This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8dc9f5fdc1 [Relax][PyTorch] Fix KeyError: dtype when converting 
PyTorch model with gradient checkpointing using torch.export (#18461)
8dc9f5fdc1 is described below

commit 8dc9f5fdc14857c62395c234a638643be1f73b98
Author: Neo Chien <[email protected]>
AuthorDate: Tue Nov 18 00:59:37 2025 +0800

    [Relax][PyTorch] Fix KeyError: dtype when converting PyTorch model with gradient checkpointing using torch.export (#18461)
    
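    This PR fixes issue https://github.com/apache/tvm/issues/18439.
    When `torch.empty` is called without an explicit `dtype`, the node has no
    `dtype` keyword argument, so `_empty` raised `KeyError: dtype`. `_empty`
    now falls back to `torch.get_default_dtype()`, and `empty.default` is
    registered in the ExportedProgram converter map alongside
    `empty.memory_format`.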
    
    Co-authored-by: cchung100m <[email protected]>
---
 .../frontend/torch/base_fx_graph_translator.py      |  6 +++++-
 .../frontend/torch/exported_program_translator.py   |  1 +
 .../relax/test_frontend_from_exported_program.py    | 21 +++++++++++++++++++++
 3 files changed, 27 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py 
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 83a045ef54..b20b27eb09 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -2036,7 +2036,11 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         return self.block_builder.emit(relax.op.arange(*start_end_step, 
dtype=dtype))
 
     def _empty(self, node: fx.Node) -> relax.Var:
-        dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
+        import torch
+
+        dtype = self._convert_data_type(
+            node.kwargs.get("dtype", torch.get_default_dtype()), self.env
+        )
         return self.block_builder.emit(relax.op.zeros(node.args[0], dtype))
 
     def _empty_like(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py 
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 44e967ec0e..3b982b6b46 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1143,6 +1143,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "_assert_tensor_metadata.default": lambda node: self.env[
                 node.args[0]
             ],  # metadata assertion: no-op
+            "empty.default": self._empty,
             "empty.memory_format": self._empty,
             "empty_permuted.default": self._empty,  # Similar to empty with 
permuted layout
             "empty_like.default": self._empty_like,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py 
b/tests/python/relax/test_frontend_from_exported_program.py
index ef2736778f..001df64815 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -5278,6 +5278,27 @@ def test_empty():
     verify_model(Empty(), example_args, {}, Expected, 
run_ep_decomposition=True)
 
 
+def test_empty_without_dtype():
+    class EmptyWithoutDtype(Module):
+        def forward(self, input):
+            return torch.empty((5, 5))
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            input: R.Tensor((10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((5, 5), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((5, 5), dtype="float32") = R.zeros(R.shape([5, 
5]), dtype="float32")
+                gv: R.Tuple(R.Tensor((5, 5), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(10, 10, dtype=torch.float32),)
+    verify_model(EmptyWithoutDtype(), example_args, {}, Expected, 
run_ep_decomposition=True)
+
+
 def test_fill():
     class Fill(Module):
         def forward(self, input: torch.Tensor):

Reply via email to