This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch vm_rpc_support
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit de8be3464236a8e4161b6a6e3a2414e5166e5442
Author: Jared Roesch <[email protected]>
AuthorDate: Thu Mar 25 16:45:06 2021 -0700

    Get basic version of VM RPC working
---
 include/tvm/runtime/vm/executable.h | 10 ++++-
 python/tvm/runtime/module.py        |  9 ++--
 src/relay/backend/vm/compiler.cc    |  1 +
 src/runtime/library_module.cc       | 43 +++++++++++--------
 src/runtime/library_module.h        |  3 ++
 src/runtime/vm/executable.cc        | 86 +++++++++++++++++++++++++++++++++++++
 src/runtime/vm/vm.cc                | 19 ++++++--
 7 files changed, 145 insertions(+), 26 deletions(-)

diff --git a/include/tvm/runtime/vm/executable.h b/include/tvm/runtime/vm/executable.h
index 8d3f651..1b300b5 100644
--- a/include/tvm/runtime/vm/executable.h
+++ b/include/tvm/runtime/vm/executable.h
@@ -64,6 +64,14 @@ class Executable : public ModuleNode {
   PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;
 
   /*!
+   * \brief Save the entire executable to a binary stream.
+   * \param stream The binary stream to save to.
+   */
+  void SaveToBinary(dmlc::Stream* stream) final;
+
+  void SaveToFile(const std::string& path, const std::string& format) final;
+
+  /*!
    * \brief Serialize the executable into global section, constant section, and
    * code section.
    *
@@ -125,7 +133,7 @@ class Executable : public ModuleNode {
    * \brief Get the `lib` module in an executable. Users have the flexibility to call
    * `export_library` from the frontend to save the library to disk.
    *
-   * \return The runtime module that contains the hardwre dependent code.
+   * \return The runtime module that contains the hardware dependent code.
    */
   runtime::Module GetLib() const { return lib; }
 
diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py
index 09bef9e..712d728 100644
--- a/python/tvm/runtime/module.py
+++ b/python/tvm/runtime/module.py
@@ -269,10 +269,13 @@ class Module(object):
         return self._collect_from_import_tree(is_dso_exportable)
 
     def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=None, **kwargs):
-        """Export the module and its imported device code one library.
+        """
+        Export the module and all imported modules into a single device library.
 
-        This function only works on host llvm modules.
-        It will pack all the imported modules
+        This function only works on host LLVM modules, other runtime::Module
+        subclasses DO NOT work with this API. If you do in fact have an LLVM
+        module, this API will pack the module with all imported modules into
+        a single binary library which can be used with TVM.
 
         Parameters
         ----------
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 9d3ffc5..c4f8ba6 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -1167,6 +1167,7 @@ void VMCompiler::Codegen() {
     exec_->lib = codegen::CSourceModuleCreate(";", "", Array<String>{});
   }
   exec_->lib = codegen::CreateMetadataModule(params_, exec_->lib, ext_mods, target_host_);
+  exec_->Import(exec_->lib);
 }
 
 ExprDeviceMap VMCompiler::AnalyzeContext() const {
diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc
index 30ef214..9699a94 100644
--- a/src/runtime/library_module.cc
+++ b/src/runtime/library_module.cc
@@ -99,6 +99,29 @@ void InitContextFunctions(std::function<void*(const char*)> fgetsymbol) {
 #undef TVM_INIT_CONTEXT_FUNC
 }
 
+Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) {
+  std::string loadkey = "runtime.module.loadbinary_";
+  std::string fkey = loadkey + type_key;
+  const PackedFunc* f = Registry::Get(fkey);
+  if (f == nullptr) {
+    std::string loaders = "";
+    for (auto name : Registry::ListNames()) {
+      if (name.rfind(loadkey, 0) == 0) {
+        if (loaders.size() > 0) {
+          loaders += ", ";
+        }
+        loaders += name.substr(loadkey.size());
+      }
+    }
+    ICHECK(f != nullptr)
+        << "Binary was created using " << type_key
+        << " but a loader of that name is not registered. Available loaders are " << loaders
+        << ". Perhaps you need to recompile with this runtime enabled.";
+  }
+
+  return (*f)(static_cast<void*>(stream));
+}
+
 /*!
  * \brief Load and append module blob to module list
  * \param mblob The module blob.
@@ -133,25 +156,7 @@ runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr<Library> lib) {
       ICHECK(stream->Read(&import_tree_row_ptr));
       ICHECK(stream->Read(&import_tree_child_indices));
     } else {
-      std::string loadkey = "runtime.module.loadbinary_";
-      std::string fkey = loadkey + tkey;
-      const PackedFunc* f = Registry::Get(fkey);
-      if (f == nullptr) {
-        std::string loaders = "";
-        for (auto name : Registry::ListNames()) {
-          if (name.rfind(loadkey, 0) == 0) {
-            if (loaders.size() > 0) {
-              loaders += ", ";
-            }
-            loaders += name.substr(loadkey.size());
-          }
-        }
-        ICHECK(f != nullptr)
-            << "Binary was created using " << tkey
-            << " but a loader of that name is not registered. Available loaders are " << loaders
-            << ". Perhaps you need to recompile with this runtime enabled.";
-      }
-      Module m = (*f)(static_cast<void*>(stream));
+      auto m = LoadModuleFromBinary(tkey, stream);
       modules.emplace_back(m);
     }
   }
diff --git a/src/runtime/library_module.h b/src/runtime/library_module.h
index 91918c1..75bd287 100644
--- a/src/runtime/library_module.h
+++ b/src/runtime/library_module.h
@@ -32,6 +32,9 @@
 
 namespace tvm {
 namespace runtime {
+
+Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream);
+
 /*!
  * \brief Library is the common interface
  *  for storing data in the form of shared libaries.
diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc
index 6992097..3246f19 100644
--- a/src/runtime/vm/executable.cc
+++ b/src/runtime/vm/executable.cc
@@ -38,6 +38,8 @@
 #include <vector>
 
 #include "serialize_utils.h"
+#include "../library_module.h"
+#include "../file_utils.h"
 
 namespace tvm {
 namespace runtime {
@@ -74,6 +76,12 @@ PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtr<Obje
       int index = args[1];
       *rv = this->GetFunctionParameterName(func_name, index);
     });
+  } else if (name == "create_vm") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      auto vm = make_object<VirtualMachine>();
+      vm->LoadExecutable(this);
+      *rv = Module(vm);
+    });
   } else {
     LOG(FATAL) << "Unknown packed function: " << name;
     return PackedFunc(nullptr);
@@ -476,11 +484,17 @@ void LoadHeader(dmlc::Stream* strm) {
 }
 
 runtime::Module Executable::Load(const std::string& code, const runtime::Module lib) {
+  std::cout << "code: " << code.size() << std::endl;
   auto exec = make_object<Executable>();
   exec->lib = lib;
   exec->code_ = code;
   dmlc::MemoryStringStream strm(&exec->code_);
 
+  if (lib.defined()) {
+    std::cout << "Importing: " << std::endl;
+    exec->Import(lib);
+  }
+
   // Load header.
   LoadHeader(&strm);
 
@@ -765,6 +779,78 @@ void Executable::LoadCodeSection(dmlc::Stream* strm) {
   }
 }
 
+void Executable::SaveToBinary(dmlc::Stream* stream) {
+  auto code_bytes = this->Save();
+  std::string code(code_bytes.data, code_bytes.size);
+  stream->Write(code);
+
+  CHECK(this->lib.defined())
+    << "the library must be defined before serialization";
+
+  // this->lib->SaveToBinary(stream);
+  // std::vector<std::string> names;
+  // std::vector<DLTensor*> arrays;
+  // for (const auto& v : params_) {
+  //   names.emplace_back(v.first);
+  //   arrays.emplace_back(const_cast<DLTensor*>(v.second.operator->()));
+  // }
+  // uint64_t sz = arrays.size();
+  // ICHECK(sz == names.size());
+  // stream->Write(sz);
+  // stream->Write(names);
+  // for (size_t i = 0; i < sz; ++i) {
+  //   tvm::runtime::SaveDLTensor(stream, arrays[i]);
+  // }
+}
+
+Module ExecutableLoadBinary(void* strm) {
+  dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
+  std::string code;
+  stream->Read(&code);
+  auto exec = Executable::Load(code, Module());
+  auto exec_node = exec.as<Executable>();
+  std::cout << exec_node->primitive_map.size() << std::endl;
+
+  // // std::unordered_map<std::string, tvm::runtime::NDArray> params;
+  // std::string module_name;
+  // ICHECK(stream->Read(&graph_json));
+  // uint64_t sz;
+  // ICHECK(stream->Read(&sz));
+  // std::vector<std::string> names;
+  // ICHECK(stream->Read(&names));
+  // ICHECK(sz == names.size());
+  // for (size_t i = 0; i < sz; ++i) {
+  //   tvm::runtime::NDArray temp;
+  //   temp.Load(stream);
+  //   params[names[i]] = temp;
+  // }
+
+  return exec;
+}
+
+void Executable::SaveToFile(const std::string& path, const std::string& format) {
+  std::string data;
+  dmlc::MemoryStringStream writer(&data);
+  dmlc::SeekStream* strm = &writer;
+  SaveToBinary(strm);
+  SaveBinaryToFile(path, data);
+}
+
+TVM_REGISTER_GLOBAL("runtime.module.loadbinary_VMExecutable")
+    .set_body_typed(ExecutableLoadBinary);
+
+  // Load a VM executable module from a file.
+Module ExecutableLoadFile(const std::string& file_name, const std::string& format) {
+  std::string data;
+  LoadBinaryFromFile(file_name, &data);
+  dmlc::MemoryStringStream reader(&data);
+  dmlc::Stream* strm = &reader;
+  auto exec = ExecutableLoadBinary(reinterpret_cast<void*>(strm));
+  return exec;
+}
+
+TVM_REGISTER_GLOBAL("runtime.module.loadfile_VMExecutable").set_body_typed(ExecutableLoadFile);
+
 TVM_REGISTER_GLOBAL("runtime.GetNumOfGlobals").set_body([](TVMArgs args, TVMRetValue* rv) {
   runtime::Module mod = args[0];
   const auto* exec = dynamic_cast<Executable*>(mod.operator->());
diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc
index 4683398..03bc20b 100644
--- a/src/runtime/vm/vm.cc
+++ b/src/runtime/vm/vm.cc
@@ -281,11 +281,24 @@ void VirtualMachine::LoadExecutable(const Executable* exec) {
   ICHECK(exec) << "The executable is not created yet.";
   exec_ = exec;
 
-  runtime::Module lib = exec_->lib;
+  runtime::Module lib;
+  if (exec_->lib.defined()) {
+    lib = exec_->lib;
+  } else {
+    ICHECK(exec_->imports().size() > 0)
+      << "fix";
+    lib = exec_->imports()[0];
+  }
+
   // Get the list of packed functions.
-  ICHECK(exec->primitive_map.empty() || lib.operator->())
-      << "runtime module should have been built for primitive functions"
+  ICHECK(!exec->primitive_map.empty())
+      << "runtime module primitive map is empty"
       << "\n";
+
+  ICHECK(lib.operator->())
+      << "library is null"
+      << "\n";
+
   for (const auto& it : exec_->primitive_map) {
     const auto& packed_name = it.first;
     auto packed_index = static_cast<size_t>(it.second);

Reply via email to