This is an automated email from the ASF dual-hosted git repository.

csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d9e01475af [UnitTest][TIR] Support IRModule comparisons in CompareBeforeAfter (#12920)
d9e01475af is described below

commit d9e01475af1253ca3fb52d7ad91165407ca8e740
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Oct 7 10:43:16 2022 -0500

    [UnitTest][TIR] Support IRModule comparisons in CompareBeforeAfter (#12920)
    
    A follow-up commit from https://github.com/apache/tvm/pull/12264.
    This allows the before/expected fixtures generated by
    `tvm.testing.CompareBeforeAfter` to be `IRModule` instances as well as
    `PrimFunc`.  This is intended to allow testing that requires comparing
    more than one function (e.g. hoisting/fusing a PrimFunc).
    
    * Prevent circular fixture references
---
 python/tvm/testing/utils.py                        | 105 ++++++++++++++-------
 .../unittest/test_tvm_testing_before_after.py      |  49 +++++++++-
 2 files changed, 117 insertions(+), 37 deletions(-)

diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index 1c4dcba29d..f89d5e6369 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -1914,10 +1914,7 @@ class CompareBeforeAfter:
             cls.transform = cls._normalize_transform(cls.transform)
 
     @classmethod
-    def _normalize_before(cls, func):
-        if hasattr(func, "_pytestfixturefunction"):
-            return func
-
+    def _normalize_ir_module(cls, func):
         if isinstance(func, tvm.tir.PrimFunc):
 
             def inner(self):
@@ -1930,6 +1927,22 @@ class CompareBeforeAfter:
                 # pylint: disable=unused-argument
                 return func(self)
 
+        elif inspect.isclass(func):
+
+            def inner(self):
+                # pylint: disable=unused-argument
+                func_dict = {}
+                for name, method in func.__dict__.items():
+                    if name.startswith("_"):
+                        pass
+                    elif isinstance(method, tvm.ir.function.BaseFunc):
+                        func_dict[name] = method
+                    else:
+                        source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(method))
+                        prim_func = tvm.script.from_source(source_code)
+                        func_dict[name] = prim_func
+                return tvm.IRModule(func_dict)
+
         else:
 
             def inner(self):
@@ -1939,50 +1952,64 @@ class CompareBeforeAfter:
 
         return pytest.fixture(inner)
 
+    @classmethod
+    def _normalize_before(cls, func):
+        if hasattr(func, "_pytestfixturefunction"):
+            return func
+        else:
+            return cls._normalize_ir_module(func)
+
     @classmethod
     def _normalize_expected(cls, func):
         if hasattr(func, "_pytestfixturefunction"):
             return func
 
-        if isinstance(func, tvm.tir.PrimFunc) or (
-            inspect.isclass(func) and issubclass(func, Exception)
-        ):
+        elif inspect.isclass(func) and issubclass(func, Exception):
 
             def inner(self):
                 # pylint: disable=unused-argument
                 return func
 
-        elif cls._is_method(func):
-
-            def inner(self):
-                # pylint: disable=unused-argument
-                return func(self)
+            return pytest.fixture(inner)
 
         else:
-
-            def inner(self):
-                # pylint: disable=unused-argument
-                source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(func))
-                return tvm.script.from_source(source_code)
-
-        return pytest.fixture(inner)
+            return cls._normalize_ir_module(func)
 
     @classmethod
     def _normalize_transform(cls, transform):
+        def apply(module_transform):
+            def inner(obj):
+                if isinstance(obj, tvm.IRModule):
+                    return module_transform(obj)
+                elif isinstance(obj, tvm.tir.PrimFunc):
+                    mod = tvm.IRModule({"main": obj})
+                    mod = module_transform(mod)
+                    return mod["main"]
+                else:
+                    raise TypeError(f"Expected IRModule or PrimFunc, but received {type(obj)}")
+
+            return inner
+
         if hasattr(transform, "_pytestfixturefunction"):
-            return transform
 
-        if isinstance(transform, tvm.ir.transform.Pass):
+            if not hasattr(cls, "_transform_orig"):
+                cls._transform_orig = transform
+
+            def inner(self, _transform_orig):
+                # pylint: disable=unused-argument
+                return apply(_transform_orig)
+
+        elif isinstance(transform, tvm.ir.transform.Pass):
 
             def inner(self):
                 # pylint: disable=unused-argument
-                return transform
+                return apply(transform)
 
         elif cls._is_method(transform):
 
             def inner(self):
                 # pylint: disable=unused-argument
-                return transform(self)
+                return apply(transform(self))
 
         else:
 
@@ -2000,36 +2027,42 @@ class CompareBeforeAfter:
     def test_compare(self, before, expected, transform):
         """Unit test to compare the expected TIR PrimFunc to actual"""
 
-        before_mod = tvm.IRModule.from_expr(before)
+        def pprint(name, obj):
+            script = obj.script()
+            if isinstance(obj, tvm.IRModule):
+                return script.replace("class Module", f"class {name}")
+            else:
+                return script.replace("def func", f"def {name}")
 
         if inspect.isclass(expected) and issubclass(expected, Exception):
             with pytest.raises(expected):
-                after_mod = transform(before_mod)
+                after = transform(before)
 
                 # This portion through pytest.fail isn't strictly
                 # necessary, but gives a better error message that
                 # includes the before/after.
-                after = after_mod["main"]
-                script = tvm.IRModule({"after": after, "before": before}).script()
+                before_str = pprint("before", before)
+                after_str = pprint("after", after)
+
                 pytest.fail(
                     msg=(
                         f"Expected {expected.__name__} to be raised from transformation, "
-                        f"instead received TIR\n:{script}"
+                        f"instead received TIR\n:{before_str}\n{after_str}"
                     )
                 )
 
-        elif isinstance(expected, tvm.tir.PrimFunc):
-            after_mod = transform(before_mod)
-            after = after_mod["main"]
+        elif isinstance(expected, (tvm.tir.PrimFunc, tvm.ir.IRModule)):
+            after = transform(before)
 
             try:
                 tvm.ir.assert_structural_equal(after, expected)
             except ValueError as err:
-                script = tvm.IRModule(
-                    {"expected": expected, "after": after, "before": before}
-                ).script()
+                before_str = pprint("before", before)
+                after_str = pprint("after", after)
+                expected_str = pprint("expected", expected)
                 raise ValueError(
-                    f"TIR after transformation did not match expected:\n{script}"
+                    f"TIR after transformation did not match expected:\n"
+                    f"{before_str}\n{after_str}\n{expected_str}"
                 ) from err
 
         else:
@@ -2037,5 +2070,5 @@ class CompareBeforeAfter:
                 f"tvm.testing.CompareBeforeAfter requires the `expected` fixture "
                 f"to return either `Exception`, an `Exception` subclass, "
                 f"or an instance of `tvm.tir.PrimFunc`.  "
-                f"Instead, received {type(exception)}."
+                f"Instead, received {type(expected)}."
             )
diff --git a/tests/python/unittest/test_tvm_testing_before_after.py b/tests/python/unittest/test_tvm_testing_before_after.py
index 613d66ccdb..946493922e 100644
--- a/tests/python/unittest/test_tvm_testing_before_after.py
+++ b/tests/python/unittest/test_tvm_testing_before_after.py
@@ -18,7 +18,7 @@
 
 import tvm
 import tvm.testing
-from tvm.script import tir as T
+from tvm.script import tir as T, ir_module
 
 
 class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
@@ -79,5 +79,52 @@ class TestBeforeAfterParametrizedFixture(BaseBeforeAfter):
     expected = before
 
 
+class TestBeforeAfterIRModule(BaseBeforeAfter):
+    """The preferred form for writing TIR unit tests
+
+    All evaluation is done at test-time, with the minimal amount of
+    additional lines.  The `@tvm.testing.fixture`, `@ir_module`, and
+    `@T.prim_func` annotations are handled by
+    `tvm.testing.CompareBeforeAfter`.
+    """
+
+    class before:
+        def func_A(A: T.Buffer[16, "float32"]):
+            for i in T.serial(16):
+                A[i] = 0.0
+
+        def func_B(A: T.Buffer[16, "int32"]):
+            for i in T.serial(16):
+                A[i] = 42
+
+    expected = before
+
+
+class TestBeforeAfterIRModuleExplicitFixture(BaseBeforeAfter):
+    """Like TestBeforeAfterIRModule, but with an explicit fixture
+
+    If the IRModule depends on additional fixtures, this form can be
+    used.
+    """
+
+    @tvm.testing.fixture
+    def before(self):
+        @ir_module
+        class mod:
+            @T.prim_func
+            def func_A(A: T.Buffer[16, "float32"]):
+                for i in T.serial(16):
+                    A[i] = 0.0
+
+            @T.prim_func
+            def func_B(A: T.Buffer[16, "int32"]):
+                for i in T.serial(16):
+                    A[i] = 42
+
+        return mod
+
+    expected = before
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to