This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 9030522960 [Unity][Frontend] Introducing Object (#16316)
9030522960 is described below

commit 90305229604b0ca4cce34bd6de5b6b21925b55d4
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Dec 31 18:33:10 2023 -0500

    [Unity][Frontend] Introducing Object (#16316)
    
    This PR supports `Object` as a new spec/frontend type in
    nn.Module, so that non-tensor opaque objects (such as
    PagedKVCache) can be effectively represented.
---
 python/tvm/relax/frontend/nn/__init__.py |  2 +-
 python/tvm/relax/frontend/nn/core.py     | 23 ++++++++++++++++++++++-
 python/tvm/relax/frontend/nn/exporter.py |  8 +++++---
 python/tvm/relax/frontend/nn/spec.py     | 16 ++++++++++++++--
 4 files changed, 42 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/__init__.py b/python/tvm/relax/frontend/nn/__init__.py
index 5723e3d9ff..61d1001ea8 100644
--- a/python/tvm/relax/frontend/nn/__init__.py
+++ b/python/tvm/relax/frontend/nn/__init__.py
@@ -17,7 +17,7 @@
 """A PyTorch-like API to build IRModules."""
 # pylint: disable=redefined-builtin
 from . import op, spec
-from .core import Effect, Module, ModuleList, Parameter, Tensor
+from .core import Effect, Module, ModuleList, Object, Parameter, Tensor
 from .exporter import add_extern
 from .extern import ExternModule, ObjectModule, SourceModule
 from .modules import (
diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py
index 8ed0efe2cd..9c99ba6177 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -50,7 +50,12 @@ from tvm.target import Target
 
 from ... import expr as rx
 from ...block_builder import BlockBuilder
-from ...struct_info import ShapeStructInfo, TensorStructInfo, TupleStructInfo
+from ...struct_info import (
+    ObjectStructInfo,
+    ShapeStructInfo,
+    TensorStructInfo,
+    TupleStructInfo,
+)
 from ._tensor_op import _TensorOp
 from .subroutine import SubroutineMixin
 
@@ -274,6 +279,22 @@ class Parameter(Tensor):
             )._expr
 
 
+class Object:
+    """A wrapper on top of relax.Expr whose struct_info is the base
+    ObjectStructInfo (rather than any of its subclasses). Object effectively
+    represents non-tensor frontend components such as KV caches.
+    """
+
+    _expr: rx.Var
+
+    def __init__(self, *, _expr: rx.Expr, _name: str) -> None:
+        """Private constructor. Object is never supposed to be constructed directly by users."""
+        if not isinstance(_expr, rx.Var):
+            _expr = BlockBuilder.current().emit(_expr, _name)
+        self._expr = _expr
+        assert isinstance(self._expr.struct_info, ObjectStructInfo)
+
+
 class Effect:
     """Effect is a special non-user facing type that is used to represent operations with side
     effects, for example, print. It is used to represent the output of a computation.
diff --git a/python/tvm/relax/frontend/nn/exporter.py b/python/tvm/relax/frontend/nn/exporter.py
index 416913def4..99591c8a3e 100644
--- a/python/tvm/relax/frontend/nn/exporter.py
+++ b/python/tvm/relax/frontend/nn/exporter.py
@@ -23,7 +23,7 @@ from tvm.ir import IRModule
 
 from ... import expr as rx
 from ...block_builder import BlockBuilder
-from ...struct_info import ShapeStructInfo, TupleStructInfo
+from ...struct_info import ObjectStructInfo, ShapeStructInfo, TupleStructInfo
 from . import core, extern
 from . import spec as _spec
 from .modules import IOEffect
@@ -160,7 +160,7 @@ def _emit_method(  # pylint: disable=too-many-locals,too-many-branches,too-many-
 ):
     # pylint: disable=protected-access
     def _unwrap_ret(expr: typing.Any) -> typing.Any:
-        if isinstance(expr, core.Tensor):
+        if isinstance(expr, (core.Tensor, core.Object)):
             return expr._expr
         if isinstance(expr, tuple):
             return rx.Tuple([_unwrap_ret(x) for x in expr])
@@ -171,7 +171,7 @@ def _emit_method(  # pylint: disable=too-many-locals,too-many-branches,too-many-
     def _convert_input(arg):
         if isinstance(arg, tir.Var):
             return rx.Var(arg.name, struct_info=ShapeStructInfo(values=[arg]))
-        if isinstance(arg, core.Tensor):
+        if isinstance(arg, (core.Tensor, core.Object)):
             return arg._expr  # pylint: disable=protected-access
         if isinstance(arg, _spec.Tuple):
             return rx.Var(
@@ -292,6 +292,8 @@ def _method_spec_to_inputs(
                 dtype=arg_spec.dtype,
                 name=arg_name,
             )
+        elif isinstance(arg_spec, _spec.Object):
+            arg = arg_spec.object_type(_expr=rx.Var(arg_name, ObjectStructInfo()), _name=arg_name)
         elif isinstance(arg_spec, _spec.Tuple):
             elements = type(arg_spec.elements)(
                 [
diff --git a/python/tvm/relax/frontend/nn/spec.py b/python/tvm/relax/frontend/nn/spec.py
index 210b16ce01..54928ce07b 100644
--- a/python/tvm/relax/frontend/nn/spec.py
+++ b/python/tvm/relax/frontend/nn/spec.py
@@ -24,7 +24,7 @@ if typing.TYPE_CHECKING:
 ArgSpecType = typing.Union["Int", "Tensor"]
 MethodSpecType = typing.Union["MethodSpec", typing.Dict[str, ArgSpecType]]
 ModuleSpecType = typing.Union["ModuleSpec", typing.Dict[str, MethodSpecType]]
-SpecAny = typing.Union["Int", "Tensor", "Tuple"]
+SpecAny = typing.Union["Object", "Int", "Tensor", "Tuple"]
 
 
 class Int:  # pylint: disable=too-few-public-methods
@@ -52,6 +52,18 @@ class Tensor:  # pylint: disable=too-few-public-methods
         return f"Tensor([{shape}], '{self.dtype}')"
 
 
+class Object:  # pylint: disable=too-few-public-methods
+    """A non-tensor opaque frontend object."""
+
+    object_type: typing.Type
+
+    def __init__(self, object_type: typing.Type) -> None:
+        self.object_type = object_type
+
+    def __repr__(self) -> str:
+        return "object"
+
+
 class Tuple:  # pylint: disable=too-few-public-methods
     """A tuple input or a list input"""
 
@@ -141,7 +153,7 @@ class MethodSpec:
                 return Int()
             if isinstance(arg_spec, str) and arg_spec == "int":
                 return Int()
-            if isinstance(arg_spec, (Int, Tensor)):
+            if isinstance(arg_spec, (Int, Tensor, Object)):
                 return arg_spec
             if isinstance(arg_spec, (tuple, list, Tuple)):
                 return Tuple(

Reply via email to