This is an automated email from the ASF dual-hosted git repository.
yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new bcbfc2933a [Relax][PyTorch] Support where, cumprod and reciprocal ops
for ExportedProgram importer (#17801)
bcbfc2933a is described below
commit bcbfc2933a66188dfab290dbdf90a3df158de022
Author: Shushi Hong <[email protected]>
AuthorDate: Thu Apr 3 06:07:51 2025 +0800
[Relax][PyTorch] Support where, cumprod and reciprocal ops for
ExportedProgram importer (#17801)
* Update exported_program_translator.py
* Update test_frontend_from_exported_program.py
* Update test_frontend_from_exported_program.py
---
.../frontend/torch/exported_program_translator.py | 7 +++
.../relax/test_frontend_from_exported_program.py | 68 ++++++++++++++++++++++
2 files changed, 75 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 97ccc6393c..62e98b88ed 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -60,6 +60,10 @@ class ExportedProgramImporter(BaseFXGraphImporter):
one = relax.const(1, x.struct_info.dtype)
return self.block_builder.emit(relax.op.log(relax.op.add(x, one)))
+ def _reciprocal(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ return self.block_builder.emit(relax.op.divide(relax.const(1.0, x.struct_info.dtype), x))
+
########## Neural Network ##########
def _batch_norm(self, node: fx.Node, training) -> relax.Var:
@@ -272,6 +276,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"log1p.default": self._log1p,
"log_softmax.int": self._log_softmax,
"neg.default": self._unary_op(relax.op.negative),
+ "reciprocal.default": self._reciprocal,
"relu.default": self._unary_op(relax.op.nn.relu),
"round.default": self._round,
"rsqrt.default": self._unary_op(relax.op.rsqrt),
@@ -361,6 +366,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
# search
"argmax.default": self._argmax_argmin(relax.op.argmax),
"argmin.default": self._argmax_argmin(relax.op.argmin),
+ "where.self": self._where,
# tensor manipulation
"cat.default": self._cat,
"chunk.default": self._chunk,
@@ -368,6 +374,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"concat.default": self._cat,
"copy_.default": self._copy_,
"cumsum.default": self._cumsum,
+ "cumprod.default": self._cumprod,
"expand.default": self._expand,
"expand_as.default": self._expand_as,
"flip.default": self._flip,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 98f0f1d9ca..e37ee0e404 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -477,6 +477,27 @@ def test_extended_unary_ops():
# log_softmax
test_logsoftmax()
+ # reciprocal
+ class Reciprocal(Module):
+ def forward(self, input):
+ return torch.reciprocal(input)
+
+ @tvm.script.ir_module
+ class expected_reciprocal:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
+ R.const(1.0, "float32"), input_1
+ )
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ verify_model(Reciprocal(), example_args, {}, expected_reciprocal)
+
# relu
class ReLU0(Module):
def __init__(self):
@@ -3818,5 +3839,52 @@ def test_prod():
verify_model(Prod(), example_args, {}, Expected)
+def test_cumprod():
+ class Cumprod(Module):
+ def forward(self, x):
+ return torch.cumprod(x, 0)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ inp_0: R.Tensor((5, 3), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((5, 3), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((5, 3), dtype="float32") = R.cumprod(inp_0, axis=0, exclusive=False)
+ gv: R.Tuple(R.Tensor((5, 3), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_input = torch.randn(5, 3, dtype=torch.float32)
+ verify_model(Cumprod(), (example_input,), {}, Expected)
+
+
+def test_where():
+ class Where(Module):
+ def forward(self, condition, x, y):
+ return torch.where(condition, x, y)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ inp_0: R.Tensor((5, 3), dtype="bool"),
+ inp_1: R.Tensor((5, 3), dtype="float32"),
+ inp_2: R.Tensor((5, 3), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((5, 3), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((5, 3), dtype="float32") = R.where(inp_0, inp_1, inp_2)
+ gv: R.Tuple(R.Tensor((5, 3), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ condition = torch.randint(0, 2, (5, 3), dtype=torch.bool)
+ x = torch.randn(5, 3, dtype=torch.float32)
+ y = torch.randn(5, 3, dtype=torch.float32)
+
+ verify_model(Where(), (condition, x, y), {}, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()