This is an automated email from the ASF dual-hosted git repository.

yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7bedfeb209 [Codegen] FP4 support (#17630)
7bedfeb209 is described below

commit 7bedfeb209d89b693a51bde70344f0c7130d5abf
Author: Hongyi Jin <[email protected]>
AuthorDate: Thu Feb 27 22:42:54 2025 -0500

    [Codegen] FP4 support (#17630)
    
    * fp4
    
    * fix test
    
    * fix lint
    
    * fix
    
    * Test with manually built images
    
    ---------
    
    Co-authored-by: Yong Wu <[email protected]>
---
 ci/jenkins/docker-images.ini                       |  4 +-
 ci/jenkins/unity_jenkinsfile.groovy                |  4 +-
 include/tvm/runtime/data_type.h                    | 22 ++++-
 include/tvm/script/ir_builder/tir/ir.h             |  2 +
 include/tvm/tir/op.h                               |  2 +-
 python/tvm/_ffi/runtime_ctypes.py                  |  7 ++
 python/tvm/contrib/nvcc.py                         | 16 ++++
 python/tvm/runtime/ndarray.py                      | 25 +++++-
 python/tvm/script/ir_builder/tir/ir.py             | 14 ++++
 src/ir/expr.cc                                     |  7 +-
 src/runtime/ndarray.cc                             |  3 +
 src/script/ir_builder/tir/ir.cc                    |  3 +
 src/support/scalars.h                              |  3 +
 src/target/llvm/codegen_llvm.cc                    |  2 +
 src/target/source/codegen_c.cc                     |  2 +-
 src/target/source/codegen_cuda.cc                  | 48 ++++++++++-
 src/target/source/codegen_cuda.h                   |  6 +-
 src/tir/op/op.cc                                   | 10 +++
 src/tir/transforms/dtype_conversion.cc             |  2 +-
 src/tir/transforms/dtype_conversion.h              |  8 +-
 .../python/codegen/test_target_codegen_cuda_fp4.py | 96 ++++++++++++++++++++++
 21 files changed, 267 insertions(+), 19 deletions(-)

diff --git a/ci/jenkins/docker-images.ini b/ci/jenkins/docker-images.ini
index 6d3f78190f..0626ff2b5f 100644
--- a/ci/jenkins/docker-images.ini
+++ b/ci/jenkins/docker-images.ini
@@ -18,8 +18,8 @@
 # This data file is read during when Jenkins runs job to determine docker images.
 [jenkins]
 ci_arm: tlcpack/ci-arm:20250226-223225-63bc315f
-ci_cpu: tlcpack/ci_cpu:20250226-223225-63bc315f
-ci_gpu: tlcpack/ci-gpu:20250226-223225-63bc315f
+ci_cpu: tlcpack/ci_cpu:20250226-223225-63bc315f_patch
+ci_gpu: tlcpack/ci-gpu:20250226-223225-63bc315f_patch
 ci_hexagon: tlcpack/ci-hexagon:20250226-223225-63bc315f
 ci_i386: tlcpack/ci-i386:20250226-223225-63bc315f
 ci_lint: tlcpack/ci-lint:20250226-223225-63bc315f
diff --git a/ci/jenkins/unity_jenkinsfile.groovy b/ci/jenkins/unity_jenkinsfile.groovy
index 4cfe96937f..8dee0c7f8c 100755
--- a/ci/jenkins/unity_jenkinsfile.groovy
+++ b/ci/jenkins/unity_jenkinsfile.groovy
@@ -30,8 +30,8 @@
 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils
 
 // NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. -->
-ci_gpu = 'tlcpack/ci-gpu:20250226-223225-63bc315f'
-ci_cpu = 'tlcpack/ci-cpu:20250226-223225-63bc315f'
+ci_gpu = 'tlcpack/ci-gpu:20250226-223225-63bc315f_patch'
+ci_cpu = 'tlcpack/ci-cpu:20250226-223225-63bc315f_patch'
 // <--- End of regex-scanned config.
 
 // Parameters to allow overriding (in Jenkins UI), the images
diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index c49fde1746..3d35d5241e 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -58,6 +58,7 @@ class DataType {
     kBFloat = kDLBfloat,
     kE4M3Float = 6U,
     kE5M2Float = 7U,
+    kE2M1Float = 8U,
     kCustomBegin = 129
   };
   /*! \brief default constructor */
@@ -87,6 +88,9 @@ class DataType {
     if (code == kE4M3Float || code == kE5M2Float) {
       ICHECK_EQ(bits, 8);
     }
+    if (code == kE2M1Float) {
+      ICHECK_EQ(bits, 4);
+    }
   }
   /*! \return The type code. */
   int code() const { return static_cast<int>(data_.code); }
@@ -126,9 +130,13 @@ class DataType {
             code() == DataType::kE5M2Float) &&
            bits() == 8;
   }
+  /*! \return whether type is a float4 type. */
+  bool is_float4() const { return code() == DataType::kE2M1Float && bits() == 4; }
   bool is_e4m3_float8() const { return (code() == DataType::kE4M3Float && bits() == 8); }
 
   bool is_e5m2_float8() const { return (code() == DataType::kE5M2Float && bits() == 8); }
+
+  bool is_e2m1_float4() const { return (code() == DataType::kE2M1Float && bits() == 4); }
   /*! \return whether type is a float16 type. */
   bool is_float16() const { return is_float() && bits() == 16; }
   /*! \return whether type is a bfloat16 type. */
@@ -253,6 +261,12 @@ class DataType {
    * \return The constructed data type.
    */
   static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kE5M2Float, 8, lanes); }
+  /*!
+   * \brief Construct NV float4 e2m1 datatype.
+   * \param lanes The number of lanes
+   * \return The constructed data type.
+   */
+  static DataType NVFloat4E2M1(int lanes = 1) { return DataType(kE2M1Float, 4, lanes); }
   /*!
    * \brief Construct a bool type.
    * \param lanes The number of lanes.
@@ -299,7 +313,7 @@ inline int GetVectorBytes(DataType dtype) {
   int data_bits = dtype.bits() * dtype.lanes();
   // allow bool to exist
   if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype == DataType::UInt(4) ||
-      dtype == DataType::Int(1)) {
+      dtype == DataType::Int(1) || dtype == DataType::NVFloat4E2M1()) {
     return 1;
   }
   ICHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes";
@@ -385,6 +399,8 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) {
       return "e4m3_float";
     case DataType::kE5M2Float:
       return "e5m2_float";
+    case DataType::kE2M1Float:
+      return "e2m1_float";
     default:
       LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
   }
@@ -466,6 +482,10 @@ inline DLDataType String2DLDataType(std::string s) {
     t.code = DataType::kE5M2Float;
     t.bits = 8;
     scan = s.c_str() + 10;
+  } else if (s.substr(0, 10) == "e2m1_float") {
+    t.code = DataType::kE2M1Float;
+    t.bits = 4;
+    scan = s.c_str() + 10;
   } else if (s.substr(0, 6) == "custom") {
     t.code = ParseCustomDatatype(s, &scan);
   } else {
diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index 380c2fcce2..5dd1a5c733 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -505,6 +505,8 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int);
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E4M3Float8, DataType::NVFloat8E4M3);
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E5M2Float8, DataType::NVFloat8E5M2);
 
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E2M1Float4, DataType::NVFloat4E2M1);
+
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());
 
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index d06bb779d0..e98eb46be9 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -940,7 +940,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span())
       return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high), span);
     }
   }
-  if (t.is_float() || t.is_bfloat16() || t.is_float8())
+  if (t.is_float() || t.is_bfloat16() || t.is_float8() || t.is_float4())
     return FloatImm(t, static_cast<double>(value), span);
   // For now, we store const scalar values of custom datatypes within doubles; later, during the
   // datatypes lowering pass, we will lower the value to its true representation in the format
diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py
index f79df1644e..263a4ff69f 100644
--- a/python/tvm/_ffi/runtime_ctypes.py
+++ b/python/tvm/_ffi/runtime_ctypes.py
@@ -68,6 +68,7 @@ class DataTypeCode(object):
     BFLOAT = 4
     E4M3Float = 6
     E5M2Float = 7
+    E2M1Float = 8
 
 
 class DataType(ctypes.Structure):
@@ -82,6 +83,7 @@ class DataType(ctypes.Structure):
         DataTypeCode.BFLOAT: "bfloat",
         DataTypeCode.E4M3Float: "e4m3_float",
         DataTypeCode.E5M2Float: "e5m2_float",
+        DataTypeCode.E2M1Float: "e2m1_float",
     }
     NUMPY2STR = {
         np.dtype(np.bool_): "bool",
@@ -112,6 +114,7 @@ class DataType(ctypes.Structure):
         "uint64": {"type_code": DataTypeCode.UINT, "bits": 64, "lanes": 1},
         "e4m3_float8": {"type_code": DataTypeCode.E4M3Float, "bits": 8, "lanes": 1},
         "e5m2_float8": {"type_code": DataTypeCode.E5M2Float, "bits": 8, "lanes": 1},
+        "e2m1_float4": {"type_code": DataTypeCode.E2M1Float, "bits": 4, "lanes": 1},
         "float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1},
         "float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1},
         "float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1},
@@ -168,6 +171,9 @@ class DataType(ctypes.Structure):
         elif head.startswith("e5m2_float"):
             self.type_code = DataTypeCode.E5M2Float
             head = head[10:]
+        elif head.startswith("e2m1_float"):
+            self.type_code = DataTypeCode.E2M1Float
+            head = head[10:]
         elif head.startswith("custom"):
             # pylint: disable=import-outside-toplevel
             import tvm.runtime._ffi_api
@@ -232,6 +238,7 @@ if ml_dtypes is not None:
     DataType.NUMPY2STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16"
     DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "e4m3_float8"
     DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] = "e5m2_float8"
+    DataType.NUMPY2STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "e2m1_float4"
 
 RPC_SESS_MASK = 128
 
diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py
index be35bf6319..d12ddf883c 100644
--- a/python/tvm/contrib/nvcc.py
+++ b/python/tvm/contrib/nvcc.py
@@ -445,3 +445,19 @@ def have_fp8(compute_version):
     if major >= 9:
         return True
     return False
+
+
+@tvm._ffi.register_func("tvm.contrib.nvcc.supports_fp4")
+def have_fp4(compute_version):
+    """Whether fp4 support is provided in the specified compute capability or not
+
+    Parameters
+    ----------
+    compute_version : str
+        GPU capability
+    """
+    major, minor = parse_compute_version(compute_version)
+    # fp4 is suppored in Blackwell (10.0) or later architectures.
+    if major == 10 and minor == 0:
+        return True
+    return False
diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py
index 082a28c7e2..3514ee6168 100644
--- a/python/tvm/runtime/ndarray.py
+++ b/python/tvm/runtime/ndarray.py
@@ -197,6 +197,13 @@ class NDArray(NDArrayBase):
             source_array = np.ascontiguousarray(
                 source_array, dtype="uint16" if dtype == "bfloat16" else dtype
             )
+        if dtype.startswith("e2m1_float4"):
+            data_bits = source_array.view(dtype="uint8")
+            if data_bits.size % 2:
+                data_bits = np.pad(data_bits, (0, 1), mode="constant", constant_values=0)
+            data_bits = data_bits.reshape(-1, 2)
+            packed = ((data_bits[:, 0] & 0x0F) << 4) | (data_bits[:, 1] & 0x0F)
+            source_array = packed.astype(np.int8)
         assert source_array.flags["C_CONTIGUOUS"]
         data = source_array.ctypes.data_as(ctypes.c_void_p)
         nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize)
@@ -254,20 +261,32 @@ class NDArray(NDArrayBase):
                 raise RuntimeError(
                     "ml_dtypes is not installed, cannot convert e5m2_float8 array to numpy."
                 )
+        if dtype == "e2m1_float4":
+            if ml_dtypes is not None:
+                dtype = ml_dtypes.float4_e2m1fn
+            else:
+                raise RuntimeError(
+                    "ml_dtypes is not installed, cannot convert e2m1_float4 array to numpy."
+                )
         np_arr = np.empty(shape, dtype=dtype)
         assert np_arr.flags["C_CONTIGUOUS"]
         data = np_arr.ctypes.data_as(ctypes.c_void_p)
-        nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize)
+        if old_dtype.startswith("e2m1_float4"):
+            nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize // 2)
+        else:
+            nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize)
         check_call(_LIB.TVMArrayCopyToBytes(self.handle, data, nbytes))
-        if old_dtype == "int4":
+        if old_dtype == "int4" or old_dtype.startswith("e2m1_float4"):
             length = np_arr.size
+            np_arr = np_arr.view("int8")
             np_arr_ret = np.empty((length,), dtype="int8")
             np_arr = np_arr.reshape((length,))
             old_index = np.bitwise_and(np_arr, 0x0F)
             even_index = np.bitwise_and(np_arr >> 4, 0x0F)
             np_arr_ret[1::2] = old_index[0 : length // 2]
             np_arr_ret[0::2] = even_index[0 : length // 2]
-            return np_arr_ret.reshape(shape)
+            return np_arr_ret.reshape(shape).view(dtype)
+
         return np_arr
 
     def copyto(self, target, mem_scope=None):
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index da0e2954e8..6cc19305e4 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1457,6 +1457,14 @@ e5m2_float8x16 = func_gen(("E5M2Float8x16"))
 e5m2_float8x32 = func_gen(("E5M2Float8x32"))
 e5m2_float8x64 = func_gen(("E5M2Float8x64"))
 
+e2m1_float4 = func_gen(("E2M1Float4"))
+e2m1_float4x4 = func_gen(("E2M1Float4x4"))
+e2m1_float4x8 = func_gen(("E2M1Float4x8"))
+e2m1_float4x16 = func_gen(("E2M1Float4x16"))
+e2m1_float4x32 = func_gen(("E2M1Float4x32"))
+e2m1_float4x64 = func_gen(("E2M1Float4x64"))
+
+
 # pylint: enable=invalid-name
 
 
@@ -2005,31 +2013,37 @@ __all__ = [
     "uint64x64",
     "e4m3_float8",
     "e5m2_float8",
+    "e2m1_float4",
     "float16",
     "float32",
     "float64",
     "e4m3_float8x4",
     "e5m2_float8x4",
+    "e2m1_float4x4",
     "float16x4",
     "float32x4",
     "float64x4",
     "e4m3_float8x8",
     "e5m2_float8x8",
+    "e2m1_float4x8",
     "float16x8",
     "float32x8",
     "float64x8",
     "e4m3_float8x16",
     "e5m2_float8x16",
+    "e2m1_float4x16",
     "float16x16",
     "float32x16",
     "float64x16",
     "e4m3_float8x32",
     "e5m2_float8x32",
+    "e2m1_float4x32",
     "float16x32",
     "float32x32",
     "float64x32",
     "e4m3_float8x64",
     "e5m2_float8x64",
+    "e2m1_float4x64",
     "float16x64",
     "float32x64",
     "float64x64",
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index ded046eafc..766abf3483 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -110,7 +110,7 @@ TVM_REGISTER_NODE_TYPE(IntImmNode);
 FloatImm::FloatImm(DataType dtype, double value, Span span) {
   ICHECK_EQ(dtype.lanes(), 1) << "ValueError: FloatImm can only take scalar.";
 
-  ICHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() ||
+  ICHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || dtype.is_float4() ||
          dtype.code() >= DataType::kCustomBegin)
       << "ValueError: FloatImm supports only float, but " << dtype << " was supplied.";
 
@@ -137,6 +137,11 @@ FloatImm::FloatImm(DataType dtype, double value, Span span) {
                                << dtype;
       ICHECK_LE(value, bound) << "ValueError: Literal vaule " << value << " exceeds maximum of "
                               << dtype;
+    } else if (dtype.is_float4()) {
+      ICHECK_GE(value, -support::kMaxE2M1)
+          << "ValueError: Literal value " << value << " exceeds minimum of " << dtype;
+      ICHECK_LE(value, support::kMaxE2M1)
+          << "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
     }
   }
   ObjectPtr<FloatImmNode> node = make_object<FloatImmNode>();
diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc
index c2cf5f388a..1812d13b9c 100644
--- a/src/runtime/ndarray.cc
+++ b/src/runtime/ndarray.cc
@@ -28,6 +28,7 @@
 #include <tvm/runtime/registry.h>
 
 #include "runtime_base.h"
+#include "tvm/runtime/data_type.h"
 
 extern "C" {
 // C-mangled dlpack deleter.
@@ -53,6 +54,8 @@ inline void VerifyDataType(DLDataType dtype) {
       return;
     else if (dtype.bits == 4 && dtype.code == kDLInt)
       return;
+    else if (dtype.bits == 4 && dtype.code == DataType::kE2M1Float)
+      return;
     else
       ICHECK_EQ(dtype.bits % 8, 0);
   }
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 17353561ee..e452e102bf 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -757,6 +757,9 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.E5M2Float8").set_body_typed(E5M2Float
 TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E4M3Float8", E4M3Float8);
 TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E5M2Float8", E5M2Float8);
 
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.E2M1Float4").set_body_typed(E2M1Float4);
+TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E2M1Float4", E2M1Float4);
+
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Void").set_body_typed(Void);
diff --git a/src/support/scalars.h b/src/support/scalars.h
index 05763f8044..b229a6b338 100644
--- a/src/support/scalars.h
+++ b/src/support/scalars.h
@@ -69,6 +69,9 @@ constexpr double kMaxE4M3 = 448;
 // See https://arxiv.org/pdf/2209.05433.pdf
 constexpr double kMaxE5M2 = 57344;
 
+// 2^2 * (1 + 1/2)
+constexpr double kMaxE2M1 = 6.0;
+
 }  // namespace support
 }  // namespace tvm
 
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 6c051fc939..f5f17c70ef 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -581,6 +581,8 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
     }
   } else if (dtype.code() == DataType::kE4M3Float || dtype.code() == DataType::kE5M2Float) {
     etype = llvm::Type::getInt8Ty(*ctx);
+  } else if (dtype.code() == DataType::kE2M1Float) {
+    etype = llvm::Type::getIntNTy(*ctx, 4);
   }
   if (!dtype.is_scalar()) {
 #if TVM_LLVM_VERSION >= 110
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 9f68cd8d66..0e0971b8f8 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -240,7 +240,7 @@ std::string CodeGenC::GetBufferRef(DataType t, const BufferNode* buffer, PrimExp
   }
 
   std::string index_str = PrintExpr(index);
-  if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) {
+  if ((t.bits() == 4 && !t.is_float4()) || (t.bits() == 1 && t.is_int())) {
     // This is a special case, because CodegenCUDA::PrintType()
     // returns "int" for bool and for 4-bit integers. In most cases,
     // we divide by the number of lanes to determine the index.
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 0400518251..872a024366 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -71,6 +71,30 @@ std::string GetFP8Type(DataType type) {
   return stream.str();
 }
 
+std::string GetFP4Type(DataType type) {
+  std::stringstream stream;
+  int32_t lanes = type.lanes();
+  std::string vec;
+  if (type.is_scalar()) {
+    vec = "";
+  } else if (lanes == 2) {
+    vec = "x2";
+  } else if (lanes == 4) {
+    vec = "x4";
+  } else {
+    LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8) for FP8";
+  }
+  stream << "__nv_fp4";
+  std::string suffix;
+  if (type.code() == DataType::kE2M1Float) {
+    suffix = "_e2m1";
+  } else {
+    LOG(FATAL) << "Unsupported FP8 type in CUDA codegen";
+  }
+  stream << vec << suffix;
+  return stream.str();
+}
+
 CodeGenCUDA::CodeGenCUDA() { restrict_keyword_ = "__restrict__"; }
 
 void CodeGenCUDA::Init(bool output_ssa) {
@@ -133,7 +157,11 @@ std::string CodeGenCUDA::Finish() {
     decl_stream << "#else\n";
     decl_stream << _cuda_half_t_def;
     decl_stream << "#endif\n\n";
+
+    decl_stream << "#include <cuda.h>\n";
+    decl_stream << "#if (CUDA_VERSION <12080)\n";
     decl_stream << _cuda_half_util;
+    decl_stream << "#endif\n";
   }
 
   if (enable_bf16_) {
@@ -163,6 +191,11 @@ std::string CodeGenCUDA::Finish() {
     decl_stream << "struct fp8_e5x16_t {\n fp8_e5_t data[16]; \n};\n";
     decl_stream << "#endif\n\n";
   }
+  if (enable_fp4_) {
+    decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)\n";
+    decl_stream << "#include <cuda_fp4.h>\n";
+    decl_stream << "#endif\n\n";
+  }
   declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_);
 
   if (enable_warp_shuffle_) {
@@ -314,6 +347,14 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
       os << "uint" << t.lanes() / 4;
     }
     return;
+  } else if (t.is_float4()) {
+    enable_fp4_ = true;
+    if (t.lanes() <= 4) {
+      os << GetFP4Type(t);
+    } else {
+      fail = true;
+    }
+    return;
   } else if (t == DataType::Bool()) {
     os << "bool";
     return;
@@ -691,7 +732,8 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
   if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os);
 
   if (target_ty.code() == DataType::kE4M3Float || target_ty.code() == DataType::kE5M2Float ||
-      from_ty.code() == DataType::kE4M3Float || from_ty.code() == DataType::kE5M2Float) {
+      target_ty.code() == DataType::kE2M1Float || from_ty.code() == DataType::kE4M3Float ||
+      from_ty.code() == DataType::kE5M2Float || from_ty.code() == DataType::kE2M1Float) {
     std::ostringstream val;
     val << "(";
     PrintType(target_ty, val);
@@ -1273,7 +1315,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) {  // NO
     return;
   }
 
-  if (op->dtype.is_float8()) {
+  if (op->dtype.is_float8() || op->dtype.is_float4()) {
     int lanes = op->dtype.lanes();
     ICHECK(lanes == 1 || lanes == 2 || lanes == 4);
     std::string v = PrintExpr(op->value);
@@ -1388,7 +1430,7 @@ inline void PrintConst(const FloatImmNode* op, 
std::ostream& os, CodeGenCUDA* p)
     return;
   }
   // Type code is kE5M2Float or kE4M4Float
-  if (op->dtype.is_float8()) {
+  if (op->dtype.is_float8() || op->dtype.is_float4()) {
     p->PrintType(op->dtype, os);
     os << '(' << std::scientific << op->value << 'f' << ')';
     return;
diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h
index 7fe818b6b4..ed5709ac12 100644
--- a/src/target/source/codegen_cuda.h
+++ b/src/target/source/codegen_cuda.h
@@ -42,8 +42,8 @@ class CodeGenCUDA final : public CodeGenC {
   void Init(bool output_ssa);
   std::string Finish();
   bool need_include_path() {
-    return (enable_fp16_ || enable_bf16_ || enable_int8_ || enable_fp8_ || need_math_constants_h_ ||
-            need_mma_h_);
+    return (enable_fp16_ || enable_bf16_ || enable_int8_ || enable_fp8_ || enable_fp4_ ||
+            need_math_constants_h_ || need_mma_h_);
   }
   // override behavior
   void PrintFuncPrefix(std::ostream& os) final;
@@ -96,6 +96,8 @@ class CodeGenCUDA final : public CodeGenC {
   bool enable_bf16_{false};
   // whether enable fp8
   bool enable_fp8_{false};
+  // whether enable fp4
+  bool enable_fp4_{false};
   // whether enable int8
   bool enable_int8_{false};
   // whether enable warp shuffle intrinsics
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index dad4ea98d6..039acf7e92 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -201,6 +201,12 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) {  // NOLINT(*)
   } else if (ltype.is_float8() && !rtype.is_float8()) {
     // Cast int->float8 for rhs when lhs is a float8
     rhs = cast(ltype, rhs);
+  } else if (!ltype.is_float4() && rtype.is_float4()) {
+    // Cast int->float4 for lhs when rhs is a float4
+    lhs = cast(rtype, lhs);
+  } else if (ltype.is_float4() && !rtype.is_float4()) {
+    // Cast int->float4 for rhs when lhs is a float4
+    rhs = cast(ltype, rhs);
   } else if ((ltype.is_int() && rtype.is_int()) || (ltype.is_uint() && rtype.is_uint())) {
     // Promote int to higher bits e.g. int8 + int16 --> int16 + int16
     if (ltype.bits() < rtype.bits()) {
@@ -272,6 +278,8 @@ PrimExpr max_value(const DataType& dtype, Span span) {
     } else if (dtype.code() == DataType::TypeCode::kE4M3Float) {
       return FloatImm(dtype, 448.0, span);
     }
+  } else if (dtype.is_float4()) {
+    return FloatImm(dtype, 6.0, span);
   }
   LOG(FATAL) << "Cannot decide max_value for type" << dtype;
 }
@@ -313,6 +321,8 @@ PrimExpr min_value(const DataType& dtype, Span span) {
     } else if (dtype.code() == DataType::TypeCode::kE4M3Float) {
       return FloatImm(dtype, -448.0, span);
     }
+  } else if (dtype.is_float4()) {
+    return FloatImm(dtype, -6.0, span);
   }
   LOG(FATAL) << "Cannot decide min_value for type" << dtype;
 }
diff --git a/src/tir/transforms/dtype_conversion.cc b/src/tir/transforms/dtype_conversion.cc
index de94cf6473..dfb0a5a631 100644
--- a/src/tir/transforms/dtype_conversion.cc
+++ b/src/tir/transforms/dtype_conversion.cc
@@ -39,7 +39,7 @@ PrimExpr DTypeConversion(PrimExpr src_value, DataType tgt_dtype, RoundingMode ro
   CHECK_EQ(src_dtype.lanes(), tgt_dtype.lanes())
       << "The lanes for data type for source value must matches the target datatype.";
   auto is_floating_point = [](DataType dtype) {
-    return dtype.is_float() || dtype.is_float8() || dtype.is_bfloat16();
+    return dtype.is_float() || dtype.is_float8() || dtype.is_bfloat16() || dtype.is_float4();
   };
   // Both source dtype and target dtype should be floating point.
   CHECK(is_floating_point(src_dtype) && is_floating_point(tgt_dtype));
diff --git a/src/tir/transforms/dtype_conversion.h b/src/tir/transforms/dtype_conversion.h
index b509abb9cd..8edbf1bc1e 100644
--- a/src/tir/transforms/dtype_conversion.h
+++ b/src/tir/transforms/dtype_conversion.h
@@ -99,7 +99,7 @@ class FloatConfig {
    * \return The FloatConfig class containing internal floating point representation.
    */
   static FloatConfig FromDataType(DataType dtype) {
-    CHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8())
+    CHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || dtype.is_float4())
         << "FloatConfig is only applicable to floating point data types, got " << dtype
         << " instead.";
     if (dtype.is_float()) {
@@ -117,7 +117,7 @@ class FloatConfig {
     } else if (dtype.is_bfloat16()) {
       // bfloat16,
       return FloatConfig(8, 7, 127, InftyStyle::kIEEE, NaNStyle::kIEEE);
-    } else {  // float8
+    } else if (dtype.is_float8()) {  // float8
       // NVIDIA/Arm/Intel's FP8 formats for Deep Learning
       // Reference: https://arxiv.org/abs/2209.05433
       switch (dtype.code()) {
@@ -128,6 +128,10 @@ class FloatConfig {
           // E5M2 format, consistent with IEEE-754
           return FloatConfig(5, 2, 15, InftyStyle::kIEEE, NaNStyle::kIEEE);
       }
+    } else {
+      // float4
+      // E2M1 format, not consistent with IEEE-754
+      return FloatConfig(2, 1, 1, InftyStyle::kNone, NaNStyle::kNone);
     }
   }
 };
diff --git a/tests/python/codegen/test_target_codegen_cuda_fp4.py b/tests/python/codegen/test_target_codegen_cuda_fp4.py
new file mode 100644
index 0000000000..f137e83cc9
--- /dev/null
+++ b/tests/python/codegen/test_target_codegen_cuda_fp4.py
@@ -0,0 +1,96 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import tvm
+from tvm.script import tir as T
+import numpy as np
+import tvm.testing
+from tvm.script import ir as I, relax as R, tir as T
+
+try:
+    import ml_dtypes
+except ImportError:
+    ml_dtypes = None
+
+native_dtype, promoted_dtype = tvm.testing.parameters(
+    ("e2m1_float4x2", "float32x2"),
+    ("e2m1_float4x2", "float16x2"),
+)
+
+
+@tvm.testing.requires_cuda_compute_version(9)
+def test_e2m1_vector_conversions(native_dtype, promoted_dtype):
+    vector_length = 64
+
+    @T.prim_func
+    def add(
+        A: T.Buffer((vector_length,), native_dtype),
+        B: T.Buffer((vector_length,), native_dtype),
+        C: T.Buffer((vector_length,), native_dtype),
+    ):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        for i in range(vector_length):
+            with T.block("C"):
+                v_i = T.axis.spatial(vector_length, i)
+                T.reads(A[v_i], B[v_i])
+                T.writes(C[v_i])
+                C[v_i] = T.Cast(
+                    native_dtype, T.Cast(promoted_dtype, A[v_i]) + T.Cast(promoted_dtype, B[v_i])
+                )
+
+    sch = tvm.tir.Schedule(add)
+    block = sch.get_block("C")
+    b = sch.get_loops(block)
+    bx, tx = sch.split(b[0], factors=[None, 32])
+    sch.bind(bx, "blockIdx.x")
+    sch.bind(tx, "threadIdx.x")
+
+    target = "cuda"
+    fadd = tvm.build(sch.mod, target=target)
+    cuda_src = fadd.imported_modules[0].get_source()
+    dev = tvm.device(target, 0)
+
+    numpytype = "float4_e2m1fn"
+    if "x" in native_dtype:
+        lanes = int(native_dtype.split("x")[-1])
+    else:
+        lanes = 1
+
+    if "x" in promoted_dtype:
+        promoted_base_dtype = promoted_dtype.split("x")[0]
+    else:
+        promoted_base_dtype = promoted_dtype
+
+    np_shape = (vector_length, lanes) if lanes > 1 else (vector_length,)
+    a_np = np.random.uniform(low=0, high=5, size=np_shape).astype(numpytype)
+    a = tvm.nd.empty(shape=(vector_length,), dtype=native_dtype, device=dev)
+    a.copyfrom(a_np)
+    b_np = np.random.uniform(low=0, high=5, size=np_shape).astype(numpytype)
+    b = tvm.nd.empty(shape=(vector_length,), dtype=native_dtype, device=dev)
+    b.copyfrom(b_np)
+    c = tvm.nd.empty(shape=(vector_length,), dtype=native_dtype, device=dev)
+    fadd(a, b, c)
+
+    tvm.testing.assert_allclose(
+        c.numpy().astype(promoted_base_dtype), (a_np + b_np).astype(promoted_base_dtype)
+    )
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to