This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 89e9028849 [Cutlass] Add group gemm kernels (#16751)
89e9028849 is described below

commit 89e9028849ae3803a10eda086434c8d9e3bc3298
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Mar 20 05:50:19 2024 -0700

    [Cutlass] Add group gemm kernels (#16751)
    
    * [CMAKE][CUTLASS] Improve dependency management with different cutlass 
versions.
     * Each cutlass-based submodule library now uses its own cutlass submodule 
dependency
     * TVM's cutlass submodule is decoupled from others and is bumped to
     v3.4.1 for H100 support
     * Add scaffold for new cutlass fp8 dequant gemm interface targeting
     TVM's cutlass submodule
    
    * Remove handling for moe_gemm.cc and flash_decoding.cu which are no longer 
used upstream.
    
    * Add cutlass fp8 group gemm
    
    * Add fp16 grouped gemm support for sm90
    
    * [Cutlass] Support alpha scaling in fp8 group gemm
    
    * [Cutlass] Support device alpha_ptr for fp8 group gemm
    
    
    
    ---------
    
    Co-authored-by: Chris Sullivan <[email protected]>
    Co-authored-by: masahi <[email protected]>
---
 3rdparty/cutlass                                  |   2 +-
 CMakeLists.txt                                    |  27 ++-
 cmake/modules/contrib/CUTLASS.cmake               |  49 ++++-
 src/runtime/contrib/cutlass/fp16_group_gemm.cu    |  70 ++++++++
 src/runtime/contrib/cutlass/fp8_group_gemm.cu     |  83 +++++++++
 src/runtime/contrib/cutlass/group_gemm_runner.cuh | 209 ++++++++++++++++++++++
 src/runtime/contrib/cutlass/weight_preprocess.cc  |   2 +-
 tests/python/contrib/test_cutlass.py              |  98 ++++++++++
 8 files changed, 531 insertions(+), 9 deletions(-)

diff --git a/3rdparty/cutlass b/3rdparty/cutlass
index ff61a49dd1..bbe579a9e3 160000
--- a/3rdparty/cutlass
+++ b/3rdparty/cutlass
@@ -1 +1 @@
-Subproject commit ff61a49dd1a728a96e9a8434ed408a2a52d73119
+Subproject commit bbe579a9e3beb6ea6626d9227ec32d0dae119a49
diff --git a/CMakeLists.txt b/CMakeLists.txt
index c9d836b681..906509004a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -369,6 +369,7 @@ tvm_file_glob(GLOB RUNTIME_SRCS
   src/runtime/minrpc/*.cc
   src/runtime/relax_vm/*.cc
 )
+set(TVM_RUNTIME_EXT_OBJS "")
 
 if(BUILD_FOR_HEXAGON)
   if(NOT BUILD_STATIC_RUNTIME)
@@ -595,18 +596,32 @@ add_library(tvm_libinfo_objs OBJECT ${LIBINFO_FILE})
 
 include(GNUInstallDirs)
 if(NOT BUILD_DUMMY_LIBTVM)
-  add_library(tvm SHARED $<TARGET_OBJECTS:tvm_objs> 
$<TARGET_OBJECTS:tvm_runtime_objs> $<TARGET_OBJECTS:tvm_libinfo_objs>)
+  add_library(tvm SHARED
+    $<TARGET_OBJECTS:tvm_objs>
+    $<TARGET_OBJECTS:tvm_runtime_objs>
+    $<TARGET_OBJECTS:tvm_libinfo_objs>
+    ${TVM_RUNTIME_EXT_OBJS}
+  )
+
 else()
   # dummy version of libtvm that can be used by downstream to specify 
dependencies
   # the real runner still need a full version of libtvm
-  add_library(tvm SHARED $<TARGET_OBJECTS:tvm_runtime_objs> 
$<TARGET_OBJECTS:tvm_libinfo_objs>)
+  add_library(tvm SHARED
+    $<TARGET_OBJECTS:tvm_runtime_objs>
+    $<TARGET_OBJECTS:tvm_libinfo_objs>
+    ${TVM_RUNTIME_EXT_OBJS}
+  )
 endif()
 
 target_include_directories(tvm PUBLIC 
"$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}>")
 set_property(TARGET tvm APPEND PROPERTY LINK_OPTIONS 
"${TVM_NO_UNDEFINED_SYMBOLS}")
 set_property(TARGET tvm APPEND PROPERTY LINK_OPTIONS "${TVM_VISIBILITY_FLAG}")
 if(BUILD_STATIC_RUNTIME)
-  add_library(tvm_runtime STATIC $<TARGET_OBJECTS:tvm_runtime_objs> 
$<TARGET_OBJECTS:tvm_libinfo_objs>)
+  add_library(tvm_runtime STATIC
+    $<TARGET_OBJECTS:tvm_runtime_objs>
+    $<TARGET_OBJECTS:tvm_libinfo_objs>
+    ${TVM_RUNTIME_EXT_OBJS}
+  )
   set(NOTICE_MULTILINE
     "You have build static version of the TVM runtime library. Make "
     "sure to use --whole-archive when linking it into your project.")
@@ -614,7 +629,11 @@ if(BUILD_STATIC_RUNTIME)
   add_custom_command(TARGET tvm_runtime POST_BUILD
     COMMAND ${CMAKE_COMMAND} -E cmake_echo_color --yellow --bold ${NOTICE})
 else()
-  add_library(tvm_runtime SHARED $<TARGET_OBJECTS:tvm_runtime_objs> 
$<TARGET_OBJECTS:tvm_libinfo_objs>)
+  add_library(tvm_runtime SHARED
+    $<TARGET_OBJECTS:tvm_runtime_objs>
+    $<TARGET_OBJECTS:tvm_libinfo_objs>
+    ${TVM_RUNTIME_EXT_OBJS}
+  )
   set_property(TARGET tvm_runtime APPEND PROPERTY LINK_OPTIONS 
"${TVM_NO_UNDEFINED_SYMBOLS}")
 endif()
 
diff --git a/cmake/modules/contrib/CUTLASS.cmake 
b/cmake/modules/contrib/CUTLASS.cmake
index 9ce27820b8..fa4a608f61 100644
--- a/cmake/modules/contrib/CUTLASS.cmake
+++ b/cmake/modules/contrib/CUTLASS.cmake
@@ -16,16 +16,59 @@
 # under the License.
 
 if(USE_CUDA AND USE_CUTLASS)
-  tvm_file_glob(GLOB CUTLASS_CONTRIB_SRC 
src/relay/backend/contrib/cutlass/*.cc src/relax/backend/contrib/cutlass/*.cc)
+  set(CUTLASS_GEN_COND "$<AND:$<BOOL:${USE_CUDA}>,$<BOOL:${USE_CUTLASS}>>")
+  set(CUTLASS_RUNTIME_OBJS "")
+
+  tvm_file_glob(GLOB CUTLASS_CONTRIB_SRC
+    src/relay/backend/contrib/cutlass/*.cc
+    src/relax/backend/contrib/cutlass/*.cc
+  )
   list(APPEND COMPILER_SRCS ${CUTLASS_CONTRIB_SRC})
 
   set(FPA_INTB_GEMM_TVM_BINDING ON)
   set(FPA_INTB_GEMM_TVM_HOME ${PROJECT_SOURCE_DIR})
 
-  set(CUTLASS_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass)
+  ### Build cutlass runtime objects for fpA_intB_gemm using its cutlass 
submodule
   add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm)
+  target_include_directories(fpA_intB_gemm PRIVATE
+    ${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm
+    ${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm/cutlass/include
+  )
+  set(CUTLASS_FPA_INTB_RUNTIME_SRCS "")
+  list(APPEND CUTLASS_FPA_INTB_RUNTIME_SRCS 
src/runtime/contrib/cutlass/weight_preprocess.cc)
+  add_library(fpA_intB_cutlass_objs OBJECT ${CUTLASS_FPA_INTB_RUNTIME_SRCS})
+  target_compile_definitions(fpA_intB_cutlass_objs PRIVATE 
DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
+  target_include_directories(fpA_intB_cutlass_objs PRIVATE
+    ${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm
+    ${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm/cutlass/include
+  )
+  list(APPEND CUTLASS_RUNTIME_OBJS 
"$<${CUTLASS_GEN_COND}:$<TARGET_OBJECTS:fpA_intB_cutlass_objs>>")
+
+  ### Build cutlass runtime objects for flash attention
   add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn)
-  list(APPEND RUNTIME_SRCS src/runtime/contrib/cutlass/weight_preprocess.cc)
+  target_include_directories(flash_attn PRIVATE
+    ${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn
+    ${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn/cutlass/include
+  )
+
+  ### Build cutlass runtime objects using TVM's 3rdparty/cutlass submodule
+  set(CUTLASS_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass)
+  set(TVM_CUTLASS_RUNTIME_SRCS "")
+
+  if (CMAKE_CUDA_ARCHITECTURES MATCHES "90a")
+    list(APPEND TVM_CUTLASS_RUNTIME_SRCS 
src/runtime/contrib/cutlass/fp16_group_gemm.cu)
+    list(APPEND TVM_CUTLASS_RUNTIME_SRCS 
src/runtime/contrib/cutlass/fp8_group_gemm.cu)
+  endif()
+  if(TVM_CUTLASS_RUNTIME_SRCS)
+    add_library(tvm_cutlass_objs OBJECT ${TVM_CUTLASS_RUNTIME_SRCS})
+    target_compile_options(tvm_cutlass_objs PRIVATE 
$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>)
+    target_include_directories(tvm_cutlass_objs PRIVATE ${CUTLASS_DIR}/include)
+    target_compile_definitions(tvm_cutlass_objs PRIVATE 
DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
+    list(APPEND CUTLASS_RUNTIME_OBJS 
"$<${CUTLASS_GEN_COND}:$<TARGET_OBJECTS:tvm_cutlass_objs>>")
+  endif()
+
+  ### Add cutlass objects to list of TVM runtime extension objs
+  list(APPEND TVM_RUNTIME_EXT_OBJS "${CUTLASS_RUNTIME_OBJS}")
 
   message(STATUS "Build with CUTLASS")
 endif()
diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm.cu 
b/src/runtime/contrib/cutlass/fp16_group_gemm.cu
new file mode 100644
index 0000000000..3c051819b2
--- /dev/null
+++ b/src/runtime/contrib/cutlass/fp16_group_gemm.cu
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <cuda_fp16.h>
+#include <float.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include "group_gemm_runner.cuh"
+
+#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
+
+template <>
+struct KernelTraits<cutlass::half_t> {
+  using KernelSchedule = 
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
+  using TileShape = Shape<_128, _256, _64>;  // Threadblock-level tile size
+  using ClusterShape = Shape<_2, _2, _1>;    // Shape of the threadblocks in a 
cluster
+};
+
+namespace tvm {
+namespace runtime {
+
+template <typename ElementA, typename ElementB, typename ElementC>
+void tvm_cutlass_group_gemm_sm90(NDArray x, NDArray weight, NDArray indptr, 
NDArray workspace,
+                                 NDArray out) {
+  // Workspace is used for storing device-side group gemm arguments and 
cutlass internal workspace.
+  // Recommened size is 4MB.
+  auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+  ICHECK(func != nullptr);
+  CHECK_EQ(x->ndim, 2);
+  CHECK_EQ(weight->ndim, 3);
+  CHECK_EQ(indptr->ndim, 1);
+  CHECK_EQ(workspace->ndim, 1);
+  CHECK_EQ(out->ndim, 2);
+  int num_groups = weight->shape[0];
+  int n = weight->shape[1];
+  int k = weight->shape[2];
+  float alpha = 1.0f;
+  float beta = 0.0f;
+  cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
+  cutlass_group_gemm(static_cast<ElementA*>(x->data), 
static_cast<ElementB*>(weight->data),
+                     static_cast<int64_t*>(indptr->data), 
static_cast<uint8_t*>(workspace->data),
+                     workspace->shape[0], n, k, num_groups, alpha, beta,
+                     static_cast<ElementC*>(out->data), stream);
+}
+
+TVM_REGISTER_GLOBAL("cutlass.group_gemm_fp16_sm90")
+    .set_body_typed(tvm_cutlass_group_gemm_sm90<cutlass::half_t, 
cutlass::half_t, cutlass::half_t>);
+
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED
diff --git a/src/runtime/contrib/cutlass/fp8_group_gemm.cu 
b/src/runtime/contrib/cutlass/fp8_group_gemm.cu
new file mode 100644
index 0000000000..c93da6ff57
--- /dev/null
+++ b/src/runtime/contrib/cutlass/fp8_group_gemm.cu
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <cuda_fp16.h>
+#include <float.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include "group_gemm_runner.cuh"
+
+#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
+
+template <>
+struct KernelTraits<cutlass::float_e4m3_t> {
+  using KernelSchedule = 
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum;
+  using TileShape = Shape<_128, _256, _64>;  // Threadblock-level tile size
+  using ClusterShape = Shape<_2, _2, _1>;    // Shape of the threadblocks in a 
cluster
+};
+
+template <>
+struct KernelTraits<cutlass::float_e5m2_t> : 
KernelTraits<cutlass::float_e4m3_t> {};
+
+namespace tvm {
+namespace runtime {
+
+template <typename ElementA, typename ElementB, typename ElementC>
+void tvm_cutlass_fp8_group_gemm(NDArray x, NDArray weight, NDArray indptr, 
NDArray workspace,
+                                NDArray alpha, NDArray out) {
+  // Workspace is used for storing device-side group gemm arguments and 
cutlass internal workspace.
+  // Recommened size is 4MB.
+  auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+  ICHECK(func != nullptr);
+  CHECK_EQ(x->ndim, 2);
+  CHECK_EQ(weight->ndim, 3);
+  CHECK_EQ(indptr->ndim, 1);
+  CHECK_EQ(workspace->ndim, 1);
+  CHECK_EQ(out->ndim, 2);
+  CHECK_EQ(alpha->dtype.code, kDLFloat);
+  CHECK_EQ(alpha->dtype.bits, 32);
+  int num_groups = weight->shape[0];
+  int n = weight->shape[1];
+  int k = weight->shape[2];
+  const float* beta = nullptr;
+  cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
+  cutlass_group_gemm(static_cast<ElementA*>(x->data), 
static_cast<ElementB*>(weight->data),
+                     static_cast<int64_t*>(indptr->data), 
static_cast<uint8_t*>(workspace->data),
+                     workspace->shape[0], n, k, num_groups, 
static_cast<float*>(alpha->data), beta,
+                     static_cast<ElementC*>(out->data), stream);
+}
+
+TVM_REGISTER_GLOBAL("cutlass.group_gemm_e5m2_e5m2_fp16")
+    .set_body_typed(
+        tvm_cutlass_fp8_group_gemm<cutlass::float_e5m2_t, 
cutlass::float_e5m2_t, cutlass::half_t>);
+
+TVM_REGISTER_GLOBAL("cutlass.group_gemm_e5m2_e4m3_fp16")
+    .set_body_typed(
+        tvm_cutlass_fp8_group_gemm<cutlass::float_e5m2_t, 
cutlass::float_e4m3_t, cutlass::half_t>);
+
+TVM_REGISTER_GLOBAL("cutlass.group_gemm_e4m3_e4m3_fp16")
+    .set_body_typed(
+        tvm_cutlass_fp8_group_gemm<cutlass::float_e4m3_t, 
cutlass::float_e4m3_t, cutlass::half_t>);
+
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED
diff --git a/src/runtime/contrib/cutlass/group_gemm_runner.cuh 
b/src/runtime/contrib/cutlass/group_gemm_runner.cuh
new file mode 100644
index 0000000000..50bdcf7bec
--- /dev/null
+++ b/src/runtime/contrib/cutlass/group_gemm_runner.cuh
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <fstream>
+#include <iostream>
+#include <sstream>
+#include <variant>
+#include <vector>
+
+#include "../../cuda/cuda_common.h"
+
+// clang-format off
+#include "cutlass/cutlass.h"
+
+#include "cute/tensor.hpp"
+#include "cutlass/tensor_ref.h"
+#include "cutlass/epilogue/collective/default_epilogue.hpp"
+#include "cutlass/epilogue/thread/linear_combination.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
+#include "cutlass/gemm/group_array_problem_shape.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+// clang-format on
+
+#define CUTLASS_CHECK(status)                                                  
                  \
+  {                                                                            
                  \
+    cutlass::Status error = status;                                            
                  \
+    if (error != cutlass::Status::kSuccess) {                                  
                  \
+      std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " 
at: " << __LINE__ \
+                << std::endl;                                                  
                  \
+      exit(EXIT_FAILURE);                                                      
                  \
+    }                                                                          
                  \
+  }
+
+using namespace cute;
+using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;  
// <M,N,K> per group
+
+inline size_t aligned(size_t value, size_t alignment = 16) {
+  return (value + alignment - 1) / alignment * alignment;
+}
+
+template <typename T>
+struct KernelTraits;
+
+template <typename ElementA, typename ElementB, typename ElementC,
+          typename LayoutA = cutlass::layout::RowMajor,
+          typename LayoutB = cutlass::layout::ColumnMajor,
+          typename LayoutC = cutlass::layout::RowMajor>
+struct CutlassGroupGemmRunner {
+  static constexpr int AlignmentA =
+      128 / cutlass::sizeof_bits<ElementA>::value;  // Alignment of A matrix 
in units of elements
+                                                    // (up to 16 bytes)
+
+  static constexpr int AlignmentB =
+      128 / cutlass::sizeof_bits<ElementB>::value;  // Alignment of B matrix 
in units of elements
+                                                    // (up to 16 bytes)
+
+  static constexpr int AlignmentC =
+      128 / cutlass::sizeof_bits<ElementC>::value;  // Alignment of C matrix 
in units of elements
+                                                    // (up to 16 bytes)
+
+  // Core kernel configurations
+  using ElementAccumulator = float;  // Element type for internal accumulation
+  using ScaleType = std::variant<ElementAccumulator, const 
ElementAccumulator*>;
+  using ArchTag =
+      cutlass::arch::Sm90;  // Tag indicating the minimum SM that supports the 
intended feature
+  using OperatorClass = cutlass::arch::OpClassTensorOp;  // Operator class tag
+  using TileShape = typename KernelTraits<ElementA>::TileShape;
+  using ClusterShape = typename KernelTraits<ElementA>::ClusterShape;
+  using StageCountType =
+      cutlass::gemm::collective::StageCountAuto;  // Stage count maximized 
based on the tile size
+  using KernelSchedule = typename KernelTraits<ElementA>::KernelSchedule;     
// Kernel to launch
+  using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized;  
// Epilogue to launch
+
+  using CollectiveEpilogue = typename 
cutlass::epilogue::collective::CollectiveBuilder<
+      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, 
ClusterShape,
+      cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, 
ElementAccumulator,
+      ElementC, LayoutC*, AlignmentC, ElementC, LayoutC*, AlignmentC,
+      EpilogueSchedule>::CollectiveOp;
+
+  using CollectiveMainloop = typename 
cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, 
LayoutB*, AlignmentB,
+      ElementAccumulator, TileShape, ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
+          sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      KernelSchedule>::CollectiveOp;
+
+  using GemmKernel =
+      cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, 
CollectiveEpilogue>;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+  using StrideA = typename Gemm::GemmKernel::UnderlyingStrideA;
+  using StrideB = typename Gemm::GemmKernel::UnderlyingStrideB;
+  using StrideC = typename Gemm::GemmKernel::UnderlyingStrideC;
+  using StrideD = typename Gemm::GemmKernel::UnderlyingStrideD;
+
+  void run_group_gemm(const ElementA** ptr_A, const ElementB** ptr_B, const 
ElementC** ptr_C,
+                      ElementC** ptr_D,
+                      typename ProblemShape::UnderlyingProblemShape* 
problem_sizes,
+                      typename ProblemShape::UnderlyingProblemShape* 
problem_sizes_host,
+                      StrideA* stride_A, StrideB* stride_B, StrideC* stride_C, 
StrideD* stride_D,
+                      uint8_t* workspace, int64_t workspace_size, int 
num_groups, ScaleType alpha,
+                      ScaleType beta, cudaStream_t stream) {
+    typename Gemm::EpilogueOutputOp::Params epilogue_params = [&]() {
+      ICHECK(alpha.index() == beta.index()) << "alpha and beta must have the 
same type";
+      if (std::holds_alternative<ElementAccumulator>(alpha)) {
+        return typename 
Gemm::EpilogueOutputOp::Params{std::get<ElementAccumulator>(alpha),
+                                                       
std::get<ElementAccumulator>(beta)};
+      } else if (std::holds_alternative<const ElementAccumulator*>(alpha)) {
+        return typename Gemm::EpilogueOutputOp::Params{std::get<const 
ElementAccumulator*>(alpha),
+                                                       std::get<const 
ElementAccumulator*>(beta)};
+      } else {
+        LOG(FATAL) << "Unsupported alpha and beta type";
+        throw;
+      }
+    }();
+
+    cutlass::KernelHardwareInfo hw_info;
+    hw_info.device_id = 0;
+    hw_info.sm_count =
+        
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
+    typename Gemm::Arguments 
arguments{cutlass::gemm::GemmUniversalMode::kGrouped,
+                                       {num_groups, problem_sizes, 
problem_sizes_host},
+                                       {ptr_A, stride_A, ptr_B, stride_B},
+                                       {epilogue_params, ptr_C, stride_C, 
ptr_D, stride_D},
+                                       hw_info};
+    Gemm gemm_op;
+    CUTLASS_CHECK(gemm_op.can_implement(arguments));
+    CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
+    CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
+    CUTLASS_CHECK(gemm_op.run());
+  }
+};
+
+template <typename ElementA, typename ElementB, typename ElementC, typename 
StrideA,
+          typename StrideB, typename StrideC>
+__global__ void prepare_group_gemm_arguments(
+    const ElementA** ptr_A, const ElementB** ptr_B, ElementC** ptr_D,
+    typename ProblemShape::UnderlyingProblemShape* problem_sizes, StrideA* 
stride_A,
+    StrideB* stride_B, StrideC* stride_D, const ElementA* x, const ElementB* 
weight, ElementC* out,
+    int64_t* indptr, int64_t n, int64_t k, int64_t num_groups) {
+  int group_id = threadIdx.x;
+  if (group_id >= num_groups) return;
+  int prev_rows = group_id == 0 ? 0 : indptr[group_id - 1];
+  ptr_A[group_id] = x + prev_rows * k;
+  ptr_B[group_id] = weight + group_id * k * n;
+  ptr_D[group_id] = out + prev_rows * n;
+  problem_sizes[group_id] = {static_cast<int>(indptr[group_id] - prev_rows), 
static_cast<int>(n),
+                             static_cast<int>(k)};
+  stride_A[group_id] = cute::make_stride(k, Int<1>{}, int64_t{0});
+  stride_B[group_id] = cute::make_stride(k, Int<1>{}, int64_t{0});
+  stride_D[group_id] = cute::make_stride(n, Int<1>{}, int64_t{0});
+}
+
+template <typename ElementA, typename ElementB, typename ElementC>
+void cutlass_group_gemm(ElementA* x, ElementB* weight, int64_t* indptr, 
uint8_t* workspace,
+                        int64_t workspace_size, int64_t n, int64_t k, int64_t 
num_groups,
+                        std::variant<float, const float*> alpha,
+                        std::variant<float, const float*> beta, ElementC* out,
+                        cudaStream_t stream) {
+  using Runner = CutlassGroupGemmRunner<ElementA, ElementB, ElementC>;
+  using StrideA = typename Runner::StrideA;
+  using StrideB = typename Runner::StrideB;
+  using StrideC = typename Runner::StrideC;
+
+  Runner runner;
+  std::ptrdiff_t offset = 0;
+  const ElementA** ptr_A = reinterpret_cast<const ElementA**>(workspace + 
offset);
+  offset += aligned(sizeof(ElementA*) * num_groups);
+  const ElementB** ptr_B = reinterpret_cast<const ElementB**>(workspace + 
offset);
+  offset += aligned(sizeof(ElementB*) * num_groups);
+  ElementC** ptr_D = reinterpret_cast<ElementC**>(workspace + offset);
+  offset += aligned(sizeof(ElementC*) * num_groups);
+  typename ProblemShape::UnderlyingProblemShape* problem_sizes =
+      reinterpret_cast<typename 
ProblemShape::UnderlyingProblemShape*>(workspace + offset);
+  offset += aligned(sizeof(typename ProblemShape::UnderlyingProblemShape) * 
num_groups);
+  StrideA* stride_A = reinterpret_cast<StrideA*>(workspace + offset);
+  offset += aligned(sizeof(StrideA) * num_groups);
+  StrideB* stride_B = reinterpret_cast<StrideB*>(workspace + offset);
+  offset += aligned(sizeof(StrideB) * num_groups);
+  StrideC* stride_D = reinterpret_cast<StrideC*>(workspace + offset);
+  offset += aligned(sizeof(StrideC) * num_groups);
+  prepare_group_gemm_arguments<<<1, num_groups, 0, stream>>>(ptr_A, ptr_B, 
ptr_D, problem_sizes,
+                                                             stride_A, 
stride_B, stride_D, x,
+                                                             weight, out, 
indptr, n, k, num_groups);
+  offset = aligned(offset, 256);
+  runner.run_group_gemm(ptr_A, ptr_B, const_cast<const ElementC**>(ptr_D), 
ptr_D, problem_sizes,
+                        nullptr, stride_A, stride_B, stride_D, stride_D, 
workspace + offset,
+                        workspace_size - offset, num_groups, alpha, beta, 
stream);
+}
diff --git a/src/runtime/contrib/cutlass/weight_preprocess.cc 
b/src/runtime/contrib/cutlass/weight_preprocess.cc
index 4b378fa4a7..5fded82762 100644
--- a/src/runtime/contrib/cutlass/weight_preprocess.cc
+++ b/src/runtime/contrib/cutlass/weight_preprocess.cc
@@ -21,7 +21,7 @@
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/registry.h>
 
-#include 
"../../../3rdparty/cutlass_fpA_intB_gemm/cutlass_kernels/cutlass_preprocessors.h"
+#include "cutlass_kernels/cutlass_preprocessors.h"
 
 namespace tvm {
 namespace runtime {
diff --git a/tests/python/contrib/test_cutlass.py 
b/tests/python/contrib/test_cutlass.py
index 6eaf10c2ab..154a68e116 100644
--- a/tests/python/contrib/test_cutlass.py
+++ b/tests/python/contrib/test_cutlass.py
@@ -17,6 +17,7 @@
 import logging
 import tempfile
 import math
+import ml_dtypes
 import tvm
 from tvm import relay
 from tvm.contrib.cudnn import conv_output_shape
@@ -32,6 +33,7 @@ from tvm.contrib.cutlass import (
     finalize_modules,
     finalize_modules_vm,
 )
+from tvm.contrib.pickle_memoize import memoize
 import tvm.testing
 
 logging.basicConfig(level=logging.INFO)
@@ -1105,5 +1107,101 @@ def test_dense_transpose_dense():
     verify_dense_transpose_dense(get_dense_transpose_dense(M, N, K), M, N, K)
 
 
+def verify_group_gemm(
+    func_name, M, N, K, num_groups, x_dtype, weight_dtype, out_dtype, 
use_scale, rtol, atol
+):
+    group_gemm_func = tvm.get_global_func(func_name, allow_missing=True)
+    if group_gemm_func is None:
+        print(f"Skipped as {func_name} is not available")
+        return
+
+    @memoize("tvm.contrib.cutlass.test_group_gemm_sm90")
+    def get_ref_data():
+        assert M % num_groups == 0
+        M_per_group = M // num_groups
+        a_np = get_random_ndarray((M, K), "float16")
+        b_np = get_random_ndarray((num_groups, N, K), "float16")
+        indptr_np = np.arange(1, num_groups + 1).astype("int64") * M_per_group
+        c_np = np.concatenate(
+            [a_np[i * M_per_group : (i + 1) * M_per_group] @ b_np[i].T for i 
in range(num_groups)],
+            axis=0,
+        )
+        return a_np, b_np, indptr_np, c_np
+
+    def to_numpy_dtype(dtype):
+        mapping = {"e5m2_float8": ml_dtypes.float8_e5m2, "e4m3_float8": 
ml_dtypes.float8_e4m3fn}
+        return mapping.get(dtype, dtype)
+
+    a_np, b_np, indptr_np, c_np = get_ref_data()
+    dev = tvm.cuda(0)
+    a_nd = tvm.nd.array(a_np.astype(to_numpy_dtype(x_dtype)), device=dev)
+    b_nd = tvm.nd.array(b_np.astype(to_numpy_dtype(weight_dtype)), device=dev)
+    c_nd = tvm.nd.empty(c_np.shape, dtype=out_dtype, device=dev)
+    indptr_nd = tvm.nd.array(indptr_np, device=dev)
+    workspace = tvm.nd.empty((4096 * 1024,), dtype="uint8", device=dev)
+    if use_scale:
+        scale = tvm.nd.array(np.array([1.0], dtype="float32"), device=dev)
+        group_gemm_func(a_nd, b_nd, indptr_nd, workspace, scale, c_nd)
+    else:
+        group_gemm_func(a_nd, b_nd, indptr_nd, workspace, c_nd)
+    tvm.testing.assert_allclose(c_nd.asnumpy(), c_np, rtol=rtol, atol=atol)
+
+
[email protected]_cutlass
+def test_group_gemm_sm90():
+    verify_group_gemm(
+        "cutlass.group_gemm_fp16_sm90",
+        8,
+        128,
+        128,
+        4,
+        "float16",
+        "float16",
+        "float16",
+        False,
+        rtol=1e-3,
+        atol=1e-3,
+    )
+    verify_group_gemm(
+        "cutlass.group_gemm_e5m2_e5m2_fp16",
+        8,
+        16,
+        16,
+        4,
+        "e5m2_float8",
+        "e5m2_float8",
+        "float16",
+        True,
+        rtol=1e-1,
+        atol=1,
+    )
+    verify_group_gemm(
+        "cutlass.group_gemm_e4m3_e4m3_fp16",
+        8,
+        16,
+        16,
+        4,
+        "e4m3_float8",
+        "e4m3_float8",
+        "float16",
+        True,
+        rtol=1e-1,
+        atol=1,
+    )
+    verify_group_gemm(
+        "cutlass.group_gemm_e4m3_e5m2_fp16",
+        8,
+        16,
+        16,
+        4,
+        "e4m3_float8",
+        "e5m2_float8",
+        "float16",
+        True,
+        rtol=1e-1,
+        atol=1,
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to