This is an automated email from the ASF dual-hosted git repository. junrushao pushed a commit to branch ir-builder-v2 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 964f195f8ab8e9b53d4fdca7b9b5a6c6bc44eea1 Author: Junru Shao <[email protected]> AuthorDate: Sat Sep 17 15:53:15 2022 -0700 Fix unittests --- .../unittest/test_aot_legalize_packed_call.py | 12 ++-- .../unittest/test_meta_schedule_space_cuda.py | 5 +- .../python/unittest/test_tir_lower_match_buffer.py | 43 +++++++------ .../test_tir_transform_inject_software_pipeline.py | 14 ++--- .../test_tir_transform_inject_virtual_thread.py | 17 +++-- .../python/unittest/test_tvmscript_error_report.py | 45 ++----------- tests/python/unittest/test_tvmscript_spans.py | 73 ---------------------- .../python/unittest/test_tvmscript_syntax_sugar.py | 53 +++++++++------- 8 files changed, 84 insertions(+), 178 deletions(-) diff --git a/tests/python/unittest/test_aot_legalize_packed_call.py b/tests/python/unittest/test_aot_legalize_packed_call.py index 9c597a55e5..cd0114d464 100644 --- a/tests/python/unittest/test_aot_legalize_packed_call.py +++ b/tests/python/unittest/test_aot_legalize_packed_call.py @@ -15,11 +15,11 @@ # specific language governing permissions and limitations # under the License. 
# pylint: disable=missing-function-docstring,missing-module-docstring +import pytest import tvm -from tvm.script import tir as T -from tvm import tir import tvm.testing -import pytest +from tvm import tir +from tvm.script import tir as T @tvm.script.ir_module @@ -85,7 +85,7 @@ class Expected: T.tvm_stack_make_shape(1, dtype="handle"), T.reinterpret(T.uint64(0), dtype="handle"), T.uint32(1), - T.cast(0, dtype="float32"), + T.Cast("float32", 0), 0, dtype="handle", ), @@ -94,7 +94,7 @@ class Expected: T.tvm_stack_make_shape(1, dtype="handle"), T.reinterpret(T.uint64(0), dtype="handle"), T.uint32(1), - T.cast(0, dtype="float32"), + T.Cast("float32", 0), 0, dtype="handle", ), @@ -103,7 +103,7 @@ class Expected: T.tvm_stack_make_shape(1, dtype="handle"), T.reinterpret(T.uint64(0), dtype="handle"), T.uint32(1), - T.cast(0, dtype="float32"), + T.Cast("float32", 0), 0, dtype="handle", ), diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py index ffa2b57ba8..666170e819 100644 --- a/tests/python/unittest/test_meta_schedule_space_cuda.py +++ b/tests/python/unittest/test_meta_schedule_space_cuda.py @@ -15,8 +15,9 @@ # specific language governing permissions and limitations # under the License. 
"""Tests for MetaSchedule search space on CUDA""" -from tvm import te, topi, autotvm +from tvm import autotvm from tvm import meta_schedule as ms +from tvm import te, topi from tvm.meta_schedule.testing.space_generation import check_sketches, print_sketches from tvm.meta_schedule.testing.te_workload import create_te_workload from tvm.script import tir as T @@ -910,7 +911,7 @@ def test_cuda_nrm(): for i0_1 in T.thread_binding(128, thread="threadIdx.x"): with T.block("D"): b = T.axis.spatial(1, i0_1) - T.where(0 * 128 + i0_1 < 1) + T.where(T.int32(0) * 128 + i0_1 < 1) T.reads(C_shared[b]) T.writes(D[b]) D[b] = T.sqrt(C_shared[b], dtype="float32") diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py b/tests/python/unittest/test_tir_lower_match_buffer.py index 93b7caf9cd..51fb28b7da 100644 --- a/tests/python/unittest/test_tir_lower_match_buffer.py +++ b/tests/python/unittest/test_tir_lower_match_buffer.py @@ -16,7 +16,6 @@ # under the License. import pytest - import tvm from tvm.script import tir as T @@ -62,9 +61,17 @@ def transformed_buffer_load_store(a: T.handle, c: T.handle) -> None: A[i * 4 + ii, j, k * 2 + kk] += C[i * 4 + ii, k * 2 + kk] -@tvm.ir.register_op_attr("tir.intrin_test", "") -def intrin_test(data, elem_offset, stride_0, stride_1, shape_0, shape_1): - return 0 +def intrin_test(data, elem_offset, stride_0, stride_1, shape_0, shape_1, dtype): + return tvm.tir.call_intrin( + dtype, + "tir.intrin_test", + data, + elem_offset, + stride_0, + stride_1, + shape_0, + shape_1, + ) @T.prim_func @@ -82,7 +89,7 @@ def opaque_access(a: T.handle, b: T.handle) -> None: offset_factor=1, ) T.evaluate( - T.intrin_test( + intrin_test( sub_A.data, sub_A.elem_offset, sub_A.strides[0], @@ -105,7 +112,7 @@ def opaque_access(a: T.handle, b: T.handle) -> None: offset_factor=1, ) T.evaluate( - T.intrin_test( + intrin_test( sub_B.data, sub_B.elem_offset, sub_B.strides[0], @@ -126,7 +133,7 @@ def transformed_opaque_access(a: T.handle, b: T.handle) -> None: T.reads([])
T.writes(A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16]) T.evaluate( - T.intrin_test( + intrin_test( A.data, i * 131072 + j * 128 + k * 16, 8192, @@ -141,7 +148,7 @@ def transformed_opaque_access(a: T.handle, b: T.handle) -> None: T.reads([]) T.writes(B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8]) T.evaluate( - T.intrin_test( + intrin_test( B.data, i * 4096 + j * 2048 + k * 8, 64, @@ -169,7 +176,7 @@ def high_dim_opaque_access(a: T.handle) -> None: offset_factor=1, ) T.evaluate( - T.intrin_test( + intrin_test( sub_A.data, sub_A.elem_offset, sub_A.strides[0], @@ -189,7 +196,7 @@ def transformed_high_dim_opaque_access(a: T.handle) -> None: T.reads([]) T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16]) T.evaluate( - T.intrin_test( + intrin_test( A.data, i * 2048 + j * 1024 + k * 16, 64, @@ -217,7 +224,7 @@ def high_dim_opaque_access_with_source_strides(a: T.handle) -> None: offset_factor=1, ) T.evaluate( - T.intrin_test( + intrin_test( sub_A.data, sub_A.elem_offset, sub_A.strides[0], @@ -237,7 +244,7 @@ def transformed_high_dim_opaque_access_with_source_strides(a: T.handle) -> None: T.reads([]) T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16]) T.evaluate( - T.intrin_test( + intrin_test( A.data, i * 2576 + j * 1280 + k * 16, 80, @@ -298,7 +305,7 @@ def recursive_match(a: T.handle, b: T.handle) -> None: offset_factor=1, ) T.evaluate( - T.intrin_test( + intrin_test( sub_sub_A.data, sub_sub_A.elem_offset, sub_sub_A.strides[0], @@ -343,7 +350,7 @@ def transformed_recursive_match(a: T.handle, b: T.handle) -> None: ] ) T.evaluate( - T.intrin_test( + intrin_test( A.data, i * 4096 + j * 1024 + jj * 256 + k * 16 + kk * 4, 64, @@ -375,7 +382,7 @@ def symbolic_match(a: T.handle, b: T.handle, n: T.int32, m: T.int32) -> None: sub_A[ii, jj] = 1 for j in range(0, 4): T.evaluate( - T.intrin_test( + intrin_test( sub_B.data, sub_B.elem_offset, sub_B.strides[0], @@ -399,7 +406,7 @@ def transformed_symbolic_match(a: T.handle, b: T.handle, n: T.int32, m: T.int32) A[i * m 
+ ii, jj] = 1 for j in range(0, 4): T.evaluate( - T.intrin_test( + intrin_test( B.data, i * n * (m * 4), m * 4, @@ -423,7 +430,7 @@ def rank0_buffer(a: T.handle, b: T.handle) -> None: sub_B = T.match_buffer(B[i, j], (), offset_factor=1) sub_A[()] = 1 T.evaluate( - T.intrin_test( + intrin_test( sub_B.data, sub_B.elem_offset, 0, @@ -445,7 +452,7 @@ def transformed_rank0_buffer(a: T.handle, b: T.handle) -> None: T.writes([A[i, j], B[i, j]]) A[i, j] = 1 T.evaluate( - T.intrin_test( + intrin_test( B.data, i * 8 + j, 0, diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index edaeb7c9b6..34f988c77c 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -14,16 +14,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-import pytest import sys -import numpy as np +import numpy as np +import pytest import tvm import tvm.testing import tvm.tir.tensor_intrin.cuda -from tvm import tir, te, TVMError -from tvm.script import tir as T +from tvm import TVMError, te, tir from tvm.meta_schedule.testing import te_workload +from tvm.script import tir as T from tvm.testing.tir import mma_schedule from tvm.tir.tensor_intrin.cuda import ( LDMATRIX_16x16_A_DYN_INTRIN, @@ -1060,7 +1060,7 @@ def test_simple_compute_async(): T.writes(B[0, tx, 0]) with T.attr(0, "async_commit_queue_scope", 0): with T.attr(0, "async_scope", 1): - B[0 % 2, tx, 0] = A[tx, 0] * T.float32(2) + B[T.int32(0) % 2, tx, 0] = A[tx, 0] * T.float32(2) with T.block(): T.reads(A[tx, 1:16], B[0:2, tx, 0]) T.writes(B[0:2, tx, 0], C[tx, 0:15]) @@ -1080,11 +1080,11 @@ def test_simple_compute_async(): with T.attr(0, "async_wait_inflight_count", 1): C[tx, i - 1 + 1] = B[(i - 1 + 1) % 2, tx, 0] + T.float32(1) with T.block(): - T.reads(B[15 % 2, tx, 0]) + T.reads(B[T.int32(15) % 2, tx, 0]) T.writes(C[tx, 15]) with T.attr(0, "async_wait_queue_scope", 0): with T.attr(0, "async_wait_inflight_count", 0): - C[tx, 15] = B[15 % 2, tx, 0] + T.float32(1) + C[tx, 15] = B[T.int32(15) % 2, tx, 0] + T.float32(1) tvm.ir.assert_structural_equal(mod["main"], ref, True) diff --git a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py index 548f3bc8d1..6000ea339b 100644 --- a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py +++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py @@ -16,7 +16,6 @@ # under the License. 
import tvm from tvm import te - from tvm.script import tir as T vthread_name = tvm.testing.parameter("vthread", "cthread") @@ -155,10 +154,10 @@ def test_vthread_simplified(): B = T.buffer_decl([16], "int32", data=B_data, scope="shared") # The indices for B should each be a single Ramp node, and # should not be the sum of a Ramp and Broadcast node. - B[0 * 4 : 0 * 4 + 4] = T.broadcast(0, 4) - B[1 * 4 : 1 * 4 + 4] = T.broadcast(1, 4) - B[2 * 4 : 2 * 4 + 4] = T.broadcast(2, 4) - B[3 * 4 : 3 * 4 + 4] = T.broadcast(3, 4) + B[T.int32(0) * 4 : T.int32(0) * 4 + 4] = T.broadcast(0, 4) + B[T.int32(1) * 4 : T.int32(1) * 4 + 4] = T.broadcast(1, 4) + B[T.int32(2) * 4 : T.int32(2) * 4 + 4] = T.broadcast(2, 4) + B[T.int32(3) * 4 : T.int32(3) * 4 + 4] = T.broadcast(3, 4) before_mod = tvm.IRModule.from_expr(before_func) after_mod = tvm.tir.transform.InjectVirtualThread()(before_mod) @@ -182,10 +181,10 @@ def test_vthread_vectorized(): def expected_func(): B_data = T.allocate([4], "int32x4", "shared") B = T.buffer_decl([4], "int32x4", data=B_data, scope="shared") - B[0 * 4 / 4] = T.broadcast(0, 4) - B[1 * 4 / 4] = T.broadcast(1, 4) - B[2 * 4 / 4] = T.broadcast(2, 4) - B[3 * 4 / 4] = T.broadcast(3, 4) + B[T.int32(0) * 4 / 4] = T.broadcast(0, 4) + B[T.int32(1) * 4 / 4] = T.broadcast(1, 4) + B[T.int32(2) * 4 / 4] = T.broadcast(2, 4) + B[T.int32(3) * 4 / 4] = T.broadcast(3, 4) before_mod = tvm.IRModule.from_expr(before_func) intermediate_mod = tvm.tir.transform.InjectVirtualThread()(before_mod) diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index acc68af065..7d27eac4e9 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -15,12 +15,13 @@ # specific language governing permissions and limitations # under the License. 
-import pytest import sys + +import pytest import tvm from tvm import tir -from tvm.testing import check_error from tvm.script import tir as T +from tvm.testing import check_error def buffer_bind_missing_args(a: T.handle) -> None: @@ -76,14 +77,6 @@ def test_missing_type_annotation(): check_error(missing_type_annotation, 1) -def invalid_expr_stmt() -> None: - T.max(1, 2) # error - - -def test_invalid_expr_stmt(): - check_error(invalid_expr_stmt, 2) - - def invalid_for_function(a: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") @@ -115,14 +108,6 @@ def test_return_not_allowed(): check_error(return_not_allowed, 2) -def tir_assert(a: T.handle) -> None: - T.Assert(0, "") # error - - -def test_tir_assert(): - check_error(tir_assert, 2) - - def no_body(a: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") T.realize(A, "") # error @@ -250,19 +235,6 @@ def test_invalid_match_buffer_region(): check_error(invalid_match_buffer_region, 5) -def duplicate_buffer() -> None: - A = T.alloc_buffer((128, 128), "float32") - for i, j in T.grid(128, 128): - with T.block(): - vi, vj = T.axis.remap("SS", [i, j]) - A = T.alloc_buffer((128, 128), "float32") # error - T.evaluate(1.0) - - -def test_duplicate_buffer(): - check_error(duplicate_buffer, 6) - - def duplicate_reads() -> None: A = T.alloc_buffer((128, 128), "float32") for i, j in T.grid(128, 128): @@ -334,7 +306,7 @@ def opaque_access_during_complete(a: T.handle) -> None: # error def test_opaque_access_during_complete(): - check_error(opaque_access_during_complete, 1) + check_error(opaque_access_during_complete, 0) def convert_slice_to_bufferload() -> None: @@ -608,15 +580,6 @@ def test_binop_bad_type(): check_error(binop_bad_type, 3) -def floor_dtype(h: T.handle): - h_ = T.match_buffer(h, [1]) - h_[0] = T.floor(2) # error floor requires a dtype - - -def test_floor_dtype(): - check_error(floor_dtype, 3) - - def non_integer_typed_block_iter(): with T.block(): i = T.axis.S(0.1, 0.1) # error IterVar requires an 
integer dtype diff --git a/tests/python/unittest/test_tvmscript_spans.py b/tests/python/unittest/test_tvmscript_spans.py deleted file mode 100644 index f863a4dd98..0000000000 --- a/tests/python/unittest/test_tvmscript_spans.py +++ /dev/null @@ -1,73 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- - -from tvm.script import tir as T - - -@T.prim_func -def loops() -> None: - for i in T.parallel(0, 2): - for j in T.serial(0, 1): - for z in T.vectorized(3, 4): - T.evaluate(0) - - -def test_loops(): - start_line = 23 - parsed = loops - - assert parsed.span.line == start_line - - assert parsed.body.span.line == start_line + 1 - assert parsed.body.min.span.column == 25 - assert parsed.body.extent.span.column == 28 - assert parsed.body.extent.span.line == start_line + 1 - - assert parsed.body.body.span.line == start_line + 2 - assert parsed.body.body.loop_var.span.line == start_line + 2 - assert parsed.body.body.loop_var.span.column == 13 - - assert parsed.body.body.body.span.line == start_line + 3 - assert parsed.body.body.body.span.column == 22 - - assert parsed.body.body.body.body.span.line == start_line + 4 - assert parsed.body.body.body.body.span.column == 17 - - -@T.prim_func -def statements() -> None: - T.evaluate(1) - T.evaluate("test") - - -def test_statements(): - start_line = 53 - parsed = statements - - assert parsed.body.span.line == start_line + 1 - - assert parsed.body[0].span.line == start_line + 1 - assert parsed.body[0].span.column == 5 - - assert parsed.body[0].span.line == start_line + 1 - assert parsed.body[0].span.column == 5 - - -if __name__ == "__main__": - test_loops() - test_statements() diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py index d955ec0a8c..d09a0d143a 100644 --- a/tests/python/unittest/test_tvmscript_syntax_sugar.py +++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py @@ -20,8 +20,8 @@ import sys import pytest import tvm.testing from tvm.ir import assert_structural_equal +from tvm.script import from_source from tvm.script import tir as T -from tvm.script.parser import from_source from tvm.testing import check_error @@ -164,15 +164,24 @@ def test_match_buffer_1d(): # match buffer failed case -def
test_match_buffer_no_kwargs_failed(): - with pytest.raises(ValueError) as e: - - @T.prim_func - def elementwise_buffer_no_kwargs_failed( - a: T.Buffer[(128, 128, 128, 128)], - b: T.Buffer[(128, 128, 128, 128)], - ) -> None: - pass +def test_match_buffer_without_dtype(): + @T.prim_func + def no_dtype( + a: T.Buffer[(128, 128, 128, 128)], + b: T.Buffer[(128, 128, 128, 128)], + ) -> None: + pass + + a0, a1, a2, a3 = no_dtype.buffer_map[no_dtype.params[0]].shape + b0, b1, b2, b3 = no_dtype.buffer_map[no_dtype.params[1]].shape + assert a0 == 128 + assert a1 == 128 + assert a2 == 128 + assert a3 == 128 + assert b0 == 128 + assert b1 == 128 + assert b2 == 128 + assert b3 == 128 # dynamic shape gemm @@ -274,8 +283,8 @@ def test_letstmt_bind_with_constant(): @T.prim_func def constant_binds_wrapped(): - x = T.int32(1) - y = T.float32(42.0) + x = T.inline(T.int32(1)) + y = T.inline(T.float32(42.0)) T.evaluate(T.cast(x, "float32") + y) assert_structural_equal(constant_binds, constant_binds_wrapped) @@ -288,9 +297,9 @@ def test_func_call(): @T.prim_func def mma_sync_m16n16k16_desc(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (32, 8), "float16", align=64, offset_factor=16, scope="warp") - B = T.match_buffer(b, (32, 8), "float16", align=64, offset_factor=16, scope="warp") - C = T.match_buffer(c, (32, 8), "float16", align=64, offset_factor=16, scope="warp") + A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp") + B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp") + C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp") with T.block("root"): T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8]) @@ -298,9 +307,9 @@ def test_func_call(): for i, j, k in T.grid(16, 16, 16): with T.block("C"): i, j, k = T.axis.remap("SSR", [i, j, k]) - thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j) - thread_id_A, local_id_A = shared_16x16_to_ldmatrix_32x8_layout(i, 
k) - thread_id_B, local_id_B = shared_16x16_to_ldmatrix_32x8_layout(k, j) + thread_id_C, local_id_C = T.inline(shared_16x16_to_ldmatrix_32x8_layout(i, j)) + thread_id_A, local_id_A = T.inline(shared_16x16_to_ldmatrix_32x8_layout(i, k)) + thread_id_B, local_id_B = T.inline(shared_16x16_to_ldmatrix_32x8_layout(k, j)) T.reads( C[thread_id_C, local_id_C], @@ -315,9 +324,9 @@ def test_func_call(): @T.prim_func def mma_sync_m16n16k16_desc_manual(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (32, 8), "float16", align=64, offset_factor=16, scope="warp") - B = T.match_buffer(b, (32, 8), "float16", align=64, offset_factor=16, scope="warp") - C = T.match_buffer(c, (32, 8), "float16", align=64, offset_factor=16, scope="warp") + A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp") + B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp") + C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp") with T.block("root"): T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8]) @@ -341,7 +350,7 @@ def test_func_call(): # The following is an example of an error message from calling an invalid function - # error: Error occurred when invoking the function sqrt: + # error: Error occured when invoking the function sqrt: # loop of ufunc does not support argument 0 of type Var which has no callable sqrt method # --> test_tvmscript_syntax_sugar.py:334:19 # |
