This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 24d9afddf5 [Relax] Fix Relax Operator PReLU (#18179)
24d9afddf5 is described below

commit 24d9afddf57da3f1d4bb38702ee4e3f479c8c717
Author: ysh329 <[email protected]>
AuthorDate: Mon Aug 4 21:00:41 2025 +0800

    [Relax] Fix Relax Operator PReLU (#18179)
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py    |  3 +-
 src/relax/op/nn/nn.cc                              | 45 +++++++++++-
 tests/python/relax/test_frontend_onnx.py           |  2 +-
 .../python/relax/test_transform_legalize_ops_nn.py | 83 ++++++++++++++++++++++
 tests/python/relax/test_tvmscript_parser_op_nn.py  | 19 +++++
 5 files changed, 147 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 926da7f022..103275375b 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1100,8 +1100,7 @@ class PRelu(OnnxOpConverter):
     def _impl_v1(cls, bb, inputs, attr, params):
         x = inputs[0]
         slope = inputs[1]
-        # TODO(tvm-team): Should add a new op for this.
-        return x * slope + relax.op.nn.relu(x) * (relax.const(1.0) - slope)
+        return relax.op.nn.prelu(x, slope)
 
 
 class ThresholdedRelu(OnnxOpConverter):
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index 165e265d93..63d03553ed 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -120,13 +120,54 @@ TVM_FFI_STATIC_INIT_BLOCK({
   refl::GlobalDef().def("relax.op.nn.prelu", prelu);
 });
 
+StructInfo InferStructInfoPRelu(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  if (data_sinfo->IsUnknownNdim()) {
+    return data_sinfo;
+  }
+  if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) {
+    ctx->ReportFatal(Diagnostic::Error(call) << "Prelu requires the input tensor to have float "
+                                                "dtype. However, the given input dtype is "
+                                             << data_sinfo->dtype);
+  }
+  const auto* attrs = call->attrs.as<PReluAttrs>();
+  NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis);
+
+  return data_sinfo;
+}
+
+InferLayoutOutput InferLayoutPRelu(const Call& call,
+                                   const Map<String, Array<String>>& desired_layouts,
+                                   const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+  const auto* attrs = call->attrs.as<PReluAttrs>();
+  ICHECK(attrs) << "Invalid Call";
+
+  LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+
+  // TODO(Siva): We could handle if the axis is not the sub indexed one.
+  if (layout->layout.ndim() != layout->layout.ndim_primal()) {
+    const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+    ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+    ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
+    int ndim = tensor_sinfo->ndim;
+    layout = LayoutDecision(InitialLayout(ndim));
+  }
+
+  ObjectPtr<PReluAttrs> new_attrs = make_object<PReluAttrs>(*attrs);
+  new_attrs->axis = FindAxis(layout->layout, attrs->axis);
+
+  LayoutDecision alpha_layout = GetLayoutDecision(var_layout_map, call->args[1]);
+  return InferLayoutOutput({layout, alpha_layout}, {layout}, Attrs(new_attrs));
+}
+
 TVM_REGISTER_OP("relax.nn.prelu")
     .set_num_inputs(2)
     .add_argument("data", "Tensor", "The input tensor.")
     .add_argument("alpha", "Tensor", "The channel-wise learnable slope.")
     .set_attrs_type<PReluAttrs>()
-    .set_attr<FInferStructInfo>("FInferStructInfo",
-                                InferStructInfoUnaryArith</*require_float_dtype=*/true>)
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPRelu)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPRelu)
     .set_attr<Bool>("FPurity", Bool(true));
 
 /* relax.nn.softmax */
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 7a0a7d7bc9..3d112c2f3b 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -948,7 +948,7 @@ def test_mish():
 
 
 def test_prelu():
-    verify_binary("PRelu", [3, 32, 32], [3, 32, 32], [3, 32, 32])
+    verify_binary("PRelu", [3, 32, 32], [1], [3, 32, 32])
 
 
 def test_thresholded_relu():
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 575f4a0fb0..ff03ab4152 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -1159,6 +1159,89 @@ def test_leakyrelu_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_prelu():
+    # fmt: off
+    @tvm.script.ir_module
+    class PRelu:
+        @R.function
+        def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((1,), "float32")) -> R.Tensor((2, 3), "float32"):
+            gv: R.Tensor((2, 3), "float32") = R.nn.prelu(x, y)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((1,), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
+            gv = R.call_tir(Expected.prelu, (x, y), out_sinfo=R.Tensor((2, 3), dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def prelu(x: T.Buffer((T.int64(2), T.int64(3)), "float32"), y: T.Buffer((T.int64(1),), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+            T.func_attr({"tir.noalias": True})
+            # with T.block("root"):
+            slope_broadcasted = T.alloc_buffer((T.int64(3),))
+            for c in range(T.int64(3)):
+                with T.block("slope_broadcasted"):
+                    v_c = T.axis.spatial(T.int64(3), c)
+                    T.reads(y[T.int64(0)])
+                    T.writes(slope_broadcasted[v_c])
+                    slope_broadcasted[v_c] = y[T.int64(0)]
+            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(x[v_i0, v_i1], slope_broadcasted[v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.Select(T.float32(0.0) < x[v_i0, v_i1], x[v_i0, v_i1], x[v_i0, v_i1] * slope_broadcasted[v_i1])
+    # fmt: on
+
+    mod = LegalizeOps()(PRelu)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_prelu_symbolic():
+    # fmt: off
+    @tvm.script.ir_module
+    class PRelu:
+        @R.function
+        def main(x: R.Tensor(("m", 7), "float32"), y: R.Tensor((1,), "float32")) -> R.Tensor(("m", 7), "float32"):
+            m = T.int64()
+            gv: R.Tensor((m, 7), "float32") = R.nn.prelu(x, y)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor(("m", 7), dtype="float32"), y: R.Tensor((1,), dtype="float32")) -> R.Tensor(("m", 7), dtype="float32"):
+            m = T.int64()
+            gv = R.call_tir(Expected.prelu, (x, y), out_sinfo=R.Tensor((m, 7), dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def prelu(var_x: T.handle, y: T.Buffer((T.int64(1),), "float32"), var_compute: T.handle):
+            T.func_attr({"tir.noalias": True})
+            m = T.int64()
+            x = T.match_buffer(var_x, (m, T.int64(7)))
+            compute = T.match_buffer(var_compute, (m, T.int64(7)))
+            # with T.block("root"):
+            slope_broadcasted = T.alloc_buffer((T.int64(7),))
+            for c in range(T.int64(7)):
+                with T.block("slope_broadcasted"):
+                    v_c = T.axis.spatial(T.int64(7), c)
+                    T.reads(y[T.int64(0)])
+                    T.writes(slope_broadcasted[v_c])
+                    slope_broadcasted[v_c] = y[T.int64(0)]
+            for i0, i1 in T.grid(m, T.int64(7)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(x[v_i0, v_i1], slope_broadcasted[v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.Select(T.float32(0.0) < x[v_i0, v_i1], x[v_i0, v_i1], x[v_i0, v_i1] * slope_broadcasted[v_i1])
+    # fmt: on
+
+    mod = LegalizeOps()(PRelu)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_gelu():
     # fmt: off
     @tvm.script.ir_module
diff --git a/tests/python/relax/test_tvmscript_parser_op_nn.py b/tests/python/relax/test_tvmscript_parser_op_nn.py
index bba08d4d84..4c458a7ead 100644
--- a/tests/python/relax/test_tvmscript_parser_op_nn.py
+++ b/tests/python/relax/test_tvmscript_parser_op_nn.py
@@ -364,5 +364,24 @@ def test_nll_loss_no_weights():
     _check(foo, bb.get()["foo"])
 
 
+def test_prelu():
+    @R.function
+    def foo(
+        x: R.Tensor((2, 4, 4, 5), "float32"),
+        alpha: R.Tensor((1,), "float32"),
+    ) -> R.Tensor((2, 4, 4, 5), "float32"):
+        gv: R.Tensor((2, 4, 4, 5), "float32") = R.nn.prelu(x, alpha)
+        return gv
+
+    x = relax.Var("x", R.Tensor((2, 4, 4, 5), "float32"))
+    alpha = relax.Var("alpha", R.Tensor((1,), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x, alpha]):
+        gv = bb.emit(relax.op.nn.prelu(x, alpha))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to