This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 89e9028849 [Cutlass] Add group gemm kernels (#16751)
89e9028849 is described below
commit 89e9028849ae3803a10eda086434c8d9e3bc3298
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Mar 20 05:50:19 2024 -0700
[Cutlass] Add group gemm kernels (#16751)
* [CMAKE][CUTLASS] Improve dependency management with different cutlass
versions.
* Each cutlass-based submodule library now uses its own cutlass submodule
dependency
* TVM's cutlass submodule is decoupled from others and is bumped to
v3.4.1 for H100 support
* Add scaffold for new cutlass fp8 dequant gemm interface targeting
TVM's cutlass submodule
* Remove handling for moe_gemm.cc and flash_decoding.cu which are no longer
used upstream.
* Add cutlass fp8 group gemm
* Add fp16 grouped gemm support for sm90
* [Cutlass] Support alpha scaling in fp8 group gemm
* [Cutlass] Support device alpha_ptr for fp8 group gemm
---------
Co-authored-by: Chris Sullivan <[email protected]>
Co-authored-by: masahi <[email protected]>
---
3rdparty/cutlass | 2 +-
CMakeLists.txt | 27 ++-
cmake/modules/contrib/CUTLASS.cmake | 49 ++++-
src/runtime/contrib/cutlass/fp16_group_gemm.cu | 70 ++++++++
src/runtime/contrib/cutlass/fp8_group_gemm.cu | 83 +++++++++
src/runtime/contrib/cutlass/group_gemm_runner.cuh | 209 ++++++++++++++++++++++
src/runtime/contrib/cutlass/weight_preprocess.cc | 2 +-
tests/python/contrib/test_cutlass.py | 98 ++++++++++
8 files changed, 531 insertions(+), 9 deletions(-)
diff --git a/3rdparty/cutlass b/3rdparty/cutlass
index ff61a49dd1..bbe579a9e3 160000
--- a/3rdparty/cutlass
+++ b/3rdparty/cutlass
@@ -1 +1 @@
-Subproject commit ff61a49dd1a728a96e9a8434ed408a2a52d73119
+Subproject commit bbe579a9e3beb6ea6626d9227ec32d0dae119a49
diff --git a/CMakeLists.txt b/CMakeLists.txt
index c9d836b681..906509004a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -369,6 +369,7 @@ tvm_file_glob(GLOB RUNTIME_SRCS
src/runtime/minrpc/*.cc
src/runtime/relax_vm/*.cc
)
+# Extra object files (possibly wrapped in generator expressions) that contrib
+# modules such as CUTLASS.cmake append; they are linked into libtvm and
+# libtvm_runtime alongside tvm_runtime_objs and tvm_libinfo_objs.
+set(TVM_RUNTIME_EXT_OBJS "")
if(BUILD_FOR_HEXAGON)
if(NOT BUILD_STATIC_RUNTIME)
@@ -595,18 +596,32 @@ add_library(tvm_libinfo_objs OBJECT ${LIBINFO_FILE})
include(GNUInstallDirs)
if(NOT BUILD_DUMMY_LIBTVM)
- add_library(tvm SHARED $<TARGET_OBJECTS:tvm_objs>
$<TARGET_OBJECTS:tvm_runtime_objs> $<TARGET_OBJECTS:tvm_libinfo_objs>)
+ # ${TVM_RUNTIME_EXT_OBJS} is expanded unquoted so that each entry of the
+ # list becomes its own source argument (it may be empty).
+ add_library(tvm SHARED
+ $<TARGET_OBJECTS:tvm_objs>
+ $<TARGET_OBJECTS:tvm_runtime_objs>
+ $<TARGET_OBJECTS:tvm_libinfo_objs>
+ ${TVM_RUNTIME_EXT_OBJS}
+ )
+
else()
# dummy version of libtvm that can be used by downstream to specify
dependencies
# the real runner still need a full version of libtvm
- add_library(tvm SHARED $<TARGET_OBJECTS:tvm_runtime_objs>
$<TARGET_OBJECTS:tvm_libinfo_objs>)
+ add_library(tvm SHARED
+ $<TARGET_OBJECTS:tvm_runtime_objs>
+ $<TARGET_OBJECTS:tvm_libinfo_objs>
+ ${TVM_RUNTIME_EXT_OBJS}
+ )
endif()
target_include_directories(tvm PUBLIC
"$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}>")
set_property(TARGET tvm APPEND PROPERTY LINK_OPTIONS
"${TVM_NO_UNDEFINED_SYMBOLS}")
set_property(TARGET tvm APPEND PROPERTY LINK_OPTIONS "${TVM_VISIBILITY_FLAG}")
if(BUILD_STATIC_RUNTIME)
- add_library(tvm_runtime STATIC $<TARGET_OBJECTS:tvm_runtime_objs>
$<TARGET_OBJECTS:tvm_libinfo_objs>)
+ add_library(tvm_runtime STATIC
+ $<TARGET_OBJECTS:tvm_runtime_objs>
+ $<TARGET_OBJECTS:tvm_libinfo_objs>
+ ${TVM_RUNTIME_EXT_OBJS}
+ )
set(NOTICE_MULTILINE
"You have build static version of the TVM runtime library. Make "
"sure to use --whole-archive when linking it into your project.")
@@ -614,7 +629,11 @@ if(BUILD_STATIC_RUNTIME)
add_custom_command(TARGET tvm_runtime POST_BUILD
COMMAND ${CMAKE_COMMAND} -E cmake_echo_color --yellow --bold ${NOTICE})
else()
- add_library(tvm_runtime SHARED $<TARGET_OBJECTS:tvm_runtime_objs>
$<TARGET_OBJECTS:tvm_libinfo_objs>)
+ add_library(tvm_runtime SHARED
+ $<TARGET_OBJECTS:tvm_runtime_objs>
+ $<TARGET_OBJECTS:tvm_libinfo_objs>
+ ${TVM_RUNTIME_EXT_OBJS}
+ )
set_property(TARGET tvm_runtime APPEND PROPERTY LINK_OPTIONS
"${TVM_NO_UNDEFINED_SYMBOLS}")
endif()
diff --git a/cmake/modules/contrib/CUTLASS.cmake
b/cmake/modules/contrib/CUTLASS.cmake
index 9ce27820b8..fa4a608f61 100644
--- a/cmake/modules/contrib/CUTLASS.cmake
+++ b/cmake/modules/contrib/CUTLASS.cmake
@@ -16,16 +16,59 @@
# under the License.
if(USE_CUDA AND USE_CUTLASS)
- tvm_file_glob(GLOB CUTLASS_CONTRIB_SRC
src/relay/backend/contrib/cutlass/*.cc src/relax/backend/contrib/cutlass/*.cc)
+ # Generator-expression condition wrapping every object list appended to
+ # CUTLASS_RUNTIME_OBJS. NOTE(review): this code only runs when
+ # USE_CUDA AND USE_CUTLASS already hold, so the condition is expected to be
+ # true here.
+ set(CUTLASS_GEN_COND "$<AND:$<BOOL:${USE_CUDA}>,$<BOOL:${USE_CUTLASS}>>")
+ # Cutlass-based object files collected below; appended to
+ # TVM_RUNTIME_EXT_OBJS at the end of this block.
+ set(CUTLASS_RUNTIME_OBJS "")
+
+ tvm_file_glob(GLOB CUTLASS_CONTRIB_SRC
+ src/relay/backend/contrib/cutlass/*.cc
+ src/relax/backend/contrib/cutlass/*.cc
+ )
list(APPEND COMPILER_SRCS ${CUTLASS_CONTRIB_SRC})
set(FPA_INTB_GEMM_TVM_BINDING ON)
set(FPA_INTB_GEMM_TVM_HOME ${PROJECT_SOURCE_DIR})
- set(CUTLASS_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass)
+ ### Build cutlass runtime objects for fpA_intB_gemm using its cutlass submodule
add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm)
+ target_include_directories(fpA_intB_gemm PRIVATE
+ ${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm
+ ${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm/cutlass/include
+ )
+ set(CUTLASS_FPA_INTB_RUNTIME_SRCS "")
+ list(APPEND CUTLASS_FPA_INTB_RUNTIME_SRCS
src/runtime/contrib/cutlass/weight_preprocess.cc)
+ add_library(fpA_intB_cutlass_objs OBJECT ${CUTLASS_FPA_INTB_RUNTIME_SRCS})
+ target_compile_definitions(fpA_intB_cutlass_objs PRIVATE
DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
+ target_include_directories(fpA_intB_cutlass_objs PRIVATE
+ ${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm
+ ${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm/cutlass/include
+ )
+ list(APPEND CUTLASS_RUNTIME_OBJS
"$<${CUTLASS_GEN_COND}:$<TARGET_OBJECTS:fpA_intB_cutlass_objs>>")
+
+ ### Build cutlass runtime objects for flash attention
add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn)
- list(APPEND RUNTIME_SRCS src/runtime/contrib/cutlass/weight_preprocess.cc)
+ target_include_directories(flash_attn PRIVATE
+ ${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn
+ ${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn/cutlass/include
+ )
+
+ ### Build cutlass runtime objects using TVM's 3rdparty/cutlass submodule
+ set(CUTLASS_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass)
+ set(TVM_CUTLASS_RUNTIME_SRCS "")
+
+ # The fp16/fp8 group GEMM sources are only built when "90a" appears in
+ # CMAKE_CUDA_ARCHITECTURES; inside the sources their kernels are further
+ # guarded by CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED.
+ if (CMAKE_CUDA_ARCHITECTURES MATCHES "90a")
+ list(APPEND TVM_CUTLASS_RUNTIME_SRCS
src/runtime/contrib/cutlass/fp16_group_gemm.cu)
+ list(APPEND TVM_CUTLASS_RUNTIME_SRCS
src/runtime/contrib/cutlass/fp8_group_gemm.cu)
+ endif()
+ # Only create tvm_cutlass_objs when at least one source was selected above.
+ if(TVM_CUTLASS_RUNTIME_SRCS)
+ add_library(tvm_cutlass_objs OBJECT ${TVM_CUTLASS_RUNTIME_SRCS})
+ target_compile_options(tvm_cutlass_objs PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>)
+ target_include_directories(tvm_cutlass_objs PRIVATE ${CUTLASS_DIR}/include)
+ target_compile_definitions(tvm_cutlass_objs PRIVATE
DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
+ list(APPEND CUTLASS_RUNTIME_OBJS
"$<${CUTLASS_GEN_COND}:$<TARGET_OBJECTS:tvm_cutlass_objs>>")
+ endif()
+
+ ### Add cutlass objects to list of TVM runtime extension objs
+ list(APPEND TVM_RUNTIME_EXT_OBJS "${CUTLASS_RUNTIME_OBJS}")
message(STATUS "Build with CUTLASS")
endif()
diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm.cu
b/src/runtime/contrib/cutlass/fp16_group_gemm.cu
new file mode 100644
index 0000000000..3c051819b2
--- /dev/null
+++ b/src/runtime/contrib/cutlass/fp16_group_gemm.cu
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <cuda_fp16.h>
+#include <float.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include "group_gemm_runner.cuh"
+
+#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
+
+/*! \brief SM90 grouped-GEMM kernel configuration used when A is fp16 (cutlass::half_t). */
+template <>
+struct KernelTraits<cutlass::half_t> {
+  using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
+  using TileShape = Shape<_128, _256, _64>;  // Threadblock-level tile size
+  using ClusterShape = Shape<_2, _2, _1>;    // Shape of the threadblocks in a cluster
+};
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief fp16 grouped GEMM on SM90 (registered as "cutlass.group_gemm_fp16_sm90").
+ *
+ * For every group g, rows [indptr[g - 1], indptr[g]) of `x` (indptr[-1] taken as 0) are
+ * multiplied by weight[g]^T and written to the same rows of `out`.
+ *
+ * \param x Activations, 2-D with k columns.
+ * \param weight Per-group weights, shape (num_groups, n, k).
+ * \param indptr int64 cumulative end row of each group, shape (num_groups,).
+ * \param workspace 1-D byte buffer used for the device-side group gemm arguments and the cutlass
+ *        internal workspace. Recommended size is 4MB.
+ * \param out Output, 2-D with n columns.
+ */
+template <typename ElementA, typename ElementB, typename ElementC>
+void tvm_cutlass_group_gemm_sm90(NDArray x, NDArray weight, NDArray indptr, NDArray workspace,
+                                 NDArray out) {
+  auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+  ICHECK(func != nullptr);
+  CHECK_EQ(x->ndim, 2);
+  CHECK_EQ(weight->ndim, 3);
+  CHECK_EQ(indptr->ndim, 1);
+  CHECK_EQ(workspace->ndim, 1);
+  CHECK_EQ(out->ndim, 2);
+  // indptr is read through an int64_t pointer by the argument-preparation kernel.
+  CHECK_EQ(indptr->dtype.code, kDLInt);
+  CHECK_EQ(indptr->dtype.bits, 64);
+  int num_groups = weight->shape[0];
+  int n = weight->shape[1];
+  int k = weight->shape[2];
+  // Validate on the host: the kernel reads indptr[0..num_groups) and addresses x/out with row
+  // strides k/n, so any mismatch here would read or write out of bounds on the device.
+  CHECK_EQ(indptr->shape[0], num_groups);
+  CHECK_EQ(x->shape[1], k);
+  CHECK_EQ(out->shape[1], n);
+  float alpha = 1.0f;
+  float beta = 0.0f;
+  cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
+  cutlass_group_gemm(static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data),
+                     static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
+                     workspace->shape[0], n, k, num_groups, alpha, beta,
+                     static_cast<ElementC*>(out->data), stream);
+}
+
+TVM_REGISTER_GLOBAL("cutlass.group_gemm_fp16_sm90")
+    .set_body_typed(
+        tvm_cutlass_group_gemm_sm90<cutlass::half_t, cutlass::half_t, cutlass::half_t>);
+
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED
diff --git a/src/runtime/contrib/cutlass/fp8_group_gemm.cu
b/src/runtime/contrib/cutlass/fp8_group_gemm.cu
new file mode 100644
index 0000000000..c93da6ff57
--- /dev/null
+++ b/src/runtime/contrib/cutlass/fp8_group_gemm.cu
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <cuda_fp16.h>
+#include <float.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include "group_gemm_runner.cuh"
+
+#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
+
+/*! \brief SM90 grouped-GEMM kernel configuration used when A is fp8 e4m3. */
+template <>
+struct KernelTraits<cutlass::float_e4m3_t> {
+  using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum;
+  using TileShape = Shape<_128, _256, _64>;  // Threadblock-level tile size
+  using ClusterShape = Shape<_2, _2, _1>;    // Shape of the threadblocks in a cluster
+};
+
+/*! \brief fp8 e5m2 reuses the e4m3 kernel configuration. */
+template <>
+struct KernelTraits<cutlass::float_e5m2_t> : KernelTraits<cutlass::float_e4m3_t> {};
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief fp8 grouped GEMM on SM90 with a device-side float32 scale.
+ *
+ * For every group g, rows [indptr[g - 1], indptr[g]) of `x` (indptr[-1] taken as 0) are
+ * multiplied by weight[g]^T, scaled by the value pointed to by `alpha`, and written to the same
+ * rows of `out`.
+ *
+ * \param x Activations, 2-D with k columns.
+ * \param weight Per-group weights, shape (num_groups, n, k).
+ * \param indptr int64 cumulative end row of each group, shape (num_groups,).
+ * \param workspace 1-D byte buffer used for the device-side group gemm arguments and the cutlass
+ *        internal workspace. Recommended size is 4MB.
+ * \param alpha float32 device array holding the output scale.
+ * \param out Output, 2-D with n columns.
+ */
+template <typename ElementA, typename ElementB, typename ElementC>
+void tvm_cutlass_fp8_group_gemm(NDArray x, NDArray weight, NDArray indptr, NDArray workspace,
+                                NDArray alpha, NDArray out) {
+  auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+  ICHECK(func != nullptr);
+  CHECK_EQ(x->ndim, 2);
+  CHECK_EQ(weight->ndim, 3);
+  CHECK_EQ(indptr->ndim, 1);
+  CHECK_EQ(workspace->ndim, 1);
+  CHECK_EQ(out->ndim, 2);
+  CHECK_EQ(alpha->dtype.code, kDLFloat);
+  CHECK_EQ(alpha->dtype.bits, 32);
+  // indptr is read through an int64_t pointer by the argument-preparation kernel.
+  CHECK_EQ(indptr->dtype.code, kDLInt);
+  CHECK_EQ(indptr->dtype.bits, 64);
+  int num_groups = weight->shape[0];
+  int n = weight->shape[1];
+  int k = weight->shape[2];
+  // Validate on the host: the kernel reads indptr[0..num_groups) and addresses x/out with row
+  // strides k/n, so any mismatch here would read or write out of bounds on the device.
+  CHECK_EQ(indptr->shape[0], num_groups);
+  CHECK_EQ(x->shape[1], k);
+  CHECK_EQ(out->shape[1], n);
+  // A null beta pointer is passed together with the device alpha pointer.
+  const float* beta = nullptr;
+  cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
+  cutlass_group_gemm(static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data),
+                     static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
+                     workspace->shape[0], n, k, num_groups, static_cast<float*>(alpha->data), beta,
+                     static_cast<ElementC*>(out->data), stream);
+}
+
+TVM_REGISTER_GLOBAL("cutlass.group_gemm_e5m2_e5m2_fp16")
+    .set_body_typed(
+        tvm_cutlass_fp8_group_gemm<cutlass::float_e5m2_t, cutlass::float_e5m2_t, cutlass::half_t>);
+
+TVM_REGISTER_GLOBAL("cutlass.group_gemm_e5m2_e4m3_fp16")
+    .set_body_typed(
+        tvm_cutlass_fp8_group_gemm<cutlass::float_e5m2_t, cutlass::float_e4m3_t, cutlass::half_t>);
+
+TVM_REGISTER_GLOBAL("cutlass.group_gemm_e4m3_e4m3_fp16")
+    .set_body_typed(
+        tvm_cutlass_fp8_group_gemm<cutlass::float_e4m3_t, cutlass::float_e4m3_t, cutlass::half_t>);
+
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED
diff --git a/src/runtime/contrib/cutlass/group_gemm_runner.cuh
b/src/runtime/contrib/cutlass/group_gemm_runner.cuh
new file mode 100644
index 0000000000..50bdcf7bec
--- /dev/null
+++ b/src/runtime/contrib/cutlass/group_gemm_runner.cuh
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <fstream>
+#include <iostream>
+#include <sstream>
+#include <variant>
+#include <vector>
+
+#include "../../cuda/cuda_common.h"
+
+// clang-format off
+#include "cutlass/cutlass.h"
+
+#include "cute/tensor.hpp"
+#include "cutlass/tensor_ref.h"
+#include "cutlass/epilogue/collective/default_epilogue.hpp"
+#include "cutlass/epilogue/thread/linear_combination.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
+#include "cutlass/gemm/group_array_problem_shape.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+// clang-format on
+
+/*!
+ * \brief Evaluate a cutlass::Status expression and raise a TVM error via LOG(FATAL) on failure.
+ *
+ * Raising (instead of calling exit()) lets the failure be reported through TVM's error mechanism
+ * rather than terminating the whole host process. Wrapped in do/while(0) so it behaves as a
+ * single statement.
+ */
+#define CUTLASS_CHECK(status)                                                          \
+  do {                                                                                 \
+    cutlass::Status error = (status);                                                  \
+    if (error != cutlass::Status::kSuccess) {                                          \
+      LOG(FATAL) << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " \
+                 << __FILE__ << ":" << __LINE__;                                       \
+    }                                                                                  \
+  } while (0)
+
+using namespace cute;
+// Per-group problem shape: <M, N, K>.
+using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
+
+/*! \brief Round `value` up to the nearest multiple of `alignment` (16 by default). */
+inline size_t aligned(size_t value, size_t alignment = 16) {
+  const size_t padded = value + alignment - 1;
+  return padded - padded % alignment;
+}
+
+/*! \brief Kernel configuration selected by element type; specialized per ElementA type. */
+template <typename T>
+struct KernelTraits;
+
+/*!
+ * \brief Builds and launches an SM90 CUTLASS grouped (pointer-array) GEMM.
+ * \tparam ElementA Element type of A; also selects KernelTraits<ElementA>.
+ * \tparam ElementB Element type of B.
+ * \tparam ElementC Element type of C and D.
+ */
+template <typename ElementA, typename ElementB, typename ElementC,
+          typename LayoutA = cutlass::layout::RowMajor,
+          typename LayoutB = cutlass::layout::ColumnMajor,
+          typename LayoutC = cutlass::layout::RowMajor>
+struct CutlassGroupGemmRunner {
+  // Alignment of each operand in units of elements (up to 16 bytes).
+  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
+  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
+  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
+
+  // Core kernel configurations
+  using ElementAccumulator = float;  // Element type for internal accumulation
+  // alpha/beta are either host scalars or pointers to a scalar.
+  using ScaleType = std::variant<ElementAccumulator, const ElementAccumulator*>;
+  // Tag indicating the minimum SM that supports the intended feature
+  using ArchTag = cutlass::arch::Sm90;
+  using OperatorClass = cutlass::arch::OpClassTensorOp;  // Operator class tag
+  using TileShape = typename KernelTraits<ElementA>::TileShape;
+  using ClusterShape = typename KernelTraits<ElementA>::ClusterShape;
+  // Stage count maximized based on the tile size
+  using StageCountType = cutlass::gemm::collective::StageCountAuto;
+  using KernelSchedule = typename KernelTraits<ElementA>::KernelSchedule;    // Kernel to launch
+  using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized;  // Epilogue to launch
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape,
+      cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator,
+      ElementC, LayoutC*, AlignmentC, ElementC, LayoutC*, AlignmentC,
+      EpilogueSchedule>::CollectiveOp;
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, LayoutB*, AlignmentB,
+      ElementAccumulator, TileShape, ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
+          sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      KernelSchedule>::CollectiveOp;
+
+  using GemmKernel =
+      cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+  using StrideA = typename Gemm::GemmKernel::UnderlyingStrideA;
+  using StrideB = typename Gemm::GemmKernel::UnderlyingStrideB;
+  using StrideC = typename Gemm::GemmKernel::UnderlyingStrideC;
+  using StrideD = typename Gemm::GemmKernel::UnderlyingStrideD;
+
+  /*!
+   * \brief Initialize and launch the grouped GEMM on `stream`.
+   *
+   * All per-group arrays (pointers, problem sizes, strides) are device pointers prepared by the
+   * caller. `workspace`/`workspace_size` describe the cutlass internal workspace. `alpha` and
+   * `beta` must hold the same alternative (both scalars or both pointers).
+   */
+  void run_group_gemm(const ElementA** ptr_A, const ElementB** ptr_B, const ElementC** ptr_C,
+                      ElementC** ptr_D,
+                      typename ProblemShape::UnderlyingProblemShape* problem_sizes,
+                      typename ProblemShape::UnderlyingProblemShape* problem_sizes_host,
+                      StrideA* stride_A, StrideB* stride_B, StrideC* stride_C, StrideD* stride_D,
+                      uint8_t* workspace, int64_t workspace_size, int num_groups, ScaleType alpha,
+                      ScaleType beta, cudaStream_t stream) {
+    typename Gemm::EpilogueOutputOp::Params epilogue_params = [&]() {
+      ICHECK(alpha.index() == beta.index()) << "alpha and beta must have the same type";
+      if (std::holds_alternative<ElementAccumulator>(alpha)) {
+        return typename Gemm::EpilogueOutputOp::Params{std::get<ElementAccumulator>(alpha),
+                                                       std::get<ElementAccumulator>(beta)};
+      } else if (std::holds_alternative<const ElementAccumulator*>(alpha)) {
+        return typename Gemm::EpilogueOutputOp::Params{std::get<const ElementAccumulator*>(alpha),
+                                                       std::get<const ElementAccumulator*>(beta)};
+      } else {
+        LOG(FATAL) << "Unsupported alpha and beta type";
+        throw;
+      }
+    }();
+
+    cutlass::KernelHardwareInfo hw_info;
+    // Use the device that is current for this thread (not always device 0), so the queried SM
+    // count matches the GPU the kernel is actually launched on.
+    cudaError_t device_err = cudaGetDevice(&hw_info.device_id);
+    ICHECK(device_err == cudaSuccess) << "cudaGetDevice failed: " << cudaGetErrorString(device_err);
+    hw_info.sm_count =
+        cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
+    typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGrouped,
+                                       {num_groups, problem_sizes, problem_sizes_host},
+                                       {ptr_A, stride_A, ptr_B, stride_B},
+                                       {epilogue_params, ptr_C, stride_C, ptr_D, stride_D},
+                                       hw_info};
+    Gemm gemm_op;
+    CUTLASS_CHECK(gemm_op.can_implement(arguments));
+    // Rule out a negative size before comparing as unsigned; otherwise the signed->unsigned
+    // conversion would let an undersized workspace pass.
+    CHECK_GE(workspace_size, 0);
+    CHECK_GE(static_cast<size_t>(workspace_size), gemm_op.get_workspace_size(arguments));
+    CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
+    // Launch on the same stream as initialize() and the argument-preparation kernel. run() without
+    // a stream argument defaults to the null stream, which can race with work on `stream`.
+    CUTLASS_CHECK(gemm_op.run(stream));
+  }
+};
+
+/*!
+ * \brief Fill the per-group pointer, problem-size and stride arrays consumed by the grouped GEMM.
+ *
+ * Thread `threadIdx.x == g` handles group g. Only threadIdx.x is used, so all groups must be
+ * covered by a single thread block. indptr[g] is the end row of group g in x/out; group g starts
+ * at indptr[g - 1] (or 0 for the first group).
+ */
+template <typename ElementA, typename ElementB, typename ElementC, typename StrideA,
+          typename StrideB, typename StrideC>
+__global__ void prepare_group_gemm_arguments(
+    const ElementA** ptr_A, const ElementB** ptr_B, ElementC** ptr_D,
+    typename ProblemShape::UnderlyingProblemShape* problem_sizes, StrideA* stride_A,
+    StrideB* stride_B, StrideC* stride_D, const ElementA* x, const ElementB* weight, ElementC* out,
+    int64_t* indptr, int64_t n, int64_t k, int64_t num_groups) {
+  int group_id = threadIdx.x;
+  if (group_id >= num_groups) return;
+  // Keep the row offset 64-bit: indptr is int64 and the offset is scaled by k / n below.
+  int64_t prev_rows = group_id == 0 ? 0 : indptr[group_id - 1];
+  ptr_A[group_id] = x + prev_rows * k;
+  ptr_B[group_id] = weight + group_id * k * n;
+  ptr_D[group_id] = out + prev_rows * n;
+  problem_sizes[group_id] = {static_cast<int>(indptr[group_id] - prev_rows), static_cast<int>(n),
+                             static_cast<int>(k)};
+  stride_A[group_id] = cute::make_stride(k, Int<1>{}, int64_t{0});
+  stride_B[group_id] = cute::make_stride(k, Int<1>{}, int64_t{0});
+  stride_D[group_id] = cute::make_stride(n, Int<1>{}, int64_t{0});
+}
+
+/*!
+ * \brief Grouped GEMM: for every group g, out[rows of g] = x[rows of g] * weight[g]^T.
+ *
+ * The front of `workspace` holds the device-side per-group argument arrays written by
+ * prepare_group_gemm_arguments (each 16-byte aligned). The remainder, starting at a 256-byte
+ * aligned offset, is handed to CUTLASS as its internal workspace.
+ *
+ * \param indptr Device array of num_groups int64 end rows (see prepare_group_gemm_arguments).
+ * \param workspace_size Size of `workspace` in bytes.
+ * \param alpha,beta Epilogue scales, either host scalars or pointers (must match each other).
+ * \param stream CUDA stream on which all work is enqueued.
+ */
+template <typename ElementA, typename ElementB, typename ElementC>
+void cutlass_group_gemm(ElementA* x, ElementB* weight, int64_t* indptr, uint8_t* workspace,
+                        int64_t workspace_size, int64_t n, int64_t k, int64_t num_groups,
+                        std::variant<float, const float*> alpha,
+                        std::variant<float, const float*> beta, ElementC* out,
+                        cudaStream_t stream) {
+  using Runner = CutlassGroupGemmRunner<ElementA, ElementB, ElementC>;
+  using StrideA = typename Runner::StrideA;
+  using StrideB = typename Runner::StrideB;
+  using StrideC = typename Runner::StrideC;
+
+  // prepare_group_gemm_arguments is launched as one block with one thread per group, so the
+  // group count is bounded by the CUDA limit of 1024 threads per block.
+  constexpr int64_t kMaxGroups = 1024;
+  CHECK_GT(num_groups, 0);
+  CHECK_LE(num_groups, kMaxGroups) << "cutlass group gemm supports at most " << kMaxGroups
+                                   << " groups";
+
+  Runner runner;
+  std::ptrdiff_t offset = 0;
+  const ElementA** ptr_A = reinterpret_cast<const ElementA**>(workspace + offset);
+  offset += aligned(sizeof(ElementA*) * num_groups);
+  const ElementB** ptr_B = reinterpret_cast<const ElementB**>(workspace + offset);
+  offset += aligned(sizeof(ElementB*) * num_groups);
+  ElementC** ptr_D = reinterpret_cast<ElementC**>(workspace + offset);
+  offset += aligned(sizeof(ElementC*) * num_groups);
+  typename ProblemShape::UnderlyingProblemShape* problem_sizes =
+      reinterpret_cast<typename ProblemShape::UnderlyingProblemShape*>(workspace + offset);
+  offset += aligned(sizeof(typename ProblemShape::UnderlyingProblemShape) * num_groups);
+  StrideA* stride_A = reinterpret_cast<StrideA*>(workspace + offset);
+  offset += aligned(sizeof(StrideA) * num_groups);
+  StrideB* stride_B = reinterpret_cast<StrideB*>(workspace + offset);
+  offset += aligned(sizeof(StrideB) * num_groups);
+  StrideC* stride_D = reinterpret_cast<StrideC*>(workspace + offset);
+  offset += aligned(sizeof(StrideC) * num_groups);
+  // The CUTLASS internal workspace starts at the next 256-byte boundary.
+  offset = aligned(offset, 256);
+  // Reject undersized workspaces before any device write; otherwise the argument arrays would be
+  // written past the end of the buffer.
+  CHECK_LE(offset, workspace_size) << "Workspace of " << workspace_size
+                                   << " bytes is too small for " << num_groups << " groups";
+
+  prepare_group_gemm_arguments<<<1, num_groups, 0, stream>>>(ptr_A, ptr_B, ptr_D, problem_sizes,
+                                                             stride_A, stride_B, stride_D, x,
+                                                             weight, out, indptr, n, k, num_groups);
+  cudaError_t launch_err = cudaGetLastError();
+  ICHECK(launch_err == cudaSuccess) << "Failed to launch prepare_group_gemm_arguments: "
+                                    << cudaGetErrorString(launch_err);
+
+  runner.run_group_gemm(ptr_A, ptr_B, const_cast<const ElementC**>(ptr_D), ptr_D, problem_sizes,
+                        nullptr, stride_A, stride_B, stride_D, stride_D, workspace + offset,
+                        workspace_size - offset, num_groups, alpha, beta, stream);
+}
diff --git a/src/runtime/contrib/cutlass/weight_preprocess.cc
b/src/runtime/contrib/cutlass/weight_preprocess.cc
index 4b378fa4a7..5fded82762 100644
--- a/src/runtime/contrib/cutlass/weight_preprocess.cc
+++ b/src/runtime/contrib/cutlass/weight_preprocess.cc
@@ -21,7 +21,7 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
-#include
"../../../3rdparty/cutlass_fpA_intB_gemm/cutlass_kernels/cutlass_preprocessors.h"
+#include "cutlass_kernels/cutlass_preprocessors.h"
namespace tvm {
namespace runtime {
diff --git a/tests/python/contrib/test_cutlass.py
b/tests/python/contrib/test_cutlass.py
index 6eaf10c2ab..154a68e116 100644
--- a/tests/python/contrib/test_cutlass.py
+++ b/tests/python/contrib/test_cutlass.py
@@ -17,6 +17,7 @@
import logging
import tempfile
import math
+import ml_dtypes
import tvm
from tvm import relay
from tvm.contrib.cudnn import conv_output_shape
@@ -32,6 +33,7 @@ from tvm.contrib.cutlass import (
finalize_modules,
finalize_modules_vm,
)
+from tvm.contrib.pickle_memoize import memoize
import tvm.testing
logging.basicConfig(level=logging.INFO)
@@ -1105,5 +1107,101 @@ def test_dense_transpose_dense():
verify_dense_transpose_dense(get_dense_transpose_dense(M, N, K), M, N, K)
+def verify_group_gemm(
+ func_name, M, N, K, num_groups, x_dtype, weight_dtype, out_dtype,
use_scale, rtol, atol
+):
+ group_gemm_func = tvm.get_global_func(func_name, allow_missing=True)
+ if group_gemm_func is None:
+ print(f"Skipped as {func_name} is not available")
+ return
+
+ @memoize("tvm.contrib.cutlass.test_group_gemm_sm90")
+ def get_ref_data():
+ assert M % num_groups == 0
+ M_per_group = M // num_groups
+ a_np = get_random_ndarray((M, K), "float16")
+ b_np = get_random_ndarray((num_groups, N, K), "float16")
+ indptr_np = np.arange(1, num_groups + 1).astype("int64") * M_per_group
+ c_np = np.concatenate(
+ [a_np[i * M_per_group : (i + 1) * M_per_group] @ b_np[i].T for i
in range(num_groups)],
+ axis=0,
+ )
+ return a_np, b_np, indptr_np, c_np
+
+ def to_numpy_dtype(dtype):
+ mapping = {"e5m2_float8": ml_dtypes.float8_e5m2, "e4m3_float8":
ml_dtypes.float8_e4m3fn}
+ return mapping.get(dtype, dtype)
+
+ a_np, b_np, indptr_np, c_np = get_ref_data()
+ dev = tvm.cuda(0)
+ a_nd = tvm.nd.array(a_np.astype(to_numpy_dtype(x_dtype)), device=dev)
+ b_nd = tvm.nd.array(b_np.astype(to_numpy_dtype(weight_dtype)), device=dev)
+ c_nd = tvm.nd.empty(c_np.shape, dtype=out_dtype, device=dev)
+ indptr_nd = tvm.nd.array(indptr_np, device=dev)
+ workspace = tvm.nd.empty((4096 * 1024,), dtype="uint8", device=dev)
+ if use_scale:
+ scale = tvm.nd.array(np.array([1.0], dtype="float32"), device=dev)
+ group_gemm_func(a_nd, b_nd, indptr_nd, workspace, scale, c_nd)
+ else:
+ group_gemm_func(a_nd, b_nd, indptr_nd, workspace, c_nd)
+ tvm.testing.assert_allclose(c_nd.asnumpy(), c_np, rtol=rtol, atol=atol)
+
+
[email protected]_cutlass
+def test_group_gemm_sm90():
+ verify_group_gemm(
+ "cutlass.group_gemm_fp16_sm90",
+ 8,
+ 128,
+ 128,
+ 4,
+ "float16",
+ "float16",
+ "float16",
+ False,
+ rtol=1e-3,
+ atol=1e-3,
+ )
+ verify_group_gemm(
+ "cutlass.group_gemm_e5m2_e5m2_fp16",
+ 8,
+ 16,
+ 16,
+ 4,
+ "e5m2_float8",
+ "e5m2_float8",
+ "float16",
+ True,
+ rtol=1e-1,
+ atol=1,
+ )
+ verify_group_gemm(
+ "cutlass.group_gemm_e4m3_e4m3_fp16",
+ 8,
+ 16,
+ 16,
+ 4,
+ "e4m3_float8",
+ "e4m3_float8",
+ "float16",
+ True,
+ rtol=1e-1,
+ atol=1,
+ )
+ verify_group_gemm(
+ "cutlass.group_gemm_e4m3_e5m2_fp16",
+ 8,
+ 16,
+ 16,
+ 4,
+ "e4m3_float8",
+ "e5m2_float8",
+ "float16",
+ True,
+ rtol=1e-1,
+ atol=1,
+ )
+
+
if __name__ == "__main__":
tvm.testing.main()