This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3a33771494 [TVMScript] Handle parsing of PrimFunc calls with non-void
return (#15239)
3a33771494 is described below
commit 3a337714947a03be54c26b083e6a274c411c3815
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Jul 6 20:15:57 2023 -0500
[TVMScript] Handle parsing of PrimFunc calls with non-void return (#15239)
* [TVMScript] Handle parsing of PrimFunc calls with non-void return
Prior to this commit, the return type of all internal function calls
was hard-coded as `"void"`. After this commit, the `GlobalVar`
representing the internal function has a type annotation based on the
callee's signature, which is then used as the return type of the
internal call.
* Update CallNode return type in MakeUnpackedAPI
---
python/tvm/tir/op.py | 9 ++++++++-
src/script/ir_builder/ir/ir.cc | 14 +++++++++++++-
src/tir/transforms/make_unpacked_api.cc | 11 +++++++----
tests/python/unittest/test_tvmscript_roundtrip.py | 17 +++++++++++++++++
4 files changed, 45 insertions(+), 6 deletions(-)
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 32c98efa69..cdbdb4b542 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -445,7 +445,14 @@ def call_tir(global_var: tvm.ir.GlobalVar, *args):
The call expression.
"""
assert isinstance(global_var, tvm.ir.GlobalVar)
- return Call(dtype="void", op=global_var, args=args)
+
+ dtype = "void"
+ if global_var.checked_type is not None:
+ ret_type = global_var.checked_type.ret_type
+ if hasattr(ret_type, "dtype"):
+ dtype = ret_type.dtype
+
+ return Call(dtype=dtype, op=global_var, args=args)
def start_profile_intrinsic(id):
diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc
index 0c34f85246..02fb899f0d 100644
--- a/src/script/ir_builder/ir/ir.cc
+++ b/src/script/ir_builder/ir/ir.cc
@@ -19,6 +19,8 @@
#include <tvm/ir/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/script/ir_builder/ir/ir.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
#include "./utils.h"
@@ -38,7 +40,17 @@ GlobalVar DeclFunction(const String& func_name, const
BaseFunc& func_signature)
IRModuleFrame frame = FindModuleFrame("I.DeclFunction");
CHECK(!frame->global_var_map.count(func_name))
<< "ValueError: function " << func_name << " already exists";
- GlobalVar gv = GlobalVar(func_name);
+
+ auto gvar_type = [&]() -> Type {
+ if (auto prim_func = func_signature.as<tir::PrimFuncNode>()) {
+ Array<Type> arg_types = prim_func->params.Map([](const auto& var) {
return GetType(var); });
+ return FuncType(arg_types, prim_func->ret_type, {}, {});
+ }
+
+ return {};
+ }();
+
+ GlobalVar gv = GlobalVar(func_name, gvar_type);
CHECK(frame->functions.find(gv) == frame->functions.end())
<< "ValueError: function " << func_name << " has already been defined.";
frame->global_var_map.Set(func_name, gv);
diff --git a/src/tir/transforms/make_unpacked_api.cc
b/src/tir/transforms/make_unpacked_api.cc
index 2646b5baea..0cb072701c 100644
--- a/src/tir/transforms/make_unpacked_api.cc
+++ b/src/tir/transforms/make_unpacked_api.cc
@@ -64,18 +64,21 @@ class SubroutineCallRewriter : public StmtExprMutator {
if (auto gvar = node->op.as<GlobalVarNode>()) {
if (external_methods_.count(gvar)) {
- Array<PrimExpr> args = node->args.Map([this](const PrimExpr& arg) ->
PrimExpr {
+ Array<PrimExpr> args = node->args.Map([](const PrimExpr& arg) ->
PrimExpr {
if (auto* as_call = arg.as<CallNode>()) {
if (as_call->op.same_as(builtin::tvm_stack_make_array())) {
PrimExpr data_ptr = as_call->args[0];
- made_change_ = true;
return data_ptr;
}
}
return arg;
});
- if (!args.same_as(node->args)) {
- node.CopyOnWrite()->args = args;
+
+ if (!args.same_as(node->args) || node->dtype != DataType::Int(32)) {
+ auto write_ptr = node.CopyOnWrite();
+ write_ptr->dtype = DataType::Int(32);
+ write_ptr->args = args;
+ made_change_ = true;
}
}
}
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py
b/tests/python/unittest/test_tvmscript_roundtrip.py
index d36641dfc2..90d2599b58 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3817,6 +3817,22 @@ def subroutine_call():
return mod
+def subroutine_call_returning_int():
+ """An internal function call may return non-void"""
+
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def main(A: T.Buffer(2, "float32")):
+ mod.subroutine(A[0]) + mod.subroutine(A[1])
+
+ @T.prim_func
+ def subroutine(x: T.float32) -> T.float32:
+ T.ret(x * x)
+
+ return mod
+
+
def undefined_data_ptr_in_decl_buffer():
"""The T.decl_buffer syntax should not introduce an Allocate
@@ -4009,6 +4025,7 @@ ir_generator = tvm.testing.parameter(
ir_module_with_attrs,
nested_seqstmt,
subroutine_call,
+ subroutine_call_returning_int,
undefined_data_ptr_in_decl_buffer,
undefined_shape_in_decl_buffer,
undefined_stride_in_decl_buffer,