This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8dc9f5fdc1 [Relax][PyTorch] Fix KeyError: dtype when converting PyTorch model with gradient checkpointing using torch.export (#18461)
8dc9f5fdc1 is described below
commit 8dc9f5fdc14857c62395c234a638643be1f73b98
Author: Neo Chien <[email protected]>
AuthorDate: Tue Nov 18 00:59:37 2025 +0800
[Relax][PyTorch] Fix KeyError: dtype when converting PyTorch model with gradient checkpointing using torch.export (#18461)
This PR fixes issue https://github.com/apache/tvm/issues/18439.
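For context, a rough reproduction sketch of the failure mode: the model, shapes, and
checkpoint usage below are illustrative assumptions, not copied from the issue.

    import torch
    from torch.utils.checkpoint import checkpoint
    from tvm.relax.frontend.torch import from_exported_program

    class Block(torch.nn.Module):  # illustrative model, not from the issue
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(16, 16)

        def forward(self, x):
            # Gradient checkpointing can leave aten::empty nodes in the
            # exported graph without an explicit dtype kwarg.
            return checkpoint(self.linear, x, use_reentrant=False)

    ep = torch.export.export(Block(), (torch.randn(4, 16),))
    mod = from_exported_program(ep)  # per the linked issue: KeyError: 'dtype' before this fix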
Co-authored-by: cchung100m <[email protected]>
---
.../frontend/torch/base_fx_graph_translator.py | 6 +++++-
.../frontend/torch/exported_program_translator.py | 1 +
.../relax/test_frontend_from_exported_program.py | 21 +++++++++++++++++++++
3 files changed, 27 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 83a045ef54..b20b27eb09 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -2036,7 +2036,11 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         return self.block_builder.emit(relax.op.arange(*start_end_step, dtype=dtype))
 
     def _empty(self, node: fx.Node) -> relax.Var:
-        dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
+        import torch
+
+        dtype = self._convert_data_type(
+            node.kwargs.get("dtype", torch.get_default_dtype()), self.env
+        )
         return self.block_builder.emit(relax.op.zeros(node.args[0], dtype))
 
     def _empty_like(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 44e967ec0e..3b982b6b46 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1143,6 +1143,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"_assert_tensor_metadata.default": lambda node: self.env[
node.args[0]
], # metadata assertion: no-op
+ "empty.default": self._empty,
"empty.memory_format": self._empty,
"empty_permuted.default": self._empty, # Similar to empty with
permuted layout
"empty_like.default": self._empty_like,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index ef2736778f..001df64815 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -5278,6 +5278,27 @@ def test_empty():
     verify_model(Empty(), example_args, {}, Expected, run_ep_decomposition=True)
 
 
+def test_empty_without_dtype():
+    class EmptyWithoutDtype(Module):
+        def forward(self, input):
+            return torch.empty((5, 5))
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            input: R.Tensor((10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((5, 5), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((5, 5), dtype="float32") = R.zeros(R.shape([5, 5]), dtype="float32")
+                gv: R.Tuple(R.Tensor((5, 5), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(10, 10, dtype=torch.float32),)
+    verify_model(EmptyWithoutDtype(), example_args, {}, Expected, run_ep_decomposition=True)
+
+
 def test_fill():
     class Fill(Module):
         def forward(self, input: torch.Tensor):