This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 10ba3c232b [Frontend][ONNX] Support select_last_index for ArgMax and
ArgMin (#18969)
10ba3c232b is described below
commit 10ba3c232bfce87d7dc4f079c7717ac82ce7203a
Author: Kryptonite <[email protected]>
AuthorDate: Fri Apr 3 06:29:56 2026 +0300
[Frontend][ONNX] Support select_last_index for ArgMax and ArgMin (#18969)
### Summary
This PR implements the `select_last_index` attribute (introduced in
opset 12) for the `ArgMax` and `ArgMin` ONNX operators.
Previously, setting `select_last_index=1` raised
`OpAttributeUnImplemented`. This closes the limitation tracked in the
ONNX frontend issue.
### Implementation
When `select_last_index=1`, the input tensor is reversed along the
reduction axis using `relax.op.flip`, argmax/argmin is computed on the
flipped copy, and the result is remapped back to the original index
space via `last_idx = (axis_size - 1) - flipped_idx`.
Closes part of #18945
---------
Signed-off-by: OmarAzizi <[email protected]>
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 38 +++++++---
tests/python/relax/test_frontend_onnx.py | 98 +++++++++++++++++++++++++
2 files changed, 125 insertions(+), 11 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index fd883e3d4a..ab1ea2b292 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -2779,9 +2779,7 @@ class Resize(OnnxOpConverter):
else:
roi_static = roi_np
else:
- roi_dynamic_vec = bb.normalize(
- _onnx_resize_spatial_roi_vector(roi, ndims)
- )
+ roi_dynamic_vec = bb.normalize(_onnx_resize_spatial_roi_vector(roi, ndims))
else:
roi_static = [0.0] * (2 * (ndims - 2))
@@ -3757,6 +3755,30 @@ class ReduceL2(OnnxOpConverter):
return relax.op.sqrt(relax.op.sum(relax.op.multiply(data, data), axes, keepdims))
+def _argreduce_select_last_index(bb, data, axis, keepdims, op):
+ """Helper for ArgMax/ArgMin with select_last_index=1.
+
+ Reverses the tensor along the reduction axis, runs the reduction op,
+ then remaps the index back: last_idx = (axis_size - 1) - flipped_idx.
+ Handles both static and dynamic axis sizes.
+ """
+ data_flipped = relax.op.flip(data, axis=axis)
+ flipped_idx = bb.normalize(op(data_flipped, axis, keepdims))
+ axis_size = data.struct_info.shape[axis]
+ if isinstance(axis_size, tirx.IntImm):
+ offset = relax.const(int(axis_size) - 1, "int64")
+ else:
+ # dynamic: get axis size at runtime and subtract 1
+ shape_tensor = bb.normalize(relax.op.shape_to_tensor(
+ bb.normalize(relax.op.shape_of(data))
+ ))
+ offset = bb.normalize(relax.op.subtract(
+ bb.normalize(relax.op.take(shape_tensor, relax.const(axis, "int64"), axis=0)),
+ relax.const(1, "int64"),
+ ))
+ return relax.op.subtract(offset, flipped_idx)
+
+
class ArgMax(OnnxOpConverter):
"""Converts an onnx ArgMax node into an equivalent Relax expression."""
@@ -3788,10 +3810,7 @@ class ArgMax(OnnxOpConverter):
axis, keepdims = cls._check_attrs(data, attr)
select_last_index = attr.get("select_last_index", False)
if select_last_index:
- # TODO(vvchernov): support attr
- raise tvm.error.OpAttributeUnImplemented(
- "'select_last_index' attribute has not been supported yet"
- )
+ return _argreduce_select_last_index(bb, data, axis, keepdims, relax.op.argmax)
return relax.op.argmax(data, axis, keepdims)
@@ -3826,10 +3845,7 @@ class ArgMin(OnnxOpConverter):
axis, keepdims = cls._check_attrs(data, attr)
select_last_index = attr.get("select_last_index", False)
if select_last_index:
- # TODO(vvchernov): support attr
- raise tvm.error.OpAttributeUnImplemented(
- "'select_last_index' attribute has not been supported yet"
- )
+ return _argreduce_select_last_index(bb, data, axis, keepdims, relax.op.argmin)
return relax.op.argmin(data, axis, keepdims)
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index ab3a5c5148..7f9cd177ad 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -5360,5 +5360,103 @@ def test_max_roi_pool(pooled_shape, rois):
check_correctness(model, inputs=inputs, opset=16, rtol=1e-5, atol=1e-5)
+@pytest.mark.parametrize("op_name", ["ArgMax", "ArgMin"])
+@pytest.mark.parametrize("axis", [0, 1, 2])
+@pytest.mark.parametrize("keepdims", [True, False])
+def test_arg_min_max_select_last_index(op_name, axis, keepdims):
+ """select_last_index=1 must return the LAST occurrence of the extreme value."""
+ shape = [3, 4, 5]
+
+ # Force a tie: place the extreme value at both index 0 and index (axis_size-1)
+ # so that select_last_index=0 and =1 give observably different results.
+ data = np.random.uniform(-10, 10, shape).astype(np.float32)
+ slices_first = [slice(None)] * len(shape)
+ slices_last = [slice(None)] * len(shape)
+ slices_first[axis] = 0
+ slices_last[axis] = shape[axis] - 1
+ extreme = data.max() + 1.0 if op_name == "ArgMax" else data.min() - 1.0
+ data[tuple(slices_first)] = extreme
+ data[tuple(slices_last)] = extreme
+
+ node = helper.make_node(
+ op_name,
+ inputs=["data"],
+ outputs=["out"],
+ axis=axis,
+ keepdims=int(keepdims),
+ select_last_index=1,
+ )
+
+ out_shape = list(shape)
+ if keepdims:
+ out_shape[axis] = 1
+ else:
+ out_shape.pop(axis)
+
+ graph = helper.make_graph(
+ [node],
+ "arg_select_last_index_test",
+ inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, shape)],
+ outputs=[helper.make_tensor_value_info("out", TensorProto.INT64, out_shape)],
+ )
+ model = helper.make_model(graph, producer_name="arg_select_last_index_test")
+ check_correctness(model, inputs={"data": data}, opset=12)
+
+
+@pytest.mark.parametrize("op_name", ["ArgMax", "ArgMin"])
+def test_arg_min_max_select_last_index_no_tie(op_name):
+ """With all-unique values, select_last_index=1 must agree with select_last_index=0."""
+ shape = [4, 5]
+ # arange guarantees uniqueness so first == last for every row
+ data = np.arange(20, dtype=np.float32).reshape(shape)
+
+ for select_last in [0, 1]:
+ node = helper.make_node(
+ op_name,
+ inputs=["data"],
+ outputs=["out"],
+ axis=1,
+ keepdims=1,
+ select_last_index=select_last,
+ )
+ graph = helper.make_graph(
+ [node],
+ "arg_no_tie_test",
+ inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, shape)],
+ outputs=[helper.make_tensor_value_info("out", TensorProto.INT64, [4, 1])],
+ )
+ model = helper.make_model(graph, producer_name="arg_no_tie_test")
+ check_correctness(model, inputs={"data": data}, opset=12)
+
+
+@pytest.mark.parametrize("op_name", ["ArgMax", "ArgMin"])
+def test_arg_min_max_select_last_index_ir(op_name):
+ """select_last_index=1 must lower to flip + argmax/argmin + subtract in the Relax IR."""
+ shape = [3, 4, 5]
+ relax_op = "relax.argmax" if op_name == "ArgMax" else "relax.argmin"
+
+ node = helper.make_node(
+ op_name,
+ inputs=["data"],
+ outputs=["out"],
+ axis=1,
+ keepdims=1,
+ select_last_index=1,
+ )
+ graph = helper.make_graph(
+ [node],
+ "arg_select_last_index_ir_test",
+ inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, shape)],
+ outputs=[helper.make_tensor_value_info("out", TensorProto.INT64, [3, 1, 5])],
+ )
+ model = helper.make_model(graph, producer_name="arg_select_last_index_ir_test")
+ tvm_model = from_onnx(model, opset=12, keep_params_in_input=True)
+
+ call_ops = collect_relax_call_ops(tvm_model["main"])
+ assert relax_op in call_ops, f"Expected {relax_op} in IR, got {call_ops}"
+ assert "relax.flip" in call_ops, f"Expected relax.flip in IR, got {call_ops}"
+ assert "relax.subtract" in call_ops, f"Expected relax.subtract in IR, got {call_ops}"
+
+
if __name__ == "__main__":
tvm.testing.main()