This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e60fd806eb [3rdparty] Enable bfloat16 for custom allreduce kernel (#17780)
e60fd806eb is described below
commit e60fd806ebe91f5f6f394eb89101dbe4cf57df8c
Author: Ruihang Lai <[email protected]>
AuthorDate: Wed Mar 26 19:06:45 2025 -0400
[3rdparty] Enable bfloat16 for custom allreduce kernel (#17780)
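This PR enables the bfloat16 data type for the 3rdparty custom
allreduce kernel. Prior to this PR, bfloat16 support was guarded by
`#ifdef ENABLE_BF16` and was therefore compiled out unless that macro
was defined. This PR removes the guard and includes `<cuda_bf16.h>`
so that `__nv_bfloat16` is always available.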
---
3rdparty/tensorrt_llm/custom_allreduce_kernels.cu | 11 +++--------
1 file changed, 3 insertions(+), 8 deletions(-)
diff --git a/3rdparty/tensorrt_llm/custom_allreduce_kernels.cu b/3rdparty/tensorrt_llm/custom_allreduce_kernels.cu
index 36ac0e3a43..d26baea342 100644
--- a/3rdparty/tensorrt_llm/custom_allreduce_kernels.cu
+++ b/3rdparty/tensorrt_llm/custom_allreduce_kernels.cu
@@ -14,6 +14,7 @@
* limitations under the License.
*/
+#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <dlpack/dlpack.h>
#include <stdint.h>
@@ -69,7 +70,6 @@ struct PackedOn16Bytes<half> {
using Type = PackedHalf;
};
-#ifdef ENABLE_BF16
using PackedBFloat16 = union {
int4 packed;
__nv_bfloat162 unpacked[4];
@@ -79,7 +79,6 @@ template <>
struct PackedOn16Bytes<__nv_bfloat16> {
using Type = PackedBFloat16;
};
-#endif
// add two 128b data
template <typename T>
@@ -387,13 +386,9 @@ void customAllReduce(AllReduceParams& params, void* data, size_t elts, DLDataTyp
invokeOneOrTwoShotAllReduceKernel<float>(params, strat, stream);
} else if (dataType.code == kDLFloat && dataType.bits == 16) {
invokeOneOrTwoShotAllReduceKernel<half>(params, strat, stream);
- }
-#ifdef ENABLE_BF16
- else if (dataType.code == kDLBfloat && dataType.bits == 16) {
+ } else if (dataType.code == kDLBfloat && dataType.bits == 16) {
invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(params, strat, stream);
- }
-#endif
- else {
+ } else {
LOG(FATAL) << ("Unsupported dataType for customAllReduce");
}
}