Caenorst commented on a change in pull request #16408: Add MXNet Ops for fast 
multihead attention
URL: https://github.com/apache/incubator-mxnet/pull/16408#discussion_r336220153
 
 

 ##########
 File path: src/operator/contrib/transformer.cc
 ##########
 @@ -29,6 +29,231 @@
 namespace mxnet {
 namespace op {
 
+DMLC_REGISTER_PARAMETER(InterleavedMatMulParam);  // param struct is declared outside this hunk
+
+// Shape inference for the self-attention queries x keys product.
+//
+// Expects exactly one input, queries_keys_values, which must be 3D (the error
+// text describes it as seq_length-batch-3*proj_dim). Produces one output of
+// shape (heads * qkv.shape[1], qkv.shape[0], qkv.shape[0]), where `heads` is
+// read from the parsed InterleavedMatMulParam. Violated preconditions fail a
+// CHECK; otherwise the function returns true.
+static bool InterleavedMatMulSelfAttQKShape(const NodeAttrs& attrs,
+                                            mxnet::ShapeVector* in_shape,
+                                            mxnet::ShapeVector* out_shape) {
+  const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
+  CHECK_EQ(in_shape->size(), 1U) << "Input:[queries_keys_values] currently have, "
+                                 << in_shape->size() << " inputs";
+  // Borrow the input shape instead of copying it; in_shape is not modified here.
+  const auto& qkv_shape = in_shape->at(0);
+  // The last axis interleaves queries, keys and values, hence 3*proj_dim.
+  CHECK_EQ(qkv_shape.ndim(), 3U)
+    << "Input queries_keys_values should be 3D in seq_length-batch-3*proj_dim, "
+    << "currently is: " << qkv_shape.ndim() << "D";
+  out_shape->resize(1);
+  SHAPE_ASSIGN_CHECK(*out_shape, 0,
+    mxnet::TShape({params.heads * qkv_shape[1], qkv_shape[0], qkv_shape[0]}));
+  return true;
+}
+
+// Infers the output shape of the self-attention "values x attention" product.
+//
+// in_shape must hold two 3D shapes: [queries_keys_values, attention]. The two
+// trailing attention dimensions must both equal queries_keys_values.shape[0],
+// and queries_keys_values.shape[2] must be divisible by 3. The output written
+// to slot 0 is (qkv.shape[0], qkv.shape[1], qkv.shape[2] / 3).
+static bool InterleavedMatMulSelfAttValAttShape(const NodeAttrs& attrs,
+                                                mxnet::ShapeVector* in_shape,
+                                                mxnet::ShapeVector* out_shape) {
+  CHECK_EQ(in_shape->size(), 2U)
+    << "Input:[queries_keys_values, attention] currently have, "
+    << in_shape->size() << " inputs";
+  const auto& qkv = in_shape->at(0);
+  const auto& att = in_shape->at(1);
+
+  // Rank checks come first so that the indexing below stays in range.
+  CHECK_EQ(qkv.ndim(), 3U)
+    << "Input queries_keys_values should be 3D in seq_length-batch-3*proj_dim, "
+    << "currently is: " << qkv.ndim() << "D";
+  CHECK_EQ(att.ndim(), 3U)
+    << "Input attention should be 3D in batch-seq_length-seq_length, "
+    << "currently is: " << att.ndim() << "D";
+
+  const auto seq_len = qkv[0];
+  const auto packed_dim = qkv[2];
+  CHECK_EQ(seq_len, att[1])
+    << "queries_keys_values.shape[0] and attention.shape[1] should be the same, "
+    << "currently are " << seq_len << " and " << att[1];
+  CHECK_EQ(seq_len, att[2])
+    << "queries_keys_values.shape[0] and attention.shape[2] should be the same, "
+    << "currently are " << seq_len << " and " << att[2];
+  CHECK_EQ(packed_dim % 3, 0)
+    << "queries_keys_values.shape[2] should be a multiple of 3, "
+    << "currently is " << packed_dim;
+
+  SHAPE_ASSIGN_CHECK(*out_shape, 0,
+    mxnet::TShape({seq_len, qkv[1], packed_dim / 3}));
+  return true;
+}
+
+// Shape inference for the encoder-decoder queries x keys product.
+//
+// Expects two 3D inputs, [queries, keys_values]. keys_values.shape[2] must be
+// twice queries.shape[2], and both inputs must agree on shape[1]. Produces
+// one output of shape (queries.shape[1] * heads, queries.shape[0],
+// keys_values.shape[0]), where `heads` is read from the parsed
+// InterleavedMatMulParam. Violated preconditions fail a CHECK; otherwise the
+// function returns true.
+static bool InterleavedMatMulEncDecQKShape(const NodeAttrs& attrs,
+                                           mxnet::ShapeVector* in_shape,
+                                           mxnet::ShapeVector* out_shape) {
+  const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
+  CHECK_EQ(in_shape->size(), 2U) << "Input:[queries, keys_values], currently have "
+                                 << in_shape->size() << " inputs";
+  // Borrow the input shapes instead of copying them; in_shape is not modified here.
+  const auto& q_shape = in_shape->at(0);
+  const auto& kv_shape = in_shape->at(1);
+  CHECK_EQ(q_shape.ndim(), 3U)
+    << "Input queries should be 3D in seq_length-batch-proj_dim, "
+    << "currently is " << q_shape.ndim() << "D";
+  // This check is about keys_values, so the message must name that input.
+  CHECK_EQ(kv_shape.ndim(), 3U)
+    << "Input keys_values should be 3D in seq_length-batch-2*proj_dim, "
+    << "currently is " << kv_shape.ndim() << "D";
+  CHECK_EQ(q_shape[2] * 2, kv_shape[2])
+    << "keys_values.shape[2] should be equal to queries.shape[2] * 2, "
+    << "currently are: " << kv_shape[2] << " and " << q_shape[2];
+  CHECK_EQ(q_shape[1], kv_shape[1])
+    << "queries.shape[1] should be equal to keys_values.shape[1], "
+    << "currently are: " << q_shape[1] << " and " << kv_shape[1];
+  SHAPE_ASSIGN_CHECK(*out_shape, 0,
+      mxnet::TShape({q_shape[1] * params.heads, q_shape[0], kv_shape[0]}));
+  return true;
+}
+
+// Shape inference for the encoder-decoder "values x attention" product.
+//
+// Expects two 3D inputs, [keys_values, attention]. attention.shape[2] must
+// equal keys_values.shape[0], attention.shape[0] must equal
+// keys_values.shape[1] * heads (`heads` is read from the parsed
+// InterleavedMatMulParam), and keys_values.shape[2] must be even. Produces
+// one output of shape (attention.shape[1], keys_values.shape[1],
+// keys_values.shape[2] / 2). Violated preconditions fail a CHECK; otherwise
+// the function returns true.
+static bool InterleavedMatMulEncDecValAttShape(const NodeAttrs& attrs,
+                                               mxnet::ShapeVector* in_shape,
+                                               mxnet::ShapeVector* out_shape) {
+  const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
+  CHECK_EQ(in_shape->size(), 2U) << "Input: [keys_values, attention], currently have "
+                                 << in_shape->size() << " inputs";
+  // Borrow the input shapes instead of copying them; in_shape is not modified here.
+  const auto& kv_shape = in_shape->at(0);
+  const auto& att_shape = in_shape->at(1);
+  CHECK_EQ(kv_shape.ndim(), 3U)
+    << "Input keys_values should be 3D in seq_length-batch-2*proj_dim, "
+    << "currently is " << kv_shape.ndim() << "D";
+  CHECK_EQ(att_shape.ndim(), 3U)
+    << "Input attention should be 3D in batch-seq_length-seq_length, "
+    << "currently is " << att_shape.ndim() << "D";
+  CHECK_EQ(kv_shape[0], att_shape[2])
+    << "keys_values.shape[0] should be equal to attention.shape[2], currently are "
+    << kv_shape[0] << " and " << att_shape[2];
+  // Report the two values that are actually compared: attention.shape[0] and
+  // keys_values.shape[1] * heads.
+  CHECK_EQ(kv_shape[1] * params.heads, att_shape[0]) << "attention.shape[0] "
+    << "should be equal to keys_values.shape[1] * heads, currently are: "
+    << att_shape[0] << " and " << kv_shape[1] * params.heads;
+  // The output halves the last axis, so reject odd sizes instead of silently
+  // truncating (mirrors the "multiple of 3" check in the self-attention variant).
+  CHECK_EQ(kv_shape[2] % 2, 0)
+    << "keys_values.shape[2] should be a multiple of 2, "
+    << "currently is " << kv_shape[2];
+  SHAPE_ASSIGN_CHECK(*out_shape, 0,
+      mxnet::TShape({att_shape[1], kv_shape[1], kv_shape[2] / 2}));
+  return true;
+}
+
+NNVM_REGISTER_OP(interleaved_matmul_selfatt_qk)
+.describe(R"code(Compute the matrix multiplication between the projections of
+queries and keys in multihead attention use as self attention.
+
+the input must be a single tensor of interleaved projections
+of queries, keys and values following the layout:
+(seq_length, batch_size, num_heads * head_dim * 3)
 
 Review comment:
   Adding the concatenation reduces the speedup due to multihead attention by 
about 20%. I think we can look into an improvement later, but in the meantime 
it is still a speedup. I would encourage analyzing how the LAMB coefficients 
differ within multihead attention blocks; maybe applying it directly to the 
concatenation of weights would be fine :man_shrugging:.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to