junrushao commented on code in PR #16278:
URL: https://github.com/apache/tvm/pull/16278#discussion_r1436523719
##########
python/tvm/relax/frontend/nn/op.py:
##########
@@ -1461,13 +1461,87 @@ def _convert(arg):
OutType = TypeVar("OutType", bound=Union[Tensor, Sequence[Tensor]])
+def tensor_ir_op(
+ func: _tir.PrimFunc,
+ name_hint: str,
Review Comment:
There’s a bit of a complication here: if the provided PrimFunc is a public
function (i.e., it has a “global_symbol” field in its attrs), Relax is not
allowed to rename it, and in that case the string is not a name hint but the
name itself. Therefore, we will have to check for symbol duplication and throw
an error if a collision occurs.
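To make the rule concrete, here is a minimal, framework-free sketch of the naming logic described above (the helper `resolve_symbol` and its arguments are made up for illustration; this is not TVM’s actual API): a public function’s name is fixed and collisions are errors, while a private function’s name is only a hint and can be deduplicated.

```python
def resolve_symbol(existing: set, attrs: dict, name: str) -> str:
    """Sketch of the naming rule: fixed name for public funcs, hint otherwise."""
    if "global_symbol" in attrs:
        # Public function: Relax may not rename it, so a collision is an error.
        if name in existing:
            raise ValueError(f"Duplicate global symbol: {name}")
        return name
    # Private function: the string is only a hint; dedup with a numeric suffix.
    candidate, i = name, 0
    while candidate in existing:
        i += 1
        candidate = f"{name}{i}"
    return candidate
```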
##########
python/tvm/relax/frontend/nn/op.py:
##########
@@ -1461,13 +1461,87 @@ def _convert(arg):
OutType = TypeVar("OutType", bound=Union[Tensor, Sequence[Tensor]])
+def tensor_ir_op(
+ func: _tir.PrimFunc,
+ name_hint: str,
Review Comment:
We could probably leave this logic to future work, but let’s rename
name_hint to name to better reflect this point.
##########
tests/python/relax/test_frontend_nn_op.py:
##########
@@ -508,5 +509,134 @@ def test(x: R.Tensor((10, 10), dtype="float32"), _io: R.Object) -> R.Tuple(R.Ten
tvm.ir.assert_structural_equal(irmodule, Expected)
+def test_tensor_ir_op():
+ num_q_heads, num_kv_heads, head_dim = 8, 8, 16
Review Comment:
This unittest is a bit more complicated than I expected :)) In the simplest
case, we could probably just supply a “B = A + 1”-style TIR function.
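For illustration, a “B = A + 1”-style PrimFunc written in TVMScript might look like the sketch below (assuming TVM is importable; the name `add_one`, the shape `(10,)`, and the block name `"B"` are made-up details, not taken from the PR):

```python
import tvm
from tvm.script import tir as T

@T.prim_func
def add_one(A: T.Buffer((10,), "float32"), B: T.Buffer((10,), "float32")):
    # Simplest possible TIR: elementwise B = A + 1.
    for i in range(10):
        with T.block("B"):
            vi = T.axis.spatial(10, i)
            B[vi] = A[vi] + T.float32(1)
```

Such a one-op PrimFunc would keep the expected IRModule in the unittest short and easy to read.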
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]