This is an automated email from the ASF dual-hosted git repository. jroesch pushed a commit to branch vm_rpc_support in repository https://gitbox.apache.org/repos/asf/tvm.git
commit de8be3464236a8e4161b6a6e3a2414e5166e5442 Author: Jared Roesch <[email protected]> AuthorDate: Thu Mar 25 16:45:06 2021 -0700 Get basic version of VM RPC working --- include/tvm/runtime/vm/executable.h | 10 ++++- python/tvm/runtime/module.py | 9 ++-- src/relay/backend/vm/compiler.cc | 1 + src/runtime/library_module.cc | 43 +++++++++++-------- src/runtime/library_module.h | 3 ++ src/runtime/vm/executable.cc | 86 +++++++++++++++++++++++++++++++++++++ src/runtime/vm/vm.cc | 19 ++++++-- 7 files changed, 145 insertions(+), 26 deletions(-) diff --git a/include/tvm/runtime/vm/executable.h b/include/tvm/runtime/vm/executable.h index 8d3f651..1b300b5 100644 --- a/include/tvm/runtime/vm/executable.h +++ b/include/tvm/runtime/vm/executable.h @@ -64,6 +64,14 @@ class Executable : public ModuleNode { PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final; /*! + * \brief Save the entire executable to a binary stream. + * \param stream The binary stream to save to. + */ + void SaveToBinary(dmlc::Stream* stream) final; + + void SaveToFile(const std::string& path, const std::string& format) final; + + /*! * \brief Serialize the executable into global section, constant section, and * code section. * @@ -125,7 +133,7 @@ class Executable : public ModuleNode { * \brief Get the `lib` module in an executable. Users have the flexibility to call * `export_library` from the frontend to save the library to disk. * - * \return The runtime module that contains the hardwre dependent code. + * \return The runtime module that contains the hardware dependent code. 
*/ runtime::Module GetLib() const { return lib; } diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index 09bef9e..712d728 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -269,10 +269,13 @@ class Module(object): return self._collect_from_import_tree(is_dso_exportable) def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=None, **kwargs): - """Export the module and its imported device code one library. + """ + Export the module and all imported modules into a single device library. - This function only works on host llvm modules. - It will pack all the imported modules + This function only works on host LLVM modules, other runtime::Module + subclasses DO NOT work with this API. If you do in fact have an LLVM + module, this API will pack the module with all imported modules into + a single binary library which can be used with TVM. Parameters ---------- diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 9d3ffc5..c4f8ba6 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1167,6 +1167,7 @@ void VMCompiler::Codegen() { exec_->lib = codegen::CSourceModuleCreate(";", "", Array<String>{}); } exec_->lib = codegen::CreateMetadataModule(params_, exec_->lib, ext_mods, target_host_); + exec_->Import(exec_->lib); } ExprDeviceMap VMCompiler::AnalyzeContext() const { diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index 30ef214..9699a94 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -99,6 +99,29 @@ void InitContextFunctions(std::function<void*(const char*)> fgetsymbol) { #undef TVM_INIT_CONTEXT_FUNC } +Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) { + std::string loadkey = "runtime.module.loadbinary_"; + std::string fkey = loadkey + type_key; + const PackedFunc* f = Registry::Get(fkey); + if (f == nullptr) { + std::string loaders = ""; + 
for (auto name : Registry::ListNames()) { + if (name.rfind(loadkey, 0) == 0) { + if (loaders.size() > 0) { + loaders += ", "; + } + loaders += name.substr(loadkey.size()); + } + } + ICHECK(f != nullptr) + << "Binary was created using " << type_key + << " but a loader of that name is not registered. Available loaders are " << loaders + << ". Perhaps you need to recompile with this runtime enabled."; + } + + return (*f)(static_cast<void*>(stream)); +} + /*! * \brief Load and append module blob to module list * \param mblob The module blob. @@ -133,25 +156,7 @@ runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr<Library> lib) { ICHECK(stream->Read(&import_tree_row_ptr)); ICHECK(stream->Read(&import_tree_child_indices)); } else { - std::string loadkey = "runtime.module.loadbinary_"; - std::string fkey = loadkey + tkey; - const PackedFunc* f = Registry::Get(fkey); - if (f == nullptr) { - std::string loaders = ""; - for (auto name : Registry::ListNames()) { - if (name.rfind(loadkey, 0) == 0) { - if (loaders.size() > 0) { - loaders += ", "; - } - loaders += name.substr(loadkey.size()); - } - } - ICHECK(f != nullptr) - << "Binary was created using " << tkey - << " but a loader of that name is not registered. Available loaders are " << loaders - << ". Perhaps you need to recompile with this runtime enabled."; - } - Module m = (*f)(static_cast<void*>(stream)); + auto m = LoadModuleFromBinary(tkey, stream); modules.emplace_back(m); } } diff --git a/src/runtime/library_module.h b/src/runtime/library_module.h index 91918c1..75bd287 100644 --- a/src/runtime/library_module.h +++ b/src/runtime/library_module.h @@ -32,6 +32,9 @@ namespace tvm { namespace runtime { + +Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream); + /*! * \brief Library is the common interface * for storing data in the form of shared libaries. 
diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 6992097..3246f19 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -38,6 +38,8 @@ #include <vector> #include "serialize_utils.h" +#include "../library_module.h" +#include "../file_utils.h" namespace tvm { namespace runtime { @@ -74,6 +76,12 @@ PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtr<Obje int index = args[1]; *rv = this->GetFunctionParameterName(func_name, index); }); + } else if (name == "create_vm") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + auto vm = make_object<VirtualMachine>(); + vm->LoadExecutable(this); + *rv = Module(vm); + }); } else { LOG(FATAL) << "Unknown packed function: " << name; return PackedFunc(nullptr); @@ -476,11 +484,17 @@ void LoadHeader(dmlc::Stream* strm) { } runtime::Module Executable::Load(const std::string& code, const runtime::Module lib) { + std::cout << "code: " << code.size() << std::endl; auto exec = make_object<Executable>(); exec->lib = lib; exec->code_ = code; dmlc::MemoryStringStream strm(&exec->code_); + if (lib.defined()) { + std::cout << "Importing: " << std::endl; + exec->Import(lib); + } + // Load header. 
LoadHeader(&strm); @@ -765,6 +779,78 @@ void Executable::LoadCodeSection(dmlc::Stream* strm) { } } +void Executable::SaveToBinary(dmlc::Stream* stream) { + auto code_bytes = this->Save(); + std::string code(code_bytes.data, code_bytes.size); + stream->Write(code); + + CHECK(this->lib.defined()) + << "the library must be defined before serialization"; + + // this->lib->SaveToBinary(stream); + // std::vector<std::string> names; + // std::vector<DLTensor*> arrays; + // for (const auto& v : params_) { + // names.emplace_back(v.first); + // arrays.emplace_back(const_cast<DLTensor*>(v.second.operator->())); + // } + // uint64_t sz = arrays.size(); + // ICHECK(sz == names.size()); + // stream->Write(sz); + // stream->Write(names); + // for (size_t i = 0; i < sz; ++i) { + // tvm::runtime::SaveDLTensor(stream, arrays[i]); + // } +} + +Module ExecutableLoadBinary(void* strm) { + dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm); + std::string code; + stream->Read(&code); + auto exec = Executable::Load(code, Module()); + auto exec_node = exec.as<Executable>(); + std::cout << exec_node->primitive_map.size() << std::endl; + + // // std::unordered_map<std::string, tvm::runtime::NDArray> params; + // std::string module_name; + // ICHECK(stream->Read(&graph_json)); + // uint64_t sz; + // ICHECK(stream->Read(&sz)); + // std::vector<std::string> names; + // ICHECK(stream->Read(&names)); + // ICHECK(sz == names.size()); + // for (size_t i = 0; i < sz; ++i) { + // tvm::runtime::NDArray temp; + // temp.Load(stream); + // params[names[i]] = temp; + // } + + return exec; +} + +void Executable::SaveToFile(const std::string& path, const std::string& format) { + std::string data; + dmlc::MemoryStringStream writer(&data); + dmlc::SeekStream* strm = &writer; + SaveToBinary(strm); + SaveBinaryToFile(path, data); +} + +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_VMExecutable") + .set_body_typed(ExecutableLoadBinary); + + // Load module from module. 
+Module ExecutableLoadFile(const std::string& file_name, const std::string& format) { + std::string data; + LoadBinaryFromFile(file_name, &data); + dmlc::MemoryStringStream reader(&data); + dmlc::Stream* strm = &reader; + auto exec = ExecutableLoadBinary(reinterpret_cast<void*>(strm)); + return exec; +} + +TVM_REGISTER_GLOBAL("runtime.module.loadfile_VMExecutable").set_body_typed(ExecutableLoadFile); + TVM_REGISTER_GLOBAL("runtime.GetNumOfGlobals").set_body([](TVMArgs args, TVMRetValue* rv) { runtime::Module mod = args[0]; const auto* exec = dynamic_cast<Executable*>(mod.operator->()); diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 4683398..03bc20b 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -281,11 +281,24 @@ void VirtualMachine::LoadExecutable(const Executable* exec) { ICHECK(exec) << "The executable is not created yet."; exec_ = exec; - runtime::Module lib = exec_->lib; + runtime::Module lib; + if (exec_->lib.defined()) { + lib = exec_->lib; + } else { + ICHECK(exec_->imports().size() > 0) + << "fix"; + lib = exec_->imports()[0]; + } + // Get the list of packed functions. - ICHECK(exec->primitive_map.empty() || lib.operator->()) - << "runtime module should have been built for primitive functions" + ICHECK(!exec->primitive_map.empty()) + << "runtime module primitive map is empty" << "\n"; + + ICHECK(lib.operator->()) + << "library is null" + << "\n"; + for (const auto& it : exec_->primitive_map) { const auto& packed_name = it.first; auto packed_index = static_cast<size_t>(it.second);
