This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s2
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 18911705717210495e8a16914cbd9ea1d8d848a0
Author: tqchen <[email protected]>
AuthorDate: Mon Apr 21 11:32:58 2025 -0400

    Fix packed API test cases
---
 src/tir/transforms/make_packed_api.cc              |   3 +-
 .../test_tir_transform_make_packed_api.py          | 156 ++++-----------------
 2 files changed, 26 insertions(+), 133 deletions(-)

diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc
index 3d0ace67f7..13cff5276b 100644
--- a/src/tir/transforms/make_packed_api.cc
+++ b/src/tir/transforms/make_packed_api.cc
@@ -263,8 +263,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
   }()));
 
   if (num_args > 0) {
-    seq_init.push_back(
-        MakeAssertNotNull(v_packed_args, name_hint + ": TVMValue* arg pointer 
was NULL"));
+    seq_init.push_back(MakeAssertNotNull(v_packed_args, name_hint + ": args 
pointer is NULL"));
   }
 
   // Need to delay binding of the buffers, in case some arguments also
diff --git a/tests/python/tir-transform/test_tir_transform_make_packed_api.py b/tests/python/tir-transform/test_tir_transform_make_packed_api.py
index c676330062..cbd5f0b3e5 100644
--- a/tests/python/tir-transform/test_tir_transform_make_packed_api.py
+++ b/tests/python/tir-transform/test_tir_transform_make_packed_api.py
@@ -60,92 +60,6 @@ def _find_compute_scope(func):
 
     return result
 
-
-def test_variable_passed_from_args():
-    ib = tvm.tir.ir_builder.create()
-
-    input_buffer = tvm.tir.decl_buffer(name="input_buffer", shape=[1])
-    not_device_context = tvm.tir.Var("not_device_context", dtype="handle")
-
-    ib.emit(
-        tvm.tir.call_extern("float32", "some_external_call", 
input_buffer.data, not_device_context),
-    )
-    stmt = ib.get()
-
-    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([input_buffer, 
not_device_context], stmt))
-    mod = tvm.tir.transform.Apply(
-        lambda f: f.with_attr("target", tvm.target.Target("llvm", host="llvm"))
-    )(mod)
-    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", 
"main"))(mod)
-    func = tvm.tir.transform.MakePackedAPI()(mod)["main"]
-
-    num_args = func.params[2]
-
-    # num_args assertion
-    assert func.body.condition.a == num_args
-    assert func.body.condition.b == 2
-
-    # Arguments unpacking
-    assignment = _find_assignment(func.body, "input_buffer")
-    assert str(assignment.value) == 'T.tvm_struct_get(args, 0, 12, "handle")'
-
-    assignment = _find_assignment(assignment.body, "input_buffer")
-    assert str(assignment.value) == 'T.tvm_struct_get(input_buffer, 0, 1, 
"handle")'
-    unpacked_input_buffer = assignment.var
-
-    assignment = _find_assignment(func.body, "not_device_context")
-    assert str(assignment.value) == 'T.tvm_struct_get(args, 1, 12, "handle")'
-    unpacked_not_device_context = assignment.var
-
-    seq_stmt = _find_next(assignment, tvm.tir.SeqStmt)
-    call = _find_next(seq_stmt[1], tvm.tir.Evaluate)
-    call_extern = call.value
-
-    assert call_extern.args[1] == unpacked_input_buffer
-    assert call_extern.args[2] == unpacked_not_device_context
-
-
-def test_device_api_context_implicit_resource_handle():
-    ib = tvm.tir.ir_builder.create()
-
-    input_buffer = tvm.tir.decl_buffer(name="input_buffer", shape=[1])
-    device_context = tvm.tir.Var("device_api_context", dtype="handle")
-
-    ib.emit(
-        tvm.tir.call_extern("float32", "some_external_call", 
input_buffer.data, device_context),
-    )
-    stmt = ib.get()
-
-    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([input_buffer, 
device_context], stmt))
-    mod = tvm.tir.transform.Apply(
-        lambda f: f.with_attr("target", tvm.target.Target("llvm", host="llvm"))
-    )(mod)
-    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", 
"main"))(mod)
-    func = tvm.tir.transform.MakePackedAPI()(mod)["main"]
-
-    num_args = func.params[2]
-    device_context_in_resource_handle = func.params[5]
-
-    # num_args assertion
-    assert func.body.condition.a == num_args
-    assert func.body.condition.b == 1
-
-    # Arguments unpacking
-    assignment = _find_assignment(func.body, "input_buffer")
-    assert str(assignment.value) == 'T.tvm_struct_get(args, 0, 12, "handle")'
-
-    assignment = _find_assignment(assignment.body, "input_buffer")
-    assert str(assignment.value) == 'T.tvm_struct_get(input_buffer, 0, 1, 
"handle")'
-    unpacked_input_buffer = assignment.var
-
-    seq_stmt = _find_next(assignment, tvm.tir.SeqStmt)
-    call = _find_next(seq_stmt[1], tvm.tir.Evaluate)
-    call_extern = call.value
-
-    assert call_extern.args[1] == unpacked_input_buffer
-    assert call_extern.args[2] == device_context_in_resource_handle
-
-
 @pytest.mark.parametrize("use_global_symbol", [True, False])
 def test_no_op_when_global_symbol_is_absent(use_global_symbol):
     func_attr = {"target": tvm.target.Target("llvm", host="llvm")}
@@ -160,7 +74,7 @@ def test_no_op_when_global_symbol_is_absent(use_global_symbol):
 
     after = 
tvm.tir.transform.MakePackedAPI()(tvm.IRModule.from_expr(before))["main"]
     if use_global_symbol:
-        assert len(after.params) == 6
+        assert len(after.params) == 4
     else:
         tvm.ir.assert_structural_equal(before, after)
 
@@ -341,12 +255,10 @@ def test_zero_arg_function():
     class Expected:
         @T.prim_func
         def func_without_arg(
+            self: T.handle,
             args: T.handle,
-            arg_type_ids: T.handle("int32"),
             num_args: T.int32,
-            out_ret_value: T.handle("void"),
-            out_ret_tcode: T.handle("int32"),
-            resource_handle: T.handle,
+            result: T.handle("void"),
         ) -> T.int32:
             T.func_attr(
                 {
@@ -355,15 +267,11 @@ def test_zero_arg_function():
                 }
             )
             assert num_args == 0, "func_without_arg: num_args should be 0"
-            arg_type_ids_1 = T.decl_buffer((0,), "int32", data=arg_type_ids)
             with T.attr(0, "compute_scope", "func_without_arg_compute_"):
-                out_ret_value_1 = T.Buffer((1,), "int64", data=out_ret_value, 
strides=(1,))
-                out_ret_value_1[0] = T.Cast("int64", T.int64(42))
-                out_ret_tcode_1 = T.Buffer((1,), "int32", data=out_ret_tcode, 
strides=(1,))
-                out_ret_tcode_1[0] = 0
+                T.tvm_struct_set(result, 0, 13, 1)
+                T.tvm_struct_set(result, 0, 14, T.Cast("int64", T.int64(42)))
                 return 0
             return 0
-
     After = tvm.tir.transform.MakePackedAPI()(Before)
     tvm.ir.assert_structural_equal(Expected, After)
 
@@ -399,12 +307,10 @@ def test_int_parameter():
     class Expected:
         @T.prim_func
         def main(
+            self: T.handle,
             args: T.handle,
-            arg_type_ids: T.handle("int32"),
             num_args: T.int32,
-            out_ret_value: T.handle("void"),
-            out_ret_tcode: T.handle("int32"),
-            resource_handle: T.handle,
+            result: T.handle("void"),
         ) -> T.int32:
             T.func_attr(
                 {
@@ -413,27 +319,22 @@ def test_int_parameter():
                 }
             )
             assert num_args == 1, "main: num_args should be 1"
-            assert not T.isnullptr(args), "main: TVMValue* arg pointer was 
NULL"
-            assert not T.isnullptr(arg_type_ids), "main: int* type_codes was 
NULL"
-            arg_type_ids_1 = T.decl_buffer((1,), "int32", data=arg_type_ids)
-            arg_code: T.int32 = arg_type_ids_1[0]
-            assert arg_code == 0 or arg_code == 15, "main: Expect arg[0] to be 
int"
-            arg: T.int32 = T.Cast("int32", T.tvm_struct_get(args, 0, 12, 
"int64"))
+            assert not T.isnullptr(args), "main: args pointer is NULL"
+            arg_type_index: T.int32 = T.tvm_struct_get(args, 0, 13, "int32")
+            assert arg_type_index == 1 or arg_type_index == 2, "main: Expect 
arg[0] to be int"
+            arg: T.int32 = T.Cast("int32", T.tvm_struct_get(args, 0, 14, 
"int64"))
             with T.attr(0, "compute_scope", "main_compute_"):
-                out_ret_value_1 = T.Buffer((1,), "int64", data=out_ret_value, 
strides=(1,))
-                out_ret_tcode_1 = T.Buffer((1,), "int32", data=out_ret_tcode, 
strides=(1,))
                 if arg > 0:
-                    out_ret_value_1[0] = T.Cast("int64", 10)
-                    out_ret_tcode_1[0] = 0
+                    T.tvm_struct_set(result, 0, 13, 1)
+                    T.tvm_struct_set(result, 0, 14, T.Cast("int64", 10))
                     return 0
                 else:
-                    out_ret_value_1[0] = T.Cast("int64", 20)
-                    out_ret_tcode_1[0] = 0
+                    T.tvm_struct_set(result, 0, 13, 1)
+                    T.tvm_struct_set(result, 0, 14, T.Cast("int64", 20))
                     return 0
             return 0
 
     After = tvm.tir.transform.MakePackedAPI()(Before)
-
     tvm.ir.assert_structural_equal(Expected, After)
 
 
@@ -461,12 +362,10 @@ def test_bool_parameter():
     class Expected:
         @T.prim_func
         def main(
+            self: T.handle,
             args: T.handle,
-            arg_type_ids: T.handle("int32"),
             num_args: T.int32,
-            out_ret_value: T.handle("void"),
-            out_ret_tcode: T.handle("int32"),
-            resource_handle: T.handle,
+            result: T.handle("void"),
         ) -> T.int32:
             T.func_attr(
                 {
@@ -475,27 +374,22 @@ def test_bool_parameter():
                 }
             )
             assert num_args == 1, "main: num_args should be 1"
-            assert not T.isnullptr(args), "main: TVMValue* arg pointer was 
NULL"
-            assert not T.isnullptr(arg_type_ids), "main: int* type_codes was 
NULL"
-            arg_type_ids_1 = T.decl_buffer((1,), "int32", data=arg_type_ids)
-            arg_code: T.int32 = arg_type_ids_1[0]
-            assert arg_code == 15 or arg_code == 0, "main: Expect arg[0] to be 
boolean"
-            arg: T.bool = T.Cast("bool", T.tvm_struct_get(args, 0, 12, 
"int64"))
+            assert not T.isnullptr(args), "main: args pointer is NULL"
+            arg_type_index: T.int32 = T.tvm_struct_get(args, 0, 13, "int32")
+            assert arg_type_index == 2 or arg_type_index == 1, "main: Expect 
arg[0] to be boolean"
+            arg: T.bool = T.Cast("bool", T.tvm_struct_get(args, 0, 14, 
"int64"))
             with T.attr(0, "compute_scope", "main_compute_"):
-                out_ret_value_1 = T.Buffer((1,), "int64", data=out_ret_value, 
strides=(1,))
-                out_ret_tcode_1 = T.Buffer((1,), "int32", data=out_ret_tcode, 
strides=(1,))
                 if arg:
-                    out_ret_value_1[0] = T.Cast("int64", 10)
-                    out_ret_tcode_1[0] = 0
+                    T.tvm_struct_set(result, 0, 13, 1)
+                    T.tvm_struct_set(result, 0, 14, T.Cast("int64", 10))
                     return 0
                 else:
-                    out_ret_value_1[0] = T.Cast("int64", 20)
-                    out_ret_tcode_1[0] = 0
+                    T.tvm_struct_set(result, 0, 13, 1)
+                    T.tvm_struct_set(result, 0, 14, T.Cast("int64", 20))
                     return 0
             return 0
 
     After = tvm.tir.transform.MakePackedAPI()(Before)
-
     tvm.ir.assert_structural_equal(Expected, After)
 
 

Reply via email to