This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 378c4f3043 [BugFix][Relax]: handle ONNX ScatterElements reduction (#19527)
378c4f3043 is described below
commit 378c4f3043a81f0a500c9d4685df0e82758df67d
Author: Sun <[email protected]>
AuthorDate: Tue May 12 12:17:28 2026 +0800
[BugFix][Relax]: handle ONNX ScatterElements reduction (#19527)
### Summary
- Respect the ONNX `reduction` attribute in the Relax ONNX frontend
`ScatterElements` converter.
- Preserve the existing default behavior by mapping a missing `reduction`
attribute and ONNX `none` to Relax `update`.
- Add focused regression coverage for opset 11 default behavior, opset
16 `add`/`mul`, and opset 18 `none`/`min`/`max`.
### Changes
- Added a shared helper to normalize and validate ONNX reduction
attributes.
- Implemented `ScatterElements` opset 16 and opset 18 converters.
- Reused the existing `relax.op.scatter_elements(..., reduction=...)`
API.
- Reused the same reduction helper in `ScatterND` to keep behavior
consistent.
### Test Plan
- `python -m py_compile python/tvm/relax/frontend/onnx/onnx_frontend.py tests/python/relax/test_frontend_onnx.py`
- `python -m pytest
tests/python/relax/test_frontend_onnx.py::test_gather_elements
tests/python/relax/test_frontend_onnx.py::test_scatter
tests/python/relax/test_frontend_onnx.py::test_scatter_elements_reduction
tests/python/relax/test_frontend_onnx.py::test_scatter_nd -q`
### Issue
Fixes #19435
## Local Verification Notes
- WSL conda environment: `/home/thinker/.cache/tvm-conda-onnx`
- TVM build directory: `/home/thinker/.cache/tvm-build-onnx`
- LLVM runtime check: `tvm.runtime.enabled("llvm") == True`
- Relevant ONNX frontend subset: `15 passed, 4 skipped, 2 warnings`
- The full `tests/python/relax/test_frontend_onnx.py` suite was also run. It
currently has 14 failures in unrelated `Reduce* axes input` and `TopK`
tests; re-running those failing tests against `origin/main` reproduces the
same failures, so they are not introduced by this PR.
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 40 +++++++---
tests/python/relax/test_frontend_onnx.py | 100 ++++++++++++++++++++++++
2 files changed, 131 insertions(+), 9 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 7d85906cff..622e262cc4 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1159,6 +1159,20 @@ class Scatter(OnnxOpConverter):
raise ValueError("Scatter is deprecated in ONNX 11")
+def _get_onnx_reduction(attr, valid_reductions: list[str]):
+    """Normalize and validate the ONNX ``reduction`` attribute of a scatter op.
+
+    Parameters
+    ----------
+    attr : dict
+        Converted ONNX node attributes. ``attr["reduction"]`` may be absent,
+        ``bytes`` (decoded as UTF-8) or an already-decoded ``str``.
+    valid_reductions : list[str]
+        Relax reduction names accepted by the calling converter.
+
+    Returns
+    -------
+    str
+        The Relax reduction name. A missing or empty attribute and ONNX
+        ``"none"`` both map to Relax ``"update"``.
+
+    Raises
+    ------
+    ValueError
+        If the normalized reduction is not in ``valid_reductions``.
+    """
+    reduction = attr.get("reduction", None)
+    reduction = reduction or b"update"
+    if isinstance(reduction, bytes):
+        reduction = reduction.decode("utf-8")
+    # ONNX spells plain overwrite "none"; the Relax op calls it "update".
+    reduction = "update" if reduction == "none" else reduction
+    if reduction not in valid_reductions:
+        raise ValueError(
+            f"Only {valid_reductions} reductions are supported, but got {reduction}"
+        )
+
+    return reduction
+
+
class ScatterElements(OnnxOpConverter):
"""Convert an onnx ScatterElements node into an equivalent Relax
expression."""
@@ -1167,21 +1181,29 @@ class ScatterElements(OnnxOpConverter):
axis = attr.get("axis", 0)
return relax.op.scatter_elements(inputs[0], inputs[1], inputs[2],
axis=axis)
+    @staticmethod
+    def _convert_with_reduction(inputs, attr, valid_reductions):
+        """Emit ``relax.op.scatter_elements`` honoring the ONNX ``reduction`` attribute.
+
+        ``valid_reductions`` lists the Relax reduction names the converted
+        opset allows. ``_get_onnx_reduction`` maps a missing attribute or ONNX
+        ``"none"`` to ``"update"`` and raises ``ValueError`` for anything
+        outside that list. The name intentionally avoids the ``_impl_v``
+        substring so it is not mistaken for a versioned converter.
+        """
+        axis = attr.get("axis", 0)
+        reduction = _get_onnx_reduction(attr, valid_reductions)
+        return relax.op.scatter_elements(
+            inputs[0], inputs[1], inputs[2], axis=axis, reduction=reduction
+        )
+
+    @classmethod
+    def _impl_v16(cls, bb, inputs, attr, params):
+        # Opset 16 accepts "none" (mapped to "update"), "add" and "mul".
+        return cls._convert_with_reduction(inputs, attr, ["update", "add", "mul"])
+
+    @classmethod
+    def _impl_v18(cls, bb, inputs, attr, params):
+        # Opset 18 additionally accepts "min" and "max".
+        return cls._convert_with_reduction(
+            inputs, attr, ["update", "add", "mul", "min", "max"]
+        )
+
class ScatterND(OnnxOpConverter):
"""Convert an onnx ScatterND node into an equivalent Relax expression."""
@staticmethod
def _reduction_check(attr, valid_reductions: list[str]):
- reduction = attr.get("reduction", None)
- reduction = reduction or b"update"
- reduction = reduction.decode("utf-8")
- reduction = "update" if reduction == "none" else reduction
- assert reduction in valid_reductions, (
- f"Only {valid_reductions} reductions are supported, but
{reduction} is gotten"
- )
-
- return reduction
+ return _get_onnx_reduction(attr, valid_reductions)
@classmethod
def _impl_v11(cls, bb, inputs, attr, params):
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 52a4064cc8..94b85ab95a 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1023,6 +1023,106 @@ def test_scatter(axis: int, name: str, opset: int):
check_correctness(model, inputs={"indices": indices}, opset=opset)
+@pytest.mark.parametrize(
+    "reduction, opset, data, indices, updates",
+    [
+        # No ``reduction`` attribute at opset 11: plain overwrite. Each row of
+        # ``indices`` is a permutation, so no output element is written twice.
+        (
+            None,
+            11,
+            np.array([[1, 2, 3], [4, 5, 6]], dtype="float32"),
+            np.array([[2, 0, 1], [1, 2, 0]], dtype="int64"),
+            np.array([[30, 10, 20], [50, 60, 40]], dtype="float32"),
+        ),
+        # Explicit ONNX "none" at opset 18 should match the default above.
+        (
+            "none",
+            18,
+            np.array([[1, 2, 3], [4, 5, 6]], dtype="float32"),
+            np.array([[2, 0, 1], [1, 2, 0]], dtype="int64"),
+            np.array([[30, 10, 20], [50, 60, 40]], dtype="float32"),
+        ),
+        # The reducing cases below repeat a column index within each row
+        # (row 0 -> column 0 twice, row 1 -> column 1 twice), so the reduction
+        # actually combines several updates into a single output element.
+        (
+            "add",
+            16,
+            np.full((2, 3), 10, dtype="float32"),
+            np.array([[0, 0, 2], [1, 1, 2]], dtype="int64"),
+            np.array([[2, 5, 7], [20, 3, 4]], dtype="float32"),
+        ),
+        (
+            "mul",
+            16,
+            np.full((2, 3), 10, dtype="float32"),
+            np.array([[0, 0, 2], [1, 1, 2]], dtype="int64"),
+            np.array([[2, 5, 7], [20, 3, 4]], dtype="float32"),
+        ),
+        (
+            "min",
+            18,
+            np.full((2, 3), 10, dtype="float32"),
+            np.array([[0, 0, 2], [1, 1, 2]], dtype="int64"),
+            np.array([[2, 5, 7], [20, 3, 4]], dtype="float32"),
+        ),
+        (
+            "max",
+            18,
+            np.full((2, 3), 10, dtype="float32"),
+            np.array([[0, 0, 2], [1, 1, 2]], dtype="int64"),
+            np.array([[2, 5, 7], [20, 3, 4]], dtype="float32"),
+        ),
+    ],
+)
+def test_scatter_elements_reduction(reduction, opset, data, indices, updates):
+    """Check ScatterElements conversion for each supported ONNX ``reduction`` mode.
+
+    A ``reduction`` of ``None`` omits the attribute entirely; every other value
+    is passed through to the ONNX node unchanged.
+    """
+    attrs = {"axis": 1}
+    if reduction is not None:
+        attrs["reduction"] = reduction
+    scatter_elements_node = helper.make_node(
+        "ScatterElements", ["data", "indices", "updates"], ["output"], **attrs
+    )
+
+    graph = helper.make_graph(
+        [scatter_elements_node],
+        "scatter_elements_reduction_test",
+        inputs=[
+            helper.make_tensor_value_info("data", TensorProto.FLOAT, list(data.shape)),
+            helper.make_tensor_value_info("indices", TensorProto.INT64, list(indices.shape)),
+            helper.make_tensor_value_info("updates", TensorProto.FLOAT, list(updates.shape)),
+        ],
+        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, list(data.shape))],
+    )
+    model = helper.make_model(graph, producer_name="scatter_elements_reduction_test")
+
+    check_correctness(
+        model,
+        inputs={"data": data, "indices": indices, "updates": updates},
+        opset=opset,
+    )
+
+
+def test_scatter_elements_invalid_reduction():
+    """An unknown ``reduction`` value must make ``from_onnx`` raise ``ValueError``."""
+    data_shape = [2, 3]
+    scatter_elements_node = helper.make_node(
+        "ScatterElements",
+        ["data", "indices", "updates"],
+        ["output"],
+        axis=1,
+        reduction="unsupported",
+    )
+
+    graph = helper.make_graph(
+        [scatter_elements_node],
+        "scatter_elements_invalid_reduction_test",
+        inputs=[
+            helper.make_tensor_value_info("data", TensorProto.FLOAT, data_shape),
+            helper.make_tensor_value_info("indices", TensorProto.INT64, data_shape),
+            helper.make_tensor_value_info("updates", TensorProto.FLOAT, data_shape),
+        ],
+        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, data_shape)],
+    )
+    model = helper.make_model(graph, producer_name="scatter_elements_invalid_reduction_test")
+
+    with pytest.raises(ValueError, match="Only .* reductions are supported, but got unsupported"):
+        from_onnx(model, opset=18, keep_params_in_input=True)
+
+
@pytest.mark.parametrize("reduction", ["none", "add", "mul"])
def test_scatter_nd(reduction):
def verify_scatter_nd(data_shape, indices_shape, updates_shape):