This is an automated email from the ASF dual-hosted git repository.

spectrometerHBH pushed a commit to branch fix-tirx-cuda-sm-guards
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 7b1c7124aee60dfc3d5a482d4399dbcbd0ca4f59
Author: spectrometerHBH <[email protected]>
AuthorDate: Tue Jun 9 14:07:04 2026 -0400

    [Tests] Guard TIRX CUDA tests by compute capability
---
 tests/python/tirx/codegen/test_codegen_cuda.py                 |  5 +++++
 tests/python/tirx/codegen/test_codegen_dsmem.py                |  2 ++
 tests/python/tirx/codegen/test_codegen_hopper.py               |  1 +
 .../tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py  |  3 +++
 .../operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py |  7 +++++++
 .../operator/tile_primitive/cuda/gemm_async/test_gemm_async.py | 10 ++++++++++
 .../operator/tile_primitive/cuda/reduction/test_reduction.py   |  1 +
 7 files changed, 29 insertions(+)

diff --git a/tests/python/tirx/codegen/test_codegen_cuda.py b/tests/python/tirx/codegen/test_codegen_cuda.py
index f253d6d375..cd5baaac33 100644
--- a/tests/python/tirx/codegen/test_codegen_cuda.py
+++ b/tests/python/tirx/codegen/test_codegen_cuda.py
@@ -87,6 +87,7 @@ def test_serial_pragma_unroll_codegen():
     assert "break;" in src
 
 
+@tvm.testing.requires_cuda_compute_version(9)
 def test_cluster_cta_id_codegen_uses_coordinate_sregs():
     @T.prim_func
     def main(A: T.Buffer((1,), "int32")):
@@ -160,6 +161,7 @@ def test_ptx_ld_acquire_and_volatile_codegen():
     assert "ld.volatile.global.u64" in src
 
 
+@tvm.testing.requires_cuda_compute_version(10)
 def test_megamoe_extracted_intrinsics_codegen():
     @T.prim_func
     def main(
@@ -265,6 +267,7 @@ def test_megamoe_extracted_intrinsics_codegen():
         assert snippet in src
 
 
+@tvm.testing.requires_cuda_compute_version(9)
 def test_ptx_cp_async_bulk_non_tma_form_codegen():
     @T.prim_func
     def main(
@@ -304,6 +307,7 @@ def test_tensor_map_param_codegen():
     assert "((unsigned long long)(&(A_map)))" in src
 
 
+@tvm.testing.requires_cuda_compute_version(9)
 def test_tma_cache_policy_operand_codegen():
     @T.prim_func
     def main(Cache: T.Buffer((1,), "uint64")):
@@ -537,6 +541,7 @@ def test_warp_shuffle_xor_sync():
 @pytest.mark.parametrize("prefetch_size", [-1, 64, 128, 256])
 @pytest.mark.parametrize("predicate", [-1, T.int32(0), T.int32(1)])
 @pytest.mark.parametrize("fill_mode", ["", "zero"])
+@tvm.testing.requires_cuda_compute_version(9)
 def test_ptx_cp_async(cp_size, cache_hint, prefetch_size, predicate, fill_mode):
     if fill_mode != "" and predicate == -1:
         return
diff --git a/tests/python/tirx/codegen/test_codegen_dsmem.py b/tests/python/tirx/codegen/test_codegen_dsmem.py
index d538be571f..ed4f1e7e18 100644
--- a/tests/python/tirx/codegen/test_codegen_dsmem.py
+++ b/tests/python/tirx/codegen/test_codegen_dsmem.py
@@ -30,6 +30,7 @@ def _get_source(func: tvm.tirx.PrimFunc) -> str:
     return src
 
 
+@tvm.testing.requires_cuda_compute_version(9)
 def test_ptx_cp_async_bulk_s2c_codegen():
     """Test that T.ptx.cp_async.bulk.s2c emits the correct PTX instruction."""
 
@@ -58,6 +59,7 @@ def test_ptx_cp_async_bulk_s2c_codegen():
     assert "cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes" in src
 
 
+@tvm.testing.requires_cuda_compute_version(9)
 def test_ptx_cp_async_bulk_s2c_codegen_address_conversion():
     """Test that the codegen correctly converts addresses to shared space."""
 
diff --git a/tests/python/tirx/codegen/test_codegen_hopper.py b/tests/python/tirx/codegen/test_codegen_hopper.py
index 8f14dfc3c2..90b1921503 100644
--- a/tests/python/tirx/codegen/test_codegen_hopper.py
+++ b/tests/python/tirx/codegen/test_codegen_hopper.py
@@ -139,6 +139,7 @@ def test_stmatrix_sync_aligned(trans):
 
 @pytest.mark.parametrize("trans", [False, True])
 @pytest.mark.parametrize("num", [1, 2, 4])
+@tvm.testing.requires_cuda_compute_version(9)
 def test_ptx_stmatrix(trans, num):
     # fmt: off
     @T.prim_func
diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py
index 0f910a4376..af180e15cc 100644
--- a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py
+++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py
@@ -30,6 +30,7 @@ from tvm.tirx.layout import tid_in_wg as axis_tid_in_wg
 
 @pytest.mark.parametrize("dtype", ["float16", "float32"])
 @pytest.mark.parametrize("width_32b", [4, 8, 16, 32])
+@tvm.testing.requires_cuda_compute_version(10)
 def test_copy_tmem2reg_async(dtype, width_32b):
     """Test async tmem<->local copy using copy_async instead of copy.
 
@@ -135,6 +136,7 @@ def test_copy_tmem2reg_async(dtype, width_32b):
 @pytest.mark.parametrize("dtype", ["uint8", "float16", "float32"])
 @pytest.mark.parametrize("width_32b", [2, 4, 8, 16, 32, 64, 128])
 @pytest.mark.parametrize("offset_32b", [0, 3, 10])
+@tvm.testing.requires_cuda_compute_version(10)
 def test_copy_tmem2reg(dtype, width_32b, offset_32b):
     def next_power_of_2(x):
         if x <= 1:
@@ -227,6 +229,7 @@ def test_copy_tmem2reg(dtype, width_32b, offset_32b):
 @pytest.mark.parametrize("dtype", ["float16", "float32"])
 @pytest.mark.parametrize("width_32b", [4, 8, 16, 32])
 @pytest.mark.parametrize("local_offset_32b", [0, 2, 4])
+@tvm.testing.requires_cuda_compute_version(10)
 def test_copy_tmem2reg_sliced_local(dtype, width_32b, local_offset_32b):
     """tmem<->local copy with a sliced local buffer region."""
 
diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py
index 4209359460..eab1b83d89 100644
--- a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py
+++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py
@@ -155,6 +155,7 @@ def _expected_reg_value_16b(
 @pytest.mark.parametrize("shape", list(_SHAPE_REPS))
 @pytest.mark.parametrize("rep", [1, 2, 4, 8, 16, 32])  # subset; full reps below
 @pytest.mark.parametrize("dtype", ["float32"])
+@tvm.testing.requires_cuda_compute_version(10)
 def test_tcgen05_ld_16xnb_load_fp32(shape, rep, dtype):
     """Bit-exact verification of ``tcgen05.<shape>.x<rep>.b32`` load."""
     if rep not in _SHAPE_REPS[shape]:
@@ -170,6 +171,7 @@ def test_tcgen05_ld_16xnb_load_fp32(shape, rep, dtype):
         ("16x128b", 64),
     ],
 )
+@tvm.testing.requires_cuda_compute_version(10)
 def test_tcgen05_ld_16xnb_load_fp32_large_rep(shape, rep):
     """High-rep entries that aren't in the parametrize-cross above."""
     _run_load_test(shape, rep, "float32")
@@ -178,6 +180,7 @@ def test_tcgen05_ld_16xnb_load_fp32_large_rep(shape, rep):
 @pytest.mark.parametrize("shape", list(_SHAPE_REPS))
 @pytest.mark.parametrize("rep", [1, 2, 4, 8, 16, 32])
 @pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
+@tvm.testing.requires_cuda_compute_version(10)
 def test_tcgen05_16xnb_roundtrip_16b(shape, rep, dtype):
     """Self-consistent round-trip for 16-bit pack::16b path.
 
@@ -204,6 +207,7 @@ def test_tcgen05_16xnb_roundtrip_16b(shape, rep, dtype):
 @pytest.mark.parametrize("shape", ["16x64b", "16x128b", "16x256b"])
 @pytest.mark.parametrize("rep", [1, 2, 4])
 @pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
+@tvm.testing.requires_cuda_compute_version(10)
 def test_tcgen05_16xnb_roundtrip_16b_M128(shape, rep, dtype):
     if rep not in _SHAPE_REPS[shape]:
         pytest.skip(f"rep {rep} not valid for {shape}")
@@ -217,6 +221,7 @@ def test_tcgen05_16xnb_roundtrip_16b_M128(shape, rep, dtype):
 @pytest.mark.parametrize("shape", ["16x64b", "16x128b", "16x256b"])
 @pytest.mark.parametrize("rep", [1, 2, 4])
 @pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
+@tvm.testing.requires_cuda_compute_version(10)
 def test_tcgen05_16xnb_roundtrip_16b_layout_F(shape, rep, dtype):
     if rep not in _SHAPE_REPS[shape]:
         pytest.skip(f"rep {rep} not valid for {shape}")
@@ -642,6 +647,7 @@ def _run_load_test(shape: str, rep: int, dtype: str):
 @pytest.mark.parametrize("shape", list(_SHAPE_REPS))
 @pytest.mark.parametrize("rep", [1, 4, 16])
 @pytest.mark.parametrize("dtype", ["float32"])
+@tvm.testing.requires_cuda_compute_version(10)
 def test_tcgen05_st_16xnb_store(shape, rep, dtype):
     """Round-trip test: write the M=64 fragment via .<shape>.x<rep>.st then read
     via the standard .32x32b path; verify the host-known fragment data ends up
@@ -807,6 +813,7 @@ def test_tcgen05_st_16xnb_store(shape, rep, dtype):
         ("16x256b", 64, 64),  # .16x256b.x8 fp32
     ],
 )
+@tvm.testing.requires_cuda_compute_version(10)
 def test_alloc_tcgen05_frag_wrapper_compiles(shape, frag_rows, K_cols):
     """Ensure T.alloc_tcgen05_ldst_frag yields a buffer that ``T.copy_async`` accepts
     and lowers to the correct tcgen05 atom for each supported instr_shape."""
diff --git a/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py b/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py
index 8c32bbe048..359bbbe171 100644
--- a/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py
+++ b/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py
@@ -179,6 +179,7 @@ def pack_sf_fp8_uint32(sf_uint8, n_total=128):
         )
     ],
 )
+@tvm.testing.requires_cuda_compute_version(10)
 def test_gemm_tcgen05_cta_group_1(task):
     (
         (C_shape, C_dtype, C_region),
@@ -293,6 +294,7 @@ def test_gemm_tcgen05_cta_group_1(task):
         np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1e-3, rtol=1e-3)
 
 
+@tvm.testing.requires_cuda_compute_version(10)
 def test_gemm_tcgen05_cta_group_1_layout_f_m64():
     """M=64 MMA with C operand allocated as Layout F (datapath="F").
 
@@ -417,6 +419,7 @@ def test_gemm_tcgen05_cta_group_1_layout_f_m64():
         )
     ],
 )
+@tvm.testing.requires_cuda_compute_version(10)
 def test_gemm_tcgen05_cta_group_2(task):
     (
         (C_shape, C_dtype, C_region),
@@ -545,6 +548,7 @@ def test_gemm_tcgen05_cta_group_2(task):
         np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1e-3, rtol=1e-3)
 
 
+@tvm.testing.requires_cuda_compute_version(10)
 def test_gemm_tcgen05_cta_group_2_layout_b():
     """Test cta_group=2 with Layout B (2x2 datapath, M=128 total, 64 per CTA).
 
@@ -689,6 +693,7 @@ def test_gemm_tcgen05_cta_group_2_layout_b():
         )
     ],
 )
+@tvm.testing.requires_cuda_compute_version(10)
 def test_gemm_block_scaled_fp8_cta_group_1(task):
     """Test block-scaled fp8 GEMM with cta_group=1 using gemm_async op.
 
@@ -882,6 +887,7 @@ def test_gemm_block_scaled_fp8_cta_group_1(task):
         )
     ],
 )
+@tvm.testing.requires_cuda_compute_version(10)
 def test_gemm_block_scaled_fp8_cta_group_2(task):
     """Test block-scaled fp8 GEMM with cta_group=2 using gemm_async op.
 
@@ -1090,6 +1096,7 @@ def test_gemm_block_scaled_fp8_cta_group_2(task):
 
 
 @pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes")
+@tvm.testing.requires_cuda_compute_version(10)
 def test_gemm_block_scaled_nvfp4_cta_group_1():
     """Test block-scaled nvfp4 GEMM with cta_group=1.
 
@@ -1259,6 +1266,7 @@ def test_gemm_block_scaled_nvfp4_cta_group_1():
 
 
 @pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes")
+@tvm.testing.requires_cuda_compute_version(10)
 def test_gemm_block_scaled_nvfp4_cta_group_2():
     """Test block-scaled nvfp4 GEMM with cta_group=2.
 
@@ -1463,6 +1471,7 @@ def test_gemm_block_scaled_nvfp4_cta_group_2():
 
 
 @pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes")
+@tvm.testing.requires_cuda_compute_version(10)
 def test_gemm_block_scaled_fp8_sf_id():
     """Test sf_id auto-derivation from layout for fp8 block-scaled MMA.
 
@@ -1809,6 +1818,7 @@ def test_gemm_block_scaled_fp8_sf_id():
         "transA_kmajor_smem",
     ],
 )
+@tvm.testing.requires_cuda_compute_version(10)
 def test_gemm_tcgen05_arbitrary_tiles(task):
     """Test arbitrary tile decomposition for tcgen05 gemm_async.
 
diff --git a/tests/python/tirx/operator/tile_primitive/cuda/reduction/test_reduction.py b/tests/python/tirx/operator/tile_primitive/cuda/reduction/test_reduction.py
index 0474ad2dc4..92077fa449 100644
--- a/tests/python/tirx/operator/tile_primitive/cuda/reduction/test_reduction.py
+++ b/tests/python/tirx/operator/tile_primitive/cuda/reduction/test_reduction.py
@@ -687,6 +687,7 @@ def test_reduction_local_optimized_3input_maxmin(reduction_len, op_type, accum):
 
 @pytest.mark.parametrize("reduction_len", [8, 16, 64, 128, 256, 9, 17, 63, 65, 100])
 @pytest.mark.parametrize("accum", [False, True])
+@tvm.testing.requires_cuda_compute_version(10)
 def test_reduction_local_optimized_packed_add_sum(reduction_len, accum):
     """Test thread-level sum reduction using packed add with add.f32x2 PTX instruction."""
     dev = tvm.cuda(0)

Reply via email to