altanh commented on a change in pull request #10026: URL: https://github.com/apache/tvm/pull/10026#discussion_r797142059
########## File path: src/relay/backend/vm/manifest_lifetimes.cc ########## @@ -0,0 +1,611 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/vm/manifest_lifetimes.cc + * \brief Analysis and explicit manifestation of variable lifetimes. + */ + +#include <tvm/relay/transform.h> + +#include "../../op/memory/device_copy.h" +#include "../../transforms/device_aware_visitors.h" +#include "../../transforms/let_list.h" + +namespace tvm { +namespace relay { +namespace transform { + +using VarSet = std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>; + +// TODO(@altanh, @mbs, @mbrookhart): we should do a survey of all "*-flow graphs" in the codebase +// to see what can be deduplicated. + +// TODO(@altanh): support Relay Refs once/if they are supported by the VM. + +/*! + * \brief A representation of an input expression (typically a Function) as a directed graph of + * basic blocks, with edges between basic blocks corresponding to control flow branching. + */ +class ControlFlowGraph { + public: + struct Node; + struct BasicBlock; + + using NodePtr = std::shared_ptr<Node>; + using BasicBlockPtr = std::shared_ptr<BasicBlock>; + + /*! + * \brief A chunk of IR that does not have any control flow branching. At this stage in the IR, + * basic blocks correspond to: + * (1) a sequence of nested Let expressions, where each node in the block corresponds to a + * binding and the last node is either the (non-Let) body or a binding that branches + * (e.g. "let %x = if (%c) { true_block } else { false_block }"). + * (2) an atomic expression representing the target expression of a control flow branch, e.g. + * %v and %u in "let %x = if (%c) { %v } else { %u }". + */ + struct BasicBlock { + // The nodes of the basic block. + std::vector<NodePtr> nodes; + // The predecessor basic blocks. + std::vector<BasicBlockPtr> pred; + // The successor basic blocks. + std::vector<BasicBlockPtr> succ; + + static BasicBlockPtr Make() { return std::make_shared<BasicBlock>(); } + }; + + /*! + * \brief Roughly corresponds to a "statement" in the IR, such as an individual binding in a + * basic block or the "return value" of a block. Each node maps to a single corresponding expr in + * the IR, but the converse is not true (e.g. in the case of variables). + */ + struct Node { + /*! \brief The basic block this node belongs to. */ + BasicBlockPtr parent; + /*! \brief The index into the parent basic block where this node is. */ + size_t index; + /*! \brief The expr this node corresponds to. */ + Expr expr; + + /*! \brief Returns whether or not this node is the first one in the parent basic block. */ + bool IsFirst() const { return index == 0; } + + /*! \brief Returns whether or not this node is the last one in the parent basic block. */ + bool IsLast() const { return index == parent->nodes.size() - 1; } + + /*! \brief Returns the predecessor nodes of this node. */ + std::vector<NodePtr> GetPred() const { + std::vector<NodePtr> pred; + if (IsFirst()) { + for (const BasicBlockPtr& pred_block : parent->pred) { + pred.push_back(pred_block->nodes.back()); + } + } else { + pred.push_back(parent->nodes[index - 1]); + } + return pred; + } + + /*! \brief Returns the successor nodes of this node. */ + std::vector<NodePtr> GetSucc() const { + std::vector<NodePtr> succ; + if (IsLast()) { + for (const BasicBlockPtr& succ_block : parent->succ) { + succ.push_back(succ_block->nodes.front()); + } + } else { + succ.push_back(parent->nodes[index + 1]); + } + return succ; + } + + /*! \brief Creates a node with the given expr and appends it to the parent basic block. */ + static NodePtr Make(BasicBlockPtr parent, Expr expr) { + NodePtr n = std::make_shared<Node>(); + n->parent = parent; + n->expr = expr; + n->index = parent->nodes.size(); + parent->nodes.push_back(n); + return n; + } + }; + + /*! \brief The basic block where control flow begins. */ + BasicBlockPtr entry; + + /*! + * \brief Mapping from Let expressions to their corresponding nodes. Note that Let expressions + * are never shared in ANF (unlike vars), so this is an injection. + */ + std::unordered_map<Expr, NodePtr, ObjectPtrHash, ObjectPtrEqual> let_map; + + /*! \brief The nodes of the CFG in reverse post order. */ + std::vector<NodePtr> reverse_post_order; + + /*! \brief Creates and returns the CFG of the given expression. */ + static ControlFlowGraph Create(const Expr& body); + + private: + class Creator; +}; + +/*! \brief Helper class for building CFGs. */ +class ControlFlowGraph::Creator : private ExprFunctor<void(const Expr&, BasicBlockPtr)> { + public: + Creator() {} + + ControlFlowGraph Create(const Expr& body) { + cfg_.entry = BasicBlock::Make(); + VisitExpr(body, cfg_.entry); + return std::move(cfg_); + } + + private: + /*! \brief The CFG being built. */ + ControlFlowGraph cfg_; + /*! + * \brief Whether or not we are in a function. CFGs do not support nested functions so this is + * used to error out in such a case. + */ + bool in_func_ = false; + + /*! + * \brief Link \p to as a successor block to \p from. + */ + void Succ(BasicBlockPtr from, BasicBlockPtr to) { + from->succ.push_back(to); + to->pred.push_back(from); + } + +#define DEFAULT_CFG(OP) \ + void VisitExpr_(const OP* op, BasicBlockPtr parent) final { \ + NodePtr n = Node::Make(parent, GetRef<Expr>(op)); \ + cfg_.reverse_post_order.push_back(n); \ + } + + void VisitExpr_(const FunctionNode* f, BasicBlockPtr parent) final { + ICHECK(!in_func_) << "nested functions not supported by CFG analysis"; + in_func_ = true; + + // Unwrap the nested function and proceed normally. + if (f->HasNonzeroAttr(attr::kClosure)) { + ICHECK(f->body.as<FunctionNode>()); + return VisitExpr(Downcast<Function>(f->body)->body, parent); + } + + return VisitExpr(f->body, parent); + } + + void VisitExpr_(const LetNode* let_node, BasicBlockPtr parent) final { + Expr expr = GetRef<Expr>(let_node); + + while (const LetNode* inner_let_node = expr.as<LetNode>()) { + NodePtr curr_node = Node::Make(parent, expr); + + ICHECK(!cfg_.let_map.count(expr)); + cfg_.let_map[expr] = curr_node; + cfg_.reverse_post_order.push_back(curr_node); + + if (const IfNode* ite = AsIgnoringOnDevice<IfNode>(inner_let_node->value)) { Review comment: let me know if the comment I added helps at all -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
