This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new bddc091bff [TIR][Schedule] Fix bug on bfloat16 conversion (#18556)
bddc091bff is described below
commit bddc091bffc31a3cc9dde16c169222774784e0dc
Author: Park Woorak <[email protected]>
AuthorDate: Tue Dec 9 04:25:14 2025 +0900
[TIR][Schedule] Fix bug on bfloat16 conversion (#18556)
## Description
This PR fixes a conversion bug that occurs when performing operations on
`bfloat16` tensors.
In short, when the `BF16ComputeLegalize` compile pass visits a
`BufferStoreNode` whose stored value has a dtype different from the
buffer's, it should use `DTypeConversion()` instead of a simple `cast`
so that the appropriate conversion logic is applied.
## Test
I added a test for this situation based on the existing tests.
With the fix, `B[i] = A[i]` is correctly rewritten to
`B[i] = bf16tof32(A[i])`, so the test passes.
I'm not sure whether the structure or name of the new test is
appropriate, so I'm happy to revise it based on any feedback.
## Process
### Problem observed
This bug surfaced when applying `nn.Linear()` to a `bfloat16` tensor
produced excessively large values.
Although the bug appears to affect other operations as well, it is
particularly noticeable when the inner dimension of `MatMul` is a
multiple of `8` (`16` for CUDA and ROCm).
#### Example of problematic code
```python
# Reproducer: a bfloat16 matmul whose inner dimension is a multiple of 8
# (16 on CUDA/ROCm) returns wrong values before the BF16ComputeLegalize fix.
from ml_dtypes import bfloat16
import numpy as np
import tvm
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import Tensor, op
from tvm.target import Target

dtype = "bfloat16"
n = 10
INNER_DIM = 8 * n  # if INNER_DIM is a multiple of 8


class TestModule(nn.Module):
    """Single bfloat16 matmul: (32, INNER_DIM) @ (INNER_DIM, 100)."""

    def __init__(self):
        self.weight = nn.Parameter((32, INNER_DIM), dtype=dtype)

    def run(self, x: Tensor):
        t = op.matmul(self.weight, x, out_dtype=dtype)
        return t

    def get_default_spec(self):
        mod_spec = {
            "run": {
                "x": nn.spec.Tensor([INNER_DIM, 100], dtype),
                "$": {
                    "param_mode": "packed",
                    "effect_mode": "none",
                },
            },
        }
        return nn.spec.ModuleSpec.from_raw(mod_spec, self)


def compile_module(model, target):
    """Elided: export `model` and build it for `target`; returns (executable, params)."""
    ...


def create_vm(ex, device):
    """Elided: create a VM for executable `ex` on `device`."""
    ...


def main():
    target = "metal"  # or "cuda", "vulkan", ...
    model = TestModule()
    ex, _ = compile_module(model, target)
    device = tvm.device(target, 0)
    vm = create_vm(ex, device=device)
    frun = vm["run"]

    params = []
    param = tvm.runtime.empty(
        (32, INNER_DIM),
        dtype="bfloat16",
        device=device,
    )
    # All-ones operands: every element of the result should equal INNER_DIM.
    param.copyfrom(np.ones((32, INNER_DIM), dtype=bfloat16))
    params.append(param)
    inputs = np.ones((INNER_DIM, 100), dtype=bfloat16)
    arr = frun(inputs, params)
    print(f"{arr=}")  # arr has weird values!


if __name__ == "__main__":
    main()
```
When the inner dimension is not a multiple of `8` (or `16`), the issue
does not appear, because `PadEinsum` wraps the load in `T.if_then_else()`.
`PadEinsum` itself is not the cause; rather, it helped identify the
issue.
### Problem Identified
I observed that the problem was avoided whenever the expression was
wrapped with `T.if_then_else()` or `T.Cast()` before the
`BF16ComputeLegalize` compile pass was applied.
#### Statement with problem
```python
weight_reindex_shared[v0, v1, v2] = weight[v1, v2]
```
#### Statements without problem
```python
# 1) wrapped with T.if_then_else()
weight_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v2 < 511, weight[v1,
v2], T.bfloat16(0.0))
# 2) wrapped with T.Cast()
weight_reindex_pad_shared[v0, v1, v2] = T.Cast("float32", weight[v1, v2])
# ...
```
In the `BF16ComputeLegalize` compile pass, when an `Expr` (here,
`weight[...]`) is processed through `PromoteToTarget()` (and ultimately
`DTypeConversion()`), it is rewritten into the form shown below (TO-BE),
which applies the bit-level conversion logic. The problematic statement,
by contrast, simply applies `T.Cast()` (AS-IS).
#### AS-IS
```python
T.Cast("float32", weight[...])
```
#### TO-BE
```python
T.reinterpret("float32", T.shift_left(T.Cast("uint32",
T.reinterpret("uint16", weight[...])), T.uint32(16)))
```
### Fixing the problem
The problem is caused by line 332 of
`src/tir/transforms/unsupported_dtype_legalize.cc` (linked below).
Changing that line to apply `DTypeConversion()` instead of `cast()`
resolves the issue.
(When the `Expr` is wrapped in `T.if_then_else()` or a similar
expression, it is processed correctly by other visit functions via
line 312 or 313, which is why the problem was avoided in those cases.)
#### L332
```diff
- value = cast(new_buf->dtype.with_lanes(value.dtype().lanes()), value);
+ value = DTypeConversion(value,
new_buf->dtype.with_lanes(value.dtype().lanes()));
```
https://github.com/apache/tvm/blob/26b107fa12672c3b958da222fc87755a69d64c42/src/tir/transforms/unsupported_dtype_legalize.cc#L311-L338
---
src/tir/transforms/unsupported_dtype_legalize.cc | 2 +-
.../test_tir_transform_bf16_legalize.py | 63 ++++++++++++++++++++++
2 files changed, 64 insertions(+), 1 deletion(-)
diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc
b/src/tir/transforms/unsupported_dtype_legalize.cc
index d35caa4db9..74a69dfbc3 100644
--- a/src/tir/transforms/unsupported_dtype_legalize.cc
+++ b/src/tir/transforms/unsupported_dtype_legalize.cc
@@ -329,7 +329,7 @@ class ComputeLegalizer : public StmtExprMutator {
// this happens when buffer get rewritten to f32
// but values remain as fp8/bf16
ICHECK(MatchDType(value->dtype));
- value = cast(new_buf->dtype.with_lanes(value.dtype().lanes()), value);
+ value = DTypeConversion(value,
new_buf->dtype.with_lanes(value.dtype().lanes()));
}
ICHECK(!op->predicate.defined()) << "Predicated buffer store is not
currently supported in "
"data type legalizer pass.";
diff --git a/tests/python/tir-transform/test_tir_transform_bf16_legalize.py
b/tests/python/tir-transform/test_tir_transform_bf16_legalize.py
index fa1aa558b6..37e3d34f8c 100644
--- a/tests/python/tir-transform/test_tir_transform_bf16_legalize.py
+++ b/tests/python/tir-transform/test_tir_transform_bf16_legalize.py
@@ -44,6 +44,69 @@ def f32tobf16(v):
return T.reinterpret("bfloat16", f32tou16(v))
+def test_bf16_simple_store_will_legalize():
+ def get_before():
+ @tvm.script.ir_module
+ class Before:
+ @T.prim_func
+ def main(
+ Aptr: T.handle("bfloat16", storage_scope="shared"),
+ Cptr: T.handle("bfloat16"),
+ ):
+ T.func_attr({"global_symbol": "main"})
+ A = T.decl_buffer((100,), "bfloat16", data=Aptr)
+ B = T.decl_buffer((100,), "bfloat16")
+ C = T.decl_buffer((100,), "bfloat16", data=Cptr)
+ for i in T.grid(100):
+ B[i] = A[i]
+ C[i] = T.exp(B[i])
+
+ return Before
+
+ def after_compute_legalize():
+ @tvm.script.ir_module
+ class After:
+ @T.prim_func
+ def main(
+ Aptr: T.handle("bfloat16", storage_scope="shared"),
+ Cptr: T.handle("bfloat16"),
+ ):
+ T.func_attr({"global_symbol": "main"})
+ A = T.decl_buffer((100,), "bfloat16", data=Aptr)
+ B = T.decl_buffer((100,), "float32")
+ C = T.decl_buffer((100,), "bfloat16", data=Cptr)
+ for i in T.grid(100):
+ B[i] = bf16tof32(A[i])
+ C[i] = f32tobf16(T.exp(B[i]))
+
+ return After
+
+ def after_storage_legalize():
+ @tvm.script.ir_module
+ class After:
+ @T.prim_func
+ def main(
+ Aptr: T.handle("uint16", storage_scope="shared"),
+ Cptr: T.handle("uint16"),
+ ):
+ T.func_attr({"global_symbol": "main"})
+ A = T.decl_buffer((100,), "uint16", data=Aptr)
+ B = T.decl_buffer((100,), "float32")
+ C = T.decl_buffer((100,), "uint16", data=Cptr)
+ for i in T.grid(100):
+ B[i] = u16tof32(A[i])
+ C[i] = f32tou16(T.exp(B[i]))
+
+ return After
+
+ target = Target("nvidia/geforce-rtx-2080-ti")
+ before = BindTarget(target)(get_before())
+ after_compute = tvm.tir.transform.BF16ComputeLegalize()(before)
+ after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute)
+ tvm.ir.assert_structural_equal(after_compute,
BindTarget(target)(after_compute_legalize()))
+ tvm.ir.assert_structural_equal(after_storage,
BindTarget(target)(after_storage_legalize()))
+
+
def test_bf16_storage_compute_scope_will_legalize():
def get_before():
@tvm.script.ir_module