This is an automated email from the ASF dual-hosted git repository.
csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d9e01475af [UnitTest][TIR] Support IRModule comparisons in CompareBeforeAfter (#12920)
d9e01475af is described below
commit d9e01475af1253ca3fb52d7ad91165407ca8e740
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Oct 7 10:43:16 2022 -0500
[UnitTest][TIR] Support IRModule comparisons in CompareBeforeAfter (#12920)
This is a follow-up to https://github.com/apache/tvm/pull/12264.
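This allows the `before`/`expected` fixtures normalized by
`tvm.testing.CompareBeforeAfter` to be either `PrimFunc` or `IRModule`
instances. This is intended to support tests that compare more than
one function (e.g. hoisting or fusing a PrimFunc). The `transform` is
applied to an `IRModule` directly. A `PrimFunc` is first wrapped in a
module under the name "main", and the transformed "main" is returned.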
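* Prevent circular fixture references: when `transform` is itself a
  pytest fixture, the original is stored as `_transform_orig`, and the
  wrapping `transform` fixture requests `_transform_orig` instead of itself.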
---
python/tvm/testing/utils.py | 105 ++++++++++++++-------
.../unittest/test_tvm_testing_before_after.py | 49 +++++++++-
2 files changed, 117 insertions(+), 37 deletions(-)
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index 1c4dcba29d..f89d5e6369 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -1914,10 +1914,7 @@ class CompareBeforeAfter:
cls.transform = cls._normalize_transform(cls.transform)
@classmethod
- def _normalize_before(cls, func):
- if hasattr(func, "_pytestfixturefunction"):
- return func
-
+ def _normalize_ir_module(cls, func):
if isinstance(func, tvm.tir.PrimFunc):
def inner(self):
@@ -1930,6 +1927,22 @@ class CompareBeforeAfter:
# pylint: disable=unused-argument
return func(self)
+ elif inspect.isclass(func):
+
+ def inner(self):
+ # pylint: disable=unused-argument
+ func_dict = {}
+ for name, method in func.__dict__.items():
+ if name.startswith("_"):
+ pass
+ elif isinstance(method, tvm.ir.function.BaseFunc):
+ func_dict[name] = method
+ else:
+ source_code = "@T.prim_func\n" +
textwrap.dedent(inspect.getsource(method))
+ prim_func = tvm.script.from_source(source_code)
+ func_dict[name] = prim_func
+ return tvm.IRModule(func_dict)
+
else:
def inner(self):
@@ -1939,50 +1952,64 @@ class CompareBeforeAfter:
return pytest.fixture(inner)
+ @classmethod
+ def _normalize_before(cls, func):
+ if hasattr(func, "_pytestfixturefunction"):
+ return func
+ else:
+ return cls._normalize_ir_module(func)
+
@classmethod
def _normalize_expected(cls, func):
if hasattr(func, "_pytestfixturefunction"):
return func
- if isinstance(func, tvm.tir.PrimFunc) or (
- inspect.isclass(func) and issubclass(func, Exception)
- ):
+ elif inspect.isclass(func) and issubclass(func, Exception):
def inner(self):
# pylint: disable=unused-argument
return func
- elif cls._is_method(func):
-
- def inner(self):
- # pylint: disable=unused-argument
- return func(self)
+ return pytest.fixture(inner)
else:
-
- def inner(self):
- # pylint: disable=unused-argument
- source_code = "@T.prim_func\n" +
textwrap.dedent(inspect.getsource(func))
- return tvm.script.from_source(source_code)
-
- return pytest.fixture(inner)
+ return cls._normalize_ir_module(func)
@classmethod
def _normalize_transform(cls, transform):
+ def apply(module_transform):
+ def inner(obj):
+ if isinstance(obj, tvm.IRModule):
+ return module_transform(obj)
+ elif isinstance(obj, tvm.tir.PrimFunc):
+ mod = tvm.IRModule({"main": obj})
+ mod = module_transform(mod)
+ return mod["main"]
+ else:
+ raise TypeError(f"Expected IRModule or PrimFunc, but
received {type(obj)}")
+
+ return inner
+
if hasattr(transform, "_pytestfixturefunction"):
- return transform
- if isinstance(transform, tvm.ir.transform.Pass):
+ if not hasattr(cls, "_transform_orig"):
+ cls._transform_orig = transform
+
+ def inner(self, _transform_orig):
+ # pylint: disable=unused-argument
+ return apply(_transform_orig)
+
+ elif isinstance(transform, tvm.ir.transform.Pass):
def inner(self):
# pylint: disable=unused-argument
- return transform
+ return apply(transform)
elif cls._is_method(transform):
def inner(self):
# pylint: disable=unused-argument
- return transform(self)
+ return apply(transform(self))
else:
@@ -2000,36 +2027,42 @@ class CompareBeforeAfter:
def test_compare(self, before, expected, transform):
"""Unit test to compare the expected TIR PrimFunc to actual"""
- before_mod = tvm.IRModule.from_expr(before)
+ def pprint(name, obj):
+ script = obj.script()
+ if isinstance(obj, tvm.IRModule):
+ return script.replace("class Module", f"class {name}")
+ else:
+ return script.replace("def func", f"def {name}")
if inspect.isclass(expected) and issubclass(expected, Exception):
with pytest.raises(expected):
- after_mod = transform(before_mod)
+ after = transform(before)
# This portion through pytest.fail isn't strictly
# necessary, but gives a better error message that
# includes the before/after.
- after = after_mod["main"]
- script = tvm.IRModule({"after": after, "before":
before}).script()
+ before_str = pprint("before", before)
+ after_str = pprint("after", after)
+
pytest.fail(
msg=(
f"Expected {expected.__name__} to be raised from
transformation, "
- f"instead received TIR\n:{script}"
+ f"instead received TIR\n:{before_str}\n{after_str}"
)
)
- elif isinstance(expected, tvm.tir.PrimFunc):
- after_mod = transform(before_mod)
- after = after_mod["main"]
+ elif isinstance(expected, (tvm.tir.PrimFunc, tvm.ir.IRModule)):
+ after = transform(before)
try:
tvm.ir.assert_structural_equal(after, expected)
except ValueError as err:
- script = tvm.IRModule(
- {"expected": expected, "after": after, "before": before}
- ).script()
+ before_str = pprint("before", before)
+ after_str = pprint("after", after)
+ expected_str = pprint("expected", expected)
raise ValueError(
- f"TIR after transformation did not match
expected:\n{script}"
+ f"TIR after transformation did not match expected:\n"
+ f"{before_str}\n{after_str}\n{expected_str}"
) from err
else:
@@ -2037,5 +2070,5 @@ class CompareBeforeAfter:
f"tvm.testing.CompareBeforeAfter requires the `expected`
fixture "
f"to return either `Exception`, an `Exception` subclass, "
f"or an instance of `tvm.tir.PrimFunc`. "
- f"Instead, received {type(exception)}."
+ f"Instead, received {type(expected)}."
)
diff --git a/tests/python/unittest/test_tvm_testing_before_after.py b/tests/python/unittest/test_tvm_testing_before_after.py
index 613d66ccdb..946493922e 100644
--- a/tests/python/unittest/test_tvm_testing_before_after.py
+++ b/tests/python/unittest/test_tvm_testing_before_after.py
@@ -18,7 +18,7 @@
import tvm
import tvm.testing
-from tvm.script import tir as T
+from tvm.script import tir as T, ir_module
class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
@@ -79,5 +79,52 @@ class TestBeforeAfterParametrizedFixture(BaseBeforeAfter):
expected = before
+class TestBeforeAfterIRModule(BaseBeforeAfter):
+ """The preferred form for writing TIR unit tests
+
+ All evaluation is done at test-time, with the minimal amount of
+ additional lines. The `@tvm.testing.fixture`, `@ir_module`, and
+ `@T.prim_func` annotations are handled by
+ `tvm.testing.CompareBeforeAfter`.
+ """
+
+ class before:
+ def func_A(A: T.Buffer[16, "float32"]):
+ for i in T.serial(16):
+ A[i] = 0.0
+
+ def func_B(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ A[i] = 42
+
+ expected = before
+
+
+class TestBeforeAfterIRModuleExplicitFixture(BaseBeforeAfter):
+ """Like TestBeforeAfterIRModule, but with an explicit fixture
+
+ If the IRModule depends on additional fixtures, this form can be
+ used.
+ """
+
+ @tvm.testing.fixture
+ def before(self):
+ @ir_module
+ class mod:
+ @T.prim_func
+ def func_A(A: T.Buffer[16, "float32"]):
+ for i in T.serial(16):
+ A[i] = 0.0
+
+ @T.prim_func
+ def func_B(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ A[i] = 42
+
+ return mod
+
+ expected = before
+
+
if __name__ == "__main__":
tvm.testing.main()