This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 9030522960 [Unity][Frontend] Introducing Object (#16316)
9030522960 is described below
commit 90305229604b0ca4cce34bd6de5b6b21925b55d4
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Dec 31 18:33:10 2023 -0500
[Unity][Frontend] Introducing Object (#16316)
This PR adds `Object` as a new spec/frontend type in
nn.Module, so that non-tensor opaque objects (such as
PagedKVCache) can be represented directly.
---
python/tvm/relax/frontend/nn/__init__.py | 2 +-
python/tvm/relax/frontend/nn/core.py | 23 ++++++++++++++++++++++-
python/tvm/relax/frontend/nn/exporter.py | 8 +++++---
python/tvm/relax/frontend/nn/spec.py | 16 ++++++++++++++--
4 files changed, 42 insertions(+), 7 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/__init__.py b/python/tvm/relax/frontend/nn/__init__.py
index 5723e3d9ff..61d1001ea8 100644
--- a/python/tvm/relax/frontend/nn/__init__.py
+++ b/python/tvm/relax/frontend/nn/__init__.py
@@ -17,7 +17,7 @@
"""A PyTorch-like API to build IRModules."""
# pylint: disable=redefined-builtin
from . import op, spec
-from .core import Effect, Module, ModuleList, Parameter, Tensor
+from .core import Effect, Module, ModuleList, Object, Parameter, Tensor
from .exporter import add_extern
from .extern import ExternModule, ObjectModule, SourceModule
from .modules import (
diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py
index 8ed0efe2cd..9c99ba6177 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -50,7 +50,12 @@ from tvm.target import Target
from ... import expr as rx
from ...block_builder import BlockBuilder
-from ...struct_info import ShapeStructInfo, TensorStructInfo, TupleStructInfo
+from ...struct_info import (
+ ObjectStructInfo,
+ ShapeStructInfo,
+ TensorStructInfo,
+ TupleStructInfo,
+)
from ._tensor_op import _TensorOp
from .subroutine import SubroutineMixin
@@ -274,6 +279,22 @@ class Parameter(Tensor):
)._expr
+class Object:
+ """A wrapper on top of relax.Expr whose struct_info is the base
+ ObjectStructInfo (rather than any of its subclasses). Object effectively
+ represents non-tensor frontend components such as KV caches.
+ """
+
+ _expr: rx.Var # always a Var after __init__: non-Var exprs are emitted there
+
+ def __init__(self, *, _expr: rx.Expr, _name: str) -> None:
+ """Private constructor. Object is never supposed to be constructed directly by users."""
+ if not isinstance(_expr, rx.Var): # bind non-Var exprs in the current block, named _name
+ _expr = BlockBuilder.current().emit(_expr, _name)
+ self._expr = _expr
+ assert isinstance(self._expr.struct_info, ObjectStructInfo) # NOTE(review): skipped under python -O; confirm acceptable
+
+
class Effect:
"""Effect is a special non-user facing type that is used to represent
operations with side
effects, for example, print. It is used to represent the output of a
computation.
diff --git a/python/tvm/relax/frontend/nn/exporter.py b/python/tvm/relax/frontend/nn/exporter.py
index 416913def4..99591c8a3e 100644
--- a/python/tvm/relax/frontend/nn/exporter.py
+++ b/python/tvm/relax/frontend/nn/exporter.py
@@ -23,7 +23,7 @@ from tvm.ir import IRModule
from ... import expr as rx
from ...block_builder import BlockBuilder
-from ...struct_info import ShapeStructInfo, TupleStructInfo
+from ...struct_info import ObjectStructInfo, ShapeStructInfo, TupleStructInfo
from . import core, extern
from . import spec as _spec
from .modules import IOEffect
@@ -160,7 +160,7 @@ def _emit_method( # pylint: disable=too-many-locals,too-many-branches,too-many-
):
# pylint: disable=protected-access
def _unwrap_ret(expr: typing.Any) -> typing.Any:
- if isinstance(expr, core.Tensor):
+ if isinstance(expr, (core.Tensor, core.Object)):
return expr._expr
if isinstance(expr, tuple):
return rx.Tuple([_unwrap_ret(x) for x in expr])
@@ -171,7 +171,7 @@ def _emit_method( # pylint: disable=too-many-locals,too-many-branches,too-many-
def _convert_input(arg):
if isinstance(arg, tir.Var):
return rx.Var(arg.name, struct_info=ShapeStructInfo(values=[arg]))
- if isinstance(arg, core.Tensor):
+ if isinstance(arg, (core.Tensor, core.Object)):
return arg._expr # pylint: disable=protected-access
if isinstance(arg, _spec.Tuple):
return rx.Var(
@@ -292,6 +292,8 @@ def _method_spec_to_inputs(
dtype=arg_spec.dtype,
name=arg_name,
)
+ elif isinstance(arg_spec, _spec.Object):
+ arg = arg_spec.object_type(_expr=rx.Var(arg_name,
ObjectStructInfo()), _name=arg_name)
elif isinstance(arg_spec, _spec.Tuple):
elements = type(arg_spec.elements)(
[
diff --git a/python/tvm/relax/frontend/nn/spec.py b/python/tvm/relax/frontend/nn/spec.py
index 210b16ce01..54928ce07b 100644
--- a/python/tvm/relax/frontend/nn/spec.py
+++ b/python/tvm/relax/frontend/nn/spec.py
@@ -24,7 +24,7 @@ if typing.TYPE_CHECKING:
ArgSpecType = typing.Union["Int", "Tensor"]
MethodSpecType = typing.Union["MethodSpec", typing.Dict[str, ArgSpecType]]
ModuleSpecType = typing.Union["ModuleSpec", typing.Dict[str, MethodSpecType]]
-SpecAny = typing.Union["Int", "Tensor", "Tuple"]
+SpecAny = typing.Union["Object", "Int", "Tensor", "Tuple"]
class Int: # pylint: disable=too-few-public-methods
@@ -52,6 +52,18 @@ class Tensor: # pylint: disable=too-few-public-methods
return f"Tensor([{shape}], '{self.dtype}')"
+class Object: # pylint: disable=too-few-public-methods
+ """A non-tensor opaque frontend object."""
+
+ object_type: typing.Type # class the exporter instantiates (with _expr/_name kwargs) for this arg
+
+ def __init__(self, object_type: typing.Type) -> None: # presumably an nn.Object subclass; verify
+ self.object_type = object_type
+
+ def __repr__(self) -> str:
+ return "object" # NOTE(review): omits object_type; consider including it for debuggability
+
+
class Tuple: # pylint: disable=too-few-public-methods
"""A tuple input or a list input"""
@@ -141,7 +153,7 @@ class MethodSpec:
return Int()
if isinstance(arg_spec, str) and arg_spec == "int":
return Int()
- if isinstance(arg_spec, (Int, Tensor)):
+ if isinstance(arg_spec, (Int, Tensor, Object)):
return arg_spec
if isinstance(arg_spec, (tuple, list, Tuple)):
return Tuple(