This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d3bb716e47 [Relax][PyTorch] Add support for decomposed operators in extended unary ops tests (#18400)
d3bb716e47 is described below
commit d3bb716e473354b68bd20fb1cb6d9bd63124b48d
Author: Shushi Hong <[email protected]>
AuthorDate: Sun Oct 26 23:56:17 2025 -0400
[Relax][PyTorch] Add support for decomposed operators in extended unary ops tests (#18400)
* finish1
* finish2
* finish3
* finish4
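
For context on the mechanism: "decomposition" here means letting torch.export rewrite composite ATen ops into primitive ones before the Relax importer translates the graph. A minimal sketch of the PyTorch side, using a hypothetical example module that is not part of this patch:

    # Illustrative sketch (not from this patch): inspect what torch.export's
    # decomposition does to a composite op such as F.elu.
    import torch
    from torch.export import export


    class M(torch.nn.Module):  # hypothetical example module
        def forward(self, x):
            return torch.nn.functional.elu(x)


    ep = export(M(), (torch.randn(1, 3, 10, 10),))
    # run_decompositions() lowers elu into where/exp/mul/sub primitives,
    # the same shape as the expected Relax modules in the tests below.
    decomposed = ep.run_decompositions()
    print(decomposed.graph_module.graph)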
---
.../frontend/torch/exported_program_translator.py | 10 +
.../relax/test_frontend_from_exported_program.py | 223 +++++++++++++++------
2 files changed, 169 insertions(+), 64 deletions(-)
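
On the TVM side, the tests thread a run_ep_decomposition flag into from_exported_program, as the diff below shows. A hedged usage sketch, with a stand-in model that is not taken from the patch:

    # Sketch of how a caller might exercise the importer with decomposition
    # enabled, matching the updated verify_model helper below.
    import torch
    from torch.export import export

    from tvm.relax.frontend.torch import from_exported_program


    class Expm1Model(torch.nn.Module):  # hypothetical stand-in model
        def forward(self, x):
            return torch.expm1(x)


    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
    exported_program = export(Expm1Model(), args=example_args)
    # With run_ep_decomposition=True, composite ops reach the importer in
    # decomposed form; expm1 is also mapped directly to exp(x) - 1 below.
    mod = from_exported_program(exported_program, run_ep_decomposition=True)
    mod.show()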
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 67d93b0669..cbf9e33a12 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -809,9 +809,16 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"cosh.default": self._unary_op(relax.op.cosh),
"dropout.default": lambda node: self.env[node.args[0]],
"dropout_.default": lambda node: self.env[node.args[0]],
+ "native_dropout.default": lambda node: self.env[node.args[0]],
"elu.default": self._elu,
"erf.default": self._unary_op(relax.op.erf),
"exp.default": self._unary_op(relax.op.exp),
+ "expm1.default": lambda node: self.block_builder.emit(
+ relax.op.subtract(
+ relax.op.exp(self.env[node.args[0]]),
+ relax.const(1.0, self.env[node.args[0]].struct_info.dtype),
+ )
+ ),
"floor.default": self._unary_op(relax.op.floor),
"gelu.default": self._gelu,
"hardsigmoid.default": self._hardsigmoid,
@@ -869,6 +876,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"bitwise_or.Scalar": self._binary_op(relax.op.bitwise_or,
operator.or_),
"bitwise_or_.Tensor": self._binary_op(relax.op.bitwise_or,
operator.or_),
"bitwise_or.Tensor": self._binary_op(relax.op.bitwise_or,
operator.or_),
+ "div.Scalar": self._binary_op(relax.op.divide, operator.truediv),
"div.Tensor": self._binary_op(relax.op.divide, operator.truediv),
"div.Tensor_mode": self._div,
"eq.Scalar": self._binary_op(relax.op.equal, operator.eq),
@@ -1019,7 +1027,9 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"detach_.default": self._detach,
"contiguous.default": lambda node: self.env[node.args[0]], # no-op
"clone.default": lambda node: self.env[node.args[0]],
+ "bernoulli.p": lambda node: self.env[node.args[0]], # Dropout:
just return input
"empty.memory_format": self._empty,
+ "empty_permuted.default": self._empty, # Similar to empty with
permuted layout
"empty_like.default": self._empty_like,
"eye.default": self._eye,
"eye.m": self._eye,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 657ade455b..3382141567 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -31,9 +31,11 @@ from tvm.script import tir as T
from tvm.relax.frontend.torch import from_exported_program
-def verify_model(torch_model, example_args, binding, expected, dynamic_shapes=None):
+def verify_model(
+ torch_model, example_args, binding, expected, dynamic_shapes=None, run_ep_decomposition=False
+):
exported_program = export(torch_model, args=example_args, dynamic_shapes=dynamic_shapes)
- mod = from_exported_program(exported_program)
+ mod = from_exported_program(exported_program, run_ep_decomposition=run_ep_decomposition)
binding = {k: tvm.runtime.tensor(v) for k, v in binding.items()}
expected = relax.transform.BindParams("main", binding)(expected)
@@ -155,26 +157,19 @@ def test_extended_unary_ops():
) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
- lv_div: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
+ lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
lv, R.const(1.0, "float32")
)
- lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
- lv_div, R.const(1.0, "float32")
- )
- lv_min: R.Tensor((1, 3, 10, 10), dtype="float32") = R.minimum(
- R.const(0.0, "float32"), lv_sub
- )
- lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
- R.const(1.0, "float32"), lv_min
+ lv2: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(
+ input_1, R.const(0.0, "float32")
)
- lv_relu_x: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1)
- lv_celu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_scaled, lv_relu_x)
- gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_celu,)
+ lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv2, input_1, lv1)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,)
R.output(gv)
return gv
- verify_model(Celu1(), example_args, {}, expected_celu)
- verify_model(Celu2(), example_args, {}, expected_celu)
+ verify_model(Celu1(), example_args, {}, expected_celu, run_ep_decomposition=True)
+ verify_model(Celu2(), example_args, {}, expected_celu, run_ep_decomposition=True)
# clamp
class Clamp(Module):
@@ -197,7 +192,7 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(Clamp(), example_args, {}, expected_clamp)
+ verify_model(Clamp(), example_args, {}, expected_clamp, run_ep_decomposition=True)
class ClampMinOnly(Module):
def forward(self, input):
@@ -217,7 +212,9 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(ClampMinOnly(), example_args, {}, expected_clamp_min_only)
+ verify_model(
+ ClampMinOnly(), example_args, {}, expected_clamp_min_only, run_ep_decomposition=True
+ )
class ClampTensors(Module):
def forward(self, input):
@@ -245,7 +242,9 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(ClampTensors(), example_args, {}, expected_clamp_tensors)
+ verify_model(
+ ClampTensors(), example_args, {}, expected_clamp_tensors, run_ep_decomposition=True
+ )
# dropout
@@ -266,20 +265,44 @@ def test_extended_unary_ops():
return torch.ops.aten.dropout_(input, 0.5, train=True)
@tvm.script.ir_module
- class expected_dropout:
+ class expected_dropout_for_1_2:
@R.function
def main(
- input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ input: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
# block 0
with R.dataflow():
- gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (input_1,)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (input,)
+ R.output(gv)
+ return gv
+
+ @tvm.script.ir_module
+ class expected_dropout_for_3:
+ @R.function
+ def main(
+ input: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(
+ R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10,
10), dtype="float32")
+ ):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.zeros(
+ R.shape([1, 3, 10, 10]), dtype="float32"
+ )
+ lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
+ lv, R.const(0.5, "float32")
+ )
+ lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(input, lv1)
+ gv: R.Tuple(
+ R.Tensor((1, 3, 10, 10), dtype="float32"),
+ R.Tensor((1, 3, 10, 10), dtype="float32"),
+ ) = (lv2, lv2)
R.output(gv)
return gv
- verify_model(Dropout1(), example_args, {}, expected_dropout)
- verify_model(Dropout2(), example_args, {}, expected_dropout)
- verify_model(Dropout3(), example_args, {}, expected_dropout)
+ verify_model(Dropout1(), example_args, {}, expected_dropout_for_1_2, run_ep_decomposition=True)
+ verify_model(Dropout2(), example_args, {}, expected_dropout_for_1_2, run_ep_decomposition=True)
+ verify_model(Dropout3(), example_args, {}, expected_dropout_for_3, run_ep_decomposition=True)
# elu
class Elu(Module):
@@ -298,28 +321,32 @@ def test_extended_unary_ops():
class expected_elu:
@R.function
def main(
- input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ input: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
- # block 0
with R.dataflow():
- lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
- lv_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
- R.const(1.0, dtype="float32"), lv_exp
+ lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(
+ input, R.const(0.0, "float32")
)
- lv_relu_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(
- lv_one_minus_exp
+ lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
+ input, R.const(1.0, "float32")
)
- lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
- R.const(-1.0, dtype="float32"), lv_relu_one_minus_exp
+ lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
+ input, R.const(1.0, "float32")
)
- lv_relu_x: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1)
- lv_elu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_scaled, lv_relu_x)
- gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_elu,)
+ lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(lv2)
+ lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
+ lv3, R.const(1.0, "float32")
+ )
+ lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
+ lv4, R.const(1.0, "float32")
+ )
+ lv6: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv, lv1, lv5)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv6,)
R.output(gv)
return gv
- verify_model(Elu(), example_args, {}, expected_elu)
- verify_model(Elu2(), example_args, {}, expected_elu)
+ verify_model(Elu(), example_args, {}, expected_elu, run_ep_decomposition=True)
+ verify_model(Elu2(), example_args, {}, expected_elu, run_ep_decomposition=True)
# hardsigmoid
class Hardsigmoid(torch.nn.Module):
@@ -341,17 +368,24 @@ def test_extended_unary_ops():
inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
with R.dataflow():
- lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32"))
- lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0, 6)
- lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
- lv1, R.const(6, "float32")
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(
+ inp_0, R.const(3.0, "float32")
+ )
+ lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
+ lv, R.prim_value(0), R.prim_value(T.float64("inf"))
+ )
+ lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
+ lv1, R.prim_value(T.float64("-inf")), R.prim_value(6)
+ )
+ lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
+ lv2, R.const(6.0, "float32")
)
- gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv2,)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,)
R.output(gv)
return gv
- verify_model(Hardsigmoid(), example_args, {}, expected_hardsigmoid)
- verify_model(Hardsigmoid2(), example_args, {}, expected_hardsigmoid)
+ verify_model(Hardsigmoid(), example_args, {}, expected_hardsigmoid, run_ep_decomposition=True)
+ verify_model(Hardsigmoid2(), example_args, {}, expected_hardsigmoid, run_ep_decomposition=True)
# hardswish
class Hardswish(torch.nn.Module):
@@ -371,25 +405,67 @@ def test_extended_unary_ops():
return torch.ops.aten.hardswish_(input)
@tvm.script.ir_module
- class expected1:
+ class expected_hardswish_for_1_2:
@R.function
def main(
inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
with R.dataflow():
- lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32"))
- lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0, 6)
- lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
- lv1, R.const(6, "float32")
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(
+ inp_0, R.const(3.0, "float32")
+ )
+ lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
+ lv, R.prim_value(0), R.prim_value(T.float64("inf"))
+ )
+ lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
+ lv1, R.prim_value(T.float64("-inf")), R.prim_value(6)
)
lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(inp_0, lv2)
- gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,)
+ lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
+ lv3, R.const(6.0, "float32")
+ )
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv4,)
+ R.output(gv)
+ return gv
+
+ @tvm.script.ir_module
+ class expected_hardswish_for_3:
+ @R.function
+ def main(
+ input: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(
+ R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32")
+ ):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(
+ input, R.const(3.0, "float32")
+ )
+ lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
+ lv, R.prim_value(0), R.prim_value(T.float64("inf"))
+ )
+ lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
+ lv1, R.prim_value(T.float64("-inf")), R.prim_value(6)
+ )
+ lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(input, lv2)
+ lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
+ lv3, R.const(6.0, "float32")
+ )
+ gv: R.Tuple(
+ R.Tensor((1, 3, 10, 10), dtype="float32"),
+ R.Tensor((1, 3, 10, 10), dtype="float32"),
+ ) = (lv4, lv4)
R.output(gv)
return gv
- verify_model(Hardswish(), example_args, {}, expected1)
- verify_model(Hardswish2(), example_args, {}, expected1)
- verify_model(Hardswish3(), example_args, {}, expected1)
+ verify_model(
+ Hardswish(), example_args, {}, expected_hardswish_for_1_2, run_ep_decomposition=True
+ )
+ verify_model(
+ Hardswish2(), example_args, {}, expected_hardswish_for_1_2, run_ep_decomposition=True
+ )
+ verify_model(
+ Hardswish3(), example_args, {}, expected_hardswish_for_3, run_ep_decomposition=True
+ )
# log2
class Log2(Module):
@@ -411,7 +487,7 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(Log2(), example_args, {}, Expected_log2)
+ verify_model(Log2(), example_args, {}, Expected_log2, run_ep_decomposition=True)
# log10
class Log10(Module):
@@ -433,7 +509,7 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(Log10(), example_args, {}, Expected_log10)
+ verify_model(Log10(), example_args, {}, Expected_log10, run_ep_decomposition=True)
# log1p
class Log1p(Module):
@@ -454,7 +530,7 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(Log1p(), example_args, {}, Expected_log1p)
+ verify_model(Log1p(), example_args, {}, Expected_log1p, run_ep_decomposition=True)
# reciprocal
class Reciprocal(Module):
@@ -475,7 +551,7 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(Reciprocal(), example_args, {}, expected_reciprocal)
+ verify_model(Reciprocal(), example_args, {}, expected_reciprocal, run_ep_decomposition=True)
# Returns the maximum value of all elements in the input tensor.
class MaxModel(Module):
@@ -494,7 +570,7 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(MaxModel(), example_args, {}, expected_max)
+ verify_model(MaxModel(), example_args, {}, expected_max, run_ep_decomposition=True)
# Returns the minimum value of all elements in the input tensor.
class MinModel(Module):
@@ -513,7 +589,7 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(MinModel(), example_args, {}, expected_min)
+ verify_model(MinModel(), example_args, {}, expected_min, run_ep_decomposition=True)
# relu6
class ReLU6_1(torch.nn.Module):
@@ -558,9 +634,28 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(ReLU6_1(), example_args, {}, expected_relu6_1)
- verify_model(ReLU6_2(), example_args, {}, expected_relu6_2)
- verify_model(ReLU6_3(), example_args, {}, expected_relu6_2)
+ @tvm.script.ir_module
+ class expected_relu6_3:
+ @R.function
+ def main(
+ x: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(
+ R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32")
+ ):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
+ x, R.prim_value(0), R.prim_value(6)
+ )
+ gv: R.Tuple(
+ R.Tensor((1, 3, 10, 10), dtype="float32"),
+ R.Tensor((1, 3, 10, 10), dtype="float32"),
+ ) = (lv, lv)
+ R.output(gv)
+ return gv
+
+ verify_model(ReLU6_1(), example_args, {}, expected_relu6_1, run_ep_decomposition=True)
+ verify_model(ReLU6_2(), example_args, {}, expected_relu6_2, run_ep_decomposition=True)
+ verify_model(ReLU6_3(), example_args, {}, expected_relu6_3, run_ep_decomposition=True)
def test_hardtanh():
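
For reference, the hardswish decomposition that the expected modules above encode (add 3, clip below at 0, clip above at 6, multiply by the input, divide by 6) matches PyTorch's definition. A hedged check in plain PyTorch, not part of the patch:

    # Plain-PyTorch restatement of the decomposed hardswish in the expected
    # modules: x * clamp(x + 3, 0, 6) / 6, with the clamp split into the two
    # one-sided clips that appear above as R.clip(..., 0, +inf) and
    # R.clip(..., -inf, 6).
    import torch


    def decomposed_hardswish(x: torch.Tensor) -> torch.Tensor:
        y = torch.clamp_min(x + 3.0, 0.0)  # first clip: lower bound only
        y = torch.clamp_max(y, 6.0)        # second clip: upper bound only
        return x * y / 6.0


    x = torch.randn(1, 3, 10, 10)
    assert torch.allclose(
        decomposed_hardswish(x), torch.nn.functional.hardswish(x), atol=1e-6
    )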