This is an automated email from the ASF dual-hosted git repository.
yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b0ccfb39c5 [CUTLASS] Add blockwise scale gemm/bmm kernels (#17789)
b0ccfb39c5 is described below
commit b0ccfb39c5d595de07f0a69dceb0cf51f0cb86a9
Author: Ruihang Lai <[email protected]>
AuthorDate: Tue Apr 1 01:00:49 2025 +0200
[CUTLASS] Add blockwise scale gemm/bmm kernels (#17789)
This PR introduces blockwise scale matmul and batch matmul CUTLASS
kernels, adapted from SGLang (https://github.com/sgl-project/sglang),
vLLM (https://github.com/vllm-project/vllm) and
https://github.com/soundOfDestiny/cutlass.
We add unit tests for gemm and bmm. This PR also restores some
cutlass gemm tests that were previously removed during the Relay phase-out.
---
3rdparty/cutlass | 2 +-
3rdparty/cutlass_fpA_intB_gemm | 2 +-
cmake/modules/contrib/CUTLASS.cmake | 6 +-
.../cutlass/blockwise_scaled_gemm_runner.cuh | 228 +++++++++++++
.../contrib/cutlass/fp8_blockwise_scaled_gemm.cu | 164 ++++++++++
src/runtime/contrib/cutlass/group_gemm_runner.cuh | 14 +-
tests/python/contrib/test_cutlass_gemm.py | 352 +++++++++++++++++++++
7 files changed, 758 insertions(+), 10 deletions(-)
diff --git a/3rdparty/cutlass b/3rdparty/cutlass
index bbe579a9e3..afa1772203 160000
--- a/3rdparty/cutlass
+++ b/3rdparty/cutlass
@@ -1 +1 @@
-Subproject commit bbe579a9e3beb6ea6626d9227ec32d0dae119a49
+Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008
diff --git a/3rdparty/cutlass_fpA_intB_gemm b/3rdparty/cutlass_fpA_intB_gemm
index f09824e950..fdef230791 160000
--- a/3rdparty/cutlass_fpA_intB_gemm
+++ b/3rdparty/cutlass_fpA_intB_gemm
@@ -1 +1 @@
-Subproject commit f09824e950ed6678670004bd23578757b3473f21
+Subproject commit fdef2307917ec2c7cc5becc29fb95d77498484bd
diff --git a/cmake/modules/contrib/CUTLASS.cmake
b/cmake/modules/contrib/CUTLASS.cmake
index b302622cbc..b9097a02e9 100644
--- a/cmake/modules/contrib/CUTLASS.cmake
+++ b/cmake/modules/contrib/CUTLASS.cmake
@@ -58,11 +58,15 @@ if(USE_CUDA AND USE_CUTLASS)
list(APPEND TVM_CUTLASS_RUNTIME_SRCS
src/runtime/contrib/cutlass/fp16_group_gemm.cu)
list(APPEND TVM_CUTLASS_RUNTIME_SRCS
src/runtime/contrib/cutlass/fp8_group_gemm.cu)
list(APPEND TVM_CUTLASS_RUNTIME_SRCS
src/runtime/contrib/cutlass/fp8_gemm.cu)
+ list(APPEND TVM_CUTLASS_RUNTIME_SRCS
src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu)
endif()
if(TVM_CUTLASS_RUNTIME_SRCS)
add_library(tvm_cutlass_objs OBJECT ${TVM_CUTLASS_RUNTIME_SRCS})
target_compile_options(tvm_cutlass_objs PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>)
- target_include_directories(tvm_cutlass_objs PRIVATE ${CUTLASS_DIR}/include)
+ target_include_directories(tvm_cutlass_objs PRIVATE
+ ${CUTLASS_DIR}/include
+
${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm/cutlass_extensions/include
+ )
target_compile_definitions(tvm_cutlass_objs PRIVATE
DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
list(APPEND CUTLASS_RUNTIME_OBJS
"$<${CUTLASS_GEN_COND}:$<TARGET_OBJECTS:tvm_cutlass_objs>>")
endif()
diff --git a/src/runtime/contrib/cutlass/blockwise_scaled_gemm_runner.cuh
b/src/runtime/contrib/cutlass/blockwise_scaled_gemm_runner.cuh
new file mode 100644
index 0000000000..f520bf815a
--- /dev/null
+++ b/src/runtime/contrib/cutlass/blockwise_scaled_gemm_runner.cuh
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <fstream>
+#include <iostream>
+#include <sstream>
+#include <type_traits>
+#include <variant>
+#include <vector>
+
+#include "../../cuda/cuda_common.h"
+
+// clang-format off
+#include "cutlass/cutlass.h"
+
+#include "cute/tensor.hpp"
+#include "cutlass/float8.h"
+#include "cutlass/tensor_ref.h"
+#include "cutlass/epilogue/collective/default_epilogue.hpp"
+#include "cutlass/epilogue/thread/linear_combination.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+
+#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
+#include "cutlass_extensions/gemm/dispatch_policy.hpp"
+// clang-format on
+
+#define CUTLASS_CHECK(status) \
+ { \
+ cutlass::Status error = status; \
+ CHECK(error == cutlass::Status::kSuccess) \
+ << "Got cutlass error: " << cutlassGetStatusString(error); \
+ }
+
+using namespace cute;
+using ProblemShape = Shape<int, int, int, int>;
+using tvm::runtime::NDArray;
+
+template <typename TileShape, typename ClusterShape, typename ElementD,
typename SchedulerType,
+ int ScaleGranularityM = 1>
+struct CutlassFP8ScaledBlockwiseGemmRunner {
+ using ElementAccumulator = float;
+ using ElementCompute = float;
+ using ElementBlockScale = float;
+
+ using ElementA = cutlass::float_e4m3_t;
+ using LayoutA = cutlass::layout::RowMajor;
+ static constexpr int AlignmentA = 128 /
cutlass::sizeof_bits<ElementA>::value;
+
+ using ElementB = cutlass::float_e4m3_t;
+ using LayoutB = cutlass::layout::ColumnMajor;
+ static constexpr int AlignmentB = 128 /
cutlass::sizeof_bits<ElementB>::value;
+
+ using ElementC = void;
+ using LayoutC = cutlass::layout::RowMajor;
+ static constexpr int AlignmentC = 128 /
cutlass::sizeof_bits<ElementD>::value;
+
+ using LayoutD = cutlass::layout::RowMajor;
+ static constexpr int AlignmentD = 128 /
cutlass::sizeof_bits<ElementD>::value;
+
+ using ArchTag = cutlass::arch::Sm90;
+ using OperatorClass = cutlass::arch::OpClassTensorOp;
+ using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
+ using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
+ using StoreEpilogueCompute =
+ typename
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;
+
+ using KernelSchedule =
+
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<
+ ScaleGranularityM>;
+ using CollectiveEpilogue = typename
cutlass::epilogue::collective::CollectiveBuilder<
+ ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
ElementAccumulator,
+ ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD,
AlignmentD,
+ EpilogueSchedule, StoreEpilogueCompute>::CollectiveOp;
+
+ using CollectiveMainloop = typename
cutlass::gemm::collective::CollectiveBuilder<
+ ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB,
LayoutB, AlignmentB,
+ ElementAccumulator, TileShape, ClusterShape,
+ cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
+ sizeof(typename CollectiveEpilogue::SharedStorage))>,
+ KernelSchedule>::CollectiveOp;
+
+ using GemmKernel =
+ cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, //
Indicates ProblemShape
+ CollectiveMainloop,
CollectiveEpilogue, SchedulerType>;
+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+ using StrideA = typename Gemm::GemmKernel::StrideA;
+ using StrideB = typename Gemm::GemmKernel::StrideB;
+ using StrideD = typename Gemm::GemmKernel::StrideD;
+
+ void run_gemm(const ElementA* a_ptr, const ElementB* b_ptr, const
ElementBlockScale* scales_a_ptr,
+ const ElementBlockScale* scales_b_ptr, ElementD* o_ptr,
ProblemShape* problem_size,
+ StrideA* stride_a, StrideB* stride_b, StrideD* stride_d,
uint8_t* workspace,
+ int64_t workspace_size, cudaStream_t stream) {
+ cutlass::KernelHardwareInfo hw_info;
+ hw_info.device_id = 0;
+ hw_info.sm_count =
+
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
+
+ typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
+ static constexpr bool UsesStreamKScheduler =
+ cute::is_same_v<typename Gemm::GemmKernel::TileSchedulerTag,
+ cutlass::gemm::StreamKScheduler>;
+ if constexpr (UsesStreamKScheduler) {
+ using DecompositionMode = typename cutlass::gemm::kernel::detail::
+ PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
+ using ReductionMode = typename cutlass::gemm::kernel::detail::
+ PersistentTileSchedulerSm90StreamKParams::ReductionMode;
+ scheduler.decomposition_mode = DecompositionMode::StreamK;
+ scheduler.reduction_mode = ReductionMode::Nondeterministic;
+ }
+
+ typename Gemm::Arguments arguments = {
+ cutlass::gemm::GemmUniversalMode::kGemm,
+ *problem_size,
+ {a_ptr, *stride_a, b_ptr, *stride_b, scales_a_ptr, scales_b_ptr},
+ {{}, nullptr, *stride_d, o_ptr, *stride_d},
+ hw_info,
+ scheduler};
+
+ Gemm gemm_op;
+ CUTLASS_CHECK(gemm_op.can_implement(arguments));
+ CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
+ CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
+ CUTLASS_CHECK(gemm_op.run(stream));
+ }
+};
+
+template <typename TileShape, typename ClusterShape, typename ElementA,
typename ElementB,
+ typename ElementD, typename ElementBlockScale>
+void cutlass_fp8_blockwise_scaled_gemm(ElementA* a, ElementB* b,
ElementBlockScale* scales_a,
+ ElementBlockScale* scales_b, ElementD*
out,
+ uint8_t* workspace, int64_t
workspace_size, int64_t m,
+ int64_t n, int64_t k, cudaStream_t
stream) {
+ if (k > 3 * n) {
+ using SchedulerType = cutlass::gemm::StreamKScheduler;
+ using Runner =
+ CutlassFP8ScaledBlockwiseGemmRunner<TileShape, ClusterShape, ElementD,
SchedulerType>;
+ using StrideA = typename Runner::StrideA;
+ using StrideB = typename Runner::StrideB;
+ using StrideD = typename Runner::StrideD;
+
+ Runner runner;
+ StrideA stride_a = cute::make_stride(k, Int<1>{}, m * k);
+ StrideB stride_b = cute::make_stride(k, Int<1>{}, n * k);
+ StrideD stride_d = cute::make_stride(n, Int<1>{}, m * n);
+ ProblemShape problem_size{static_cast<int>(m), static_cast<int>(n),
static_cast<int>(k), 1};
+ runner.run_gemm(a, b, scales_a, scales_b, out, &problem_size, &stride_a,
&stride_b, &stride_d,
+ workspace, workspace_size, stream);
+ } else {
+ using SchedulerType = cutlass::gemm::PersistentScheduler;
+ using Runner =
+ CutlassFP8ScaledBlockwiseGemmRunner<TileShape, ClusterShape, ElementD,
SchedulerType>;
+ using StrideA = typename Runner::StrideA;
+ using StrideB = typename Runner::StrideB;
+ using StrideD = typename Runner::StrideD;
+
+ Runner runner;
+ StrideA stride_a = cute::make_stride(k, Int<1>{}, m * k);
+ StrideB stride_b = cute::make_stride(k, Int<1>{}, n * k);
+ StrideD stride_d = cute::make_stride(n, Int<1>{}, m * n);
+ ProblemShape problem_size{static_cast<int>(m), static_cast<int>(n),
static_cast<int>(k), 1};
+ runner.run_gemm(a, b, scales_a, scales_b, out, &problem_size, &stride_a,
&stride_b, &stride_d,
+ workspace, workspace_size, stream);
+ }
+}
+
+template <typename TileShape, typename ClusterShape, typename ElementA,
typename ElementB,
+ typename ElementD, typename ElementBlockScale>
+void cutlass_fp8_blockwise_scaled_bmm(ElementA* a, ElementB* b,
ElementBlockScale* scales_a,
+ ElementBlockScale* scales_b, ElementD*
out,
+ uint8_t* workspace, int64_t
workspace_size, int64_t m,
+ int64_t n, int64_t k, int64_t l,
cudaStream_t stream) {
+ if (k > 3 * n) {
+ using SchedulerType = cutlass::gemm::StreamKScheduler;
+ using Runner =
+ CutlassFP8ScaledBlockwiseGemmRunner<TileShape, ClusterShape, ElementD,
SchedulerType>;
+ using StrideA = typename Runner::StrideA;
+ using StrideB = typename Runner::StrideB;
+ using StrideD = typename Runner::StrideD;
+
+ Runner runner;
+ StrideA stride_a = cute::make_stride(k, Int<1>{}, m * k);
+ StrideB stride_b = cute::make_stride(k, Int<1>{}, n * k);
+ StrideD stride_d = cute::make_stride(n, Int<1>{}, m * n);
+ ProblemShape problem_size{static_cast<int>(m), static_cast<int>(n),
static_cast<int>(k),
+ static_cast<int>(l)};
+ runner.run_gemm(a, b, scales_a, scales_b, out, &problem_size, &stride_a,
&stride_b, &stride_d,
+ workspace, workspace_size, stream);
+ } else {
+ using SchedulerType = cutlass::gemm::PersistentScheduler;
+ using Runner =
+ CutlassFP8ScaledBlockwiseGemmRunner<TileShape, ClusterShape, ElementD,
SchedulerType>;
+ using StrideA = typename Runner::StrideA;
+ using StrideB = typename Runner::StrideB;
+ using StrideD = typename Runner::StrideD;
+
+ Runner runner;
+ StrideA stride_a = cute::make_stride(k, Int<1>{}, m * k);
+ StrideB stride_b = cute::make_stride(k, Int<1>{}, n * k);
+ StrideD stride_d = cute::make_stride(n, Int<1>{}, m * n);
+ ProblemShape problem_size{static_cast<int>(m), static_cast<int>(n),
static_cast<int>(k),
+ static_cast<int>(l)};
+ runner.run_gemm(a, b, scales_a, scales_b, out, &problem_size, &stride_a,
&stride_b, &stride_d,
+ workspace, workspace_size, stream);
+ }
+}
diff --git a/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu
b/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu
new file mode 100644
index 0000000000..4ac5a621a0
--- /dev/null
+++ b/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <cuda_fp16.h>
+#include <float.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include "../cublas/cublas_utils.h"
+#include "blockwise_scaled_gemm_runner.cuh"
+
+#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
+
+namespace tvm {
+namespace runtime {
+
+void tvm_cutlass_fp8_blockwise_scaled_gemm(NDArray a, NDArray b, NDArray
scales_a, NDArray scales_b,
+ NDArray workspace, int64_t
block_size_0,
+ int64_t block_size_1, NDArray out) {
+ using TileShape = Shape<_128, _128, _128>;
+ using ClusterShape = Shape<_1, _1, _1>;
+
+ // Workspace is used for storing device-side gemm arguments and cutlass
internal workspace.
+ // Recommended size is 4MB.
+ auto get_stream_func =
tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+ ICHECK(get_stream_func != nullptr);
+ cudaStream_t stream =
static_cast<cudaStream_t>((*get_stream_func)().operator void*());
+
+ CHECK_GE(a->ndim, 2);
+ CHECK_EQ(scales_a->ndim, a->ndim);
+ CHECK_EQ(b->ndim, 2);
+ CHECK_EQ(scales_b->ndim, 2);
+ CHECK_EQ(workspace->ndim, 1);
+ CHECK_EQ(out->ndim, a->ndim);
+ int64_t m = 1;
+ for (int64_t i = 0; i < a->ndim - 1; ++i) {
+ m *= a->shape[i];
+ }
+ int64_t n = b->shape[0];
+ CHECK_EQ(a->shape[a->ndim - 1], b->shape[1]) << "Only col-major B is
supported now.";
+ int64_t k = a->shape[a->ndim - 1];
+
+ // scales_a is col-major of (*a_shape[:-1], k / block_size)
+ CHECK_EQ(scales_a->shape[0] * block_size_1, k);
+ for (int64_t i = 1; i < scales_a->ndim; ++i) {
+ CHECK_EQ(scales_a->shape[i], a->shape[i - 1]);
+ }
+ // scales_b is col-major of (k / block_size, n / block_size)
+ CHECK_EQ(scales_b->shape[0] * block_size_0, n);
+ CHECK_EQ(scales_b->shape[1] * block_size_1, k);
+
+ using tvm::runtime::DataType;
+ CHECK_EQ(DataType(a->dtype), DataType::NVFloat8E4M3());
+ CHECK_EQ(DataType(b->dtype), DataType::NVFloat8E4M3());
+ CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32));
+ CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32));
+ CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8));
+
+ if (DataType(out->dtype) == DataType::Float(16)) {
+ cutlass_fp8_blockwise_scaled_gemm<TileShape, ClusterShape,
cutlass::float_e4m3_t,
+ cutlass::float_e4m3_t, cutlass::half_t,
float>(
+ static_cast<cutlass::float_e4m3_t*>(a->data),
static_cast<cutlass::float_e4m3_t*>(b->data),
+ static_cast<float*>(scales_a->data),
static_cast<float*>(scales_b->data),
+ static_cast<cutlass::half_t*>(out->data),
static_cast<uint8_t*>(workspace->data),
+ workspace->shape[0] * DataType(workspace->dtype).bytes(), m, n, k,
stream);
+ } else if (DataType(out->dtype) == DataType::BFloat(16)) {
+ cutlass_fp8_blockwise_scaled_gemm<TileShape, ClusterShape,
cutlass::float_e4m3_t,
+ cutlass::float_e4m3_t,
cutlass::bfloat16_t, float>(
+ static_cast<cutlass::float_e4m3_t*>(a->data),
static_cast<cutlass::float_e4m3_t*>(b->data),
+ static_cast<float*>(scales_a->data),
static_cast<float*>(scales_b->data),
+ static_cast<cutlass::bfloat16_t*>(out->data),
static_cast<uint8_t*>(workspace->data),
+ workspace->shape[0] * DataType(workspace->dtype).bytes(), m, n, k,
stream);
+ } else {
+ LOG(FATAL) << "Unsupported output dtype: " << DataType(out->dtype);
+ }
+}
+
+void tvm_cutlass_fp8_blockwise_scaled_bmm(NDArray a, NDArray b, NDArray
scales_a, NDArray scales_b,
+ NDArray workspace, int64_t
block_size_0,
+ int64_t block_size_1, NDArray out) {
+ using TileShape = Shape<_128, _128, _128>;
+ using ClusterShape = Shape<_1, _1, _1>;
+
+ // Workspace is used for storing device-side gemm arguments and cutlass
internal workspace.
+ // Recommended size is 4MB.
+ auto get_stream_func =
tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+ ICHECK(get_stream_func != nullptr);
+ cudaStream_t stream =
static_cast<cudaStream_t>((*get_stream_func)().operator void*());
+
+ CHECK_EQ(a->ndim, 3);
+ CHECK_EQ(scales_a->ndim, 3);
+ CHECK_EQ(b->ndim, 3);
+ CHECK_EQ(scales_b->ndim, 3);
+ CHECK_EQ(workspace->ndim, 1);
+ CHECK_EQ(out->ndim, 3);
+ int64_t batch_size = a->shape[0];
+ int64_t m = a->shape[1];
+ int64_t n = b->shape[1];
+ CHECK_EQ(a->shape[2], b->shape[2]) << "Only col-major B is supported now.";
+ int64_t k = a->shape[2];
+ CHECK_EQ(b->shape[0], batch_size);
+ CHECK_EQ(scales_a->shape[0], batch_size);
+ CHECK_EQ(scales_b->shape[0], batch_size);
+ CHECK_EQ(out->shape[0], batch_size);
+
+ // scales_a is col-major of (batch_size, m, k / block_size)
+ CHECK_EQ(scales_a->shape[1] * block_size_1, k);
+ CHECK_EQ(scales_a->shape[2], m);
+ // scales_b is col-major of (k / block_size, n / block_size)
+ CHECK_EQ(scales_b->shape[1] * block_size_0, n);
+ CHECK_EQ(scales_b->shape[2] * block_size_1, k);
+
+ using tvm::runtime::DataType;
+ CHECK_EQ(DataType(a->dtype), DataType::NVFloat8E4M3());
+ CHECK_EQ(DataType(b->dtype), DataType::NVFloat8E4M3());
+ CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32));
+ CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32));
+ CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8));
+
+ if (DataType(out->dtype) == DataType::Float(16)) {
+ cutlass_fp8_blockwise_scaled_bmm<TileShape, ClusterShape,
cutlass::float_e4m3_t,
+ cutlass::float_e4m3_t, cutlass::half_t,
float>(
+ static_cast<cutlass::float_e4m3_t*>(a->data),
static_cast<cutlass::float_e4m3_t*>(b->data),
+ static_cast<float*>(scales_a->data),
static_cast<float*>(scales_b->data),
+ static_cast<cutlass::half_t*>(out->data),
static_cast<uint8_t*>(workspace->data),
+ workspace->shape[0] * DataType(workspace->dtype).bytes(), m, n, k,
batch_size, stream);
+ } else if (DataType(out->dtype) == DataType::BFloat(16)) {
+ cutlass_fp8_blockwise_scaled_bmm<TileShape, ClusterShape,
cutlass::float_e4m3_t,
+ cutlass::float_e4m3_t,
cutlass::bfloat16_t, float>(
+ static_cast<cutlass::float_e4m3_t*>(a->data),
static_cast<cutlass::float_e4m3_t*>(b->data),
+ static_cast<float*>(scales_a->data),
static_cast<float*>(scales_b->data),
+ static_cast<cutlass::bfloat16_t*>(out->data),
static_cast<uint8_t*>(workspace->data),
+ workspace->shape[0] * DataType(workspace->dtype).bytes(), m, n, k,
batch_size, stream);
+ } else {
+ LOG(FATAL) << "Unsupported output dtype: " << DataType(out->dtype);
+ }
+}
+
+TVM_REGISTER_GLOBAL("cutlass.blockwise_scaled_gemm_e4m3fn_e4m3fn")
+ .set_body_typed(tvm_cutlass_fp8_blockwise_scaled_gemm);
+TVM_REGISTER_GLOBAL("cutlass.blockwise_scaled_bmm_e4m3fn_e4m3fn")
+ .set_body_typed(tvm_cutlass_fp8_blockwise_scaled_bmm);
+
+} // namespace runtime
+} // namespace tvm
+
+#endif // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED
diff --git a/src/runtime/contrib/cutlass/group_gemm_runner.cuh
b/src/runtime/contrib/cutlass/group_gemm_runner.cuh
index 71979672b9..a3c52e27a9 100644
--- a/src/runtime/contrib/cutlass/group_gemm_runner.cuh
+++ b/src/runtime/contrib/cutlass/group_gemm_runner.cuh
@@ -105,10 +105,10 @@ struct CutlassGroupGemmRunner {
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
- using StrideA = typename Gemm::GemmKernel::UnderlyingStrideA;
- using StrideB = typename Gemm::GemmKernel::UnderlyingStrideB;
- using StrideC = typename Gemm::GemmKernel::UnderlyingStrideC;
- using StrideD = typename Gemm::GemmKernel::UnderlyingStrideD;
+ using StrideA = typename Gemm::GemmKernel::InternalStrideA;
+ using StrideB = typename Gemm::GemmKernel::InternalStrideB;
+ using StrideC = typename Gemm::GemmKernel::InternalStrideC;
+ using StrideD = typename Gemm::GemmKernel::InternalStrideD;
void run_group_gemm(const ElementA** ptr_A, const ElementB** ptr_B, const
ElementC** ptr_C,
ElementC** ptr_D,
@@ -163,9 +163,9 @@ __global__ void prepare_group_gemm_arguments(
ptr_D[group_id] = out + prev_rows * n;
problem_sizes[group_id] = {static_cast<int>(indptr[group_id] - prev_rows),
static_cast<int>(n),
static_cast<int>(k)};
- stride_A[group_id] = cute::make_stride(k, Int<1>{}, int64_t{0});
- stride_B[group_id] = cute::make_stride(k, Int<1>{}, int64_t{0});
- stride_D[group_id] = cute::make_stride(n, Int<1>{}, int64_t{0});
+ stride_A[group_id] = cute::make_stride(k, Int<1>{}, Int<0>{});
+ stride_B[group_id] = cute::make_stride(k, Int<1>{}, Int<0>{});
+ stride_D[group_id] = cute::make_stride(n, Int<1>{}, Int<0>{});
}
template <typename ElementA, typename ElementB, typename ElementC>
diff --git a/tests/python/contrib/test_cutlass_gemm.py
b/tests/python/contrib/test_cutlass_gemm.py
new file mode 100644
index 0000000000..7c259e6f7d
--- /dev/null
+++ b/tests/python/contrib/test_cutlass_gemm.py
@@ -0,0 +1,352 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import Tuple
+
+import ml_dtypes
+import numpy as np
+
+import tvm
+import tvm.testing
+from tvm.contrib.pickle_memoize import memoize
+
+
+def get_random_ndarray(shape, dtype):
+ if dtype == "int8":
+ return np.random.randint(-128, 128, shape).astype(dtype)
+ elif dtype == "uint8":
+ return np.random.randint(0, 256, shape).astype(dtype)
+ return np.random.uniform(-1, 1, shape).astype(dtype)
+
+
+def verify_group_gemm(
+ func_name, M, N, K, num_groups, x_dtype, weight_dtype, out_dtype,
use_scale, rtol, atol
+):
+ group_gemm_func = tvm.get_global_func(func_name, allow_missing=True)
+ if group_gemm_func is None:
+ print(f"Skipped as {func_name} is not available")
+ return
+
+ @memoize("tvm.contrib.cutlass.test_group_gemm_sm90")
+ def get_ref_data():
+ assert M % num_groups == 0
+ M_per_group = M // num_groups
+ a_np = get_random_ndarray((M, K), "float16")
+ b_np = get_random_ndarray((num_groups, N, K), "float16")
+ indptr_np = np.arange(1, num_groups + 1).astype("int64") * M_per_group
+ c_np = np.concatenate(
+ [a_np[i * M_per_group : (i + 1) * M_per_group] @ b_np[i].T for i
in range(num_groups)],
+ axis=0,
+ )
+ return a_np, b_np, indptr_np, c_np
+
+ def to_numpy_dtype(dtype):
+ mapping = {"float8_e5m2": ml_dtypes.float8_e5m2, "float8_e4m3fn":
ml_dtypes.float8_e4m3fn}
+ return mapping.get(dtype, dtype)
+
+ a_np, b_np, indptr_np, c_np = get_ref_data()
+ dev = tvm.cuda(0)
+ a_nd = tvm.nd.array(a_np.astype(to_numpy_dtype(x_dtype)), device=dev)
+ b_nd = tvm.nd.array(b_np.astype(to_numpy_dtype(weight_dtype)), device=dev)
+ c_nd = tvm.nd.empty(c_np.shape, dtype=out_dtype, device=dev)
+ indptr_nd = tvm.nd.array(indptr_np, device=dev)
+ workspace = tvm.nd.empty((4096 * 1024,), dtype="uint8", device=dev)
+ if use_scale:
+ scale = tvm.nd.array(np.array([1.0], dtype="float32"), device=dev)
+ group_gemm_func(a_nd, b_nd, indptr_nd, workspace, scale, c_nd)
+ else:
+ group_gemm_func(a_nd, b_nd, indptr_nd, workspace, c_nd)
+ tvm.testing.assert_allclose(c_nd.numpy(), c_np, rtol=rtol, atol=atol)
+
+
[email protected]_cutlass
[email protected]_cuda_compute_version(9)
+def test_group_gemm_sm90():
+ verify_group_gemm(
+ "cutlass.group_gemm_fp16_sm90",
+ 8,
+ 128,
+ 128,
+ 4,
+ "float16",
+ "float16",
+ "float16",
+ False,
+ rtol=1e-3,
+ atol=1e-3,
+ )
+ verify_group_gemm(
+ "cutlass.group_gemm_e5m2_e5m2_fp16",
+ 8,
+ 16,
+ 16,
+ 4,
+ "float8_e5m2",
+ "float8_e5m2",
+ "float16",
+ True,
+ rtol=1e-1,
+ atol=1,
+ )
+ verify_group_gemm(
+ "cutlass.group_gemm_e4m3_e4m3_fp16",
+ 8,
+ 16,
+ 16,
+ 4,
+ "float8_e4m3fn",
+ "float8_e4m3fn",
+ "float16",
+ True,
+ rtol=1e-1,
+ atol=1,
+ )
+
+
+def rowwise_quant_fp8_e4m3(shape: Tuple[int, int], block_size: Tuple[int,
int], dtype: str):
+ x_full_np = (np.random.rand(*shape) * 2 - 1).astype(dtype)
+ x_scale_shape = (
+ *shape[:-1],
+ (shape[-1] + block_size[1] - 1) // block_size[1],
+ )
+ # For each (block_size[1]) block, compute the max abs value of `w_full_np`
+ x_max_abs_np = np.zeros(x_scale_shape, dtype="float32")
+ for i in range(x_scale_shape[-1]):
+ x_max_abs_np[..., i] = np.max(
+ np.abs(x_full_np[..., i * block_size[1] : min((i + 1) *
block_size[1], shape[-1])]),
+ axis=-1,
+ )[0]
+ # Scale is the `x_max_abs_np` divided by the max value of quant_dtype in
ml_dtypes
+ fp8_max = float(ml_dtypes.finfo("float8_e4m3fn").max)
+ x_scale_np = x_max_abs_np / fp8_max
+ # `x_np` is the `x_full_np` divided by the `x_scale_np` (with block
awareness),
+ # clamped to (-fp8_max, fp8_max), and cast to `quant_dtype`
+ x_np = np.zeros_like(x_full_np, dtype="float8_e4m3fn")
+ for i in range(x_scale_shape[-1]):
+ x_np[..., i * block_size[1] : min((i + 1) * block_size[1], shape[-1])]
= np.clip(
+ x_full_np[..., i * block_size[1] : min((i + 1) * block_size[1],
shape[-1])]
+ / x_scale_np[..., i : i + 1],
+ -fp8_max,
+ fp8_max,
+ )
+
+ x_scale_np = np.random.rand(*x_scale_np.shape).astype("float32") / fp8_max
+ for i in range(x_scale_shape[-1]):
+ x_full_np[..., i * block_size[1] : min((i + 1) * block_size[1],
shape[-1])] = (
+ x_np[..., i * block_size[1] : min((i + 1) * block_size[1],
shape[-1])].astype(
+ x_scale_np.dtype
+ )
+ * x_scale_np[..., i : i + 1]
+ )
+ return x_np, x_scale_np
+
+
+def blockwise_quant_fp8_e4m3(shape: Tuple[int, int], block_size: Tuple[int,
int], dtype: str):
+ w_full_np = (np.random.rand(*shape) * 2 - 1).astype(dtype)
+ w_scale_shape = (
+ *shape[:-2],
+ (shape[-2] + block_size[0] - 1) // block_size[0],
+ (shape[-1] + block_size[1] - 1) // block_size[1],
+ )
+ # For each (block_size[0], block_size[1]) block, compute the max abs value
of `w_full_np`
+ w_max_abs_np = np.zeros(w_scale_shape, dtype="float32")
+ for i in range(w_scale_shape[-2]):
+ for j in range(w_scale_shape[-1]):
+ block_shape = (
+ *shape[:-2],
+ min(block_size[0], shape[-2] - i * block_size[0]),
+ min(block_size[1], shape[-1] - j * block_size[1]),
+ )
+ w_max_abs_np[..., i, j] = np.max(
+ np.abs(
+ w_full_np[
+ ...,
+ i * block_size[0] : min((i + 1) * block_size[0],
shape[-2]),
+ j * block_size[1] : min((j + 1) * block_size[1],
shape[-1]),
+ ]
+ ).reshape(*shape[:-2], block_shape[-2] * block_shape[-1]),
+ axis=-1,
+ )
+ # Scale is the `w_max_abs_np` divided by the max value of quant_dtype in
ml_dtypes
+ fp8_max = float(ml_dtypes.finfo("float8_e4m3fn").max)
+ w_scale_np = w_max_abs_np / fp8_max
+ # `w_np` is the `w_full_np` divided by the `w_scale_np` (with block
awareness),
+ # clamped to (-fp8_max, fp8_max), and cast to `quant_dtype`
+ w_np = np.zeros_like(w_full_np, dtype="float8_e4m3fn")
+ if len(w_scale_shape) == 2:
+ for i in range(w_scale_shape[-2]):
+ for j in range(w_scale_shape[-1]):
+ w_np[
+ i * block_size[0] : min((i + 1) * block_size[0],
shape[-2]),
+ j * block_size[1] : min((j + 1) * block_size[1],
shape[-1]),
+ ] = np.clip(
+ w_full_np[
+ i * block_size[0] : min((i + 1) * block_size[0],
shape[-2]),
+ j * block_size[1] : min((j + 1) * block_size[1],
shape[-1]),
+ ]
+ / w_scale_np[..., i, j],
+ -fp8_max,
+ fp8_max,
+ )
+ else:
+ for e in range(w_scale_shape[0]):
+ for i in range(w_scale_shape[-2]):
+ for j in range(w_scale_shape[-1]):
+ w_np[
+ e,
+ i * block_size[0] : min((i + 1) * block_size[0],
shape[-2]),
+ j * block_size[1] : min((j + 1) * block_size[1],
shape[-1]),
+ ] = np.clip(
+ w_full_np[
+ e,
+ i * block_size[0] : min((i + 1) * block_size[0],
shape[-2]),
+ j * block_size[1] : min((j + 1) * block_size[1],
shape[-1]),
+ ]
+ / w_scale_np[e, i, j],
+ -fp8_max,
+ fp8_max,
+ )
+
+ w_scale_np = np.random.rand(*w_scale_np.shape).astype("float32") / fp8_max
+ return w_np, w_scale_np
+
+
+def blockwise_matmul(
+ x_fp8_np: np.ndarray,
+ x_scale_np: np.ndarray,
+ w_np: np.ndarray,
+ w_scale_np: np.ndarray,
+ block_size: Tuple[int, int],
+ dtype: str,
+):
+ o_np = np.zeros((x_fp8_np.shape[0], w_np.shape[0]), dtype=dtype)
+ for j in range(w_scale_np.shape[0]):
+ for k in range(w_scale_np.shape[1]):
+ o_np[:, j * block_size[0] : min((j + 1) * block_size[0],
w_np.shape[0])] += (
+ np.matmul(
+ x_fp8_np[
+ :, k * block_size[1] : min((k + 1) * block_size[1],
x_fp8_np.shape[1])
+ ].astype(dtype),
+ w_np[
+ j * block_size[0] : min((j + 1) * block_size[0],
w_np.shape[0]),
+ k * block_size[1] : min((k + 1) * block_size[1],
w_np.shape[1]),
+ ].T.astype(dtype),
+ )
+ * x_scale_np[:, k : k + 1]
+ * w_scale_np[j, k]
+ )
+ return o_np
+
+
+def blockwise_bmm(
+ x_fp8_np: np.ndarray,
+ x_scale_np: np.ndarray,
+ w_np: np.ndarray,
+ w_scale_np: np.ndarray,
+ block_size: Tuple[int, int],
+ dtype: str,
+):
+ o_np = np.zeros((x_fp8_np.shape[0], x_fp8_np.shape[1], w_np.shape[1]),
dtype=dtype)
+ for j in range(w_scale_np.shape[1]):
+ for k in range(w_scale_np.shape[2]):
+ o_np[..., j * block_size[0] : min((j + 1) * block_size[0],
w_np.shape[1])] += (
+ np.matmul(
+ x_fp8_np[
+ ..., k * block_size[1] : min((k + 1) * block_size[1],
x_fp8_np.shape[2])
+ ].astype(dtype),
+ w_np[
+ ...,
+ j * block_size[0] : min((j + 1) * block_size[0],
w_np.shape[1]),
+ k * block_size[1] : min((k + 1) * block_size[1],
w_np.shape[2]),
+ ]
+ .transpose(0, 2, 1)
+ .astype(dtype),
+ )
+ * x_scale_np[..., k : k + 1]
+ * w_scale_np[..., j : j + 1, k : k + 1]
+ )
+ return o_np
+
+
[email protected]_cutlass
[email protected]_cuda_compute_version(9)
+def test_fp8_e4m3_blockwise_scaled_gemm():
+ M = 16
+ N = 4608
+ K = 896
+ block_size = (128, 128)
+ assert N % 128 == 0 and K % 128 == 0 # Only support N/K are multiple of
128
+
+ func_name = "cutlass.blockwise_scaled_gemm_e4m3fn_e4m3fn"
+ gemm_func = tvm.get_global_func(func_name, allow_missing=True)
+ if gemm_func is None:
+ print(f"Skipped as {func_name} is not available")
+ return
+
+ device = tvm.cuda(0)
+ dtype = "bfloat16"
+ x_np, x_scale_np = rowwise_quant_fp8_e4m3((M, K), block_size, dtype)
+ w_np, w_scale_np = blockwise_quant_fp8_e4m3((N, K), block_size, dtype)
+ o_np = blockwise_matmul(x_np, x_scale_np, w_np, w_scale_np, block_size,
dtype)
+ x_tvm = tvm.nd.array(x_np, device=device)
+ x_scale_tvm = tvm.nd.array(x_scale_np.T, device=device)
+ w_tvm = tvm.nd.array(w_np, device=device)
+ w_scale_tvm = tvm.nd.array(w_scale_np, device=device)
+ workspace = tvm.nd.empty((4096 * 1024,), dtype="uint8", device=device)
+ o_tvm = tvm.nd.empty((M, N), dtype=dtype, device=device)
+ gemm_func(
+ x_tvm, w_tvm, x_scale_tvm, w_scale_tvm, workspace, block_size[0],
block_size[1], o_tvm
+ )
+ o_tvm = o_tvm.numpy()
+ tvm.testing.assert_allclose(o_tvm, o_np, rtol=1e-4, atol=0.5)
+
+
[email protected]_cutlass
[email protected]_cuda_compute_version(9)
+def test_fp8_e4m3_blockwise_scaled_bmm():
+ B = 16
+ M = 40
+ N = 512
+ K = 128
+ block_size = (128, 128)
+ assert N % 128 == 0 and K % 128 == 0 # Only support N/K are multiple of
128
+
+ func_name = "cutlass.blockwise_scaled_bmm_e4m3fn_e4m3fn"
+ gemm_func = tvm.get_global_func(func_name, allow_missing=True)
+ if gemm_func is None:
+ print(f"Skipped as {func_name} is not available")
+ return
+
+ device = tvm.cuda(0)
+ dtype = "bfloat16"
+ x_np, x_scale_np = rowwise_quant_fp8_e4m3((B, M, K), block_size, dtype)
+ w_np, w_scale_np = blockwise_quant_fp8_e4m3((B, N, K), block_size, dtype)
+ o_np = blockwise_bmm(x_np, x_scale_np, w_np, w_scale_np, block_size, dtype)
+ x_tvm = tvm.nd.array(x_np, device=device)
+ x_scale_tvm = tvm.nd.array(x_scale_np.transpose(0, 2, 1), device=device)
+ w_tvm = tvm.nd.array(w_np, device=device)
+ w_scale_tvm = tvm.nd.array(w_scale_np, device=device)
+ workspace = tvm.nd.empty((4096 * 1024,), dtype="uint8", device=device)
+ o_tvm = tvm.nd.empty((B, M, N), dtype=dtype, device=device)
+ gemm_func(
+ x_tvm, w_tvm, x_scale_tvm, w_scale_tvm, workspace, block_size[0],
block_size[1], o_tvm
+ )
+ o_tvm = o_tvm.numpy()
+ tvm.testing.assert_allclose(o_tvm, o_np, rtol=1e-4, atol=0.5)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()