zhiics commented on a change in pull request #5272: [BYOC] Add example of Composite + Annotate for DNNL fused op
URL: https://github.com/apache/incubator-tvm/pull/5272#discussion_r405813179
##########
File path: src/relay/backend/contrib/dnnl/codegen.cc
##########
@@ -133,83 +209,100 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
}
private:
- std::vector<std::string> Conv2d(const CallNode* call) {
- std::vector<std::string> args;
- const auto* conv2d_attr = call->attrs.as<Conv2DAttrs>();
- CHECK(conv2d_attr);
-
- auto ishape = GetShape(call->args[0]->checked_type());
- auto wshape = GetShape(call->args[1]->checked_type());
+ struct GenerateBodyOutput {
+ std::string decl, buf;
+ int out_size = 1;
+ std::string out;
+ };
- // Args: N, C, H, W
- for (auto s : ishape) {
- args.push_back(std::to_string(s));
+ std::vector<std::string> GetArgumentNames(const CallNode* call) {
+ std::vector<std::string> arg_names;
+ for (size_t i = 0; i < call->args.size(); ++i) {
+ VisitExpr(call->args[i]);
+ for (auto out : out_) {
+ arg_names.push_back(out.name);
+ }
}
-
- // Args: O, G, Ph, Pw, Kh, Kw, Sh, Sw
- args.push_back(std::to_string(wshape[0]));
- args.push_back(std::to_string(conv2d_attr->groups));
- args.push_back(std::to_string(conv2d_attr->padding[0].as<IntImmNode>()->value));
- args.push_back(std::to_string(conv2d_attr->padding[1].as<IntImmNode>()->value));
- args.push_back(std::to_string(wshape[2]));
- args.push_back(std::to_string(wshape[3]));
- args.push_back(std::to_string(conv2d_attr->strides[0].as<IntImmNode>()->value));
- args.push_back(std::to_string(conv2d_attr->strides[1].as<IntImmNode>()->value));
-
- return args;
+ return arg_names;
}
- std::vector<std::string> Dense(const CallNode* call) {
- std::vector<std::string> args;
- auto ishape = GetShape(call->args[0]->checked_type());
- auto wshape = GetShape(call->args[1]->checked_type());
-
- // Args: N, C, O
- args.push_back(std::to_string(ishape[0]));
- args.push_back(std::to_string(ishape[1]));
- args.push_back(std::to_string(wshape[0]));
+ GenerateBodyOutput GenerateOpCall(const CallNode* call) {
+ const auto* op_node = call->op.as<OpNode>();
+ CHECK(op_node) << "OpNode expected, got something else";
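Review comment:
Let's also print out what `call->op` actually is when it is not an OpNode. Note that `op_node` is null whenever this check fails, so the type key must be read from `call->op`, not from `op_node` (dereferencing `op_node` in the failure message would crash):
```
CHECK(op_node) << "Expect OpNode, but got " << call->op->GetTypeKey();
```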
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services