This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 244499a4cd [Tests] Inline thin gating helpers in the pytest plugin and
tvm.testing.env (#19819)
244499a4cd is described below
commit 244499a4cd8130fb062184754f81a19862a9247d
Author: Shushi Hong <[email protected]>
AuthorDate: Wed Jun 17 20:26:28 2026 -0400
[Tests] Inline thin gating helpers in the pytest plugin and tvm.testing.env
(#19819)
tvm.testing's test-gating layer had a number of one-line helper
functions that add a name but no behavior. Inline the thin ones so call
sites name the underlying flag/feature/condition directly.
Pytest plugin (plugin.py): _target_to_requirement built its skip / gpu
marks through two one-line wrappers (_gpu_mark_and_skip / _skip_only)
plus a per-kind if ladder. Replace them with two frozensets (GPU- vs
CPU-family kinds) and resolve the skip probe by name:
marks.append(pytest.mark.skipif(not getattr(env, f"has_{kind}")(),
reason=f"need {kind}"))
The cuda+cudnn / cuda+cublas accelerator-library cases are remapped
inline to the has_cudnn / has_cublas probes (cudnn takes precedence
when both libs are listed). Adds two direct unit tests: one for the
cudnn/cublas special case and one for the unknown-kind ([]) fallback.
tvm.testing.env (env.py): inline the pure probe wrappers that just
forwarded to a primitive --
* build-flag (5): has_cutlass/rpc/nnapi/openclml/mrvl ->
env.build_flag_enabled("USE_CUTLASS"/"USE_RPC"/"USE_NNAPI_CODEGEN"/
"USE_CLML"/"USE_MRVL"), respectively. The private _build_flag_enabled
is promoted to the public build_flag_enabled; the composed probes
(has_cudnn/cublas/nccl/hipblas) and the hexagon/adreno probes call it
too.
* cpu-feature (5 pure):
has_arm_dot/arm_fp16/aarch64_sve/aarch64_sme/x86_amx ->
env.has_cpu_feature("dotprod"/"fullfp16"/"sve"/"sme"/"amx-int8"),
respectively. The composed has_x86_vnni (avx512vnni OR avxvnni) and
has_x86_avx512 (a five-feature set) are kept -- they are not thin
wrappers.
Also drops the now-obsolete test_build_flag_probe_matches_libinfo
self-test and removes the ten deleted probes from the _BOOL_PROBES
list in test_env.py.
The runtime device probes (has_cuda/has_rocm/...) are intentionally left
as-is: the pytest plugin resolves env.has_<kind>() from each target
kind, so those names are load-bearing rather than thin wrappers.
---
python/tvm/testing/env.py | 79 +++--------------------
python/tvm/testing/plugin.py | 55 +++++++---------
tests/python/contrib/test_cutlass_gemm.py | 8 +--
tests/python/nightly/test_nnapi/test_network.py | 2 +-
tests/python/relax/test_codegen_cutlass.py | 2 +-
tests/python/relax/texture/test_texture_nd.py | 2 +-
tests/python/runtime/test_runtime_rpc.py | 34 +++++-----
tests/python/target/test_arm_target.py | 8 +--
tests/python/testing/test_env.py | 26 --------
tests/python/testing/test_tvm_testing_features.py | 21 ++++++
10 files changed, 84 insertions(+), 153 deletions(-)
diff --git a/python/tvm/testing/env.py b/python/tvm/testing/env.py
index 0c9b48e5c1..793f541b48 100644
--- a/python/tvm/testing/env.py
+++ b/python/tvm/testing/env.py
@@ -39,7 +39,7 @@ Three kinds of probe live here:
* **runtime device** probes (``has_cuda``, ``has_gpu`` …) ask whether a
usable device of a given kind is present;
-* **build-support** probes (``has_cutlass``, ``has_cudnn`` …) ask whether
+* **build-support** probes (``has_cudnn`` …, ``build_flag_enabled`` …) ask
whether
an optional library was compiled into the runtime;
* **version / capability** probes (``has_cuda_compute``,
``has_tensorcore`` …) ask about a finer capability of a present device
@@ -53,12 +53,9 @@ import platform
import tvm
__all__ = [
- "has_aarch64_sme",
- "has_aarch64_sve",
+ "build_flag_enabled",
"has_adreno_opencl",
"has_aprofile_aem_fvp",
- "has_arm_dot",
- "has_arm_fp16",
# cpu features
"has_cpu_feature",
"has_cublas",
@@ -69,7 +66,6 @@ __all__ = [
"has_cudagraph",
# build support
"has_cudnn",
- "has_cutlass",
"has_gpu",
# toolchain / environment
"has_hexagon",
@@ -79,20 +75,15 @@ __all__ = [
"has_llvm_min_version",
"has_matrixcore",
"has_metal",
- "has_mrvl",
"has_multi_gpu",
"has_nccl",
- "has_nnapi",
"has_nvcc_version",
"has_nvptx",
"has_nvshmem",
"has_opencl",
- "has_openclml",
"has_rocm",
- "has_rpc",
"has_tensorcore",
"has_vulkan",
- "has_x86_amx",
"has_x86_avx512",
"has_x86_vnni",
"is_aarch64",
@@ -112,12 +103,12 @@ def _device_exists(kind: str, index: int = 0) -> bool:
@functools.cache
-def _build_flag_enabled(flag: str) -> bool:
+def build_flag_enabled(flag: str) -> bool:
"""Return whether an optional build flag (e.g. ``USE_CUTLASS``) is on.
A flag counts as enabled unless it is explicitly disabled, so library
flags carrying a path (rather than a boolean) still register as present.
- Callers gate on this via ``@pytest.mark.skipif(not
tvm.testing.env.has_cutlass(), ...)``.
+ Callers gate via ``@pytest.mark.skipif(not
env.build_flag_enabled("USE_X"), ...)``.
"""
try:
value = tvm.support.libinfo().get(flag, "OFF")
@@ -239,47 +230,22 @@ def has_multi_gpu(count: int = 2) -> bool:
def has_cudnn() -> bool:
"""True if cuDNN was built in and a CUDA device is present."""
- return has_cuda() and _build_flag_enabled("USE_CUDNN")
+ return has_cuda() and build_flag_enabled("USE_CUDNN")
def has_cublas() -> bool:
"""True if cuBLAS was built in and a CUDA device is present."""
- return has_cuda() and _build_flag_enabled("USE_CUBLAS")
+ return has_cuda() and build_flag_enabled("USE_CUBLAS")
def has_nccl() -> bool:
"""True if NCCL was built in and a CUDA device is present."""
- return has_cuda() and _build_flag_enabled("USE_NCCL")
+ return has_cuda() and build_flag_enabled("USE_NCCL")
def has_hipblas() -> bool:
"""True if hipBLAS was built in and a ROCm device is present."""
- return has_rocm() and _build_flag_enabled("USE_HIPBLAS")
-
-
-def has_cutlass() -> bool:
- """True if CUTLASS support was built into the runtime."""
- return _build_flag_enabled("USE_CUTLASS")
-
-
-def has_rpc() -> bool:
- """True if RPC support was built into the runtime."""
- return _build_flag_enabled("USE_RPC")
-
-
-def has_nnapi() -> bool:
- """True if NNAPI codegen support was built into the runtime."""
- return _build_flag_enabled("USE_NNAPI_CODEGEN")
-
-
-def has_openclml() -> bool:
- """True if OpenCLML (CLML) support was built into the runtime."""
- return _build_flag_enabled("USE_CLML")
-
-
-def has_mrvl() -> bool:
- """True if the Marvell (MRVL) backend was built into the runtime."""
- return _build_flag_enabled("USE_MRVL")
+ return has_rocm() and build_flag_enabled("USE_HIPBLAS")
@functools.cache
@@ -414,7 +380,7 @@ def has_hexagon_toolchain() -> bool:
_ci_env_check,
)
- return _build_flag_enabled("USE_HEXAGON") and
_ci_env_check._compile_time_check() is True
+ return build_flag_enabled("USE_HEXAGON") and
_ci_env_check._compile_time_check() is True
except Exception: # pylint: disable=broad-except
return False
@@ -435,7 +401,7 @@ def has_hexagon() -> bool:
@functools.cache
def has_adreno_opencl() -> bool:
"""True if remote Adreno OpenCL testing is configured (RPC_TARGET set)."""
- return _build_flag_enabled("USE_OPENCL") and os.environ.get("RPC_TARGET")
is not None
+ return build_flag_enabled("USE_OPENCL") and os.environ.get("RPC_TARGET")
is not None
@functools.cache
@@ -472,26 +438,6 @@ def has_cpu_feature(features) -> bool:
return _has_cpu_feature(features)
-def has_arm_dot() -> bool:
- """True if the host CPU supports the ARM dot-product instructions."""
- return has_cpu_feature("dotprod")
-
-
-def has_arm_fp16() -> bool:
- """True if the host CPU supports ARM Neon FP16 instructions."""
- return has_cpu_feature("fullfp16")
-
-
-def has_aarch64_sve() -> bool:
- """True if the host CPU supports AArch64 SVE."""
- return has_cpu_feature("sve")
-
-
-def has_aarch64_sme() -> bool:
- """True if the host CPU supports AArch64 SME."""
- return has_cpu_feature("sme")
-
-
def has_x86_vnni() -> bool:
"""True if the host CPU supports x86 VNNI (AVX512-VNNI or AVX-VNNI)."""
return has_cpu_feature("avx512vnni") or has_cpu_feature("avxvnni")
@@ -502,11 +448,6 @@ def has_x86_avx512() -> bool:
return has_cpu_feature(["avx512bw", "avx512cd", "avx512dq", "avx512vl",
"avx512f"])
-def has_x86_amx() -> bool:
- """True if the host CPU supports the x86 AMX (int8) extensions."""
- return has_cpu_feature("amx-int8")
-
-
# --- host architecture probes ----------------------------------------------
diff --git a/python/tvm/testing/plugin.py b/python/tvm/testing/plugin.py
index 8f7cd0c516..cb46dedc67 100644
--- a/python/tvm/testing/plugin.py
+++ b/python/tvm/testing/plugin.py
@@ -279,44 +279,39 @@ def _sort_tests(items):
items.sort(key=sort_key)
-def _gpu_mark_and_skip(has_fn, reason):
- """A GPU-family target: the ``gpu`` selection marker plus an env skip."""
- return [pytest.mark.gpu, pytest.mark.skipif(not has_fn(), reason=reason)]
-
-
-def _skip_only(has_fn, reason):
- """A non-GPU target: an env skip with no selection marker."""
- return [pytest.mark.skipif(not has_fn(), reason=reason)]
+# GPU-family target kinds carry the ``gpu`` selection marker; CPU-family kinds
+# (llvm, hexagon) only skip. The skip condition is the matching tvm.testing.env
+# probe, resolved by name, so there is no per-kind ladder of has_* calls.
+_GPU_TARGET_KINDS = frozenset(
+ {"cuda", "cudnn", "cublas", "rocm", "vulkan", "nvptx", "metal", "opencl"}
+)
+_CPU_TARGET_KINDS = frozenset({"llvm", "hexagon"})
def _target_to_requirement(target):
if isinstance(target, str | dict):
target = tvm.target.Target(target)
- # GPU-family kinds get the `gpu` selection marker; CPU-family kinds only
skip.
+ # A cuda target carrying an accelerator library gates on that library's
probe
+ # (cudnn before cublas) instead of plain cuda.
kind = target.kind.name
- if kind == "cuda" and "cudnn" in target.attrs.get("libs", []):
- return _gpu_mark_and_skip(env.has_cudnn, "need cudnn")
- if kind == "cuda" and "cublas" in target.attrs.get("libs", []):
- return _gpu_mark_and_skip(env.has_cublas, "need cublas")
if kind == "cuda":
- return _gpu_mark_and_skip(env.has_cuda, "need cuda")
- if kind == "rocm":
- return _gpu_mark_and_skip(env.has_rocm, "need rocm")
- if kind == "vulkan":
- return _gpu_mark_and_skip(env.has_vulkan, "need vulkan")
- if kind == "nvptx":
- return _gpu_mark_and_skip(env.has_nvptx, "need nvptx")
- if kind == "metal":
- return _gpu_mark_and_skip(env.has_metal, "need metal")
- if kind == "opencl":
- return _gpu_mark_and_skip(env.has_opencl, "need opencl")
- if kind == "llvm":
- return _skip_only(env.has_llvm, "need llvm")
- if kind == "hexagon":
- return _skip_only(env.has_hexagon, "need hexagon")
-
- return []
+ libs = target.attrs.get("libs", [])
+ if "cudnn" in libs:
+ kind = "cudnn"
+ elif "cublas" in libs:
+ kind = "cublas"
+
+ if kind in _GPU_TARGET_KINDS:
+ is_gpu = True
+ elif kind in _CPU_TARGET_KINDS:
+ is_gpu = False
+ else:
+ return []
+
+ marks = [pytest.mark.gpu] if is_gpu else []
+ marks.append(pytest.mark.skipif(not getattr(env, f"has_{kind}")(),
reason=f"need {kind}"))
+ return marks
# pytest-xdist isn't required but is used in CI, so guard on its presence
diff --git a/tests/python/contrib/test_cutlass_gemm.py
b/tests/python/contrib/test_cutlass_gemm.py
index 19e53c48b0..f5785c77e6 100644
--- a/tests/python/contrib/test_cutlass_gemm.py
+++ b/tests/python/contrib/test_cutlass_gemm.py
@@ -73,7 +73,7 @@ def verify_group_gemm(
tvm.testing.assert_allclose(c_nd.numpy(), c_np, rtol=rtol, atol=atol)
-@pytest.mark.skipif(not env.has_cutlass(), reason="need cutlass")
+@pytest.mark.skipif(not env.build_flag_enabled("USE_CUTLASS"), reason="need
cutlass")
@pytest.mark.gpu
@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >=
9.0")
def test_group_gemm_sm90():
@@ -118,7 +118,7 @@ def test_group_gemm_sm90():
)
-@pytest.mark.skipif(not env.has_cutlass(), reason="need cutlass")
+@pytest.mark.skipif(not env.build_flag_enabled("USE_CUTLASS"), reason="need
cutlass")
@pytest.mark.gpu
@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >=
10.0")
def test_group_gemm_sm100():
@@ -302,7 +302,7 @@ def blockwise_bmm(
return o_np
-@pytest.mark.skipif(not env.has_cutlass(), reason="need cutlass")
+@pytest.mark.skipif(not env.build_flag_enabled("USE_CUTLASS"), reason="need
cutlass")
@pytest.mark.gpu
@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >=
9.0")
def test_fp8_e4m3_groupwise_scaled_gemm():
@@ -336,7 +336,7 @@ def test_fp8_e4m3_groupwise_scaled_gemm():
tvm.testing.assert_allclose(o_tvm, o_np, rtol=1e-4, atol=0.5)
-@pytest.mark.skipif(not env.has_cutlass(), reason="need cutlass")
+@pytest.mark.skipif(not env.build_flag_enabled("USE_CUTLASS"), reason="need
cutlass")
@pytest.mark.gpu
@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >=
9.0")
def test_fp8_e4m3_groupwise_scaled_bmm():
diff --git a/tests/python/nightly/test_nnapi/test_network.py
b/tests/python/nightly/test_nnapi/test_network.py
index 85b8b31456..88d8854801 100644
--- a/tests/python/nightly/test_nnapi/test_network.py
+++ b/tests/python/nightly/test_nnapi/test_network.py
@@ -116,7 +116,7 @@ def get_network(name, dtype, input_shape=(1, 3, 224, 224)):
"float32",
],
)
-@pytest.mark.skipif(not env.has_nnapi(), reason="need nnapi")
+@pytest.mark.skipif(not env.build_flag_enabled("USE_NNAPI_CODEGEN"),
reason="need nnapi")
def test_network(name, dtype):
remote_obj, tracker = remote()
print(f"Network evaluating {name} with dtype {dtype}")
diff --git a/tests/python/relax/test_codegen_cutlass.py
b/tests/python/relax/test_codegen_cutlass.py
index ff662f4faa..cc6fd499fa 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -85,7 +85,7 @@ class Conv2dx2:
pytestmark = [
- pytest.mark.skipif(not env.has_cutlass(), reason="need cutlass"),
+ pytest.mark.skipif(not env.build_flag_enabled("USE_CUTLASS"), reason="need
cutlass"),
]
diff --git a/tests/python/relax/texture/test_texture_nd.py
b/tests/python/relax/texture/test_texture_nd.py
index 3c3447749d..56025f1298 100644
--- a/tests/python/relax/texture/test_texture_nd.py
+++ b/tests/python/relax/texture/test_texture_nd.py
@@ -106,7 +106,7 @@ def postprocess_pipeline(mod: IRModule) -> IRModule:
return mod
-@pytest.mark.skipif(not env.has_rpc(), reason="need rpc")
+@pytest.mark.skipif(not env.build_flag_enabled("USE_RPC"), reason="need rpc")
@pytest.mark.gpu
@pytest.mark.skipif(not env.has_adreno_opencl(), reason="need adreno opencl")
@pytest.mark.parametrize("backend", ["opencl"])
diff --git a/tests/python/runtime/test_runtime_rpc.py
b/tests/python/runtime/test_runtime_rpc.py
index b48d9631dc..2f9890da90 100644
--- a/tests/python/runtime/test_runtime_rpc.py
+++ b/tests/python/runtime/test_runtime_rpc.py
@@ -66,7 +66,7 @@ pytestmark = pytest.mark.skipif(
# to ensure all the remote resources destructs before the server terminates
-@pytest.mark.skipif(not env.has_rpc(), reason="need rpc")
+@pytest.mark.skipif(not env.build_flag_enabled("USE_RPC"), reason="need rpc")
def test_bigendian_rpc():
"""Test big endian rpc when there is a PowerPC RPC server available"""
host = os.environ.get("TVM_POWERPC_TEST_HOST", None)
@@ -97,7 +97,7 @@ def test_bigendian_rpc():
verify_rpc(remote, target, (10,), dtype)
-@pytest.mark.skipif(not env.has_rpc(), reason="need rpc")
+@pytest.mark.skipif(not env.build_flag_enabled("USE_RPC"), reason="need rpc")
def test_rpc_simple():
server = rpc.Server(key="x1")
client = rpc.connect("127.0.0.1", server.port, key="x1")
@@ -116,7 +116,7 @@ def test_rpc_simple():
check_remote()
-@pytest.mark.skipif(not env.has_rpc(), reason="need rpc")
+@pytest.mark.skipif(not env.build_flag_enabled("USE_RPC"), reason="need rpc")
def test_rpc_runtime_string():
server = rpc.Server(key="x1")
client = rpc.connect("127.0.0.1", server.port, key="x1")
@@ -130,7 +130,7 @@ def test_rpc_runtime_string():
check_remote()
-@pytest.mark.skipif(not env.has_rpc(), reason="need rpc")
+@pytest.mark.skipif(not env.build_flag_enabled("USE_RPC"), reason="need rpc")
def test_rpc_array():
server = rpc.Server()
remote = rpc.connect("127.0.0.1", server.port)
@@ -146,7 +146,7 @@ def test_rpc_array():
check_remote()
-@pytest.mark.skipif(not env.has_rpc(), reason="need rpc")
+@pytest.mark.skipif(not env.build_flag_enabled("USE_RPC"), reason="need rpc")
def test_rpc_large_array():
# testcase of large array creation
server = rpc.Server()
@@ -165,7 +165,7 @@ def test_rpc_large_array():
@tvm.testing.skip_if_32bit(reason="skipping test for i386.")
-@pytest.mark.skipif(not env.has_rpc(), reason="need rpc")
+@pytest.mark.skipif(not env.build_flag_enabled("USE_RPC"), reason="need rpc")
def test_rpc_echo():
def check(remote, local_session):
fecho = remote.get_function("testing.echo")
@@ -214,7 +214,7 @@ def test_rpc_echo():
# check_minrpc()
-@pytest.mark.skipif(not env.has_rpc(), reason="need rpc")
+@pytest.mark.skipif(not env.build_flag_enabled("USE_RPC"), reason="need rpc")
def test_rpc_file_exchange():
server = rpc.Server()
remote = rpc.connect("127.0.0.1", server.port)
@@ -228,7 +228,7 @@ def test_rpc_file_exchange():
check_remote()
-@pytest.mark.skipif(not env.has_rpc(), reason="need rpc")
+@pytest.mark.skipif(not env.build_flag_enabled("USE_RPC"), reason="need rpc")
@pytest.mark.skipif(not env.has_llvm(), reason="need llvm")
def test_rpc_remote_module():
# graph
@@ -339,7 +339,7 @@ def test_rpc_remote_module():
check_minrpc()
-@pytest.mark.skipif(not env.has_rpc(), reason="need rpc")
+@pytest.mark.skipif(not env.build_flag_enabled("USE_RPC"), reason="need rpc")
def test_rpc_return_func():
server = rpc.Server(key="x1")
client = rpc.connect("127.0.0.1", server.port, key="x1")
@@ -352,7 +352,7 @@ def test_rpc_return_func():
check_remote()
-@pytest.mark.skipif(not env.has_rpc(), reason="need rpc")
+@pytest.mark.skipif(not env.build_flag_enabled("USE_RPC"), reason="need rpc")
def test_rpc_session_constructor_args():
# start server
server0 = rpc.Server(key="x0")
@@ -389,7 +389,7 @@ def test_rpc_session_constructor_args():
check_error_handling()
-@pytest.mark.skipif(not env.has_rpc(), reason="need rpc")
+@pytest.mark.skipif(not env.build_flag_enabled("USE_RPC"), reason="need rpc")
def test_rpc_return_tensor():
def run_arr_test():
server = rpc.Server(key="x1")
@@ -410,7 +410,7 @@ def test_rpc_return_tensor():
run_arr_test()
-@pytest.mark.skipif(not env.has_rpc(), reason="need rpc")
+@pytest.mark.skipif(not env.build_flag_enabled("USE_RPC"), reason="need rpc")
def test_rpc_return_remote_object():
def check(client, is_local):
make_shape = client.get_function("ffi.Shape")
@@ -456,7 +456,7 @@ def test_rpc_return_remote_object():
check_minrpc()
-@pytest.mark.skipif(not env.has_rpc(), reason="need rpc")
+@pytest.mark.skipif(not env.build_flag_enabled("USE_RPC"), reason="need rpc")
def test_local_func():
client = rpc.LocalSession()
@@ -473,7 +473,7 @@ def test_local_func():
check_remote()
-@pytest.mark.skipif(not env.has_rpc(), reason="need rpc")
+@pytest.mark.skipif(not env.build_flag_enabled("USE_RPC"), reason="need rpc")
@pytest.mark.parametrize("device_key", ["test_device", "127.0.0.1:5555"])
def test_rpc_tracker_register(device_key):
# test registration
@@ -546,7 +546,7 @@ def _target(host, port, device_key, timeout):
remote.cpu()
-@pytest.mark.skipif(not env.has_rpc(), reason="need rpc")
+@pytest.mark.skipif(not env.build_flag_enabled("USE_RPC"), reason="need rpc")
@pytest.mark.parametrize("device_key", ["test_device", "127.0.0.1:5555"])
def test_rpc_tracker_request(device_key):
# test concurrent request
@@ -587,7 +587,7 @@ def test_rpc_tracker_request(device_key):
tracker.terminate()
-@pytest.mark.skipif(not env.has_rpc(), reason="need rpc")
+@pytest.mark.skipif(not env.build_flag_enabled("USE_RPC"), reason="need rpc")
@pytest.mark.parametrize("device_key", ["test_device", "127.0.0.1:5555"])
def test_rpc_tracker_via_proxy(device_key):
"""
@@ -629,7 +629,7 @@ def test_rpc_tracker_via_proxy(device_key):
tracker_server.terminate()
-@pytest.mark.skipif(not env.has_rpc(), reason="need rpc")
+@pytest.mark.skipif(not env.build_flag_enabled("USE_RPC"), reason="need rpc")
@pytest.mark.parametrize("with_proxy", (True, False))
def test_rpc_session_timeout_error(with_proxy):
port = 9000
diff --git a/tests/python/target/test_arm_target.py
b/tests/python/target/test_arm_target.py
index 862f41e146..12fe17fe23 100644
--- a/tests/python/target/test_arm_target.py
+++ b/tests/python/target/test_arm_target.py
@@ -51,7 +51,7 @@ def sve_device_vector_length():
return int(out)
-@pytest.mark.skipif(not env.has_aarch64_sve(), reason="need aarch64 sve")
+@pytest.mark.skipif(not env.has_cpu_feature("sve"), reason="need aarch64 sve")
def test_scalable_div(sve_device_vector_length):
np.random.seed(0)
target = {"kind": "llvm", "mtriple": "aarch64-linux-gnu", "mattr":
["+sve"]}
@@ -72,7 +72,7 @@ def test_scalable_div(sve_device_vector_length):
tvm.testing.assert_allclose(A_nd.numpy()[0], ref)
-@pytest.mark.skipif(not env.has_aarch64_sve(), reason="need aarch64 sve")
+@pytest.mark.skipif(not env.has_cpu_feature("sve"), reason="need aarch64 sve")
def test_scalable_buffer_load_store(sve_device_vector_length):
np.random.seed(0)
target = {"kind": "llvm", "mtriple": "aarch64-linux-gnu", "mattr":
["+sve"]}
@@ -97,7 +97,7 @@ def test_scalable_buffer_load_store(sve_device_vector_length):
tvm.testing.assert_allclose(B_nd.numpy(), A_np)
-@pytest.mark.skipif(not env.has_aarch64_sve(), reason="need aarch64 sve")
+@pytest.mark.skipif(not env.has_cpu_feature("sve"), reason="need aarch64 sve")
def test_scalable_loop_bound(sve_device_vector_length):
np.random.seed(0)
@@ -125,7 +125,7 @@ def test_scalable_loop_bound(sve_device_vector_length):
tvm.testing.assert_allclose(B_nd.numpy(), A_np)
-@pytest.mark.skipif(not env.has_aarch64_sve(), reason="need aarch64 sve")
+@pytest.mark.skipif(not env.has_cpu_feature("sve"), reason="need aarch64 sve")
def test_scalable_broadcast(sve_device_vector_length):
target = {"kind": "llvm", "mtriple": "aarch64-linux-gnu", "mattr":
["+sve"]}
num_elements = sve_device_vector_length // 32
diff --git a/tests/python/testing/test_env.py b/tests/python/testing/test_env.py
index 24277a2c9e..c90c450381 100644
--- a/tests/python/testing/test_env.py
+++ b/tests/python/testing/test_env.py
@@ -38,11 +38,6 @@ _BOOL_PROBES = [
env.has_cublas,
env.has_nccl,
env.has_hipblas,
- env.has_cutlass,
- env.has_rpc,
- env.has_nnapi,
- env.has_openclml,
- env.has_mrvl,
env.has_nvshmem,
# version / capability
env.has_tensorcore,
@@ -54,13 +49,8 @@ _BOOL_PROBES = [
env.has_adreno_opencl,
env.has_aprofile_aem_fvp,
# cpu features
- env.has_arm_dot,
- env.has_arm_fp16,
- env.has_aarch64_sve,
- env.has_aarch64_sme,
env.has_x86_vnni,
env.has_x86_avx512,
- env.has_x86_amx,
# host architecture
env.is_x86,
env.is_aarch64,
@@ -123,22 +113,6 @@ def test_has_multi_gpu_is_bool():
assert env.has_multi_gpu(1) or not env.has_multi_gpu(2)
-@pytest.mark.parametrize(
- "probe,flag",
- [
- (env.has_cutlass, "USE_CUTLASS"),
- (env.has_rpc, "USE_RPC"),
- (env.has_nnapi, "USE_NNAPI_CODEGEN"),
- (env.has_openclml, "USE_CLML"),
- (env.has_mrvl, "USE_MRVL"),
- ],
- ids=lambda v: getattr(v, "__name__", v),
-)
-def test_build_flag_probe_matches_libinfo(probe, flag):
- """Pure build-flag probes agree with the build-info flag they wrap."""
- assert probe() == env._build_flag_enabled(flag) # pylint:
disable=protected-access
-
-
def test_llvm_min_version_is_monotone():
if not env.has_llvm():
assert not env.has_llvm_min_version(1)
diff --git a/tests/python/testing/test_tvm_testing_features.py
b/tests/python/testing/test_tvm_testing_features.py
index 071da7b9bc..63d5c1ac7b 100644
--- a/tests/python/testing/test_tvm_testing_features.py
+++ b/tests/python/testing/test_tvm_testing_features.py
@@ -238,6 +238,27 @@ class TestAutomaticMarks:
self.check_marks(request, target)
+def test_target_to_requirement_cuda_libs():
+ """cuda+cudnn / cuda+cublas select their own probe; cudnn wins when both
are present."""
+ ttr = tvm.testing.plugin._target_to_requirement
+
+ def skip_reasons(target):
+ return [d.mark.kwargs["reason"] for d in ttr(target) if d.mark.name ==
"skipif"]
+
+ assert skip_reasons({"kind": "cuda", "libs": ["cudnn"]}) == ["need cudnn"]
+ assert skip_reasons({"kind": "cuda", "libs": ["cublas"]}) == ["need
cublas"]
+ # cudnn is checked before cublas, so it wins when both are present.
+ assert skip_reasons({"kind": "cuda", "libs": ["cudnn", "cublas"]}) ==
["need cudnn"]
+ assert skip_reasons("cuda") == ["need cuda"]
+ # every cuda variant is GPU-family and carries the `gpu` selection marker.
+ assert any(d.mark.name == "gpu" for d in ttr({"kind": "cuda", "libs":
["cudnn"]}))
+
+
+def test_target_to_requirement_unknown_kind_has_no_marks():
+ """A target kind with no requirement entry produces no marks (no gpu, no
skip)."""
+ assert tvm.testing.plugin._target_to_requirement("c") == []
+
+
@pytest.mark.skipif(
bool(int(os.environ.get("TVM_TEST_DISABLE_CACHE", "0"))),
reason="Cannot test cache behavior while caching is disabled",