agrabows commented on code in PR #21115:
URL: https://github.com/apache/incubator-mxnet/pull/21115#discussion_r951288927


##########
src/operator/subgraph/dnnl/dnnl_transformer.cc:
##########
@@ -318,26 +405,81 @@ nnvm::ObjectPtr SgDNNLSelfAttQKQuantizedOp(const NodeAttrs& attrs) {
   return node;
 }
 
-NNVM_REGISTER_OP(_sg_onednn_selfatt_qk)
-    .add_alias("_sg_mkldnn_selfatt_qk")
+#define MXNET_OPERATOR_REGISTER_SELFATT_QK(name)                                                 \
+  NNVM_REGISTER_OP(name)                                                                         \
+      .set_num_outputs([](const NodeAttrs& attrs) {                                              \
+        auto const& param = nnvm::get<DNNLSelfAttParam>(attrs.parsed);                           \
+        if (param.quantized && !param.enabled_float_output.has_value()) {                        \
+          return 3;                                                                              \
+        } else {                                                                                 \
+          return 1;                                                                              \
+        }                                                                                        \
+      })                                                                                         \
+      .set_attr<nnvm::FListOutputNames>(                                                         \
+          "FListOutputNames",                                                                    \
+          [](const NodeAttrs& attrs) {                                                           \
+            auto const& param = nnvm::get<DNNLSelfAttParam>(attrs.parsed);                       \
+            std::vector<std::string> output_names{"output"};                                     \
+            if (param.quantized && !param.enabled_float_output.has_value()) {                    \
+              output_names.emplace_back("min_output");                                           \
+              output_names.emplace_back("max_output");                                           \
+            }                                                                                    \
+            return output_names;                                                                 \
+          })                                                                                     \
+      .set_attr_parser(ParamParser<DNNLSelfAttParam>)                                            \
+      .set_attr<FInferStorageType>("FInferStorageType", SgDNNLSelfAttStorageType)                \
+      .set_attr<FCreateOpState>("FCreateOpState", CreateSgDNNLSelfAttQKState)                    \
+      .set_attr<bool>("TIsDNNL", true)                                                           \
+      .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)                                 \
+      .set_attr<FQuantizable>("FQuantizable",                                                    \
+                              [](const NodeAttrs& attrs) { return QuantizeType::kMust; })        \
+      .set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; }) \
+      .add_arguments(DNNLSelfAttParam::__FIELDS__())
+
+MXNET_OPERATOR_REGISTER_SELFATT_QK(_sg_onednn_selfatt_qk)
     .describe(R"code(_sg_onednn_selfatt_qk)code" ADD_FILELINE)
     .set_num_inputs([](const NodeAttrs& attrs) {
       auto const& param = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
       if (param.quantized) {
-        return 3;
+        return 6;
       } else {
-        return 1;
+        return 2;
       }
     })
-    .set_num_outputs([](const NodeAttrs& attrs) {
+    .set_attr<nnvm::FListInputNames>("FListInputNames",
+                                     [](const NodeAttrs& attrs) {
+                                       auto const& param =
+                                           nnvm::get<DNNLSelfAttParam>(attrs.parsed);
+                                       std::vector<std::string> input_names{"queries"};
+                                       if (param.quantized) {
+                                         input_names.emplace_back("min_q");
+                                         input_names.emplace_back("max_q");
+                                       }
+                                       input_names.emplace_back("keys");

Review Comment:
   done



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to