This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new dc043fe64c [Relax][PyTorch] Support one_hot, empty_like ops for ExportedProgram importer (#17751)
dc043fe64c is described below
commit dc043fe64c4e82c466bfa89a4295dc17594036f7
Author: Shushi Hong <[email protected]>
AuthorDate: Sun Mar 16 14:29:28 2025 +0800
[Relax][PyTorch] Support one_hot, empty_like ops for ExportedProgram importer (#17751)
* Update exported_program_translator.py
* Update test_frontend_from_exported_program.py
* Update test_frontend_from_exported_program.py
* Update exported_program_translator.py
---
.../frontend/torch/exported_program_translator.py | 19 ++++++
.../relax/test_frontend_from_exported_program.py | 68 ++++++++++++++++++++++
2 files changed, 87 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index d67cacb960..bc7a4c4cb0 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -174,6 +174,23 @@ class ExportedProgramImporter(BaseFXGraphImporter):
stride = [node.args[4] if len(node.args) > 4 else 1]
        return self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride))
+ ########## Creation ##########
+
+    def _one_hot(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        num_classes = node.args[1] if len(node.args) > 1 else node.kwargs.get("num_classes")
+        if num_classes is None:
+            raise ValueError("num_classes not found in node.args or node.kwargs")
+
+        on_value = node.args[2] if len(node.args) > 2 else node.kwargs.get("on_value", 1)
+        off_value = node.args[3] if len(node.args) > 3 else node.kwargs.get("off_value", 0)
+        axis = node.args[4] if len(node.args) > 4 else node.kwargs.get("axis", -1)
+
+        on_value = relax.PrimValue(on_value)
+        off_value = relax.PrimValue(off_value)
+
+        return self.block_builder.emit(relax.op.one_hot(x, on_value, off_value, num_classes, axis))
+
########## Others ##########
def create_convert_map(
@@ -331,8 +348,10 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"contiguous.default": lambda node: self.env[node.args[0]], # no-op
"clone.default": lambda node: self.env[node.args[0]],
"empty.memory_format": self._empty,
+ "empty_like.default": self._empty_like,
"fill.Scalar": self._fill,
"new_ones.default": self._new_ones,
+ "one_hot.default": self._one_hot,
# other
"getitem": self._getitem,
}
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index e2be933050..1b4e802539 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3425,6 +3425,74 @@ def test_no_bind_return_tuple():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_empty_like():
+ class EmptyLike(Module):
+ def forward(self, data):
+ return torch.empty_like(data)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ inp_0: R.Tensor((5,), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((5,), dtype="float32")):
+ with R.dataflow():
+                lv: R.Tensor((5,), dtype="float32") = R.zeros_like(inp_0, dtype="void")
+ gv: R.Tuple(R.Tensor((5,), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(5, dtype=torch.float32),)
+
+ verify_model(EmptyLike(), example_args, {}, Expected)
+
+
+def test_one_hot():
+ class OneHot(Module):
+ def forward(self, indices):
+ return torch.nn.functional.one_hot(indices, num_classes=10)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ inp_0: R.Tensor((5,), dtype="int64"),
+ ) -> R.Tuple(R.Tensor((5, 10), dtype="int64")):
+ with R.dataflow():
+ lv: R.Tensor((5, 10), dtype="int64") = R.one_hot(
+ inp_0, R.prim_value(1), R.prim_value(0), depth=10, axis=-1
+ )
+ gv: R.Tuple(R.Tensor((5, 10), dtype="int64")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randint(0, 10, (5,), dtype=torch.int64),)
+
+ verify_model(OneHot(), example_args, {}, Expected)
+
+
+def test_select():
+ class Select(Module):
+ def forward(self, input):
+ return torch.select(input, 0, 1)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ inp_0: R.Tensor((2, 3), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((3,), dtype="float32")):
+ with R.dataflow():
+                lv: R.Tensor((3,), dtype="float32") = R.take(inp_0, R.const(1, "int64"), axis=0)
+ gv: R.Tuple(R.Tensor((3,), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(2, 3, dtype=torch.float32),)
+
+ verify_model(Select(), example_args, {}, Expected)
+
+
def test_gather():
class Gather0(Module):
def forward(self, data, indices):