This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ed477b0acb Add op support for zeros_like and fill_ (#17896)
ed477b0acb is described below

commit ed477b0acb8b3f0266934ccb88f2bd8b562acb22
Author: Pratheesh-04-MCW <[email protected]>
AuthorDate: Mon Apr 28 11:50:47 2025 +0530

    Add op support for zeros_like and fill_ (#17896)
    
    * add op support for zeros_like and fill_
    
    * fixing whitespace issues
    
    * unity issue
    
    * solved datatype issue
    
    * unity issue
    
    * lint error
---
 .../frontend/torch/base_fx_graph_translator.py     | 13 +++++++
 .../frontend/torch/exported_program_translator.py  |  2 +
 python/tvm/relax/frontend/torch/fx_translator.py   | 11 +-----
 .../relax/test_frontend_from_exported_program.py   | 45 ++++++++++++++++++++++
 tests/python/relax/test_frontend_from_fx.py        | 20 ++++++++++
 5 files changed, 82 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index e5a1ba5e99..ff5e51da0b 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1457,6 +1457,15 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
         return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))
 
+    def _inplace_fill(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dtype = x.struct_info.dtype
+        value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
+        filled = self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))
+        self.env[node.args[0]] = filled
+        return filled
+
     def _full(self, node: fx.Node) -> relax.Var:
         import torch
 
@@ -1670,6 +1679,10 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         self.env[node.args[0]] = output
         return output
 
+    def _zeros_like(self, node: fx.node) -> relax.Var:
+        x = self.env[node.args[0]]
+        return self.block_builder.emit(relax.op.zeros_like(x))
+
     @abc.abstractmethod
     def create_convert_map(
         self,
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index ce325f5fb8..86f5de5f36 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -474,6 +474,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "eye.default": self._eye,
             "eye.m": self._eye,
             "fill.Scalar": self._fill,
+            "fill_.Scalar": self._inplace_fill,
             "full.default": self._full,
             "full_like.default": self._full_like,
             "index_select.default": self._index_select,
@@ -488,6 +489,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             ),
             "zero_.default": self._zeros_inplace,
             "zeros.default": self._zeros,
+            "zeros_like.default": self._zeros_like,
             # datatype
             "to.dtype": self._to,
             "to.dtype_layout": self._to,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index f99c508aa2..07b3df20aa 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -515,15 +515,6 @@ class TorchFXImporter(BaseFXGraphImporter):
 
     ########## Creation ##########
 
-    def _inplace_fill(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        x = args[0]
-        dtype = x.struct_info.dtype
-        value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
-        filled = self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))
-        self.env[node.args[0]] = filled
-        return filled
-
     def _inplace_copy(self, node: fx.Node) -> relax.Var:
         src = self.env[node.args[1]]
         self.env[node.args[0]] = src
@@ -830,6 +821,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "clone": lambda node: self.env[node.args[0]],
             "empty": self._empty,
             "empty_like": self._empty_like,
+            "fill": self._fill,
             "fill_": self._inplace_fill,
             "full": self._full,
             "index_select": self._index_select,
@@ -844,6 +836,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             ),
             "tensor": self._tensor,
             "zero_": self._zeros_inplace,
+            "zeros_like": self._zeros_like,
             "copy_": self._inplace_copy,
             # datatype
             "astype": self._type,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 651bcc1a28..dd1869a23c 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3679,6 +3679,30 @@ def test_fill():
     verify_model(Fill(), example_args, {}, Expected)
 
 
+def test_fill_inplace():
+    class FillInplace(Module):
+        def forward(self, input: torch.Tensor):
+            input.fill_(42.0)
+            return input
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3), dtype="float32")
+        ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 3), dtype="float32") = R.full(
+                    R.shape([2, 3]), R.const(42.0, "float32"), dtype="float32"
+                )
+                gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(2, 3, dtype=torch.float32),)
+    verify_model(FillInplace(), example_args, {}, Expected)
+
+
 def test_masked_fill():
     class Masked_Fill(Module):
         def forward(self, input: torch.Tensor, mask: torch.Tensor):
@@ -4046,6 +4070,27 @@ def test_zeros():
     verify_model(Zeros(), example_args, {}, Expected)
 
 
+def test_zeros_like():
+    class ZerosLike(Module):
+        def forward(self, input):
+            return torch.zeros_like(input)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            input: R.Tensor((128, 128), dtype="float32")
+        ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((128, 128), dtype="float32") = R.zeros_like(input, dtype="void")
+                gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.rand(128, 128, dtype=torch.float32),)
+    verify_model(ZerosLike(), example_args, {}, Expected)
+
+
 def test_type_as():
     class TypeAs(Module):
         def forward(self, input, other):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 52a0a9e7a7..f60f158cbf 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4747,6 +4747,26 @@ def test_zero_inplace():
     verify_model(ZeroInplace(), [([128, 128], "float32")], {}, Expected)
 
 
+def test_zeros_like():
+    class ZerosLike(Module):
+        def forward(self, data):
+            return torch.zeros_like(data)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((128, 128), dtype="float32")
+        ) -> R.Tensor((128, 128), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((128, 128), dtype="float32") = R.zeros_like(inp_0, dtype="void")
+                gv: R.Tensor((128, 128), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(ZerosLike(), [([128, 128], "float32")], {}, Expected)
+
+
 def test_type_as():
     class TypeAs(Module):
         def forward(self, data, other):

Reply via email to