agrabows commented on code in PR #21115:
URL: https://github.com/apache/incubator-mxnet/pull/21115#discussion_r942622573
##########
src/operator/subgraph/dnnl/dnnl_transformer.cc:
##########
@@ -37,46 +38,62 @@ namespace op {
 DMLC_REGISTER_PARAMETER(DNNLSelfAttParam);
+template <qk_common::mode qk_mode>
 static bool SgDNNLSelfAttShape(const NodeAttrs& attrs,
                                mxnet::ShapeVector* in_shape,
                                mxnet::ShapeVector* out_shape) {
   const auto& params = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
-  auto qkv_shape = in_shape->at(0);
-  CHECK_EQ(qkv_shape.ndim(), 3U)
+  uint in_shape_num = 1;
+  auto input_shape = in_shape->at(0);
+  CHECK_EQ(input_shape.ndim(), 3U)
       << "Input queries_keys_values should be 3D in batch-seq_length-proj_dim, "
-      << "but the given tensor is " << qkv_shape.ndim() << "D";
+      << "but the given tensor is " << input_shape.ndim() << "D";
-  if (params.quantized) {
-    CHECK_EQ(in_shape->size(), 3U) << "Input: [queries_keys_values, min_qkv, max_qkv] "
-                                   << "- currently have " << in_shape->size() << " inputs";
+  if constexpr (qk_mode == qk_common::mode::without_split) {
+    CHECK_EQ(in_shape->at(0), in_shape->at(1));
+    in_shape_num = 2;
+  }
-    SHAPE_ASSIGN_CHECK(*in_shape, 1, mxnet::TShape({1}));
+  if (params.quantized) {
+    CHECK_EQ(in_shape->size(), 3 * in_shape_num)
+        << "Input: [queries_keys_values, min_qkv, max_qkv] "
+        << "- currently have " << in_shape->size() << " inputs";
     SHAPE_ASSIGN_CHECK(*in_shape, 2, mxnet::TShape({1}));
+    if constexpr (qk_mode == qk_common::mode::without_split) {
+      SHAPE_ASSIGN_CHECK(*in_shape, 3, mxnet::TShape({1}));
+      SHAPE_ASSIGN_CHECK(*in_shape, 4, mxnet::TShape({1}));
+      SHAPE_ASSIGN_CHECK(*in_shape, 5, mxnet::TShape({1}));
+    } else {
+      SHAPE_ASSIGN_CHECK(*in_shape, 1, mxnet::TShape({1}));
+    }
     out_shape->resize(3);
-    SHAPE_ASSIGN_CHECK(
-        *out_shape, 0, mxnet::TShape({qkv_shape[0], params.heads, qkv_shape[1], qkv_shape[1]}));
     if (!params.enabled_float_output.has_value()) {
       SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape({1}));  // min output
       SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape({1}));  // max output
     }
   } else {
-    CHECK_EQ(in_shape->size(), 1U)
+    CHECK_EQ(in_shape->size(), in_shape_num)
         << "Input:[queries_keys_values] - currently have " << in_shape->size() << " inputs";
     out_shape->resize(1);
-    SHAPE_ASSIGN_CHECK(
-        *out_shape, 0, mxnet::TShape({qkv_shape[0], params.heads, qkv_shape[1], qkv_shape[1]}));
   }
+  SHAPE_ASSIGN_CHECK(
+      *out_shape, 0, mxnet::TShape({input_shape[0], params.heads, input_shape[1], input_shape[1]}));
   return true;
 }
+template <qk_common::mode qk_mode>
 static bool SgDNNLSelfAttQKInferType(const nnvm::NodeAttrs& attrs,
                                      std::vector<int>* in_types,
                                      std::vector<int>* out_types) {
   const auto& params = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
+  uint in_shape_num = 1;
+  if constexpr (qk_mode == qk_common::mode::without_split) {
+    in_shape_num = 2;
+  }
Review Comment:
This if block now contains other instructions, so I think it is outdated now.
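
Side note for context (not part of the reviewed patch): a minimal, self-contained sketch of the input-count bookkeeping that in_shape_num implements in the diff above. The enumerator name with_split and the ExpectedInputs helper are assumptions made purely for illustration; only without_split and the 3x factor for quantized min/max scalars appear in the patch.

#include <cstddef>
#include <iostream>

// Hypothetical stand-in for qk_common::mode from the patch.
namespace qk_common {
enum class mode { with_split, without_split };
}  // namespace qk_common

// without_split takes separate query and key tensors (2 data inputs);
// with_split takes one fused queries_keys_values tensor (1 data input).
// Quantized mode adds a min and a max scalar per data input, hence 3x.
template <qk_common::mode qk_mode>
constexpr std::size_t ExpectedInputs(bool quantized) {
  std::size_t in_shape_num = (qk_mode == qk_common::mode::without_split) ? 2 : 1;
  return quantized ? 3 * in_shape_num : in_shape_num;
}

int main() {
  std::cout << ExpectedInputs<qk_common::mode::with_split>(true) << '\n';      // 3
  std::cout << ExpectedInputs<qk_common::mode::without_split>(true) << '\n';   // 6
  std::cout << ExpectedInputs<qk_common::mode::without_split>(false) << '\n';  // 2
  return 0;
}
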
##########
src/operator/subgraph/dnnl/dnnl_transformer.cc:
##########
@@ -86,8 +103,20 @@ static bool SgDNNLSelfAttQKInferType(const nnvm::NodeAttrs& attrs,
         << "QuantizedSelfAttentionQK only supports int8 input, while " << in_types->at(0)
         << " is given.";
-    TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kFloat32);
     TYPE_ASSIGN_CHECK(*in_types, 2, mshadow::kFloat32);
Review Comment:
moved
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]