This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 4a7e4fec37 [Unity] Fix nn.op.tensor_ir_op signature (#16333)
4a7e4fec37 is described below

commit 4a7e4fec376d9dbc3127dde19def2eb356228d8e
Author: Hongyi Jin <[email protected]>
AuthorDate: Wed Jan 3 02:00:18 2024 -0500

    [Unity] Fix nn.op.tensor_ir_op signature (#16333)
    
    * fix nn tensorir
    
    * fix lint
---
 python/tvm/relax/frontend/nn/op.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index 66f023ef9d..1d3454fc88 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -1523,7 +1523,7 @@ OutType = TypeVar("OutType", bound=Union[Tensor, Sequence[Tensor]])
 def tensor_ir_op(
     func: _tir.PrimFunc,
     name_hint: str,
-    args: Union[Tensor, Sequence[Union[Tensor, _tir.Var]]],
+    args: Union[Tensor, Sequence[Union[Tensor, rx.ShapeExpr, _tir.PrimExpr]]],
     out: OutType,
 ) -> OutType:
     """Create a `call_tir` binding with given PrimFunc
@@ -1536,7 +1536,7 @@ def tensor_ir_op(
     name_hint : str
         Name hint.
 
-    args : Union[Tensor, Sequence[Union[Tensor, _tir.Var]]]
+    args : Union[Tensor, Sequence[Union[Tensor, rx.ShapeExpr, _tir.PrimExpr]]]
         The arguments to pass to the PrimFunc.
 
     out : Union[Tensor, List[Tensor]]
@@ -1556,11 +1556,12 @@ def tensor_ir_op(
     for arg in args:
         if isinstance(arg, Tensor):
             call_tir_args.append(arg._expr)
-        elif isinstance(arg, _tir.Var):
+        elif isinstance(arg, (rx.ShapeExpr, _tir.PrimExpr)):
             tir_vars.append(arg)
         else:
             raise TypeError(
-                f"Unsupported type: tensor_ir_op args expect Tensor or 
tir.Var, but got {type(arg)}"
+                "Unsupported type: tensor_ir_op args expect Tensor or 
ShapeExpr or PrimExpr,"
+                f"but got {type(arg)}"
             )
 
     if isinstance(out, Tensor):

Reply via email to