This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new dc043fe64c [Relax][PyTorch] Support one_hot, empty_like ops for ExportedProgram importer (#17751)
dc043fe64c is described below

commit dc043fe64c4e82c466bfa89a4295dc17594036f7
Author: Shushi Hong <[email protected]>
AuthorDate: Sun Mar 16 14:29:28 2025 +0800

    [Relax][PyTorch] Support one_hot, empty_like ops for ExportedProgram importer (#17751)
    
    * Update exported_program_translator.py
    
    * Update test_frontend_from_exported_program.py
    
    * Update test_frontend_from_exported_program.py
    
    * Update exported_program_translator.py
---
 .../frontend/torch/exported_program_translator.py  | 19 ++++++
 .../relax/test_frontend_from_exported_program.py   | 68 ++++++++++++++++++++++
 2 files changed, 87 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index d67cacb960..bc7a4c4cb0 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -174,6 +174,23 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         stride = [node.args[4] if len(node.args) > 4 else 1]
         return self.block_builder.emit(relax.op.strided_slice(x, axes, begin, 
end, stride))
 
+    ########## Creation ##########
+
+    def _one_hot(self, node: fx.Node) -> relax.Var:
+        """Convert an ``aten.one_hot`` node into ``relax.op.one_hot``.
+
+        ``num_classes`` is taken from ``node.args[1]`` or the ``num_classes``
+        kwarg and is forwarded as the ``depth`` of the Relax op. PyTorch uses
+        ``num_classes=-1`` to mean "infer from the largest index", which gives a
+        data-dependent output shape. That case, like a missing value, is
+        rejected with a ``ValueError`` instead of being forwarded as a negative
+        depth.
+
+        ``on_value``, ``off_value`` and ``axis`` are read from
+        ``node.args[2:5]`` or the kwargs when present. Otherwise they default
+        to 1, 0 and -1, which gives a 0/1 encoding along a new last axis.
+        """
+        x = self.env[node.args[0]]
+        num_classes = node.args[1] if len(node.args) > 1 else node.kwargs.get("num_classes")
+        if num_classes is None:
+            raise ValueError("num_classes not found in node.args or node.kwargs")
+        # Only concrete ints are checked here; any other value is passed through unchanged.
+        if isinstance(num_classes, int) and num_classes < 0:
+            raise ValueError(
+                f"one_hot with num_classes={num_classes} (infer from data) is not "
+                "supported; pass an explicit non-negative num_classes"
+            )
+
+        on_value = node.args[2] if len(node.args) > 2 else node.kwargs.get("on_value", 1)
+        off_value = node.args[3] if len(node.args) > 3 else node.kwargs.get("off_value", 0)
+        axis = node.args[4] if len(node.args) > 4 else node.kwargs.get("axis", -1)
+
+        # relax.op.one_hot receives the on/off values wrapped as PrimValue expressions.
+        on_value = relax.PrimValue(on_value)
+        off_value = relax.PrimValue(off_value)
+
+        return self.block_builder.emit(relax.op.one_hot(x, on_value, off_value, num_classes, axis))
+
     ########## Others ##########
 
     def create_convert_map(
@@ -331,8 +348,10 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "contiguous.default": lambda node: self.env[node.args[0]],  # no-op
             "clone.default": lambda node: self.env[node.args[0]],
             "empty.memory_format": self._empty,
+            "empty_like.default": self._empty_like,
             "fill.Scalar": self._fill,
             "new_ones.default": self._new_ones,
+            "one_hot.default": self._one_hot,
             # other
             "getitem": self._getitem,
         }
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index e2be933050..1b4e802539 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3425,6 +3425,74 @@ def test_no_bind_return_tuple():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_empty_like():
+    """``torch.empty_like`` is imported as ``R.zeros_like`` of the input.
+
+    The expected IR zero-fills the result and passes ``dtype="void"``. The
+    output is still annotated float32, so "void" presumably means "keep the
+    input dtype" (TODO confirm against the importer's ``_empty_like``).
+    """
+    class EmptyLike(Module):
+        def forward(self, data):
+            return torch.empty_like(data)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5,), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((5,), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((5,), dtype="float32") = R.zeros_like(inp_0, dtype="void")
+                gv: R.Tuple(R.Tensor((5,), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(5, dtype=torch.float32),)
+
+    verify_model(EmptyLike(), example_args, {}, Expected)
+
+
+def test_one_hot():
+    """``F.one_hot(indices, num_classes=10)`` imports as ``R.one_hot``.
+
+    Five int64 indices are expected to become a (5, 10) int64 tensor, built
+    with on/off values 1/0, ``depth=10`` and ``axis=-1``.
+    """
+    class OneHot(Module):
+        def forward(self, indices):
+            return torch.nn.functional.one_hot(indices, num_classes=10)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5,), dtype="int64"),
+        ) -> R.Tuple(R.Tensor((5, 10), dtype="int64")):
+            with R.dataflow():
+                lv: R.Tensor((5, 10), dtype="int64") = R.one_hot(
+                    inp_0, R.prim_value(1), R.prim_value(0), depth=10, axis=-1
+                )
+                gv: R.Tuple(R.Tensor((5, 10), dtype="int64")) = (lv,)
+                R.output(gv)
+            return gv
+
+    # Indices are drawn from [0, 10) so that they fit within num_classes.
+    example_args = (torch.randint(0, 10, (5,), dtype=torch.int64),)
+
+    verify_model(OneHot(), example_args, {}, Expected)
+
+
+def test_select():
+    """``torch.select(input, 0, 1)`` imports as ``R.take`` along axis 0.
+
+    Index 1 is a scalar int64 constant, so the selected axis is dropped and a
+    (2, 3) input yields a (3,) output.
+    """
+    class Select(Module):
+        def forward(self, input):
+            return torch.select(input, 0, 1)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((3,), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((3,), dtype="float32") = R.take(inp_0, R.const(1, "int64"), axis=0)
+                gv: R.Tuple(R.Tensor((3,), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(2, 3, dtype=torch.float32),)
+
+    verify_model(Select(), example_args, {}, Expected)
+
+
 def test_gather():
     class Gather0(Module):
         def forward(self, data, indices):

Reply via email to