This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a86e41bcd2 [Unity][TVMScript] Update `call_packed` semantics to
support empty sinfo_args (#16379)
a86e41bcd2 is described below
commit a86e41bcd27085514d911f1e03d8d8a1db1eef24
Author: Siyuan Feng <[email protected]>
AuthorDate: Wed Jan 24 23:13:26 2024 +0800
[Unity][TVMScript] Update `call_packed` semantics to support empty
sinfo_args (#16379)
In low-level Relax (after the `CallTIRRewrite` pass), the `call_packed` nodes
do not always have explicit `sinfo_args`. This PR extends the parser to
support this case.
---
python/tvm/script/ir_builder/relax/ir.py | 6 +++---
tests/python/relax/test_tvmscript_parser.py | 22 ++++++++++++++++++++++
2 files changed, 25 insertions(+), 3 deletions(-)
diff --git a/python/tvm/script/ir_builder/relax/ir.py
b/python/tvm/script/ir_builder/relax/ir.py
index 9105fce00f..6447178909 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -330,7 +330,7 @@ def output(*vars: Tuple[Var]) -> None:
def call_packed(
func: py_str,
*args: Expr,
- sinfo_args: Union[StructInfo, List[StructInfo]],
+ sinfo_args: Optional[Union[StructInfo, List[StructInfo]]] = None,
**kwargs: Any,
) -> Call:
"""Create a relax Call, which calls a packed function.
@@ -340,7 +340,7 @@ def call_packed(
The name of extern function.
*args : Expr
The arguments.
- sinfo_args: Union[StructInfo, List[StructInfo]]
+ sinfo_args: Optional[Union[StructInfo, List[StructInfo]]]
The list of structure info arguments.
kwargs: Expr
The keyword arguments.
@@ -352,7 +352,7 @@ def call_packed(
"""
op = ExternFunc(func)
if sinfo_args is None:
- raise ValueError("R.call_packed is required to have type_args")
+ sinfo_args = []
if isinstance(sinfo_args, py_tuple): # type: ignore
sinfo_args = list(sinfo_args)
elif not isinstance(sinfo_args, list):
diff --git a/tests/python/relax/test_tvmscript_parser.py
b/tests/python/relax/test_tvmscript_parser.py
index 3ef75b4b49..71970ad965 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -842,6 +842,28 @@ def test_call_packed():
_check(foo, bb.get()["foo"])
+def test_call_packed_without_sinfo_args():  # R.call_packed without sinfo_args should build a Call whose sinfo_args is []
+ @R.function
+ def foo(x: R.Object) -> R.Object:
+ z = R.call_packed("test", x)  # sinfo_args omitted: call_packed now defaults None to []
+ return z
+
+ x = relax.Var("x", R.Object())
+ bb = relax.BlockBuilder()
+ with bb.function("foo", (x)):  # NOTE(review): (x) is a bare Var, not a 1-tuple; presumably bb.function accepts a single Var — confirm
+ z = bb.emit(
+ relax.Call(
+ relax.ExternFunc("test"),
+ (x,),
+ None,
+ sinfo_args=[],  # expected value produced by the new None -> [] default
+ )
+ )
+ bb.emit_func_output(z)
+
+ _check(foo, bb.get()["foo"])  # presumably compares the parsed foo against the BlockBuilder-built function; see _check
+
+
def test_annotation():
@R.function
def foo(