This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0225d67d30 [Relax][PyTorch] Fix MultiheadAttention compile (#18459)
0225d67d30 is described below
commit 0225d67d303c8b4435bf0e0cae0b2a1a4b7ba021
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Sun Nov 16 14:08:37 2025 +0800
[Relax][PyTorch] Fix MultiheadAttention compile (#18459)
## Related Issues
closes #18440
## Why
- PyTorch `masked_fill` / `full_like` accept inf or nan fill values, but TVM
could not handle these values when the tensor dtype was not a floating-point
type, which caused wrong results or compilation errors (e.g. when importing
`nn.MultiheadAttention`).
## How
- If `fill_value` is inf or nan and the tensor dtype is not float → emit the
fill as float32 (see the sketch below).
- For `masked_fill` → create a float values tensor with `full_like`, then cast
the input to float if needed so both `where` operands have matching dtypes.
- In TOPI → reject creating `full` with inf/nan fill values on non-float dtypes.
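
The sketch below distills the dtype rule shared by the three call sites above.
`pick_fill_dtype` is a hypothetical standalone helper used only for
illustration; the patch inlines the same check in each translator method and
in `topi.full`.

```python
import math


def pick_fill_dtype(x_dtype: str, fill_value):
    """Return "float32" when an inf/nan fill must be promoted, else None."""
    if isinstance(fill_value, (int, float)) and (
        math.isinf(fill_value) or math.isnan(fill_value)
    ):
        if not ("float" in x_dtype or "bfloat16" in x_dtype):
            return "float32"
    return None


assert pick_fill_dtype("bool", float("-inf")) == "float32"  # promoted; input cast to match
assert pick_fill_dtype("float16", float("-inf")) is None    # float dtypes keep inf as-is
assert pick_fill_dtype("int32", 0) is None                  # finite fills are unchanged
```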
---
.../frontend/torch/base_fx_graph_translator.py | 43 +++++++++++++++++++---
python/tvm/topi/tensor.py | 9 +++++
2 files changed, 47 insertions(+), 5 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index b03723cb91..83a045ef54 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -2085,8 +2085,16 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
    def _full_like(self, node: fx.Node) -> relax.Var:
        x = self.env[node.args[0]]
-        fill_value = relax.const(node.args[1])
-        return self.block_builder.emit(relax.op.full_like(x, fill_value))
+        value = node.args[1]
+        fill_value = relax.const(value)
+
+        x_dtype = x.struct_info.dtype
+        fill_dtype = None
+        if isinstance(value, (int, float)) and (math.isinf(value) or math.isnan(value)):
+            if not ("float" in x_dtype or "bfloat16" in x_dtype):
+                fill_dtype = "float32"
+
+        return self.block_builder.emit(relax.op.full_like(x, fill_value, dtype=fill_dtype))

    def _index_select(self, node: fx.Node) -> relax.Var:
        x = self.env[node.args[0]]
@@ -2099,7 +2107,19 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
        mask = self.env[node.args[1]]
        value = node.args[2]
        rx_value = relax.const(value)
-        values = self.block_builder.emit(relax.op.full_like(x, rx_value))
+
+        x_dtype = x.struct_info.dtype
+        fill_dtype = None
+        if isinstance(value, (int, float)) and (math.isinf(value) or math.isnan(value)):
+            if not ("float" in x_dtype or "bfloat16" in x_dtype):
+                fill_dtype = "float32"
+
+        values = self.block_builder.emit(relax.op.full_like(x, rx_value, dtype=fill_dtype))
+
+        # Cast x to match values dtype if necessary
+        if fill_dtype is not None:
+            x = self.block_builder.emit(relax.op.astype(x, fill_dtype))
+
        output = self.block_builder.emit(relax.op.where(mask, values, x))
        self.env[node.args[0]] = output
        return output
@@ -2130,8 +2150,21 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
    def _masked_fill(self, node: fx.Node) -> relax.Var:
        x = self.env[node.args[0]]
        mask = self.env[node.args[1]]
-        rx_value = relax.const(node.args[2])
-        values = self.block_builder.emit(relax.op.full_like(x, rx_value))
+        value = node.args[2]
+        rx_value = relax.const(value)
+
+        x_dtype = x.struct_info.dtype
+        fill_dtype = None
+        if isinstance(value, (int, float)) and (math.isinf(value) or math.isnan(value)):
+            if not ("float" in x_dtype or "bfloat16" in x_dtype):
+                fill_dtype = "float32"
+
+        values = self.block_builder.emit(relax.op.full_like(x, rx_value, dtype=fill_dtype))
+
+        # Cast x to match values dtype if necessary
+        if fill_dtype is not None:
+            x = self.block_builder.emit(relax.op.astype(x, fill_dtype))
+
        return self.block_builder.emit(relax.op.where(mask, values, x))

    def _new_ones(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/topi/tensor.py b/python/tvm/topi/tensor.py
index 449c599dea..9206e876a1 100644
--- a/python/tvm/topi/tensor.py
+++ b/python/tvm/topi/tensor.py
@@ -17,6 +17,8 @@
# pylint: disable=invalid-name,consider-using-enumerate,unused-argument,len-as-condition
"""Elementwise operators"""
+import math as _math
+
from typing import Optional
from tvm import te
@@ -57,6 +59,13 @@ def full(shape, dtype, fill_value):
    y : tvm.te.Tensor
        The result.
    """
+
+    if isinstance(fill_value, (int, float)) and (
+        _math.isinf(fill_value) or _math.isnan(fill_value)
+    ):
+        if not ("float" in dtype or "bfloat16" in dtype):
+            raise ValueError("Infinite and NaN require a floating-point dtype.")
+
    return cpp.full(shape, dtype, fill_value)
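
For reference, a minimal usage sketch of the new TOPI guard; it assumes a
local TVM build that already includes this patch.

```python
from tvm import topi

# inf fill on a float dtype is still accepted
ok = topi.full((2, 2), "float32", float("-inf"))

# inf/nan fill on a non-float dtype is now rejected up front
try:
    topi.full((2, 2), "int32", float("-inf"))
except ValueError as err:
    print(err)  # Infinite and NaN require a floating-point dtype.
```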