masahi commented on a change in pull request #4258: [WIP][TVM] Bring Your Own Codegen to TVM
URL: https://github.com/apache/incubator-tvm/pull/4258#discussion_r352304021
########## File path: src/relay/pass/partition_graph.cc ########## @@ -0,0 +1,398 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * \file src/relay/pass/partition_graph.cc + * + * \brief Partition an input function into multiple Functions according based + * on the inserted annotation nodes (i.e. begin and end). These nodes are used + * as boundaries to partition the Relay function into multiple regions that can + * be offloaded to different accelerators. + * + * Each of these paritioned functions, a.k.a subgraphs, will be viewed as + * external functions, and they will use external tools for codegen. + */ + +#include <tvm/relay/analysis.h> +#include <tvm/relay/attrs/annotation.h> +#include <tvm/relay/expr.h> +#include <tvm/relay/expr_functor.h> +#include <tvm/relay/transform.h> + +#include <string> +#include <unordered_map> +#include <unordered_set> +#include <utility> +#include <vector> + +namespace tvm { +namespace relay { +namespace graph_partitioning { + +/*! + * \brief The subgraph properties for partition. + */ +struct Subgraph { + /*! \brief The subgraph ID. */ + int id; + + /*! \brief The input arguments of this subgraph. */ + std::vector<std::pair<Var, Expr>> args; + + /*! 
\brief Nodes in this subgraph. */ + std::unordered_set<Expr, ExprHash, ExprEqual> nodes; +}; + +/*! + * \brief The checker that verifies if a Relay program is annotated correctly + * for graph partitioning. + */ +class AnnotationChecker : public ExprVisitor { + public: + bool Check() { + if (!this->found_start && !this->found_end) { + LOG(WARNING) << "No subgraph annotation found"; + } else if (!this->found_start) { + LOG(ERROR) << "Subgraph start annotation is missing"; + return false; + } else if (!this->found_end) { + LOG(ERROR) << "Subgraph end annotation is missing"; + return false; + } + return true; + } + + void VisitExpr_(const CallNode* call) final { + auto op_node = call->op.as<OpNode>(); + if (op_node == nullptr || call->attrs.as<SubgraphAttrs>() == nullptr) { + return; + } else if (GetRef<Op>(op_node) == Op::Get("annotation.subgraph_begin")) { + this->found_start = true; + } else if (GetRef<Op>(op_node) == Op::Get("annotation.subgraph_end")) { + this->found_end = true; + } + } + + private: + bool found_start = false; + bool found_end = false; +}; + +/*! \brief This class partitions the graph labeled with begin and end annoations + * into function containing multiple subgraphs. Each subgraph is labeled as + * external. + * + * TODO(@zhiics) This following algorithm is not adequate to handle all cases, + * i.e. multiple `end` nodes. + */ +class Partitioner : public ExprMutator { + public: + Subgraph* GetSubgraph(const Expr node) { + for (auto candidate : this->subgraphs_) { + if (candidate->nodes.find(node) != candidate->nodes.end()) { + return candidate; + } + } + return nullptr; + } + + void MergeSubgraph(Subgraph* subgraph1, Subgraph* subgraph2) { + if (subgraph1 == subgraph2) { + return; + } + + // Merge subgraph 2 to subgraph 1 and erase subgraph 2. 
+ subgraph1->nodes.insert(subgraph2->nodes.begin(), subgraph2->nodes.end()); + for (auto arg : subgraph2->args) { + subgraph1->args.push_back(arg); + } + this->subgraphs_.erase(subgraph2); + } + + void AddToSubgraph(Subgraph* subgraph, const Expr expr) { + auto subgraph2 = GetSubgraph(expr); + if (subgraph2) { + MergeSubgraph(subgraph, subgraph2); + } else { + subgraph->nodes.insert(expr); + } + } + + Expr VisitExpr_(const CallNode* call) final { + auto op_node = call->op.as<OpNode>(); + + if (op_node == nullptr || call->attrs.as<SubgraphAttrs>() == nullptr) { + // Propogate subgraph to arguments + auto subgraph = GetSubgraph(GetRef<Call>(call)); + if (subgraph) { + for (auto arg : call->args) { + AddToSubgraph(subgraph, arg); + } + } + return ExprMutator::VisitExpr_(call); + } else if (GetRef<Op>(op_node) == Op::Get("annotation.subgraph_begin")) { + // The annotation node is inserted on edge so it must have only one argument. + CHECK_EQ(call->args.size(), 1U); + + // Traverse the rest graph. + auto input_expr = VisitExpr(call->args[0]); + + // Replace the begin annotation with an external call input variable. + auto subgraph_attrs = call->attrs.as<SubgraphAttrs>(); + auto var = VarNode::make(subgraph_attrs->compiler + "_input" + std::to_string(var_id_++), + input_expr->checked_type_); + + // Find the corresponding subgraph and add the argument. + auto subgraph = GetSubgraph(GetRef<Call>(call)); + if (!subgraph) { + throw Error(RELAY_ERROR("Cannot find the corresponding subgraph for start annotation:\n" + << AsText(GetRef<Call>(call), false))); + } + subgraph->args.push_back({var, input_expr}); + return std::move(var); + } else { + CHECK(GetRef<Op>(op_node) == Op::Get("annotation.subgraph_end")); + // The annotation node is inserted on edge so it must have only one argument. 
+ CHECK_EQ(call->args.size(), 1U); + + auto subgraph_attrs = call->attrs.as<SubgraphAttrs>(); + + // Check if the argument is already belonged to an exist subgraph + auto subgraph = GetSubgraph(call->args[0]); + if (!subgraph) { + auto ret = this->subgraphs_.emplace(new Subgraph()); + subgraph = *ret.first; + subgraph->nodes.insert(call->args[0]); + subgraph->id = this->subgraph_id_++; + } + subgraph->nodes.insert(GetRef<Call>(call)); + + // Traverse towarding to subgraph inputs. + auto input = VisitExpr(call->args[0]); + Array<Var> params; + Array<Expr> args; + + // The subgraph may be merged so we need to update it again. + subgraph = GetSubgraph(GetRef<Call>(call)); + for (auto pair : subgraph->args) { + params.push_back(pair.first); + args.push_back(pair.second); + } + + auto subgraph_func = + FunctionNode::make(params, input, call->args[0]->checked_type_, {}, Attrs()); + + Expr arg0 = call->args[0]; + std::string name = subgraph_attrs->compiler + "_" + std::to_string(subgraph->id); + subgraph_func = + FunctionSetAttr(subgraph_func, attr::kFuncName, tvm::ir::StringImm::make(name)); + subgraph_func = FunctionSetAttr(subgraph_func, attr::kPrimitive, tvm::Integer(1)); + subgraph_func = FunctionSetAttr(subgraph_func, attr::kExternal, + tvm::ir::StringImm::make(subgraph_attrs->compiler)); + return CallNode::make(subgraph_func, args); + } + } + + Expr VisitExpr_(const TupleNode* op) final { + auto subgraph = GetSubgraph(GetRef<Tuple>(op)); + if (!subgraph) { + return ExprMutator::VisitExpr_(op); + } else { + for (auto field : op->fields) { + AddToSubgraph(subgraph, field); + } + Array<Expr> fields; + for (auto field : op->fields) { + fields.push_back(VisitExpr(field)); + } + return TupleNode::make(fields); + } + } + + Expr VisitExpr_(const TupleGetItemNode* g) final { + auto subgraph = GetSubgraph(GetRef<TupleGetItem>(g)); + if (!subgraph) { + return ExprMutator::VisitExpr_(g); + } else { + AddToSubgraph(subgraph, g->tuple); + auto t = VisitExpr(g->tuple); + return 
TupleGetItemNode::make(t, g->index); + } + } + + Expr VisitExpr_(const FunctionNode* op) final { + auto subgraph = GetSubgraph(GetRef<Function>(op)); + if (!subgraph) { + return ExprMutator::VisitExpr_(op); + } else { + Array<Var> params; + for (auto param : op->params) { + AddToSubgraph(subgraph, param); + } + for (auto param : op->params) { + Var new_param = Downcast<Var>(VisitExpr(param)); + params.push_back(new_param); + } + auto body = VisitExpr(op->body); + return FunctionNode::make(params, body, op->ret_type, op->type_params, op->attrs); + } + } + + Expr VisitExpr_(const LetNode* op) final { + auto subgraph = GetSubgraph(GetRef<Let>(op)); + if (!subgraph) { + return ExprMutator::VisitExpr_(op); + } else { + AddToSubgraph(subgraph, op->var); + AddToSubgraph(subgraph, op->value); + AddToSubgraph(subgraph, op->body); + Var var = Downcast<Var>(VisitExpr(op->var)); + auto value = VisitExpr(op->value); + auto body = VisitExpr(op->body); + + return LetNode::make(var, value, body); + } + } + + Expr VisitExpr_(const IfNode* op) final { + auto subgraph = GetSubgraph(GetRef<If>(op)); + if (!subgraph) { + return ExprMutator::VisitExpr_(op); + } else { + AddToSubgraph(subgraph, op->cond); + AddToSubgraph(subgraph, op->true_branch); + AddToSubgraph(subgraph, op->false_branch); + auto guard = VisitExpr(op->cond); + auto true_b = VisitExpr(op->true_branch); + auto false_b = VisitExpr(op->false_branch); + return IfNode::make(guard, true_b, false_b); + } + } + + Expr VisitExpr_(const RefCreateNode* op) final { + auto subgraph = GetSubgraph(GetRef<RefCreate>(op)); + if (!subgraph) { + return ExprMutator::VisitExpr_(op); + } else { + AddToSubgraph(subgraph, op->value); + Expr value = VisitExpr(op->value); + return RefCreateNode::make(value); + } + } + + Expr VisitExpr_(const RefReadNode* op) final { + auto subgraph = GetSubgraph(GetRef<RefRead>(op)); + if (!subgraph) { + return ExprMutator::VisitExpr_(op); + } else { + AddToSubgraph(subgraph, op->ref); + Expr ref = 
VisitExpr(op->ref); + return RefReadNode::make(ref); + } + } + + Expr VisitExpr_(const RefWriteNode* op) final { + auto subgraph = GetSubgraph(GetRef<RefWrite>(op)); + if (!subgraph) { + return ExprMutator::VisitExpr_(op); + } else { + AddToSubgraph(subgraph, op->ref); + Expr ref = VisitExpr(op->ref); + Expr value = VisitExpr(op->value); + return RefWriteNode::make(ref, value); + } + } + + private: + int var_id_{0}; + int subgraph_id_{0}; + std::unordered_set<Subgraph*> subgraphs_; +}; + +/*! + * \brief TODO(@zhiics, @comaniac) Combine parallel subgraphs that belong to + * the same codegen backend. This reduces rounds trips between TVM and external + * backends. + * + * For example, sg1 and sg2 should be combined if they belong to the same + * codegen tool in the following case. + * + * op1 + * / \ + * sg1 sg2 + * + * | + * \|/ + * + * op1 + * | + * sg1_sg2 + * + * where the return type of the new subgraph sg1_sg2 is a tuple, and op1 has two + * inputs that obtained from the tuple. + */ Review comment: Maybe we can reuse some code from op fusion pass for this ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: [email protected] With regards, Apache Git Services
