This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8a94b6699a [Runtime][PipelineExecutor] Added Interface to Track Number of Global Inputs (#11315)
8a94b6699a is described below

commit 8a94b6699a16e688c2da26c5e83bf52e671d94fc
Author: Raghav Chakravarthy <[email protected]>
AuthorDate: Fri Jun 17 15:36:31 2022 -0400

    [Runtime][PipelineExecutor] Added Interface to Track Number of Global Inputs (#11315)
    
    * [Runtime][PipelineExecutor] Added Interface to Track Number of Global Inputs
    
    Added a feature to PipelineExecutor to track number of Global Inputs.
    
    * Fixed CI Error
    
    * Fixed remaining CI Error
---
 python/tvm/contrib/pipeline_executor.py      | 11 +++++++++++
 src/runtime/pipeline/pipeline_executor.cc    |  8 +++++++-
 src/runtime/pipeline/pipeline_executor.h     |  1 +
 src/runtime/pipeline/pipeline_struct.h       |  3 +++
 tests/python/relay/test_pipeline_executor.py |  2 ++
 5 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/python/tvm/contrib/pipeline_executor.py b/python/tvm/contrib/pipeline_executor.py
index a50fffaa2b..5ef309bb28 100644
--- a/python/tvm/contrib/pipeline_executor.py
+++ b/python/tvm/contrib/pipeline_executor.py
@@ -55,6 +55,7 @@ class PipelineModule(object):
         self._get_input = self.module["get_input"]
         self._get_output = self.module["get_output"]
         self._get_num_outputs = self.module["get_num_outputs"]
+        self._get_num_inputs = self.module["get_num_inputs"]
         self._get_input_pipeline_map = self.module["get_input_pipeline_map"]
         self._get_pipe_execute_count = self.module["get_execute_count"]
 
@@ -159,6 +160,16 @@ class PipelineModule(object):
         """
         return self._get_num_outputs()
 
+    @property
+    def num_inputs(self):
+        """Get the number of inputs
+        Returns
+        -------
+        count : int
+            The number of inputs
+        """
+        return self._get_num_inputs()
+
     @staticmethod
     def load_library(config_file_name):
         """Import files to create a pipeline executor.
diff --git a/src/runtime/pipeline/pipeline_executor.cc b/src/runtime/pipeline/pipeline_executor.cc
index a191f816f7..b5c560e255 100644
--- a/src/runtime/pipeline/pipeline_executor.cc
+++ b/src/runtime/pipeline/pipeline_executor.cc
@@ -34,6 +34,9 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name,
   if (name == "get_num_outputs") {
     return PackedFunc(
         [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = 
this->NumOutputs(); });
+  } else if (name == "get_num_inputs") {
+    return PackedFunc(
+        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = 
this->NumInputs(); });
   } else if (name == "get_input_pipeline_map") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
       if (String::CanConvertFrom(args[0])) {
@@ -87,7 +90,10 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name,
     return PackedFunc();
   }
 }
-
+/*!
+ * \brief Returns the number of global inputs.
+ */
+int PipelineExecutor::NumInputs(void) { return 
input_connection_config_.GetInputNum(); }
 /*!
  * \brief set input to the runtime module.
  * \param input_name The input name.
diff --git a/src/runtime/pipeline/pipeline_executor.h b/src/runtime/pipeline/pipeline_executor.h
index 9f9b24bdf0..87b50ed3a1 100644
--- a/src/runtime/pipeline/pipeline_executor.h
+++ b/src/runtime/pipeline/pipeline_executor.h
@@ -115,6 +115,7 @@ class TVM_DLL PipelineExecutor : public ModuleNode {
   int NumOutputs() const { return num_outputs_; }
   /*!\brief Run the pipeline executor.*/
   void Run();
+  int NumInputs();
   /*!
    * \brief Get a list output data.
    * \return A list of output data.
diff --git a/src/runtime/pipeline/pipeline_struct.h b/src/runtime/pipeline/pipeline_struct.h
index 2cb7b4a6d2..540103d018 100644
--- a/src/runtime/pipeline/pipeline_struct.h
+++ b/src/runtime/pipeline/pipeline_struct.h
@@ -560,6 +560,9 @@ struct InputConnectionConfig {
     }
     return input_connection[key];
   }
+  /*!\brief Returns the number of global inputs through the input_runtime_map list size.*/
+  int GetInputNum() { return input_runtime_map.size(); }
+
   /*!
    * \brief Getting the global input index through the input name.
    * \param input_name The global input name.
diff --git a/tests/python/relay/test_pipeline_executor.py b/tests/python/relay/test_pipeline_executor.py
index 541f3bba13..06614977d4 100644
--- a/tests/python/relay/test_pipeline_executor.py
+++ b/tests/python/relay/test_pipeline_executor.py
@@ -595,6 +595,8 @@ def test_pipeline():
                 if input_map[0] == "0":
                     input_data = pipeline_module_test.get_input("data_a")
                     tvm.testing.assert_allclose(data, input_data.numpy())
+
+                assert pipeline_module_test.num_inputs == 2
                 # Running the pipeline executor in the pipeline mode.
                 pipeline_module_test.run()
 

Reply via email to