This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fd18751 Add support for using the VM across the RPC boundary. (#7746)
fd18751 is described below
commit fd18751e684f96df888184ab1796b929b32c86f2
Author: Jared Roesch <[email protected]>
AuthorDate: Tue Mar 30 02:01:40 2021 -0700
Add support for using the VM across the RPC boundary. (#7746)
* Get basic version of VM RPC working
* Test case passes
* Clean up PR
* Lint
* Format
* Address Andrew R and TK feedback
* Add comment for Andrew
* Address Zhi's comment
* Format
* Fix broken test
---
include/tvm/runtime/vm/executable.h | 34 ++++++++++++++---
python/tvm/runtime/module.py | 23 ++++++++---
python/tvm/runtime/vm.py | 37 +++++++++++++++++-
src/relay/backend/vm/compiler.cc | 8 ++--
src/runtime/library_module.cc | 42 ++++++++++----------
src/runtime/library_module.h | 12 ++++++
src/runtime/vm/executable.cc | 76 ++++++++++++++++++++++++++++++++++++-
src/runtime/vm/vm.cc | 9 +++--
tests/python/relay/test_vm.py | 46 +++++++++++++++++++++-
9 files changed, 245 insertions(+), 42 deletions(-)
diff --git a/include/tvm/runtime/vm/executable.h
b/include/tvm/runtime/vm/executable.h
index 8d3f651..95c6d6f 100644
--- a/include/tvm/runtime/vm/executable.h
+++ b/include/tvm/runtime/vm/executable.h
@@ -64,6 +64,19 @@ class Executable : public ModuleNode {
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>&
sptr_to_self) final;
/*!
+ * \brief Write the Executable to the binary stream in serialized form.
+ * \param stream The binary stream to save the executable to.
+ */
+ void SaveToBinary(dmlc::Stream* stream) final;
+
+ /*!
+ * \brief Write the Executable to the provided path as a file contianing its
serialized content.
+ * \param path The path to write the serialized data to.
+ * \param format The format of the serialized blob.
+ */
+ void SaveToFile(const std::string& path, const std::string& format) final;
+
+ /*!
* \brief Serialize the executable into global section, constant section, and
* code section.
*
@@ -125,12 +138,24 @@ class Executable : public ModuleNode {
* \brief Get the `lib` module in an executable. Users have the flexibility
to call
* `export_library` from the frontend to save the library to disk.
*
- * \return The runtime module that contains the hardwre dependent code.
+ * \return The runtime module that contains the hardware dependent code.
+ */
+ runtime::Module GetLib() const;
+
+ /*!
+ * \brief Set the `lib` module in an executable.
+ *
+ * This allows us to do partial initialization in the case of
(de|ser)ialization cases.
+ * This method also ensures correct initialization of library ensuring we
only Import a
+ * single library.
+ *
+ * NB: This also provides some abstraction over how libraries are stored as
there are plans
+ * to iterate on the way runtime::Module works in the backend of the
compiler.
*/
- runtime::Module GetLib() const { return lib; }
+ void SetLib(const runtime::Module& lib);
/*!
- * \brief Get the arity of the VM Fucntion.
+ * \brief Get the arity of the VMFunction.
* \param func Function name.
* \return The number of parameters.
*/
@@ -148,9 +173,6 @@ class Executable : public ModuleNode {
const char* type_key() const final { return "VMExecutable"; }
- /*! \brief The runtime module/library that contains both the host and also
the device
- * code when executing on non-CPU devices. */
- runtime::Module lib;
/*! \brief The global constant pool. */
std::vector<ObjectRef> constants;
/*! \brief A map from globals (as strings) to their index in the function
map. */
diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py
index 5165ae0..d36554b 100644
--- a/python/tvm/runtime/module.py
+++ b/python/tvm/runtime/module.py
@@ -269,10 +269,14 @@ class Module(object):
return self._collect_from_import_tree(is_dso_exportable)
def export_library(self, file_name, fcompile=None, addons=None,
workspace_dir=None, **kwargs):
- """Export the module and its imported device code one library.
+ """
+ Export the module and all imported modules into a single device
library.
- This function only works on host llvm modules.
- It will pack all the imported modules
+ This function only works on host LLVM modules, other runtime::Module
+ subclasses will work with this API but they must support implement
+ the save and load mechanisms of modules completely including saving
+ from streams and files. This will pack your non-shared library module
+ into a single shared library which can later be loaded by TVM.
Parameters
----------
@@ -280,13 +284,20 @@ class Module(object):
The name of the shared library.
fcompile : function(target, file_list, kwargs), optional
- Compilation function to use create dynamic library.
+ The compilation function to use create the final library object
during
+ export.
+
+ For example, when fcompile=_cc.create_shared, or when it is not
supplied but
+ module is "llvm," this is used to link all produced artifacts
+ into a final dynamic library.
+
+ This behavior is controlled by the type of object exported.
If fcompile has attribute object_format, will compile host library
to that format. Otherwise, will use default format "o".
workspace_dir : str, optional
- the path to a directory used to create intermediary
- artifacts for the process exporting of the library.
+ The path of the directory used to create the intermediate
+ artifacts when exporting the module.
If this is not provided a temporary dir will be created.
kwargs : dict, optional
diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py
index a503da5..d0de052 100644
--- a/python/tvm/runtime/vm.py
+++ b/python/tvm/runtime/vm.py
@@ -23,6 +23,7 @@ Implements a Python interface to executing the compiled VM
object.
import numpy as np
import tvm
+from tvm.runtime import Module
from tvm._ffi.runtime_ctypes import TVMByteArray
from tvm._ffi import base as _base
from .object import Object
@@ -299,12 +300,44 @@ class VirtualMachine(object):
POOLED_ALLOCATOR = 2
def __init__(self, exe, device, memory_cfg=None):
- if not isinstance(exe, Executable):
+ """
+ Construct a VirtualMachine wrapper class which provides a simple
+ interface over the raw C++ Module based API.
+
+ Parameters
+ ----------
+ exe: Union[Executable, Module]
+ The executable either with the wrapper Python type or the raw
runtime.Module.
+
+ In most cases this will be the Python wrapper class
tvm.runtime.vm.Executable but
+ if you instead get the underlying runtime.Module subclass (i.e
`exe.mod`) you
+ can directly pass it to this method.
+
+ This case can occur when doing things such as RPC where TVM's
module APIs
+ return the raw modules, not the wrapped modules. This constructor
will
+ handle this internally.
+
+ device: Union[Device, List[Device]]
+ The device, or devices on which to execute the VM code.
+
+ memory_cfg: Optional[str]
+ The allocator behavior to use for the VM.
+
+ Returns
+ -------
+ vm: VirtualMachine
+ A VM wrapper object.
+ """
+ if not isinstance(exe, Executable) and not isinstance(exe, Module):
raise TypeError(
"exe is expected to be the type of Executable, "
+ "but received {}".format(type(exe))
)
- self.module = _ffi_api._VirtualMachine(exe.module)
+
+ if not isinstance(exe, Executable):
+ exe = Executable(exe)
+
+ self.module = exe.mod["vm_load_executable"]()
self._exec = exe
self._init = self.module["init"]
self._invoke = self.module["invoke"]
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index dafaed1..906250c 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -1155,18 +1155,20 @@ void VMCompiler::Codegen() {
auto compile_engine = CompileEngine::Global();
auto ext_mods = compile_engine->LowerExternalFunctions();
+ runtime::Module lib;
if (funcs.size() > 0) {
Map<String, IRModule> build_funcs;
for (const auto& i : funcs) {
build_funcs.Set(i.first, i.second);
}
- exec_->lib = tvm::build(build_funcs, target_host_);
+ lib = tvm::build(build_funcs, target_host_);
} else {
// There is no function handled by TVM. We create a virtual main module
// to make sure a DSO module will be also available.
- exec_->lib = codegen::CSourceModuleCreate(";", "", Array<String>{});
+ lib = codegen::CSourceModuleCreate(";", "", Array<String>{});
}
- exec_->lib = codegen::CreateMetadataModule(params_, exec_->lib, ext_mods,
target_host_);
+ lib = codegen::CreateMetadataModule(params_, lib, ext_mods, target_host_);
+ exec_->SetLib(lib);
}
ExprDeviceMap VMCompiler::AnalyzeContext() const {
diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc
index 30ef214..370dc83 100644
--- a/src/runtime/library_module.cc
+++ b/src/runtime/library_module.cc
@@ -99,6 +99,28 @@ void InitContextFunctions(std::function<void*(const char*)>
fgetsymbol) {
#undef TVM_INIT_CONTEXT_FUNC
}
+/*!
+ * Look up the "runtime.module.loadbinary_<type_key>" deserializer in the global
+ * registry and use it to read a module from `stream`. When no deserializer is
+ * registered for `type_key`, a LOG(FATAL) error listing the available loaders is emitted.
+ */
+Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) {
+  std::string loadkey = "runtime.module.loadbinary_";
+  std::string fkey = loadkey + type_key;
+  const PackedFunc* f = Registry::Get(fkey);
+  if (f == nullptr) {
+    std::string loaders = "";
+    // Bind by reference to avoid copying every registered name; rfind(prefix, 0)
+    // only compares at position 0, i.e. a prefix test that never scans the whole name.
+    for (const auto& name : Registry::ListNames()) {
+      if (name.rfind(loadkey, 0) == 0) {
+        if (loaders.size() > 0) {
+          loaders += ", ";
+        }
+        loaders += name.substr(loadkey.size());
+      }
+    }
+    LOG(FATAL) << "Binary was created using " << type_key
+               << " but a loader of that name is not registered. Available loaders are " << loaders
+               << ". Perhaps you need to recompile with this runtime enabled.";
+  }
+
+  return (*f)(static_cast<void*>(stream));
+}
+
/*!
* \brief Load and append module blob to module list
* \param mblob The module blob.
@@ -133,25 +155,7 @@ runtime::Module ProcessModuleBlob(const char* mblob,
ObjectPtr<Library> lib) {
ICHECK(stream->Read(&import_tree_row_ptr));
ICHECK(stream->Read(&import_tree_child_indices));
} else {
- std::string loadkey = "runtime.module.loadbinary_";
- std::string fkey = loadkey + tkey;
- const PackedFunc* f = Registry::Get(fkey);
- if (f == nullptr) {
- std::string loaders = "";
- for (auto name : Registry::ListNames()) {
- if (name.rfind(loadkey, 0) == 0) {
- if (loaders.size() > 0) {
- loaders += ", ";
- }
- loaders += name.substr(loadkey.size());
- }
- }
- ICHECK(f != nullptr)
- << "Binary was created using " << tkey
- << " but a loader of that name is not registered. Available
loaders are " << loaders
- << ". Perhaps you need to recompile with this runtime enabled.";
- }
- Module m = (*f)(static_cast<void*>(stream));
+ auto m = LoadModuleFromBinary(tkey, stream);
modules.emplace_back(m);
}
}
diff --git a/src/runtime/library_module.h b/src/runtime/library_module.h
index 91918c1..00c79e8 100644
--- a/src/runtime/library_module.h
+++ b/src/runtime/library_module.h
@@ -29,9 +29,21 @@
#include <tvm/runtime/module.h>
#include <functional>
+#include <string>
namespace tvm {
namespace runtime {
+
+/*! \brief Load a module with the given type key directly from the stream.
+ * This function wraps the registry mechanism used to store type based deserializers
+ * for each runtime::Module sub-class: it looks up the global function
+ * "runtime.module.loadbinary_<type_key>" and invokes it on the stream.
+ *
+ * \param type_key The type key of the serialized module.
+ * \param stream A pointer to the stream containing the serialized module.
+ * \return module The deserialized module.
+ * \note If no loader is registered for type_key, a LOG(FATAL) error listing the
+ *       available loaders is emitted.
+ */
+Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream);
+
/*!
* \brief Library is the common interface
* for storing data in the form of shared libaries.
diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc
index 6992097..e8b948d 100644
--- a/src/runtime/vm/executable.cc
+++ b/src/runtime/vm/executable.cc
@@ -37,6 +37,8 @@
#include <utility>
#include <vector>
+#include "../file_utils.h"
+#include "../library_module.h"
#include "serialize_utils.h"
namespace tvm {
@@ -74,6 +76,12 @@ PackedFunc Executable::GetFunction(const std::string& name,
const ObjectPtr<Obje
int index = args[1];
*rv = this->GetFunctionParameterName(func_name, index);
});
+ } else if (name == "vm_load_executable") {
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+ auto vm = make_object<VirtualMachine>();
+ vm->LoadExecutable(this);
+ *rv = Module(vm);
+ });
} else {
LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc(nullptr);
@@ -475,9 +483,37 @@ void LoadHeader(dmlc::Stream* strm) {
STREAM_CHECK(version == TVM_VERSION, "version");
}
+// The kernel library is stored as the executable's sole import; return it, or a
+// null Module when no library has been attached yet.
+runtime::Module Executable::GetLib() const {
+  ICHECK_LE(this->imports_.size(), 1)
+      << "The kernel library must be imported as the only module in an Executable";
+
+  if (this->imports_.size() == 0) {
+    return Module(nullptr);
+  } else {
+    return this->imports_[0];
+  }
+}
+
+// Attach the kernel library as the first (and only) import of this executable.
+void Executable::SetLib(const runtime::Module& lib) {
+  ICHECK(lib.defined()) << "the provided library can not be null";
+
+  ICHECK_EQ(this->imports_.size(), 0)
+      << "A VMExecutable should never have more than one import inside the executable, \n"
+      << "the first import should *always* be the library containing "
+      << "the platform specific kernel code";
+
+  this->Import(lib);
+}
+
runtime::Module Executable::Load(const std::string& code, const
runtime::Module lib) {
auto exec = make_object<Executable>();
- exec->lib = lib;
+
+ // Support null-initialization of lib, to enable initialization during
+ // deserialization before we have we have deserialized the imports.
+ if (lib.defined()) {
+ exec->SetLib(lib);
+ }
+
exec->code_ = code;
dmlc::MemoryStringStream strm(&exec->code_);
@@ -765,6 +801,44 @@ void Executable::LoadCodeSection(dmlc::Stream* strm) {
}
}
+void Executable::SaveToBinary(dmlc::Stream* stream) {
+  // Validate first: indexing imports()[0] on an empty import list is out of bounds,
+  // and checking after Write() would leave a partially written stream on failure.
+  ICHECK(!this->imports().empty() && this->imports()[0].defined())
+      << "the library must be imported before serialization";
+  auto code_bytes = this->Save();
+  std::string code(code_bytes.data, code_bytes.size);
+  stream->Write(code);
+}
+
+// Deserializer registered as "runtime.module.loadbinary_VMExecutable".
+Module ExecutableLoadBinary(void* strm) {
+  dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
+  std::string code;
+  // Check the read, matching how ProcessModuleBlob consumes its stream.
+  ICHECK(stream->Read(&code)) << "failed to read the serialized VMExecutable from the stream";
+  // Load with a null lib: the kernel library is deserialized later as an import.
+  return Executable::Load(code, Module());
+}
+
+void Executable::SaveToFile(const std::string& path, const std::string& format) {
+  // Serialize into an in-memory buffer, then flush that buffer to `path` in one write.
+  // NOTE(review): `format` is not consulted; the raw binary form is always written.
+  std::string data;
+  dmlc::MemoryStringStream writer(&data);
+  SaveToBinary(&writer);
+  SaveBinaryToFile(path, data);
+}
+
+TVM_REGISTER_GLOBAL("runtime.module.loadbinary_VMExecutable").set_body_typed(ExecutableLoadBinary);
+
+// Load a VMExecutable from a file holding the binary blob that SaveToFile writes.
+// NOTE(review): `format` is currently ignored; the file is always read as a raw blob.
+Module ExecutableLoadFile(const std::string& file_name, const std::string& format) {
+  std::string data;
+  LoadBinaryFromFile(file_name, &data);
+  dmlc::MemoryStringStream reader(&data);
+  dmlc::Stream* strm = &reader;
+  return ExecutableLoadBinary(static_cast<void*>(strm));
+}
+
+TVM_REGISTER_GLOBAL("runtime.module.loadfile_VMExecutable").set_body_typed(ExecutableLoadFile);
+
TVM_REGISTER_GLOBAL("runtime.GetNumOfGlobals").set_body([](TVMArgs args,
TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc
index ee06da8..76ca009 100644
--- a/src/runtime/vm/vm.cc
+++ b/src/runtime/vm/vm.cc
@@ -281,11 +281,12 @@ void VirtualMachine::LoadExecutable(const Executable*
exec) {
ICHECK(exec) << "The executable is not created yet.";
exec_ = exec;
- runtime::Module lib = exec_->lib;
- // Get the list of packed functions.
+ runtime::Module lib = exec_->GetLib();
+
ICHECK(exec->primitive_map.empty() || lib.operator->())
- << "runtime module should have been built for primitive functions"
- << "\n";
+ << "If the executable has declared primitive functions, the"
+ << "generated kernel library must non-be null.";
+
for (const auto& it : exec_->primitive_map) {
const auto& packed_name = it.first;
auto packed_index = static_cast<size_t>(it.second);
diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py
index 4ecd0d9..c1bdc3f 100644
--- a/tests/python/relay/test_vm.py
+++ b/tests/python/relay/test_vm.py
@@ -19,11 +19,14 @@ import pytest
import tvm
from tvm import runtime
-from tvm import relay
+from tvm import relay, IRModule
+from tvm.relay.backend import vm
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.prelude import Prelude
from tvm.relay.loops import while_loop
from tvm.relay import testing
+from tvm.contrib import utils
+from tvm import rpc
import tvm.testing
@@ -799,5 +802,46 @@ def test_constant_shape_with_external_codegen():
assert "shape_func" in opt_mod.astext(False)
+def test_vm_rpc():
+ """
+ This test checks to make sure you can export a VMExecutable,
+ upload it to a remote machine using RPC and then execute it
+ on the other machine.
+ """
+ target = "llvm"
+ target_host = "llvm"
+
+ # Build an IRModule whose main function computes x + x.
+ x = relay.var("x", shape=(10, 1))
+ f = relay.Function([x], x + x)
+ mod = IRModule.from_expr(f)
+
+ # Compile to VMExecutable.
+ vm_exec = vm.compile(mod, target=target, target_host=target_host)
+
+ # Export the executable module (and its imported kernel library) to a single
+ # shared library on disk.
+ temp = utils.tempdir()
+ path = temp.relpath("vm_library.so")
+ vm_exec.mod.export_library(path)
+
+ # Use LocalRPC for testing.
+ remote = rpc.LocalSession()
+
+ # Upload the serialized Executable.
+ remote.upload(path)
+ # Get a handle to remote Executable; over RPC this is the raw runtime.Module,
+ # not the Python Executable wrapper.
+ rexec = remote.load_module("vm_library.so")
+
+ ctx = remote.cpu()
+ # Build a VM out of the executable and context. VirtualMachine accepts the raw
+ # Module and wraps it in an Executable internally.
+ vm_factory = runtime.vm.VirtualMachine(rexec, ctx)
+ # NOTE(review): relies on numpy being imported as np at module level (not shown in this hunk).
+ np_input = np.random.uniform(size=(10, 1)).astype("float32")
+ input_tensor = tvm.nd.array(np_input, ctx)
+ # Invoke its "main" function.
+ out = vm_factory.invoke("main", [input_tensor])
+ # Check the result: main doubles its input.
+ np.testing.assert_allclose(out.asnumpy(), np_input + np_input)
+
if __name__ == "__main__":
pytest.main([__file__])