This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 889d2f6cef [Unity][Frontend] NNModule `tensor_ir_op` support (#16278)
889d2f6cef is described below
commit 889d2f6cef5a0a533f48e763b626367d9c36ccff
Author: Siyuan Feng <[email protected]>
AuthorDate: Wed Dec 27 05:05:47 2023 +0800
[Unity][Frontend] NNModule `tensor_ir_op` support (#16278)
This PR adds support for `tensor_ir_op` in NNModule, which enables us to
call TensorIR functions from NNModule.
This PR also adds a test case for the extern op.
---
python/tvm/relax/frontend/nn/_tensor_op.py | 12 +++
python/tvm/relax/frontend/nn/modules.py | 4 +-
python/tvm/relax/frontend/nn/op.py | 76 ++++++++++++++++-
tests/python/relax/test_frontend_nn_op.py | 130 +++++++++++++++++++++++++++++
4 files changed, 219 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/_tensor_op.py
b/python/tvm/relax/frontend/nn/_tensor_op.py
index a653c9fa29..627b8b626c 100644
--- a/python/tvm/relax/frontend/nn/_tensor_op.py
+++ b/python/tvm/relax/frontend/nn/_tensor_op.py
@@ -47,10 +47,22 @@ class _TensorOp:
other = _convert_scalar(other, self)
return _op().add(self, other)
+ def __sub__(self, other):
+ other = _convert_scalar(other, self)
+ return _op().subtract(self, other)
+
+ def __rsub__(self, other):
+ other = _convert_scalar(other, self)
+ return _op().subtract(other, self)
+
def __mul__(self, other):
other = _convert_scalar(other, self)
return _op().multiply(self, other)
+ def __rmul__(self, other):
+ other = _convert_scalar(other, self)
+ return _op().multiply(self, other)
+
def __truediv__(self, other):
other = _convert_scalar(other, self)
return _op().divide(self, other)
diff --git a/python/tvm/relax/frontend/nn/modules.py
b/python/tvm/relax/frontend/nn/modules.py
index b2c97a567a..03d6a06994 100644
--- a/python/tvm/relax/frontend/nn/modules.py
+++ b/python/tvm/relax/frontend/nn/modules.py
@@ -311,7 +311,7 @@ class ConvTranspose1D(Module):
def forward(self, x: Tensor) -> Tensor:
"""
- Forward method for convtranspose1d layer.
+ Forward method for conv transpose 1d layer.
Parameters
----------
@@ -321,7 +321,7 @@ class ConvTranspose1D(Module):
Returns
-------
ret : Tensor
- The output tensor for the convtranspose1d layer.
+ The output tensor for the conv transpose 1d layer.
"""
return op.conv1d_transpose(
x,
diff --git a/python/tvm/relax/frontend/nn/op.py
b/python/tvm/relax/frontend/nn/op.py
index 2369451ac9..3197145289 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -1461,13 +1461,87 @@ def tensor_expr_op(
OutType = TypeVar("OutType", bound=Union[Tensor, Sequence[Tensor]])
+def tensor_ir_op(
+ func: _tir.PrimFunc,
+ name_hint: str,
+ args: Union[Tensor, Sequence[Union[Tensor, _tir.Var]]],
+ out: OutType,
+) -> OutType:
+ """Create a `call_tir` binding with given PrimFunc
+
+ Parameters
+ ----------
+ func : _tir.PrimFunc
+ The PrimFunc to call.
+
+ name_hint : str
+ Name hint.
+
+ args : Union[Tensor, Sequence[Union[Tensor, _tir.Var]]]
+ The arguments to pass to the PrimFunc.
+
+ out : Union[Tensor, List[Tensor]]
+ The output tensors.
+
+ Returns
+ -------
+ result : Union[Tensor, List[Tensor]]
+ The result tensor or tensors, matching the structure of `out`.
+ """
+ from tvm import relax as rx # pylint: disable=import-outside-toplevel
+
+ call_tir_args, tir_vars = [], []
+ if not isinstance(args, (tuple, list)):
+ args = [args]
+
+ for arg in args:
+ if isinstance(arg, Tensor):
+ call_tir_args.append(arg._expr)
+ elif isinstance(arg, _tir.Var):
+ tir_vars.append(arg)
+ else:
+ raise TypeError(
+ f"Unsupported type: tensor_ir_op args expect Tensor or
tir.Var, but got {type(arg)}"
+ )
+
+ if isinstance(out, Tensor):
+ out_sinfo = [out._expr.struct_info]
+ else:
+ out_sinfo = [x._expr.struct_info for x in out]
+
+ bb = BlockBuilder.current()
+ global_var = bb.add_func(func, name_hint)
+
+ return wrap_nested(
+ bb.emit(rx.call_tir(global_var, call_tir_args, out_sinfo,
tir_vars=tir_vars)),
+ name=name_hint,
+ )
+
+
def extern(
name: str,
args: Sequence[Union[Tensor, _tir.PrimExpr, int, float, str]],
out: OutType,
) -> OutType:
"""Invoke an extern function during runtime. The extern function must be
registered with the "
- TVM runtime using `TVM_REGISTER_GLOBAL` (C++), or `tvm.register_func`
(Python)."""
+ TVM runtime using `TVM_REGISTER_GLOBAL` (C++), or `tvm.register_func`
(Python).
+
+ Parameters
+ ----------
+ name : str
+ The name of the extern function to call.
+
+ args : Sequence[Union[Tensor, _tir.PrimExpr, int, float, str]]
+ The arguments to pass to the extern function.
+
+ out : Union[Tensor, List[Tensor]]
+ The output tensors.
+
+ Returns
+ -------
+ result : Tensor
+ The result tensor.
+ """
from tvm import relax as rx # pylint: disable=import-outside-toplevel
def _convert(arg, name: str):
diff --git a/tests/python/relax/test_frontend_nn_op.py
b/tests/python/relax/test_frontend_nn_op.py
index ddaec7234b..55870426e4 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+# pylint: disable=missing-docstring, invalid-name
import tvm
import tvm.testing
from tvm import tir
@@ -508,5 +509,134 @@ def test_tensor_expr_op():
tvm.ir.assert_structural_equal(irmodule, Expected)
+def test_tensor_ir_op():
+ num_q_heads, num_kv_heads, head_dim = 8, 8, 16
+ fused_heads = num_q_heads + num_kv_heads * 2
+ dtype = "float16"
+
+ @T.prim_func(private=True)
+ def fused_rope( # pylint: disable=too-many-locals
+ var_qkv: T.handle,
+ offset: T.int64,
+ var_q: T.handle,
+ var_k: T.handle,
+ var_v: T.handle,
+ ):
+ batch_size = T.int64()
+ seq_len = T.int64()
+ qkv = T.match_buffer(var_qkv, (batch_size, seq_len, fused_heads,
head_dim), dtype)
+ q = T.match_buffer(var_q, (batch_size, seq_len, num_q_heads,
head_dim), dtype)
+ k = T.match_buffer(var_k, (batch_size, seq_len, num_kv_heads,
head_dim), dtype)
+ v = T.match_buffer(var_v, (batch_size, seq_len, num_kv_heads,
head_dim), dtype)
+ T.evaluate(offset)
+
+ class Model(Module):
+ def test(self, qkv: Tensor, offset: tir.Var):
+ tensor_expr_op_out = op.tensor_ir_op(
+ fused_rope,
+ "llama_fused_rope",
+ args=[qkv, offset],
+ out=[
+ Tensor.placeholder((1, 1, num_q_heads, head_dim), dtype),
+ Tensor.placeholder((1, 1, num_kv_heads, head_dim), dtype),
+ Tensor.placeholder((1, 1, num_kv_heads, head_dim), dtype),
+ ],
+ )
+ return tensor_expr_op_out
+
+ # fmt: off
+ @I.ir_module
+ class Expected:
+ @T.prim_func(private=True)
+ def llama_fused_rope(var_qkv: T.handle, offset: T.int64, var_q:
T.handle, var_k: T.handle, var_v: T.handle):
+ batch_size, seq_len = T.int64(), T.int64()
+ qkv = T.match_buffer(var_qkv, (batch_size, seq_len, 24, 16),
"float16")
+ q = T.match_buffer(var_q, (batch_size, seq_len, 8, 16), "float16")
+ k = T.match_buffer(var_k, (batch_size, seq_len, 8, 16), "float16")
+ v = T.match_buffer(var_v, (batch_size, seq_len, 8, 16), "float16")
+ T.evaluate(offset)
+
+ @R.function
+ def _initialize_effect() -> R.Tuple(R.Object):
+ with R.dataflow():
+ _io: R.Object = R.null_value()
+ lv: R.Tuple(R.Object) = (_io,)
+ gv: R.Tuple(R.Object) = lv
+ R.output(gv)
+ return gv
+
+ @R.function
+ def test(qkv: R.Tensor((1, 1, 24, 16), dtype="float16"), offset:
R.Shape(["offset_1"]), _io: R.Object) -> R.Tuple(R.Tuple(R.Tensor((1, 1, 8,
16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1,
1, 8, 16), dtype="float16")), R.Tuple(R.Object)):
+ offset_1 = T.int64()
+ R.func_attr({"num_input": 3})
+ cls = Expected
+ with R.dataflow():
+ lv1 = R.call_tir(cls.llama_fused_rope, (qkv,),
out_sinfo=[R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16),
dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16")],
tir_vars=R.shape([offset_1]))
+ llama_fused_rope_0: R.Tensor((1, 1, 8, 16), dtype="float16") =
lv1[0]
+ llama_fused_rope_1: R.Tensor((1, 1, 8, 16), dtype="float16") =
lv1[1]
+ llama_fused_rope_2: R.Tensor((1, 1, 8, 16), dtype="float16") =
lv1[2]
+ gv1: R.Tuple(R.Tuple(R.Tensor((1, 1, 8, 16), dtype="float16"),
R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16),
dtype="float16")), R.Tuple(R.Object)) = (llama_fused_rope_0,
llama_fused_rope_1, llama_fused_rope_2), (_io,)
+ R.output(gv1)
+ return gv1
+ # fmt: on
+
+ m = Model()
+ irmodule, _ = m.export_tvm(
+ spec={
+ "test": {"qkv": spec.Tensor([1, 1, fused_heads, head_dim],
"float16"), "offset": int}
+ },
+ debug=True,
+ )
+ tvm.ir.assert_structural_equal(irmodule, Expected)
+
+
+def test_extern():
+ class Model(Module):
+ def test(self, q: Tensor, k: Tensor, v: Tensor):
+ b, s, h_q, d = q.shape
+ tensor_expr_op_out = op.extern(
+ name="flashinfer.single_decode",
+ args=[q, k, v, 0, 0, 1.0, 10000.0],
+ out=Tensor.placeholder((b, s, h_q * d), dtype="float16"),
+ )
+ return tensor_expr_op_out
+
+ # fmt: off
+ @I.ir_module
+ class Expected:
+ @R.function
+ def _initialize_effect() -> R.Tuple(R.Object):
+ with R.dataflow():
+ _io: R.Object = R.null_value()
+ lv: R.Tuple(R.Object) = (_io,)
+ gv: R.Tuple(R.Object) = lv
+ R.output(gv)
+ return gv
+
+ @R.function
+ def test(q: R.Tensor((1, 1, 16, 8), dtype="float32"), k: R.Tensor((64,
16, 8), dtype="float32"), v: R.Tensor((64, 16, 8), dtype="float32"), _io:
R.Object) -> R.Tuple(R.Tensor((1, 1, 128), dtype="float16"), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 4})
+ with R.dataflow():
+ flashinfer_single_decode =
R.call_dps_packed("flashinfer.single_decode", (q, k, v, R.prim_value(0),
R.prim_value(0), R.prim_value(T.float64(1)), R.prim_value(T.float64(10000))),
out_sinfo=R.Tensor((1, 1, 128), dtype="float16"))
+ gv1: R.Tuple(R.Tensor((1, 1, 128), dtype="float16"),
R.Tuple(R.Object)) = flashinfer_single_decode, (_io,)
+ R.output(gv1)
+ return gv1
+ # fmt: on
+
+ batch, seq, t, d, h_q, h_kv = 1, 1, 64, 8, 16, 16
+ m = Model()
+ irmodule, _ = m.export_tvm(
+ spec={
+ "test": {
+ "q": spec.Tensor([batch, seq, h_q, d], "float32"),
+ "k": spec.Tensor([t, h_kv, d], "float32"),
+ "v": spec.Tensor([t, h_kv, d], "float32"),
+ }
+ },
+ debug=True,
+ )
+ tvm.ir.assert_structural_equal(irmodule, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()