This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0701aaba4b [Relax][PyTorch] Fix the sqrt operation requiring a float dtype but receiving int64 in attention scaling (#18454)
0701aaba4b is described below
commit 0701aaba4b37666b30e75b5722c1e2d3bb0b50ce
Author: Neo Chien <[email protected]>
AuthorDate: Mon Nov 17 04:35:04 2025 +0800
[Relax][PyTorch] Fix the sqrt operation requiring a float dtype but receiving int64 in attention scaling (#18454)
This PR fixes issue https://github.com/apache/tvm/issues/18443.
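A minimal repro sketch of the failure (illustrative only: the scalar-tensor
module and the head size of 64 below are assumptions, not the exact model
from the issue):

    import torch
    from torch.export import export
    from tvm.relax.frontend.torch import from_exported_program

    class AttnScale(torch.nn.Module):
        def forward(self, d_k):
            # d_k is an int64 scalar tensor (e.g. a head dimension); before
            # this fix, relax.op.sqrt rejected the integer dtype on import.
            return torch.sqrt(d_k)

    ep = export(AttnScale(), (torch.tensor(64, dtype=torch.int64),))
    mod = from_exported_program(ep)  # now emits astype(..., "float32") before sqrt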
---------
Co-authored-by: cchung100m <[email protected]>
---
.../frontend/torch/exported_program_translator.py | 24 +++++++++++--
python/tvm/relax/frontend/torch/fx_translator.py | 24 +++++++++++--
.../relax/test_frontend_from_exported_program.py | 41 ++++++++++++++++++++++
tests/python/relax/test_frontend_from_fx.py | 21 +++++++++++
4 files changed, 106 insertions(+), 4 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 63aba55a78..c6243c113e 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -64,6 +64,26 @@ class ExportedProgramImporter(BaseFXGraphImporter):
x = self.env[node.args[0]]
return self.block_builder.emit(relax.op.divide(relax.const(1.0, x.struct_info.dtype), x))
+ def _sqrt(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ dtype = x.struct_info.dtype
+
+ # Check if input is integer type and convert to float32 if needed
+ if dtype in ("int8", "int16", "int32", "int64", "uint8", "uint16",
"uint32", "uint64"):
+ x = self.block_builder.emit(relax.op.astype(x, "float32"))
+
+ return self.block_builder.emit(relax.op.sqrt(x))
+
+ def _rsqrt(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ dtype = x.struct_info.dtype
+
+ # Check if input is integer type and convert to float32 if needed
+ if dtype in ("int8", "int16", "int32", "int64", "uint8", "uint16",
"uint32", "uint64"):
+ x = self.block_builder.emit(relax.op.astype(x, "float32"))
+
+ return self.block_builder.emit(relax.op.rsqrt(x))
+
########## Neural Network ##########
def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var:
@@ -919,7 +939,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"relu6.default": self._unary_op(relax.op.nn.relu6),
"relu6_.default": self._unary_op(relax.op.nn.relu6),
"round.default": self._round,
- "rsqrt.default": self._unary_op(relax.op.rsqrt),
+ "rsqrt.default": self._rsqrt,
"scalar_tensor.default": self._scalar_tensor,
"rsub.Tensor": self._rsub,
"rsub.Scalar": self._rsub,
@@ -935,7 +955,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"softplus.default": self._softplus,
"softshrink.default": self._softshrink,
"softsign.default": self._softsign,
- "sqrt.default": self._unary_op(relax.op.sqrt),
+ "sqrt.default": self._sqrt,
"square.default": self._unary_op(relax.op.square),
"tan.default": self._unary_op(relax.op.tan),
"tanh.default": self._unary_op(relax.op.tanh),
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 0d2e240be6..a93f788669 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -96,6 +96,26 @@ class TorchFXImporter(BaseFXGraphImporter):
one = relax.const(1, x.struct_info.dtype)
return self.block_builder.emit(relax.op.log(relax.op.add(x, one)))
+ def _sqrt(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ dtype = x.struct_info.dtype
+
+ # Check if input is integer type and convert to float32 if needed
+ if dtype in ["int8", "int16", "int32", "int64", "uint8", "uint16",
"uint32", "uint64"]:
+ x = self.block_builder.emit(relax.op.astype(x, "float32"))
+
+ return self.block_builder.emit(relax.op.sqrt(x))
+
+ def _rsqrt(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ dtype = x.struct_info.dtype
+
+ # Check if input is integer type and convert to float32 if needed
+ if dtype in ["int8", "int16", "int32", "int64", "uint8", "uint16",
"uint32", "uint64"]:
+ x = self.block_builder.emit(relax.op.astype(x, "float32"))
+
+ return self.block_builder.emit(relax.op.rsqrt(x))
+
def _log_softmax_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
@@ -825,7 +845,7 @@ class TorchFXImporter(BaseFXGraphImporter):
"relu": self._unary_op(relax.op.nn.relu),
"relu6": self._unary_op(relax.op.nn.relu6),
"round": self._round,
- "rsqrt": self._unary_op(relax.op.rsqrt),
+ "rsqrt": self._rsqrt,
"selu": self._unary_op(relax.op.nn.selu),
"sigmoid": self._unary_op(relax.op.sigmoid),
"sign": self._unary_op(relax.op.sign),
@@ -834,7 +854,7 @@ class TorchFXImporter(BaseFXGraphImporter):
"sinh": self._unary_op(relax.op.sinh),
"softmax": self._softmax,
"softplus": self._softplus,
- "sqrt": self._unary_op(relax.op.sqrt),
+ "sqrt": self._sqrt,
"square": self._unary_op(relax.op.square),
"tan": self._unary_op(relax.op.tan),
"tanh": self._unary_op(relax.op.tanh),
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 1b816432ce..6cf293d96b 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -126,6 +126,47 @@ def test_bool_unary_ops(pytorch_op, relax_op):
verify_model(UnaryOp(), example_args, {}, expected, run_ep_decomposition=True)
+def test_sqrt_integer_input():
+ """Test that sqrt operation works with integer tensors by auto-converting
to float."""
+ example_args = (torch.tensor([[4, 9, 16, 25]], dtype=torch.int64),)
+
+ class SqrtIntModel(Module):
+ def forward(self, input):
+ return torch.sqrt(input)
+
+ @tvm.script.ir_module
+ class expected_int64:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 4), dtype="int64")
+ ) -> R.Tuple(R.Tensor((1, 4), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((1, 4), dtype="float32") = R.astype(input_1, dtype="float32")
+ lv1: R.Tensor((1, 4), dtype="float32") = R.sqrt(lv)
+ gv: R.Tuple(R.Tensor((1, 4), dtype="float32")) = (lv1,)
+ R.output(gv)
+ return gv
+
+ verify_model(SqrtIntModel(), example_args, {}, expected_int64, run_ep_decomposition=True)
+
+ example_args_int32 = (torch.tensor([[1, 4, 9]], dtype=torch.int32),)
+
+ @tvm.script.ir_module
+ class expected_int32:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3), dtype="int32")
+ ) -> R.Tuple(R.Tensor((1, 3), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((1, 3), dtype="float32") = R.astype(input_1, dtype="float32")
+ lv1: R.Tensor((1, 3), dtype="float32") = R.sqrt(lv)
+ gv: R.Tuple(R.Tensor((1, 3), dtype="float32")) = (lv1,)
+ R.output(gv)
+ return gv
+
+ verify_model(SqrtIntModel(), example_args_int32, {}, expected_int32, run_ep_decomposition=True)
+
+
def test_extended_unary_ops():
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 69ebdcbf76..d377bb7574 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -2749,6 +2749,27 @@ def test_basic_unary_ops(pytorch_op, relax_op):
verify_model(Unary(), input_info, {}, expected_unary)
+def test_sqrt_integer_input_fx():
+ input_info = [([1, 4], "int64")]
+
+ class SqrtIntModel(Module):
+ def forward(self, input):
+ return torch.sqrt(input)
+
+ @tvm.script.ir_module
+ class expected:
+ @R.function
+ def main(input_1: R.Tensor((1, 4), dtype="int64")) -> R.Tensor((1, 4), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((1, 4), dtype="float32") = R.astype(input_1, dtype="float32")
+ lv1: R.Tensor((1, 4), dtype="float32") = R.sqrt(lv)
+ gv: R.Tensor((1, 4), dtype="float32") = lv1
+ R.output(gv)
+ return gv
+
+ verify_model(SqrtIntModel(), input_info, {}, expected)
+
+
operator_bool_unary = [
(torch.isnan, R.isnan),
(torch.isinf, R.isinf),