This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 038b327cff [Relax][Onnx][PReLU] Handle slope and axis argument with
different slope shapes (#18658)
038b327cff is described below
commit 038b327cff150f16dd85de49aa2449210d9cffcb
Author: Nguyen Duy Loc <[email protected]>
AuthorDate: Mon Jan 12 19:55:54 2026 +0700
[Relax][Onnx][PReLU] Handle slope and axis argument with different slope
shapes (#18658)
This PR adds support for handling the slope and axis arguments of the PReLU op
for different slope shapes: (1xCx1x1), (S,), (1,1), etc.
### Description
- Derive the slope and axis arguments of the PReLU op (to pass into the
relax.op.nn.prelu function)
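- If the slope has the same rank as x with exactly one non-1 dimension, e.g. (1xCx1x1), use that dimension as the axis (axis = 1 for (1xCx1x1)) and reshape the slope to (C,)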
- Else, if the slope shape is (S,) or all ones, e.g. (1, 1), use axis = len(x_shape) - 1
(the last axis of the input x, following ONNX unidirectional broadcasting)
(https://onnx.ai/onnx/repo-docs/Broadcasting.html)
- Otherwise, raise a ValueError
### Resolved
- Fixed 1: #18596
- Fixed 2: #18598
- Fixed 3: #18606
- Fixed 4: #18607
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 26 ++++++++++++++++++++++++-
python/tvm/topi/nn/elemwise.py | 7 ++++---
tests/python/relax/test_frontend_onnx.py | 3 +++
3 files changed, 32 insertions(+), 4 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 2212fa6c68..1479d6f239 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1127,7 +1127,31 @@ class PRelu(OnnxOpConverter):
def _impl_v1(cls, bb, inputs, attr, params):
x = inputs[0]
slope = inputs[1]
- return relax.op.nn.prelu(x, slope)
+
+ x_shape = x.struct_info.shape
+ slope_shape = slope.struct_info.shape
+
+ ndim = len(x_shape)
+ s_ndim = len(slope_shape)
+
+ if all(ss == 1 for ss in slope_shape) or s_ndim == 1:
+ slope = relax.op.reshape(slope, (slope_shape[0],))
+ return relax.op.nn.prelu(x, slope, ndim - 1)
+
+ if s_ndim == ndim:
+ non_one_axes = [i for i, ss in enumerate(slope_shape) if ss != 1]
+
+ # Must have only ONE non-broadcast axis
+ if len(non_one_axes) != 1:
+ raise ValueError(
+ f"Invalid PRelu slope shape (multiple non-broadcast dims):
{slope_shape}"
+ )
+ axis = non_one_axes[0]
+
+ slope = relax.op.reshape(slope, (slope_shape[axis],))
+ return relax.op.nn.prelu(x, slope, axis)
+
+ raise ValueError(f"Unsupported PRelu slope shape: {slope_shape}")
class ThresholdedRelu(OnnxOpConverter):
diff --git a/python/tvm/topi/nn/elemwise.py b/python/tvm/topi/nn/elemwise.py
index 59cc3598e9..332636185c 100644
--- a/python/tvm/topi/nn/elemwise.py
+++ b/python/tvm/topi/nn/elemwise.py
@@ -129,9 +129,10 @@ def prelu(x, slope, axis=1):
assert len(slope.shape) == 1
assert axis < len(x.shape)
- slope = te.compute(
- (get_const_int(x.shape[axis]),), lambda c: slope[0],
name="slope_broadcasted"
- )
+ if slope.shape[0] == 1:
+ slope = te.compute(
+ (get_const_int(x.shape[axis]),), lambda c: slope[0],
name="slope_broadcasted"
+ )
assert get_const_int(slope.shape[0]) == get_const_int(x.shape[axis])
def _compute_channelwise(*indices):
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index eb4c557e75..f967b3c4c6 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1068,6 +1068,9 @@ def test_mish():
def test_prelu():
verify_binary("PRelu", [3, 32, 32], [1], [3, 32, 32])
+ verify_binary("PRelu", [3, 32, 32], [1, 1], [3, 32, 32])
+ verify_binary("PRelu", [3, 32, 32], [32], [3, 32, 32])
+ verify_binary("PRelu", [3, 32, 32], [3, 1, 1], [3, 32, 32])
def test_thresholded_relu():