This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 4a7e4fec37 [Unity] Fix nn.op.tensor_ir_op signature (#16333)
4a7e4fec37 is described below
commit 4a7e4fec376d9dbc3127dde19def2eb356228d8e
Author: Hongyi Jin <[email protected]>
AuthorDate: Wed Jan 3 02:00:18 2024 -0500
[Unity] Fix nn.op.tensor_ir_op signature (#16333)
* fix nn tensorir
* fix lint
---
python/tvm/relax/frontend/nn/op.py | 9 +++++----
1 file changed, 5 insertions(+), 4 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index 66f023ef9d..1d3454fc88 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -1523,7 +1523,7 @@ OutType = TypeVar("OutType", bound=Union[Tensor, Sequence[Tensor]])
def tensor_ir_op(
func: _tir.PrimFunc,
name_hint: str,
- args: Union[Tensor, Sequence[Union[Tensor, _tir.Var]]],
+ args: Union[Tensor, Sequence[Union[Tensor, rx.ShapeExpr, _tir.PrimExpr]]],
out: OutType,
) -> OutType:
"""Create a `call_tir` binding with given PrimFunc
@@ -1536,7 +1536,7 @@ def tensor_ir_op(
name_hint : str
Name hint.
- args : Union[Tensor, Sequence[Union[Tensor, _tir.Var]]]
+ args : Union[Tensor, Sequence[Union[Tensor, rx.ShapeExpr, _tir.PrimExpr]]]
The arguments to pass to the PrimFunc.
out : Union[Tensor, List[Tensor]]
@@ -1556,11 +1556,12 @@ def tensor_ir_op(
for arg in args:
if isinstance(arg, Tensor):
call_tir_args.append(arg._expr)
- elif isinstance(arg, _tir.Var):
+ elif isinstance(arg, (rx.ShapeExpr, _tir.PrimExpr)):
tir_vars.append(arg)
else:
raise TypeError(
- f"Unsupported type: tensor_ir_op args expect Tensor or
tir.Var, but got {type(arg)}"
+ "Unsupported type: tensor_ir_op args expect Tensor or
ShapeExpr or PrimExpr,"
+ f"but got {type(arg)}"
)
if isinstance(out, Tensor):