This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ed477b0acb Add op support for zeros_like and fill_ (#17896)
ed477b0acb is described below
commit ed477b0acb8b3f0266934ccb88f2bd8b562acb22
Author: Pratheesh-04-MCW <[email protected]>
AuthorDate: Mon Apr 28 11:50:47 2025 +0530
Add op support for zeros_like and fill_ (#17896)
* add op support for zeros_like and fill_
* fixing whitespace issues
* unity issue
* solved datatype issue
* unity issue
* lint error
---
.../frontend/torch/base_fx_graph_translator.py | 13 +++++++
.../frontend/torch/exported_program_translator.py | 2 +
python/tvm/relax/frontend/torch/fx_translator.py | 11 +-----
.../relax/test_frontend_from_exported_program.py | 45 ++++++++++++++++++++++
tests/python/relax/test_frontend_from_fx.py | 20 ++++++++++
5 files changed, 82 insertions(+), 9 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index e5a1ba5e99..ff5e51da0b 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1457,6 +1457,15 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
         return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))
 
+    def _inplace_fill(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dtype = x.struct_info.dtype
+        value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
+        filled = self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))
+        self.env[node.args[0]] = filled
+        return filled
+
     def _full(self, node: fx.Node) -> relax.Var:
         import torch
 
@@ -1670,6 +1679,10 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         self.env[node.args[0]] = output
         return output
 
+    def _zeros_like(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        return self.block_builder.emit(relax.op.zeros_like(x))
+
     @abc.abstractmethod
     def create_convert_map(
         self,
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index ce325f5fb8..86f5de5f36 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -474,6 +474,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "eye.default": self._eye,
             "eye.m": self._eye,
             "fill.Scalar": self._fill,
+            "fill_.Scalar": self._inplace_fill,
             "full.default": self._full,
             "full_like.default": self._full_like,
             "index_select.default": self._index_select,
@@ -488,6 +489,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             ),
             "zero_.default": self._zeros_inplace,
             "zeros.default": self._zeros,
+            "zeros_like.default": self._zeros_like,
             # datatype
             "to.dtype": self._to,
             "to.dtype_layout": self._to,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index f99c508aa2..07b3df20aa 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -515,15 +515,6 @@ class TorchFXImporter(BaseFXGraphImporter):
 
     ########## Creation ##########
 
-    def _inplace_fill(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        x = args[0]
-        dtype = x.struct_info.dtype
-        value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
-        filled = self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))
-        self.env[node.args[0]] = filled
-        return filled
-
     def _inplace_copy(self, node: fx.Node) -> relax.Var:
         src = self.env[node.args[1]]
         self.env[node.args[0]] = src
@@ -830,6 +821,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "clone": lambda node: self.env[node.args[0]],
             "empty": self._empty,
             "empty_like": self._empty_like,
+            "fill": self._fill,
             "fill_": self._inplace_fill,
             "full": self._full,
             "index_select": self._index_select,
@@ -844,6 +836,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             ),
             "tensor": self._tensor,
             "zero_": self._zeros_inplace,
+            "zeros_like": self._zeros_like,
             "copy_": self._inplace_copy,
             # datatype
             "astype": self._type,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 651bcc1a28..dd1869a23c 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3679,6 +3679,30 @@ def test_fill():
     verify_model(Fill(), example_args, {}, Expected)
 
 
+def test_fill_inplace():
+    class FillInplace(Module):
+        def forward(self, input: torch.Tensor):
+            input.fill_(42.0)
+            return input
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3), dtype="float32")
+        ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 3), dtype="float32") = R.full(
+                    R.shape([2, 3]), R.const(42.0, "float32"), dtype="float32"
+                )
+                gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(2, 3, dtype=torch.float32),)
+    verify_model(FillInplace(), example_args, {}, Expected)
+
+
 def test_masked_fill():
     class Masked_Fill(Module):
         def forward(self, input: torch.Tensor, mask: torch.Tensor):
@@ -4046,6 +4070,27 @@ def test_zeros():
     verify_model(Zeros(), example_args, {}, Expected)
 
 
+def test_zeros_like():
+    class ZerosLike(Module):
+        def forward(self, input):
+            return torch.zeros_like(input)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            input: R.Tensor((128, 128), dtype="float32")
+        ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((128, 128), dtype="float32") = R.zeros_like(input, dtype="void")
+                gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.rand(128, 128, dtype=torch.float32),)
+    verify_model(ZerosLike(), example_args, {}, Expected)
+
+
 def test_type_as():
     class TypeAs(Module):
         def forward(self, input, other):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 52a0a9e7a7..f60f158cbf 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4747,6 +4747,26 @@ def test_zero_inplace():
     verify_model(ZeroInplace(), [([128, 128], "float32")], {}, Expected)
 
 
+def test_zeros_like():
+    class ZerosLike(Module):
+        def forward(self, data):
+            return torch.zeros_like(data)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((128, 128), dtype="float32")
+        ) -> R.Tensor((128, 128), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((128, 128), dtype="float32") = R.zeros_like(inp_0, dtype="void")
+                gv: R.Tensor((128, 128), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(ZerosLike(), [([128, 128], "float32")], {}, Expected)
+
+
 def test_type_as():
     class TypeAs(Module):
         def forward(self, data, other):