This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8ad9bffb11 [Relax][PyTorch] Add support for lerp, select and clone ops (#17760)
8ad9bffb11 is described below

commit 8ad9bffb1124c1553c1bc9c3e4ae5d8a84826afa
Author: Shushi Hong <[email protected]>
AuthorDate: Tue Mar 18 05:13:05 2025 +0800

    [Relax][PyTorch] Add support for lerp, select and clone ops (#17760)
    
    This PR adds support for the PyTorch `lerp`, `select` and `clone` ops in the Relax frontend.
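    
    The `lerp` lowering relies on the identity torch.lerp(start, end, weight)
    == start + weight * (end - start). A minimal PyTorch sketch, not part of
    the diff below and with shapes chosen only for illustration, checking the
    semantics behind all three mappings:
    
        import torch
        
        start = torch.zeros(5, 3)
        end = torch.ones(5, 3)
        weight = torch.full((5, 3), 0.25)
        
        # _lerp emits this decomposition via relax.op.add,
        # relax.op.multiply and relax.op.subtract.
        assert torch.allclose(torch.lerp(start, end, weight),
                              start + weight * (end - start))
        
        # torch.select(x, dim, i) drops the indexed dimension, which matches
        # R.take with a 0-d int64 index along that axis.
        x = torch.arange(15.0).reshape(5, 3)
        assert torch.equal(torch.select(x, 0, 1), x[1])
        
        # clone is imported as an identity mapping: the Relax function
        # simply returns its input binding.
        assert torch.equal(x.clone(), x)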
---
 python/tvm/relax/frontend/torch/fx_translator.py | 20 ++++++++
 tests/python/relax/test_frontend_from_fx.py      | 65 ++++++++++++++++++++++++
 2 files changed, 85 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 952fb6f971..9b835b0eee 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -376,6 +376,16 @@ class TorchFXImporter(BaseFXGraphImporter):
 
         return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)
 
+    ########## Linear Interpolation ##########
+
+    def _lerp(self, node: fx.Node) -> relax.Var:
+        start = self.env[node.args[0]]
+        end = self.env[node.args[1]]
+        weight = self.env[node.args[2]]
+        return self.block_builder.emit(
+            relax.op.add(start, relax.op.multiply(weight, relax.op.subtract(end, start)))
+        )
+
     ########## Manipulation ##########
 
     def _chunk(self, node: fx.Node) -> relax.Var:
@@ -414,6 +424,12 @@ class TorchFXImporter(BaseFXGraphImporter):
         shape = self.shape_of(x)
         return relax.const(reduce(lambda x, y: x * y, [s.value for s in shape]), "int32")
 
+    def _select(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dim = node.args[1]
+        index = relax.const(node.args[2], "int64")
+        return self.block_builder.emit(relax.op.take(x, index, dim))
+
     def _size(self, node: fx.Node) -> relax.Expr:
         x = self.env[node.args[0]]
         shape = self.shape_of(x)
@@ -737,6 +753,8 @@ class TorchFXImporter(BaseFXGraphImporter):
             "scaled_dot_product_attention": self._scaled_dot_product_attention,
             "stochastic_depth": lambda node: self.env[node.args[0]],
             "unbind": self._unbind,
+            # linear interpolation
+            "lerp": self._lerp,
             # statistical
             "mean": self._mean,
             "sum": self._sum,
@@ -759,6 +777,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "repeat": self._repeat,
             "reshape": self._reshape,
             "scatter": self._scatter,
+            "select": self._select,
             "size": self._size,
             "split": self._split,
             "squeeze": self._squeeze,
@@ -772,6 +791,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "view": self._reshape,
             # tensor creation
             "arange": self._arange,
+            "clone": lambda node: self.env[node.args[0]],
             "empty": self._empty,
             "empty_like": self._empty_like,
             "fill_": self._inplace_fill,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index fbea8b7388..f06ce7a753 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4125,5 +4125,70 @@ def test_numel():
     verify_model(Numel(), [([5, 3], "float32")], {}, Expected)
 
 
+def test_select():
+    class Select(Module):
+        def forward(self, data):
+            return torch.select(data, 0, 1)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((3,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((3,), dtype="float32") = R.take(inp_0, R.const(1, "int64"), axis=0)
+                gv: R.Tensor((3,), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Select(), [([5, 3], "float32")], {}, Expected)
+
+
+def test_clone():
+    class Clone(Module):
+        def forward(self, x):
+            return x.clone()
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((5, 3), dtype="float32"):
+            with R.dataflow():
+                gv: R.Tensor((5, 3), dtype="float32") = inp_0
+                R.output(gv)
+            return gv
+
+    verify_model(Clone(), [([5, 3], "float32")], {}, Expected)
+
+
+def test_lerp():
+    class Lerp(Module):
+        def forward(self, start, end, weight):
+            return torch.lerp(start, end, weight)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+            inp_1: R.Tensor((5, 3), dtype="float32"),
+            inp_2: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((5, 3), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((5, 3), dtype="float32") = R.add(
+                    inp_0, R.multiply(inp_2, R.subtract(inp_1, inp_0))
+                )
+                gv: R.Tensor((5, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(
+        Lerp(), [([5, 3], "float32"), ([5, 3], "float32"), ([5, 3], "float32")], {}, Expected
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
