eric-haibin-lin commented on a change in pull request #16408: Add MXNet Ops for fast multihead attention
URL: https://github.com/apache/incubator-mxnet/pull/16408#discussion_r335803864
 
 

 ##########
 File path: src/operator/contrib/transformer.cc
 ##########
 @@ -29,6 +29,231 @@
 namespace mxnet {
 namespace op {
 
+DMLC_REGISTER_PARAMETER(InterleavedMatMulParam);
+
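+// Shape inference for interleaved_matmul_selfatt_qk:
+// input [0]: interleaved queries/keys/values, (seq_length, batch_size, heads * head_dim * 3)
+// output[0]: attention scores, (batch_size * heads, seq_length, seq_length)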
+static bool InterleavedMatMulSelfAttQKShape(const NodeAttrs& attrs,
+                                            mxnet::ShapeVector* in_shape,
+                                            mxnet::ShapeVector* out_shape) {
+  const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
+  CHECK_EQ(in_shape->size(), 1U) << "Input:[queries_keys_values] currently has "
+                                 << in_shape->size() << " inputs";
+  auto qkv_shape = in_shape->at(0);
+  CHECK_EQ(qkv_shape.ndim(), 3U)
+    << "Input queries_keys_values should be 3D in seq_length-batch-proj_dim, "
+    << "currently is: " << qkv_shape.ndim() << "D";
+  out_shape->resize(1);
+  SHAPE_ASSIGN_CHECK(*out_shape, 0,
+    mxnet::TShape({params.heads * qkv_shape[1], qkv_shape[0], qkv_shape[0]}));
+  return true;
+}
+
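+// Shape inference for interleaved_matmul_selfatt_valatt:
+// input [0]: interleaved queries/keys/values, (seq_length, batch_size, heads * head_dim * 3)
+// input [1]: attention maps, (batch_size * heads, seq_length, seq_length)
+// output[0]: contextualized values, (seq_length, batch_size, heads * head_dim)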
+static bool InterleavedMatMulSelfAttValAttShape(const NodeAttrs& attrs,
+                                                mxnet::ShapeVector* in_shape,
+                                                mxnet::ShapeVector* out_shape) {
+  CHECK_EQ(in_shape->size(), 2U) << "Input:[queries_keys_values, attention] currently have "
+                                 << in_shape->size() << " inputs";
+  auto qkv_shape = in_shape->at(0);
+  auto att_shape = in_shape->at(1);
+  CHECK_EQ(qkv_shape.ndim(), 3U)
+    << "Input queries_keys_values should be 3D in seq_length-batch-3*proj_dim, 
"
+    << "currently is: " << qkv_shape.ndim() << "D";
+  CHECK_EQ(att_shape.ndim(), 3U)
+    << "Input attention should be 3D in batch-seq_length-seq_length, "
+    << "currently is: " << att_shape.ndim() << "D";
+  CHECK_EQ(qkv_shape[0], att_shape[1])
+    << "queries_keys_values.shape[0] and attention.shape[1] should be the 
same, "
+    << "currently are " << qkv_shape[0] << " and " << att_shape[1];
+  CHECK_EQ(qkv_shape[0], att_shape[2])
+    << "queries_keys_values.shape[0] and attention.shape[2] should be the 
same, "
+    << "currently are " << qkv_shape[0] << " and " << att_shape[2];
+  CHECK_EQ(qkv_shape[2] % 3, 0)
+    << "queries_keys_values.shape[2] should be a multiple of 3, "
+    << "currently is " << qkv_shape[2];
+  SHAPE_ASSIGN_CHECK(*out_shape, 0,
+    mxnet::TShape({qkv_shape[0], qkv_shape[1], qkv_shape[2] / 3}));
+  return true;
+}
+
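+// Shape inference for interleaved_matmul_encdec_qk:
+// input [0]: query projections, (q_seq_length, batch_size, heads * head_dim)
+// input [1]: interleaved keys/values, (kv_seq_length, batch_size, heads * head_dim * 2)
+// output[0]: attention scores, (batch_size * heads, q_seq_length, kv_seq_length)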
+static bool InterleavedMatMulEncDecQKShape(const NodeAttrs& attrs,
+                                           mxnet::ShapeVector* in_shape,
+                                           mxnet::ShapeVector* out_shape) {
+  const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
+  CHECK_EQ(in_shape->size(), 2U) << "Input:[queries, keys_values], currently 
have "
+                                 << in_shape->size() << " inputs";
+  auto q_shape = in_shape->at(0);
+  auto kv_shape = in_shape->at(1);
+  CHECK_EQ(q_shape.ndim(), 3U) << "Input queries should be 3D in seq_length-batch-proj_dim, "
+                               << "currently is " << q_shape.ndim() << "D";
+  CHECK_EQ(kv_shape.ndim(), 3U) << "Input keys_values should be 3D in seq_length-batch-2*proj_dim, "
+                                << "currently is " << kv_shape.ndim() << "D";
+  CHECK_EQ(q_shape[2] * 2, kv_shape[2])
+    << "keys_values.shape[2] should be equal to queries.shape[2] * 2, "
+    << "currently are: " << kv_shape[2] << " and " << q_shape[2];
+  CHECK_EQ(q_shape[1], kv_shape[1])
+    << "queries.shape[1] should be equal to keys_values.shape[1], "
+    << "currently are: " << q_shape[1] << " and " << kv_shape[1];
+  SHAPE_ASSIGN_CHECK(*out_shape, 0,
+      mxnet::TShape({q_shape[1] * params.heads, q_shape[0], kv_shape[0]}));
+  return true;
+}
+
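+// Shape inference for interleaved_matmul_encdec_valatt:
+// input [0]: interleaved keys/values, (kv_seq_length, batch_size, heads * head_dim * 2)
+// input [1]: attention maps, (batch_size * heads, q_seq_length, kv_seq_length)
+// output[0]: contextualized values, (q_seq_length, batch_size, heads * head_dim)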
+static bool InterleavedMatMulEncDecValAttShape(const NodeAttrs& attrs,
+                                               mxnet::ShapeVector* in_shape,
+                                               mxnet::ShapeVector* out_shape) {
+  const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
+  CHECK_EQ(in_shape->size(), 2U) << "Input: [keys_values, attention], 
currently have "
+                                 << in_shape->size() << " inputs";
+  auto kv_shape = in_shape->at(0);
+  auto att_shape = in_shape->at(1);
+  CHECK_EQ(kv_shape.ndim(), 3U)
+    << "Input keys_values should be 3D in seq_length-batch-2*proj_dim, "
+    << "currently is " << kv_shape.ndim() << "D";
+  CHECK_EQ(att_shape.ndim(), 3U)
+    << "Input attention should be 3D in batch-seq_length-seq_length, "
+    << "currently is " << att_shape.ndim() << "D";
+  CHECK_EQ(kv_shape[0], att_shape[2])
+    << "keys_values.shape[0] should be equal to attention.shape[2], currently 
are "
+    << kv_shape[0] << " and " << att_shape[2];
+  CHECK_EQ(kv_shape[1] * params.heads, att_shape[0]) << "attention.shape[0] "
+    << "should be equal to keys_values.shape[1] * heads, currently are: "
+    << att_shape[0] << " and " << kv_shape[1] * params.heads;
+  SHAPE_ASSIGN_CHECK(*out_shape, 0,
+      mxnet::TShape({att_shape[1], kv_shape[1], kv_shape[2] / 2}));
+  return true;
+}
+
+NNVM_REGISTER_OP(interleaved_matmul_selfatt_qk)
+.describe(R"code(Compute the matrix multiplication between the projections of
+queries and keys in multihead attention, used as self-attention.
+
+The input must be a single tensor of interleaved projections
+of queries, keys and values following the layout:
+(seq_length, batch_size, num_heads * head_dim * 3)
+)code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<InterleavedMatMulParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"queries_keys_values"};
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"output"};
+})
+.set_attr<mxnet::FInferShape>("FInferShape", InterleavedMatMulSelfAttQKShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<nnvm::FGradient>("FGradient",
+  ElemwiseGradUseIn{"_backward_interleaved_matmul_selfatt_qk"})
+.add_argument("queries_keys_values", "NDArray-or-Symbol", "Interleaved 
queries, keys and values")
+.add_arguments(InterleavedMatMulParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_interleaved_matmul_selfatt_qk)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr_parser(ParamParser<InterleavedMatMulParam>);
+
+NNVM_REGISTER_OP(interleaved_matmul_selfatt_valatt)
+.describe(R"code(Compute the matrix multiplication between the projections of
+values and the attention weights in multihead attention, used as self-attention.
+
+The inputs must be a tensor of interleaved projections
+of queries, keys and values following the layout:
+(seq_length, batch_size, num_heads * head_dim * 3)
+
+and the attention weights following the layout:
+(batch_size, seq_length, seq_length)
+)code" ADD_FILELINE)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<InterleavedMatMulParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"queries_keys_values", "attention"};
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"output"};
+})
+.set_attr<mxnet::FInferShape>("FInferShape", InterleavedMatMulSelfAttValAttShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
+.set_attr<nnvm::FGradient>("FGradient",
+  ElemwiseGradUseIn{"_backward_interleaved_matmul_selfatt_valatt"})
+.add_argument("queries_keys_values", "NDArray-or-Symbol", "Queries, keys and 
values interleaved")
+.add_argument("attention", "NDArray-or-Symbol", "Attention maps")
+.add_arguments(InterleavedMatMulParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_interleaved_matmul_selfatt_valatt)
+.set_num_inputs(3)
+.set_num_outputs(2)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr_parser(ParamParser<InterleavedMatMulParam>);
+
+NNVM_REGISTER_OP(interleaved_matmul_encdec_qk)
+.describe(R"code(Compute the matrix multiplication between the projections of
+queries and keys in multihead attention, used as encoder-decoder attention.
+
+The inputs must be a tensor of projections of queries following the layout:
+(seq_length, batch_size, num_heads * head_dim)
+
+and a tensor of interleaved projections of keys and values following the layout:
+(seq_length, batch_size, num_heads * head_dim * 2)
+)code" ADD_FILELINE)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<InterleavedMatMulParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"queries", "keys_values"};
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"output"};
+})
+.set_attr<mxnet::FInferShape>("FInferShape", InterleavedMatMulEncDecQKShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
+.set_attr<nnvm::FGradient>("FGradient",
+    ElemwiseGradUseIn{"_backward_interleaved_matmul_encdec_qk"})
+.add_argument("queries", "NDArray-or-Symbol", "Queries")
+.add_argument("keys_values", "NDArray-or-Symbol", "Keys and values 
interleaved")
+.add_arguments(InterleavedMatMulParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_interleaved_matmul_encdec_qk)
+.set_num_inputs(3)
+.set_num_outputs(2)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr_parser(ParamParser<InterleavedMatMulParam>);
+
+NNVM_REGISTER_OP(interleaved_matmul_encdec_valatt)
+.describe(R"code(Compute the matrix multiplication between the projections of
+values and the attention weights in multihead attention, used as encoder-decoder attention.
+
+The inputs must be a tensor of interleaved projections of
+keys and values following the layout:
+(seq_length, batch_size, num_heads * head_dim * 2)
+
+and the attention weights following the layout:
+(batch_size, seq_length, seq_length)
 
 Review comment:
   Can we add a math formula, or at least the unfused version, so that others can understand what it computes?
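
   For example, a rough sketch of the unfused version of `interleaved_matmul_selfatt_qk`
   in the Python API (untested; the helper name is illustrative, and it assumes the
   fused op scales the query projections by 1/sqrt(head_dim) before the batched
   matmul, i.e. output = Q K^T / sqrt(head_dim) per head):

   ```python
   import math
   import mxnet as mx

   def unfused_selfatt_qk(queries_keys_values, heads):
       # queries_keys_values: (seq_length, batch_size, heads * head_dim * 3)
       seq_len, batch_size, proj3 = queries_keys_values.shape
       head_dim = proj3 // (heads * 3)
       # De-interleave the projections into (seq, batch, heads, 3, head_dim)
       tmp = queries_keys_values.reshape((seq_len, batch_size, heads, 3, head_dim))
       q_proj, k_proj = tmp[:, :, :, 0, :], tmp[:, :, :, 1, :]
       # (seq, batch, heads, head_dim) -> (batch * heads, seq, head_dim)
       q_proj = mx.nd.transpose(q_proj, axes=(1, 2, 0, 3)).reshape((-1, seq_len, head_dim))
       k_proj = mx.nd.transpose(k_proj, axes=(1, 2, 0, 3)).reshape((-1, seq_len, head_dim))
       # Scaled attention scores: (batch * heads, seq_length, seq_length)
       return mx.nd.batch_dot(q_proj / math.sqrt(head_dim), k_proj, transpose_b=True)
   ```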

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
