sxjscience commented on a change in pull request #19387:
URL: https://github.com/apache/incubator-mxnet/pull/19387#discussion_r508814571
##########
File path: src/operator/contrib/transformer-inl.h
##########
@@ -61,6 +61,229 @@ static void DivSqrtDimForward_(const nnvm::NodeAttrs& attrs,
});
}
+
+
+struct SldWinAttenParam : public dmlc::Parameter<SldWinAttenParam> {
+ int w;
+ bool symmetric;
+ DMLC_DECLARE_PARAMETER(SldWinAttenParam) {
+ DMLC_DECLARE_FIELD(w)
+ .describe("The one-sided window length");
+ DMLC_DECLARE_FIELD(symmetric)
+ .describe("Whether to use causal attention");
+ }
+};
+
+
+struct SldWinAttenMaskLike {
+ MSHADOW_XINLINE static void Map(int i, float *out, int32_t *dilation,
int32_t *val_length,
+ bool symmetric, int w, int seq_length, int
num_heads) {
+ out[i] = 1;
+ int w_len = symmetric ? (w + w + 1) : (w + 1);
+ int idx_0 = i / (seq_length * num_heads * w_len); // batch idx
+ int tmp = i % (seq_length * num_heads * w_len);
+ int idx_1 = tmp / (num_heads * w_len); // sequence idx
+ tmp = tmp % (num_heads * w_len);
+ int idx_2 = tmp / w_len; // head idx
+ int idx_3 = tmp % w_len; // win idx
+
+ bool is_zero = idx_3 < (w - idx_1/dilation[idx_2]) || idx_1 >=
val_length[idx_0] \
+ || (symmetric && (w_len-idx_3-1) < (w -
(val_length[idx_0]-idx_1-1)/dilation[idx_2]));
+ if (is_zero) out[i] = 0;
+ }
+};
+
+
+template<typename xpu>
+void SldWinAttenMaskLikeForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
+ CHECK_EQ(inputs.size(), 3U);
+ CHECK_EQ(outputs.size(), 1U);
+ CHECK_EQ(req.size(), 1U);
+ CHECK_EQ(req[0], kWriteTo) << "Currently only support kWriteTo";
+ using namespace mshadow;
+ Stream<xpu>* s = ctx.get_stream<xpu>();
+ const SldWinAttenParam& param = nnvm::get<SldWinAttenParam>(attrs.parsed);
+ CHECK_EQ(outputs[0].type_flag_, kFloat32);
+ float* out = outputs[0].dptr<float>();
+ CHECK_EQ(inputs[1].type_flag_, kInt32);
+ int32_t* dilation = inputs[1].dptr<int32_t>();
+ CHECK_EQ(inputs[2].type_flag_, kInt32);
+ int32_t* val_length = inputs[2].dptr<int32_t>();
+
+ int seq_length = inputs[0].shape_[1];
+ int num_heads = inputs[0].shape_[2];
+ int num_threads = outputs[0].Size();
+
+ mxnet_op::Kernel<SldWinAttenMaskLike, xpu>::Launch(s, num_threads, out,
dilation,
+ val_length, param.symmetric, param.w, seq_length, num_heads);
+}
+
+
+
+
+struct DiagMM {
+ MSHADOW_XINLINE static void Map(int tid, float *out, float *lhs, float *rhs,
+ int32_t *dilation, int batch_size, int
seq_length,
+ int num_heads, int out_last_dim, int
lhs_last_dim, int w,
+ int w_right, bool diagonal_lhs, bool
transpose_lhs) {
Review comment:
   We may try to improve the kernel with TVM/Ansor, using TVM as the code
generator. Another option is to use CUTLASS.
   The general idea of optimizing the kernel is as follows: if you first
reorganize the local context into contiguous blocks, the attention values can
be obtained via a single cuBLAS batched matmul. However, this costs more memory
than the current solution. To balance memory cost against latency, you could
pipeline the process, similar to how implicit-GEMM convolution is implemented.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]