zhiics commented on a change in pull request #5272: [BYOC] Add example of Composite + Annotate for DNNL fused op
URL: https://github.com/apache/incubator-tvm/pull/5272#discussion_r405827170
 
 

 ##########
 File path: src/relay/backend/contrib/dnnl/codegen.cc
 ##########
 @@ -133,83 +209,100 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
   }
 
  private:
-  std::vector<std::string> Conv2d(const CallNode* call) {
-    std::vector<std::string> args;
-    const auto* conv2d_attr = call->attrs.as<Conv2DAttrs>();
-    CHECK(conv2d_attr);
-
-    auto ishape = GetShape(call->args[0]->checked_type());
-    auto wshape = GetShape(call->args[1]->checked_type());
+  struct GenerateBodyOutput {
+    std::string decl, buf;
+    int out_size = 1;
+    std::string out;
+  };
 
-    // Args: N, C, H, W
-    for (auto s : ishape) {
-      args.push_back(std::to_string(s));
+  std::vector<std::string> GetArgumentNames(const CallNode* call) {
+    std::vector<std::string> arg_names;
+    for (size_t i = 0; i < call->args.size(); ++i) {
+      VisitExpr(call->args[i]);
+      for (auto out : out_) {
+        arg_names.push_back(out.name);
+      }
     }
-
-    // Args: O, G, Ph, Pw, Kh, Kw, Sh, Sw
-    args.push_back(std::to_string(wshape[0]));
-    args.push_back(std::to_string(conv2d_attr->groups));
-    args.push_back(std::to_string(conv2d_attr->padding[0].as<IntImmNode>()->value));
-    args.push_back(std::to_string(conv2d_attr->padding[1].as<IntImmNode>()->value));
-    args.push_back(std::to_string(wshape[2]));
-    args.push_back(std::to_string(wshape[3]));
-    args.push_back(std::to_string(conv2d_attr->strides[0].as<IntImmNode>()->value));
-    args.push_back(std::to_string(conv2d_attr->strides[1].as<IntImmNode>()->value));
-
-    return args;
+    return arg_names;
   }
 
-  std::vector<std::string> Dense(const CallNode* call) {
-    std::vector<std::string> args;
-    auto ishape = GetShape(call->args[0]->checked_type());
-    auto wshape = GetShape(call->args[1]->checked_type());
-
-    // Args: N, C, O
-    args.push_back(std::to_string(ishape[0]));
-    args.push_back(std::to_string(ishape[1]));
-    args.push_back(std::to_string(wshape[0]));
+  GenerateBodyOutput GenerateOpCall(const CallNode* call) {
+    const auto* op_node = call->op.as<OpNode>();
+    CHECK(op_node) << "OpNode expected, got something else";
 
 Review comment:
   yes, my bad.
   
   Update: I meant `call->op->GetTypeKey()`.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to