This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new eacc2cb  [TIR] Bugfix for zero number arguments tir functions. (#8515)
eacc2cb is described below

commit eacc2cb0d3f6e7320ccf6537022ecb99e511e745
Author: ziheng <[email protected]>
AuthorDate: Wed Jul 21 09:20:59 2021 -0700

    [TIR] Bugfix for zero number arguments tir functions. (#8515)
    
    * [TIR] Bugfix for zero number arguments tir functions.
    
    
    Co-authored-by: Junru Shao <[email protected]>
---
 python/tvm/tir/transform/transform.py  |  5 +++--
 src/driver/driver_api.cc               |  2 +-
 src/tir/transforms/make_packed_api.cc  | 15 ++++++++++-----
 tests/python/unittest/test_tir_base.py | 11 +++++++++++
 4 files changed, 25 insertions(+), 8 deletions(-)

diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 437a06c..1e5c303 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -331,14 +331,15 @@ def LowerCustomDatatypes():
     return _ffi_api.LowerCustomDatatypes()  # type: ignore
 
 
-def MakePackedAPI(num_unpacked_params: int = 0):
+def MakePackedAPI(num_unpacked_params: int = -1):
     """Transform the PrimFuncs in the module to a packed func API.
 
     Parameters
     ----------
     num_unpacked_params : int
         Number of parameters that we hope to directly pass via normal arguments
-        following the PackedFunc input signature.
+        following the PackedFunc input signature. If it is specified as -1 or it
+        is less than the number of arguments, the pass will still pack arguments.
 
     Returns
     -------
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index b043404..1591e87 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -388,7 +388,7 @@ std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target
   if (target->GetAttr<Bool>("unpacked-api").value_or(Bool(false))) {
     mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI());
   } else {
-    mixed_pass_list.push_back(tir::transform::MakePackedAPI(0));
+    mixed_pass_list.push_back(tir::transform::MakePackedAPI(-1));
   }
 
   mixed_pass_list.push_back(tir::transform::SplitHostDevice());
diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc
index ee52a6f..393ce6c 100644
--- a/src/tir/transforms/make_packed_api.cc
+++ b/src/tir/transforms/make_packed_api.cc
@@ -119,7 +119,12 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
   const Stmt nop = Evaluate(0);
   int num_args = static_cast<int>(func_ptr->params.size());
   ICHECK_LE(num_unpacked_args, num_args);
-
+  bool pack_args = (num_unpacked_args == -1) || (num_args > num_unpacked_args);
+  if (num_unpacked_args == -1) {
+    // reset to zero
+    num_unpacked_args = 0;
+  }
+  ICHECK_GE(num_unpacked_args, 0);
   int num_packed_args = num_args - num_unpacked_args;
   // Data field definitions
   // The packed fields
@@ -154,11 +159,10 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
     }
     return res;
   };
-
   // ---------------------------
   // start of logics
   // add signiture for packed arguments.
-  if (num_packed_args != 0) {
+  if (pack_args) {
     args.push_back(v_packed_args);
     args.push_back(v_packed_arg_type_ids);
     args.push_back(v_num_packed_args);
@@ -214,13 +218,13 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
   }
 
   // allow return value if the function is packed.
-  if (num_packed_args != 0) {
+  if (pack_args) {
     args.push_back(v_out_ret_value);
     args.push_back(v_out_ret_tcode);
     args.push_back(v_resource_handle);
   }
 
-  size_t expected_nargs = num_unpacked_args + (num_packed_args != 0 ? 6 : 0);
+  size_t expected_nargs = num_unpacked_args + (pack_args ? 6 : 0);
   ICHECK_EQ(args.size(), expected_nargs);
 
   // Arg definitions are defined before buffer binding to avoid the use before
@@ -282,6 +286,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
 namespace transform {
 
 Pass MakePackedAPI(int num_unpacked_args) {
+  // Arguments are always packed when `num_unpacked_args` is -1.
   auto pass_func = [num_unpacked_args](IRModule m, PassContext ctx) {
     IRModuleNode* mptr = m.CopyOnWrite();
     std::vector<std::pair<GlobalVar, PrimFunc> > updates;
diff --git a/tests/python/unittest/test_tir_base.py b/tests/python/unittest/test_tir_base.py
index 6e081a1..06eb221 100644
--- a/tests/python/unittest/test_tir_base.py
+++ b/tests/python/unittest/test_tir_base.py
@@ -41,6 +41,16 @@ def test_scalar_add():
     assert out == 3.0
 
 
+def test_ret_const():
+    a = tir.const(0)
+    b = tir.ret(a)
+    b = tir.Evaluate(b)
+    func = tir.PrimFunc([], b)
+    func = build_tir_func(func)
+    out = func()
+    assert out == 0
+
+
 def test_control_flow_jump():
     ib = tvm.tir.ir_builder.create()
     a = tir.Var("a", "float32")
@@ -57,4 +67,5 @@ def test_control_flow_jump():
 
 if __name__ == "__main__":
     test_scalar_add()
+    test_ret_const()
     test_control_flow_jump()

Reply via email to