ptrendx commented on a change in pull request #16408 ("Add MXNet Ops for fast multihead attention"). URL: https://github.com/apache/incubator-mxnet/pull/16408#discussion_r335094420
########## File path: src/operator/contrib/transformer.cc ########## @@ -29,6 +29,163 @@ namespace mxnet { namespace op { +DMLC_REGISTER_PARAMETER(InterleavedMatMulParam); + +static bool InterleavedMatMulSelfAttQKShape(const NodeAttrs& attrs, + mxnet::ShapeVector* in_shape, + mxnet::ShapeVector* out_shape) { + const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed); + CHECK_EQ(in_shape->size(), 1); + auto qkv_shape = in_shape->at(0); + CHECK_EQ(qkv_shape.ndim(), 3); + out_shape->resize(1); + SHAPE_ASSIGN_CHECK(*out_shape, 0, + mxnet::TShape({params.heads * qkv_shape[1], qkv_shape[0], qkv_shape[0]})); + return true; +} + +static bool InterleavedMatMulSelfAttValAttShape(const NodeAttrs& attrs, + mxnet::ShapeVector* in_shape, + mxnet::ShapeVector* out_shape) { + CHECK_EQ(in_shape->size(), 2); + auto qkv_shape = in_shape->at(0); + auto att_shape = in_shape->at(1); + CHECK_EQ(qkv_shape.ndim(), 3); + CHECK_EQ(att_shape.ndim(), 3); + CHECK_EQ(qkv_shape[0], att_shape[1]); + CHECK_EQ(qkv_shape[0], att_shape[2]); + CHECK_EQ(qkv_shape[2] % 3, 0); Review comment: Could you add meaningful error messages for the cases where the shapes do not match your expectations? ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services