masahi commented on a change in pull request #4741: [External codegen] Add test cases for fused ops with manual annotation URL: https://github.com/apache/incubator-tvm/pull/4741#discussion_r368707692
########## File path: src/relay/backend/contrib/dnnl/codegen.cc ########## @@ -50,82 +51,109 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { out_.push_back({node->name_hint(), 0}); } - void VisitExpr_(const TupleGetItemNode* op) final { - // Do nothing - } - void VisitExpr_(const CallNode* call) final { - std::ostringstream decl_stream; - std::ostringstream buf_stream; - // Args: ID - std::vector<std::string> args; + struct Output { + std::string decl, buf; + int out_size = 1; + std::string out; + }; + + auto generate_body = [=](const CallNode* root_call, const std::string& func_name, + const std::vector<std::string>& args, + const std::vector<std::string>& fused_func_args) { + // Make function call with input buffers when visiting arguments + bool first = true; + std::ostringstream arg_stream; + arg_stream << "("; + for (size_t i = 0; i < root_call->args.size(); ++i) { + VisitExpr(root_call->args[i]); + for (auto out : out_) { + if (!first) { + arg_stream << ", "; + } + first = false; + arg_stream << out.first; + } + } + + for (auto arg_name : fused_func_args) { + arg_stream << ", " << arg_name; + } + + // Analyze the output buffer + auto type_node = root_call->checked_type().as<TensorTypeNode>(); + CHECK(type_node != nullptr && runtime::TypeMatch(type_node->dtype, kDLFloat, 32)) + << "Only support single output tensor with float type"; + + auto out_shape = GetShape(root_call->checked_type()); + + Output ret; + ret.out = "buf_" + std::to_string(buf_idx_++); + ret.out_size = std::accumulate(out_shape.begin(), out_shape.end(), 1, std::multiplies<int>()); + + this->PrintIndents(); + + std::ostringstream buf_stream; + buf_stream << "float* " << ret.out << " = (float*)std::malloc(4 * " << ret.out_size << ");"; + ret.buf = buf_stream.str(); - // Get the arguments for various DNNL kernels. 
- if (IsOp(call, "nn.conv2d")) { - decl_stream << "dnnl_conv2d"; - args = Conv2d(call); + arg_stream << ", " << ret.out; + // Attach attribute arguments + for (size_t i = 0; i < args.size(); ++i) { + arg_stream << ", " << args[i]; + } + arg_stream << ");"; + ret.decl = func_name + arg_stream.str(); + + return ret; + }; + + Output ret; + if (auto conv_call = DetectFusedConv2DBiasReLU(call)) { Review comment: I can also leave the current dumb implementation as it is, with the understanding that: * This is a temporary solution. * It will serve as a concrete motivation and test case for validating a more general mechanism to be introduced. Trying to be a bit more clever and duplicating an entire state machine's logic here does not seem worth it to me anymore. Either way, I'm fine. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services