This is an automated email from the ASF dual-hosted git repository.

yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 87b37b4d5b Fix onnx expand op (#17900)
87b37b4d5b is described below

commit 87b37b4d5bd19ee4553e384cffa3aa9e8a1d72bd
Author: Taylor <44216613+xinxi...@users.noreply.github.com>
AuthorDate: Mon Apr 28 00:48:55 2025 +0800

    Fix onnx expand op (#17900)
    
    * [ONNX] Fix Expand operator to properly handle target shapes
    
    This fixes issue #17746 where the ONNX Expand operator was not correctly
    expanding tensors to higher dimensions. The issue manifested when a
    downstream ArgMin operation received a tensor with fewer dimensions than
    expected, causing an 'axis out of bounds' error.
    
    Specifically:
    1. The Expand op was incorrectly skipping the broadcast when input and
       target shapes had the same values but different ranks
    2. This caused a tensor with shape [5,60] to remain [5,60] when it
       should have been expanded to [1,1,5,60]
    3. The subsequent ArgMin op with axis=2 then failed as the tensor only
       had 2 dimensions instead of the expected 4
    
    The fix ensures that Expand always broadcasts to the target shape,
    preserving the rank specified in the ONNX model. This allows downstream
    operations to work with the correct tensor dimensions.
    
    Fixes #17746
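The shape-alignment logic described above can be sketched in plain Python (the function name and structure here are illustrative, not taken verbatim from the patch): the shorter shape is right-aligned by padding with leading 1s, and any target dimension smaller than the corresponding input dimension falls back to the input dimension, per ONNX Expand semantics.

```python
def expand_target_shape(data_shape, target_shape):
    """Sketch of the fixed Expand shape computation (illustrative only)."""
    data_shape = list(data_shape)
    target_shape = list(target_shape)
    # Right-align the two shapes by padding the shorter one with leading 1s.
    if len(target_shape) > len(data_shape):
        data_shape = [1] * (len(target_shape) - len(data_shape)) + data_shape
    else:
        target_shape = [1] * (len(data_shape) - len(target_shape)) + target_shape
    # ONNX allows a target dim smaller than the input dim; keep the input dim.
    return [max(t, d) for t, d in zip(target_shape, data_shape)]

# The bug: input [5, 60] with target [1, 1, 5, 60] must produce rank 4,
# not be returned unchanged as [5, 60].
print(expand_target_shape([5, 60], [1, 1, 5, 60]))  # [1, 1, 5, 60]
```

Note that the old code's early return (`if new_shape == data_shape: return data`) compared the padded shapes, so it skipped the broadcast even when the ranks differed; always emitting `broadcast_to` preserves the rank the model asked for.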
    
    * add expand test case
    
    * fix test case
    
    * reformat
    
    ---------
    
    Co-authored-by: Anurag Singh <10385586+singh20anu...@users.noreply.github.com>
    Co-authored-by: taylor <xinxi...@qq.com>
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 22 ++++++++++++----------
 tests/python/relax/test_frontend_onnx.py        |  6 ++++++
 2 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index dd4b8a4254..24217184b5 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1917,18 +1917,20 @@ class Expand(OnnxOpConverter):
         # If possible, directly expand to constant shape.
         if isinstance(shape, relax.Constant):
             new_shape = shape.data.numpy().tolist()
-            # For some reason, onnx allows target shapes to be smaller than input shapes.
-            # We need to go correct it.
+            # ONNX Expand operator requires preserving target rank and broadcasting
+            # according to standard rules. Dimensions are right-aligned.
             data_shape = [dim.value for dim in data.struct_info.shape]
-            # Dimensions are right alignment.
-            data_shape = [1] * (len(new_shape) - len(data_shape)) + data_shape
-            # Fix small target shapes.
-            for i, s in enumerate(new_shape):
-                if i < len(data_shape) and s < data_shape[i]:
+
+            # Right-align the shapes
+            if len(new_shape) > len(data_shape):
+                data_shape = [1] * (len(new_shape) - len(data_shape)) + data_shape
+            else:
+                new_shape = [1] * (len(data_shape) - len(new_shape)) + new_shape
+            # Fix small target shapes - if target dim is smaller than input dim
+            # use the input dim (ONNX-specific behavior).
+            for i in range(len(new_shape)):
+                if new_shape[i] < data_shape[i]:
                     new_shape[i] = data_shape[i]
-            # If the new shape matches the input shape, no transformation is needed.
-            if new_shape == data_shape:
-                return data
             return relax.op.broadcast_to(data, relax.ShapeExpr(new_shape))
 
         # Otherwise handle dynamic shapes.
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 10c185ae09..ebc1454c23 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1692,6 +1692,12 @@ def test_expand(dynamic):
         data = np.random.uniform(size=in_shape).astype(np.float32)
         ref_data = np.tile(data, (1, 1, 4))
         _test_expand("expand_with_diff_dim", data, shape, ref_data)
+
+        in_shape = (3, 1)
+        shape = (1, 1, 3, 1)
+        data = np.random.uniform(size=in_shape).astype(np.float32)
+        ref_data = np.tile(data, (1, 1, 1, 1))
+        _test_expand("expand_with_the_same_suffix_dims", data, shape, ref_data)
     else:
         in_shape = (1, 32, 32)
         shape = ("batch", 32, 32)

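As a sanity check of the new test's reference data (assuming NumPy is available): tiling the rank-2 input by `(1, 1, 1, 1)` promotes it to rank 4 and matches a plain broadcast to the target shape, which is exactly what the fixed Expand op now emits.

```python
import numpy as np

# Mirrors the new test case: input (3, 1) expanded to target (1, 1, 3, 1).
data = np.random.uniform(size=(3, 1)).astype(np.float32)

# np.tile with a length-4 reps tuple first promotes data to (1, 1, 3, 1),
# then tiles each axis once, so it equals a straight broadcast.
ref_data = np.tile(data, (1, 1, 1, 1))
expanded = np.broadcast_to(data, (1, 1, 3, 1))

assert ref_data.shape == (1, 1, 3, 1)
assert np.array_equal(ref_data, expanded)
```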