This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0225d67d30 [Relax][PyTorch] Fix MultiheadAttention compile (#18459)
0225d67d30 is described below

commit 0225d67d303c8b4435bf0e0cae0b2a1a4b7ba021
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Sun Nov 16 14:08:37 2025 +0800

    [Relax][PyTorch] Fix MultiheadAttention compile (#18459)
    
    ## Related Issues
    
    closes #18440
    
    ## Why
    
    - PyTorch `masked_fill` / `full_like` accept `inf` or `nan` fill values, but TVM
    could not handle these values when the tensor dtype was not floating point,
    which caused wrong behavior or errors (see the repro sketch below).
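
    A rough repro sketch of the failing pattern (the module setup and shapes below
    are illustrative, not taken from the issue):

    ```python
    import torch

    from tvm.relax.frontend.torch import from_exported_program


    class Attn(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.attn = torch.nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)

        def forward(self, x, key_padding_mask):
            out, _ = self.attn(x, x, x, key_padding_mask=key_padding_mask)
            return out


    x = torch.randn(2, 8, 16)
    # A boolean key_padding_mask drives masked_fill / full_like calls with
    # float("-inf"), which the frontend previously could not translate when
    # the tensor dtype was not float.
    mask = torch.zeros(2, 8, dtype=torch.bool)
    ep = torch.export.export(Attn(), (x, mask))
    mod = from_exported_program(ep)  # failed to compile before this change
    ```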
    
    ## How
    
    - If `fill_value` is inf or nan and the tensor dtype is not floating point,
    emit the fill as float32 (the rule is restated below).
    - For `masked_fill`, create the float values tensor with `full_like`.
    - Cast the input to float when the fill dtype was promoted, so `where` sees
    matching dtypes.
    - In TOPI, reject `full` with an inf/nan fill value on non-float dtypes.
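
    For reference, the promotion rule restated in isolation (the helper name is
    illustrative, not part of this patch):

    ```python
    import math


    def promoted_fill_dtype(fill_value, tensor_dtype: str):
        """Return the dtype the fill should be emitted in, or None to keep it as is."""
        if isinstance(fill_value, (int, float)) and (
            math.isinf(fill_value) or math.isnan(fill_value)
        ):
            if not ("float" in tensor_dtype or "bfloat16" in tensor_dtype):
                return "float32"
        return None


    assert promoted_fill_dtype(float("-inf"), "bool") == "float32"
    assert promoted_fill_dtype(float("-inf"), "float16") is None
    assert promoted_fill_dtype(0, "int64") is None
    ```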
---
 .../frontend/torch/base_fx_graph_translator.py     | 43 +++++++++++++++++++---
 python/tvm/topi/tensor.py                          |  9 +++++
 2 files changed, 47 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index b03723cb91..83a045ef54 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -2085,8 +2085,16 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
 
     def _full_like(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
-        fill_value = relax.const(node.args[1])
-        return self.block_builder.emit(relax.op.full_like(x, fill_value))
+        value = node.args[1]
+        fill_value = relax.const(value)
+
+        x_dtype = x.struct_info.dtype
+        fill_dtype = None
+        if isinstance(value, (int, float)) and (math.isinf(value) or math.isnan(value)):
+            if not ("float" in x_dtype or "bfloat16" in x_dtype):
+                fill_dtype = "float32"
+
+        return self.block_builder.emit(relax.op.full_like(x, fill_value, dtype=fill_dtype))
 
     def _index_select(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
@@ -2099,7 +2107,19 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         mask = self.env[node.args[1]]
         value = node.args[2]
         rx_value = relax.const(value)
-        values = self.block_builder.emit(relax.op.full_like(x, rx_value))
+
+        x_dtype = x.struct_info.dtype
+        fill_dtype = None
+        if isinstance(value, (int, float)) and (math.isinf(value) or math.isnan(value)):
+            if not ("float" in x_dtype or "bfloat16" in x_dtype):
+                fill_dtype = "float32"
+
+        values = self.block_builder.emit(relax.op.full_like(x, rx_value, dtype=fill_dtype))
+
+        # Cast x to match values dtype if necessary
+        if fill_dtype is not None:
+            x = self.block_builder.emit(relax.op.astype(x, fill_dtype))
+
         output = self.block_builder.emit(relax.op.where(mask, values, x))
         self.env[node.args[0]] = output
         return output
@@ -2130,8 +2150,21 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
     def _masked_fill(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         mask = self.env[node.args[1]]
-        rx_value = relax.const(node.args[2])
-        values = self.block_builder.emit(relax.op.full_like(x, rx_value))
+        value = node.args[2]
+        rx_value = relax.const(value)
+
+        x_dtype = x.struct_info.dtype
+        fill_dtype = None
+        if isinstance(value, (int, float)) and (math.isinf(value) or math.isnan(value)):
+            if not ("float" in x_dtype or "bfloat16" in x_dtype):
+                fill_dtype = "float32"
+
+        values = self.block_builder.emit(relax.op.full_like(x, rx_value, dtype=fill_dtype))
+
+        # Cast x to match values dtype if necessary
+        if fill_dtype is not None:
+            x = self.block_builder.emit(relax.op.astype(x, fill_dtype))
+
         return self.block_builder.emit(relax.op.where(mask, values, x))
 
     def _new_ones(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/topi/tensor.py b/python/tvm/topi/tensor.py
index 449c599dea..9206e876a1 100644
--- a/python/tvm/topi/tensor.py
+++ b/python/tvm/topi/tensor.py
@@ -17,6 +17,8 @@
 # pylint: disable=invalid-name,consider-using-enumerate,unused-argument,len-as-condition
 """Elementwise operators"""
 
+import math as _math
+
 from typing import Optional
 
 from tvm import te
@@ -57,6 +59,13 @@ def full(shape, dtype, fill_value):
     y : tvm.te.Tensor
         The result.
     """
+
+    if isinstance(fill_value, (int, float)) and (
+        _math.isinf(fill_value) or _math.isnan(fill_value)
+    ):
+        if not ("float" in dtype or "bfloat16" in dtype):
+            raise ValueError("Infinite and NaN require a floating-point dtype.")
+
     return cpp.full(shape, dtype, fill_value)
 
 
