This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new aee96e64b5 [DTYPE] Fix dtype functions after dtype refactor (#18041)
aee96e64b5 is described below
commit aee96e64b576da4806e5dec563116fd1258fdd29
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Jun 5 23:50:37 2025 -0400
[DTYPE] Fix dtype functions after dtype refactor (#18041)
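This PR fixes compilation of the dtype helper calls in the NCCL and CUTLASS FP8 GEMM runtime code after the dtype refactor. It replaces `DataType::NVFloat8E4M3()` / `DataType::NVFloat8E5M2()` with `DataType::Float8E4M3FN()` / `DataType::Float8E5M2()`.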
---
src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu | 8 ++++----
src/runtime/disco/nccl/nccl.cc | 2 +-
src/runtime/disco/nccl/nccl_context.h | 4 ++--
3 files changed, 7 insertions(+), 7 deletions(-)
diff --git a/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu b/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu
index b8732357c7..5164958afe 100644
--- a/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu
+++ b/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu
@@ -67,8 +67,8 @@ void tvm_cutlass_fp8_blockwise_scaled_gemm(NDArray a, NDArray b, NDArray scales_
CHECK_EQ(scales_b->shape[1] * block_size_1, k);
using tvm::runtime::DataType;
- CHECK_EQ(DataType(a->dtype), DataType::NVFloat8E4M3());
- CHECK_EQ(DataType(b->dtype), DataType::NVFloat8E4M3());
+ CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN());
+ CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN());
CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32));
CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32));
CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8));
@@ -128,8 +128,8 @@ void tvm_cutlass_fp8_blockwise_scaled_bmm(NDArray a, NDArray b, NDArray scales_a
CHECK_EQ(scales_b->shape[2] * block_size_1, k);
using tvm::runtime::DataType;
- CHECK_EQ(DataType(a->dtype), DataType::NVFloat8E4M3());
- CHECK_EQ(DataType(b->dtype), DataType::NVFloat8E4M3());
+ CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN());
+ CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN());
CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32));
CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32));
CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8));
diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc
index 8095cbeeea..2b860b6b63 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -120,7 +120,7 @@ void AllReduce(NDArray send, ReduceKind reduce_kind, bool in_group, NDArray recv
int64_t numel = shape->Product();
deviceStream_t stream = ctx->GetDefaultStream();
DataType dtype = DataType(send->dtype);
- if (dtype == DataType::NVFloat8E4M3() || dtype == DataType::NVFloat8E5M2()) {
+ if (dtype == DataType::Float8E4M3FN() || dtype == DataType::Float8E5M2()) {
LOG(FATAL) << "Float8 data type cannot be allreduced, as nccl does not
support this data type.";
}
NCCL_CALL(ncclAllReduce(send->data, recv->data, numel,
diff --git a/src/runtime/disco/nccl/nccl_context.h b/src/runtime/disco/nccl/nccl_context.h
index fff165bfdd..e24687d867 100644
--- a/src/runtime/disco/nccl/nccl_context.h
+++ b/src/runtime/disco/nccl/nccl_context.h
@@ -86,8 +86,8 @@ inline ncclDataType_t AsNCCLDataType(runtime::DataType dtype) {
if (dtype == DataType::Int(8)) {
return ncclInt8;
}
- if (dtype == DataType::UInt(8) || dtype == DataType::NVFloat8E4M3() ||
- dtype == DataType::NVFloat8E5M2()) {
+ if (dtype == DataType::UInt(8) || dtype == DataType::Float8E4M3FN() ||
+ dtype == DataType::Float8E5M2()) {
// For float8 data type, pretend to be uint8 in nccl.
// And will throw error when allreduce, as it makes no sense in this case.
return ncclUint8;