Lunderberg commented on a change in pull request #8010:
URL: https://github.com/apache/tvm/pull/8010#discussion_r658963278



##########
File path: python/tvm/testing.py
##########
@@ -718,31 +844,421 @@ def parametrize_targets(*args):
 
     Example
     -------
-    >>> @tvm.testing.parametrize
+    >>> @tvm.testing.parametrize_targets("llvm", "cuda")
+    >>> def test_mytest(target, dev):
+    >>>     ...  # do something
+    """
+
+    def wrap(targets):
+        def func(f):
+            return pytest.mark.parametrize(
+                "target", _pytest_target_params(targets), scope="session"
+            )(f)
+
+        return func
+
+    if len(args) == 1 and callable(args[0]):
+        return wrap(None)(args[0])
+    return wrap(args)
+
+
+def exclude_targets(*args):
+    """Exclude a test from running on a particular target.
+
+    Use this decorator when you want your test to be run over a
+    variety of targets and devices (including cpu and gpu devices),
+    but want to exclude some particular target or targets.  For
+    example, a test may wish to be run against all targets in
+    tvm.testing.enabled_targets(), except for a particular target that
+    does not support the capabilities being tested.
+
+    Applies pytest.mark.skipif to the targets given.
+
+    Parameters
+    ----------
+    f : function
+        Function to parametrize. Must be of the form `def test_xxxxxxxxx(target, dev)`,
+        where `xxxxxxxxx` is any name.
+    targets : list[str]
+        Set of targets to exclude.
+
+    Example
+    -------
+    >>> @tvm.testing.exclude_targets("cuda")
     >>> def test_mytest(target, dev):
     >>>     ...  # do something
 
     Or
 
-    >>> @tvm.testing.parametrize("llvm", "cuda")
+    >>> @tvm.testing.exclude_targets("llvm", "cuda")
     >>> def test_mytest(target, dev):
     >>>     ...  # do something
+
     """
 
-    def wrap(targets):
-        def func(f):
-            params = [
-                pytest.param(target, tvm.device(target, 0), marks=_target_to_requirement(target))
-                for target in targets
-            ]
-            return pytest.mark.parametrize("target,dev", params)(f)
+    def wraps(func):
+        func.tvm_excluded_targets = args
+        return func
+
+    return wraps
+
+
+def known_failing_targets(*args):
+    """Skip a test that is known to fail on a particular target.
+
+    Use this decorator when you want your test to be run over a
+    variety of targets and devices (including cpu and gpu devices),
+    but know that it fails for some targets.  For example, a newly
+    implemented runtime may not support all features being tested, and
+    should be excluded.
+
+    Applies pytest.mark.xfail to the targets given.
 
+    Parameters
+    ----------
+    f : function
+        Function to parametrize. Must be of the form `def test_xxxxxxxxx(target, dev)`,
+        where `xxxxxxxxx` is any name.
+    targets : list[str]
+        Set of targets to skip.
+
+    Example
+    -------
+    >>> @tvm.testing.known_failing_targets("cuda")
+    >>> def test_mytest(target, dev):
+    >>>     ...  # do something
+
+    Or
+
+    >>> @tvm.testing.known_failing_targets("llvm", "cuda")
+    >>> def test_mytest(target, dev):
+    >>>     ...  # do something
+
+    """
+
+    def wraps(func):
+        func.tvm_known_failing_targets = args
         return func
 
-    if len(args) == 1 and callable(args[0]):
-        targets = [t for t, _ in enabled_targets()]
-        return wrap(targets)(args[0])
-    return wrap(args)
+    return wraps
+
+
+def parameter(*values, ids=None):
+    """Convenience function to define pytest parametrized fixtures.
+
+    Declaring a variable using ``tvm.testing.parameter`` will define a
+    parametrized pytest fixture that can be used by test
+    functions. This is intended for cases that have no setup cost,
+    such as strings, integers, tuples, etc.  For cases that have a
+    significant setup cost, please use :py:func:`tvm.testing.fixture`
+    instead.
+
+    If a test function accepts multiple parameters defined using
+    ``tvm.testing.parameter``, then the test will be run using every
+    combination of those parameters.
+
+    The parameter definition applies to all tests in a module.  If a
+    specific test should have different values for the parameter, that
+    test should be marked with ``@pytest.mark.parametrize``.
+
+    Parameters
+    ----------
+    values
+       A list of parameter values.  A unit test that accepts this
+       parameter as an argument will be run once for each parameter
+       given.
+
+    ids : List[str], optional
+       A list of names for the parameters.  If None, pytest will
+       generate a name from the value.  These generated names may not
+       be readable/useful for composite types such as tuples.
+
+    Returns
+    -------
+    function
+       A function output from pytest.fixture.
+
+    Example
+    -------
+    >>> size = tvm.testing.parameter(1, 10, 100)
+    >>> def test_using_size(size):
+    >>>     ... # Test code here
+
+    Or
+
+    >>> shape = tvm.testing.parameter((5,10), (512,1024), ids=['small','large'])
+    >>> def test_using_shape(shape):
+    >>>     ... # Test code here
+
+    """
+
+    # Optional cls parameter in case a parameter is defined inside a
+    # class scope.
+    @pytest.fixture(params=values, ids=ids)
+    def as_fixture(*_cls, request):
+        return request.param
+
+    return as_fixture
+
+
+_parametrize_group = 0
+
+
+def parameters(*value_sets):
+    """Convenience function to define pytest parametrized fixtures.
+
+    Declaring a variable using tvm.testing.parameters will define a
+    parametrized pytest fixture that can be used by test
+    functions. Like :py:func:`tvm.testing.parameter`, this is intended
+    for cases that have no setup cost, such as strings, integers,
+    tuples, etc.  For cases that have a significant setup cost, please
+    use :py:func:`tvm.testing.fixture` instead.
+
+    Unlike :py:func:`tvm.testing.parameter`, if a test function
+    accepts multiple parameters defined using a single call to
+    ``tvm.testing.parameters``, then the test will only be run once
+    for each set of parameters, not for all combinations of
+    parameters.
+
+    These parameter definitions apply to all tests in a module.  If a
+    specific test should have different values for some parameters,
+    that test should be marked with ``@pytest.mark.parametrize``.
+
+    Parameters
+    ----------
+    values : List[tuple]
+       A list of parameter value sets.  Each set of values represents
+       a single combination of values to be tested.  A unit test that
+       accepts these parameters as arguments will be run once for every
+       set of parameters in the list.
+
+    Returns
+    -------
+    List[function]
+       Function outputs from pytest.fixture.  These should be unpacked
+       into individual named parameters.
+
+    Example
+    -------
+    >>> size, dtype = tvm.testing.parameters( (16,'float32'), (512,'float16') )
+    >>> def test_feature_x(size, dtype):
+    >>>     # Test code here
+    >>>     assert( (size,dtype) in [(16,'float32'), (512,'float16')])
+
+    """
+    global _parametrize_group
+    parametrize_group = _parametrize_group
+    _parametrize_group += 1
+
+    outputs = []
+    for param_values in zip(*value_sets):
+
+        # Optional cls parameter in case a parameter is defined inside a
+        # class scope.
+        def fixture_func(*_cls, request):
+            return request.param
+
+        fixture_func.parametrize_group = parametrize_group
+        fixture_func.parametrize_values = param_values
+        outputs.append(pytest.fixture(fixture_func))
+
+    return outputs
+
+
+def _parametrize_correlated_parameters(metafunc):
+    parametrize_needed = collections.defaultdict(list)
+
+    for name, fixturedefs in metafunc.definition._fixtureinfo.name2fixturedefs.items():
+        fixturedef = fixturedefs[-1]
+        if hasattr(fixturedef.func, "parametrize_group") and hasattr(
+            fixturedef.func, "parametrize_values"
+        ):
+            group = fixturedef.func.parametrize_group
+            values = fixturedef.func.parametrize_values
+            parametrize_needed[group].append((name, values))
+
+    for parametrize_group in parametrize_needed.values():
+        if len(parametrize_group) == 1:
+            name, values = parametrize_group[0]
+            metafunc.parametrize(name, values, indirect=True)
+        else:
+            names = ",".join(name for name, values in parametrize_group)
+            value_sets = zip(*[values for name, values in parametrize_group])
+            metafunc.parametrize(names, value_sets, indirect=True)
+
+
+def fixture(func=None, *, cache_return_value=False):
+    """Convenience function to define pytest fixtures.
+
+    This should be used as a decorator to mark functions that set up
+    state before a test.  The return value of that fixture
+    function is then accessible by test functions that accept it as
+    a parameter.
+
+    Fixture functions can accept parameters defined with
+    :py:func:`tvm.testing.parameter`.
+
+    By default, the setup will be performed once for each unit test
+    that uses a fixture, to ensure that unit tests are independent.
+    If the setup is expensive to perform, then the
+    cache_return_value=True argument can be passed to cache the setup.
+    The fixture function will be run only once (or once per parameter,
+    if used with tvm.testing.parameter), and the same return value
+    will be passed to all tests that use it.  If the environment
+    variable TVM_TEST_DISABLE_CACHE is set to a non-zero value, it
+    will disable this feature and no caching will be performed.
+
+    Example
+    -------
+    >>> @tvm.testing.fixture
+    >>> def cheap_setup():
+    >>>     return 5 # Setup code here.
+    >>>
+    >>> def test_feature_x(target, dev, cheap_setup):
+    >>>     assert(cheap_setup == 5) # Run test here
+
+    Or
+
+    >>> size = tvm.testing.parameter(1, 10, 100)
+    >>>
+    >>> @tvm.testing.fixture
+    >>> def cheap_setup(size):
+    >>>     return 5*size # Setup code here, based on size.
+    >>>
+    >>> def test_feature_x(cheap_setup):
+    >>>     assert(cheap_setup in [5, 50, 500])
+
+    Or
+
+    >>> @tvm.testing.fixture(cache_return_value=True)
+    >>> def expensive_setup():
+    >>>     time.sleep(10) # Setup code here
+    >>>     return 5
+    >>>
+    >>> def test_feature_x(target, dev, expensive_setup):
+    >>>     assert(expensive_setup == 5)
+
+    """
+
+    force_disable_cache = bool(int(os.environ.get("TVM_TEST_DISABLE_CACHE", "0")))
+    cache_return_value = cache_return_value and not force_disable_cache
+
+    # Deliberately at function scope, so that caching can track how
+    # many times the fixture has been used.  If caching is used, the
+    # cache gets cleared after the fixture is no longer needed.
+    scope = "function"
+
+    def wraps(func):
+        if cache_return_value:
+            func = _fixture_cache(func)
+        func = pytest.fixture(func, scope=scope)
+        return func
+
+    if func is None:
+        return wraps
+
+    return wraps(func)
+
+
+def _fixture_cache(func):
+    cache = {}
+
+    # Can't use += on a bound method's property.  Therefore, this is a
+    # list rather than a variable so that it can be accessed from
+    # pytest_collection_modifyitems().
+    num_uses_remaining = [0]
+
+    # Using functools.lru_cache would require the function arguments
+    # to be hashable, which wouldn't allow caching fixtures that
+    # depend on numpy arrays.  For example, a fixture that takes a
+    # numpy array as input, then uses a slow method to compute a
+    # known correct output for that input.  Therefore, this includes
+    # a fallback for serializable types.
+    def get_cache_key(*args, **kwargs):
+        try:
+            hash((args, kwargs))
+            return (args, kwargs)
+        except TypeError:
+            pass
+
+        try:
+            return pickle.dumps((args, kwargs))
+        except TypeError as e:
+            raise TypeError(
+                "TVM caching of fixtures requires arguments to the fixture "
+                "to be either hashable or serializable"
+            ) from e
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        try:
+            cache_key = get_cache_key(*args, **kwargs)
+
+            try:
+                cached_value = cache[cache_key]
+            except KeyError:
+                cached_value = cache[cache_key] = func(*args, **kwargs)

Review comment:
If the fixture definition `func` raises an exception, the exception is passed
on to pytest and treated as a failure to generate the fixture.  The test still
fails, but it is recorded as a failed setup, and the test body itself is never
run.  This is pytest's default behavior, and it is the same in both the cached
and uncached versions of `tvm.testing.fixture`.

I don't have a unit test yet to verify this behavior, but I'll add one.
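
For reference, a minimal sketch of what such a test module could look like (not part of this PR; the names `failing_setup` and `test_uses_failing_setup` are hypothetical, and the only assumption is the `tvm.testing.fixture` decorator added in this diff):

```python
import tvm.testing


@tvm.testing.fixture
def failing_setup():
    # Any exception raised in the fixture body is treated by pytest as an
    # error during setup, not as a test failure.
    raise RuntimeError("fixture setup failed")


def test_uses_failing_setup(failing_setup):
    # Never reached: pytest skips the test body when its fixture setup fails.
    assert failing_setup is not None
```

Running pytest on such a module would report the test as an error during setup rather than as a failed assertion, matching the behavior described above.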



