This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new c3862b573f [Unity] nn.Module Spec (#15416)
c3862b573f is described below
commit c3862b573f1598c08c51d6b24149db965b24ff93
Author: Junru Shao <[email protected]>
AuthorDate: Thu Jul 27 05:08:22 2023 -0700
[Unity] nn.Module Spec (#15416)
This PR introduces ModuleSpec and SpecBuilder, which are key classes to
export an nn.Module to TVM's IRModule.
Example:
```python
model.export_tvm(
spec={
"prefill": {
"inputs": spec.Tensor([batch_size, "seq_len"], "int32"),
"total_seq_len": int,
},
"decode": {
"inputs": spec.Tensor([batch_size, 1], "int32"),
"total_seq_len": int,
},
"softmax_with_temperature": {
"logits": spec.Tensor([1, 1, vocab_size], "float32"),
"temperature": spec.Tensor([], "float32"),
},
},
)
```
---
python/tvm/relax/frontend/nn/__init__.py | 3 +-
python/tvm/relax/frontend/nn/core.py | 82 +++++++-
python/tvm/relax/frontend/nn/modules.py | 51 +++++
python/tvm/relax/frontend/nn/spec.py | 336 +++++++++++++++++++++++++++++++
4 files changed, 470 insertions(+), 2 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/__init__.py b/python/tvm/relax/frontend/nn/__init__.py
index b7687eb924..8375cb6618 100644
--- a/python/tvm/relax/frontend/nn/__init__.py
+++ b/python/tvm/relax/frontend/nn/__init__.py
@@ -14,5 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=invalid-name
"""A PyTorch-like API to build IRModules."""
+from . import spec
+from .core import Effect, Module, ModuleList, Parameter, Tensor
diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py
index 8f258dd81c..1ac4a6fbdb 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -24,18 +24,36 @@
impure external function callings, inplace mutation, etc.
"""
from collections import OrderedDict
-from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence,
Tuple, Union
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Iterator,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Union,
+)
import numpy as np
from tvm import tir
+from tvm.ir import IRModule
from tvm.runtime import Device, NDArray, ndarray
+from tvm.runtime.relax_vm import VirtualMachine
+from tvm.target import Target
from ... import expr as rx
from ...block_builder import BlockBuilder
from ...struct_info import ShapeStructInfo, TensorStructInfo
from ._tensor_op import _TensorOp
+if TYPE_CHECKING:
+ from . import spec as _spec
+
+
# Module-level default dtype name (its consumers are not shown here — confirm usage).
_DEFAULT_DTYPE = "float32"
@@ -327,6 +345,44 @@ class Module:
if hasattr(item, "to") and callable(item.to):
item.to(dtype=dtype)
def export_tvm(
    self, spec: "_spec.ModuleSpecType"
) -> Tuple[IRModule, List[Tuple[str, Parameter]]]:
    """Export the module to a TVM IRModule together with its parameters.

    Parameters
    ----------
    spec : _spec.ModuleSpecType
        Either a ``spec.ModuleSpec`` or a raw dictionary mapping method names
        to per-argument specs. Raw dictionaries are normalized through
        ``spec.ModuleSpec.from_raw``.

    Returns
    -------
    mod : IRModule
        The exported IRModule. It contains one function per method in ``spec``
        plus an ``_initialize_effect`` function.
    params : List[Tuple[str, Parameter]]
        ``(name, parameter)`` pairs for every Parameter found on this module.
        Each exported function takes them as extra inputs after its explicit
        arguments.
    """
    from . import spec as _spec  # pylint: disable=import-outside-toplevel

    module_spec = _spec.ModuleSpec.from_raw(spec, self)
    mod, params = _spec.SpecBuilder().build(module_spec)
    return mod, params
+
def jit(
    self,
    spec: "_spec.ModuleSpecType",
    target: Union[str, Target] = "llvm",
    device: str = "cpu",
    pipeline: str = "zero",
    out_format: str = "torch",
) -> Callable:
    """Just-in-time compile this nn.Module into an executable.

    Parameters
    ----------
    spec : _spec.ModuleSpecType
        A ``spec.ModuleSpec`` or a raw ``{method name: {arg name: arg spec}}``
        dictionary describing the methods to export.
    target : Union[str, Target]
        Compilation target handed to ``relax.build``.
    device : str
        Device string such as ``"cpu"`` or ``"cuda:1"``, parsed by
        ``_str_to_device``.
    pipeline : str
        Name of the relax pipeline fetched via ``relax.pipeline.get_pipeline``.
    out_format : str
        Output wrapper format. Only ``"torch"`` is recognized, and it is not
        implemented yet.

    Returns
    -------
    Callable
        Currently never returns; see Raises.

    Raises
    ------
    ValueError
        If ``out_format`` is unknown, ``device`` is malformed or unsupported,
        or some parameter has no concrete value.
    NotImplementedError
        For ``out_format="torch"``, after compilation succeeds.
    """
    from tvm import relax  # pylint: disable=import-outside-toplevel

    from . import spec as _spec  # pylint: disable=import-outside-toplevel

    # Reject unknown output formats before paying for export and compilation.
    if out_format != "torch":
        raise ValueError(f"Unknown out_format: {out_format}")

    # Convert nn.Module to IRModule
    module_spec = _spec.ModuleSpec.from_raw(spec, self)
    mod, params = _spec.SpecBuilder().build(module_spec)

    # Convert parameters (raises early if any parameter is unset)
    dev = _str_to_device(device)
    params = _param_to_ndarray(params, dev)

    # Compile mod and feed it to VM
    mod = relax.pipeline.get_pipeline(pipeline)(mod)  # pylint: disable=no-value-for-parameter
    mod = relax.build(mod, target=target)
    # The VM is constructed but not yet used; presumably it will be wrapped
    # by the output format once one is implemented.
    VirtualMachine(mod, dev)

    # out_format == "torch": wrapping the VM as a torch callable is not implemented.
    raise NotImplementedError
+
class ModuleList(Module):
"""Holds submodules in a list."""
@@ -417,3 +473,27 @@ def _from_dlpack(tensor) -> NDArray:
device_id,
),
)
+
+
def _str_to_device(device: str) -> Device:
    """Parse a device string of the form ``"<type>"`` or ``"<type>:<id>"``.

    ``<type>`` must be a key of ``Device.STR2MASK``; ``<id>`` defaults to 0.

    Raises
    ------
    ValueError
        If the string has more than one ``:``, the id is not an integer, or
        the device type is unsupported.
    """
    split = device.split(":")
    if len(split) > 2:
        raise ValueError(f"Invalid device: {device}")
    device_type = split[0]
    if len(split) == 1:
        device_id = 0
    else:
        try:
            device_id = int(split[1])
        except ValueError:
            # Report the whole device string rather than int()'s bare message.
            raise ValueError(f"Invalid device: {device}") from None
    if device_type not in Device.STR2MASK:
        raise ValueError(f"Unsupported device type: {device_type}")
    return Device(Device.STR2MASK[device_type], device_id)
+
+
def _param_to_ndarray(params: List[Tuple[str, Parameter]], device: Device) -> List[NDArray]:
    """Copy each parameter's data onto ``device``, preserving the input order.

    Raises
    ------
    ValueError
        If any parameter's ``data`` is None. The message lists every such name.
    """
    # Validate before copying so no device transfers are wasted on a failing call.
    missing = [name for name, param in params if param.data is None]
    if missing:
        raise ValueError(f"Parameters are not set to any concrete values: {', '.join(missing)}")
    return [param.data.copyto(target=device) for _, param in params]
diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py
new file mode 100644
index 0000000000..20f5a1d3c3
--- /dev/null
+++ b/python/tvm/relax/frontend/nn/modules.py
@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Builtin Modules."""
+from typing import List, Optional
+
+from tvm import relax as rx
+
+from .core import Effect, Tensor
+
+
class IOEffect(Effect):
    """
    Models I/O side effects, such as printing NDArray contents to the screen or
    inserting debug breakpoints. The effect is carried as an opaque relax
    variable with ``ObjectStructInfo``.
    """

    effect: Optional[rx.Var]

    def __init__(self):
        self.effect = None

    def emit_init(self, name_hint, builder: rx.BlockBuilder) -> List[rx.DataflowVar]:
        """Emit the initial effect value (a null object) named ``{name_hint}.io``."""
        null_obj = rx.op.null_value()
        io_var = builder.emit(null_obj, f"{name_hint}.io")
        return [io_var]

    def create(self, name_hint: str) -> List[rx.Var]:
        """Create the input variable for this effect; must not already be live."""
        assert self.effect is None
        io_var = rx.Var(f"{name_hint}.io", struct_info=rx.ObjectStructInfo())
        self.effect = io_var
        return [io_var]

    def finalize(self) -> List[rx.Var]:
        """Hand back the current effect variable and clear the internal state."""
        io_var, self.effect = self.effect, None
        return [io_var]

    def print_(self, tensor: Tensor) -> None:
        """Encloses the side effect of NDArray printing"""
        raise NotImplementedError
diff --git a/python/tvm/relax/frontend/nn/spec.py b/python/tvm/relax/frontend/nn/spec.py
new file mode 100644
index 0000000000..3a9be83a51
--- /dev/null
+++ b/python/tvm/relax/frontend/nn/spec.py
@@ -0,0 +1,336 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Compilation specifications, for example, dynamic shape inputs."""
+import inspect
+import threading
+from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
+
+from tvm import tir
+from tvm.ir import IRModule
+
+from ... import expr as rx
+from ...block_builder import BlockBuilder
+from ...struct_info import ShapeStructInfo
+from . import core
+
# Normalized spec of a single method argument: an integer or a tensor input.
# NOTE(review): MethodSpec.from_raw also accepts `int`, the `Int` class and the
# string "int" for integers; this alias only names the normalized forms.
ArgSpecType = Union["Int", "Tensor"]
# A method spec: a MethodSpec, or a raw {argument name: argument spec} dict.
MethodSpecType = Union["MethodSpec", Dict[str, ArgSpecType]]
# A module spec: a ModuleSpec, or a raw {method name: method spec} dict.
ModuleSpecType = Union["ModuleSpec", Dict[str, MethodSpecType]]
+
+
class Int:  # pylint: disable=too-few-public-methods
    """Marker spec for a scalar integer argument; it carries no state."""

    def __init__(self) -> None:
        """Nothing to initialize: every Int spec is interchangeable."""

    def __repr__(self) -> str:
        return "int"
+
+
class Tensor:  # pylint: disable=too-few-public-methods
    """Spec for a tensor argument.

    Rank and dtype are fixed. Each dimension is either a concrete ``int`` or a
    ``str`` that names a symbolic dimension.
    """

    shape: List[Union[int, str]]
    dtype: str

    def __init__(self, shape: Sequence[Union[int, str]], dtype: str) -> None:
        # Take a private list copy so later mutation of `shape` has no effect here.
        dims = list(shape)
        self.shape = dims
        self.dtype = dtype

    def __repr__(self) -> str:
        dims_text = ", ".join(map(str, self.shape))
        return f"Tensor([{dims_text}], '{self.dtype}')"
+
+
class MethodSpec:
    """A spec for a compiled method.

    Holds the method to trace together with the name and spec (``Int`` or
    ``Tensor``) of each of its arguments, in signature order.
    """

    method: Callable
    arg_names: List[str]
    arg_specs: List[ArgSpecType]

    def __init__(self, method: Callable, arg_names: List[str], arg_specs: List[ArgSpecType]):
        self.method = method
        self.arg_names = arg_names
        self.arg_specs = arg_specs

    def _repr(self, name: str) -> str:
        """Render as ``name(arg0: spec0, arg1: spec1, ...)``."""
        # Distinct loop names keep the `name` label from being shadowed.
        args = ", ".join(
            f"{arg_name}: {arg_spec}"
            for arg_name, arg_spec in zip(self.arg_names, self.arg_specs)
        )
        return f"{name}({args})"

    def __repr__(self) -> str:
        return self._repr(name="MethodSpec")

    @staticmethod
    def from_raw(spec: MethodSpecType, method: Callable) -> "MethodSpec":
        """Create MethodSpec from raw python dictionaries.

        Every parameter reported by ``inspect.signature(method)`` must have an
        entry in ``spec``. Keys that do not name a parameter are ignored.

        An integer argument may be written as ``spec.Int``, ``spec.Int()``, the
        builtin ``int``, or the string ``"int"``. A tensor argument is a
        ``spec.Tensor`` instance.

        Examples
        --------
        .. code-block:: python

            MethodSpec.from_raw(
                spec={
                    "inputs": spec.Tensor([batch_size, "seq_len"], "int32"),
                    "total_seq_len": "int",
                },
                method=module.prefill,
            )

        Raises
        ------
        KeyError
            If an argument of ``method`` has no entry in ``spec``.
        TypeError
            If an argument's spec is not one of the accepted forms.
        """
        if isinstance(spec, MethodSpec):
            return spec
        method_signature = inspect.signature(method)
        arg_names = list(method_signature.parameters.keys())
        arg_specs = []
        for arg_name in arg_names:
            try:
                arg_spec = spec[arg_name]
            except KeyError:
                raise KeyError(f"Missing spec for argument: {arg_name}") from None
            if arg_spec is Int or arg_spec is int or (isinstance(arg_spec, str) and arg_spec == "int"):
                # Normalize every integer spelling to an Int instance. Calling the
                # builtin `int()` here would produce the literal 0, not a spec.
                arg_spec = Int()
            elif not isinstance(arg_spec, (Int, Tensor)):
                raise TypeError(f"Invalid spec for argument {arg_name}: {arg_spec}")
            arg_specs.append(arg_spec)
        return MethodSpec(method, arg_names, arg_specs)

    @staticmethod
    def from_torch(torch_args: List[Any], method: Callable) -> "MethodSpec":
        """Converts a list of torch tensors to MethodSpec."""
        raise NotImplementedError

    def as_inputs(self) -> List[Union[tir.Var, core.Tensor]]:
        """Convert the MethodSpec to a list of inputs to Module's method.

        ``Int`` arguments become int64 ``tir.Var``s. ``Tensor`` arguments become
        tensor placeholders whose string dimensions become int64 ``tir.Var``s.

        Variables are shared by name across all arguments. An ``Int`` argument
        and a symbolic dimension with the same name therefore map to the same
        ``tir.Var``.
        """
        str2var: Dict[str, tir.Var] = {}

        def _get_var(name: str) -> tir.Var:
            if name in str2var:
                return str2var[name]
            var = tir.Var(name, "int64")
            str2var[name] = var
            return var

        args = []
        for arg_name, arg_spec in zip(self.arg_names, self.arg_specs):
            if isinstance(arg_spec, Int):
                arg = _get_var(arg_name)
            elif isinstance(arg_spec, Tensor):
                arg = core._tensor_placeholder(  # pylint: disable=protected-access
                    name=arg_name,
                    shape=[_get_var(x) if isinstance(x, str) else x for x in arg_spec.shape],
                    dtype=arg_spec.dtype,
                )
            else:
                raise TypeError(f"Invalid spec for argument {arg_name}: {arg_spec}")
            args.append(arg)
        return args
+
+
class ModuleSpec:
    """Spec for a compiled nn.Module: the module plus one spec per exported method."""

    module: core.Module
    method_names: List[str]
    method_specs: List[MethodSpecType]

    def __init__(
        self,
        module: core.Module,
        method_names: List[str],
        method_specs: List[MethodSpecType],
    ) -> None:
        self.module = module
        self.method_names = method_names
        self.method_specs = method_specs

    @staticmethod
    def from_raw(spec: ModuleSpecType, module: core.Module) -> "ModuleSpec":
        """Create ModuleSpec from raw python dictionaries.

        Each value that is not already a MethodSpec is converted with
        ``MethodSpec.from_raw`` against the same-named attribute of ``module``.

        Examples
        --------
        .. code-block:: python

            ModuleSpec.from_raw(
                spec={
                    "prefill": {
                        "inputs": spec.Tensor([batch_size, "seq_len"], "int32"),
                        "total_seq_len": int,
                    },
                    "decode": {
                        "inputs": spec.Tensor([batch_size, 1], "int32"),
                        "total_seq_len": int,
                    },
                    "softmax_with_temperature": {
                        "logits": spec.Tensor([1, 1, config.vocab_size], "float32"),
                        "temperature": spec.Tensor([], "float32"),
                    },
                },
                module=module,
            )
        """
        if isinstance(spec, ModuleSpec):
            return spec

        def _to_method_spec(method_name):
            # Already-built MethodSpecs are kept as-is; nothing is looked up on the module.
            raw = spec[method_name]
            if isinstance(raw, MethodSpec):
                return raw
            return MethodSpec.from_raw(raw, getattr(module, method_name))

        names = list(spec.keys())
        return ModuleSpec(module, names, [_to_method_spec(n) for n in names])

    def __repr__(self) -> str:
        entries = [
            " " + method_spec._repr(method_name)  # pylint: disable=protected-access
            for method_name, method_spec in zip(self.method_names, self.method_specs)
        ]
        return "ModuleSpec:\n" + "\n".join(entries)
+
+
class SpecBuilder:
    """Exports an nn.Module to a TVM IRModule according to a ModuleSpec.

    During ``build`` the builder registers itself in thread-local storage, so
    code running inside can reach it through ``SpecBuilder.current()``.
    """

    _tls = threading.local()

    builder: BlockBuilder
    io_effect: core.Effect

    def __init__(self) -> None:
        from .modules import IOEffect  # pylint: disable=import-outside-toplevel

        self.builder = BlockBuilder()
        self.io_effect = IOEffect()

    @staticmethod
    def current() -> "SpecBuilder":
        """Return the SpecBuilder registered by the enclosing ``with`` scope."""
        assert hasattr(SpecBuilder._tls, "current")
        return SpecBuilder._tls.current

    def __enter__(self) -> "SpecBuilder":
        # Scopes do not nest: at most one builder is registered per thread.
        assert not hasattr(SpecBuilder._tls, "current")
        SpecBuilder._tls.current = self
        return self

    def __exit__(self, exc_type, exc, traceback) -> None:
        assert hasattr(SpecBuilder._tls, "current")
        del SpecBuilder._tls.current

    def build(self, spec: ModuleSpec) -> Tuple[IRModule, List[Tuple[str, core.Parameter]]]:
        """Build the ModuleSpec into a TVM IRModule.

        Emits an ``_initialize_effect`` function plus one function per method
        in ``spec``. Returns the IRModule and the module's ``(name, Parameter)``
        pairs.
        """
        # pylint: disable=protected-access
        params = [
            (param_name, param)
            for param_name, param in core._attribute_finder(
                spec.module, prefix="", condition_yield=lambda x: isinstance(x, core.Parameter)
            )
        ]
        # The builder-owned IO effect always comes first, under an empty prefix.
        effects = [("", self.io_effect)] + [
            (effect_name, effect)
            for effect_name, effect in core._attribute_finder(
                spec.module, prefix="", condition_yield=lambda x: isinstance(x, core.Effect)
            )
        ]
        # pylint: enable=protected-access

        bb = self.builder
        with self:
            with bb.function("_initialize_effect"):
                with bb.dataflow():
                    init_outputs = _emit_effect_init(bb, effects)
                bb.emit_func_output(init_outputs, params=[])
            for method_name, method_spec in zip(spec.method_names, spec.method_specs):
                with bb.function(method_name):
                    with bb.dataflow():
                        func_outputs, func_inputs = _emit_method(bb, method_spec, params, effects)
                    bb.emit_func_output(func_outputs, func_inputs)
        return bb.get(), params
+
+
def _emit_effect_init(
    builder: BlockBuilder,
    effects: List[Tuple[str, core.Effect]],
):
    """Emit every effect's initial values and return them as one dataflow output tuple."""
    init_vars = []
    for prefix, effect in effects:
        emitted = effect.emit_init(prefix, builder)
        assert isinstance(emitted, list)
        init_vars += emitted
    packed = builder.emit(rx.Tuple(init_vars))
    return builder.emit_output(packed)
+
+
def _emit_method(
    builder: BlockBuilder,
    spec: MethodSpec,
    params: List[Tuple[str, core.Parameter]],
    effects: List[Tuple[str, core.Effect]],
):
    """Trace one method of the module and emit its body into ``builder``.

    Returns ``(outputs, inputs)``. ``outputs`` is the dataflow output holding
    ``(method result, tuple of effect outputs)``. ``inputs`` lists the function
    parameters in order: explicit arguments, module parameters, then effect
    variables.
    """

    def _unwrap_ret(expr: Any) -> Any:
        # Tensors unwrap to their relax expr; tuples/lists recurse into rx.Tuple.
        if isinstance(expr, core.Tensor):
            return expr._expr  # pylint: disable=protected-access
        if isinstance(expr, (tuple, list)):
            return rx.Tuple([_unwrap_ret(item) for item in expr])
        raise TypeError(f"Unsupported return type: {type(expr)}")

    def _convert_input(arg):
        # Symbolic ints travel as relax vars with a single-value shape struct info.
        if isinstance(arg, tir.Var):
            return rx.Var(arg.name, struct_info=ShapeStructInfo(values=[arg]))
        if isinstance(arg, core.Tensor):
            return arg._expr  # pylint: disable=protected-access
        raise TypeError(f"Unsupported input type: {type(arg)}")

    # pylint: disable=protected-access
    explicit_inputs = spec.as_inputs()
    inputs = [_convert_input(arg) for arg in explicit_inputs]
    for param_name, param in params:
        # Rebind each parameter to a fresh placeholder for this function.
        param._expr = core._tensor_placeholder(param_name, param.shape, param.dtype)._expr
        inputs.append(param._expr)
    for effect_name, effect in effects:
        inputs.extend(effect.create(effect_name))
    # pylint: enable=protected-access

    method_result = spec.method(*explicit_inputs)
    effect_outputs = []
    for _, effect in effects:
        effect_outputs += effect.finalize()
    packed = rx.Tuple([_unwrap_ret(method_result), rx.Tuple(effect_outputs)])
    return builder.emit_output(packed), inputs