bartekkuncer commented on code in PR #21115:
URL: https://github.com/apache/incubator-mxnet/pull/21115#discussion_r946823797
##########
src/operator/subgraph/dnnl/dnnl_transformer.cc:
##########
@@ -29,54 +29,78 @@
#include "operator/subgraph/common.h"
#include "dnnl_transformer-inl.h"
-// 3 tensors within one (queries key values) =
+// 3 tensors within one (queries key values)
#define QKV_NUM 3
namespace mxnet {
namespace op {
DMLC_REGISTER_PARAMETER(DNNLSelfAttParam);
+template <bool with_split>
static bool SgDNNLSelfAttShape(const NodeAttrs& attrs,
mxnet::ShapeVector* in_shape,
mxnet::ShapeVector* out_shape) {
- const auto& params = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
- auto qkv_shape = in_shape->at(0);
- CHECK_EQ(qkv_shape.ndim(), 3U)
+ const auto& params = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
+ unsigned int in_shape_num = 1;
+ auto in_shape_0 = in_shape->at(0);
+ auto in_shape_1 = in_shape_0; // in with_split mode there is only
one input
+ CHECK_EQ(in_shape_0.ndim(), 3U)
<< "Input queries_keys_values should be 3D in batch-seq_length-proj_dim,
"
- << "but the given tensor is " << qkv_shape.ndim() << "D";
+ << "but the given tensor is " << in_shape_0.ndim() << "D";
- if (params.quantized) {
- CHECK_EQ(in_shape->size(), 3U) << "Input: [queries_keys_values, min_qkv,
max_qkv] "
- << "- currently have " << in_shape->size()
<< " inputs";
+ if constexpr (!with_split) {
+ in_shape_1 = in_shape->at(1); // in without_split mode we need to
consider 2nd input
Review Comment:
```suggestion
in_shape_1 = in_shape->at(1); // without split we need to consider 2nd
input
```
##########
src/operator/subgraph/dnnl/dnnl_transformer.cc:
##########
@@ -29,54 +29,78 @@
#include "operator/subgraph/common.h"
#include "dnnl_transformer-inl.h"
-// 3 tensors within one (queries key values) =
+// 3 tensors within one (queries key values)
#define QKV_NUM 3
namespace mxnet {
namespace op {
DMLC_REGISTER_PARAMETER(DNNLSelfAttParam);
+template <bool with_split>
static bool SgDNNLSelfAttShape(const NodeAttrs& attrs,
mxnet::ShapeVector* in_shape,
mxnet::ShapeVector* out_shape) {
- const auto& params = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
- auto qkv_shape = in_shape->at(0);
- CHECK_EQ(qkv_shape.ndim(), 3U)
+ const auto& params = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
+ unsigned int in_shape_num = 1;
+ auto in_shape_0 = in_shape->at(0);
+ auto in_shape_1 = in_shape_0; // in with_split mode there is only
one input
Review Comment:
Maybe simply "with split"/"without split", without the word "mode" and the underscore?
```suggestion
auto in_shape_1 = in_shape_0; // with split there is only one
input
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]