This is an automated email from the ASF dual-hosted git repository.

yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3f16ec25ac [Relax][PyTorch] Add support for where, cumprod and reciprocal ops (#17788)
3f16ec25ac is described below

commit 3f16ec25aca5f479df60258e790949bf43e783bc
Author: Shushi Hong <[email protected]>
AuthorDate: Tue Apr 1 02:08:24 2025 +0800

    [Relax][PyTorch] Add support for where, cumprod and reciprocal ops (#17788)
    
    * Update fx_translator.py
    
    * Update base_fx_graph_translator.py
    
    * Update test_frontend_from_fx.py
    
    * Update test_frontend_from_fx.py
---
 .../frontend/torch/base_fx_graph_translator.py     | 17 ++++++
 python/tvm/relax/frontend/torch/fx_translator.py   |  7 +++
 tests/python/relax/test_frontend_from_fx.py        | 65 ++++++++++++++++++++++
 3 files changed, 89 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 71554a8a5b..fe0ae412a2 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -949,6 +949,37 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
 
         return convert
 
+    def _where(self, node: fx.Node) -> relax.Var:
+        """Convert ``torch.where(condition, x, y)`` to ``relax.op.where``.
+
+        FX records Python-scalar operands (e.g. ``torch.where(c, t, 0.0)``) as
+        plain values rather than nodes, so they cannot be looked up in
+        ``self.env``; they are materialised as Relax constants instead.
+        """
+
+        def _operand(arg, other):
+            # Traced tensors were already converted and live in self.env.
+            if isinstance(arg, fx.Node):
+                return self.env[arg]
+            if isinstance(other, fx.Node):
+                # Reuse the tensor operand's dtype so both branches agree.
+                # NOTE(review): torch type promotion may yield a float result for
+                # an int tensor paired with a float scalar; here the scalar is
+                # cast to the tensor's dtype instead.
+                dtype = self.env[other].struct_info.dtype
+            elif isinstance(arg, bool):  # bool first: bool is a subclass of int
+                dtype = "bool"
+            elif isinstance(arg, int):
+                dtype = "int64"
+            else:
+                dtype = "float32"
+            return relax.const(arg, dtype)
+
+        condition = self.env[node.args[0]]
+        x = _operand(node.args[1], node.args[2])
+        y = _operand(node.args[2], node.args[1])
+        return self.block_builder.emit(relax.op.where(condition, x, y))
+
     ########## Manipulation ##########
 
     def _cat(self, node: fx.Node) -> relax.Var:
@@ -967,6 +998,22 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
             relax.op.split(x=x, indices_or_sections=n_sections, axis=dim)
         )
 
+    def _cumprod(self, node: fx.Node) -> relax.Var:
+        """Convert ``torch.cumprod(input, dim, *, dtype=None)`` to ``relax.op.cumprod``."""
+        x = self.env[node.args[0]]
+
+        # ``dim`` may be passed positionally or as a keyword.
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        # An explicit ``dtype=None`` means the same as omitting it, so it must not
+        # be forwarded to ``_convert_data_type`` as the string "None".
+        torch_dtype = node.kwargs.get("dtype", None)
+        if torch_dtype is not None:
+            dtype = self._convert_data_type(str(torch_dtype), self.env)
+        else:
+            dtype = None
+
+        return self.block_builder.emit(relax.op.cumprod(x, dim, dtype))
+
     def _cumsum(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
 
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 022a7bffea..c4008a9396 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -62,6 +62,20 @@ class TorchFXImporter(BaseFXGraphImporter):
 
     ########## Unary Ops ##########
 
+    def _reciprocal(self, node: fx.Node) -> relax.Var:
+        """Convert ``torch.reciprocal(x)`` to ``1 / x``.
+
+        ``torch.reciprocal`` promotes integer and bool inputs to the default
+        float dtype (float32), so such inputs are cast to float32 first;
+        dividing in the integer dtype would truncate the result (1 / 2 -> 0).
+        """
+        x = self.env[node.args[0]]
+        dtype = x.struct_info.dtype
+        if "float" not in dtype:  # keeps float16/32/64 and bfloat16 as-is
+            dtype = "float32"
+            x = self.block_builder.emit(relax.op.astype(x, dtype))
+        return self.block_builder.emit(relax.op.divide(relax.const(1.0, dtype), x))
+
     def _leakyrelu_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
@@ -708,6 +722,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "logical_not": self._unary_op(relax.op.logical_not),
             "log_softmax": self._log_softmax,
             "neg": self._unary_op(relax.op.negative),
+            "reciprocal": self._reciprocal,
             "relu": self._unary_op(relax.op.nn.relu),
             "round": self._round,
             "rsqrt": self._unary_op(relax.op.rsqrt),
@@ -784,11 +799,13 @@ class TorchFXImporter(BaseFXGraphImporter):
             # search
             "argmax": self._argmax_argmin(relax.op.argmax),
             "argmin": self._argmax_argmin(relax.op.argmin),
+            "where": self._where,
             # tensor manipulation
             "cat": self._cat,
             "chunk": self._chunk,
             "concat": self._cat,
             "contiguous": lambda node: self.env[node.args[0]],
+            "cumprod": self._cumprod,
             "cumsum": self._cumsum,
             "expand": self._expand,
             "expand_as.default": self._expand_as,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 726ff6f8e8..b8d7f0b14e 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -2339,6 +2339,27 @@ def test_extended_unary_ops():
     verify_model(LogSoftmax(), input_info, {}, expected_log_softmax)
     verify_model(LogSoftmax2(), input_info, {}, expected_log_softmax)
 
+    # reciprocal
+    class Reciprocal(Module):
+        def forward(self, input):
+            return torch.reciprocal(input)
+
+    @tvm.script.ir_module
+    class expected_reciprocal:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
+                    R.const(1.0, "float32"), input_1
+                )
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Reciprocal(), input_info, {}, expected_reciprocal)
+
     # relu
     class ReLU0(Module):
         def __init__(self):
@@ -4315,5 +4336,49 @@ def test_prod():
     verify_model(Prod(), [([5, 3], "float32")], {}, Expected)
 
 
+def test_cumprod():
+    class Cumprod(Module):
+        def forward(self, x):
+            return torch.cumprod(x, 0)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((5, 3), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((5, 3), dtype="float32") = R.cumprod(inp_0, axis=0, exclusive=False)
+                gv: R.Tensor((5, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Cumprod(), [([5, 3], "float32")], {}, Expected)
+
+
+def test_where():
+    class Where(Module):
+        def forward(self, condition, x, y):
+            return torch.where(condition, x, y)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="bool"),
+            inp_1: R.Tensor((5, 3), dtype="float32"),
+            inp_2: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((5, 3), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((5, 3), dtype="float32") = R.where(inp_0, inp_1, inp_2)
+                gv: R.Tensor((5, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(
+        Where(), [([5, 3], "bool"), ([5, 3], "float32"), ([5, 3], "float32")], {}, Expected
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to