This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 94c1b89abc [TVMScript][TIR] Parse subroutine calls with no arguments (#14919)
94c1b89abc is described below

commit 94c1b89abc1f561627ab5e2a152e5ee4c949c580
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sun May 28 15:39:23 2023 -0500

    [TVMScript][TIR] Parse subroutine calls with no arguments (#14919)
    
    * [TVMScript][TIR] Parse subroutine calls with no arguments
    
    In most cases, the IR dialect in `GlobalVar.__call__` can be inferred
    from the argument types.  If there are no arguments, then the returned
    value is ambiguous.  This commit updates the TIR parser to identify
    and fix this case of erroneously producing a `relay.Call` instead of
    `tir.Call`.
    
    In addition, to prevent this from recurring, an unrecognized type
    resulting from `def visit_expr_stmt` now results in an error, rather
    than being silently ignored.
    
    * Ignore str for unknown parser result
    
    These may be used as docstrings in the TVMScript, even though they
    are not represented in the TIR.
    
    * Lint fixes
---
 python/tvm/script/parser/tir/parser.py            | 14 +++++++++++++-
 tests/python/unittest/test_tvmscript_roundtrip.py | 19 +++++++++++++++++++
 2 files changed, 32 insertions(+), 1 deletion(-)

diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py
index dfecaacdf6..7d81fecedb 100644
--- a/python/tvm/script/parser/tir/parser.py
+++ b/python/tvm/script/parser/tir/parser.py
@@ -406,13 +406,25 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
         The doc AST Expr node.
     """
     res = self.eval_expr(node.value)
-    if isinstance(res, Frame):
+    if res is None:
+        pass
+    elif isinstance(res, Frame):
         res.add_callback(partial(res.__exit__, None, None, None))
         res.__enter__()
     elif isinstance(res, PrimExpr):
         T.evaluate(res)
     elif isinstance(res, (int, bool)):
         T.evaluate(tvm.tir.const(res))
+    elif isinstance(res, tvm.relay.Call) and not res.args:
+        # Using GlobalVar.__call__ with no arguments is ambiguous, as
+        # each IR has a different function Call representation.  If
+        # this occurs, convert to the TIR representation.
+        T.evaluate(tvm.tir.call_tir(res.op))
+    elif isinstance(res, str):
+        # Ignore docstrings
+        pass
+    else:
+        self.report_error(node, f"Parsing resulted in unexpected type 
{type(res)}")
 
 
 @dispatch.register(token="tir", type_name="If")
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index e3ec311cc0..58be4e14d0 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3864,6 +3864,24 @@ def undefined_elem_offset_in_decl_buffer():
     return func
 
 
+def subroutine_call_without_arguments():
+    @I.ir_module
+    class mod:
+        @T.prim_func
+        def main():
+            # Should be equivalent to the bare "mod.subroutine()", but
+            # that relies on `GlobalVar.__call__` returning the
+            # correct IR type.  Previously, this instead returned a
+            # `relay.Call` object.
+            tir.call_tir(mod.subroutine)
+
+        @T.prim_func
+        def subroutine():
+            T.evaluate(0)
+
+    return mod
+
+
 ir_generator = tvm.testing.parameter(
     launch_env_thread,
     opt_gemm_normalize,
@@ -3939,6 +3957,7 @@ ir_generator = tvm.testing.parameter(
     undefined_shape_in_decl_buffer,
     undefined_stride_in_decl_buffer,
     undefined_elem_offset_in_decl_buffer,
+    subroutine_call_without_arguments,
 )
 
 

Reply via email to