This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8327a8cb9f [Relax][ONNX] Update ReduceL1 to opset 18 (#18072)
8327a8cb9f is described below
commit 8327a8cb9f2d0bf1088784fa8c58e9425bc762d7
Author: Youngsik Yang <[email protected]>
AuthorDate: Mon Jun 23 12:11:00 2025 +0900
[Relax][ONNX] Update ReduceL1 to opset 18 (#18072)
[Relax][ONNX] Update ReduceL1 to version 18
- Update ReduceL1-13 to ReduceL1-18
- Add the corresponding test cases
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 23 ++++
tests/python/relax/test_frontend_onnx.py | 153 +++++++++++++++++++++---
2 files changed, 162 insertions(+), 14 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 201e99e3d1..59354c7589 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -2612,6 +2612,29 @@ class ReduceL1(OnnxOpConverter):
keepdims = attr.get("keepdims", 1)
return relax.op.sum(relax.op.abs(data), axes, keepdims)
+ @classmethod
+ def _impl_v18(cls, bb, inputs, attr, params):
+ data = inputs[0]
+ keepdims = attr.get("keepdims", 1)
+ noop_with_empty_axes = attr.get("noop_with_empty_axes", 0)
+
+ # Optional axes input
+ axes = None
+ if len(inputs) > 1 and inputs[1] is not None:
+ axes_const = get_constant(inputs[1], params)
+ assert isinstance(axes_const, relax.Constant), "Only constant axes currently supported"
+ axes = axes_const.data.numpy().tolist()
+
+ # If axes is empty and noop_with_empty_axes is 0, reduce all dimensions
+ if not axes and not noop_with_empty_axes:
+ return relax.op.sum(relax.op.abs(data), None, keepdims)
+ # If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.
+ elif not axes and noop_with_empty_axes:
+ return data
+ # If axes is provided, reduce over specified axes
+ else:
+ return relax.op.sum(relax.op.abs(data), axes, keepdims)
+
class ReduceL2(OnnxOpConverter):
"""Converts an onnx ReduceL2 node into an equivalent Relax expression."""
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 6c3334f64d..2cfa156dda 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1503,24 +1503,24 @@ def test_embedlayernormalization():
)
-def create_reduce_test_parameters():
+def create_reduce_test_parameters_axes_attr():
output = []
for value in [True, False]:
- output.append(("ReduceMax", value))
- output.append(("ReduceMean", value))
- output.append(("ReduceMin", value))
- output.append(("ReduceProd", value))
- output.append(("ReduceSum", value))
- output.append(("ReduceSumSquare", value))
- output.append(("ReduceLogSum", value))
- output.append(("ReduceLogSumExp", value))
- output.append(("ReduceL1", value))
- output.append(("ReduceL2", value))
+ output.append(("ReduceMax", value, 11))
+ output.append(("ReduceMean", value, 13))
+ output.append(("ReduceMin", value, 11))
+ output.append(("ReduceProd", value, 13))
+ output.append(("ReduceSum", value, 11))
+ output.append(("ReduceSumSquare", value, 13))
+ output.append(("ReduceLogSum", value, 13))
+ output.append(("ReduceLogSumExp", value, 13))
+ output.append(("ReduceL1", value, 13))
+ output.append(("ReduceL2", value, 13))
return output
-@pytest.mark.parametrize("func, dynamic", create_reduce_test_parameters())
-def test_all_reduce_funcs(func, dynamic):
+@pytest.mark.parametrize("func, dynamic, opset", create_reduce_test_parameters_axes_attr())
+def test_all_reduce_funcs_axes_attr(func, dynamic, opset):
def verify_reduce_func(func, data, axis, keepdims):
inshape = data.shape
outshape = np.sum(data, axis=axis, keepdims=keepdims == 1).shape
@@ -1549,7 +1549,7 @@ def test_all_reduce_funcs(func, dynamic):
inputs_dict = {"x": data}
# Reduction ops accumulate arithmetic errors, so we use a higher tolerance.
- check_correctness(model, inputs_dict, opset=11, rtol=1e-4, atol=1e-4)
+ check_correctness(model, inputs_dict, opset=opset, rtol=1e-4, atol=1e-4)
for keepdims in [True, False]:
verify_reduce_func(
@@ -1577,6 +1577,131 @@ def test_all_reduce_funcs(func, dynamic):
)
+def create_reduce_test_parameters_axes_input():
+ output = []
+ for dynamic in [True, False]:
+ # TODO(@vacu9708): Enable the tests after implementing other reduce ops
+ # output.append(("ReduceMax", dynamic, 20))
+ # output.append(("ReduceMean", dynamic, 18))
+ # output.append(("ReduceMin", dynamic, 20))
+ # output.append(("ReduceProd", dynamic, 18))
+ # output.append(("ReduceSum", dynamic, 13))
+ # output.append(("ReduceSumSquare", dynamic, 18))
+ # output.append(("ReduceLogSum", dynamic, 18))
+ # output.append(("ReduceLogSumExp", dynamic, 18))
+ output.append(("ReduceL1", dynamic, 18))
+ # output.append(("ReduceL2", dynamic, 18))
+ return output
+
+
+@pytest.mark.parametrize("func, dynamic, opset", create_reduce_test_parameters_axes_input())
+def test_all_reduce_funcs_axes_input(func, dynamic, opset):
+ def verify_reduce_func(func, data, axes, keepdims, noop_with_empty_axes):
+ inshape = data.shape
+
+ inputs = ["x"]
+ initializers = []
+
+ # Optional `axes` input
+ if axes is not None:
+ axes_name = "reduce_axes"
+ axes_np = np.asarray(axes, dtype=np.int64)
+ axes_init = helper.make_tensor(
+ name=axes_name,
+ data_type=TensorProto.INT64,
+ dims=axes_np.shape,
+ vals=axes_np,
+ )
+ initializers.append(axes_init)
+ inputs.append(axes_name)
+
+ # Determine input and output shapes
+ if not axes and not noop_with_empty_axes:
+ outshape = np.sum(data, axis=None, keepdims=keepdims).shape
+ elif not axes and noop_with_empty_axes:
+ outshape = inshape
+ else:
+ outshape = np.sum(data, axis=axes, keepdims=keepdims).shape
+
+ if dynamic:
+ in_list = ["?"] * len(inshape)
+ out_list = ["?"] * len(outshape)
+ else:
+ in_list = list(inshape)
+ out_list = list(outshape)
+
+ # Make a model node
+ node = helper.make_node(
+ func,
+ inputs=inputs,
+ outputs=["y"],
+ keepdims=keepdims,
+ noop_with_empty_axes=noop_with_empty_axes,
+ )
+
+ # Make a model graph and a model
+ graph = helper.make_graph(
+ [node],
+ "reduce18_test",
+ inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, in_list)],
+ initializer=initializers,
+ outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, out_list)],
+ )
+ model = helper.make_model(graph, producer_name="reduce18_test")
+
+ # Run TVM importer vs onnxruntime
+ inputs_dict = {"x": data}
+ check_correctness(model, inputs_dict, opset=opset, rtol=1e-4, atol=1e-4)
+
+ # Verify
+ for keepdims in [True, False]:
+ # no `axes` input && `noop_with_empty_axes` = 0 -> reduce over all dimensions.
+ verify_reduce_func(
+ func,
+ np.random.randn(3, 2, 2).astype(np.float32),
+ axes=[],
+ keepdims=keepdims,
+ noop_with_empty_axes=False,
+ )
+
+ # no `axes` input && `noop_with_empty_axes` = 0 -> reduce over all dimensions.
+ verify_reduce_func(
+ func,
+ np.random.randn(3, 2, 2).astype(np.float32),
+ axes=None,
+ keepdims=keepdims,
+ noop_with_empty_axes=False,
+ )
+
+ # no `axes` input && `noop_with_empty_axes` = 1 -> return the input unchanged.
+ verify_reduce_func(
+ func,
+ np.random.randn(4, 3).astype(np.float32),
+ axes=[],
+ keepdims=keepdims,
+ noop_with_empty_axes=True,
+ )
+
+ # no `axes` input && `noop_with_empty_axes` = 1 -> return the input unchanged.
+ # (onnxruntime bug) Runtime error on the onnxruntime part
+ # verify_reduce_func(
+ # func,
+ # np.random.randn(4, 3).astype(np.float32),
+ # axes=None,
+ # keepdims=keepdims,
+ # noop_with_empty_axes=True,
+ # )
+
+ # `axes` provided -> reduce over specified axes.
+ verify_reduce_func(
+ func,
+ np.random.randn(3, 3, 3, 1).astype(np.float32),
+ axes=(1, 2),
+ keepdims=keepdims,
+ noop_with_empty_axes=True,
+ )
+
+
@pytest.mark.parametrize("in_dtype", [np.float32, np.int32])
@pytest.mark.parametrize("axis", [None, 0, 1, 2])
@pytest.mark.parametrize("keepdims", [None, True, False])