This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 1a418e4e1c [FEATURE] Add query_keys transformer version without split (#21115)
1a418e4e1c is described below

commit 1a418e4e1cf89a87838eb02abf74ccfa4bfe6c37
Author: AdamGrabowski <[email protected]>
AuthorDate: Tue Aug 23 15:33:02 2022 +0200

    [FEATURE] Add query_keys transformer version without split (#21115)
    
    * Add query_keys transformer version without split
    
    * Add op to amp, implement different shaped inputs cases
    
    * Fix sanity
    
    * Remove semicolon
    
    * Fix windows build
    
    * Fix windows build sanity
    
    * Change to bool template, fix comments, shorten test case
    
    * Change comment message
    
    * Fix SelectInput(), input names and Forward()
---
 python/mxnet/amp/lists/symbol_bf16.py              |   1 +
 python/mxnet/amp/lists/symbol_fp16.py              |   1 +
 .../subgraph/dnnl/dnnl_post_amp_property.h         |   1 +
 .../subgraph/dnnl/dnnl_post_quantize_property.h    |   1 +
 .../subgraph/dnnl/dnnl_subgraph_property.cc        |   2 +
 src/operator/subgraph/dnnl/dnnl_transformer.cc     | 295 ++++++++++++++-------
 .../subgraph/dnnl/dnnl_transformer_qk_common.h     | 230 ++++++++++++++++
 .../subgraph/dnnl/dnnl_transformer_qk_property.h   | 217 ++++++---------
 .../dnnl/dnnl_transformer_valatt_property.h        |  12 +-
 tests/python/dnnl/op_cfg.py                        |   8 +-
 .../python/dnnl/subgraphs/test_matmul_subgraph.py  |  38 ++-
 11 files changed, 555 insertions(+), 251 deletions(-)

diff --git a/python/mxnet/amp/lists/symbol_bf16.py 
b/python/mxnet/amp/lists/symbol_bf16.py
index 5b9df27497..3f9bfbc707 100644
--- a/python/mxnet/amp/lists/symbol_bf16.py
+++ b/python/mxnet/amp/lists/symbol_bf16.py
@@ -31,6 +31,7 @@ if Features.instance.is_enabled('ONEDNN'):
         '_sg_onednn_conv',
         '_sg_onednn_fully_connected',
         '_sg_onednn_selfatt_qk',
+        '_sg_onednn_selfatt_qk_split',
         '_sg_onednn_selfatt_valatt'
     ])
 
diff --git a/python/mxnet/amp/lists/symbol_fp16.py 
b/python/mxnet/amp/lists/symbol_fp16.py
index 76e9488f69..62e87dfc59 100644
--- a/python/mxnet/amp/lists/symbol_fp16.py
+++ b/python/mxnet/amp/lists/symbol_fp16.py
@@ -634,6 +634,7 @@ if Features().is_enabled('ONEDNN'):
         '_sg_onednn_conv',
         '_sg_onednn_fully_connected',
         '_sg_onednn_selfatt_qk',
+        '_sg_onednn_selfatt_qk_split',
         '_sg_onednn_selfatt_valatt',
         '_sg_onednn_batch_dot',
         '_sg_onednn_batch_norm',
diff --git a/src/operator/subgraph/dnnl/dnnl_post_amp_property.h 
b/src/operator/subgraph/dnnl/dnnl_post_amp_property.h
index 6ec7c54e38..8e14c9b361 100644
--- a/src/operator/subgraph/dnnl/dnnl_post_amp_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_post_amp_property.h
@@ -35,6 +35,7 @@ inline bool IsSupportedAMPFuseOp(const nnvm::Node& node) {
   static const std::set<const Op*> supported_ops = {Op::Get("_sg_onednn_conv"),
                                                     
Op::Get("_sg_onednn_fully_connected"),
                                                     
Op::Get("_sg_onednn_selfatt_qk"),
+                                                    
Op::Get("_sg_onednn_selfatt_qk_split"),
                                                     
Op::Get("_sg_onednn_selfatt_valatt")};
   return supported_ops.count(node.op()) > 0;
 }
diff --git a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h 
b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
index 94c7e63085..6e905f59da 100644
--- a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
@@ -45,6 +45,7 @@ bool SupportsRequantizeFusion(const Op* op) {
       Op::Get("_sg_onednn_conv"),
       Op::Get("_sg_onednn_fully_connected"),
       Op::Get("_sg_onednn_selfatt_qk"),
+      Op::Get("_sg_onednn_selfatt_qk_split"),
       Op::Get("_sg_onednn_selfatt_valatt"),
       Op::Get("_sg_onednn_batch_dot")};
 
diff --git a/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc 
b/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
index 86e08020ee..bc190a73ae 100644
--- a/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
+++ b/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
@@ -45,6 +45,7 @@ MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLConvProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLFCProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLBNReLUProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLRemoveCastsProperty);
+MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLTransformerQKSplitProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLTransformerQKProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLTransformerValAttProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLBatchDotProperty);
@@ -56,6 +57,7 @@ 
MXNET_REGISTER_SUBGRAPH_BACKEND(ONEDNN_QUANTIZE).set_attr("context", Context::CP
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLIdentityProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, 
SgDNNLConvProperty).set_attr("quantize", true);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, 
SgDNNLFCProperty).set_attr("quantize", true);
+MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, 
SgDNNLTransformerQKSplitProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLTransformerQKProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, 
SgDNNLTransformerValAttProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLBatchDotProperty)
diff --git a/src/operator/subgraph/dnnl/dnnl_transformer.cc 
b/src/operator/subgraph/dnnl/dnnl_transformer.cc
index 67fee5f98f..dfafbb038d 100644
--- a/src/operator/subgraph/dnnl/dnnl_transformer.cc
+++ b/src/operator/subgraph/dnnl/dnnl_transformer.cc
@@ -29,7 +29,7 @@
 #include "operator/subgraph/common.h"
 #include "dnnl_transformer-inl.h"
 
-// 3 tensors within one (queries key values) =
+// 3 tensors within one (queries keys values)
 #define QKV_NUM 3
 
 namespace mxnet {
@@ -37,46 +37,70 @@ namespace op {
 
 DMLC_REGISTER_PARAMETER(DNNLSelfAttParam);
 
+template <bool with_split>
 static bool SgDNNLSelfAttShape(const NodeAttrs& attrs,
                                mxnet::ShapeVector* in_shape,
                                mxnet::ShapeVector* out_shape) {
-  const auto& params = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
-  auto qkv_shape     = in_shape->at(0);
-  CHECK_EQ(qkv_shape.ndim(), 3U)
+  const auto& params        = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
+  unsigned int in_shape_num = 1;
+  auto in_shape_0           = in_shape->at(0);
+  auto in_shape_1           = in_shape_0;  // with split there is only one input
+  CHECK_EQ(in_shape_0.ndim(), 3U)
       << "Input queries_keys_values should be 3D in batch-seq_length-proj_dim, 
"
-      << "but the given tensor is " << qkv_shape.ndim() << "D";
+      << "but the given tensor is " << in_shape_0.ndim() << "D";
 
-  if (params.quantized) {
-    CHECK_EQ(in_shape->size(), 3U) << "Input: [queries_keys_values, min_qkv, 
max_qkv] "
-                                   << "- currently have " << in_shape->size() 
<< " inputs";
+  if constexpr (!with_split) {
+    in_shape_1 = in_shape->at(1);  // without split we need to consider 2nd input
+    CHECK_EQ(in_shape_1.ndim(), 3U)
+        << "Input queries_keys_values should be 3D in 
batch-seq_length-proj_dim, "
+        << "but the given tensor is " << in_shape_1.ndim() << "D";
+    CHECK_EQ(in_shape_0[0], in_shape_1[0]);
+    CHECK_EQ(in_shape_0[2], in_shape_1[2]);
+    in_shape_num = 2;
+  }
 
-    SHAPE_ASSIGN_CHECK(*in_shape, 1, mxnet::TShape({1}));
-    SHAPE_ASSIGN_CHECK(*in_shape, 2, mxnet::TShape({1}));
+  if (params.quantized) {
+    CHECK_EQ(in_shape->size(), 3 * in_shape_num)
+        << "Input: [queries_keys_values, min_qkv, max_qkv] "
+        << "- currently have " << in_shape->size() << " inputs";
+    if constexpr (with_split) {
+      SHAPE_ASSIGN_CHECK(*in_shape, 1, mxnet::TShape({1}));
+      SHAPE_ASSIGN_CHECK(*in_shape, 2, mxnet::TShape({1}));
+    } else {
+      SHAPE_ASSIGN_CHECK(*in_shape, 2, mxnet::TShape({1}));
+      SHAPE_ASSIGN_CHECK(*in_shape, 3, mxnet::TShape({1}));
+      SHAPE_ASSIGN_CHECK(*in_shape, 4, mxnet::TShape({1}));
+      SHAPE_ASSIGN_CHECK(*in_shape, 5, mxnet::TShape({1}));
+    }
 
     out_shape->resize(3);
-    SHAPE_ASSIGN_CHECK(
-        *out_shape, 0, mxnet::TShape({qkv_shape[0], params.heads, 
qkv_shape[1], qkv_shape[1]}));
     if (!params.enabled_float_output.has_value()) {
       SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape({1}));  // min output
       SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape({1}));  // max output
     }
   } else {
-    CHECK_EQ(in_shape->size(), 1U)
+    CHECK_EQ(in_shape->size(), in_shape_num)
         << "Input:[queries_keys_values] - currently have " << in_shape->size() 
<< " inputs";
     out_shape->resize(1);
-    SHAPE_ASSIGN_CHECK(
-        *out_shape, 0, mxnet::TShape({qkv_shape[0], params.heads, 
qkv_shape[1], qkv_shape[1]}));
   }
 
+  SHAPE_ASSIGN_CHECK(
+      *out_shape, 0, mxnet::TShape({in_shape_0[0], params.heads, 
in_shape_0[1], in_shape_1[1]}));
   return true;
 }
 
+template <bool with_split>
 static bool SgDNNLSelfAttQKInferType(const nnvm::NodeAttrs& attrs,
                                      std::vector<int>* in_types,
                                      std::vector<int>* out_types) {
-  const auto& params = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
+  const auto& params        = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
+  unsigned int in_shape_num = 1;
+  if constexpr (!with_split) {
+    CHECK_EQ(in_types->at(0), in_types->at(1));
+    in_shape_num = 2;
+  }
   if (params.quantized) {
-    CHECK_EQ(in_types->size(), 3U);
+    CHECK_EQ(in_types->size(), 3 * in_shape_num);
 
     if (in_types->at(0) == mshadow::kBfloat16) {
       return false;
@@ -86,8 +110,15 @@ static bool SgDNNLSelfAttQKInferType(const nnvm::NodeAttrs& 
attrs,
         << "QuantizedSelfAttentionQK only supports int8 input, while " << 
in_types->at(0)
         << " is given.";
 
-    TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kFloat32);
-    TYPE_ASSIGN_CHECK(*in_types, 2, mshadow::kFloat32);
+    if constexpr (with_split) {
+      TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kFloat32);
+      TYPE_ASSIGN_CHECK(*in_types, 2, mshadow::kFloat32);
+    } else {
+      TYPE_ASSIGN_CHECK(*in_types, 2, mshadow::kFloat32);
+      TYPE_ASSIGN_CHECK(*in_types, 3, mshadow::kFloat32);
+      TYPE_ASSIGN_CHECK(*in_types, 4, mshadow::kFloat32);
+      TYPE_ASSIGN_CHECK(*in_types, 5, mshadow::kFloat32);
+    }
 
     if (params.enabled_float_output.has_value()) {
       CHECK_EQ(out_types->size(), 1U);
@@ -103,12 +134,18 @@ static bool SgDNNLSelfAttQKInferType(const 
nnvm::NodeAttrs& attrs,
       TYPE_ASSIGN_CHECK(*out_types, 2, mshadow::kFloat32);
     }
   } else {
-    CHECK_EQ(in_types->size(), 1U);
+    CHECK_EQ(in_types->size(), in_shape_num);
     CHECK_EQ(out_types->size(), 1U);
     if (in_types->at(0) == mshadow::kFloat32) {
       TYPE_ASSIGN_CHECK(*in_types, 0, mshadow::kFloat32);
+      if constexpr (!with_split) {
+        TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kFloat32);
+      }
       TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);
     } else if (in_types->at(0) == mshadow::kBfloat16) {
+      if constexpr (!with_split) {
+        TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kBfloat16);
+      }
       if (params.enabled_float_output.has_value()) {
         TYPE_ASSIGN_CHECK(*out_types, 0, params.enabled_float_output.value());
       } else {
@@ -128,6 +165,7 @@ class SgDNNLSelfAttQKOp {
   explicit SgDNNLSelfAttQKOp(const nnvm::NodeAttrs& attrs)
       : param_(nnvm::get<DNNLSelfAttParam>(attrs.parsed)) {}
 
+  template <bool with_split>
   void Forward(const OpContext& ctx,
                const std::vector<NDArray>& inputs,
                const std::vector<OpReqType>& req,
@@ -142,6 +180,7 @@ class SgDNNLSelfAttQKOp {
                   "inference computation.";
   }
 
+  template <bool with_split>
   void Initialize(const OpContext& ctx,
                   const std::vector<NDArray>& inputs,
                   const std::vector<OpReqType>& req,
@@ -159,11 +198,14 @@ class SgDNNLSelfAttQKOp {
   std::shared_ptr<dnnl::memory> cached_query_mem_;
   std::shared_ptr<dnnl::memory> cached_key_mem_;
   std::shared_ptr<dnnl::memory> cached_out_mem_;
-  float min_data_;
-  float max_data_;
+  float min_data_0_;
+  float max_data_0_;
+  float min_data_1_;
+  float max_data_1_;
   float min_output_;
   float max_output_;
-  float data_scale_{0.0f};
+  float data_scale_0_{0.0f};
+  float data_scale_1_{0.0f};
 };
 
 static OpStatePtr CreateSgDNNLSelfAttQKState(const nnvm::NodeAttrs& attrs,
@@ -173,18 +215,19 @@ static OpStatePtr CreateSgDNNLSelfAttQKState(const 
nnvm::NodeAttrs& attrs,
   return OpStatePtr::Create<SgDNNLSelfAttQKOp>(attrs);
 }
 
-static void SgDNNLSelfAttQKForward(const OpStatePtr& state_pointer,
-                                   const OpContext& ctx,
-                                   const std::vector<NDArray>& inputs,
-                                   const std::vector<OpReqType>& req,
-                                   const std::vector<NDArray>& outputs) {
+template <bool with_split>
+void SgDNNLSelfAttQKForward(const OpStatePtr& state_pointer,
+                            const OpContext& ctx,
+                            const std::vector<NDArray>& inputs,
+                            const std::vector<OpReqType>& req,
+                            const std::vector<NDArray>& outputs) {
   SgDNNLSelfAttQKOp& op = state_pointer.get_state<SgDNNLSelfAttQKOp>();
   bool already_prepared = false;
   if (!op.IsInitialized()) {
-    op.Initialize(ctx, inputs, req, outputs);
+    op.Initialize<with_split>(ctx, inputs, req, outputs);
     already_prepared = true;
   }
-  op.Forward(ctx, inputs, req, outputs, already_prepared);
+  op.Forward<with_split>(ctx, inputs, req, outputs, already_prepared);
 }
 
 static bool SgDNNLSelfAttStorageType(const nnvm::NodeAttrs& attrs,
@@ -195,52 +238,70 @@ static bool SgDNNLSelfAttStorageType(const 
nnvm::NodeAttrs& attrs,
   return DNNLStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, 
out_attrs);
 }
 
+template <bool with_split>
 void SgDNNLSelfAttQKOp::Initialize(const OpContext& ctx,
                                    const std::vector<NDArray>& inputs,
                                    const std::vector<OpReqType>& req,
                                    const std::vector<NDArray>& outputs) {
   using namespace dnnl;
 
-  const auto qkv_tensor = inputs[0];
-  const auto out_tensor = outputs[0];
+  const auto in_tensor_0 = inputs[0];
+  auto in_tensor_1       = in_tensor_0;  // with split there is only one input
+  const auto out_tensor  = outputs[0];
 
-  const auto qkv_dtype = get_dnnl_type(qkv_tensor.dtype());
+  const auto in_dtype = get_dnnl_type(in_tensor_0.dtype());
 
   const memory::dim heads          = param_.heads;
-  const memory::dim sequences      = inputs[0].shape()[0];
-  const memory::dim qkv_seq_len    = inputs[0].shape()[1];
-  const memory::dim output_lin_dim = inputs[0].shape()[2];
-  const memory::dim embed_dim      = output_lin_dim / QKV_NUM;
+  const memory::dim sequences      = in_tensor_0.shape()[0];
+  const memory::dim qkv_seq_len_0  = in_tensor_0.shape()[1];
+  const memory::dim output_lin_dim = in_tensor_0.shape()[2];
+  memory::dim embed_dim            = output_lin_dim;
+  if constexpr (with_split) {
+    embed_dim /= QKV_NUM;
+  } else {
+    in_tensor_1 = inputs[1];  // without split we need to consider 2nd input
+  }
+  const memory::dim qkv_seq_len_1  = in_tensor_1.shape()[1];
   const memory::dim head_dim       = embed_dim / heads;
-  const memory::dim batch_stride   = output_lin_dim * qkv_seq_len;
+  const memory::dim batch_stride_0 = output_lin_dim * qkv_seq_len_0;
+  const memory::dim batch_stride_1 = output_lin_dim * qkv_seq_len_1;
 
   float min_data = 0.0f;
   float max_data = 0.0f;
 
   const auto engine = CpuEngine::Get()->get_engine();
 
-  memory::dims query_dims = {sequences, heads, qkv_seq_len, head_dim};
-  memory::dims key_dims   = {sequences, heads, head_dim, qkv_seq_len};
-
-  memory::dims query_strides = {batch_stride, head_dim, output_lin_dim, 1};
-  memory::dims key_strides   = {batch_stride, head_dim, 1, output_lin_dim};
+  memory::dims query_dims    = {sequences, heads, qkv_seq_len_0, head_dim};
+  memory::dims key_dims      = {sequences, heads, head_dim, qkv_seq_len_1};
+  memory::dims query_strides = {batch_stride_0, head_dim, output_lin_dim, 1};
+  memory::dims key_strides   = {batch_stride_1, head_dim, 1, output_lin_dim};
 
-  auto query_md = memory::desc(query_dims, qkv_dtype, query_strides);
-  auto key_md   = memory::desc(key_dims, qkv_dtype, key_strides);
+  auto query_md = memory::desc(query_dims, in_dtype, query_strides);
+  auto key_md   = memory::desc(key_dims, in_dtype, key_strides);
 
   float oscale = 1.0f;
   if (param_.quantized) {
-    min_data_   = inputs[1].data().dptr<float>()[0];
-    max_data_   = inputs[2].data().dptr<float>()[0];
-    data_scale_ = GetQuantizeScale(qkv_tensor.dtype(), min_data_, max_data_);
+    if constexpr (with_split) {
+      min_data_0_   = inputs[1].data().dptr<float>()[0];
+      max_data_0_   = inputs[2].data().dptr<float>()[0];
+      data_scale_0_ = data_scale_1_ =
+          GetQuantizeScale(in_tensor_0.dtype(), min_data_0_, max_data_0_);
+    } else {
+      min_data_0_   = inputs[2].data().dptr<float>()[0];
+      max_data_0_   = inputs[3].data().dptr<float>()[0];
+      min_data_1_   = inputs[4].data().dptr<float>()[0];
+      max_data_1_   = inputs[5].data().dptr<float>()[0];
+      data_scale_0_ = GetQuantizeScale(in_tensor_0.dtype(), min_data_0_, 
max_data_0_);
+      data_scale_1_ = GetQuantizeScale(in_tensor_1.dtype(), min_data_1_, 
max_data_1_);
+    }
 
     if (param_.min_calib_range.has_value() && 
param_.max_calib_range.has_value()) {
       min_output_ = param_.min_calib_range.value();
       max_output_ = param_.max_calib_range.value();
       oscale      = GetQuantizeScale(out_tensor.dtype(), min_output_, 
max_output_) /
-               (data_scale_ * data_scale_);
+               (data_scale_0_ * data_scale_1_);
     } else if (param_.enabled_float_output.has_value()) {
-      oscale = 1.0f / (data_scale_ * data_scale_);
+      oscale = 1.0f / (data_scale_0_ * data_scale_1_);
     } else {
       mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
       mxnet_op::Kernel<QuantizationRangeForS8S8MultiplicationStruct, 
cpu>::Launch(
@@ -256,9 +317,14 @@ void SgDNNLSelfAttQKOp::Initialize(const OpContext& ctx,
 
   MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, {
     DType* query_mem_ptr = inputs[0].data().dptr<DType>();
-    DType* key_mem_ptr   = query_mem_ptr + embed_dim;
-    cached_query_mem_    = std::make_shared<memory>(query_md, engine, 
query_mem_ptr);
-    cached_key_mem_      = std::make_shared<memory>(key_md, engine, 
key_mem_ptr);
+    DType* key_mem_ptr;
+    if constexpr (with_split) {
+      key_mem_ptr = query_mem_ptr + embed_dim;
+    } else {
+      key_mem_ptr = inputs[1].data().dptr<DType>();
+    }
+    cached_query_mem_ = std::make_shared<memory>(query_md, engine, 
query_mem_ptr);
+    cached_key_mem_   = std::make_shared<memory>(key_md, engine, key_mem_ptr);
   });
 
   MSHADOW_TYPE_SWITCH(out_tensor.dtype(), DType, {
@@ -272,6 +338,7 @@ void SgDNNLSelfAttQKOp::Initialize(const OpContext& ctx,
   initialized_            = true;
 }
 
+template <bool with_split>
 void SgDNNLSelfAttQKOp::Forward(const OpContext& ctx,
                                 const std::vector<NDArray>& inputs,
                                 const std::vector<OpReqType>& req,
@@ -283,7 +350,12 @@ void SgDNNLSelfAttQKOp::Forward(const OpContext& ctx,
 
     MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, {
       DType* query_mem_ptr = inputs[0].data().dptr<DType>();
-      DType* key_mem_ptr   = query_mem_ptr + embed_dim;
+      DType* key_mem_ptr;
+      if constexpr (with_split) {
+        key_mem_ptr = query_mem_ptr + embed_dim;
+      } else {
+        key_mem_ptr = inputs[1].data().dptr<DType>();
+      }
       cached_query_mem_->set_data_handle(query_mem_ptr);
       cached_key_mem_->set_data_handle(key_mem_ptr);
     });
@@ -298,16 +370,20 @@ void SgDNNLSelfAttQKOp::Forward(const OpContext& ctx,
   if (param_.quantized && !param_.enabled_float_output.has_value()) {
     float* output_min = outputs[1].data().dptr<float>();
     float* output_max = outputs[2].data().dptr<float>();
-
-    *output_min = min_output_;
-    *output_max = max_output_;
+    *output_min       = min_output_;
+    *output_max       = max_output_;
   }
 }
 
+template <bool with_split>
 nnvm::ObjectPtr SgDNNLSelfAttQKQuantizedOp(const NodeAttrs& attrs) {
-  nnvm::ObjectPtr node          = nnvm::Node::Create();
-  auto const& param             = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
-  node->attrs.op                = Op::Get("_sg_onednn_selfatt_qk");
+  nnvm::ObjectPtr node = nnvm::Node::Create();
+  auto const& param    = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
+  if constexpr (with_split) {
+    node->attrs.op = Op::Get("_sg_onednn_selfatt_qk_split");
+  } else {
+    node->attrs.op = Op::Get("_sg_onednn_selfatt_qk");
+  }
   node->attrs.name              = "quantized_" + attrs.name;
   node->attrs.dict              = attrs.dict;
   node->attrs.dict["heads"]     = std::to_string(param.heads);
@@ -318,26 +394,79 @@ nnvm::ObjectPtr SgDNNLSelfAttQKQuantizedOp(const 
NodeAttrs& attrs) {
   return node;
 }
 
-NNVM_REGISTER_OP(_sg_onednn_selfatt_qk)
-    .add_alias("_sg_mkldnn_selfatt_qk")
+#define MXNET_OPERATOR_REGISTER_SELFATT_QK(name)                               
                  \
+  NNVM_REGISTER_OP(name)                                                       
                  \
+      .set_num_outputs([](const NodeAttrs& attrs) {                            
                  \
+        auto const& param = nnvm::get<DNNLSelfAttParam>(attrs.parsed);         
                  \
+        if (param.quantized && !param.enabled_float_output.has_value()) {      
                  \
+          return 3;                                                            
                  \
+        } else {                                                               
                  \
+          return 1;                                                            
                  \
+        }                                                                      
                  \
+      })                                                                       
                  \
+      .set_attr<nnvm::FListOutputNames>(                                       
                  \
+          "FListOutputNames",                                                  
                  \
+          [](const NodeAttrs& attrs) {                                         
                  \
+            auto const& param = nnvm::get<DNNLSelfAttParam>(attrs.parsed);     
                  \
+            std::vector<std::string> output_names{"output"};                   
                  \
+            if (param.quantized && !param.enabled_float_output.has_value()) {  
                  \
+              output_names.emplace_back("min_output");                         
                  \
+              output_names.emplace_back("max_output");                         
                  \
+            }                                                                  
                  \
+            return output_names;                                               
                  \
+          })                                                                   
                  \
+      .set_attr_parser(ParamParser<DNNLSelfAttParam>)                          
                  \
+      .set_attr<FInferStorageType>("FInferStorageType", 
SgDNNLSelfAttStorageType)                \
+      .set_attr<FCreateOpState>("FCreateOpState", CreateSgDNNLSelfAttQKState)  
                  \
+      .set_attr<bool>("TIsDNNL", true)                                         
                  \
+      .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)               
                  \
+      .set_attr<FQuantizable>("FQuantizable",                                  
                  \
+                              [](const NodeAttrs& attrs) { return 
QuantizeType::kMust; })        \
+      .set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) 
{ return true; }) \
+      .add_arguments(DNNLSelfAttParam::__FIELDS__())
+
+MXNET_OPERATOR_REGISTER_SELFATT_QK(_sg_onednn_selfatt_qk)
     .describe(R"code(_sg_onednn_selfatt_qk)code" ADD_FILELINE)
     .set_num_inputs([](const NodeAttrs& attrs) {
       auto const& param = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
       if (param.quantized) {
-        return 3;
+        return 6;
       } else {
-        return 1;
+        return 2;
       }
     })
-    .set_num_outputs([](const NodeAttrs& attrs) {
+    .set_attr<nnvm::FListInputNames>("FListInputNames",
+                                     [](const NodeAttrs& attrs) {
+                                       auto const& param =
+                                           
nnvm::get<DNNLSelfAttParam>(attrs.parsed);
+                                       std::vector<std::string> 
input_names{"queries"};
+                                       input_names.emplace_back("keys");
+                                       if (param.quantized) {
+                                         input_names.emplace_back("min_q");
+                                         input_names.emplace_back("max_q");
+                                         input_names.emplace_back("min_k");
+                                         input_names.emplace_back("max_k");
+                                       }
+                                       return input_names;
+                                     })
+    .set_attr<mxnet::FInferShape>("FInferShape", SgDNNLSelfAttShape<false>)
+    .set_attr<nnvm::FInferType>("FInferType", SgDNNLSelfAttQKInferType<false>)
+    .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", 
SgDNNLSelfAttQKForward<false>)
+    .set_attr<FQuantizedOp>("FQuantizedOp", SgDNNLSelfAttQKQuantizedOp<false>)
+    .add_argument("queries", "NDArray-or-Symbol", "Interleaved queries, keys 
and values")
+    .add_argument("keys", "NDArray-or-Symbol", "Interleaved queries, keys and 
values");
+
+MXNET_OPERATOR_REGISTER_SELFATT_QK(_sg_onednn_selfatt_qk_split)
+    .add_alias("_sg_mkldnn_selfatt_qk")
+    .describe(R"code(_sg_onednn_selfatt_qk_split)code" ADD_FILELINE)
+    .set_num_inputs([](const NodeAttrs& attrs) {
       auto const& param = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
-      if (param.quantized && !param.enabled_float_output.has_value()) {
+      if (param.quantized) {
         return 3;
       } else {
         return 1;
       }
     })
-    .set_attr_parser(ParamParser<DNNLSelfAttParam>)
     .set_attr<nnvm::FListInputNames>("FListInputNames",
                                      [](const NodeAttrs& attrs) {
                                        auto const& param =
@@ -349,33 +478,11 @@ NNVM_REGISTER_OP(_sg_onednn_selfatt_qk)
                                        }
                                        return input_names;
                                      })
-    .set_attr<nnvm::FListOutputNames>("FListOutputNames",
-                                      [](const NodeAttrs& attrs) {
-                                        auto const& param =
-                                            
nnvm::get<DNNLSelfAttParam>(attrs.parsed);
-                                        std::vector<std::string> 
output_names{"output"};
-                                        if (param.quantized &&
-                                            
!param.enabled_float_output.has_value()) {
-                                          
output_names.emplace_back("min_output");
-                                          
output_names.emplace_back("max_output");
-                                        }
-                                        return output_names;
-                                      })
-    .set_attr<mxnet::FInferShape>("FInferShape", SgDNNLSelfAttShape)
-    .set_attr<nnvm::FInferType>("FInferType", SgDNNLSelfAttQKInferType)
-    .set_attr<FInferStorageType>("FInferStorageType", SgDNNLSelfAttStorageType)
-    .set_attr<FCreateOpState>("FCreateOpState", CreateSgDNNLSelfAttQKState)
-    .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", 
SgDNNLSelfAttQKForward)
-    .set_attr<bool>("TIsDNNL", true)
-    .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
-    .set_attr<FQuantizable>("FQuantizable",
-                            [](const NodeAttrs& attrs) { return 
QuantizeType::kMust; })
-    .set_attr<FQuantizedOp>("FQuantizedOp", SgDNNLSelfAttQKQuantizedOp)
-    .set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { 
return true; })
-    .add_argument("queries_keys_values",
-                  "NDArray-or-Symbol",
-                  "Interleaved queries, keys and values")
-    .add_arguments(DNNLSelfAttParam::__FIELDS__());
+    .set_attr<mxnet::FInferShape>("FInferShape", SgDNNLSelfAttShape<true>)
+    .set_attr<nnvm::FInferType>("FInferType", SgDNNLSelfAttQKInferType<true>)
+    .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", 
SgDNNLSelfAttQKForward<true>)
+    .set_attr<FQuantizedOp>("FQuantizedOp", SgDNNLSelfAttQKQuantizedOp<true>)
+    .add_argument("query_keys_values", "NDArray-or-Symbol", "Interleaved 
queries, keys and values");
 
 
/**********************************_sg_onednn_selfatt_valatt**********************************/
 
diff --git a/src/operator/subgraph/dnnl/dnnl_transformer_qk_common.h 
b/src/operator/subgraph/dnnl/dnnl_transformer_qk_common.h
new file mode 100644
index 0000000000..902c58dfbc
--- /dev/null
+++ b/src/operator/subgraph/dnnl/dnnl_transformer_qk_common.h
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_TRANSFORMER_QK_COMMON_H_
+#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_TRANSFORMER_QK_COMMON_H_
+
+#if MXNET_USE_ONEDNN == 1
+
+#include <string>
+#include <vector>
+
+#include "operator/contrib/transformer-inl.h"
+#include "operator/numpy/np_matrix_op-inl.h"
+#include "operator/tensor/matrix_op-inl.h"
+#include "operator/subgraph/common.h"
+#include "dnnl_common.h"
+#include "dnnl_subgraph_base-inl.h"
+#include "dnnl_transformer-inl.h"
+
+namespace mxnet {
+namespace op {
+namespace qk_common {
+
+enum SelectStatusTransformerQK {
+  kFail = 0,
+  kStart,
+  kFirstSwapAx,
+  kSecondSwapAx,
+  kFirstReshape,
+  kSecondReshape,
+  kSuccess
+};
+
+// Status transitions followed by SelectInput():
+// with_split == true:
+// kStart ---> kFirstSwapAx ---> kSecondSwapAx ---> kFirstReshape ---> kSecondReshape ---> kSuccess
+// with_split == false:
+// kStart ---> kFirstSwapAx ---> kSecondSwapAx ---> kFirstReshape ---> kSuccess
+// Each status except kStart is connected with kFail.
+
+inline bool CheckSwapAxisConditionsQK(const BiDirectedNode& input_node) {
+  if (input_node.outputs.size() != 1)
+    return false;
+  return CheckSwapAxisConditions(*input_node.node);
+}
+
+inline bool CheckReshapeConditionsQK(const BiDirectedNode& input_node, const 
index_t out_index) {
+  if (input_node.outputs.size() != 1)
+    return false;
+  return CheckReshapeConditions(*input_node.node, out_index);
+}
+
+inline bool CheckSplitConditions(const std::vector<const BiDirectedNode*>& 
matched_list,
+                                 const BiDirectedNode& node) {
+  const SplitParam& param = dmlc::get<SplitParam>(node.node->attrs.parsed);
+
+  if (param.axis != -1 || param.sections != 3 || param.squeeze_axis)
+    return false;
+
+  const auto first_reshape  = (*(matched_list.end() - 2))->node;
+  const auto second_reshape = (*(matched_list.end() - 1))->node;
+  if (first_reshape->op() != Op::Get("_npx_reshape") ||
+      second_reshape->op() != Op::Get("_npx_reshape")) {
+    return false;
+  }
+  // 3 sections - ensure that every output is used only once
+  if (node.outputs.size() == 3 && node.outputs.count(first_reshape) &&
+      node.outputs.count(second_reshape)) {
+    return true;
+  }
+
+  return false;
+}
+
+inline bool Select(SelectStatusTransformerQK* status,
+                   std::vector<const BiDirectedNode*>* matched_list,
+                   const BiDirectedNode& seed_node,
+                   const std::shared_ptr<NodeAttr>& node_attr) {
+  if (seed_node.node->op() == Op::Get("batch_dot")) {
+    *status = kStart;
+    matched_list->clear();
+    matched_list->push_back(&seed_node);
+    return true;
+  }
+  return false;
+}
+
+template <bool with_split>
+bool SelectInput(SelectStatusTransformerQK* status,
+                 std::vector<const BiDirectedNode*>* matched_list,
+                 const BiDirectedNode& n,
+                 const BiDirectedNode& input_node) {
+  if (*status == kFail || *status == kSuccess || 
input_node.node->is_variable())
+    return false;
+  const auto& raw_input_node = *input_node.node;
+  switch (*status) {
+    case kStart:
+      if (raw_input_node.op() == Op::Get("SwapAxis")) {
+        if (CheckSwapAxisConditionsQK(input_node)) {
+          *status = kFirstSwapAx;
+          matched_list->push_back(&input_node);
+          return true;
+        }
+      }
+      break;
+    case kFirstSwapAx:
+      if (raw_input_node.op() == Op::Get("SwapAxis")) {
+        if (CheckSwapAxisConditionsQK(input_node)) {
+          *status = kSecondSwapAx;
+          matched_list->push_back(&input_node);
+          return true;
+        }
+      }
+      break;
+    case kSecondSwapAx:
+      if (raw_input_node.op() == Op::Get("_npx_reshape")) {
+        // input to reshape must be first or second output from split
+        if (CheckReshapeConditionsQK(input_node, 0) || 
CheckReshapeConditionsQK(input_node, 1)) {
+          *status = kFirstReshape;
+          matched_list->push_back(&input_node);
+          return true;
+        }
+      }
+      break;
+    case kFirstReshape:
+      if (raw_input_node.op() == Op::Get("_npx_reshape")) {
+        if (CheckReshapeConditionsQK(input_node, 0) || 
CheckReshapeConditionsQK(input_node, 1)) {
+          if constexpr (with_split) {
+            *status = kSecondReshape;
+          } else {
+            *status = kSuccess;
+          }
+          matched_list->push_back(&input_node);
+          return true;
+        }
+      }
+      break;
+    case kSecondReshape:
+      if (raw_input_node.op() == Op::Get("_split_v2") &&
+          CheckSplitConditions(*matched_list, input_node)) {
+        *status = kSuccess;
+        return true;
+      }
+      break;
+    default:
+      *status = kFail;
+      return false;
+  }
+  return false;
+}
+
+inline std::vector<BiDirectedNode*> Filter(const SelectStatusTransformerQK& 
status,
+                                           const std::vector<const 
BiDirectedNode*>& matched_list,
+                                           const std::vector<BiDirectedNode*>& 
candidates) {
+  if (status != kSuccess) {
+    return std::vector<BiDirectedNode*>(0);
+  } else {
+    std::vector<BiDirectedNode*> ret;
+    for (auto i : matched_list) {
+      auto non_const_i = const_cast<BiDirectedNode*>(i);
+      if (std::find(candidates.begin(), candidates.end(), non_const_i) != 
candidates.end()) {
+        ret.push_back(non_const_i);
+      }
+    }
+    return ret;
+  }
+}
+
+template <bool with_split>
+nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym, const int 
subgraph_id = 0) {
+  std::string op_name;
+  if constexpr (with_split) {
+    op_name = "_sg_onednn_selfatt_qk_split";
+  } else {
+    op_name = "_sg_onednn_selfatt_qk";
+  }
+  nnvm::ObjectPtr n = nnvm::Node::Create();
+  // this op has a single output, so keep only the first output entry and drop duplicates
+  auto last_node = sym.outputs[0].node;
+  nnvm::Symbol new_sym;
+  new_sym.outputs.emplace_back(last_node);
+  std::ostringstream node_name;
+
+  DFSVisit(new_sym.outputs, [&](const nnvm::ObjectPtr& node) {
+    if ((node->op() == Op::Get("_npx_reshape"))) {
+      auto const& reshape_param = 
nnvm::get<NumpyXReshapeParam>(node->attrs.parsed);
+      // set heads attribute - all necessary conditions are checked before
+      n->attrs.dict["heads"] = std::to_string(reshape_param.newshape[2]);
+    }
+  });
+
+  node_name << op_name << subgraph_id;
+  n->attrs.name = node_name.str();
+  n->attrs.op   = Op::Get(op_name);
+  CHECK(n->attrs.op);
+  n->op()->attr_parser(&(n->attrs));
+  return n;
+}
+
+inline void ConnectSubgraphOutputs(const nnvm::ObjectPtr n,
+                                   std::vector<nnvm::NodeEntry*>* 
output_entries) {
+  // connect all extern output entries to output[0]
+  for (size_t i = 0; i < output_entries->size(); ++i) {
+    auto entry_ptr = output_entries->at(i);
+    *entry_ptr     = nnvm::NodeEntry{n, entry_ptr->index, 0};
+  }
+}
+
+}  // namespace qk_common
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // if MXNET_USE_ONEDNN == 1
+#endif  // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_TRANSFORMER_QK_COMMON_H_
diff --git a/src/operator/subgraph/dnnl/dnnl_transformer_qk_property.h 
b/src/operator/subgraph/dnnl/dnnl_transformer_qk_property.h
index 9f4a9991a1..a34af4711f 100644
--- a/src/operator/subgraph/dnnl/dnnl_transformer_qk_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_transformer_qk_property.h
@@ -22,19 +22,26 @@
 
 #if MXNET_USE_ONEDNN == 1
 
-#include <map>
 #include <string>
 #include <vector>
 
-#include "operator/contrib/transformer-inl.h"
-#include "operator/numpy/np_matrix_op-inl.h"
-#include "operator/tensor/matrix_op-inl.h"
-#include "operator/subgraph/common.h"
-#include "dnnl_common.h"
-#include "dnnl_subgraph_base-inl.h"
-#include "dnnl_transformer-inl.h"
+#include "dnnl_transformer_qk_common.h"
 
 /*
+       custom_op      custom_op
+          |              |
+    ______|______________|________
+   |      |              |        |
+   |  _npx_reshape  _npx_reshape  |
+   |      |              |        |
+   |  SwapAxis        SwapAxis    |
+   |       \          /           |
+   |        batch_dot             |
+   |            |                 |
+   |______________________________|
+
+OR
+
               custom_op
                  |
     _____________|_________________
@@ -47,113 +54,55 @@
    |        batch_dot             |
    |            |                 |
    |______________________________|
+
 */
 namespace mxnet {
 namespace op {
 
 class SgDNNLTransformerQKSelector : public SubgraphSelectorV2 {
-  enum SelectStatusTransformerQK {
-    kFail = 0,
-    kStart,
-    kFirstSwapAx,
-    kSecondSwapAx,
-    kFirstReshape,
-    kSecondReshape,
-    kSuccess
-  };
-
-  /*!
-    kStart ---> kFirstSwapAx ---> kSecondSwapAx ---> kFirstReshape ---> 
kSecondReshape ---> kSuccess
-    Each status except kStart is connected with kFail
-  */
-
  private:
-  SelectStatusTransformerQK status_;
+  qk_common::SelectStatusTransformerQK status_;
   std::vector<const BiDirectedNode*> matched_list_;
 
-  bool CheckSplitConditions(const BiDirectedNode& node) {
-    const SplitParam& param = dmlc::get<SplitParam>(node.node->attrs.parsed);
-
-    if (param.axis != -1 || param.sections != 3 || param.squeeze_axis)
-      return false;
+ public:
+  bool Select(const BiDirectedNode& seed_node,
+              const std::shared_ptr<NodeAttr>& node_attr) override {
+    return qk_common::Select(&status_, &matched_list_, seed_node, node_attr);
+  }
 
-    const auto first_reshape  = (*(matched_list_.end() - 2))->node;
-    const auto second_reshape = (*(matched_list_.end() - 1))->node;
-    if (first_reshape->op() != Op::Get("_npx_reshape") ||
-        second_reshape->op() != Op::Get("_npx_reshape")) {
-      return false;
-    }
-    // 3 sections - ensure that every output is used only once
-    if (node.outputs.size() == 3 && node.outputs.count(first_reshape) &&
-        node.outputs.count(second_reshape)) {
-      return true;
-    }
+  bool SelectInput(const BiDirectedNode& n, const BiDirectedNode& input_node) 
override {
+    return qk_common::SelectInput<false>(&status_, &matched_list_, n, 
input_node);
+  }
 
+  bool SelectOutput(const BiDirectedNode& n, const BiDirectedNode& 
output_node) override {
     return false;
   }
 
+  std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>& 
candidates) override {
+    return qk_common::Filter(status_, matched_list_, candidates);
+  }
+
+  void Reset() override {
+    CHECK_GE(matched_list_.size(), 1);
+    auto new_selector = SgDNNLTransformerQKSelector();
+    new_selector.Select(*matched_list_[0], nullptr);
+    *this = new_selector;
+  }
+};
+
+class SgDNNLTransformerQKSplitSelector : public SubgraphSelectorV2 {
+ private:
+  qk_common::SelectStatusTransformerQK status_;
+  std::vector<const BiDirectedNode*> matched_list_;
+
  public:
   bool Select(const BiDirectedNode& seed_node,
               const std::shared_ptr<NodeAttr>& node_attr) override {
-    if (seed_node.node->op() == Op::Get("batch_dot")) {
-      status_ = kStart;
-      matched_list_.clear();
-      matched_list_.push_back(&seed_node);
-      return true;
-    }
-    return false;
+    return qk_common::Select(&status_, &matched_list_, seed_node, node_attr);
   }
 
   bool SelectInput(const BiDirectedNode& n, const BiDirectedNode& input_node) 
override {
-    if (status_ == kFail || status_ == kSuccess || 
input_node.node->is_variable())
-      return false;
-    const auto& raw_input_node = *input_node.node;
-    switch (status_) {
-      case kStart:
-        if (raw_input_node.op() == Op::Get("SwapAxis")) {
-          if (CheckSwapAxisConditions(raw_input_node)) {
-            status_ = kFirstSwapAx;
-            matched_list_.push_back(&input_node);
-          }
-          return true;
-        }
-      case kFirstSwapAx:
-        if (raw_input_node.op() == Op::Get("SwapAxis")) {
-          if (CheckSwapAxisConditions(raw_input_node)) {
-            status_ = kSecondSwapAx;
-            matched_list_.push_back(&input_node);
-            return true;
-          }
-        }
-      case kSecondSwapAx:
-        if (raw_input_node.op() == Op::Get("_npx_reshape")) {
-          // input to reshape must be first or second output from split
-          if (CheckReshapeConditions(raw_input_node, 0) ||
-              CheckReshapeConditions(raw_input_node, 1)) {
-            status_ = kFirstReshape;
-            matched_list_.push_back(&input_node);
-            return true;
-          }
-        }
-      case kFirstReshape:
-        if (raw_input_node.op() == Op::Get("_npx_reshape")) {
-          if (CheckReshapeConditions(raw_input_node, 0) ||
-              CheckReshapeConditions(raw_input_node, 1)) {
-            status_ = kSecondReshape;
-            matched_list_.push_back(&input_node);
-            return true;
-          }
-        }
-      case kSecondReshape:
-        if (raw_input_node.op() == Op::Get("_split_v2") && 
CheckSplitConditions(input_node)) {
-          status_ = kSuccess;
-          return true;
-        }
-      default:
-        status_ = kFail;
-        return false;
-    }
-    return false;
+    return qk_common::SelectInput<true>(&status_, &matched_list_, n, 
input_node);
   }
 
   bool SelectOutput(const BiDirectedNode& n, const BiDirectedNode& 
output_node) override {
@@ -161,23 +110,12 @@ class SgDNNLTransformerQKSelector : public 
SubgraphSelectorV2 {
   }
 
   std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>& 
candidates) override {
-    if (status_ != kSuccess) {
-      return std::vector<BiDirectedNode*>(0);
-    } else {
-      std::vector<BiDirectedNode*> ret;
-      for (auto i : matched_list_) {
-        auto non_const_i = const_cast<BiDirectedNode*>(i);
-        if (std::find(candidates.begin(), candidates.end(), non_const_i) != 
candidates.end()) {
-          ret.push_back(non_const_i);
-        }
-      }
-      return ret;
-    }
+    return qk_common::Filter(status_, matched_list_, candidates);
   }
 
   void Reset() override {
     CHECK_GE(matched_list_.size(), 1);
-    auto new_selector = SgDNNLTransformerQKSelector();
+    auto new_selector = SgDNNLTransformerQKSplitSelector();
     new_selector.Select(*matched_list_[0], nullptr);
     *this = new_selector;
   }
@@ -200,43 +138,43 @@ class SgDNNLTransformerQKProperty : public 
SubgraphProperty {
 
   nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym,
                                      const int subgraph_id = 0) const override 
{
-    nnvm::ObjectPtr n = nnvm::Node::Create();
-    // This op has single output, remove duplicated.
-    auto last_node = sym.outputs[0].node;
-    nnvm::Symbol new_sym;
-    new_sym.outputs.emplace_back(last_node);
-    std::ostringstream node_name;
-    std::string op_name;
-
-    DFSVisit(new_sym.outputs, [&](const nnvm::ObjectPtr& node) {
-      if ((node->op() == Op::Get("_npx_reshape"))) {
-        auto const& reshape_param = 
nnvm::get<NumpyXReshapeParam>(node->attrs.parsed);
-        // set heads attribute - all necessary conditions are checked before
-        n->attrs.dict["heads"] = std::to_string(reshape_param.newshape[2]);
-      }
-    });
-
-    node_name << "_sg_onednn_selfatt_qk_" << subgraph_id;
-
-    n->attrs.name = node_name.str();
-    n->attrs.op   = Op::Get("_sg_onednn_selfatt_qk");
-    CHECK(n->attrs.op);
-    n->op()->attr_parser(&(n->attrs));
-    return n;
+    return qk_common::CreateSubgraphNode<false>(sym, subgraph_id);
+  }
+
+  void ConnectSubgraphOutputs(const nnvm::ObjectPtr n,
+                              std::vector<nnvm::NodeEntry*>* output_entries) 
const override {
+    qk_common::ConnectSubgraphOutputs(n, output_entries);
   }
 
   SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override {
     auto selector = std::make_shared<SgDNNLTransformerQKSelector>();
     return selector;
   }
+};
+
+class SgDNNLTransformerQKSplitProperty : public SubgraphProperty {
+ public:
+  SgDNNLTransformerQKSplitProperty() {}
+
+  static SubgraphPropertyPtr Create() {
+    static const std::string& name = "oneDNN Transformer optimization pass";
+    auto property                  = 
std::make_shared<SgDNNLTransformerQKSplitProperty>();
+    property->SetAttr<std::string>("property_name", name);
+    property->SetAttr<bool>("inference_only", true);
+    if (dmlc::GetEnv("MXNET_DISABLE_ONEDNN_TRANSFORMER_OPT", 0)) {
+      property->SetAttr<bool>("disable", true);
+    }
+    return property;
+  }
+
+  nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym,
+                                     const int subgraph_id = 0) const override 
{
+    return qk_common::CreateSubgraphNode<true>(sym, subgraph_id);
+  }
 
   void ConnectSubgraphOutputs(const nnvm::ObjectPtr n,
                               std::vector<nnvm::NodeEntry*>* output_entries) 
const override {
-    // Connect all extern output entries to output[0]
-    for (size_t i = 0; i < output_entries->size(); ++i) {
-      auto entry_ptr = output_entries->at(i);
-      *entry_ptr     = nnvm::NodeEntry{n, entry_ptr->index, 0};
-    }
+    qk_common::ConnectSubgraphOutputs(n, output_entries);
   }
 
   void ConnectSubgraphInputs(const nnvm::ObjectPtr subgraph_node,
@@ -247,6 +185,11 @@ class SgDNNLTransformerQKProperty : public 
SubgraphProperty {
     // connect subgraph input with split input
     subgraph_node->inputs[0] = orig_input_entries->at(0).node->inputs[0];
   }
+
+  SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override {
+    auto selector = std::make_shared<SgDNNLTransformerQKSplitSelector>();
+    return selector;
+  }
 };
 
 }  // namespace op
diff --git a/src/operator/subgraph/dnnl/dnnl_transformer_valatt_property.h 
b/src/operator/subgraph/dnnl/dnnl_transformer_valatt_property.h
index 35f02f6203..0cdde8c964 100644
--- a/src/operator/subgraph/dnnl/dnnl_transformer_valatt_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_transformer_valatt_property.h
@@ -36,12 +36,12 @@
 #include "dnnl_transformer-inl.h"
 
 /*
-                 custom_op
-   _________________|_________
-  |               Split      |
-  |                 |        |
-  |             _npx_reshape |
-  |                 |        |
+            custom_op
+   _____________|_____________
+  |           Split          |
+  |          /      |        |
+  | custom_op   _npx_reshape |
+  |    ...          |        |
   | custom_op    SwapAxis    |
   |      \        /          |
   |       batch_dot          |
diff --git a/tests/python/dnnl/op_cfg.py b/tests/python/dnnl/op_cfg.py
index c4739ab807..50ec87a017 100644
--- a/tests/python/dnnl/op_cfg.py
+++ b/tests/python/dnnl/op_cfg.py
@@ -291,11 +291,17 @@ def get_all_ops_cfgs(dtype):
         '_sg_onednn_batch_norm': {CFG_BASED_ON: 'BatchNorm'},
         '_sg_onednn_selfatt_qk': {
             CFG_SUBGRAPH: [SubgraphCfg('_sg_onednn_selfatt_qk', 'ONEDNN')],
+            'queries': [mx.nd.random.normal(0, 1, (1, 4, 3*2*8), dtype)],
+            'keys': [mx.nd.random.normal(0, 1, (1, 8, 3*2*8), dtype)],
+            'heads': [2]
+        },
+        '_sg_onednn_selfatt_qk_split': {
+            CFG_SUBGRAPH: [SubgraphCfg('_sg_onednn_selfatt_qk_split', 
'ONEDNN')],
             'queries_keys_values': [mx.nd.random.normal(0, 1, (1, 4, 3*2*8), 
dtype)],
             'heads': [2]
         },
         '_sg_onednn_selfatt_valatt': {
-            CFG_BASED_ON: '_sg_onednn_selfatt_qk',
+            CFG_BASED_ON: '_sg_onednn_selfatt_qk_split',
             CFG_SUBGRAPH: [SubgraphCfg('_sg_onednn_selfatt_valatt', 'ONEDNN')],
             'attention': [CfgBasedArg(valatt_attention_tensor)]
         }
diff --git a/tests/python/dnnl/subgraphs/test_matmul_subgraph.py 
b/tests/python/dnnl/subgraphs/test_matmul_subgraph.py
index a96c8c31ba..938d49444f 100644
--- a/tests/python/dnnl/subgraphs/test_matmul_subgraph.py
+++ b/tests/python/dnnl/subgraphs/test_matmul_subgraph.py
@@ -27,22 +27,28 @@ import math
 
 
 class MultiHeadAttention(nn.HybridBlock):
-  def __init__(self, units, num_heads, dtype='float32', negative_case=False, 
**kwargs):
+  def __init__(self, units, num_heads, batch_size=-1, seq_length=-1, 
dtype='float32', negative_case=False, no_split_case = False, **kwargs):
       super(MultiHeadAttention, self).__init__(**kwargs)
       self._units = units
       self._num_heads = num_heads
       self._fc = nn.Dense(in_units=self._units, units=3*self._units, 
flatten=False, dtype=dtype)
       self._scale = math.sqrt(self._units // self._num_heads)
       self.negative_case = negative_case
+      self.no_split_case = no_split_case
+      self.batch_size = batch_size
+      self.seq_length = seq_length
 
   def forward(self, x, mask):
       out = self._fc(x)
       query, key, value = mx.np.split(out, 3, axis=-1)
+      if self.no_split_case:
+        key = mx.np.concat((key, key), axis = 1)
+        value = mx.np.concat((value, value), axis = 1)
+      query = mx.np.reshape(query, (-2, -2, self._num_heads, -1))
       if self.negative_case:
-        key = key * 2
-      query = mx.npx.reshape(query, (-2, -2, self._num_heads, -1))
-      key = mx.npx.reshape(key, (-2, -2, self._num_heads, -1))
-      value = mx.npx.reshape(value, (-2, -2, self._num_heads, -1))
+        query = query * 2
+      key = mx.np.reshape(key, (-2, -2, self._num_heads, -1))
+      value = mx.np.reshape(value, (-2, -2, self._num_heads, -1))
       scores = mx.npx.batch_dot(mx.np.swapaxes(query, 1, 2), 
mx.np.swapaxes(key, 1, 2),
                                 transpose_b=True)
       mask = mx.np.expand_dims(mask, axis=1).astype(np.bool)
@@ -58,10 +64,16 @@ class MultiHeadAttention(nn.HybridBlock):
 @pytest.mark.parametrize('seq_length', [124, 384])
 @pytest.mark.parametrize('units', [256, 768])
 @pytest.mark.parametrize('num_heads', [4, 8])
-def test_self_attention(batch_size, seq_length, units, num_heads):
-  net = MultiHeadAttention(units, num_heads)
[email protected]('split', [True, False])
+def test_self_attention(batch_size, seq_length, units, num_heads, split):
+  net = MultiHeadAttention(units, num_heads, no_split_case=not split)
   in_data = mx.np.random.uniform(size=[batch_size, seq_length, units], 
dtype='float32')
-  mask = mx.np.random.uniform(low=0, high=2, size=[batch_size, seq_length, 
seq_length], dtype='int32')
+  if (split):
+    mask = mx.np.random.uniform(low=0, high=2, size=[batch_size, seq_length, 
seq_length], dtype='int32')
+  else:
+    # key (and value) is concatenated with itself along axis 1 to simulate gpt-2 model,
+    # so the last mask dimension needs to be doubled (seq_length * 2) as well
+    mask = mx.np.random.uniform(low=0, high=2, size=[batch_size, seq_length, 
seq_length * 2], dtype='int32')
 
   net.initialize()
   fused_net = net
@@ -74,13 +86,13 @@ def test_self_attention(batch_size, seq_length, units, 
num_heads):
 
   assert_almost_equal(out.asnumpy(), ref_out.asnumpy())
 
-  calib_data = mx.gluon.data.DataLoader(mx.gluon.data.ArrayDataset(in_data, 
mask), batch_size=1)
+  calib_data = mx.gluon.data.DataLoader(mx.gluon.data.ArrayDataset(in_data, 
mask), batch_size=batch_size)
   qnet = mx.contrib.quant.quantize_net(net, quantized_dtype='auto',
                                             exclude_layers=None,
                                             exclude_layers_match=None,
                                             calib_data=calib_data,
                                             calib_mode='naive',
-                                            num_calib_batches=1,
+                                            num_calib_batches=batch_size,
                                             ctx=mx.cpu())
 
   qout = qnet(in_data, mask)
@@ -97,7 +109,7 @@ def test_self_attention(batch_size, seq_length, units, 
num_heads):
 @pytest.mark.parametrize('units', [256, 768])
 @pytest.mark.parametrize('num_heads', [4, 8])
 def test_self_attention_negative(batch_size, seq_length, units, num_heads):
-  net = MultiHeadAttention(units, num_heads, negative_case=True)
+  net = MultiHeadAttention(units, num_heads, batch_size, seq_length, 
negative_case=True)
   in_data = mx.np.random.uniform(size=[batch_size, seq_length, units], 
dtype='float32')
   mask = mx.np.random.uniform(low=0, high=2, size=[batch_size, seq_length, 
seq_length], dtype='int32')
 
@@ -112,13 +124,13 @@ def test_self_attention_negative(batch_size, seq_length, 
units, num_heads):
 
   assert_almost_equal(out.asnumpy(), ref_out.asnumpy())
 
-  calib_data = mx.gluon.data.DataLoader(mx.gluon.data.ArrayDataset(in_data, 
mask), batch_size=1)
+  calib_data = mx.gluon.data.DataLoader(mx.gluon.data.ArrayDataset(in_data, 
mask), batch_size=batch_size)
   qnet = mx.contrib.quant.quantize_net(net, quantized_dtype='auto',
                                             exclude_layers=None,
                                             exclude_layers_match=None,
                                             calib_data=calib_data,
                                             calib_mode='naive',
-                                            num_calib_batches=1,
+                                            num_calib_batches=batch_size,
                                             ctx=mx.cpu())
 
   qout = qnet(in_data, mask)

Reply via email to