This is an automated email from the ASF dual-hosted git repository. yongwww pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 87b37b4d5b Fix onnx expand op (#17900) 87b37b4d5b is described below commit 87b37b4d5bd19ee4553e384cffa3aa9e8a1d72bd Author: Taylor <44216613+xinxi...@users.noreply.github.com> AuthorDate: Mon Apr 28 00:48:55 2025 +0800 Fix onnx expand op (#17900) * [ONNX] Fix Expand operator to properly handle target shapes This fixes issue #17746 where the ONNX Expand operator was not correctly expanding tensors to higher dimensions. The issue manifested when a downstream ArgMin operation received a tensor with fewer dimensions than expected, causing an 'axis out of bounds' error. Specifically: 1. The Expand op was incorrectly skipping the broadcast when input and target shapes had the same values but different ranks 2. This caused a tensor with shape [5,60] to remain [5,60] when it should have been expanded to [1,1,5,60] 3. The subsequent ArgMin op with axis=2 then failed as the tensor only had 2 dimensions instead of the expected 4 The fix ensures that Expand always broadcasts to the target shape, preserving the rank specified in the ONNX model. This allows downstream operations to work with the correct tensor dimensions. Fixes #17746 * add expand test case * fix test case * reformat --------- Co-authored-by: Anurag Singh <10385586+singh20anu...@users.noreply.github.com> Co-authored-by: taylor <xinxi...@qq.com> --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 22 ++++++++++++---------- tests/python/relax/test_frontend_onnx.py | 6 ++++++ 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index dd4b8a4254..24217184b5 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1917,18 +1917,20 @@ class Expand(OnnxOpConverter): # If possible, directly expand to constant shape. if isinstance(shape, relax.Constant): new_shape = shape.data.numpy().tolist() - # For some reason, onnx allows target shapes to be smaller than input shapes. - # We need to go correct it. + # ONNX Expand operator requires preserving target rank and broadcasting + # according to standard rules. Dimensions are right-aligned. data_shape = [dim.value for dim in data.struct_info.shape] - # Dimensions are right alignment. - data_shape = [1] * (len(new_shape) - len(data_shape)) + data_shape - # Fix small target shapes. - for i, s in enumerate(new_shape): - if i < len(data_shape) and s < data_shape[i]: + + # Right-align the shapes + if len(new_shape) > len(data_shape): + data_shape = [1] * (len(new_shape) - len(data_shape)) + data_shape + else: + new_shape = [1] * (len(data_shape) - len(new_shape)) + new_shape + # Fix small target shapes - if target dim is smaller than input dim + # use the input dim (ONNX-specific behavior). + for i in range(len(new_shape)): + if new_shape[i] < data_shape[i]: new_shape[i] = data_shape[i] - # If the new shape matches the input shape, no transformation is needed. - if new_shape == data_shape: - return data return relax.op.broadcast_to(data, relax.ShapeExpr(new_shape)) # Otherwise handle dynamic shapes. diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 10c185ae09..ebc1454c23 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -1692,6 +1692,12 @@ def test_expand(dynamic): data = np.random.uniform(size=in_shape).astype(np.float32) ref_data = np.tile(data, (1, 1, 4)) _test_expand("expand_with_diff_dim", data, shape, ref_data) + + in_shape = (3, 1) + shape = (1, 1, 3, 1) + data = np.random.uniform(size=in_shape).astype(np.float32) + ref_data = np.tile(data, (1, 1, 1, 1)) + _test_expand("expand_with_the_same_suffix_dims", data, shape, ref_data) else: in_shape = (1, 32, 32) shape = ("batch", 32, 32)