This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s2 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 18911705717210495e8a16914cbd9ea1d8d848a0 Author: tqchen <[email protected]> AuthorDate: Mon Apr 21 11:32:58 2025 -0400 Fix packed api testcases --- src/tir/transforms/make_packed_api.cc | 3 +- .../test_tir_transform_make_packed_api.py | 156 ++++----------------- 2 files changed, 26 insertions(+), 133 deletions(-) diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 3d0ace67f7..13cff5276b 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -263,8 +263,7 @@ PrimFunc MakePackedAPI(PrimFunc func) { }())); if (num_args > 0) { - seq_init.push_back( - MakeAssertNotNull(v_packed_args, name_hint + ": TVMValue* arg pointer was NULL")); + seq_init.push_back(MakeAssertNotNull(v_packed_args, name_hint + ": args pointer is NULL")); } // Need to delay binding of the buffers, in case some arguments also diff --git a/tests/python/tir-transform/test_tir_transform_make_packed_api.py b/tests/python/tir-transform/test_tir_transform_make_packed_api.py index c676330062..cbd5f0b3e5 100644 --- a/tests/python/tir-transform/test_tir_transform_make_packed_api.py +++ b/tests/python/tir-transform/test_tir_transform_make_packed_api.py @@ -60,92 +60,6 @@ def _find_compute_scope(func): return result - -def test_variable_passed_from_args(): - ib = tvm.tir.ir_builder.create() - - input_buffer = tvm.tir.decl_buffer(name="input_buffer", shape=[1]) - not_device_context = tvm.tir.Var("not_device_context", dtype="handle") - - ib.emit( - tvm.tir.call_extern("float32", "some_external_call", input_buffer.data, not_device_context), - ) - stmt = ib.get() - - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([input_buffer, not_device_context], stmt)) - mod = tvm.tir.transform.Apply( - lambda f: f.with_attr("target", tvm.target.Target("llvm", host="llvm")) - )(mod) - mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod) - func = tvm.tir.transform.MakePackedAPI()(mod)["main"] - - num_args = func.params[2] - - # num_args assertion - assert func.body.condition.a == num_args - assert func.body.condition.b == 2 - - # Arguments unpacking - assignment = _find_assignment(func.body, "input_buffer") - assert str(assignment.value) == 'T.tvm_struct_get(args, 0, 12, "handle")' - - assignment = _find_assignment(assignment.body, "input_buffer") - assert str(assignment.value) == 'T.tvm_struct_get(input_buffer, 0, 1, "handle")' - unpacked_input_buffer = assignment.var - - assignment = _find_assignment(func.body, "not_device_context") - assert str(assignment.value) == 'T.tvm_struct_get(args, 1, 12, "handle")' - unpacked_not_device_context = assignment.var - - seq_stmt = _find_next(assignment, tvm.tir.SeqStmt) - call = _find_next(seq_stmt[1], tvm.tir.Evaluate) - call_extern = call.value - - assert call_extern.args[1] == unpacked_input_buffer - assert call_extern.args[2] == unpacked_not_device_context - - -def test_device_api_context_implicit_resource_handle(): - ib = tvm.tir.ir_builder.create() - - input_buffer = tvm.tir.decl_buffer(name="input_buffer", shape=[1]) - device_context = tvm.tir.Var("device_api_context", dtype="handle") - - ib.emit( - tvm.tir.call_extern("float32", "some_external_call", input_buffer.data, device_context), - ) - stmt = ib.get() - - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([input_buffer, device_context], stmt)) - mod = tvm.tir.transform.Apply( - lambda f: f.with_attr("target", tvm.target.Target("llvm", host="llvm")) - )(mod) - mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod) - func = tvm.tir.transform.MakePackedAPI()(mod)["main"] - - num_args = func.params[2] - device_context_in_resource_handle = func.params[5] - - # num_args assertion - assert func.body.condition.a == num_args - assert func.body.condition.b == 1 - - # Arguments unpacking - assignment = _find_assignment(func.body, "input_buffer") - assert str(assignment.value) == 'T.tvm_struct_get(args, 0, 12, "handle")' - - assignment = _find_assignment(assignment.body, "input_buffer") - assert str(assignment.value) == 'T.tvm_struct_get(input_buffer, 0, 1, "handle")' - unpacked_input_buffer = assignment.var - - seq_stmt = _find_next(assignment, tvm.tir.SeqStmt) - call = _find_next(seq_stmt[1], tvm.tir.Evaluate) - call_extern = call.value - - assert call_extern.args[1] == unpacked_input_buffer - assert call_extern.args[2] == device_context_in_resource_handle - - @pytest.mark.parametrize("use_global_symbol", [True, False]) def test_no_op_when_global_symbol_is_absent(use_global_symbol): func_attr = {"target": tvm.target.Target("llvm", host="llvm")} @@ -160,7 +74,7 @@ def test_no_op_when_global_symbol_is_absent(use_global_symbol): after = tvm.tir.transform.MakePackedAPI()(tvm.IRModule.from_expr(before))["main"] if use_global_symbol: - assert len(after.params) == 6 + assert len(after.params) == 4 else: tvm.ir.assert_structural_equal(before, after) @@ -341,12 +255,10 @@ def test_zero_arg_function(): class Expected: @T.prim_func def func_without_arg( + self: T.handle, args: T.handle, - arg_type_ids: T.handle("int32"), num_args: T.int32, - out_ret_value: T.handle("void"), - out_ret_tcode: T.handle("int32"), - resource_handle: T.handle, + result: T.handle("void"), ) -> T.int32: T.func_attr( { @@ -355,15 +267,11 @@ def test_zero_arg_function(): } ) assert num_args == 0, "func_without_arg: num_args should be 0" - arg_type_ids_1 = T.decl_buffer((0,), "int32", data=arg_type_ids) with T.attr(0, "compute_scope", "func_without_arg_compute_"): - out_ret_value_1 = T.Buffer((1,), "int64", data=out_ret_value, strides=(1,)) - out_ret_value_1[0] = T.Cast("int64", T.int64(42)) - out_ret_tcode_1 = T.Buffer((1,), "int32", data=out_ret_tcode, strides=(1,)) - out_ret_tcode_1[0] = 0 + T.tvm_struct_set(result, 0, 13, 1) + T.tvm_struct_set(result, 0, 14, T.Cast("int64", T.int64(42))) return 0 return 0 - After = tvm.tir.transform.MakePackedAPI()(Before) tvm.ir.assert_structural_equal(Expected, After) @@ -399,12 +307,10 @@ def test_int_parameter(): class Expected: @T.prim_func def main( + self: T.handle, args: T.handle, - arg_type_ids: T.handle("int32"), num_args: T.int32, - out_ret_value: T.handle("void"), - out_ret_tcode: T.handle("int32"), - resource_handle: T.handle, + result: T.handle("void"), ) -> T.int32: T.func_attr( { @@ -413,27 +319,22 @@ def test_int_parameter(): } ) assert num_args == 1, "main: num_args should be 1" - assert not T.isnullptr(args), "main: TVMValue* arg pointer was NULL" - assert not T.isnullptr(arg_type_ids), "main: int* type_codes was NULL" - arg_type_ids_1 = T.decl_buffer((1,), "int32", data=arg_type_ids) - arg_code: T.int32 = arg_type_ids_1[0] - assert arg_code == 0 or arg_code == 15, "main: Expect arg[0] to be int" - arg: T.int32 = T.Cast("int32", T.tvm_struct_get(args, 0, 12, "int64")) + assert not T.isnullptr(args), "main: args pointer is NULL" + arg_type_index: T.int32 = T.tvm_struct_get(args, 0, 13, "int32") + assert arg_type_index == 1 or arg_type_index == 2, "main: Expect arg[0] to be int" + arg: T.int32 = T.Cast("int32", T.tvm_struct_get(args, 0, 14, "int64")) with T.attr(0, "compute_scope", "main_compute_"): - out_ret_value_1 = T.Buffer((1,), "int64", data=out_ret_value, strides=(1,)) - out_ret_tcode_1 = T.Buffer((1,), "int32", data=out_ret_tcode, strides=(1,)) if arg > 0: - out_ret_value_1[0] = T.Cast("int64", 10) - out_ret_tcode_1[0] = 0 + T.tvm_struct_set(result, 0, 13, 1) + T.tvm_struct_set(result, 0, 14, T.Cast("int64", 10)) return 0 else: - out_ret_value_1[0] = T.Cast("int64", 20) - out_ret_tcode_1[0] = 0 + T.tvm_struct_set(result, 0, 13, 1) + T.tvm_struct_set(result, 0, 14, T.Cast("int64", 20)) return 0 return 0 After = tvm.tir.transform.MakePackedAPI()(Before) - tvm.ir.assert_structural_equal(Expected, After) @@ -461,12 +362,10 @@ def test_bool_parameter(): class Expected: @T.prim_func def main( + self: T.handle, args: T.handle, - arg_type_ids: T.handle("int32"), num_args: T.int32, - out_ret_value: T.handle("void"), - out_ret_tcode: T.handle("int32"), - resource_handle: T.handle, + result: T.handle("void"), ) -> T.int32: T.func_attr( { @@ -475,27 +374,22 @@ def test_bool_parameter(): } ) assert num_args == 1, "main: num_args should be 1" - assert not T.isnullptr(args), "main: TVMValue* arg pointer was NULL" - assert not T.isnullptr(arg_type_ids), "main: int* type_codes was NULL" - arg_type_ids_1 = T.decl_buffer((1,), "int32", data=arg_type_ids) - arg_code: T.int32 = arg_type_ids_1[0] - assert arg_code == 15 or arg_code == 0, "main: Expect arg[0] to be boolean" - arg: T.bool = T.Cast("bool", T.tvm_struct_get(args, 0, 12, "int64")) + assert not T.isnullptr(args), "main: args pointer is NULL" + arg_type_index: T.int32 = T.tvm_struct_get(args, 0, 13, "int32") + assert arg_type_index == 2 or arg_type_index == 1, "main: Expect arg[0] to be boolean" + arg: T.bool = T.Cast("bool", T.tvm_struct_get(args, 0, 14, "int64")) with T.attr(0, "compute_scope", "main_compute_"): - out_ret_value_1 = T.Buffer((1,), "int64", data=out_ret_value, strides=(1,)) - out_ret_tcode_1 = T.Buffer((1,), "int32", data=out_ret_tcode, strides=(1,)) if arg: - out_ret_value_1[0] = T.Cast("int64", 10) - out_ret_tcode_1[0] = 0 + T.tvm_struct_set(result, 0, 13, 1) + T.tvm_struct_set(result, 0, 14, T.Cast("int64", 10)) return 0 else: - out_ret_value_1[0] = T.Cast("int64", 20) - out_ret_tcode_1[0] = 0 + T.tvm_struct_set(result, 0, 13, 1) + T.tvm_struct_set(result, 0, 14, T.Cast("int64", 20)) return 0 return 0 After = tvm.tir.transform.MakePackedAPI()(Before) - tvm.ir.assert_structural_equal(Expected, After)
