huajsj commented on a change in pull request #9494:
URL: https://github.com/apache/tvm/pull/9494#discussion_r747999521



##########
File path: src/runtime/pipeline/pipeline_executor.cc
##########
@@ -34,13 +36,134 @@ PackedFunc PipelineExecutor::GetFunction(const 
std::string& name,
   if (name == "get_num_outputs") {
     return PackedFunc(
         [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = 
this->NumOutputs(); });
+  } else if (name == "get_num_inputs") {
+    return PackedFunc(
+        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = 
this->NumInputs(); });
+  } else if (name == "set_input") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      if (String::CanConvertFrom(args[0])) {
+        this->SetInput(args[0].operator String(), args[1]);
+      } else {
+        LOG(FATAL) << "Function only support the input name value in the form 
of string";
+      }
+    });
+  } else if (name == "set_param") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      if (String::CanConvertFrom(args[0]) && String::CanConvertFrom(args[1])) {
+        this->SetParam(args[0].operator String(), args[1].operator String(), 
args[2]);
+      } else {
+        LOG(FATAL) << "Function only support the params name and keyin the 
form of string";
+      }
+    });
+  } else if (name == "get_output") {
+    return PackedFunc(
+        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = 
this->GetOutput(); });
+  } else if (name == "get_input") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      if (String::CanConvertFrom(args[0])) {
+        *rv = this->GetInput(args[0].operator String());
+      } else {
+        LOG(FATAL) << "Function only support the input name value in the form 
of string";
+      }
+    });
+  } else if (name == "run") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { 
this->Run(args[0]); });
+  } else if (name == "stop") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { 
this->Stop(); });
   } else {
     LOG(FATAL) << "Unknown packed function: " << name;
     return PackedFunc();
   }
   return nullptr;
 }
+/*!
+ * \brief There are some input called pipeline global input that user need to 
use function
+    "set_input" to set the data for it, this function return the number of 
such global input.
+   \return Return the number of pipeline global input.
+ */
+
+int PipelineExecutor::NumInputs() const {
+  // The number of inputs obtained from the input configuration.
+  size_t config_inputs_num = input_connection_config.size(), ret = 0;
+  // The number of inputs obtained from the graph runtime and pipeline 
configuration.
+  size_t internal_inputs_num = pipeline_config_.GetInputOutputBindingNum();
+  for (auto runtime : runtimes_) {
+    ret += runtime->NumInputs();
+  }
+  // Use the summary of all backend runtime module input number to minus the 
internal inputs
+  // number, then we will get the pipeline global input number
+  ret -= internal_inputs_num;
+  // Check whether these two numbers are equal.
+  if (config_inputs_num != ret) {
+    LOG(FATAL) << "The number of inputs from the configuration file is 
inconsistent!";
+  }
+  return ret;
+}
+/*!
+ * \brief Return the input index and module index for a given input name.
+ * \param name The input name.
+ * \return std::pair<int, int> The module index and the input index.
+ */
+std::pair<int, int> PipelineExecutor::GetInputIndex(const std::string& name) {
+  std::pair<int, std::string> index = input_connection_config[name];
+  auto gruntime = runtimes_[index.first];
+  return std::make_pair(index.first, gruntime->GetInputIndex(index.second));
+}
+/*!
+ * \brief Return the module index for a given input param name.
+ * \param name The params name.
+ * \return int The module index.
+ */
+int PipelineExecutor::GetParamModuleIndex(const std::string& name) {
+  return param_connection_config[name];
+}
+/*!
+ * \brief set input to the graph module.
+ * \param input_name The input name.
+ * \param data_in The input data.
+ */
+void PipelineExecutor::SetInput(std::string input_name, DLTensor* data_in) {
+  std::pair<int, int> indexs = this->GetInputIndex(input_name);
+  runtimes_[indexs.first]->SetInput(indexs.second, data_in);

Review comment:
       Add a bounds check on `indexs.first` (the module index) before using it to index `runtimes_`.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to