This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fd18751  Add support for using the VM across the RPC boundary.  (#7746)
fd18751 is described below

commit fd18751e684f96df888184ab1796b929b32c86f2
Author: Jared Roesch <[email protected]>
AuthorDate: Tue Mar 30 02:01:40 2021 -0700

    Add support for using the VM across the RPC boundary.  (#7746)
    
    * Get basic version of VM RPC working
    
    * Test case passes
    
    * Clean up PR
    
    * Lint
    
    * Format
    
    * Address Andrew R and TK feedback
    
    * Add comment for Andrew
    
    * Address Zhi's comment
    
    * Format
    
    * Fix broken test
---
 include/tvm/runtime/vm/executable.h | 34 ++++++++++++++---
 python/tvm/runtime/module.py        | 23 ++++++++---
 python/tvm/runtime/vm.py            | 37 +++++++++++++++++-
 src/relay/backend/vm/compiler.cc    |  8 ++--
 src/runtime/library_module.cc       | 42 ++++++++++----------
 src/runtime/library_module.h        | 12 ++++++
 src/runtime/vm/executable.cc        | 76 ++++++++++++++++++++++++++++++++++++-
 src/runtime/vm/vm.cc                |  9 +++--
 tests/python/relay/test_vm.py       | 46 +++++++++++++++++++++-
 9 files changed, 245 insertions(+), 42 deletions(-)

diff --git a/include/tvm/runtime/vm/executable.h b/include/tvm/runtime/vm/executable.h
index 8d3f651..95c6d6f 100644
--- a/include/tvm/runtime/vm/executable.h
+++ b/include/tvm/runtime/vm/executable.h
@@ -64,6 +64,19 @@ class Executable : public ModuleNode {
   PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& 
sptr_to_self) final;
 
   /*!
+   * \brief Write the Executable to the binary stream in serialized form.
+   * \param stream The binary stream to save the executable to.
+   */
+  void SaveToBinary(dmlc::Stream* stream) final;
+
+  /*!
+   * \brief Write the Executable to the provided path as a file containing its serialized content.
+   * \param path The path to write the serialized data to.
+   * \param format The format of the serialized blob.
+   */
+  void SaveToFile(const std::string& path, const std::string& format) final;
+
+  /*!
    * \brief Serialize the executable into global section, constant section, and
    * code section.
    *
@@ -125,12 +138,24 @@ class Executable : public ModuleNode {
    * \brief Get the `lib` module in an executable. Users have the flexibility 
to call
    * `export_library` from the frontend to save the library to disk.
    *
-   * \return The runtime module that contains the hardwre dependent code.
+   * \return The runtime module that contains the hardware dependent code.
+   */
+  runtime::Module GetLib() const;
+
+  /*!
+   * \brief Set the `lib` module in an executable.
+   *
+   * This allows us to do partial initialization during (de)serialization.
+   * This method also ensures the library is initialized correctly by checking that
+   * only a single library is ever imported.
+   *
+   * NB: This also provides some abstraction over how libraries are stored, as there are plans
+   * to iterate on the way runtime::Module works in the backend of the compiler.
    */
-  runtime::Module GetLib() const { return lib; }
+  void SetLib(const runtime::Module& lib);
 
   /*!
-   * \brief Get the arity of the VM Fucntion.
+   * \brief Get the arity of the VMFunction.
    * \param func Function name.
    * \return The number of parameters.
    */
@@ -148,9 +173,6 @@ class Executable : public ModuleNode {
 
   const char* type_key() const final { return "VMExecutable"; }
 
-  /*! \brief The runtime module/library that contains both the host and also 
the device
-   * code when executing on non-CPU devices. */
-  runtime::Module lib;
   /*! \brief The global constant pool. */
   std::vector<ObjectRef> constants;
   /*! \brief A map from globals (as strings) to their index in the function 
map. */
diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py
index 5165ae0..d36554b 100644
--- a/python/tvm/runtime/module.py
+++ b/python/tvm/runtime/module.py
@@ -269,10 +269,14 @@ class Module(object):
         return self._collect_from_import_tree(is_dso_exportable)
 
     def export_library(self, file_name, fcompile=None, addons=None, 
workspace_dir=None, **kwargs):
-        """Export the module and its imported device code one library.
+        """
+        Export the module and all imported modules into a single device library.
 
-        This function only works on host llvm modules.
-        It will pack all the imported modules
+        This function works on host LLVM modules. Other runtime::Module
+        subclasses are also supported, but they must fully implement the
+        save and load mechanisms of modules, including saving to streams
+        and files. This will pack your non-shared library module into a
+        single shared library which can later be loaded by TVM.
 
         Parameters
         ----------
@@ -280,13 +284,20 @@ class Module(object):
             The name of the shared library.
 
         fcompile : function(target, file_list, kwargs), optional
-            Compilation function to use create dynamic library.
+            The compilation function used to create the final library object during
+            export.
+
+            For example, when fcompile=_cc.create_shared, or when it is not supplied and
+            the module is "llvm", this is used to link all produced artifacts
+            into a final dynamic library.
+
+            This behavior is controlled by the type of object exported.
             If fcompile has attribute object_format, will compile host library
             to that format. Otherwise, will use default format "o".
 
         workspace_dir : str, optional
-            the path to a directory used to create intermediary
-            artifacts for the process exporting of the library.
+            The path of the directory used to create the intermediate
+            artifacts when exporting the module.
             If this is not provided a temporary dir will be created.
 
         kwargs : dict, optional
diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py
index a503da5..d0de052 100644
--- a/python/tvm/runtime/vm.py
+++ b/python/tvm/runtime/vm.py
@@ -23,6 +23,7 @@ Implements a Python interface to executing the compiled VM object.
 import numpy as np
 
 import tvm
+from tvm.runtime import Module
 from tvm._ffi.runtime_ctypes import TVMByteArray
 from tvm._ffi import base as _base
 from .object import Object
@@ -299,12 +300,44 @@ class VirtualMachine(object):
     POOLED_ALLOCATOR = 2
 
     def __init__(self, exe, device, memory_cfg=None):
-        if not isinstance(exe, Executable):
+        """
+        Construct a VirtualMachine wrapper class which provides a simple
+        interface over the raw C++ Module based API.
+
+        Parameters
+        ----------
+        exe: Union[Executable, Module]
+            The executable, either as the Python wrapper type or as the raw runtime.Module.
+
+            In most cases this will be the Python wrapper class tvm.runtime.vm.Executable, but
+            if you instead have the underlying runtime.Module subclass (i.e. `exe.mod`), you
+            can pass it directly to this method.
+
+            This case can occur when doing things such as RPC, where TVM's module APIs
+            return the raw modules, not the wrapped modules. This constructor will
+            handle this internally by wrapping the raw module in an Executable.
+
+        device: Union[Device, List[Device]]
+            The device, or devices on which to execute the VM code.
+
+        memory_cfg: Optional[str]
+            The allocator behavior to use for the VM.
+
+        Returns
+        -------
+        vm: VirtualMachine
+            A VM wrapper object.
+        """
+        if not isinstance(exe, Executable) and not isinstance(exe, Module):
             raise TypeError(
                 "exe is expected to be the type of Executable, "
                 + "but received {}".format(type(exe))
             )
-        self.module = _ffi_api._VirtualMachine(exe.module)
+
+        if not isinstance(exe, Executable):
+            exe = Executable(exe)
+
+        self.module = exe.mod["vm_load_executable"]()
         self._exec = exe
         self._init = self.module["init"]
         self._invoke = self.module["invoke"]
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index dafaed1..906250c 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -1155,18 +1155,20 @@ void VMCompiler::Codegen() {
 
   auto compile_engine = CompileEngine::Global();
   auto ext_mods = compile_engine->LowerExternalFunctions();
+  runtime::Module lib;
   if (funcs.size() > 0) {
     Map<String, IRModule> build_funcs;
     for (const auto& i : funcs) {
       build_funcs.Set(i.first, i.second);
     }
-    exec_->lib = tvm::build(build_funcs, target_host_);
+    lib = tvm::build(build_funcs, target_host_);
   } else {
     // There is no function handled by TVM. We create a virtual main module
     // to make sure a DSO module will be also available.
-    exec_->lib = codegen::CSourceModuleCreate(";", "", Array<String>{});
+    lib = codegen::CSourceModuleCreate(";", "", Array<String>{});
   }
-  exec_->lib = codegen::CreateMetadataModule(params_, exec_->lib, ext_mods, 
target_host_);
+  lib = codegen::CreateMetadataModule(params_, lib, ext_mods, target_host_);
+  exec_->SetLib(lib);
 }
 
 ExprDeviceMap VMCompiler::AnalyzeContext() const {
diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc
index 30ef214..370dc83 100644
--- a/src/runtime/library_module.cc
+++ b/src/runtime/library_module.cc
@@ -99,6 +99,28 @@ void InitContextFunctions(std::function<void*(const char*)> fgetsymbol) {
 #undef TVM_INIT_CONTEXT_FUNC
 }
 
+Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) 
{
+  std::string loadkey = "runtime.module.loadbinary_";
+  std::string fkey = loadkey + type_key;
+  const PackedFunc* f = Registry::Get(fkey);
+  if (f == nullptr) {
+    std::string loaders = "";
+    for (auto name : Registry::ListNames()) {
+      if (name.find(loadkey, 0) == 0) {
+        if (loaders.size() > 0) {
+          loaders += ", ";
+        }
+        loaders += name.substr(loadkey.size());
+      }
+    }
+    LOG(FATAL) << "Binary was created using " << type_key
+               << " but a loader of that name is not registered. Available 
loaders are " << loaders
+               << ". Perhaps you need to recompile with this runtime enabled.";
+  }
+
+  return (*f)(static_cast<void*>(stream));
+}
+
 /*!
  * \brief Load and append module blob to module list
  * \param mblob The module blob.
@@ -133,25 +155,7 @@ runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr<Library> lib) {
       ICHECK(stream->Read(&import_tree_row_ptr));
       ICHECK(stream->Read(&import_tree_child_indices));
     } else {
-      std::string loadkey = "runtime.module.loadbinary_";
-      std::string fkey = loadkey + tkey;
-      const PackedFunc* f = Registry::Get(fkey);
-      if (f == nullptr) {
-        std::string loaders = "";
-        for (auto name : Registry::ListNames()) {
-          if (name.rfind(loadkey, 0) == 0) {
-            if (loaders.size() > 0) {
-              loaders += ", ";
-            }
-            loaders += name.substr(loadkey.size());
-          }
-        }
-        ICHECK(f != nullptr)
-            << "Binary was created using " << tkey
-            << " but a loader of that name is not registered. Available 
loaders are " << loaders
-            << ". Perhaps you need to recompile with this runtime enabled.";
-      }
-      Module m = (*f)(static_cast<void*>(stream));
+      auto m = LoadModuleFromBinary(tkey, stream);
       modules.emplace_back(m);
     }
   }
diff --git a/src/runtime/library_module.h b/src/runtime/library_module.h
index 91918c1..00c79e8 100644
--- a/src/runtime/library_module.h
+++ b/src/runtime/library_module.h
@@ -29,9 +29,21 @@
 #include <tvm/runtime/module.h>
 
 #include <functional>
+#include <string>
 
 namespace tvm {
 namespace runtime {
+
+/*! \brief Load a module with the given type key directly from the stream.
+ *  This function wraps the registry mechanism used to store type-based deserializers
+ *  for each runtime::Module sub-class.
+ *
+ * \param type_key The type key of the serialized module.
+ * \param stream A pointer to the stream containing the serialized module.
+ * \return module The deserialized module.
+ */
+Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream);
+
 /*!
  * \brief Library is the common interface
  *  for storing data in the form of shared libaries.
diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc
index 6992097..e8b948d 100644
--- a/src/runtime/vm/executable.cc
+++ b/src/runtime/vm/executable.cc
@@ -37,6 +37,8 @@
 #include <utility>
 #include <vector>
 
+#include "../file_utils.h"
+#include "../library_module.h"
 #include "serialize_utils.h"
 
 namespace tvm {
@@ -74,6 +76,12 @@ PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtr<Obje
       int index = args[1];
       *rv = this->GetFunctionParameterName(func_name, index);
     });
+  } else if (name == "vm_load_executable") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      auto vm = make_object<VirtualMachine>();
+      vm->LoadExecutable(this);
+      *rv = Module(vm);
+    });
   } else {
     LOG(FATAL) << "Unknown packed function: " << name;
     return PackedFunc(nullptr);
@@ -475,9 +483,37 @@ void LoadHeader(dmlc::Stream* strm) {
   STREAM_CHECK(version == TVM_VERSION, "version");
 }
 
+runtime::Module Executable::GetLib() const {
+  ICHECK_LE(this->imports_.size(), 1)
+      << "The kernel library must be imported as the only module in an 
Executable";
+
+  if (this->imports().size() == 0) {
+    return Module(nullptr);
+  } else {
+    return this->imports_[0];
+  }
+}
+
+void Executable::SetLib(const runtime::Module& lib) {
+  ICHECK(lib.defined()) << "the provided library can not be null";
+
+  ICHECK_EQ(this->imports_.size(), 0)
+      << "A VMExecutable should never have more than one import inside an the 
executable, \n"
+      << "the first import should *always* be the library containing"
+      << "the platform specific kernel code";
+
+  this->Import(lib);
+}
+
 runtime::Module Executable::Load(const std::string& code, const 
runtime::Module lib) {
   auto exec = make_object<Executable>();
-  exec->lib = lib;
+
+  // Support null-initialization of lib, to enable initialization during
+  // deserialization before we have deserialized the imports.
+  if (lib.defined()) {
+    exec->SetLib(lib);
+  }
+
   exec->code_ = code;
   dmlc::MemoryStringStream strm(&exec->code_);
 
@@ -765,6 +801,44 @@ void Executable::LoadCodeSection(dmlc::Stream* strm) {
   }
 }
 
+void Executable::SaveToBinary(dmlc::Stream* stream) {
+  auto code_bytes = this->Save();
+  std::string code(code_bytes.data, code_bytes.size);
+  stream->Write(code);
+
+  ICHECK(this->imports()[0].defined()) << "the library must be imported before 
serialization";
+}
+
+Module ExecutableLoadBinary(void* strm) {
+  dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
+  std::string code;
+  stream->Read(&code);
+  auto exec = Executable::Load(code, Module());
+  return exec;
+}
+
+void Executable::SaveToFile(const std::string& path, const std::string& 
format) {
+  std::string data;
+  dmlc::MemoryStringStream writer(&data);
+  dmlc::SeekStream* strm = &writer;
+  SaveToBinary(strm);
+  SaveBinaryToFile(path, data);
+}
+
+TVM_REGISTER_GLOBAL("runtime.module.loadbinary_VMExecutable").set_body_typed(ExecutableLoadBinary);
+
+// Load an executable from a file containing its serialized binary form.
+Module ExecutableLoadFile(const std::string& file_name, const std::string& 
format) {
+  std::string data;
+  LoadBinaryFromFile(file_name, &data);
+  dmlc::MemoryStringStream reader(&data);
+  dmlc::Stream* strm = &reader;
+  auto exec = ExecutableLoadBinary(reinterpret_cast<void*>(strm));
+  return exec;
+}
+
+TVM_REGISTER_GLOBAL("runtime.module.loadfile_VMExecutable").set_body_typed(ExecutableLoadFile);
+
 TVM_REGISTER_GLOBAL("runtime.GetNumOfGlobals").set_body([](TVMArgs args, 
TVMRetValue* rv) {
   runtime::Module mod = args[0];
   const auto* exec = dynamic_cast<Executable*>(mod.operator->());
diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc
index ee06da8..76ca009 100644
--- a/src/runtime/vm/vm.cc
+++ b/src/runtime/vm/vm.cc
@@ -281,11 +281,12 @@ void VirtualMachine::LoadExecutable(const Executable* exec) {
   ICHECK(exec) << "The executable is not created yet.";
   exec_ = exec;
 
-  runtime::Module lib = exec_->lib;
-  // Get the list of packed functions.
+  runtime::Module lib = exec_->GetLib();
+
   ICHECK(exec->primitive_map.empty() || lib.operator->())
-      << "runtime module should have been built for primitive functions"
-      << "\n";
+      << "If the executable has declared primitive functions, the"
+      << "generated kernel library must non-be null.";
+
   for (const auto& it : exec_->primitive_map) {
     const auto& packed_name = it.first;
     auto packed_index = static_cast<size_t>(it.second);
diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py
index 4ecd0d9..c1bdc3f 100644
--- a/tests/python/relay/test_vm.py
+++ b/tests/python/relay/test_vm.py
@@ -19,11 +19,14 @@ import pytest
 
 import tvm
 from tvm import runtime
-from tvm import relay
+from tvm import relay, IRModule
+from tvm.relay.backend import vm
 from tvm.relay.scope_builder import ScopeBuilder
 from tvm.relay.prelude import Prelude
 from tvm.relay.loops import while_loop
 from tvm.relay import testing
+from tvm.contrib import utils
+from tvm import rpc
 import tvm.testing
 
 
@@ -799,5 +802,46 @@ def test_constant_shape_with_external_codegen():
     assert "shape_func" in opt_mod.astext(False)
 
 
+def test_vm_rpc():
+    """
+    This test checks to make sure you can export a VMExecutable,
+    upload it to a remote machine using RPC and then execute it
+    on the other machine.
+    """
+    target = "llvm"
+    target_host = "llvm"
+
+    # Build an IRModule.
+    x = relay.var("x", shape=(10, 1))
+    f = relay.Function([x], x + x)
+    mod = IRModule.from_expr(f)
+
+    # Compile to VMExecutable.
+    vm_exec = vm.compile(mod, target=target, target_host=target_host)
+
+    # Export to disk.
+    temp = utils.tempdir()
+    path = temp.relpath("vm_library.so")
+    vm_exec.mod.export_library(path)
+
+    # Use LocalRPC for testing.
+    remote = rpc.LocalSession()
+
+    # Upload the serialized Executable.
+    remote.upload(path)
+    # Get a handle to remote Executable.
+    rexec = remote.load_module("vm_library.so")
+
+    ctx = remote.cpu()
+    # Build a VM out of the executable and context.
+    vm_factory = runtime.vm.VirtualMachine(rexec, ctx)
+    np_input = np.random.uniform(size=(10, 1)).astype("float32")
+    input_tensor = tvm.nd.array(np_input, ctx)
+    # Invoke its "main" function.
+    out = vm_factory.invoke("main", [input_tensor])
+    # Check the result.
+    np.testing.assert_allclose(out.asnumpy(), np_input + np_input)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

Reply via email to