This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5ee38eae80 [TIR][CUDA] Preserve float precision in codegen with hexfloat output (#18320)
5ee38eae80 is described below
commit 5ee38eae809dc27eae651176fbc245c72b3d3361
Author: Lei Wang <[email protected]>
AuthorDate: Fri Sep 19 21:49:01 2025 +0800
[TIR][CUDA] Preserve float precision in codegen with hexfloat output (#18320)
Previously, `float` constants in codegen were always emitted in
**scientific decimal format**, e.g.:
```cpp
bfloat16_t(3.487723e-05f);
```
This could introduce slight **rounding differences** compared to the actual
binary representation, since the constant is printed and then re-parsed in
decimal. We now emit the value in **hexadecimal floating-point format**
(`std::hexfloat`) to preserve the exact binary value, and additionally include
the decimal form as a comment for readability:
```cpp
bfloat16_t(0x1.2492492492492p-15f /*3.487723e-05*/)
```
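For context (not part of the commit), here is a minimal standalone C++ sketch of why the switch matters: with the stream's default six-digit precision, `std::scientific` output does not generally round-trip back to the same `double`, while `std::hexfloat` output does. The `1.0 / 27.0` constant is an arbitrary example value.
```cpp
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>

int main() {
  // Arbitrary constant that is not exactly representable in six decimal digits.
  double value = 1.0 / 27.0;

  // Default-precision scientific output keeps only 6 significant digits,
  // so re-parsing the printed text gives back a slightly different double.
  std::ostringstream dec;
  dec << std::scientific << value;
  double dec_roundtrip = std::stod(dec.str());

  // Hexfloat prints the exact binary mantissa and exponent,
  // so the round-trip is lossless.
  std::ostringstream hex;
  hex << std::hexfloat << value;
  double hex_roundtrip = std::stod(hex.str());

  std::cout << std::boolalpha;
  std::cout << dec.str() << " -> exact: " << (dec_roundtrip == value) << "\n";  // false
  std::cout << hex.str() << " -> exact: " << (hex_roundtrip == value) << "\n";  // true
}
```
Emitting the hexfloat form with the decimal value in a trailing comment, as this patch does, keeps the generated CUDA bit-exact without sacrificing readability.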
---
src/target/source/codegen_cuda.cc | 11 ++++++++---
tests/python/codegen/test_target_codegen_cuda.py | 19 +++++++++++++++++++
.../test_tir_transform_inject_ptx_async_copy.py | 4 ++--
3 files changed, 29 insertions(+), 5 deletions(-)
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 4454dd3197..defc94efa2 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -1615,13 +1615,17 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p)
   // Type code is kBFloat
   if (op->dtype.is_bfloat16()) {
     os << "__float2bfloat16_rn";
-    os << '(' << std::scientific << op->value << 'f' << ')';
+    os << '(' << std::hexfloat << op->value << 'f';
+    os << "/*" << std::scientific << op->value << "*/";
+    os << ')';
     return;
   }
   // Type code is kFloat8_e5m2 or kE4M4Float
   if (op->dtype.is_float8() || op->dtype.is_float4()) {
     p->PrintType(op->dtype, os);
-    os << '(' << std::scientific << op->value << 'f' << ')';
+    os << '(' << std::hexfloat << op->value << 'f';
+    os << "/*" << std::scientific << op->value << "*/";
+    os << ')';
     return;
   }
   // Type code is kFloat
@@ -1656,7 +1660,8 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p)
temp << "CUDART_NAN_F";
p->need_math_constants_h_ = true;
} else {
- temp << std::scientific << op->value << 'f';
+ temp << std::hexfloat << op->value << 'f';
+ temp << "/*" << std::scientific << op->value << "*/";
}
p->MarkConst(temp.str());
os << temp.str();
diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py
index db49f56045..0841d0f545 100644
--- a/tests/python/codegen/test_target_codegen_cuda.py
+++ b/tests/python/codegen/test_target_codegen_cuda.py
@@ -801,6 +801,25 @@ def test_cuda_device_func_call():
     assert 'extern "C" __device__ float add(float a, float b) {\n  return (a + b);\n}' in cuda_code
 
 
+@tvm.testing.requires_cuda
+def test_cuda_float_const_hex_format():
+    """Test that float constants are emitted in hexadecimal format for precision"""
+
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def main(
+            A: T.Buffer((1024, 1024), "float32"),
+        ):
+            for bx in T.thread_binding(1024, "blockIdx.x"):
+                for tx in T.thread_binding(1024, "threadIdx.x"):
+                    A[bx, tx] = T.float32(1 / 27)
+
+    lib = tvm.compile(Module, target="cuda")
+    cuda_code = lib.mod.imports[0].inspect_source()
+    assert "0x1.2f684bda12f68p-5f" in cuda_code
+
+
 @tvm.testing.requires_cuda
 def test_device_host_call_same_func():
     @I.ir_module
diff --git a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
index 67598b0ba0..aa4f5138a1 100644
--- a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
@@ -264,8 +264,8 @@ extern "C" __global__ void __launch_bounds__(16) main_kernel(float* __restrict__
 extern "C" __global__ void __launch_bounds__(16) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C) {
   __shared__ float A_shared[64];
   __shared__ float B_shared[64];
-  A_shared[((int)threadIdx.x)] = 0.000000e+00f;
-  B_shared[((int)threadIdx.x)] = 0.000000e+00f;
+  A_shared[((int)threadIdx.x)] = 0x0p+0f/*0.000000e+00*/;
+  B_shared[((int)threadIdx.x)] = 0x0p+0f/*0.000000e+00*/;
   __asm__ __volatile__("cp.async.commit_group;");