This is an automated email from the ASF dual-hosted git repository.
yaxingcai pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new b959645454 [Unity] nn.Module external modules (#15487)
b959645454 is described below
commit b959645454bdd47641c91887cfd9fe23d920d74a
Author: Yaxing Cai <[email protected]>
AuthorDate: Sun Aug 20 00:21:17 2023 -0700
[Unity] nn.Module external modules (#15487)
* [Unity] nn.Module external modules
This PR introduces the feature of importing external `*.o` modules into our
`nn.Module` frontend.
* apply code review suggestions
* add ASF header
---------
Co-authored-by: Ubuntu <[email protected]>
---
python/tvm/relax/frontend/nn/__init__.py | 2 +-
python/tvm/relax/frontend/nn/core.py | 66 +++++++++++
python/tvm/relax/frontend/nn/spec.py | 62 ++++++++++-
.../python/relax/test_frontend_nn_extern_module.py | 121 +++++++++++++++++++++
4 files changed, 249 insertions(+), 2 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/__init__.py
b/python/tvm/relax/frontend/nn/__init__.py
index 358e1211a9..2c5736c4fc 100644
--- a/python/tvm/relax/frontend/nn/__init__.py
+++ b/python/tvm/relax/frontend/nn/__init__.py
@@ -16,7 +16,7 @@
# under the License.
"""A PyTorch-like API to build IRModules."""
from . import op, spec
-from .core import Effect, Module, ModuleList, Parameter, Tensor
+from .core import Effect, Module, ModuleList, Parameter, Tensor, ExternModule
from .modules import Embedding, IOEffect, KVCache, Linear, RMSNorm
from .op import *
from .subroutine import SubroutineMixin
diff --git a/python/tvm/relax/frontend/nn/core.py
b/python/tvm/relax/frontend/nn/core.py
index 1d1869fb6c..c282815160 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -402,6 +402,72 @@ class Module(SubroutineMixin):
raise ValueError(f"Unknown out_format: {out_format}")
+class ExternModule(Module):
+ """Base class for external module. Subclass it to import your external
models.
+ Modules can nest within each other in a tree structure using regular
attribute assignment."""
+
+ module_spec: "_spec.ExternModuleSpec"
+
+ def __init__(self, module_spec: "_spec.ExternModuleSpec") -> None:
+ super().__init__()
+ self.module_spec = module_spec
+
+ def get_extern_func(self, func_name: str) -> Callable:
+ """This method helps get the external funciton in external module by
function name.
+ It will wrap the functions as other prebuilt operators.
+
+ Parameters
+ ----------
+ func_name : str
+ The name of the function to get.
+
+ Returns
+ ------
+ ret_func: Callable
+ The callable function to call.
+ """
+ for function_spec in self.module_spec.functions:
+ if function_spec.symbol == func_name:
+ # pylint: disable=cell-var-from-loop, import-outside-toplevel,
protected-access
+ from .op import _wrap_nested
+ from tvm.relax import call_dps_packed, Tuple as RxTuple
+
+ def extern_func(*args: Tensor) -> Tensor:
+ spec2var = {}
+ for arg, arg_spec in zip(args, function_spec.args):
+ for value, value_spec in zip(arg.shape,
arg_spec.shape):
+ if isinstance(value_spec, str):
+ if not value_spec in spec2var:
+ spec2var[value_spec] = value
+ else:
+ if not spec2var[value_spec] == value:
+ raise ValueError(
+ f"Confilict vars
{spec2var[value_spec]} and {value} "
+ f"for {value_spec} in
{function_spec}"
+ )
+ out_shape = []
+ for value_spec in function_spec.ret.shape:
+ if isinstance(value_spec, int):
+ out_shape.append(value_spec)
+ elif isinstance(value_spec, str):
+ if not value_spec in spec2var:
+ raise ValueError(f"Undefined var {value_spec}
in {function_spec}")
+ out_shape.append(spec2var[value_spec])
+ out_sinfo = TensorStructInfo(out_shape,
function_spec.ret.dtype)
+ return _wrap_nested(
+ call_dps_packed(
+ func_name,
+ args=RxTuple([tensor._expr for tensor in args]),
+ out_sinfo=out_sinfo,
+ ),
+ func_name,
+ )
+
+ return extern_func
+
+ raise ValueError(f"Unknown function {func_name} in
{self.module_spec.filename}")
+
+
class ModuleList(Module):
"""Holds submodules in a list."""
diff --git a/python/tvm/relax/frontend/nn/spec.py
b/python/tvm/relax/frontend/nn/spec.py
index a279616f31..95772f2f94 100644
--- a/python/tvm/relax/frontend/nn/spec.py
+++ b/python/tvm/relax/frontend/nn/spec.py
@@ -15,12 +15,14 @@
# specific language governing permissions and limitations
# under the License.
"""Compilation specifications, for example, dynamic shape inputs."""
+from collections import defaultdict
import inspect
import threading
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
from tvm import tir
from tvm.ir import IRModule
+from tvm.runtime import load_static_library
from ... import expr as rx
from ...block_builder import BlockBuilder
@@ -219,6 +221,43 @@ class ModuleSpec:
)
+class ExternFunctionSpec:
+ """A spec for a compiled external function."""
+
+ symbol: str
+ args: List[Tensor]
+ ret: Union[Tensor, List[Tensor]]
+
+ def __init__(self, symbol: str, args: List[Tensor], ret: Union[Tensor,
List[Tensor]]) -> None:
+ self.symbol = symbol
+ self.args = args
+ self.ret = ret
+
+ def __repr__(self) -> str:
+ arg_repr = ", ".join(arg.__repr__() for arg in self.args)
+ if isinstance(self.ret, list):
+ ret_repr = "(" + ", ".join(ret.__repr__() for ret in self.ret) +
")"
+ else:
+ ret_repr = self.ret.__repr__()
+ return f"ExternFunctionSpec: {self.symbol}({arg_repr}) -> {ret_repr}"
+
+
+class ExternModuleSpec:
+ """A spec for a compiled external Module."""
+
+ filename: str
+ functions: List[ExternFunctionSpec]
+
+ def __init__(self, filename: str, functions: List[ExternFunctionSpec]) ->
None:
+ self.filename = filename
+ self.functions = functions
+
+ def __repr__(self) -> str:
+ return f"ExternModuleSpec(path={self.filename}):\n" + "\n".join(
+ f" {func.__repr__()}" for func in self.functions
+ )
+
+
class SpecBuilder:
"""Builder of ModuleSpec, which exports an nn.Module to TVM IRModule."""
@@ -268,10 +307,22 @@ class SpecBuilder:
result.append((name, effect))
return result
+ def _extern_modules() -> List[Tuple[str, List[str]]]:
+ mod2func = defaultdict(set)
+ for _, extern_module in core._attribute_finder(
+ spec.module, "", condition_yield=lambda x: isinstance(x,
core.ExternModule)
+ ):
+ module_spec = extern_module.module_spec
+ mod2func[module_spec.filename].update(
+ [function_spec.symbol for function_spec in
module_spec.functions]
+ )
+ return [(mod, list(funcs)) for mod, funcs in mod2func.items()]
+
# pylint: enable=protected-access
params = _params()
effects = _effects()
+ extern_modules = _extern_modules()
with self:
with self.builder.function("_initialize_effect"):
with self.builder.dataflow():
@@ -282,7 +333,16 @@ class SpecBuilder:
with self.builder.dataflow():
outputs, inputs = _emit_method(self.builder,
method_spec, params, effects)
self.builder.emit_func_output(outputs, inputs)
- return self.builder.get(), params
+ external_mods = []
+ for lib_path, func_names in extern_modules:
+ external_mods.append(load_static_library(path=lib_path,
func_names=func_names))
+ mod = self.builder.get()
+ if extern_modules:
+ original_external_mods = mod.get_attr("external_mods")
+ if original_external_mods is not None:
+ external_mods = original_external_mods + extern_modules
+ mod = mod.with_attr("external_mods", external_mods)
+ return mod, params
def _emit_effect_init(
diff --git a/tests/python/relax/test_frontend_nn_extern_module.py
b/tests/python/relax/test_frontend_nn_extern_module.py
new file mode 100644
index 0000000000..a5753c19ab
--- /dev/null
+++ b/tests/python/relax/test_frontend_nn_extern_module.py
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import os
+import subprocess
+import tempfile
+
+import pytest
+
+import tvm
+import tvm.testing
+from tvm.script import ir as I, tir as T, relax as R
+from tvm.relax.frontend import nn
+from tvm.relax.frontend.nn import spec
+
+
+def _gen_extern_module(mod_dir, file):
+ src = """#include <dlpack/dlpack.h>
+ #include <tvm/runtime/packed_func.h>
+
+ int f_matmul(DLTensor* a, DLTensor* b, DLTensor* c) { return 0; }
+
+ TVM_DLL_EXPORT_TYPED_FUNC(matmul, f_matmul)"""
+ with open(f"{mod_dir}/{file}.cc", "w") as cc_file:
+ cc_file.write(src)
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
+ os.system(
+ f"gcc -c {mod_dir}/{file}.cc "
+ f"-o {mod_dir}/{file}.o "
+ f"-I{cur_dir}/../../../include "
+ f"-I{cur_dir}/../../../3rdparty/dlpack/include "
+ f"-I{cur_dir}/../../../3rdparty/dmlc-core/include"
+ )
+ return f"{mod_dir}/{file}.o"
+
+
+def test_extern_module():
+ shape_a = ("a", "b", "c", "d", 1, 2, 3, 4)
+ shape_b = ("c", "d", "e", "f", 5, 6, 7, 8)
+ shape_c = ("a", "b", "c", "d", "e", "f", 9, 10)
+ dtype = "float32"
+ tmp_dir = tempfile.mkdtemp()
+ obj_file = _gen_extern_module(tmp_dir, "test")
+ func_name = "matmul"
+ os.system(f"ls {tmp_dir}")
+
+ ext_mod = nn.ExternModule(
+ module_spec=spec.ExternModuleSpec(
+ filename=obj_file,
+ functions=[
+ spec.ExternFunctionSpec(
+ symbol=func_name,
+ args=[
+ spec.Tensor(shape_a, dtype),
+ spec.Tensor(shape_b, dtype),
+ ],
+ ret=spec.Tensor(shape_c, dtype),
+ )
+ ],
+ )
+ )
+
+ class MatmulModule(nn.Module):
+ def __init__(self) -> None:
+ self.Matmul = ext_mod
+
+ def forward(self, a: nn.Tensor, b: nn.Tensor):
+ return self.Matmul.get_extern_func(func_name)(a, b)
+
+ matmul_mod = MatmulModule()
+ ir_module, _ = matmul_mod.export_tvm(
+ spec={
+ "forward": {
+ "a": spec.Tensor(shape_a, dtype),
+ "b": spec.Tensor(shape_b, dtype),
+ }
+ }
+ )
+
+ @R.function
+ def forward(
+ a_1: R.Tensor(("a", "b", "c", "d", 1, 2, 3, 4), dtype="float32"),
+ b_1: R.Tensor(("c", "d", "e", "f", 5, 6, 7, 8), dtype="float32"),
+ _io: R.Object,
+ ) -> R.Tuple(
+ R.Tensor(("a", "b", "c", "d", "e", "f", 9, 10), dtype="float32"),
R.Tuple(R.Object)
+ ):
+ a = T.int64()
+ b = T.int64()
+ c = T.int64()
+ d = T.int64()
+ e = T.int64()
+ f = T.int64()
+ with R.dataflow():
+ matmul = R.call_dps_packed(
+ "matmul",
+ (a_1, b_1),
+ out_sinfo=R.Tensor((a, b, c, d, e, f, 9, 10), dtype="float32"),
+ )
+ gv1: R.Tuple(
+ R.Tensor((a, b, c, d, e, f, 9, 10), dtype="float32"),
R.Tuple(R.Object)
+ ) = matmul, (_io,)
+ R.output(gv1)
+ return gv1
+
+ tvm.ir.assert_structural_equal(ir_module["forward"], forward)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()