This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 10ba3c232b [Frontend][ONNX] Support select_last_index for ArgMax and ArgMin (#18969)
10ba3c232b is described below

commit 10ba3c232bfce87d7dc4f079c7717ac82ce7203a
Author: Kryptonite <[email protected]>
AuthorDate: Fri Apr 3 06:29:56 2026 +0300

    [Frontend][ONNX] Support select_last_index for ArgMax and ArgMin (#18969)
    
    ### Summary
    
    This PR implements the `select_last_index` attribute (introduced in
    opset 12) for the `ArgMax` and `ArgMin` ONNX operators.
    
    Previously, setting `select_last_index=1` raised
    `OpAttributeUnImplemented`. This PR removes that limitation, which is
    tracked as part of the ONNX frontend issue #18945.
    
    ### Implementation
    
    When `select_last_index=1`, the input tensor is reversed along the
    reduction axis using `relax.op.flip`, argmax/argmin is computed on the
    flipped copy, and the result is remapped back to the original index
    space via `last_idx = (axis_size - 1) - flipped_idx`.
    
    Closes part of #18945
    
    ---------
    
    Signed-off-by: OmarAzizi <[email protected]>
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 38 +++++++---
 tests/python/relax/test_frontend_onnx.py        | 98 +++++++++++++++++++++++++
 2 files changed, 125 insertions(+), 11 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py 
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index fd883e3d4a..ab1ea2b292 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -2779,9 +2779,7 @@ class Resize(OnnxOpConverter):
                 else:
                     roi_static = roi_np
             else:
-                roi_dynamic_vec = bb.normalize(
-                    _onnx_resize_spatial_roi_vector(roi, ndims)
-                )
+                roi_dynamic_vec = 
bb.normalize(_onnx_resize_spatial_roi_vector(roi, ndims))
         else:
             roi_static = [0.0] * (2 * (ndims - 2))
 
@@ -3757,6 +3755,30 @@ class ReduceL2(OnnxOpConverter):
             return relax.op.sqrt(relax.op.sum(relax.op.multiply(data, data), 
axes, keepdims))
 
 
+def _argreduce_select_last_index(bb, data, axis, keepdims, op):
+    """Helper for ArgMax/ArgMin with select_last_index=1.
+
+    Reverses the tensor along the reduction axis, runs the reduction op,
+    then remaps the index back: last_idx = (axis_size - 1) - flipped_idx.
+    Handles both static and dynamic axis sizes.
+    """
+    data_flipped = relax.op.flip(data, axis=axis)
+    flipped_idx = bb.normalize(op(data_flipped, axis, keepdims))
+    axis_size = data.struct_info.shape[axis]
+    if isinstance(axis_size, tirx.IntImm):
+        offset = relax.const(int(axis_size) - 1, "int64")
+    else:
+        # dynamic: get axis size at runtime and subtract 1
+        shape_tensor = bb.normalize(relax.op.shape_to_tensor(
+            bb.normalize(relax.op.shape_of(data))
+        ))
+        offset = bb.normalize(relax.op.subtract(
+            bb.normalize(relax.op.take(shape_tensor, relax.const(axis, 
"int64"), axis=0)),
+            relax.const(1, "int64"),
+        ))
+    return relax.op.subtract(offset, flipped_idx)
+
+
 class ArgMax(OnnxOpConverter):
     """Converts an onnx ArgMax node into an equivalent Relax expression."""
 
@@ -3788,10 +3810,7 @@ class ArgMax(OnnxOpConverter):
         axis, keepdims = cls._check_attrs(data, attr)
         select_last_index = attr.get("select_last_index", False)
         if select_last_index:
-            # TODO(vvchernov): support attr
-            raise tvm.error.OpAttributeUnImplemented(
-                "'select_last_index' attribute has not been supported yet"
-            )
+            return _argreduce_select_last_index(bb, data, axis, keepdims, 
relax.op.argmax)
         return relax.op.argmax(data, axis, keepdims)
 
 
@@ -3826,10 +3845,7 @@ class ArgMin(OnnxOpConverter):
         axis, keepdims = cls._check_attrs(data, attr)
         select_last_index = attr.get("select_last_index", False)
         if select_last_index:
-            # TODO(vvchernov): support attr
-            raise tvm.error.OpAttributeUnImplemented(
-                "'select_last_index' attribute has not been supported yet"
-            )
+            return _argreduce_select_last_index(bb, data, axis, keepdims, 
relax.op.argmin)
         return relax.op.argmin(data, axis, keepdims)
 
 
diff --git a/tests/python/relax/test_frontend_onnx.py 
b/tests/python/relax/test_frontend_onnx.py
index ab3a5c5148..7f9cd177ad 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -5360,5 +5360,103 @@ def test_max_roi_pool(pooled_shape, rois):
     check_correctness(model, inputs=inputs, opset=16, rtol=1e-5, atol=1e-5)
 
 
+@pytest.mark.parametrize("op_name", ["ArgMax", "ArgMin"])
+@pytest.mark.parametrize("axis", [0, 1, 2])
+@pytest.mark.parametrize("keepdims", [True, False])
+def test_arg_min_max_select_last_index(op_name, axis, keepdims):
+    """select_last_index=1 must return the LAST occurrence of the extreme 
value."""
+    shape = [3, 4, 5]
+
+    # Force a tie: place the extreme value at both index 0 and index 
(axis_size-1)
+    # so that select_last_index=0 and =1 give observably different results.
+    data = np.random.uniform(-10, 10, shape).astype(np.float32)
+    slices_first = [slice(None)] * len(shape)
+    slices_last = [slice(None)] * len(shape)
+    slices_first[axis] = 0
+    slices_last[axis] = shape[axis] - 1
+    extreme = data.max() + 1.0 if op_name == "ArgMax" else data.min() - 1.0
+    data[tuple(slices_first)] = extreme
+    data[tuple(slices_last)] = extreme
+
+    node = helper.make_node(
+        op_name,
+        inputs=["data"],
+        outputs=["out"],
+        axis=axis,
+        keepdims=int(keepdims),
+        select_last_index=1,
+    )
+
+    out_shape = list(shape)
+    if keepdims:
+        out_shape[axis] = 1
+    else:
+        out_shape.pop(axis)
+
+    graph = helper.make_graph(
+        [node],
+        "arg_select_last_index_test",
+        inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, 
shape)],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.INT64, 
out_shape)],
+    )
+    model = helper.make_model(graph, 
producer_name="arg_select_last_index_test")
+    check_correctness(model, inputs={"data": data}, opset=12)
+
+
+@pytest.mark.parametrize("op_name", ["ArgMax", "ArgMin"])
+def test_arg_min_max_select_last_index_no_tie(op_name):
+    """With all-unique values, select_last_index=1 must agree with 
select_last_index=0."""
+    shape = [4, 5]
+    # arange guarantees uniqueness so first == last for every row
+    data = np.arange(20, dtype=np.float32).reshape(shape)
+
+    for select_last in [0, 1]:
+        node = helper.make_node(
+            op_name,
+            inputs=["data"],
+            outputs=["out"],
+            axis=1,
+            keepdims=1,
+            select_last_index=select_last,
+        )
+        graph = helper.make_graph(
+            [node],
+            "arg_no_tie_test",
+            inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, 
shape)],
+            outputs=[helper.make_tensor_value_info("out", TensorProto.INT64, 
[4, 1])],
+        )
+        model = helper.make_model(graph, producer_name="arg_no_tie_test")
+        check_correctness(model, inputs={"data": data}, opset=12)
+
+
+@pytest.mark.parametrize("op_name", ["ArgMax", "ArgMin"])
+def test_arg_min_max_select_last_index_ir(op_name):
+    """select_last_index=1 must lower to flip + argmax/argmin + subtract in 
the Relax IR."""
+    shape = [3, 4, 5]
+    relax_op = "relax.argmax" if op_name == "ArgMax" else "relax.argmin"
+
+    node = helper.make_node(
+        op_name,
+        inputs=["data"],
+        outputs=["out"],
+        axis=1,
+        keepdims=1,
+        select_last_index=1,
+    )
+    graph = helper.make_graph(
+        [node],
+        "arg_select_last_index_ir_test",
+        inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, 
shape)],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.INT64, [3, 
1, 5])],
+    )
+    model = helper.make_model(graph, 
producer_name="arg_select_last_index_ir_test")
+    tvm_model = from_onnx(model, opset=12, keep_params_in_input=True)
+
+    call_ops = collect_relax_call_ops(tvm_model["main"])
+    assert relax_op in call_ops, f"Expected {relax_op} in IR, got {call_ops}"
+    assert "relax.flip" in call_ops, f"Expected relax.flip in IR, got 
{call_ops}"
+    assert "relax.subtract" in call_ops, f"Expected relax.subtract in IR, got 
{call_ops}"
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to