masahi commented on a change in pull request #4570: [relay] Relay annotation and partitioning for external compilers URL: https://github.com/apache/incubator-tvm/pull/4570#discussion_r361615989
########## File path: src/relay/pass/partition_graph.cc ########## @@ -0,0 +1,376 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * \file src/relay/pass/partition_graph.cc + * + * \brief Partition an input function into multiple functions based + * on the inserted annotation nodes (i.e. compiler_begin and compiler_end). + * These nodes are used as boundaries to partition the Relay function into + * multiple regions that can be offloaded to different accelerators/backends. + * + * Each of these partitioned functions, a.k.a. subgraphs, will be viewed as + * external functions, and they will use the provided compiler for codegen. + */ + +#include <tvm/relay/analysis.h> +#include <tvm/relay/attrs/annotation.h> +#include <tvm/relay/expr.h> +#include <tvm/relay/expr_functor.h> +#include <tvm/relay/transform.h> + +#include <string> +#include <unordered_map> +#include <unordered_set> +#include <utility> +#include <vector> + +namespace tvm { +namespace relay { +namespace partitioning { + +/*! + * \brief The subgraph properties for partitioning. + */ +struct Subgraph { + /*! \brief The subgraph ID. */ + int id; + + /*! \brief The input arguments of this subgraph.
*/ + std::vector<std::pair<Var, Expr>> args; + + /*! \brief Nodes in this subgraph. */ + std::unordered_set<Expr, ExprHash, ExprEqual> nodes; +}; + +/*! + * \brief The checker that verifies if a Relay program is annotated correctly + * for partitioning. + */ +class AnnotationChecker : public ExprVisitor { + public: + bool Check() { + if (!this->found_start && !this->found_end) { + LOG(WARNING) << "No compiler annotation found"; + } else if (!this->found_start) { + LOG(ERROR) << "compiler_begin annotation is missing"; + return false; + } else if (!this->found_end) { + LOG(ERROR) << "compiler_end annotation is missing"; + return false; + } + return true; + } + + void VisitExpr_(const CallNode* call) final { + auto op_node = call->op.as<OpNode>(); + if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) { + return; + } else if (GetRef<Op>(op_node) == Op::Get("annotation.compiler_begin")) { + this->found_start = true; + } else if (GetRef<Op>(op_node) == Op::Get("annotation.compiler_end")) { + this->found_end = true; + } + } + + private: + bool found_start = false; + bool found_end = false; +}; + +/*! \brief This class partitions the expr labeled with begin and end annoations + * into function containing multiple regions. Each region is labeled with + * a compiler attribute so that it will be handled by any compilers that are not + * in the TVM stack. + * + * TODO(@zhiics) This following algorithm is not adequate to handle all cases, + * i.e. multiple `compiler_end` nodes. + */ +class Partitioner : public ExprMutator { + public: + Subgraph* GetSubgraph(const Expr node) { + for (auto candidate : this->subgraphs_) { + if (candidate->nodes.find(node) != candidate->nodes.end()) { + return candidate; + } + } + return nullptr; + } + + void MergeSubgraph(Subgraph* subgraph1, Subgraph* subgraph2) { + if (subgraph1 == subgraph2) { + return; + } + + // Merge subgraph 2 to subgraph 1 and erase subgraph 2. 
+ subgraph1->nodes.insert(subgraph2->nodes.begin(), subgraph2->nodes.end()); + for (auto arg : subgraph2->args) { + subgraph1->args.push_back(arg); + } + this->subgraphs_.erase(subgraph2); + } + + void AddToSubgraph(Subgraph* subgraph, const Expr expr) { + auto subgraph2 = GetSubgraph(expr); + if (subgraph2) { + MergeSubgraph(subgraph, subgraph2); + } else { + subgraph->nodes.insert(expr); + } + } + + Expr VisitExpr_(const CallNode* call) final { + auto op_node = call->op.as<OpNode>(); + + if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) { + // Propogate subgraph to arguments + auto subgraph = GetSubgraph(GetRef<Call>(call)); + if (subgraph) { + for (auto arg : call->args) { + AddToSubgraph(subgraph, arg); + } + } + return ExprMutator::VisitExpr_(call); + } else if (GetRef<Op>(op_node) == Op::Get("annotation.compiler_begin")) { + // The annotation node is inserted on edge so it must have only one argument. + CHECK_EQ(call->args.size(), 1U); + + // Traverse the rest graph. + auto input_expr = VisitExpr(call->args[0]); + + // Replace the begin annotation with an external call input variable. + auto compiler_attrs = call->attrs.as<CompilerAttrs>(); + auto var = VarNode::make(compiler_attrs->compiler + "_input" + std::to_string(var_id_++), + input_expr->checked_type_); + + // Find the corresponding subgraph and add the argument. + auto subgraph = GetSubgraph(GetRef<Call>(call)); + if (!subgraph) { + throw Error(RELAY_ERROR("Cannot find the corresponding subgraph for start annotation:\n" + << AsText(GetRef<Call>(call), false))); + } + subgraph->args.push_back({var, input_expr}); + return std::move(var); + } else { + CHECK(GetRef<Op>(op_node) == Op::Get("annotation.compiler_end")); + // The annotation node is inserted on edge so it must have only one argument. 
+ CHECK_EQ(call->args.size(), 1U); + + auto compiler_attrs = call->attrs.as<CompilerAttrs>(); + + // Check if the argument is already belonged to an exist subgraph + auto subgraph = GetSubgraph(call->args[0]); + if (!subgraph) { + auto ret = this->subgraphs_.emplace(new Subgraph()); + subgraph = *ret.first; + subgraph->nodes.insert(call->args[0]); + subgraph->id = this->subgraph_id_++; + } + subgraph->nodes.insert(GetRef<Call>(call)); + + // Traverse towarding to subgraph inputs. + auto input = VisitExpr(call->args[0]); + Array<Var> params; + Array<Expr> args; + + // The subgraph may be merged so we need to update it again. + subgraph = GetSubgraph(GetRef<Call>(call)); Review comment: may be better to add CHECK(subgraph) ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: [email protected] With regards, Apache Git Services
