This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5ee38eae80 [TIR][CUDA] Preserve float precision in codegen with 
hexfloat output (#18320)
5ee38eae80 is described below

commit 5ee38eae809dc27eae651176fbc245c72b3d3361
Author: Lei Wang <[email protected]>
AuthorDate: Fri Sep 19 21:49:01 2025 +0800

    [TIR][CUDA] Preserve float precision in codegen with hexfloat output 
(#18320)
    
    Previously, `float` constants in codegen were always emitted in 
**scientific decimal format**, e.g.:
    
    ```cpp
    bfloat16_t(3.487723e-05f);
    ```
    
    This could introduce slight **rounding differences** compared to the actual 
binary representation, since the constant is printed and then re-parsed in 
decimal. We now emit the value in **hexadecimal floating-point format** 
(`std::hexfloat`) to preserve the exact binary value, and additionally include 
the decimal form as a comment for readability:
    
    ```cpp
    bfloat16_t(0x1.2492492492492p-15f /*3.487723e-05*/)
    ```
---
 src/target/source/codegen_cuda.cc                     | 11 ++++++++---
 tests/python/codegen/test_target_codegen_cuda.py      | 19 +++++++++++++++++++
 .../test_tir_transform_inject_ptx_async_copy.py       |  4 ++--
 3 files changed, 29 insertions(+), 5 deletions(-)

diff --git a/src/target/source/codegen_cuda.cc 
b/src/target/source/codegen_cuda.cc
index 4454dd3197..defc94efa2 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -1615,13 +1615,17 @@ inline void PrintConst(const FloatImmNode* op, 
std::ostream& os, CodeGenCUDA* p)
   // Type code is kBFloat
   if (op->dtype.is_bfloat16()) {
     os << "__float2bfloat16_rn";
-    os << '(' << std::scientific << op->value << 'f' << ')';
+    os << '(' << std::hexfloat << op->value << 'f';
+    os << "/*" << std::scientific << op->value << "*/";
+    os << ')';
     return;
   }
   // Type code is kFloat8_e5m2 or kE4M4Float
   if (op->dtype.is_float8() || op->dtype.is_float4()) {
     p->PrintType(op->dtype, os);
-    os << '(' << std::scientific << op->value << 'f' << ')';
+    os << '(' << std::hexfloat << op->value << 'f';
+    os << "/*" << std::scientific << op->value << "*/";
+    os << ')';
     return;
   }
   // Type code is kFloat
@@ -1656,7 +1660,8 @@ inline void PrintConst(const FloatImmNode* op, 
std::ostream& os, CodeGenCUDA* p)
         temp << "CUDART_NAN_F";
         p->need_math_constants_h_ = true;
       } else {
-        temp << std::scientific << op->value << 'f';
+        temp << std::hexfloat << op->value << 'f';
+        temp << "/*" << std::scientific << op->value << "*/";
       }
       p->MarkConst(temp.str());
       os << temp.str();
diff --git a/tests/python/codegen/test_target_codegen_cuda.py 
b/tests/python/codegen/test_target_codegen_cuda.py
index db49f56045..0841d0f545 100644
--- a/tests/python/codegen/test_target_codegen_cuda.py
+++ b/tests/python/codegen/test_target_codegen_cuda.py
@@ -801,6 +801,25 @@ def test_cuda_device_func_call():
     assert 'extern "C" __device__ float add(float a, float b) {\n  return (a + 
b);\n}' in cuda_code
 
 
+@tvm.testing.requires_cuda
+def test_cuda_float_const_hex_format():
+    """Test that float constants are emitted in hexadecimal format for 
precision"""
+
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def main(
+            A: T.Buffer((1024, 1024), "float32"),
+        ):
+            for bx in T.thread_binding(1024, "blockIdx.x"):
+                for tx in T.thread_binding(1024, "threadIdx.x"):
+                    A[bx, tx] = T.float32(1 / 27)
+
+    lib = tvm.compile(Module, target="cuda")
+    cuda_code = lib.mod.imports[0].inspect_source()
+    assert "0x1.2f684bda12f68p-5f" in cuda_code
+
+
 @tvm.testing.requires_cuda
 def test_device_host_call_same_func():
     @I.ir_module
diff --git 
a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py 
b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
index 67598b0ba0..aa4f5138a1 100644
--- a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
@@ -264,8 +264,8 @@ extern "C" __global__ void __launch_bounds__(16) 
main_kernel(float* __restrict__
 extern "C" __global__ void __launch_bounds__(16) main_kernel(float* 
__restrict__ A, float* __restrict__ B, float* __restrict__ C) {
   __shared__ float A_shared[64];
   __shared__ float B_shared[64];
-  A_shared[((int)threadIdx.x)] = 0.000000e+00f;
-  B_shared[((int)threadIdx.x)] = 0.000000e+00f;
+  A_shared[((int)threadIdx.x)] = 0x0p+0f/*0.000000e+00*/;
+  B_shared[((int)threadIdx.x)] = 0x0p+0f/*0.000000e+00*/;
 __asm__ __volatile__("cp.async.commit_group;");
 
 

Reply via email to