This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 701a7539aa [Relax][Onnx] Support Multi Input Ops with Multidirectional 
Broadcasting (#18673)
701a7539aa is described below

commit 701a7539aaf99e2855ba0274a9d45e4c7942d733
Author: Nguyen Duy Loc <[email protected]>
AuthorDate: Thu Jan 29 18:22:07 2026 +0700

    [Relax][Onnx] Support Multi Input Ops with Multidirectional Broadcasting 
(#18673)
    
    This PR supports Multi Input Ops with Multidirectional Broadcasting
    ### Description
    - Support Multi Input Ops with Multidirectional Broadcasting (Min, Max,
    Mean, Sum)
    - Edit the handling workflow for MultiInputBase:
      + Compute target shape for Multidirectional Broadcasting
      + Broadcast_to with target shape
      + Stack op
      + Reduce ops along the same axis used by the stack op
    
    ### Expected
    - Example target shape:
    <img width="700" height="183" alt="image"
    
src="https://github.com/user-attachments/assets/f9569dff-588e-49c5-ae72-c5b6ea22b6f3"
    />
    
    ### Reference
    - Multidirectional Broadcasting:
    https://onnx.ai/onnx/repo-docs/Broadcasting.html
    - Fixed: #18592
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 26 ++++++++++--
 tests/python/relax/test_frontend_onnx.py        | 55 +++++++++++++++++++------
 2 files changed, 66 insertions(+), 15 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py 
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index befe131a69..c71fd96caf 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -38,6 +38,7 @@ import math
 import operator
 import re
 import warnings
+import functools
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as _np
@@ -1659,6 +1660,22 @@ class Sqrt(OnnxOpConverter):
         return relax.op.sqrt(inputs[0])
 
 
+def compute_broadcast_shape(shape_a, shape_b):
+    """Compute target shape for Multidirectional Broadcasting"""
+    rank = max(len(shape_a), len(shape_b))
+
+    a = (1,) * (rank - len(shape_a)) + tuple(shape_a)
+    b = (1,) * (rank - len(shape_b)) + tuple(shape_b)
+
+    target = []
+    for ai, bi in zip(a, b):
+        if ai == bi or ai == 1 or bi == 1:
+            target.append(max(ai, bi))
+        else:
+            raise ValueError(f"Cannot broadcast {ai} and {bi}")
+    return tuple(target)
+
+
 class MultiInputBase(OnnxOpConverter):
     """Converts an onnx MultiInputBase node into an equivalent Relax 
expression."""
 
@@ -1674,9 +1691,12 @@ class MultiInputBase(OnnxOpConverter):
             output = cls.numpy_op(*np_inputs)  # pylint: disable=not-callable
             return relax.const(output, output.dtype)
 
-        # Expand inputs, stack them, then perform minimum over the new axis.
-        inputs = [bb.normalize(relax.op.expand_dims(i, axis=0)) for i in 
inputs]
-        stacked_tensor = relax.op.concat(inputs, axis=0)
+        input_shapes = [inp.struct_info.shape for inp in inputs]
+        target_shape = functools.reduce(compute_broadcast_shape, input_shapes)
+
+        # broadcast_to, stack them, then perform minimum over the new axis.
+        inputs = [bb.normalize(relax.op.broadcast_to(i, target_shape)) for i 
in inputs]
+        stacked_tensor = bb.normalize(relax.op.stack(inputs, axis=0))
         return cls.relax_op(stacked_tensor, axis=0)  # pylint: 
disable=not-callable
 
 
diff --git a/tests/python/relax/test_frontend_onnx.py 
b/tests/python/relax/test_frontend_onnx.py
index 344bc26065..b4b3baeb4d 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -393,22 +393,53 @@ def test_mod(int_mode: bool):
     verify_binary_scalar("Mod", attrs={"fmod": fmod}, dtype=dtype)
 
 
[email protected]("num_inputs", [1, 2, 4])
+SHAPE_PARAMS = [
+    ([[32, 32], [32, 32]], [32, 32]),
+    ([[32, 1], [1, 2]], [32, 2]),
+    (
+        [
+            [
+                32,
+            ],
+            [
+                1,
+            ],
+        ],
+        [
+            32,
+        ],
+    ),
+    ([[32, 32, 1, 1], [1, 32, 32]], [32, 32, 32, 32]),
+    (
+        [
+            [32, 32, 1, 1],
+            [1, 32, 1],
+            [
+                32,
+            ],
+        ],
+        [32, 32, 32, 32],
+    ),
+]
+
+
[email protected]("input_shapes, expected_output_shape", SHAPE_PARAMS)
 @pytest.mark.parametrize("op_name", ["Min", "Max", "Sum", "Mean"])
-def test_multi_input(op_name: str, num_inputs: int):
-    input_shape = [32, 32]
-    input_var = ["i" + str(i) for i in range(num_inputs)]
-    input_values = [
-        helper.make_tensor_value_info(var, TensorProto.FLOAT, input_shape) for 
var in input_var
-    ]
-    test_node = helper.make_node(op_name, input_var, ["c"])
+def test_multi_input_broadcasting(op_name, input_shapes, 
expected_output_shape):
+    num_inputs = len(input_shapes)
+    input_names = [f"i{i}" for i in range(num_inputs)]
+
+    input_values_info = []
+    for name, shape in zip(input_names, input_shapes):
+        input_values_info.append(helper.make_tensor_value_info(name, 
TensorProto.FLOAT, shape))
+    test_node = helper.make_node(op_name, input_names, ["output"])
+    output_info = helper.make_tensor_value_info("output", TensorProto.FLOAT, 
expected_output_shape)
     graph = helper.make_graph(
         [test_node],
-        "multi_input_test",
-        inputs=input_values,
-        outputs=[helper.make_tensor_value_info("c", TensorProto.FLOAT, 
input_shape)],
+        f"multi_input_{op_name}_test",
+        inputs=input_values_info,
+        outputs=[output_info],
     )
-
     model = helper.make_model(graph, producer_name="multi_input_test")
     check_correctness(model)
 

Reply via email to