This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1437d5caac [Relax][Pytorch] Add support for ones_like, zero_, zeros, type_as, item ops (#17868)
1437d5caac is described below

commit 1437d5caac4fc31b8aeeab87046a857ad3c5853e
Author: kavin-mcw <[email protected]>
AuthorDate: Thu Apr 24 10:42:52 2025 +0530

    [Relax][Pytorch] Add support for ones_like, zero_, zeros, type_as, item ops (#17868)
    
    * Add support for ones_like,zero_,zeros,type_as,item
    
    * Fix lint issues
    
    * Fix lint issues
    
    * Removed unused import
---
 .../frontend/torch/base_fx_graph_translator.py     |  16 +++
 .../frontend/torch/exported_program_translator.py  |  15 +++
 python/tvm/relax/frontend/torch/fx_translator.py   |   6 ++
 .../relax/test_frontend_from_exported_program.py   | 111 +++++++++++++++++++++
 tests/python/relax/test_frontend_from_fx.py        |  89 +++++++++++++++++
 5 files changed, 237 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index a89726495e..33f6ffc313 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1501,6 +1501,12 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
             return self.block_builder.emit(relax.op.astype(x, dtype))
         return x
 
+    def _type_as(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        other = self.env[node.args[1]]
+        dtype = other.struct_info.dtype
+        return self.block_builder.emit(relax.op.astype(x, dtype))
+
     ########## Others ##########
 
     def _getitem(self, node: fx.Node) -> relax.Var:
@@ -1584,6 +1590,16 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         else:
             assert False
 
+    def _item(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        return self.block_builder.emit(relax.op.take(x, relax.const(0, 
"int64"), axis=0))
+
+    def _zeros_inplace(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        output = self.block_builder.emit(relax.op.zeros_like(x))
+        self.env[node.args[0]] = output
+        return output
+
     @abc.abstractmethod
     def create_convert_map(
         self,
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index cdf0c46bb5..0434712050 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -253,6 +253,14 @@ class ExportedProgramImporter(BaseFXGraphImporter):
 
         return self.block_builder.emit(relax.op.one_hot(x, on_value, 
off_value, num_classes, axis))
 
+    def _zeros(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) 
else (args[0],))
+        dtype = self._convert_data_type(
+            node.kwargs.get("dtype", torch.get_default_dtype()), self.env
+        )
+        return self.block_builder.emit(relax.op.zeros(size, dtype))
+
     ########## Others ##########
 
     def create_convert_map(
@@ -470,11 +478,18 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "new_ones.default": self._new_ones,
             "one_hot.default": self._one_hot,
             "ones.default": self._ones,
+            "ones_like.default": lambda node: self.block_builder.emit(
+                relax.op.ones_like(self.env[node.args[0]])
+            ),
+            "zero_.default": self._zeros_inplace,
+            "zeros.default": self._zeros,
             # datatype
             "to.dtype": self._to,
             "to.dtype_layout": self._to,
+            "type_as.default": self._type_as,
             # other
             "getitem": self._getitem,
+            "item.default": self._item,
         }
 
     def create_input_vars(
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index c3bf8f0454..55abf20fcc 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -836,7 +836,11 @@ class TorchFXImporter(BaseFXGraphImporter):
             "new_ones": self._new_ones,
             "ones": self._ones,
             "one_hot": self._one_hot,
+            "ones_like": lambda node: self.block_builder.emit(
+                relax.op.ones_like(self.env[node.args[0]])
+            ),
             "tensor": self._tensor,
+            "zero_": self._zeros_inplace,
             "copy_": self._inplace_copy,
             # datatype
             "astype": self._type,
@@ -845,10 +849,12 @@ class TorchFXImporter(BaseFXGraphImporter):
             "is_floating_point": self._is_floating_point,
             "to": self._to,
             "type": self._type,
+            "type_as": self._type_as,
             # other
             "getattr": self._getattr,
             "getitem": self._getitem,
             "sym_size.int": self._sym_size_int,
+            "item": self._item,
         }
 
     def update_convert_map(self, custom_convert_map: dict):
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index c6ead5aacc..108617991b 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3948,6 +3948,98 @@ def test_one_hot():
     verify_model(OneHot(), example_args, {}, Expected)
 
 
+def test_ones_like():
+    class OnesLike(Module):
+        def forward(self, input):
+            return torch.ones_like(input)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            input: R.Tensor((128, 128), dtype="float32")
+        ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((128, 128), dtype="float32") = R.ones_like(input, 
dtype="void")
+                gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.rand(128, 128, dtype=torch.float32),)
+
+    verify_model(OnesLike(), example_args, {}, Expected)
+
+
+def test_zero_inplace():
+    class ZeroInplace(Module):
+        def forward(self, input):
+            return input.zero_()
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            input: R.Tensor((128, 128), dtype="float32")
+        ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((128, 128), dtype="float32") = 
R.zeros_like(input, dtype="void")
+                gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.rand(128, 128, dtype=torch.float32),)
+
+    verify_model(ZeroInplace(), example_args, {}, Expected)
+
+
+def test_zeros():
+    class Zeros(Module):
+        def forward(self, input):
+            return torch.zeros(5, 2)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            input: R.Tensor((128, 128), dtype="float32")
+        ) -> R.Tuple(R.Tensor((5, 2), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((5, 2), dtype="float32") = R.zeros(R.shape([5, 
2]), dtype="float32")
+                gv: R.Tuple(R.Tensor((5, 2), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.rand(128, 128, dtype=torch.float32),)
+
+    verify_model(Zeros(), example_args, {}, Expected)
+
+
+def test_type_as():
+    class TypeAs(Module):
+        def forward(self, input, other):
+            return input.type_as(other)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            input: R.Tensor((128, 128), dtype="float32"),
+            other: R.Tensor((128, 128), dtype="float16"),
+        ) -> R.Tuple(R.Tensor((128, 128), dtype="float16")):
+            with R.dataflow():
+                lv: R.Tensor((128, 128), dtype="float16") = R.astype(input, 
dtype="float16")
+                gv: R.Tuple(R.Tensor((128, 128), dtype="float16")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (
+        torch.rand(128, 128, dtype=torch.float32),
+        torch.rand(128, 128, dtype=torch.float16),
+    )
+
+    verify_model(TypeAs(), example_args, {}, Expected)
+
+
 def test_select():
     class Select(Module):
         def forward(self, input):
@@ -4379,6 +4471,25 @@ def test_narrow():
     verify_model(Narrow(), example_args, {}, Expected)
 
 
+def test_item():
+    class Item(Module):
+        def forward(self, x):
+            return x.item()
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(input: R.Tensor((1,), dtype="float32")) -> 
R.Tuple(R.Tensor((), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.take(input, R.const(0, 
"int64"), axis=0)
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, dtype=torch.float32),)
+    verify_model(Item(), example_args, {}, Expected)
+
+
 def test_norm():
     class Norm(Module):
         def __init__(self, p, dim=None, keepdim=False):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index f21cde6df2..cb69398e0a 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4506,6 +4506,95 @@ def test_empty_like():
     verify_model(EmptyLike(), [([5], "float32")], {}, Expected)
 
 
+def test_ones_like():
+    class OnesLike(Module):
+        def forward(self, data):
+            return torch.ones_like(data)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((128, 128), dtype="float32")
+        ) -> R.Tensor((128, 128), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((128, 128), dtype="float32") = R.ones_like(inp_0, 
dtype="void")
+                gv: R.Tensor((128, 128), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(OnesLike(), [([128, 128], "float32")], {}, Expected)
+
+
+def test_zero_inplace():
+    class ZeroInplace(Module):
+        def forward(self, data):
+            return data.zero_()
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((128, 128), dtype="float32")
+        ) -> R.Tensor((128, 128), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((128, 128), dtype="float32") = 
R.zeros_like(inp_0, dtype="void")
+                gv: R.Tensor((128, 128), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(ZeroInplace(), [([128, 128], "float32")], {}, Expected)
+
+
+def test_type_as():
+    class TypeAs(Module):
+        def forward(self, data, other):
+            return data.type_as(other)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((128, 128), dtype="float16"),
+            inp_1: R.Tensor((128, 128), dtype="float32"),
+        ) -> R.Tensor((128, 128), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((128, 128), dtype="float32") = R.astype(inp_0, 
dtype="float32")
+                gv: R.Tensor((128, 128), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(TypeAs(), [([128, 128], "float16"), ([128, 128], "float32")], 
{}, Expected)
+
+
+def test_item():
+    class Item(Module):
+        def forward(self, data):
+            return data.item()
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(inp_0: R.Tensor((1,), dtype="float32")) -> R.Tensor((), 
dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.take(inp_0, R.const(0, 
"int64"), axis=0)
+                gv: R.Tensor((), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(
+        Item(),
+        [
+            (
+                [1],
+                "float32",
+            )
+        ],
+        {},
+        Expected,
+    )
+
+
 def test_numel():
     class Numel(Module):
         def forward(self, data):

Reply via email to