eric-haibin-lin commented on a change in pull request #17138: Interleaved MHA for CPU path
URL: https://github.com/apache/incubator-mxnet/pull/17138#discussion_r360784355
 
 

 ##########
 File path: src/operator/contrib/transformer.cc
 ##########
 @@ -122,6 +122,531 @@ static bool InterleavedMatMulEncDecValAttShape(const NodeAttrs& attrs,
   return true;
 }
 
+// Batched SGEMM helper: multiplies batchCount pairs of sub-matrices laid out at
+// fixed strides in memory. Uses MKL's cblas_sgemm_batch when available,
+// otherwise falls back to a loop of cblas_sgemm calls.
+void strided_batch_sgemm(bool transA, bool transB,
+                         index_t m, index_t n, index_t k,
+                         float alpha, const float *a, index_t lda,
+                         index_t strideA, const float *b, index_t ldb,
+                         index_t strideB, float beta, float *c, index_t ldc,
+                         index_t strideC, int32_t batchCount) {
+  std::vector<const float*> pp_A(batchCount, nullptr);
+  std::vector<const float*> pp_B(batchCount, nullptr);
+  std::vector<float*> pp_C(batchCount, nullptr);
+
+  for (int i = 0; i < batchCount; i++) {
+    pp_A[i] = a + i * strideA;
+    pp_B[i] = b + i * strideB;
+    pp_C[i] = c + i * strideC;
+  }
+
+#if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000)
+  const int GROUP_SIZE = 1;
+  MKL_INT p_m[GROUP_SIZE] = {static_cast<MKL_INT>(m)};
+  MKL_INT p_n[GROUP_SIZE] = {static_cast<MKL_INT>(n)};
+  MKL_INT p_k[GROUP_SIZE] = {static_cast<MKL_INT>(k)};
+  MKL_INT p_lda[GROUP_SIZE] = {static_cast<MKL_INT>(lda)};
+  MKL_INT p_ldb[GROUP_SIZE] = {static_cast<MKL_INT>(ldb)};
+  MKL_INT p_ldc[GROUP_SIZE] = {static_cast<MKL_INT>(ldc)};
+
+  float p_alpha[GROUP_SIZE] = {alpha};
+  float p_beta[GROUP_SIZE] = {beta};
+
+  CBLAS_TRANSPOSE cblas_a_trans = transA ? CblasTrans : CblasNoTrans;
+  CBLAS_TRANSPOSE cblas_b_trans = transB ? CblasTrans : CblasNoTrans;
+
+  MKL_INT p_group_sizeb[GROUP_SIZE] = {batchCount};
+  CBLAS_TRANSPOSE p_transa[GROUP_SIZE] = {cblas_a_trans};
+  CBLAS_TRANSPOSE p_transb[GROUP_SIZE] = {cblas_b_trans};
+
+  cblas_sgemm_batch(CblasColMajor, p_transa, p_transb,
+                    p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda, pp_B.data(),
+                    p_ldb, p_beta, pp_C.data(), p_ldc, GROUP_SIZE, p_group_sizeb);
+#else
+  for (int i = 0; i < batchCount; ++i) {
+    cblas_sgemm(CblasColMajor,
+                transA ? CblasTrans : CblasNoTrans,
+                transB ? CblasTrans : CblasNoTrans,
+                m, n, k,
+                alpha, pp_A[i], lda,
+                pp_B[i], ldb, beta, pp_C[i], ldc);
+  }
+#endif
+}
+
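+// Forward pass of the self-attention QK operator on CPU: for every head and
+// sequence, computes the scaled query-key dot products (one
+// qkv_seq_len x qkv_seq_len score matrix per attention batch) from the
+// interleaved QKV input; keys start at offset head_dim within each chunk.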
+void InterleavedMatMulSelfAttQKCPU(const nnvm::NodeAttrs& attrs,
+                                   const OpContext &ctx,
+                                   const std::vector<TBlob> &inputs,
+                                   const std::vector<OpReqType> &req,
+                                   const std::vector<TBlob> &outputs) {
+  const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
+
+  if (req[0] == kNullOp)
+    return;
+
+  CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32)
+    << "Only FP32 is supported on CPU at the moment";
+
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+  const float* queries_keys_values = inputs[0].FlatTo2D<cpu, float>(s).dptr_;
+  float* output = outputs[0].FlatTo2D<cpu, float>(s).dptr_;
+
+  const index_t qkv_seq_len    = inputs[0].shape_[0];
+  const index_t sequences      = inputs[0].shape_[1];
+  const index_t output_lin_dim = inputs[0].shape_[2];
+  const index_t embed_dim      = output_lin_dim / 3;
+  const index_t head_dim       = embed_dim / params.heads;
+  const index_t attn_batches   = params.heads * sequences;
+  const index_t lead_dim       = attn_batches * 3 * head_dim;
+  const index_t batch_stride   = 3 * head_dim;
+  const float beta             = req[0] == kAddTo ? 1.f : 0.f;
+  const float scale            = 1.f / sqrt(static_cast<float>(head_dim));
+
+  strided_batch_sgemm(true,
+                      false,
+                      qkv_seq_len,
+                      qkv_seq_len,
+                      head_dim,
+                      scale,
+                      queries_keys_values + head_dim,
+                      lead_dim,
+                      batch_stride,
+                      queries_keys_values,
+                      lead_dim,
+                      batch_stride,
+                      beta,
+                      output,
+                      qkv_seq_len,
+                      qkv_seq_len * qkv_seq_len,
+                      attn_batches);
+}
+
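+// Backward pass of the self-attention QK operator on CPU: the first GEMM
+// writes the query gradients (keys times score gradients) and the second
+// writes the key gradients (queries times transposed score gradients) into
+// their interleaved slots of the QKV gradient buffer.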
+void BackwardInterleavedMatMulSelfAttQKCPU(const nnvm::NodeAttrs& attrs,
+                                           const OpContext &ctx,
+                                           const std::vector<TBlob> &inputs,
+                                           const std::vector<OpReqType> &req,
+                                           const std::vector<TBlob> &outputs) {
+  const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
+  if (req[0] == kNullOp)
+    return;
+
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+  CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32)
+    << "Only FP32 is supported on CPU at the moment";
+
+  const float* output_grads        = inputs[0].FlatTo2D<cpu, float>(s).dptr_;
+  const float* queries_keys_values = inputs[1].FlatTo2D<cpu, float>(s).dptr_;
+  float* queries_keys_values_grads = outputs[0].FlatTo2D<cpu, float>(s).dptr_;
+  const index_t qkv_seq_len    = inputs[1].shape_[0];
+  const index_t sequences      = inputs[1].shape_[1];
+  const index_t output_lin_dim = inputs[1].shape_[2];
+  const index_t embed_dim      = output_lin_dim / 3;
+  const index_t head_dim       = embed_dim / params.heads;
+  const index_t attn_batches   = params.heads * sequences;
+  const index_t lead_dim       = attn_batches * 3 * head_dim;
+  const index_t batch_stride   = 3 * head_dim;
+  const float scale            = 1.f / sqrt(static_cast<float>(head_dim));
+  const float beta = req[0] == kAddTo ? 1.f : 0.f;
+
+  if (req[0] == kWriteTo) {
+    memset(queries_keys_values_grads, 0, outputs[0].shape_.Size() * sizeof(float));
+  }
+
+  strided_batch_sgemm(false,
+                      false,
+                      head_dim,
+                      qkv_seq_len,
+                      qkv_seq_len,
+                      scale,
+                      queries_keys_values + head_dim,
+                      lead_dim,
+                      batch_stride,
+                      output_grads,
+                      qkv_seq_len,
+                      qkv_seq_len * qkv_seq_len,
+                      beta,
+                      queries_keys_values_grads,
+                      lead_dim,
+                      batch_stride,
+                      attn_batches);
+
+  strided_batch_sgemm(false,
+                      true,
+                      head_dim,
+                      qkv_seq_len,
+                      qkv_seq_len,
+                      scale,
+                      queries_keys_values,
+                      lead_dim,
+                      batch_stride,
+                      output_grads,
+                      qkv_seq_len,
+                      qkv_seq_len * qkv_seq_len,
+                      beta,
+                      queries_keys_values_grads + head_dim,
+                      lead_dim,
+                      batch_stride,
+                      attn_batches);
+}
+
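+// Forward pass of the self-attention value-weighting operator on CPU: applies
+// the attention maps to the value part of the interleaved QKV input (values
+// start at offset 2 * head_dim), producing the per-head context vectors.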
+void InterleavedMatMulSelfAttValAttCPU(const nnvm::NodeAttrs& attrs,
+                                       const OpContext &ctx,
+                                       const std::vector<TBlob> &inputs,
+                                       const std::vector<OpReqType> &req,
+                                       const std::vector<TBlob> &outputs) {
+  const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
+  if (req[0] == kNullOp)
+    return;
+
+  CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32)
+    << "Only FP32 is supported on CPU at the moment";
+
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+  const float* queries_keys_values = inputs[0].FlatTo2D<cpu, float>(s).dptr_;
+  const float* attention_maps      = inputs[1].FlatTo2D<cpu, float>(s).dptr_;
+  float* output                    = outputs[0].FlatTo2D<cpu, float>(s).dptr_;
+  const index_t qkv_seq_len    = inputs[0].shape_[0];
+  const index_t sequences      = inputs[0].shape_[1];
+  const index_t output_lin_dim = inputs[0].shape_[2];
+  const index_t embed_dim      = output_lin_dim / 3;
+  const index_t head_dim       = embed_dim / params.heads;
+  const index_t attn_batches   = params.heads * sequences;
+  const index_t lead_dim       = attn_batches * 3 * head_dim;
+  const index_t batch_stride   = 3 * head_dim;
+  const float alpha             = 1.f;
+  const float beta              = req[0] == kAddTo ? 1.f : 0.f;
+
+  strided_batch_sgemm(false,
+                      false,
+                      head_dim,
+                      qkv_seq_len,
+                      qkv_seq_len,
+                      alpha,
+                      queries_keys_values + 2 * head_dim,
+                      lead_dim,
+                      batch_stride,
+                      attention_maps,
+                      qkv_seq_len,
+                      qkv_seq_len * qkv_seq_len,
+                      beta,
+                      output,
+                      head_dim * attn_batches,
+                      head_dim,
+                      attn_batches);
+}
+
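+// Backward pass of the value-weighting operator on CPU: the first GEMM
+// produces the value gradients (interleaved at offset 2 * head_dim), the
+// second GEMM produces the attention-map gradients.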
+void BackwardInterleavedMatMulSelfAttValAttCPU(const nnvm::NodeAttrs& attrs,
+                                               const OpContext &ctx,
+                                               const std::vector<TBlob> &inputs,
+                                               const std::vector<OpReqType> &req,
+                                               const std::vector<TBlob> &outputs) {
+  const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
+  if (req[0] == kNullOp)
+    return;
+
+  CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32)
+    << "Only FP32 is supported on CPU at the moment";
+
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+  const float* output_grads              = inputs[0].FlatTo2D<cpu, float>(s).dptr_;
+  const float* queries_keys_values       = inputs[1].FlatTo2D<cpu, float>(s).dptr_;
+  const float* attention_maps            = inputs[2].FlatTo2D<cpu, float>(s).dptr_;
+  float* queries_keys_values_grads       = outputs[0].FlatTo2D<cpu, float>(s).dptr_;
+  float* attention_maps_grads            = outputs[1].FlatTo2D<cpu, float>(s).dptr_;
+  const index_t qkv_seq_len    = inputs[1].shape_[0];
+  const index_t sequences      = inputs[1].shape_[1];
+  const index_t output_lin_dim = inputs[1].shape_[2];
+  const index_t embed_dim      = output_lin_dim / 3;
+  const index_t head_dim       = embed_dim / params.heads;
+  const index_t attn_batches   = params.heads * sequences;
+  const index_t lead_dim       = attn_batches * 3 * head_dim;
+  const index_t batch_stride   = 3 * head_dim;
+  const float alpha            = 1.f;
+  if (req[0] != kNullOp) {
+    if (req[0] == kWriteTo) {
+      memset(queries_keys_values_grads, 0, outputs[0].shape_.Size() * sizeof(float));
+    }
+
+    const float beta = req[0] == kAddTo ? 1.f : 0.f;
+    strided_batch_sgemm(false,
+                        true,
+                        head_dim,
+                        qkv_seq_len,
+                        qkv_seq_len,
+                        alpha,
+                        output_grads,
+                        head_dim * attn_batches,
+                        head_dim,
+                        attention_maps,
+                        qkv_seq_len,
+                        qkv_seq_len * qkv_seq_len,
+                        beta,
+                        queries_keys_values_grads + 2 * head_dim,
+                        lead_dim,
+                        batch_stride,
+                        attn_batches);
+  }
+  if (req[1] != kNullOp) {
+    const float beta = req[1] == kAddTo ? 1.f : 0.f;
+    strided_batch_sgemm(true,
+                        false,
+                        qkv_seq_len,
+                        qkv_seq_len,
+                        head_dim,
+                        alpha,
+                        queries_keys_values + 2 * head_dim,
+                        lead_dim,
+                        batch_stride,
+                        output_grads,
+                        head_dim * attn_batches,
+                        head_dim,
+                        beta,
+                        attention_maps_grads,
+                        qkv_seq_len,
+                        qkv_seq_len * qkv_seq_len,
+                        attn_batches);
+  }
+}
+
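+// Forward pass of the encoder-decoder QK operator on CPU: queries come from a
+// separate projection while keys are the first half of each interleaved KV
+// chunk; computes the scaled key-query dot products per head and sequence.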
+void InterleavedMatMulEncDecQKCPU(const nnvm::NodeAttrs& attrs,
+                                  const OpContext &ctx,
+                                  const std::vector<TBlob> &inputs,
+                                  const std::vector<OpReqType> &req,
+                                  const std::vector<TBlob> &outputs) {
+  const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
+  if (req[0] == kNullOp)
+    return;
+
+  CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32)
+    << "Only FP32 is supported on CPU at the moment";
+
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+  const float* queries     = inputs[0].FlatTo2D<cpu, float>(s).dptr_;
+  const float* keys_values = inputs[1].FlatTo2D<cpu, float>(s).dptr_;
+  float* output            = outputs[0].FlatTo2D<cpu, float>(s).dptr_;
+  const index_t q_seq_len         = inputs[0].shape_[0];
+  const index_t sequences         = inputs[0].shape_[1];
+  const index_t output_lin_q_dim  = inputs[0].shape_[2];
+  const index_t kv_seq_len        = inputs[1].shape_[0];
+  const index_t embed_dim         = output_lin_q_dim;
+  const index_t head_dim          = embed_dim / params.heads;
+  const index_t attn_batches      = params.heads * sequences;
+  const index_t lead_dim_q        = attn_batches * head_dim;
+  const index_t lead_dim_kv       = attn_batches * 2 * head_dim;
+  const index_t batch_stride_q    = head_dim;
+  const index_t batch_stride_kv   = head_dim * 2;
+  const float beta                = req[0] == kAddTo ? 1.f : 0.f;
+  const float scale               = 1.f / sqrt(static_cast<float>(head_dim));
+
+  strided_batch_sgemm(true,
+                      false,
+                      kv_seq_len,
+                      q_seq_len,
+                      head_dim,
+                      scale,
+                      keys_values,
+                      lead_dim_kv,
+                      batch_stride_kv,
+                      queries,
+                      lead_dim_q,
+                      batch_stride_q,
+                      beta,
+                      output,
+                      kv_seq_len,
+                      kv_seq_len * q_seq_len,
+                      attn_batches);
+}
+
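+// Backward pass of the encoder-decoder QK operator on CPU: propagates the
+// score gradients back to the query gradients and to the key part of the
+// interleaved KV gradient buffer.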
+void BackwardInterleavedMatMulEncDecQKCPU(const nnvm::NodeAttrs& attrs,
+                                          const OpContext &ctx,
+                                          const std::vector<TBlob> &inputs,
+                                          const std::vector<OpReqType> &req,
+                                          const std::vector<TBlob> &outputs) {
+  const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
+  if (req[0] == kNullOp)
+    return;
+
+  CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32)
+    << "Only FP32 is supported on CPU at the moment";
+
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+  const float* output_grads  = inputs[0].FlatTo2D<cpu, float>(s).dptr_;
+  const float* queries       = inputs[1].FlatTo2D<cpu, float>(s).dptr_;
+  const float* keys_values   = inputs[2].FlatTo2D<cpu, float>(s).dptr_;
+  float* queries_grads       = outputs[0].FlatTo2D<cpu, float>(s).dptr_;
+  float* keys_values_grads   = outputs[1].FlatTo2D<cpu, float>(s).dptr_;
+  const index_t q_seq_len         = inputs[1].shape_[0];
+  const index_t sequences         = inputs[1].shape_[1];
+  const index_t output_lin_q_dim  = inputs[1].shape_[2];
+  const index_t kv_seq_len        = inputs[2].shape_[0];
+  const index_t embed_dim         = output_lin_q_dim;
+  const index_t head_dim          = embed_dim / params.heads;
+  const index_t attn_batches      = params.heads * sequences;
+  const index_t lead_dim_q        = attn_batches * head_dim;
+  const index_t lead_dim_kv       = attn_batches * 2 * head_dim;
+  const index_t batch_stride_q    = head_dim;
+  const index_t batch_stride_kv   = head_dim * 2;
+  const float scale               = 1.f / sqrt(static_cast<float>(head_dim));
+
+  if (req[0] != kNullOp) {
+    const float beta = req[0] == kAddTo ? 1.f : 0.f;
+    strided_batch_sgemm(false,
+                        false,
+                        head_dim,
+                        q_seq_len,
+                        kv_seq_len,
+                        scale,
+                        keys_values,
+                        lead_dim_kv,
+                        batch_stride_kv,
+                        output_grads,
+                        kv_seq_len,
+                        kv_seq_len * q_seq_len,
+                        beta,
+                        queries_grads,
+                        lead_dim_q,
+                        batch_stride_q,
+                        attn_batches);
+  }
+  if (req[1] != kNullOp) {
+    if (req[1] == kWriteTo) {
 
 Review comment:
   @Caenorst shall we change it to `req[1] == kWriteTo || req[1] == kWriteInplace` for the GPU code path, too?
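   
   For reference, something along these lines is what I have in mind (a standalone sketch only; the enum and the `PrepareGradBuffer` helper merely mirror `OpReqType` for illustration and are not the actual GPU code):
   
   ```cpp
   #include <cstddef>
   #include <cstring>
   
   // Stand-in for mxnet::OpReqType, for illustration only.
   enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };
   
   // Hypothetical helper: zero-initialize the gradient buffer for both write
   // modes, and leave it untouched for kAddTo so the subsequent GEMM can
   // accumulate with beta = 1.f.
   inline void PrepareGradBuffer(OpReqType req, float* grad, std::size_t size) {
     if (req == kWriteTo || req == kWriteInplace) {
       std::memset(grad, 0, size * sizeof(float));
     }
   }
   ```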

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services
