This is an automated email from the ASF dual-hosted git repository.

yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new bcbfc2933a [Relax][PyTorch] Support where, cumprod and reciprocal ops for ExportedProgram importer (#17801)
bcbfc2933a is described below

commit bcbfc2933a66188dfab290dbdf90a3df158de022
Author: Shushi Hong <[email protected]>
AuthorDate: Thu Apr 3 06:07:51 2025 +0800

    [Relax][PyTorch] Support where, cumprod and reciprocal ops for ExportedProgram importer (#17801)
    
    * Update exported_program_translator.py
    
    * Update test_frontend_from_exported_program.py
    
    * Update test_frontend_from_exported_program.py
---
 .../frontend/torch/exported_program_translator.py  |  7 +++
 .../relax/test_frontend_from_exported_program.py   | 68 ++++++++++++++++++++++
 2 files changed, 75 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 97ccc6393c..62e98b88ed 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -60,6 +60,10 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         one = relax.const(1, x.struct_info.dtype)
         return self.block_builder.emit(relax.op.log(relax.op.add(x, one)))
 
+    def _reciprocal(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        return self.block_builder.emit(relax.op.divide(relax.const(1.0, x.struct_info.dtype), x))
+
     ########## Neural Network ##########
 
     def _batch_norm(self, node: fx.Node, training) -> relax.Var:
@@ -272,6 +276,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "log1p.default": self._log1p,
             "log_softmax.int": self._log_softmax,
             "neg.default": self._unary_op(relax.op.negative),
+            "reciprocal.default": self._reciprocal,
             "relu.default": self._unary_op(relax.op.nn.relu),
             "round.default": self._round,
             "rsqrt.default": self._unary_op(relax.op.rsqrt),
@@ -361,6 +366,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             # search
             "argmax.default": self._argmax_argmin(relax.op.argmax),
             "argmin.default": self._argmax_argmin(relax.op.argmin),
+            "where.self": self._where,
             # tensor manipulation
             "cat.default": self._cat,
             "chunk.default": self._chunk,
@@ -368,6 +374,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "concat.default": self._cat,
             "copy_.default": self._copy_,
             "cumsum.default": self._cumsum,
+            "cumprod.default": self._cumprod,
             "expand.default": self._expand,
             "expand_as.default": self._expand_as,
             "flip.default": self._flip,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 98f0f1d9ca..e37ee0e404 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -477,6 +477,27 @@ def test_extended_unary_ops():
     # log_softmax
     test_logsoftmax()
 
+    # reciprocal
+    class Reciprocal(Module):
+        def forward(self, input):
+            return torch.reciprocal(input)
+
+    @tvm.script.ir_module
+    class expected_reciprocal:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
+                    R.const(1.0, "float32"), input_1
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Reciprocal(), example_args, {}, expected_reciprocal)
+
     # relu
     class ReLU0(Module):
         def __init__(self):
@@ -3818,5 +3839,52 @@ def test_prod():
     verify_model(Prod(), example_args, {}, Expected)
 
 
+def test_cumprod():
+    class Cumprod(Module):
+        def forward(self, x):
+            return torch.cumprod(x, 0)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((5, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((5, 3), dtype="float32") = R.cumprod(inp_0, axis=0, exclusive=False)
+                gv: R.Tuple(R.Tensor((5, 3), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_input = torch.randn(5, 3, dtype=torch.float32)
+    verify_model(Cumprod(), (example_input,), {}, Expected)
+
+
+def test_where():
+    class Where(Module):
+        def forward(self, condition, x, y):
+            return torch.where(condition, x, y)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="bool"),
+            inp_1: R.Tensor((5, 3), dtype="float32"),
+            inp_2: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((5, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((5, 3), dtype="float32") = R.where(inp_0, inp_1, inp_2)
+                gv: R.Tuple(R.Tensor((5, 3), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    condition = torch.randint(0, 2, (5, 3), dtype=torch.bool)
+    x = torch.randn(5, 3, dtype=torch.float32)
+    y = torch.randn(5, 3, dtype=torch.float32)
+
+    verify_model(Where(), (condition, x, y), {}, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to