This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 1a418e4e1c [FEATURE] Add query_keys transformer version without split (#21115)
1a418e4e1c is described below
commit 1a418e4e1cf89a87838eb02abf74ccfa4bfe6c37
Author: AdamGrabowski <[email protected]>
AuthorDate: Tue Aug 23 15:33:02 2022 +0200
[FEATURE] Add query_keys transformer version without split (#21115)
* Add query_keys transformer version without split
* Add op to amp, implement different shaped inputs cases
* Fix sanity
* Remove semicolon
* Fix windows build
* Fix windows build sanity
* Change to bool template, fix comments, shorten test case
* Change comment message
* Fix SelectInput(), input names and Forward()
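In short: the existing interleaved-input QK fusion is registered as _sg_onednn_selfatt_qk_split, while _sg_onednn_selfatt_qk now covers the case where queries and keys arrive as two separate tensors (no _split_v2 node in the graph). A minimal sketch of the pattern the new pass is meant to match, based on the MultiHeadAttention test updated below; shapes and head count are illustrative only, not part of the patch:

    import mxnet as mx

    num_heads = 2
    # separate query/key tensors; the key sequence length may differ from the query's
    query = mx.np.random.normal(0, 1, (1, 4, num_heads * 8))  # (batch, q_seq_len, proj_dim)
    key   = mx.np.random.normal(0, 1, (1, 8, num_heads * 8))  # (batch, k_seq_len, proj_dim)

    # reshape into heads and swap axes before batch_dot; roughly the subgraph the
    # updated SgDNNLTransformerQKProperty is expected to fuse into _sg_onednn_selfatt_qk
    query = mx.npx.reshape(query, (-2, -2, num_heads, -1))
    key   = mx.npx.reshape(key, (-2, -2, num_heads, -1))
    scores = mx.npx.batch_dot(mx.np.swapaxes(query, 1, 2),
                              mx.np.swapaxes(key, 1, 2),
                              transpose_b=True)  # (batch, heads, q_seq_len, k_seq_len)

When quantized, the no-split variant takes six inputs (queries, keys, min_q, max_q, min_k, max_k), as listed in FListInputNames in the operator registration below.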
---
python/mxnet/amp/lists/symbol_bf16.py | 1 +
python/mxnet/amp/lists/symbol_fp16.py | 1 +
.../subgraph/dnnl/dnnl_post_amp_property.h | 1 +
.../subgraph/dnnl/dnnl_post_quantize_property.h | 1 +
.../subgraph/dnnl/dnnl_subgraph_property.cc | 2 +
src/operator/subgraph/dnnl/dnnl_transformer.cc | 295 ++++++++++++++-------
.../subgraph/dnnl/dnnl_transformer_qk_common.h | 230 ++++++++++++++++
.../subgraph/dnnl/dnnl_transformer_qk_property.h | 217 ++++++---------
.../dnnl/dnnl_transformer_valatt_property.h | 12 +-
tests/python/dnnl/op_cfg.py | 8 +-
.../python/dnnl/subgraphs/test_matmul_subgraph.py | 38 ++-
11 files changed, 555 insertions(+), 251 deletions(-)
diff --git a/python/mxnet/amp/lists/symbol_bf16.py b/python/mxnet/amp/lists/symbol_bf16.py
index 5b9df27497..3f9bfbc707 100644
--- a/python/mxnet/amp/lists/symbol_bf16.py
+++ b/python/mxnet/amp/lists/symbol_bf16.py
@@ -31,6 +31,7 @@ if Features.instance.is_enabled('ONEDNN'):
'_sg_onednn_conv',
'_sg_onednn_fully_connected',
'_sg_onednn_selfatt_qk',
+ '_sg_onednn_selfatt_qk_split',
'_sg_onednn_selfatt_valatt'
])
diff --git a/python/mxnet/amp/lists/symbol_fp16.py b/python/mxnet/amp/lists/symbol_fp16.py
index 76e9488f69..62e87dfc59 100644
--- a/python/mxnet/amp/lists/symbol_fp16.py
+++ b/python/mxnet/amp/lists/symbol_fp16.py
@@ -634,6 +634,7 @@ if Features().is_enabled('ONEDNN'):
'_sg_onednn_conv',
'_sg_onednn_fully_connected',
'_sg_onednn_selfatt_qk',
+ '_sg_onednn_selfatt_qk_split',
'_sg_onednn_selfatt_valatt',
'_sg_onednn_batch_dot',
'_sg_onednn_batch_norm',
diff --git a/src/operator/subgraph/dnnl/dnnl_post_amp_property.h b/src/operator/subgraph/dnnl/dnnl_post_amp_property.h
index 6ec7c54e38..8e14c9b361 100644
--- a/src/operator/subgraph/dnnl/dnnl_post_amp_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_post_amp_property.h
@@ -35,6 +35,7 @@ inline bool IsSupportedAMPFuseOp(const nnvm::Node& node) {
static const std::set<const Op*> supported_ops = {Op::Get("_sg_onednn_conv"),
Op::Get("_sg_onednn_fully_connected"),
Op::Get("_sg_onednn_selfatt_qk"),
+ Op::Get("_sg_onednn_selfatt_qk_split"),
Op::Get("_sg_onednn_selfatt_valatt")};
return supported_ops.count(node.op()) > 0;
}
diff --git a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
index 94c7e63085..6e905f59da 100644
--- a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
@@ -45,6 +45,7 @@ bool SupportsRequantizeFusion(const Op* op) {
Op::Get("_sg_onednn_conv"),
Op::Get("_sg_onednn_fully_connected"),
Op::Get("_sg_onednn_selfatt_qk"),
+ Op::Get("_sg_onednn_selfatt_qk_split"),
Op::Get("_sg_onednn_selfatt_valatt"),
Op::Get("_sg_onednn_batch_dot")};
diff --git a/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc b/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
index 86e08020ee..bc190a73ae 100644
--- a/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
+++ b/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
@@ -45,6 +45,7 @@ MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLConvProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLFCProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLBNReLUProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLRemoveCastsProperty);
+MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLTransformerQKSplitProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLTransformerQKProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLTransformerValAttProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLBatchDotProperty);
@@ -56,6 +57,7 @@ MXNET_REGISTER_SUBGRAPH_BACKEND(ONEDNN_QUANTIZE).set_attr("context", Context::CP
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLIdentityProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLConvProperty).set_attr("quantize", true);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLFCProperty).set_attr("quantize", true);
+MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLTransformerQKSplitProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLTransformerQKProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLTransformerValAttProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLBatchDotProperty)
diff --git a/src/operator/subgraph/dnnl/dnnl_transformer.cc b/src/operator/subgraph/dnnl/dnnl_transformer.cc
index 67fee5f98f..dfafbb038d 100644
--- a/src/operator/subgraph/dnnl/dnnl_transformer.cc
+++ b/src/operator/subgraph/dnnl/dnnl_transformer.cc
@@ -29,7 +29,7 @@
#include "operator/subgraph/common.h"
#include "dnnl_transformer-inl.h"
-// 3 tensors within one (queries key values) =
+// 3 tensors within one (queries keys values)
#define QKV_NUM 3
namespace mxnet {
@@ -37,46 +37,70 @@ namespace op {
DMLC_REGISTER_PARAMETER(DNNLSelfAttParam);
+template <bool with_split>
static bool SgDNNLSelfAttShape(const NodeAttrs& attrs,
mxnet::ShapeVector* in_shape,
mxnet::ShapeVector* out_shape) {
- const auto& params = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
- auto qkv_shape = in_shape->at(0);
- CHECK_EQ(qkv_shape.ndim(), 3U)
+ const auto& params = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
+ unsigned int in_shape_num = 1;
+ auto in_shape_0 = in_shape->at(0);
+ auto in_shape_1 = in_shape_0; // with split there is only one input
+ CHECK_EQ(in_shape_0.ndim(), 3U)
<< "Input queries_keys_values should be 3D in batch-seq_length-proj_dim,
"
- << "but the given tensor is " << qkv_shape.ndim() << "D";
+ << "but the given tensor is " << in_shape_0.ndim() << "D";
- if (params.quantized) {
- CHECK_EQ(in_shape->size(), 3U) << "Input: [queries_keys_values, min_qkv, max_qkv] "
- << "- currently have " << in_shape->size() << " inputs";
+ if constexpr (!with_split) {
+ in_shape_1 = in_shape->at(1); // without split we need to consider 2nd input
+ CHECK_EQ(in_shape_1.ndim(), 3U)
+ << "Input queries_keys_values should be 3D in
batch-seq_length-proj_dim, "
+ << "but the given tensor is " << in_shape_1.ndim() << "D";
+ CHECK_EQ(in_shape_0[0], in_shape_1[0]);
+ CHECK_EQ(in_shape_0[2], in_shape_1[2]);
+ in_shape_num = 2;
+ }
- SHAPE_ASSIGN_CHECK(*in_shape, 1, mxnet::TShape({1}));
- SHAPE_ASSIGN_CHECK(*in_shape, 2, mxnet::TShape({1}));
+ if (params.quantized) {
+ CHECK_EQ(in_shape->size(), 3 * in_shape_num)
+ << "Input: [queries_keys_values, min_qkv, max_qkv] "
+ << "- currently have " << in_shape->size() << " inputs";
+ if constexpr (with_split) {
+ SHAPE_ASSIGN_CHECK(*in_shape, 1, mxnet::TShape({1}));
+ SHAPE_ASSIGN_CHECK(*in_shape, 2, mxnet::TShape({1}));
+ } else {
+ SHAPE_ASSIGN_CHECK(*in_shape, 2, mxnet::TShape({1}));
+ SHAPE_ASSIGN_CHECK(*in_shape, 3, mxnet::TShape({1}));
+ SHAPE_ASSIGN_CHECK(*in_shape, 4, mxnet::TShape({1}));
+ SHAPE_ASSIGN_CHECK(*in_shape, 5, mxnet::TShape({1}));
+ }
out_shape->resize(3);
- SHAPE_ASSIGN_CHECK(
- *out_shape, 0, mxnet::TShape({qkv_shape[0], params.heads, qkv_shape[1], qkv_shape[1]}));
if (!params.enabled_float_output.has_value()) {
SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape({1})); // min output
SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape({1})); // max output
}
} else {
- CHECK_EQ(in_shape->size(), 1U)
+ CHECK_EQ(in_shape->size(), in_shape_num)
<< "Input:[queries_keys_values] - currently have " << in_shape->size()
<< " inputs";
out_shape->resize(1);
- SHAPE_ASSIGN_CHECK(
- *out_shape, 0, mxnet::TShape({qkv_shape[0], params.heads, qkv_shape[1], qkv_shape[1]}));
}
+ SHAPE_ASSIGN_CHECK(
+ *out_shape, 0, mxnet::TShape({in_shape_0[0], params.heads, in_shape_0[1], in_shape_1[1]}));
return true;
}
+template <bool with_split>
static bool SgDNNLSelfAttQKInferType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_types,
std::vector<int>* out_types) {
- const auto& params = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
+ const auto& params = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
+ unsigned int in_shape_num = 1;
+ if constexpr (!with_split) {
+ CHECK_EQ(in_types->at(0), in_types->at(1));
+ in_shape_num = 2;
+ }
if (params.quantized) {
- CHECK_EQ(in_types->size(), 3U);
+ CHECK_EQ(in_types->size(), 3 * in_shape_num);
if (in_types->at(0) == mshadow::kBfloat16) {
return false;
@@ -86,8 +110,15 @@ static bool SgDNNLSelfAttQKInferType(const nnvm::NodeAttrs& attrs,
<< "QuantizedSelfAttentionQK only supports int8 input, while " << in_types->at(0)
<< " is given.";
- TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kFloat32);
- TYPE_ASSIGN_CHECK(*in_types, 2, mshadow::kFloat32);
+ if constexpr (with_split) {
+ TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kFloat32);
+ TYPE_ASSIGN_CHECK(*in_types, 2, mshadow::kFloat32);
+ } else {
+ TYPE_ASSIGN_CHECK(*in_types, 2, mshadow::kFloat32);
+ TYPE_ASSIGN_CHECK(*in_types, 3, mshadow::kFloat32);
+ TYPE_ASSIGN_CHECK(*in_types, 4, mshadow::kFloat32);
+ TYPE_ASSIGN_CHECK(*in_types, 5, mshadow::kFloat32);
+ }
if (params.enabled_float_output.has_value()) {
CHECK_EQ(out_types->size(), 1U);
@@ -103,12 +134,18 @@ static bool SgDNNLSelfAttQKInferType(const nnvm::NodeAttrs& attrs,
TYPE_ASSIGN_CHECK(*out_types, 2, mshadow::kFloat32);
}
} else {
- CHECK_EQ(in_types->size(), 1U);
+ CHECK_EQ(in_types->size(), in_shape_num);
CHECK_EQ(out_types->size(), 1U);
if (in_types->at(0) == mshadow::kFloat32) {
TYPE_ASSIGN_CHECK(*in_types, 0, mshadow::kFloat32);
+ if constexpr (!with_split) {
+ TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kFloat32);
+ }
TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);
} else if (in_types->at(0) == mshadow::kBfloat16) {
+ if constexpr (!with_split) {
+ TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kBfloat16);
+ }
if (params.enabled_float_output.has_value()) {
TYPE_ASSIGN_CHECK(*out_types, 0, params.enabled_float_output.value());
} else {
@@ -128,6 +165,7 @@ class SgDNNLSelfAttQKOp {
explicit SgDNNLSelfAttQKOp(const nnvm::NodeAttrs& attrs)
: param_(nnvm::get<DNNLSelfAttParam>(attrs.parsed)) {}
+ template <bool with_split>
void Forward(const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
@@ -142,6 +180,7 @@ class SgDNNLSelfAttQKOp {
"inference computation.";
}
+ template <bool with_split>
void Initialize(const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
@@ -159,11 +198,14 @@ class SgDNNLSelfAttQKOp {
std::shared_ptr<dnnl::memory> cached_query_mem_;
std::shared_ptr<dnnl::memory> cached_key_mem_;
std::shared_ptr<dnnl::memory> cached_out_mem_;
- float min_data_;
- float max_data_;
+ float min_data_0_;
+ float max_data_0_;
+ float min_data_1_;
+ float max_data_1_;
float min_output_;
float max_output_;
- float data_scale_{0.0f};
+ float data_scale_0_{0.0f};
+ float data_scale_1_{0.0f};
};
static OpStatePtr CreateSgDNNLSelfAttQKState(const nnvm::NodeAttrs& attrs,
@@ -173,18 +215,19 @@ static OpStatePtr CreateSgDNNLSelfAttQKState(const nnvm::NodeAttrs& attrs,
return OpStatePtr::Create<SgDNNLSelfAttQKOp>(attrs);
}
-static void SgDNNLSelfAttQKForward(const OpStatePtr& state_pointer,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs) {
+template <bool with_split>
+void SgDNNLSelfAttQKForward(const OpStatePtr& state_pointer,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
SgDNNLSelfAttQKOp& op = state_pointer.get_state<SgDNNLSelfAttQKOp>();
bool already_prepared = false;
if (!op.IsInitialized()) {
- op.Initialize(ctx, inputs, req, outputs);
+ op.Initialize<with_split>(ctx, inputs, req, outputs);
already_prepared = true;
}
- op.Forward(ctx, inputs, req, outputs, already_prepared);
+ op.Forward<with_split>(ctx, inputs, req, outputs, already_prepared);
}
static bool SgDNNLSelfAttStorageType(const nnvm::NodeAttrs& attrs,
@@ -195,52 +238,70 @@ static bool SgDNNLSelfAttStorageType(const nnvm::NodeAttrs& attrs,
return DNNLStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
}
+template <bool with_split>
void SgDNNLSelfAttQKOp::Initialize(const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
using namespace dnnl;
- const auto qkv_tensor = inputs[0];
- const auto out_tensor = outputs[0];
+ const auto in_tensor_0 = inputs[0];
+ auto in_tensor_1 = in_tensor_0; // with split there is only one input
+ const auto out_tensor = outputs[0];
- const auto qkv_dtype = get_dnnl_type(qkv_tensor.dtype());
+ const auto in_dtype = get_dnnl_type(in_tensor_0.dtype());
const memory::dim heads = param_.heads;
- const memory::dim sequences = inputs[0].shape()[0];
- const memory::dim qkv_seq_len = inputs[0].shape()[1];
- const memory::dim output_lin_dim = inputs[0].shape()[2];
- const memory::dim embed_dim = output_lin_dim / QKV_NUM;
+ const memory::dim sequences = in_tensor_0.shape()[0];
+ const memory::dim qkv_seq_len_0 = in_tensor_0.shape()[1];
+ const memory::dim output_lin_dim = in_tensor_0.shape()[2];
+ memory::dim embed_dim = output_lin_dim;
+ if constexpr (with_split) {
+ embed_dim /= QKV_NUM;
+ } else {
+ in_tensor_1 = inputs[1]; // without split we need to consider 2nd input
+ }
+ const memory::dim qkv_seq_len_1 = in_tensor_1.shape()[1];
const memory::dim head_dim = embed_dim / heads;
- const memory::dim batch_stride = output_lin_dim * qkv_seq_len;
+ const memory::dim batch_stride_0 = output_lin_dim * qkv_seq_len_0;
+ const memory::dim batch_stride_1 = output_lin_dim * qkv_seq_len_1;
float min_data = 0.0f;
float max_data = 0.0f;
const auto engine = CpuEngine::Get()->get_engine();
- memory::dims query_dims = {sequences, heads, qkv_seq_len, head_dim};
- memory::dims key_dims = {sequences, heads, head_dim, qkv_seq_len};
-
- memory::dims query_strides = {batch_stride, head_dim, output_lin_dim, 1};
- memory::dims key_strides = {batch_stride, head_dim, 1, output_lin_dim};
+ memory::dims query_dims = {sequences, heads, qkv_seq_len_0, head_dim};
+ memory::dims key_dims = {sequences, heads, head_dim, qkv_seq_len_1};
+ memory::dims query_strides = {batch_stride_0, head_dim, output_lin_dim, 1};
+ memory::dims key_strides = {batch_stride_1, head_dim, 1, output_lin_dim};
- auto query_md = memory::desc(query_dims, qkv_dtype, query_strides);
- auto key_md = memory::desc(key_dims, qkv_dtype, key_strides);
+ auto query_md = memory::desc(query_dims, in_dtype, query_strides);
+ auto key_md = memory::desc(key_dims, in_dtype, key_strides);
float oscale = 1.0f;
if (param_.quantized) {
- min_data_ = inputs[1].data().dptr<float>()[0];
- max_data_ = inputs[2].data().dptr<float>()[0];
- data_scale_ = GetQuantizeScale(qkv_tensor.dtype(), min_data_, max_data_);
+ if constexpr (with_split) {
+ min_data_0_ = inputs[1].data().dptr<float>()[0];
+ max_data_0_ = inputs[2].data().dptr<float>()[0];
+ data_scale_0_ = data_scale_1_ =
+ GetQuantizeScale(in_tensor_0.dtype(), min_data_0_, max_data_0_);
+ } else {
+ min_data_0_ = inputs[2].data().dptr<float>()[0];
+ max_data_0_ = inputs[3].data().dptr<float>()[0];
+ min_data_1_ = inputs[4].data().dptr<float>()[0];
+ max_data_1_ = inputs[5].data().dptr<float>()[0];
+ data_scale_0_ = GetQuantizeScale(in_tensor_0.dtype(), min_data_0_, max_data_0_);
+ data_scale_1_ = GetQuantizeScale(in_tensor_1.dtype(), min_data_1_, max_data_1_);
+ }
if (param_.min_calib_range.has_value() && param_.max_calib_range.has_value()) {
min_output_ = param_.min_calib_range.value();
max_output_ = param_.max_calib_range.value();
oscale = GetQuantizeScale(out_tensor.dtype(), min_output_, max_output_) /
- (data_scale_ * data_scale_);
+ (data_scale_0_ * data_scale_1_);
} else if (param_.enabled_float_output.has_value()) {
- oscale = 1.0f / (data_scale_ * data_scale_);
+ oscale = 1.0f / (data_scale_0_ * data_scale_1_);
} else {
mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
mxnet_op::Kernel<QuantizationRangeForS8S8MultiplicationStruct, cpu>::Launch(
@@ -256,9 +317,14 @@ void SgDNNLSelfAttQKOp::Initialize(const OpContext& ctx,
MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, {
DType* query_mem_ptr = inputs[0].data().dptr<DType>();
- DType* key_mem_ptr = query_mem_ptr + embed_dim;
- cached_query_mem_ = std::make_shared<memory>(query_md, engine, query_mem_ptr);
- cached_key_mem_ = std::make_shared<memory>(key_md, engine, key_mem_ptr);
+ DType* key_mem_ptr;
+ if constexpr (with_split) {
+ key_mem_ptr = query_mem_ptr + embed_dim;
+ } else {
+ key_mem_ptr = inputs[1].data().dptr<DType>();
+ }
+ cached_query_mem_ = std::make_shared<memory>(query_md, engine, query_mem_ptr);
+ cached_key_mem_ = std::make_shared<memory>(key_md, engine, key_mem_ptr);
});
MSHADOW_TYPE_SWITCH(out_tensor.dtype(), DType, {
@@ -272,6 +338,7 @@ void SgDNNLSelfAttQKOp::Initialize(const OpContext& ctx,
initialized_ = true;
}
+template <bool with_split>
void SgDNNLSelfAttQKOp::Forward(const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
@@ -283,7 +350,12 @@ void SgDNNLSelfAttQKOp::Forward(const OpContext& ctx,
MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, {
DType* query_mem_ptr = inputs[0].data().dptr<DType>();
- DType* key_mem_ptr = query_mem_ptr + embed_dim;
+ DType* key_mem_ptr;
+ if constexpr (with_split) {
+ key_mem_ptr = query_mem_ptr + embed_dim;
+ } else {
+ key_mem_ptr = inputs[1].data().dptr<DType>();
+ }
cached_query_mem_->set_data_handle(query_mem_ptr);
cached_key_mem_->set_data_handle(key_mem_ptr);
});
@@ -298,16 +370,20 @@ void SgDNNLSelfAttQKOp::Forward(const OpContext& ctx,
if (param_.quantized && !param_.enabled_float_output.has_value()) {
float* output_min = outputs[1].data().dptr<float>();
float* output_max = outputs[2].data().dptr<float>();
-
- *output_min = min_output_;
- *output_max = max_output_;
+ *output_min = min_output_;
+ *output_max = max_output_;
}
}
+template <bool with_split>
nnvm::ObjectPtr SgDNNLSelfAttQKQuantizedOp(const NodeAttrs& attrs) {
- nnvm::ObjectPtr node = nnvm::Node::Create();
- auto const& param = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
- node->attrs.op = Op::Get("_sg_onednn_selfatt_qk");
+ nnvm::ObjectPtr node = nnvm::Node::Create();
+ auto const& param = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
+ if constexpr (with_split) {
+ node->attrs.op = Op::Get("_sg_onednn_selfatt_qk_split");
+ } else {
+ node->attrs.op = Op::Get("_sg_onednn_selfatt_qk");
+ }
node->attrs.name = "quantized_" + attrs.name;
node->attrs.dict = attrs.dict;
node->attrs.dict["heads"] = std::to_string(param.heads);
@@ -318,26 +394,79 @@ nnvm::ObjectPtr SgDNNLSelfAttQKQuantizedOp(const NodeAttrs& attrs) {
return node;
}
-NNVM_REGISTER_OP(_sg_onednn_selfatt_qk)
- .add_alias("_sg_mkldnn_selfatt_qk")
+#define MXNET_OPERATOR_REGISTER_SELFATT_QK(name) \
+ NNVM_REGISTER_OP(name) \
+ .set_num_outputs([](const NodeAttrs& attrs) { \
+ auto const& param = nnvm::get<DNNLSelfAttParam>(attrs.parsed); \
+ if (param.quantized && !param.enabled_float_output.has_value()) { \
+ return 3; \
+ } else { \
+ return 1; \
+ } \
+ }) \
+ .set_attr<nnvm::FListOutputNames>( \
+ "FListOutputNames", \
+ [](const NodeAttrs& attrs) { \
+ auto const& param = nnvm::get<DNNLSelfAttParam>(attrs.parsed); \
+ std::vector<std::string> output_names{"output"}; \
+ if (param.quantized && !param.enabled_float_output.has_value()) { \
+ output_names.emplace_back("min_output"); \
+ output_names.emplace_back("max_output"); \
+ } \
+ return output_names; \
+ }) \
+ .set_attr_parser(ParamParser<DNNLSelfAttParam>) \
+ .set_attr<FInferStorageType>("FInferStorageType", SgDNNLSelfAttStorageType) \
+ .set_attr<FCreateOpState>("FCreateOpState", CreateSgDNNLSelfAttQKState) \
+ .set_attr<bool>("TIsDNNL", true) \
+ .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes) \
+ .set_attr<FQuantizable>("FQuantizable", \
+ [](const NodeAttrs& attrs) { return QuantizeType::kMust; }) \
+ .set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; }) \
+ .add_arguments(DNNLSelfAttParam::__FIELDS__())
+
+MXNET_OPERATOR_REGISTER_SELFATT_QK(_sg_onednn_selfatt_qk)
.describe(R"code(_sg_onednn_selfatt_qk)code" ADD_FILELINE)
.set_num_inputs([](const NodeAttrs& attrs) {
auto const& param = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
if (param.quantized) {
- return 3;
+ return 6;
} else {
- return 1;
+ return 2;
}
})
- .set_num_outputs([](const NodeAttrs& attrs) {
+ .set_attr<nnvm::FListInputNames>("FListInputNames",
+ [](const NodeAttrs& attrs) {
+ auto const& param =
+ nnvm::get<DNNLSelfAttParam>(attrs.parsed);
+ std::vector<std::string> input_names{"queries"};
+ input_names.emplace_back("keys");
+ if (param.quantized) {
+ input_names.emplace_back("min_q");
+ input_names.emplace_back("max_q");
+ input_names.emplace_back("min_k");
+ input_names.emplace_back("max_k");
+ }
+ return input_names;
+ })
+ .set_attr<mxnet::FInferShape>("FInferShape", SgDNNLSelfAttShape<false>)
+ .set_attr<nnvm::FInferType>("FInferType", SgDNNLSelfAttQKInferType<false>)
+ .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", SgDNNLSelfAttQKForward<false>)
+ .set_attr<FQuantizedOp>("FQuantizedOp", SgDNNLSelfAttQKQuantizedOp<false>)
+ .add_argument("queries", "NDArray-or-Symbol", "Interleaved queries, keys and values")
+ .add_argument("keys", "NDArray-or-Symbol", "Interleaved queries, keys and values");
+
+MXNET_OPERATOR_REGISTER_SELFATT_QK(_sg_onednn_selfatt_qk_split)
+ .add_alias("_sg_mkldnn_selfatt_qk")
+ .describe(R"code(_sg_onednn_selfatt_qk_split)code" ADD_FILELINE)
+ .set_num_inputs([](const NodeAttrs& attrs) {
auto const& param = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
- if (param.quantized && !param.enabled_float_output.has_value()) {
+ if (param.quantized) {
return 3;
} else {
return 1;
}
})
- .set_attr_parser(ParamParser<DNNLSelfAttParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
auto const& param =
@@ -349,33 +478,11 @@ NNVM_REGISTER_OP(_sg_onednn_selfatt_qk)
}
return input_names;
})
- .set_attr<nnvm::FListOutputNames>("FListOutputNames",
- [](const NodeAttrs& attrs) {
- auto const& param =
- nnvm::get<DNNLSelfAttParam>(attrs.parsed);
- std::vector<std::string> output_names{"output"};
- if (param.quantized &&
- !param.enabled_float_output.has_value()) {
- output_names.emplace_back("min_output");
- output_names.emplace_back("max_output");
- }
- return output_names;
- })
- .set_attr<mxnet::FInferShape>("FInferShape", SgDNNLSelfAttShape)
- .set_attr<nnvm::FInferType>("FInferType", SgDNNLSelfAttQKInferType)
- .set_attr<FInferStorageType>("FInferStorageType", SgDNNLSelfAttStorageType)
- .set_attr<FCreateOpState>("FCreateOpState", CreateSgDNNLSelfAttQKState)
- .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", SgDNNLSelfAttQKForward)
- .set_attr<bool>("TIsDNNL", true)
- .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
- .set_attr<FQuantizable>("FQuantizable",
- [](const NodeAttrs& attrs) { return QuantizeType::kMust; })
- .set_attr<FQuantizedOp>("FQuantizedOp", SgDNNLSelfAttQKQuantizedOp)
- .set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
- .add_argument("queries_keys_values",
- "NDArray-or-Symbol",
- "Interleaved queries, keys and values")
- .add_arguments(DNNLSelfAttParam::__FIELDS__());
+ .set_attr<mxnet::FInferShape>("FInferShape", SgDNNLSelfAttShape<true>)
+ .set_attr<nnvm::FInferType>("FInferType", SgDNNLSelfAttQKInferType<true>)
+ .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", SgDNNLSelfAttQKForward<true>)
+ .set_attr<FQuantizedOp>("FQuantizedOp", SgDNNLSelfAttQKQuantizedOp<true>)
+ .add_argument("query_keys_values", "NDArray-or-Symbol", "Interleaved queries, keys and values");
/**********************************_sg_onednn_selfatt_valatt**********************************/
diff --git a/src/operator/subgraph/dnnl/dnnl_transformer_qk_common.h b/src/operator/subgraph/dnnl/dnnl_transformer_qk_common.h
new file mode 100644
index 0000000000..902c58dfbc
--- /dev/null
+++ b/src/operator/subgraph/dnnl/dnnl_transformer_qk_common.h
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_TRANSFORMER_QK_COMMON_H_
+#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_TRANSFORMER_QK_COMMON_H_
+
+#if MXNET_USE_ONEDNN == 1
+
+#include <string>
+#include <vector>
+
+#include "operator/contrib/transformer-inl.h"
+#include "operator/numpy/np_matrix_op-inl.h"
+#include "operator/tensor/matrix_op-inl.h"
+#include "operator/subgraph/common.h"
+#include "dnnl_common.h"
+#include "dnnl_subgraph_base-inl.h"
+#include "dnnl_transformer-inl.h"
+
+namespace mxnet {
+namespace op {
+namespace qk_common {
+
+enum SelectStatusTransformerQK {
+ kFail = 0,
+ kStart,
+ kFirstSwapAx,
+ kSecondSwapAx,
+ kFirstReshape,
+ kSecondReshape,
+ kSuccess
+};
+
+// /*
+// kStart ---> kFirstSwapAx ---> kSecondSwapAx ---> kFirstReshape ---> kSecondReshape ---> kSuccess
+// OR
+// kStart ---> kFirstSwapAx ---> kSecondSwapAx ---> kFirstReshape ---> kSuccess
+// each status except kStart is connected with kFail
+// */
+
+inline bool CheckSwapAxisConditionsQK(const BiDirectedNode& input_node) {
+ if (input_node.outputs.size() != 1)
+ return false;
+ return CheckSwapAxisConditions(*input_node.node);
+}
+
+inline bool CheckReshapeConditionsQK(const BiDirectedNode& input_node, const index_t out_index) {
+ if (input_node.outputs.size() != 1)
+ return false;
+ return CheckReshapeConditions(*input_node.node, out_index);
+}
+
+inline bool CheckSplitConditions(const std::vector<const BiDirectedNode*>& matched_list,
+ const BiDirectedNode& node) {
+ const SplitParam& param = dmlc::get<SplitParam>(node.node->attrs.parsed);
+
+ if (param.axis != -1 || param.sections != 3 || param.squeeze_axis)
+ return false;
+
+ const auto first_reshape = (*(matched_list.end() - 2))->node;
+ const auto second_reshape = (*(matched_list.end() - 1))->node;
+ if (first_reshape->op() != Op::Get("_npx_reshape") ||
+ second_reshape->op() != Op::Get("_npx_reshape")) {
+ return false;
+ }
+ // 3 sections - ensure that every output is used only once
+ if (node.outputs.size() == 3 && node.outputs.count(first_reshape) &&
+ node.outputs.count(second_reshape)) {
+ return true;
+ }
+
+ return false;
+}
+
+inline bool Select(SelectStatusTransformerQK* status,
+ std::vector<const BiDirectedNode*>* matched_list,
+ const BiDirectedNode& seed_node,
+ const std::shared_ptr<NodeAttr>& node_attr) {
+ if (seed_node.node->op() == Op::Get("batch_dot")) {
+ *status = kStart;
+ matched_list->clear();
+ matched_list->push_back(&seed_node);
+ return true;
+ }
+ return false;
+}
+
+template <bool with_split>
+bool SelectInput(SelectStatusTransformerQK* status,
+ std::vector<const BiDirectedNode*>* matched_list,
+ const BiDirectedNode& n,
+ const BiDirectedNode& input_node) {
+ if (*status == kFail || *status == kSuccess || input_node.node->is_variable())
+ return false;
+ const auto& raw_input_node = *input_node.node;
+ switch (*status) {
+ case kStart:
+ if (raw_input_node.op() == Op::Get("SwapAxis")) {
+ if (CheckSwapAxisConditionsQK(input_node)) {
+ *status = kFirstSwapAx;
+ matched_list->push_back(&input_node);
+ return true;
+ }
+ }
+ break;
+ case kFirstSwapAx:
+ if (raw_input_node.op() == Op::Get("SwapAxis")) {
+ if (CheckSwapAxisConditionsQK(input_node)) {
+ *status = kSecondSwapAx;
+ matched_list->push_back(&input_node);
+ return true;
+ }
+ }
+ break;
+ case kSecondSwapAx:
+ if (raw_input_node.op() == Op::Get("_npx_reshape")) {
+ // input to reshape must be first or second output from split
+ if (CheckReshapeConditionsQK(input_node, 0) || CheckReshapeConditionsQK(input_node, 1)) {
+ *status = kFirstReshape;
+ matched_list->push_back(&input_node);
+ return true;
+ }
+ }
+ break;
+ case kFirstReshape:
+ if (raw_input_node.op() == Op::Get("_npx_reshape")) {
+ if (CheckReshapeConditionsQK(input_node, 0) || CheckReshapeConditionsQK(input_node, 1)) {
+ if constexpr (with_split) {
+ *status = kSecondReshape;
+ } else {
+ *status = kSuccess;
+ }
+ matched_list->push_back(&input_node);
+ return true;
+ }
+ }
+ break;
+ case kSecondReshape:
+ if (raw_input_node.op() == Op::Get("_split_v2") &&
+ CheckSplitConditions(*matched_list, input_node)) {
+ *status = kSuccess;
+ return true;
+ }
+ break;
+ default:
+ *status = kFail;
+ return false;
+ }
+ return false;
+}
+
+inline std::vector<BiDirectedNode*> Filter(const SelectStatusTransformerQK& status,
+ const std::vector<const BiDirectedNode*>& matched_list,
+ const std::vector<BiDirectedNode*>& candidates) {
+ if (status != kSuccess) {
+ return std::vector<BiDirectedNode*>(0);
+ } else {
+ std::vector<BiDirectedNode*> ret;
+ for (auto i : matched_list) {
+ auto non_const_i = const_cast<BiDirectedNode*>(i);
+ if (std::find(candidates.begin(), candidates.end(), non_const_i) != candidates.end()) {
+ ret.push_back(non_const_i);
+ }
+ }
+ return ret;
+ }
+}
+
+template <bool with_split>
+nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym, const int subgraph_id = 0) {
+ std::string op_name;
+ if constexpr (with_split) {
+ op_name = "_sg_onednn_selfatt_qk_split";
+ } else {
+ op_name = "_sg_onednn_selfatt_qk";
+ }
+ nnvm::ObjectPtr n = nnvm::Node::Create();
+ // this op has single output, remove duplicated
+ auto last_node = sym.outputs[0].node;
+ nnvm::Symbol new_sym;
+ new_sym.outputs.emplace_back(last_node);
+ std::ostringstream node_name;
+
+ DFSVisit(new_sym.outputs, [&](const nnvm::ObjectPtr& node) {
+ if ((node->op() == Op::Get("_npx_reshape"))) {
+ auto const& reshape_param = nnvm::get<NumpyXReshapeParam>(node->attrs.parsed);
+ // set heads attribute - all necessary conditions are checked before
+ n->attrs.dict["heads"] = std::to_string(reshape_param.newshape[2]);
+ }
+ });
+
+ node_name << op_name << subgraph_id;
+ n->attrs.name = node_name.str();
+ n->attrs.op = Op::Get(op_name);
+ CHECK(n->attrs.op);
+ n->op()->attr_parser(&(n->attrs));
+ return n;
+}
+
+inline void ConnectSubgraphOutputs(const nnvm::ObjectPtr n,
+ std::vector<nnvm::NodeEntry*>* output_entries) {
+ // connect all extern output entries to output[0]
+ for (size_t i = 0; i < output_entries->size(); ++i) {
+ auto entry_ptr = output_entries->at(i);
+ *entry_ptr = nnvm::NodeEntry{n, entry_ptr->index, 0};
+ }
+}
+
+} // namespace qk_common
+} // namespace op
+} // namespace mxnet
+
+#endif // if MXNET_USE_ONEDNN == 1
+#endif // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_TRANSFORMER_QK_COMMON_H_
diff --git a/src/operator/subgraph/dnnl/dnnl_transformer_qk_property.h b/src/operator/subgraph/dnnl/dnnl_transformer_qk_property.h
index 9f4a9991a1..a34af4711f 100644
--- a/src/operator/subgraph/dnnl/dnnl_transformer_qk_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_transformer_qk_property.h
@@ -22,19 +22,26 @@
#if MXNET_USE_ONEDNN == 1
-#include <map>
#include <string>
#include <vector>
-#include "operator/contrib/transformer-inl.h"
-#include "operator/numpy/np_matrix_op-inl.h"
-#include "operator/tensor/matrix_op-inl.h"
-#include "operator/subgraph/common.h"
-#include "dnnl_common.h"
-#include "dnnl_subgraph_base-inl.h"
-#include "dnnl_transformer-inl.h"
+#include "dnnl_transformer_qk_common.h"
/*
+ custom_op custom_op
+ | |
+ ______|______________|________
+ | | | |
+ | _npx_reshape _npx_reshape |
+ | | | |
+ | SwapAxis SwapAxis |
+ | \ / |
+ | batch_dot |
+ | | |
+ |______________________________|
+
+OR
+
custom_op
|
_____________|_________________
@@ -47,113 +54,55 @@
| batch_dot |
| | |
|______________________________|
+
*/
namespace mxnet {
namespace op {
class SgDNNLTransformerQKSelector : public SubgraphSelectorV2 {
- enum SelectStatusTransformerQK {
- kFail = 0,
- kStart,
- kFirstSwapAx,
- kSecondSwapAx,
- kFirstReshape,
- kSecondReshape,
- kSuccess
- };
-
- /*!
- kStart ---> kFirstSwapAx ---> kSecondSwapAx ---> kFirstReshape ---> kSecondReshape ---> kSuccess
- Each status except kStart is connected with kFail
- */
-
private:
- SelectStatusTransformerQK status_;
+ qk_common::SelectStatusTransformerQK status_;
std::vector<const BiDirectedNode*> matched_list_;
- bool CheckSplitConditions(const BiDirectedNode& node) {
- const SplitParam& param = dmlc::get<SplitParam>(node.node->attrs.parsed);
-
- if (param.axis != -1 || param.sections != 3 || param.squeeze_axis)
- return false;
+ public:
+ bool Select(const BiDirectedNode& seed_node,
+ const std::shared_ptr<NodeAttr>& node_attr) override {
+ return qk_common::Select(&status_, &matched_list_, seed_node, node_attr);
+ }
- const auto first_reshape = (*(matched_list_.end() - 2))->node;
- const auto second_reshape = (*(matched_list_.end() - 1))->node;
- if (first_reshape->op() != Op::Get("_npx_reshape") ||
- second_reshape->op() != Op::Get("_npx_reshape")) {
- return false;
- }
- // 3 sections - ensure that every output is used only once
- if (node.outputs.size() == 3 && node.outputs.count(first_reshape) &&
- node.outputs.count(second_reshape)) {
- return true;
- }
+ bool SelectInput(const BiDirectedNode& n, const BiDirectedNode& input_node) override {
+ return qk_common::SelectInput<false>(&status_, &matched_list_, n, input_node);
+ }
+ bool SelectOutput(const BiDirectedNode& n, const BiDirectedNode& output_node) override {
return false;
}
+ std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>& candidates) override {
+ return qk_common::Filter(status_, matched_list_, candidates);
+ }
+
+ void Reset() override {
+ CHECK_GE(matched_list_.size(), 1);
+ auto new_selector = SgDNNLTransformerQKSelector();
+ new_selector.Select(*matched_list_[0], nullptr);
+ *this = new_selector;
+ }
+};
+
+class SgDNNLTransformerQKSplitSelector : public SubgraphSelectorV2 {
+ private:
+ qk_common::SelectStatusTransformerQK status_;
+ std::vector<const BiDirectedNode*> matched_list_;
+
public:
bool Select(const BiDirectedNode& seed_node,
const std::shared_ptr<NodeAttr>& node_attr) override {
- if (seed_node.node->op() == Op::Get("batch_dot")) {
- status_ = kStart;
- matched_list_.clear();
- matched_list_.push_back(&seed_node);
- return true;
- }
- return false;
+ return qk_common::Select(&status_, &matched_list_, seed_node, node_attr);
}
bool SelectInput(const BiDirectedNode& n, const BiDirectedNode& input_node) override {
- if (status_ == kFail || status_ == kSuccess || input_node.node->is_variable())
- return false;
- const auto& raw_input_node = *input_node.node;
- switch (status_) {
- case kStart:
- if (raw_input_node.op() == Op::Get("SwapAxis")) {
- if (CheckSwapAxisConditions(raw_input_node)) {
- status_ = kFirstSwapAx;
- matched_list_.push_back(&input_node);
- }
- return true;
- }
- case kFirstSwapAx:
- if (raw_input_node.op() == Op::Get("SwapAxis")) {
- if (CheckSwapAxisConditions(raw_input_node)) {
- status_ = kSecondSwapAx;
- matched_list_.push_back(&input_node);
- return true;
- }
- }
- case kSecondSwapAx:
- if (raw_input_node.op() == Op::Get("_npx_reshape")) {
- // input to reshape must be first or second output from split
- if (CheckReshapeConditions(raw_input_node, 0) ||
- CheckReshapeConditions(raw_input_node, 1)) {
- status_ = kFirstReshape;
- matched_list_.push_back(&input_node);
- return true;
- }
- }
- case kFirstReshape:
- if (raw_input_node.op() == Op::Get("_npx_reshape")) {
- if (CheckReshapeConditions(raw_input_node, 0) ||
- CheckReshapeConditions(raw_input_node, 1)) {
- status_ = kSecondReshape;
- matched_list_.push_back(&input_node);
- return true;
- }
- }
- case kSecondReshape:
- if (raw_input_node.op() == Op::Get("_split_v2") &&
CheckSplitConditions(input_node)) {
- status_ = kSuccess;
- return true;
- }
- default:
- status_ = kFail;
- return false;
- }
- return false;
+ return qk_common::SelectInput<true>(&status_, &matched_list_, n, input_node);
}
bool SelectOutput(const BiDirectedNode& n, const BiDirectedNode& output_node) override {
@@ -161,23 +110,12 @@ class SgDNNLTransformerQKSelector : public SubgraphSelectorV2 {
}
std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>& candidates) override {
- if (status_ != kSuccess) {
- return std::vector<BiDirectedNode*>(0);
- } else {
- std::vector<BiDirectedNode*> ret;
- for (auto i : matched_list_) {
- auto non_const_i = const_cast<BiDirectedNode*>(i);
- if (std::find(candidates.begin(), candidates.end(), non_const_i) != candidates.end()) {
- ret.push_back(non_const_i);
- }
- }
- return ret;
- }
+ return qk_common::Filter(status_, matched_list_, candidates);
}
void Reset() override {
CHECK_GE(matched_list_.size(), 1);
- auto new_selector = SgDNNLTransformerQKSelector();
+ auto new_selector = SgDNNLTransformerQKSplitSelector();
new_selector.Select(*matched_list_[0], nullptr);
*this = new_selector;
}
@@ -200,43 +138,43 @@ class SgDNNLTransformerQKProperty : public SubgraphProperty {
nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym,
const int subgraph_id = 0) const override {
- nnvm::ObjectPtr n = nnvm::Node::Create();
- // This op has single output, remove duplicated.
- auto last_node = sym.outputs[0].node;
- nnvm::Symbol new_sym;
- new_sym.outputs.emplace_back(last_node);
- std::ostringstream node_name;
- std::string op_name;
-
- DFSVisit(new_sym.outputs, [&](const nnvm::ObjectPtr& node) {
- if ((node->op() == Op::Get("_npx_reshape"))) {
- auto const& reshape_param = nnvm::get<NumpyXReshapeParam>(node->attrs.parsed);
- // set heads attribute - all necessary conditions are checked before
- n->attrs.dict["heads"] = std::to_string(reshape_param.newshape[2]);
- }
- });
-
- node_name << "_sg_onednn_selfatt_qk_" << subgraph_id;
-
- n->attrs.name = node_name.str();
- n->attrs.op = Op::Get("_sg_onednn_selfatt_qk");
- CHECK(n->attrs.op);
- n->op()->attr_parser(&(n->attrs));
- return n;
+ return qk_common::CreateSubgraphNode<false>(sym, subgraph_id);
+ }
+
+ void ConnectSubgraphOutputs(const nnvm::ObjectPtr n,
+ std::vector<nnvm::NodeEntry*>* output_entries) const override {
+ qk_common::ConnectSubgraphOutputs(n, output_entries);
}
SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override {
auto selector = std::make_shared<SgDNNLTransformerQKSelector>();
return selector;
}
+};
+
+class SgDNNLTransformerQKSplitProperty : public SubgraphProperty {
+ public:
+ SgDNNLTransformerQKSplitProperty() {}
+
+ static SubgraphPropertyPtr Create() {
+ static const std::string& name = "oneDNN Transformer optimization pass";
+ auto property = std::make_shared<SgDNNLTransformerQKSplitProperty>();
+ property->SetAttr<std::string>("property_name", name);
+ property->SetAttr<bool>("inference_only", true);
+ if (dmlc::GetEnv("MXNET_DISABLE_ONEDNN_TRANSFORMER_OPT", 0)) {
+ property->SetAttr<bool>("disable", true);
+ }
+ return property;
+ }
+
+ nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym,
+ const int subgraph_id = 0) const override {
+ return qk_common::CreateSubgraphNode<true>(sym, subgraph_id);
+ }
void ConnectSubgraphOutputs(const nnvm::ObjectPtr n,
std::vector<nnvm::NodeEntry*>* output_entries) const override {
- // Connect all extern output entries to output[0]
- for (size_t i = 0; i < output_entries->size(); ++i) {
- auto entry_ptr = output_entries->at(i);
- *entry_ptr = nnvm::NodeEntry{n, entry_ptr->index, 0};
- }
+ qk_common::ConnectSubgraphOutputs(n, output_entries);
}
void ConnectSubgraphInputs(const nnvm::ObjectPtr subgraph_node,
@@ -247,6 +185,11 @@ class SgDNNLTransformerQKProperty : public SubgraphProperty {
// connect subgraph input with split input
subgraph_node->inputs[0] = orig_input_entries->at(0).node->inputs[0];
}
+
+ SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override {
+ auto selector = std::make_shared<SgDNNLTransformerQKSplitSelector>();
+ return selector;
+ }
};
} // namespace op
diff --git a/src/operator/subgraph/dnnl/dnnl_transformer_valatt_property.h b/src/operator/subgraph/dnnl/dnnl_transformer_valatt_property.h
index 35f02f6203..0cdde8c964 100644
--- a/src/operator/subgraph/dnnl/dnnl_transformer_valatt_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_transformer_valatt_property.h
@@ -36,12 +36,12 @@
#include "dnnl_transformer-inl.h"
/*
- custom_op
- _________________|_________
- | Split |
- | | |
- | _npx_reshape |
- | | |
+ custom_op
+ _____________|_____________
+ | Split |
+ | / | |
+ | custom_op _npx_reshape |
+ | ... | |
| custom_op SwapAxis |
| \ / |
| batch_dot |
diff --git a/tests/python/dnnl/op_cfg.py b/tests/python/dnnl/op_cfg.py
index c4739ab807..50ec87a017 100644
--- a/tests/python/dnnl/op_cfg.py
+++ b/tests/python/dnnl/op_cfg.py
@@ -291,11 +291,17 @@ def get_all_ops_cfgs(dtype):
'_sg_onednn_batch_norm': {CFG_BASED_ON: 'BatchNorm'},
'_sg_onednn_selfatt_qk': {
CFG_SUBGRAPH: [SubgraphCfg('_sg_onednn_selfatt_qk', 'ONEDNN')],
+ 'queries': [mx.nd.random.normal(0, 1, (1, 4, 3*2*8), dtype)],
+ 'keys': [mx.nd.random.normal(0, 1, (1, 8, 3*2*8), dtype)],
+ 'heads': [2]
+ },
+ '_sg_onednn_selfatt_qk_split': {
+ CFG_SUBGRAPH: [SubgraphCfg('_sg_onednn_selfatt_qk_split', 'ONEDNN')],
'queries_keys_values': [mx.nd.random.normal(0, 1, (1, 4, 3*2*8), dtype)],
'heads': [2]
},
'_sg_onednn_selfatt_valatt': {
- CFG_BASED_ON: '_sg_onednn_selfatt_qk',
+ CFG_BASED_ON: '_sg_onednn_selfatt_qk_split',
CFG_SUBGRAPH: [SubgraphCfg('_sg_onednn_selfatt_valatt', 'ONEDNN')],
'attention': [CfgBasedArg(valatt_attention_tensor)]
}
diff --git a/tests/python/dnnl/subgraphs/test_matmul_subgraph.py b/tests/python/dnnl/subgraphs/test_matmul_subgraph.py
index a96c8c31ba..938d49444f 100644
--- a/tests/python/dnnl/subgraphs/test_matmul_subgraph.py
+++ b/tests/python/dnnl/subgraphs/test_matmul_subgraph.py
@@ -27,22 +27,28 @@ import math
class MultiHeadAttention(nn.HybridBlock):
- def __init__(self, units, num_heads, dtype='float32', negative_case=False, **kwargs):
+ def __init__(self, units, num_heads, batch_size=-1, seq_length=-1, dtype='float32', negative_case=False, no_split_case = False, **kwargs):
super(MultiHeadAttention, self).__init__(**kwargs)
self._units = units
self._num_heads = num_heads
self._fc = nn.Dense(in_units=self._units, units=3*self._units, flatten=False, dtype=dtype)
self._scale = math.sqrt(self._units // self._num_heads)
self.negative_case = negative_case
+ self.no_split_case = no_split_case
+ self.batch_size = batch_size
+ self.seq_length = seq_length
def forward(self, x, mask):
out = self._fc(x)
query, key, value = mx.np.split(out, 3, axis=-1)
+ if self.no_split_case:
+ key = mx.np.concat((key, key), axis = 1)
+ value = mx.np.concat((value, value), axis = 1)
+ query = mx.np.reshape(query, (-2, -2, self._num_heads, -1))
if self.negative_case:
- key = key * 2
- query = mx.npx.reshape(query, (-2, -2, self._num_heads, -1))
- key = mx.npx.reshape(key, (-2, -2, self._num_heads, -1))
- value = mx.npx.reshape(value, (-2, -2, self._num_heads, -1))
+ query = query * 2
+ key = mx.np.reshape(key, (-2, -2, self._num_heads, -1))
+ value = mx.np.reshape(value, (-2, -2, self._num_heads, -1))
scores = mx.npx.batch_dot(mx.np.swapaxes(query, 1, 2),
mx.np.swapaxes(key, 1, 2),
transpose_b=True)
mask = mx.np.expand_dims(mask, axis=1).astype(np.bool)
@@ -58,10 +64,16 @@ class MultiHeadAttention(nn.HybridBlock):
@pytest.mark.parametrize('seq_length', [124, 384])
@pytest.mark.parametrize('units', [256, 768])
@pytest.mark.parametrize('num_heads', [4, 8])
-def test_self_attention(batch_size, seq_length, units, num_heads):
- net = MultiHeadAttention(units, num_heads)
+@pytest.mark.parametrize('split', [True, False])
+def test_self_attention(batch_size, seq_length, units, num_heads, split):
+ net = MultiHeadAttention(units, num_heads, no_split_case=not split)
in_data = mx.np.random.uniform(size=[batch_size, seq_length, units], dtype='float32')
- mask = mx.np.random.uniform(low=0, high=2, size=[batch_size, seq_length, seq_length], dtype='int32')
+ if (split):
+ mask = mx.np.random.uniform(low=0, high=2, size=[batch_size, seq_length, seq_length], dtype='int32')
+ else:
+ # key dimension will be expanded by num_heads value to simulate gpt-2 model
+ # mask needs to be expanded as well
+ mask = mx.np.random.uniform(low=0, high=2, size=[batch_size, seq_length, seq_length * 2], dtype='int32')
net.initialize()
fused_net = net
@@ -74,13 +86,13 @@ def test_self_attention(batch_size, seq_length, units, num_heads):
assert_almost_equal(out.asnumpy(), ref_out.asnumpy())
- calib_data = mx.gluon.data.DataLoader(mx.gluon.data.ArrayDataset(in_data, mask), batch_size=1)
+ calib_data = mx.gluon.data.DataLoader(mx.gluon.data.ArrayDataset(in_data, mask), batch_size=batch_size)
qnet = mx.contrib.quant.quantize_net(net, quantized_dtype='auto',
exclude_layers=None,
exclude_layers_match=None,
calib_data=calib_data,
calib_mode='naive',
- num_calib_batches=1,
+ num_calib_batches=batch_size,
ctx=mx.cpu())
qout = qnet(in_data, mask)
@@ -97,7 +109,7 @@ def test_self_attention(batch_size, seq_length, units, num_heads):
@pytest.mark.parametrize('units', [256, 768])
@pytest.mark.parametrize('num_heads', [4, 8])
def test_self_attention_negative(batch_size, seq_length, units, num_heads):
- net = MultiHeadAttention(units, num_heads, negative_case=True)
+ net = MultiHeadAttention(units, num_heads, batch_size, seq_length, negative_case=True)
in_data = mx.np.random.uniform(size=[batch_size, seq_length, units], dtype='float32')
mask = mx.np.random.uniform(low=0, high=2, size=[batch_size, seq_length, seq_length], dtype='int32')
@@ -112,13 +124,13 @@ def test_self_attention_negative(batch_size, seq_length, units, num_heads):
assert_almost_equal(out.asnumpy(), ref_out.asnumpy())
- calib_data = mx.gluon.data.DataLoader(mx.gluon.data.ArrayDataset(in_data, mask), batch_size=1)
+ calib_data = mx.gluon.data.DataLoader(mx.gluon.data.ArrayDataset(in_data, mask), batch_size=batch_size)
qnet = mx.contrib.quant.quantize_net(net, quantized_dtype='auto',
exclude_layers=None,
exclude_layers_match=None,
calib_data=calib_data,
calib_mode='naive',
- num_calib_batches=1,
+ num_calib_batches=batch_size,
ctx=mx.cpu())
qout = qnet(in_data, mask)