This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new 1437d5caac [Relax][Pytorch] Add support for ones_like, zero_, zeros, type_as, item ops (#17868)
1437d5caac is described below
commit 1437d5caac4fc31b8aeeab87046a857ad3c5853e
Author: kavin-mcw <[email protected]>
AuthorDate: Thu Apr 24 10:42:52 2025 +0530
    [Relax][Pytorch] Add support for ones_like, zero_, zeros, type_as, item ops (#17868)
    
    * Add support for ones_like, zero_, zeros, type_as, item
    * Fix lint issues
    * Fix lint issues
    * Removed unused import
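    
    For quick orientation, a minimal sketch of the PyTorch-level ops this
    change covers and the Relax ops the importers lower them to (the tensors
    and shapes below are illustrative, not taken from the patch):
    
        import torch
    
        x = torch.rand(4, 4)                        # float32
        y = torch.rand(4, 4, dtype=torch.float16)
    
        torch.ones_like(x)    # -> relax.op.ones_like
        x.clone().zero_()     # zero_   -> emitted as relax.op.zeros_like
        torch.zeros(5, 2)     # zeros   -> relax.op.zeros
        x.type_as(y)          # type_as -> relax.op.astype(x, y.struct_info.dtype)
        torch.rand(1).item()  # item    -> relax.op.take(x, relax.const(0, "int64"), axis=0)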
---
.../frontend/torch/base_fx_graph_translator.py | 16 +++
.../frontend/torch/exported_program_translator.py | 15 +++
python/tvm/relax/frontend/torch/fx_translator.py | 6 ++
.../relax/test_frontend_from_exported_program.py | 111 +++++++++++++++++++++
tests/python/relax/test_frontend_from_fx.py | 89 +++++++++++++++++
5 files changed, 237 insertions(+)
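
A minimal end-to-end sketch, assuming the standard torch.export workflow
(the TypeAs module and shapes here are illustrative; from_exported_program
is the existing tvm.relax.frontend.torch entry point):

    import torch
    from torch.export import export
    from tvm.relax.frontend.torch import from_exported_program

    class TypeAs(torch.nn.Module):
        def forward(self, x, y):
            # type_as now lowers to relax.op.astype with y's dtype
            return x.type_as(y)

    example_args = (torch.rand(8, 8), torch.rand(8, 8, dtype=torch.float16))
    mod = from_exported_program(export(TypeAs(), example_args))
    mod.show()  # main casts x from float32 to float16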
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index a89726495e..33f6ffc313 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1501,6 +1501,12 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
             return self.block_builder.emit(relax.op.astype(x, dtype))
         return x
 
+    def _type_as(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        other = self.env[node.args[1]]
+        dtype = other.struct_info.dtype
+        return self.block_builder.emit(relax.op.astype(x, dtype))
+
     ########## Others ##########
 
     def _getitem(self, node: fx.Node) -> relax.Var:
@@ -1584,6 +1590,16 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         else:
             assert False
 
+    def _item(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        return self.block_builder.emit(relax.op.take(x, relax.const(0, "int64"), axis=0))
+
+    def _zeros_inplace(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        output = self.block_builder.emit(relax.op.zeros_like(x))
+        self.env[node.args[0]] = output
+        return output
+
     @abc.abstractmethod
     def create_convert_map(
         self,
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index cdf0c46bb5..0434712050 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -253,6 +253,14 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         return self.block_builder.emit(relax.op.one_hot(x, on_value, off_value, num_classes, axis))
 
+    def _zeros(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
+        dtype = self._convert_data_type(
+            node.kwargs.get("dtype", torch.get_default_dtype()), self.env
+        )
+        return self.block_builder.emit(relax.op.zeros(size, dtype))
+
     ########## Others ##########
 
     def create_convert_map(
@@ -470,11 +478,18 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "new_ones.default": self._new_ones,
             "one_hot.default": self._one_hot,
             "ones.default": self._ones,
+            "ones_like.default": lambda node: self.block_builder.emit(
+                relax.op.ones_like(self.env[node.args[0]])
+            ),
+            "zero_.default": self._zeros_inplace,
+            "zeros.default": self._zeros,
             # datatype
             "to.dtype": self._to,
             "to.dtype_layout": self._to,
+            "type_as.default": self._type_as,
             # other
             "getitem": self._getitem,
+            "item.default": self._item,
         }
 
     def create_input_vars(
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index c3bf8f0454..55abf20fcc 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -836,7 +836,11 @@ class TorchFXImporter(BaseFXGraphImporter):
             "new_ones": self._new_ones,
             "ones": self._ones,
             "one_hot": self._one_hot,
+            "ones_like": lambda node: self.block_builder.emit(
+                relax.op.ones_like(self.env[node.args[0]])
+            ),
             "tensor": self._tensor,
+            "zero_": self._zeros_inplace,
             "copy_": self._inplace_copy,
             # datatype
             "astype": self._type,
@@ -845,10 +849,12 @@ class TorchFXImporter(BaseFXGraphImporter):
             "is_floating_point": self._is_floating_point,
             "to": self._to,
             "type": self._type,
+            "type_as": self._type_as,
             # other
             "getattr": self._getattr,
             "getitem": self._getitem,
             "sym_size.int": self._sym_size_int,
+            "item": self._item,
         }
 
     def update_convert_map(self, custom_convert_map: dict):
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index c6ead5aacc..108617991b 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3948,6 +3948,98 @@ def test_one_hot():
     verify_model(OneHot(), example_args, {}, Expected)
 
+def test_ones_like():
+    class OnesLike(Module):
+        def forward(self, input):
+            return torch.ones_like(input)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            input: R.Tensor((128, 128), dtype="float32")
+        ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((128, 128), dtype="float32") = R.ones_like(input, dtype="void")
+                gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.rand(128, 128, dtype=torch.float32),)
+
+    verify_model(OnesLike(), example_args, {}, Expected)
+
+
+def test_zero_inplace():
+    class ZeroInplace(Module):
+        def forward(self, input):
+            return input.zero_()
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            input: R.Tensor((128, 128), dtype="float32")
+        ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((128, 128), dtype="float32") = R.zeros_like(input, dtype="void")
+                gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.rand(128, 128, dtype=torch.float32),)
+
+    verify_model(ZeroInplace(), example_args, {}, Expected)
+
+
+def test_zeros():
+    class Zeros(Module):
+        def forward(self, input):
+            return torch.zeros(5, 2)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            input: R.Tensor((128, 128), dtype="float32")
+        ) -> R.Tuple(R.Tensor((5, 2), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((5, 2), dtype="float32") = R.zeros(R.shape([5, 2]), dtype="float32")
+                gv: R.Tuple(R.Tensor((5, 2), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.rand(128, 128, dtype=torch.float32),)
+
+    verify_model(Zeros(), example_args, {}, Expected)
+
+
+def test_type_as():
+    class TypeAs(Module):
+        def forward(self, input, other):
+            return input.type_as(other)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            input: R.Tensor((128, 128), dtype="float32"),
+            other: R.Tensor((128, 128), dtype="float16"),
+        ) -> R.Tuple(R.Tensor((128, 128), dtype="float16")):
+            with R.dataflow():
+                lv: R.Tensor((128, 128), dtype="float16") = R.astype(input, dtype="float16")
+                gv: R.Tuple(R.Tensor((128, 128), dtype="float16")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (
+        torch.rand(128, 128, dtype=torch.float32),
+        torch.rand(128, 128, dtype=torch.float16),
+    )
+
+    verify_model(TypeAs(), example_args, {}, Expected)
+
+
 def test_select():
     class Select(Module):
         def forward(self, input):
@@ -4379,6 +4471,25 @@ def test_narrow():
     verify_model(Narrow(), example_args, {}, Expected)
 
+def test_item():
+    class Item(Module):
+        def forward(self, x):
+            return x.item()
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(input: R.Tensor((1,), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.take(input, R.const(0, "int64"), axis=0)
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, dtype=torch.float32),)
+    verify_model(Item(), example_args, {}, Expected)
+
+
 def test_norm():
     class Norm(Module):
         def __init__(self, p, dim=None, keepdim=False):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index f21cde6df2..cb69398e0a 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4506,6 +4506,95 @@ def test_empty_like():
     verify_model(EmptyLike(), [([5], "float32")], {}, Expected)
 
+def test_ones_like():
+    class OnesLike(Module):
+        def forward(self, data):
+            return torch.ones_like(data)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((128, 128), dtype="float32")
+        ) -> R.Tensor((128, 128), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((128, 128), dtype="float32") = R.ones_like(inp_0, dtype="void")
+                gv: R.Tensor((128, 128), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(OnesLike(), [([128, 128], "float32")], {}, Expected)
+
+
+def test_zero_inplace():
+    class ZeroInplace(Module):
+        def forward(self, data):
+            return data.zero_()
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((128, 128), dtype="float32")
+        ) -> R.Tensor((128, 128), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((128, 128), dtype="float32") = R.zeros_like(inp_0, dtype="void")
+                gv: R.Tensor((128, 128), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(ZeroInplace(), [([128, 128], "float32")], {}, Expected)
+
+
+def test_type_as():
+    class TypeAs(Module):
+        def forward(self, data, other):
+            return data.type_as(other)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((128, 128), dtype="float16"),
+            inp_1: R.Tensor((128, 128), dtype="float32"),
+        ) -> R.Tensor((128, 128), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((128, 128), dtype="float32") = R.astype(inp_0, dtype="float32")
+                gv: R.Tensor((128, 128), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(TypeAs(), [([128, 128], "float16"), ([128, 128], "float32")], {}, Expected)
+
+
+def test_item():
+    class Item(Module):
+        def forward(self, data):
+            return data.item()
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(inp_0: R.Tensor((1,), dtype="float32")) -> R.Tensor((), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.take(inp_0, R.const(0, "int64"), axis=0)
+                gv: R.Tensor((), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(
+        Item(),
+        [
+            (
+                [1],
+                "float32",
+            )
+        ],
+        {},
+        Expected,
+    )
+
+
 def test_numel():
     class Numel(Module):
         def forward(self, data):