This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new aee96e64b5 [DTYPE] Fix dtype functions after dtype refactor (#18041)
aee96e64b5 is described below

commit aee96e64b576da4806e5dec563116fd1258fdd29
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Jun 5 23:50:37 2025 -0400

    [DTYPE] Fix dtype functions after dtype refactor (#18041)
    
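    This PR fixes compilation of the dtype helpers used by the NCCL disco runtime and the CUTLASS FP8 blockwise-scaled GEMM/BMM after the dtype refactor, replacing DataType::NVFloat8E4M3()/NVFloat8E5M2() with DataType::Float8E4M3FN()/Float8E5M2().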
---
 src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu | 8 ++++----
 src/runtime/disco/nccl/nccl.cc                           | 2 +-
 src/runtime/disco/nccl/nccl_context.h                    | 4 ++--
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu b/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu
index b8732357c7..5164958afe 100644
--- a/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu
+++ b/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu
@@ -67,8 +67,8 @@ void tvm_cutlass_fp8_blockwise_scaled_gemm(NDArray a, NDArray b, NDArray scales_
   CHECK_EQ(scales_b->shape[1] * block_size_1, k);
 
   using tvm::runtime::DataType;
-  CHECK_EQ(DataType(a->dtype), DataType::NVFloat8E4M3());
-  CHECK_EQ(DataType(b->dtype), DataType::NVFloat8E4M3());
+  CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN());
+  CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN());
   CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32));
   CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32));
   CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8));
@@ -128,8 +128,8 @@ void tvm_cutlass_fp8_blockwise_scaled_bmm(NDArray a, NDArray b, NDArray scales_a
   CHECK_EQ(scales_b->shape[2] * block_size_1, k);
 
   using tvm::runtime::DataType;
-  CHECK_EQ(DataType(a->dtype), DataType::NVFloat8E4M3());
-  CHECK_EQ(DataType(b->dtype), DataType::NVFloat8E4M3());
+  CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN());
+  CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN());
   CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32));
   CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32));
   CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8));
diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc
index 8095cbeeea..2b860b6b63 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -120,7 +120,7 @@ void AllReduce(NDArray send, ReduceKind reduce_kind, bool in_group, NDArray recv
   int64_t numel = shape->Product();
   deviceStream_t stream = ctx->GetDefaultStream();
   DataType dtype = DataType(send->dtype);
-  if (dtype == DataType::NVFloat8E4M3() || dtype == DataType::NVFloat8E5M2()) {
+  if (dtype == DataType::Float8E4M3FN() || dtype == DataType::Float8E5M2()) {
     LOG(FATAL) << "Float8 data type cannot be allreduced, as nccl does not 
support this data type.";
   }
   NCCL_CALL(ncclAllReduce(send->data, recv->data, numel,
diff --git a/src/runtime/disco/nccl/nccl_context.h b/src/runtime/disco/nccl/nccl_context.h
index fff165bfdd..e24687d867 100644
--- a/src/runtime/disco/nccl/nccl_context.h
+++ b/src/runtime/disco/nccl/nccl_context.h
@@ -86,8 +86,8 @@ inline ncclDataType_t AsNCCLDataType(runtime::DataType dtype) {
   if (dtype == DataType::Int(8)) {
     return ncclInt8;
   }
-  if (dtype == DataType::UInt(8) || dtype == DataType::NVFloat8E4M3() ||
-      dtype == DataType::NVFloat8E5M2()) {
+  if (dtype == DataType::UInt(8) || dtype == DataType::Float8E4M3FN() ||
+      dtype == DataType::Float8E5M2()) {
     // For float8 data type, pretend to be uint8 in nccl.
     // And will throw error when allreduce, as it makes no sense in this case.
     return ncclUint8;

Reply via email to