This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0701aaba4b [Relax][PyTorch] Fix sqrt operation requiring float dtype but receiving int64 in attention scaling (#18454)
0701aaba4b is described below

commit 0701aaba4b37666b30e75b5722c1e2d3bb0b50ce
Author: Neo Chien <[email protected]>
AuthorDate: Mon Nov 17 04:35:04 2025 +0800

    [Relax][PyTorch] Fix sqrt operation requiring float dtype but receiving int64 in attention scaling (#18454)
    
    This PR fixes issue https://github.com/apache/tvm/issues/18443.
    
    ---------
    
    Co-authored-by: cchung100m <[email protected]>
---
 .../frontend/torch/exported_program_translator.py  | 24 +++++++++++--
 python/tvm/relax/frontend/torch/fx_translator.py   | 24 +++++++++++--
 .../relax/test_frontend_from_exported_program.py   | 41 ++++++++++++++++++++++
 tests/python/relax/test_frontend_from_fx.py        | 21 +++++++++++
 4 files changed, 106 insertions(+), 4 deletions(-)
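
For context, here is a minimal sketch of the failure mode this patch addresses
(a hypothetical repro based on issue #18443, not part of the commit): scaling
attention scores by sqrt(head_dim) builds an int64 scalar tensor, which
relax.op.sqrt previously rejected during import.

    import torch

    class ScaledAttention(torch.nn.Module):
        def forward(self, q, k):
            head_dim = q.shape[-1]  # Python int
            # torch.tensor(head_dim) is an int64 tensor; eager PyTorch promotes
            # it for sqrt, but the Relax importer previously raised a dtype
            # error when it reached relax.op.sqrt
            scale = torch.sqrt(torch.tensor(head_dim))
            return (q @ k.transpose(-2, -1)) / scale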

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 63aba55a78..c6243c113e 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -64,6 +64,26 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         x = self.env[node.args[0]]
         return self.block_builder.emit(relax.op.divide(relax.const(1.0, x.struct_info.dtype), x))
 
+    def _sqrt(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dtype = x.struct_info.dtype
+
+        # Check if input is integer type and convert to float32 if needed
+        if dtype in ("int8", "int16", "int32", "int64", "uint8", "uint16", 
"uint32", "uint64"):
+            x = self.block_builder.emit(relax.op.astype(x, "float32"))
+
+        return self.block_builder.emit(relax.op.sqrt(x))
+
+    def _rsqrt(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dtype = x.struct_info.dtype
+
+        # Check if input is integer type and convert to float32 if needed
+        if dtype in ("int8", "int16", "int32", "int64", "uint8", "uint16", 
"uint32", "uint64"):
+            x = self.block_builder.emit(relax.op.astype(x, "float32"))
+
+        return self.block_builder.emit(relax.op.rsqrt(x))
+
     ########## Neural Network ##########
 
     def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var:
@@ -919,7 +939,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "relu6.default": self._unary_op(relax.op.nn.relu6),
             "relu6_.default": self._unary_op(relax.op.nn.relu6),
             "round.default": self._round,
-            "rsqrt.default": self._unary_op(relax.op.rsqrt),
+            "rsqrt.default": self._rsqrt,
             "scalar_tensor.default": self._scalar_tensor,
             "rsub.Tensor": self._rsub,
             "rsub.Scalar": self._rsub,
@@ -935,7 +955,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "softplus.default": self._softplus,
             "softshrink.default": self._softshrink,
             "softsign.default": self._softsign,
-            "sqrt.default": self._unary_op(relax.op.sqrt),
+            "sqrt.default": self._sqrt,
             "square.default": self._unary_op(relax.op.square),
             "tan.default": self._unary_op(relax.op.tan),
             "tanh.default": self._unary_op(relax.op.tanh),
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 0d2e240be6..a93f788669 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -96,6 +96,26 @@ class TorchFXImporter(BaseFXGraphImporter):
         one = relax.const(1, x.struct_info.dtype)
         return self.block_builder.emit(relax.op.log(relax.op.add(x, one)))
 
+    def _sqrt(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dtype = x.struct_info.dtype
+
+        # Check if input is integer type and convert to float32 if needed
+        if dtype in ["int8", "int16", "int32", "int64", "uint8", "uint16", 
"uint32", "uint64"]:
+            x = self.block_builder.emit(relax.op.astype(x, "float32"))
+
+        return self.block_builder.emit(relax.op.sqrt(x))
+
+    def _rsqrt(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dtype = x.struct_info.dtype
+
+        # Check if input is integer type and convert to float32 if needed
+        if dtype in ["int8", "int16", "int32", "int64", "uint8", "uint16", 
"uint32", "uint64"]:
+            x = self.block_builder.emit(relax.op.astype(x, "float32"))
+
+        return self.block_builder.emit(relax.op.rsqrt(x))
+
     def _log_softmax_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
@@ -825,7 +845,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "relu": self._unary_op(relax.op.nn.relu),
             "relu6": self._unary_op(relax.op.nn.relu6),
             "round": self._round,
-            "rsqrt": self._unary_op(relax.op.rsqrt),
+            "rsqrt": self._rsqrt,
             "selu": self._unary_op(relax.op.nn.selu),
             "sigmoid": self._unary_op(relax.op.sigmoid),
             "sign": self._unary_op(relax.op.sign),
@@ -834,7 +854,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "sinh": self._unary_op(relax.op.sinh),
             "softmax": self._softmax,
             "softplus": self._softplus,
-            "sqrt": self._unary_op(relax.op.sqrt),
+            "sqrt": self._sqrt,
             "square": self._unary_op(relax.op.square),
             "tan": self._unary_op(relax.op.tan),
             "tanh": self._unary_op(relax.op.tanh),
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 1b816432ce..6cf293d96b 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -126,6 +126,47 @@ def test_bool_unary_ops(pytorch_op, relax_op):
     verify_model(UnaryOp(), example_args, {}, expected, run_ep_decomposition=True)
 
 
+def test_sqrt_integer_input():
+    """Test that sqrt operation works with integer tensors by auto-converting 
to float."""
+    example_args = (torch.tensor([[4, 9, 16, 25]], dtype=torch.int64),)
+
+    class SqrtIntModel(Module):
+        def forward(self, input):
+            return torch.sqrt(input)
+
+    @tvm.script.ir_module
+    class expected_int64:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 4), dtype="int64")
+        ) -> R.Tuple(R.Tensor((1, 4), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 4), dtype="float32") = R.astype(input_1, dtype="float32")
+                lv1: R.Tensor((1, 4), dtype="float32") = R.sqrt(lv)
+                gv: R.Tuple(R.Tensor((1, 4), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    verify_model(SqrtIntModel(), example_args, {}, expected_int64, run_ep_decomposition=True)
+
+    example_args_int32 = (torch.tensor([[1, 4, 9]], dtype=torch.int32),)
+
+    @tvm.script.ir_module
+    class expected_int32:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3), dtype="int32")
+        ) -> R.Tuple(R.Tensor((1, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3), dtype="float32") = R.astype(input_1, dtype="float32")
+                lv1: R.Tensor((1, 3), dtype="float32") = R.sqrt(lv)
+                gv: R.Tuple(R.Tensor((1, 3), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    verify_model(SqrtIntModel(), example_args_int32, {}, expected_int32, run_ep_decomposition=True)
+
+
 def test_extended_unary_ops():
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
 
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 69ebdcbf76..d377bb7574 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -2749,6 +2749,27 @@ def test_basic_unary_ops(pytorch_op, relax_op):
     verify_model(Unary(), input_info, {}, expected_unary)
 
 
+def test_sqrt_integer_input_fx():
+    input_info = [([1, 4], "int64")]
+
+    class SqrtIntModel(Module):
+        def forward(self, input):
+            return torch.sqrt(input)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(input_1: R.Tensor((1, 4), dtype="int64")) -> R.Tensor((1, 4), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 4), dtype="float32") = R.astype(input_1, dtype="float32")
+                lv1: R.Tensor((1, 4), dtype="float32") = R.sqrt(lv)
+                gv: R.Tensor((1, 4), dtype="float32") = lv1
+                R.output(gv)
+            return gv
+
+    verify_model(SqrtIntModel(), input_info, {}, expected)
+
+
 operator_bool_unary = [
     (torch.isnan, R.isnan),
     (torch.isinf, R.isinf),
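
As a usage sketch (mirroring the new tests; from_exported_program is the public
entry point of the ExportedProgram frontend), integer inputs to torch.sqrt now
import cleanly instead of raising a dtype error:

    import torch
    from tvm.relax.frontend.torch import from_exported_program

    class SqrtIntModel(torch.nn.Module):
        def forward(self, x):
            return torch.sqrt(x)

    example_args = (torch.tensor([[4, 9, 16, 25]], dtype=torch.int64),)
    ep = torch.export.export(SqrtIntModel(), example_args)
    mod = from_exported_program(ep)
    mod.show()  # per the test above, main casts to float32 before R.sqrt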
