tkonolige commented on a change in pull request #6331:
URL: https://github.com/apache/incubator-tvm/pull/6331#discussion_r476625453



##########
File path: python/tvm/testing.py
##########
@@ -285,4 +288,184 @@ def _check_forward(constraints1, constraints2, varmap, 
backvarmap):
                    constraints_trans.dst_to_src, constraints_trans.src_to_dst)
 
 
+def uses_gpu(f):
+    """Mark to differentiate tests that use the GPU in some capacity.
+
+    These tests will be run on CPU-only test nodes and on test nodes with GPUs.
+    To mark a test that must have a GPU present to run, use `@requires_gpu`.
+
+    Parameters
+    ----------
+    f : function
+        Function to mark
+    """
+    return pytest.mark.gpu(f)
+
+
+def requires_gpu(f):
+    """Mark a test as requiring a GPU to run.
+
+    Tests with this mark will not be run unless a gpu is present.
+
+    Parameters
+    ----------
+    f : function
+        Function to mark
+    """
+    return pytest.mark.skipif(not tvm.gpu().exist, reason="No GPU 
present")(uses_gpu(f))
+
+
+def requires_cuda(f):
+    """Mark a test as requiring the CUDA runtime.
+
+    This also marks the test as requiring a gpu.
+
+    Parameters
+    ----------
+    f : function
+        Function to mark
+    """
+    return pytest.mark.cuda(
+        pytest.mark.skipif(
+            not tvm.runtime.enabled("cuda"), reason="CUDA support not enabled"
+        )(requires_gpu(f))
+    )
+
+
+def requires_opencl(f):
+    """Mark a test as requiring the OpenCL runtime.
+
+    This also marks the test as requiring a gpu.
+
+    Parameters
+    ----------
+    f : function
+        Function to mark
+    """
+    return pytest.mark.opencl(
+        pytest.mark.skipif(
+            not tvm.runtime.enabled("opencl"), reason="OpenCL support not 
enabled"
+        )(requires_gpu(f))
+    )
+
+
+def requires_rocm(f):
+    """Mark a test as requiring the rocm runtime.
+
+    This also marks the test as requiring a gpu.
+
+    Parameters
+    ----------
+    f : function
+        Function to mark
+    """
+    return pytest.mark.rocm(
+        pytest.mark.skipif(
+            not tvm.runtime.enabled("rocm"), reason="rocm support not enabled"
+        )(requires_gpu(f))
+    )
+
+
+def requires_tensorcore(f):
+    """Mark a test as requiring a tensorcore to run.
+
+    Tests with this mark will not be run unless a tensorcore is present.
+
+    Parameters
+    ----------
+    f : function
+        Function to mark
+    """
+    return pytest.mark.tensorcore(
+        pytest.mark.skipif(
+            not tvm.gpu().exist or not 
nvcc.have_tensorcore(tvm.gpu(0).compute_version),
+            reason="No tensorcore present",
+        )(f)
+    )
+
+
+def parametrize_devices(f):
+    """Parametrize a test over all enabled devices.
+
+    Parameters
+    ----------
+    f : function
+        Function to parametrize. Must be of the form `def 
test_xxxxxxxxx(device, ctx):`,
+        where `xxxxxxxxx` is any name.
+
+    Example
+    -------
+    >>> @tvm.testing.parametrize_devices
+    >>> def test_mytest(device, ctx):
+    >>>     ...  # do something
+    """
+    return pytest.mark.parametrize("device,ctx", 
enabled_devices())(uses_gpu(f))
+
+
+def _get_backends():
+    backend_str = os.environ.get("TVM_TEST_DEVICES", "")
+    if len(backend_str) == 0:
+        backend_str = DEFAULT_TEST_DEVICES
+    backends = {
+        dev
+        for dev in backend_str.split(";")
+        if len(dev) > 0 and tvm.context(dev, 0).exist and 
tvm.runtime.enabled(dev)
+    }
+    if len(backends) == 0:
+        logging.warning(
+            "None of the following backends are supported by this build of 
TVM: %s."
+            "Try setting TVM_TEST_DEVICES to a supported backend. Defaulting 
to llvm.",
+            backend_str
+        )
+        return {"llvm"}
+    return backends
+
+
+DEFAULT_TEST_DEVICES = (
+    "llvm;cuda;opencl;metal;rocm;vulkan;nvptx;"
+    "llvm -device=arm_cpu;opencl -device=mali,aocl_sw_emu"
+)
+TEST_DEVICES = _get_backends()
+
+
+def device_enabled(device):

Review comment:
       It seems like I'm confusing device and target here. I'll rename these 
functions to `target_enabled` and `enabled_targets`.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to