This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 24d9afddf5 [Relax] Fix Relax Operator PReLU (#18179)
24d9afddf5 is described below
commit 24d9afddf57da3f1d4bb38702ee4e3f479c8c717
Author: ysh329 <[email protected]>
AuthorDate: Mon Aug 4 21:00:41 2025 +0800
[Relax] Fix Relax Operator PReLU (#18179)
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 3 +-
src/relax/op/nn/nn.cc | 45 +++++++++++-
tests/python/relax/test_frontend_onnx.py | 2 +-
.../python/relax/test_transform_legalize_ops_nn.py | 83 ++++++++++++++++++++++
tests/python/relax/test_tvmscript_parser_op_nn.py | 19 +++++
5 files changed, 147 insertions(+), 5 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 926da7f022..103275375b 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1100,8 +1100,7 @@ class PRelu(OnnxOpConverter):
def _impl_v1(cls, bb, inputs, attr, params):
x = inputs[0]
slope = inputs[1]
- # TODO(tvm-team): Should add a new op for this.
- return x * slope + relax.op.nn.relu(x) * (relax.const(1.0) - slope)
+ return relax.op.nn.prelu(x, slope)
class ThresholdedRelu(OnnxOpConverter):
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index 165e265d93..63d03553ed 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -120,13 +120,54 @@ TVM_FFI_STATIC_INIT_BLOCK({
refl::GlobalDef().def("relax.op.nn.prelu", prelu);
});
+StructInfo InferStructInfoPRelu(const Call& call, const BlockBuilder& ctx) {
+ TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+ if (data_sinfo->IsUnknownNdim()) {
+ return data_sinfo;
+ }
+ if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) {
+    ctx->ReportFatal(Diagnostic::Error(call) << "Prelu requires the input tensor to have float "
+                                                "dtype. However, the given input dtype is "
+                                             << data_sinfo->dtype);
+ }
+ const auto* attrs = call->attrs.as<PReluAttrs>();
+ NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis);
+
+ return data_sinfo;
+}
+
+InferLayoutOutput InferLayoutPRelu(const Call& call,
+                                   const Map<String, Array<String>>& desired_layouts,
+                                   const VarLayoutMap& var_layout_map) {
+ ICHECK(NoDesiredLayout(call, desired_layouts));
+ const auto* attrs = call->attrs.as<PReluAttrs>();
+ ICHECK(attrs) << "Invalid Call";
+
+ LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+
+ // TODO(Siva): We could handle if the axis is not the sub indexed one.
+ if (layout->layout.ndim() != layout->layout.ndim_primal()) {
+    const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+ ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+    ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
+ int ndim = tensor_sinfo->ndim;
+ layout = LayoutDecision(InitialLayout(ndim));
+ }
+
+ ObjectPtr<PReluAttrs> new_attrs = make_object<PReluAttrs>(*attrs);
+ new_attrs->axis = FindAxis(layout->layout, attrs->axis);
+
+  LayoutDecision alpha_layout = GetLayoutDecision(var_layout_map, call->args[1]);
+ return InferLayoutOutput({layout, alpha_layout}, {layout}, Attrs(new_attrs));
+}
+
TVM_REGISTER_OP("relax.nn.prelu")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("alpha", "Tensor", "The channel-wise learnable slope.")
.set_attrs_type<PReluAttrs>()
-    .set_attr<FInferStructInfo>("FInferStructInfo",
-                                InferStructInfoUnaryArith</*require_float_dtype=*/true>)
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPRelu)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPRelu)
.set_attr<Bool>("FPurity", Bool(true));
/* relax.nn.softmax */
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 7a0a7d7bc9..3d112c2f3b 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -948,7 +948,7 @@ def test_mish():
def test_prelu():
- verify_binary("PRelu", [3, 32, 32], [3, 32, 32], [3, 32, 32])
+ verify_binary("PRelu", [3, 32, 32], [1], [3, 32, 32])
def test_thresholded_relu():
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 575f4a0fb0..ff03ab4152 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -1159,6 +1159,89 @@ def test_leakyrelu_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_prelu():
+ # fmt: off
+ @tvm.script.ir_module
+ class PRelu:
+ @R.function
+        def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((1,), "float32")) -> R.Tensor((2, 3), "float32"):
+ gv: R.Tensor((2, 3), "float32") = R.nn.prelu(x, y)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+        def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((1,), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
+            gv = R.call_tir(Expected.prelu, (x, y), out_sinfo=R.Tensor((2, 3), dtype="float32"))
+ return gv
+
+ @T.prim_func(private=True)
+        def prelu(x: T.Buffer((T.int64(2), T.int64(3)), "float32"), y: T.Buffer((T.int64(1),), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ slope_broadcasted = T.alloc_buffer((T.int64(3),))
+ for c in range(T.int64(3)):
+ with T.block("slope_broadcasted"):
+ v_c = T.axis.spatial(T.int64(3), c)
+ T.reads(y[T.int64(0)])
+ T.writes(slope_broadcasted[v_c])
+ slope_broadcasted[v_c] = y[T.int64(0)]
+ for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+ with T.block("compute"):
+ v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+ T.reads(x[v_i0, v_i1], slope_broadcasted[v_i1])
+ T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.Select(T.float32(0.0) < x[v_i0, v_i1], x[v_i0, v_i1], x[v_i0, v_i1] * slope_broadcasted[v_i1])
+ # fmt: on
+
+ mod = LegalizeOps()(PRelu)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_prelu_symbolic():
+ # fmt: off
+ @tvm.script.ir_module
+ class PRelu:
+ @R.function
+        def main(x: R.Tensor(("m", 7), "float32"), y: R.Tensor((1,), "float32")) -> R.Tensor(("m", 7), "float32"):
+ m = T.int64()
+ gv: R.Tensor((m, 7), "float32") = R.nn.prelu(x, y)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+        def main(x: R.Tensor(("m", 7), dtype="float32"), y: R.Tensor((1,), dtype="float32")) -> R.Tensor(("m", 7), dtype="float32"):
+ m = T.int64()
+            gv = R.call_tir(Expected.prelu, (x, y), out_sinfo=R.Tensor((m, 7), dtype="float32"))
+ return gv
+
+ @T.prim_func(private=True)
+        def prelu(var_x: T.handle, y: T.Buffer((T.int64(1),), "float32"), var_compute: T.handle):
+ T.func_attr({"tir.noalias": True})
+ m = T.int64()
+ x = T.match_buffer(var_x, (m, T.int64(7)))
+ compute = T.match_buffer(var_compute, (m, T.int64(7)))
+ # with T.block("root"):
+ slope_broadcasted = T.alloc_buffer((T.int64(7),))
+ for c in range(T.int64(7)):
+ with T.block("slope_broadcasted"):
+ v_c = T.axis.spatial(T.int64(7), c)
+ T.reads(y[T.int64(0)])
+ T.writes(slope_broadcasted[v_c])
+ slope_broadcasted[v_c] = y[T.int64(0)]
+ for i0, i1 in T.grid(m, T.int64(7)):
+ with T.block("compute"):
+ v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+ T.reads(x[v_i0, v_i1], slope_broadcasted[v_i1])
+ T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.Select(T.float32(0.0) < x[v_i0, v_i1], x[v_i0, v_i1], x[v_i0, v_i1] * slope_broadcasted[v_i1])
+ # fmt: on
+
+ mod = LegalizeOps()(PRelu)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_gelu():
# fmt: off
@tvm.script.ir_module
diff --git a/tests/python/relax/test_tvmscript_parser_op_nn.py b/tests/python/relax/test_tvmscript_parser_op_nn.py
index bba08d4d84..4c458a7ead 100644
--- a/tests/python/relax/test_tvmscript_parser_op_nn.py
+++ b/tests/python/relax/test_tvmscript_parser_op_nn.py
@@ -364,5 +364,24 @@ def test_nll_loss_no_weights():
_check(foo, bb.get()["foo"])
+def test_prelu():
+ @R.function
+ def foo(
+ x: R.Tensor((2, 4, 4, 5), "float32"),
+ alpha: R.Tensor((1,), "float32"),
+ ) -> R.Tensor((2, 4, 4, 5), "float32"):
+ gv: R.Tensor((2, 4, 4, 5), "float32") = R.nn.prelu(x, alpha)
+ return gv
+
+ x = relax.Var("x", R.Tensor((2, 4, 4, 5), "float32"))
+ alpha = relax.Var("alpha", R.Tensor((1,), "float32"))
+ bb = relax.BlockBuilder()
+ with bb.function("foo", [x, alpha]):
+ gv = bb.emit(relax.op.nn.prelu(x, alpha))
+ bb.emit_func_output(gv)
+
+ _check(foo, bb.get()["foo"])
+
+
if __name__ == "__main__":
tvm.testing.main()