This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 35a35b8434 [Tests][Refactor] Remove unused testing helpers (#19800)
35a35b8434 is described below
commit 35a35b8434c36cbaa8ff2b3854a4528ed83d2856
Author: Shushi Hong <[email protected]>
AuthorDate: Tue Jun 16 17:51:23 2026 -0400
[Tests][Refactor] Remove unused testing helpers (#19800)
CompareBeforeAfter, skip_parameterizations, and xfail_parameterizations
have no remaining users anywhere in the repo. CompareBeforeAfter (a base
class for TIR before/after transform tests) has been superseded by the
inline assert_structural_equal(transform(Before), Expected) pattern, and
the {skip,xfail}_parameterizations helpers (which marked specific
parametrizations at runtime) are unused -- native pytest.param(...,
marks=...) covers that need.
Also drop the private _mark_parameterizations helper they relied on and
the now-unused 'import textwrap'.
---
python/tvm/testing/utils.py | 280 --------------------------------------------
1 file changed, 280 deletions(-)
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index 9adeba689b..c686bc1184 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -74,7 +74,6 @@ import os
import pickle
import platform
import sys
-import textwrap
import time
from pathlib import Path
@@ -1170,40 +1169,6 @@ def install_request_hook(depth: int) -> None:
request_hook.init()
-def _mark_parameterizations(*params, marker_fn, reason):
- """
- Mark tests with a nodeid parameters that exactly matches one in params.
- Useful for quickly marking tests as xfail when they have a large
- combination of parameters.
- """
- params = set(params)
-
- def decorator(func):
- @functools.wraps(func)
- def wrapper(request, *args, **kwargs):
- if "[" in request.node.name and "]" in request.node.name:
- # Strip out the test name and the [ and ] brackets
- params_from_name = request.node.name[len(request.node.originalname) + 1 : -1]
- if params_from_name in params:
- marker_fn(
- reason=f"{marker_fn.__name__} on nodeid {request.node.nodeid}: " + reason
- )
-
- return func(request, *args, **kwargs)
-
- return wrapper
-
- return decorator
-
-
-def xfail_parameterizations(*xfail_params, reason):
- return _mark_parameterizations(*xfail_params, marker_fn=pytest.xfail, reason=reason)
-
-
-def skip_parameterizations(*skip_params, reason):
- return _mark_parameterizations(*skip_params, marker_fn=pytest.skip, reason=reason)
-
-
def strtobool(val):
"""Convert a string representation of truth to true (1) or false (0).
True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
@@ -1224,251 +1189,6 @@ def main():
sys.exit(pytest.main([test_file, *sys.argv[1:]]))
-class CompareBeforeAfter:
- """Utility for comparing before/after of TIR transforms
-
- A standard framework for writing tests that take a TIR PrimFunc as
- input, apply a transformation, then either compare against an
- expected output or assert that the transformation raised an error.
- A test should subclass CompareBeforeAfter, defining class members
- `before` / `Before`, `transform`, and `expected` / `Expected`. CompareBeforeAfter will
- then use these members to define a test method and test fixture.
-
- `transform` may be one of the following.
-
- - An instance of `tvm.ir.transform.Pass`
-
- - A method that takes no arguments and returns a `tvm.ir.transform.Pass`
-
- - A pytest fixture that returns a `tvm.ir.transform.Pass`
-
- `before` / `Before` may be any one of the following.
-
- - An instance of `tvm.tirx.PrimFunc`. This is allowed, but is not
- the preferred method, as any errors in constructing the
- `PrimFunc` occur while collecting the test, preventing any other
- tests in the same file from being run.
-
- - An TVMScript function, without the ``@T.prim_func`` decoration.
- The ``@T.prim_func`` decoration will be applied when running the
- test, rather than at module import.
-
- - A method that takes no arguments and returns a `tvm.tirx.PrimFunc`
-
- - A pytest fixture that returns a `tvm.tirx.PrimFunc`
-
- `expected` / `Expected` may be any one of the following. The type of
- `expected` / `Expected` defines the test being performed. If `expected`
- provides a `tvm.tirx.PrimFunc`, the result of the transformation
- must match `expected`. If `expected` is an exception, then the
- transformation must raise that exception type.
-
- - Any option supported for `before` / `Before`.
-
- - The `Exception` class object, or a class object that inherits
- from `Exception`.
-
- - A method that takes no arguments and returns `Exception` or a
- class object that inherits from `Exception`.
-
- - A pytest fixture that returns `Exception` or an class object
- that inherits from `Exception`.
-
- Examples
- --------
-
- .. code-block:: python
-
- class TestRemoveIf(tvm.testing.CompareBeforeAfter):
- transform = tvm.tirx.transform.StmtSimplify()
-
- def before(A: T.Buffer(1, "int32")):
- if True:
- A[0] = 42
- else:
- A[0] = 5
-
- def expected(A: T.Buffer(1, "int32")):
- A[0] = 42
-
- """
-
- check_well_formed: bool = True
-
- def __init_subclass__(cls):
- assert len([getattr(cls, name) for name in ["before", "Before"] if hasattr(cls, name)]) <= 1
- assert (
- len([getattr(cls, name) for name in ["expected", "Expected"] if hasattr(cls, name)])
- <= 1
- )
- for name in ["before", "Before"]:
- if hasattr(cls, name):
- cls.before = cls._normalize_before(getattr(cls, name))
- break
- for name in ["expected", "Expected"]:
- if hasattr(cls, name):
- cls.expected = cls._normalize_expected(getattr(cls, name))
- break
- if hasattr(cls, "transform"):
- cls.transform = cls._normalize_transform(cls.transform)
-
- @classmethod
- def _normalize_ir_module(cls, func):
- if isinstance(func, tvm.tirx.PrimFunc | tvm.IRModule):
-
- def inner(self):
- # pylint: disable=unused-argument
- return func
-
- elif cls._is_method(func):
-
- def inner(self):
- # pylint: disable=unused-argument
- return func(self)
-
- elif inspect.isclass(func):
-
- def inner(self):
- # pylint: disable=unused-argument
- func_dict = {}
- for name, method in func.__dict__.items():
- if name.startswith("_"):
- pass
- elif isinstance(method, tvm.ir.function.BaseFunc):
- func_dict[name] = method.with_attr("global_symbol", name)
- else:
- source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(method))
- prim_func = tvm.script.from_source(
- source_code, check_well_formed=self.check_well_formed
- )
- func_dict[name] = prim_func.with_attr("global_symbol", name)
- return tvm.IRModule(func_dict)
-
- else:
-
- def inner(self):
- # pylint: disable=unused-argument
- source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(func))
- return tvm.script.from_source(source_code, check_well_formed=self.check_well_formed)
-
- return pytest.fixture(inner)
-
- @classmethod
- def _normalize_before(cls, func):
- if hasattr(func, "_pytestfixturefunction"):
- return func
- else:
- return cls._normalize_ir_module(func)
-
- @classmethod
- def _normalize_expected(cls, func):
- if hasattr(func, "_pytestfixturefunction"):
- return func
-
- elif inspect.isclass(func) and issubclass(func, Exception):
-
- def inner(self):
- # pylint: disable=unused-argument
- return func
-
- return pytest.fixture(inner)
-
- else:
- return cls._normalize_ir_module(func)
-
- @classmethod
- def _normalize_transform(cls, transform):
- def apply(module_transform):
- def inner(obj):
- if isinstance(obj, tvm.IRModule):
- return module_transform(obj)
- elif isinstance(obj, tvm.tirx.PrimFunc):
- mod = tvm.IRModule({"main": obj})
- mod = module_transform(mod)
- return mod["main"]
- else:
- raise TypeError(f"Expected IRModule or PrimFunc, but received {type(obj)}")
-
- return inner
-
- if hasattr(transform, "_pytestfixturefunction"):
- if not hasattr(cls, "_transform_orig"):
- cls._transform_orig = transform
-
- def inner(self, _transform_orig):
- # pylint: disable=unused-argument
- return apply(_transform_orig)
-
- elif isinstance(transform, tvm.ir.transform.Pass):
-
- def inner(self):
- # pylint: disable=unused-argument
- return apply(transform)
-
- elif cls._is_method(transform):
-
- def inner(self):
- # pylint: disable=unused-argument
- return apply(transform(self))
-
- else:
- raise TypeError(
- "Expected transform to be a tvm.ir.transform.Pass, or a method returning a Pass"
- )
-
- return pytest.fixture(inner)
-
- @staticmethod
- def _is_method(func):
- return callable(func) and "self" in inspect.signature(func).parameters
-
- def test_compare(self, before, expected, transform):
- """Unit test to compare the expected TIR PrimFunc to actual"""
-
- if inspect.isclass(expected) and issubclass(expected, Exception):
- with pytest.raises(expected):
- after = transform(before)
-
- # This portion through pytest.fail isn't strictly
- # necessary, but gives a better error message that
- # includes the before/after.
- before_str = before.script(name="before")
- after_str = after.script(name="after")
-
- pytest.fail(
- msg=(
- f"Expected {expected.__name__} to be raised from transformation, "
- f"instead received TIR\n:{before_str}\n{after_str}"
- )
- )
-
- elif isinstance(expected, tvm.tirx.PrimFunc | tvm.ir.IRModule):
- after = transform(before)
-
- try:
- # overwrite global symbol so it doesn't come up in the comparison
- if isinstance(after, tvm.tirx.PrimFunc):
- after = after.with_attr("global_symbol", "main")
- expected = expected.with_attr("global_symbol", "main")
- tvm.ir.assert_structural_equal(after, expected)
- except ValueError as err:
- before_str = before.script(name="before")
- after_str = after.script(name="after")
- expected_str = expected.script(name="expected")
- raise ValueError(
- f"TIR after transformation did not match expected:\n"
- f"{before_str}\n{after_str}\n{expected_str}"
- ) from err
-
- else:
- raise TypeError(
- f"tvm.testing.CompareBeforeAfter requires the `expected` fixture "
- f"to return either `Exception`, an `Exception` subclass, "
- f"or an instance of `tvm.tirx.PrimFunc`. "
- f"Instead, received {type(expected)}."
- )
-
-
ml_dtypes_dict = {
"float8_e4m3fn": ml_dtypes.float8_e4m3fn,
"float8_e5m2": ml_dtypes.float8_e5m2,