This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b3d3a7aef0 [Relax][PyTorch] Add copy_ op support in fxGraph (#17858)
b3d3a7aef0 is described below

commit b3d3a7aef0fe75292e889c1fa705e3e06709e1e5
Author: Krishna <[email protected]>
AuthorDate: Mon Apr 21 17:16:12 2025 +0530

    [Relax][PyTorch] Add copy_ op support in fxGraph (#17858)
---
 python/tvm/relax/frontend/torch/fx_translator.py |  6 ++++++
 tests/python/relax/test_frontend_from_fx.py      | 26 ++++++++++++++++++++++++
 2 files changed, 32 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 5a34befb92..ed42f995bb 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -476,6 +476,11 @@ class TorchFXImporter(BaseFXGraphImporter):
         self.env[node.args[0]] = filled
         return filled
 
+    def _inplace_copy(self, node: fx.Node) -> relax.Var:
+        src = self.env[node.args[1]]
+        self.env[node.args[0]] = src
+        return src
+
     def _masked_scatter(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         mask = self.env[node.args[1]]
@@ -782,6 +787,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "ones": self._ones,
             "one_hot": self._one_hot,
             "tensor": self._tensor,
+            "copy_": self._inplace_copy,
             # datatype
             "astype": self._type,
             "float": self._float,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index c522556380..2498fec35c 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4543,6 +4543,32 @@ def test_select():
     verify_model(Select(), [([5, 3], "float32")], {}, Expected)
 
 
+def test_inplace_copy():
+    class Inplace_Copy(Module):
+        def forward(self, x, y):
+            x.copy_(y)
+            return x
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 2, 3, 4), dtype="float32"),
+            inp_1: R.Tensor((1, 2, 3, 4), dtype="float32"),
+        ) -> R.Tensor((1, 2, 3, 4), dtype="float32"):
+            with R.dataflow():
+                gv: R.Tensor((1, 2, 3, 4), dtype="float32") = inp_1
+                R.output(gv)
+            return gv
+
+    verify_model(
+        Inplace_Copy(),
+        [((1, 2, 3, 4), "float32"), ((1, 2, 3, 4), "float32")],
+        {},
+        Expected,
+    )
+
+
 def test_clone():
     class Clone(Module):
         def forward(self, x):

Reply via email to