sxjscience commented on a change in pull request #16408: Add MXNet Ops for fast 
multihead attention
URL: https://github.com/apache/incubator-mxnet/pull/16408#discussion_r333268972
 
 

 ##########
 File path: src/operator/contrib/transformer.cu
 ##########
 @@ -22,12 +22,898 @@
  * \file transformer.cu
  * \brief GPU implementation of the operators used in Transformer
  */
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_profiler_api.h>
+
 #include <mxnet/base.h>
 #include "./transformer-inl.h"
+#include "../../common/cuda_utils.h"
+
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/wmma_matrix.h"
+#ifdef CUTLASS_USE_WMMA_API
+#include "cutlass/gemm/wmma_gemm_traits.h"
 
 namespace mxnet {
 namespace op {
 
+// gemm_switch_fp32accum and the functions called are almost fully copied from:
+// MLPerf v0.6 submission repository from NVIDIA by 
https://github.com/kevinstephano
// Strided-batched GEMM computed with cublasGemmStridedBatchedEx.
//
// Computes, for each batch i in [0, batchCount):
//   C_i = alpha * op(A_i) * op(B_i) + beta * C_i
// where op() is transpose/no-transpose per transA/transB, and A_i/B_i/C_i are
// located at a + i*strideA, b + i*strideB, c + i*strideC.
//
// Accumulation is always performed in FP32 (CUDA_R_32F) regardless of DType;
// the element type is taken from CublasType<DType>::kCudaFlag.
// Precondition: the stream must own its cuBLAS handle (checked below).
// `algo` selects the cuBLAS algorithm; defaults to the Tensor-Core default.
template<typename DType>
void CublasStridedBatchedGemm(mshadow::Stream<gpu>* s, bool transA, bool transB,
                              int32_t m, int32_t n, int32_t k,
                              float alpha, const DType* a, int32_t lda, int32_t strideA,
                              const DType *b, int32_t ldb, int32_t strideB, float beta,
                              DType *c, int32_t ldc, int32_t strideC, int32_t batchCount,
                              cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP) {
  using namespace mxnet::common::cuda;
  CHECK_EQ(s->blas_handle_ownership_, mshadow::Stream<gpu>::OwnHandle)
      << "Must init CuBLAS handle in stream";

  cublasHandle_t blas_handle = mshadow::Stream<gpu>::GetBlasHandle(s);
  // TODO(cfujitsang): handle computation_precision
  cublasStatus_t err = cublasGemmStridedBatchedEx(
      blas_handle, CublasTransposeOp(transA), CublasTransposeOp(transB),
      static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
      reinterpret_cast<void*>(&alpha),
      a, CublasType<DType>::kCudaFlag, static_cast<int>(lda), strideA,
      b, CublasType<DType>::kCudaFlag, static_cast<int>(ldb), strideB,
      reinterpret_cast<void*>(&beta),
      c, CublasType<DType>::kCudaFlag, static_cast<int>(ldc), strideC,
      static_cast<int>(batchCount), CUDA_R_32F, algo);
  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "cublasGemmStridedBatchedEx failed.";
}
+
// Fallback overload of the CUTLASS strided-batched GEMM dispatcher.
// Selected when no specialization exists for the requested DType /
// layout / per-thread vector-width (SRC_A, SRC_B, DST_C) combination;
// it never computes anything and aborts with a fatal error. The FP16
// specialization below handles the supported Tensor-Core path.
template<::cutlass::MatrixLayout::Kind A_LAYOUT,
         ::cutlass::MatrixLayout::Kind B_LAYOUT,
         int SRC_A, int SRC_B, int DST_C, typename DType>
void CutlassGemm_FP32Accum(cudaStream_t, int32_t m, int32_t n, int32_t k,
                           float alpha, const DType *a, int32_t lda,
                           int32_t strideA, const DType *b, int32_t ldb,
                           int32_t strideB, float beta, DType *c, int32_t ldc,
                           int32_t strideC, int32_t batchCount) {
  LOG(FATAL) << "Not implemented with this DType and shape (Cutlass)";
}
+
+
// FP16 specialization: strided-batched GEMM via the CUTLASS WMMA (Tensor Core)
// path with FP32 accumulation and a LinearScaling<float> epilogue
// (D = alpha * A*B + beta * C). SRC_A / SRC_B / DST_C are the per-thread
// scalar counts for global/shared loads of A, B and loads/stores of C/D;
// the caller (gemm_switch_fp32accum) is responsible for picking values
// compatible with the alignment of lda/ldb/ldc.
// NOTE(review): requires Tensor-Core-capable hardware (SM70+) via the WMMA
// API — confirm the dispatcher guards on compute capability.
template<::cutlass::MatrixLayout::Kind A_LAYOUT,
         ::cutlass::MatrixLayout::Kind B_LAYOUT, int SRC_A, int SRC_B, int DST_C>
void CutlassGemm_FP32Accum(cudaStream_t stream, int32_t m, int32_t n, int32_t k,
                           float alpha, const mshadow::half::half_t *a, int32_t lda,
                           int32_t strideA, const mshadow::half::half_t *b, int32_t ldb,
                           int32_t strideB, float beta, mshadow::half::half_t *c, int32_t ldc,
                           int32_t strideC, int32_t batchCount) {
  typedef cutlass::gemm::WmmaGemmTraits<
    A_LAYOUT,
    B_LAYOUT,
    cutlass::Shape<32, 16, 16>,           // threadblock tile shape
    half,                                 // A element type
    half,                                 // B element type
    half,                                 // C/D element type
    cutlass::gemm::LinearScaling<float>,  // epilogue functor (alpha/beta in FP32)
    float,                                // accumulator type
    typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<
      typename cutlass::Shape<32, 16, 16> >::Shape,
      typename cutlass::Shape<16, 16, 16>,  // WMMA instruction shape
      SRC_A,   // kScalarsPerLdgA_
      SRC_B,   // kScalarsPerLdgB_
      SRC_A,   // KScalarsPerLdsA_
      SRC_B,   // KScalarsPerLdsB_
      DST_C,   // kScalarsPerLdgCAndStgD_
      DST_C/2,  // kScalarsPerStsD_
      DST_C/2  // kScalarsPerLdsD_
    >
    WmmaGemmTraits;

  typedef cutlass::gemm::Gemm<WmmaGemmTraits> Gemm;
  typename Gemm::Params params;

  int result = params.initialize(
    m,  // M dimension for each batch
    n,  // N dimension for each batch
    k,  // K dimension for each batch
    alpha,  // scalar alpha
    reinterpret_cast<const __half*>(a),
    lda,
    strideA,  // distance in memory between the first element of neighboring batch
    reinterpret_cast<const __half*>(b),
    ldb,
    strideB,  // distance in memory between the first element of neighboring batch
    beta,  // scalar beta
    reinterpret_cast<__half*>(c),  // source matrix C
    ldc,
    strideC,  // distance in memory between the first element of neighboring batch
    reinterpret_cast<__half*>(c),  // destination matrix D (aliases C here)
    ldc,
    strideC,  // distance in memory between the first element of neighboring batch
    batchCount);

  CHECK_EQ(result, 0) << "Failed to initialize CUTLASS Gemm::Params object.";

  // Launch the CUTLASS GEMM kernel on the caller-provided stream.
  // Previously the `stream` argument was dropped and the kernel implicitly
  // ran on the default stream, breaking ordering with the mshadow stream.
  Gemm::launch(params, stream);
}
+
+template<typename DType>
+void gemm_switch_fp32accum(mshadow::Stream<gpu>* s, bool transA, bool transB,
+                           int32_t m, int32_t n, int32_t k,
+                           float alpha, const DType *a, int32_t lda,
+                           int32_t strideA, const DType *b, int32_t ldb,
+                           int32_t strideB, float beta, DType *c, int32_t ldc,
+                           int32_t strideC, int32_t batchCount) {
+  using cutlass::MatrixLayout::kRowMajor;
+  using cutlass::MatrixLayout::kColumnMajor;
+  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
+  if (transA && (!transB)) {
+    if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
+      CublasStridedBatchedGemm(s, transA, transB, m, n, k, alpha, a, lda, 
strideA, b, ldb,
+        strideB, beta, c, ldc, strideC, batchCount, 
CUBLAS_GEMM_ALGO0_TENSOR_OP);
+    } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<kRowMajor, kColumnMajor, 8, 8, 4>(stream, m, n, k, 
alpha, a, lda,
+        strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
+    } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<kRowMajor, kColumnMajor, 8, 8, 2>(stream, m, n, k, 
alpha, a, lda,
+        strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
+    } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<kRowMajor, kColumnMajor, 8, 4, 8>(stream, m, n, k, 
alpha, a, lda,
+        strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
+    } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<kRowMajor, kColumnMajor, 8, 4, 4>(stream, m, n, k, 
alpha, a, lda,
+        strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
+    } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<kRowMajor, kColumnMajor, 8, 4, 2>(stream, m, n, k, 
alpha, a, lda,
+        strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
+    } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<kRowMajor, kColumnMajor, 8, 2, 8>(stream, m, n, k, 
alpha, a, lda,
+        strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
+    } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<kRowMajor, kColumnMajor, 8, 2, 4>(stream, m, n, k, 
alpha, a, lda,
+        strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
+    } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<kRowMajor, kColumnMajor, 8, 2, 2>(stream, m, n, k, 
alpha, a, lda,
+        strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<kRowMajor, kColumnMajor, 4, 8, 8>(stream, m, n, k, 
alpha, a, lda,
+        strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<kRowMajor, kColumnMajor, 4, 8, 4>(stream, m, n, k, 
alpha, a, lda,
+        strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<kRowMajor, kColumnMajor, 4, 8, 2>(stream, m, n, k, 
alpha, a, lda,
+        strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<kRowMajor, kColumnMajor, 4, 4, 8>(stream, m, n, k, 
alpha, a, lda,
+        strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<kRowMajor, kColumnMajor, 4, 4, 4>(stream, m, n, k, 
alpha, a, lda,
+        strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<kRowMajor, kColumnMajor, 4, 4, 2>(stream, m, n, k, 
alpha, a, lda,
+        strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<kRowMajor, kColumnMajor, 4, 2, 8>(stream, m, n, k, 
alpha, a, lda,
+        strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<kRowMajor, kColumnMajor, 4, 2, 4>(stream, m, n, k, 
alpha, a, lda,
+        strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<kRowMajor, kColumnMajor, 4, 2, 2>(stream, m, n, k, 
alpha, a, lda,
+        strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<kRowMajor, kColumnMajor, 2, 8, 8>(stream, m, n, k, 
alpha, a, lda,
+        strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<kRowMajor, kColumnMajor, 2, 8, 4>(stream, m, n, k, 
alpha, a, lda,
+        strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<kRowMajor, kColumnMajor, 2, 8, 2>(stream, m, n, k, 
alpha, a, lda,
+        strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<kRowMajor, kColumnMajor, 2, 4, 8>(stream, m, n, k, 
alpha, a, lda,
+        strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<kRowMajor, kColumnMajor, 2, 4, 4>(stream, m, n, k, 
alpha, a, lda,
+        strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<kRowMajor, kColumnMajor, 2, 4, 2>(stream, m, n, k, 
alpha, a, lda,
+        strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<kRowMajor, kColumnMajor, 2, 2, 8>(stream, m, n, k, 
alpha, a, lda,
+        strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<kRowMajor, kColumnMajor, 2, 2, 4>(stream, m, n, k, 
alpha, a, lda,
+        strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<kRowMajor, kColumnMajor, 2, 2, 2>(stream, m, n, k, 
alpha, a, lda,
+        strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
+    } else {
+      CublasStridedBatchedGemm(s, transA, transB, m, n, k, alpha, a, lda, 
strideA, b, ldb,
 
 Review comment:
  One issue with the optimizations here is that they will increase the library
size.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to