This is an automated email from the ASF dual-hosted git repository.

yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b0ccfb39c5 [CUTLASS] Add blockwise scale gemm/bmm kernels (#17789)
b0ccfb39c5 is described below

commit b0ccfb39c5d595de07f0a69dceb0cf51f0cb86a9
Author: Ruihang Lai <[email protected]>
AuthorDate: Tue Apr 1 01:00:49 2025 +0200

    [CUTLASS] Add blockwise scale gemm/bmm kernels (#17789)
    
    This PR introduces blockwise scale matmul and batch matmul CUTLASS
    kernels, adapted from SGLang (http://github.com/sgl-project/sglang),
    vLLM (https://github.com/vllm-project/vllm) and
    https://github.com/soundOfDestiny/cutlass.
    
    We add unit tests for gemm and bmm. This PR also restores some
    CUTLASS gemm tests that were previously removed during the Relay
    phase-out.
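    
    Blockwise scaling means B carries one float32 scale per
    (128, 128) block, while A carries one scale per row for each
    128-wide K block; each scale is applied to the per-block partial
    product. A rough NumPy sketch of the reference semantics, mirroring
    the `blockwise_matmul` helper added in the unit tests and assuming
    M, N, K divide evenly by the block sizes `bs0`/`bs1` (the helper
    name `blockwise_matmul_ref` below is illustrative only):
    
        import numpy as np
    
        def blockwise_matmul_ref(x, x_scale, w, w_scale, bs0, bs1):
            # x: (M, K) fp8, x_scale: (M, K // bs1)
            # w: (N, K) fp8, w_scale: (N // bs0, K // bs1)
            out = np.zeros((x.shape[0], w.shape[0]), dtype="float32")
            for j in range(w_scale.shape[0]):
                for k in range(w_scale.shape[1]):
                    a_blk = x[:, k * bs1 : (k + 1) * bs1].astype("float32")
                    b_blk = w[j * bs0 : (j + 1) * bs0,
                              k * bs1 : (k + 1) * bs1].astype("float32")
                    out[:, j * bs0 : (j + 1) * bs0] += (
                        (a_blk @ b_blk.T)
                        * x_scale[:, k : k + 1]
                        * w_scale[j, k]
                    )
            return out
    
    The kernels are registered as the TVM global functions
    `cutlass.blockwise_scaled_gemm_e4m3fn_e4m3fn` and
    `cutlass.blockwise_scaled_bmm_e4m3fn_e4m3fn`.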
---
 3rdparty/cutlass                                   |   2 +-
 3rdparty/cutlass_fpA_intB_gemm                     |   2 +-
 cmake/modules/contrib/CUTLASS.cmake                |   6 +-
 .../cutlass/blockwise_scaled_gemm_runner.cuh       | 228 +++++++++++++
 .../contrib/cutlass/fp8_blockwise_scaled_gemm.cu   | 164 ++++++++++
 src/runtime/contrib/cutlass/group_gemm_runner.cuh  |  14 +-
 tests/python/contrib/test_cutlass_gemm.py          | 352 +++++++++++++++++++++
 7 files changed, 758 insertions(+), 10 deletions(-)

diff --git a/3rdparty/cutlass b/3rdparty/cutlass
index bbe579a9e3..afa1772203 160000
--- a/3rdparty/cutlass
+++ b/3rdparty/cutlass
@@ -1 +1 @@
-Subproject commit bbe579a9e3beb6ea6626d9227ec32d0dae119a49
+Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008
diff --git a/3rdparty/cutlass_fpA_intB_gemm b/3rdparty/cutlass_fpA_intB_gemm
index f09824e950..fdef230791 160000
--- a/3rdparty/cutlass_fpA_intB_gemm
+++ b/3rdparty/cutlass_fpA_intB_gemm
@@ -1 +1 @@
-Subproject commit f09824e950ed6678670004bd23578757b3473f21
+Subproject commit fdef2307917ec2c7cc5becc29fb95d77498484bd
diff --git a/cmake/modules/contrib/CUTLASS.cmake b/cmake/modules/contrib/CUTLASS.cmake
index b302622cbc..b9097a02e9 100644
--- a/cmake/modules/contrib/CUTLASS.cmake
+++ b/cmake/modules/contrib/CUTLASS.cmake
@@ -58,11 +58,15 @@ if(USE_CUDA AND USE_CUTLASS)
     list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_group_gemm.cu)
     list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_group_gemm.cu)
     list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_gemm.cu)
+    list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu)
   endif()
   if(TVM_CUTLASS_RUNTIME_SRCS)
     add_library(tvm_cutlass_objs OBJECT ${TVM_CUTLASS_RUNTIME_SRCS})
     target_compile_options(tvm_cutlass_objs PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>)
-    target_include_directories(tvm_cutlass_objs PRIVATE ${CUTLASS_DIR}/include)
+    target_include_directories(tvm_cutlass_objs PRIVATE
+      ${CUTLASS_DIR}/include
+      ${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm/cutlass_extensions/include
+    )
     target_compile_definitions(tvm_cutlass_objs PRIVATE DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
     list(APPEND CUTLASS_RUNTIME_OBJS "$<${CUTLASS_GEN_COND}:$<TARGET_OBJECTS:tvm_cutlass_objs>>")
   endif()
diff --git a/src/runtime/contrib/cutlass/blockwise_scaled_gemm_runner.cuh b/src/runtime/contrib/cutlass/blockwise_scaled_gemm_runner.cuh
new file mode 100644
index 0000000000..f520bf815a
--- /dev/null
+++ b/src/runtime/contrib/cutlass/blockwise_scaled_gemm_runner.cuh
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <fstream>
+#include <iostream>
+#include <sstream>
+#include <type_traits>
+#include <variant>
+#include <vector>
+
+#include "../../cuda/cuda_common.h"
+
+// clang-format off
+#include "cutlass/cutlass.h"
+
+#include "cute/tensor.hpp"
+#include "cutlass/float8.h"
+#include "cutlass/tensor_ref.h"
+#include "cutlass/epilogue/collective/default_epilogue.hpp"
+#include "cutlass/epilogue/thread/linear_combination.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+
+#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
+#include "cutlass_extensions/gemm/dispatch_policy.hpp"
+// clang-format on
+
+#define CUTLASS_CHECK(status)                                      \
+  {                                                                \
+    cutlass::Status error = status;                                \
+    CHECK(error == cutlass::Status::kSuccess)                      \
+        << "Got cutlass error: " << cutlassGetStatusString(error); \
+  }
+
+using namespace cute;
+using ProblemShape = Shape<int, int, int, int>;
+using tvm::runtime::NDArray;
+
+template <typename TileShape, typename ClusterShape, typename ElementD, typename SchedulerType,
+          int ScaleGranularityM = 1>
+struct CutlassFP8ScaledBlockwiseGemmRunner {
+  using ElementAccumulator = float;
+  using ElementCompute = float;
+  using ElementBlockScale = float;
+
+  using ElementA = cutlass::float_e4m3_t;
+  using LayoutA = cutlass::layout::RowMajor;
+  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
+
+  using ElementB = cutlass::float_e4m3_t;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
+
+  using ElementC = void;
+  using LayoutC = cutlass::layout::RowMajor;
+  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value;
+
+  using LayoutD = cutlass::layout::RowMajor;
+  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
+
+  using ArchTag = cutlass::arch::Sm90;
+  using OperatorClass = cutlass::arch::OpClassTensorOp;
+  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
+  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
+  using StoreEpilogueCompute =
+      typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;
+
+  using KernelSchedule =
+      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<
+          ScaleGranularityM>;
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, ElementAccumulator,
+      ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD,
+      EpilogueSchedule, StoreEpilogueCompute>::CollectiveOp;
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB,
+      ElementAccumulator, TileShape, ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
+          sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      KernelSchedule>::CollectiveOp;
+
+  using GemmKernel =
+      cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>,  // Indicates ProblemShape
+                                           CollectiveMainloop, CollectiveEpilogue, SchedulerType>;
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+  using StrideA = typename Gemm::GemmKernel::StrideA;
+  using StrideB = typename Gemm::GemmKernel::StrideB;
+  using StrideD = typename Gemm::GemmKernel::StrideD;
+
+  void run_gemm(const ElementA* a_ptr, const ElementB* b_ptr, const ElementBlockScale* scales_a_ptr,
+                const ElementBlockScale* scales_b_ptr, ElementD* o_ptr, ProblemShape* problem_size,
+                StrideA* stride_a, StrideB* stride_b, StrideD* stride_d, uint8_t* workspace,
+                int64_t workspace_size, cudaStream_t stream) {
+    cutlass::KernelHardwareInfo hw_info;
+    hw_info.device_id = 0;
+    hw_info.sm_count =
+        cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
+
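+    // Tile-scheduler arguments. Stream-K additionally selects the StreamK
+    // decomposition with a nondeterministic reduction mode: faster, but not
+    // guaranteed to be bitwise reproducible across runs.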
+    typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
+    static constexpr bool UsesStreamKScheduler =
+        cute::is_same_v<typename Gemm::GemmKernel::TileSchedulerTag,
+                        cutlass::gemm::StreamKScheduler>;
+    if constexpr (UsesStreamKScheduler) {
+      using DecompositionMode = typename cutlass::gemm::kernel::detail::
+          PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
+      using ReductionMode = typename cutlass::gemm::kernel::detail::
+          PersistentTileSchedulerSm90StreamKParams::ReductionMode;
+      scheduler.decomposition_mode = DecompositionMode::StreamK;
+      scheduler.reduction_mode = ReductionMode::Nondeterministic;
+    }
+
+    typename Gemm::Arguments arguments = {
+        cutlass::gemm::GemmUniversalMode::kGemm,
+        *problem_size,
+        {a_ptr, *stride_a, b_ptr, *stride_b, scales_a_ptr, scales_b_ptr},
+        {{}, nullptr, *stride_d, o_ptr, *stride_d},
+        hw_info,
+        scheduler};
+
+    Gemm gemm_op;
+    CUTLASS_CHECK(gemm_op.can_implement(arguments));
+    CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
+    CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
+    CUTLASS_CHECK(gemm_op.run(stream));
+  }
+};
+
+template <typename TileShape, typename ClusterShape, typename ElementA, typename ElementB,
+          typename ElementD, typename ElementBlockScale>
+void cutlass_fp8_blockwise_scaled_gemm(ElementA* a, ElementB* b, ElementBlockScale* scales_a,
+                                       ElementBlockScale* scales_b, ElementD* out,
+                                       uint8_t* workspace, int64_t workspace_size, int64_t m,
+                                       int64_t n, int64_t k, cudaStream_t stream) {
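+  // Scheduler heuristic: when the reduction dimension K dominates N, there are
+  // few output tiles relative to the reduction work, so stream-K splits K
+  // across SMs; otherwise the persistent tile scheduler is used.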
+  if (k > 3 * n) {
+    using SchedulerType = cutlass::gemm::StreamKScheduler;
+    using Runner =
+        CutlassFP8ScaledBlockwiseGemmRunner<TileShape, ClusterShape, ElementD, SchedulerType>;
+    using StrideA = typename Runner::StrideA;
+    using StrideB = typename Runner::StrideB;
+    using StrideD = typename Runner::StrideD;
+
+    Runner runner;
+    StrideA stride_a = cute::make_stride(k, Int<1>{}, m * k);
+    StrideB stride_b = cute::make_stride(k, Int<1>{}, n * k);
+    StrideD stride_d = cute::make_stride(n, Int<1>{}, m * n);
+    ProblemShape problem_size{static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1};
+    runner.run_gemm(a, b, scales_a, scales_b, out, &problem_size, &stride_a, &stride_b, &stride_d,
+                    workspace, workspace_size, stream);
+  } else {
+    using SchedulerType = cutlass::gemm::PersistentScheduler;
+    using Runner =
+        CutlassFP8ScaledBlockwiseGemmRunner<TileShape, ClusterShape, ElementD, SchedulerType>;
+    using StrideA = typename Runner::StrideA;
+    using StrideB = typename Runner::StrideB;
+    using StrideD = typename Runner::StrideD;
+
+    Runner runner;
+    StrideA stride_a = cute::make_stride(k, Int<1>{}, m * k);
+    StrideB stride_b = cute::make_stride(k, Int<1>{}, n * k);
+    StrideD stride_d = cute::make_stride(n, Int<1>{}, m * n);
+    ProblemShape problem_size{static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1};
+    runner.run_gemm(a, b, scales_a, scales_b, out, &problem_size, &stride_a, &stride_b, &stride_d,
+                    workspace, workspace_size, stream);
+  }
+}
+
+template <typename TileShape, typename ClusterShape, typename ElementA, typename ElementB,
+          typename ElementD, typename ElementBlockScale>
+void cutlass_fp8_blockwise_scaled_bmm(ElementA* a, ElementB* b, ElementBlockScale* scales_a,
+                                      ElementBlockScale* scales_b, ElementD* out,
+                                      uint8_t* workspace, int64_t workspace_size, int64_t m,
+                                      int64_t n, int64_t k, int64_t l, cudaStream_t stream) {
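+  // Same scheduler heuristic as cutlass_fp8_blockwise_scaled_gemm above.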
+  if (k > 3 * n) {
+    using SchedulerType = cutlass::gemm::StreamKScheduler;
+    using Runner =
+        CutlassFP8ScaledBlockwiseGemmRunner<TileShape, ClusterShape, ElementD, SchedulerType>;
+    using StrideA = typename Runner::StrideA;
+    using StrideB = typename Runner::StrideB;
+    using StrideD = typename Runner::StrideD;
+
+    Runner runner;
+    StrideA stride_a = cute::make_stride(k, Int<1>{}, m * k);
+    StrideB stride_b = cute::make_stride(k, Int<1>{}, n * k);
+    StrideD stride_d = cute::make_stride(n, Int<1>{}, m * n);
+    ProblemShape problem_size{static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
+                              static_cast<int>(l)};
+    runner.run_gemm(a, b, scales_a, scales_b, out, &problem_size, &stride_a, &stride_b, &stride_d,
+                    workspace, workspace_size, stream);
+  } else {
+    using SchedulerType = cutlass::gemm::PersistentScheduler;
+    using Runner =
+        CutlassFP8ScaledBlockwiseGemmRunner<TileShape, ClusterShape, ElementD, SchedulerType>;
+    using StrideA = typename Runner::StrideA;
+    using StrideB = typename Runner::StrideB;
+    using StrideD = typename Runner::StrideD;
+
+    Runner runner;
+    StrideA stride_a = cute::make_stride(k, Int<1>{}, m * k);
+    StrideB stride_b = cute::make_stride(k, Int<1>{}, n * k);
+    StrideD stride_d = cute::make_stride(n, Int<1>{}, m * n);
+    ProblemShape problem_size{static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
+                              static_cast<int>(l)};
+    runner.run_gemm(a, b, scales_a, scales_b, out, &problem_size, &stride_a, &stride_b, &stride_d,
+                    workspace, workspace_size, stream);
+  }
+}
diff --git a/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu b/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu
new file mode 100644
index 0000000000..4ac5a621a0
--- /dev/null
+++ b/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <cuda_fp16.h>
+#include <float.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include "../cublas/cublas_utils.h"
+#include "blockwise_scaled_gemm_runner.cuh"
+
+#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
+
+namespace tvm {
+namespace runtime {
+
+void tvm_cutlass_fp8_blockwise_scaled_gemm(NDArray a, NDArray b, NDArray scales_a, NDArray scales_b,
+                                           NDArray workspace, int64_t block_size_0,
+                                           int64_t block_size_1, NDArray out) {
+  using TileShape = Shape<_128, _128, _128>;
+  using ClusterShape = Shape<_1, _1, _1>;
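+  // The 128x128x128 tile and 1x1x1 cluster are fixed here; ScaleGranularityM
+  // defaults to 1 in the runner, i.e. one scale per row of A along M.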
+
+  // Workspace is used for storing device-side gemm arguments and cutlass internal workspace.
+  // Recommended size is 4MB.
+  auto get_stream_func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+  ICHECK(get_stream_func != nullptr);
+  cudaStream_t stream = static_cast<cudaStream_t>((*get_stream_func)().operator void*());
+
+  CHECK_GE(a->ndim, 2);
+  CHECK_EQ(scales_a->ndim, a->ndim);
+  CHECK_EQ(b->ndim, 2);
+  CHECK_EQ(scales_b->ndim, 2);
+  CHECK_EQ(workspace->ndim, 1);
+  CHECK_EQ(out->ndim, a->ndim);
+  int64_t m = 1;
+  for (int64_t i = 0; i < a->ndim - 1; ++i) {
+    m *= a->shape[i];
+  }
+  int64_t n = b->shape[0];
+  CHECK_EQ(a->shape[a->ndim - 1], b->shape[1]) << "Only col-major B is supported now.";
+  int64_t k = a->shape[a->ndim - 1];
+
+  // scales_a is col-major of (*a_shape[:-1], k / block_size)
+  CHECK_EQ(scales_a->shape[0] * block_size_1, k);
+  for (int64_t i = 1; i < scales_a->ndim; ++i) {
+    CHECK_EQ(scales_a->shape[i], a->shape[i - 1]);
+  }
+  // scales_b is col-major of (k / block_size, n / block_size)
+  CHECK_EQ(scales_b->shape[0] * block_size_0, n);
+  CHECK_EQ(scales_b->shape[1] * block_size_1, k);
+
+  using tvm::runtime::DataType;
+  CHECK_EQ(DataType(a->dtype), DataType::NVFloat8E4M3());
+  CHECK_EQ(DataType(b->dtype), DataType::NVFloat8E4M3());
+  CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32));
+  CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32));
+  CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8));
+
+  if (DataType(out->dtype) == DataType::Float(16)) {
+    cutlass_fp8_blockwise_scaled_gemm<TileShape, ClusterShape, cutlass::float_e4m3_t,
+                                      cutlass::float_e4m3_t, cutlass::half_t, float>(
+        static_cast<cutlass::float_e4m3_t*>(a->data), static_cast<cutlass::float_e4m3_t*>(b->data),
+        static_cast<float*>(scales_a->data), static_cast<float*>(scales_b->data),
+        static_cast<cutlass::half_t*>(out->data), static_cast<uint8_t*>(workspace->data),
+        workspace->shape[0] * DataType(workspace->dtype).bytes(), m, n, k, stream);
+  } else if (DataType(out->dtype) == DataType::BFloat(16)) {
+    cutlass_fp8_blockwise_scaled_gemm<TileShape, ClusterShape, cutlass::float_e4m3_t,
+                                      cutlass::float_e4m3_t, cutlass::bfloat16_t, float>(
+        static_cast<cutlass::float_e4m3_t*>(a->data), static_cast<cutlass::float_e4m3_t*>(b->data),
+        static_cast<float*>(scales_a->data), static_cast<float*>(scales_b->data),
+        static_cast<cutlass::bfloat16_t*>(out->data), static_cast<uint8_t*>(workspace->data),
+        workspace->shape[0] * DataType(workspace->dtype).bytes(), m, n, k, stream);
+  } else {
+    LOG(FATAL) << "Unsupported output dtype: " << DataType(out->dtype);
+  }
+}
+
+void tvm_cutlass_fp8_blockwise_scaled_bmm(NDArray a, NDArray b, NDArray scales_a, NDArray scales_b,
+                                          NDArray workspace, int64_t block_size_0,
+                                          int64_t block_size_1, NDArray out) {
+  using TileShape = Shape<_128, _128, _128>;
+  using ClusterShape = Shape<_1, _1, _1>;
+
+  // Workspace is used for storing device-side gemm arguments and cutlass internal workspace.
+  // Recommended size is 4MB.
+  auto get_stream_func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+  ICHECK(get_stream_func != nullptr);
+  cudaStream_t stream = static_cast<cudaStream_t>((*get_stream_func)().operator void*());
+
+  CHECK_EQ(a->ndim, 3);
+  CHECK_EQ(scales_a->ndim, 3);
+  CHECK_EQ(b->ndim, 3);
+  CHECK_EQ(scales_b->ndim, 3);
+  CHECK_EQ(workspace->ndim, 1);
+  CHECK_EQ(out->ndim, 3);
+  int64_t batch_size = a->shape[0];
+  int64_t m = a->shape[1];
+  int64_t n = b->shape[1];
+  CHECK_EQ(a->shape[2], b->shape[2]) << "Only col-major B is supported now.";
+  int64_t k = a->shape[2];
+  CHECK_EQ(b->shape[0], batch_size);
+  CHECK_EQ(scales_a->shape[0], batch_size);
+  CHECK_EQ(scales_b->shape[0], batch_size);
+  CHECK_EQ(out->shape[0], batch_size);
+
+  // scales_a is col-major of (batch_size, m, k / block_size)
+  CHECK_EQ(scales_a->shape[1] * block_size_1, k);
+  CHECK_EQ(scales_a->shape[2], m);
+  // scales_b is col-major of (k / block_size, n / block_size)
+  CHECK_EQ(scales_b->shape[1] * block_size_0, n);
+  CHECK_EQ(scales_b->shape[2] * block_size_1, k);
+
+  using tvm::runtime::DataType;
+  CHECK_EQ(DataType(a->dtype), DataType::NVFloat8E4M3());
+  CHECK_EQ(DataType(b->dtype), DataType::NVFloat8E4M3());
+  CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32));
+  CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32));
+  CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8));
+
+  if (DataType(out->dtype) == DataType::Float(16)) {
+    cutlass_fp8_blockwise_scaled_bmm<TileShape, ClusterShape, cutlass::float_e4m3_t,
+                                     cutlass::float_e4m3_t, cutlass::half_t, float>(
+        static_cast<cutlass::float_e4m3_t*>(a->data), static_cast<cutlass::float_e4m3_t*>(b->data),
+        static_cast<float*>(scales_a->data), static_cast<float*>(scales_b->data),
+        static_cast<cutlass::half_t*>(out->data), static_cast<uint8_t*>(workspace->data),
+        workspace->shape[0] * DataType(workspace->dtype).bytes(), m, n, k, batch_size, stream);
+  } else if (DataType(out->dtype) == DataType::BFloat(16)) {
+    cutlass_fp8_blockwise_scaled_bmm<TileShape, ClusterShape, cutlass::float_e4m3_t,
+                                     cutlass::float_e4m3_t, cutlass::bfloat16_t, float>(
+        static_cast<cutlass::float_e4m3_t*>(a->data), static_cast<cutlass::float_e4m3_t*>(b->data),
+        static_cast<float*>(scales_a->data), static_cast<float*>(scales_b->data),
+        static_cast<cutlass::bfloat16_t*>(out->data), static_cast<uint8_t*>(workspace->data),
+        workspace->shape[0] * DataType(workspace->dtype).bytes(), m, n, k, batch_size, stream);
+  } else {
+    LOG(FATAL) << "Unsupported output dtype: " << DataType(out->dtype);
+  }
+}
+
+TVM_REGISTER_GLOBAL("cutlass.blockwise_scaled_gemm_e4m3fn_e4m3fn")
+    .set_body_typed(tvm_cutlass_fp8_blockwise_scaled_gemm);
+TVM_REGISTER_GLOBAL("cutlass.blockwise_scaled_bmm_e4m3fn_e4m3fn")
+    .set_body_typed(tvm_cutlass_fp8_blockwise_scaled_bmm);
+
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED
diff --git a/src/runtime/contrib/cutlass/group_gemm_runner.cuh b/src/runtime/contrib/cutlass/group_gemm_runner.cuh
index 71979672b9..a3c52e27a9 100644
--- a/src/runtime/contrib/cutlass/group_gemm_runner.cuh
+++ b/src/runtime/contrib/cutlass/group_gemm_runner.cuh
@@ -105,10 +105,10 @@ struct CutlassGroupGemmRunner {
 
   using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
 
-  using StrideA = typename Gemm::GemmKernel::UnderlyingStrideA;
-  using StrideB = typename Gemm::GemmKernel::UnderlyingStrideB;
-  using StrideC = typename Gemm::GemmKernel::UnderlyingStrideC;
-  using StrideD = typename Gemm::GemmKernel::UnderlyingStrideD;
+  using StrideA = typename Gemm::GemmKernel::InternalStrideA;
+  using StrideB = typename Gemm::GemmKernel::InternalStrideB;
+  using StrideC = typename Gemm::GemmKernel::InternalStrideC;
+  using StrideD = typename Gemm::GemmKernel::InternalStrideD;
 
   void run_group_gemm(const ElementA** ptr_A, const ElementB** ptr_B, const ElementC** ptr_C,
                       ElementC** ptr_D,
@@ -163,9 +163,9 @@ __global__ void prepare_group_gemm_arguments(
   ptr_D[group_id] = out + prev_rows * n;
   problem_sizes[group_id] = {static_cast<int>(indptr[group_id] - prev_rows), static_cast<int>(n),
                              static_cast<int>(k)};
-  stride_A[group_id] = cute::make_stride(k, Int<1>{}, int64_t{0});
-  stride_B[group_id] = cute::make_stride(k, Int<1>{}, int64_t{0});
-  stride_D[group_id] = cute::make_stride(n, Int<1>{}, int64_t{0});
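+  // Batch strides are compile-time zeros (Int<0>{}) to match the Internal
+  // stride types exposed by the updated CUTLASS submodule.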
+  stride_A[group_id] = cute::make_stride(k, Int<1>{}, Int<0>{});
+  stride_B[group_id] = cute::make_stride(k, Int<1>{}, Int<0>{});
+  stride_D[group_id] = cute::make_stride(n, Int<1>{}, Int<0>{});
 }
 
 template <typename ElementA, typename ElementB, typename ElementC>
diff --git a/tests/python/contrib/test_cutlass_gemm.py b/tests/python/contrib/test_cutlass_gemm.py
new file mode 100644
index 0000000000..7c259e6f7d
--- /dev/null
+++ b/tests/python/contrib/test_cutlass_gemm.py
@@ -0,0 +1,352 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import Tuple
+
+import ml_dtypes
+import numpy as np
+
+import tvm
+import tvm.testing
+from tvm.contrib.pickle_memoize import memoize
+
+
+def get_random_ndarray(shape, dtype):
+    if dtype == "int8":
+        return np.random.randint(-128, 128, shape).astype(dtype)
+    elif dtype == "uint8":
+        return np.random.randint(0, 256, shape).astype(dtype)
+    return np.random.uniform(-1, 1, shape).astype(dtype)
+
+
+def verify_group_gemm(
+    func_name, M, N, K, num_groups, x_dtype, weight_dtype, out_dtype, use_scale, rtol, atol
+):
+    group_gemm_func = tvm.get_global_func(func_name, allow_missing=True)
+    if group_gemm_func is None:
+        print(f"Skipped as {func_name} is not available")
+        return
+
+    @memoize("tvm.contrib.cutlass.test_group_gemm_sm90")
+    def get_ref_data():
+        assert M % num_groups == 0
+        M_per_group = M // num_groups
+        a_np = get_random_ndarray((M, K), "float16")
+        b_np = get_random_ndarray((num_groups, N, K), "float16")
+        indptr_np = np.arange(1, num_groups + 1).astype("int64") * M_per_group
+        c_np = np.concatenate(
+            [a_np[i * M_per_group : (i + 1) * M_per_group] @ b_np[i].T for i in range(num_groups)],
+            axis=0,
+        )
+        return a_np, b_np, indptr_np, c_np
+
+    def to_numpy_dtype(dtype):
+        mapping = {"float8_e5m2": ml_dtypes.float8_e5m2, "float8_e4m3fn": ml_dtypes.float8_e4m3fn}
+        return mapping.get(dtype, dtype)
+
+    a_np, b_np, indptr_np, c_np = get_ref_data()
+    dev = tvm.cuda(0)
+    a_nd = tvm.nd.array(a_np.astype(to_numpy_dtype(x_dtype)), device=dev)
+    b_nd = tvm.nd.array(b_np.astype(to_numpy_dtype(weight_dtype)), device=dev)
+    c_nd = tvm.nd.empty(c_np.shape, dtype=out_dtype, device=dev)
+    indptr_nd = tvm.nd.array(indptr_np, device=dev)
+    workspace = tvm.nd.empty((4096 * 1024,), dtype="uint8", device=dev)
+    if use_scale:
+        scale = tvm.nd.array(np.array([1.0], dtype="float32"), device=dev)
+        group_gemm_func(a_nd, b_nd, indptr_nd, workspace, scale, c_nd)
+    else:
+        group_gemm_func(a_nd, b_nd, indptr_nd, workspace, c_nd)
+    tvm.testing.assert_allclose(c_nd.numpy(), c_np, rtol=rtol, atol=atol)
+
+
[email protected]_cutlass
[email protected]_cuda_compute_version(9)
+def test_group_gemm_sm90():
+    verify_group_gemm(
+        "cutlass.group_gemm_fp16_sm90",
+        8,
+        128,
+        128,
+        4,
+        "float16",
+        "float16",
+        "float16",
+        False,
+        rtol=1e-3,
+        atol=1e-3,
+    )
+    verify_group_gemm(
+        "cutlass.group_gemm_e5m2_e5m2_fp16",
+        8,
+        16,
+        16,
+        4,
+        "float8_e5m2",
+        "float8_e5m2",
+        "float16",
+        True,
+        rtol=1e-1,
+        atol=1,
+    )
+    verify_group_gemm(
+        "cutlass.group_gemm_e4m3_e4m3_fp16",
+        8,
+        16,
+        16,
+        4,
+        "float8_e4m3fn",
+        "float8_e4m3fn",
+        "float16",
+        True,
+        rtol=1e-1,
+        atol=1,
+    )
+
+
+def rowwise_quant_fp8_e4m3(shape: Tuple[int, int], block_size: Tuple[int, int], dtype: str):
+    x_full_np = (np.random.rand(*shape) * 2 - 1).astype(dtype)
+    x_scale_shape = (
+        *shape[:-1],
+        (shape[-1] + block_size[1] - 1) // block_size[1],
+    )
+    # For each (block_size[1]) block, compute the max abs value of `w_full_np`
+    x_max_abs_np = np.zeros(x_scale_shape, dtype="float32")
+    for i in range(x_scale_shape[-1]):
+        x_max_abs_np[..., i] = np.max(
+            np.abs(x_full_np[..., i * block_size[1] : min((i + 1) * block_size[1], shape[-1])]),
+            axis=-1,
+        )[0]
+    # Scale is the `x_max_abs_np` divided by the max value of quant_dtype in ml_dtypes
+    fp8_max = float(ml_dtypes.finfo("float8_e4m3fn").max)
+    x_scale_np = x_max_abs_np / fp8_max
+    # `x_np` is the `x_full_np` divided by the `x_scale_np` (with block awareness),
+    # clamped to (-fp8_max, fp8_max), and cast to `quant_dtype`
+    x_np = np.zeros_like(x_full_np, dtype="float8_e4m3fn")
+    for i in range(x_scale_shape[-1]):
+        x_np[..., i * block_size[1] : min((i + 1) * block_size[1], shape[-1])] = np.clip(
+            x_full_np[..., i * block_size[1] : min((i + 1) * block_size[1], shape[-1])]
+            / x_scale_np[..., i : i + 1],
+            -fp8_max,
+            fp8_max,
+        )
+
+    x_scale_np = np.random.rand(*x_scale_np.shape).astype("float32") / fp8_max
+    for i in range(x_scale_shape[-1]):
+        x_full_np[..., i * block_size[1] : min((i + 1) * block_size[1], shape[-1])] = (
+            x_np[..., i * block_size[1] : min((i + 1) * block_size[1], shape[-1])].astype(
+                x_scale_np.dtype
+            )
+            * x_scale_np[..., i : i + 1]
+        )
+    return x_np, x_scale_np
+
+
+def blockwise_quant_fp8_e4m3(shape: Tuple[int, int], block_size: Tuple[int, int], dtype: str):
+    w_full_np = (np.random.rand(*shape) * 2 - 1).astype(dtype)
+    w_scale_shape = (
+        *shape[:-2],
+        (shape[-2] + block_size[0] - 1) // block_size[0],
+        (shape[-1] + block_size[1] - 1) // block_size[1],
+    )
+    # For each (block_size[0], block_size[1]) block, compute the max abs value of `w_full_np`
+    w_max_abs_np = np.zeros(w_scale_shape, dtype="float32")
+    for i in range(w_scale_shape[-2]):
+        for j in range(w_scale_shape[-1]):
+            block_shape = (
+                *shape[:-2],
+                min(block_size[0], shape[-2] - i * block_size[0]),
+                min(block_size[1], shape[-1] - j * block_size[1]),
+            )
+            w_max_abs_np[..., i, j] = np.max(
+                np.abs(
+                    w_full_np[
+                        ...,
+                        i * block_size[0] : min((i + 1) * block_size[0], shape[-2]),
+                        j * block_size[1] : min((j + 1) * block_size[1], shape[-1]),
+                    ]
+                ).reshape(*shape[:-2], block_shape[-2] * block_shape[-1]),
+                axis=-1,
+            )
+    # Scale is the `w_max_abs_np` divided by the max value of quant_dtype in ml_dtypes
+    fp8_max = float(ml_dtypes.finfo("float8_e4m3fn").max)
+    w_scale_np = w_max_abs_np / fp8_max
+    # `w_np` is the `w_full_np` divided by the `w_scale_np` (with block awareness),
+    # clamped to (-fp8_max, fp8_max), and cast to `quant_dtype`
+    w_np = np.zeros_like(w_full_np, dtype="float8_e4m3fn")
+    if len(w_scale_shape) == 2:
+        for i in range(w_scale_shape[-2]):
+            for j in range(w_scale_shape[-1]):
+                w_np[
+                    i * block_size[0] : min((i + 1) * block_size[0], shape[-2]),
+                    j * block_size[1] : min((j + 1) * block_size[1], shape[-1]),
+                ] = np.clip(
+                    w_full_np[
+                        i * block_size[0] : min((i + 1) * block_size[0], shape[-2]),
+                        j * block_size[1] : min((j + 1) * block_size[1], shape[-1]),
+                    ]
+                    / w_scale_np[..., i, j],
+                    -fp8_max,
+                    fp8_max,
+                )
+    else:
+        for e in range(w_scale_shape[0]):
+            for i in range(w_scale_shape[-2]):
+                for j in range(w_scale_shape[-1]):
+                    w_np[
+                        e,
+                        i * block_size[0] : min((i + 1) * block_size[0], shape[-2]),
+                        j * block_size[1] : min((j + 1) * block_size[1], shape[-1]),
+                    ] = np.clip(
+                        w_full_np[
+                            e,
+                            i * block_size[0] : min((i + 1) * block_size[0], shape[-2]),
+                            j * block_size[1] : min((j + 1) * block_size[1], shape[-1]),
+                        ]
+                        / w_scale_np[e, i, j],
+                        -fp8_max,
+                        fp8_max,
+                    )
+
+    w_scale_np = np.random.rand(*w_scale_np.shape).astype("float32") / fp8_max
+    return w_np, w_scale_np
+
+
+def blockwise_matmul(
+    x_fp8_np: np.ndarray,
+    x_scale_np: np.ndarray,
+    w_np: np.ndarray,
+    w_scale_np: np.ndarray,
+    block_size: Tuple[int, int],
+    dtype: str,
+):
+    o_np = np.zeros((x_fp8_np.shape[0], w_np.shape[0]), dtype=dtype)
+    for j in range(w_scale_np.shape[0]):
+        for k in range(w_scale_np.shape[1]):
+            o_np[:, j * block_size[0] : min((j + 1) * block_size[0], w_np.shape[0])] += (
+                np.matmul(
+                    x_fp8_np[
+                        :, k * block_size[1] : min((k + 1) * block_size[1], x_fp8_np.shape[1])
+                    ].astype(dtype),
+                    w_np[
+                        j * block_size[0] : min((j + 1) * block_size[0], w_np.shape[0]),
+                        k * block_size[1] : min((k + 1) * block_size[1], w_np.shape[1]),
+                    ].T.astype(dtype),
+                )
+                * x_scale_np[:, k : k + 1]
+                * w_scale_np[j, k]
+            )
+    return o_np
+
+
+def blockwise_bmm(
+    x_fp8_np: np.ndarray,
+    x_scale_np: np.ndarray,
+    w_np: np.ndarray,
+    w_scale_np: np.ndarray,
+    block_size: Tuple[int, int],
+    dtype: str,
+):
+    o_np = np.zeros((x_fp8_np.shape[0], x_fp8_np.shape[1], w_np.shape[1]), dtype=dtype)
+    for j in range(w_scale_np.shape[1]):
+        for k in range(w_scale_np.shape[2]):
+            o_np[..., j * block_size[0] : min((j + 1) * block_size[0], w_np.shape[1])] += (
+                np.matmul(
+                    x_fp8_np[
+                        ..., k * block_size[1] : min((k + 1) * block_size[1], x_fp8_np.shape[2])
+                    ].astype(dtype),
+                    w_np[
+                        ...,
+                        j * block_size[0] : min((j + 1) * block_size[0], w_np.shape[1]),
+                        k * block_size[1] : min((k + 1) * block_size[1], w_np.shape[2]),
+                    ]
+                    .transpose(0, 2, 1)
+                    .astype(dtype),
+                )
+                * x_scale_np[..., k : k + 1]
+                * w_scale_np[..., j : j + 1, k : k + 1]
+            )
+    return o_np
+
+
[email protected]_cutlass
[email protected]_cuda_compute_version(9)
+def test_fp8_e4m3_blockwise_scaled_gemm():
+    M = 16
+    N = 4608
+    K = 896
+    block_size = (128, 128)
+    assert N % 128 == 0 and K % 128 == 0  # Only N and K that are multiples of 128 are supported
+
+    func_name = "cutlass.blockwise_scaled_gemm_e4m3fn_e4m3fn"
+    gemm_func = tvm.get_global_func(func_name, allow_missing=True)
+    if gemm_func is None:
+        print(f"Skipped as {func_name} is not available")
+        return
+
+    device = tvm.cuda(0)
+    dtype = "bfloat16"
+    x_np, x_scale_np = rowwise_quant_fp8_e4m3((M, K), block_size, dtype)
+    w_np, w_scale_np = blockwise_quant_fp8_e4m3((N, K), block_size, dtype)
+    o_np = blockwise_matmul(x_np, x_scale_np, w_np, w_scale_np, block_size, dtype)
+    x_tvm = tvm.nd.array(x_np, device=device)
+    x_scale_tvm = tvm.nd.array(x_scale_np.T, device=device)
+    w_tvm = tvm.nd.array(w_np, device=device)
+    w_scale_tvm = tvm.nd.array(w_scale_np, device=device)
+    workspace = tvm.nd.empty((4096 * 1024,), dtype="uint8", device=device)
+    o_tvm = tvm.nd.empty((M, N), dtype=dtype, device=device)
+    gemm_func(
+        x_tvm, w_tvm, x_scale_tvm, w_scale_tvm, workspace, block_size[0], block_size[1], o_tvm
+    )
+    o_tvm = o_tvm.numpy()
+    tvm.testing.assert_allclose(o_tvm, o_np, rtol=1e-4, atol=0.5)
+
+
[email protected]_cutlass
[email protected]_cuda_compute_version(9)
+def test_fp8_e4m3_blockwise_scaled_bmm():
+    B = 16
+    M = 40
+    N = 512
+    K = 128
+    block_size = (128, 128)
+    assert N % 128 == 0 and K % 128 == 0  # Only N and K that are multiples of 128 are supported
+
+    func_name = "cutlass.blockwise_scaled_bmm_e4m3fn_e4m3fn"
+    gemm_func = tvm.get_global_func(func_name, allow_missing=True)
+    if gemm_func is None:
+        print(f"Skipped as {func_name} is not available")
+        return
+
+    device = tvm.cuda(0)
+    dtype = "bfloat16"
+    x_np, x_scale_np = rowwise_quant_fp8_e4m3((B, M, K), block_size, dtype)
+    w_np, w_scale_np = blockwise_quant_fp8_e4m3((B, N, K), block_size, dtype)
+    o_np = blockwise_bmm(x_np, x_scale_np, w_np, w_scale_np, block_size, dtype)
+    x_tvm = tvm.nd.array(x_np, device=device)
+    x_scale_tvm = tvm.nd.array(x_scale_np.transpose(0, 2, 1), device=device)
+    w_tvm = tvm.nd.array(w_np, device=device)
+    w_scale_tvm = tvm.nd.array(w_scale_np, device=device)
+    workspace = tvm.nd.empty((4096 * 1024,), dtype="uint8", device=device)
+    o_tvm = tvm.nd.empty((B, M, N), dtype=dtype, device=device)
+    gemm_func(
+        x_tvm, w_tvm, x_scale_tvm, w_scale_tvm, workspace, block_size[0], block_size[1], o_tvm
+    )
+    o_tvm = o_tvm.numpy()
+    tvm.testing.assert_allclose(o_tvm, o_np, rtol=1e-4, atol=0.5)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

