altanh commented on code in PR #11208: URL: https://github.com/apache/tvm/pull/11208#discussion_r877477353
########## src/relay/backend/aot/annotate_used_memory.cc: ########## @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/aot/annotate_used_memory.cc + * \brief Analyzes the memory pressure at the callsite of primitive functions. + */ + +#include <tvm/ir/module.h> +#include <tvm/relay/transform.h> + +#include "../../transforms/device_aware_visitors.h" +#include "../manifest_lifetimes.h" + +namespace tvm { +namespace relay { +namespace backend { +namespace aot { + +/*! + * \brief Annotates the memory usage of each primitive function by analysing the liveness + * of the input/output tensors at the function callsite and calculating the total amount of + * memory these tensors require. + */ +class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator { + public: + AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg, + const transform::LivenessAnalysis& lva) + : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {} + + /*! + * \brief Get the memory required for a primitive Relay function by calculating the total + * bytes of the live tensors at the callsite of the function. 
+ * + * \param live_tensors The tensors that are live when the function is called. + * \return int The total number of bytes a function requires. + */ + int GetMemoryUsage(const transform::VarSet& live_tensors) { + Array<Type> types_stack = {}; + int memory_usage = 0; Review Comment: can we widen the type for `memory_usage`? this would overflow at ~2GB which is pretty realistic these days. maybe `uint64_t`? ########## src/relay/backend/aot/annotate_used_memory.cc: ########## @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/aot/annotate_used_memory.cc + * \brief Analyzes the memory pressure at the callsite of primitive functions. + */ + +#include <tvm/ir/module.h> +#include <tvm/relay/transform.h> + +#include "../../transforms/device_aware_visitors.h" +#include "../manifest_lifetimes.h" + +namespace tvm { +namespace relay { +namespace backend { +namespace aot { + +/*! + * \brief Annotates the memory usage of each primitive function by analysing the liveness + * of the input/output tensors at the function callsite and calculating the total amount of + * memory these tensors require. 
+ */ +class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator { + public: + AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg, + const transform::LivenessAnalysis& lva) + : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {} + + /*! + * \brief Get the memory required for a primitive Relay function by calculating the total + * bytes of the live tensors at the callsite of the function. + * + * \param live_tensors The tensors that are live when the function is called. + * \return int The total number of bytes a function requires. + */ + int GetMemoryUsage(const transform::VarSet& live_tensors) { + Array<Type> types_stack = {}; + int memory_usage = 0; + + for (const Var& var : live_tensors) { + Type var_type = var->checked_type(); + ICHECK(var_type.defined()) << "InferTypes pass should be run before AnnotateUsedMemory pass."; + types_stack.push_back(var_type); + } + + while (!types_stack.empty()) { + Type current_type = types_stack.back(); + types_stack.pop_back(); + + if (const auto* tt_node = current_type.as<TupleTypeNode>()) { + for (const Type& type : tt_node->fields) { + types_stack.push_back(type); + } + continue; + } else if (const auto* ft_node = current_type.as<FuncTypeNode>()) { + types_stack.push_back(ft_node->ret_type); + continue; + } + + const auto* tt_node = current_type.as<TensorTypeNode>(); + ICHECK(tt_node) << "Expected TensorTypeNode but was " << current_type->GetTypeKey(); + int total_tensor_bytes = GetTensorBytes(tt_node); + memory_usage += total_tensor_bytes; + } + return memory_usage; + } + + /*! + * \brief Get the number of bytes a tensor requires. + * + * \param tensor_type_node The checked type of the tensor. + * \return int The number of bytes required. 
+ */ + int GetTensorBytes(const TensorTypeNode* tensor_type_node) { + PrimExpr size = tensor_type_node->Size(); + const auto* size_int_imm = size.as<IntImmNode>(); + ICHECK(size_int_imm) << "Expected tensor size to be an IntImmNode but was " + << size->GetTypeKey(); + + int total_size = size_int_imm->value; + int dtype_bytes = tensor_type_node->dtype.bytes(); + return total_size * dtype_bytes; + } + + Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) override { + if (const auto* func_node = pre_let_node->value.as<FunctionNode>()) { + const auto let_bound_values = control_flow_graph_.let_map; Review Comment: probably worth making this a reference so the map doesn't get copied unnecessarily ########## src/relay/backend/aot/annotate_used_memory.cc: ########## @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/aot/annotate_used_memory.cc + * \brief Analyzes the memory pressure at the callsite of primitive functions. 
+ */ + +#include <tvm/ir/module.h> +#include <tvm/relay/transform.h> + +#include "../../transforms/device_aware_visitors.h" +#include "../manifest_lifetimes.h" + +namespace tvm { +namespace relay { +namespace backend { +namespace aot { + +/*! + * \brief Annotates the memory usage of each primitive function by analysing the liveness + * of the input/output tensors at the function callsite and calculating the total amount of + * memory these tensors require. + */ +class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator { + public: + AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg, + const transform::LivenessAnalysis& lva) + : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {} + + /*! + * \brief Get the memory required for a primitive Relay function by calculating the total + * bytes of the live tensors at the callsite of the function. + * + * \param live_tensors The tensors that are live when the function is called. + * \return int The total number of bytes a function requires. + */ + int GetMemoryUsage(const transform::VarSet& live_tensors) { + Array<Type> types_stack = {}; + int memory_usage = 0; + + for (const Var& var : live_tensors) { + Type var_type = var->checked_type(); + ICHECK(var_type.defined()) << "InferTypes pass should be run before AnnotateUsedMemory pass."; + types_stack.push_back(var_type); + } + + while (!types_stack.empty()) { + Type current_type = types_stack.back(); + types_stack.pop_back(); + + if (const auto* tt_node = current_type.as<TupleTypeNode>()) { + for (const Type& type : tt_node->fields) { + types_stack.push_back(type); + } + continue; + } else if (const auto* ft_node = current_type.as<FuncTypeNode>()) { + types_stack.push_back(ft_node->ret_type); Review Comment: why would a function show up here? is it a function call? seems like we should just ignore functions bound to variables since it's the actual result (e.g.`x: Tensor[...] 
= f_var(...)`) that uses space, not the function variable itself ########## src/relay/backend/aot/annotate_used_memory.cc: ########## @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/aot/annotate_used_memory.cc + * \brief Analyzes the memory pressure at the callsite of primitive functions. + */ + +#include <tvm/ir/module.h> +#include <tvm/relay/transform.h> + +#include "../../transforms/device_aware_visitors.h" +#include "../manifest_lifetimes.h" + +namespace tvm { +namespace relay { +namespace backend { +namespace aot { + +/*! + * \brief Annotates the memory usage of each primitive function by analysing the liveness + * of the input/output tensors at the function callsite and calculating the total amount of + * memory these tensors require. + */ +class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator { + public: + AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg, + const transform::LivenessAnalysis& lva) + : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {} + + /*! 
+ * \brief Get the memory required for a primitive Relay function by calculating the total + * bytes of the live tensors at the callsite of the function. + * + * \param live_tensors The tensors that are live when the function is called. + * \return int The total number of bytes a function requires. + */ + int GetMemoryUsage(const transform::VarSet& live_tensors) { + Array<Type> types_stack = {}; + int memory_usage = 0; + + for (const Var& var : live_tensors) { + Type var_type = var->checked_type(); + ICHECK(var_type.defined()) << "InferTypes pass should be run before AnnotateUsedMemory pass."; + types_stack.push_back(var_type); + } + + while (!types_stack.empty()) { + Type current_type = types_stack.back(); + types_stack.pop_back(); + + if (const auto* tt_node = current_type.as<TupleTypeNode>()) { + for (const Type& type : tt_node->fields) { + types_stack.push_back(type); + } + continue; + } else if (const auto* ft_node = current_type.as<FuncTypeNode>()) { + types_stack.push_back(ft_node->ret_type); + continue; + } + + const auto* tt_node = current_type.as<TensorTypeNode>(); + ICHECK(tt_node) << "Expected TensorTypeNode but was " << current_type->GetTypeKey(); + int total_tensor_bytes = GetTensorBytes(tt_node); + memory_usage += total_tensor_bytes; + } + return memory_usage; + } + + /*! + * \brief Get the number of bytes a tensor requires. + * + * \param tensor_type_node The checked type of the tensor. + * \return int The number of bytes required. 
+ */ + int GetTensorBytes(const TensorTypeNode* tensor_type_node) { + PrimExpr size = tensor_type_node->Size(); + const auto* size_int_imm = size.as<IntImmNode>(); + ICHECK(size_int_imm) << "Expected tensor size to be an IntImmNode but was " + << size->GetTypeKey(); + + int total_size = size_int_imm->value; + int dtype_bytes = tensor_type_node->dtype.bytes(); + return total_size * dtype_bytes; + } + + Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) override { + if (const auto* func_node = pre_let_node->value.as<FunctionNode>()) { Review Comment: yeah I am also confused, since (at least from IR perspective) this binding is not a call to a (primitive?) function but just binding the function to a var. I would have expected `pre_let_node->value.as<CallNode>()` with some check that the op in the call is a func. Maybe I am missing some context for the IR form at this point? definitely would like a comment ########## src/relay/backend/manifest_lifetimes.cc: ########## @@ -0,0 +1,367 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/manifest_lifetimes.cc + * \brief Analysis and explicit manifestation of variable lifetimes. 
NOTE: the input IR should be in + * ANF and post-memory-lowering (explicit manifestation of allocations). + */ + +#include "manifest_lifetimes.h" + +#include <list> +#include <unordered_set> +#include <utility> +#include <vector> + +namespace tvm { +namespace relay { +namespace transform { + +using support::Arena; +using VarSet = std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>; + +ControlFlowGraph ControlFlowGraph::Create(Arena* arena, const Expr& body) { + return Creator().Create(arena, body); +} + +ControlFlowGraph ControlFlowGraph::Creator::Create(Arena* arena, const Expr& body) { + arena_ = arena; + cfg_.entry = BasicBlock::Make(arena); + VisitExpr(body, cfg_.entry); + return std::move(cfg_); +} + +void ControlFlowGraph::Creator::Succ(BasicBlockPtr from, BasicBlockPtr to) { + from->succ.push_back(to); + to->pred.push_back(from); +} + +void ControlFlowGraph::Creator::VisitExpr_(const FunctionNode* f, BasicBlockPtr parent) { + ICHECK(!in_func_) << "nested functions not supported by CFG analysis"; + in_func_ = true; + + // Unwrap the nested function and proceed normally. + if (f->HasNonzeroAttr(attr::kClosure)) { + ICHECK(f->body.as<FunctionNode>()); + return VisitExpr(Downcast<Function>(f->body)->body, parent); + } + + return VisitExpr(f->body, parent); +} + +void ControlFlowGraph::Creator::VisitExpr_(const LetNode* let_node, BasicBlockPtr parent) { + Expr expr = GetRef<Expr>(let_node); + + while (const LetNode* inner_let_node = expr.as<LetNode>()) { + NodePtr curr_node = Node::Make(arena_, parent, expr); + + ICHECK(!cfg_.let_map.count(expr)); + cfg_.let_map[expr] = curr_node; + cfg_.reverse_post_order.push_back(curr_node); + + // The basic block ends upon reaching control flow, with successor blocks corresponding to the + // control flow branch exprs (true/false in If, and one for each clause in Match). 
+ if (const IfNode* ite = AsIgnoringOnDevice<IfNode>(inner_let_node->value)) { + // Create the basic blocks for each branch and mark them as successors to the current block. + BasicBlockPtr t_block = BasicBlock::Make(arena_); + BasicBlockPtr f_block = BasicBlock::Make(arena_); + Succ(parent, t_block); + Succ(parent, f_block); + + VisitExpr(ite->true_branch, t_block); + VisitExpr(ite->false_branch, f_block); + + // All subsequent bindings (and/or the body expr) will be in a new basic block. + BasicBlockPtr next = BasicBlock::Make(arena_); + Succ(t_block, next); + Succ(f_block, next); + parent = next; + } else if (const MatchNode* match = AsIgnoringOnDevice<MatchNode>(inner_let_node->value)) { + // Same as above but one for each pattern. + std::vector<BasicBlockPtr> clause_blocks; + BasicBlockPtr next = BasicBlock::Make(arena_); + for (const Clause& clause : match->clauses) { + BasicBlockPtr clause_block = BasicBlock::Make(arena_); + Succ(parent, clause_block); + Succ(clause_block, next); + VisitExpr(clause->rhs, clause_block); + } + parent = next; + } + + expr = inner_let_node->body; + } + + VisitExpr(expr, parent); +} + +void ControlFlowGraph::Creator::VisitExpr_(const IfNode* if_node, BasicBlockPtr parent) { + // TODO(@altanh): is there a way of making this work? 
+ LOG(FATAL) << "If expressions should be bound to variables."; +} + +void ControlFlowGraph::Creator::VisitExpr_(const MatchNode* match_node, BasicBlockPtr parent) { + // TODO(@altanh): same as If + LOG(FATAL) << "Match expressions should be bound to variables."; +} + +VarSet VarUseCollector::VisitExpr_(const VarNode* var_node) { return {GetRef<Var>(var_node)}; } + +VarSet VarUseCollector::VisitExpr_(const CallNode* call_node) { + VarSet use = VisitExpr(call_node->op); + for (const Expr& arg : call_node->args) { + VarSet arg_use = VisitExpr(arg); + use.insert(arg_use.begin(), arg_use.end()); + } + return use; +} + +VarSet VarUseCollector::VisitExpr_(const TupleNode* tuple_node) { + VarSet use; + for (const Expr& field : tuple_node->fields) { + VarSet field_use = VisitExpr(field); + use.insert(field_use.begin(), field_use.end()); + } + return use; +} + +VarSet VarUseCollector::VisitExpr_(const TupleGetItemNode* get_node) { + return VisitExpr(get_node->tuple); +} + +VarSet VarUseCollector::VisitExpr_(const IfNode* if_node) { return VisitExpr(if_node->cond); } + +VarSet VarUseCollector::VisitExpr_(const MatchNode* match_node) { + return VisitExpr(match_node->data); +} + +UseDefAnalysis UseDefAnalysis::Analyze(const CFG& cfg) { + UseDefAnalysis a; + + // One pass is sufficient. 
+ for (auto it = cfg.reverse_post_order.begin(); it != cfg.reverse_post_order.end(); ++it) { + const CFG::NodePtr& node = *it; + if (const LetNode* let_node = AsIgnoringOnDevice<LetNode>(node->expr)) { + a.use[node] = a.use_collector.VisitExpr(let_node->value); + a.def[node] = let_node->var; + } else { + a.use[node] = a.use_collector.VisitExpr(node->expr); + a.def[node] = Var(); + } + } + + return a; +} + +bool SetEqual(const VarSet& a, const VarSet& b) { + if (a.size() != b.size()) { + return false; + } + for (auto& xa : a) { + if (!b.count(xa)) { + return false; + } + } + return true; +} + +LivenessAnalysis LivenessAnalysis::Analyze(const ControlFlowGraph& cfg, + const UseDefAnalysis& use_def) { + LivenessAnalysis a; + std::list<CFG::NodePtr> worklist; + + // Initialize worklist to post-order traversal for quick convergence. + worklist.insert(worklist.end(), cfg.reverse_post_order.rbegin(), cfg.reverse_post_order.rend()); + + // See https://lambda.uta.edu/cse5317/notes/node40.html for an overview of the algorithm. + auto visitor = [&](const CFG::NodePtr n) { + VarSet old_in_n = a.live_in[n]; + VarSet old_out_n = a.live_out[n]; + + a.live_in[n] = use_def.use.at(n); + for (const Var& v : a.live_out[n]) { + if (!v.same_as(use_def.def.at(n))) { + a.live_in[n].insert(v); + } + } + + a.live_out[n] = VarSet(); + for (const CFG::NodePtr& s : n->GetSucc()) { + a.live_out[n].insert(a.live_in[s].begin(), a.live_in[s].end()); + } + + if (SetEqual(old_in_n, a.live_in[n]) && SetEqual(old_out_n, a.live_out[n])) { + // No need to update the worklist. + } else { + // Add predecessor nodes back to worklist (no need to add successors, since each node's + // in/out sets are not dependent on its predecessors). 
+ for (const CFG::NodePtr& p : n->GetPred()) { + worklist.push_back(p); + } + } + }; + + while (!worklist.empty()) { + const CFG::NodePtr n = worklist.front(); + worklist.pop_front(); + visitor(n); + } + + return a; +} + +Expr KillInserter::VisitExpr_(const LetNode* let_node) { + Expr expr = GetRef<Expr>(let_node); + LetList ll; + + while (const LetNode* inner_let_node = expr.as<LetNode>()) { + ll.Push(inner_let_node->var, VisitExpr(inner_let_node->value)); + + ICHECK(!inner_let_node->value.as<VarNode>()) << "aliasing should have been eliminated."; + ICHECK(cfg_->let_map.count(expr)) << "all Let exprs should be mapped in the CFG"; + + const ControlFlowGraph::NodePtr n = cfg_->let_map.at(expr); + + const VarSet& li = lva_->live_in.at(n); + const VarSet& lo = lva_->live_out.at(n); + + // Killed vars = live in - live out. + VarSet kills; + for (const Var& v : li) { + if (!lo.count(v)) { + kills.insert(v); + } + } + + for (const Var& v : kills) { + ll.Push(Call(Op::Get("memory.kill"), {v})); + } + + expr = inner_let_node->body; + } + + return ll.Get(VisitExpr(expr)); +} + +Expr AliasEliminator::VisitExpr_(const LetNode* let_node) { + Expr expr = GetRef<Expr>(let_node); + LetList ll; + std::vector<Var> aliased_vars; + + while (const LetNode* inner_let_node = expr.as<LetNode>()) { + const Var& var = inner_let_node->var; + const Expr& val = inner_let_node->value; + bool aliased = false; + ICHECK(!alias_.count(var)); + + if (const VarNode* alias_of_n = AsIgnoringOnDevice<VarNode>(val)) { + alias_[var] = Downcast<Var>(VisitExpr_(alias_of_n)); + aliased = true; + } else if (AsIgnoringOnDevice<CallNode>(val)) { + // Copying to the same device is aliasing. + // WARNING: this must be kept in sync with the VM compiler logic in + // src/relay/backend/vm/compiler.cc, line 541, in DeviceAwareVisitExpr_(const CallNode*). Review Comment: I feel a little weird about this move, since there's still some VM-specific logic in the stuff for manifesting lifetimes. 
 The other stuff can probably be factored out though (CFG, and the basic analyses). -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected]. For queries about this service, please contact Infrastructure at: [email protected]
