This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3a33771494 [TVMScript] Handle parsing of PrimFunc calls with non-void return (#15239)
3a33771494 is described below

commit 3a337714947a03be54c26b083e6a274c411c3815
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Jul 6 20:15:57 2023 -0500

    [TVMScript] Handle parsing of PrimFunc calls with non-void return (#15239)
    
    * [TVMScript] Handle parsing of PrimFunc calls with non-void return
    
    Prior to this commit, the return type of all internal function calls
    was hard-coded as `"void"`.  After this commit, the `GlobalVar`
    representing the internal function has type annotation based on the
    callee's signature, which is then used as the return type of the
    internal call.
    
    * Update CallNode return type in MakeUnpackedAPI
---
 python/tvm/tir/op.py                              |  9 ++++++++-
 src/script/ir_builder/ir/ir.cc                    | 14 +++++++++++++-
 src/tir/transforms/make_unpacked_api.cc           | 11 +++++++----
 tests/python/unittest/test_tvmscript_roundtrip.py | 17 +++++++++++++++++
 4 files changed, 45 insertions(+), 6 deletions(-)

diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 32c98efa69..cdbdb4b542 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -445,7 +445,14 @@ def call_tir(global_var: tvm.ir.GlobalVar, *args):
         The call expression.
     """
     assert isinstance(global_var, tvm.ir.GlobalVar)
-    return Call(dtype="void", op=global_var, args=args)
+
+    dtype = "void"
+    if global_var.checked_type is not None:
+        ret_type = global_var.checked_type.ret_type
+        if hasattr(ret_type, "dtype"):
+            dtype = ret_type.dtype
+
+    return Call(dtype=dtype, op=global_var, args=args)
 
 
 def start_profile_intrinsic(id):
diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc
index 0c34f85246..02fb899f0d 100644
--- a/src/script/ir_builder/ir/ir.cc
+++ b/src/script/ir_builder/ir/ir.cc
@@ -19,6 +19,8 @@
 #include <tvm/ir/module.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/script/ir_builder/ir/ir.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
 
 #include "./utils.h"
 
@@ -38,7 +40,17 @@ GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature)
   IRModuleFrame frame = FindModuleFrame("I.DeclFunction");
   CHECK(!frame->global_var_map.count(func_name))
       << "ValueError: function " << func_name << " already exists";
-  GlobalVar gv = GlobalVar(func_name);
+
+  auto gvar_type = [&]() -> Type {
+    if (auto prim_func = func_signature.as<tir::PrimFuncNode>()) {
+      Array<Type> arg_types = prim_func->params.Map([](const auto& var) { return GetType(var); });
+      return FuncType(arg_types, prim_func->ret_type, {}, {});
+    }
+
+    return {};
+  }();
+
+  GlobalVar gv = GlobalVar(func_name, gvar_type);
   CHECK(frame->functions.find(gv) == frame->functions.end())
       << "ValueError: function " << func_name << " has already been defined.";
   frame->global_var_map.Set(func_name, gv);
diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc
index 2646b5baea..0cb072701c 100644
--- a/src/tir/transforms/make_unpacked_api.cc
+++ b/src/tir/transforms/make_unpacked_api.cc
@@ -64,18 +64,21 @@ class SubroutineCallRewriter : public StmtExprMutator {
 
     if (auto gvar = node->op.as<GlobalVarNode>()) {
       if (external_methods_.count(gvar)) {
-        Array<PrimExpr> args = node->args.Map([this](const PrimExpr& arg) -> PrimExpr {
+        Array<PrimExpr> args = node->args.Map([](const PrimExpr& arg) -> PrimExpr {
           if (auto* as_call = arg.as<CallNode>()) {
             if (as_call->op.same_as(builtin::tvm_stack_make_array())) {
               PrimExpr data_ptr = as_call->args[0];
-              made_change_ = true;
               return data_ptr;
             }
           }
           return arg;
         });
-        if (!args.same_as(node->args)) {
-          node.CopyOnWrite()->args = args;
+
+        if (!args.same_as(node->args) || node->dtype != DataType::Int(32)) {
+          auto write_ptr = node.CopyOnWrite();
+          write_ptr->dtype = DataType::Int(32);
+          write_ptr->args = args;
+          made_change_ = true;
         }
       }
     }
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index d36641dfc2..90d2599b58 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3817,6 +3817,22 @@ def subroutine_call():
     return mod
 
 
+def subroutine_call_returning_int():
+    """An internal function call may return non-void"""
+
+    @I.ir_module
+    class mod:
+        @T.prim_func
+        def main(A: T.Buffer(2, "float32")):
+            mod.subroutine(A[0]) + mod.subroutine(A[1])
+
+        @T.prim_func
+        def subroutine(x: T.float32) -> T.float32:
+            T.ret(x * x)
+
+    return mod
+
+
 def undefined_data_ptr_in_decl_buffer():
     """The T.decl_buffer syntax should not introduce an Allocate
 
@@ -4009,6 +4025,7 @@ ir_generator = tvm.testing.parameter(
     ir_module_with_attrs,
     nested_seqstmt,
     subroutine_call,
+    subroutine_call_returning_int,
     undefined_data_ptr_in_decl_buffer,
     undefined_shape_in_decl_buffer,
     undefined_stride_in_decl_buffer,

Reply via email to