This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new db5f4fe65c [Runtime] Add 'static_library' runtime::Module (#11442)
db5f4fe65c is described below
commit db5f4fe65cb01ff50dfd05d2b2d59a66b78079c7
Author: Mark Shields <[email protected]>
AuthorDate: Thu May 26 09:26:05 2022 -0700
[Runtime] Add 'static_library' runtime::Module (#11442)
(See
https://discuss.tvm.apache.org/t/byoc-supporting-cutlass-byoc-with-collage/12796/6
for
context, which in turn is part of Collage
(https://github.com/apache/tvm-rfcs/blob/main/rfcs/0062-collage.md).)
This adds a new 'DSO exportable' runtime module representing the contents
of a .o file. It
allows external codegen toolchains to yield a result which:
- Like CSource modules, can be conveyed directly to the final
export_library compilation
step for linking into the final .so and saved to a known location without
risk that the
underlying code artifact will be lost.
- Like DSOLibrary modules, are self-contained so that no additional
compile-time arguments
need be conveyed from the CSource module to the final export_library
command line.
Since this is the third flavor of 'DSO exportable' module, add a
Module::IsDSOExportable.
Since adding the above, can't resist also adding a
Module::ImplementsFunction virtual and
calling it from TECompiler to check if an external codegen function
actually provided the
implementation it promised.
Note:
- I've left the existing implementation of runtime.load_module alone which
relinks .o files to .so files.
- Though also contained in the .o metadata, I require static libraries to
always
carry their list of exported function names.
This is all pretty stopgap pending a good rework of TVM to support the
notion of artifacts
and, perhaps, build rules.
---
include/tvm/runtime/module.h | 28 ++++++
python/tvm/contrib/cc.py | 3 +
python/tvm/contrib/nvcc.py | 2 +
python/tvm/runtime/__init__.py | 2 +-
python/tvm/runtime/module.py | 52 ++++++++--
src/printer/model_library_format_printer.cc | 2 +-
src/relay/backend/contrib/ethosu/source_module.cc | 11 ++-
src/relay/backend/te_compiler.cc | 29 +++---
src/relay/backend/vm/compiler.h | 2 +-
src/runtime/aot_executor/aot_executor_factory.h | 2 +-
src/runtime/const_loader_module.cc | 2 +-
src/runtime/contrib/json/json_runtime.h | 2 +-
src/runtime/contrib/tensorrt/tensorrt_runtime.cc | 2 +-
.../graph_executor/graph_executor_factory.h | 2 +-
src/runtime/metadata.cc | 2 +-
src/runtime/module.cc | 18 +++-
src/runtime/stackvm/stackvm_module.cc | 2 +-
src/runtime/static_library.cc | 106 +++++++++++++++++++++
src/runtime/static_library.h | 50 ++++++++++
src/support/ffi_testing.cc | 2 +-
src/target/codegen.cc | 12 +--
src/target/llvm/llvm_module.cc | 8 +-
src/target/metadata_module.cc | 6 +-
src/target/source/interface_c.cc | 2 +-
src/target/source/source_module.cc | 20 +++-
.../python/unittest/test_runtime_module_export.py | 48 ++++++++--
26 files changed, 356 insertions(+), 61 deletions(-)
diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h
index 076172a8b5..875d999c64 100644
--- a/include/tvm/runtime/module.h
+++ b/include/tvm/runtime/module.h
@@ -28,6 +28,7 @@
#include <dmlc/io.h>
#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/container/string.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>
@@ -190,6 +191,33 @@ class TVM_DLL ModuleNode : public Object {
/*! \return The module it imports from */
const std::vector<Module>& imports() const { return imports_; }
+ /*!
+ * \brief Returns true if this module is 'DSO exportable'.
+ *
+ * A DSO exportable module (eg a CSourceModuleNode of type_key 'c') can be
incorporated into the
+ * final runtime artifact (ie shared library) by compilation and/or linking
using the external
+ * compiler (llvm, nvcc, etc). DSO exportable modules must implement
SaveToFile.
+ *
+ * By contrast, non-DSO exportable modules (eg CUDAModuleNode of type_key
'cuda') typically must
+ * be incorporated into the final runtime artifact by being serialized as
data into the
+ * artifact, then deserialized at runtime. Non-DSO exportable modules must
implement SaveToBinary,
+ * and have a matching deserializer registered as
'runtime.module.loadbinary_<type_key>'.
+ *
+ * The default implementation returns false.
+ */
+ virtual bool IsDSOExportable() const;
+
+ /*!
+ * \brief Returns true if this module has a definition for a function of \p
name. If
+ * \p query_imports is true, also search in any imported modules.
+ *
+ * Note that even if this function returns true the corresponding \p
GetFunction result may be
+ * nullptr if the function is not yet callable without further compilation.
+ *
+ * The default implementation just checkis if \p GetFunction is non-null.
+ */
+ virtual bool ImplementsFunction(const String& name, bool query_imports =
false);
+
// integration with the existing components.
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeModule;
static constexpr const char* _type_key = "runtime.Module";
diff --git a/python/tvm/contrib/cc.py b/python/tvm/contrib/cc.py
index 867cbd6012..ec40ef3189 100644
--- a/python/tvm/contrib/cc.py
+++ b/python/tvm/contrib/cc.py
@@ -19,6 +19,7 @@
import sys
import os
import subprocess
+import logging
from .._ffi.base import py_str
@@ -238,6 +239,7 @@ def _linux_compile(output, objects, options, compile_cmd,
compile_shared=False):
cmd += objects
if options:
cmd += options
+ logging.info("invoking '%s'", cmd)
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
@@ -264,6 +266,7 @@ def _windows_compile(output, objects, options):
cmd += options
try:
+ logging.info("invoking '%s'", cmd)
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
except FileNotFoundError:
diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py
index 5a104be996..33a32c9c00 100644
--- a/python/tvm/contrib/nvcc.py
+++ b/python/tvm/contrib/nvcc.py
@@ -21,6 +21,7 @@ from __future__ import absolute_import as _abs
import subprocess
import os
import warnings
+import logging
import tvm._ffi
from tvm.target import Target
@@ -102,6 +103,7 @@ def compile_cuda(code, target_format="ptx", arch=None,
options=None, path_target
# if cxx_compiler_path != "":
# cmd += ["-ccbin", cxx_compiler_path]
+ logging.info("invoking '%s'", cmd)
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py
index e0da680a24..114f01dd0e 100644
--- a/python/tvm/runtime/__init__.py
+++ b/python/tvm/runtime/__init__.py
@@ -28,7 +28,7 @@ from .profiling import Report
from .object_generic import convert_to_object, convert, const
from .ndarray import device, cpu, cuda, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, ext_dev
-from .module import load_module, enabled, system_lib
+from .module import load_module, enabled, system_lib, load_static_library
from .container import String, ShapeTuple
from .params import save_param_dict, load_param_dict
diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py
index 64b3d506b6..c614e5d757 100644
--- a/python/tvm/runtime/module.py
+++ b/python/tvm/runtime/module.py
@@ -127,6 +127,28 @@ class Module(object):
self._entry = self.get_function(self.entry_name)
return self._entry
+ def implements_function(self, name, query_imports=False):
+ """Returns True if the module has a definition for the global function
with name. Note
+ that has_function(name) does not imply get_function(name) is non-null
since the module
+ may be, eg, a CSourceModule which cannot supply a packed-func
implementation of the function
+ without further compilation. However, get_function(name) non null
should always imply
+ has_function(name).
+
+ Parameters
+ ----------
+ name : str
+ The name of the function
+
+ query_imports : bool
+ Whether to also query modules imported by this module.
+
+ Returns
+ -------
+ b : Bool
+ True if module (or one of its imports) has a definition for name.
+ """
+ return _ffi_api.ModuleImplementsFunction(self, name, query_imports)
+
def get_function(self, name, query_imports=False):
"""Get function from the module.
@@ -217,6 +239,18 @@ class Module(object):
nmod = _ffi_api.ModuleImportsSize(self)
return [_ffi_api.ModuleGetImport(self, i) for i in range(nmod)]
+ @property
+ def is_dso_exportable(self):
+ """Returns true if module is 'DSO exportable', ie can be included in
result of
+ export_library by the external compiler directly.
+
+ Returns
+ -------
+ b : Bool
+ True if the module is DSO exportable.
+ """
+ return _ffi_api.ModuleIsDSOExportable(self)
+
def save(self, file_name, fmt=""):
"""Save the module to file.
@@ -332,8 +366,7 @@ class Module(object):
return dso_modules
def _collect_dso_modules(self):
- is_dso_exportable = lambda m: (m.type_key == "llvm" or m.type_key ==
"c")
- return self._collect_from_import_tree(is_dso_exportable)
+ return self._collect_from_import_tree(lambda m: m.is_dso_exportable)
def export_library(self, file_name, fcompile=None, addons=None,
workspace_dir=None, **kwargs):
"""
@@ -418,10 +451,7 @@ class Module(object):
else:
object_format = fcompile.object_format
else:
- if module.type_key == "llvm":
- object_format = "o"
- else:
- assert module.type_key == "c"
+ if module.type_key == "c":
if len(module.format) > 0:
assert module.format in [
"c",
@@ -436,6 +466,9 @@ class Module(object):
if kwargs["cc"] == "nvcc":
object_format = "cu"
has_c_module = True
+ else:
+ assert module.type_key == "llvm" or module.type_key ==
"static_library"
+ object_format = "o"
path_obj = os.path.join(workspace_dir,
f"lib{index}.{object_format}")
module.save(path_obj)
files.append(path_obj)
@@ -552,6 +585,13 @@ def load_module(path, fmt=""):
return _ffi_api.ModuleLoadFromFile(path, fmt)
+def load_static_library(path, func_names):
+ """Load the .o library at path which implements functions with func_names.
+ Unlike the generic load_module the result will remain as a static_library
+ and will not be relinked on-the-fly into a .so library."""
+ return _ffi_api.ModuleLoadStaticLibrary(path, func_names)
+
+
def enabled(target):
"""Whether module runtime is enabled for target
diff --git a/src/printer/model_library_format_printer.cc
b/src/printer/model_library_format_printer.cc
index 17ba84e68d..f6ac39ce79 100644
--- a/src/printer/model_library_format_printer.cc
+++ b/src/printer/model_library_format_printer.cc
@@ -35,7 +35,7 @@ class ModelLibraryFormatPrinter : public
::tvm::runtime::ModuleNode {
bool show_warning)
: text_printer_{show_meta_data, annotate, show_warning} {}
- const char* type_key() const override { return
"model_library_format_printer"; }
+ const char* type_key() const final { return "model_library_format_printer"; }
std::string Print(const ObjectRef& node) {
Doc doc;
diff --git a/src/relay/backend/contrib/ethosu/source_module.cc
b/src/relay/backend/contrib/ethosu/source_module.cc
index c79785c869..eb4b779ecd 100644
--- a/src/relay/backend/contrib/ethosu/source_module.cc
+++ b/src/relay/backend/contrib/ethosu/source_module.cc
@@ -114,13 +114,22 @@ class EthosUModuleNode : public ModuleNode {
return PackedFunc();
}
- const char* type_key() const override { return "c"; }
+ const char* type_key() const final { return "c"; }
static Module Create(Array<CompilationArtifact> compilation_artifacts) {
auto n = make_object<EthosUModuleNode>(compilation_artifacts);
return Module(n);
}
+ bool IsDSOExportable() const final { return true; }
+
+ bool ImplementsFunction(const String& name, bool query_imports) final {
+ return std::find_if(compilation_artifacts_.begin(),
compilation_artifacts_.end(),
+ [&name](const CompilationArtifact& artifact) {
+ return artifact->function_name == name;
+ }) != compilation_artifacts_.end();
+ }
+
private:
std::string c_source;
Array<CompilationArtifact> compilation_artifacts_;
diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc
index 5a0502d175..76dbfef538 100644
--- a/src/relay/backend/te_compiler.cc
+++ b/src/relay/backend/te_compiler.cc
@@ -155,6 +155,7 @@ class TECompilerImpl : public TECompilerNode {
}
}
for (const auto& global_var : to_be_deleted) {
+ VLOG(1) << "Removing definition for external codegened '" <<
global_var->name_hint << "'";
module->Remove(global_var);
}
// HOWEVER we still need a Relay definition to go with those now external
functions, so
@@ -203,27 +204,29 @@ class TECompilerImpl : public TECompilerNode {
std::string ext_name = "relay.ext." + opt_compiler.value();
auto pf = tvm::runtime::Registry::Get(ext_name);
- ICHECK(pf) << "Failed to find the codegen tool for " << ext_name;
+ ICHECK(pf) << "Failed to find the external codegen tool for " <<
ext_name;
// No need to keep compiler attribute at this point, functions have
been
// extracted for specific codegen.
src_func = WithAttr(std::move(src_func), attr::kCompiler,
NullValue<ObjectRef>());
- VLOG_CONTEXT << ext_name;
+ VLOG_CONTEXT << opt_compiler.value();
+ With<Target> with_target(it.first->target);
runtime::Module ext_mod = (*pf)(src_func);
if (ext_mod.defined()) {
- if (ext_mod->GetFunction(opt_symbol_name.value(),
/*query_imports=*/true) == nullptr) {
- // It's possible the codegen yielded C or C++ tracked separately
and thus the
- // returned runtime module can be empty.
- VLOG(1) << "Unable to find definition for the external function '"
- << opt_symbol_name.value()
- << "' in the runtime module generated by external codegen
'"
- << opt_compiler.value() << "'";
+ // TODO(mbs): Can this be an ICHECKs?
+ if (!ext_mod->ImplementsFunction(opt_symbol_name.value())) {
+ VLOG(1) << "Note that the external codegen for '" <<
opt_compiler.value()
+ << "' returned a runtime module which does not appear to
implement '"
+ << opt_symbol_name.value() << "'";
}
ret.push_back(ext_mod);
} else {
- // A warning only so that we can write unit tests which can return
an empty runtime
- // module.
- LOG(WARNING) << "No external runtime module was generated by
external codegen '"
- << opt_compiler.value() << "'";
+ // It is valid for the external codegen function to return null:
+ // - Unit tests can use it.
+ // - The true compilation may have already been handled by a
RelayToTIR custom hook pass
+ // on the Target's kind. The original Relay functions will be
left in place so
+ // that we can capture that their function names are now
externally defined.
+ VLOG(1) << "Note that no external runtime module was generated by
external codegen '"
+ << opt_compiler.value() << "'";
}
}
}
diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h
index b1c977e526..a65bdc5ab3 100644
--- a/src/relay/backend/vm/compiler.h
+++ b/src/relay/backend/vm/compiler.h
@@ -91,7 +91,7 @@ class VMCompiler : public runtime::ModuleNode {
virtual PackedFunc GetFunction(const std::string& name, const
ObjectPtr<Object>& sptr_to_self);
- const char* type_key() const { return "VMCompiler"; }
+ const char* type_key() const final { return "VMCompiler"; }
/*!
* \brief Set the parameters
diff --git a/src/runtime/aot_executor/aot_executor_factory.h
b/src/runtime/aot_executor/aot_executor_factory.h
index 1d6a0a6277..ada63f0ba8 100644
--- a/src/runtime/aot_executor/aot_executor_factory.h
+++ b/src/runtime/aot_executor/aot_executor_factory.h
@@ -63,7 +63,7 @@ class TVM_DLL AotExecutorFactory : public runtime::ModuleNode
{
/*!
* \return The type key of the executor.
*/
- const char* type_key() const override { return "AotExecutorFactory"; }
+ const char* type_key() const final { return "AotExecutorFactory"; }
/*!
* \brief Save the module to binary stream.
diff --git a/src/runtime/const_loader_module.cc
b/src/runtime/const_loader_module.cc
index 5496e161e5..2e91d26d5f 100644
--- a/src/runtime/const_loader_module.cc
+++ b/src/runtime/const_loader_module.cc
@@ -79,7 +79,7 @@ class ConstLoaderModuleNode : public ModuleNode {
return PackedFunc(nullptr);
}
- const char* type_key() const { return "const_loader"; }
+ const char* type_key() const final { return "const_loader"; }
/*!
* \brief Get the list of constants that is required by the given module.
diff --git a/src/runtime/contrib/json/json_runtime.h
b/src/runtime/contrib/json/json_runtime.h
index 374a440e29..355390765d 100644
--- a/src/runtime/contrib/json/json_runtime.h
+++ b/src/runtime/contrib/json/json_runtime.h
@@ -54,7 +54,7 @@ class JSONRuntimeBase : public ModuleNode {
LoadGraph(graph_json_);
}
- const char* type_key() const override { return "json"; }
+ const char* type_key() const override { return "json"; } // May be
overridden
/*! \brief Initialize a specific json runtime. */
virtual void Init(const Array<NDArray>& consts) = 0;
diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
index b60074e66d..554515c456 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
@@ -95,7 +95,7 @@ class TensorRTRuntime : public JSONRuntimeBase {
*
* \return module type key.
*/
- const char* type_key() const override { return "tensorrt"; }
+ const char* type_key() const final { return "tensorrt"; }
/*!
* \brief Initialize runtime. Create TensorRT layer from JSON
diff --git a/src/runtime/graph_executor/graph_executor_factory.h
b/src/runtime/graph_executor/graph_executor_factory.h
index 1ee74c3547..d8ebe44bb9 100644
--- a/src/runtime/graph_executor/graph_executor_factory.h
+++ b/src/runtime/graph_executor/graph_executor_factory.h
@@ -65,7 +65,7 @@ class TVM_DLL GraphExecutorFactory : public
runtime::ModuleNode {
/*!
* \return The type key of the executor.
*/
- const char* type_key() const override { return "GraphExecutorFactory"; }
+ const char* type_key() const final { return "GraphExecutorFactory"; }
/*!
* \brief Save the module to binary stream.
diff --git a/src/runtime/metadata.cc b/src/runtime/metadata.cc
index c08f2872fe..8e034cc94d 100644
--- a/src/runtime/metadata.cc
+++ b/src/runtime/metadata.cc
@@ -75,7 +75,7 @@ class MetadataModuleNode : public ::tvm::runtime::ModuleNode {
explicit MetadataModuleNode(runtime::metadata::Metadata metadata)
: metadata_{::std::move(metadata)} {}
- const char* type_key() const { return "metadata_module"; }
+ const char* type_key() const final { return "metadata_module"; }
static Module LoadFromBinary() {
return
Module(make_object<MetadataModuleNode>(runtime::metadata::Metadata()));
diff --git a/src/runtime/module.cc b/src/runtime/module.cc
index 097d6a2f53..57fe575689 100644
--- a/src/runtime/module.cc
+++ b/src/runtime/module.cc
@@ -83,6 +83,7 @@ Module Module::LoadFromFile(const std::string& file_name,
const std::string& for
fmt = "so";
}
std::string load_f_name = "runtime.module.loadfile_" + fmt;
+ VLOG(1) << "Loading module from '" << file_name << "' of format '" << fmt <<
"'";
const PackedFunc* f = Registry::Get(load_f_name);
ICHECK(f != nullptr) << "Loader for `." << format << "` files is not
registered,"
<< " resolved to (" << load_f_name << ") in the global
registry."
@@ -132,6 +133,12 @@ std::string ModuleNode::GetFormat() {
return "";
}
+bool ModuleNode::IsDSOExportable() const { return false; }
+
+bool ModuleNode::ImplementsFunction(const String& name, bool query_imports) {
+ return GetFunction(name, query_imports) != nullptr;
+}
+
bool RuntimeEnabled(const std::string& target) {
std::string f_name;
if (target == "cpu") {
@@ -191,8 +198,15 @@
TVM_REGISTER_GLOBAL("runtime.ModuleGetFormat").set_body_typed([](Module mod) {
TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile").set_body_typed(Module::LoadFromFile);
TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile")
- .set_body_typed([](Module mod, tvm::String name, tvm::String fmt) {
- mod->SaveToFile(name, fmt);
+ .set_body_typed([](Module mod, String name, tvm::String fmt) {
mod->SaveToFile(name, fmt); });
+
+TVM_REGISTER_GLOBAL("runtime.ModuleIsDSOExportable").set_body_typed([](Module
mod) {
+ return mod->IsDSOExportable();
+});
+
+TVM_REGISTER_GLOBAL("runtime.ModuleImplementsFunction")
+ .set_body_typed([](Module mod, String name, bool query_imports) {
+ return mod->ImplementsFunction(std::move(name), query_imports);
});
TVM_REGISTER_OBJECT_TYPE(ModuleNode);
diff --git a/src/runtime/stackvm/stackvm_module.cc
b/src/runtime/stackvm/stackvm_module.cc
index c784a9d048..bbcadd21b4 100644
--- a/src/runtime/stackvm/stackvm_module.cc
+++ b/src/runtime/stackvm/stackvm_module.cc
@@ -37,7 +37,7 @@ namespace runtime {
class StackVMModuleNode : public runtime::ModuleNode {
public:
- const char* type_key() const { return "stackvm"; }
+ const char* type_key() const final { return "stackvm"; }
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>&
sptr_to_self) final {
if (name == runtime::symbol::tvm_module_main) {
diff --git a/src/runtime/static_library.cc b/src/runtime/static_library.cc
new file mode 100644
index 0000000000..e845d0fac2
--- /dev/null
+++ b/src/runtime/static_library.cc
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file runtime/static_library.cc
+ * \brief Represents a generic '.o' static library which can be linked into
the final output
+ * dynamic library by export_library.
+ */
+#include "./static_library.h"
+
+#include <tvm/runtime/memory.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include <iostream>
+
+#include "file_utils.h"
+
+namespace tvm {
+namespace runtime {
+
+namespace {
+
+/*!
+ * \brief A '.o' library which can be linked into the final output library by
export_library.
+ * Can be used by external codegen tools which can produce a ready-to-link
artifact.
+ */
+class StaticLibraryNode final : public runtime::ModuleNode {
+ public:
+ ~StaticLibraryNode() override = default;
+
+ const char* type_key() const final { return "static_library"; }
+
+ PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>&
sptr_to_self) final {
+ if (name == "get_func_names") {
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = func_names_; });
+ } else {
+ return {};
+ }
+ }
+
+ void SaveToFile(const std::string& file_name, const std::string& format)
final {
+ VLOG(0) << "Saving static library of " << data_.size() << " bytes
implementing " << FuncNames()
+ << " to '" << file_name << "'";
+ SaveBinaryToFile(file_name, data_);
+ }
+
+ bool IsDSOExportable() const final { return true; }
+
+ bool ImplementsFunction(const String& name, bool query_imports) final {
+ return std::find(func_names_.begin(), func_names_.end(), name) !=
func_names_.end();
+ }
+
+ std::string FuncNames() {
+ std::ostringstream os;
+ os << "[";
+ bool first = true;
+ for (const auto& func_name : func_names_) {
+ if (first) {
+ first = false;
+ } else {
+ os << ", ";
+ }
+ os << "'" << func_name << "'";
+ }
+ os << "]";
+ return os.str();
+ }
+
+ /*! \brief Contents of the object file. */
+ std::string data_;
+ /*! \brief Function names exported by the above. */
+ Array<String> func_names_;
+};
+
+} // namespace
+
+Module LoadStaticLibrary(const std::string& filename, Array<String>
func_names) {
+ auto node = make_object<StaticLibraryNode>();
+ LoadBinaryFromFile(filename, &node->data_);
+ node->func_names_ = std::move(func_names);
+ VLOG(0) << "Loaded static library from '" << filename << "' implementing "
<< node->FuncNames();
+ return Module(node);
+}
+
+TVM_REGISTER_GLOBAL("runtime.ModuleLoadStaticLibrary").set_body_typed(LoadStaticLibrary);
+
+} // namespace runtime
+} // namespace tvm
diff --git a/src/runtime/static_library.h b/src/runtime/static_library.h
new file mode 100644
index 0000000000..352891f6fb
--- /dev/null
+++ b/src/runtime/static_library.h
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file runtime/static_library.h
+ * \brief Represents a generic '.o' static library which can be linked into
the final output
+ * dynamic library by export_library.
+ */
+
+#ifndef TVM_RUNTIME_STATIC_LIBRARY_H_
+#define TVM_RUNTIME_STATIC_LIBRARY_H_
+
+#include <tvm/runtime/logging.h>
+#include <tvm/runtime/module.h>
+
+#include <array>
+#include <memory>
+#include <set>
+#include <string>
+#include <unordered_map>
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief Returns a static library with the contents loaded from filename
which exports
+ * func_names with the usual packed-func calling convention.
+ */
+Module LoadStaticLibrary(const std::string& filename, Array<String>
func_names);
+
+} // namespace runtime
+} // namespace tvm
+
+#endif // TVM_RUNTIME_STATIC_LIBRARY_H_
diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc
index c7cec21508..26d7a4b70f 100644
--- a/src/support/ffi_testing.cc
+++ b/src/support/ffi_testing.cc
@@ -121,7 +121,7 @@
TVM_REGISTER_GLOBAL("testing.object_use_count").set_body([](TVMArgs args, TVMRet
class FrontendTestModuleNode : public runtime::ModuleNode {
public:
- virtual const char* type_key() const { return "frontend_test"; }
+ const char* type_key() const final { return "frontend_test"; }
static constexpr const char* kAddFunctionName = "__add_function";
diff --git a/src/target/codegen.cc b/src/target/codegen.cc
index 41221ad8a3..3c4866be1b 100644
--- a/src/target/codegen.cc
+++ b/src/target/codegen.cc
@@ -68,7 +68,7 @@ class ModuleSerializer {
// Only have one DSO module and it is in the root, then
// we will not produce import_tree_.
bool has_import_tree = true;
- if (DSOExportable(mod_.operator->()) && mod_->imports().empty()) {
+ if (mod_->IsDSOExportable() && mod_->imports().empty()) {
has_import_tree = false;
}
uint64_t sz = 0;
@@ -84,7 +84,7 @@ class ModuleSerializer {
for (const auto& group : mod_group_vec_) {
ICHECK_NE(group.size(), 0) << "Every allocated group must have at least
one module";
- if (!DSOExportable(group[0])) {
+ if (!group[0]->IsDSOExportable()) {
ICHECK_EQ(group.size(), 1U) << "Non DSO module is never merged";
std::string mod_type_key = group[0]->type_key();
stream->Write(mod_type_key);
@@ -147,7 +147,7 @@ class ModuleSerializer {
while (!stack.empty()) {
runtime::ModuleNode* n = stack.back();
stack.pop_back();
- if (DSOExportable(n)) {
+ if (n->IsDSOExportable()) {
// do not recursively expand dso modules
// we will expand in phase 1
dso_exportable_boundary.emplace_back(n);
@@ -174,7 +174,7 @@ class ModuleSerializer {
runtime::ModuleNode* n = stack.back();
stack.pop_back();
- if (DSOExportable(n)) {
+ if (n->IsDSOExportable()) {
mod_group_vec_[dso_module_index].emplace_back(n);
mod2index_[n] = dso_module_index;
} else {
@@ -219,10 +219,6 @@ class ModuleSerializer {
}
}
- bool DSOExportable(const runtime::ModuleNode* mod) {
- return !std::strcmp(mod->type_key(), "llvm") ||
!std::strcmp(mod->type_key(), "c");
- }
-
runtime::Module mod_;
// construct module to index
std::unordered_map<runtime::ModuleNode*, size_t> mod2index_;
diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc
index b2dc4c81f9..c7aea3dc19 100644
--- a/src/target/llvm/llvm_module.cc
+++ b/src/target/llvm/llvm_module.cc
@@ -56,7 +56,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
}
}
- const char* type_key() const { return "llvm"; }
+ const char* type_key() const final { return "llvm"; }
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>&
sptr_to_self) final {
if (name == "__tvm_is_system_module") {
@@ -357,6 +357,12 @@ class LLVMModuleNode final : public runtime::ModuleNode {
Init(std::move(module), ctx);
}
+ bool IsDSOExportable() const final { return true; }
+
+ bool ImplementsFunction(const String& name, bool query_imports) final {
+ return std::find(function_names_.begin(), function_names_.end(), name) !=
function_names_.end();
+ }
+
private:
void LazyInitJIT() {
std::lock_guard<std::mutex> lock(mutex_);
diff --git a/src/target/metadata_module.cc b/src/target/metadata_module.cc
index 840ba5cab2..97299c6375 100644
--- a/src/target/metadata_module.cc
+++ b/src/target/metadata_module.cc
@@ -190,10 +190,6 @@ runtime::Module CreateMetadataModule(
Array<runtime::Module> crt_exportable_modules;
Array<runtime::Module> non_crt_exportable_modules;
- auto DSOExportable = [](tvm::runtime::Module& mod) {
- return !std::strcmp(mod->type_key(), "llvm") ||
!std::strcmp(mod->type_key(), "c");
- };
-
bool is_targeting_crt = runtime->name == "crt";
// Wrap all submodules in the initialization wrapper.
@@ -219,7 +215,7 @@ runtime::Module CreateMetadataModule(
// TODO(@manupa-arm) : we should be able to use csource_metadata
// if the variables are empty when all the runtime modules implement
get_func_names
- if (symbol_const_vars.empty() && is_targeting_crt && DSOExportable(mod) &&
+ if (symbol_const_vars.empty() && is_targeting_crt &&
mod->IsDSOExportable() &&
(target->kind->name == "c" || target->kind->name == "llvm")) {
crt_exportable_modules.push_back(mod);
} else {
diff --git a/src/target/source/interface_c.cc b/src/target/source/interface_c.cc
index 12d930d8f8..1bb567d148 100644
--- a/src/target/source/interface_c.cc
+++ b/src/target/source/interface_c.cc
@@ -52,7 +52,7 @@ class InterfaceCNode : public runtime::ModuleNode {
pools_(FilterExternalPools(pools)),
io_pool_allocations_(io_pool_allocations),
workspace_size_(workspace_size) {}
- const char* type_key() const { return "h"; }
+ const char* type_key() const final { return "h"; }
std::string GetSource(const std::string& format) final {
std::stringstream code;
diff --git a/src/target/source/source_module.cc
b/src/target/source/source_module.cc
index 8f581f4cbb..2c4993419f 100644
--- a/src/target/source/source_module.cc
+++ b/src/target/source/source_module.cc
@@ -59,7 +59,7 @@ using runtime::SaveBinaryToFile;
class SourceModuleNode : public runtime::ModuleNode {
public:
SourceModuleNode(std::string code, std::string fmt) : code_(code), fmt_(fmt)
{}
- const char* type_key() const { return "source"; }
+ const char* type_key() const final { return "source"; }
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>&
sptr_to_self) final {
LOG(FATAL) << "Source module cannot execute, to get executable module"
@@ -87,7 +87,7 @@ class CSourceModuleNode : public runtime::ModuleNode {
CSourceModuleNode(const std::string& code, const std::string& fmt,
const Array<String>& func_names, const Array<String>&
const_vars)
: code_(code), fmt_(fmt), const_vars_(const_vars),
func_names_(func_names) {}
- const char* type_key() const { return "c"; }
+ const char* type_key() const final { return "c"; }
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>&
sptr_to_self) final {
// Currently c-source module is used as demonstration purposes with binary
metadata module
@@ -123,6 +123,12 @@ class CSourceModuleNode : public runtime::ModuleNode {
}
}
+ bool IsDSOExportable() const final { return true; }
+
+ bool ImplementsFunction(const String& name, bool query_imports) final {
+ return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end();
+ }
+
protected:
std::string code_;
std::string fmt_;
@@ -166,7 +172,7 @@ class CSourceCrtMetadataModuleNode : public
runtime::ModuleNode {
metadata_(metadata) {
CreateSource();
}
- const char* type_key() const { return "c"; }
+ const char* type_key() const final { return "c"; }
std::string GetSource(const std::string& format) final { return code_.str(); }
@@ -187,6 +193,12 @@ class CSourceCrtMetadataModuleNode : public
runtime::ModuleNode {
}
}
+ bool IsDSOExportable() const final { return true; }
+
+ bool ImplementsFunction(const String& name, bool query_imports) final {
+ return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end();
+ }
+
protected:
std::stringstream code_;
std::string fmt_;
@@ -908,7 +920,7 @@ class DeviceSourceModuleNode final : public
runtime::ModuleNode {
}
}
- const char* type_key() const { return type_key_.c_str(); }
+ const char* type_key() const final { return type_key_.c_str(); }
void SaveToFile(const std::string& file_name, const std::string& format)
final {
std::string fmt = GetFileFormat(file_name, format);
diff --git a/tests/python/unittest/test_runtime_module_export.py b/tests/python/unittest/test_runtime_module_export.py
index 9ea1ff437f..57fcaea03d 100644
--- a/tests/python/unittest/test_runtime_module_export.py
+++ b/tests/python/unittest/test_runtime_module_export.py
@@ -18,7 +18,6 @@ from tvm import relay
from tvm.relay import testing
import tvm
from tvm import te
-
import tvm.testing
from tvm.contrib import utils
@@ -80,8 +79,6 @@ def test_mod_export():
synthetic_llvm_mod, "llvm", params=synthetic_llvm_params,
mod_name="llvmlib"
)
- from tvm.contrib import utils
-
temp = utils.tempdir()
if obj_format == ".so":
file_name = "deploy_lib.so"
@@ -109,8 +106,6 @@ def test_mod_export():
mod0 = tvm.build(s, [A, B], "llvm", name="myadd0")
mod1 = tvm.build(s, [A, B], "llvm", name="myadd1")
- from tvm.contrib import utils
-
temp = utils.tempdir()
if obj_format == ".so":
file_name = "deploy_lib.so"
@@ -152,8 +147,6 @@ def test_mod_export():
+ "mul 6 inputs: 5 3 shape: 10 10"
)
- from tvm.contrib import utils
-
temp = utils.tempdir()
subgraph_path = temp.relpath("subgraph.examplejson")
with open(subgraph_path, "w") as f:
@@ -203,7 +196,6 @@ def test_mod_export():
s = te.create_schedule(B.op)
f = tvm.build(s, [A, B], "c", name="myadd")
engine_module = generate_engine_module()
- from tvm.contrib import utils
temp = utils.tempdir()
file_name = "deploy_lib.so"
@@ -225,5 +217,43 @@ def test_mod_export():
verify_multi_c_mod_export()
+@tvm.testing.requires_llvm
+def test_import_static_library():
+ # Generate two LLVM modules.
+ A = te.placeholder((1024,), name="A")
+ B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
+ s = te.create_schedule(B.op)
+ mod0 = tvm.build(s, [A, B], "llvm", name="myadd0")
+ mod1 = tvm.build(s, [A, B], "llvm", name="myadd1")
+
+ assert mod0.implements_function("myadd0")
+ assert mod1.implements_function("myadd1")
+ assert mod1.is_dso_exportable
+
+ # mod1 is currently an 'llvm' module.
+ # Save and reload it as a vanilla 'static_library'.
+ temp = utils.tempdir()
+ mod1_o_path = temp.relpath("mod1.o")
+ mod1.save(mod1_o_path)
+ mod1_o = tvm.runtime.load_static_library(mod1_o_path, ["myadd1"])
+ assert mod1_o.implements_function("myadd1")
+ assert mod1_o.is_dso_exportable
+
+ # Import mod1 as a static library into mod0 and compile to its own DSO.
+ mod0.import_module(mod1_o)
+ mod0_dso_path = temp.relpath("mod0.so")
+ mod0.export_library(mod0_dso_path)
+
+ # The imported mod1 is statically linked into mod0.
+ loaded_lib = tvm.runtime.load_module(mod0_dso_path)
+ assert loaded_lib.type_key == "library"
+ assert len(loaded_lib.imported_modules) == 0
+ assert loaded_lib.implements_function("myadd0")
+ assert loaded_lib.get_function("myadd0")
+ assert loaded_lib.implements_function("myadd1")
+ assert loaded_lib.get_function("myadd1")
+ assert not loaded_lib.is_dso_exportable
+
+
if __name__ == "__main__":
- test_mod_export()
+ tvm.testing.main()