This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8660e408cf [Relax][PyTorch] Add broadcast support for `copy` operation
(#18493)
8660e408cf is described below
commit 8660e408cffd15a3a5230ec9fdacaa757a4c9d66
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Mon Nov 24 03:17:55 2025 +0900
[Relax][PyTorch] Add broadcast support for `copy` operation (#18493)
As per title.
ref: [torch.Tensor.copy_ — PyTorch 2.9 documentation](https://docs.pytorch.org/docs/stable/generated/torch.Tensor.copy_.html)
---
.../frontend/torch/base_fx_graph_translator.py | 18 ++-
python/tvm/relax/frontend/torch/fx_translator.py | 10 ++
.../relax/test_frontend_from_exported_program.py | 125 ++++++++++++++++-----
tests/python/relax/test_frontend_from_fx.py | 24 +++-
4 files changed, 143 insertions(+), 34 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 4165086808..9c2e45c8fd 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -2026,9 +2026,21 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
return self.env[node.args[0]]
def _copy_(self, node: fx.Node) -> relax.Var:
- # Copies the source tensor's into the destination tensor
- # In TVM, that means simply returning the source tensor
- return self.env[node.args[1]]
+ dest = self.env[node.args[0]]
+ src = self.env[node.args[1]]
+
+ # Match PyTorch semantics: cast to destination dtype and broadcast to destination shape.
+ if src.struct_info.dtype != dest.struct_info.dtype:
+ src = self.block_builder.emit(relax.op.astype(src,
dest.struct_info.dtype))
+
+ dest_shape = self.shape_of(dest)
+ src_shape = self.shape_of(src)
+ if dest_shape != src_shape:
+ src = self.block_builder.emit(relax.op.broadcast_to(src,
dest_shape))
+
+ # copy_ writes into the destination tensor, so update env accordingly
+ self.env[node.args[0]] = src
+ return src
def _to_copy(self, node: fx.Node) -> relax.Var:
# Returns a copy of the input tensor
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index 6bf164430a..9c2d53a685 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -652,7 +652,17 @@ class TorchFXImporter(BaseFXGraphImporter):
########## Creation ##########
def _inplace_copy(self, node: fx.Node) -> relax.Var:
+ dest = self.env[node.args[0]]
src = self.env[node.args[1]]
+
+ if src.struct_info.dtype != dest.struct_info.dtype:
+ src = self.block_builder.emit(relax.op.astype(src,
dest.struct_info.dtype))
+
+ dest_shape = self.shape_of(dest)
+ src_shape = self.shape_of(src)
+ if dest_shape != src_shape:
+ src = self.block_builder.emit(relax.op.broadcast_to(src,
dest_shape))
+
self.env[node.args[0]] = src
return src
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 341bafc267..4c5d71216c 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2916,6 +2916,7 @@ def test_pad():
lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.zeros(
R.shape([1, 3, 14, 12]), dtype="float32"
)
+
lv1: R.Tensor((1, 3, 14, 10), dtype="float32") =
R.strided_slice(
lv,
(R.prim_value(3),),
@@ -2924,6 +2925,7 @@ def test_pad():
(R.prim_value(1),),
assume_inbound=False,
)
+
lv2: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.strided_slice(
x,
(R.prim_value(3),),
@@ -2932,6 +2934,7 @@ def test_pad():
(R.prim_value(1),),
assume_inbound=False,
)
+
lv3: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.strided_slice(
lv1,
(R.prim_value(2),),
@@ -2940,6 +2943,7 @@ def test_pad():
(R.prim_value(1),),
assume_inbound=False,
)
+
lv4: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.strided_slice(
lv2,
(R.prim_value(2),),
@@ -2948,7 +2952,12 @@ def test_pad():
(R.prim_value(1),),
assume_inbound=False,
)
- lv5: R.Tensor((1, 3, 14, 10), dtype="float32") =
R.strided_slice(
+
+ lv5: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.broadcast_to(
+ lv4, R.shape([1, 3, 10, 10])
+ )
+
+ lv6: R.Tensor((1, 3, 14, 10), dtype="float32") =
R.strided_slice(
lv,
(R.prim_value(3),),
(R.prim_value(1),),
@@ -2956,89 +2965,117 @@ def test_pad():
(R.prim_value(1),),
assume_inbound=False,
)
- lv6: R.Tensor((1, 3, 14, 10), dtype="float32") =
R.slice_scatter(
- lv5, lv4, R.prim_value(2), R.prim_value(12),
R.prim_value(1), axis=2
+
+ lv7: R.Tensor((1, 3, 14, 10), dtype="float32") =
R.slice_scatter(
+ lv6, lv5, R.prim_value(2), R.prim_value(12),
R.prim_value(1), axis=2
)
- lv7: R.Tensor((1, 3, 14, 12), dtype="float32") =
R.slice_scatter(
- lv, lv6, R.prim_value(1), R.prim_value(11),
R.prim_value(1), axis=3
+
+ lv8: R.Tensor((1, 3, 14, 12), dtype="float32") =
R.slice_scatter(
+ lv, lv7, R.prim_value(1), R.prim_value(11),
R.prim_value(1), axis=3
)
- lv8: R.Tensor((1, 3, 14, 1), dtype="float32") =
R.strided_slice(
- lv7,
+
+ lv9: R.Tensor((1, 3, 14, 1), dtype="float32") =
R.strided_slice(
+ lv8,
(R.prim_value(3),),
(R.prim_value(0),),
(R.prim_value(1),),
(R.prim_value(1),),
assume_inbound=False,
)
- lv9: R.Tensor((1, 3, 14, 1), dtype="float32") =
R.strided_slice(
- lv7,
+
+ lv10: R.Tensor((1, 3, 14, 1), dtype="float32") =
R.strided_slice(
+ lv8,
(R.prim_value(3),),
(R.prim_value(10),),
(R.prim_value(11),),
(R.prim_value(1),),
assume_inbound=False,
)
- lv10: R.Tensor((1, 3, 14, 12), dtype="float32") =
R.slice_scatter(
- lv7, lv9, R.prim_value(0), R.prim_value(1),
R.prim_value(1), axis=3
+
+ lv11: R.Tensor((1, 3, 14, 1), dtype="float32") =
R.broadcast_to(
+ lv10, R.shape([1, 3, 14, 1])
+ )
+
+ lv12: R.Tensor((1, 3, 14, 12), dtype="float32") =
R.slice_scatter(
+ lv8, lv11, R.prim_value(0), R.prim_value(1),
R.prim_value(1), axis=3
)
- lv11: R.Tensor((1, 3, 14, 1), dtype="float32") =
R.strided_slice(
- lv10,
+
+ lv13: R.Tensor((1, 3, 14, 1), dtype="float32") =
R.strided_slice(
+ lv12,
(R.prim_value(3),),
(R.prim_value(11),),
(R.prim_value(12),),
(R.prim_value(1),),
assume_inbound=False,
)
- lv12: R.Tensor((1, 3, 14, 1), dtype="float32") =
R.strided_slice(
- lv10,
+
+ lv14: R.Tensor((1, 3, 14, 1), dtype="float32") =
R.strided_slice(
+ lv12,
(R.prim_value(3),),
(R.prim_value(1),),
(R.prim_value(2),),
(R.prim_value(1),),
assume_inbound=False,
)
- lv13: R.Tensor((1, 3, 14, 12), dtype="float32") =
R.slice_scatter(
- lv10, lv12, R.prim_value(11), R.prim_value(12),
R.prim_value(1), axis=3
+
+ lv15: R.Tensor((1, 3, 14, 1), dtype="float32") =
R.broadcast_to(
+ lv14, R.shape([1, 3, 14, 1])
+ )
+ lv16: R.Tensor((1, 3, 14, 12), dtype="float32") =
R.slice_scatter(
+ lv12, lv15, R.prim_value(11), R.prim_value(12),
R.prim_value(1), axis=3
)
- lv14: R.Tensor((1, 3, 2, 12), dtype="float32") =
R.strided_slice(
- lv13,
+
+ lv17: R.Tensor((1, 3, 2, 12), dtype="float32") =
R.strided_slice(
+ lv16,
(R.prim_value(2),),
(R.prim_value(0),),
(R.prim_value(2),),
(R.prim_value(1),),
assume_inbound=False,
)
- lv15: R.Tensor((1, 3, 2, 12), dtype="float32") =
R.strided_slice(
- lv13,
+
+ lv18: R.Tensor((1, 3, 2, 12), dtype="float32") =
R.strided_slice(
+ lv16,
(R.prim_value(2),),
(R.prim_value(10),),
(R.prim_value(12),),
(R.prim_value(1),),
assume_inbound=False,
)
- lv16: R.Tensor((1, 3, 14, 12), dtype="float32") =
R.slice_scatter(
- lv13, lv15, R.prim_value(0), R.prim_value(2),
R.prim_value(1), axis=2
+
+ lv19: R.Tensor((1, 3, 2, 12), dtype="float32") =
R.broadcast_to(
+ lv18, R.shape([1, 3, 2, 12])
)
- lv17: R.Tensor((1, 3, 2, 12), dtype="float32") =
R.strided_slice(
- lv16,
+
+ lv20: R.Tensor((1, 3, 14, 12), dtype="float32") =
R.slice_scatter(
+ lv16, lv19, R.prim_value(0), R.prim_value(2),
R.prim_value(1), axis=2
+ )
+ lv21: R.Tensor((1, 3, 2, 12), dtype="float32") =
R.strided_slice(
+ lv20,
(R.prim_value(2),),
(R.prim_value(12),),
(R.prim_value(14),),
(R.prim_value(1),),
assume_inbound=False,
)
- lv18: R.Tensor((1, 3, 2, 12), dtype="float32") =
R.strided_slice(
- lv16,
+
+ lv22: R.Tensor((1, 3, 2, 12), dtype="float32") =
R.strided_slice(
+ lv20,
(R.prim_value(2),),
(R.prim_value(2),),
(R.prim_value(4),),
(R.prim_value(1),),
assume_inbound=False,
)
- lv19: R.Tensor((1, 3, 14, 12), dtype="float32") =
R.slice_scatter(
- lv16, lv18, R.prim_value(12), R.prim_value(14),
R.prim_value(1), axis=2
+
+ lv23: R.Tensor((1, 3, 2, 12), dtype="float32") =
R.broadcast_to(
+ lv22, R.shape([1, 3, 2, 12])
+ )
+
+ lv24: R.Tensor((1, 3, 14, 12), dtype="float32") =
R.slice_scatter(
+ lv20, lv23, R.prim_value(12), R.prim_value(14),
R.prim_value(1), axis=2
)
- gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) =
(lv19,)
+ gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) =
(lv24,)
R.output(gv)
return gv
@@ -5945,6 +5982,34 @@ def test_new_zeros():
verify_model(NewZeros(), example_args, {}, expected1)
+def test_copy():
+ class CopyBroadcast(Module):
+ def forward(self, x, src):
+ x.copy_(src)
+ return x
+
+ @tvm.script.ir_module
+ class expected_copy:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3), dtype="float32"), src: R.Tensor((),
dtype="int64")
+ ) -> R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3),
dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((), dtype="float32") = R.astype(src,
dtype="float32")
+ lv1: R.Tensor((2, 3), dtype="float32") = R.broadcast_to(lv,
(2, 3))
+ gv: R.Tuple(
+ R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3),
dtype="float32")
+ ) = (
+ lv1,
+ lv1,
+ )
+ R.output(gv)
+ return gv
+
+ example_args = (torch.zeros(2, 3, dtype=torch.float32), torch.tensor(1,
dtype=torch.int64))
+ verify_model(CopyBroadcast(), example_args, {}, expected_copy)
+
+
def test_to_copy():
# float
class ToFloat(Module):
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index 031a855fb9..7f0905088c 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -5737,7 +5737,28 @@ def test_inplace_copy():
inp_1: R.Tensor((1, 2, 3, 4), dtype="float32"),
) -> R.Tensor((1, 2, 3, 4), dtype="float32"):
with R.dataflow():
- gv: R.Tensor((1, 2, 3, 4), dtype="float32") = inp_1
+ lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.broadcast_to(
+ inp_1, R.shape([1, 2, 3, 4])
+ )
+ gv: R.Tensor((1, 2, 3, 4), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ class CopyBroadcast(Module):
+ def forward(self, x, src):
+ x.copy_(src)
+ return x
+
+ @tvm.script.ir_module
+ class expected_copy:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3), dtype="float32"), src: R.Tensor((),
dtype="int64")
+ ) -> R.Tensor((2, 3), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((), dtype="float32") = R.astype(src,
dtype="float32")
+ lv1: R.Tensor((2, 3), dtype="float32") = R.broadcast_to(lv,
(2, 3))
+ gv: R.Tensor((2, 3), dtype="float32") = lv1
R.output(gv)
return gv
@@ -5747,6 +5768,7 @@ def test_inplace_copy():
{},
Expected,
)
+ verify_model(CopyBroadcast(), [((2, 3), "float32"), ((), "int64")], {},
expected_copy)
def test_clone():