This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new eacc2cb [TIR] Bugfix for zero number arguments tir functions. (#8515)
eacc2cb is described below
commit eacc2cb0d3f6e7320ccf6537022ecb99e511e745
Author: ziheng <[email protected]>
AuthorDate: Wed Jul 21 09:20:59 2021 -0700
[TIR] Bugfix for zero number arguments tir functions. (#8515)
* [TIR] Bugfix for zero number arguments tir functions.
Co-authored-by: Junru Shao <[email protected]>
---
python/tvm/tir/transform/transform.py | 5 +++--
src/driver/driver_api.cc | 2 +-
src/tir/transforms/make_packed_api.cc | 15 ++++++++++-----
tests/python/unittest/test_tir_base.py | 11 +++++++++++
4 files changed, 25 insertions(+), 8 deletions(-)
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 437a06c..1e5c303 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -331,14 +331,15 @@ def LowerCustomDatatypes():
return _ffi_api.LowerCustomDatatypes() # type: ignore
-def MakePackedAPI(num_unpacked_params: int = 0):
+def MakePackedAPI(num_unpacked_params: int = -1):
"""Transform the PrimFuncs in the module to a packed func API.
Parameters
----------
num_unpacked_params : int
Number of parameters that we hope to directly pass via normal arguments
- following the PackedFunc input signature.
+ following the PackedFunc input signature. If it is specified as -1 or it
+ is less than the number of arguments, the pass will packed arguments still.
Returns
-------
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index b043404..1591e87 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -388,7 +388,7 @@ std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target
if (target->GetAttr<Bool>("unpacked-api").value_or(Bool(false))) {
mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI());
} else {
- mixed_pass_list.push_back(tir::transform::MakePackedAPI(0));
+ mixed_pass_list.push_back(tir::transform::MakePackedAPI(-1));
}
mixed_pass_list.push_back(tir::transform::SplitHostDevice());
diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc
index ee52a6f..393ce6c 100644
--- a/src/tir/transforms/make_packed_api.cc
+++ b/src/tir/transforms/make_packed_api.cc
@@ -119,7 +119,12 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
const Stmt nop = Evaluate(0);
int num_args = static_cast<int>(func_ptr->params.size());
ICHECK_LE(num_unpacked_args, num_args);
-
+ bool pack_args = (num_unpacked_args == -1) || (num_args > num_unpacked_args);
+ if (num_unpacked_args == -1) {
+ // reset to zero
+ num_unpacked_args = 0;
+ }
+ ICHECK_GE(num_unpacked_args, 0);
int num_packed_args = num_args - num_unpacked_args;
// Data field definitions
// The packed fields
@@ -154,11 +159,10 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
}
return res;
};
-
// ---------------------------
// start of logics
// add signiture for packed arguments.
- if (num_packed_args != 0) {
+ if (pack_args) {
args.push_back(v_packed_args);
args.push_back(v_packed_arg_type_ids);
args.push_back(v_num_packed_args);
@@ -214,13 +218,13 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
}
// allow return value if the function is packed.
- if (num_packed_args != 0) {
+ if (pack_args) {
args.push_back(v_out_ret_value);
args.push_back(v_out_ret_tcode);
args.push_back(v_resource_handle);
}
- size_t expected_nargs = num_unpacked_args + (num_packed_args != 0 ? 6 : 0);
+ size_t expected_nargs = num_unpacked_args + (pack_args ? 6 : 0);
ICHECK_EQ(args.size(), expected_nargs);
// Arg definitions are defined before buffer binding to avoid the use before
@@ -282,6 +286,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
namespace transform {
Pass MakePackedAPI(int num_unpacked_args) {
+ // packed arguments anyway while `num_unpacked_args` is -1
auto pass_func = [num_unpacked_args](IRModule m, PassContext ctx) {
IRModuleNode* mptr = m.CopyOnWrite();
std::vector<std::pair<GlobalVar, PrimFunc> > updates;
diff --git a/tests/python/unittest/test_tir_base.py b/tests/python/unittest/test_tir_base.py
index 6e081a1..06eb221 100644
--- a/tests/python/unittest/test_tir_base.py
+++ b/tests/python/unittest/test_tir_base.py
@@ -41,6 +41,16 @@ def test_scalar_add():
assert out == 3.0
+def test_ret_const():
+ a = tir.const(0)
+ b = tir.ret(a)
+ b = tir.Evaluate(b)
+ func = tir.PrimFunc([], b)
+ func = build_tir_func(func)
+ out = func()
+ assert out == 0
+
+
def test_control_flow_jump():
ib = tvm.tir.ir_builder.create()
a = tir.Var("a", "float32")
@@ -57,4 +67,5 @@ def test_control_flow_jump():
if __name__ == "__main__":
test_scalar_add()
+ test_ret_const()
test_control_flow_jump()