This is an automated email from the ASF dual-hosted git repository.
zhic pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9713d67 Created CSourceMetaData module for model metadata (#7002)
9713d67 is described below
commit 9713d675c64ae3075e10be5acadeef1328a44bb5
Author: manupa-arm <[email protected]>
AuthorDate: Mon Dec 21 22:07:33 2020 +0000
Created CSourceMetaData module for model metadata (#7002)
* Created CSourceMetaData module for model metadata
* Currently, there is a MetaData module to capture constants
conditionally if the runtime modules implement const init
PackedFuncs. However, it relies on a load process
in which the metadata is created in volatile memory, which
may not be usable in uTVM environments.
* There is a need for model-level metadata that is valid
across all runtime modules, such as the func registry
when creating a system-lib.
* This commit implements a CSourceMetaData module to hold
the func registry, which collects function names from the
runtime modules and generates a C source file to be
linked with the final artifact.
* Modified and added export_library for utvm
Change-Id: Ie2e8e2aea1a66520f03fe8af7cc5bdf27339ea10
* Created CSourceMetaData module for model metadata
* fixed llvm_module to return null pfs for
get_symbol and get_const_vars
Change-Id: I84810e0695d4d6fb314af2469117f965eed71b51
* Created CSourceMetaData module for model metadata
* fixed bundle_deploy tests
Change-Id: I0d1332a4abbb6830531784c59264021bbbd7148a
* Created CSourceMetaData module for model metadata
* fixed export_library not to insert "options" when targeting tar
* fixed unit tests
Change-Id: Ia1686889498b71af66f1a0311a059154ad3c2c3e
* Created CSourceMetaData module for model metadata
* enable wasm to support csource metadata module
* disabled non DSOExportables from using csource metadata module
Change-Id: Ie09beaad35cbc2ef738d1d24d91e249b5e099569
* Created CSourceMetaData module for model metadata
* changed const pfs to be called only on external modules
or DSOExportable modules
Change-Id: I6ad28f166c0fc27a2548c851bf9287ec805550d1
* Created CSourceMetaData module for model metadata
* CSourceMetadata module wrapper is only created for c/llvm targets
Change-Id: I13cb4140c17e2e1f91d495b15a1ff7eeab9fb14d
* Created CSourceMetaData module for model metadata
* target should be defined to use the csourcemetadata module
Change-Id: Id8e55b23d0007a79c550334de2c0fec63d40171f
* Created CSourceMetaData module for model metadata
* reinstate llvm func registry
Change-Id: I53e0754b6fb533637f08b25e98064d8c04092de4
* Created CSourceMetaData module for model metadata
* addressed comments and fixed bugs
Change-Id: I26401685dc803aeaf7642c865df88d683419e859
* Created CSourceMetaData module for model metadata
* addressed a missed comment
Change-Id: I65e65c30bc780a946f3f1b8372c40a49a5c20582
* Created CSourceMetaData module for model metadata
* te build interface should only include c-source metadata if
targeting "c"
Change-Id: Ie23cb8c6231c1f2de6d2827084774e3510288098
* Created CSourceMetaData module for model metadata
* c_source modules should be created only if they are
non-DSO exportable
Change-Id: I53f2f8e9caa41f133446f8881b9dc541ebeee8cc
* Created CSourceMetaData module for model metadata
* documentation misalignment in source_module.cc
Change-Id: I83e2c29b1f2980ca65a694304720dc58a5cb7879
* Created CSourceMetaData module for model metadata
* typo: the same object file was written as a dependency in the Makefile
Change-Id: I8becc4196d286cfb6372768687b3c836799dcb78
* Created CSourceMetaData module for model metadata
* removed unused param from a brief
Change-Id: Ie4db2aca3b7ea147bd8c65ef5d1cc2146f530e76
* Created CSourceMetaData module for model metadata
* made export library use c as the format for c source modules
Change-Id: Ie2fd6204414f0fa43988a8082d18af7a3225e237
* Created CSourceMetaData module for model metadata
* addressed a nit
Change-Id: I6084b8c06ddfaaece295439dbab589e6e202b664
---
apps/bundle_deploy/build_model.py | 2 -
python/tvm/driver/build_module.py | 12 ++
python/tvm/micro/build.py | 14 +-
python/tvm/runtime/module.py | 40 +++--
src/relay/backend/build_module.cc | 10 +-
src/relay/backend/contrib/codegen_c/codegen.cc | 26 ++-
src/relay/backend/contrib/dnnl/codegen.cc | 3 +-
src/relay/backend/vm/compiler.cc | 6 +-
src/target/func_registry_generator.cc | 2 +-
src/target/func_registry_generator.h | 7 +-
src/target/llvm/codegen_cpu.cc | 4 +-
src/target/llvm/llvm_module.cc | 14 +-
src/target/source/codegen_c_host.cc | 31 +---
src/target/source/codegen_c_host.h | 8 +-
src/target/source/codegen_source_base.h | 22 ++-
src/target/source/source_module.cc | 183 ++++++++++++++++++---
tests/micro/qemu/test_zephyr.py | 30 +++-
tests/python/relay/test_pass_partition_graph.py | 28 ++--
tests/python/unittest/test_crt.py | 1 -
tests/python/unittest/test_link_params.py | 2 +-
.../python/unittest/test_runtime_module_export.py | 2 +-
21 files changed, 316 insertions(+), 131 deletions(-)
diff --git a/apps/bundle_deploy/build_model.py
b/apps/bundle_deploy/build_model.py
index 623d246..a2513c8 100644
--- a/apps/bundle_deploy/build_model.py
+++ b/apps/bundle_deploy/build_model.py
@@ -51,7 +51,6 @@ def build_module(opts):
build_dir = os.path.abspath(opts.out_dir)
if not os.path.isdir(build_dir):
os.makedirs(build_dir)
-
lib.save(os.path.join(build_dir, file_format_str.format(name="model",
ext="o")))
with open(
os.path.join(build_dir, file_format_str.format(name="graph",
ext="json")), "w"
@@ -85,7 +84,6 @@ def build_test_module(opts):
build_dir = os.path.abspath(opts.out_dir)
if not os.path.isdir(build_dir):
os.makedirs(build_dir)
-
lib.save(os.path.join(build_dir,
file_format_str.format(name="test_model", ext="o")))
with open(
os.path.join(build_dir, file_format_str.format(name="test_graph",
ext="json")), "w"
diff --git a/python/tvm/driver/build_module.py
b/python/tvm/driver/build_module.py
index dc9d741..7ad48e1 100644
--- a/python/tvm/driver/build_module.py
+++ b/python/tvm/driver/build_module.py
@@ -424,4 +424,16 @@ def build(inputs, args=None, target=None,
target_host=None, name="default_functi
for mdev in device_modules:
if mdev:
rt_mod_host.import_module(mdev)
+
+ if not isinstance(target_host, Target):
+ target_host = Target(target_host)
+ if (
+ "system-lib" in target_host.attrs
+ and target_host.attrs["system-lib"].value == 1
+ and target_host.kind.name == "c"
+ ):
+ create_csource_metadata_module = tvm._ffi.get_global_func(
+ "runtime.CreateCSourceMetadataModule"
+ )
+ return create_csource_metadata_module([rt_mod_host], target_host)
return rt_mod_host
diff --git a/python/tvm/micro/build.py b/python/tvm/micro/build.py
index 4aec9ea..cad385b 100644
--- a/python/tvm/micro/build.py
+++ b/python/tvm/micro/build.py
@@ -95,6 +95,7 @@ _CRT_GENERATED_LIB_OPTIONS = copy.copy(_CRT_DEFAULT_OPTIONS)
# void* arg0 = (((TVMValue*)args)[0].v_handle);
# int32_t arg0_code = ((int32_t*)arg_type_ids)[(0)];
_CRT_GENERATED_LIB_OPTIONS["cflags"].append("-Wno-unused-variable")
+_CRT_GENERATED_LIB_OPTIONS["ccflags"].append("-Wno-unused-variable")
# Many TVM-intrinsic operators (i.e. expf, in particular)
@@ -159,9 +160,6 @@ def build_static_runtime(
mod_build_dir = workspace.relpath(os.path.join("build", "module"))
os.makedirs(mod_build_dir)
mod_src_dir = workspace.relpath(os.path.join("src", "module"))
- os.makedirs(mod_src_dir)
- mod_src_path = os.path.join(mod_src_dir, "module.c")
- module.save(mod_src_path, "cc")
libs = []
for mod_or_src_dir in (extra_libs or []) + RUNTIME_LIB_SRC_DIRS:
@@ -181,7 +179,15 @@ def build_static_runtime(
libs.append(compiler.library(lib_build_dir, lib_srcs, lib_opts))
- libs.append(compiler.library(mod_build_dir, [mod_src_path],
generated_lib_opts))
+ mod_src_dir = workspace.relpath(os.path.join("src", "module"))
+ os.makedirs(mod_src_dir)
+ libs.append(
+ module.export_library(
+ mod_build_dir,
+ workspace_dir=mod_src_dir,
+ fcompile=lambda bdir, srcs, **kwargs: compiler.library(bdir, srcs,
generated_lib_opts),
+ )
+ )
runtime_build_dir = workspace.relpath(f"build/runtime")
os.makedirs(runtime_build_dir)
diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py
index cef6173..6326796 100644
--- a/python/tvm/runtime/module.py
+++ b/python/tvm/runtime/module.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=invalid-name, unused-import, import-outside-toplevel
+# pylint: disable=invalid-name, unused-import, import-outside-toplevel,
inconsistent-return-statements
"""Runtime Module namespace."""
import os
import ctypes
@@ -252,7 +252,7 @@ class Module(object):
def _dso_exportable(self):
return self.type_key == "llvm" or self.type_key == "c"
- def export_library(self, file_name, fcompile=None, addons=None, **kwargs):
+ def export_library(self, file_name, fcompile=None, addons=None,
workspace_dir=None, **kwargs):
"""Export the module and its imported device code one library.
This function only works on host llvm modules.
@@ -268,8 +268,19 @@ class Module(object):
If fcompile has attribute object_format, will compile host library
to that format. Otherwise, will use default format "o".
+ workspace_dir : str, optional
+ the path to a directory used to create intermediary
+ artifacts for the process exporting of the library.
+ If this is not provided a temporary dir will be created.
+
kwargs : dict, optional
Additional arguments passed to fcompile
+
+ Returns
+ -------
+ result of fcompile() : unknown, optional
+ If the compilation function returns an artifact it would be
returned via
+ export_library, if any.
"""
# NOTE: this function depends on contrib library features
# which are only available in when TVM function is available.
@@ -292,22 +303,28 @@ class Module(object):
return
modules = self._collect_dso_modules()
- temp = _utils.tempdir()
+ if workspace_dir is None:
+ temp = _utils.tempdir()
+ workspace_dir = temp.temp_dir
files = addons if addons else []
is_system_lib = False
has_c_module = False
llvm_target_triple = None
for index, module in enumerate(modules):
if fcompile is not None and hasattr(fcompile, "object_format"):
- object_format = fcompile.object_format
+ if module.type_key == "c":
+ object_format = "c"
+ has_c_module = True
+ else:
+ object_format = fcompile.object_format
else:
if module.type_key == "llvm":
object_format = "o"
else:
assert module.type_key == "c"
- object_format = "cc"
+ object_format = "c"
has_c_module = True
- path_obj = temp.relpath("lib" + str(index) + "." + object_format)
+ path_obj = os.path.join(workspace_dir,
f"lib{index}.{object_format}")
module.save(path_obj)
files.append(path_obj)
is_system_lib = (
@@ -330,17 +347,20 @@ class Module(object):
if self.imported_modules:
if enabled("llvm") and llvm_target_triple:
- path_obj = temp.relpath("devc." + object_format)
+ path_obj = os.path.join(workspace_dir, f"devc.{object_format}")
m = _ffi_api.ModulePackImportsToLLVM(self, is_system_lib,
llvm_target_triple)
m.save(path_obj)
files.append(path_obj)
else:
- path_cc = temp.relpath("devc.cc")
+ path_cc = os.path.join(workspace_dir, "devc.c")
with open(path_cc, "w") as f:
f.write(_ffi_api.ModulePackImportsToC(self, is_system_lib))
files.append(path_cc)
- if has_c_module:
+ # The imports could contain a c module but the object format could be
tar
+ # Thus, it would not recognize the following include paths as options
+ # which are there assuming a c compiler is the fcompile.
+ if has_c_module and not file_name.endswith(".tar"):
options = []
if "options" in kwargs:
opts = kwargs["options"]
@@ -348,7 +368,7 @@ class Module(object):
opts = options + ["-I" + path for path in find_include_path()]
kwargs.update({"options": opts})
- fcompile(file_name, files, **kwargs)
+ return fcompile(file_name, files, **kwargs)
def system_lib():
diff --git a/src/relay/backend/build_module.cc
b/src/relay/backend/build_module.cc
index a0828d1..09b0966 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -510,18 +510,14 @@ class RelayBuildModule : public runtime::ModuleNode {
// If we cannot decide the target is LLVM, we create an empty
CSourceModule.
// The code content is initialized with ";" to prevent complaining
// from CSourceModuleNode::SaveToFile.
- ret_.mod = tvm::codegen::CSourceModuleCreate(";", "");
+ ret_.mod = tvm::codegen::CSourceModuleCreate(";", "", Array<String>{});
}
} else {
ret_.mod = tvm::build(lowered_funcs, target_host_);
}
- Array<tvm::runtime::Module> ext_mods =
graph_codegen_->GetExternalModules();
- // TODO(zhiics) We should be able to completely switch to MetadataModule no
- // matter whether there are external modules or not.
- if (!ext_mods.empty()) {
- ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod,
ext_mods);
- }
+ auto ext_mods = graph_codegen_->GetExternalModules();
+ ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod,
ext_mods, GetTargetHost());
}
private:
diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc
b/src/relay/backend/contrib/codegen_c/codegen.cc
index 935ac16..998393d 100644
--- a/src/relay/backend/contrib/codegen_c/codegen.cc
+++ b/src/relay/backend/contrib/codegen_c/codegen.cc
@@ -215,20 +215,19 @@ class CodegenC : public
MemoizedExprTranslator<std::vector<Output>>, public Code
class CSourceCodegen : public CSourceModuleCodegenBase {
public:
- std::pair<std::string, Array<String>> GenCFunc(const Function& func) {
+ std::tuple<Array<String>, String, String> GenCFunc(const Function& func) {
ICHECK(func.defined()) << "Input error: expect a Relay function.";
-
- // Record the external symbol for runtime lookup.
- auto sid = GetExtSymbol(func);
-
- CodegenC builder(sid);
+ CodegenC builder(GetExtSymbol(func));
auto out = builder.VisitExpr(func->body);
- code_stream_ << builder.JIT(out);
-
- return {sid, builder.const_vars_};
+ return std::make_tuple(builder.const_vars_, builder.ext_func_id_,
builder.JIT(out));
}
runtime::Module CreateCSourceModule(const ObjectRef& ref) override {
+ ICHECK(ref->IsInstance<FunctionNode>());
+ auto res = GenCFunc(Downcast<Function>(ref));
+ Array<String> variables = std::get<0>(res);
+ String func_name = std::get<1>(res);
+
// Create headers
code_stream_ << "#include <cstring>\n";
code_stream_ << "#include <vector>\n";
@@ -259,18 +258,13 @@ class CSourceCodegen : public CSourceModuleCodegenBase {
)op_macro";
code_stream_ << operator_macro << "\n\n";
-
- ICHECK(ref->IsInstance<FunctionNode>());
- auto res = GenCFunc(Downcast<Function>(ref));
+ code_stream_ << std::get<2>(res);
std::string code = code_stream_.str();
- String sym = std::get<0>(res);
- Array<String> variables = std::get<1>(res);
-
// Create a CSource module
const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate");
ICHECK(pf != nullptr) << "Cannot find csource module to create the
external runtime module";
- return (*pf)(code, "c", sym, variables);
+ return (*pf)(code, "c", Array<String>{func_name}, variables);
}
private:
diff --git a/src/relay/backend/contrib/dnnl/codegen.cc
b/src/relay/backend/contrib/dnnl/codegen.cc
index bfc5c77..c9a5828 100644
--- a/src/relay/backend/contrib/dnnl/codegen.cc
+++ b/src/relay/backend/contrib/dnnl/codegen.cc
@@ -413,7 +413,8 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase {
// Create a CSource module
const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate");
ICHECK(pf != nullptr) << "Cannot find csource module to create the
external runtime module";
- return (*pf)(code, "c", sym, variables);
+ // TODO(@manupa-arm): pass the function names to enable system-lib creation
+ return (*pf)(code, "c", Array<String>{sym}, variables);
}
private:
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index bed2510..8fbe31e 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -1146,11 +1146,9 @@ void VMCompiler::Codegen() {
} else {
// There is no function handled by TVM. We create a virtual main module
// to make sure a DSO module will be also available.
- exec_->lib = codegen::CSourceModuleCreate(";", "");
- }
- if (!ext_mods.empty()) {
- exec_->lib = codegen::CreateMetadataModule(params_, exec_->lib, ext_mods);
+ exec_->lib = codegen::CSourceModuleCreate(";", "", Array<String>{});
}
+ exec_->lib = codegen::CreateMetadataModule(params_, exec_->lib, ext_mods,
target_host_);
}
ExprDeviceMap VMCompiler::AnalyzeContext() const {
diff --git a/src/target/func_registry_generator.cc
b/src/target/func_registry_generator.cc
index 402d0f8..7c948d5 100644
--- a/src/target/func_registry_generator.cc
+++ b/src/target/func_registry_generator.cc
@@ -29,7 +29,7 @@
namespace tvm {
namespace target {
-std::string GenerateFuncRegistryNames(const std::vector<std::string>&
function_names) {
+std::string GenerateFuncRegistryNames(const Array<String>& function_names) {
std::stringstream ss;
ss << (unsigned char)(function_names.size());
for (auto f : function_names) {
diff --git a/src/target/func_registry_generator.h
b/src/target/func_registry_generator.h
index 362fca8..fb59648 100644
--- a/src/target/func_registry_generator.h
+++ b/src/target/func_registry_generator.h
@@ -24,13 +24,18 @@
#ifndef TVM_TARGET_FUNC_REGISTRY_GENERATOR_H_
#define TVM_TARGET_FUNC_REGISTRY_GENERATOR_H_
+#include <tvm/runtime/container.h>
+
#include <string>
#include <vector>
+using tvm::runtime::Array;
+using tvm::runtime::String;
+
namespace tvm {
namespace target {
-std::string GenerateFuncRegistryNames(const std::vector<std::string>&
function_names);
+std::string GenerateFuncRegistryNames(const Array<String>& function_names);
} // namespace target
} // namespace tvm
diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc
index fea5f80..6143e70 100644
--- a/src/target/llvm/codegen_cpu.cc
+++ b/src/target/llvm/codegen_cpu.cc
@@ -794,10 +794,10 @@ llvm::Value* CodeGenCPU::RuntimeTVMParallelBarrier() {
void CodeGenCPU::AddStartupFunction() {
if (registry_functions_.size() != 0) {
ICHECK(is_system_lib_) << "Loading of --system-lib modules is yet to be
defined for C runtime";
- std::vector<std::string> symbols;
+ Array<String> symbols;
std::vector<llvm::Constant*> funcs;
for (auto sym : registry_functions_) {
- symbols.emplace_back(sym.first);
+ symbols.push_back(sym.first);
funcs.emplace_back(llvm::ConstantExpr::getBitCast(
sym.second, ftype_tvm_backend_packed_c_func_->getPointerTo()));
}
diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc
index a0ab49d..43d2097 100644
--- a/src/target/llvm/llvm_module.cc
+++ b/src/target/llvm/llvm_module.cc
@@ -60,6 +60,13 @@ class LLVMModuleNode final : public runtime::ModuleNode {
if (name == "__tvm_is_system_module") {
bool flag = (mptr_->getFunction("__tvm_module_startup") != nullptr);
return PackedFunc([flag](TVMArgs args, TVMRetValue* rv) { *rv = flag; });
+ } else if (name == "get_func_names") {
+ return PackedFunc(
+ [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv =
this->function_names_; });
+ } else if (name == "get_symbol") {
+ return PackedFunc(nullptr);
+ } else if (name == "get_const_vars") {
+ return PackedFunc(nullptr);
} else if (name == "_get_target_triple") {
std::string target_triple = tm_->getTargetTriple().str();
// getTargetTriple() doesn't include other flags besides the triple. Add
back flags which are
@@ -218,9 +225,10 @@ class LLVMModuleNode final : public runtime::ModuleNode {
ICHECK(kv.second->IsInstance<PrimFuncNode>())
<< "Can only lower IR Module with PrimFuncs, but got " <<
kv.second->GetTypeKey();
auto f = Downcast<PrimFunc>(kv.second);
+ auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ ICHECK(global_symbol.defined());
+ function_names_.push_back(global_symbol.value());
if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
- auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
- ICHECK(global_symbol.defined());
entry_func = global_symbol.value();
}
funcs.push_back(f);
@@ -377,6 +385,8 @@ class LLVMModuleNode final : public runtime::ModuleNode {
std::unique_ptr<llvm::Module> module_;
// the context.
std::shared_ptr<llvm::LLVMContext> ctx_;
+ /* \brief names of the functions declared in this module */
+ Array<String> function_names_;
};
TVM_REGISTER_GLOBAL("target.build.llvm")
diff --git a/src/target/source/codegen_c_host.cc
b/src/target/source/codegen_c_host.cc
index 0a19fc1..bee5441 100644
--- a/src/target/source/codegen_c_host.cc
+++ b/src/target/source/codegen_c_host.cc
@@ -55,7 +55,7 @@ void CodeGenCHost::AddFunction(const PrimFunc& f) {
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined())
<< "CodeGenCHost: Expect PrimFunc to have the global_symbol attribute";
- function_names_.emplace_back(global_symbol.value());
+ function_names_.push_back(global_symbol.value());
CodeGenC::AddFunction(f);
}
@@ -73,7 +73,7 @@ void CodeGenCHost::LinkParameters(Map<String, LinkedParam>
params) {
<< " out_ret_tcode[0] = " << kTVMNullptr << ";\n"
<< " return 0;\n";
- function_names_.emplace_back(tvm::runtime::symbol::tvm_lookup_linked_param);
+ function_names_.push_back(tvm::runtime::symbol::tvm_lookup_linked_param);
for (auto kv : params) {
decl_stream << "\n"
<< "#ifdef __cplusplus\n"
@@ -322,29 +322,6 @@ inline void CodeGenCHost::PrintTernaryCondExpr(const T*
op, const char* compare,
<< "? (" << a_id << ") : (" << b_id << "))";
}
-void CodeGenCHost::GenerateFuncRegistry() {
- decl_stream << "#include <tvm/runtime/crt/module.h>\n";
- stream << "static TVMBackendPackedCFunc _tvm_func_array[] = {\n";
- for (auto f : function_names_) {
- stream << " (TVMBackendPackedCFunc)" << f << ",\n";
- }
- stream << "};\n";
- auto registry = target::GenerateFuncRegistryNames(function_names_);
- stream << "static const TVMFuncRegistry _tvm_func_registry = {\n"
- << " \"" << ::tvm::support::StrEscape(registry.data(),
registry.size(), true) << "\","
- << " _tvm_func_array,\n"
- << "};\n";
-}
-
-void CodeGenCHost::GenerateCrtSystemLib() {
- stream << "static const TVMModule _tvm_system_lib = {\n"
- << " &_tvm_func_registry,\n"
- << "};\n"
- << "const TVMModule* TVMSystemLibEntryPoint(void) {\n"
- << " return &_tvm_system_lib;\n"
- << "}\n";
-}
-
runtime::Module BuildCHost(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
@@ -380,12 +357,10 @@ runtime::Module BuildCHost(IRModule mod, Target target) {
if (target->GetAttr<Bool>("system-lib").value_or(Bool(false))) {
ICHECK_EQ(target->GetAttr<String>("runtime").value_or(""), "c")
<< "c target only supports generating C runtime SystemLibs";
- cg.GenerateFuncRegistry();
- cg.GenerateCrtSystemLib();
}
std::string code = cg.Finish();
- return CSourceModuleCreate(code, "c");
+ return CSourceModuleCreate(code, "c", cg.GetFunctionNames());
}
TVM_REGISTER_GLOBAL("target.build.c").set_body_typed(BuildCHost);
diff --git a/src/target/source/codegen_c_host.h
b/src/target/source/codegen_c_host.h
index b54b6fb..97fe7ab 100644
--- a/src/target/source/codegen_c_host.h
+++ b/src/target/source/codegen_c_host.h
@@ -59,18 +59,14 @@ class CodeGenCHost final : public CodeGenC {
void VisitStmt_(const AssertStmtNode* op) final; // NOLINT(*)
- /*! \brief Generate C runtime FuncRegistry global constant. */
- void GenerateFuncRegistry();
-
- /*! \brief Generate C runtime SystemLib entry point. */
- void GenerateCrtSystemLib();
+ Array<String> GetFunctionNames() { return function_names_; }
private:
std::string module_name_;
/* \brief tracks declared global variables which live despite GetUniqueName
*/
std::set<std::string> declared_globals_;
/* \brief names of the functions declared in this module */
- std::vector<std::string> function_names_;
+ Array<String> function_names_;
/*! \brief whether to emit asserts in the resulting C code */
bool emit_asserts_;
diff --git a/src/target/source/codegen_source_base.h
b/src/target/source/codegen_source_base.h
index 7e5e403..ed838f8 100644
--- a/src/target/source/codegen_source_base.h
+++ b/src/target/source/codegen_source_base.h
@@ -136,25 +136,26 @@ runtime::Module SourceModuleCreate(std::string code,
std::string fmt);
* \brief Create a C source module for viewing and compiling GCC code.
* \param code The code to be viewed.
* \param fmt The code format.
- * \param symbol The symbol that the c source module represents.
+ * \param func_names The name of functions inside the runtime module.
* \param const_vars. The constant variables that the c source module needs.
* \return The created module.
*/
runtime::Module CSourceModuleCreate(const String& code, const String& fmt,
- const String& symbol = "",
+ const Array<String>& func_names,
const Array<String>& const_vars = {});
/*!
* \brief Wrap the submodules in a metadata module.
* \param params The variable to constant mapping that is collected by the host
* module.
- * \param dso_module The host module to be wrapped.
- * \param modules The modules to be wrapped.
+ * \param target_module The main TIR-lowered internal runtime module
+ * \param modules All the external modules that needs to be imported inside
the metadata module(s).
+ * \param target The target that all the modules are compiled for
* \return The wrapped module.
*/
runtime::Module CreateMetadataModule(
- const std::unordered_map<std::string, runtime::NDArray>& params,
- const runtime::Module& dso_module, const Array<runtime::Module>& modules);
+ const std::unordered_map<std::string, runtime::NDArray>& params,
runtime::Module target_module,
+ const Array<runtime::Module>& ext_modules, Target target);
/*!
* \brief Create a source module for viewing and limited saving for device.
@@ -167,6 +168,15 @@ runtime::Module CreateMetadataModule(
runtime::Module DeviceSourceModuleCreate(
std::string data, std::string fmt, std::unordered_map<std::string,
runtime::FunctionInfo> fmap,
std::string type_key, std::function<std::string(const std::string&)>
fget_source = nullptr);
+
+/*!
+ * \brief Wrap the submodules that are to be wrapped in a c-source metadata
module.
+ * \param modules The modules to be wrapped.
+ * \param target the target the modules are compiled for.
+ * \return The wrapped module.
+ */
+runtime::Module CreateCSourceMetadataModule(const Array<runtime::Module>&
modules, Target target);
+
} // namespace codegen
} // namespace tvm
#endif // TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_
diff --git a/src/target/source/source_module.cc
b/src/target/source/source_module.cc
index 3be658a..4b4770a 100644
--- a/src/target/source/source_module.cc
+++ b/src/target/source/source_module.cc
@@ -27,6 +27,8 @@
#include "../../runtime/file_utils.h"
#include "../../runtime/meta_data.h"
+#include "../../support/str_escape.h"
+#include "../func_registry_generator.h"
#include "codegen_source_base.h"
namespace tvm {
@@ -46,40 +48,66 @@ using runtime::SaveBinaryToFile;
* codegens, such as graph runtime codegen and the vm compiler.
*
* \param params The metadata for initialization of all modules.
- * \param dso_module The DSO module that contains TVM primitives.
- * \param modules The submodules that will be wrapped, e.g. CSource modules
that
- * contain vendor library calls or customized runtime modules.
- *
+ * \param target_module the internal module that is compiled by tvm.
+ * \param ext_modules The external modules that needs to be imported inside
the metadata
+ * module(s).
+ * \param target The target that all the modules are compiled for
* \return The created metadata module that manages initialization of metadata.
*/
runtime::Module CreateMetadataModule(
const std::unordered_map<std::string, runtime::NDArray>& params,
- const runtime::Module& dso_module, const Array<runtime::Module>& modules) {
+ tvm::runtime::Module target_module, const Array<runtime::Module>&
ext_modules, Target target) {
+ Array<tvm::runtime::Module> csource_modules;
+ Array<tvm::runtime::Module> binary_modules;
+
+ auto DSOExportable = [](tvm::runtime::Module& mod) {
+ return !std::strcmp(mod->type_key(), "llvm") ||
!std::strcmp(mod->type_key(), "c");
+ };
+
// Wrap all submodules in the initialization wrapper.
std::unordered_map<std::string, std::vector<std::string>> sym_metadata;
- for (runtime::Module it : modules) {
- auto pf_sym = it.GetFunction("get_symbol");
- auto pf_var = it.GetFunction("get_const_vars");
+ for (tvm::runtime::Module mod : ext_modules) {
+ auto pf_sym = mod.GetFunction("get_symbol");
+ auto pf_var = mod.GetFunction("get_const_vars");
+ std::vector<std::string> arrays;
if (pf_sym != nullptr && pf_var != nullptr) {
String symbol = pf_sym();
Array<String> variables = pf_var();
- std::vector<std::string> arrays;
for (size_t i = 0; i < variables.size(); i++) {
arrays.push_back(variables[i].operator std::string());
}
ICHECK_EQ(sym_metadata.count(symbol), 0U) << "Found duplicated symbol: "
<< symbol;
sym_metadata[symbol] = arrays;
}
+ // We only need loading of serialized constant data
+ // if there are constants present and required by the
+ // runtime module to be initialized by the binary
+ // metadata module. If not rest of the modules are
+ // wrapped in c-source metadata module.
+
+ // TODO(@manupa-arm) : we should be able to use csource_metadata
+ // if the variables are empty when all the runtime modules implement
get_func_names
+ if (arrays.empty() && DSOExportable(mod) && target->kind->name == "c") {
+ csource_modules.push_back(mod);
+ } else {
+ binary_modules.push_back(mod);
+ }
}
- // Wrap the modules.
- runtime::Module init_m = runtime::MetadataModuleCreate(params, sym_metadata);
- init_m.Import(dso_module);
- for (const auto& it : modules) {
- init_m.Import(it);
+ if (target.defined() && target->kind->name == "c") {
+ csource_modules.push_back(target_module);
+ target_module = CreateCSourceMetadataModule(csource_modules, target);
}
- return init_m;
+ if (!binary_modules.empty()) {
+ runtime::Module binary_meta_mod = runtime::MetadataModuleCreate(params,
sym_metadata);
+ binary_meta_mod.Import(target_module);
+ for (const auto& it : binary_modules) {
+ binary_meta_mod.Import(it);
+ }
+ return binary_meta_mod;
+ }
+ return target_module;
}
// Simulator function
@@ -109,18 +137,25 @@ runtime::Module SourceModuleCreate(std::string code,
std::string fmt) {
// Simulator function
class CSourceModuleNode : public runtime::ModuleNode {
public:
- CSourceModuleNode(const std::string& code, const std::string& fmt, const
std::string& symbol,
- const Array<String>& const_vars)
- : code_(code), fmt_(fmt), symbol_(symbol), const_vars_(const_vars) {}
+ CSourceModuleNode(const std::string& code, const std::string& fmt,
+ const Array<String>& func_names, const Array<String>&
const_vars)
+ : code_(code), fmt_(fmt), const_vars_(const_vars),
func_names_(func_names) {}
const char* type_key() const { return "c"; }
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>&
sptr_to_self) final {
+    // Currently the c-source module is used for demonstration purposes with the binary metadata
+    // module, which expects the get_symbol interface. When a c-source module is used as an
+    // external module it contains only one function; when it is used as an internal module
+    // (e.g., target "c") it can contain many, and get_symbol reports the first one.
if (name == "get_symbol") {
return PackedFunc(
- [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv =
this->symbol_; });
+        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+          // A module created without function names reports an empty symbol (matching the
+          // previous API, where callers could pass symbol "") instead of indexing an empty array.
+          *rv = this->func_names_.empty() ? String("") : this->func_names_[0];
+        });
} else if (name == "get_const_vars") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv =
this->const_vars_; });
+ } else if (name == "get_func_names") {
+ return PackedFunc(
+ [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv =
this->func_names_; });
} else {
return PackedFunc(nullptr);
}
@@ -131,7 +166,7 @@ class CSourceModuleNode : public runtime::ModuleNode {
void SaveToFile(const std::string& file_name, const std::string& format)
final {
std::string fmt = GetFileFormat(file_name, format);
std::string meta_file = GetMetaFilePath(file_name);
- if (fmt == "cc") {
+    // Keep accepting "cc" (the previously accepted format) alongside "c"; otherwise a save
+    // requesting "cc" no longer takes the branch that writes the source code.
+    if (fmt == "c" || fmt == "cc") {
ICHECK_NE(code_.length(), 0);
SaveBinaryToFile(file_name, code_);
} else {
@@ -142,17 +177,109 @@ class CSourceModuleNode : public runtime::ModuleNode {
protected:
std::string code_;
std::string fmt_;
- std::string symbol_;
Array<String> const_vars_;
+ Array<String> func_names_;
};
-runtime::Module CSourceModuleCreate(const String& code, const String& fmt,
const String& symbol,
+// Creates a C source module holding `code` in format `fmt`, together with the function names
+// it reports via "get_func_names" and the constant variables it reports via "get_const_vars".
+runtime::Module CSourceModuleCreate(const String& code, const String& fmt,
+ const Array<String>& func_names,
const Array<String>& const_vars) {
auto n = make_object<CSourceModuleNode>(code.operator std::string(),
fmt.operator std::string(),
- symbol.operator std::string(),
const_vars);
+ func_names, const_vars);
return runtime::Module(n);
}
+class CSourceMetadataModuleNode : public runtime::ModuleNode {
+ public:
+  /*!
+   * \brief Generates model-level metadata C source for a set of functions.
+   * \param func_names Names of the functions to place in the generated function registry.
+   * \param fmt Format reported by this module; checked when saving to a non-C format.
+   * \param target Target whose "system-lib" attribute decides whether a CRT function
+   *        registry and TVMSystemLibEntryPoint are emitted.
+   */
+  CSourceMetadataModuleNode(const Array<String>& func_names, const std::string& fmt, Target target)
+      : fmt_(fmt), func_names_(func_names), target_(target) {
+    CreateSource();
+  }
+  const char* type_key() const { return "c"; }
+
+  std::string GetSource(const std::string& format) final { return code_.str(); }
+
+  // This module exposes no packed functions of its own.
+  PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
+    return PackedFunc(nullptr);
+  }
+
+  void SaveToFile(const std::string& file_name, const std::string& format) final {
+    std::string fmt = GetFileFormat(file_name, format);
+    // Accept "cc" as well as "c": this module is created with fmt_ == "cc", so a save with
+    // format "cc" would otherwise pass the ICHECK_EQ below and silently write nothing.
+    if (fmt == "c" || fmt == "cc") {
+      std::string code_str = code_.str();
+      ICHECK_NE(code_str.length(), 0);
+      SaveBinaryToFile(file_name, code_str);
+    } else {
+      ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_;
+    }
+  }
+
+ protected:
+  std::stringstream code_;  // accumulated generated C source
+  std::string fmt_;
+  Array<String> func_names_;
+  Target target_;
+
+  // Declares every function in func_names_ with the packed C function signature, collects
+  // them into _tvm_func_array, and emits _tvm_func_registry pairing that array with the
+  // (C-escaped) name blob returned by target::GenerateFuncRegistryNames.
+  void CreateFuncRegistry() {
+    code_ << "#include <tvm/runtime/crt/module.h>\n";
+    for (const auto& fname : func_names_) {
+      code_ << "#ifdef __cplusplus\n";
+      code_ << "extern \"C\"\n";
+      code_ << "#endif\n";
+      code_ << "TVM_DLL int32_t " << fname.data();
+      code_ << "(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, int* "
+               "out_type_code);\n";
+    }
+    code_ << "static TVMBackendPackedCFunc _tvm_func_array[] = {\n";
+    for (const auto& fname : func_names_) {
+      code_ << " (TVMBackendPackedCFunc)" << fname << ",\n";
+    }
+    code_ << "};\n";
+    auto registry = target::GenerateFuncRegistryNames(func_names_);
+    code_ << "static const TVMFuncRegistry _tvm_func_registry = {\n"
+          << " \"" << ::tvm::support::StrEscape(registry.data(), registry.size(), true) << "\","
+          << " _tvm_func_array,\n"
+          << "};\n";
+  }
+
+  // Emits a TVMModule wrapping _tvm_func_registry and a TVMSystemLibEntryPoint returning it.
+  void GenerateCrtSystemLib() {
+    code_ << "static const TVMModule _tvm_system_lib = {\n"
+          << " &_tvm_func_registry,\n"
+          << "};\n"
+          << "const TVMModule* TVMSystemLibEntryPoint(void) {\n"
+          << " return &_tvm_system_lib;\n"
+          << "}\n";
+  }
+
+  // Emits the registry and system-lib entry point only for a defined target with the
+  // "system-lib" attribute set and at least one function name.
+  void CreateSource() {
+    bool is_system_lib =
+        target_.defined() && target_->GetAttr<Bool>("system-lib").value_or(Bool(false));
+    if (is_system_lib && !func_names_.empty()) {
+      CreateFuncRegistry();
+      GenerateCrtSystemLib();
+    }
+    // NOTE(review): presumably keeps the generated file from being an empty translation unit
+    // when nothing above is emitted; confirm before removing.
+    code_ << ";";
+  }
+};
+
+// Collects the names reported by each module's "get_func_names" (modules without that function
+// contribute none), builds a CSourceMetadataModuleNode from them for `target`, and imports every
+// input module into the resulting metadata module.
+runtime::Module CreateCSourceMetadataModule(const Array<runtime::Module>& modules, Target target) {
+  Array<String> func_names;
+  for (runtime::Module mod : modules) {
+    auto pf_funcs = mod.GetFunction("get_func_names");
+    if (pf_funcs != nullptr) {
+      Array<String> mod_func_names = pf_funcs();
+      for (const auto& fname : mod_func_names) {
+        func_names.push_back(fname);
+      }
+    }
+  }
+  auto n = make_object<CSourceMetadataModuleNode>(func_names, "cc", target);
+  auto csrc_metadata_module = runtime::Module(n);
+  for (const auto& mod : modules) {
+    csrc_metadata_module.Import(mod);
+  }
+  // Return the local by name: wrapping it in std::move would suppress copy elision (NRVO).
+  return csrc_metadata_module;
+}
+
// supports limited save without cross compile
class DeviceSourceModuleNode final : public runtime::ModuleNode {
public:
@@ -209,8 +336,14 @@ runtime::Module DeviceSourceModuleCreate(
TVM_REGISTER_GLOBAL("runtime.SourceModuleCreate").set_body_typed(SourceModuleCreate);
+// The Python-facing constructor takes a list of function names in place of the former single
+// `symbol` string.
TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate")
- .set_body_typed([](String code, String fmt, String symbol, Array<String>
const_vars) {
- return CSourceModuleCreate(code, fmt, symbol, const_vars);
+ .set_body_typed([](String code, String fmt, Array<String> func_names,
+ Array<String> const_vars) {
+ return CSourceModuleCreate(code, fmt, func_names, const_vars);
+ });
+
+// Exposes CreateCSourceMetadataModule(modules, target) through the FFI.
+TVM_REGISTER_GLOBAL("runtime.CreateCSourceMetadataModule")
+ .set_body_typed([](const Array<runtime::Module>& modules, Target target) {
+ return CreateCSourceMetadataModule(modules, target);
});
} // namespace codegen
diff --git a/tests/micro/qemu/test_zephyr.py b/tests/micro/qemu/test_zephyr.py
index 2213203..3e73307 100644
--- a/tests/micro/qemu/test_zephyr.py
+++ b/tests/micro/qemu/test_zephyr.py
@@ -29,7 +29,7 @@ import numpy as np
import tvm
import tvm.rpc
import tvm.micro
-import tvm.relay
+import tvm.relay as relay
from tvm.micro.contrib import zephyr
from tvm.contrib import utils
@@ -143,5 +143,33 @@ def test_compile_runtime(platform):
test_basic_add(sess)
+def test_relay(platform):
+    """Build a small Relay graph computing x * x + 1 for the micro target and check it on-device."""
+    # tvm.testing is not pulled in by the module-level imports shown in this diff.
+    import tvm.testing  # pylint: disable=import-outside-toplevel
+
+    model, zephyr_board = PLATFORMS[platform]
+    shape = (10,)
+    dtype = "int8"
+
+    # Construct Relay program.
+    x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype))
+    xx = relay.multiply(x, x)
+    z = relay.add(xx, relay.const(np.ones(shape=shape, dtype=dtype)))
+    func = relay.Function([x], z)
+
+    target = tvm.target.target.micro(model)
+    with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
+        graph, mod, params = relay.build(func, target=target)
+
+    with _make_session(model, target, zephyr_board, mod) as session:
+        graph_mod = tvm.micro.create_local_graph_runtime(
+            graph, session.get_system_lib(), session.context
+        )
+        graph_mod.set_input(**params)
+        # Values stay below 10 so x * x + 1 (at most 82) fits in int8 without overflow.
+        x_in = np.random.randint(10, size=shape, dtype=dtype)
+        graph_mod.run(x=x_in)
+        result = graph_mod.get_output(0).asnumpy()
+        tvm.testing.assert_allclose(graph_mod.get_input(0).asnumpy(), x_in)
+        tvm.testing.assert_allclose(result, x_in * x_in + 1)
+
+
if __name__ == "__main__":
sys.exit(pytest.main([os.path.dirname(__file__)] + sys.argv[1:]))
diff --git a/tests/python/relay/test_pass_partition_graph.py
b/tests/python/relay/test_pass_partition_graph.py
index 059d0b4..d8f674e 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -273,19 +273,23 @@ def test_multi_node_compiler():
map_inputs = {"w{}".format(i): w_data[i] for i in range(8)}
map_inputs["x"] = x_data
- check_result(
- mod,
- map_inputs,
- (30, 10),
- np.concatenate(
- (
- ((x_data + w_data[0]) - w_data[1]) * w_data[2],
- ((x_data + w_data[3]) - w_data[4]) * w_data[5],
- x_data + w_data[6] - w_data[7],
+
+ targets = ["llvm", "c -runtime=c --system-lib"]
+ for tgt in targets:
+ check_result(
+ mod,
+ map_inputs,
+ (30, 10),
+ np.concatenate(
+ (
+ ((x_data + w_data[0]) - w_data[1]) * w_data[2],
+ ((x_data + w_data[3]) - w_data[4]) * w_data[5],
+ x_data + w_data[6] - w_data[7],
+ ),
+ axis=0,
),
- axis=0,
- ),
- )
+ target=tgt,
+ )
def test_extern_ccompiler_single_op():
diff --git a/tests/python/unittest/test_crt.py
b/tests/python/unittest/test_crt.py
index 07a4cfc..1d84d4e 100644
--- a/tests/python/unittest/test_crt.py
+++ b/tests/python/unittest/test_crt.py
@@ -49,7 +49,6 @@ def _make_sess_from_op(workspace, op_name, sched, arg_bufs):
def _make_session(workspace, mod):
compiler = tvm.micro.DefaultCompiler(target=TARGET)
opts = tvm.micro.default_options(os.path.join(tvm.micro.CRT_ROOT_DIR,
"host"))
-
micro_binary = tvm.micro.build_static_runtime(
# the x86 compiler *expects* you to give the exact same dictionary for
both
# lib_opts and bin_opts. so the library compiler is mutating lib_opts
and
diff --git a/tests/python/unittest/test_link_params.py
b/tests/python/unittest/test_link_params.py
index e3bd634..da87a31 100644
--- a/tests/python/unittest/test_link_params.py
+++ b/tests/python/unittest/test_link_params.py
@@ -266,7 +266,7 @@ def test_c_link_params():
lib = tvm.relay.build(mod, target, params=param_init)
assert set(lib.params.keys()) == {"p0", "p1"} # NOTE: op folded
- src = lib.lib.get_source()
+ src = lib.lib.imported_modules[0].get_source()
lib.lib.save("test.c", "cc")
c_dtype = _get_c_datatype(dtype)
src_lines = src.split("\n")
diff --git a/tests/python/unittest/test_runtime_module_export.py
b/tests/python/unittest/test_runtime_module_export.py
index 88b7af9..af9a8ab 100644
--- a/tests/python/unittest/test_runtime_module_export.py
+++ b/tests/python/unittest/test_runtime_module_export.py
@@ -58,7 +58,7 @@ def generate_engine_module():
import tvm.runtime._ffi_api
gen_engine_header()
- csource_module = tvm.runtime._ffi_api.CSourceModuleCreate(code, "cc", "",
None)
+ csource_module = tvm.runtime._ffi_api.CSourceModuleCreate(code, "cc", [],
None)
return csource_module