This is an automated email from the ASF dual-hosted git repository.

yaxingcai pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new b959645454 [Unity] nn.Module external modules (#15487)
b959645454 is described below

commit b959645454bdd47641c91887cfd9fe23d920d74a
Author: Yaxing Cai <[email protected]>
AuthorDate: Sun Aug 20 00:21:17 2023 -0700

    [Unity] nn.Module external modules (#15487)
    
    * [Unity] nn.Module external modules
    
    This PR introduces the feature of importing external `*.o` modules into our 
`nn.Module` frontend.
    
    * apply code review suggestions
    
    * add ASF header
    
    ---------
    
    Co-authored-by: Ubuntu <[email protected]>
---
 python/tvm/relax/frontend/nn/__init__.py           |   2 +-
 python/tvm/relax/frontend/nn/core.py               |  66 +++++++++++
 python/tvm/relax/frontend/nn/spec.py               |  62 ++++++++++-
 .../python/relax/test_frontend_nn_extern_module.py | 121 +++++++++++++++++++++
 4 files changed, 249 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/__init__.py 
b/python/tvm/relax/frontend/nn/__init__.py
index 358e1211a9..2c5736c4fc 100644
--- a/python/tvm/relax/frontend/nn/__init__.py
+++ b/python/tvm/relax/frontend/nn/__init__.py
@@ -16,7 +16,7 @@
 # under the License.
 """A PyTorch-like API to build IRModules."""
 from . import op, spec
-from .core import Effect, Module, ModuleList, Parameter, Tensor
+from .core import Effect, Module, ModuleList, Parameter, Tensor, ExternModule
 from .modules import Embedding, IOEffect, KVCache, Linear, RMSNorm
 from .op import *
 from .subroutine import SubroutineMixin
diff --git a/python/tvm/relax/frontend/nn/core.py 
b/python/tvm/relax/frontend/nn/core.py
index 1d1869fb6c..c282815160 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -402,6 +402,72 @@ class Module(SubroutineMixin):
         raise ValueError(f"Unknown out_format: {out_format}")
 
 
class ExternModule(Module):
    """Base class for external modules. Subclass it to import your external models.
    Modules can nest within each other in a tree structure using regular attribute assignment."""

    # Spec describing the object file and the functions it exports.
    module_spec: "_spec.ExternModuleSpec"

    def __init__(self, module_spec: "_spec.ExternModuleSpec") -> None:
        super().__init__()
        self.module_spec = module_spec

    def get_extern_func(self, func_name: str) -> Callable:
        """This method helps get the external function in the external module by
        function name. It will wrap the functions as other prebuilt operators.

        Parameters
        ----------
        func_name : str
            The name of the function to get.

        Returns
        -------
        ret_func: Callable
            The callable function to call.

        Raises
        ------
        ValueError
            If `func_name` is not declared in the module spec, or the shapes of
            the runtime arguments are inconsistent with the function spec.
        """
        for function_spec in self.module_spec.functions:
            if function_spec.symbol != func_name:
                continue
            # pylint: disable=cell-var-from-loop, import-outside-toplevel, protected-access
            from tvm.relax import Tuple as RxTuple
            from tvm.relax import call_dps_packed

            from .op import _wrap_nested

            def extern_func(*args: Tensor) -> Tensor:
                # Bind each named (string) dimension of the spec to the concrete
                # shape value observed on the actual arguments, verifying that a
                # name is never bound to two different values.
                spec2var = {}
                for arg, arg_spec in zip(args, function_spec.args):
                    for value, value_spec in zip(arg.shape, arg_spec.shape):
                        if isinstance(value_spec, str):
                            if value_spec not in spec2var:
                                spec2var[value_spec] = value
                            else:
                                if not spec2var[value_spec] == value:
                                    raise ValueError(
                                        f"Conflict vars {spec2var[value_spec]} and {value} "
                                        f"for {value_spec} in {function_spec}"
                                    )
                # Materialize the output shape: ints pass through unchanged,
                # named dimensions are resolved from the bindings above.
                # NOTE(review): shape entries that are neither int nor str are
                # silently skipped — presumably the spec only contains these two
                # kinds; confirm against spec.Tensor.
                out_shape = []
                for value_spec in function_spec.ret.shape:
                    if isinstance(value_spec, int):
                        out_shape.append(value_spec)
                    elif isinstance(value_spec, str):
                        if value_spec not in spec2var:
                            raise ValueError(f"Undefined var {value_spec} in {function_spec}")
                        out_shape.append(spec2var[value_spec])
                out_sinfo = TensorStructInfo(out_shape, function_spec.ret.dtype)
                # Emit a call_dps_packed to the external symbol and wrap the
                # resulting relax expression as a frontend Tensor.
                return _wrap_nested(
                    call_dps_packed(
                        func_name,
                        args=RxTuple([tensor._expr for tensor in args]),
                        out_sinfo=out_sinfo,
                    ),
                    func_name,
                )

            return extern_func

        raise ValueError(f"Unknown function {func_name} in {self.module_spec.filename}")
+
+
 class ModuleList(Module):
     """Holds submodules in a list."""
 
diff --git a/python/tvm/relax/frontend/nn/spec.py 
b/python/tvm/relax/frontend/nn/spec.py
index a279616f31..95772f2f94 100644
--- a/python/tvm/relax/frontend/nn/spec.py
+++ b/python/tvm/relax/frontend/nn/spec.py
@@ -15,12 +15,14 @@
 # specific language governing permissions and limitations
 # under the License.
 """Compilation specifications, for example, dynamic shape inputs."""
+from collections import defaultdict
 import inspect
 import threading
 from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
 
 from tvm import tir
 from tvm.ir import IRModule
+from tvm.runtime import load_static_library
 
 from ... import expr as rx
 from ...block_builder import BlockBuilder
@@ -219,6 +221,43 @@ class ModuleSpec:
         )
 
 
class ExternFunctionSpec:
    """A spec for a compiled external function."""

    # Exported symbol name of the function in the object file.
    symbol: str
    # Specs of the positional tensor arguments.
    args: List[Tensor]
    # Spec of the returned tensor, or a list of specs for multiple returns.
    ret: Union[Tensor, List[Tensor]]

    def __init__(self, symbol: str, args: List[Tensor], ret: Union[Tensor, List[Tensor]]) -> None:
        self.symbol = symbol
        self.args = args
        self.ret = ret

    def __repr__(self) -> str:
        rendered_args = ", ".join(repr(arg) for arg in self.args)
        if isinstance(self.ret, list):
            rendered_ret = "(" + ", ".join(repr(item) for item in self.ret) + ")"
        else:
            rendered_ret = repr(self.ret)
        return f"ExternFunctionSpec: {self.symbol}({rendered_args}) -> {rendered_ret}"
+
+
class ExternModuleSpec:
    """A spec for a compiled external Module."""

    # Path to the compiled object file.
    filename: str
    # Specs of the functions the module exports.
    functions: List[ExternFunctionSpec]

    def __init__(self, filename: str, functions: List[ExternFunctionSpec]) -> None:
        self.filename = filename
        self.functions = functions

    def __repr__(self) -> str:
        lines = [f"ExternModuleSpec(path={self.filename}):"]
        lines.extend(f"  {repr(func)}" for func in self.functions)
        return "\n".join(lines)
+
+
 class SpecBuilder:
     """Builder of ModuleSpec, which exports an nn.Module to TVM IRModule."""
 
@@ -268,10 +307,22 @@ class SpecBuilder:
                 result.append((name, effect))
             return result
 
        def _extern_modules() -> List[Tuple[str, List[str]]]:
            # Walk the module tree for every attached ExternModule and group the
            # requested function symbols by object file, deduplicating so each
            # library is listed once with the union of the symbols it must export.
            mod2func = defaultdict(set)
            for _, extern_module in core._attribute_finder(
                spec.module, "", condition_yield=lambda x: isinstance(x, core.ExternModule)
            ):
                module_spec = extern_module.module_spec
                mod2func[module_spec.filename].update(
                    [function_spec.symbol for function_spec in module_spec.functions]
                )
            # Result shape: [(object-file path, [symbol, ...]), ...]
            return [(mod, list(funcs)) for mod, funcs in mod2func.items()]
+
         # pylint: enable=protected-access
 
         params = _params()
         effects = _effects()
+        extern_modules = _extern_modules()
         with self:
             with self.builder.function("_initialize_effect"):
                 with self.builder.dataflow():
@@ -282,7 +333,16 @@ class SpecBuilder:
                     with self.builder.dataflow():
                         outputs, inputs = _emit_method(self.builder, 
method_spec, params, effects)
                     self.builder.emit_func_output(outputs, inputs)
-        return self.builder.get(), params
+        external_mods = []
+        for lib_path, func_names in extern_modules:
+            external_mods.append(load_static_library(path=lib_path, 
func_names=func_names))
+        mod = self.builder.get()
+        if extern_modules:
+            original_external_mods = mod.get_attr("external_mods")
+            if original_external_mods is not None:
+                external_mods = original_external_mods + extern_modules
+            mod = mod.with_attr("external_mods", external_mods)
+        return mod, params
 
 
 def _emit_effect_init(
diff --git a/tests/python/relax/test_frontend_nn_extern_module.py 
b/tests/python/relax/test_frontend_nn_extern_module.py
new file mode 100644
index 0000000000..a5753c19ab
--- /dev/null
+++ b/tests/python/relax/test_frontend_nn_extern_module.py
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
import os
import subprocess
import tempfile

import pytest

import tvm
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import spec
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
+
+
def _gen_extern_module(mod_dir, file):
    """Compile a tiny C++ source that exposes a `matmul` packed function into an
    object file under `mod_dir`, and return the path of the produced `.o`."""
    src = """#include <dlpack/dlpack.h>
    #include <tvm/runtime/packed_func.h>

    int f_matmul(DLTensor* a, DLTensor* b, DLTensor* c) { return 0; }

    TVM_DLL_EXPORT_TYPED_FUNC(matmul, f_matmul)"""
    with open(f"{mod_dir}/{file}.cc", "w") as cc_file:
        cc_file.write(src)
    cur_dir = os.path.dirname(os.path.abspath(__file__))
    # Invoke gcc with an argument list (no shell string) and fail loudly on a
    # non-zero exit status; `os.system` silently ignored compilation failures,
    # producing confusing downstream errors when the .o file was missing.
    subprocess.run(
        [
            "gcc",
            "-c",
            f"{mod_dir}/{file}.cc",
            "-o",
            f"{mod_dir}/{file}.o",
            f"-I{cur_dir}/../../../include",
            f"-I{cur_dir}/../../../3rdparty/dlpack/include",
            f"-I{cur_dir}/../../../3rdparty/dmlc-core/include",
        ],
        check=True,
    )
    return f"{mod_dir}/{file}.o"
+
+
def test_extern_module():
    # Shapes mix symbolic dims ("a".."f") with static ints, so exporting must
    # bind the named dims from the runtime arguments to build the output shape.
    shape_a = ("a", "b", "c", "d", 1, 2, 3, 4)
    shape_b = ("c", "d", "e", "f", 5, 6, 7, 8)
    shape_c = ("a", "b", "c", "d", "e", "f", 9, 10)
    dtype = "float32"
    tmp_dir = tempfile.mkdtemp()
    obj_file = _gen_extern_module(tmp_dir, "test")
    func_name = "matmul"
    os.system(f"ls {tmp_dir}")

    # Declare the external object file and the signature of the symbol it exports.
    ext_mod = nn.ExternModule(
        module_spec=spec.ExternModuleSpec(
            filename=obj_file,
            functions=[
                spec.ExternFunctionSpec(
                    symbol=func_name,
                    args=[
                        spec.Tensor(shape_a, dtype),
                        spec.Tensor(shape_b, dtype),
                    ],
                    ret=spec.Tensor(shape_c, dtype),
                )
            ],
        )
    )

    class MatmulModule(nn.Module):
        # The extern module nests as a regular submodule attribute.
        def __init__(self) -> None:
            self.Matmul = ext_mod

        def forward(self, a: nn.Tensor, b: nn.Tensor):
            return self.Matmul.get_extern_func(func_name)(a, b)

    matmul_mod = MatmulModule()
    ir_module, _ = matmul_mod.export_tvm(
        spec={
            "forward": {
                "a": spec.Tensor(shape_a, dtype),
                "b": spec.Tensor(shape_b, dtype),
            }
        }
    )

    # Expected lowering: the extern call becomes R.call_dps_packed, with the
    # output struct-info assembled from the symbolic vars bound on the inputs.
    @R.function
    def forward(
        a_1: R.Tensor(("a", "b", "c", "d", 1, 2, 3, 4), dtype="float32"),
        b_1: R.Tensor(("c", "d", "e", "f", 5, 6, 7, 8), dtype="float32"),
        _io: R.Object,
    ) -> R.Tuple(
        R.Tensor(("a", "b", "c", "d", "e", "f", 9, 10), dtype="float32"), R.Tuple(R.Object)
    ):
        a = T.int64()
        b = T.int64()
        c = T.int64()
        d = T.int64()
        e = T.int64()
        f = T.int64()
        with R.dataflow():
            matmul = R.call_dps_packed(
                "matmul",
                (a_1, b_1),
                out_sinfo=R.Tensor((a, b, c, d, e, f, 9, 10), dtype="float32"),
            )
            gv1: R.Tuple(
                R.Tensor((a, b, c, d, e, f, 9, 10), dtype="float32"), R.Tuple(R.Object)
            ) = matmul, (_io,)
            R.output(gv1)
        return gv1

    tvm.ir.assert_structural_equal(ir_module["forward"], forward)
+
+
# Entry point so this test file can also be run directly as a script.
if __name__ == "__main__":
    tvm.testing.main()

Reply via email to