This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c19e5f4be0 [CUDA] FP4 cast and reinterpret support (#17708)
c19e5f4be0 is described below
commit c19e5f4be061c21cf04ba432abc2375bf783a441
Author: Ruihang Lai <[email protected]>
AuthorDate: Thu Mar 6 08:44:03 2025 -0500
[CUDA] FP4 cast and reinterpret support (#17708)
* [CUDA] FP4 cast and reinterpret support
Following up on a previous PR, this PR introduces cast and
reinterpret support between `__nv_fp4_e2m1` and other dtypes.
This PR also makes sure that cast and reinterpret support
vectorization.
* change to float4_e2m1fn
---
include/tvm/runtime/data_type.h | 36 +++---
include/tvm/script/ir_builder/tir/ir.h | 2 +-
python/tvm/_ffi/runtime_ctypes.py | 23 ++--
python/tvm/runtime/ndarray.py | 14 ++-
python/tvm/script/ir_builder/tir/ir.py | 31 ++---
src/runtime/ndarray.cc | 2 +-
src/script/ir_builder/tir/ir.cc | 4 +-
src/target/llvm/codegen_llvm.cc | 2 +-
src/target/source/codegen_c.cc | 8 +-
src/target/source/codegen_cuda.cc | 93 ++++++++++++--
src/target/source/literal/cuda_half_t.h | 39 +++++-
src/tir/op/op.cc | 6 +-
.../python/codegen/test_target_codegen_cuda_fp4.py | 133 +++++++++++++++++++--
13 files changed, 326 insertions(+), 67 deletions(-)
diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index 3d35d5241e..76e5e3833f 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -58,7 +58,7 @@ class DataType {
kBFloat = kDLBfloat,
kE4M3Float = 6U,
kE5M2Float = 7U,
- kE2M1Float = 8U,
+ kFloat4E2M1Fn = 8U,
kCustomBegin = 129
};
/*! \brief default constructor */
@@ -88,7 +88,7 @@ class DataType {
if (code == kE4M3Float || code == kE5M2Float) {
ICHECK_EQ(bits, 8);
}
- if (code == kE2M1Float) {
+ if (code == kFloat4E2M1Fn) {
ICHECK_EQ(bits, 4);
}
}
@@ -131,12 +131,10 @@ class DataType {
bits() == 8;
}
/*! \return whether type is a float4 type. */
- bool is_float4() const { return code() == DataType::kE2M1Float && bits() ==
4; }
+ bool is_float4() const { return code() == DataType::kFloat4E2M1Fn && bits()
== 4; }
bool is_e4m3_float8() const { return (code() == DataType::kE4M3Float &&
bits() == 8); }
-
bool is_e5m2_float8() const { return (code() == DataType::kE5M2Float &&
bits() == 8); }
-
- bool is_e2m1_float4() const { return (code() == DataType::kE2M1Float &&
bits() == 4); }
+ bool is_float4_e2m1fn() const { return (code() == DataType::kFloat4E2M1Fn &&
bits() == 4); }
/*! \return whether type is a float16 type. */
bool is_float16() const { return is_float() && bits() == 16; }
/*! \return whether type is a bfloat16 type. */
@@ -262,11 +260,11 @@ class DataType {
*/
static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kE5M2Float, 8,
lanes); }
/*!
- * \brief Construct NV float4 e2m1 datatype.
+ * \brief Construct NV float4_e2m1fn datatype.
* \param lanes The number of lanes
* \return The constructed data type.
*/
- static DataType NVFloat4E2M1(int lanes = 1) { return DataType(kE2M1Float, 4,
lanes); }
+ static DataType NVFloat4E2M1FN(int lanes = 1) { return
DataType(kFloat4E2M1Fn, 4, lanes); }
/*!
* \brief Construct a bool type.
* \param lanes The number of lanes.
@@ -313,7 +311,7 @@ inline int GetVectorBytes(DataType dtype) {
int data_bits = dtype.bits() * dtype.lanes();
// allow bool to exist
if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype ==
DataType::UInt(4) ||
- dtype == DataType::Int(1) || dtype == DataType::NVFloat4E2M1()) {
+ dtype == DataType::Int(1) || dtype == DataType::NVFloat4E2M1FN()) {
return 1;
}
ICHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes";
@@ -399,8 +397,8 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode
type_code) {
return "e4m3_float";
case DataType::kE5M2Float:
return "e5m2_float";
- case DataType::kE2M1Float:
- return "e2m1_float";
+ case DataType::kFloat4E2M1Fn:
+ return "float4_e2m1fn";
default:
LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
}
@@ -458,6 +456,18 @@ inline DLDataType String2DLDataType(std::string s) {
} else if (s.substr(0, 4) == "uint") {
t.code = kDLUInt;
scan = s.c_str() + 4;
+ } else if (s.substr(0, 13) == "float4_e2m1fn") {
+ // Avoid being treated as "float"
+ t.code = DataType::kFloat4E2M1Fn;
+ t.bits = 4;
+ scan = s.c_str() + 13;
+ char* endpt = nullptr;
+ if (*scan == 'x') {
+ t.lanes = static_cast<uint16_t>(strtoul(scan + 1, &endpt, 10));
+ scan = endpt;
+ }
+ ICHECK(scan == s.c_str() + s.length()) << "unknown type " << s;
+ return t;
} else if (s.substr(0, 5) == "float") {
t.code = kDLFloat;
scan = s.c_str() + 5;
@@ -482,10 +492,6 @@ inline DLDataType String2DLDataType(std::string s) {
t.code = DataType::kE5M2Float;
t.bits = 8;
scan = s.c_str() + 10;
- } else if (s.substr(0, 10) == "e2m1_float") {
- t.code = DataType::kE2M1Float;
- t.bits = 4;
- scan = s.c_str() + 10;
} else if (s.substr(0, 6) == "custom") {
t.code = ParseCustomDatatype(s, &scan);
} else {
diff --git a/include/tvm/script/ir_builder/tir/ir.h
b/include/tvm/script/ir_builder/tir/ir.h
index 5dd1a5c733..e78e0d51fd 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -505,7 +505,7 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int,
DataType::Int);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E4M3Float8,
DataType::NVFloat8E4M3);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E5M2Float8,
DataType::NVFloat8E5M2);
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E2M1Float4,
DataType::NVFloat4E2M1);
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float4E2M1fn,
DataType::NVFloat4E2M1FN);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());
diff --git a/python/tvm/_ffi/runtime_ctypes.py
b/python/tvm/_ffi/runtime_ctypes.py
index 263a4ff69f..3f4ceadd1d 100644
--- a/python/tvm/_ffi/runtime_ctypes.py
+++ b/python/tvm/_ffi/runtime_ctypes.py
@@ -68,7 +68,7 @@ class DataTypeCode(object):
BFLOAT = 4
E4M3Float = 6
E5M2Float = 7
- E2M1Float = 8
+ FLOAT4E2M1FN = 8
class DataType(ctypes.Structure):
@@ -83,7 +83,7 @@ class DataType(ctypes.Structure):
DataTypeCode.BFLOAT: "bfloat",
DataTypeCode.E4M3Float: "e4m3_float",
DataTypeCode.E5M2Float: "e5m2_float",
- DataTypeCode.E2M1Float: "e2m1_float",
+ DataTypeCode.FLOAT4E2M1FN: "float4_e2m1fn",
}
NUMPY2STR = {
np.dtype(np.bool_): "bool",
@@ -114,7 +114,7 @@ class DataType(ctypes.Structure):
"uint64": {"type_code": DataTypeCode.UINT, "bits": 64, "lanes": 1},
"e4m3_float8": {"type_code": DataTypeCode.E4M3Float, "bits": 8,
"lanes": 1},
"e5m2_float8": {"type_code": DataTypeCode.E5M2Float, "bits": 8,
"lanes": 1},
- "e2m1_float4": {"type_code": DataTypeCode.E2M1Float, "bits": 4,
"lanes": 1},
+ "float4_e2m1fn": {"type_code": DataTypeCode.FLOAT4E2M1FN, "bits": 4,
"lanes": 1},
"float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1},
"float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1},
"float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1},
@@ -155,6 +155,11 @@ class DataType(ctypes.Structure):
elif head.startswith("uint"):
self.type_code = DataTypeCode.UINT
head = head[4:]
+ elif head.startswith("float4_e2m1fn"):
+ # Avoid being treated as "float"
+ self.type_code = DataTypeCode.FLOAT4E2M1FN
+ bits = 4
+ head = ""
elif head.startswith("float"):
self.type_code = DataTypeCode.FLOAT
head = head[5:]
@@ -171,9 +176,6 @@ class DataType(ctypes.Structure):
elif head.startswith("e5m2_float"):
self.type_code = DataTypeCode.E5M2Float
head = head[10:]
- elif head.startswith("e2m1_float"):
- self.type_code = DataTypeCode.E2M1Float
- head = head[10:]
elif head.startswith("custom"):
# pylint: disable=import-outside-toplevel
import tvm.runtime._ffi_api
@@ -201,7 +203,12 @@ class DataType(ctypes.Structure):
import tvm.runtime._ffi_api
type_name = "custom[%s]" %
tvm.runtime._ffi_api._datatype_get_type_name(self.type_code)
- x = "%s%d" % (type_name, self.bits)
+ if self.type_code in [
+ DataTypeCode.FLOAT4E2M1FN,
+ ]:
+ x = type_name
+ else:
+ x = "%s%d" % (type_name, self.bits)
lanes_as_int = ctypes.c_int16(self.lanes).value
if lanes_as_int > 1:
x += "x%d" % self.lanes
@@ -238,7 +245,7 @@ if ml_dtypes is not None:
DataType.NUMPY2STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16"
DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "e4m3_float8"
DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] = "e5m2_float8"
- DataType.NUMPY2STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "e2m1_float4"
+ DataType.NUMPY2STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "float4_e2m1fn"
RPC_SESS_MASK = 128
diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py
index 3514ee6168..47fcccf52b 100644
--- a/python/tvm/runtime/ndarray.py
+++ b/python/tvm/runtime/ndarray.py
@@ -197,7 +197,9 @@ class NDArray(NDArrayBase):
source_array = np.ascontiguousarray(
source_array, dtype="uint16" if dtype == "bfloat16" else dtype
)
- if dtype.startswith("e2m1_float4"):
+ if self.dtype.startswith("float4_e2m1fn") and self.dtype !=
"float4_e2m1fn":
+ # float4_e2m1fn in numpy is not packed.
+ # So we need to pack the input data when converting to vectorized
float4_e2m1fn type.
data_bits = source_array.view(dtype="uint8")
if data_bits.size % 2:
data_bits = np.pad(data_bits, (0, 1), mode="constant",
constant_values=0)
@@ -261,22 +263,24 @@ class NDArray(NDArrayBase):
raise RuntimeError(
"ml_dtypes is not installed, cannot convert e5m2_float8
array to numpy."
)
- if dtype == "e2m1_float4":
+ if dtype == "float4_e2m1fn":
if ml_dtypes is not None:
dtype = ml_dtypes.float4_e2m1fn
else:
raise RuntimeError(
- "ml_dtypes is not installed, cannot convert e2m1_float4
array to numpy."
+ "ml_dtypes is not installed, cannot convert float4_e2m1fn
array to numpy."
)
np_arr = np.empty(shape, dtype=dtype)
assert np_arr.flags["C_CONTIGUOUS"]
data = np_arr.ctypes.data_as(ctypes.c_void_p)
- if old_dtype.startswith("e2m1_float4"):
+ if old_dtype.startswith("float4_e2m1fn") and old_dtype !=
"float4_e2m1fn":
nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize // 2)
else:
nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize)
check_call(_LIB.TVMArrayCopyToBytes(self.handle, data, nbytes))
- if old_dtype == "int4" or old_dtype.startswith("e2m1_float4"):
+ if old_dtype == "int4" or (
+ old_dtype.startswith("float4_e2m1fn") and old_dtype !=
"float4_e2m1fn"
+ ):
length = np_arr.size
np_arr = np_arr.view("int8")
np_arr_ret = np.empty((length,), dtype="int8")
diff --git a/python/tvm/script/ir_builder/tir/ir.py
b/python/tvm/script/ir_builder/tir/ir.py
index 6cc19305e4..c35df7a093 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -18,8 +18,8 @@
import functools
import inspect
-from numbers import Integral
import sys
+from numbers import Integral
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
# isort: off
@@ -29,8 +29,7 @@ from typing_extensions import Literal
import numpy as np # type: ignore
-from tvm import tir
-from tvm import ir
+from tvm import ir, tir
from tvm.ir import Type
from tvm.ir.base import deprecated
from tvm.runtime import String, convert, ndarray
@@ -1457,12 +1456,13 @@ e5m2_float8x16 = func_gen(("E5M2Float8x16"))
e5m2_float8x32 = func_gen(("E5M2Float8x32"))
e5m2_float8x64 = func_gen(("E5M2Float8x64"))
-e2m1_float4 = func_gen(("E2M1Float4"))
-e2m1_float4x4 = func_gen(("E2M1Float4x4"))
-e2m1_float4x8 = func_gen(("E2M1Float4x8"))
-e2m1_float4x16 = func_gen(("E2M1Float4x16"))
-e2m1_float4x32 = func_gen(("E2M1Float4x32"))
-e2m1_float4x64 = func_gen(("E2M1Float4x64"))
+float4_e2m1fn = func_gen(("Float4E2M1fn"))
+float4_e2m1fnx2 = func_gen(("Float4E2M1fnx2"))
+float4_e2m1fnx4 = func_gen(("Float4E2M1fnx4"))
+float4_e2m1fnx8 = func_gen(("Float4E2M1fnx8"))
+float4_e2m1fnx16 = func_gen(("Float4E2M1fnx16"))
+float4_e2m1fnx32 = func_gen(("Float4E2M1fnx32"))
+float4_e2m1fnx64 = func_gen(("Float4E2M1fnx64"))
# pylint: enable=invalid-name
@@ -2013,37 +2013,38 @@ __all__ = [
"uint64x64",
"e4m3_float8",
"e5m2_float8",
- "e2m1_float4",
+ "float4_e2m1fn",
"float16",
"float32",
"float64",
+ "float4_e2m1fnx2",
"e4m3_float8x4",
"e5m2_float8x4",
- "e2m1_float4x4",
+ "float4_e2m1fnx4",
"float16x4",
"float32x4",
"float64x4",
"e4m3_float8x8",
"e5m2_float8x8",
- "e2m1_float4x8",
+ "float4_e2m1fnx8",
"float16x8",
"float32x8",
"float64x8",
"e4m3_float8x16",
"e5m2_float8x16",
- "e2m1_float4x16",
+ "float4_e2m1fnx16",
"float16x16",
"float32x16",
"float64x16",
"e4m3_float8x32",
"e5m2_float8x32",
- "e2m1_float4x32",
+ "float4_e2m1fnx32",
"float16x32",
"float32x32",
"float64x32",
"e4m3_float8x64",
"e5m2_float8x64",
- "e2m1_float4x64",
+ "float4_e2m1fnx64",
"float16x64",
"float32x64",
"float64x64",
diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc
index 1812d13b9c..d876065325 100644
--- a/src/runtime/ndarray.cc
+++ b/src/runtime/ndarray.cc
@@ -54,7 +54,7 @@ inline void VerifyDataType(DLDataType dtype) {
return;
else if (dtype.bits == 4 && dtype.code == kDLInt)
return;
- else if (dtype.bits == 4 && dtype.code == DataType::kE2M1Float)
+ else if (dtype.bits == 4 && dtype.code == DataType::kFloat4E2M1Fn)
return;
else
ICHECK_EQ(dtype.bits % 8, 0);
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index e452e102bf..a73c9cb5b4 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -757,8 +757,8 @@
TVM_REGISTER_GLOBAL("script.ir_builder.tir.E5M2Float8").set_body_typed(E5M2Float
TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E4M3Float8", E4M3Float8);
TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E5M2Float8", E5M2Float8);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.E2M1Float4").set_body_typed(E2M1Float4);
-TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E2M1Float4", E2M1Float4);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float4E2M1fn").set_body_typed(Float4E2M1fn);
+TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float4E2M1fn", Float4E2M1fn);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle);
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index f5f17c70ef..9c2ce0bbb2 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -581,7 +581,7 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType&
dtype) const {
}
} else if (dtype.code() == DataType::kE4M3Float || dtype.code() ==
DataType::kE5M2Float) {
etype = llvm::Type::getInt8Ty(*ctx);
- } else if (dtype.code() == DataType::kE2M1Float) {
+ } else if (dtype.code() == DataType::kFloat4E2M1Fn) {
etype = llvm::Type::getIntNTy(*ctx, 4);
}
if (!dtype.is_scalar()) {
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 0e0971b8f8..575f52e225 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -789,6 +789,11 @@ void CodeGenC::VisitExpr_(const BufferLoadNode* op,
std::ostream& os) { // NOLI
}
}
+ if (value_dtype.is_float4_e2m1fn() && lanes != 1) {
+ // A float4_e2m1fn element has 4 bits, which is an incomplete byte.
+ // So we cannot vector load it.
+ can_vector_load = false;
+ }
if (can_vector_load) {
std::string ref = GetVecLoad(op->dtype, op->buffer.get(), base.Eval());
HandleVolatileLoads(ref, op, os);
@@ -839,7 +844,8 @@ void CodeGenC::VisitStmt_(const BufferStoreNode* op) {
} else {
arith::PVar<PrimExpr> base;
- if (arith::ramp(base, 1, value_dtype.lanes()).Match(index_expr)) {
+ if (arith::ramp(base, 1, value_dtype.lanes()).Match(index_expr) &&
+ !value_dtype.is_float4_e2m1fn()) {
std::string value = this->PrintExpr(op->value);
this->PrintVecStore(op->buffer.get(), value_dtype, base.Eval(), value);
} else {
diff --git a/src/target/source/codegen_cuda.cc
b/src/target/source/codegen_cuda.cc
index 872a024366..20b29750dc 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -82,11 +82,11 @@ std::string GetFP4Type(DataType type) {
} else if (lanes == 4) {
vec = "x4";
} else {
- LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8) for
FP8";
+ LOG(FATAL) << "Only support scalar and vector types of width (2, 4) for
FP8";
}
stream << "__nv_fp4";
std::string suffix;
- if (type.code() == DataType::kE2M1Float) {
+ if (type.code() == DataType::kFloat4E2M1Fn) {
suffix = "_e2m1";
} else {
LOG(FATAL) << "Unsupported FP8 type in CUDA codegen";
@@ -196,7 +196,7 @@ std::string CodeGenCUDA::Finish() {
decl_stream << "#include <cuda_fp4.h>\n";
decl_stream << "#endif\n\n";
}
- declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_);
+ declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_,
enable_fp4_);
if (enable_warp_shuffle_) {
decl_stream << _cuda_warp_intrinsic_util;
@@ -597,6 +597,9 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec,
DataType t, int i,
}
ICHECK(!type_name.empty());
os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] <<
")))->" << access[i % 2];
+ } else if (t.is_float4_e2m1fn()) {
+ os << "([](__nv_fp4_storage_t v) { __nv_fp4_e2m1 t; t.__x = v; return t;
})((" << vec
+ << ".__x >> " << i * 4 << ") & 0xF)";
} else {
os << vec << "." << access[i];
}
@@ -732,8 +735,8 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op,
std::ostream& os) {
if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os);
if (target_ty.code() == DataType::kE4M3Float || target_ty.code() ==
DataType::kE5M2Float ||
- target_ty.code() == DataType::kE2M1Float || from_ty.code() ==
DataType::kE4M3Float ||
- from_ty.code() == DataType::kE5M2Float || from_ty.code() ==
DataType::kE2M1Float) {
+ target_ty.code() == DataType::kFloat4E2M1Fn || from_ty.code() ==
DataType::kE4M3Float ||
+ from_ty.code() == DataType::kE5M2Float || from_ty.code() ==
DataType::kFloat4E2M1Fn) {
std::ostringstream val;
val << "(";
PrintType(target_ty, val);
@@ -1036,8 +1039,8 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op,
std::ostream& os) {
var_idmap_[inverse_index_map->initial_indices[1].get()] = "local_id";
os << "for (int local_id = 0; local_id < 8; ++local_id) {\n";
- os << dst << "[" + this->PrintExpr(dst_ind) + "]"
- << " = " << src << "[" << src_offset << " + local_id];\n";
+ os << dst << "[" + this->PrintExpr(dst_ind) + "] = " << src << "[" <<
src_offset
+ << " + local_id];\n";
os << "}\n";
} else if (op->op.same_as(builtin::mma_fill())) {
@@ -1155,6 +1158,82 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op,
std::ostream& os) {
stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr <<
")), \"r\"((int)"
<< guard << ")\n";
stream << ");\n";
+ } else if (op->op.same_as(builtin::reinterpret())) {
+ DataType tgt_dtype = op->dtype;
+ DataType src_dtype = op->args[0]->dtype;
+ PrimExpr value = op->args[0];
+
+ // Handle float4_e2m1fn reinterpret
+ if (!src_dtype.is_float4_e2m1fn() && !tgt_dtype.is_float4_e2m1fn()) {
+ return CodeGenC::VisitExpr_(op, os);
+ }
+ if (src_dtype == tgt_dtype ||
+ tgt_dtype.lanes() * tgt_dtype.bits() == src_dtype.lanes() *
src_dtype.bits()) {
+ return CodeGenC::VisitExpr_(op, os);
+ }
+ CHECK_EQ(tgt_dtype.lanes(), src_dtype.lanes())
+ << "E2M1 float4 reinterpret expects source and target to have the same
number of lanes. "
+ << "Source dtype: " << src_dtype << ", Target dtype: " << tgt_dtype;
+ CHECK_EQ(tgt_dtype.bytes(), src_dtype.bytes())
+ << "E2M1 float4 reinterpret expects source and target to have the same
number of bytes. "
+ << "Source dtype: " << src_dtype << ", Target dtype: " << tgt_dtype;
+
+ int lanes = tgt_dtype.lanes();
+
+ int ssa_scope = BeginScope();
+ if (lanes == 1) {
+ // The case of lane=1 is same as the normal reinterpret,
+ // except that we allow the src and dst dtype to have different number
of bits.
+ std::string rhs = SSAGetID(PrintExpr(value), src_dtype);
+ os << "(*(";
+ this->PrintType(tgt_dtype, os);
+ os << " *)(&(" << rhs << ")))";
+ } else if (lanes == 2) {
+ if (tgt_dtype.is_float4_e2m1fn()) {
+ // We view the source as an uint16, and then extract bits of two fp4
numbers,
+ // and finally reinterpret the result as fp4x2.
+ value = tir::Call(DataType::UInt(16), tir::builtin::reinterpret(),
{value});
+ tir::Var temp_var("temp_var", DataType::UInt(16));
+ value = tir::Let(
+ temp_var, value,
+ tir::Cast(DataType::UInt(8), (temp_var &
IntImm(DataType::UInt(16), 0xF)) |
+ ((temp_var >> 4) &
IntImm(DataType::UInt(16), 0xF0))));
+ } else {
+ value = tir::Cast(DataType::UInt(16),
+ tir::Call(DataType::UInt(8),
tir::builtin::reinterpret(), {value}));
+ tir::Var temp_var("temp_var", DataType::UInt(16));
+ value = tir::Let(temp_var, value,
+ (temp_var & IntImm(DataType::UInt(16), 0xF)) |
+ ((temp_var & IntImm(DataType::UInt(16), 0xF0)) <<
4));
+ }
+ os << PrintExpr(tir::Call(tgt_dtype, tir::builtin::reinterpret(),
{value}));
+ } else if (lanes == 4) {
+ if (tgt_dtype.is_float4_e2m1fn()) {
+ // We view the source as an uint32, and then extract bits of four fp4
numbers,
+ // and finally reinterpret the result as fp4x4.
+ value = tir::Call(DataType::UInt(32), tir::builtin::reinterpret(),
{value});
+ tir::Var temp_var("temp_var", DataType::UInt(32));
+ value = tir::Let(temp_var, value,
+ tir::Cast(DataType::UInt(16),
+ (temp_var & IntImm(DataType::UInt(32),
0xF)) |
+ ((temp_var >> 4) &
IntImm(DataType::UInt(32), 0xF0)) |
+ ((temp_var >> 8) &
IntImm(DataType::UInt(32), 0xF00)) |
+ ((temp_var >> 12) &
IntImm(DataType::UInt(32), 0xF000))));
+ } else {
+ value = tir::Cast(DataType::UInt(32),
+ tir::Call(DataType::UInt(16),
tir::builtin::reinterpret(), {value}));
+ tir::Var temp_var("temp_var", DataType::UInt(32));
+ value = tir::Let(temp_var, value,
+ (temp_var & IntImm(DataType::UInt(32), 0xF)) |
+ ((temp_var & IntImm(DataType::UInt(32), 0xF0)) <<
4) |
+ ((temp_var & IntImm(DataType::UInt(32), 0xF00))
<< 8) |
+ ((temp_var & IntImm(DataType::UInt(32), 0xF000))
<< 12));
+ }
+ os << PrintExpr(tir::Call(tgt_dtype, tir::builtin::reinterpret(),
{value}));
+ } else {
+ LOG(FATAL) << "Invalid number of lanes for float4_e2m1fn reinterpret: "
<< lanes;
+ }
+ EndScope(ssa_scope);
} else {
CodeGenC::VisitExpr_(op, os);
}
diff --git a/src/target/source/literal/cuda_half_t.h
b/src/target/source/literal/cuda_half_t.h
index abdf22df26..86f2219fe8 100644
--- a/src/target/source/literal/cuda_half_t.h
+++ b/src/target/source/literal/cuda_half_t.h
@@ -385,8 +385,9 @@ static constexpr const char* _cuda_warp_intrinsic_util = R"(
)";
-void declare_vector_type_extensions(std::ostringstream& stream, bool
enable_fp16, bool enable_fp8) {
- if (enable_fp16 || enable_fp8) {
+void declare_vector_type_extensions(std::ostringstream& stream, bool
enable_fp16, bool enable_fp8,
+ bool enable_fp4) {
+ if (enable_fp16 || enable_fp8 || enable_fp4) {
stream << R"(
struct __align__(8) half4 {
__half x, y, z, w;
@@ -455,6 +456,26 @@ struct __align__(8) half4 {
result.__x = (a) | (b << 8) | (c << 16) | (d << 24);
return result;
}
+ )";
+ }
+ if (enable_fp4) {
+ stream << R"(
+ __host__ __device__ explicit half4(const __nv_fp4x4_e2m1& fp4x4) {
+ __nv_fp4x2_storage_t lo_part, hi_part;
+ lo_part = static_cast<__nv_fp4x2_storage_t>(fp4x4.__x & 0xFF);
+ hi_part = static_cast<__nv_fp4x2_storage_t>((fp4x4.__x >> 8) & 0xFF);
+ __half2 lo_half2 = __half2(__nv_cvt_fp4x2_to_halfraw2(lo_part, __NV_E2M1));
+ __half2 hi_half2 = __half2(__nv_cvt_fp4x2_to_halfraw2(hi_part, __NV_E2M1));
+ x = reinterpret_cast<__half*>(&lo_half2)[0];
+ y = reinterpret_cast<__half*>(&lo_half2)[1];
+ z = reinterpret_cast<__half*>(&hi_half2)[0];
+ w = reinterpret_cast<__half*>(&hi_half2)[1];
+ }
+ __host__ __device__ explicit operator __nv_fp4x4_e2m1() const {
+ __half2 lo_half2 = *reinterpret_cast<const __half2*>(&x);
+ __half2 hi_half2 = *reinterpret_cast<const __half2*>(&z);
+ return __nv_fp4x4_e2m1(lo_half2, hi_half2);
+ }
)";
}
stream << R"(
@@ -462,6 +483,20 @@ struct __align__(8) half4 {
__host__ __device__ half4 make_half4(__half x, __half y, __half z, __half w) {
return half4(x, y, z, w);
}
+)";
+ }
+ if (enable_fp4) {
+ stream << R"(
+__device__ __nv_fp4x2_e2m1 make___nv_fp4x2_e2m1(__nv_fp4_e2m1 x, __nv_fp4_e2m1
y) {
+ __nv_fp4x2_e2m1 result;
+ result.__x = (x.__x) | (y.__x << 4);
+ return result;
+}
+__device__ __nv_fp4x4_e2m1 make___nv_fp4x4_e2m1(__nv_fp4_e2m1 a, __nv_fp4_e2m1
b, __nv_fp4_e2m1 c, __nv_fp4_e2m1 d) {
+ __nv_fp4x4_e2m1 result;
+ result.__x = (static_cast<__nv_fp4x4_storage_t>(a.__x)) |
(static_cast<__nv_fp4x4_storage_t>(b.__x) << 4) |
(static_cast<__nv_fp4x4_storage_t>(c.__x) << 8) |
(static_cast<__nv_fp4x4_storage_t>(d.__x) << 12);
+ return result;
+}
)";
}
}
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 039acf7e92..3dab634f16 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -425,8 +425,10 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span
span) {
PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span) {
if (value.dtype() == t) return value;
if (!t.is_scalable_vector() && !value.dtype().is_scalable_vector()) {
- ICHECK(value.dtype().bits() * value.dtype().lanes() == t.bits() *
t.lanes())
- << "Bitcast requires size match " << t << " vs " << value.dtype();
+ ICHECK(value.dtype().bits() * value.dtype().lanes() == t.bits() *
t.lanes() ||
+ ((value.dtype().is_float4_e2m1fn() || t.is_float4_e2m1fn()) &&
+ value.dtype().bytes() * value.dtype().lanes() == t.bytes() *
t.lanes()))
+ << "Reinterpret requires size match " << t << " vs " << value.dtype();
}
return tir::Call(t, tir::builtin::reinterpret(), {value}, span);
}
diff --git a/tests/python/codegen/test_target_codegen_cuda_fp4.py
b/tests/python/codegen/test_target_codegen_cuda_fp4.py
index f137e83cc9..46825826a9 100644
--- a/tests/python/codegen/test_target_codegen_cuda_fp4.py
+++ b/tests/python/codegen/test_target_codegen_cuda_fp4.py
@@ -15,12 +15,13 @@
# specific language governing permissions and limitations
# under the License.
+from itertools import product
-import tvm
-from tvm.script import tir as T
import numpy as np
+
+import tvm
import tvm.testing
-from tvm.script import ir as I, relax as R, tir as T
+from tvm.script import tir as T
try:
import ml_dtypes
@@ -28,12 +29,12 @@ except ImportError:
ml_dtypes = None
native_dtype, promoted_dtype = tvm.testing.parameters(
- ("e2m1_float4x2", "float32x2"),
- ("e2m1_float4x2", "float16x2"),
+ ("float4_e2m1fnx2", "float32x2"),
+ ("float4_e2m1fnx2", "float16x2"),
)
[email protected]_cuda_compute_version(9)
[email protected]_cuda_compute_version(10)
def test_e2m1_vector_conversions(native_dtype, promoted_dtype):
vector_length = 64
@@ -63,7 +64,6 @@ def test_e2m1_vector_conversions(native_dtype,
promoted_dtype):
target = "cuda"
fadd = tvm.build(sch.mod, target=target)
- cuda_src = fadd.imported_modules[0].get_source()
dev = tvm.device(target, 0)
numpytype = "float4_e2m1fn"
@@ -92,5 +92,124 @@ def test_e2m1_vector_conversions(native_dtype,
promoted_dtype):
)
[email protected]_cuda_compute_version(10)
+def test_e2m1_schedule_vectorize():
+ native_dtype = "float4_e2m1fn"
+ n = 128
+
+ dev = tvm.device("cuda", 0)
+ target = tvm.target.Target.from_device(dev)
+ for promoted_dtype, vector_length in product(
+ ["float16", "bfloat16", "float32"],
+ [1, 2, 4],
+ ):
+
+ @T.prim_func
+ def add(
+ A: T.Buffer((n,), native_dtype),
+ B: T.Buffer((n,), native_dtype),
+ C: T.Buffer((n,), native_dtype),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ for i in range(n):
+ with T.block("C"):
+ v_i = T.axis.spatial(n, i)
+ T.reads(A[v_i], B[v_i])
+ T.writes(C[v_i])
+ C[v_i] = T.Cast(
+ native_dtype,
+ T.Cast(promoted_dtype, A[v_i]) +
T.Cast(promoted_dtype, B[v_i]),
+ )
+
+ sch = tvm.tir.Schedule(add)
+ block = sch.get_block("C")
+ b = sch.get_loops(block)
+ bx, tx, vec = sch.split(b[0], factors=[None, 32, vector_length])
+ sch.bind(bx, "blockIdx.x")
+ sch.bind(tx, "threadIdx.x")
+ sch.vectorize(vec)
+
+ fadd = tvm.build(sch.mod, target=target)
+
+ numpytype = "float4_e2m1fn"
+ promoted_base_dtype = promoted_dtype
+
+ a_np = np.random.uniform(low=-6, high=6, size=(n,)).astype(numpytype)
+ a = tvm.nd.empty(shape=(n,), dtype=native_dtype, device=dev)
+ a.copyfrom(a_np)
+ b_np = np.random.uniform(low=-6, high=6, size=(n,)).astype(numpytype)
+ b = tvm.nd.empty(shape=(n,), dtype=native_dtype, device=dev)
+ b.copyfrom(b_np)
+ c = tvm.nd.empty(shape=(n,), dtype=native_dtype, device=dev)
+ fadd(a, b, c)
+
+ if promoted_base_dtype != "bfloat16":
+ tvm.testing.assert_allclose(
+ c.numpy().astype(promoted_base_dtype), (a_np +
b_np).astype(promoted_base_dtype)
+ )
+ else:
+ # assert_allclose with bfloat16 throws an error here.
+ # Thus we convert bfloat16 to float32 for comparison.
+ tvm.testing.assert_allclose(
+ c.numpy().astype(promoted_base_dtype).astype("float32"),
+ (a_np + b_np).astype(promoted_base_dtype).astype("float32"),
+ )
+
+
[email protected]_cuda_compute_version(10)
+def test_e2m1_reinterpret():
+ n = 128
+
+ dev = tvm.device("cuda", 0)
+ target = tvm.target.Target.from_device(dev)
+
+ def get_reinterpret_mod(src_dtype, dst_dtype, vector_length):
+ @T.prim_func
+ def reinterpret(
+ A: T.Buffer((n,), src_dtype),
+ B: T.Buffer((n,), dst_dtype),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ for i in range(n):
+ with T.block("C"):
+ v_i = T.axis.spatial(n, i)
+ T.reads(A[v_i])
+ T.writes(B[v_i])
+ B[v_i] = T.reinterpret(dst_dtype, A[v_i])
+
+ sch = tvm.tir.Schedule(reinterpret)
+ block = sch.get_block("C")
+ b = sch.get_loops(block)
+ bx, tx, vec = sch.split(b[0], factors=[None, 32, vector_length])
+ sch.bind(bx, "blockIdx.x")
+ sch.bind(tx, "threadIdx.x")
+ sch.vectorize(vec)
+ return sch.mod
+
+ # Part 1. reinterpret float4_e2m1fn to uint8
+ for vector_length in [1, 2, 4]:
+ mod = get_reinterpret_mod("float4_e2m1fn", "uint8", vector_length)
+ f = tvm.build(mod, target=target)
+ a_np = np.random.uniform(low=-6, high=6,
size=(n,)).astype("float4_e2m1fn")
+ a = tvm.nd.empty(shape=(n,), dtype="float4_e2m1fn", device=dev)
+ a.copyfrom(a_np)
+ b = tvm.nd.empty(shape=(n,), dtype="uint8", device=dev)
+ f(a, b)
+ tvm.testing.assert_allclose(b.numpy(), a_np.view("uint8"))
+
+ # Part 2. reinterpret uint8 to float4_e2m1fn
+ for vector_length in [1, 2, 4]:
+ mod = get_reinterpret_mod("uint8", "float4_e2m1fn", vector_length)
+ f = tvm.build(mod, target=target)
+ a_np = np.random.uniform(low=-6, high=6, size=(n,)).astype("uint8")
+ a = tvm.nd.empty(shape=(n,), dtype="uint8", device=dev)
+ a.copyfrom(a_np)
+ b = tvm.nd.empty(shape=(n,), dtype="float4_e2m1fn", device=dev)
+ f(a, b)
+ tvm.testing.assert_allclose(
+ b.numpy().astype("float32"),
a_np.view("float4_e2m1fn").astype("float32")
+ )
+
+
if __name__ == "__main__":
tvm.testing.main()