This is an automated email from the ASF dual-hosted git repository.
csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2a62c72154 [FP8][Codegen] Add make_fp8 vector constructors (#17065)
2a62c72154 is described below
commit 2a62c7215419a859321460c7fb9e2da272f4d003
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Jun 5 07:45:04 2024 -0700
[FP8][Codegen] Add make_fp8 vector constructors (#17065)
* [FP8][Codegen] Add make_fp8 vector constructors.
Allows vectorized fp8 loading.
---------
Co-authored-by: Chris Sullivan <[email protected]>
---
src/target/source/codegen_cuda.cc | 25 +++++++++++-----------
src/target/source/literal/cuda_half_t.h | 20 +++++++++++++++++
.../python/codegen/test_target_codegen_cuda_fp8.py | 2 +-
3 files changed, 33 insertions(+), 14 deletions(-)
diff --git a/src/target/source/codegen_cuda.cc
b/src/target/source/codegen_cuda.cc
index ecb0957611..bd28048301 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -48,21 +48,22 @@ std::string GetFP8Type(DataType type) {
if (type.is_scalar()) {
vec = "";
} else if (lanes == 2) {
- vec = "_2";
+ vec = "x2";
} else if (lanes == 4) {
- vec = "_4";
- } else if (lanes == 8) {
- vec = "_8";
+ vec = "x4";
} else {
LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8) for
FP8";
}
+ stream << "__nv_fp8";
+ std::string suffix;
if (type.code() == DataType::kE4M3Float) {
- stream << "fp8_e4" << vec << "_t";
+ suffix = "_e4m3";
} else if (type.code() == DataType::kE5M2Float) {
- stream << "fp8_e5" << vec << "_t";
+ suffix = "_e5m2";
} else {
LOG(FATAL) << "Unsupported FP8 type in CUDA codegen";
}
+ stream << vec << suffix;
return stream.str();
}
@@ -146,12 +147,6 @@ std::string CodeGenCUDA::Finish() {
if (enable_fp8_) {
decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)\n";
decl_stream << "#include <cuda_fp8.h>\n";
- decl_stream << "using fp8_e4_t = __nv_fp8_e4m3;\n";
- decl_stream << "using fp8_e4_2_t = __nv_fp8x2_e4m3;\n";
- decl_stream << "using fp8_e4_4_t = __nv_fp8x4_e4m3;\n";
- decl_stream << "using fp8_e5_t = __nv_fp8_e5m2;\n";
- decl_stream << "using fp8_e5_2_t = __nv_fp8x2_e5m2;\n";
- decl_stream << "using fp8_e5_4_t = __nv_fp8x4_e5m2;\n";
decl_stream << "#endif\n\n";
}
declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_);
@@ -299,7 +294,11 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os)
{ // NOLINT(*)
if (!fail) return;
} else if (t.is_float8()) {
enable_fp8_ = true;
- os << GetFP8Type(t);
+ if (t.lanes() <= 4) {
+ os << GetFP8Type(t);
+ } else {
+ os << "uint" << t.lanes() / 4;
+ }
return;
} else if (t == DataType::Bool()) {
os << "bool";
diff --git a/src/target/source/literal/cuda_half_t.h
b/src/target/source/literal/cuda_half_t.h
index 27d44d9f7f..c5ecda07a4 100644
--- a/src/target/source/literal/cuda_half_t.h
+++ b/src/target/source/literal/cuda_half_t.h
@@ -431,6 +431,26 @@ struct __align__(8) half4 {
(static_cast<__uint32_t>(lo_part.__x) |
(static_cast<__uint32_t>(hi_part.__x) << 16));
return result;
}
+ __device__ __nv_fp8x2_e5m2 make_fp8x2_e5m2(__nv_fp8_storage_t x, __nv_fp8_storage_t y) {  // packs x into the low byte and y into the high byte of an fp8x2 value
+ __nv_fp8x2_e5m2 result;
+ result.__x = static_cast<decltype(result.__x)>(x | (y << 8));  // the OR is evaluated in int (at most 0xFFFF); narrow explicitly to __x's storage type
+ return result;
+ }
+ __device__ __nv_fp8x4_e5m2 make_fp8x4_e5m2(__nv_fp8_storage_t a, __nv_fp8_storage_t b, __nv_fp8_storage_t c, __nv_fp8_storage_t d) {  // packs a..d into bytes 0..3 (a lowest) of an fp8x4 value
+ __nv_fp8x4_e5m2 result;
+ result.__x = static_cast<unsigned int>(a) | (static_cast<unsigned int>(b) << 8) | (static_cast<unsigned int>(c) << 16) | (static_cast<unsigned int>(d) << 24);  // shift in unsigned arithmetic: d << 24 on an int-promoted byte exceeds INT_MAX when d >= 0x80
+ return result;
+ }
+ __device__ __nv_fp8x2_e4m3 make_fp8x2_e4m3(__nv_fp8_storage_t x, __nv_fp8_storage_t y) {  // packs x into the low byte and y into the high byte of an fp8x2 value
+ __nv_fp8x2_e4m3 result;
+ result.__x = static_cast<decltype(result.__x)>(x | (y << 8));  // the OR is evaluated in int (at most 0xFFFF); narrow explicitly to __x's storage type
+ return result;
+ }
+ __device__ __nv_fp8x4_e4m3 make_fp8x4_e4m3(__nv_fp8_storage_t a, __nv_fp8_storage_t b, __nv_fp8_storage_t c, __nv_fp8_storage_t d) {  // packs a..d into bytes 0..3 (a lowest) of an fp8x4 value
+ __nv_fp8x4_e4m3 result;
+ result.__x = static_cast<unsigned int>(a) | (static_cast<unsigned int>(b) << 8) | (static_cast<unsigned int>(c) << 16) | (static_cast<unsigned int>(d) << 24);  // shift in unsigned arithmetic: d << 24 on an int-promoted byte exceeds INT_MAX when d >= 0x80
+ return result;
+ }
)";
}
stream << R"(
diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py
b/tests/python/codegen/test_target_codegen_cuda_fp8.py
index 5566ae2434..adcb05839b 100644
--- a/tests/python/codegen/test_target_codegen_cuda_fp8.py
+++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py
@@ -64,7 +64,7 @@ def test_e4m3_conversions():
fadd = tvm.build(sch.mod, target=target)
cuda_src = fadd.imported_modules[0].get_source()
- assert "fp8_e4_t" in cuda_src, "FP8E4M3 (fp8_e4_t) datatype not found in
generated CUDA"
+ assert "__nv_fp8_e4m3" in cuda_src, "FP8E4M3 (fp8_e4_t) datatype not found
in generated CUDA"
dev = tvm.device(target, 0)