This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8a94b6699a [Runtime][PipelineExecutor] Added Interface to Track Number
of Global Inputs (#11315)
8a94b6699a is described below
commit 8a94b6699a16e688c2da26c5e83bf52e671d94fc
Author: Raghav Chakravarthy <[email protected]>
AuthorDate: Fri Jun 17 15:36:31 2022 -0400
[Runtime][PipelineExecutor] Added Interface to Track Number of Global
Inputs (#11315)
* [Runtime][PipelineExecutor] Added Interface to Track Number of Global
Inputs
Added a feature to PipelineExecutor to track the number of global inputs.
* Fixed CI Error
* Fixed remaining CI Error
---
python/tvm/contrib/pipeline_executor.py | 11 +++++++++++
src/runtime/pipeline/pipeline_executor.cc | 8 +++++++-
src/runtime/pipeline/pipeline_executor.h | 1 +
src/runtime/pipeline/pipeline_struct.h | 3 +++
tests/python/relay/test_pipeline_executor.py | 2 ++
5 files changed, 24 insertions(+), 1 deletion(-)
diff --git a/python/tvm/contrib/pipeline_executor.py
b/python/tvm/contrib/pipeline_executor.py
index a50fffaa2b..5ef309bb28 100644
--- a/python/tvm/contrib/pipeline_executor.py
+++ b/python/tvm/contrib/pipeline_executor.py
@@ -55,6 +55,7 @@ class PipelineModule(object):
self._get_input = self.module["get_input"]
self._get_output = self.module["get_output"]
self._get_num_outputs = self.module["get_num_outputs"]
+ self._get_num_inputs = self.module["get_num_inputs"]
self._get_input_pipeline_map = self.module["get_input_pipeline_map"]
self._get_pipe_execute_count = self.module["get_execute_count"]
@@ -159,6 +160,16 @@ class PipelineModule(object):
"""
return self._get_num_outputs()
+ @property
+ def num_inputs(self):
+ """Get the number of global inputs of the pipeline.
+ Returns
+ -------
+ count : int
+ The number of global (pipeline-level) inputs.
+ """
+ return self._get_num_inputs()
+
@staticmethod
def load_library(config_file_name):
"""Import files to create a pipeline executor.
diff --git a/src/runtime/pipeline/pipeline_executor.cc
b/src/runtime/pipeline/pipeline_executor.cc
index a191f816f7..b5c560e255 100644
--- a/src/runtime/pipeline/pipeline_executor.cc
+++ b/src/runtime/pipeline/pipeline_executor.cc
@@ -34,6 +34,9 @@ PackedFunc PipelineExecutor::GetFunction(const std::string&
name,
if (name == "get_num_outputs") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv =
this->NumOutputs(); });
+ } else if (name == "get_num_inputs") {
+ return PackedFunc(
+ [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv =
this->NumInputs(); });
} else if (name == "get_input_pipeline_map") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (String::CanConvertFrom(args[0])) {
@@ -87,7 +90,10 @@ PackedFunc PipelineExecutor::GetFunction(const std::string&
name,
return PackedFunc();
}
}
-
+/*!
+ * \brief Returns the number of global inputs, as reported by input_connection_config_.
+ */
+int PipelineExecutor::NumInputs(void) { return
input_connection_config_.GetInputNum(); }
/*!
* \brief set input to the runtime module.
* \param input_name The input name.
diff --git a/src/runtime/pipeline/pipeline_executor.h
b/src/runtime/pipeline/pipeline_executor.h
index 9f9b24bdf0..87b50ed3a1 100644
--- a/src/runtime/pipeline/pipeline_executor.h
+++ b/src/runtime/pipeline/pipeline_executor.h
@@ -115,6 +115,7 @@ class TVM_DLL PipelineExecutor : public ModuleNode {
int NumOutputs() const { return num_outputs_; }
/*!\brief Run the pipeline executor.*/
void Run();
+ int NumInputs();
/*!
* \brief Get a list output data.
* \return A list of output data.
diff --git a/src/runtime/pipeline/pipeline_struct.h
b/src/runtime/pipeline/pipeline_struct.h
index 2cb7b4a6d2..540103d018 100644
--- a/src/runtime/pipeline/pipeline_struct.h
+++ b/src/runtime/pipeline/pipeline_struct.h
@@ -560,6 +560,9 @@ struct InputConnectionConfig {
}
return input_connection[key];
}
+ /*!\brief Returns the number of global inputs, i.e. the number of entries in
input_runtime_map (the container size is implicitly narrowed to int).*/
+ int GetInputNum() { return input_runtime_map.size(); }
+
/*!
* \brief Getting the global input index through the input name.
* \param input_name The global input name.
diff --git a/tests/python/relay/test_pipeline_executor.py
b/tests/python/relay/test_pipeline_executor.py
index 541f3bba13..06614977d4 100644
--- a/tests/python/relay/test_pipeline_executor.py
+++ b/tests/python/relay/test_pipeline_executor.py
@@ -595,6 +595,8 @@ def test_pipeline():
if input_map[0] == "0":
input_data = pipeline_module_test.get_input("data_a")
tvm.testing.assert_allclose(data, input_data.numpy())
+
+ assert pipeline_module_test.num_inputs == 2
# Running the pipeline executor in the pipeline mode.
pipeline_module_test.run()