This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new db5f4fe65c [Runtime] Add 'static_library' runtime::Module (#11442)
db5f4fe65c is described below

commit db5f4fe65cb01ff50dfd05d2b2d59a66b78079c7
Author: Mark Shields <[email protected]>
AuthorDate: Thu May 26 09:26:05 2022 -0700

    [Runtime] Add 'static_library' runtime::Module (#11442)
    
    (See https://discuss.tvm.apache.org/t/byoc-supporting-cutlass-byoc-with-collage/12796/6 for
    context, which in turn is part of Collage (https://github.com/apache/tvm-rfcs/blob/main/rfcs/0062-collage.md).)
    
    This adds a new 'DSO exportable' runtime module representing the contents of a .o file. It
    allows external codegen toolchains to yield a result which:
     - Like CSource modules, can be conveyed directly to the final export_library compilation
       step for linking into the final .so, and saved to a known location without risk that the
       underlying code artifact will be lost.
     - Like DSOLibrary modules, is self-contained, so that no additional compile-time arguments
       need to be conveyed from the CSource module to the final export_library command line.
    
    Since this is the third flavor of 'DSO exportable' module, add a Module::IsDSOExportable virtual method.
    
    Having added the above, I couldn't resist also adding a Module::ImplementsFunction virtual and
    calling it from TECompiler to check whether an external codegen function actually provided the
    implementation it promised.
    
    Note:
     - I've left the existing implementation of runtime.load_module alone, which
       relinks .o files to .so files.
     - Though the function names are also contained in the .o metadata, I require static
       libraries to always carry their list of exported function names.
    
    This is all a stop-gap pending a good rework of TVM to support the notion of artifacts
    and, perhaps, build rules.
---
 include/tvm/runtime/module.h                       |  28 ++++++
 python/tvm/contrib/cc.py                           |   3 +
 python/tvm/contrib/nvcc.py                         |   2 +
 python/tvm/runtime/__init__.py                     |   2 +-
 python/tvm/runtime/module.py                       |  52 ++++++++--
 src/printer/model_library_format_printer.cc        |   2 +-
 src/relay/backend/contrib/ethosu/source_module.cc  |  11 ++-
 src/relay/backend/te_compiler.cc                   |  29 +++---
 src/relay/backend/vm/compiler.h                    |   2 +-
 src/runtime/aot_executor/aot_executor_factory.h    |   2 +-
 src/runtime/const_loader_module.cc                 |   2 +-
 src/runtime/contrib/json/json_runtime.h            |   2 +-
 src/runtime/contrib/tensorrt/tensorrt_runtime.cc   |   2 +-
 .../graph_executor/graph_executor_factory.h        |   2 +-
 src/runtime/metadata.cc                            |   2 +-
 src/runtime/module.cc                              |  18 +++-
 src/runtime/stackvm/stackvm_module.cc              |   2 +-
 src/runtime/static_library.cc                      | 106 +++++++++++++++++++++
 src/runtime/static_library.h                       |  50 ++++++++++
 src/support/ffi_testing.cc                         |   2 +-
 src/target/codegen.cc                              |  12 +--
 src/target/llvm/llvm_module.cc                     |   8 +-
 src/target/metadata_module.cc                      |   6 +-
 src/target/source/interface_c.cc                   |   2 +-
 src/target/source/source_module.cc                 |  20 +++-
 .../python/unittest/test_runtime_module_export.py  |  48 ++++++++--
 26 files changed, 356 insertions(+), 61 deletions(-)

diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h
index 076172a8b5..875d999c64 100644
--- a/include/tvm/runtime/module.h
+++ b/include/tvm/runtime/module.h
@@ -28,6 +28,7 @@
 
 #include <dmlc/io.h>
 #include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/container/string.h>
 #include <tvm/runtime/memory.h>
 #include <tvm/runtime/object.h>
 
@@ -190,6 +191,33 @@ class TVM_DLL ModuleNode : public Object {
   /*! \return The module it imports from */
   const std::vector<Module>& imports() const { return imports_; }
 
+  /*!
+   * \brief Returns true if this module is 'DSO exportable'.
+   *
+   * A DSO exportable module (eg a CSourceModuleNode of type_key 'c') can be 
incorporated into the
+   * final runtime artifact (ie shared library) by compilation and/or linking 
using the external
+   * compiler (llvm, nvcc, etc). DSO exportable modules must implement 
SaveToFile.
+   *
+   * By contrast, non-DSO exportable modules (eg CUDAModuleNode of type_key 
'cuda') typically must
+   * be incorporated into the final runtime artifact by being serialized as 
data into the
+   * artifact, then deserialized at runtime. Non-DSO exportable modules must 
implement SaveToBinary,
+   * and have a matching deserializer registered as 
'runtime.module.loadbinary_<type_key>'.
+   *
+   * The default implementation returns false.
+   */
+  virtual bool IsDSOExportable() const;
+
+  /*!
+   * \brief Returns true if this module has a definition for a function of \p 
name. If
+   * \p query_imports is true, also search in any imported modules.
+   *
+   * Note that even if this function returns true the corresponding \p 
GetFunction result may be
+   * nullptr if the function is not yet callable without further compilation.
+   *
+   * The default implementation just checkis if \p GetFunction is non-null.
+   */
+  virtual bool ImplementsFunction(const String& name, bool query_imports = 
false);
+
   // integration with the existing components.
   static constexpr const uint32_t _type_index = TypeIndex::kRuntimeModule;
   static constexpr const char* _type_key = "runtime.Module";
diff --git a/python/tvm/contrib/cc.py b/python/tvm/contrib/cc.py
index 867cbd6012..ec40ef3189 100644
--- a/python/tvm/contrib/cc.py
+++ b/python/tvm/contrib/cc.py
@@ -19,6 +19,7 @@
 import sys
 import os
 import subprocess
+import logging
 
 from .._ffi.base import py_str
 
@@ -238,6 +239,7 @@ def _linux_compile(output, objects, options, compile_cmd, 
compile_shared=False):
         cmd += objects
     if options:
         cmd += options
+    logging.info("invoking '%s'", cmd)
     proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, 
stderr=subprocess.STDOUT)
     (out, _) = proc.communicate()
     if proc.returncode != 0:
@@ -264,6 +266,7 @@ def _windows_compile(output, objects, options):
         cmd += options
 
     try:
+        logging.info("invoking '%s'", cmd)
         proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, 
stderr=subprocess.STDOUT)
         (out, _) = proc.communicate()
     except FileNotFoundError:
diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py
index 5a104be996..33a32c9c00 100644
--- a/python/tvm/contrib/nvcc.py
+++ b/python/tvm/contrib/nvcc.py
@@ -21,6 +21,7 @@ from __future__ import absolute_import as _abs
 import subprocess
 import os
 import warnings
+import logging
 
 import tvm._ffi
 from tvm.target import Target
@@ -102,6 +103,7 @@ def compile_cuda(code, target_format="ptx", arch=None, 
options=None, path_target
     # if cxx_compiler_path != "":
     #    cmd += ["-ccbin", cxx_compiler_path]
 
+    logging.info("invoking '%s'", cmd)
     proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, 
stderr=subprocess.STDOUT)
 
     (out, _) = proc.communicate()
diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py
index e0da680a24..114f01dd0e 100644
--- a/python/tvm/runtime/__init__.py
+++ b/python/tvm/runtime/__init__.py
@@ -28,7 +28,7 @@ from .profiling import Report
 from .object_generic import convert_to_object, convert, const
 from .ndarray import device, cpu, cuda, gpu, opencl, cl, vulkan, metal, mtl
 from .ndarray import vpi, rocm, ext_dev
-from .module import load_module, enabled, system_lib
+from .module import load_module, enabled, system_lib, load_static_library
 from .container import String, ShapeTuple
 from .params import save_param_dict, load_param_dict
 
diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py
index 64b3d506b6..c614e5d757 100644
--- a/python/tvm/runtime/module.py
+++ b/python/tvm/runtime/module.py
@@ -127,6 +127,28 @@ class Module(object):
         self._entry = self.get_function(self.entry_name)
         return self._entry
 
+    def implements_function(self, name, query_imports=False):
+        """Returns True if the module has a definition for the global function 
with name. Note
+        that has_function(name) does not imply get_function(name) is non-null 
since the module
+        may be, eg, a CSourceModule which cannot supply a packed-func 
implementation of the function
+        without further compilation. However, get_function(name) non null 
should always imply
+        has_function(name).
+
+        Parameters
+        ----------
+        name : str
+            The name of the function
+
+        query_imports : bool
+            Whether to also query modules imported by this module.
+
+        Returns
+        -------
+        b : Bool
+            True if module (or one of its imports) has a definition for name.
+        """
+        return _ffi_api.ModuleImplementsFunction(self, name, query_imports)
+
     def get_function(self, name, query_imports=False):
         """Get function from the module.
 
@@ -217,6 +239,18 @@ class Module(object):
         nmod = _ffi_api.ModuleImportsSize(self)
         return [_ffi_api.ModuleGetImport(self, i) for i in range(nmod)]
 
+    @property
+    def is_dso_exportable(self):
+        """Returns true if module is 'DSO exportable', ie can be included in 
result of
+        export_library by the external compiler directly.
+
+        Returns
+        -------
+        b : Bool
+            True if the module is DSO exportable.
+        """
+        return _ffi_api.ModuleIsDSOExportable(self)
+
     def save(self, file_name, fmt=""):
         """Save the module to file.
 
@@ -332,8 +366,7 @@ class Module(object):
         return dso_modules
 
     def _collect_dso_modules(self):
-        is_dso_exportable = lambda m: (m.type_key == "llvm" or m.type_key == 
"c")
-        return self._collect_from_import_tree(is_dso_exportable)
+        return self._collect_from_import_tree(lambda m: m.is_dso_exportable)
 
     def export_library(self, file_name, fcompile=None, addons=None, 
workspace_dir=None, **kwargs):
         """
@@ -418,10 +451,7 @@ class Module(object):
                 else:
                     object_format = fcompile.object_format
             else:
-                if module.type_key == "llvm":
-                    object_format = "o"
-                else:
-                    assert module.type_key == "c"
+                if module.type_key == "c":
                     if len(module.format) > 0:
                         assert module.format in [
                             "c",
@@ -436,6 +466,9 @@ class Module(object):
                         if kwargs["cc"] == "nvcc":
                             object_format = "cu"
                     has_c_module = True
+                else:
+                    assert module.type_key == "llvm" or module.type_key == 
"static_library"
+                    object_format = "o"
             path_obj = os.path.join(workspace_dir, 
f"lib{index}.{object_format}")
             module.save(path_obj)
             files.append(path_obj)
@@ -552,6 +585,13 @@ def load_module(path, fmt=""):
     return _ffi_api.ModuleLoadFromFile(path, fmt)
 
 
+def load_static_library(path, func_names):
+    """Load the .o library at path which implements functions with func_names.
+    Unlike the generic load_module the result will remain as a static_library
+    and will not be relinked on-the-fly into a .so library."""
+    return _ffi_api.ModuleLoadStaticLibrary(path, func_names)
+
+
 def enabled(target):
     """Whether module runtime is enabled for target
 
diff --git a/src/printer/model_library_format_printer.cc 
b/src/printer/model_library_format_printer.cc
index 17ba84e68d..f6ac39ce79 100644
--- a/src/printer/model_library_format_printer.cc
+++ b/src/printer/model_library_format_printer.cc
@@ -35,7 +35,7 @@ class ModelLibraryFormatPrinter : public 
::tvm::runtime::ModuleNode {
                             bool show_warning)
       : text_printer_{show_meta_data, annotate, show_warning} {}
 
-  const char* type_key() const override { return 
"model_library_format_printer"; }
+  const char* type_key() const final { return "model_library_format_printer"; }
 
   std::string Print(const ObjectRef& node) {
     Doc doc;
diff --git a/src/relay/backend/contrib/ethosu/source_module.cc 
b/src/relay/backend/contrib/ethosu/source_module.cc
index c79785c869..eb4b779ecd 100644
--- a/src/relay/backend/contrib/ethosu/source_module.cc
+++ b/src/relay/backend/contrib/ethosu/source_module.cc
@@ -114,13 +114,22 @@ class EthosUModuleNode : public ModuleNode {
     return PackedFunc();
   }
 
-  const char* type_key() const override { return "c"; }
+  const char* type_key() const final { return "c"; }
 
   static Module Create(Array<CompilationArtifact> compilation_artifacts) {
     auto n = make_object<EthosUModuleNode>(compilation_artifacts);
     return Module(n);
   }
 
+  bool IsDSOExportable() const final { return true; }
+
+  bool ImplementsFunction(const String& name, bool query_imports) final {
+    return std::find_if(compilation_artifacts_.begin(), 
compilation_artifacts_.end(),
+                        [&name](const CompilationArtifact& artifact) {
+                          return artifact->function_name == name;
+                        }) != compilation_artifacts_.end();
+  }
+
  private:
   std::string c_source;
   Array<CompilationArtifact> compilation_artifacts_;
diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc
index 5a0502d175..76dbfef538 100644
--- a/src/relay/backend/te_compiler.cc
+++ b/src/relay/backend/te_compiler.cc
@@ -155,6 +155,7 @@ class TECompilerImpl : public TECompilerNode {
       }
     }
     for (const auto& global_var : to_be_deleted) {
+      VLOG(1) << "Removing definition for external codegened '" << 
global_var->name_hint << "'";
       module->Remove(global_var);
     }
     // HOWEVER we still need a Relay definition to go with those now external 
functions, so
@@ -203,27 +204,29 @@ class TECompilerImpl : public TECompilerNode {
 
         std::string ext_name = "relay.ext." + opt_compiler.value();
         auto pf = tvm::runtime::Registry::Get(ext_name);
-        ICHECK(pf) << "Failed to find the codegen tool for " << ext_name;
+        ICHECK(pf) << "Failed to find the external codegen tool for " << 
ext_name;
         // No need to keep compiler attribute at this point, functions have 
been
         // extracted for specific codegen.
         src_func = WithAttr(std::move(src_func), attr::kCompiler, 
NullValue<ObjectRef>());
-        VLOG_CONTEXT << ext_name;
+        VLOG_CONTEXT << opt_compiler.value();
+        With<Target> with_target(it.first->target);
         runtime::Module ext_mod = (*pf)(src_func);
         if (ext_mod.defined()) {
-          if (ext_mod->GetFunction(opt_symbol_name.value(), 
/*query_imports=*/true) == nullptr) {
-            // It's possible the codegen yielded C or C++ tracked separately 
and thus the
-            // returned runtime module can be empty.
-            VLOG(1) << "Unable to find definition for the external function '"
-                    << opt_symbol_name.value()
-                    << "' in the runtime module generated by external codegen 
'"
-                    << opt_compiler.value() << "'";
+          // TODO(mbs): Can this be an ICHECKs?
+          if (!ext_mod->ImplementsFunction(opt_symbol_name.value())) {
+            VLOG(1) << "Note that the external codegen for '" << 
opt_compiler.value()
+                    << "' returned a runtime module which does not appear to 
implement '"
+                    << opt_symbol_name.value() << "'";
           }
           ret.push_back(ext_mod);
         } else {
-          // A warning only so that we can write unit tests which can return 
an empty runtime
-          // module.
-          LOG(WARNING) << "No external runtime module was generated by 
external codegen '"
-                       << opt_compiler.value() << "'";
+          // It is valid for the external codegen function to return null:
+          //  - Unit tests can use it.
+          //  - The true compilation may have already been handled by a 
RelayToTIR custom hook pass
+          //    on the Target's kind. The original Relay functions will be 
left in place so
+          //    that we can capture that their function names are now 
externally defined.
+          VLOG(1) << "Note that no external runtime module was generated by 
external codegen '"
+                  << opt_compiler.value() << "'";
         }
       }
     }
diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h
index b1c977e526..a65bdc5ab3 100644
--- a/src/relay/backend/vm/compiler.h
+++ b/src/relay/backend/vm/compiler.h
@@ -91,7 +91,7 @@ class VMCompiler : public runtime::ModuleNode {
 
   virtual PackedFunc GetFunction(const std::string& name, const 
ObjectPtr<Object>& sptr_to_self);
 
-  const char* type_key() const { return "VMCompiler"; }
+  const char* type_key() const final { return "VMCompiler"; }
 
   /*!
    * \brief Set the parameters
diff --git a/src/runtime/aot_executor/aot_executor_factory.h 
b/src/runtime/aot_executor/aot_executor_factory.h
index 1d6a0a6277..ada63f0ba8 100644
--- a/src/runtime/aot_executor/aot_executor_factory.h
+++ b/src/runtime/aot_executor/aot_executor_factory.h
@@ -63,7 +63,7 @@ class TVM_DLL AotExecutorFactory : public runtime::ModuleNode 
{
   /*!
    * \return The type key of the executor.
    */
-  const char* type_key() const override { return "AotExecutorFactory"; }
+  const char* type_key() const final { return "AotExecutorFactory"; }
 
   /*!
    * \brief Save the module to binary stream.
diff --git a/src/runtime/const_loader_module.cc 
b/src/runtime/const_loader_module.cc
index 5496e161e5..2e91d26d5f 100644
--- a/src/runtime/const_loader_module.cc
+++ b/src/runtime/const_loader_module.cc
@@ -79,7 +79,7 @@ class ConstLoaderModuleNode : public ModuleNode {
     return PackedFunc(nullptr);
   }
 
-  const char* type_key() const { return "const_loader"; }
+  const char* type_key() const final { return "const_loader"; }
 
   /*!
    * \brief Get the list of constants that is required by the given module.
diff --git a/src/runtime/contrib/json/json_runtime.h 
b/src/runtime/contrib/json/json_runtime.h
index 374a440e29..355390765d 100644
--- a/src/runtime/contrib/json/json_runtime.h
+++ b/src/runtime/contrib/json/json_runtime.h
@@ -54,7 +54,7 @@ class JSONRuntimeBase : public ModuleNode {
     LoadGraph(graph_json_);
   }
 
-  const char* type_key() const override { return "json"; }
+  const char* type_key() const override { return "json"; }  // May be 
overridden
 
   /*! \brief Initialize a specific json runtime. */
   virtual void Init(const Array<NDArray>& consts) = 0;
diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc 
b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
index b60074e66d..554515c456 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
@@ -95,7 +95,7 @@ class TensorRTRuntime : public JSONRuntimeBase {
    *
    * \return module type key.
    */
-  const char* type_key() const override { return "tensorrt"; }
+  const char* type_key() const final { return "tensorrt"; }
 
   /*!
    * \brief Initialize runtime. Create TensorRT layer from JSON
diff --git a/src/runtime/graph_executor/graph_executor_factory.h 
b/src/runtime/graph_executor/graph_executor_factory.h
index 1ee74c3547..d8ebe44bb9 100644
--- a/src/runtime/graph_executor/graph_executor_factory.h
+++ b/src/runtime/graph_executor/graph_executor_factory.h
@@ -65,7 +65,7 @@ class TVM_DLL GraphExecutorFactory : public 
runtime::ModuleNode {
   /*!
    * \return The type key of the executor.
    */
-  const char* type_key() const override { return "GraphExecutorFactory"; }
+  const char* type_key() const final { return "GraphExecutorFactory"; }
 
   /*!
    * \brief Save the module to binary stream.
diff --git a/src/runtime/metadata.cc b/src/runtime/metadata.cc
index c08f2872fe..8e034cc94d 100644
--- a/src/runtime/metadata.cc
+++ b/src/runtime/metadata.cc
@@ -75,7 +75,7 @@ class MetadataModuleNode : public ::tvm::runtime::ModuleNode {
   explicit MetadataModuleNode(runtime::metadata::Metadata metadata)
       : metadata_{::std::move(metadata)} {}
 
-  const char* type_key() const { return "metadata_module"; }
+  const char* type_key() const final { return "metadata_module"; }
 
   static Module LoadFromBinary() {
     return 
Module(make_object<MetadataModuleNode>(runtime::metadata::Metadata()));
diff --git a/src/runtime/module.cc b/src/runtime/module.cc
index 097d6a2f53..57fe575689 100644
--- a/src/runtime/module.cc
+++ b/src/runtime/module.cc
@@ -83,6 +83,7 @@ Module Module::LoadFromFile(const std::string& file_name, 
const std::string& for
     fmt = "so";
   }
   std::string load_f_name = "runtime.module.loadfile_" + fmt;
+  VLOG(1) << "Loading module from '" << file_name << "' of format '" << fmt << 
"'";
   const PackedFunc* f = Registry::Get(load_f_name);
   ICHECK(f != nullptr) << "Loader for `." << format << "` files is not 
registered,"
                        << " resolved to (" << load_f_name << ") in the global 
registry."
@@ -132,6 +133,12 @@ std::string ModuleNode::GetFormat() {
   return "";
 }
 
+bool ModuleNode::IsDSOExportable() const { return false; }
+
+bool ModuleNode::ImplementsFunction(const String& name, bool query_imports) {
+  return GetFunction(name, query_imports) != nullptr;
+}
+
 bool RuntimeEnabled(const std::string& target) {
   std::string f_name;
   if (target == "cpu") {
@@ -191,8 +198,15 @@ 
TVM_REGISTER_GLOBAL("runtime.ModuleGetFormat").set_body_typed([](Module mod) {
 
TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile").set_body_typed(Module::LoadFromFile);
 
 TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile")
-    .set_body_typed([](Module mod, tvm::String name, tvm::String fmt) {
-      mod->SaveToFile(name, fmt);
+    .set_body_typed([](Module mod, String name, tvm::String fmt) { 
mod->SaveToFile(name, fmt); });
+
+TVM_REGISTER_GLOBAL("runtime.ModuleIsDSOExportable").set_body_typed([](Module 
mod) {
+  return mod->IsDSOExportable();
+});
+
+TVM_REGISTER_GLOBAL("runtime.ModuleImplementsFunction")
+    .set_body_typed([](Module mod, String name, bool query_imports) {
+      return mod->ImplementsFunction(std::move(name), query_imports);
     });
 
 TVM_REGISTER_OBJECT_TYPE(ModuleNode);
diff --git a/src/runtime/stackvm/stackvm_module.cc 
b/src/runtime/stackvm/stackvm_module.cc
index c784a9d048..bbcadd21b4 100644
--- a/src/runtime/stackvm/stackvm_module.cc
+++ b/src/runtime/stackvm/stackvm_module.cc
@@ -37,7 +37,7 @@ namespace runtime {
 
 class StackVMModuleNode : public runtime::ModuleNode {
  public:
-  const char* type_key() const { return "stackvm"; }
+  const char* type_key() const final { return "stackvm"; }
 
   PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& 
sptr_to_self) final {
     if (name == runtime::symbol::tvm_module_main) {
diff --git a/src/runtime/static_library.cc b/src/runtime/static_library.cc
new file mode 100644
index 0000000000..e845d0fac2
--- /dev/null
+++ b/src/runtime/static_library.cc
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file runtime/static_library.cc
+ * \brief Represents a generic '.o' static library which can be linked into 
the final output
+ * dynamic library by export_library.
+ */
+#include "./static_library.h"
+
+#include <tvm/runtime/memory.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include <iostream>
+
+#include "file_utils.h"
+
+namespace tvm {
+namespace runtime {
+
+namespace {
+
+/*!
+ * \brief A '.o' library which can be linked into the final output library by 
export_library.
+ * Can be used by external codegen tools which can produce a ready-to-link 
artifact.
+ */
+class StaticLibraryNode final : public runtime::ModuleNode {
+ public:
+  ~StaticLibraryNode() override = default;
+
+  const char* type_key() const final { return "static_library"; }
+
+  PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& 
sptr_to_self) final {
+    if (name == "get_func_names") {
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { 
*rv = func_names_; });
+    } else {
+      return {};
+    }
+  }
+
+  void SaveToFile(const std::string& file_name, const std::string& format) 
final {
+    VLOG(0) << "Saving static library of " << data_.size() << " bytes 
implementing " << FuncNames()
+            << " to '" << file_name << "'";
+    SaveBinaryToFile(file_name, data_);
+  }
+
+  bool IsDSOExportable() const final { return true; }
+
+  bool ImplementsFunction(const String& name, bool query_imports) final {
+    return std::find(func_names_.begin(), func_names_.end(), name) != 
func_names_.end();
+  }
+
+  std::string FuncNames() {
+    std::ostringstream os;
+    os << "[";
+    bool first = true;
+    for (const auto& func_name : func_names_) {
+      if (first) {
+        first = false;
+      } else {
+        os << ", ";
+      }
+      os << "'" << func_name << "'";
+    }
+    os << "]";
+    return os.str();
+  }
+
+  /*! \brief Contents of the object file. */
+  std::string data_;
+  /*! \brief Function names exported by the above. */
+  Array<String> func_names_;
+};
+
+}  // namespace
+
+Module LoadStaticLibrary(const std::string& filename, Array<String> 
func_names) {
+  auto node = make_object<StaticLibraryNode>();
+  LoadBinaryFromFile(filename, &node->data_);
+  node->func_names_ = std::move(func_names);
+  VLOG(0) << "Loaded static library from '" << filename << "' implementing " 
<< node->FuncNames();
+  return Module(node);
+}
+
+TVM_REGISTER_GLOBAL("runtime.ModuleLoadStaticLibrary").set_body_typed(LoadStaticLibrary);
+
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/static_library.h b/src/runtime/static_library.h
new file mode 100644
index 0000000000..352891f6fb
--- /dev/null
+++ b/src/runtime/static_library.h
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file runtime/static_library.h
+ * \brief Represents a generic '.o' static library which can be linked into 
the final output
+ * dynamic library by export_library.
+ */
+
+#ifndef TVM_RUNTIME_STATIC_LIBRARY_H_
+#define TVM_RUNTIME_STATIC_LIBRARY_H_
+
+#include <tvm/runtime/logging.h>
+#include <tvm/runtime/module.h>
+
+#include <array>
+#include <memory>
+#include <set>
+#include <string>
+#include <unordered_map>
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief Returns a static library with the contents loaded from filename 
which exports
+ * func_names with the usual packed-func calling convention.
+ */
+Module LoadStaticLibrary(const std::string& filename, Array<String> 
func_names);
+
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // TVM_RUNTIME_STATIC_LIBRARY_H_
diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc
index c7cec21508..26d7a4b70f 100644
--- a/src/support/ffi_testing.cc
+++ b/src/support/ffi_testing.cc
@@ -121,7 +121,7 @@ 
TVM_REGISTER_GLOBAL("testing.object_use_count").set_body([](TVMArgs args, TVMRet
 
 class FrontendTestModuleNode : public runtime::ModuleNode {
  public:
-  virtual const char* type_key() const { return "frontend_test"; }
+  const char* type_key() const final { return "frontend_test"; }
 
   static constexpr const char* kAddFunctionName = "__add_function";
 
diff --git a/src/target/codegen.cc b/src/target/codegen.cc
index 41221ad8a3..3c4866be1b 100644
--- a/src/target/codegen.cc
+++ b/src/target/codegen.cc
@@ -68,7 +68,7 @@ class ModuleSerializer {
     // Only have one DSO module and it is in the root, then
     // we will not produce import_tree_.
     bool has_import_tree = true;
-    if (DSOExportable(mod_.operator->()) && mod_->imports().empty()) {
+    if (mod_->IsDSOExportable() && mod_->imports().empty()) {
       has_import_tree = false;
     }
     uint64_t sz = 0;
@@ -84,7 +84,7 @@ class ModuleSerializer {
 
     for (const auto& group : mod_group_vec_) {
       ICHECK_NE(group.size(), 0) << "Every allocated group must have at least 
one module";
-      if (!DSOExportable(group[0])) {
+      if (!group[0]->IsDSOExportable()) {
         ICHECK_EQ(group.size(), 1U) << "Non DSO module is never merged";
         std::string mod_type_key = group[0]->type_key();
         stream->Write(mod_type_key);
@@ -147,7 +147,7 @@ class ModuleSerializer {
     while (!stack.empty()) {
       runtime::ModuleNode* n = stack.back();
       stack.pop_back();
-      if (DSOExportable(n)) {
+      if (n->IsDSOExportable()) {
         // do not recursively expand dso modules
         // we will expand in phase 1
         dso_exportable_boundary.emplace_back(n);
@@ -174,7 +174,7 @@ class ModuleSerializer {
       runtime::ModuleNode* n = stack.back();
       stack.pop_back();
 
-      if (DSOExportable(n)) {
+      if (n->IsDSOExportable()) {
         mod_group_vec_[dso_module_index].emplace_back(n);
         mod2index_[n] = dso_module_index;
       } else {
@@ -219,10 +219,6 @@ class ModuleSerializer {
     }
   }
 
-  bool DSOExportable(const runtime::ModuleNode* mod) {
-    return !std::strcmp(mod->type_key(), "llvm") || 
!std::strcmp(mod->type_key(), "c");
-  }
-
   runtime::Module mod_;
   // construct module to index
   std::unordered_map<runtime::ModuleNode*, size_t> mod2index_;
diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc
index b2dc4c81f9..c7aea3dc19 100644
--- a/src/target/llvm/llvm_module.cc
+++ b/src/target/llvm/llvm_module.cc
@@ -56,7 +56,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
     }
   }
 
-  const char* type_key() const { return "llvm"; }
+  const char* type_key() const final { return "llvm"; }
 
   PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& 
sptr_to_self) final {
     if (name == "__tvm_is_system_module") {
@@ -357,6 +357,12 @@ class LLVMModuleNode final : public runtime::ModuleNode {
     Init(std::move(module), ctx);
   }
 
+  bool IsDSOExportable() const final { return true; }
+
+  bool ImplementsFunction(const String& name, bool query_imports) final {
+    return std::find(function_names_.begin(), function_names_.end(), name) != function_names_.end();
+  }
+
  private:
   void LazyInitJIT() {
     std::lock_guard<std::mutex> lock(mutex_);
diff --git a/src/target/metadata_module.cc b/src/target/metadata_module.cc
index 840ba5cab2..97299c6375 100644
--- a/src/target/metadata_module.cc
+++ b/src/target/metadata_module.cc
@@ -190,10 +190,6 @@ runtime::Module CreateMetadataModule(
   Array<runtime::Module> crt_exportable_modules;
   Array<runtime::Module> non_crt_exportable_modules;
 
-  auto DSOExportable = [](tvm::runtime::Module& mod) {
-    return !std::strcmp(mod->type_key(), "llvm") || !std::strcmp(mod->type_key(), "c");
-  };
-
   bool is_targeting_crt = runtime->name == "crt";
 
   // Wrap all submodules in the initialization wrapper.
@@ -219,7 +215,7 @@ runtime::Module CreateMetadataModule(
 
     // TODO(@manupa-arm) : we should be able to use csource_metadata
     // if the variables are empty when all the runtime modules implement get_func_names
-    if (symbol_const_vars.empty() && is_targeting_crt && DSOExportable(mod) &&
+    if (symbol_const_vars.empty() && is_targeting_crt && mod->IsDSOExportable() &&
         (target->kind->name == "c" || target->kind->name == "llvm")) {
       crt_exportable_modules.push_back(mod);
     } else {
diff --git a/src/target/source/interface_c.cc b/src/target/source/interface_c.cc
index 12d930d8f8..1bb567d148 100644
--- a/src/target/source/interface_c.cc
+++ b/src/target/source/interface_c.cc
@@ -52,7 +52,7 @@ class InterfaceCNode : public runtime::ModuleNode {
         pools_(FilterExternalPools(pools)),
         io_pool_allocations_(io_pool_allocations),
         workspace_size_(workspace_size) {}
-  const char* type_key() const { return "h"; }
+  const char* type_key() const final { return "h"; }
 
   std::string GetSource(const std::string& format) final {
     std::stringstream code;
diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc
index 8f581f4cbb..2c4993419f 100644
--- a/src/target/source/source_module.cc
+++ b/src/target/source/source_module.cc
@@ -59,7 +59,7 @@ using runtime::SaveBinaryToFile;
 class SourceModuleNode : public runtime::ModuleNode {
  public:
   SourceModuleNode(std::string code, std::string fmt) : code_(code), fmt_(fmt) {}
-  const char* type_key() const { return "source"; }
+  const char* type_key() const final { return "source"; }
 
   PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
     LOG(FATAL) << "Source module cannot execute, to get executable module"
@@ -87,7 +87,7 @@ class CSourceModuleNode : public runtime::ModuleNode {
   CSourceModuleNode(const std::string& code, const std::string& fmt,
                     const Array<String>& func_names, const Array<String>& const_vars)
       : code_(code), fmt_(fmt), const_vars_(const_vars), func_names_(func_names) {}
-  const char* type_key() const { return "c"; }
+  const char* type_key() const final { return "c"; }
 
   PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
     // Currently c-source module is used as demonstration purposes with binary metadata module
@@ -123,6 +123,12 @@ class CSourceModuleNode : public runtime::ModuleNode {
     }
   }
 
+  bool IsDSOExportable() const final { return true; }
+
+  bool ImplementsFunction(const String& name, bool query_imports) final {
+    return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end();
+  }
+
  protected:
   std::string code_;
   std::string fmt_;
@@ -166,7 +172,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
         metadata_(metadata) {
     CreateSource();
   }
-  const char* type_key() const { return "c"; }
+  const char* type_key() const final { return "c"; }
 
   std::string GetSource(const std::string& format) final { return code_.str(); }
 
@@ -187,6 +193,12 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
     }
   }
 
+  bool IsDSOExportable() const final { return true; }
+
+  bool ImplementsFunction(const String& name, bool query_imports) final {
+    return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end();
+  }
+
  protected:
   std::stringstream code_;
   std::string fmt_;
@@ -908,7 +920,7 @@ class DeviceSourceModuleNode final : public runtime::ModuleNode {
     }
   }
 
-  const char* type_key() const { return type_key_.c_str(); }
+  const char* type_key() const final { return type_key_.c_str(); }
 
   void SaveToFile(const std::string& file_name, const std::string& format) final {
     std::string fmt = GetFileFormat(file_name, format);
diff --git a/tests/python/unittest/test_runtime_module_export.py b/tests/python/unittest/test_runtime_module_export.py
index 9ea1ff437f..57fcaea03d 100644
--- a/tests/python/unittest/test_runtime_module_export.py
+++ b/tests/python/unittest/test_runtime_module_export.py
@@ -18,7 +18,6 @@ from tvm import relay
 from tvm.relay import testing
 import tvm
 from tvm import te
-
 import tvm.testing
 
 from tvm.contrib import utils
@@ -80,8 +79,6 @@ def test_mod_export():
                 synthetic_llvm_mod, "llvm", params=synthetic_llvm_params, mod_name="llvmlib"
             )
 
-        from tvm.contrib import utils
-
         temp = utils.tempdir()
         if obj_format == ".so":
             file_name = "deploy_lib.so"
@@ -109,8 +106,6 @@ def test_mod_export():
         mod0 = tvm.build(s, [A, B], "llvm", name="myadd0")
         mod1 = tvm.build(s, [A, B], "llvm", name="myadd1")
 
-        from tvm.contrib import utils
-
         temp = utils.tempdir()
         if obj_format == ".so":
             file_name = "deploy_lib.so"
@@ -152,8 +147,6 @@ def test_mod_export():
             + "mul 6 inputs: 5 3 shape: 10 10"
         )
 
-        from tvm.contrib import utils
-
         temp = utils.tempdir()
         subgraph_path = temp.relpath("subgraph.examplejson")
         with open(subgraph_path, "w") as f:
@@ -203,7 +196,6 @@ def test_mod_export():
         s = te.create_schedule(B.op)
         f = tvm.build(s, [A, B], "c", name="myadd")
         engine_module = generate_engine_module()
-        from tvm.contrib import utils
 
         temp = utils.tempdir()
         file_name = "deploy_lib.so"
@@ -225,5 +217,43 @@ def test_mod_export():
     verify_multi_c_mod_export()
 
 
+@tvm.testing.requires_llvm
+def test_import_static_library():
+    # Generate two LLVM modules.
+    A = te.placeholder((1024,), name="A")
+    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
+    s = te.create_schedule(B.op)
+    mod0 = tvm.build(s, [A, B], "llvm", name="myadd0")
+    mod1 = tvm.build(s, [A, B], "llvm", name="myadd1")
+
+    assert mod0.implements_function("myadd0")
+    assert mod1.implements_function("myadd1")
+    assert mod1.is_dso_exportable
+
+    # mod1 is currently an 'llvm' module.
+    # Save and reload it as a vanilla 'static_library'.
+    temp = utils.tempdir()
+    mod1_o_path = temp.relpath("mod1.o")
+    mod1.save(mod1_o_path)
+    mod1_o = tvm.runtime.load_static_library(mod1_o_path, ["myadd1"])
+    assert mod1_o.implements_function("myadd1")
+    assert mod1_o.is_dso_exportable
+
+    # Import mod1 as a static library into mod0 and compile to its own DSO.
+    mod0.import_module(mod1_o)
+    mod0_dso_path = temp.relpath("mod0.so")
+    mod0.export_library(mod0_dso_path)
+
+    # The imported mod1 is statically linked into mod0.
+    loaded_lib = tvm.runtime.load_module(mod0_dso_path)
+    assert loaded_lib.type_key == "library"
+    assert len(loaded_lib.imported_modules) == 0
+    assert loaded_lib.implements_function("myadd0")
+    assert loaded_lib.get_function("myadd0")
+    assert loaded_lib.implements_function("myadd1")
+    assert loaded_lib.get_function("myadd1")
+    assert not loaded_lib.is_dso_exportable
+
+
 if __name__ == "__main__":
-    test_mod_export()
+    tvm.testing.main()

Reply via email to