huajsj commented on a change in pull request #9494:
URL: https://github.com/apache/tvm/pull/9494#discussion_r748003421



##########
File path: src/runtime/pipeline/pipeline_struct.h
##########
@@ -134,37 +220,316 @@ struct OutputMap {
     }
   }
 };
+
+/*!
+ * \brief A map of the global module input interfaces and the graph modules 
input interfaces.
+ */
+struct InputConnectionConfig {
+  /*!\brief The key is the name of global module input interfaces. The value 
is the pair of
+   * the index of a graph module and the name of a graph module input 
interface.
+   */
+  std::unordered_map<std::string, std::pair<int, std::string>> 
input_connection;
+  /*! \brief Return true if no input connection has been loaded. */
+  bool Empty() { return input_connection.empty(); }
+  /*!
+   * \brief Look up the (module index, module input name) of a global input.
+   *  Reports LOG(FATAL) when the key is not present.
+   */
+  std::pair<int, std::string> operator[](const std::string key) {
+    if (input_connection.find(key) == input_connection.end()) {
+      LOG(FATAL) << "Not find the key " << key;
+    }
+    return input_connection[key];
+  }
+
+  /*! \brief Return the number of global input connections. */
+  size_t size() const { return input_connection.size(); }
+  /*!
+   * \brief Create an input connection config from JSONReader.
+   *  Expects an array of objects with the keys "global_interface_name",
+   *  "mod_idx" and "module_interface_name"; any other key is fatal. Each
+   *  entry needs a non-negative mod_idx and non-empty names; a repeated
+   *  global_interface_name overwrites the earlier entry.
+   * \param reader Json reader.
+   */
+  void Load(dmlc::JSONReader* reader) {
+    reader->BeginArray();
+    while (reader->NextArrayItem()) {
+      reader->BeginObject();
+      std::string key;
+      std::string global_interface_name;
+      std::string module_interface_name;
+      int mod_idx = -1;
+      while (reader->NextObjectItem(&key)) {
+        if (key == "global_interface_name") {
+          reader->Read(&global_interface_name);
+        } else if (key == "mod_idx") {
+          reader->Read(&mod_idx);
+        } else if (key == "module_interface_name") {
+          reader->Read(&module_interface_name);
+        } else {
+          LOG(FATAL) << "do not support key " << key;
+        }
+      }
+      ICHECK(mod_idx >= 0) << "Invalid mod_idx value " << mod_idx;
+      ICHECK(!global_interface_name.empty()) << "Invalid global interface name 
value";
+      ICHECK(!module_interface_name.empty()) << "Invalid module interface name 
value";
+      input_connection[global_interface_name] = make_pair(mod_idx, 
module_interface_name);
+    }
+  }
+};
+/*!
+ * \brief A map of the global module param interfaces and the graph modules 
params.
+ */
+struct ParamConnectionConfig {
+  /*!\brief The key is the name of global module param interfaces. The value 
is the
+   * index of a graph module.
+   */
+  std::unordered_map<std::string, int> param_connection;
+  /*! \brief Return true if no param connection has been loaded. */
+  bool Empty() { return param_connection.empty(); }
+  /*!
+   * \brief Look up the graph module index of a global param name.
+   * NOTE(review): a missing key is reported as "do not support key", the
+   * same text Load() uses for an unexpected JSON key.
+   */
+  int operator[](const std::string key) {
+    if (param_connection.find(key) == param_connection.end()) {
+      LOG(FATAL) << "do not support key " << key;
+    }
+    return param_connection[key];
+  }
+  /*!
+   * \brief Create a param connection config from JSONReader.
+   *  Expects an array of objects with the keys "global_param_name" and
+   *  "mod_idx"; any other key is fatal. Each entry needs a non-negative
+   *  mod_idx and a non-empty name; a repeated name overwrites the old one.
+   * \param reader Json reader.
+   */
+  void Load(dmlc::JSONReader* reader) {
+    reader->BeginArray();
+    while (reader->NextArrayItem()) {
+      reader->BeginObject();
+      std::string key;
+      std::string global_param_name;
+      int mod_idx = -1;
+      while (reader->NextObjectItem(&key)) {
+        if (key == "global_param_name") {
+          reader->Read(&global_param_name);
+        } else if (key == "mod_idx") {
+          reader->Read(&mod_idx);
+        } else {
+          LOG(FATAL) << "do not support key " << key;
+        }
+      }
+      ICHECK(mod_idx >= 0) << "Invalid mod_idx value " << mod_idx;
+      ICHECK(!global_param_name.empty()) << "Invalid global param name value";
+      param_connection[global_param_name] = mod_idx;
+    }
+  }
+};
 /*!
  * \brief The binding or dependency information of each module output 
interface.
  */
-struct PipelineConfig {
-  /*!\brief The key is the module index, this variable records all module 
pipeline configuration
+class ConfigPipelineExecution {
+ private:
+  /*!
+   * \brief The key is the module index, this variable records all module 
pipeline configuration
    * information.
    */
-  std::unordered_map<int, OutputMap> config;
-  OutputMap& operator[](int key) {
+  std::unordered_map<int, ConfigOutputBindings> config;
+  /*!
+   * \brief The key is the global output index, this variable records the 
mapping of global output
+   * and the module output.
+   */
+  std::unordered_map<int, ModuleOutputPair> global_output_map;
+  /*!
+   * \brief The number of bindings between module outputs and module inputs.
+   * NOTE(review): declared without an initializer; it is only assigned in
+   * ParseConfiguration().
+   */
+  size_t module_input_output_binding_total_num;
+
+ public:
+  /*! \brief Return the bindings of module "key"; ICHECK-fails if absent. */
+  ConfigOutputBindings& operator[](int key) {
     ICHECK(config.find(key) != config.end());
     return config[key];
   }
+  /*!
+   * \brief Check whether the module index exists in "config".
+   */
+  bool FindModuleInConfig(int mod_idx) { return config.find(mod_idx) != 
config.end(); }
+  /*!
+   * \brief Insert or overwrite the "ConfigOutputBindings" of module index key.
+   */
+  void Insert(int key, const ConfigOutputBindings& map) { config[key] = map; }
 
-  void Insert(int key, const OutputMap& map) { config[key] = map; }
-
-  /*!\brief This function is used to verify whether config is loaded 
successfully.
+  /*!
+   * \brief This function is used to verify whether config is loaded 
successfully.
    * \return Return true to indicate that this class has not been successfully 
loaded.
    */
   bool Empty() { return config.empty(); }
-
   /*!
    * \brief Get the number of global outputs.
    * \return The number of outputs the entire pipeline has.
    */
   size_t GetGlobalOutputNum() const {
-    size_t num_output = 0;
+    // The number of pipeline outputs is the size of "global_output_map".
+    return global_output_map.size();
+  }
+  /*!
+   * \brief Get the map of global outputs and module outputs.
+   */
+  std::unordered_map<int, ModuleOutputPair>& 
GetGlobalConfigOutputBindings(void) {
+    return global_output_map;
+  }
+  /*!
+   * \brief Get the number of module output and module input bindings.
+   */
+  size_t GetInputOutputBindingNum() const { return 
module_input_output_binding_total_num; }
+  /*!
+   * \brief Parse the config to build the data structures used in pipeline 
execution.
+   * It resets the binding counter and fills "global_output_map", keyed by
+   * global output index.
+   * NOTE(review): the "config" parameter shadows the member of the same name,
+   * "global_output_map" is not cleared before being refilled, and the loops
+   * below iterate by value (copying each element).
+   */
+  void ParseConfiguration(const std::unordered_map<int, ConfigOutputBindings>& 
config) {
+    if (config.empty()) {
+      LOG(FATAL) << "The Configuration loading not finish yet.";
+    }
+    module_input_output_binding_total_num = 0;
     for (auto mod_output : config) {
-      num_output += mod_output.second.GetGlobalOutputNum();
+      // Accumulate the number of module input/output bindings.
+      module_input_output_binding_total_num += 
mod_output.second.GetInputOutputBindingNum();
+      // Use global output index as key to create a mapping of global index 
and module output.
+      const std::vector<GlobalOutputPair>& global_output =
+          mod_output.second.GetGlobalConfigOutputBindings();
+
+      for (auto output : global_output) {
+        global_output_map[output.global_output_idx] =
+            ModuleOutputPair(mod_output.first, output.mod_output_idx);
+      }
     }
-    return num_output;
+    return;
+  }
+  /*!
+   * \brief Create a pipeline config from JSONReader.
+   *  Expects an array of objects with the keys "mod_idx", "dev" and
+   *  "output"; any other key is fatal. "dev" is read but not used here.
+   * \param reader Json reader.
+   */
+  void Load(dmlc::JSONReader* reader) {
+    reader->BeginArray();
+    while (reader->NextArrayItem()) {
+      std::string key;
+      reader->BeginObject();
+      int mod_idx = -1;
+      ConfigOutputBindings output;
+      std::string dev;
+      while (reader->NextObjectItem(&key)) {
+        if (key == "mod_idx") {
+          reader->Read(&mod_idx);
+        } else if (key == "dev") {
+          reader->Read(&dev);
+        } else if (key == "output") {
+          reader->Read(&output);
+        } else {
+          LOG(FATAL) << "do not support key " << key;
+        }
+      }
+      ICHECK(mod_idx >= 0) << "Invalid mod_idx value " << mod_idx;
+      // Check if the output is successfully read.
+      ICHECK(!output.Empty()) << "Invalid output binding result.";
+      Insert(mod_idx, output);
+    }
+    // Build the derived maps once every module's bindings are loaded.
+    ParseConfiguration(config);
+  }
+};
+/*
+ *\brief Runtime of backend.
+ */
+class BackendRuntime {
+ private:
+  /*\brief The index of runtime indicate the position in the pipeline.*/
+  int runtime_idx;
+  /*\brief The Runtime module of a backedn graph executor.*/
+  Module module;
+  /*!
+   *\brief To transfer data between two different backends, we need a local
+   * tensor variable as a medium. This variable is a mapping of input data and 
local
+   * data.
+   */
+  std::unordered_map<DLTensor*, DLTensor*> input_tensor_local_copy;
+  /*!\brief The packed functions.*/
+  tvm::runtime::PackedFunc run;
+  tvm::runtime::PackedFunc set_input;
+  tvm::runtime::PackedFunc get_input;
+  tvm::runtime::PackedFunc get_output;
+  tvm::runtime::PackedFunc get_num_output;
+  tvm::runtime::PackedFunc get_num_inputs;
+  tvm::runtime::PackedFunc get_input_index;
+  /*!\brief The new DLTensor have same shape, data type with a existing 
DLTensor.*/
+  DLTensor* CreateFromDLTensor(const DLTensor* from) {
+    DLTensor* ret = NULL;
+    TVMArrayAlloc(from->shape, from->ndim, from->dtype.code, from->dtype.bits, 
from->dtype.lanes,
+                  kDLCPU, 0, &ret);

Review comment:
       fixed.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to