This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1228104726 [Minor][Testing] Consolidate IRs into corresponding functions (#13339)
1228104726 is described below

commit 1228104726b3cfb63c0c13da7a584ca6d7b5e584
Author: Junru Shao <[email protected]>
AuthorDate: Wed Nov 9 22:13:36 2022 -0800

    [Minor][Testing] Consolidate IRs into corresponding functions (#13339)
    
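    We moved most of the IR definitions into their corresponding test functions.
    The `check_error` helper is no longer exported from `tvm.testing`; it now
    lives in test_tvmscript_error_report.py.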
    
    Co-authored-by: Yaxing Cai <[email protected]>
---
 python/tvm/testing/__init__.py                     |   2 -
 python/tvm/testing/tir.py                          |  45 +-
 .../python/unittest/test_tvmscript_error_report.py | 710 ++++++++++-----------
 .../python/unittest/test_tvmscript_syntax_sugar.py |  13 +-
 4 files changed, 330 insertions(+), 440 deletions(-)

diff --git a/python/tvm/testing/__init__.py b/python/tvm/testing/__init__.py
index 9a18f16891..d84846725e 100644
--- a/python/tvm/testing/__init__.py
+++ b/python/tvm/testing/__init__.py
@@ -28,7 +28,5 @@ from .popen_pool import initializer, after_initializer, register_ffi, call_cpp_f
 from .popen_pool import call_py_ffi, call_cpp_py_ffi, fast_summation, 
slow_summation
 from .popen_pool import timeout_job
 
-from .tir import check_error
-
 from . import auto_scheduler
 from . import autotvm
diff --git a/python/tvm/testing/tir.py b/python/tvm/testing/tir.py
index 8dd4826738..57c1a85c5b 100644
--- a/python/tvm/testing/tir.py
+++ b/python/tvm/testing/tir.py
@@ -16,49 +16,6 @@
 # under the License.
 # pylint: disable=invalid-name, import-outside-toplevel, unused-variable
 """Common utility functions in TVM tir"""
-import inspect
-import re
-import tvm
-from tvm.ir.diagnostics import override_renderer
-
-
-CHECK_ERROR_RE = re.compile(r"^.*# check_error: (.+)$")
-
-
-def check_error(func, rel_lineno):
-    """check if TIR script throws error"""
-    # Override the default renderer to accumulate errors
-    errors = []
-
-    def render(e):
-        for d in e.diagnostics:
-            errors.append(d)
-
-    override_renderer(render)
-    # The diagnostic context throws an exception when it gets an error
-    try:
-        source_code = inspect.getsource(func)
-        source_code = "@T.prim_func\n" + source_code
-        from tvm.script import from_source
-
-        # to avoid cyclic import
-        from_source(source_code)
-    except tvm.error.DiagnosticError as e:
-        pass
-    assert len(errors) == 1, errors
-    for d in errors:
-        assert (
-            d.span.line - 1 == rel_lineno
-        ), f"Expected error to be on line {rel_lineno}, but it was on 
{d.span.line - 1}"
-
-    error_line = source_code.split("\n")[rel_lineno]
-    m = CHECK_ERROR_RE.match(error_line)
-    if m:
-        expected_error_text = m.group(1)
-        errors = [e.message for e in errors]
-        assert (
-            expected_error_text in errors
-        ), f'check_error expects "{expected_error_text} in str(errors): 
{errors}'
 
 
 def mma_schedule(
@@ -80,6 +37,8 @@ def mma_schedule(
     shared_scope="shared",
 ):
     """Create a tensorized schedule for GEMM with MMA intrinsics."""
+    import tvm  # pylint: disable=import-outside-toplevel
+
     ir_module = tvm.IRModule({"main": workload})
     sch = tvm.tir.Schedule(ir_module)
 
diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py
index acc68af065..36de35fa92 100644
--- a/tests/python/unittest/test_tvmscript_error_report.py
+++ b/tests/python/unittest/test_tvmscript_error_report.py
@@ -14,310 +14,304 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import inspect
+import re
 
 import pytest
-import sys
 import tvm
+import tvm.testing
 from tvm import tir
-from tvm.testing import check_error
+from tvm.ir.diagnostics import override_renderer
+from tvm.script import from_source
 from tvm.script import tir as T
 
 
-def buffer_bind_missing_args(a: T.handle) -> None:
-    A = T.match_buffer((16, 16), "float32")  # error
+def check_error(func, rel_lineno):
+    check_error_re = re.compile(r"^.*# check_error: (.+)$")
+    """check if TIR script throws error"""
+    # Override the default renderer to accumulate errors
+    errors = []
+
+    def render(e):
+        for d in e.diagnostics:
+            errors.append(d)
+
+    override_renderer(render)
+    # The diagnostic context throws an exception when it gets an error
+    try:
+        source_code = inspect.getsource(func)
+        indent = len(re.match(r"^\s*", source_code).group(0))
+        source_code = "@T.prim_func\n" + "\n".join(
+            line[indent:] for line in source_code.splitlines()
+        )
+        from_source(source_code)
+    except tvm.error.DiagnosticError as e:
+        pass
+    assert len(errors) == 1, errors
+    if rel_lineno is None:
+        return
+    error = errors[0]
+    assert (
+        error.span.line - 1 == rel_lineno
+    ), f"Expected error to be on line {rel_lineno}, but it was on 
{error.span.line - 1}"
+
+    error_line = source_code.split("\n")[rel_lineno]
+    m = check_error_re.match(error_line)
+    if m:
+        expected_error_text = m.group(1)
+        error = error.message
+        assert (
+            expected_error_text == error
+        ), f'check_error expects "{expected_error_text} in str(errors): 
{error}'
 
 
 def test_buffer_bind():
-    check_error(buffer_bind_missing_args, 2)
-
+    def buffer_bind_missing_args(a: T.handle) -> None:
+        A = T.match_buffer((16, 16), "float32")  # error
 
-def undefined_buffer(a: T.handle) -> None:
-    A = T.match_buffer(a, (16, 16), "float32")
-
-    T.attr(A, "realize_scope", "")
-    T.realize(C[0:16, 0:16], "")  # error
-    for i in T.serial(16):
-        for j in T.serial(0, 16):
-            A[i, j] = 0.0
+    check_error(buffer_bind_missing_args, 2)
 
 
 def test_undefined_buffer():
-    check_error(undefined_buffer, 5)
+    def undefined_buffer(a: T.handle) -> None:
+        A = T.match_buffer(a, (16, 16), "float32")
 
+        T.attr(A, "realize_scope", "")
+        T.realize(C[0:16, 0:16], "")  # error
+        for i in T.serial(16):
+            for j in T.serial(0, 16):
+                A[i, j] = 0.0
 
-def unsupported_stmt(a: T.int32) -> None:
-    if a > 0:
-        print("I love tvm")  # error
+    check_error(undefined_buffer, 5)
 
 
 def test_unsupported_stmt():
-    check_error(unsupported_stmt, 3)
-
+    def unsupported_stmt(a: T.int32) -> None:
+        if a > 0:
+            print("I love tvm")  # error
 
-def unsupported_function_call(a: T.handle) -> None:
-    A = T.match_buffer(a, (16, 16), "float32")
-
-    T.attr(A, "realize_scope", "")
-    T.realize(A[0:16, 0:16], "")
-    for i in T.const_range(16):  # error
-        for j in T.serial(0, 16):
-            A[i, j] = 0.0
+    check_error(unsupported_stmt, 3)
 
 
 def test_unsupported_function_call():
-    check_error(unsupported_function_call, 6)
+    def unsupported_function_call(a: T.handle) -> None:
+        A = T.match_buffer(a, (16, 16), "float32")
 
+        T.attr(A, "realize_scope", "")
+        T.realize(A[0:16, 0:16], "")
+        for i in T.const_range(16):  # error
+            for j in T.serial(0, 16):
+                A[i, j] = 0.0
 
-def missing_type_annotation(a) -> None:  # error
-    T.evaluate(0.0)
+    check_error(unsupported_function_call, 6)
 
 
 def test_missing_type_annotation():
-    check_error(missing_type_annotation, 1)
-
-
-def invalid_expr_stmt() -> None:
-    T.max(1, 2)  # error
-
-
-def test_invalid_expr_stmt():
-    check_error(invalid_expr_stmt, 2)
-
-
-def invalid_for_function(a: T.handle) -> None:
-    A = T.match_buffer(a, (16, 16), "float32")
+    def missing_type_annotation(a) -> None:  # error
+        T.evaluate(0.0)
 
-    for i in T.evaluate(0.0):  # error
-        for j in T.serial(0, 16):
-            A[i, j] = 0.0
+    check_error(missing_type_annotation, 1)
 
 
 def test_invalid_for_function():
-    check_error(invalid_for_function, 4)
+    def invalid_for_function(a: T.handle) -> None:
+        A = T.match_buffer(a, (16, 16), "float32")
 
+        for i in T.evaluate(0.0):  # error
+            for j in T.serial(0, 16):
+                A[i, j] = 0.0
 
-def invalid_block_function(a: T.handle) -> None:
-    A = T.match_buffer(a, (16, 16), "float32")
-
-    with T.evaluate(0.0):  # error
-        T.evaluate(1.0)
+    check_error(invalid_for_function, 4)
 
 
 def test_invalid_block_function():
-    check_error(invalid_block_function, 4)
+    def invalid_block_function(a: T.handle) -> None:
+        A = T.match_buffer(a, (16, 16), "float32")
 
+        with T.evaluate(0.0):  # error
+            T.evaluate(1.0)
 
-def return_not_allowed(a: T.handle) -> None:
-    return T.evaluate(0)  # error
+    check_error(invalid_block_function, 4)
 
 
 def test_return_not_allowed():
-    check_error(return_not_allowed, 2)
+    def return_not_allowed(a: T.handle) -> None:
+        return T.evaluate(0)  # error
 
-
-def tir_assert(a: T.handle) -> None:
-    T.Assert(0, "")  # error
-
-
-def test_tir_assert():
-    check_error(tir_assert, 2)
-
-
-def no_body(a: T.handle) -> None:
-    A = T.match_buffer(a, (16, 16), "float32")
-    T.realize(A, "")  # error
+    check_error(return_not_allowed, 2)
 
 
 def test_no_body():
-    check_error(no_body, 3)
+    def no_body(a: T.handle) -> None:
+        A = T.match_buffer(a, (16, 16), "float32")
+        T.realize(A, "")  # error
 
-
-def allocate_with_buffers() -> None:
-    with T.allocate([1], "float32", "") as [A, B]:  # error
-        T.evaluate(1.0)
+    check_error(no_body, 3)
 
 
 def test_allocate_with_buffers():
-    check_error(allocate_with_buffers, 2)
-
+    def allocate_with_buffers() -> None:
+        with T.allocate([1], "float32", "") as [A, B]:  # error
+            T.evaluate(1.0)
 
-def inconsistent_binding_value() -> None:
-    for i, j in T.grid(16, 16):
-        vi, vj = T.axis.remap("SS", [i])  # error
-        T.evaluate(1.0)
+    check_error(allocate_with_buffers, 2)
 
 
-def inconsistent_binding_type() -> None:
-    for i, j in T.grid(16, 16):
-        vi, vj = T.axis.remap("S", [i, j])  # error
-        T.evaluate(1.0)
+def test_inconsistent_binding():
+    def inconsistent_binding_value() -> None:
+        for i, j in T.grid(16, 16):
+            vi, vj = T.axis.remap("SS", [i])  # error
+            T.evaluate(1.0)
 
+    def inconsistent_binding_type() -> None:
+        for i, j in T.grid(16, 16):
+            vi, vj = T.axis.remap("S", [i, j])  # error
+            T.evaluate(1.0)
 
-def test_inconsistent_binding():
     check_error(inconsistent_binding_value, 3)
     check_error(inconsistent_binding_type, 3)
 
 
-def error_remap_type() -> None:
-    for i, j in T.grid(16, 16):
-        with T.block():
-            vi, vj = T.axis.remap("TT", [i, j])  # error
-            T.evaluate(1.0)
-
-
-def error_remap_value() -> None:
-    for i, j in T.grid(16, 16):
-        with T.block():
-            vi, vj = T.axis.remap("SS", [i + j, j])  # error
-            T.evaluate(1.0)
+def test_error_remap_args():
+    def error_remap_type() -> None:
+        for i, j in T.grid(16, 16):
+            with T.block():
+                vi, vj = T.axis.remap("TT", [i, j])  # error
+                T.evaluate(1.0)
 
+    def error_remap_value() -> None:
+        for i, j in T.grid(16, 16):
+            with T.block():
+                vi, vj = T.axis.remap("SS", [i + j, j])  # error
+                T.evaluate(1.0)
 
-def test_error_remap_args():
     check_error(error_remap_type, 4)
     check_error(error_remap_value, 4)
 
 
-def invalid_block_axes(a: T.handle) -> None:
-    A = T.match_buffer(a, (16, 16), "float32")
-    for i, j in T.grid(16, 16):
-        with T.block():
-            vi = T.axis.S(i, A)  # error
-            T.evaluate(1.0)
-
-
 def test_invalid_block_axes():
-    check_error(invalid_block_axes, 5)
-
+    def invalid_block_axes(a: T.handle) -> None:
+        A = T.match_buffer(a, (16, 16), "float32")
+        for i, j in T.grid(16, 16):
+            with T.block():
+                vi = T.axis.S(i, A)  # error
+                T.evaluate(1.0)
 
-def duplicate_block_axes() -> None:
-    for i, j in T.grid(16, 16):
-        with T.block():
-            vi = T.axis.S(16, i)
-            vi = T.axis.S(16, j)  # error
-            T.evaluate(1.0)
+    check_error(invalid_block_axes, 5)
 
 
-def duplicate_block_axes_remap() -> None:
-    for i, j in T.grid(16, 16):
-        with T.block():
-            vi, vi = T.axis.remap("SS", [i, j])  # error
-            T.evaluate(1.0)
+def test_duplicate_block_axes():
+    def duplicate_block_axes() -> None:
+        for i, j in T.grid(16, 16):
+            with T.block():
+                vi = T.axis.S(16, i)
+                vi = T.axis.S(16, j)  # error
+                T.evaluate(1.0)
 
+    def duplicate_block_axes_remap() -> None:
+        for i, j in T.grid(16, 16):
+            with T.block():
+                vi, vi = T.axis.remap("SS", [i, j])  # error
+                T.evaluate(1.0)
 
-def test_duplicate_block_axes():
     check_error(duplicate_block_axes, 5)
     check_error(duplicate_block_axes_remap, 4)
 
 
-def miss_block_bind_value() -> None:
-    for i, j in T.grid(128, 128):
-        with T.block():
-            vi = T.axis.S(i)  # error
-            T.evaluate(1.0)
-
-
 def test_miss_block_bind():
-    check_error(miss_block_bind_value, 4)
-
+    def miss_block_bind_value() -> None:
+        for i, j in T.grid(128, 128):
+            with T.block():
+                vi = T.axis.S(i)  # error
+                T.evaluate(1.0)
 
-def invalid_loop_var() -> None:
-    for i, j in range(0, 16):  # error
-        T.evaluate(1.0)
+    check_error(miss_block_bind_value, 4)
 
 
 def test_invalid_loop_var():
-    check_error(invalid_loop_var, 2)
-
+    def invalid_loop_var() -> None:
+        for i, j in range(0, 16):  # error
+            T.evaluate(1.0)
 
-def inconsistent_grid() -> None:
-    for i in T.grid(16, 16):  # error
-        T.evaluate(1.0)
+    check_error(invalid_loop_var, 2)
 
 
 def test_inconsistent_grid():
-    check_error(inconsistent_grid, 2)
-
-
-def invalid_match_buffer_region() -> None:
-    for i, j in T.grid(128, 128):
-        with T.block():
-            vi, vj = T.axis.remap("SS", [i, j])
-            A = T.match_buffer(vi)  # error
+    def inconsistent_grid() -> None:
+        for i in T.grid(16, 16):  # error
             T.evaluate(1.0)
 
+    check_error(inconsistent_grid, 2)
 
-def test_invalid_match_buffer_region():
-    check_error(invalid_match_buffer_region, 5)
 
+def test_invalid_match_buffer_region():
+    def invalid_match_buffer_region() -> None:
+        for i, j in T.grid(128, 128):
+            with T.block():
+                vi, vj = T.axis.remap("SS", [i, j])
+                A = T.match_buffer(vi)  # error
+                T.evaluate(1.0)
 
-def duplicate_buffer() -> None:
-    A = T.alloc_buffer((128, 128), "float32")
-    for i, j in T.grid(128, 128):
-        with T.block():
-            vi, vj = T.axis.remap("SS", [i, j])
-            A = T.alloc_buffer((128, 128), "float32")  # error
-            T.evaluate(1.0)
+    check_error(invalid_match_buffer_region, 5)
 
 
 def test_duplicate_buffer():
-    check_error(duplicate_buffer, 6)
+    def duplicate_buffer() -> None:
+        A = T.alloc_buffer((128, 128), "float32")
+        A = T.alloc_buffer((128, 128), "float32")  # error
 
-
-def duplicate_reads() -> None:
-    A = T.alloc_buffer((128, 128), "float32")
-    for i, j in T.grid(128, 128):
-        with T.block():
-            vi, vj = T.axis.remap("SS", [i, j])
-            T.reads(A[0:8, 0:8])
-            T.reads(A[0:16, 0:16])  # error
-            T.evaluate(1.0)
-
-
-def duplicate_writes() -> None:
-    A = T.alloc_buffer((128, 128), "float32")
-    for i, j in T.grid(128, 128):
-        with T.block():
-            vi, vj = T.axis.remap("SS", [i, j])
-            T.writes(A[0:8, 0:8])
-            T.writes(A[0:16, 0:16])  # error
-            T.evaluate(1.0)
+    check_error(duplicate_buffer, 3)
 
 
-def duplicate_predicate() -> None:
-    for i, j in T.grid(16, 16):
-        with T.block():
-            vi, vj = T.axis.remap("SS", [i, j])
-            T.where(1)
-            T.where(0)  # error
-
-
-def duplicate_annotations() -> None:
-    for i, j in T.grid(16, 16):
-        with T.block():
-            vi, vj = T.axis.remap("SS", [i, j])
-            T.block_attr({})
-            T.block_attr({})  # error
-
-
-def duplicate_init() -> None:
-    for i, j in T.grid(16, 16):
-        with T.block():
-            vi, vj = T.axis.remap("SS", [i, j])
-            with T.init():
+def test_duplicate_block_signature():
+    def duplicate_reads() -> None:
+        A = T.alloc_buffer((128, 128), "float32")
+        for i, j in T.grid(128, 128):
+            with T.block():
+                vi, vj = T.axis.remap("SS", [i, j])
+                T.reads(A[0:8, 0:8])
+                T.reads(A[0:16, 0:16])  # error
                 T.evaluate(1.0)
-            with T.init():  # error
+
+    def duplicate_writes() -> None:
+        A = T.alloc_buffer((128, 128), "float32")
+        for i, j in T.grid(128, 128):
+            with T.block():
+                vi, vj = T.axis.remap("SS", [i, j])
+                T.writes(A[0:8, 0:8])
+                T.writes(A[0:16, 0:16])  # error
                 T.evaluate(1.0)
 
+    def duplicate_predicate() -> None:
+        for i, j in T.grid(16, 16):
+            with T.block():
+                vi, vj = T.axis.remap("SS", [i, j])
+                T.where(1)
+                T.where(0)  # error
 
-def duplicate_axes() -> None:
-    for i, j in T.grid(16, 16):
-        with T.block():
-            vi, vj = T.axis.remap("SS", [i, j])
-            vi = T.axis.S(i, 16)  # error
-            T.evaluate(1.0)
+    def duplicate_annotations() -> None:
+        for i, j in T.grid(16, 16):
+            with T.block():
+                vi, vj = T.axis.remap("SS", [i, j])
+                T.block_attr({})
+                T.block_attr({})  # error
 
+    def duplicate_init() -> None:
+        for i, j in T.grid(16, 16):
+            with T.block():
+                vi, vj = T.axis.remap("SS", [i, j])
+                with T.init():
+                    T.evaluate(1.0)
+                with T.init():  # error
+                    T.evaluate(1.0)
+
+    def duplicate_axes() -> None:
+        for i, j in T.grid(16, 16):
+            with T.block():
+                vi, vj = T.axis.remap("SS", [i, j])
+                vi = T.axis.S(i, 16)  # error
+                T.evaluate(1.0)
 
-def test_duplicate_block_signature():
     check_error(duplicate_reads, 7)
     check_error(duplicate_writes, 7)
     check_error(duplicate_predicate, 6)
@@ -326,143 +320,105 @@ def test_duplicate_block_signature():
     check_error(duplicate_axes, 5)
 
 
-def opaque_access_during_complete(a: T.handle) -> None:  # error
-    A = T.match_buffer(a, (16, 16), "float32")
-    for i, j in T.grid(16, 16):
-        with T.block():
-            T.evaluate(T.call_extern("dummy_extern_function", A.data, 
dtype="int32"))
-
-
 def test_opaque_access_during_complete():
-    check_error(opaque_access_during_complete, 1)
-
+    def opaque_access_during_complete(a: T.handle) -> None:  # error
+        A = T.match_buffer(a, (16, 16), "float32")
+        for i, j in T.grid(16, 16):
+            with T.block():
+                T.evaluate(T.call_extern("dummy_extern_function", A.data, 
dtype="int32"))
 
-def convert_slice_to_bufferload() -> None:
-    A = T.alloc_buffer((128, 128), "float32")
-    for i, j in T.grid(128, 128):
-        with T.block():
-            vi, vj = T.axis.remap("SS", [i, j])
-            A[vi, vj] = A[vi : vi + 2, vj] + 1  # error
+    check_error(opaque_access_during_complete, None)
 
 
 def test_convert_slice_to_bufferload():
-    check_error(convert_slice_to_bufferload, 6)
-
-
-def error_index_type() -> None:
-    A = T.alloc_buffer((128, 128), "float32")
-    for i, j in T.grid(16, 16):
-        with T.block():
-            vi, vj = T.axis.remap("SS", [i, j])
-            A[vi, vj] = A[vi, 0.0] + 1  # error
-
-
-def error_bufferslice_index_type() -> None:
-    A = T.alloc_buffer((1,), "float32")
-    B = T.alloc_buffer((16, 16), "float32")
-    C = T.alloc_buffer((16, 16), "float32")
-    for i, j in T.grid(16, 16):
-        with T.block():
-            vi, vj = T.axis.remap("SS", [i, j])
-            C[vi, vj] = B[vi, A[0]]  # error
-
-
-def test_error_index_type():
-    check_error(error_index_type, 6)
-    check_error(error_bufferslice_index_type, 8)
-
-
-def special_stmt_except() -> None:
-    A = T.alloc_buffer("(128, 128)", "float32")  # error
-    T.evaluate(1.0)
+    def convert_slice_to_bufferload() -> None:
+        A = T.alloc_buffer((128, 128), "float32")
+        for i, j in T.grid(128, 128):
+            with T.block():
+                vi, vj = T.axis.remap("SS", [i, j])
+                A[vi, vj] = A[vi : vi + 2, vj] + 1  # error
 
-
-def scope_handler_except() -> None:
-    for i in T.serial("1", "1"):  # error
-        T.evaluate(1)
+    check_error(convert_slice_to_bufferload, 6)
 
 
-def intrin_except_unassign(a: T.handle) -> None:
-    A = T.match_buffer(a, (16, 16), "float32")
-    T.evaluate(A)  # error
+def test_tvm_exception_catch():
+    def special_stmt_except() -> None:
+        A = T.alloc_buffer("(128, 128)", "float32")  # error
+        T.evaluate(1.0)
 
+    def scope_handler_except() -> None:
+        for i in T.serial("1", "1"):  # error
+            T.evaluate(1)
 
-def intrin_except_assign(a: T.handle) -> None:
-    A = T.match_buffer(a, (16, 16), "float32")
-    A[0, 0] = A[A]  # error
+    def intrin_except_unassign(a: T.handle) -> None:
+        A = T.match_buffer(a, (16, 16), "float32")
+        T.evaluate(A)  # error
 
+    def intrin_except_assign(a: T.handle) -> None:
+        A = T.match_buffer(a, (16, 16), "float32")
+        A[0, 0] = A[A]  # error
 
-def test_tvm_exception_catch():
-    # test catching c++ side exception
     check_error(special_stmt_except, 2)
     check_error(scope_handler_except, 2)
     check_error(intrin_except_unassign, 3)
     check_error(intrin_except_assign, 3)
 
 
-def buffer_shape_mismatch(a: T.handle) -> None:
-    A = T.match_buffer(a, (8, 8))
-    for i, j in T.grid(8, 2):
-        with T.block():
-            T.reads([])
-            T.writes([A[i, j * 4 : j * 4 + 4]])
-            sub_A = T.match_buffer(
-                A[i, j * 4 : j * 4 + 4], (5)
-            )  # error: shape mismatched between 4 and 5
-            for jj in range(0, 4):
-                sub_A[i, j * 4 + jj] = 1
-
-
 def test_match_buffer_shape_mismatch():
-    check_error(buffer_shape_mismatch, 7)
-
+    def buffer_shape_mismatch(a: T.handle) -> None:
+        A = T.match_buffer(a, (8, 8))
+        for i, j in T.grid(8, 2):
+            with T.block():
+                T.reads([])
+                T.writes([A[i, j * 4 : j * 4 + 4]])
+                sub_A = T.match_buffer(
+                    A[i, j * 4 : j * 4 + 4], (5)
+                )  # error: shape mismatched between 4 and 5
+                for jj in range(0, 4):
+                    sub_A[i, j * 4 + jj] = 1
 
-def high_dim_store() -> None:
-    with T.block("root"):
-        B = T.allocate([256], "float32", "global")
-        for i, j in T.grid(16, 16):
-            B[i, j] = 1.0  # error: Store is only allowed with one index
+    check_error(buffer_shape_mismatch, 7)
 
 
 def test_high_dim_store():
-    check_error(high_dim_store, 5)
+    def high_dim_store() -> None:
+        with T.block("root"):
+            B = T.allocate([256], "float32", "global")
+            for i, j in T.grid(16, 16):
+                B[i, j] = 1.0  # error: Store is only allowed with one index
 
-
-def block_has_option_vars() -> None:
-    with T.block("root") as x:  # error: block does not support option_vars
-        T.evaluate(0.0)
+    check_error(high_dim_store, 5)
 
 
 def test_block_has_option_vars():
-    check_error(block_has_option_vars, 2)
-
-
-def implicit_root_has_read():
-    T.reads([])  # error: implicit root does not support reads
-    T.evaluate(0.0)
-
-
-def implicit_root_has_write():
-    T.writes([])  # error: implicit root does not support writes
-    T.evaluate(0.0)
+    def block_has_option_vars() -> None:
+        with T.block("root") as x:  # error: block does not support option_vars
+            T.evaluate(0.0)
 
+    check_error(block_has_option_vars, 2)
 
-def implicit_root_has_attrs():
-    T.block_attr({})  # error: implicit root does not support block_attr
-    T.evaluate(0.0)
 
+def test_implicit_root_has_attrs():
+    def implicit_root_has_read():
+        T.reads([])  # error: implicit root does not support reads
+        T.evaluate(0.0)
 
-def implicit_root_has_predicate():
-    T.where(True)  # error: implicit root does not support predicate
-    T.evaluate(0.0)
+    def implicit_root_has_write():
+        T.writes([])  # error: implicit root does not support writes
+        T.evaluate(0.0)
 
+    def implicit_root_has_attrs():
+        T.block_attr({})  # error: implicit root does not support block_attr
+        T.evaluate(0.0)
 
-def implicit_root_has_axes():
-    v = T.axis.S(0, 0)  # error: implicit root does not support axis define
-    T.evaluate(0.0)
+    def implicit_root_has_predicate():
+        T.where(True)  # error: implicit root does not support predicate
+        T.evaluate(0.0)
 
+    def implicit_root_has_axes():
+        v = T.axis.S(0, 0)  # error: implicit root does not support axis define
+        T.evaluate(0.0)
 
-def test_implicit_root_has_attrs():
     check_error(implicit_root_has_read, 2)
     check_error(implicit_root_has_write, 2)
     check_error(implicit_root_has_attrs, 2)
@@ -554,127 +510,115 @@ def test_report_error_root_block():
     assert expected_sub_error_message in str(execinfo.value)
 
 
-def load_var_multiple() -> None:
-    d = T.var("float32")
-    d[2] = d[2, 1]  # error cannot provide two indices to load
-
-
 def test_load_var():
-    check_error(load_var_multiple, 3)
-
+    def load_var_multiple() -> None:
+        d = T.var("float32")
+        d[2] = d[2, 1]  # error cannot provide two indices to load
 
-def store_var_multiple() -> None:
-    d = T.var("float32")
-    d[2, 1] = d[1]  # error cannot provide two indices to store
+    check_error(load_var_multiple, 3)
 
 
 def test_store_var():
-    check_error(store_var_multiple, 3)
-
+    def store_var_multiple() -> None:
+        d = T.var("float32")
+        d[2, 1] = d[1]  # error cannot provide two indices to store
 
-def load_handle(h: T.handle) -> None:
-    h_ = T.match_buffer(h, [1])
-    h_[0] = h[0]  # error cannot load from handle
+    check_error(store_var_multiple, 3)
 
 
 def test_load_handle():
-    check_error(load_var_multiple, 3)
+    def load_handle(h: T.handle) -> None:
+        h_ = T.match_buffer(h, [1])
+        h_[0] = h[0]  # error cannot load from handle
 
-
-def store_handle(h: T.handle) -> None:
-    h_ = T.match_buffer(h, [1])
-    h[0] = h_[0]  # error cannot store to handle
+    check_error(load_handle, 3)
 
 
 def test_store_handle():
-    check_error(store_var_multiple, 3)
-
+    def store_handle(h: T.handle) -> None:
+        h_ = T.match_buffer(h, [1])
+        h[0] = h_[0]  # error cannot store to handle
 
-def binop_bad_ast_type(h: T.handle):
-    h_ = T.match_buffer(h, [1])
-    h_[0] = h + [2]  # error rhs should be a primexpr
+    check_error(store_handle, 3)
 
 
 def test_binop_bad_ast_type():
-    check_error(binop_bad_ast_type, 3)
-
+    def binop_bad_ast_type(h: T.handle):
+        h_ = T.match_buffer(h, [1])
+        h_[0] = h + [2]  # error rhs should be a primexpr
 
-def binop_bad_type(h: T.handle):
-    h_ = T.match_buffer(h, [1])
-    h_[0] = h + 2  # error lhs and rhs should be the same type
+    check_error(binop_bad_ast_type, 3)
 
 
 def test_binop_bad_type():
-    check_error(binop_bad_type, 3)
-
-
-def floor_dtype(h: T.handle):
-    h_ = T.match_buffer(h, [1])
-    h_[0] = T.floor(2)  # error floor requires a dtype
-
+    def binop_bad_type(h: T.handle):
+        h_ = T.match_buffer(h, [1])
+        h_[0] = h + 2  # error lhs and rhs should be the same type
 
-def test_floor_dtype():
-    check_error(floor_dtype, 3)
-
-
-def non_integer_typed_block_iter():
-    with T.block():
-        i = T.axis.S(0.1, 0.1)  # error IterVar requires an integer dtype
+    check_error(binop_bad_type, 3)
 
 
 def test_non_integer_typed_block_iter():
-    check_error(non_integer_typed_block_iter, 3)
-
+    def non_integer_typed_block_iter():
+        with T.block():
+            i = T.axis.S(0.1, 0.1)  # error IterVar requires an integer dtype
 
-def preflattened_buffer_map_align_nonint(foo: T.handle):
-    foo_1 = T.match_buffer(foo, [1])
-    T.preflattened_buffer(
-        foo_1, [1], align="bar"
-    )  # check_error: align: want int or IntImm, got 'bar'
+    check_error(non_integer_typed_block_iter, 3)
 
 
 def test_preflattened_buffer_map_align():
-    check_error(preflattened_buffer_map_align_nonint, 3)
-
+    def preflattened_buffer_map_align_nonint(foo: T.handle):
+        foo_1 = T.match_buffer(foo, [1])
+        T.preflattened_buffer(
+            foo_1, [1], align="bar"
+        )  # check_error: align: want int or IntImm, got 'bar'
 
-def preflattened_buffer_map_offset_factor_nonint(foo: T.handle):
-    foo_1 = T.match_buffer(foo, [1])
-    T.preflattened_buffer(
-        foo_1, [1], offset_factor="bar"
-    )  # check_error: offset_factor: want int or IntImm, got 'bar'
+    check_error(preflattened_buffer_map_align_nonint, 3)
 
 
 def test_preflattened_buffer_map_offset_factor():
-    check_error(preflattened_buffer_map_offset_factor_nonint, 3)
-
-
-def strided_buffer_region(A: T.handle):
-    # do not allow stride in buffer region
-    A = T.match_buffer((128, 128), "int32")
-    with T.block():
-        T.reads([])
-        T.writes([A[0:128:2, 0:128:3]])  # error
-        T.evaluate(T.call_extern("strided_compute", dtype=""))
+    def preflattened_buffer_map_offset_factor_nonint(foo: T.handle):
+        foo_1 = T.match_buffer(foo, [1])
+        T.preflattened_buffer(
+            foo_1, [1], offset_factor="bar"
+        )  # check_error: offset_factor: want int or IntImm, got 'bar'
 
+    check_error(preflattened_buffer_map_offset_factor_nonint, 3)
 
-def access_reversed_slice(A: T.handle):
-    # do not allow reversed slice step
-    A = T.match_buffer((128,), "int32")
-    A[0:128:-1] = T.broadcast(1, 128)  # error
 
+def test_illegal_buffer_slice():
+    def strided_buffer_region(A: T.handle):
+        # do not allow stride in buffer region
+        A = T.match_buffer((128, 128), "int32")
+        with T.block():
+            T.reads([])
+            T.writes([A[0:128:2, 0:128:3]])  # error
+            T.evaluate(T.call_extern("strided_compute", dtype=""))
 
-def access_non_const_slice_length(A: T.handle):
-    # do not allow non-constant slice length
-    A = T.match_buffer((128,), "int32")
-    for i in range(4):
-        T.evaluate(A[0:i:1])  # error
+    def access_reversed_slice(A: T.handle):
+        # do not allow reversed slice step
+        A = T.match_buffer((128,), "int32")
+        A[0:128:-1] = T.broadcast(1, 128)  # error
 
+    def access_non_const_slice_length(A: T.handle):
+        # do not allow non-constant slice length
+        A = T.match_buffer((128,), "int32")
+        for i in range(4):
+            T.evaluate(A[0:i:1])  # error
 
-def test_illegal_buffer_slice():
     check_error(strided_buffer_region, 3)
     check_error(access_reversed_slice, 3)
     check_error(access_non_const_slice_length, 3)
 
 
+def test_syntax_sugar_fail():
+    def loop_syntax_sugar_fail(a: T.handle) -> None:
+        A = T.match_buffer(a, (128,))
+        for i in T.thread_binding(128, 128):
+            A[i] = A[i] * 2.0
+
+    check_error(loop_syntax_sugar_fail, 3)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py
index 849b0fc03d..32572d392c 100644
--- a/tests/python/unittest/test_tvmscript_syntax_sugar.py
+++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py
@@ -20,9 +20,8 @@ import sys
 import pytest
 import tvm.testing
 from tvm.ir import assert_structural_equal
+from tvm.script import from_source
 from tvm.script import tir as T
-from tvm.script.parser import from_source
-from tvm.testing import check_error
 
 
 @T.prim_func
@@ -89,20 +88,10 @@ def loop_syntax_sugar(a: T.handle) -> None:
                             A[i, j, k, x] = A[i, j, k, x] * 2.0
 
 
-def loop_syntax_sugar_fail(a: T.handle) -> None:
-    A = T.match_buffer(a, (128,))
-    for i in T.thread_binding(128, 128):
-        A[i] = A[i] * 2.0
-
-
 def test_loop_syntax_sugar():
     assert_structural_equal(loop_no_syntax_sugar, loop_syntax_sugar)
 
 
-def test_syntax_sugar_fail():
-    check_error(loop_syntax_sugar_fail, 3)
-
-
 # match buffer - use kwargs
 @T.prim_func
 def elementwise_handle(

Reply via email to