sxjscience commented on a change in pull request #19387:
URL: https://github.com/apache/incubator-mxnet/pull/19387#discussion_r508814571



##########
File path: src/operator/contrib/transformer-inl.h
##########
@@ -61,6 +61,229 @@ static void DivSqrtDimForward_(const nnvm::NodeAttrs& attrs,
   });
 }
 
+
+
+struct SldWinAttenParam : public dmlc::Parameter<SldWinAttenParam> {
+  int w;
+  bool symmetric;
+  DMLC_DECLARE_PARAMETER(SldWinAttenParam) {
+    DMLC_DECLARE_FIELD(w)
+    .describe("The one-sided window length");
+    DMLC_DECLARE_FIELD(symmetric)
+    .describe("Whether to use causal attention");
+  }
+};
+
+
/*!
 * \brief Element-wise kernel that fills a 0/1 validity mask for
 *        sliding-window attention.
 *
 * The flattened index i decomposes over a tensor laid out as
 * (batch, seq_length, num_heads, w_len), where w_len = 2*w+1 for
 * symmetric attention and w+1 for causal attention.
 *
 * \param i          flattened element index
 * \param out        mask output; 1 = valid slot, 0 = masked out
 * \param dilation   per-head dilation factors (int32) — assumed length num_heads
 * \param val_length per-batch valid sequence lengths (int32) — assumed length batch
 * \param symmetric  if true the window extends in both directions
 * \param w          one-sided window length
 * \param seq_length sequence length of the score tensor
 * \param num_heads  number of attention heads
 */
struct SldWinAttenMaskLike {
  MSHADOW_XINLINE static void Map(int i, float *out, int32_t *dilation, int32_t *val_length,
                                  bool symmetric, int w, int seq_length, int num_heads) {
    // Default to valid; selectively zero below.
    out[i] = 1;
    int w_len = symmetric ? (w + w + 1) : (w + 1);  // window slots per (batch, pos, head)
    int idx_0 = i / (seq_length * num_heads * w_len); // batch idx
    int tmp = i % (seq_length * num_heads * w_len);
    int idx_1 = tmp / (num_heads * w_len); // sequence idx
    tmp = tmp % (num_heads * w_len);
    int idx_2 = tmp / w_len; // head idx
    int idx_3 = tmp % w_len; // win idx

    // A slot is masked when any of the following holds:
    //  1. the (dilated) left reach of the window falls before position 0,
    //  2. the query position itself lies beyond this sample's valid length,
    //  3. (symmetric only) the mirrored right-hand slot would reach past the
    //     end of the valid region.
    bool is_zero = idx_3 < (w - idx_1/dilation[idx_2]) || idx_1 >= val_length[idx_0] \
      || (symmetric && (w_len-idx_3-1) < (w - (val_length[idx_0]-idx_1-1)/dilation[idx_2]));
    if (is_zero) out[i] = 0;
  }
};
+
+
+template<typename xpu>
+void SldWinAttenMaskLikeForward(const nnvm::NodeAttrs& attrs,
+                                const OpContext& ctx,
+                                const std::vector<TBlob>& inputs,
+                                const std::vector<OpReqType>& req,
+                                const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 3U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  CHECK_EQ(req[0], kWriteTo) << "Currently only support kWriteTo";
+  using namespace mshadow;
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  const SldWinAttenParam& param = nnvm::get<SldWinAttenParam>(attrs.parsed);
+  CHECK_EQ(outputs[0].type_flag_, kFloat32);
+  float* out = outputs[0].dptr<float>();
+  CHECK_EQ(inputs[1].type_flag_, kInt32);
+  int32_t* dilation = inputs[1].dptr<int32_t>();
+  CHECK_EQ(inputs[2].type_flag_, kInt32);
+  int32_t* val_length = inputs[2].dptr<int32_t>();
+
+  int seq_length = inputs[0].shape_[1];
+  int num_heads = inputs[0].shape_[2];
+  int num_threads = outputs[0].Size();
+
+  mxnet_op::Kernel<SldWinAttenMaskLike, xpu>::Launch(s, num_threads, out, 
dilation,
+      val_length, param.symmetric, param.w, seq_length, num_heads);
+}
+
+
+
+
+struct DiagMM {
+  MSHADOW_XINLINE static void Map(int tid, float *out, float *lhs, float *rhs,
+                                  int32_t *dilation, int batch_size, int 
seq_length,
+                                  int num_heads, int out_last_dim, int 
lhs_last_dim, int w,
+                                  int w_right, bool diagonal_lhs, bool 
transpose_lhs) {

Review comment:
       We may try to improve the kernel with TVM/Ansor, using TVM as the
code generator. Another option is to use CUTLASS.
   
   The general idea for optimizing the kernel is the following:
   
   If you reorganize the local context first, the attention values can be
obtained via a single cuBLAS batched matmul. However, this costs more memory
than the current solution. To balance memory cost against latency, you should
streamline the process, similar to how convolution is implemented.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to