This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ee052dd  Introduce Model Library Format export format (#7533)
ee052dd is described below

commit ee052dd6425ca889cb33948826f96fdcc37ff4e4
Author: Andrew Reusch <[email protected]>
AuthorDate: Wed Mar 10 07:22:20 2021 -0800

    Introduce Model Library Format export format (#7533)
    
    * Introduce Model Library Format export format.
    
     * This function produces a stable on-disk representation of TVM's
       compiler output.
     * It's intended just for use with the C runtime for microTVM right
       now. It could be expanded for other use cases.
     * This PR implements the Model Library Format RFC, which ultimately
       is intended to support the Project Generator API (RFC
       forthcoming).
     * There may be some changes to the format without revving the version
       number until downstream consumers are known. The Project Generator
       API is the first such known downstream consumer.
     * There are no plans currently to support generating old Model
       Library Format from TVM. The version number is intended as a
       compatibility check between the generator and downstream consumers.
---
 python/tvm/micro/__init__.py                       |   1 +
 python/tvm/micro/model_library_format.py           | 171 +++++++++++++++++++
 python/tvm/relay/backend/graph_runtime_factory.py  |  12 +-
 python/tvm/relay/build_module.py                   |  20 ++-
 python/tvm/runtime/module.py                       |  26 ++-
 src/runtime/graph/graph_runtime_factory.cc         |   3 +-
 .../unittest/test_micro_model_library_format.py    | 190 +++++++++++++++++++++
 7 files changed, 404 insertions(+), 19 deletions(-)

diff --git a/python/tvm/micro/__init__.py b/python/tvm/micro/__init__.py
index 8e5807a..ade63f2 100644
--- a/python/tvm/micro/__init__.py
+++ b/python/tvm/micro/__init__.py
@@ -23,6 +23,7 @@ from .compiler import Compiler, DefaultCompiler, Flasher
 from .debugger import GdbRemoteDebugger
 from .micro_library import MicroLibrary
 from .micro_binary import MicroBinary
+from .model_library_format import export_model_library_format, 
UnsupportedInModelLibraryFormatError
 from .session import (
     create_local_graph_runtime,
     create_local_debug_runtime,
diff --git a/python/tvm/micro/model_library_format.py 
b/python/tvm/micro/model_library_format.py
new file mode 100644
index 0000000..4ce80be
--- /dev/null
+++ b/python/tvm/micro/model_library_format.py
@@ -0,0 +1,171 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Defines functions for exporting to Model Library Format."""
+
+import datetime
+import json
+import os
+import re
+import tarfile
+
+from ..contrib import utils
+from ..relay.backend import graph_runtime_factory
+from ..relay import param_dict
+
+
class UnsupportedInModelLibraryFormatError(Exception):
    """Raised when a Module tree cannot be exported to Model Library Format."""
+
+
def _populate_codegen_dir(mod, codegen_dir: str):
    """Populate the codegen sub-directory as part of a Model Library Format export.

    Walks the module import tree and saves each DSO-exportable module to disk:
    C-source modules under ``codegen/host/src/lib{n}.c`` and LLVM object modules
    under ``codegen/host/lib/lib{n}.o``.

    Parameters
    ----------
    mod : tvm.runtime.Module
        Module which should be written to codegen_dir.
    codegen_dir : str
        Path to the codegen directory on disk.

    Raises
    ------
    UnsupportedInModelLibraryFormatError
        When the import tree contains modules that are neither "c" nor "llvm".
    """
    dso_modules = mod._collect_dso_modules()
    non_dso_modules = mod._collect_from_import_tree(lambda m: m not in dso_modules)
    if non_dso_modules:
        raise UnsupportedInModelLibraryFormatError(
            f"Don't know how to export non-c or non-llvm modules; found: {non_dso_modules!r}"
        )

    # Separate counters so source and object files each get contiguous indices.
    mod_indices = {"lib": 0, "src": 0}
    host_codegen_dir = os.path.join(codegen_dir, "host")
    for dso_mod in dso_modules:
        if dso_mod.type_key == "c":
            index = mod_indices["src"]
            mod_indices["src"] += 1
            parent_dir = os.path.join(host_codegen_dir, "src")
            file_name = os.path.join(parent_dir, f"lib{index}.c")
        elif dso_mod.type_key == "llvm":
            index = mod_indices["lib"]
            mod_indices["lib"] += 1
            parent_dir = os.path.join(host_codegen_dir, "lib")
            file_name = os.path.join(parent_dir, f"lib{index}.o")
        else:
            # Unreachable given _collect_dso_modules() filters to "c"/"llvm"; raise
            # explicitly (asserts are stripped under -O) and name the offending module.
            raise AssertionError(
                f"do not expect module with type_key={dso_mod.type_key} from _collect_dso_modules"
            )

        os.makedirs(parent_dir, exist_ok=True)
        dso_mod.save(file_name)
+
+
def _build_memory_map(graph_json):
    """Build a simpler memory map from graph JSON.

    Parameters
    ----------
    graph_json : str
        String representation of the graph_json created from tvm.relay.build().

    Returns
    -------
    list :
        A list with one entry per storage id describing that memory.
    """
    graph = json.loads(graph_json)

    seen_storage_ids = set()
    memory_map = []
    for node_id, storage_id in enumerate(graph["attrs"]["storage_id"][1]):
        if storage_id in seen_storage_ids:
            continue

        seen_storage_ids.add(storage_id)

        # NOTE: "shape" and "dltype" are per-node (entry) arrays, so they must be
        # indexed by node_id, not storage_id (those only coincide when storage is
        # not reused and ids happen to be assigned in node order).
        num_elements = 1
        for dim in graph["attrs"]["shape"][1][node_id]:
            num_elements *= dim

        dltype = graph["attrs"]["dltype"][1][node_id]
        m = re.match(r"^[a-zA-Z]+([0-9]+)$", dltype)
        assert m, f"Exported graph contains unknown dltype {dltype}"

        # Round bit-count up to whole bytes.
        elem_bits = int(m.group(1))

        map_entry = {
            "storage_id": storage_id,
            "size_bytes": (num_elements * elem_bits + 7) // 8,
        }
        if node_id in graph["arg_nodes"]:
            map_entry["input_binding"] = graph["nodes"][node_id]["name"]

        memory_map.append(map_entry)

    return memory_map
+
+
def export_model_library_format(mod: graph_runtime_factory.GraphRuntimeFactoryModule, file_name):
    """Export the build artifact in Model Library Format.

    This function creates a .tar archive containing the build artifacts in a standardized
    layout. It's intended to allow downstream automation to build TVM artifacts against the C
    runtime.

    Parameters
    ----------
    mod : tvm.relay.backend.graph_runtime_factory.GraphRuntimeFactoryModule
        The return value of tvm.relay.build, which will be exported into Model Library Format.
    file_name : str
        Path to the .tar archive to generate.
    """
    staging_dir = utils.tempdir()

    # metadata.json: format version plus a summary of this build.
    with open(staging_dir.relpath("metadata.json"), "w") as json_f:
        json.dump(
            {
                "version": 1,
                "model_name": mod.libmod_name,
                "export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"),
                "memory": _build_memory_map(mod.graph_json),
                "target": {int(k): str(v) for k, v in mod.target.items()},
                "runtimes": ["graph"],
            },
            json_f,
            indent=2,
            sort_keys=True,
        )

    # codegen/: generated source/object files per target.
    codegen_path = staging_dir.relpath("codegen")
    os.mkdir(codegen_path)
    _populate_codegen_dir(mod.lib, codegen_path)

    # parameters/: serialized simplified parameters.
    params_path = staging_dir.relpath("parameters")
    os.mkdir(params_path)
    with open(os.path.join(params_path, f"{mod.libmod_name}.params"), "wb") as params_f:
        params_f.write(param_dict.save_param_dict(mod.params))

    # relay.txt: text-format Relay IR the model was built from.
    with open(staging_dir.relpath("relay.txt"), "w") as relay_f:
        relay_f.write(str(mod.ir_mod))

    # runtime-config/graph/graph.json: configuration for the graph runtime.
    graph_config_path = staging_dir.relpath(os.path.join("runtime-config", "graph"))
    os.makedirs(graph_config_path)
    with open(os.path.join(graph_config_path, "graph.json"), "w") as graph_f:
        graph_f.write(mod.graph_json)

    def _reset_ownership(tarinfo):
        # Normalize owner metadata so archives are reproducible across machines.
        tarinfo.uid = tarinfo.gid = 0
        tarinfo.uname = tarinfo.gname = "root"
        return tarinfo

    with tarfile.open(file_name, "w") as tar_f:
        tar_f.add(staging_dir.temp_dir, arcname=".", filter=_reset_ownership)
diff --git a/python/tvm/relay/backend/graph_runtime_factory.py 
b/python/tvm/relay/backend/graph_runtime_factory.py
index 3427a62..e92ae71 100644
--- a/python/tvm/relay/backend/graph_runtime_factory.py
+++ b/python/tvm/relay/backend/graph_runtime_factory.py
@@ -16,9 +16,9 @@
 # under the License.
 """Graph runtime factory."""
 import warnings
-from tvm._ffi.base import string_types
-from tvm._ffi.registry import get_global_func
-from tvm.runtime import ndarray
+from ..._ffi.base import string_types
+from ..._ffi.registry import get_global_func
+from ...runtime import ndarray
 
 
 class GraphRuntimeFactoryModule:
@@ -31,6 +31,8 @@ class GraphRuntimeFactoryModule:
         The graph to be deployed in json format output by graph compiler.
         The graph can contain operator(tvm_op) that points to the name of
         PackedFunc in the libmod.
+    target : tvm.Target
+        The Target used to build this module.
     libmod : tvm.Module
         The module of the corresponding function
     libmod_name: str
@@ -39,13 +41,15 @@ class GraphRuntimeFactoryModule:
         The parameters of module
     """
 
-    def __init__(self, graph_json_str, libmod, libmod_name, params):
+    def __init__(self, ir_mod, target, graph_json_str, libmod, libmod_name, 
params):
         assert isinstance(graph_json_str, string_types)
         fcreate = get_global_func("tvm.graph_runtime_factory.create")
         args = []
         for k, v in params.items():
             args.append(k)
             args.append(ndarray.array(v))
+        self.ir_mod = ir_mod
+        self.target = target
         self.module = fcreate(graph_json_str, libmod, libmod_name, *args)
         self.graph_json = graph_json_str
         self.lib = libmod
diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index 4c9a898..8e69d28 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -208,14 +208,14 @@ def _build_module_no_factory(mod, target=None, 
target_host=None, params=None, mo
     return build(mod, target, target_host, params, mod_name).module
 
 
-def build(mod, target=None, target_host=None, params=None, mod_name="default"):
+def build(ir_mod, target=None, target_host=None, params=None, 
mod_name="default"):
     # fmt: off
     # pylint: disable=line-too-long
     """Helper function that builds a Relay function to run on TVM graph 
runtime.
 
     Parameters
     ----------
-    mod : :py:class:`~tvm.IRModule`
+    ir_mod : :py:class:`~tvm.IRModule`
         The IR module to build. Using relay.Function is deprecated.
 
     target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context 
name) to str/tvm.target.Target, optional
@@ -251,13 +251,13 @@ def build(mod, target=None, target_host=None, 
params=None, mod_name="default"):
     """
     # pylint: enable=line-too-long
     # fmt: on
-    if not isinstance(mod, (IRModule, _function.Function)):
+    if not isinstance(ir_mod, (IRModule, _function.Function)):
         raise ValueError("Type of input parameter mod must be tvm.IRModule")
 
-    if isinstance(mod, _function.Function):
+    if isinstance(ir_mod, _function.Function):
         if params:
-            mod = bind_params_by_name(mod, params)
-        mod = IRModule.from_expr(mod)
+            ir_mod = bind_params_by_name(ir_mod, params)
+        ir_mod = IRModule.from_expr(ir_mod)
         warnings.warn(
             "Please use input parameter mod (tvm.IRModule) "
             "instead of deprecated parameter mod 
(tvm.relay.function.Function)",
@@ -280,9 +280,11 @@ def build(mod, target=None, target_host=None, params=None, 
mod_name="default"):
 
     with tophub_context:
         bld_mod = BuildModule()
-        graph_json, mod, params = bld_mod.build(mod, target, target_host, 
params)
-        mod = _graph_runtime_factory.GraphRuntimeFactoryModule(graph_json, 
mod, mod_name, params)
-        return mod
+        graph_json, runtime_mod, params = bld_mod.build(ir_mod, target, 
target_host, params)
+        runtime_mod = _graph_runtime_factory.GraphRuntimeFactoryModule(
+            ir_mod, target, graph_json, runtime_mod, mod_name, params
+        )
+        return runtime_mod
 
 
 def optimize(mod, target=None, params=None):
diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py
index 6326796..53576a6 100644
--- a/python/tvm/runtime/module.py
+++ b/python/tvm/runtime/module.py
@@ -105,6 +105,9 @@ class Module(object):
             raise ValueError("Can only take string as function name")
         return self.get_function(name)
 
+    def __eq__(self, other):
+        return self.handle.value == other.handle.value
+
     def __call__(self, *args):
         if self._entry:
             return self._entry(*args)
@@ -233,15 +236,27 @@ class Module(object):
         except NameError:
             raise NameError("time_evaluate is only supported when RPC is 
enabled")
 
-    def _collect_dso_modules(self):
-        """Helper function to collect dso modules, then return it."""
+    def _collect_from_import_tree(self, filter_func):
+        """Helper function to collect modules from the tree matching a 
filter_func, then return it.
+
+        Parameters
+        ----------
+        filter_func : Callable[[Module], bool]
+            A function which is invoked for each Module discovered in the 
import tree (including
+            self).
+
+        Returns
+        -------
+        list[Module] :
+            A list of matching Module.
+        """
         visited, stack, dso_modules = set(), [], []
         # append root module
         visited.add(self)
         stack.append(self)
         while stack:
             module = stack.pop()
-            if module._dso_exportable():
+            if filter_func(module):
                 dso_modules.append(module)
             for m in module.imported_modules:
                 if m not in visited:
@@ -249,8 +264,9 @@ class Module(object):
                     stack.append(m)
         return dso_modules
 
-    def _dso_exportable(self):
-        return self.type_key == "llvm" or self.type_key == "c"
+    def _collect_dso_modules(self):
+        is_dso_exportable = lambda m: (m.type_key == "llvm" or m.type_key == 
"c")
+        return self._collect_from_import_tree(is_dso_exportable)
 
     def export_library(self, file_name, fcompile=None, addons=None, 
workspace_dir=None, **kwargs):
         """Export the module and its imported device code one library.
diff --git a/src/runtime/graph/graph_runtime_factory.cc 
b/src/runtime/graph/graph_runtime_factory.cc
index 4d3993a..605d6b0 100644
--- a/src/runtime/graph/graph_runtime_factory.cc
+++ b/src/runtime/graph/graph_runtime_factory.cc
@@ -156,7 +156,8 @@ 
TVM_REGISTER_GLOBAL("tvm.graph_runtime_factory.create").set_body([](TVMArgs args
                                  "graph_runtime_factory.create needs at least 
3, "
                                  "but it has "
                               << args.num_args;
-  // The argument order is graph_json, module, module_name, params.
+  // The argument order is graph_json, module, module_name, param0_name, 
param0_tensor,
+  // [param1_name, param1_tensor], ...
   ICHECK_EQ((args.size() - 3) % 2, 0);
   std::unordered_map<std::string, tvm::runtime::NDArray> params;
   for (size_t i = 3; i < static_cast<size_t>(args.size()); i += 2) {
diff --git a/tests/python/unittest/test_micro_model_library_format.py 
b/tests/python/unittest/test_micro_model_library_format.py
new file mode 100644
index 0000000..c999091
--- /dev/null
+++ b/tests/python/unittest/test_micro_model_library_format.py
@@ -0,0 +1,190 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import datetime
+import json
+import os
+import sys
+import tarfile
+
+import numpy
+import pytest
+
+import tvm
+import tvm.relay
+from tvm.relay.backend import graph_runtime_factory
+import tvm.runtime.module
+import tvm.testing
+from tvm.contrib import utils
+
+
def validate_graph_json(extract_dir, factory):
    """Assert the exported graph.json matches the factory's and parses sanely."""
    graph_json_path = os.path.join(extract_dir, "runtime-config", "graph", "graph.json")
    with open(graph_json_path) as graph_f:
        graph_json = graph_f.read()

    assert graph_json == factory.graph_json

    # Just check it parses and looks roughly right.
    graph = json.loads(graph_json)
    assert "nodes" in graph
    assert len(graph["nodes"]) == 4
    assert "attrs" in graph
+
@tvm.testing.requires_micro
def test_export_model_library_format_c():
    """Export a C-target build in Model Library Format and verify the archive contents."""
    with utils.TempDirectory.set_keep_for_debug(True):
        target = tvm.target.target.micro("host")
        with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
            relay_mod = tvm.parser.fromtext(
                """
            #[version = "0.0.5"]
            def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[(1, 2), float32]) {
            %0 = cast(%a, dtype="float32") + %b * %c;
            %0
            }"""
            )
            factory = tvm.relay.build(
                relay_mod,
                target,
                target_host=target,
                mod_name="add",
                params={"c": numpy.array([[2.0, 4.0]], dtype="float32")},
            )

        out_dir = utils.tempdir()
        tar_path = out_dir.relpath("lib.tar")
        import tvm.micro as micro

        micro.export_model_library_format(factory, tar_path)

        extracted_dir = out_dir.relpath("extract")
        os.mkdir(extracted_dir)
        tarfile.open(tar_path).extractall(extracted_dir)

        with open(os.path.join(extracted_dir, "metadata.json")) as meta_f:
            metadata = json.load(meta_f)

        assert metadata["version"] == 1
        assert metadata["model_name"] == "add"
        exported_at = datetime.datetime.strptime(metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ")
        # The export timestamp should be recent (within five minutes).
        assert (datetime.datetime.now() - exported_at) < datetime.timedelta(seconds=60 * 5)
        assert metadata["target"] == {"1": str(target)}
        assert metadata["memory"] == [
            {"storage_id": 0, "size_bytes": 2, "input_binding": "a"},
            {"storage_id": 1, "size_bytes": 8, "input_binding": "b"},
            {"storage_id": 2, "size_bytes": 8, "input_binding": "p0"},
            {"storage_id": 3, "size_bytes": 8},
        ]

        # C targets emit generated source files, not objects.
        for source_name in ("lib0.c", "lib1.c"):
            assert os.path.exists(os.path.join(extracted_dir, "codegen", "host", "src", source_name))

        validate_graph_json(extracted_dir, factory)

        with open(os.path.join(extracted_dir, "relay.txt")) as relay_f:
            assert relay_f.read() == str(relay_mod)

        with open(os.path.join(extracted_dir, "parameters", "add.params"), "rb") as params_f:
            assert "p0" in tvm.relay.load_param_dict(params_f.read())
+
+
@tvm.testing.requires_micro
def test_export_model_library_format_llvm():
    """Export an LLVM-target build in Model Library Format and verify the archive contents."""
    with utils.TempDirectory.set_keep_for_debug(True):
        target = tvm.target.target.micro("host")
        assert str(target)[:2] == "c "
        # Rebuild the micro target string as an llvm target.
        target = tvm.target.Target("llvm " + str(target)[2:])
        with tvm.transform.PassContext(opt_level=3):
            relay_mod = tvm.parser.fromtext(
                """
            #[version = "0.0.5"]
            def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[(1, 2), float32]) {
            %0 = cast(%a, dtype="float32") + %b * %c;
            %0
            }"""
            )
            factory = tvm.relay.build(
                relay_mod,
                target,
                target_host=target,
                mod_name="add",
                params={"c": numpy.array([[2.0, 4.0]], dtype="float32")},
            )

        out_dir = utils.tempdir()
        tar_path = out_dir.relpath("lib.tar")
        import tvm.micro as micro

        micro.export_model_library_format(factory, tar_path)

        extracted_dir = out_dir.relpath("extract")
        os.mkdir(extracted_dir)
        tarfile.open(tar_path).extractall(extracted_dir)

        with open(os.path.join(extracted_dir, "metadata.json")) as meta_f:
            metadata = json.load(meta_f)

        assert metadata["version"] == 1
        assert metadata["model_name"] == "add"
        exported_at = datetime.datetime.strptime(metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ")
        # The export timestamp should be recent (within five minutes).
        assert (datetime.datetime.now() - exported_at) < datetime.timedelta(seconds=60 * 5)
        assert metadata["target"] == {"1": str(target)}
        assert metadata["memory"] == [
            {"storage_id": 0, "size_bytes": 2, "input_binding": "a"},
            {"storage_id": 1, "size_bytes": 8, "input_binding": "b"},
            {"storage_id": 2, "size_bytes": 8, "input_binding": "p0"},
            {"storage_id": 3, "size_bytes": 8},
        ]

        # LLVM targets emit object files, not source.
        assert os.path.exists(os.path.join(extracted_dir, "codegen", "host", "lib", "lib0.o"))

        validate_graph_json(extracted_dir, factory)

        with open(os.path.join(extracted_dir, "relay.txt")) as relay_f:
            assert relay_f.read() == str(relay_mod)

        with open(os.path.join(extracted_dir, "parameters", "add.params"), "rb") as params_f:
            assert "p0" in tvm.relay.load_param_dict(params_f.read())
+
+
@tvm.testing.requires_micro
def test_export_model():
    """_populate_codegen_dir must reject modules that are neither c nor llvm."""
    module = tvm.support.FrontendTestModule()
    # Construct a factory around the non-exportable module; graph_json/params are dummies.
    factory = graph_runtime_factory.GraphRuntimeFactoryModule(
        None, tvm.target.target.micro("host"), '"graph_json"', module, "test_module", {}
    )

    temp_dir = utils.tempdir()
    import tvm.micro as micro
    import tvm.micro.model_library_format as model_library_format

    with pytest.raises(micro.UnsupportedInModelLibraryFormatError) as exc:
        model_library_format._populate_codegen_dir(module, temp_dir.relpath("codegen"))

    # NOTE(fix): the original placed this assertion inside the `with pytest.raises(...)`
    # block, after the raising call, so it never executed; it also used the unittest-style
    # `exc.exception` attribute — pytest's ExceptionInfo exposes the exception as `.value`.
    assert str(exc.value).startswith("Don't know how to export non-c or non-llvm modules")
+
+
def _run_tests() -> int:
    """Run this file's tests under pytest, forwarding any command-line arguments."""
    return pytest.main([__file__] + sys.argv[1:])


if __name__ == "__main__":
    sys.exit(_run_tests())

Reply via email to