This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 889d2f6cef [Unity][Frontend] NNModule `tensor_ir_op` support (#16278)
889d2f6cef is described below

commit 889d2f6cef5a0a533f48e763b626367d9c36ccff
Author: Siyuan Feng <[email protected]>
AuthorDate: Wed Dec 27 05:05:47 2023 +0800

    [Unity][Frontend] NNModule `tensor_ir_op` support (#16278)
    
    This PR adds support for `tensor_ir_op` in NNModule, which enables calling
    TensorIR functions from an NNModule.
    
    This PR also adds `__sub__`, `__rsub__`, and `__rmul__` to `_TensorOp`, and a
    test case for the extern op.
---
 python/tvm/relax/frontend/nn/_tensor_op.py |  12 +++
 python/tvm/relax/frontend/nn/modules.py    |   4 +-
 python/tvm/relax/frontend/nn/op.py         |  76 ++++++++++++++++-
 tests/python/relax/test_frontend_nn_op.py  | 130 +++++++++++++++++++++++++++++
 4 files changed, 219 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/_tensor_op.py b/python/tvm/relax/frontend/nn/_tensor_op.py
index a653c9fa29..627b8b626c 100644
--- a/python/tvm/relax/frontend/nn/_tensor_op.py
+++ b/python/tvm/relax/frontend/nn/_tensor_op.py
@@ -47,10 +47,22 @@ class _TensorOp:
         other = _convert_scalar(other, self)
         return _op().add(self, other)
 
+    def __sub__(self, other):  # self - other
+        other = _convert_scalar(other, self)
+        return _op().subtract(self, other)
+
+    def __rsub__(self, other):  # reflected form: `other - self`, e.g. `1 - x`
+        other = _convert_scalar(other, self)
+        return _op().subtract(other, self)  # operands swapped: `other` is the minuend
+
     def __mul__(self, other):
         other = _convert_scalar(other, self)
         return _op().multiply(self, other)
 
+    def __rmul__(self, other):  # reflected form: `other * self`, e.g. `2 * x`
+        other = _convert_scalar(other, self)
+        return _op().multiply(self, other)  # same order as __mul__; multiply is commutative
+
     def __truediv__(self, other):
         other = _convert_scalar(other, self)
         return _op().divide(self, other)
diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py
index b2c97a567a..03d6a06994 100644
--- a/python/tvm/relax/frontend/nn/modules.py
+++ b/python/tvm/relax/frontend/nn/modules.py
@@ -311,7 +311,7 @@ class ConvTranspose1D(Module):
 
     def forward(self, x: Tensor) -> Tensor:
         """
-        Forward method for convtranspose1d layer.
+        Forward method for the 1D transposed convolution layer.
 
         Parameters
         ----------
@@ -321,7 +321,7 @@ class ConvTranspose1D(Module):
         Returns
         -------
         ret : Tensor
-            The output tensor for the convtranspose1d layer.
+            The output tensor of the 1D transposed convolution layer.
         """
         return op.conv1d_transpose(
             x,
diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index 2369451ac9..3197145289 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -1461,13 +1461,87 @@ def tensor_expr_op(
 OutType = TypeVar("OutType", bound=Union[Tensor, Sequence[Tensor]])


+def tensor_ir_op(
+    func: _tir.PrimFunc,
+    name_hint: str,
+    args: Union[Tensor, Sequence[Union[Tensor, _tir.Var]]],
+    out: OutType,
+) -> OutType:
+    """Create a `call_tir` binding that invokes the given PrimFunc.
+
+    Parameters
+    ----------
+    func : _tir.PrimFunc
+        The PrimFunc to call; it is registered via `BlockBuilder.add_func(func, name_hint)`.
+
+    name_hint : str
+        Name hint passed to `add_func`; also used as the name of the emitted result.
+
+    args : Union[Tensor, Sequence[Union[Tensor, _tir.Var]]]
+        Tensors become `call_tir` inputs, tir.Vars its `tir_vars`; other types raise TypeError.
+
+    out : Union[Tensor, List[Tensor]]
+        Placeholder(s) whose struct info (shape and dtype) describes the output(s).
+
+    Returns
+    -------
+    result : Union[Tensor, List[Tensor]]
+        The output(s) of the `call_tir`, wrapped as Tensor(s).
+    """
+    from tvm import relax as rx  # pylint: disable=import-outside-toplevel
+
+    call_tir_args, tir_vars = [], []
+    if not isinstance(args, (tuple, list)):
+        args = [args]
+
+    for arg in args:
+        if isinstance(arg, Tensor):
+            call_tir_args.append(arg._expr)
+        elif isinstance(arg, _tir.Var):
+            tir_vars.append(arg)
+        else:
+            raise TypeError(
+                f"Unsupported type: tensor_ir_op args expect Tensor or tir.Var, but got {type(arg)}"
+            )
+
+    if isinstance(out, Tensor):
+        out_sinfo = [out._expr.struct_info]
+    else:
+        out_sinfo = [x._expr.struct_info for x in out]
+
+    bb = BlockBuilder.current()
+    global_var = bb.add_func(func, name_hint)
+
+    return wrap_nested(
+        bb.emit(rx.call_tir(global_var, call_tir_args, out_sinfo, tir_vars=tir_vars or None)),
+        name=name_hint,
+    )
+
+
 def extern(
     name: str,
     args: Sequence[Union[Tensor, _tir.PrimExpr, int, float, str]],
     out: OutType,
 ) -> OutType:
     """Invoke an extern function during runtime. The extern function must be registered with the "
-    TVM runtime using `TVM_REGISTER_GLOBAL` (C++), or `tvm.register_func` (Python)."""
+    TVM runtime using `TVM_REGISTER_GLOBAL` (C++), or `tvm.register_func` (Python).
+
+    Parameters
+    ----------
+    name : str
+        The name of the extern function to call.
+
+    args : Sequence[Union[Tensor, _tir.PrimExpr, int, float, str]]
+        The arguments to pass to the extern function.
+
+    out : Union[Tensor, List[Tensor]]
+        Placeholder(s) describing the output tensor(s) of the extern function.
+
+    Returns
+    -------
+    result : Union[Tensor, List[Tensor]]
+        The result of the extern call, of the same type as `out`.
+    """
     from tvm import relax as rx  # pylint: disable=import-outside-toplevel

     def _convert(arg, name: str):
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index ddaec7234b..55870426e4 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=missing-docstring, invalid-name
 import tvm
 import tvm.testing
 from tvm import tir
@@ -508,5 +509,134 @@ def test_tensor_expr_op():
     tvm.ir.assert_structural_equal(irmodule, Expected)


+def test_tensor_ir_op():
+    num_q_heads, num_kv_heads, head_dim = 8, 8, 16
+    fused_heads = num_q_heads + num_kv_heads * 2
+    dtype = "float16"
+
+    @T.prim_func(private=True)
+    def fused_rope(  # pylint: disable=too-many-locals
+        var_qkv: T.handle,
+        offset: T.int64,  # NOTE(review): call_tir likely passes tir_vars after the outputs
+        var_q: T.handle,
+        var_k: T.handle,
+        var_v: T.handle,
+    ):
+        batch_size = T.int64()
+        seq_len = T.int64()
+        qkv = T.match_buffer(var_qkv, (batch_size, seq_len, fused_heads, head_dim), dtype)
+        q = T.match_buffer(var_q, (batch_size, seq_len, num_q_heads, head_dim), dtype)
+        k = T.match_buffer(var_k, (batch_size, seq_len, num_kv_heads, head_dim), dtype)
+        v = T.match_buffer(var_v, (batch_size, seq_len, num_kv_heads, head_dim), dtype)
+        T.evaluate(offset)
+
+    class Model(Module):
+        def test(self, qkv: Tensor, offset: tir.Var):
+            tensor_ir_op_out = op.tensor_ir_op(
+                fused_rope,
+                "llama_fused_rope",
+                args=[qkv, offset],
+                out=[
+                    Tensor.placeholder((1, 1, num_q_heads, head_dim), dtype),
+                    Tensor.placeholder((1, 1, num_kv_heads, head_dim), dtype),
+                    Tensor.placeholder((1, 1, num_kv_heads, head_dim), dtype),
+                ],
+            )
+            return tensor_ir_op_out
+
+    # fmt: off
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def llama_fused_rope(var_qkv: T.handle, offset: T.int64, var_q: T.handle, var_k: T.handle, var_v: T.handle):
+            batch_size, seq_len = T.int64(), T.int64()
+            qkv = T.match_buffer(var_qkv, (batch_size, seq_len, 24, 16), "float16")
+            q = T.match_buffer(var_q, (batch_size, seq_len, 8, 16), "float16")
+            k = T.match_buffer(var_k, (batch_size, seq_len, 8, 16), "float16")
+            v = T.match_buffer(var_v, (batch_size, seq_len, 8, 16), "float16")
+            T.evaluate(offset)
+
+        @R.function
+        def _initialize_effect() -> R.Tuple(R.Object):
+            with R.dataflow():
+                _io: R.Object = R.null_value()
+                lv: R.Tuple(R.Object) = (_io,)
+                gv: R.Tuple(R.Object) = lv
+                R.output(gv)
+            return gv
+
+        @R.function
+        def test(qkv: R.Tensor((1, 1, 24, 16), dtype="float16"), offset: R.Shape(["offset_1"]), _io: R.Object) -> R.Tuple(R.Tuple(R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16")), R.Tuple(R.Object)):
+            offset_1 = T.int64()
+            R.func_attr({"num_input": 3})
+            cls = Expected
+            with R.dataflow():
+                lv1 = R.call_tir(cls.llama_fused_rope, (qkv,), out_sinfo=[R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16")], tir_vars=R.shape([offset_1]))
+                llama_fused_rope_0: R.Tensor((1, 1, 8, 16), dtype="float16") = lv1[0]
+                llama_fused_rope_1: R.Tensor((1, 1, 8, 16), dtype="float16") = lv1[1]
+                llama_fused_rope_2: R.Tensor((1, 1, 8, 16), dtype="float16") = lv1[2]
+                gv1: R.Tuple(R.Tuple(R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16")), R.Tuple(R.Object)) = (llama_fused_rope_0, llama_fused_rope_1, llama_fused_rope_2), (_io,)
+                R.output(gv1)
+            return gv1
+    # fmt: on
+
+    m = Model()
+    irmodule, _ = m.export_tvm(
+        spec={
+            "test": {"qkv": spec.Tensor([1, 1, fused_heads, head_dim], "float16"), "offset": int}
+        },
+        debug=True,
+    )
+    tvm.ir.assert_structural_equal(irmodule, Expected)
+
+
+def test_extern():
+    class Model(Module):
+        def test(self, q: Tensor, k: Tensor, v: Tensor):
+            b, s, h_q, d = q.shape
+            extern_op_out = op.extern(
+                name="flashinfer.single_decode",
+                args=[q, k, v, 0, 0, 1.0, 10000.0],
+                out=Tensor.placeholder((b, s, h_q * d), dtype="float16"),
+            )
+            return extern_op_out
+
+    # fmt: off
+    @I.ir_module
+    class Expected:
+        @R.function
+        def _initialize_effect() -> R.Tuple(R.Object):
+            with R.dataflow():
+                _io: R.Object = R.null_value()
+                lv: R.Tuple(R.Object) = (_io,)
+                gv: R.Tuple(R.Object) = lv
+                R.output(gv)
+            return gv
+
+        @R.function
+        def test(q: R.Tensor((1, 1, 16, 8), dtype="float32"), k: R.Tensor((64, 16, 8), dtype="float32"), v: R.Tensor((64, 16, 8), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((1, 1, 128), dtype="float16"), R.Tuple(R.Object)):
+            R.func_attr({"num_input": 4})
+            with R.dataflow():
+                flashinfer_single_decode = R.call_dps_packed("flashinfer.single_decode", (q, k, v, R.prim_value(0), R.prim_value(0), R.prim_value(T.float64(1)), R.prim_value(T.float64(10000))), out_sinfo=R.Tensor((1, 1, 128), dtype="float16"))
+                gv1: R.Tuple(R.Tensor((1, 1, 128), dtype="float16"), R.Tuple(R.Object)) = flashinfer_single_decode, (_io,)
+                R.output(gv1)
+            return gv1
+    # fmt: on
+
+    batch, seq, t, d, h_q, h_kv = 1, 1, 64, 8, 16, 16
+    m = Model()
+    irmodule, _ = m.export_tvm(
+        spec={
+            "test": {
+                "q": spec.Tensor([batch, seq, h_q, d], "float32"),
+                "k": spec.Tensor([t, h_kv, d], "float32"),
+                "v": spec.Tensor([t, h_kv, d], "float32"),
+            }
+        },
+        debug=True,
+    )
+    tvm.ir.assert_structural_equal(irmodule, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to