This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8660e408cf [Relax][PyTorch] Add broadcast support for `copy` operation (#18493)
8660e408cf is described below

commit 8660e408cffd15a3a5230ec9fdacaa757a4c9d66
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Mon Nov 24 03:17:55 2025 +0900

    [Relax][PyTorch] Add broadcast support for `copy` operation (#18493)
    
    As per title.
    ref: [torch.Tensor.copy_ — PyTorch 2.9 documentation](https://docs.pytorch.org/docs/stable/generated/torch.Tensor.copy_.html)
---
 .../frontend/torch/base_fx_graph_translator.py     |  18 ++-
 python/tvm/relax/frontend/torch/fx_translator.py   |  10 ++
 .../relax/test_frontend_from_exported_program.py   | 125 ++++++++++++++++-----
 tests/python/relax/test_frontend_from_fx.py        |  24 +++-
 4 files changed, 143 insertions(+), 34 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 4165086808..9c2e45c8fd 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -2026,9 +2026,21 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         return self.env[node.args[0]]
 
     def _copy_(self, node: fx.Node) -> relax.Var:
-        # Copies the source tensor's into the destination tensor
-        # In TVM, that means simply returning the source tensor
-        return self.env[node.args[1]]
+        dest = self.env[node.args[0]]
+        src = self.env[node.args[1]]
+
+        # Match PyTorch semantics: cast to destination dtype and broadcast to destination shape.
+        if src.struct_info.dtype != dest.struct_info.dtype:
+            src = self.block_builder.emit(relax.op.astype(src, dest.struct_info.dtype))
+
+        dest_shape = self.shape_of(dest)
+        src_shape = self.shape_of(src)
+        if dest_shape != src_shape:
+            src = self.block_builder.emit(relax.op.broadcast_to(src, dest_shape))
+
+        # copy_ writes into the destination tensor, so update env accordingly
+        self.env[node.args[0]] = src
+        return src
 
     def _to_copy(self, node: fx.Node) -> relax.Var:
         # Returns a copy of the input tensor
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 6bf164430a..9c2d53a685 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -652,7 +652,17 @@ class TorchFXImporter(BaseFXGraphImporter):
     ########## Creation ##########
 
     def _inplace_copy(self, node: fx.Node) -> relax.Var:
+        dest = self.env[node.args[0]]
         src = self.env[node.args[1]]
+
+        if src.struct_info.dtype != dest.struct_info.dtype:
+            src = self.block_builder.emit(relax.op.astype(src, dest.struct_info.dtype))
+
+        dest_shape = self.shape_of(dest)
+        src_shape = self.shape_of(src)
+        if dest_shape != src_shape:
+            src = self.block_builder.emit(relax.op.broadcast_to(src, dest_shape))
+
         self.env[node.args[0]] = src
         return src
 
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 341bafc267..4c5d71216c 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2916,6 +2916,7 @@ def test_pad():
                 lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.zeros(
                     R.shape([1, 3, 14, 12]), dtype="float32"
                 )
+
                 lv1: R.Tensor((1, 3, 14, 10), dtype="float32") = R.strided_slice(
                     lv,
                     (R.prim_value(3),),
@@ -2924,6 +2925,7 @@ def test_pad():
                     (R.prim_value(1),),
                     assume_inbound=False,
                 )
+
                 lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice(
                     x,
                     (R.prim_value(3),),
@@ -2932,6 +2934,7 @@ def test_pad():
                     (R.prim_value(1),),
                     assume_inbound=False,
                 )
+
                 lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice(
                     lv1,
                     (R.prim_value(2),),
@@ -2940,6 +2943,7 @@ def test_pad():
                     (R.prim_value(1),),
                     assume_inbound=False,
                 )
+
                 lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice(
                     lv2,
                     (R.prim_value(2),),
@@ -2948,7 +2952,12 @@ def test_pad():
                     (R.prim_value(1),),
                     assume_inbound=False,
                 )
-                lv5: R.Tensor((1, 3, 14, 10), dtype="float32") = R.strided_slice(
+
+                lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = R.broadcast_to(
+                    lv4, R.shape([1, 3, 10, 10])
+                )
+
+                lv6: R.Tensor((1, 3, 14, 10), dtype="float32") = R.strided_slice(
                     lv,
                     (R.prim_value(3),),
                     (R.prim_value(1),),
@@ -2956,89 +2965,117 @@ def test_pad():
                     (R.prim_value(1),),
                     assume_inbound=False,
                 )
-                lv6: R.Tensor((1, 3, 14, 10), dtype="float32") = R.slice_scatter(
-                    lv5, lv4, R.prim_value(2), R.prim_value(12), R.prim_value(1), axis=2
+
+                lv7: R.Tensor((1, 3, 14, 10), dtype="float32") = R.slice_scatter(
+                    lv6, lv5, R.prim_value(2), R.prim_value(12), R.prim_value(1), axis=2
                 )
-                lv7: R.Tensor((1, 3, 14, 12), dtype="float32") = R.slice_scatter(
-                    lv, lv6, R.prim_value(1), R.prim_value(11), R.prim_value(1), axis=3
+
+                lv8: R.Tensor((1, 3, 14, 12), dtype="float32") = R.slice_scatter(
+                    lv, lv7, R.prim_value(1), R.prim_value(11), R.prim_value(1), axis=3
                 )
-                lv8: R.Tensor((1, 3, 14, 1), dtype="float32") = 
R.strided_slice(
-                    lv7,
+
+                lv9: R.Tensor((1, 3, 14, 1), dtype="float32") = 
R.strided_slice(
+                    lv8,
                     (R.prim_value(3),),
                     (R.prim_value(0),),
                     (R.prim_value(1),),
                     (R.prim_value(1),),
                     assume_inbound=False,
                 )
-                lv9: R.Tensor((1, 3, 14, 1), dtype="float32") = 
R.strided_slice(
-                    lv7,
+
+                lv10: R.Tensor((1, 3, 14, 1), dtype="float32") = 
R.strided_slice(
+                    lv8,
                     (R.prim_value(3),),
                     (R.prim_value(10),),
                     (R.prim_value(11),),
                     (R.prim_value(1),),
                     assume_inbound=False,
                 )
-                lv10: R.Tensor((1, 3, 14, 12), dtype="float32") = 
R.slice_scatter(
-                    lv7, lv9, R.prim_value(0), R.prim_value(1), 
R.prim_value(1), axis=3
+
+                lv11: R.Tensor((1, 3, 14, 1), dtype="float32") = 
R.broadcast_to(
+                    lv10, R.shape([1, 3, 14, 1])
+                )
+
+                lv12: R.Tensor((1, 3, 14, 12), dtype="float32") = 
R.slice_scatter(
+                    lv8, lv11, R.prim_value(0), R.prim_value(1), 
R.prim_value(1), axis=3
                 )
-                lv11: R.Tensor((1, 3, 14, 1), dtype="float32") = 
R.strided_slice(
-                    lv10,
+
+                lv13: R.Tensor((1, 3, 14, 1), dtype="float32") = 
R.strided_slice(
+                    lv12,
                     (R.prim_value(3),),
                     (R.prim_value(11),),
                     (R.prim_value(12),),
                     (R.prim_value(1),),
                     assume_inbound=False,
                 )
-                lv12: R.Tensor((1, 3, 14, 1), dtype="float32") = 
R.strided_slice(
-                    lv10,
+
+                lv14: R.Tensor((1, 3, 14, 1), dtype="float32") = 
R.strided_slice(
+                    lv12,
                     (R.prim_value(3),),
                     (R.prim_value(1),),
                     (R.prim_value(2),),
                     (R.prim_value(1),),
                     assume_inbound=False,
                 )
-                lv13: R.Tensor((1, 3, 14, 12), dtype="float32") = 
R.slice_scatter(
-                    lv10, lv12, R.prim_value(11), R.prim_value(12), 
R.prim_value(1), axis=3
+
+                lv15: R.Tensor((1, 3, 14, 1), dtype="float32") = 
R.broadcast_to(
+                    lv14, R.shape([1, 3, 14, 1])
+                )
+                lv16: R.Tensor((1, 3, 14, 12), dtype="float32") = 
R.slice_scatter(
+                    lv12, lv15, R.prim_value(11), R.prim_value(12), 
R.prim_value(1), axis=3
                 )
-                lv14: R.Tensor((1, 3, 2, 12), dtype="float32") = 
R.strided_slice(
-                    lv13,
+
+                lv17: R.Tensor((1, 3, 2, 12), dtype="float32") = 
R.strided_slice(
+                    lv16,
                     (R.prim_value(2),),
                     (R.prim_value(0),),
                     (R.prim_value(2),),
                     (R.prim_value(1),),
                     assume_inbound=False,
                 )
-                lv15: R.Tensor((1, 3, 2, 12), dtype="float32") = 
R.strided_slice(
-                    lv13,
+
+                lv18: R.Tensor((1, 3, 2, 12), dtype="float32") = 
R.strided_slice(
+                    lv16,
                     (R.prim_value(2),),
                     (R.prim_value(10),),
                     (R.prim_value(12),),
                     (R.prim_value(1),),
                     assume_inbound=False,
                 )
-                lv16: R.Tensor((1, 3, 14, 12), dtype="float32") = 
R.slice_scatter(
-                    lv13, lv15, R.prim_value(0), R.prim_value(2), 
R.prim_value(1), axis=2
+
+                lv19: R.Tensor((1, 3, 2, 12), dtype="float32") = 
R.broadcast_to(
+                    lv18, R.shape([1, 3, 2, 12])
                 )
-                lv17: R.Tensor((1, 3, 2, 12), dtype="float32") = 
R.strided_slice(
-                    lv16,
+
+                lv20: R.Tensor((1, 3, 14, 12), dtype="float32") = 
R.slice_scatter(
+                    lv16, lv19, R.prim_value(0), R.prim_value(2), 
R.prim_value(1), axis=2
+                )
+                lv21: R.Tensor((1, 3, 2, 12), dtype="float32") = 
R.strided_slice(
+                    lv20,
                     (R.prim_value(2),),
                     (R.prim_value(12),),
                     (R.prim_value(14),),
                     (R.prim_value(1),),
                     assume_inbound=False,
                 )
-                lv18: R.Tensor((1, 3, 2, 12), dtype="float32") = 
R.strided_slice(
-                    lv16,
+
+                lv22: R.Tensor((1, 3, 2, 12), dtype="float32") = 
R.strided_slice(
+                    lv20,
                     (R.prim_value(2),),
                     (R.prim_value(2),),
                     (R.prim_value(4),),
                     (R.prim_value(1),),
                     assume_inbound=False,
                 )
-                lv19: R.Tensor((1, 3, 14, 12), dtype="float32") = 
R.slice_scatter(
-                    lv16, lv18, R.prim_value(12), R.prim_value(14), 
R.prim_value(1), axis=2
+
+                lv23: R.Tensor((1, 3, 2, 12), dtype="float32") = 
R.broadcast_to(
+                    lv22, R.shape([1, 3, 2, 12])
+                )
+
+                lv24: R.Tensor((1, 3, 14, 12), dtype="float32") = 
R.slice_scatter(
+                    lv20, lv23, R.prim_value(12), R.prim_value(14), 
R.prim_value(1), axis=2
                 )
-                gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = 
(lv19,)
+                gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = 
(lv24,)
                 R.output(gv)
             return gv
 
@@ -5945,6 +5982,34 @@ def test_new_zeros():
     verify_model(NewZeros(), example_args, {}, expected1)
 
 
+def test_copy():
+    class CopyBroadcast(Module):
+        def forward(self, x, src):
+            x.copy_(src)
+            return x
+
+    @tvm.script.ir_module
+    class expected_copy:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3), dtype="float32"), src: R.Tensor((), dtype="int64")
+        ) -> R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.astype(src, dtype="float32")
+                lv1: R.Tensor((2, 3), dtype="float32") = R.broadcast_to(lv, (2, 3))
+                gv: R.Tuple(
+                    R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")
+                ) = (
+                    lv1,
+                    lv1,
+                )
+                R.output(gv)
+            return gv
+
+    example_args = (torch.zeros(2, 3, dtype=torch.float32), torch.tensor(1, dtype=torch.int64))
+    verify_model(CopyBroadcast(), example_args, {}, expected_copy)
+
+
 def test_to_copy():
     # float
     class ToFloat(Module):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 031a855fb9..7f0905088c 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -5737,7 +5737,28 @@ def test_inplace_copy():
             inp_1: R.Tensor((1, 2, 3, 4), dtype="float32"),
         ) -> R.Tensor((1, 2, 3, 4), dtype="float32"):
             with R.dataflow():
-                gv: R.Tensor((1, 2, 3, 4), dtype="float32") = inp_1
+                lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.broadcast_to(
+                    inp_1, R.shape([1, 2, 3, 4])
+                )
+                gv: R.Tensor((1, 2, 3, 4), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    class CopyBroadcast(Module):
+        def forward(self, x, src):
+            x.copy_(src)
+            return x
+
+    @tvm.script.ir_module
+    class expected_copy:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3), dtype="float32"), src: R.Tensor((), dtype="int64")
+        ) -> R.Tensor((2, 3), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.astype(src, dtype="float32")
+                lv1: R.Tensor((2, 3), dtype="float32") = R.broadcast_to(lv, (2, 3))
+                gv: R.Tensor((2, 3), dtype="float32") = lv1
                 R.output(gv)
             return gv
 
@@ -5747,6 +5768,7 @@ def test_inplace_copy():
         {},
         Expected,
     )
+    verify_model(CopyBroadcast(), [((2, 3), "float32"), ((), "int64")], {}, expected_copy)
 
 
 def test_clone():

Reply via email to