This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4e70e4a4ba [CUTLASS] Add FP8 gemm kernels (#17408)
4e70e4a4ba is described below

commit 4e70e4a4bacc9a225dac1a90b39b5faac7d095bd
Author: Ruihang Lai <[email protected]>
AuthorDate: Wed Sep 25 00:34:09 2024 -0400

    [CUTLASS] Add FP8 gemm kernels (#17408)
    
    This PR introduces the sm90a FP8 kernels from CUTLASS. These kernels
    are helpful in cases of small `M`, where cuBLAS performance is
    unoptimized.
---
 cmake/modules/contrib/CUTLASS.cmake         |   1 +
 src/runtime/contrib/cublas/cublas.cc        |   6 +-
 src/runtime/contrib/cutlass/fp8_gemm.cu     |  95 +++++++++++++++++
 src/runtime/contrib/cutlass/gemm_runner.cuh | 155 ++++++++++++++++++++++++++++
 tests/python/contrib/test_cutlass.py        | 107 ++++++++++++++++---
 5 files changed, 349 insertions(+), 15 deletions(-)

diff --git a/cmake/modules/contrib/CUTLASS.cmake 
b/cmake/modules/contrib/CUTLASS.cmake
index fa4a608f61..11224a8d1f 100644
--- a/cmake/modules/contrib/CUTLASS.cmake
+++ b/cmake/modules/contrib/CUTLASS.cmake
@@ -58,6 +58,7 @@ if(USE_CUDA AND USE_CUTLASS)
   if (CMAKE_CUDA_ARCHITECTURES MATCHES "90a")
     list(APPEND TVM_CUTLASS_RUNTIME_SRCS 
src/runtime/contrib/cutlass/fp16_group_gemm.cu)
     list(APPEND TVM_CUTLASS_RUNTIME_SRCS 
src/runtime/contrib/cutlass/fp8_group_gemm.cu)
+    list(APPEND TVM_CUTLASS_RUNTIME_SRCS 
src/runtime/contrib/cutlass/fp8_gemm.cu)
   endif()
   if(TVM_CUTLASS_RUNTIME_SRCS)
     add_library(tvm_cutlass_objs OBJECT ${TVM_CUTLASS_RUNTIME_SRCS})
diff --git a/src/runtime/contrib/cublas/cublas.cc 
b/src/runtime/contrib/cublas/cublas.cc
index 8925080abf..c9a01fc24e 100644
--- a/src/runtime/contrib/cublas/cublas.cc
+++ b/src/runtime/contrib/cublas/cublas.cc
@@ -194,11 +194,13 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t 
stream,
                                                       &bias->data, 
sizeof(float*)));
   }
 
-  if (scaleA != nullptr && scaleB != nullptr) {
+  if (scaleA != nullptr) {
     auto scaleA_data = static_cast<char*>(scaleA->data) + scaleA->byte_offset;
-    auto scaleB_data = static_cast<char*>(scaleB->data) + scaleB->byte_offset;
     CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, 
CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &scaleA_data, 
sizeof(float*)));
+  }
+  if (scaleB != nullptr) {
+    auto scaleB_data = static_cast<char*>(scaleB->data) + scaleB->byte_offset;
     CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, 
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &scaleB_data, 
sizeof(float*)));
   }
diff --git a/src/runtime/contrib/cutlass/fp8_gemm.cu 
b/src/runtime/contrib/cutlass/fp8_gemm.cu
new file mode 100644
index 0000000000..67e502a163
--- /dev/null
+++ b/src/runtime/contrib/cutlass/fp8_gemm.cu
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <cuda_fp16.h>
+#include <float.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include "../cublas/cublas_utils.h"
+#include "gemm_runner.cuh"
+
+#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
+
+struct KernelTraitsM64 {
+  using KernelSchedule = 
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
+  using TileShape = Shape<_64, _64, _128>;
+  using ClusterShape = Shape<_1, _8, _1>;
+};
+
+namespace tvm {
+namespace runtime {
+
+template <typename ElementA, typename ElementB, typename ElementC>
+void tvm_cutlass_fp8_gemm(NDArray x, NDArray weight, NDArray workspace, 
NDArray alpha,
+                          NDArray out) {
+  // Workspace is used for storing device-side gemm arguments and cutlass 
internal workspace.
+  // Recommened size is 4MB.
+  auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+  ICHECK(func != nullptr);
+  CHECK_GE(x->ndim, 2);
+  CHECK_EQ(weight->ndim, 2);
+  CHECK_EQ(workspace->ndim, 1);
+  CHECK_GE(out->ndim, 2);
+  CHECK_EQ(alpha->dtype.code, kDLFloat);
+  CHECK_EQ(alpha->dtype.bits, 32);
+  CHECK_EQ(alpha->ndim, 1);
+  CHECK_EQ(alpha->shape[0], 1);
+  int64_t m = 1;
+  for (int i = 0; i < x->ndim - 1; ++i) {
+    m *= x->shape[i];
+  }
+  int64_t n = weight->shape[0];
+  CHECK_EQ(x->shape[x->ndim - 1], weight->shape[1]) << "Only col-major weight 
is supported now.";
+  int64_t k = x->shape[x->ndim - 1];
+  const float* beta = nullptr;
+  cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
+  if (m <= 64) {
+    cutlass_gemm<KernelTraitsM64>(
+        static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data),
+        static_cast<uint8_t*>(workspace->data), workspace->shape[0], m, n, k,
+        static_cast<float*>(alpha->data), beta, 
static_cast<ElementC*>(out->data), stream);
+  } else {
+    tvm::contrib::CuBlasLtThreadEntry* cublas_entry =
+        tvm::contrib::CuBlasLtThreadEntry::ThreadLocal();
+    tvm::contrib::CallCublasLt(cublas_entry->handle, stream, 
cublas_entry->matmul_pref_desc,
+                               x.operator->(), weight.operator->(), nullptr, 
alpha.operator->(),
+                               nullptr, out.operator->(), /*transa=*/false, 
/*transb=*/true,
+                               cublas_entry->workspace_ptr, 
cublas_entry->workspace_size,
+                               CUBLASLT_EPILOGUE_DEFAULT, std::nullopt);
+  }
+}
+
+TVM_REGISTER_GLOBAL("cutlass.gemm_e5m2_e5m2_fp16")
+    .set_body_typed(
+        tvm_cutlass_fp8_gemm<cutlass::float_e5m2_t, cutlass::float_e5m2_t, 
cutlass::half_t>);
+
+TVM_REGISTER_GLOBAL("cutlass.gemm_e5m2_e4m3_fp16")
+    .set_body_typed(
+        tvm_cutlass_fp8_gemm<cutlass::float_e5m2_t, cutlass::float_e4m3_t, 
cutlass::half_t>);
+
+TVM_REGISTER_GLOBAL("cutlass.gemm_e4m3_e4m3_fp16")
+    .set_body_typed(
+        tvm_cutlass_fp8_gemm<cutlass::float_e4m3_t, cutlass::float_e4m3_t, 
cutlass::half_t>);
+
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED
diff --git a/src/runtime/contrib/cutlass/gemm_runner.cuh 
b/src/runtime/contrib/cutlass/gemm_runner.cuh
new file mode 100644
index 0000000000..c664f6cf6f
--- /dev/null
+++ b/src/runtime/contrib/cutlass/gemm_runner.cuh
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <fstream>
+#include <iostream>
+#include <sstream>
+#include <variant>
+#include <vector>
+
+#include "../../cuda/cuda_common.h"
+
+// clang-format off
+#include "cutlass/cutlass.h"
+
+#include "cute/tensor.hpp"
+#include "cutlass/tensor_ref.h"
+#include "cutlass/epilogue/collective/default_epilogue.hpp"
+#include "cutlass/epilogue/thread/linear_combination.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+// clang-format on
+
+#define CUTLASS_CHECK(status)                                      \
+  {                                                                \
+    cutlass::Status error = status;                                \
+    CHECK(error == cutlass::Status::kSuccess)                      \
+        << "Got cutlass error: " << cutlassGetStatusString(error); \
+  }
+
+using namespace cute;
+using ProblemShape = Shape<int, int, int>;  // <M, N, K>
+
+template <typename KernelTraits, typename ElementA, typename ElementB, 
typename ElementC,
+          typename LayoutA = cutlass::layout::RowMajor,
+          typename LayoutB = cutlass::layout::ColumnMajor,
+          typename LayoutC = cutlass::layout::RowMajor>
+struct CutlassGemmRunner {
+  static constexpr int AlignmentA =
+      128 / cutlass::sizeof_bits<ElementA>::value;  // Alignment of A matrix 
in units of elements
+                                                    // (up to 16 bytes)
+
+  static constexpr int AlignmentB =
+      128 / cutlass::sizeof_bits<ElementB>::value;  // Alignment of B matrix 
in units of elements
+                                                    // (up to 16 bytes)
+
+  static constexpr int AlignmentC =
+      128 / cutlass::sizeof_bits<ElementC>::value;  // Alignment of C matrix 
in units of elements
+                                                    // (up to 16 bytes)
+
+  // Core kernel configurations
+  using ElementAccumulator = float;  // Element type for internal accumulation
+  using ScaleType = std::variant<ElementAccumulator, const 
ElementAccumulator*>;
+  using ArchTag =
+      cutlass::arch::Sm90;  // Tag indicating the minimum SM that supports the 
intended feature
+  using OperatorClass = cutlass::arch::OpClassTensorOp;  // Operator class tag
+  using TileShape = typename KernelTraits::TileShape;
+  using ClusterShape = typename KernelTraits::ClusterShape;
+  using StageCountType =
+      cutlass::gemm::collective::StageCountAuto;  // Stage count maximized 
based on the tile size
+  using KernelSchedule = typename KernelTraits::KernelSchedule;    // Kernel 
to launch
+  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;  // Epilogue 
to launch
+
+  using CollectiveEpilogue = typename 
cutlass::epilogue::collective::CollectiveBuilder<
+      ArchTag, OperatorClass, TileShape, ClusterShape,
+      cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, 
ElementAccumulator,
+      ElementC, LayoutC, AlignmentC, ElementC, LayoutC, AlignmentC, 
EpilogueSchedule>::CollectiveOp;
+  using CollectiveMainloop = typename 
cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, 
LayoutB, AlignmentB,
+      ElementAccumulator, TileShape, ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
+          sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      KernelSchedule>::CollectiveOp;
+
+  using GemmKernel =
+      cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, 
CollectiveEpilogue>;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+  using StrideA = typename Gemm::GemmKernel::StrideA;
+  using StrideB = typename Gemm::GemmKernel::StrideB;
+  using StrideC = typename Gemm::GemmKernel::StrideC;
+  using StrideD = typename Gemm::GemmKernel::StrideD;
+
+  void run_gemm(const ElementA* ptr_A, const ElementB* ptr_B, const ElementC* 
ptr_C,
+                ElementC* ptr_D, ProblemShape* problem_size, StrideA* 
stride_A, StrideB* stride_B,
+                StrideC* stride_C, StrideD* stride_D, uint8_t* workspace, 
int64_t workspace_size,
+                ScaleType alpha, ScaleType beta, cudaStream_t stream) {
+    cutlass::KernelHardwareInfo hw_info;
+    hw_info.device_id = 0;
+    hw_info.sm_count =
+        
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
+    typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGemm,
+                                       *problem_size,
+                                       {ptr_A, *stride_A, ptr_B, *stride_B},
+                                       {{}, ptr_C, *stride_C, ptr_D, 
*stride_D},
+                                       //  {epilogue_params, ptr_C, *stride_C, 
ptr_D, *stride_D},
+                                       hw_info};
+
+    ICHECK(alpha.index() == beta.index()) << "alpha and beta must have the 
same type";
+    if (std::holds_alternative<ElementAccumulator>(alpha)) {
+      arguments.epilogue.thread.alpha = std::get<ElementAccumulator>(alpha);
+      arguments.epilogue.thread.beta = std::get<ElementAccumulator>(beta);
+    } else if (std::holds_alternative<const ElementAccumulator*>(alpha)) {
+      arguments.epilogue.thread.alpha_ptr = std::get<const 
ElementAccumulator*>(alpha);
+      arguments.epilogue.thread.beta_ptr = std::get<const 
ElementAccumulator*>(beta);
+    } else {
+      LOG(FATAL) << "Unsupported alpha and beta type";
+      throw;
+    }
+
+    Gemm gemm_op;
+    CUTLASS_CHECK(gemm_op.can_implement(arguments));
+    CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
+    CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
+    CUTLASS_CHECK(gemm_op.run(stream));
+  }
+};
+
+template <typename KernelTraits, typename ElementA, typename ElementB, 
typename ElementC>
+void cutlass_gemm(ElementA* x, ElementB* weight, uint8_t* workspace, int64_t 
workspace_size,
+                  int64_t m, int64_t n, int64_t k, std::variant<float, const 
float*> alpha,
+                  std::variant<float, const float*> beta, ElementC* out, 
cudaStream_t stream) {
+  using Runner = CutlassGemmRunner<KernelTraits, ElementA, ElementB, ElementC>;
+  using StrideA = typename Runner::StrideA;
+  using StrideB = typename Runner::StrideB;
+  using StrideC = typename Runner::StrideC;
+
+  Runner runner;
+  StrideA stride_A = cute::make_stride(k, Int<1>{}, int64_t{0});
+  StrideB stride_B = cute::make_stride(k, Int<1>{}, int64_t{0});
+  StrideC stride_D = cute::make_stride(n, Int<1>{}, int64_t{0});
+  ProblemShape problem_size{static_cast<int>(m), static_cast<int>(n), 
static_cast<int>(k)};
+  runner.run_gemm(x, weight, out, out, &problem_size, &stride_A, &stride_B, 
&stride_D, &stride_D,
+                  workspace, workspace_size, alpha, beta, stream);
+}
diff --git a/tests/python/contrib/test_cutlass.py 
b/tests/python/contrib/test_cutlass.py
index 154a68e116..bc80323b75 100644
--- a/tests/python/contrib/test_cutlass.py
+++ b/tests/python/contrib/test_cutlass.py
@@ -15,26 +15,27 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
-import tempfile
 import math
+import tempfile
+
 import ml_dtypes
+import numpy as np
+
 import tvm
-from tvm import relay
+import tvm.testing
+from tvm import auto_scheduler, relay
 from tvm.contrib.cudnn import conv_output_shape
-import numpy as np
-from tvm.relay import op as _op
-from tvm.runtime.vm import VirtualMachine
-from tvm.relay.op.contrib.cutlass import partition_for_cutlass
-from tvm import auto_scheduler
-from tvm.relay.transform import FirstOrderGradient, ToMixedPrecision, InferType
 from tvm.contrib.cutlass import (
-    has_cutlass,
-    num_cutlass_partitions,
     finalize_modules,
     finalize_modules_vm,
+    has_cutlass,
+    num_cutlass_partitions,
 )
 from tvm.contrib.pickle_memoize import memoize
-import tvm.testing
+from tvm.relay import op as _op
+from tvm.relay.op.contrib.cutlass import partition_for_cutlass
+from tvm.relay.transform import FirstOrderGradient, InferType, ToMixedPrecision
+from tvm.runtime.vm import VirtualMachine
 
 logging.basicConfig(level=logging.INFO)
 
@@ -1189,13 +1190,13 @@ def test_group_gemm_sm90():
         atol=1,
     )
     verify_group_gemm(
-        "cutlass.group_gemm_e4m3_e5m2_fp16",
+        "cutlass.group_gemm_e5m2_e4m3_fp16",
         8,
         16,
         16,
         4,
-        "e4m3_float8",
         "e5m2_float8",
+        "e4m3_float8",
         "float16",
         True,
         rtol=1e-1,
@@ -1203,5 +1204,85 @@ def test_group_gemm_sm90():
     )
 
 
+def verify_gemm(func_name, M, N, K, x_dtype, weight_dtype, out_dtype, 
scale_value, rtol, atol):
+    gemm_func = tvm.get_global_func(func_name, allow_missing=True)
+    if gemm_func is None:
+        print(f"Skipped as {func_name} is not available")
+        return
+
+    @memoize("tvm.contrib.cutlass.test_fp8_gemm_sm90")
+    def get_ref_data():
+        a_np = get_random_ndarray((M, K), "float16")
+        b_np = get_random_ndarray((N, K), "float16")
+        c_np = a_np @ b_np.T * scale_value
+        return a_np, b_np, c_np
+
+    def to_numpy_dtype(dtype):
+        mapping = {"e5m2_float8": ml_dtypes.float8_e5m2, "e4m3_float8": 
ml_dtypes.float8_e4m3fn}
+        return mapping.get(dtype, dtype)
+
+    a_np, b_np, c_np = get_ref_data()
+    dev = tvm.cuda(0)
+    a_nd = tvm.nd.array(a_np.astype(to_numpy_dtype(x_dtype)), device=dev)
+    b_nd = tvm.nd.array(b_np.astype(to_numpy_dtype(weight_dtype)), device=dev)
+    c_nd = tvm.nd.empty(c_np.shape, dtype=out_dtype, device=dev)
+    workspace = tvm.nd.empty((4096 * 1024,), dtype="uint8", device=dev)
+    scale = tvm.nd.array(np.array([scale_value], dtype="float32"), device=dev)
+    gemm_func(a_nd, b_nd, workspace, scale, c_nd)
+    tvm.testing.assert_allclose(c_nd.asnumpy(), c_np, rtol=rtol, atol=atol)
+
+
[email protected]_cutlass
+def test_fp8_gemm_sm90():
+    verify_gemm(
+        "cutlass.gemm_e5m2_e5m2_fp16",
+        8,
+        16,
+        16,
+        "e5m2_float8",
+        "e5m2_float8",
+        "float16",
+        1.5,
+        rtol=1e-1,
+        atol=1,
+    )
+    verify_gemm(
+        "cutlass.gemm_e4m3_e4m3_fp16",
+        8,
+        16,
+        16,
+        "e4m3_float8",
+        "e4m3_float8",
+        "float16",
+        1.5,
+        rtol=1e-1,
+        atol=1,
+    )
+    verify_gemm(
+        "cutlass.gemm_e4m3_e4m3_fp16",
+        32,
+        16,
+        16,
+        "e4m3_float8",
+        "e4m3_float8",
+        "float16",
+        1.5,
+        rtol=1e-1,
+        atol=1,
+    )
+    verify_gemm(
+        "cutlass.gemm_e5m2_e4m3_fp16",
+        8,
+        16,
+        16,
+        "e5m2_float8",
+        "e4m3_float8",
+        "float16",
+        1.5,
+        rtol=1e-1,
+        atol=1,
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to