huajsj commented on a change in pull request #9494:
URL: https://github.com/apache/tvm/pull/9494#discussion_r747999857



##########
File path: src/runtime/pipeline/pipeline_scheduler.cc
##########
@@ -26,12 +27,97 @@ namespace runtime {
  * \brief Initialize the pipeline.
  * \param modules The list of graph executor modules.
  * \param pipeline_conf The dependency information of each graph executor 
module.
+ * \return A list of backend runtime modules.
  */
-size_t PipelineScheduler::PipelineInit(const std::vector<Module>& modules,
-                                       const PipelineConfig& pipeline_config) {
+std::vector<std::shared_ptr<BackendRuntime>> PipelineScheduler::PipelineInit(
+    const std::vector<Module>& modules, ConfigPipelineExecution 
pipeline_config) {
+  std::vector<std::shared_ptr<BackendRuntime>> runtimes;
   graph_modules_ = modules;
-  int num_output = pipeline_config.GetGlobalOutputNum();
-  return num_output;
+  for (size_t i = 0; i < graph_modules_.size(); i++) {
+    auto runItem = std::make_shared<BackendRuntime>(graph_modules_[i], i);
+    runtimes.push_back(runItem);
+  }
+  // Initialize the outputs array.
+  auto& global_output_map = pipeline_config.GetGlobalConfigOutputBindings();
+  for (size_t i = 0; i < global_output_map.size(); i++) {
+    if (global_output_map.find(i) == global_output_map.end()) {
+      LOG(FATAL) << "Not find global output index " << i;
+    }
+    ModuleOutputPair& output_pair = global_output_map[i];
+    NDArray output = 
runtimes[output_pair.mod_idx]->CreateFromOutput(output_pair.output_idx);
+    output_array.push_back(output);
+  }
+  return runtimes;
 }
+
+/*!
+ * \brief Execute in the serialized mode.
+ * \param runtimes A list of backend runtime modules.
+ * \param pipeline_config The dependency information of each graph executor 
module.
+ */
+void PipelineScheduler::PipelineRunSerial(
+    const std::vector<std::shared_ptr<BackendRuntime>>& runtimes,
+    ConfigPipelineExecution pipeline_config) {
+  for (size_t i = 0; i < runtimes.size(); i++) {
+    // The offset in the vector is the runtime execution order; this offset value
+    // should be the same as the value of "runtime_idx" in the runtime.
+    if (static_cast<int>(i) != runtimes[i]->GetModuleIndex()) {
+      LOG(FATAL) << "runtime index " << runtimes[i]->GetModuleIndex()
+                 << " is not same as vector offset value " << i;
+    }
+
+    if (!pipeline_config.FindModuleInConfig(i)) {
+      LOG(FATAL) << "Not find the configuration for the module " << i;
+    }
+
+    runtimes[i]->Run();
+    // Check whether any output needs to be forwarded to another graph module
+    // or exposed as a global output.
+    int outputs_num = runtimes[i]->NumOutputs();
+    for (int j = 0; j < outputs_num; j++) {
+      ConfigBindings& out_binding = pipeline_config[i][j];
+      std::unordered_map<int, std::string>& input_connections = 
out_binding.Get();
+      NDArray output = runtimes[i]->GetOutput(j);
+      for (auto bind : input_connections) {
+        // If the value of "bind.first" is less than 0, this is not a graph
+        // module binding.
+        if (bind.first < 0) continue;
+        // Set input data for the graph module.
+        runtimes[bind.first]->SetInput(bind.second, 
const_cast<DLTensor*>(output.operator->()));
+      }
+      // Store the output.
+      if (out_binding.IsGlobalOutput()) {
+        int global_idx = out_binding.GetGlobalOutputIndex();
+        TVMArrayCopyFromTo(const_cast<DLTensor*>(output.operator->()),
+                           
const_cast<DLTensor*>(output_array[global_idx].operator->()), nullptr);
+      }
+    }
+  }
+}
+/*!
+ * \brief Execute pipeline.
+ * \param runtimes A list of backend runtime modules.
+ * \param pipeline_config The dependency information of each graph executor 
module.
+ * \param serialize_mode If the execution is serialized.
+ */
+void PipelineScheduler::PipelineRun(const 
std::vector<std::shared_ptr<BackendRuntime>>& runtimes,
+                                    ConfigPipelineExecution pipeline_config, 
bool serialize_mode) {
+  if (!serialize_mode) {
+    // TODO(huajsj) remove this check after all of pipeline features in.
+    LOG(FATAL) << "Currently Only supports serialized mode.";
+  } else {
+    PipelineRunSerial(runtimes, pipeline_config);
+  }
+}
+/*!
+ * \brief Stop the pipeline execution.
+ */
+void PipelineScheduler::PipelineStop() {
+  // TODO(huajsj) Remove this.
+  std::cout << __FUNCTION__ << std::endl;

Review comment:
       removed




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to