This is an automated email from the ASF dual-hosted git repository.
yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7bedfeb209 [Codegen] FP4 support (#17630)
7bedfeb209 is described below
commit 7bedfeb209d89b693a51bde70344f0c7130d5abf
Author: Hongyi Jin <[email protected]>
AuthorDate: Thu Feb 27 22:42:54 2025 -0500
[Codegen] FP4 support (#17630)
* fp4
* fix test
* fix lint
* fix
* Test with manually built images
---------
Co-authored-by: Yong Wu <[email protected]>
---
ci/jenkins/docker-images.ini | 4 +-
ci/jenkins/unity_jenkinsfile.groovy | 4 +-
include/tvm/runtime/data_type.h | 22 ++++-
include/tvm/script/ir_builder/tir/ir.h | 2 +
include/tvm/tir/op.h | 2 +-
python/tvm/_ffi/runtime_ctypes.py | 7 ++
python/tvm/contrib/nvcc.py | 16 ++++
python/tvm/runtime/ndarray.py | 25 +++++-
python/tvm/script/ir_builder/tir/ir.py | 14 ++++
src/ir/expr.cc | 7 +-
src/runtime/ndarray.cc | 3 +
src/script/ir_builder/tir/ir.cc | 3 +
src/support/scalars.h | 3 +
src/target/llvm/codegen_llvm.cc | 2 +
src/target/source/codegen_c.cc | 2 +-
src/target/source/codegen_cuda.cc | 48 ++++++++++-
src/target/source/codegen_cuda.h | 6 +-
src/tir/op/op.cc | 10 +++
src/tir/transforms/dtype_conversion.cc | 2 +-
src/tir/transforms/dtype_conversion.h | 8 +-
.../python/codegen/test_target_codegen_cuda_fp4.py | 96 ++++++++++++++++++++++
21 files changed, 267 insertions(+), 19 deletions(-)
diff --git a/ci/jenkins/docker-images.ini b/ci/jenkins/docker-images.ini
index 6d3f78190f..0626ff2b5f 100644
--- a/ci/jenkins/docker-images.ini
+++ b/ci/jenkins/docker-images.ini
@@ -18,8 +18,8 @@
# This data file is read during when Jenkins runs job to determine docker images.
[jenkins]
ci_arm: tlcpack/ci-arm:20250226-223225-63bc315f
-ci_cpu: tlcpack/ci_cpu:20250226-223225-63bc315f
-ci_gpu: tlcpack/ci-gpu:20250226-223225-63bc315f
+ci_cpu: tlcpack/ci_cpu:20250226-223225-63bc315f_patch
+ci_gpu: tlcpack/ci-gpu:20250226-223225-63bc315f_patch
ci_hexagon: tlcpack/ci-hexagon:20250226-223225-63bc315f
ci_i386: tlcpack/ci-i386:20250226-223225-63bc315f
ci_lint: tlcpack/ci-lint:20250226-223225-63bc315f
diff --git a/ci/jenkins/unity_jenkinsfile.groovy b/ci/jenkins/unity_jenkinsfile.groovy
index 4cfe96937f..8dee0c7f8c 100755
--- a/ci/jenkins/unity_jenkinsfile.groovy
+++ b/ci/jenkins/unity_jenkinsfile.groovy
@@ -30,8 +30,8 @@
import org.jenkinsci.plugins.pipeline.modeldefinition.Utils
// NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. -->
-ci_gpu = 'tlcpack/ci-gpu:20250226-223225-63bc315f'
-ci_cpu = 'tlcpack/ci-cpu:20250226-223225-63bc315f'
+ci_gpu = 'tlcpack/ci-gpu:20250226-223225-63bc315f_patch'
+ci_cpu = 'tlcpack/ci-cpu:20250226-223225-63bc315f_patch'
// <--- End of regex-scanned config.
// Parameters to allow overriding (in Jenkins UI), the images
diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index c49fde1746..3d35d5241e 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -58,6 +58,7 @@ class DataType {
kBFloat = kDLBfloat,
kE4M3Float = 6U,
kE5M2Float = 7U,
+ kE2M1Float = 8U,
kCustomBegin = 129
};
/*! \brief default constructor */
@@ -87,6 +88,9 @@ class DataType {
if (code == kE4M3Float || code == kE5M2Float) {
ICHECK_EQ(bits, 8);
}
+ if (code == kE2M1Float) {
+ ICHECK_EQ(bits, 4);
+ }
}
/*! \return The type code. */
int code() const { return static_cast<int>(data_.code); }
@@ -126,9 +130,13 @@ class DataType {
code() == DataType::kE5M2Float) &&
bits() == 8;
}
+ /*! \return whether type is a float4 type. */
+ bool is_float4() const { return code() == DataType::kE2M1Float && bits() == 4; }
bool is_e4m3_float8() const { return (code() == DataType::kE4M3Float && bits() == 8); }
bool is_e5m2_float8() const { return (code() == DataType::kE5M2Float && bits() == 8); }
+
+ bool is_e2m1_float4() const { return (code() == DataType::kE2M1Float && bits() == 4); }
/*! \return whether type is a float16 type. */
bool is_float16() const { return is_float() && bits() == 16; }
/*! \return whether type is a bfloat16 type. */
@@ -253,6 +261,12 @@ class DataType {
* \return The constructed data type.
*/
static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kE5M2Float, 8, lanes); }
+ /*!
+ * \brief Construct NV float4 e2m1 datatype.
+ * \param lanes The number of lanes
+ * \return The constructed data type.
+ */
+ static DataType NVFloat4E2M1(int lanes = 1) { return DataType(kE2M1Float, 4, lanes); }
/*!
* \brief Construct a bool type.
* \param lanes The number of lanes.
@@ -299,7 +313,7 @@ inline int GetVectorBytes(DataType dtype) {
int data_bits = dtype.bits() * dtype.lanes();
// allow bool to exist
if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype == DataType::UInt(4) ||
- dtype == DataType::Int(1)) {
+ dtype == DataType::Int(1) || dtype == DataType::NVFloat4E2M1()) {
return 1;
}
ICHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes";
@@ -385,6 +399,8 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) {
return "e4m3_float";
case DataType::kE5M2Float:
return "e5m2_float";
+ case DataType::kE2M1Float:
+ return "e2m1_float";
default:
LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
}
@@ -466,6 +482,10 @@ inline DLDataType String2DLDataType(std::string s) {
t.code = DataType::kE5M2Float;
t.bits = 8;
scan = s.c_str() + 10;
+ } else if (s.substr(0, 10) == "e2m1_float") {
+ t.code = DataType::kE2M1Float;
+ t.bits = 4;
+ scan = s.c_str() + 10;
} else if (s.substr(0, 6) == "custom") {
t.code = ParseCustomDatatype(s, &scan);
} else {
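
A minimal sketch of how the new dtype code surfaces through the Python runtime API
(assumes a TVM build that includes this patch; illustrative only):

    import tvm

    dt = tvm.DataType("e2m1_float4")
    print(dt.type_code, dt.bits, dt.lanes)   # 8 (E2M1Float), 4 bits, 1 lane
    print(tvm.DataType("e2m1_float4x2"))     # vector forms follow the usual "xN" suffix
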
diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index 380c2fcce2..5dd1a5c733 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -505,6 +505,8 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E4M3Float8, DataType::NVFloat8E4M3);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E5M2Float8, DataType::NVFloat8E5M2);
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E2M1Float4, DataType::NVFloat4E2M1);
+
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index d06bb779d0..e98eb46be9 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -940,7 +940,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span())
return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high), span);
}
}
- if (t.is_float() || t.is_bfloat16() || t.is_float8())
+ if (t.is_float() || t.is_bfloat16() || t.is_float8() || t.is_float4())
return FloatImm(t, static_cast<double>(value), span);
// For now, we store const scalar values of custom datatypes within doubles; later, during the
// datatypes lowering pass, we will lower the value to its true representation in the format
diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py
index f79df1644e..263a4ff69f 100644
--- a/python/tvm/_ffi/runtime_ctypes.py
+++ b/python/tvm/_ffi/runtime_ctypes.py
@@ -68,6 +68,7 @@ class DataTypeCode(object):
BFLOAT = 4
E4M3Float = 6
E5M2Float = 7
+ E2M1Float = 8
class DataType(ctypes.Structure):
@@ -82,6 +83,7 @@ class DataType(ctypes.Structure):
DataTypeCode.BFLOAT: "bfloat",
DataTypeCode.E4M3Float: "e4m3_float",
DataTypeCode.E5M2Float: "e5m2_float",
+ DataTypeCode.E2M1Float: "e2m1_float",
}
NUMPY2STR = {
np.dtype(np.bool_): "bool",
@@ -112,6 +114,7 @@ class DataType(ctypes.Structure):
"uint64": {"type_code": DataTypeCode.UINT, "bits": 64, "lanes": 1},
"e4m3_float8": {"type_code": DataTypeCode.E4M3Float, "bits": 8,
"lanes": 1},
"e5m2_float8": {"type_code": DataTypeCode.E5M2Float, "bits": 8,
"lanes": 1},
+ "e2m1_float4": {"type_code": DataTypeCode.E2M1Float, "bits": 4,
"lanes": 1},
"float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1},
"float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1},
"float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1},
@@ -168,6 +171,9 @@ class DataType(ctypes.Structure):
elif head.startswith("e5m2_float"):
self.type_code = DataTypeCode.E5M2Float
head = head[10:]
+ elif head.startswith("e2m1_float"):
+ self.type_code = DataTypeCode.E2M1Float
+ head = head[10:]
elif head.startswith("custom"):
# pylint: disable=import-outside-toplevel
import tvm.runtime._ffi_api
@@ -232,6 +238,7 @@ if ml_dtypes is not None:
DataType.NUMPY2STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16"
DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "e4m3_float8"
DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] = "e5m2_float8"
+ DataType.NUMPY2STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "e2m1_float4"
RPC_SESS_MASK = 128
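
With ml_dtypes installed, host numpy arrays of float4_e2m1fn map onto the new
"e2m1_float4" dtype string, so data can round-trip through an NDArray. A hedged
sketch (CPU only; requires an ml_dtypes version that provides float4_e2m1fn):

    import numpy as np
    import ml_dtypes
    import tvm

    x = np.array([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=ml_dtypes.float4_e2m1fn)
    nd = tvm.nd.empty((8,), dtype="e2m1_float4").copyfrom(x)
    y = nd.numpy()                        # returned as ml_dtypes.float4_e2m1fn
    np.testing.assert_array_equal(x, y)
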
diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py
index be35bf6319..d12ddf883c 100644
--- a/python/tvm/contrib/nvcc.py
+++ b/python/tvm/contrib/nvcc.py
@@ -445,3 +445,19 @@ def have_fp8(compute_version):
if major >= 9:
return True
return False
+
+
+@tvm._ffi.register_func("tvm.contrib.nvcc.supports_fp4")
+def have_fp4(compute_version):
+ """Whether fp4 support is provided in the specified compute capability or
not
+
+ Parameters
+ ----------
+ compute_version : str
+ GPU capability
+ """
+ major, minor = parse_compute_version(compute_version)
+ # fp4 is supported in Blackwell (10.0) or later architectures.
+ if major == 10 and minor == 0:
+ return True
+ return False
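
A short usage sketch for the new helper; the Blackwell-only check mirrors the code
above (illustrative, not part of the patch):

    from tvm.contrib import nvcc

    assert nvcc.have_fp4("10.0")        # Blackwell exposes native fp4
    assert not nvcc.have_fp4("9.0")     # Hopper has fp8 but not fp4
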
diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py
index 082a28c7e2..3514ee6168 100644
--- a/python/tvm/runtime/ndarray.py
+++ b/python/tvm/runtime/ndarray.py
@@ -197,6 +197,13 @@ class NDArray(NDArrayBase):
source_array = np.ascontiguousarray(
source_array, dtype="uint16" if dtype == "bfloat16" else dtype
)
+ if dtype.startswith("e2m1_float4"):
+ data_bits = source_array.view(dtype="uint8")
+ if data_bits.size % 2:
+ data_bits = np.pad(data_bits, (0, 1), mode="constant", constant_values=0)
+ data_bits = data_bits.reshape(-1, 2)
+ packed = ((data_bits[:, 0] & 0x0F) << 4) | (data_bits[:, 1] & 0x0F)
+ source_array = packed.astype(np.int8)
assert source_array.flags["C_CONTIGUOUS"]
data = source_array.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize)
@@ -254,20 +261,32 @@ class NDArray(NDArrayBase):
raise RuntimeError(
"ml_dtypes is not installed, cannot convert e5m2_float8
array to numpy."
)
+ if dtype == "e2m1_float4":
+ if ml_dtypes is not None:
+ dtype = ml_dtypes.float4_e2m1fn
+ else:
+ raise RuntimeError(
+ "ml_dtypes is not installed, cannot convert e2m1_float4
array to numpy."
+ )
np_arr = np.empty(shape, dtype=dtype)
assert np_arr.flags["C_CONTIGUOUS"]
data = np_arr.ctypes.data_as(ctypes.c_void_p)
- nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize)
+ if old_dtype.startswith("e2m1_float4"):
+ nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize // 2)
+ else:
+ nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize)
check_call(_LIB.TVMArrayCopyToBytes(self.handle, data, nbytes))
- if old_dtype == "int4":
+ if old_dtype == "int4" or old_dtype.startswith("e2m1_float4"):
length = np_arr.size
+ np_arr = np_arr.view("int8")
np_arr_ret = np.empty((length,), dtype="int8")
np_arr = np_arr.reshape((length,))
old_index = np.bitwise_and(np_arr, 0x0F)
even_index = np.bitwise_and(np_arr >> 4, 0x0F)
np_arr_ret[1::2] = old_index[0 : length // 2]
np_arr_ret[0::2] = even_index[0 : length // 2]
- return np_arr_ret.reshape(shape)
+ return np_arr_ret.reshape(shape).view(dtype)
+
return np_arr
def copyto(self, target, mem_scope=None):
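
The copy paths above store two e2m1 values per byte: copyfrom packs the even element
into the high nibble and the odd element into the low nibble, and numpy() reverses
it. A standalone numpy sketch of that arithmetic (independent of TVM):

    import numpy as np

    codes = np.array([0x1, 0x2, 0x3, 0x4], dtype=np.uint8)       # four 4-bit codes
    pairs = codes.reshape(-1, 2)
    packed = ((pairs[:, 0] & 0x0F) << 4) | (pairs[:, 1] & 0x0F)  # -> [0x12, 0x34]

    low = packed & 0x0F                  # goes to odd positions
    high = (packed >> 4) & 0x0F          # goes to even positions
    restored = np.empty(codes.size, dtype=np.uint8)
    restored[1::2] = low
    restored[0::2] = high
    assert (restored == codes).all()
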
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index da0e2954e8..6cc19305e4 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1457,6 +1457,14 @@ e5m2_float8x16 = func_gen(("E5M2Float8x16"))
e5m2_float8x32 = func_gen(("E5M2Float8x32"))
e5m2_float8x64 = func_gen(("E5M2Float8x64"))
+e2m1_float4 = func_gen(("E2M1Float4"))
+e2m1_float4x4 = func_gen(("E2M1Float4x4"))
+e2m1_float4x8 = func_gen(("E2M1Float4x8"))
+e2m1_float4x16 = func_gen(("E2M1Float4x16"))
+e2m1_float4x32 = func_gen(("E2M1Float4x32"))
+e2m1_float4x64 = func_gen(("E2M1Float4x64"))
+
+
# pylint: enable=invalid-name
@@ -2005,31 +2013,37 @@ __all__ = [
"uint64x64",
"e4m3_float8",
"e5m2_float8",
+ "e2m1_float4",
"float16",
"float32",
"float64",
"e4m3_float8x4",
"e5m2_float8x4",
+ "e2m1_float4x4",
"float16x4",
"float32x4",
"float64x4",
"e4m3_float8x8",
"e5m2_float8x8",
+ "e2m1_float4x8",
"float16x8",
"float32x8",
"float64x8",
"e4m3_float8x16",
"e5m2_float8x16",
+ "e2m1_float4x16",
"float16x16",
"float32x16",
"float64x16",
"e4m3_float8x32",
"e5m2_float8x32",
+ "e2m1_float4x32",
"float16x32",
"float32x32",
"float64x32",
"e4m3_float8x64",
"e5m2_float8x64",
+ "e2m1_float4x64",
"float16x64",
"float32x64",
"float64x64",
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index ded046eafc..766abf3483 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -110,7 +110,7 @@ TVM_REGISTER_NODE_TYPE(IntImmNode);
FloatImm::FloatImm(DataType dtype, double value, Span span) {
ICHECK_EQ(dtype.lanes(), 1) << "ValueError: FloatImm can only take scalar.";
- ICHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() ||
+ ICHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || dtype.is_float4() ||
dtype.code() >= DataType::kCustomBegin)
<< "ValueError: FloatImm supports only float, but " << dtype << " was
supplied.";
@@ -137,6 +137,11 @@ FloatImm::FloatImm(DataType dtype, double value, Span span) {
<< dtype;
ICHECK_LE(value, bound) << "ValueError: Literal vaule " << value << " exceeds maximum of "
<< dtype;
+ } else if (dtype.is_float4()) {
+ ICHECK_GE(value, -support::kMaxE2M1)
+ << "ValueError: Literal value " << value << " exceeds minimum of "
<< dtype;
+ ICHECK_LE(value, support::kMaxE2M1)
+ << "ValueError: Literal value " << value << " exceeds maximum of "
<< dtype;
}
}
ObjectPtr<FloatImmNode> node = make_object<FloatImmNode>();
diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc
index c2cf5f388a..1812d13b9c 100644
--- a/src/runtime/ndarray.cc
+++ b/src/runtime/ndarray.cc
@@ -28,6 +28,7 @@
#include <tvm/runtime/registry.h>
#include "runtime_base.h"
+#include "tvm/runtime/data_type.h"
extern "C" {
// C-mangled dlpack deleter.
@@ -53,6 +54,8 @@ inline void VerifyDataType(DLDataType dtype) {
return;
else if (dtype.bits == 4 && dtype.code == kDLInt)
return;
+ else if (dtype.bits == 4 && dtype.code == DataType::kE2M1Float)
+ return;
else
ICHECK_EQ(dtype.bits % 8, 0);
}
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 17353561ee..e452e102bf 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -757,6 +757,9 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.E5M2Float8").set_body_typed(E5M2Float
TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E4M3Float8", E4M3Float8);
TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E5M2Float8", E5M2Float8);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.E2M1Float4").set_body_typed(E2M1Float4);
+TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E2M1Float4", E2M1Float4);
+
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Void").set_body_typed(Void);
diff --git a/src/support/scalars.h b/src/support/scalars.h
index 05763f8044..b229a6b338 100644
--- a/src/support/scalars.h
+++ b/src/support/scalars.h
@@ -69,6 +69,9 @@ constexpr double kMaxE4M3 = 448;
// See https://arxiv.org/pdf/2209.05433.pdf
constexpr double kMaxE5M2 = 57344;
+// 2^2 * (1 + 1/2)
+constexpr double kMaxE2M1 = 6.0;
+
} // namespace support
} // namespace tvm
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 6c051fc939..f5f17c70ef 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -581,6 +581,8 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
}
} else if (dtype.code() == DataType::kE4M3Float || dtype.code() == DataType::kE5M2Float) {
etype = llvm::Type::getInt8Ty(*ctx);
+ } else if (dtype.code() == DataType::kE2M1Float) {
+ etype = llvm::Type::getIntNTy(*ctx, 4);
}
if (!dtype.is_scalar()) {
#if TVM_LLVM_VERSION >= 110
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 9f68cd8d66..0e0971b8f8 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -240,7 +240,7 @@ std::string CodeGenC::GetBufferRef(DataType t, const BufferNode* buffer, PrimExp
}
std::string index_str = PrintExpr(index);
- if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) {
+ if ((t.bits() == 4 && !t.is_float4()) || (t.bits() == 1 && t.is_int())) {
// This is a special case, because CodegenCUDA::PrintType()
// returns "int" for bool and for 4-bit integers. In most cases,
// we divide by the number of lanes to determine the index.
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 0400518251..872a024366 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -71,6 +71,30 @@ std::string GetFP8Type(DataType type) {
return stream.str();
}
+std::string GetFP4Type(DataType type) {
+ std::stringstream stream;
+ int32_t lanes = type.lanes();
+ std::string vec;
+ if (type.is_scalar()) {
+ vec = "";
+ } else if (lanes == 2) {
+ vec = "x2";
+ } else if (lanes == 4) {
+ vec = "x4";
+ } else {
+ LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8) for
FP8";
+ }
+ stream << "__nv_fp4";
+ std::string suffix;
+ if (type.code() == DataType::kE2M1Float) {
+ suffix = "_e2m1";
+ } else {
+ LOG(FATAL) << "Unsupported FP8 type in CUDA codegen";
+ }
+ stream << vec << suffix;
+ return stream.str();
+}
+
CodeGenCUDA::CodeGenCUDA() { restrict_keyword_ = "__restrict__"; }
void CodeGenCUDA::Init(bool output_ssa) {
@@ -133,7 +157,11 @@ std::string CodeGenCUDA::Finish() {
decl_stream << "#else\n";
decl_stream << _cuda_half_t_def;
decl_stream << "#endif\n\n";
+
+ decl_stream << "#include <cuda.h>\n";
+ decl_stream << "#if (CUDA_VERSION <12080)\n";
decl_stream << _cuda_half_util;
+ decl_stream << "#endif\n";
}
if (enable_bf16_) {
@@ -163,6 +191,11 @@ std::string CodeGenCUDA::Finish() {
decl_stream << "struct fp8_e5x16_t {\n fp8_e5_t data[16]; \n};\n";
decl_stream << "#endif\n\n";
}
+ if (enable_fp4_) {
+ decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)\n";
+ decl_stream << "#include <cuda_fp4.h>\n";
+ decl_stream << "#endif\n\n";
+ }
declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_);
if (enable_warp_shuffle_) {
@@ -314,6 +347,14 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
os << "uint" << t.lanes() / 4;
}
return;
+ } else if (t.is_float4()) {
+ enable_fp4_ = true;
+ if (t.lanes() <= 4) {
+ os << GetFP4Type(t);
+ } else {
+ fail = true;
+ }
+ return;
} else if (t == DataType::Bool()) {
os << "bool";
return;
@@ -691,7 +732,8 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os);
if (target_ty.code() == DataType::kE4M3Float || target_ty.code() == DataType::kE5M2Float ||
- from_ty.code() == DataType::kE4M3Float || from_ty.code() == DataType::kE5M2Float) {
+ target_ty.code() == DataType::kE2M1Float || from_ty.code() == DataType::kE4M3Float ||
+ from_ty.code() == DataType::kE5M2Float || from_ty.code() == DataType::kE2M1Float) {
std::ostringstream val;
val << "(";
PrintType(target_ty, val);
@@ -1273,7 +1315,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO
return;
}
- if (op->dtype.is_float8()) {
+ if (op->dtype.is_float8() || op->dtype.is_float4()) {
int lanes = op->dtype.lanes();
ICHECK(lanes == 1 || lanes == 2 || lanes == 4);
std::string v = PrintExpr(op->value);
@@ -1388,7 +1430,7 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p)
return;
}
// Type code is kE5M2Float or kE4M4Float
- if (op->dtype.is_float8()) {
+ if (op->dtype.is_float8() || op->dtype.is_float4()) {
p->PrintType(op->dtype, os);
os << '(' << std::scientific << op->value << 'f' << ')';
return;
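
With these changes, scalar/x2/x4 e2m1 values lower to the CUDA 12.8 types
__nv_fp4_e2m1, __nv_fp4x2_e2m1 and __nv_fp4x4_e2m1 from <cuda_fp4.h>. A quick way to
eyeball the emitted source, a sketch along the lines of the new test below (requires
a CUDA build environment; details may differ):

    import tvm
    from tvm.script import tir as T

    @T.prim_func
    def copy(A: T.Buffer((64,), "e2m1_float4"), B: T.Buffer((64,), "e2m1_float4")):
        for i in T.thread_binding(64, thread="threadIdx.x"):
            B[i] = A[i]

    mod = tvm.build(copy, target="cuda")
    print(mod.imported_modules[0].get_source())   # should mention "__nv_fp4_e2m1"
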
diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h
index 7fe818b6b4..ed5709ac12 100644
--- a/src/target/source/codegen_cuda.h
+++ b/src/target/source/codegen_cuda.h
@@ -42,8 +42,8 @@ class CodeGenCUDA final : public CodeGenC {
void Init(bool output_ssa);
std::string Finish();
bool need_include_path() {
- return (enable_fp16_ || enable_bf16_ || enable_int8_ || enable_fp8_ || need_math_constants_h_ ||
- need_mma_h_);
+ return (enable_fp16_ || enable_bf16_ || enable_int8_ || enable_fp8_ || enable_fp4_ ||
+ need_math_constants_h_ || need_mma_h_);
}
// override behavior
void PrintFuncPrefix(std::ostream& os) final;
@@ -96,6 +96,8 @@ class CodeGenCUDA final : public CodeGenC {
bool enable_bf16_{false};
// whether enable fp8
bool enable_fp8_{false};
+ // whether enable fp4
+ bool enable_fp4_{false};
// whether enable int8
bool enable_int8_{false};
// whether enable warp shuffle intrinsics
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index dad4ea98d6..039acf7e92 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -201,6 +201,12 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*)
} else if (ltype.is_float8() && !rtype.is_float8()) {
// Cast int->float8 for rhs when lhs is a float8
rhs = cast(ltype, rhs);
+ } else if (!ltype.is_float4() && rtype.is_float4()) {
+ // Cast int->float4 for lhs when rhs is a float4
+ lhs = cast(rtype, lhs);
+ } else if (ltype.is_float4() && !rtype.is_float4()) {
+ // Cast int->float4 for rhs when lhs is a float4
+ rhs = cast(ltype, rhs);
} else if ((ltype.is_int() && rtype.is_int()) || (ltype.is_uint() && rtype.is_uint())) {
// Promote int to higher bits e.g. int8 + int16 --> int16 + int16
if (ltype.bits() < rtype.bits()) {
@@ -272,6 +278,8 @@ PrimExpr max_value(const DataType& dtype, Span span) {
} else if (dtype.code() == DataType::TypeCode::kE4M3Float) {
return FloatImm(dtype, 448.0, span);
}
+ } else if (dtype.is_float4()) {
+ return FloatImm(dtype, 6.0, span);
}
LOG(FATAL) << "Cannot decide max_value for type" << dtype;
}
@@ -313,6 +321,8 @@ PrimExpr min_value(const DataType& dtype, Span span) {
} else if (dtype.code() == DataType::TypeCode::kE4M3Float) {
return FloatImm(dtype, -448.0, span);
}
+ } else if (dtype.is_float4()) {
+ return FloatImm(dtype, -6.0, span);
}
LOG(FATAL) << "Cannot decide min_value for type" << dtype;
}
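
These branches make the TIR range helpers dtype-aware for fp4; a hedged check from
Python (illustrative only):

    import tvm

    assert tvm.tir.max_value("e2m1_float4").value == 6.0
    assert tvm.tir.min_value("e2m1_float4").value == -6.0
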
diff --git a/src/tir/transforms/dtype_conversion.cc b/src/tir/transforms/dtype_conversion.cc
index de94cf6473..dfb0a5a631 100644
--- a/src/tir/transforms/dtype_conversion.cc
+++ b/src/tir/transforms/dtype_conversion.cc
@@ -39,7 +39,7 @@ PrimExpr DTypeConversion(PrimExpr src_value, DataType tgt_dtype, RoundingMode ro
CHECK_EQ(src_dtype.lanes(), tgt_dtype.lanes())
<< "The lanes for data type for source value must matches the target
datatype.";
auto is_floating_point = [](DataType dtype) {
- return dtype.is_float() || dtype.is_float8() || dtype.is_bfloat16();
+ return dtype.is_float() || dtype.is_float8() || dtype.is_bfloat16() || dtype.is_float4();
};
// Both source dtype and target dtype should be floating point.
CHECK(is_floating_point(src_dtype) && is_floating_point(tgt_dtype));
diff --git a/src/tir/transforms/dtype_conversion.h b/src/tir/transforms/dtype_conversion.h
index b509abb9cd..8edbf1bc1e 100644
--- a/src/tir/transforms/dtype_conversion.h
+++ b/src/tir/transforms/dtype_conversion.h
@@ -99,7 +99,7 @@ class FloatConfig {
* \return The FloatConfig class containing internal floating point representation.
*/
static FloatConfig FromDataType(DataType dtype) {
- CHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8())
+ CHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || dtype.is_float4())
<< "FloatConfig is only applicable to floating point data types, got "
<< dtype
<< " instead.";
if (dtype.is_float()) {
@@ -117,7 +117,7 @@ class FloatConfig {
} else if (dtype.is_bfloat16()) {
// bfloat16,
return FloatConfig(8, 7, 127, InftyStyle::kIEEE, NaNStyle::kIEEE);
- } else { // float8
+ } else if (dtype.is_float8()) { // float8
// NVIDIA/Arm/Intel's FP8 formats for Deep Learning
// Reference: https://arxiv.org/abs/2209.05433
switch (dtype.code()) {
@@ -128,6 +128,10 @@ class FloatConfig {
// E5M2 format, consistent with IEEE-754
return FloatConfig(5, 2, 15, InftyStyle::kIEEE, NaNStyle::kIEEE);
}
+ } else {
+ // float4
+ // E2M1 format, not consistent with IEEE-754
+ return FloatConfig(2, 1, 1, InftyStyle::kNone, NaNStyle::kNone);
}
}
};
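
The new FloatConfig(2, 1, 1, kNone, kNone) entry says: 2 exponent bits, 1 mantissa
bit, exponent bias 1, and no Inf/NaN encodings. A small decode sketch for a 4-bit
e2m1 code under those parameters (illustrative only):

    def decode_e2m1(code):
        sign = -1.0 if (code >> 3) & 0x1 else 1.0
        exp = (code >> 1) & 0x3
        man = code & 0x1
        if exp == 0:                              # subnormal: 0 or +/-0.5
            return sign * man * 0.5
        return sign * (1.0 + 0.5 * man) * 2.0 ** (exp - 1)

    assert decode_e2m1(0b0111) == 6.0             # largest magnitude, matches kMaxE2M1
    assert decode_e2m1(0b1001) == -0.5            # smallest nonzero magnitude
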
diff --git a/tests/python/codegen/test_target_codegen_cuda_fp4.py b/tests/python/codegen/test_target_codegen_cuda_fp4.py
new file mode 100644
index 0000000000..f137e83cc9
--- /dev/null
+++ b/tests/python/codegen/test_target_codegen_cuda_fp4.py
@@ -0,0 +1,96 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import tvm
+from tvm.script import tir as T
+import numpy as np
+import tvm.testing
+from tvm.script import ir as I, relax as R, tir as T
+
+try:
+ import ml_dtypes
+except ImportError:
+ ml_dtypes = None
+
+native_dtype, promoted_dtype = tvm.testing.parameters(
+ ("e2m1_float4x2", "float32x2"),
+ ("e2m1_float4x2", "float16x2"),
+)
+
+
+@tvm.testing.requires_cuda_compute_version(9)
+def test_e2m1_vector_conversions(native_dtype, promoted_dtype):
+ vector_length = 64
+
+ @T.prim_func
+ def add(
+ A: T.Buffer((vector_length,), native_dtype),
+ B: T.Buffer((vector_length,), native_dtype),
+ C: T.Buffer((vector_length,), native_dtype),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ for i in range(vector_length):
+ with T.block("C"):
+ v_i = T.axis.spatial(vector_length, i)
+ T.reads(A[v_i], B[v_i])
+ T.writes(C[v_i])
+ C[v_i] = T.Cast(
+ native_dtype, T.Cast(promoted_dtype, A[v_i]) + T.Cast(promoted_dtype, B[v_i])
+ )
+
+ sch = tvm.tir.Schedule(add)
+ block = sch.get_block("C")
+ b = sch.get_loops(block)
+ bx, tx = sch.split(b[0], factors=[None, 32])
+ sch.bind(bx, "blockIdx.x")
+ sch.bind(tx, "threadIdx.x")
+
+ target = "cuda"
+ fadd = tvm.build(sch.mod, target=target)
+ cuda_src = fadd.imported_modules[0].get_source()
+ dev = tvm.device(target, 0)
+
+ numpytype = "float4_e2m1fn"
+ if "x" in native_dtype:
+ lanes = int(native_dtype.split("x")[-1])
+ else:
+ lanes = 1
+
+ if "x" in promoted_dtype:
+ promoted_base_dtype = promoted_dtype.split("x")[0]
+ else:
+ promoted_base_dtype = promoted_dtype
+
+ np_shape = (vector_length, lanes) if lanes > 1 else (vector_length,)
+ a_np = np.random.uniform(low=0, high=5, size=np_shape).astype(numpytype)
+ a = tvm.nd.empty(shape=(vector_length,), dtype=native_dtype, device=dev)
+ a.copyfrom(a_np)
+ b_np = np.random.uniform(low=0, high=5, size=np_shape).astype(numpytype)
+ b = tvm.nd.empty(shape=(vector_length,), dtype=native_dtype, device=dev)
+ b.copyfrom(b_np)
+ c = tvm.nd.empty(shape=(vector_length,), dtype=native_dtype, device=dev)
+ fadd(a, b, c)
+
+ tvm.testing.assert_allclose(
+ c.numpy().astype(promoted_base_dtype), (a_np + b_np).astype(promoted_base_dtype)
+ )
+
+
+if __name__ == "__main__":
+ tvm.testing.main()