szhengac commented on a change in pull request #19387:
URL: https://github.com/apache/incubator-mxnet/pull/19387#discussion_r508808216



##########
File path: src/operator/contrib/transformer.cc
##########
@@ -841,5 +841,196 @@ MXNET_OPERATOR_REGISTER_UNARY(_contrib_div_sqrt_dim)
 .set_attr<FCompute>("FCompute<cpu>", DivSqrtDimForward_<cpu>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_contrib_div_sqrt_dim"});
 
+
+DMLC_REGISTER_PARAMETER(SldWinAttenParam);
+
+NNVM_REGISTER_OP(_contrib_sldwin_atten_mask_like)
+.add_alias("_npx_sldwin_atten_mask_like")
+.describe(R"code(Compute the mask for the sliding window attention score, used in
+Longformer (https://arxiv.org/pdf/2004.05150.pdf). In this attention pattern,
+given a fixed window size *2w*, each token attends to *w* tokens on the left side
+if we use causal attention (setting *symmetric* to *False*),
+otherwise each token attends to *w* tokens on each side.
+
+The shapes of the inputs are:
+- *score* : (batch_size, seq_length, num_heads, w + w + 1) if symmetric is True,
+            otherwise (batch_size, seq_length, num_heads, w + 1).
+- *dilation* : (num_heads,)
+- *valid_length* : (batch_size,)
+
+The shape of the output is:
+- *mask* : same as the shape of *score*
+
+)code" ADD_FILELINE)
+.set_num_inputs(3)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<SldWinAttenParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"score", "dilation", "valid_length"};
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"mask"};
+})
+.set_attr<mxnet::FInferShape>("FInferShape", [](const nnvm::NodeAttrs& attrs,
+                                              mxnet::ShapeVector *in_attrs,
+                                              mxnet::ShapeVector *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 3U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  const mxnet::TShape& dshape = (*in_attrs)[0];
+  if (!shape_is_known(dshape)) return false;
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape);
+  return true;
+})
+.set_attr<nnvm::FInferType>("FInferType", [](const nnvm::NodeAttrs &attrs,
+                                             std::vector<int> *in_attrs,
+                                             std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 3U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32);
+  return out_attrs->at(0) != -1;
+})
+.set_attr<FCompute>("FCompute<cpu>", SldWinAttenMaskLikeForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.add_argument("score", "NDArray-or-Symbol", "sliding window attention score")
+.add_argument("dilation", "NDArray-or-Symbol", "dilation")
+.add_argument("valid_length", "NDArray-or-Symbol", "valid length")
+.add_arguments(SldWinAttenParam::__FIELDS__());
+
+
+NNVM_REGISTER_OP(_contrib_sldwin_atten_score)
+.add_alias("_npx_sldwin_atten_score")
+.describe(R"code(Compute the sliding window attention score, which is used in
+Longformer (https://arxiv.org/pdf/2004.05150.pdf). In this attention pattern,
+given a fixed window size *2w*, each token attends to *w* tokens on the left side
+if we use causal attention (setting *symmetric* to *False*),
+otherwise each token attends to *w* tokens on each side.
+
+The shapes of the inputs are:
+- *query* : (batch_size, seq_length, num_heads, num_head_units)
+- *key* : (batch_size, seq_length, num_heads, num_head_units)
+- *dilation* : (num_heads,)

Review comment:
       Taking *dilation* as an input of shape (num_heads,) allows us to specify a different dilation value for each attention head.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to