This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 038b327cff [Relax][Onnx][PReLU] Handle slope and axis argument with different slope shapes (#18658)
038b327cff is described below

commit 038b327cff150f16dd85de49aa2449210d9cffcb
Author: Nguyen Duy Loc <[email protected]>
AuthorDate: Mon Jan 12 19:55:54 2026 +0700

    [Relax][Onnx][PReLU] Handle slope and axis argument with different slope shapes (#18658)
    
    This PR adds support for handling the slope and axis arguments of the
    PReLU op for different slope shapes: (1, C, 1, 1), (S,), (1, 1), etc.
    
    ### Description
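    - Handle the slope and axis arguments of the PReLU op (passed to the
    relax.op.nn.prelu function)
    - If the slope has the same rank as x and exactly one non-1 dimension
    (e.g. (1, C, 1, 1)), use that dimension as the axis (axis = 1 for
    (1, C, 1, 1)) and reshape the slope to (C,)
    - Otherwise, if the slope shape is (S,) or all ones (e.g. (1, 1)), use
    axis = len(x_shape) - 1, i.e. the last axis of the input x, following
    ONNX broadcasting rules
    (https://onnx.ai/onnx/repo-docs/Broadcasting.html)
    - Otherwise, raise an error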
    
    ### Resolved
    - Fixed 1: #18596
    - Fixed 2: #18598
    - Fixed 3: #18606
    - Fixed 4: #18607
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 26 ++++++++++++++++++++++++-
 python/tvm/topi/nn/elemwise.py                  |  7 ++++---
 tests/python/relax/test_frontend_onnx.py        |  3 +++
 3 files changed, 32 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 2212fa6c68..1479d6f239 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1127,7 +1127,31 @@ class PRelu(OnnxOpConverter):
     def _impl_v1(cls, bb, inputs, attr, params):
         x = inputs[0]
         slope = inputs[1]
-        return relax.op.nn.prelu(x, slope)
+
+        x_shape = x.struct_info.shape
+        slope_shape = slope.struct_info.shape
+
+        ndim = len(x_shape)
+        s_ndim = len(slope_shape)
+
+        if all(ss == 1 for ss in slope_shape) or s_ndim == 1:
+            slope = relax.op.reshape(slope, (slope_shape[0],))
+            return relax.op.nn.prelu(x, slope, ndim - 1)
+
+        if s_ndim == ndim:
+            non_one_axes = [i for i, ss in enumerate(slope_shape) if ss != 1]
+
+            # Must have only ONE non-broadcast axis
+            if len(non_one_axes) != 1:
+                raise ValueError(
+                    f"Invalid PRelu slope shape (multiple non-broadcast dims): {slope_shape}"
+                )
+            axis = non_one_axes[0]
+
+            slope = relax.op.reshape(slope, (slope_shape[axis],))
+            return relax.op.nn.prelu(x, slope, axis)
+
+        raise ValueError(f"Unsupported PRelu slope shape: {slope_shape}")
 
 
 class ThresholdedRelu(OnnxOpConverter):
diff --git a/python/tvm/topi/nn/elemwise.py b/python/tvm/topi/nn/elemwise.py
index 59cc3598e9..332636185c 100644
--- a/python/tvm/topi/nn/elemwise.py
+++ b/python/tvm/topi/nn/elemwise.py
@@ -129,9 +129,10 @@ def prelu(x, slope, axis=1):
 
     assert len(slope.shape) == 1
     assert axis < len(x.shape)
-    slope = te.compute(
-        (get_const_int(x.shape[axis]),), lambda c: slope[0], name="slope_broadcasted"
-    )
+    if slope.shape[0] == 1:
+        slope = te.compute(
+            (get_const_int(x.shape[axis]),), lambda c: slope[0], name="slope_broadcasted"
+        )
     assert get_const_int(slope.shape[0]) == get_const_int(x.shape[axis])
 
     def _compute_channelwise(*indices):
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index eb4c557e75..f967b3c4c6 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1068,6 +1068,9 @@ def test_mish():
 
 def test_prelu():
     verify_binary("PRelu", [3, 32, 32], [1], [3, 32, 32])
+    verify_binary("PRelu", [3, 32, 32], [1, 1], [3, 32, 32])
+    verify_binary("PRelu", [3, 32, 32], [32], [3, 32, 32])
+    verify_binary("PRelu", [3, 32, 32], [3, 1, 1], [3, 32, 32])
 
 
 def test_thresholded_relu():

Reply via email to