This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 71466bb737 [Tests] Migrate tvm.testing.parameters() to
pytest.mark.parametrize (#19803)
71466bb737 is described below
commit 71466bb7379903b76a16b50c7e5923f3e043568b
Author: Shushi Hong <[email protected]>
AuthorDate: Tue Jun 16 21:53:40 2026 -0400
[Tests] Migrate tvm.testing.parameters() to pytest.mark.parametrize (#19803)
This PR phases out the custom `tvm.testing.parameters()` helper in favor
of native `pytest.mark.parametrize`. `parameters()` itself is left in
place for now and removed in a follow-up, together with updating the
framework self-test
(`tests/python/testing/test_tvm_testing_features.py`) that exercises it.
Migration rules
- A group consumed only by test functions becomes
`pytest.mark.parametrize`.
- Single-name groups are unwrapped from 1-tuples to bare values.
- A group shared by multiple tests uses a module-level named list; a
test that uses only a subset of a group's names is parametrized only on
the names in its signature.
- `pytest.mark.parametrize` is stacked above the existing, unrelated
`tvm.testing.parametrize_targets(...)`, which is kept as-is.
Per-file pytest collection case counts are unchanged, except where noted
under the intentional behavior changes below.
Behavior changes (intentional)
- tests/python/relax/test_training_optimizer_numeric.py: the names `lr`
and `weight_decay` were rebound across three `parameters()` groups, so
`test_sgd` and `test_momentum_sgd` silently used the *adam* group's
`lr`/`weight_decay` (and `test_momentum_sgd` was cross-producted with it,
giving 2/6/2 = 10 cases for `test_sgd`/`test_momentum_sgd`/`test_adam`).
Native parametrize gives each test its own co-located group: 2/3/2 = 7
cases. This fixes that latent rebinding bug; the case count drops
10 -> 7, and `test_sgd` and `test_momentum_sgd` now exercise their own
`lr`/`weight_decay` values.
- tests/python/target/test_arm_target.py: its `parameters()` group was
orphaned (no test consumed those names), so the dead definition was removed.
Note: for tests that also use `tvm.testing.parametrize_targets`, the
generated test ids reorder the target (e.g. `test_unary[abs-True-llvm]`
-> `test_unary[llvm-abs-True]`); values and case counts are unchanged.
---
.../python/codegen/test_target_codegen_cuda_fp8.py | 33 +--
tests/python/ir/test_datatype_nv_fp4.py | 5 +-
tests/python/ir/test_datatype_nv_fp8.py | 8 +-
tests/python/relax/test_op_binary.py | 24 ++-
tests/python/relax/test_op_gradient_numeric.py | 187 ++++++++--------
tests/python/relax/test_op_search.py | 8 +-
tests/python/relax/test_op_statistical.py | 13 +-
tests/python/relax/test_op_unary.py | 11 +-
.../relax/test_training_optimizer_numeric.py | 36 ++--
.../relax/test_tvmscript_parser_op_arith_cmp.py | 112 +++++-----
.../s_tir/base/test_tir_te_extern_primfunc.py | 17 +-
tests/python/target/test_arm_target.py | 57 -----
tests/python/target/test_llvm_features_info.py | 34 +--
tests/python/target/test_riscv_features.py | 8 +-
tests/python/target/test_x86_features.py | 240 ++++++++++++---------
15 files changed, 406 insertions(+), 387 deletions(-)
diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py
b/tests/python/codegen/test_target_codegen_cuda_fp8.py
index 331c96b1d3..cd762569af 100644
--- a/tests/python/codegen/test_target_codegen_cuda_fp8.py
+++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py
@@ -142,23 +142,24 @@ def test_fp8_packing(dtype):
tvm.testing.assert_allclose(a.numpy().astype("float16"),
b.numpy().astype("float16"))
-native_dtype, promoted_dtype, numpytype = tvm.testing.parameters(
- ("float8_e4m3fn", "float32", "float8_e4m3fn"),
- ("float8_e4m3fn", "float16", "float8_e4m3fn"),
- ("float8_e4m3fnx2", "float32x2", "float8_e4m3fn"),
- ("float8_e4m3fnx2", "float16x2", "float8_e4m3fn"),
- ("float8_e4m3fnx4", "float32x4", "float8_e4m3fn"),
- # Supported via half4 vector type extension in codegen
- ("float8_e4m3fnx4", "float16x4", "float8_e4m3fn"),
- ("float8_e5m2", "float32", "float8_e5m2"),
- ("float8_e5m2", "float16", "float8_e5m2"),
- ("float8_e5m2x2", "float32x2", "float8_e5m2"),
- ("float8_e5m2x2", "float16x2", "float8_e5m2"),
- ("float8_e5m2x4", "float32x4", "float8_e5m2"),
- ("float8_e5m2x4", "float16x4", "float8_e5m2"),
+@pytest.mark.parametrize(
+ "native_dtype,promoted_dtype,numpytype",
+ [
+ ("float8_e4m3fn", "float32", "float8_e4m3fn"),
+ ("float8_e4m3fn", "float16", "float8_e4m3fn"),
+ ("float8_e4m3fnx2", "float32x2", "float8_e4m3fn"),
+ ("float8_e4m3fnx2", "float16x2", "float8_e4m3fn"),
+ ("float8_e4m3fnx4", "float32x4", "float8_e4m3fn"),
+ # Supported via half4 vector type extension in codegen
+ ("float8_e4m3fnx4", "float16x4", "float8_e4m3fn"),
+ ("float8_e5m2", "float32", "float8_e5m2"),
+ ("float8_e5m2", "float16", "float8_e5m2"),
+ ("float8_e5m2x2", "float32x2", "float8_e5m2"),
+ ("float8_e5m2x2", "float16x2", "float8_e5m2"),
+ ("float8_e5m2x4", "float32x4", "float8_e5m2"),
+ ("float8_e5m2x4", "float16x4", "float8_e5m2"),
+ ],
)
-
-
@pytest.mark.gpu
@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >=
10.0")
def test_fp8_vector_conversions(native_dtype, promoted_dtype, numpytype):
diff --git a/tests/python/ir/test_datatype_nv_fp4.py
b/tests/python/ir/test_datatype_nv_fp4.py
index 653ce22050..768a4b6973 100644
--- a/tests/python/ir/test_datatype_nv_fp4.py
+++ b/tests/python/ir/test_datatype_nv_fp4.py
@@ -16,6 +16,7 @@
# under the License.
# ruff: noqa: F401
import numpy as np
+import pytest
import tvm
import tvm.testing
@@ -29,9 +30,10 @@ except ImportError:
float4_e2m1fn = None
-np_dtype, dtype_str = tvm.testing.parameters((float4_e2m1fn, "float4_e2m1fn"))
+nv_fp4_dtypes = [(float4_e2m1fn, "float4_e2m1fn")]
+@pytest.mark.parametrize("np_dtype,dtype_str", nv_fp4_dtypes)
def test_create_nv_fp4_nd_array(np_dtype, dtype_str):
if np_dtype is None:
"""Skip test if ml_dtypes is not installed"""
@@ -42,6 +44,7 @@ def test_create_nv_fp4_nd_array(np_dtype, dtype_str):
np.testing.assert_equal(x_nd.numpy(), x)
+@pytest.mark.parametrize("np_dtype,dtype_str", nv_fp4_dtypes)
def test_nv_fp4_buffer(np_dtype, dtype_str):
m = te.size_var("m")
n = te.size_var("n")
diff --git a/tests/python/ir/test_datatype_nv_fp8.py
b/tests/python/ir/test_datatype_nv_fp8.py
index 6a077d28d5..1ac9720a13 100644
--- a/tests/python/ir/test_datatype_nv_fp8.py
+++ b/tests/python/ir/test_datatype_nv_fp8.py
@@ -16,6 +16,7 @@
# under the License.
# ruff: noqa: E501, F401
import numpy as np
+import pytest
import tvm
import tvm.testing
@@ -69,7 +70,7 @@ def fp8_unary(dtype: str):
return func
-np_dtype, dtype_str = tvm.testing.parameters(
+nv_fp8_dtypes = [
(float8_e3m4, "float8_e3m4"),
(float8_e4m3, "float8_e4m3"),
(float8_e4m3b11fnuz, "float8_e4m3b11fnuz"),
@@ -78,9 +79,10 @@ np_dtype, dtype_str = tvm.testing.parameters(
(float8_e5m2, "float8_e5m2"),
(float8_e5m2fnuz, "float8_e5m2fnuz"),
(float8_e8m0fnu, "float8_e8m0fnu"),
-)
+]
+@pytest.mark.parametrize("np_dtype,dtype_str", nv_fp8_dtypes)
def test_create_nv_fp8_nd_array(np_dtype, dtype_str):
if np_dtype is None:
"""Skip test if ml_dtypes is not installed"""
@@ -91,6 +93,7 @@ def test_create_nv_fp8_nd_array(np_dtype, dtype_str):
np.testing.assert_equal(x_nd.numpy(), x)
+@pytest.mark.parametrize("np_dtype,dtype_str", nv_fp8_dtypes)
def test_fp8_unary_op(np_dtype, dtype_str):
func = fp8_unary(dtype_str)
if not tvm.testing.device_enabled("llvm"):
@@ -119,6 +122,7 @@ def test_fp8_unary_op(np_dtype, dtype_str):
np.testing.assert_equal(args[6].numpy(), expected_a_roundtrip)
+@pytest.mark.parametrize("np_dtype,dtype_str", nv_fp8_dtypes)
def test_nv_fp8_buffer(np_dtype, dtype_str):
m = te.size_var("m")
n = te.size_var("n")
diff --git a/tests/python/relax/test_op_binary.py
b/tests/python/relax/test_op_binary.py
index 94a4231e78..5823044d85 100644
--- a/tests/python/relax/test_op_binary.py
+++ b/tests/python/relax/test_op_binary.py
@@ -66,7 +66,7 @@ def _check_inference(bb: relax.BlockBuilder, call:
relax.Call, expected_sinfo: r
tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo)
-(binary_arith_op, tir_arith_op) = tvm.testing.parameters(
+binary_arith_ops = [
(relax.op.add, tirx.Add),
(relax.op.divide, tirx.Div),
(relax.op.floor_divide, tirx.FloorDiv),
@@ -78,9 +78,10 @@ def _check_inference(bb: relax.BlockBuilder, call:
relax.Call, expected_sinfo: r
(relax.op.minimum, tirx.Min),
(relax.op.mod, tirx.Mod),
(relax.op.floor_mod, tirx.FloorMod),
-)
+]
+@pytest.mark.parametrize("binary_arith_op", [row[0] for row in
binary_arith_ops])
def test_binary_arith_infer_struct_info(binary_arith_op: Callable):
bb = relax.BlockBuilder()
vdevice0 = VDevice("llvm")
@@ -125,6 +126,7 @@ def test_binary_arith_infer_struct_info(binary_arith_op:
Callable):
)
+@pytest.mark.parametrize("binary_arith_op", [row[0] for row in
binary_arith_ops])
def
test_infer_struct_info_binary_arith_prim_value_with_tensor(binary_arith_op:
Callable):
bb = relax.BlockBuilder()
@@ -134,6 +136,7 @@ def
test_infer_struct_info_binary_arith_prim_value_with_tensor(binary_arith_op:
_check_inference(bb, binary_arith_op(x, y), relax.TensorStructInfo((2, 3),
"float32"))
+@pytest.mark.parametrize("binary_arith_op", [row[0] for row in
binary_arith_ops])
def
test_infer_struct_info_binary_arith_prim_value_with_prim_value(binary_arith_op:
Callable):
bb = relax.BlockBuilder()
@@ -143,6 +146,7 @@ def
test_infer_struct_info_binary_arith_prim_value_with_prim_value(binary_arith_
_check_inference(bb, binary_arith_op(x, y),
relax.PrimStructInfo("float32"))
+@pytest.mark.parametrize("binary_arith_op,tir_arith_op", binary_arith_ops)
@pytest.mark.xfail(reason="Not yet implemented")
def test_infer_struct_info_binary_arith_known_prim_value_with_prim_value(
binary_arith_op: Callable, tir_arith_op
@@ -159,16 +163,17 @@ def
test_infer_struct_info_binary_arith_known_prim_value_with_prim_value(
_check_inference(bb, binary_arith_op(y, x),
relax.PrimStructInfo(value=tir_y + tir_x))
-(binary_cmp_op, tir_cmp_op) = tvm.testing.parameters(
+binary_cmp_ops = [
(relax.op.equal, tirx.EQ),
(relax.op.greater, tirx.GT),
(relax.op.greater_equal, tirx.GE),
(relax.op.less, tirx.LT),
(relax.op.less_equal, tirx.LE),
(relax.op.not_equal, tirx.NE),
-)
+]
+@pytest.mark.parametrize("binary_cmp_op", [row[0] for row in binary_cmp_ops])
def test_binary_cmp_infer_struct_info(binary_cmp_op: Callable):
bb = relax.BlockBuilder()
vdev0 = VDevice("llvm")
@@ -185,6 +190,7 @@ def test_binary_cmp_infer_struct_info(binary_cmp_op:
Callable):
_check_inference(bb, binary_cmp_op(x, y2), relax.TensorStructInfo((2, 3),
"bool", vdev0))
+@pytest.mark.parametrize("binary_cmp_op", [row[0] for row in binary_cmp_ops])
def test_infer_struct_info_binary_cmp_prim_value_to_tensor(binary_cmp_op:
Callable):
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor((2, 3), "float32"))
@@ -193,6 +199,7 @@ def
test_infer_struct_info_binary_cmp_prim_value_to_tensor(binary_cmp_op: Callab
_check_inference(bb, binary_cmp_op(y, x), relax.TensorStructInfo((2, 3),
"bool"))
+@pytest.mark.parametrize("binary_cmp_op", [row[0] for row in binary_cmp_ops])
def test_infer_struct_info_binary_cmp_prim_value_to_prim_value(binary_cmp_op:
Callable):
bb = relax.BlockBuilder()
x = relax.Var("x", R.Prim("float32"))
@@ -201,6 +208,7 @@ def
test_infer_struct_info_binary_cmp_prim_value_to_prim_value(binary_cmp_op: Ca
_check_inference(bb, binary_cmp_op(y, x), relax.PrimStructInfo("bool"))
+@pytest.mark.parametrize("binary_cmp_op,tir_cmp_op", binary_cmp_ops)
@pytest.mark.xfail(reason="Not yet implemented")
def test_infer_struct_info_binary_cmp_known_prim_value_to_prim_value(
binary_cmp_op: Callable, tir_cmp_op
@@ -217,6 +225,7 @@ def
test_infer_struct_info_binary_cmp_known_prim_value_to_prim_value(
_check_inference(bb, binary_cmp_op(y, x),
relax.PrimStructInfo(value=tir_cmp_op(tir_y, tir_x)))
+@pytest.mark.parametrize("binary_arith_op", [row[0] for row in
binary_arith_ops])
def test_binary_infer_struct_info_shape_symbolic(binary_arith_op: Callable):
bb = relax.BlockBuilder()
m = tirx.Var("m", "int64")
@@ -245,6 +254,7 @@ def
test_binary_infer_struct_info_shape_symbolic(binary_arith_op: Callable):
_check_inference(bb, binary_arith_op(x4, y4),
relax.TensorStructInfo(dtype="float32", ndim=-1))
+@pytest.mark.parametrize("binary_arith_op", [row[0] for row in
binary_arith_ops])
def test_binary_infer_struct_info_shape_var(binary_arith_op: Callable):
bb = relax.BlockBuilder()
s0 = relax.Var("s0", relax.ShapeStructInfo(ndim=2))
@@ -266,6 +276,7 @@ def
test_binary_infer_struct_info_shape_var(binary_arith_op: Callable):
_check_inference(bb, binary_arith_op(x, y4),
relax.TensorStructInfo(dtype="float32"))
+@pytest.mark.parametrize("binary_arith_op", [row[0] for row in
binary_arith_ops])
def test_binary_arith_infer_struct_info_more_input_dtype(binary_arith_op:
Callable):
bb = relax.BlockBuilder()
x0 = relax.Var("x", R.Tensor((2, 3), "float64"))
@@ -280,6 +291,7 @@ def
test_binary_arith_infer_struct_info_more_input_dtype(binary_arith_op: Callab
_check_inference(bb, binary_arith_op(x2, y2), relax.TensorStructInfo((2,
3), "int64"))
+@pytest.mark.parametrize("binary_arith_op", [row[0] for row in
binary_arith_ops])
def test_binary_infer_struct_info_shape_unequal_const_int(binary_arith_op:
Callable):
bb = relax.BlockBuilder()
x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
@@ -288,6 +300,7 @@ def
test_binary_infer_struct_info_shape_unequal_const_int(binary_arith_op: Calla
bb.normalize(binary_arith_op(x0, y0))
+@pytest.mark.parametrize("binary_arith_op", [row[0] for row in
binary_arith_ops])
def test_binary_arith_infer_struct_info_dtype_mismatch(binary_arith_op:
Callable):
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor((2, 3), "float32"))
@@ -296,6 +309,7 @@ def
test_binary_arith_infer_struct_info_dtype_mismatch(binary_arith_op: Callable
bb.normalize(binary_arith_op(x, y))
+@pytest.mark.parametrize("binary_arith_op", [row[0] for row in
binary_arith_ops])
def test_binary_arith_infer_struct_info_vdevice_mismatch(binary_arith_op:
Callable):
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor((2, 3), "float32", VDevice("llvm")))
@@ -304,6 +318,7 @@ def
test_binary_arith_infer_struct_info_vdevice_mismatch(binary_arith_op: Callab
bb.normalize(binary_arith_op(x, y))
+@pytest.mark.parametrize("binary_arith_op", [row[0] for row in
binary_arith_ops])
def test_binary_wrong_input_number(binary_arith_op: Callable):
x = relax.Var("x", R.Tensor((2, 3), "float32"))
@@ -315,6 +330,7 @@ def test_binary_wrong_input_number(binary_arith_op:
Callable):
binary_arith_op(x, x, x, x)
+@pytest.mark.parametrize("binary_arith_op", [row[0] for row in
binary_arith_ops])
def test_binary_infer_struct_info_wrong_input_type(binary_arith_op: Callable):
bb = relax.BlockBuilder()
x0 = relax.Var("x", relax.ShapeStructInfo((2, 3)))
diff --git a/tests/python/relax/test_op_gradient_numeric.py
b/tests/python/relax/test_op_gradient_numeric.py
index 3eb77f9412..3ed2bbf279 100644
--- a/tests/python/relax/test_op_gradient_numeric.py
+++ b/tests/python/relax/test_op_gradient_numeric.py
@@ -207,19 +207,20 @@ def relax_check_gradients(
##################### Unary #####################
-unary_op_func, can_be_neg = tvm.testing.parameters(
- (relax.op.abs, True),
- (relax.op.cos, True),
- (relax.op.exp, True),
- (relax.op.log, False),
- (relax.op.negative, True),
- (relax.op.sigmoid, True),
- (relax.op.sin, True),
- (relax.op.sqrt, False),
- (relax.op.tanh, True),
+@pytest.mark.parametrize(
+ "unary_op_func,can_be_neg",
+ [
+ (relax.op.abs, True),
+ (relax.op.cos, True),
+ (relax.op.exp, True),
+ (relax.op.log, False),
+ (relax.op.negative, True),
+ (relax.op.sigmoid, True),
+ (relax.op.sin, True),
+ (relax.op.sqrt, False),
+ (relax.op.tanh, True),
+ ],
)
-
-
@tvm.testing.parametrize_targets("llvm")
def test_unary(target, dev, unary_op_func, can_be_neg):
(low, high) = (-1, 1) if can_be_neg else (0.1, 1)
@@ -230,15 +231,16 @@ def test_unary(target, dev, unary_op_func, can_be_neg):
##################### Binary #####################
-(binary_arith_op_func,) = tvm.testing.parameters(
- (relax.op.add,),
- (relax.op.subtract,),
- (relax.op.multiply,),
- (relax.op.divide,),
- (relax.op.power,),
+@pytest.mark.parametrize(
+ "binary_arith_op_func",
+ [
+ relax.op.add,
+ relax.op.subtract,
+ relax.op.multiply,
+ relax.op.divide,
+ relax.op.power,
+ ],
)
-
-
@tvm.testing.parametrize_targets("llvm")
def test_binary_arith(target, dev, binary_arith_op_func):
data1_numpy = np.random.uniform(1, 2, (3, 3)).astype(np.float32)
@@ -246,12 +248,7 @@ def test_binary_arith(target, dev, binary_arith_op_func):
relax_check_gradients(binary_arith_op_func, [data1_numpy, data2_numpy],
target, dev)
-(binary_minmax_op_func,) = tvm.testing.parameters(
- (relax.op.maximum,),
- (relax.op.minimum,),
-)
-
-
+@pytest.mark.parametrize("binary_minmax_op_func", [relax.op.maximum,
relax.op.minimum])
@tvm.testing.parametrize_targets("llvm")
def test_binary_minmax(target, dev, binary_minmax_op_func):
# Checking numerical gradient of min and max requires data1_numpy[i] !=
data2_numpy[i]
@@ -264,16 +261,17 @@ def test_binary_minmax(target, dev,
binary_minmax_op_func):
relax_check_gradients(binary_minmax_op_func, [data1_numpy, data2_numpy],
target, dev)
-(binary_cmp_op_func,) = tvm.testing.parameters(
- (relax.op.equal,),
- (relax.op.greater,),
- (relax.op.greater_equal,),
- (relax.op.less,),
- (relax.op.less_equal,),
- (relax.op.not_equal,),
+@pytest.mark.parametrize(
+ "binary_cmp_op_func",
+ [
+ relax.op.equal,
+ relax.op.greater,
+ relax.op.greater_equal,
+ relax.op.less,
+ relax.op.less_equal,
+ relax.op.not_equal,
+ ],
)
-
-
@tvm.testing.parametrize_targets("llvm")
def test_binary_cmp(target, dev, binary_cmp_op_func):
data1_numpy = np.random.uniform(1, 2, (3, 3)).astype(np.float32)
@@ -286,12 +284,7 @@ def test_binary_cmp(target, dev, binary_cmp_op_func):
##################### Create #####################
-(like_op_func,) = tvm.testing.parameters(
- (relax.op.zeros_like,),
- (relax.op.ones_like,),
-)
-
-
+@pytest.mark.parametrize("like_op_func", [relax.op.zeros_like,
relax.op.ones_like])
@tvm.testing.parametrize_targets("llvm")
def test_ones_zeros_like(target, dev, like_op_func):
data_numpy = np.random.uniform(-1, 1, (3, 3)).astype(np.float32)
@@ -307,12 +300,7 @@ def test_full_like(target, dev):
)
-(create_op_func,) = tvm.testing.parameters(
- (relax.op.zeros,),
- (relax.op.ones,),
-)
-
-
+@pytest.mark.parametrize("create_op_func", [relax.op.zeros, relax.op.ones])
@tvm.testing.parametrize_targets("llvm")
def test_ones_zeros(target, dev, create_op_func):
relax_check_gradients(
@@ -688,16 +676,17 @@ def test_cross_entropy_with_logits_batch(target, dev):
)
-(nll_reduction, nll_weighted, nll_ignore_index) = tvm.testing.parameters(
- ("mean", True, -1),
- ("sum", True, -1),
- ("none", True, -1),
- ("mean", True, 1),
- ("mean", True, 1),
- ("mean", False, 1),
+@pytest.mark.parametrize(
+ "nll_reduction,nll_weighted,nll_ignore_index",
+ [
+ ("mean", True, -1),
+ ("sum", True, -1),
+ ("none", True, -1),
+ ("mean", True, 1),
+ ("mean", True, 1),
+ ("mean", False, 1),
+ ],
)
-
-
@tvm.testing.parametrize_targets("llvm")
def test_nll_loss(target, dev, nll_reduction, nll_weighted, nll_ignore_index):
data1_numpy = np.random.uniform(0, 16, (2, 3, 4)).astype(np.float32)
@@ -721,13 +710,14 @@ def test_nll_loss(target, dev, nll_reduction,
nll_weighted, nll_ignore_index):
)
-(nll_reduction1, nll_weighted1, nll_ignore_index1) = tvm.testing.parameters(
- ("mean", True, -1),
- ("sum", True, -1),
- ("none", True, -1),
+@pytest.mark.parametrize(
+ "nll_reduction1,nll_weighted1,nll_ignore_index1",
+ [
+ ("mean", True, -1),
+ ("sum", True, -1),
+ ("none", True, -1),
+ ],
)
-
-
@tvm.testing.parametrize_targets("llvm")
def test_nll_loss_no_batch(target, dev, nll_reduction1, nll_weighted1,
nll_ignore_index1):
data1_numpy = np.random.uniform(0, 16, (3,)).astype(np.float32)
@@ -749,40 +739,41 @@ def test_nll_loss_no_batch(target, dev, nll_reduction1,
nll_weighted1, nll_ignor
)
-(c2d_shape1, c2d_shape2, c2d_kwargs) = tvm.testing.parameters(
- (
- (3, 2, 10, 10),
- (3, 2, 3, 3),
- {},
- ),
- (
- (3, 2, 10, 10),
- (3, 2, 1, 2),
- {},
- ),
- (
- (3, 2, 10, 10),
- (3, 2, 3, 3),
- {"strides": (2, 2), "padding": (3, 2), "dilation": (1, 1)},
- ),
- (
- (3, 2, 10, 10),
- (3, 2, 3, 3),
- {"strides": (2, 1), "padding": (2, 2), "dilation": (1, 1)},
- ),
- (
- (3, 6, 10, 10),
- (4, 3, 3, 3),
- {"groups": 2},
- ),
- (
- (3, 2, 10, 10),
- (4, 1, 3, 3),
- {"groups": 2, "strides": (2, 2), "padding": (2, 2), "dilation": (1,
1)},
- ),
+@pytest.mark.parametrize(
+ "c2d_shape1,c2d_shape2,c2d_kwargs",
+ [
+ (
+ (3, 2, 10, 10),
+ (3, 2, 3, 3),
+ {},
+ ),
+ (
+ (3, 2, 10, 10),
+ (3, 2, 1, 2),
+ {},
+ ),
+ (
+ (3, 2, 10, 10),
+ (3, 2, 3, 3),
+ {"strides": (2, 2), "padding": (3, 2), "dilation": (1, 1)},
+ ),
+ (
+ (3, 2, 10, 10),
+ (3, 2, 3, 3),
+ {"strides": (2, 1), "padding": (2, 2), "dilation": (1, 1)},
+ ),
+ (
+ (3, 6, 10, 10),
+ (4, 3, 3, 3),
+ {"groups": 2},
+ ),
+ (
+ (3, 2, 10, 10),
+ (4, 1, 3, 3),
+ {"groups": 2, "strides": (2, 2), "padding": (2, 2), "dilation":
(1, 1)},
+ ),
+ ],
)
-
-
@tvm.testing.parametrize_targets("llvm")
def test_conv2d(target, dev, c2d_shape1, c2d_shape2, c2d_kwargs):
import pytest
@@ -799,7 +790,7 @@ def test_conv2d(target, dev, c2d_shape1, c2d_shape2,
c2d_kwargs):
)
-(pool_size, pool_kwargs) = tvm.testing.parameters(
+pool_params = [
(
(3, 3),
{},
@@ -818,9 +809,10 @@ def test_conv2d(target, dev, c2d_shape1, c2d_shape2,
c2d_kwargs):
"count_include_pad": True,
},
),
-)
+]
+@pytest.mark.parametrize("pool_size,pool_kwargs", pool_params)
@tvm.testing.parametrize_targets("llvm")
def test_max_pool2d(target, dev, pool_size, pool_kwargs):
data_numpy = np.random.uniform(0, 3, size=(3, 2, 10,
10)).astype(np.float32)
@@ -834,6 +826,7 @@ def test_max_pool2d(target, dev, pool_size, pool_kwargs):
)
+@pytest.mark.parametrize("pool_size,pool_kwargs", pool_params)
@tvm.testing.parametrize_targets("llvm")
def test_avg_pool2d(target, dev, pool_size, pool_kwargs):
data_numpy = np.random.uniform(0, 3, size=(3, 2, 10,
10)).astype(np.float32)
diff --git a/tests/python/relax/test_op_search.py
b/tests/python/relax/test_op_search.py
index 11733da2f4..252f5db241 100644
--- a/tests/python/relax/test_op_search.py
+++ b/tests/python/relax/test_op_search.py
@@ -285,9 +285,10 @@ def test_where_infer_struct_info_wrong_input_type():
bb.normalize(relax.op.where(cond1, x1, y0))
-(argmax_argmin_op,) = tvm.testing.parameters((relax.op.argmax,),
(relax.op.argmin,))
+argmax_argmin_ops = [relax.op.argmax, relax.op.argmin]
+@pytest.mark.parametrize("argmax_argmin_op", argmax_argmin_ops)
def test_argmax_argmin_infer_struct_info(argmax_argmin_op: Callable):
bb = relax.BlockBuilder()
vdev0 = VDevice("llvm")
@@ -360,6 +361,7 @@ def test_argmax_argmin_infer_struct_info(argmax_argmin_op:
Callable):
_check_inference(bb, argmax_argmin_op(x0, axis=-1),
relax.TensorStructInfo((2, 3, 4), "int64"))
+@pytest.mark.parametrize("argmax_argmin_op", argmax_argmin_ops)
def test_argmax_argmin_infer_struct_info_shape_symbolic(argmax_argmin_op:
Callable):
bb = relax.BlockBuilder()
a = tirx.Var("a", "int64")
@@ -382,6 +384,7 @@ def
test_argmax_argmin_infer_struct_info_shape_symbolic(argmax_argmin_op: Callab
)
+@pytest.mark.parametrize("argmax_argmin_op", argmax_argmin_ops)
def test_argmax_argmin_infer_struct_info_shape_var(argmax_argmin_op: Callable):
bb = relax.BlockBuilder()
s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4))
@@ -409,6 +412,7 @@ def
test_argmax_argmin_infer_struct_info_shape_var(argmax_argmin_op: Callable):
)
+@pytest.mark.parametrize("argmax_argmin_op", argmax_argmin_ops)
def test_argmax_argmin_infer_struct_info_more_input_dtype(argmax_argmin_op:
Callable):
bb = relax.BlockBuilder()
x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float16"))
@@ -418,6 +422,7 @@ def
test_argmax_argmin_infer_struct_info_more_input_dtype(argmax_argmin_op: Call
_check_inference(bb, argmax_argmin_op(x1), relax.TensorStructInfo((),
"int64"))
+@pytest.mark.parametrize("argmax_argmin_op", argmax_argmin_ops)
def test_argmax_argmin_infer_struct_info_axis_out_of_range(argmax_argmin_op:
Callable):
bb = relax.BlockBuilder()
x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int64"))
@@ -433,6 +438,7 @@ def
test_argmax_argmin_infer_struct_info_axis_out_of_range(argmax_argmin_op: Cal
bb.normalize(argmax_argmin_op(x1, axis=-5))
+@pytest.mark.parametrize("argmax_argmin_op", argmax_argmin_ops)
def test_argmax_argmin_infer_struct_info_wrong_input_type(argmax_argmin_op:
Callable):
bb = relax.BlockBuilder()
x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5)))
diff --git a/tests/python/relax/test_op_statistical.py
b/tests/python/relax/test_op_statistical.py
index 3f2d5bb1ee..5905843267 100644
--- a/tests/python/relax/test_op_statistical.py
+++ b/tests/python/relax/test_op_statistical.py
@@ -208,12 +208,13 @@ def test_statistical_infer_struct_info_wrong_input_type():
bb.normalize(relax.op.variance(x1))
-(scan_op,) = tvm.testing.parameters(
- (relax.op.cumprod,),
- (relax.op.cumsum,),
-)
+scan_ops = [
+ relax.op.cumprod,
+ relax.op.cumsum,
+]
+@pytest.mark.parametrize("scan_op", scan_ops)
def test_scan_op_infer_struct_info(scan_op: Callable):
bb = relax.BlockBuilder()
vdev0 = VDevice("llvm")
@@ -238,6 +239,7 @@ def test_scan_op_infer_struct_info(scan_op: Callable):
)
+@pytest.mark.parametrize("scan_op", scan_ops)
def test_scan_op_infer_struct_info_shape_symbolic(scan_op: Callable):
bb = relax.BlockBuilder()
a = tirx.Var("a", "int64")
@@ -249,6 +251,7 @@ def test_scan_op_infer_struct_info_shape_symbolic(scan_op:
Callable):
_check_inference(bb, scan_op(x), relax.TensorStructInfo((a * b * c,),
"float32"))
+@pytest.mark.parametrize("scan_op", scan_ops)
def test_scan_op_infer_struct_info_more_input_dtype(scan_op: Callable):
bb = relax.BlockBuilder()
x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16"))
@@ -258,6 +261,7 @@ def
test_scan_op_infer_struct_info_more_input_dtype(scan_op: Callable):
_check_inference(bb, scan_op(x1, axis=1), relax.TensorStructInfo((2, 3,
4), "int8"))
+@pytest.mark.parametrize("scan_op", scan_ops)
def test_scan_op_wrong_input_number(scan_op: Callable):
x = relax.Var("x", R.Tensor((3, 4, 5), "float32"))
y = relax.Var("y", R.Tensor((2, 3, 4), "float32"))
@@ -266,6 +270,7 @@ def test_scan_op_wrong_input_number(scan_op: Callable):
scan_op(x, y)
+@pytest.mark.parametrize("scan_op", scan_ops)
def test_scan_opinfer_struct_info_wrong_input_type(scan_op: Callable):
bb = relax.BlockBuilder()
x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5)))
diff --git a/tests/python/relax/test_op_unary.py
b/tests/python/relax/test_op_unary.py
index 018f042db1..0527c2e9de 100644
--- a/tests/python/relax/test_op_unary.py
+++ b/tests/python/relax/test_op_unary.py
@@ -68,7 +68,7 @@ def _check_inference(bb: relax.BlockBuilder, call:
relax.Call, expected_sinfo: r
tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo)
-unary_arith_op, require_float_dtype = tvm.testing.parameters(
+unary_arith_ops = [
(relax.op.abs, False),
(relax.op.acos, True),
(relax.op.acosh, True),
@@ -93,9 +93,10 @@ unary_arith_op, require_float_dtype = tvm.testing.parameters(
(relax.op.sqrt, True),
(relax.op.tan, True),
(relax.op.tanh, True),
-)
+]
+@pytest.mark.parametrize("unary_arith_op", [row[0] for row in unary_arith_ops])
def test_unary_arith_infer_struct_info(unary_arith_op: Callable):
bb = relax.BlockBuilder()
vdev0 = VDevice("llvm")
@@ -114,6 +115,7 @@ def test_unary_arith_infer_struct_info(unary_arith_op:
Callable):
_check_inference(bb, unary_arith_op(x4), relax.TensorStructInfo(dtype=""))
+@pytest.mark.parametrize("unary_arith_op", [row[0] for row in unary_arith_ops])
def test_unary_arith_infer_struct_info_shape_symbolic(unary_arith_op:
Callable):
bb = relax.BlockBuilder()
m = tirx.Var("m", "int64")
@@ -125,6 +127,7 @@ def
test_unary_arith_infer_struct_info_shape_symbolic(unary_arith_op: Callable):
_check_inference(bb, unary_arith_op(x1), relax.TensorStructInfo((4, n),
"float32"))
+@pytest.mark.parametrize("unary_arith_op", [row[0] for row in unary_arith_ops])
def test_unary_arith_infer_struct_info_shape_var(unary_arith_op: Callable):
bb = relax.BlockBuilder()
s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2))
@@ -136,6 +139,7 @@ def
test_unary_arith_infer_struct_info_shape_var(unary_arith_op: Callable):
_check_inference(bb, unary_arith_op(x1), relax.TensorStructInfo(s1,
"float32"))
+@pytest.mark.parametrize("unary_arith_op,require_float_dtype", unary_arith_ops)
def test_unary_arith_infer_struct_info_more_input_dtype(
unary_arith_op: Callable, require_float_dtype: bool
):
@@ -152,6 +156,7 @@ def test_unary_arith_infer_struct_info_more_input_dtype(
_check_inference(bb, unary_arith_op(x2), relax.TensorStructInfo((2, 3),
"int64"))
+@pytest.mark.parametrize("unary_arith_op,require_float_dtype", unary_arith_ops)
def test_unary_arith_infer_struct_info_invalid_input_dtype(
unary_arith_op: Callable, require_float_dtype: bool
):
@@ -168,6 +173,7 @@ def test_unary_arith_infer_struct_info_invalid_input_dtype(
bb.normalize(unary_arith_op(x1))
+@pytest.mark.parametrize("unary_arith_op", [row[0] for row in unary_arith_ops])
def test_unary_arith_wrong_input_number(unary_arith_op: Callable):
x = relax.Var("x", R.Tensor((2, 3), "float32"))
@@ -177,6 +183,7 @@ def test_unary_arith_wrong_input_number(unary_arith_op:
Callable):
unary_arith_op(x, x, x)
+@pytest.mark.parametrize("unary_arith_op", [row[0] for row in unary_arith_ops])
def test_unary_arith_infer_struct_info_wrong_input_type(unary_arith_op:
Callable):
bb = relax.BlockBuilder()
x0 = relax.Var("x", relax.ShapeStructInfo((2, 3)))
diff --git a/tests/python/relax/test_training_optimizer_numeric.py
b/tests/python/relax/test_training_optimizer_numeric.py
index 01060ffa2c..4b9e40b717 100644
--- a/tests/python/relax/test_training_optimizer_numeric.py
+++ b/tests/python/relax/test_training_optimizer_numeric.py
@@ -19,6 +19,7 @@
from collections.abc import Callable
import numpy as np
+import pytest
import tvm_ffi
import tvm
@@ -79,12 +80,13 @@ def _test_optimizer(target, dev, np_func, opt_type, *args,
**kwargs):
_assert_run_result_same(tvm_func, np_func, [param_arr, grad_arr,
state_arr])
-lr, weight_decay = tvm.testing.parameters(
- (0.01, 0),
- (0.01, 0.02),
+@pytest.mark.parametrize(
+ "lr,weight_decay",
+ [
+ (0.01, 0),
+ (0.01, 0.02),
+ ],
)
-
-
@tvm.testing.parametrize_targets("llvm")
def test_sgd(target, dev, lr, weight_decay):
def np_func(param_tuple, grad_tuple, state_tuple):
@@ -100,13 +102,14 @@ def test_sgd(target, dev, lr, weight_decay):
_test_optimizer(target, dev, np_func, SGD, lr, weight_decay)
-lr, momentum, dampening, weight_decay, nesterov = tvm.testing.parameters(
- (0.01, 0.9, 0, 0, False),
- (0.01, 0.9, 0.85, 0.02, False),
- (0.01, 0.9, 0.85, 0.02, True),
+@pytest.mark.parametrize(
+ "lr,momentum,dampening,weight_decay,nesterov",
+ [
+ (0.01, 0.9, 0, 0, False),
+ (0.01, 0.9, 0.85, 0.02, False),
+ (0.01, 0.9, 0.85, 0.02, True),
+ ],
)
-
-
@tvm.testing.parametrize_targets("llvm")
def test_momentum_sgd(target, dev, lr, momentum, dampening, weight_decay,
nesterov):
def np_func(param_tuple, grad_tuple, state_tuple):
@@ -134,12 +137,13 @@ def test_momentum_sgd(target, dev, lr, momentum,
dampening, weight_decay, nester
)
-lr, betas, eps, weight_decay = tvm.testing.parameters(
- (0.01, (0.9, 0.999), 1e-08, 0),
- (0.01, (0.8, 0.85), 1e-07, 0.1),
+@pytest.mark.parametrize(
+ "lr,betas,eps,weight_decay",
+ [
+ (0.01, (0.9, 0.999), 1e-08, 0),
+ (0.01, (0.8, 0.85), 1e-07, 0.1),
+ ],
)
-
-
@tvm.testing.parametrize_targets("llvm")
def test_adam(target, dev, lr, betas, eps, weight_decay):
def np_func(param_tuple, grad_tuple, state_tuple):
diff --git a/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py
b/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py
index 830d2cf7d4..de13c668b2 100644
--- a/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py
+++ b/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py
@@ -17,6 +17,8 @@
from collections.abc import Callable
+import pytest
+
import tvm
import tvm.script
import tvm.testing
@@ -35,34 +37,35 @@ def _check(
tvm.ir.assert_structural_equal(parsed, expect)
-(unary_arith_op,) = tvm.testing.parameters(
- (relax.op.abs,),
- (relax.op.acos,),
- (relax.op.acosh,),
- (relax.op.asin,),
- (relax.op.asinh,),
- (relax.op.atan,),
- (relax.op.atanh,),
- (relax.op.ceil,),
- (relax.op.cos,),
- (relax.op.cosh,),
- (relax.op.exp,),
- (relax.op.floor,),
- (relax.op.log,),
- (relax.op.negative,),
- (relax.op.round,),
- (relax.op.rsqrt,),
- (relax.op.sigmoid,),
- (relax.op.sign,),
- (relax.op.sin,),
- (relax.op.sinh,),
- (relax.op.square,),
- (relax.op.sqrt,),
- (relax.op.tan,),
- (relax.op.tanh,),
+@pytest.mark.parametrize(
+ "unary_arith_op",
+ [
+ relax.op.abs,
+ relax.op.acos,
+ relax.op.acosh,
+ relax.op.asin,
+ relax.op.asinh,
+ relax.op.atan,
+ relax.op.atanh,
+ relax.op.ceil,
+ relax.op.cos,
+ relax.op.cosh,
+ relax.op.exp,
+ relax.op.floor,
+ relax.op.log,
+ relax.op.negative,
+ relax.op.round,
+ relax.op.rsqrt,
+ relax.op.sigmoid,
+ relax.op.sign,
+ relax.op.sin,
+ relax.op.sinh,
+ relax.op.square,
+ relax.op.sqrt,
+ relax.op.tan,
+ relax.op.tanh,
+ ],
)
-
-
def test_unary_arith(unary_arith_op: Callable):
@R.function
def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
@@ -78,13 +81,14 @@ def test_unary_arith(unary_arith_op: Callable):
_check(foo, bb.get()["foo"])
-(unary_check_op,) = tvm.testing.parameters(
- (relax.op.isfinite,),
- (relax.op.isinf,),
- (relax.op.isnan,),
+@pytest.mark.parametrize(
+ "unary_check_op",
+ [
+ relax.op.isfinite,
+ relax.op.isinf,
+ relax.op.isnan,
+ ],
)
-
-
def test_unary_check(unary_check_op: Callable):
@R.function
def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"):
@@ -100,18 +104,19 @@ def test_unary_check(unary_check_op: Callable):
_check(foo, bb.get()["foo"])
-(binary_arith_op,) = tvm.testing.parameters(
- (relax.op.add,),
- (relax.op.divide,),
- (relax.op.floor_divide,),
- (relax.op.multiply,),
- (relax.op.power,),
- (relax.op.subtract,),
- (relax.op.maximum,),
- (relax.op.minimum,),
+@pytest.mark.parametrize(
+ "binary_arith_op",
+ [
+ relax.op.add,
+ relax.op.divide,
+ relax.op.floor_divide,
+ relax.op.multiply,
+ relax.op.power,
+ relax.op.subtract,
+ relax.op.maximum,
+ relax.op.minimum,
+ ],
)
-
-
def test_binary_arith(binary_arith_op: Callable):
@R.function
def foo(x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 1), "float32")) ->
R.Tensor(
@@ -130,16 +135,17 @@ def test_binary_arith(binary_arith_op: Callable):
_check(foo, bb.get()["foo"])
-(binary_cmp_op,) = tvm.testing.parameters(
- (relax.op.equal,),
- (relax.op.greater,),
- (relax.op.greater_equal,),
- (relax.op.less,),
- (relax.op.less_equal,),
- (relax.op.not_equal,),
+@pytest.mark.parametrize(
+ "binary_cmp_op",
+ [
+ relax.op.equal,
+ relax.op.greater,
+ relax.op.greater_equal,
+ relax.op.less,
+ relax.op.less_equal,
+ relax.op.not_equal,
+ ],
)
-
-
def test_binary_cmp(binary_cmp_op: Callable):
@R.function
def foo(x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 1), "float32")) ->
R.Tensor(
diff --git a/tests/python/s_tir/base/test_tir_te_extern_primfunc.py
b/tests/python/s_tir/base/test_tir_te_extern_primfunc.py
index 586d4647b7..d1cd710ed5 100644
--- a/tests/python/s_tir/base/test_tir_te_extern_primfunc.py
+++ b/tests/python/s_tir/base/test_tir_te_extern_primfunc.py
@@ -172,19 +172,22 @@ def verify_func_4(module):
tvm.testing.assert_allclose(a_np + 1, f.numpy(), rtol=1e-4)
-class TestPrimFuncs:
- func, params, verify = tvm.testing.parameters(
- [func_1, ("A"), verify_func_1],
- [func_2, ("C", "D"), verify_func_2],
- [func_3, ("C", "A", "D", "E"), verify_func_3],
- [func_4, ("C", "A", "D", "E"), verify_func_4],
- )
+_primfunc_cases = [
+ [func_1, ("A"), verify_func_1],
+ [func_2, ("C", "D"), verify_func_2],
+ [func_3, ("C", "A", "D", "E"), verify_func_3],
+ [func_4, ("C", "A", "D", "E"), verify_func_4],
+]
+
+class TestPrimFuncs:
+ @pytest.mark.parametrize("func,verify", [(case[0], case[2]) for case in
_primfunc_cases])
def test_primfunc_call(self, func, verify):
target = tvm.target.Target("llvm")
func = tvm.compile(func, target=target)
verify(func)
+ @pytest.mark.parametrize("func,params,verify", _primfunc_cases)
def test_te_extern_call(self, func, params, verify):
ir_mod = tvm.IRModule.from_expr(func.with_attr("global_symbol",
"main"))
prim_func = ir_mod["main"]
diff --git a/tests/python/target/test_arm_target.py
b/tests/python/target/test_arm_target.py
index 6b964f13fe..862f41e146 100644
--- a/tests/python/target/test_arm_target.py
+++ b/tests/python/target/test_arm_target.py
@@ -28,63 +28,6 @@ from tvm.script import tirx as T
from tvm.target import codegen
from tvm.testing import env
-llvm_version, arm_target, input_dtype, kernel_dtype, is_supported =
tvm.testing.parameters(
- # Testing mcpu type
- (8, {"kind": "c", "mcpu": "cortex-m4"}, "int8", "int8", False),
- (8, {"kind": "c", "mcpu": "cortex-m7"}, "int8", "int8", False),
- (8, {"kind": "c", "mcpu": "cortex-m33"}, "int8", "int8", False),
- (8, {"kind": "c", "mcpu": "cortex-m55"}, "int8", "int8", False),
- (8, {"kind": "c", "mcpu": "cortex-m3"}, "int8", "int8", False),
- (
- 7,
- {"kind": "llvm", "mtriple": "arm-linux-gnueabi", "mattr": ["+neon"]},
- "int8",
- "int8",
- False,
- ),
- (8, {"kind": "llvm", "mtriple": "arm-linux-gnueabi", "mattr": ["+neon"]},
"int8", "int8", True),
- (9, {"kind": "llvm", "mtriple": "arm-linux-gnueabi", "mattr": ["+neon"]},
"int8", "int8", True),
- (8, {"kind": "llvm", "mtriple": "arm-linux-gnueabi"}, "int8", "int8",
False),
- (
- 7,
- {"kind": "llvm", "mtriple": "aarch64-linux-gnu", "mattr": ["+v8.4a",
"+dotprod"]},
- "int8",
- "int8",
- False,
- ),
- (
- 8,
- {"kind": "llvm", "mtriple": "aarch64-linux-gnu", "mattr": ["+v8.4a",
"+dotprod"]},
- "int8",
- "int8",
- True,
- ),
- (9, {"kind": "llvm", "mtriple": "arm-linux-gnueabi", "mattr": ["+neon"]},
"int8", "int8", True),
- (8, {"kind": "llvm", "mtriple": "aarch64-linux-gnu"}, "int8", "int8",
True),
- # Testing dtype
- (
- 8,
- {"kind": "llvm", "mtriple": "aarch64-linux-gnu", "mattr": ["+neon"]},
- "int16",
- "int8",
- False,
- ),
- (
- 8,
- {"kind": "llvm", "mtriple": "aarch64-linux-gnu", "mattr": ["+neon"]},
- "int8",
- "int16",
- False,
- ),
- (
- 8,
- {"kind": "llvm", "mtriple": "aarch64-linux-gnu", "mattr": ["+neon"]},
- "int16",
- "int16",
- False,
- ),
-)
-
@pytest.fixture(scope="session")
def sve_device_vector_length():
diff --git a/tests/python/target/test_llvm_features_info.py
b/tests/python/target/test_llvm_features_info.py
index 7d56767650..2f2f7e0158 100644
--- a/tests/python/target/test_llvm_features_info.py
+++ b/tests/python/target/test_llvm_features_info.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# ruff: noqa: F401
import pytest
import tvm
@@ -51,23 +50,24 @@ def test_llvm_targets(capfd):
assert expected_str in readout_error
-min_llvm_version, llvm_target, cpu_arch, cpu_features, is_supported =
tvm.testing.parameters(
- (-1, "x86_64", "sandybridge", "sse4.1", True),
- (-1, "x86_64", "ivybridge", ["sse4.1", "ssse3"], True),
- (-1, "x86_64", "ivybridge", ["sse4.1", "ssse3", "avx512bw"], False),
- # 32bit vs 64bit
- (-1, "aarch64", "cortex-a55", "neon", True),
- (-1, "aarch64", "cortex-a55", "dotprod", True),
- (-1, "aarch64", "cortex-a55", "dsp", False),
- (-1, "arm", "cortex-a55", "dsp", True),
- (-1, "aarch64", "cortex-a55", ["neon", "dotprod"], True),
- (-1, "aarch64", "cortex-a55", ["neon", "dotprod", "dsp"], False),
- (-1, "arm", "cortex-a55", ["neon", "dotprod"], True),
- (-1, "aarch64", "cortex-a55", ["neon", "dotprod", "dsp"], False),
- (-1, "arm", "cortex-a55", ["neon", "dotprod", "dsp"], True),
+@pytest.mark.parametrize(
+ "min_llvm_version,llvm_target,cpu_arch,cpu_features,is_supported",
+ [
+ (-1, "x86_64", "sandybridge", "sse4.1", True),
+ (-1, "x86_64", "ivybridge", ["sse4.1", "ssse3"], True),
+ (-1, "x86_64", "ivybridge", ["sse4.1", "ssse3", "avx512bw"], False),
+ # 32bit vs 64bit
+ (-1, "aarch64", "cortex-a55", "neon", True),
+ (-1, "aarch64", "cortex-a55", "dotprod", True),
+ (-1, "aarch64", "cortex-a55", "dsp", False),
+ (-1, "arm", "cortex-a55", "dsp", True),
+ (-1, "aarch64", "cortex-a55", ["neon", "dotprod"], True),
+ (-1, "aarch64", "cortex-a55", ["neon", "dotprod", "dsp"], False),
+ (-1, "arm", "cortex-a55", ["neon", "dotprod"], True),
+ (-1, "aarch64", "cortex-a55", ["neon", "dotprod", "dsp"], False),
+ (-1, "arm", "cortex-a55", ["neon", "dotprod", "dsp"], True),
+ ],
)
-
-
def test_target_features(min_llvm_version, llvm_target, cpu_arch,
cpu_features, is_supported):
target = Target({"kind": "llvm", "mtriple": f"{llvm_target}--", "mcpu":
cpu_arch})
diff --git a/tests/python/target/test_riscv_features.py
b/tests/python/target/test_riscv_features.py
index 346e5411c3..b5c5f82353 100644
--- a/tests/python/target/test_riscv_features.py
+++ b/tests/python/target/test_riscv_features.py
@@ -23,8 +23,11 @@ from tvm.target.codegen import llvm_get_vector_width,
target_has_features
LLVM_VERSION = codegen.llvm_version_major()
+
# fmt: off
-min_llvm_version, tvm_target, vec_width = tvm.testing.parameters(
+@pytest.mark.parametrize(
+ "min_llvm_version,tvm_target,vec_width",
+ [
# generic, no vector -> (default 128)
(-1, {"kind": "llvm", "device": "riscv_cpu", "mtriple":
"riscv64-linux-gnu", "mcpu": "generic-rv64", "mattr": ["+i", "+m"]}, 128),
(-1, {"kind": "llvm", "device": "riscv_cpu", "mtriple":
"riscv32-linux-gnu", "mcpu": "generic-rv32", "mattr": ["+64bit", "+a", "+c",
"+d", "+f", "+m"]}, 128),
@@ -39,9 +42,8 @@ min_llvm_version, tvm_target, vec_width =
tvm.testing.parameters(
(17, {"kind": "llvm", "device": "riscv_cpu", "mtriple":
"riscv64-linux-gnu", "mcpu": "sifive-x280"}, 512),
(18, {"kind": "llvm", "device": "riscv_cpu", "mtriple":
"riscv64-linux-gnu", "mcpu": "sifive-p670"}, 128),
(19, {"kind": "llvm", "device": "riscv_cpu", "mtriple":
"riscv64-linux-gnu", "mcpu": "spacemit-x60"}, 256),
+ ],
)
-
-
def test_riscv_rvv_features(min_llvm_version, tvm_target, vec_width):
"""Test RVV features support for different targets.
diff --git a/tests/python/target/test_x86_features.py
b/tests/python/target/test_x86_features.py
index b7c2d21a22..ee10853f55 100644
--- a/tests/python/target/test_x86_features.py
+++ b/tests/python/target/test_x86_features.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# ruff: noqa: F401, F841
+# ruff: noqa: F841
import pytest
import tvm
@@ -40,113 +40,139 @@ def _feature_supported_by_llvm(x86_feature) -> bool:
return cap is None or LLVM_VERSION <= cap
-min_llvm_version, tvm_target, x86_feature, is_supported =
tvm.testing.parameters(
- # sse4.1
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "btver2"}, "sse4a",
True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "penryn"}, "sse4.1",
True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "silvermont"},
"sse4.2", True),
- (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "slm"}, "sse4.2",
True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "goldmont"},
"sse4.2", True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "goldmont-plus"},
"sse4.2", True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "tremont"}, "sse4.2",
True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "nehalem"}, "sse4.2",
True),
- (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "corei7"}, "sse4.2",
True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "westmere"},
"sse4.2", True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "bdver1"}, "sse4.2",
True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "bdver2"}, "sse4.2",
True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "bdver3"}, "sse4.2",
True),
- (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "x86-64-v2"},
"sse4.2", True),
- # avx
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "sandybridge"},
"avx", True),
- (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "corei7-avx"}, "avx",
True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "ivybridge"}, "avx",
True),
- (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "core-avx-i"}, "avx",
True),
- # avx2
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "haswell"}, "avx2",
True),
- (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "core-avx2"}, "avx2",
True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "broadwell"}, "avx2",
True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "skylake"}, "avx2",
True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "bdver4"}, "avx2",
True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "znver1"}, "avx2",
True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "znver2"}, "avx2",
True),
- (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "znver3"}, "avx2",
True),
- (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "x86-64-v3"}, "avx2",
True),
- # avx512bw
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "skylake-avx512"},
"avx512bw", True),
- (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "skx"}, "avx512bw",
True),
- (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "knl"}, "avx512bw",
False),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "knl"}, "avx512f",
True),
- (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "knl"}, ["avx512bw",
"avx512f"], False),
- (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "knl"}, ("avx512bw",
"avx512f"), False),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "knl"}, "avx512cd",
True),
- (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "knl"}, ["avx512cd",
"avx512f"], True),
- (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "knl"}, ("avx512cd",
"avx512f"), True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "knl"}, "avx512er",
True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "knl"}, "avx512pf",
True),
- (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "knm"}, "avx512bw",
False),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "knm"}, "avx512f",
True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "knm"}, "avx512cd",
True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "knm"}, "avx512er",
True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "knm"}, "avx512pf",
True),
- (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "x86-64-v4"},
"avx512bw", True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "cannonlake"},
"avx512bw", True),
- # explicit enumeration of VNNI capable due to collision with alderlake
- (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "alderlake"},
"avx512bw", False),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "cascadelake"},
"avx512bw", True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "icelake-client"},
"avx512bw", True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "icelake-server"},
"avx512bw", True),
- (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "rocketlake"},
"avx512bw", True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "tigerlake"},
"avx512bw", True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "cooperlake"},
"avx512bw", True),
- (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "sapphirerapids"},
"avx512bw", True),
- # avx512vnni
- (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "alderlake"},
"avx512vnni", False),
- (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "alderlake"},
"avxvnni", True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "cascadelake"},
"avx512vnni", True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "icelake-client"},
"avx512vnni", True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "icelake-server"},
"avx512vnni", True),
- (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "rocketlake"},
"avx512vnni", True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "tigerlake"},
"avx512vnni", True),
- (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "cooperlake"},
"avx512vnni", True),
- (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "sapphirerapids"},
"avx512vnni", True),
- # amx-int8
- (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "sapphirerapids"},
"amx-int8", True),
- # generic CPU (no features) but with extra -mattr
- (
- -1,
- {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "x86-64", "mattr":
["+sse4.1", "+avx2"]},
- "avx2",
- True,
- ),
- (
- -1,
- {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "x86-64", "mattr":
["+sse4.1", "+avx2"]},
- "sse4.1",
- True,
- ),
- # enabling +sse4.1 implies ssse3 presence in LLVM
- (
- -1,
- {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "x86-64", "mattr":
["+sse4.1", "+avx2"]},
- "ssse3",
- True,
- ),
- (
- -1,
- {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "ivybridge", "mattr":
["-ssse3"]},
- "ssse3",
- False,
- ),
- # disabling avx512f (foundation) also disables avx512bw
- (
- -1,
- {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "cascadelake",
"mattr": ["-avx512f"]},
- "avx512bw",
- False,
- ),
+@pytest.mark.parametrize(
+ "min_llvm_version,tvm_target,x86_feature,is_supported",
+ [
+ # sse4.1
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "btver2"},
"sse4a", True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "penryn"},
"sse4.1", True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "silvermont"},
"sse4.2", True),
+ (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "slm"}, "sse4.2",
True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "goldmont"},
"sse4.2", True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "goldmont-plus"},
"sse4.2", True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "tremont"},
"sse4.2", True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "nehalem"},
"sse4.2", True),
+ (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "corei7"},
"sse4.2", True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "westmere"},
"sse4.2", True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "bdver1"},
"sse4.2", True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "bdver2"},
"sse4.2", True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "bdver3"},
"sse4.2", True),
+ (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "x86-64-v2"},
"sse4.2", True),
+ # avx
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "sandybridge"},
"avx", True),
+ (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "corei7-avx"},
"avx", True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "ivybridge"},
"avx", True),
+ (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "core-avx-i"},
"avx", True),
+ # avx2
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "haswell"},
"avx2", True),
+ (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "core-avx2"},
"avx2", True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "broadwell"},
"avx2", True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "skylake"},
"avx2", True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "bdver4"},
"avx2", True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "znver1"},
"avx2", True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "znver2"},
"avx2", True),
+ (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "znver3"},
"avx2", True),
+ (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "x86-64-v3"},
"avx2", True),
+ # avx512bw
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu":
"skylake-avx512"}, "avx512bw", True),
+ (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "skx"},
"avx512bw", True),
+ (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "knl"},
"avx512bw", False),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "knl"},
"avx512f", True),
+ (
+ 11,
+ {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "knl"},
+ ["avx512bw", "avx512f"],
+ False,
+ ),
+ (
+ 11,
+ {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "knl"},
+ ("avx512bw", "avx512f"),
+ False,
+ ),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "knl"},
"avx512cd", True),
+ (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "knl"},
["avx512cd", "avx512f"], True),
+ (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "knl"},
("avx512cd", "avx512f"), True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "knl"},
"avx512er", True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "knl"},
"avx512pf", True),
+ (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "knm"},
"avx512bw", False),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "knm"},
"avx512f", True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "knm"},
"avx512cd", True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "knm"},
"avx512er", True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "knm"},
"avx512pf", True),
+ (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "x86-64-v4"},
"avx512bw", True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "cannonlake"},
"avx512bw", True),
+ # explicit enumeration of VNNI capable due to collision with alderlake
+ (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "alderlake"},
"avx512bw", False),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "cascadelake"},
"avx512bw", True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu":
"icelake-client"}, "avx512bw", True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu":
"icelake-server"}, "avx512bw", True),
+ (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "rocketlake"},
"avx512bw", True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "tigerlake"},
"avx512bw", True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "cooperlake"},
"avx512bw", True),
+ (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu":
"sapphirerapids"}, "avx512bw", True),
+ # avx512vnni
+ (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "alderlake"},
"avx512vnni", False),
+ (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "alderlake"},
"avxvnni", True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "cascadelake"},
"avx512vnni", True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu":
"icelake-client"}, "avx512vnni", True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu":
"icelake-server"}, "avx512vnni", True),
+ (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "rocketlake"},
"avx512vnni", True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "tigerlake"},
"avx512vnni", True),
+ (-1, {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "cooperlake"},
"avx512vnni", True),
+ (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu":
"sapphirerapids"}, "avx512vnni", True),
+ # amx-int8
+ (11, {"kind": "llvm", "mtriple": "x86_64--", "mcpu":
"sapphirerapids"}, "amx-int8", True),
+ # generic CPU (no features) but with extra -mattr
+ (
+ -1,
+ {
+ "kind": "llvm",
+ "mtriple": "x86_64--",
+ "mcpu": "x86-64",
+ "mattr": ["+sse4.1", "+avx2"],
+ },
+ "avx2",
+ True,
+ ),
+ (
+ -1,
+ {
+ "kind": "llvm",
+ "mtriple": "x86_64--",
+ "mcpu": "x86-64",
+ "mattr": ["+sse4.1", "+avx2"],
+ },
+ "sse4.1",
+ True,
+ ),
+ # enabling +sse4.1 implies ssse3 presence in LLVM
+ (
+ -1,
+ {
+ "kind": "llvm",
+ "mtriple": "x86_64--",
+ "mcpu": "x86-64",
+ "mattr": ["+sse4.1", "+avx2"],
+ },
+ "ssse3",
+ True,
+ ),
+ (
+ -1,
+ {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "ivybridge",
"mattr": ["-ssse3"]},
+ "ssse3",
+ False,
+ ),
+ # disabling avx512f (foundation) also disables avx512bw
+ (
+ -1,
+ {"kind": "llvm", "mtriple": "x86_64--", "mcpu": "cascadelake",
"mattr": ["-avx512f"]},
+ "avx512bw",
+ False,
+ ),
+ ],
)
-
-
def test_x86_target_features(min_llvm_version, tvm_target, x86_feature,
is_supported):
"""Test X86 features support for different targets.