This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ee052dd Introduce Model Library Format export format (#7533)
ee052dd is described below
commit ee052dd6425ca889cb33948826f96fdcc37ff4e4
Author: Andrew Reusch <[email protected]>
AuthorDate: Wed Mar 10 07:22:20 2021 -0800
Introduce Model Library Format export format (#7533)
* Introduce Model Library Format export format.
* This function produces a stable on-disk representation of TVM's
compiler output.
* It's intended just for use with the C runtime for microTVM right
now. It could be expanded for other use cases.
* This PR implements the Model Library Format RFC, which ultimately
is intended to support the Project Generator API (RFC
forthcoming).
* There may be some changes to the format without revving the version
number until downstream consumers are known. The Project Generator
API is the first such known downstream consumer.
* There are no plans currently to support generating old Model
Library Format from TVM. The version number is intended as a
compatibility check between the generator and downstream consumers.
---
python/tvm/micro/__init__.py | 1 +
python/tvm/micro/model_library_format.py | 171 +++++++++++++++++++
python/tvm/relay/backend/graph_runtime_factory.py | 12 +-
python/tvm/relay/build_module.py | 20 ++-
python/tvm/runtime/module.py | 26 ++-
src/runtime/graph/graph_runtime_factory.cc | 3 +-
.../unittest/test_micro_model_library_format.py | 190 +++++++++++++++++++++
7 files changed, 404 insertions(+), 19 deletions(-)
diff --git a/python/tvm/micro/__init__.py b/python/tvm/micro/__init__.py
index 8e5807a..ade63f2 100644
--- a/python/tvm/micro/__init__.py
+++ b/python/tvm/micro/__init__.py
@@ -23,6 +23,7 @@ from .compiler import Compiler, DefaultCompiler, Flasher
from .debugger import GdbRemoteDebugger
from .micro_library import MicroLibrary
from .micro_binary import MicroBinary
+from .model_library_format import export_model_library_format, UnsupportedInModelLibraryFormatError
from .session import (
create_local_graph_runtime,
create_local_debug_runtime,
diff --git a/python/tvm/micro/model_library_format.py
b/python/tvm/micro/model_library_format.py
new file mode 100644
index 0000000..4ce80be
--- /dev/null
+++ b/python/tvm/micro/model_library_format.py
@@ -0,0 +1,171 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Defines functions for exporting to Model Library Format."""
+
+import datetime
+import json
+import os
+import re
+import tarfile
+
+from ..contrib import utils
+from ..relay.backend import graph_runtime_factory
+from ..relay import param_dict
+
+
class UnsupportedInModelLibraryFormatError(Exception):
    """Error raised when a Module tree cannot be exported by export_model_library_format."""
+
+def _populate_codegen_dir(mod, codegen_dir: str):
+ """Populate the codegen sub-directory as part of a Model Library Format
export.
+
+ Parameters
+ ----------
+ mod : tvm.runtime.Module
+ Module which should be written to codegen_dir.
+ codegen_dir : str
+ Path to the codegen directory on disk.
+ """
+ dso_modules = mod._collect_dso_modules()
+ dso_module_handles = [m.handle.value for m in dso_modules]
+ non_dso_modules = mod._collect_from_import_tree(lambda m: m not in
dso_modules)
+ if non_dso_modules:
+ raise UnsupportedInModelLibraryFormatError(
+ f"Don't know how to export non-c or non-llvm modules; found:
{non_dso_modules!r}"
+ )
+
+ mod_indices = {"lib": 0, "src": 0}
+ host_codegen_dir = os.path.join(codegen_dir, "host")
+ for dso_mod in dso_modules:
+ if dso_mod.type_key == "c":
+ index = mod_indices["src"]
+ mod_indices["src"] += 1
+ parent_dir = os.path.join(host_codegen_dir, "src")
+ file_name = os.path.join(parent_dir, f"lib{index}.c")
+ elif dso_mod.type_key == "llvm":
+ index = mod_indices["lib"]
+ mod_indices["lib"] += 1
+ parent_dir = os.path.join(host_codegen_dir, "lib")
+ file_name = os.path.join(parent_dir, f"lib{index}.o")
+ else:
+ assert (
+ False
+ ), f"do not expect module with type_key={mod.type_key} from
_collect_dso_modules"
+
+ if not os.path.exists(parent_dir):
+ os.makedirs(parent_dir)
+ dso_mod.save(file_name)
+
+
+def _build_memory_map(graph_json):
+ """Build a simpler memory map from graph JSON.
+
+ Parameters
+ ----------
+ graph_json : str
+ String representation of the graph_json created from tvm.relay.build().
+
+ Returns
+ -------
+ list :
+ A list with one entry per storage id describing that memory.
+ """
+ graph = json.loads(graph_json)
+
+ seen_storage_ids = set()
+ memory_map = []
+ for node_id, storage_id in enumerate(graph["attrs"]["storage_id"][1]):
+ if storage_id in seen_storage_ids:
+ continue
+
+ seen_storage_ids.add(storage_id)
+ num_elements = 1
+ for dim in graph["attrs"]["shape"][1][storage_id]:
+ num_elements *= dim
+
+ dltype = graph["attrs"]["dltype"][1][storage_id]
+ m = re.match(r"^[a-zA-Z]+([0-9]+)$", dltype)
+ assert m, f"Exported graph contains unknown dltype {dltype}"
+
+ elem_bits = int(m.group(1))
+
+ map_entry = {
+ "storage_id": storage_id,
+ "size_bytes": (num_elements * elem_bits + 7) // 8,
+ }
+ if node_id in graph["arg_nodes"]:
+ map_entry["input_binding"] = graph["nodes"][node_id]["name"]
+
+ memory_map.append(map_entry)
+
+ return memory_map
+
+
def export_model_library_format(mod: graph_runtime_factory.GraphRuntimeFactoryModule, file_name):
    """Export the build artifact in Model Library Format.

    This function creates a .tar archive containing the build artifacts in a standardized
    layout. It's intended to allow downstream automation to build TVM artifacts against the C
    runtime.

    Parameters
    ----------
    mod : tvm.relay.backend.graph_runtime_factory.GraphRuntimeFactoryModule
        The return value of tvm.relay.build, which will be exported into Model Library Format.
    file_name : str
        Path to the .tar archive to generate.
    """
    tempdir = utils.tempdir()
    # metadata.json: top-level description of the export. "version" is the
    # Model Library Format compatibility number checked by downstream consumers.
    metadata = {
        "version": 1,
        "model_name": mod.libmod_name,
        # NOTE(review): the "Z" suffix suggests UTC, but datetime.now() returns
        # local time -- confirm whether utcnow() was intended here.
        "export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"),
        "memory": _build_memory_map(mod.graph_json),
        # Keys are the device-type ints from mod.target, mapped to Target strings.
        "target": {int(k): str(v) for k, v in mod.target.items()},
        "runtimes": ["graph"],
    }
    with open(tempdir.relpath("metadata.json"), "w") as json_f:
        json.dump(metadata, json_f, indent=2, sort_keys=True)

    # codegen/: generated C sources and object files from the compiled module tree.
    codegen_dir_path = tempdir.relpath("codegen")
    os.mkdir(codegen_dir_path)
    _populate_codegen_dir(mod.lib, codegen_dir_path)

    # parameters/<model_name>.params: parameters serialized with save_param_dict.
    parameters_dir_path = tempdir.relpath("parameters")
    os.mkdir(parameters_dir_path)
    param_filename = os.path.join(parameters_dir_path, f"{mod.libmod_name}.params")
    with open(param_filename, "wb") as f:
        f.write(param_dict.save_param_dict(mod.params))

    # relay.txt: text form of the Relay IRModule, for reference/debugging.
    with open(tempdir.relpath("relay.txt"), "w") as f:
        f.write(str(mod.ir_mod))

    # runtime-config/graph/graph.json: configuration consumed by the graph runtime.
    graph_config_dir_path = tempdir.relpath(os.path.join("runtime-config", "graph"))
    os.makedirs(graph_config_dir_path)
    with open(os.path.join(graph_config_dir_path, "graph.json"), "w") as f:
        f.write(mod.graph_json)

    # Pack the staged tree into the final archive.
    with tarfile.open(file_name, "w") as tar_f:

        def reset(tarinfo):
            # Normalize ownership so the archive contents are reproducible
            # regardless of which user ran the export.
            tarinfo.uid = tarinfo.gid = 0
            tarinfo.uname = tarinfo.gname = "root"
            return tarinfo

        tar_f.add(tempdir.temp_dir, arcname=".", filter=reset)
diff --git a/python/tvm/relay/backend/graph_runtime_factory.py
b/python/tvm/relay/backend/graph_runtime_factory.py
index 3427a62..e92ae71 100644
--- a/python/tvm/relay/backend/graph_runtime_factory.py
+++ b/python/tvm/relay/backend/graph_runtime_factory.py
@@ -16,9 +16,9 @@
# under the License.
"""Graph runtime factory."""
import warnings
-from tvm._ffi.base import string_types
-from tvm._ffi.registry import get_global_func
-from tvm.runtime import ndarray
+from ..._ffi.base import string_types
+from ..._ffi.registry import get_global_func
+from ...runtime import ndarray
class GraphRuntimeFactoryModule:
@@ -31,6 +31,8 @@ class GraphRuntimeFactoryModule:
The graph to be deployed in json format output by graph compiler.
The graph can contain operator(tvm_op) that points to the name of
PackedFunc in the libmod.
+ target : tvm.Target
+ The Target used to build this module.
libmod : tvm.Module
The module of the corresponding function
libmod_name: str
@@ -39,13 +41,15 @@ class GraphRuntimeFactoryModule:
The parameters of module
"""
- def __init__(self, graph_json_str, libmod, libmod_name, params):
+ def __init__(self, ir_mod, target, graph_json_str, libmod, libmod_name,
params):
assert isinstance(graph_json_str, string_types)
fcreate = get_global_func("tvm.graph_runtime_factory.create")
args = []
for k, v in params.items():
args.append(k)
args.append(ndarray.array(v))
+ self.ir_mod = ir_mod
+ self.target = target
self.module = fcreate(graph_json_str, libmod, libmod_name, *args)
self.graph_json = graph_json_str
self.lib = libmod
diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index 4c9a898..8e69d28 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -208,14 +208,14 @@ def _build_module_no_factory(mod, target=None,
target_host=None, params=None, mo
return build(mod, target, target_host, params, mod_name).module
-def build(mod, target=None, target_host=None, params=None, mod_name="default"):
+def build(ir_mod, target=None, target_host=None, params=None,
mod_name="default"):
# fmt: off
# pylint: disable=line-too-long
"""Helper function that builds a Relay function to run on TVM graph
runtime.
Parameters
----------
- mod : :py:class:`~tvm.IRModule`
+ ir_mod : :py:class:`~tvm.IRModule`
The IR module to build. Using relay.Function is deprecated.
target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context
name) to str/tvm.target.Target, optional
@@ -251,13 +251,13 @@ def build(mod, target=None, target_host=None,
params=None, mod_name="default"):
"""
# pylint: enable=line-too-long
# fmt: on
- if not isinstance(mod, (IRModule, _function.Function)):
+ if not isinstance(ir_mod, (IRModule, _function.Function)):
raise ValueError("Type of input parameter mod must be tvm.IRModule")
- if isinstance(mod, _function.Function):
+ if isinstance(ir_mod, _function.Function):
if params:
- mod = bind_params_by_name(mod, params)
- mod = IRModule.from_expr(mod)
+ ir_mod = bind_params_by_name(ir_mod, params)
+ ir_mod = IRModule.from_expr(ir_mod)
warnings.warn(
"Please use input parameter mod (tvm.IRModule) "
"instead of deprecated parameter mod
(tvm.relay.function.Function)",
@@ -280,9 +280,11 @@ def build(mod, target=None, target_host=None, params=None,
mod_name="default"):
with tophub_context:
bld_mod = BuildModule()
- graph_json, mod, params = bld_mod.build(mod, target, target_host,
params)
- mod = _graph_runtime_factory.GraphRuntimeFactoryModule(graph_json,
mod, mod_name, params)
- return mod
+ graph_json, runtime_mod, params = bld_mod.build(ir_mod, target,
target_host, params)
+ runtime_mod = _graph_runtime_factory.GraphRuntimeFactoryModule(
+ ir_mod, target, graph_json, runtime_mod, mod_name, params
+ )
+ return runtime_mod
def optimize(mod, target=None, params=None):
diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py
index 6326796..53576a6 100644
--- a/python/tvm/runtime/module.py
+++ b/python/tvm/runtime/module.py
@@ -105,6 +105,9 @@ class Module(object):
raise ValueError("Can only take string as function name")
return self.get_function(name)
+ def __eq__(self, other):
+ return self.handle.value == other.handle.value
+
def __call__(self, *args):
if self._entry:
return self._entry(*args)
@@ -233,15 +236,27 @@ class Module(object):
except NameError:
raise NameError("time_evaluate is only supported when RPC is
enabled")
- def _collect_dso_modules(self):
- """Helper function to collect dso modules, then return it."""
+ def _collect_from_import_tree(self, filter_func):
+ """Helper function to collect modules from the tree matching a
filter_func, then return it.
+
+ Parameters
+ ----------
+ filter_func : Callable[[Module], bool]
+ A function which is invoked for each Module discovered in the
import tree (including
+ self).
+
+ Returns
+ -------
+ list[Module] :
+ A list of matching Module.
+ """
visited, stack, dso_modules = set(), [], []
# append root module
visited.add(self)
stack.append(self)
while stack:
module = stack.pop()
- if module._dso_exportable():
+ if filter_func(module):
dso_modules.append(module)
for m in module.imported_modules:
if m not in visited:
@@ -249,8 +264,9 @@ class Module(object):
stack.append(m)
return dso_modules
- def _dso_exportable(self):
- return self.type_key == "llvm" or self.type_key == "c"
+ def _collect_dso_modules(self):
+ is_dso_exportable = lambda m: (m.type_key == "llvm" or m.type_key ==
"c")
+ return self._collect_from_import_tree(is_dso_exportable)
def export_library(self, file_name, fcompile=None, addons=None,
workspace_dir=None, **kwargs):
"""Export the module and its imported device code one library.
diff --git a/src/runtime/graph/graph_runtime_factory.cc
b/src/runtime/graph/graph_runtime_factory.cc
index 4d3993a..605d6b0 100644
--- a/src/runtime/graph/graph_runtime_factory.cc
+++ b/src/runtime/graph/graph_runtime_factory.cc
@@ -156,7 +156,8 @@
TVM_REGISTER_GLOBAL("tvm.graph_runtime_factory.create").set_body([](TVMArgs args
"graph_runtime_factory.create needs at least
3, "
"but it has "
<< args.num_args;
- // The argument order is graph_json, module, module_name, params.
+ // The argument order is graph_json, module, module_name, param0_name,
param0_tensor,
+ // [param1_name, param1_tensor], ...
ICHECK_EQ((args.size() - 3) % 2, 0);
std::unordered_map<std::string, tvm::runtime::NDArray> params;
for (size_t i = 3; i < static_cast<size_t>(args.size()); i += 2) {
diff --git a/tests/python/unittest/test_micro_model_library_format.py
b/tests/python/unittest/test_micro_model_library_format.py
new file mode 100644
index 0000000..c999091
--- /dev/null
+++ b/tests/python/unittest/test_micro_model_library_format.py
@@ -0,0 +1,190 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import datetime
+import json
+import os
+import sys
+import tarfile
+
+import numpy
+import pytest
+
+import tvm
+import tvm.relay
+from tvm.relay.backend import graph_runtime_factory
+import tvm.runtime.module
+import tvm.testing
+from tvm.contrib import utils
+
+
def validate_graph_json(extract_dir, factory):
    """Assert the exported graph.json matches the factory's and parses sanely."""
    graph_json_path = os.path.join(extract_dir, "runtime-config", "graph", "graph.json")
    with open(graph_json_path) as graph_f:
        graph_json = graph_f.read()
    assert graph_json == factory.graph_json

    # Just check it parses and looks roughly right.
    graph = json.loads(graph_json)
    assert "nodes" in graph
    assert len(graph["nodes"]) == 4
    assert "attrs" in graph
+
+
@tvm.testing.requires_micro
def test_export_model_library_format_c():
    """Export a C-backend build in Model Library Format and validate the archive layout."""
    # Keep the temp dir around on failure so exported artifacts can be inspected.
    with utils.TempDirectory.set_keep_for_debug(True):
        target = tvm.target.target.micro("host")
        with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
            relay_mod = tvm.parser.fromtext(
                """
                #[version = "0.0.5"]
                def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[(1, 2), float32]) {
                    %0 = cast(%a, dtype="float32") + %b * %c;
                    %0
                }"""
            )
            factory = tvm.relay.build(
                relay_mod,
                target,
                target_host=target,
                mod_name="add",
                # "c" is supplied as a constant so the build binds it as a parameter.
                params={"c": numpy.array([[2.0, 4.0]], dtype="float32")},
            )

        temp_dir = utils.tempdir()
        mlf_tar_path = temp_dir.relpath("lib.tar")
        import tvm.micro as micro

        micro.export_model_library_format(factory, mlf_tar_path)
        tf = tarfile.open(mlf_tar_path)

        extract_dir = temp_dir.relpath("extract")
        os.mkdir(extract_dir)
        tf.extractall(extract_dir)

        # Validate metadata.json contents.
        with open(os.path.join(extract_dir, "metadata.json")) as json_f:
            metadata = json.load(json_f)
            assert metadata["version"] == 1
            assert metadata["model_name"] == "add"
            # Compared against local now() to match the exporter's timestamp source.
            export_datetime = datetime.datetime.strptime(
                metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ"
            )
            assert (datetime.datetime.now() - export_datetime) < datetime.timedelta(seconds=60 * 5)
            assert metadata["target"] == {"1": str(target)}
            assert metadata["memory"] == [
                {"storage_id": 0, "size_bytes": 2, "input_binding": "a"},
                {"storage_id": 1, "size_bytes": 8, "input_binding": "b"},
                {"storage_id": 2, "size_bytes": 8, "input_binding": "p0"},
                {"storage_id": 3, "size_bytes": 8},
            ]

        # C backend emits source files under codegen/host/src.
        assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "lib0.c"))
        assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "lib1.c"))

        validate_graph_json(extract_dir, factory)

        with open(os.path.join(extract_dir, "relay.txt")) as relay_f:
            assert relay_f.read() == str(relay_mod)

        with open(os.path.join(extract_dir, "parameters", "add.params"), "rb") as params_f:
            params = tvm.relay.load_param_dict(params_f.read())
            # "p0" is presumably the bound "c" constant renamed by the compiler -- verify.
            assert "p0" in params
+
+
@tvm.testing.requires_micro
def test_export_model_library_format_llvm():
    """Export an llvm-backend build in Model Library Format and validate the archive layout."""
    # Keep the temp dir around on failure so exported artifacts can be inspected.
    with utils.TempDirectory.set_keep_for_debug(True):
        target = tvm.target.target.micro("host")
        # micro("host") yields a "c ..." target string; swap the backend kind to
        # llvm while preserving the remaining target options.
        assert str(target)[:2] == "c "
        target = tvm.target.Target("llvm " + str(target)[2:])
        with tvm.transform.PassContext(opt_level=3):
            relay_mod = tvm.parser.fromtext(
                """
                #[version = "0.0.5"]
                def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[(1, 2), float32]) {
                    %0 = cast(%a, dtype="float32") + %b * %c;
                    %0
                }"""
            )
            factory = tvm.relay.build(
                relay_mod,
                target,
                target_host=target,
                mod_name="add",
                # "c" is supplied as a constant so the build binds it as a parameter.
                params={"c": numpy.array([[2.0, 4.0]], dtype="float32")},
            )

        temp_dir = utils.tempdir()
        mlf_tar_path = temp_dir.relpath("lib.tar")
        import tvm.micro as micro

        micro.export_model_library_format(factory, mlf_tar_path)
        tf = tarfile.open(mlf_tar_path)

        extract_dir = temp_dir.relpath("extract")
        os.mkdir(extract_dir)
        tf.extractall(extract_dir)

        # Validate metadata.json contents.
        with open(os.path.join(extract_dir, "metadata.json")) as json_f:
            metadata = json.load(json_f)
            assert metadata["version"] == 1
            assert metadata["model_name"] == "add"
            # Compared against local now() to match the exporter's timestamp source.
            export_datetime = datetime.datetime.strptime(
                metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ"
            )
            assert (datetime.datetime.now() - export_datetime) < datetime.timedelta(seconds=60 * 5)
            assert metadata["target"] == {"1": str(target)}
            assert metadata["memory"] == [
                {"storage_id": 0, "size_bytes": 2, "input_binding": "a"},
                {"storage_id": 1, "size_bytes": 8, "input_binding": "b"},
                {"storage_id": 2, "size_bytes": 8, "input_binding": "p0"},
                {"storage_id": 3, "size_bytes": 8},
            ]

        # llvm backend emits object files under codegen/host/lib (not src).
        assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "lib", "lib0.o"))

        validate_graph_json(extract_dir, factory)

        with open(os.path.join(extract_dir, "relay.txt")) as relay_f:
            assert relay_f.read() == str(relay_mod)

        with open(os.path.join(extract_dir, "parameters", "add.params"), "rb") as params_f:
            params = tvm.relay.load_param_dict(params_f.read())
            # "p0" is presumably the bound "c" constant renamed by the compiler -- verify.
            assert "p0" in params
+
+
@tvm.testing.requires_micro
def test_export_model():
    """_populate_codegen_dir should reject modules that are neither "c" nor "llvm"."""
    module = tvm.support.FrontendTestModule()
    # Construct a factory around a non-DSO module; exercises the constructor path.
    factory = graph_runtime_factory.GraphRuntimeFactoryModule(
        None, tvm.target.target.micro("host"), '"graph_json"', module, "test_module", {}
    )

    temp_dir = utils.tempdir()
    import tvm.micro as micro
    import tvm.micro.model_library_format as model_library_format

    with pytest.raises(micro.UnsupportedInModelLibraryFormatError) as exc:
        model_library_format._populate_codegen_dir(module, temp_dir.relpath("codegen"))

    # pytest's ExceptionInfo exposes the exception as .value; .exception is the
    # unittest API and raises AttributeError here. The message embeds the repr
    # of the offending module list, so only the stable prefix is checked.
    assert str(exc.value).startswith("Don't know how to export non-c or non-llvm modules; found:")
+
+
if __name__ == "__main__":
    # Allow running this file directly; forward any extra CLI args to pytest.
    raise SystemExit(pytest.main([__file__] + sys.argv[1:]))