This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c19e5f4be0 [CUDA] FP4 cast and reinterpret support (#17708)
c19e5f4be0 is described below

commit c19e5f4be061c21cf04ba432abc2375bf783a441
Author: Ruihang Lai <[email protected]>
AuthorDate: Thu Mar 6 08:44:03 2025 -0500

    [CUDA] FP4 cast and reinterpret support (#17708)
    
    * [CUDA] FP4 cast and reinterpret support
    
    Following up on a previous PR, this PR introduces cast and
    reinterpret support between `__nv_fp4_e2m1` and other dtypes.
    It also ensures that these cast and reinterpret operations
    can be vectorized.
    
    * Rename the FP4 dtype from e2m1_float4 to float4_e2m1fn
---
 include/tvm/runtime/data_type.h                    |  36 +++---
 include/tvm/script/ir_builder/tir/ir.h             |   2 +-
 python/tvm/_ffi/runtime_ctypes.py                  |  23 ++--
 python/tvm/runtime/ndarray.py                      |  14 ++-
 python/tvm/script/ir_builder/tir/ir.py             |  31 ++---
 src/runtime/ndarray.cc                             |   2 +-
 src/script/ir_builder/tir/ir.cc                    |   4 +-
 src/target/llvm/codegen_llvm.cc                    |   2 +-
 src/target/source/codegen_c.cc                     |   8 +-
 src/target/source/codegen_cuda.cc                  |  93 ++++++++++++--
 src/target/source/literal/cuda_half_t.h            |  39 +++++-
 src/tir/op/op.cc                                   |   6 +-
 .../python/codegen/test_target_codegen_cuda_fp4.py | 133 +++++++++++++++++++--
 13 files changed, 326 insertions(+), 67 deletions(-)

diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index 3d35d5241e..76e5e3833f 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -58,7 +58,7 @@ class DataType {
     kBFloat = kDLBfloat,
     kE4M3Float = 6U,
     kE5M2Float = 7U,
-    kE2M1Float = 8U,
+    kFloat4E2M1Fn = 8U,
     kCustomBegin = 129
   };
   /*! \brief default constructor */
@@ -88,7 +88,7 @@ class DataType {
     if (code == kE4M3Float || code == kE5M2Float) {
       ICHECK_EQ(bits, 8);
     }
-    if (code == kE2M1Float) {
+    if (code == kFloat4E2M1Fn) {
       ICHECK_EQ(bits, 4);
     }
   }
@@ -131,12 +131,10 @@ class DataType {
            bits() == 8;
   }
   /*! \return whether type is a float4 type. */
-  bool is_float4() const { return code() == DataType::kE2M1Float && bits() == 
4; }
+  bool is_float4() const { return code() == DataType::kFloat4E2M1Fn && bits() 
== 4; }
   bool is_e4m3_float8() const { return (code() == DataType::kE4M3Float && 
bits() == 8); }
-
   bool is_e5m2_float8() const { return (code() == DataType::kE5M2Float && 
bits() == 8); }
-
-  bool is_e2m1_float4() const { return (code() == DataType::kE2M1Float && 
bits() == 4); }
+  bool is_float4_e2m1fn() const { return (code() == DataType::kFloat4E2M1Fn && 
bits() == 4); }
   /*! \return whether type is a float16 type. */
   bool is_float16() const { return is_float() && bits() == 16; }
   /*! \return whether type is a bfloat16 type. */
@@ -262,11 +260,11 @@ class DataType {
    */
   static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kE5M2Float, 8, 
lanes); }
   /*!
-   * \brief Construct NV float4 e2m1 datatype.
+   * \brief Construct NV float4_e2m1fn datatype.
    * \param lanes The number of lanes
    * \return The constructed data type.
    */
-  static DataType NVFloat4E2M1(int lanes = 1) { return DataType(kE2M1Float, 4, 
lanes); }
+  static DataType NVFloat4E2M1FN(int lanes = 1) { return 
DataType(kFloat4E2M1Fn, 4, lanes); }
   /*!
    * \brief Construct a bool type.
    * \param lanes The number of lanes.
@@ -313,7 +311,7 @@ inline int GetVectorBytes(DataType dtype) {
   int data_bits = dtype.bits() * dtype.lanes();
   // allow bool to exist
   if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype == 
DataType::UInt(4) ||
-      dtype == DataType::Int(1) || dtype == DataType::NVFloat4E2M1()) {
+      dtype == DataType::Int(1) || dtype == DataType::NVFloat4E2M1FN()) {
     return 1;
   }
   ICHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes";
@@ -399,8 +397,8 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode 
type_code) {
       return "e4m3_float";
     case DataType::kE5M2Float:
       return "e5m2_float";
-    case DataType::kE2M1Float:
-      return "e2m1_float";
+    case DataType::kFloat4E2M1Fn:
+      return "float4_e2m1fn";
     default:
       LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
   }
@@ -458,6 +456,18 @@ inline DLDataType String2DLDataType(std::string s) {
   } else if (s.substr(0, 4) == "uint") {
     t.code = kDLUInt;
     scan = s.c_str() + 4;
+  } else if (s.substr(0, 13) == "float4_e2m1fn") {
+    // Avoid being treated as "float"
+    t.code = DataType::kFloat4E2M1Fn;
+    t.bits = 4;
+    scan = s.c_str() + 13;
+    char* endpt = nullptr;
+    if (*scan == 'x') {
+      t.lanes = static_cast<uint16_t>(strtoul(scan + 1, &endpt, 10));
+      scan = endpt;
+    }
+    ICHECK(scan == s.c_str() + s.length()) << "unknown type " << s;
+    return t;
   } else if (s.substr(0, 5) == "float") {
     t.code = kDLFloat;
     scan = s.c_str() + 5;
@@ -482,10 +492,6 @@ inline DLDataType String2DLDataType(std::string s) {
     t.code = DataType::kE5M2Float;
     t.bits = 8;
     scan = s.c_str() + 10;
-  } else if (s.substr(0, 10) == "e2m1_float") {
-    t.code = DataType::kE2M1Float;
-    t.bits = 4;
-    scan = s.c_str() + 10;
   } else if (s.substr(0, 6) == "custom") {
     t.code = ParseCustomDatatype(s, &scan);
   } else {
diff --git a/include/tvm/script/ir_builder/tir/ir.h 
b/include/tvm/script/ir_builder/tir/ir.h
index 5dd1a5c733..e78e0d51fd 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -505,7 +505,7 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, 
DataType::Int);
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E4M3Float8, 
DataType::NVFloat8E4M3);
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E5M2Float8, 
DataType::NVFloat8E5M2);
 
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E2M1Float4, 
DataType::NVFloat4E2M1);
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float4E2M1fn, 
DataType::NVFloat4E2M1FN);
 
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());
diff --git a/python/tvm/_ffi/runtime_ctypes.py 
b/python/tvm/_ffi/runtime_ctypes.py
index 263a4ff69f..3f4ceadd1d 100644
--- a/python/tvm/_ffi/runtime_ctypes.py
+++ b/python/tvm/_ffi/runtime_ctypes.py
@@ -68,7 +68,7 @@ class DataTypeCode(object):
     BFLOAT = 4
     E4M3Float = 6
     E5M2Float = 7
-    E2M1Float = 8
+    FLOAT4E2M1FN = 8
 
 
 class DataType(ctypes.Structure):
@@ -83,7 +83,7 @@ class DataType(ctypes.Structure):
         DataTypeCode.BFLOAT: "bfloat",
         DataTypeCode.E4M3Float: "e4m3_float",
         DataTypeCode.E5M2Float: "e5m2_float",
-        DataTypeCode.E2M1Float: "e2m1_float",
+        DataTypeCode.FLOAT4E2M1FN: "float4_e2m1fn",
     }
     NUMPY2STR = {
         np.dtype(np.bool_): "bool",
@@ -114,7 +114,7 @@ class DataType(ctypes.Structure):
         "uint64": {"type_code": DataTypeCode.UINT, "bits": 64, "lanes": 1},
         "e4m3_float8": {"type_code": DataTypeCode.E4M3Float, "bits": 8, 
"lanes": 1},
         "e5m2_float8": {"type_code": DataTypeCode.E5M2Float, "bits": 8, 
"lanes": 1},
-        "e2m1_float4": {"type_code": DataTypeCode.E2M1Float, "bits": 4, 
"lanes": 1},
+        "float4_e2m1fn": {"type_code": DataTypeCode.FLOAT4E2M1FN, "bits": 4, 
"lanes": 1},
         "float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1},
         "float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1},
         "float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1},
@@ -155,6 +155,11 @@ class DataType(ctypes.Structure):
         elif head.startswith("uint"):
             self.type_code = DataTypeCode.UINT
             head = head[4:]
+        elif head.startswith("float4_e2m1fn"):
+            # Avoid being treated as "float"
+            self.type_code = DataTypeCode.FLOAT4E2M1FN
+            bits = 4
+            head = ""
         elif head.startswith("float"):
             self.type_code = DataTypeCode.FLOAT
             head = head[5:]
@@ -171,9 +176,6 @@ class DataType(ctypes.Structure):
         elif head.startswith("e5m2_float"):
             self.type_code = DataTypeCode.E5M2Float
             head = head[10:]
-        elif head.startswith("e2m1_float"):
-            self.type_code = DataTypeCode.E2M1Float
-            head = head[10:]
         elif head.startswith("custom"):
             # pylint: disable=import-outside-toplevel
             import tvm.runtime._ffi_api
@@ -201,7 +203,12 @@ class DataType(ctypes.Structure):
             import tvm.runtime._ffi_api
 
             type_name = "custom[%s]" % 
tvm.runtime._ffi_api._datatype_get_type_name(self.type_code)
-        x = "%s%d" % (type_name, self.bits)
+        if self.type_code in [
+            DataTypeCode.FLOAT4E2M1FN,
+        ]:
+            x = type_name
+        else:
+            x = "%s%d" % (type_name, self.bits)
         lanes_as_int = ctypes.c_int16(self.lanes).value
         if lanes_as_int > 1:
             x += "x%d" % self.lanes
@@ -238,7 +245,7 @@ if ml_dtypes is not None:
     DataType.NUMPY2STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16"
     DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "e4m3_float8"
     DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] = "e5m2_float8"
-    DataType.NUMPY2STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "e2m1_float4"
+    DataType.NUMPY2STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "float4_e2m1fn"
 
 RPC_SESS_MASK = 128
 
diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py
index 3514ee6168..47fcccf52b 100644
--- a/python/tvm/runtime/ndarray.py
+++ b/python/tvm/runtime/ndarray.py
@@ -197,7 +197,9 @@ class NDArray(NDArrayBase):
             source_array = np.ascontiguousarray(
                 source_array, dtype="uint16" if dtype == "bfloat16" else dtype
             )
-        if dtype.startswith("e2m1_float4"):
+        if self.dtype.startswith("float4_e2m1fn") and self.dtype != 
"float4_e2m1fn":
+            # float4_e2m1fn in numpy is not packed.
+            # So we need to pack the input data when converting to vectorized 
float4_e2m1fn type.
             data_bits = source_array.view(dtype="uint8")
             if data_bits.size % 2:
                 data_bits = np.pad(data_bits, (0, 1), mode="constant", 
constant_values=0)
@@ -261,22 +263,24 @@ class NDArray(NDArrayBase):
                 raise RuntimeError(
                     "ml_dtypes is not installed, cannot convert e5m2_float8 
array to numpy."
                 )
-        if dtype == "e2m1_float4":
+        if dtype == "float4_e2m1fn":
             if ml_dtypes is not None:
                 dtype = ml_dtypes.float4_e2m1fn
             else:
                 raise RuntimeError(
-                    "ml_dtypes is not installed, cannot convert e2m1_float4 
array to numpy."
+                    "ml_dtypes is not installed, cannot convert float4_e2m1fn 
array to numpy."
                 )
         np_arr = np.empty(shape, dtype=dtype)
         assert np_arr.flags["C_CONTIGUOUS"]
         data = np_arr.ctypes.data_as(ctypes.c_void_p)
-        if old_dtype.startswith("e2m1_float4"):
+        if old_dtype.startswith("float4_e2m1fn") and old_dtype != 
"float4_e2m1fn":
             nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize // 2)
         else:
             nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize)
         check_call(_LIB.TVMArrayCopyToBytes(self.handle, data, nbytes))
-        if old_dtype == "int4" or old_dtype.startswith("e2m1_float4"):
+        if old_dtype == "int4" or (
+            old_dtype.startswith("float4_e2m1fn") and old_dtype != 
"float4_e2m1fn"
+        ):
             length = np_arr.size
             np_arr = np_arr.view("int8")
             np_arr_ret = np.empty((length,), dtype="int8")
diff --git a/python/tvm/script/ir_builder/tir/ir.py 
b/python/tvm/script/ir_builder/tir/ir.py
index 6cc19305e4..c35df7a093 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -18,8 +18,8 @@
 
 import functools
 import inspect
-from numbers import Integral
 import sys
+from numbers import Integral
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 # isort: off
@@ -29,8 +29,7 @@ from typing_extensions import Literal
 
 import numpy as np  # type: ignore
 
-from tvm import tir
-from tvm import ir
+from tvm import ir, tir
 from tvm.ir import Type
 from tvm.ir.base import deprecated
 from tvm.runtime import String, convert, ndarray
@@ -1457,12 +1456,13 @@ e5m2_float8x16 = func_gen(("E5M2Float8x16"))
 e5m2_float8x32 = func_gen(("E5M2Float8x32"))
 e5m2_float8x64 = func_gen(("E5M2Float8x64"))
 
-e2m1_float4 = func_gen(("E2M1Float4"))
-e2m1_float4x4 = func_gen(("E2M1Float4x4"))
-e2m1_float4x8 = func_gen(("E2M1Float4x8"))
-e2m1_float4x16 = func_gen(("E2M1Float4x16"))
-e2m1_float4x32 = func_gen(("E2M1Float4x32"))
-e2m1_float4x64 = func_gen(("E2M1Float4x64"))
+float4_e2m1fn = func_gen(("Float4E2M1fn"))
+float4_e2m1fnx2 = func_gen(("Float4E2M1fnx2"))
+float4_e2m1fnx4 = func_gen(("Float4E2M1fnx4"))
+float4_e2m1fnx8 = func_gen(("Float4E2M1fnx8"))
+float4_e2m1fnx16 = func_gen(("Float4E2M1fnx16"))
+float4_e2m1fnx32 = func_gen(("Float4E2M1fnx32"))
+float4_e2m1fnx64 = func_gen(("Float4E2M1fnx64"))
 
 
 # pylint: enable=invalid-name
@@ -2013,37 +2013,38 @@ __all__ = [
     "uint64x64",
     "e4m3_float8",
     "e5m2_float8",
-    "e2m1_float4",
+    "float4_e2m1fn",
     "float16",
     "float32",
     "float64",
+    "float4_e2m1fnx2",
     "e4m3_float8x4",
     "e5m2_float8x4",
-    "e2m1_float4x4",
+    "float4_e2m1fnx4",
     "float16x4",
     "float32x4",
     "float64x4",
     "e4m3_float8x8",
     "e5m2_float8x8",
-    "e2m1_float4x8",
+    "float4_e2m1fnx8",
     "float16x8",
     "float32x8",
     "float64x8",
     "e4m3_float8x16",
     "e5m2_float8x16",
-    "e2m1_float4x16",
+    "float4_e2m1fnx16",
     "float16x16",
     "float32x16",
     "float64x16",
     "e4m3_float8x32",
     "e5m2_float8x32",
-    "e2m1_float4x32",
+    "float4_e2m1fnx32",
     "float16x32",
     "float32x32",
     "float64x32",
     "e4m3_float8x64",
     "e5m2_float8x64",
-    "e2m1_float4x64",
+    "float4_e2m1fnx64",
     "float16x64",
     "float32x64",
     "float64x64",
diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc
index 1812d13b9c..d876065325 100644
--- a/src/runtime/ndarray.cc
+++ b/src/runtime/ndarray.cc
@@ -54,7 +54,7 @@ inline void VerifyDataType(DLDataType dtype) {
       return;
     else if (dtype.bits == 4 && dtype.code == kDLInt)
       return;
-    else if (dtype.bits == 4 && dtype.code == DataType::kE2M1Float)
+    else if (dtype.bits == 4 && dtype.code == DataType::kFloat4E2M1Fn)
       return;
     else
       ICHECK_EQ(dtype.bits % 8, 0);
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index e452e102bf..a73c9cb5b4 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -757,8 +757,8 @@ 
TVM_REGISTER_GLOBAL("script.ir_builder.tir.E5M2Float8").set_body_typed(E5M2Float
 TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E4M3Float8", E4M3Float8);
 TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E5M2Float8", E5M2Float8);
 
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.E2M1Float4").set_body_typed(E2M1Float4);
-TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E2M1Float4", E2M1Float4);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float4E2M1fn").set_body_typed(Float4E2M1fn);
+TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float4E2M1fn", Float4E2M1fn);
 
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle);
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index f5f17c70ef..9c2ce0bbb2 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -581,7 +581,7 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& 
dtype) const {
     }
   } else if (dtype.code() == DataType::kE4M3Float || dtype.code() == 
DataType::kE5M2Float) {
     etype = llvm::Type::getInt8Ty(*ctx);
-  } else if (dtype.code() == DataType::kE2M1Float) {
+  } else if (dtype.code() == DataType::kFloat4E2M1Fn) {
     etype = llvm::Type::getIntNTy(*ctx, 4);
   }
   if (!dtype.is_scalar()) {
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 0e0971b8f8..575f52e225 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -789,6 +789,11 @@ void CodeGenC::VisitExpr_(const BufferLoadNode* op, 
std::ostream& os) {  // NOLI
       }
     }
 
+    if (value_dtype.is_float4_e2m1fn() && lanes != 1) {
+      // A float4_e2m1fn element has 4 bits, which is an incomplete byte.
+      // So we cannot vector load it.
+      can_vector_load = false;
+    }
     if (can_vector_load) {
       std::string ref = GetVecLoad(op->dtype, op->buffer.get(), base.Eval());
       HandleVolatileLoads(ref, op, os);
@@ -839,7 +844,8 @@ void CodeGenC::VisitStmt_(const BufferStoreNode* op) {
   } else {
     arith::PVar<PrimExpr> base;
 
-    if (arith::ramp(base, 1, value_dtype.lanes()).Match(index_expr)) {
+    if (arith::ramp(base, 1, value_dtype.lanes()).Match(index_expr) &&
+        !value_dtype.is_float4_e2m1fn()) {
       std::string value = this->PrintExpr(op->value);
       this->PrintVecStore(op->buffer.get(), value_dtype, base.Eval(), value);
     } else {
diff --git a/src/target/source/codegen_cuda.cc 
b/src/target/source/codegen_cuda.cc
index 872a024366..20b29750dc 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -82,11 +82,11 @@ std::string GetFP4Type(DataType type) {
   } else if (lanes == 4) {
     vec = "x4";
   } else {
-    LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8) for 
FP8";
+    LOG(FATAL) << "Only support scalar and vector types of width (2, 4) for 
FP8";
   }
   stream << "__nv_fp4";
   std::string suffix;
-  if (type.code() == DataType::kE2M1Float) {
+  if (type.code() == DataType::kFloat4E2M1Fn) {
     suffix = "_e2m1";
   } else {
     LOG(FATAL) << "Unsupported FP8 type in CUDA codegen";
@@ -196,7 +196,7 @@ std::string CodeGenCUDA::Finish() {
     decl_stream << "#include <cuda_fp4.h>\n";
     decl_stream << "#endif\n\n";
   }
-  declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_);
+  declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_, 
enable_fp4_);
 
   if (enable_warp_shuffle_) {
     decl_stream << _cuda_warp_intrinsic_util;
@@ -597,6 +597,9 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, 
DataType t, int i,
     }
     ICHECK(!type_name.empty());
     os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << 
")))->" << access[i % 2];
+  } else if (t.is_float4_e2m1fn()) {
+    os << "([](__nv_fp4_storage_t v) { __nv_fp4_e2m1 t; t.__x = v; return t; 
})((" << vec
+       << ".__x >> " << i * 4 << ") & 0xF)";
   } else {
     os << vec << "." << access[i];
   }
@@ -732,8 +735,8 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, 
std::ostream& os) {
   if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os);
 
   if (target_ty.code() == DataType::kE4M3Float || target_ty.code() == 
DataType::kE5M2Float ||
-      target_ty.code() == DataType::kE2M1Float || from_ty.code() == 
DataType::kE4M3Float ||
-      from_ty.code() == DataType::kE5M2Float || from_ty.code() == 
DataType::kE2M1Float) {
+      target_ty.code() == DataType::kFloat4E2M1Fn || from_ty.code() == 
DataType::kE4M3Float ||
+      from_ty.code() == DataType::kE5M2Float || from_ty.code() == 
DataType::kFloat4E2M1Fn) {
     std::ostringstream val;
     val << "(";
     PrintType(target_ty, val);
@@ -1036,8 +1039,8 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, 
std::ostream& os) {
     var_idmap_[inverse_index_map->initial_indices[1].get()] = "local_id";
 
     os << "for (int local_id = 0; local_id < 8; ++local_id) {\n";
-    os << dst << "[" + this->PrintExpr(dst_ind) + "]"
-       << " = " << src << "[" << src_offset << " + local_id];\n";
+    os << dst << "[" + this->PrintExpr(dst_ind) + "] = " << src << "[" << 
src_offset
+       << " + local_id];\n";
     os << "}\n";
 
   } else if (op->op.same_as(builtin::mma_fill())) {
@@ -1155,6 +1158,82 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, 
std::ostream& os) {
     stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr << 
")), \"r\"((int)"
            << guard << ")\n";
     stream << ");\n";
+  } else if (op->op.same_as(builtin::reinterpret())) {
+    DataType tgt_dtype = op->dtype;
+    DataType src_dtype = op->args[0]->dtype;
+    PrimExpr value = op->args[0];
+
+    // Handle float4_e2m1fn reinterpret
+    if (!src_dtype.is_float4_e2m1fn() && !tgt_dtype.is_float4_e2m1fn()) {
+      return CodeGenC::VisitExpr_(op, os);
+    }
+    if (src_dtype == tgt_dtype ||
+        tgt_dtype.lanes() * tgt_dtype.bits() == src_dtype.lanes() * 
src_dtype.bits()) {
+      return CodeGenC::VisitExpr_(op, os);
+    }
+    CHECK_EQ(tgt_dtype.lanes(), src_dtype.lanes())
+        << "E2M1 float4 reinterpret expects source and target to have the same 
number of lanes. "
+        << "Source dtype: " << src_dtype << ", Target dtype: " << tgt_dtype;
+    CHECK_EQ(tgt_dtype.bytes(), src_dtype.bytes())
+        << "E2M1 float4 reinterpret expects source and target to have the same 
number of bytes. "
+        << "Source dtype: " << src_dtype << ", Target dtype: " << tgt_dtype;
+
+    int lanes = tgt_dtype.lanes();
+
+    int ssa_scope = BeginScope();
+    if (lanes == 1) {
+      // The case of lane=1 is same as the normal reinterpret,
+      // except that we allow the src and dst dtype to have different number 
of bits.
+      std::string rhs = SSAGetID(PrintExpr(value), src_dtype);
+      os << "(*(";
+      this->PrintType(tgt_dtype, os);
+      os << " *)(&(" << rhs << ")))";
+    } else if (lanes == 2) {
+      if (tgt_dtype.is_float4_e2m1fn()) {
+        // We view the source as an uint16, and then extract bits of two fp4 
numbers,
+        // and finally reinterpret the result as fp4x2.
+        value = tir::Call(DataType::UInt(16), tir::builtin::reinterpret(), 
{value});
+        tir::Var temp_var("temp_var", DataType::UInt(16));
+        value = tir::Let(
+            temp_var, value,
+            tir::Cast(DataType::UInt(8), (temp_var & 
IntImm(DataType::UInt(16), 0xF)) |
+                                             ((temp_var >> 4) & 
IntImm(DataType::UInt(16), 0xF0))));
+      } else {
+        value = tir::Cast(DataType::UInt(16),
+                          tir::Call(DataType::UInt(8), 
tir::builtin::reinterpret(), {value}));
+        tir::Var temp_var("temp_var", DataType::UInt(16));
+        value = tir::Let(temp_var, value,
+                         (temp_var & IntImm(DataType::UInt(16), 0xF)) |
+                             ((temp_var & IntImm(DataType::UInt(16), 0xF0)) << 
4));
+      }
+      os << PrintExpr(tir::Call(tgt_dtype, tir::builtin::reinterpret(), 
{value}));
+    } else if (lanes == 4) {
+      if (tgt_dtype.is_float4_e2m1fn()) {
+        // We view the source as an uint32, and then extract bits of four fp4 
numbers,
+        // and finally reinterpret the result as fp4x4.
+        value = tir::Call(DataType::UInt(32), tir::builtin::reinterpret(), 
{value});
+        tir::Var temp_var("temp_var", DataType::UInt(32));
+        value = tir::Let(temp_var, value,
+                         tir::Cast(DataType::UInt(16),
+                                   (temp_var & IntImm(DataType::UInt(32), 
0xF)) |
+                                       ((temp_var >> 4) & 
IntImm(DataType::UInt(32), 0xF0)) |
+                                       ((temp_var >> 8) & 
IntImm(DataType::UInt(32), 0xF00)) |
+                                       ((temp_var >> 12) & 
IntImm(DataType::UInt(32), 0xF000))));
+      } else {
+        value = tir::Cast(DataType::UInt(32),
+                          tir::Call(DataType::UInt(16), 
tir::builtin::reinterpret(), {value}));
+        tir::Var temp_var("temp_var", DataType::UInt(32));
+        value = tir::Let(temp_var, value,
+                         (temp_var & IntImm(DataType::UInt(32), 0xF)) |
+                             ((temp_var & IntImm(DataType::UInt(32), 0xF0)) << 
4) |
+                             ((temp_var & IntImm(DataType::UInt(32), 0xF00)) 
<< 8) |
+                             ((temp_var & IntImm(DataType::UInt(32), 0xF000)) 
<< 12));
+      }
+      os << PrintExpr(tir::Call(tgt_dtype, tir::builtin::reinterpret(), 
{value}));
+    } else {
+      LOG(FATAL) << "Invalid number of lanes for float4_e2m1fn reinterpret: " 
<< lanes;
+    }
+    EndScope(ssa_scope);
   } else {
     CodeGenC::VisitExpr_(op, os);
   }
diff --git a/src/target/source/literal/cuda_half_t.h 
b/src/target/source/literal/cuda_half_t.h
index abdf22df26..86f2219fe8 100644
--- a/src/target/source/literal/cuda_half_t.h
+++ b/src/target/source/literal/cuda_half_t.h
@@ -385,8 +385,9 @@ static constexpr const char* _cuda_warp_intrinsic_util = R"(
 
 )";
 
-void declare_vector_type_extensions(std::ostringstream& stream, bool 
enable_fp16, bool enable_fp8) {
-  if (enable_fp16 || enable_fp8) {
+void declare_vector_type_extensions(std::ostringstream& stream, bool 
enable_fp16, bool enable_fp8,
+                                    bool enable_fp4) {
+  if (enable_fp16 || enable_fp8 || enable_fp4) {
     stream << R"(
 struct __align__(8) half4 {
   __half x, y, z, w;
@@ -455,6 +456,26 @@ struct __align__(8) half4 {
       result.__x = (a) | (b << 8) | (c << 16) | (d << 24);
       return result;
   }
+  )";
+    }
+    if (enable_fp4) {
+      stream << R"(
+  __host__ __device__ explicit half4(const __nv_fp4x4_e2m1& fp4x4) {
+    __nv_fp4x2_storage_t lo_part, hi_part;
+    lo_part = static_cast<__nv_fp4x2_storage_t>(fp4x4.__x & 0xFF);
+    hi_part = static_cast<__nv_fp4x2_storage_t>((fp4x4.__x >> 8) & 0xFF);
+    __half2 lo_half2 = __half2(__nv_cvt_fp4x2_to_halfraw2(lo_part, __NV_E2M1));
+    __half2 hi_half2 = __half2(__nv_cvt_fp4x2_to_halfraw2(hi_part, __NV_E2M1));
+    x = reinterpret_cast<__half*>(&lo_half2)[0];
+    y = reinterpret_cast<__half*>(&lo_half2)[1];
+    z = reinterpret_cast<__half*>(&hi_half2)[0];
+    w = reinterpret_cast<__half*>(&hi_half2)[1];
+  }
+  __host__ __device__ explicit operator __nv_fp4x4_e2m1() const {
+    __half2 lo_half2 = *reinterpret_cast<const __half2*>(&x);
+    __half2 hi_half2 = *reinterpret_cast<const __half2*>(&z);
+    return __nv_fp4x4_e2m1(lo_half2, hi_half2);
+  }
   )";
     }
     stream << R"(
@@ -462,6 +483,20 @@ struct __align__(8) half4 {
 __host__ __device__ half4 make_half4(__half x, __half y, __half z, __half w) {
     return half4(x, y, z, w);
 }
+)";
+  }
+  if (enable_fp4) {
+    stream << R"(
+__device__ __nv_fp4x2_e2m1 make___nv_fp4x2_e2m1(__nv_fp4_e2m1 x, __nv_fp4_e2m1 
y) {
+  __nv_fp4x2_e2m1 result;
+  result.__x = (x.__x) | (y.__x << 4);
+  return result;
+}
+__device__ __nv_fp4x4_e2m1 make___nv_fp4x4_e2m1(__nv_fp4_e2m1 a, __nv_fp4_e2m1 
b, __nv_fp4_e2m1 c, __nv_fp4_e2m1 d) {
+  __nv_fp4x4_e2m1 result;
+  result.__x = (static_cast<__nv_fp4x4_storage_t>(a.__x)) | 
(static_cast<__nv_fp4x4_storage_t>(b.__x) << 4) | 
(static_cast<__nv_fp4x4_storage_t>(c.__x) << 8) | 
(static_cast<__nv_fp4x4_storage_t>(d.__x) << 12);
+  return result;
+}
 )";
   }
 }
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 039acf7e92..3dab634f16 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -425,8 +425,10 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span 
span) {
 PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span) {
   if (value.dtype() == t) return value;
   if (!t.is_scalable_vector() && !value.dtype().is_scalable_vector()) {
-    ICHECK(value.dtype().bits() * value.dtype().lanes() == t.bits() * 
t.lanes())
-        << "Bitcast requires size match " << t << " vs " << value.dtype();
+    ICHECK(value.dtype().bits() * value.dtype().lanes() == t.bits() * 
t.lanes() ||
+           ((value.dtype().is_float4_e2m1fn() || t.is_float4_e2m1fn()) &&
+            value.dtype().bytes() * value.dtype().lanes() == t.bytes() * 
t.lanes()))
+        << "Reinterpret requires size match " << t << " vs " << value.dtype();
   }
   return tir::Call(t, tir::builtin::reinterpret(), {value}, span);
 }
diff --git a/tests/python/codegen/test_target_codegen_cuda_fp4.py 
b/tests/python/codegen/test_target_codegen_cuda_fp4.py
index f137e83cc9..46825826a9 100644
--- a/tests/python/codegen/test_target_codegen_cuda_fp4.py
+++ b/tests/python/codegen/test_target_codegen_cuda_fp4.py
@@ -15,12 +15,13 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from itertools import product
 
-import tvm
-from tvm.script import tir as T
 import numpy as np
+
+import tvm
 import tvm.testing
-from tvm.script import ir as I, relax as R, tir as T
+from tvm.script import tir as T
 
 try:
     import ml_dtypes
@@ -28,12 +29,12 @@ except ImportError:
     ml_dtypes = None
 
 native_dtype, promoted_dtype = tvm.testing.parameters(
-    ("e2m1_float4x2", "float32x2"),
-    ("e2m1_float4x2", "float16x2"),
+    ("float4_e2m1fnx2", "float32x2"),
+    ("float4_e2m1fnx2", "float16x2"),
 )
 
 
[email protected]_cuda_compute_version(9)
[email protected]_cuda_compute_version(10)
 def test_e2m1_vector_conversions(native_dtype, promoted_dtype):
     vector_length = 64
 
@@ -63,7 +64,6 @@ def test_e2m1_vector_conversions(native_dtype, 
promoted_dtype):
 
     target = "cuda"
     fadd = tvm.build(sch.mod, target=target)
-    cuda_src = fadd.imported_modules[0].get_source()
     dev = tvm.device(target, 0)
 
     numpytype = "float4_e2m1fn"
@@ -92,5 +92,124 @@ def test_e2m1_vector_conversions(native_dtype, 
promoted_dtype):
     )
 
 
[email protected]_cuda_compute_version(10)
+def test_e2m1_schedule_vectorize():
+    native_dtype = "float4_e2m1fn"
+    n = 128
+
+    dev = tvm.device("cuda", 0)
+    target = tvm.target.Target.from_device(dev)
+    for promoted_dtype, vector_length in product(
+        ["float16", "bfloat16", "float32"],
+        [1, 2, 4],
+    ):
+
+        @T.prim_func
+        def add(
+            A: T.Buffer((n,), native_dtype),
+            B: T.Buffer((n,), native_dtype),
+            C: T.Buffer((n,), native_dtype),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for i in range(n):
+                with T.block("C"):
+                    v_i = T.axis.spatial(n, i)
+                    T.reads(A[v_i], B[v_i])
+                    T.writes(C[v_i])
+                    C[v_i] = T.Cast(
+                        native_dtype,
+                        T.Cast(promoted_dtype, A[v_i]) + 
T.Cast(promoted_dtype, B[v_i]),
+                    )
+
+        sch = tvm.tir.Schedule(add)
+        block = sch.get_block("C")
+        b = sch.get_loops(block)
+        bx, tx, vec = sch.split(b[0], factors=[None, 32, vector_length])
+        sch.bind(bx, "blockIdx.x")
+        sch.bind(tx, "threadIdx.x")
+        sch.vectorize(vec)
+
+        fadd = tvm.build(sch.mod, target=target)
+
+        numpytype = "float4_e2m1fn"
+        promoted_base_dtype = promoted_dtype
+
+        a_np = np.random.uniform(low=-6, high=6, size=(n,)).astype(numpytype)
+        a = tvm.nd.empty(shape=(n,), dtype=native_dtype, device=dev)
+        a.copyfrom(a_np)
+        b_np = np.random.uniform(low=-6, high=6, size=(n,)).astype(numpytype)
+        b = tvm.nd.empty(shape=(n,), dtype=native_dtype, device=dev)
+        b.copyfrom(b_np)
+        c = tvm.nd.empty(shape=(n,), dtype=native_dtype, device=dev)
+        fadd(a, b, c)
+
+        if promoted_base_dtype != "bfloat16":
+            tvm.testing.assert_allclose(
+                c.numpy().astype(promoted_base_dtype), (a_np + 
b_np).astype(promoted_base_dtype)
+            )
+        else:
+            # assert_allclose with bfloat16 throws an error here.
+            # Thus we convert bfloat16 to float32 for comparison.
+            tvm.testing.assert_allclose(
+                c.numpy().astype(promoted_base_dtype).astype("float32"),
+                (a_np + b_np).astype(promoted_base_dtype).astype("float32"),
+            )
+
+
[email protected]_cuda_compute_version(10)
+def test_e2m1_reinterpret():
+    n = 128
+
+    dev = tvm.device("cuda", 0)
+    target = tvm.target.Target.from_device(dev)
+
+    def get_reinterpret_mod(src_dtype, dst_dtype, vector_length):
+        @T.prim_func
+        def reinterpret(
+            A: T.Buffer((n,), src_dtype),
+            B: T.Buffer((n,), dst_dtype),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for i in range(n):
+                with T.block("C"):
+                    v_i = T.axis.spatial(n, i)
+                    T.reads(A[v_i])
+                    T.writes(B[v_i])
+                    B[v_i] = T.reinterpret(dst_dtype, A[v_i])
+
+        sch = tvm.tir.Schedule(reinterpret)
+        block = sch.get_block("C")
+        b = sch.get_loops(block)
+        bx, tx, vec = sch.split(b[0], factors=[None, 32, vector_length])
+        sch.bind(bx, "blockIdx.x")
+        sch.bind(tx, "threadIdx.x")
+        sch.vectorize(vec)
+        return sch.mod
+
+    # Part 1. reinterpret float4_e2m1fn to uint8
+    for vector_length in [1, 2, 4]:
+        mod = get_reinterpret_mod("float4_e2m1fn", "uint8", vector_length)
+        f = tvm.build(mod, target=target)
+        a_np = np.random.uniform(low=-6, high=6, 
size=(n,)).astype("float4_e2m1fn")
+        a = tvm.nd.empty(shape=(n,), dtype="float4_e2m1fn", device=dev)
+        a.copyfrom(a_np)
+        b = tvm.nd.empty(shape=(n,), dtype="uint8", device=dev)
+        f(a, b)
+        tvm.testing.assert_allclose(b.numpy(), a_np.view("uint8"))
+
+    # Part 2. reinterpret uint8 to float4_e2m1fn
+    for vector_length in [1, 2, 4]:
+        mod = get_reinterpret_mod("uint8", "float4_e2m1fn", vector_length)
+        f = tvm.build(mod, target=target)
+        a_np = np.random.uniform(low=-6, high=6, size=(n,)).astype("uint8")
+        a = tvm.nd.empty(shape=(n,), dtype="uint8", device=dev)
+        a.copyfrom(a_np)
+        b = tvm.nd.empty(shape=(n,), dtype="float4_e2m1fn", device=dev)
+        f(a, b)
+        tvm.testing.assert_allclose(
+            b.numpy().astype("float32"), 
a_np.view("float4_e2m1fn").astype("float32")
+        )
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to