ANSHUMAN87 commented on a change in pull request #5753:
URL: https://github.com/apache/incubator-tvm/pull/5753#discussion_r449782535



##########
File path: python/tvm/runtime/module.py
##########
@@ -282,7 +292,19 @@ def export_library(self,
             self.save(file_name)
             return
 
-        modules = self._collect_dso_modules()
+        graph_runtime_factory_modules = 
self._collect_modules("GraphRuntimeFactory")

Review comment:
       Nit: Since "GraphRuntimeFactory" is a specific type key used for this 
purpose, maybe we should maintain a function in graph_runtime_factory.py that 
returns this key, so that it can be easily updated later.

##########
File path: python/tvm/runtime/module.py
##########
@@ -282,7 +292,19 @@ def export_library(self,
             self.save(file_name)
             return
 
-        modules = self._collect_dso_modules()
+        graph_runtime_factory_modules = 
self._collect_modules("GraphRuntimeFactory")
+        for index, module in enumerate(graph_runtime_factory_modules):
+            if not package_params:
+                module.get_function("diable_package_params")()
+                path_params = os.path.join(os.path.dirname(file_name), 
"deploy_" + str(index) + ".params")

Review comment:
       I think "deploy_1.params" will be difficult for users to relate to a 
specific module based on keys, as the indexing is internal to the TVM implementation.
   
   Maybe we could use "deploy_<key>.params" instead?
   In a real-world example it would look like "deploy_resnet18.params".

##########
File path: src/runtime/graph/graph_runtime_factory.cc
##########
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file graph_runtime_factory.cc
+ * \brief Graph runtime factory implementations
+ */
+
+#include "./graph_runtime_factory.h"
+
+#include <tvm/node/container.h>
+#include <tvm/runtime/registry.h>
+
+#include <algorithm>
+#include <iterator>
+#include <vector>
+
+#include "./graph_runtime.h"
+
+namespace tvm {
+namespace runtime {
+
+void GraphRuntimeFactory::Init(const std::string& kind, const std::string& 
graph_json,
+                               const std::unordered_map<std::string, 
tvm::runtime::NDArray>& params,
+                               const std::string& module_name) {
+  kind_ = kind;
+  graph_json_ = graph_json;
+  params_ = params;
+  module_name_ = module_name;
+  graph_runtime_factory_module_list_.push_back(module_name_);

Review comment:
       Should we clear this list before inserting?

##########
File path: python/tvm/runtime/module.py
##########
@@ -222,29 +223,31 @@ def evaluator(*args):
         except NameError:
             raise NameError("time_evaluate is only supported when RPC is 
enabled")
 
-    def _collect_dso_modules(self):
-        """Helper function to collect dso modules, then return it."""
-        visited, stack, dso_modules = set(), [], []
+    def _collect_modules(self, module_type_keys):
+        """Helper function to collect specifit modules, then return it."""
+        visited, stack, modules = set(), [], []
+        type_keys = module_type_keys if isinstance(module_type_keys, (list, 
tuple)) else [module_type_keys]
         # append root module
         visited.add(self)
         stack.append(self)
         while stack:
             module = stack.pop()
-            if module._dso_exportable():
-                dso_modules.append(module)
+            if module.type_key in type_keys:
+                modules.append(module)
             for m in module.imported_modules:
                 if m not in visited:
                     visited.add(m)
                     stack.append(m)
-        return dso_modules
+        return modules
 
     def _dso_exportable(self):

Review comment:
       Maybe we should change the function name now, as the behavior has also 
changed from boolean --> [type_keys]?

##########
File path: python/tvm/runtime/module.py
##########
@@ -222,29 +223,31 @@ def evaluator(*args):
         except NameError:
             raise NameError("time_evaluate is only supported when RPC is 
enabled")
 
-    def _collect_dso_modules(self):
-        """Helper function to collect dso modules, then return it."""
-        visited, stack, dso_modules = set(), [], []
+    def _collect_modules(self, module_type_keys):
+        """Helper function to collect specifit modules, then return it."""

Review comment:
       specifit -> specific

##########
File path: src/runtime/graph/graph_runtime.h
##########
@@ -94,6 +98,7 @@ class TVM_DLL GraphRuntime : public ModuleNode {
    *  processor.
    * \param ctxs The context of the host and devices where graph nodes will be
    *  executed on.
+   * \param params The params of graph.

Review comment:
       Should we add params here?

##########
File path: python/tvm/runtime/module.py
##########
@@ -282,7 +292,19 @@ def export_library(self,
             self.save(file_name)
             return
 
-        modules = self._collect_dso_modules()
+        graph_runtime_factory_modules = 
self._collect_modules("GraphRuntimeFactory")
+        for index, module in enumerate(graph_runtime_factory_modules):
+            if not package_params:
+                module.get_function("diable_package_params")()
+                path_params = os.path.join(os.path.dirname(file_name), 
"deploy_" + str(index) + ".params")
+                from tvm import relay
+                with open(path_params, "wb") as fo:
+                    graph_params = {}
+                    for k, v in module.get_function("get_params")().items():
+                        graph_params[k] = v
+                    fo.write(relay.save_param_dict(graph_params))

Review comment:
       Should we inform the user about the files that were written, since 
package_params is disabled?

##########
File path: src/runtime/graph/graph_runtime_factory.cc
##########
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file graph_runtime_factory.cc
+ * \brief Graph runtime factory implementations
+ */
+
+#include "./graph_runtime_factory.h"
+
+#include <tvm/node/container.h>
+#include <tvm/runtime/registry.h>
+
+#include <algorithm>
+#include <iterator>
+#include <vector>
+
+#include "./graph_runtime.h"
+
+namespace tvm {
+namespace runtime {
+
+void GraphRuntimeFactory::Init(const std::string& kind, const std::string& 
graph_json,
+                               const std::unordered_map<std::string, 
tvm::runtime::NDArray>& params,
+                               const std::string& module_name) {
+  kind_ = kind;
+  graph_json_ = graph_json;
+  params_ = params;
+  module_name_ = module_name;
+  graph_runtime_factory_module_list_.push_back(module_name_);
+}
+
+void GraphRuntimeFactory::ImportModule(Module other) {
+  this->Import(other);
+  auto module = other.as<GraphRuntimeFactory>();
+  CHECK(module) << "should only import graph runtime factory module";
+  graph_runtime_factory_module_list_.push_back(module->GetModuleName());
+}
+
+PackedFunc GraphRuntimeFactory::GetFunction(
+    const std::string& name, const 
tvm::runtime::ObjectPtr<tvm::runtime::Object>& sptr_to_self) {
+  if (name == "runtime_create") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      std::vector<TVMContext> contexts;
+      TVMContext ctx;
+      // arg is: module, ctxs
+      CHECK_EQ((args.size() - 1) % 2, 0);
+      for (int i = 1; i < args.num_args; i += 2) {
+        int dev_type = args[i];
+        ctx.device_type = static_cast<DLDeviceType>(dev_type);
+        ctx.device_id = args[i + 1];
+        contexts.push_back(ctx);
+      }
+      *rv = this->RuntimeCreate(args[0], contexts);
+    });
+  } else if (name == "import_module") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      CHECK_EQ(args.size(), 1);
+      this->ImportModule(args[0]);
+    });
+  } else if (name == "select_module") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      CHECK_EQ(args.size(), 1);
+      *rv = this->SelectModule(args[0]);
+    });
+  } else if (name == "get_json") {
+    return PackedFunc(
+        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = 
this->GetJson(); });
+  } else if (name == "get_lib") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      CHECK_GT(this->imports().size(), 0);
+      *rv = this->GetLib();
+    });
+  } else if (name == "get_params") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      Map<String, tvm::runtime::NDArray> ret;
+      for (const auto& kv : this->GetParams()) {
+        ret.Set(kv.first, kv.second);
+      }
+      *rv = ret;
+    });
+  } else if (name == "diable_package_params") {
+    return PackedFunc(
+        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { 
this->package_params_ = false; });
+  } else {
+    return PackedFunc();
+  }
+}
+
+void GraphRuntimeFactory::SaveToBinary(dmlc::Stream* stream) {
+  stream->Write(graph_runtime_factory_module_list_);
+  stream->Write(kind_);
+  stream->Write(graph_json_);
+  stream->Write(package_params_);
+  if (package_params_) {
+    std::vector<std::string> names;
+    std::vector<DLTensor*> arrays;
+    for (const auto& v : params_) {
+      names.emplace_back(v.first);
+      arrays.emplace_back(const_cast<DLTensor*>(v.second.operator->()));
+    }
+    uint64_t sz = arrays.size();
+    CHECK(sz == names.size());
+    stream->Write(sz);
+    stream->Write(names);
+    for (size_t i = 0; i < sz; ++i) {
+      tvm::runtime::SaveDLTensor(stream, arrays[i]);
+    }
+  }
+}
+
+Module GraphRuntimeFactory::RuntimeCreate(Module module, const 
std::vector<TVMContext>& ctxs) {
+  auto factory_module = module.as<GraphRuntimeFactory>();
+  CHECK(factory_module != nullptr);
+  if (factory_module->GetKind() == "graph") {
+    auto exec = make_object<GraphRuntime>();
+    exec->Init(factory_module->GetJson(), factory_module->GetLib(), ctxs);
+    exec->SetParams(factory_module->GetParams());
+    return Module(exec);
+  }
+
+  return Module();
+}
+
+Module GraphRuntimeFactory::SelectModule(const std::string& name) {
+  auto iter = std::find(graph_runtime_factory_module_list_.begin(),
+                        graph_runtime_factory_module_list_.end(), name);
+  CHECK(iter != graph_runtime_factory_module_list_.end());
+  if (iter == graph_runtime_factory_module_list_.begin()) {
+    auto exec = make_object<GraphRuntimeFactory>();
+    exec->Init(this->GetKind(), this->GetJson(), this->GetParams());
+    exec->Import(this->GetLib());
+    return Module(exec);
+  } else {
+    return 
this->imports()[std::distance(graph_runtime_factory_module_list_.begin(), 
iter)];
+  }
+}
+
+Module GraphRuntimeFactoryModuleLoadBinary(void* strm) {
+  dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
+  std::vector<std::string> graph_runtime_factory_module_list;
+  std::string kind;
+  std::string graph_json;
+  bool package_params;
+  std::unordered_map<std::string, tvm::runtime::NDArray> params;
+  CHECK(stream->Read(&graph_runtime_factory_module_list));
+  CHECK(stream->Read(&kind));
+  CHECK(stream->Read(&graph_json));
+  CHECK(stream->Read(&package_params));
+  if (package_params) {
+    uint64_t sz;
+    CHECK(stream->Read(&sz));
+    std::vector<std::string> names;
+    CHECK(stream->Read(&names));
+    CHECK(sz == names.size());
+    for (size_t i = 0; i < sz; ++i) {
+      tvm::runtime::NDArray temp;
+      temp.Load(stream);
+      params[names[i]] = temp;
+    }
+  }
+  auto exec = make_object<GraphRuntimeFactory>();
+  exec->Init(kind, graph_json, params);
+  exec->SetGraphRuntimeFactoryModuleList(graph_runtime_factory_module_list);
+  return Module(exec);
+}
+
+Module RuntimeCreate(Module module, const std::vector<TVMContext>& ctxs) {
+  auto mod = module.as<GraphRuntimeFactory>();
+  CHECK(mod != nullptr);
+  if (mod->GetKind() == "graph") {
+    auto exec = make_object<GraphRuntime>();
+    exec->Init(mod->GetJson(), mod->GetLib(), ctxs);
+    exec->SetParams(mod->GetParams());
+    return Module(exec);
+  } else {
+    LOG(ERROR) << "Doesn't support graph kind of " << mod->GetKind();
+  }
+
+  return Module();
+}
+
+TVM_REGISTER_GLOBAL("tvm.graph_runtime_factory.create").set_body([](TVMArgs 
args, TVMRetValue* rv) {
+  CHECK_GE(args.num_args, 4) << "The expected number of arguments for "
+                                "graph_runtime_factory.create needs at least 
3, "

Review comment:
       The check requires at least 4 arguments, but the error message says 3; 
please double-check.

##########
File path: src/runtime/graph/graph_runtime.h
##########
@@ -64,7 +68,7 @@ struct TVMOpParam {
  *  This runtime can be acccesibly in various language via
  *  TVM runtime PackedFunc API.
  */
-class TVM_DLL GraphRuntime : public ModuleNode {
+class TVM_DLL GraphRuntime : public GraphRuntimeFactory {

Review comment:
       I am sorry, I am a little confused here. Do we really need the properties of 
GraphRuntimeFactory in GraphRuntime?
   As I understand it, GraphRuntime should be emitted by GraphRuntimeFactory, not 
the reverse. Maybe I did not understand it clearly. Would you please help me 
understand it better? Thanks!

##########
File path: src/runtime/graph/graph_runtime_factory.cc
##########
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file graph_runtime_factory.cc
+ * \brief Graph runtime factory implementations
+ */
+
+#include "./graph_runtime_factory.h"
+
+#include <tvm/node/container.h>
+#include <tvm/runtime/registry.h>
+
+#include <algorithm>
+#include <iterator>
+#include <vector>
+
+#include "./graph_runtime.h"
+
+namespace tvm {
+namespace runtime {
+
+void GraphRuntimeFactory::Init(const std::string& kind, const std::string& 
graph_json,
+                               const std::unordered_map<std::string, 
tvm::runtime::NDArray>& params,
+                               const std::string& module_name) {
+  kind_ = kind;
+  graph_json_ = graph_json;
+  params_ = params;
+  module_name_ = module_name;
+  graph_runtime_factory_module_list_.push_back(module_name_);
+}
+
+void GraphRuntimeFactory::ImportModule(Module other) {
+  this->Import(other);
+  auto module = other.as<GraphRuntimeFactory>();
+  CHECK(module) << "should only import graph runtime factory module";
+  graph_runtime_factory_module_list_.push_back(module->GetModuleName());
+}
+
+PackedFunc GraphRuntimeFactory::GetFunction(
+    const std::string& name, const 
tvm::runtime::ObjectPtr<tvm::runtime::Object>& sptr_to_self) {
+  if (name == "runtime_create") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      std::vector<TVMContext> contexts;
+      TVMContext ctx;
+      // arg is: module, ctxs
+      CHECK_EQ((args.size() - 1) % 2, 0);
+      for (int i = 1; i < args.num_args; i += 2) {
+        int dev_type = args[i];
+        ctx.device_type = static_cast<DLDeviceType>(dev_type);
+        ctx.device_id = args[i + 1];
+        contexts.push_back(ctx);
+      }
+      *rv = this->RuntimeCreate(args[0], contexts);
+    });
+  } else if (name == "import_module") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      CHECK_EQ(args.size(), 1);
+      this->ImportModule(args[0]);
+    });
+  } else if (name == "select_module") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      CHECK_EQ(args.size(), 1);
+      *rv = this->SelectModule(args[0]);
+    });
+  } else if (name == "get_json") {
+    return PackedFunc(
+        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = 
this->GetJson(); });
+  } else if (name == "get_lib") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      CHECK_GT(this->imports().size(), 0);
+      *rv = this->GetLib();
+    });
+  } else if (name == "get_params") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      Map<String, tvm::runtime::NDArray> ret;
+      for (const auto& kv : this->GetParams()) {
+        ret.Set(kv.first, kv.second);
+      }
+      *rv = ret;
+    });
+  } else if (name == "diable_package_params") {
+    return PackedFunc(
+        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { 
this->package_params_ = false; });
+  } else {
+    return PackedFunc();
+  }
+}
+
+void GraphRuntimeFactory::SaveToBinary(dmlc::Stream* stream) {
+  stream->Write(graph_runtime_factory_module_list_);

Review comment:
       Should we add the module list at the end, after the metadata of 
GraphRuntimeFactory? Maybe we can get others' opinions too.

##########
File path: src/runtime/graph/graph_runtime_factory.h
##########
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/runtime/graph_runtime_factory.h
+ * \brief Graph runtime factory creating graph runtime.
+ */
+
+#ifndef TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_FACTORY_H_
+#define TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_FACTORY_H_
+
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+namespace tvm {
+namespace runtime {
+
+class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode {
+ public:
+  /*!
+   * \brief Initialize the GraphRuntimeFactory with graph and context.
+   * \param graph_json The execution graph.
+   * \param params The params of graph.
+   * \param kind The runtime kind to be created.
+   */
+  void Init(const std::string& kind, const std::string& graph_json,
+            const std::unordered_map<std::string, tvm::runtime::NDArray>& 
params,
+            const std::string& module_name = "default");
+
+  /*!
+   * \brief Import other GraphRuntimeFactory module.
+   * \param other The GraphRuntimeFactory module we want to import.
+   */
+  void ImportModule(Module other);
+
+  /*!
+   * \brief Get member function to front-end
+   * \param name The name of the function.
+   * \param sptr_to_self The pointer to the module node.
+   * \return The corresponding member function.
+   */
+  virtual PackedFunc GetFunction(const std::string& name, const 
ObjectPtr<Object>& sptr_to_self);
+
+  /*!
+   * \return The type key of the executor.
+   */
+  const char* type_key() const override { return "GraphRuntimeFactory"; }
+
+  /*!
+   * \brief Save the module to binary stream.
+   * \param stream The binary stream to save to.
+   */
+  void SaveToBinary(dmlc::Stream* stream) override;
+
+  /*!
+   * \brief Create a specific runtime module
+   * \param module The module we will be used for creating runtime
+   * \param ctxs The context of the host and devices where graph nodes will be
+   *  executed on.
+   * \return created runtime module
+   */
+  Module RuntimeCreate(Module module, const std::vector<TVMContext>& ctxs);
+
+  /*!
+   * \brief Select the specific module
+   * \param name The name of the module
+   * \return selected module
+   */
+  Module SelectModule(const std::string& name);
+
+  const std::string& GetJson() const { return graph_json_; }
+
+  std::unordered_map<std::string, tvm::runtime::NDArray> GetParams() const { 
return params_; }
+
+  Module GetLib() const {
+    CHECK_GT(this->imports().size(), 0);
+    return this->imports_[0];
+  }
+
+  const std::string& GetKind() const { return kind_; }
+
+  const std::string& GetModuleName() const { return module_name_; }
+
+  const std::vector<std::string>& GetGraphRuntimeFactoryModuleList() const {
+    return graph_runtime_factory_module_list_;
+  }
+
+  void SetGraphRuntimeFactoryModuleList(
+      const std::vector<std::string>& graph_runtime_factory_module_list) {
+    graph_runtime_factory_module_list_ = graph_runtime_factory_module_list;
+  }
+
+ protected:
+  /*! \brief The execution graph. */
+  std::string graph_json_;
+  /*! \brief The params. */
+  std::unordered_map<std::string, tvm::runtime::NDArray> params_;

Review comment:
       I have one small question here. Do these params contain the parameters for 
all the imported modules, or only for the first one?

##########
File path: src/runtime/graph/graph_runtime_factory.cc
##########
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file graph_runtime_factory.cc
+ * \brief Graph runtime factory implementations
+ */
+
+#include "./graph_runtime_factory.h"
+
+#include <tvm/node/container.h>
+#include <tvm/runtime/registry.h>
+
+#include <algorithm>
+#include <iterator>
+#include <vector>
+
+#include "./graph_runtime.h"
+
+namespace tvm {
+namespace runtime {
+
+void GraphRuntimeFactory::Init(const std::string& kind, const std::string& 
graph_json,
+                               const std::unordered_map<std::string, 
tvm::runtime::NDArray>& params,
+                               const std::string& module_name) {
+  kind_ = kind;
+  graph_json_ = graph_json;
+  params_ = params;
+  module_name_ = module_name;
+  graph_runtime_factory_module_list_.push_back(module_name_);
+}
+
+void GraphRuntimeFactory::ImportModule(Module other) {
+  this->Import(other);
+  auto module = other.as<GraphRuntimeFactory>();
+  CHECK(module) << "should only import graph runtime factory module";
+  graph_runtime_factory_module_list_.push_back(module->GetModuleName());
+}
+
+PackedFunc GraphRuntimeFactory::GetFunction(
+    const std::string& name, const 
tvm::runtime::ObjectPtr<tvm::runtime::Object>& sptr_to_self) {
+  if (name == "runtime_create") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      std::vector<TVMContext> contexts;
+      TVMContext ctx;
+      // arg is: module, ctxs
+      CHECK_EQ((args.size() - 1) % 2, 0);
+      for (int i = 1; i < args.num_args; i += 2) {
+        int dev_type = args[i];
+        ctx.device_type = static_cast<DLDeviceType>(dev_type);
+        ctx.device_id = args[i + 1];
+        contexts.push_back(ctx);
+      }
+      *rv = this->RuntimeCreate(args[0], contexts);
+    });
+  } else if (name == "import_module") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      CHECK_EQ(args.size(), 1);
+      this->ImportModule(args[0]);
+    });
+  } else if (name == "select_module") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      CHECK_EQ(args.size(), 1);
+      *rv = this->SelectModule(args[0]);
+    });
+  } else if (name == "get_json") {
+    return PackedFunc(
+        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = 
this->GetJson(); });
+  } else if (name == "get_lib") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      CHECK_GT(this->imports().size(), 0);

Review comment:
       The same check exists in GetLib() too; is it required here?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to