This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4e70e4a4ba [CUTLASS] Add FP8 gemm kernels (#17408)
4e70e4a4ba is described below
commit 4e70e4a4bacc9a225dac1a90b39b5faac7d095bd
Author: Ruihang Lai <[email protected]>
AuthorDate: Wed Sep 25 00:34:09 2024 -0400
[CUTLASS] Add FP8 gemm kernels (#17408)
This PR introduces the sm90a FP8 GEMM kernels from CUTLASS. These kernels
are helpful for small `M`, where the cuBLAS kernels are not well
optimized.
---
cmake/modules/contrib/CUTLASS.cmake | 1 +
src/runtime/contrib/cublas/cublas.cc | 6 +-
src/runtime/contrib/cutlass/fp8_gemm.cu | 95 +++++++++++++++++
src/runtime/contrib/cutlass/gemm_runner.cuh | 155 ++++++++++++++++++++++++++++
tests/python/contrib/test_cutlass.py | 107 ++++++++++++++++---
5 files changed, 349 insertions(+), 15 deletions(-)
diff --git a/cmake/modules/contrib/CUTLASS.cmake
b/cmake/modules/contrib/CUTLASS.cmake
index fa4a608f61..11224a8d1f 100644
--- a/cmake/modules/contrib/CUTLASS.cmake
+++ b/cmake/modules/contrib/CUTLASS.cmake
@@ -58,6 +58,7 @@ if(USE_CUDA AND USE_CUTLASS)
if (CMAKE_CUDA_ARCHITECTURES MATCHES "90a")
list(APPEND TVM_CUTLASS_RUNTIME_SRCS
src/runtime/contrib/cutlass/fp16_group_gemm.cu)
list(APPEND TVM_CUTLASS_RUNTIME_SRCS
src/runtime/contrib/cutlass/fp8_group_gemm.cu)
+ list(APPEND TVM_CUTLASS_RUNTIME_SRCS
src/runtime/contrib/cutlass/fp8_gemm.cu)
endif()
if(TVM_CUTLASS_RUNTIME_SRCS)
add_library(tvm_cutlass_objs OBJECT ${TVM_CUTLASS_RUNTIME_SRCS})
diff --git a/src/runtime/contrib/cublas/cublas.cc
b/src/runtime/contrib/cublas/cublas.cc
index 8925080abf..c9a01fc24e 100644
--- a/src/runtime/contrib/cublas/cublas.cc
+++ b/src/runtime/contrib/cublas/cublas.cc
@@ -194,11 +194,13 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t
stream,
&bias->data,
sizeof(float*)));
}
- if (scaleA != nullptr && scaleB != nullptr) {
+ if (scaleA != nullptr) {
auto scaleA_data = static_cast<char*>(scaleA->data) + scaleA->byte_offset;
- auto scaleB_data = static_cast<char*>(scaleB->data) + scaleB->byte_offset;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc,
CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&scaleA_data,
sizeof(float*)));
+ }
+ if (scaleB != nullptr) {
+ auto scaleB_data = static_cast<char*>(scaleB->data) + scaleB->byte_offset;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc,
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&scaleB_data,
sizeof(float*)));
}
diff --git a/src/runtime/contrib/cutlass/fp8_gemm.cu
b/src/runtime/contrib/cutlass/fp8_gemm.cu
new file mode 100644
index 0000000000..67e502a163
--- /dev/null
+++ b/src/runtime/contrib/cutlass/fp8_gemm.cu
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <cuda_fp16.h>
+#include <float.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include "../cublas/cublas_utils.h"
+#include "gemm_runner.cuh"
+
+#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
+
+struct KernelTraitsM64 {
+ using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
+ using TileShape = Shape<_64, _64, _128>;
+ using ClusterShape = Shape<_1, _8, _1>;
+};
+
+namespace tvm {
+namespace runtime {
+
+template <typename ElementA, typename ElementB, typename ElementC>
+void tvm_cutlass_fp8_gemm(NDArray x, NDArray weight, NDArray workspace,
NDArray alpha,
+ NDArray out) {
+ // Workspace is used for storing device-side gemm arguments and cutlass
internal workspace.
+ // Recommended size is 4MB.
+ auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+ ICHECK(func != nullptr);
+ CHECK_GE(x->ndim, 2);
+ CHECK_EQ(weight->ndim, 2);
+ CHECK_EQ(workspace->ndim, 1);
+ CHECK_GE(out->ndim, 2);
+ CHECK_EQ(alpha->dtype.code, kDLFloat);
+ CHECK_EQ(alpha->dtype.bits, 32);
+ CHECK_EQ(alpha->ndim, 1);
+ CHECK_EQ(alpha->shape[0], 1);
+ int64_t m = 1;
+ for (int i = 0; i < x->ndim - 1; ++i) {
+ m *= x->shape[i];
+ }
+ int64_t n = weight->shape[0];
+ CHECK_EQ(x->shape[x->ndim - 1], weight->shape[1]) << "Only col-major weight
is supported now.";
+ int64_t k = x->shape[x->ndim - 1];
+ const float* beta = nullptr;
+ cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
+ if (m <= 64) {
+ cutlass_gemm<KernelTraitsM64>(
+ static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data),
+ static_cast<uint8_t*>(workspace->data), workspace->shape[0], m, n, k,
+ static_cast<float*>(alpha->data), beta,
static_cast<ElementC*>(out->data), stream);
+ } else {
+ tvm::contrib::CuBlasLtThreadEntry* cublas_entry =
+ tvm::contrib::CuBlasLtThreadEntry::ThreadLocal();
+ tvm::contrib::CallCublasLt(cublas_entry->handle, stream,
cublas_entry->matmul_pref_desc,
+ x.operator->(), weight.operator->(), nullptr,
alpha.operator->(),
+ nullptr, out.operator->(), /*transa=*/false,
/*transb=*/true,
+ cublas_entry->workspace_ptr,
cublas_entry->workspace_size,
+ CUBLASLT_EPILOGUE_DEFAULT, std::nullopt);
+ }
+}
+
+TVM_REGISTER_GLOBAL("cutlass.gemm_e5m2_e5m2_fp16")
+ .set_body_typed(
+ tvm_cutlass_fp8_gemm<cutlass::float_e5m2_t, cutlass::float_e5m2_t,
cutlass::half_t>);
+
+TVM_REGISTER_GLOBAL("cutlass.gemm_e5m2_e4m3_fp16")
+ .set_body_typed(
+ tvm_cutlass_fp8_gemm<cutlass::float_e5m2_t, cutlass::float_e4m3_t,
cutlass::half_t>);
+
+TVM_REGISTER_GLOBAL("cutlass.gemm_e4m3_e4m3_fp16")
+ .set_body_typed(
+ tvm_cutlass_fp8_gemm<cutlass::float_e4m3_t, cutlass::float_e4m3_t,
cutlass::half_t>);
+
+} // namespace runtime
+} // namespace tvm
+
+#endif // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED
diff --git a/src/runtime/contrib/cutlass/gemm_runner.cuh
b/src/runtime/contrib/cutlass/gemm_runner.cuh
new file mode 100644
index 0000000000..c664f6cf6f
--- /dev/null
+++ b/src/runtime/contrib/cutlass/gemm_runner.cuh
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <fstream>
+#include <iostream>
+#include <sstream>
+#include <variant>
+#include <vector>
+
+#include "../../cuda/cuda_common.h"
+
+// clang-format off
+#include "cutlass/cutlass.h"
+
+#include "cute/tensor.hpp"
+#include "cutlass/tensor_ref.h"
+#include "cutlass/epilogue/collective/default_epilogue.hpp"
+#include "cutlass/epilogue/thread/linear_combination.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+// clang-format on
+
+#define CUTLASS_CHECK(status) \
+ { \
+ cutlass::Status error = status; \
+ CHECK(error == cutlass::Status::kSuccess) \
+ << "Got cutlass error: " << cutlassGetStatusString(error); \
+ }
+
+using namespace cute;
+using ProblemShape = Shape<int, int, int>; // <M, N, K>
+
+template <typename KernelTraits, typename ElementA, typename ElementB,
typename ElementC,
+ typename LayoutA = cutlass::layout::RowMajor,
+ typename LayoutB = cutlass::layout::ColumnMajor,
+ typename LayoutC = cutlass::layout::RowMajor>
+struct CutlassGemmRunner {
+ static constexpr int AlignmentA =
+ 128 / cutlass::sizeof_bits<ElementA>::value; // Alignment of A matrix
in units of elements
+ // (up to 16 bytes)
+
+ static constexpr int AlignmentB =
+ 128 / cutlass::sizeof_bits<ElementB>::value; // Alignment of B matrix
in units of elements
+ // (up to 16 bytes)
+
+ static constexpr int AlignmentC =
+ 128 / cutlass::sizeof_bits<ElementC>::value; // Alignment of C matrix
in units of elements
+ // (up to 16 bytes)
+
+ // Core kernel configurations
+ using ElementAccumulator = float; // Element type for internal accumulation
+ using ScaleType = std::variant<ElementAccumulator, const
ElementAccumulator*>;
+ using ArchTag =
+ cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the
intended feature
+ using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
+ using TileShape = typename KernelTraits::TileShape;
+ using ClusterShape = typename KernelTraits::ClusterShape;
+ using StageCountType =
+ cutlass::gemm::collective::StageCountAuto; // Stage count maximized
based on the tile size
+ using KernelSchedule = typename KernelTraits::KernelSchedule; // Kernel
to launch
+ using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; // Epilogue
to launch
+
+ using CollectiveEpilogue = typename
cutlass::epilogue::collective::CollectiveBuilder<
+ ArchTag, OperatorClass, TileShape, ClusterShape,
+ cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
ElementAccumulator,
+ ElementC, LayoutC, AlignmentC, ElementC, LayoutC, AlignmentC,
EpilogueSchedule>::CollectiveOp;
+ using CollectiveMainloop = typename
cutlass::gemm::collective::CollectiveBuilder<
+ ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB,
LayoutB, AlignmentB,
+ ElementAccumulator, TileShape, ClusterShape,
+ cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
+ sizeof(typename CollectiveEpilogue::SharedStorage))>,
+ KernelSchedule>::CollectiveOp;
+
+ using GemmKernel =
+ cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
CollectiveEpilogue>;
+
+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+ using StrideA = typename Gemm::GemmKernel::StrideA;
+ using StrideB = typename Gemm::GemmKernel::StrideB;
+ using StrideC = typename Gemm::GemmKernel::StrideC;
+ using StrideD = typename Gemm::GemmKernel::StrideD;
+
+ void run_gemm(const ElementA* ptr_A, const ElementB* ptr_B, const ElementC*
ptr_C,
+ ElementC* ptr_D, ProblemShape* problem_size, StrideA*
stride_A, StrideB* stride_B,
+ StrideC* stride_C, StrideD* stride_D, uint8_t* workspace,
int64_t workspace_size,
+ ScaleType alpha, ScaleType beta, cudaStream_t stream) {
+ cutlass::KernelHardwareInfo hw_info;
+ hw_info.device_id = 0;
+ hw_info.sm_count =
+
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
+ typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGemm,
+ *problem_size,
+ {ptr_A, *stride_A, ptr_B, *stride_B},
+ {{}, ptr_C, *stride_C, ptr_D,
*stride_D},
+ // {epilogue_params, ptr_C, *stride_C,
ptr_D, *stride_D},
+ hw_info};
+
+ ICHECK(alpha.index() == beta.index()) << "alpha and beta must have the
same type";
+ if (std::holds_alternative<ElementAccumulator>(alpha)) {
+ arguments.epilogue.thread.alpha = std::get<ElementAccumulator>(alpha);
+ arguments.epilogue.thread.beta = std::get<ElementAccumulator>(beta);
+ } else if (std::holds_alternative<const ElementAccumulator*>(alpha)) {
+ arguments.epilogue.thread.alpha_ptr = std::get<const
ElementAccumulator*>(alpha);
+ arguments.epilogue.thread.beta_ptr = std::get<const
ElementAccumulator*>(beta);
+ } else {
+ LOG(FATAL) << "Unsupported alpha and beta type";
+ throw;
+ }
+
+ Gemm gemm_op;
+ CUTLASS_CHECK(gemm_op.can_implement(arguments));
+ CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
+ CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
+ CUTLASS_CHECK(gemm_op.run(stream));
+ }
+};
+
+template <typename KernelTraits, typename ElementA, typename ElementB,
typename ElementC>
+void cutlass_gemm(ElementA* x, ElementB* weight, uint8_t* workspace, int64_t
workspace_size,
+ int64_t m, int64_t n, int64_t k, std::variant<float, const
float*> alpha,
+ std::variant<float, const float*> beta, ElementC* out,
cudaStream_t stream) {
+ using Runner = CutlassGemmRunner<KernelTraits, ElementA, ElementB, ElementC>;
+ using StrideA = typename Runner::StrideA;
+ using StrideB = typename Runner::StrideB;
+ using StrideC = typename Runner::StrideC;
+
+ Runner runner;
+ StrideA stride_A = cute::make_stride(k, Int<1>{}, int64_t{0});
+ StrideB stride_B = cute::make_stride(k, Int<1>{}, int64_t{0});
+ StrideC stride_D = cute::make_stride(n, Int<1>{}, int64_t{0});
+ ProblemShape problem_size{static_cast<int>(m), static_cast<int>(n),
static_cast<int>(k)};
+ runner.run_gemm(x, weight, out, out, &problem_size, &stride_A, &stride_B,
&stride_D, &stride_D,
+ workspace, workspace_size, alpha, beta, stream);
+}
diff --git a/tests/python/contrib/test_cutlass.py
b/tests/python/contrib/test_cutlass.py
index 154a68e116..bc80323b75 100644
--- a/tests/python/contrib/test_cutlass.py
+++ b/tests/python/contrib/test_cutlass.py
@@ -15,26 +15,27 @@
# specific language governing permissions and limitations
# under the License.
import logging
-import tempfile
import math
+import tempfile
+
import ml_dtypes
+import numpy as np
+
import tvm
-from tvm import relay
+import tvm.testing
+from tvm import auto_scheduler, relay
from tvm.contrib.cudnn import conv_output_shape
-import numpy as np
-from tvm.relay import op as _op
-from tvm.runtime.vm import VirtualMachine
-from tvm.relay.op.contrib.cutlass import partition_for_cutlass
-from tvm import auto_scheduler
-from tvm.relay.transform import FirstOrderGradient, ToMixedPrecision, InferType
from tvm.contrib.cutlass import (
- has_cutlass,
- num_cutlass_partitions,
finalize_modules,
finalize_modules_vm,
+ has_cutlass,
+ num_cutlass_partitions,
)
from tvm.contrib.pickle_memoize import memoize
-import tvm.testing
+from tvm.relay import op as _op
+from tvm.relay.op.contrib.cutlass import partition_for_cutlass
+from tvm.relay.transform import FirstOrderGradient, InferType, ToMixedPrecision
+from tvm.runtime.vm import VirtualMachine
logging.basicConfig(level=logging.INFO)
@@ -1189,13 +1190,13 @@ def test_group_gemm_sm90():
atol=1,
)
verify_group_gemm(
- "cutlass.group_gemm_e4m3_e5m2_fp16",
+ "cutlass.group_gemm_e5m2_e4m3_fp16",
8,
16,
16,
4,
- "e4m3_float8",
"e5m2_float8",
+ "e4m3_float8",
"float16",
True,
rtol=1e-1,
@@ -1203,5 +1204,85 @@ def test_group_gemm_sm90():
)
+def verify_gemm(func_name, M, N, K, x_dtype, weight_dtype, out_dtype,
scale_value, rtol, atol):
+ gemm_func = tvm.get_global_func(func_name, allow_missing=True)
+ if gemm_func is None:
+ print(f"Skipped as {func_name} is not available")
+ return
+
+ @memoize("tvm.contrib.cutlass.test_fp8_gemm_sm90")
+ def get_ref_data():
+ a_np = get_random_ndarray((M, K), "float16")
+ b_np = get_random_ndarray((N, K), "float16")
+ c_np = a_np @ b_np.T * scale_value
+ return a_np, b_np, c_np
+
+ def to_numpy_dtype(dtype):
+ mapping = {"e5m2_float8": ml_dtypes.float8_e5m2, "e4m3_float8":
ml_dtypes.float8_e4m3fn}
+ return mapping.get(dtype, dtype)
+
+ a_np, b_np, c_np = get_ref_data()
+ dev = tvm.cuda(0)
+ a_nd = tvm.nd.array(a_np.astype(to_numpy_dtype(x_dtype)), device=dev)
+ b_nd = tvm.nd.array(b_np.astype(to_numpy_dtype(weight_dtype)), device=dev)
+ c_nd = tvm.nd.empty(c_np.shape, dtype=out_dtype, device=dev)
+ workspace = tvm.nd.empty((4096 * 1024,), dtype="uint8", device=dev)
+ scale = tvm.nd.array(np.array([scale_value], dtype="float32"), device=dev)
+ gemm_func(a_nd, b_nd, workspace, scale, c_nd)
+ tvm.testing.assert_allclose(c_nd.asnumpy(), c_np, rtol=rtol, atol=atol)
+
+
+@tvm.testing.requires_cutlass
+def test_fp8_gemm_sm90():
+ verify_gemm(
+ "cutlass.gemm_e5m2_e5m2_fp16",
+ 8,
+ 16,
+ 16,
+ "e5m2_float8",
+ "e5m2_float8",
+ "float16",
+ 1.5,
+ rtol=1e-1,
+ atol=1,
+ )
+ verify_gemm(
+ "cutlass.gemm_e4m3_e4m3_fp16",
+ 8,
+ 16,
+ 16,
+ "e4m3_float8",
+ "e4m3_float8",
+ "float16",
+ 1.5,
+ rtol=1e-1,
+ atol=1,
+ )
+ verify_gemm(
+ "cutlass.gemm_e4m3_e4m3_fp16",
+ 32,
+ 16,
+ 16,
+ "e4m3_float8",
+ "e4m3_float8",
+ "float16",
+ 1.5,
+ rtol=1e-1,
+ atol=1,
+ )
+ verify_gemm(
+ "cutlass.gemm_e5m2_e4m3_fp16",
+ 8,
+ 16,
+ 16,
+ "e5m2_float8",
+ "e4m3_float8",
+ "float16",
+ 1.5,
+ rtol=1e-1,
+ atol=1,
+ )
+
+
if __name__ == "__main__":
tvm.testing.main()