This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e370fc7374 [Relax][ONNX] Normalize negative indices before the take call for `Gather` operator (#19525)
e370fc7374 is described below

commit e370fc7374853a7dd28dab0bba4b1b2252292e29
Author: Neo Chien <[email protected]>
AuthorDate: Mon May 11 20:52:03 2026 +0800

    [Relax][ONNX] Normalize negative indices before the take call for `Gather` operator (#19525)
    
    Hi Committers,
    
    This PR attempts to fix issue
    https://github.com/apache/tvm/issues/19436. Any suggestions would be
    appreciated.
    
    ### Root Cause
    1. ONNX `Gather` allows negative indices (counting from the end of the
    target axis).
    2. In the Relax ONNX importer, `Gather` was lowered directly to
    `relax.op.take` without normalizing negative indices first.
    3. This caused a semantic mismatch (incorrect results) in downstream
    lowering paths that assume non-negative indices.
    4. Test failures were also caused by pytest parametrization issues:
       - using ONNX `TensorProto` enum values directly as NumPy dtypes, and
       - tuple-style parametrization triggering fixture interpretation
         errors.
    
    ### Solutions
    1. Added conditional negative-index normalization in `Gather._impl_v13`:
       - apply it only to signed index dtypes,
       - compute `idx < 0 ? idx + axis_extent : idx`,
       - derive `axis_extent` from the runtime shape of `data` (via
         `shape_of`) so that dynamic shapes are supported.
    2. Skipped normalization for unsigned index dtypes, which cannot be
    negative, to avoid emitting redundant graph ops/checks.
    
    ---------
    
    Co-authored-by: cchung100m <[email protected]>
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 19 ++++++++
 tests/python/relax/test_frontend_onnx.py        | 62 +++++++++++++++++++++++++
 2 files changed, 81 insertions(+)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 268d91b750..7d85906cff 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1106,6 +1106,25 @@ class Gather(OnnxOpConverter):
             shape_val = data[np_index]
             return relax.PrimValue(shape_val)
 
+        indices_dtype = indices.struct_info.dtype
+        if not indices_dtype.startswith("uint"):
+            data_shape = bb.normalize(relax.op.shape_of(data))
+            data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape))
+            axis_extent = bb.normalize(
+                relax.op.take(data_shape_tensor, relax.const(axis, "int64"), axis=0, mode="wrap")
+            )
+
+            if indices_dtype !="int64":
+                axis_extent = bb.normalize(relax.op.astype(axis_extent, indices_dtype))
+
+            indices = bb.normalize(
+                relax.op.where(
+                    relax.op.less(indices, relax.const(0, indices_dtype)),
+                    relax.op.add(indices, axis_extent),
+                    indices,
+                )
+            )
+
         return relax.op.take(data, indices, axis)
 
 
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 5a8d84b090..52a4064cc8 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -874,6 +874,68 @@ def test_gather():
     _verify_gather([3, 3], [[0, 2]], [3, 1, 2], 1)


+@pytest.mark.parametrize(
+    "axis, indices, out_shape",
+    [
+        (0, [-1, 0], [2, 4]),
+        (1, [-1, 0], [3, 2]),
+        (
+            1, 
+            [[-1, 0], [1, -2]], 
+            [3, 2, 2],
+        ),
+    ],
+)
+@pytest.mark.parametrize("indices_type", [TensorProto.INT64, TensorProto.INT32])
+def test_gather_negative_indices(axis, indices, out_shape, indices_type):
+    gather_node = helper.make_node("Gather", ["data", "indices"], ["y"], axis=axis)
+    indices_shape = np.asarray(indices).shape
+
+    graph = helper.make_graph(
+        [gather_node],
+        "gather_negative_indices_test",
+        inputs=[
+            helper.make_tensor_value_info("data", TensorProto.FLOAT, [3, 4]),
+            helper.make_tensor_value_info("indices", indices_type, indices_shape),
+        ],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, out_shape)],
+    )
+
+    model = helper.make_model(graph, producer_name="gather_negative_indices_test")
+    indices_np_dtype = {
+        TensorProto.INT64: np.int64,
+        TensorProto.INT32: np.int32,
+    }[indices_type]
+    input_values = {
+        "data": np.random.randn(3, 4).astype("float32"),
+        "indices": np.array(indices).astype(indices_np_dtype),
+    }
+    check_correctness(model, inputs=input_values)
+
+
+@pytest.mark.parametrize("indices_type", [TensorProto.INT64, TensorProto.INT32])
+def test_gather_negative_indices_ir_normalization(indices_type):
+    gather_node = helper.make_node("Gather", ["data", "indices"], ["y"], axis=1)
+    graph = helper.make_graph(
+        [gather_node],
+        "gather_negative_indices_ir_test",
+        inputs=[
+            helper.make_tensor_value_info("data", TensorProto.FLOAT, [3, 4]),
+            helper.make_tensor_value_info("indices", indices_type, [2]),
+        ],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 2])],
+    )
+
+    model = helper.make_model(graph, producer_name="gather_negative_indices_ir_test")
+    tvm_model = from_onnx(model, opset=13, keep_params_in_input=True)
+    call_ops = collect_relax_call_ops(tvm_model["main"])
+
+    assert "relax.where" in call_ops
+    assert "relax.less" in call_ops
+    assert "relax.add" in call_ops
+    assert "relax.take" in call_ops
+
+
 @pytest.mark.parametrize(
     "data_shape, indices_shape, axis",
     [

Reply via email to