This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e60fd806eb [3rdparty] Enable bfloat16 for custom allreduce kernel (#17780)
e60fd806eb is described below

commit e60fd806ebe91f5f6f394eb89101dbe4cf57df8c
Author: Ruihang Lai <[email protected]>
AuthorDate: Wed Mar 26 19:06:45 2025 -0400

    [3rdparty] Enable bfloat16 for custom allreduce kernel (#17780)
    
    This PR enables the bfloat16 datatype for the 3rdparty custom
    allreduce kernel. Prior to this PR, bfloat16 support was compiled
    out unless the `ENABLE_BF16` macro was defined.
---
 3rdparty/tensorrt_llm/custom_allreduce_kernels.cu | 11 +++--------
 1 file changed, 3 insertions(+), 8 deletions(-)

diff --git a/3rdparty/tensorrt_llm/custom_allreduce_kernels.cu b/3rdparty/tensorrt_llm/custom_allreduce_kernels.cu
index 36ac0e3a43..d26baea342 100644
--- a/3rdparty/tensorrt_llm/custom_allreduce_kernels.cu
+++ b/3rdparty/tensorrt_llm/custom_allreduce_kernels.cu
@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 
+#include <cuda_bf16.h>
 #include <cuda_fp16.h>
 #include <dlpack/dlpack.h>
 #include <stdint.h>
@@ -69,7 +70,6 @@ struct PackedOn16Bytes<half> {
   using Type = PackedHalf;
 };
 
-#ifdef ENABLE_BF16
 using PackedBFloat16 = union {
   int4 packed;
   __nv_bfloat162 unpacked[4];
@@ -79,7 +79,6 @@ template <>
 struct PackedOn16Bytes<__nv_bfloat16> {
   using Type = PackedBFloat16;
 };
-#endif
 
 // add two 128b data
 template <typename T>
@@ -387,13 +386,9 @@ void customAllReduce(AllReduceParams& params, void* data, size_t elts, DLDataTyp
     invokeOneOrTwoShotAllReduceKernel<float>(params, strat, stream);
   } else if (dataType.code == kDLFloat && dataType.bits == 16) {
     invokeOneOrTwoShotAllReduceKernel<half>(params, strat, stream);
-  }
-#ifdef ENABLE_BF16
-  else if (dataType.code == kDLBfloat && dataType.bits == 16) {
+  } else if (dataType.code == kDLBfloat && dataType.bits == 16) {
     invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(params, strat, stream);
-  }
-#endif
-  else {
+  } else {
     LOG(FATAL) << ("Unsupported dataType for customAllReduce");
   }
 }

Reply via email to