This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8b012ed369 [Relax] Legalize nn.dropout as inference no-op (#19841)
8b012ed369 is described below
commit 8b012ed3696bbb47a922c8653ca01eaee3bd5c87
Author: Guan-Ming Chiu <[email protected]>
AuthorDate: Tue Jun 23 07:17:21 2026 +0800
[Relax] Legalize nn.dropout as inference no-op (#19841)
## Related Issue
Closed #19695
## Why
Building any module containing relax.nn.dropout crashed in VM codegen
because the op had no real legalization, and the ONNX frontend could not
import ONNX Dropout nodes.
## How
- Legalize nn.dropout to pass the input through with an all-ones mask,
matching its (output, mask) tuple result.
- Add and register a Dropout converter in the ONNX frontend.
- Add a legalization structural test and ONNX tests checking onnxruntime parity.
Signed-off-by: Guan-Ming (Wesley) Chiu <[email protected]>
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 20 ++++++++++++
python/tvm/relax/transform/legalize_ops/nn.py | 8 +++--
tests/python/relax/test_frontend_onnx.py | 26 +++++++++++++++
.../python/relax/test_transform_legalize_ops_nn.py | 38 ++++++++++++++++++++++
4 files changed, 90 insertions(+), 2 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 61f95e7130..8562cb60a2 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -3070,6 +3070,25 @@ class Identity(OnnxOpConverter):
return inputs[0]
+class Dropout(OnnxOpConverter):
+ """Converts an onnx Dropout node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ ratio = float(attr.get("ratio", 0.5))
+ return relax.op.nn.dropout(inputs[0], ratio)
+
+ @classmethod
+ def _impl_v12(cls, bb, inputs, attr, params):
+ # Since opset 12 ratio is the optional second input rather than an attribute.
+ ratio = 0.5
+ if len(inputs) >= 2 and inputs[1] is not None:
+ const = get_constant(inputs[1], params)
+ if isinstance(const, relax.Constant):
+ ratio = float(const.data.numpy())
+ return relax.op.nn.dropout(inputs[0], ratio)
+
+
def _onnx_resize_spatial_roi_vector(roi_full: relax.Expr, rank: int) ->
relax.Expr:
"""Map ONNX ROI [starts..., ends...] to TOPI spatial ROI (drop N/C
axes)."""
return relax.op.concat(
@@ -5284,6 +5303,7 @@ def _get_convert_map():
"ConvTranspose": ConvTranspose,
"Flatten": Flatten,
"Identity": Identity,
+ "Dropout": Dropout,
"Resize": Resize,
"Einsum": Einsum,
"Range": Range,
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index f87c16aa0a..d68426f02a 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -697,8 +697,12 @@ def _nn_rms_norm(bb: BlockBuilder, call: Call) -> Expr:
@register_legalize("relax.nn.dropout")
def _nn_dropout(bb: BlockBuilder, call: Call) -> Expr:
- logging.info("Dropout is handled by frontend translator at this moment and
is not legalized.")
- return call
+ # Dropout is a no-op at inference: pass the input through and return an all-ones mask.
+ return bb.call_te(
+ lambda x: [topi.identity(x), topi.full_like(x, 1.0)],
+ call.args[0],
+ primfunc_name_hint="dropout",
+ )
def _te_attention(
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 82e3f997d2..721e26e792 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -4254,6 +4254,32 @@ def test_maxunpool(kernel_shape, pads, strides):
check_correctness(model, inputs={"I": indices})
+def test_dropout():
+ verify_unary("Dropout", [1, 3, 32, 32])
+ verify_unary("Dropout", [1, 3, 32, 32], opset=11, attrs={"ratio": 0.5})
+
+ # Opset 12+ passes ratio as an optional input; check it is captured into the relax op.
+ node = helper.make_node("Dropout", ["x", "ratio"], ["y"])
+ graph = helper.make_graph(
+ [node],
+ "dropout_ratio_input",
+ inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 3,
4, 4])],
+ initializer=[helper.make_tensor("ratio", TensorProto.FLOAT, [],
[0.3])],
+ outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 3,
4, 4])],
+ )
+ model = helper.make_model(graph, producer_name="dropout_ratio_input")
+ model.opset_import[0].version = 13
+ mod = from_onnx(model, opset=13)
+ rates = [
+ float(b.value.attrs.rate)
+ for f in mod.functions.values()
+ for block in getattr(f.body, "blocks", [])
+ for b in block.bindings
+ if getattr(getattr(b.value, "op", None), "name", "") ==
"relax.nn.dropout"
+ ]
+ assert rates == pytest.approx([0.3])
+
+
def test_flatten():
verify_unary("Flatten", [1, 3, 32, 32], attrs={"axis": 0})
verify_unary("Flatten", [1, 3, 32, 32], attrs={"axis": -1})
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 81648e91b4..601985f7be 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -4116,5 +4116,43 @@ def test_batch_flatten_undefined_shape():
tvm.ir.assert_structural_equal(mod, BatchFlattenUndefinedShape)
+def test_dropout():
+ # fmt: off
+ @tvm.script.ir_module
+ class Dropout:
+ @R.function
+ def main(x: R.Tensor((2, 3), "float32")) -> R.Tuple(R.Tensor((2, 3),
"float32"), R.Tensor((2, 3), "float32")):
+ gv: R.Tuple(R.Tensor((2, 3), "float32"), R.Tensor((2, 3),
"float32")) = R.nn.dropout(x, rate=0.5)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func(private=True, s_tir=True)
+ def dropout(x: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute:
T.Buffer((T.int64(2), T.int64(3)), "float32"), T_full_like:
T.Buffer((T.int64(2), T.int64(3)), "float32")):
+ T.func_attr({"tirx.noalias": True})
+ for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+ with T.sblock("compute"):
+ v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+ T.reads(x[v_i0, v_i1])
+ T.writes(compute[v_i0, v_i1])
+ compute[v_i0, v_i1] = x[v_i0, v_i1]
+ for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+ with T.sblock("T_full_like"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads()
+ T.writes(T_full_like[v_ax0, v_ax1])
+ T_full_like[v_ax0, v_ax1] = T.float32(1.0)
+
+ @R.function
+ def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tuple(R.Tensor((2,
3), dtype="float32"), R.Tensor((2, 3), dtype="float32")):
+ cls = Expected
+ gv = R.call_tir(cls.dropout, (x,), out_sinfo=[R.Tensor((2, 3),
dtype="float32"), R.Tensor((2, 3), dtype="float32")])
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(Dropout)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()