Lunderberg commented on code in PR #15689: URL: https://github.com/apache/tvm/pull/15689#discussion_r1343017657
########## include/tvm/relax/dataflow_analysis.h: ########## @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/dataflow_analysis.h + * \brief A reusable framework for dataflow analysis in Relax. + * Based on Adrian Sampson's course material: + * https://www.cs.cornell.edu/courses/cs6120/2020fa/lesson/4/ + * Do not confuse with dataflow pattern matching (does not use this machinery) + */ + +#ifndef TVM_RELAX_DATAFLOW_ANALYSIS_H_ +#define TVM_RELAX_DATAFLOW_ANALYSIS_H_ + +#include <tvm/relax/analysis.h> +#include <tvm/relax/expr.h> +#include <tvm/runtime/object.h> + +#include <utility> + +namespace tvm { +namespace relax { + +/*! \brief For dataflow analysis, we need to have a control flow graph. + * We will organize this graphs by bindings, which allows analyses to + * state their results for each binding in a SeqExpr. + * + * There are a few cases that have to be handled: + * 1. A normal binding (most common)ICHECK + * 2. The condition expression in an If node (a "split" point) + * 3. A merge point (the variable to which an If node is bound: it is a "merge" between + * the SeqExprs in the true and false branches) + * 4. 
The body expression in a SeqExpr (not actually bound) Review Comment: Does this imply that the analysis can only be applied to normalized relax expressions? If non-normalized, the body of a `SeqExpr` is bound to a variable in the containing `BindingBlock` or `DataflowBlock`. ########## include/tvm/relax/dataflow_analysis.h: ########## @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/dataflow_analysis.h + * \brief A reusable framework for dataflow analysis in Relax. + * Based on Adrian Sampson's course material: + * https://www.cs.cornell.edu/courses/cs6120/2020fa/lesson/4/ + * Do not confuse with dataflow pattern matching (does not use this machinery) + */ + +#ifndef TVM_RELAX_DATAFLOW_ANALYSIS_H_ +#define TVM_RELAX_DATAFLOW_ANALYSIS_H_ + +#include <tvm/relax/analysis.h> +#include <tvm/relax/expr.h> +#include <tvm/runtime/object.h> + +#include <utility> + +namespace tvm { +namespace relax { + +/*! \brief For dataflow analysis, we need to have a control flow graph. + * We will organize this graphs by bindings, which allows analyses to + * state their results for each binding in a SeqExpr. + * + * There are a few cases that have to be handled: + * 1. 
A normal binding (most common)ICHECK + * 2. The condition expression in an If node (a "split" point) + * 3. A merge point (the variable to which an If node is bound: it is a "merge" between + * the SeqExprs in the true and false branches) + * 4. The body expression in a SeqExpr (not actually bound) + */ +enum BindingNodeKind : int { kBinding = 0, kIfCond = 1, kIfMerge = 2, kSeqBody = 3 }; Review Comment: Nit: Can we make this an `enum class`, to avoid accidental conversions to/from integer? ########## python/tvm/relax/analysis/analysis.py: ########## @@ -389,6 +389,29 @@ def udchain(dfb: DataflowBlock) -> Dict[Var, List[Var]]: return _ffi_api.udchain(dfb) # type: ignore +def liveness_analysis(func: Function) -> List[Set[Var]]: + """ + Perform a liveness analysis on the given function, returning a set of + the variables live in the given program location. + + Parameters + ---------- + func: Function + The function to be analyzed + + Returns + ------- + ret: List[Set[Var]] + The set of live variables for each binding in the function. + The indexing is determined by the control flow graph, so + use `extract_cfg` and `get_binding_index` to find the index + for a given program location in the list. + """ + live_lists = _ffi_api.LivenessAnalysis(func) + # convert the lists to sets + return [set(live_list) for live_list in live_lists] Review Comment: What is the advantage of converting the list to a set? If it is required for de-duplication, we should probably do that on the C++ side so that C++ callers also get de-duplicated outputs. ########## src/relax/analysis/dataflow_analysis.cc: ########## @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/analysis/dataflow_analysis.cc + * \brief Implementation of functionality in dataflow_analysis.h + */ +#include <tvm/relax/dataflow_analysis.h> +#include <tvm/runtime/memory.h> + +#include <queue> + +namespace tvm { +namespace relax { + +TVM_REGISTER_NODE_TYPE(GraphBindingNode); + +GraphBinding GraphBinding::Create(const SeqExpr& seq, const Array<Var>& args, size_t block_idx, + size_t binding_idx, BindingNodeKind kind) { + ObjectPtr<GraphBindingNode> n = make_object<GraphBindingNode>(); + n->seq = seq; + n->args = args; + n->block_idx = block_idx; + n->binding_idx = binding_idx; + n->kind = kind; + return GraphBinding(n); +} + +TVM_REGISTER_NODE_TYPE(ControlFlowGraphNode); + +ControlFlowGraph ControlFlowGraph::Create(const Array<GraphBinding>& bindings, + const Array<Array<Integer>>& preds, + const Array<Array<Integer>>& succs) { + ObjectPtr<ControlFlowGraphNode> n = make_object<ControlFlowGraphNode>(); + n->bindings = bindings; + n->preds = preds; + n->succs = succs; + return ControlFlowGraph(n); +} + +// Extracts a basic block and updates the running lists bindings, preds, and succs. +// The return value is the index of the final binding processed in the seq expression +// (useful for processing branches). 
+size_t ExtractCFGHelper(const SeqExpr& seq, const Array<Var>& args, size_t block_idx, + size_t binding_idx, std::vector<size_t> current_preds, + std::vector<GraphBinding>* bindings, + std::vector<std::vector<size_t>>* preds, + std::vector<std::vector<size_t>>* succs) { + // case 1: We're past the end -> this is the block body (base case) + if (block_idx == seq->blocks.size()) { + bindings->push_back(GraphBinding::Create(seq, args, block_idx, 0U, BindingNodeKind::kSeqBody)); + preds->push_back(current_preds); + // the final binding has no successors + succs->push_back({}); + return bindings->size() - 1; + } + + Binding binding = seq->blocks[block_idx]->bindings[binding_idx]; + Expr binding_value = GetBoundValue(binding); + + // case 2: Ordinary binding + if (!binding_value.as<IfNode>()) { + bindings->push_back( + GraphBinding::Create(seq, args, block_idx, binding_idx, BindingNodeKind::kBinding)); + size_t idx = bindings->size() - 1; + preds->push_back(current_preds); + // successor: the next binding (there will always be at least one binding after this, + // even if it's the seq body) + succs->push_back({idx + 1}); + } else { + // case 3: dealing with a branch + auto if_node = Downcast<If>(binding_value); + // start with the cond node + bindings->push_back( + GraphBinding::Create(seq, args, block_idx, binding_idx, BindingNodeKind::kIfCond)); + size_t idx = bindings->size() - 1; + preds->push_back(current_preds); + // there will be another successor, which we will add after recursing down the branches + succs->push_back({idx + 1}); + size_t final_true_idx = ExtractCFGHelper(Downcast<SeqExpr>(if_node->true_branch), {}, 0U, 0U, + {idx}, bindings, preds, succs); + succs->at(idx).push_back(final_true_idx + 1); + size_t final_false_idx = ExtractCFGHelper(Downcast<SeqExpr>(if_node->false_branch), {}, 0U, 0U, + {idx}, bindings, preds, succs); + // now create the merge + bindings->push_back( + GraphBinding::Create(seq, {}, block_idx, binding_idx, 
BindingNodeKind::kIfMerge)); + size_t merge_idx = bindings->size() - 1; + preds->push_back({final_true_idx, final_false_idx}); + succs->push_back({merge_idx + 1}); + // update the successors of the final true and false indices as well + succs->at(final_true_idx).push_back(merge_idx); + succs->at(final_false_idx).push_back(merge_idx); + } + // move on to next binding + size_t next_block_idx = block_idx; + size_t next_binding_idx = binding_idx + 1; + if (next_binding_idx >= seq->blocks[block_idx]->bindings.size()) { + next_block_idx = block_idx + 1; + next_binding_idx = 0U; + } + return ExtractCFGHelper(seq, {}, next_block_idx, next_binding_idx, {bindings->size() - 1}, + bindings, preds, succs); +} + +ControlFlowGraph ExtractCFG(const Function& func) { + std::vector<GraphBinding> bindings; + std::vector<std::vector<size_t>> preds; + std::vector<std::vector<size_t>> succs; + ExtractCFGHelper(Downcast<SeqExpr>(func->body), func->params, 0U, 0U, {}, &bindings, &preds, + &succs); + + Array<Array<Integer>> pred_arr; + for (auto pred_vec : preds) { + Array<Integer> pred_ints; + for (auto idx : pred_vec) { + pred_ints.push_back(Integer(idx)); + } + pred_arr.push_back(pred_ints); + } + Array<Array<Integer>> succ_arr; + for (auto succ_vec : succs) { + Array<Integer> succ_ints; + for (auto idx : succ_vec) { + succ_ints.push_back(Integer(idx)); + } + succ_arr.push_back(succ_ints); + } + return ControlFlowGraph::Create(Array<GraphBinding>(bindings), pred_arr, succ_arr); +} + +std::pair<Array<ObjectRef>, Array<ObjectRef>> DataflowAnalysis( + const ControlFlowGraph& cfg, const ObjectRef& init, + std::function<ObjectRef(const GraphBinding&, const ObjectRef&)> transfer_func, + std::function<ObjectRef(const ObjectRef&, const ObjectRef&)> merge_func, bool forward) { + std::vector<ObjectRef> in_map; + std::vector<ObjectRef> out_map; + for (size_t i = 0; i < cfg->bindings.size(); i++) { + in_map.push_back(init); + out_map.push_back(init); + } + + // Modification from Adrian Sampson's 
version: + // Since there are no loops in our AST, one traversal through the CFG suffices. + // We will do BFS + std::queue<size_t> worklist; + worklist.push((forward) ? 0 : cfg->bindings.size() - 1); + while (!worklist.empty()) { + size_t idx = worklist.front(); + worklist.pop(); + Array<Integer> prev = (forward) ? cfg->preds[idx] : cfg->succs[idx]; + Array<Integer> next = (forward) ? cfg->succs[idx] : cfg->preds[idx]; + std::vector<ObjectRef>* results = (forward) ? &out_map : &in_map; + std::vector<ObjectRef>* inputs = (forward) ? &in_map : &out_map; + + // Cases (for forward analysis): + // 0 predecessors: The first block in the function + // 1 predecessor: A branch in an If node (no merge needed) + // 2 predecessors: The merge block after an If node (merge needed) + // (Analogous for successors in backward analysis) + inputs->operator[](idx) = (prev.size() == 0) ? init + : (prev.size() == 1) ? results->at(prev[0].IntValue()) + : merge_func(results->at(prev[0].IntValue()), + results->at(prev[1].IntValue())); + results->operator[](idx) = transfer_func(cfg->bindings[idx], inputs->at(idx)); + + for (Integer next_idx : next) { + worklist.push(next_idx.IntValue()); + } + } + + return {Array<ObjectRef>(in_map), Array<ObjectRef>(out_map)}; +} + +size_t GetBindingIndex(const ControlFlowGraph& cfg, const SeqExpr& seq, size_t block_idx, + size_t binding_idx, bool match_cond) { + bool is_body = (block_idx == seq->blocks.size()); + bool is_if = + (!is_body && (GetBoundValue(seq->blocks[block_idx]->bindings[binding_idx]).as<IfNode>())); + + // This is an inefficient linear scan; it could be improved by keeping a map of + // SeqExprs to indices in the CFG data structure. + // That should be considered if this function poses performance issues (unlikely). 
+ for (size_t i = 0; i < cfg->bindings.size(); i++) { Review Comment: Another advantage of using a `relax::Var` to specify the location at which that variable is being bound is that we would avoid the possible future problem of the linear scan being inefficient. ########## src/relax/analysis/liveness.cc: ########## @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/analysis/liveness.cc + * \brief Implementation of liveness analysis + */ +#include <tvm/relax/analysis.h> +#include <tvm/relax/dataflow_analysis.h> +#include <tvm/relax/expr.h> +#include <tvm/relax/expr_functor.h> +#include <tvm/runtime/object.h> + +namespace tvm { +namespace relax { + +// just sets of vars. the bool value is unnecessary +using Domain = Map<Var, Bool>; Review Comment: Instead of using the dummy `Bool` object, we could use a `std::unordered_set<const VarNode*>` or a `std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>`. It would require changing the signature of `DataflowAnalysis`, but nothing in it depends on the `ObjectRef` base class. 
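A minimal sketch of what the set-based domain suggested above might look like, written without any TVM dependency so it can be compiled on its own. The `VarNode` struct here is a hypothetical stand-in for `tvm::relax::VarNode`, and `Transfer` and `Merge` are illustrative names rather than functions from the PR; the transfer step follows the same order as the PR's `transfer_func` (add the used vars, then remove the bound var).

```cpp
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for tvm::relax::VarNode; only pointer identity matters here.
struct VarNode {
  const char* name;
};

// Set of live variables keyed by node pointer, with no dummy Bool payload.
using Domain = std::unordered_set<const VarNode*>;

// Backward transfer step for liveness (illustrative, not the PR's signature):
// every var used by the binding becomes live, then the bound var is killed.
// Pass nullptr for `bound` when the node binds nothing (e.g. the SeqExpr body).
Domain Transfer(const Domain& live_out, const std::vector<const VarNode*>& used,
                const VarNode* bound) {
  Domain live_in(live_out);
  live_in.insert(used.begin(), used.end());
  if (bound != nullptr) {
    live_in.erase(bound);
  }
  return live_in;
}

// Merge at a control-flow join: the union of the two live sets.
Domain Merge(const Domain& a, const Domain& b) {
  Domain merged(a);
  merged.insert(b.begin(), b.end());
  return merged;
}
```

With this shape, de-duplication comes from the container itself, so no conversion to a set would be needed on the Python side. How the result would cross the FFI boundary is left open here.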
########## include/tvm/relax/dataflow_analysis.h: ########## @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/dataflow_analysis.h + * \brief A reusable framework for dataflow analysis in Relax. + * Based on Adrian Sampson's course material: + * https://www.cs.cornell.edu/courses/cs6120/2020fa/lesson/4/ + * Do not confuse with dataflow pattern matching (does not use this machinery) + */ + +#ifndef TVM_RELAX_DATAFLOW_ANALYSIS_H_ +#define TVM_RELAX_DATAFLOW_ANALYSIS_H_ + +#include <tvm/relax/analysis.h> +#include <tvm/relax/expr.h> +#include <tvm/runtime/object.h> + +#include <utility> + +namespace tvm { +namespace relax { + +/*! \brief For dataflow analysis, we need to have a control flow graph. + * We will organize this graphs by bindings, which allows analyses to + * state their results for each binding in a SeqExpr. + * + * There are a few cases that have to be handled: + * 1. A normal binding (most common)ICHECK + * 2. The condition expression in an If node (a "split" point) + * 3. A merge point (the variable to which an If node is bound: it is a "merge" between + * the SeqExprs in the true and false branches) + * 4. 
The body expression in a SeqExpr (not actually bound) + */ +enum BindingNodeKind : int { kBinding = 0, kIfCond = 1, kIfMerge = 2, kSeqBody = 3 }; + +class GraphBindingNode : public Object { + public: + /*! \brief The SeqExpr the binding resides in. */ + SeqExpr seq; + + /*! \brief The arguments to the binding. Only the first binding in the graph has arguments + * (i.e., the function arguments). */ + Array<Var> args; + + /*! \brief Index of the binding block in the SeqExpr where the binding is found. + * Convention: We put the SeqExpr body at one block past the final block. */ + size_t block_idx; + + /*! \brief Index of the binding within the binding block corresponding to this binding. + * Convention: Both the If condition and merge are mapped to the same index. + * We use the kind to distinguish. */ + size_t binding_idx; + + /*! \brief The kind of binding this is. */ + BindingNodeKind kind; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("seq", &seq); + v->Visit("args", &args); + v->Visit("block_idx", &block_idx); + v->Visit("binding_idx", &binding_idx); + v->Visit("kind", &kind); + } + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.analysis.GraphBinding"; + TVM_DECLARE_BASE_OBJECT_INFO(GraphBindingNode, Object); +}; + +/*! \brief Representation of a binding in the control flow graph */ +class GraphBinding : public ObjectRef { + public: + /*! + * \brief Create a GraphBinding. See the docs on GraphBindingNode for further details. + * + * \param seq: The SeqExpr in which the binding resides. + * \param args: The arguments to the binding (only nonempty for the first binding: + * these will be the function arguments) + * \param block_idx: The index of the BindingBlock in the SeqExpr + * where the binding resides (for the return expression, use one past the final block). + * \param binding_idx: The index of the binding in the BindingBlock corresponding to the binding. 
+ * \param kind: The kind of binding this is. (Used especially to distinguish If node conditions + * from the merge after the If) + */ + TVM_DLL static GraphBinding Create(const SeqExpr& seq, const Array<Var>& args, size_t block_idx, + size_t binding_idx, BindingNodeKind kind); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(GraphBinding, ObjectRef, GraphBindingNode); +}; + +/* A control flow graph corresponding to a function. + */ +class ControlFlowGraphNode : public Object { + public: + /*! \brief The bindings in the graph. 0 is the entry point. */ + Array<GraphBinding> bindings; + /*! \brief The ith member is the list of predecessors (indices) to binding i in bindings. */ + Array<Array<Integer>> preds; + /*! \brief The ith member is the list of successors (indices) to binding i in bindings. */ + Array<Array<Integer>> succs; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("bindings", &bindings); + v->Visit("preds", &preds); + v->Visit("succs", &succs); + } + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.analysis.ControlFlowGraph"; + TVM_DECLARE_BASE_OBJECT_INFO(ControlFlowGraphNode, Object); +}; + +class ControlFlowGraph : public ObjectRef { + public: + /*! + * \brief Create a ControlFlowGraph. + * + * \param bindings: The bindings in the graph + * \param preds: List of lists of predecessors to each binding. + * \param succs: List of lists of successors to each binding. + */ + TVM_DLL static ControlFlowGraph Create(const Array<GraphBinding>& bindings, + const Array<Array<Integer>>& preds, + const Array<Array<Integer>>& succs); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ControlFlowGraph, ObjectRef, ControlFlowGraphNode); +}; + +/*! + * \brief Extracts the control flow graph for a Relax function. + * \param func The function. This conversion expects it to be normalized. + * \return The control flow graph corresponding to the function. 
+ */ +ControlFlowGraph ExtractCFG(const Function& func); Review Comment: Can we accept a `relax::Expr` instead of a `Function`? That would allow it to be used in more cases, such as a `SeqExpr` generated inside a function visitor. ########## src/relax/analysis/liveness.cc: ########## @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/analysis/liveness.cc + * \brief Implementation of liveness analysis + */ +#include <tvm/relax/analysis.h> +#include <tvm/relax/dataflow_analysis.h> +#include <tvm/relax/expr.h> +#include <tvm/relax/expr_functor.h> +#include <tvm/runtime/object.h> + +namespace tvm { +namespace relax { + +// just sets of vars. the bool value is unnecessary +using Domain = Map<Var, Bool>; + +Domain transfer_func(const GraphBinding& binding, const ObjectRef& input) { + Domain in_domain = Downcast<Domain>(input); + Domain new_domain(in_domain); + + // 1. If a var that appears in the RHS of the binding, add it (it's live) + // 2. 
Remove the bound var (it is not live prior to being bound) + Array<Var> vars_used; + Optional<Var> var_bound; + if (binding->kind == BindingNodeKind::kSeqBody) { + vars_used = AllVars(binding->seq->body); + } else if (binding->kind == BindingNodeKind::kIfCond) { + Binding b = binding->seq->blocks[binding->block_idx]->bindings[binding->binding_idx]; + Expr cond = Downcast<If>(GetBoundValue(b))->cond; + vars_used = AllVars(cond); + } else if (binding->kind == BindingNodeKind::kIfMerge) { + // no vars are used in the merge + vars_used = {}; + // define the merge var + var_bound = binding->seq->blocks[binding->block_idx]->bindings[binding->binding_idx]->var; + } else { + // the ordinary binding case + Binding b = binding->seq->blocks[binding->block_idx]->bindings[binding->binding_idx]; + Expr bound_value = GetBoundValue(b); + // special case: if the RHS is a function literal, we only care about the free vars + // (those captured by the closure) + if (bound_value.as<FunctionNode>()) { Review Comment: Couldn't we use `FreeVars` in both cases? Assuming the input `relax::Function` is normalized, only `FunctionNode` could contain internal bindings, so for any other node type `FreeVars` would produce the same output as `AllVars`. ########## src/relax/analysis/liveness.cc: ########## @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/analysis/liveness.cc + * \brief Implementation of liveness analysis + */ +#include <tvm/relax/analysis.h> +#include <tvm/relax/dataflow_analysis.h> +#include <tvm/relax/expr.h> +#include <tvm/relax/expr_functor.h> +#include <tvm/runtime/object.h> + +namespace tvm { +namespace relax { + +// just sets of vars. the bool value is unnecessary +using Domain = Map<Var, Bool>; + +Domain transfer_func(const GraphBinding& binding, const ObjectRef& input) { + Domain in_domain = Downcast<Domain>(input); + Domain new_domain(in_domain); + + // 1. If a var that appears in the RHS of the binding, add it (it's live) + // 2. 
Remove the bound var (it is not live prior to being bound) + Array<Var> vars_used; + Optional<Var> var_bound; + if (binding->kind == BindingNodeKind::kSeqBody) { + vars_used = AllVars(binding->seq->body); + } else if (binding->kind == BindingNodeKind::kIfCond) { + Binding b = binding->seq->blocks[binding->block_idx]->bindings[binding->binding_idx]; + Expr cond = Downcast<If>(GetBoundValue(b))->cond; + vars_used = AllVars(cond); + } else if (binding->kind == BindingNodeKind::kIfMerge) { + // no vars are used in the merge + vars_used = {}; + // define the merge var + var_bound = binding->seq->blocks[binding->block_idx]->bindings[binding->binding_idx]->var; + } else { + // the ordinary binding case + Binding b = binding->seq->blocks[binding->block_idx]->bindings[binding->binding_idx]; + Expr bound_value = GetBoundValue(b); + // special case: if the RHS is a function literal, we only care about the free vars + // (those captured by the closure) + if (bound_value.as<FunctionNode>()) { + vars_used = FreeVars(bound_value); + } else { + vars_used = AllVars(bound_value); + } + var_bound = b->var; + } + + for (auto var : vars_used) { + if (!new_domain.count(var)) { + new_domain.Set(var, Bool(true)); + } + } + + // the var bound is killed + if (var_bound.defined()) { + new_domain.erase(var_bound.value()); + } + + // technically, we could kill the args too, + // but they are not actually *bound* at the first binding + + return new_domain; +} + +// simply combine sets of live vars to merge +Domain merge_func(const ObjectRef& domain1, const ObjectRef& domain2) { + Domain merged; + for (auto kv : Downcast<Domain>(domain1)) { + merged.Set(kv.first, kv.second); + } + for (auto kv : Downcast<Domain>(domain2)) { + merged.Set(kv.first, kv.second); + } + return merged; +} + +Array<Array<Var>> LivenessAnalysis(const Function& func) { + // initial domain is empty + Domain init_domain; + ControlFlowGraph cfg = ExtractCFG(func); + std::pair<ObjectRef, ObjectRef> results = + 
DataflowAnalysis(cfg, init_domain, transfer_func, merge_func, false); + + // we will return the input map but convert the maps into arrays for simplicity + Array<Domain> in_map = Downcast<Array<Domain>>(results.first); Review Comment: Nit: Instead of `Downcast<Array<Domain>>(results.first)`, can we use `results.first.Map([](const ObjectRef& obj) { return Downcast<Domain>(obj); })`? All `Array<T>` use the same untyped backing container `ArrayNode`, so the usual type-checking that `Downcast` performs isn't useful in this case, and an array with an element other than `Domain` would result in undefined behavior. The `Array::Map` method uses the same copy-on-write backing array when all elements are downcast to the same value, but does so in a type-safe manner. (In the current PR, `results.first` can only contain `Domain` by construction, but this would be to guard against future changes.) ########## include/tvm/relax/utils.h: ########## @@ -110,6 +110,11 @@ TVM_DLL Function CopyWithNewVars(Function func); */ Expr ToNonDataflow(const Expr& e); +/*! + * \brief Get the value bound in the binding. + */ +Expr GetBoundValue(const Binding& b); Review Comment: Thank you for this function, as I keep finding myself re-implementing it in many locations. What would be your opinion on either (a) making it a member of the `relax::Binding` parent class or (b) hoisting the `Expr value` field from child classes `VarBindingNode` and `MatchCastNode` into the parent class `BindingNode`? ########## include/tvm/relax/analysis.h: ########## @@ -404,6 +404,18 @@ TVM_DLL Map<Var, Array<Var>> DataflowBlockUseDef(const DataflowBlock& dfb); */ std::pair<Map<Var, Array<Var>>, Array<Var>> FunctionUseDef(const Function& fn); +/*! + * \brief Perform a liveness analysis on the function, indicating which variables + * are live at which location in the function. + * + * \param fn The function to be analyzed. + * \return An array of arrays of live variables per binding in the function. 
+ * The array is indexed based on the corresponding control flow graph, + * so use `ExtractCFG` and `GetBindingIndex` to match locations in `fn` + * to indices in the result. + */ +Array<Array<Var>> LivenessAnalysis(const Function& fn); Review Comment: Instead of `Array<Array<Var>>`, could we return `Map<Var, Array<Var>>`? That is, a map from the variable being bound to the list of variables that are live while the value of the variable is being computed. That would avoid requiring a user of the function to know the internal indexing scheme, and most of the APIs have easy access to the Var (e.g., in a mutator that implements `ExprMutator::VisitBinding`). ########## include/tvm/relax/analysis.h: ########## @@ -404,6 +404,18 @@ TVM_DLL Map<Var, Array<Var>> DataflowBlockUseDef(const DataflowBlock& dfb); */ std::pair<Map<Var, Array<Var>>, Array<Var>> FunctionUseDef(const Function& fn); +/*! + * \brief Perform a liveness analysis on the function, indicating which variables + * are live at which location in the function. + * + * \param fn The function to be analyzed. Review Comment: Can `LivenessAnalysis` be a member function of `ControlFlowGraph`? That way, the `ControlFlowGraph` would only need to be collected once if it is required by more than one analysis. ########## include/tvm/relax/dataflow_analysis.h: ########## @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/dataflow_analysis.h + * \brief A reusable framework for dataflow analysis in Relax. + * Based on Adrian Sampson's course material: + * https://www.cs.cornell.edu/courses/cs6120/2020fa/lesson/4/ + * Do not confuse with dataflow pattern matching (does not use this machinery) + */ + +#ifndef TVM_RELAX_DATAFLOW_ANALYSIS_H_ +#define TVM_RELAX_DATAFLOW_ANALYSIS_H_ + +#include <tvm/relax/analysis.h> +#include <tvm/relax/expr.h> +#include <tvm/runtime/object.h> + +#include <utility> + +namespace tvm { +namespace relax { + +/*! \brief For dataflow analysis, we need to have a control flow graph. + * We will organize this graphs by bindings, which allows analyses to + * state their results for each binding in a SeqExpr. + * + * There are a few cases that have to be handled: + * 1. A normal binding (most common)ICHECK + * 2. The condition expression in an If node (a "split" point) + * 3. A merge point (the variable to which an If node is bound: it is a "merge" between + * the SeqExprs in the true and false branches) + * 4. The body expression in a SeqExpr (not actually bound) + */ +enum BindingNodeKind : int { kBinding = 0, kIfCond = 1, kIfMerge = 2, kSeqBody = 3 }; + +class GraphBindingNode : public Object { + public: + /*! \brief The SeqExpr the binding resides in. */ + SeqExpr seq; + + /*! \brief The arguments to the binding. Only the first binding in the graph has arguments + * (i.e., the function arguments). */ + Array<Var> args; + + /*! \brief Index of the binding block in the SeqExpr where the binding is found. 
+ * Convention: We put the SeqExpr body at one block past the final block. */ + size_t block_idx; + + /*! \brief Index of the binding within the binding block corresponding to this binding. + * Convention: Both the If condition and merge are mapped to the same index. + * We use the kind to distinguish. */ + size_t binding_idx; + + /*! \brief The kind of binding this is. */ + BindingNodeKind kind; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("seq", &seq); + v->Visit("args", &args); + v->Visit("block_idx", &block_idx); + v->Visit("binding_idx", &binding_idx); + v->Visit("kind", &kind); + } + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.analysis.GraphBinding"; + TVM_DECLARE_BASE_OBJECT_INFO(GraphBindingNode, Object); +}; + +/*! \brief Representation of a binding in the control flow graph */ +class GraphBinding : public ObjectRef { + public: + /*! + * \brief Create a GraphBinding. See the docs on GraphBindingNode for further details. + * + * \param seq: The SeqExpr in which the binding resides. + * \param args: The arguments to the binding (only nonempty for the first binding: + * these will be the function arguments) + * \param block_idx: The index of the BindingBlock in the SeqExpr + * where the binding resides (for the return expression, use one past the final block). + * \param binding_idx: The index of the binding in the BindingBlock corresponding to the binding. + * \param kind: The kind of binding this is. (Used especially to distinguish If node conditions + * from the merge after the If) + */ + TVM_DLL static GraphBinding Create(const SeqExpr& seq, const Array<Var>& args, size_t block_idx, + size_t binding_idx, BindingNodeKind kind); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(GraphBinding, ObjectRef, GraphBindingNode); +}; + +/* A control flow graph corresponding to a function. + */ +class ControlFlowGraphNode : public Object { + public: + /*! \brief The bindings in the graph. 
0 is the entry point. */ + Array<GraphBinding> bindings; + /*! \brief The ith member is the list of predecessors (indices) to binding i in bindings. */ + Array<Array<Integer>> preds; + /*! \brief The ith member is the list of successors (indices) to binding i in bindings. */ + Array<Array<Integer>> succs; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("bindings", &bindings); + v->Visit("preds", &preds); + v->Visit("succs", &succs); + } + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.analysis.ControlFlowGraph"; + TVM_DECLARE_BASE_OBJECT_INFO(ControlFlowGraphNode, Object); +}; + +class ControlFlowGraph : public ObjectRef { + public: + /*! + * \brief Create a ControlFlowGraph. + * + * \param bindings: The bindings in the graph + * \param preds: List of lists of predecessors to each binding. + * \param succs: List of lists of successors to each binding. + */ + TVM_DLL static ControlFlowGraph Create(const Array<GraphBinding>& bindings, + const Array<Array<Integer>>& preds, + const Array<Array<Integer>>& succs); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ControlFlowGraph, ObjectRef, ControlFlowGraphNode); +}; + +/*! + * \brief Extracts the control flow graph for a Relax function. + * \param func The function. This conversion expects it to be normalized. + * \return The control flow graph corresponding to the function. + */ +ControlFlowGraph ExtractCFG(const Function& func); + +/*! 
+ * \brief Generic implementation of dataflow analysis, based on + * Adrian Sampson's course material, except binding by binding + * instead of basic block by basic block: + * https://www.cs.cornell.edu/courses/cs6120/2020fa/lesson/4/ + * + * The analysis creates input and output maps (mapping binding indices to a domain), + * sets the initial input and output for each binding to the init value, and then + * performs a traversal of the CFG (BFS in this implementation, since unlike the general case, + * we do not have loops) and uses the transfer and merge function to update the inputs and + * outputs. The analysis can proceed forwards (from binding 0 onwards) or backwards (from the + * last binding back), flipping the roles of the input and output maps in the cases. + * + * \param forward Whether to perform a forward or backward analysis + * \param cfg The input control flow graph + * \param init The value corresponding to an initial domain + * \param transfer_func Given an input domain and a binding, determine the resulting domain + * \param merge_func Given a set of domains, combine them to form a single new domain + * (note: in Relax, a binding can never have more than two predecessors/successors) + * + * \return Two arrays, the first being the "input map" (domain being passed *into* + * each binding in the CFG) and the second being the "output map" (the domain + * being passed *out of* the corresponding binding) + */ +std::pair<Array<ObjectRef>, Array<ObjectRef>> DataflowAnalysis( Review Comment: Should this function be externally exposed? From the changes in this PR on its own, it looks like an implementation detail for `LivenessAnalysis`, but the function signature suggests that it is intended for more general use. ########## include/tvm/relax/dataflow_analysis.h: ########## @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/dataflow_analysis.h + * \brief A reusable framework for dataflow analysis in Relax. + * Based on Adrian Sampson's course material: + * https://www.cs.cornell.edu/courses/cs6120/2020fa/lesson/4/ + * Do not confuse with dataflow pattern matching (does not use this machinery) + */ + +#ifndef TVM_RELAX_DATAFLOW_ANALYSIS_H_ +#define TVM_RELAX_DATAFLOW_ANALYSIS_H_ + +#include <tvm/relax/analysis.h> +#include <tvm/relax/expr.h> +#include <tvm/runtime/object.h> + +#include <utility> + +namespace tvm { +namespace relax { + +/*! \brief For dataflow analysis, we need to have a control flow graph. + * We will organize this graphs by bindings, which allows analyses to + * state their results for each binding in a SeqExpr. + * + * There are a few cases that have to be handled: + * 1. A normal binding (most common)ICHECK + * 2. The condition expression in an If node (a "split" point) + * 3. A merge point (the variable to which an If node is bound: it is a "merge" between + * the SeqExprs in the true and false branches) + * 4. The body expression in a SeqExpr (not actually bound) + */ +enum BindingNodeKind : int { kBinding = 0, kIfCond = 1, kIfMerge = 2, kSeqBody = 3 }; + +class GraphBindingNode : public Object { + public: + /*! 
\brief The SeqExpr the binding resides in. */ + SeqExpr seq; + + /*! \brief The arguments to the binding. Only the first binding in the graph has arguments Review Comment: Having a parameter in all `GraphBinding` nodes that is only non-empty for one of them seems a bit odd. Since this list is unique across the entire graph, can we instead move this into the `ControlFlowGraphNode`? ########## include/tvm/relax/dataflow_analysis.h: ########## @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/dataflow_analysis.h + * \brief A reusable framework for dataflow analysis in Relax. + * Based on Adrian Sampson's course material: + * https://www.cs.cornell.edu/courses/cs6120/2020fa/lesson/4/ + * Do not confuse with dataflow pattern matching (does not use this machinery) + */ + +#ifndef TVM_RELAX_DATAFLOW_ANALYSIS_H_ +#define TVM_RELAX_DATAFLOW_ANALYSIS_H_ + +#include <tvm/relax/analysis.h> +#include <tvm/relax/expr.h> +#include <tvm/runtime/object.h> + +#include <utility> + +namespace tvm { +namespace relax { + +/*! \brief For dataflow analysis, we need to have a control flow graph. 
+ * We will organize this graphs by bindings, which allows analyses to + * state their results for each binding in a SeqExpr. + * + * There are a few cases that have to be handled: + * 1. A normal binding (most common)ICHECK + * 2. The condition expression in an If node (a "split" point) + * 3. A merge point (the variable to which an If node is bound: it is a "merge" between + * the SeqExprs in the true and false branches) + * 4. The body expression in a SeqExpr (not actually bound) + */ +enum BindingNodeKind : int { kBinding = 0, kIfCond = 1, kIfMerge = 2, kSeqBody = 3 }; + +class GraphBindingNode : public Object { + public: + /*! \brief The SeqExpr the binding resides in. */ + SeqExpr seq; + + /*! \brief The arguments to the binding. Only the first binding in the graph has arguments + * (i.e., the function arguments). */ + Array<Var> args; + + /*! \brief Index of the binding block in the SeqExpr where the binding is found. + * Convention: We put the SeqExpr body at one block past the final block. */ + size_t block_idx; + + /*! \brief Index of the binding within the binding block corresponding to this binding. + * Convention: Both the If condition and merge are mapped to the same index. + * We use the kind to distinguish. */ + size_t binding_idx; + + /*! \brief The kind of binding this is. */ + BindingNodeKind kind; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("seq", &seq); + v->Visit("args", &args); + v->Visit("block_idx", &block_idx); + v->Visit("binding_idx", &binding_idx); + v->Visit("kind", &kind); + } + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.analysis.GraphBinding"; + TVM_DECLARE_BASE_OBJECT_INFO(GraphBindingNode, Object); +}; + +/*! \brief Representation of a binding in the control flow graph */ +class GraphBinding : public ObjectRef { + public: + /*! + * \brief Create a GraphBinding. See the docs on GraphBindingNode for further details. 
+ * + * \param seq: The SeqExpr in which the binding resides. + * \param args: The arguments to the binding (only nonempty for the first binding: + * these will be the function arguments) + * \param block_idx: The index of the BindingBlock in the SeqExpr + * where the binding resides (for the return expression, use one past the final block). + * \param binding_idx: The index of the binding in the BindingBlock corresponding to the binding. + * \param kind: The kind of binding this is. (Used especially to distinguish If node conditions + * from the merge after the If) + */ + TVM_DLL static GraphBinding Create(const SeqExpr& seq, const Array<Var>& args, size_t block_idx, + size_t binding_idx, BindingNodeKind kind); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(GraphBinding, ObjectRef, GraphBindingNode); +}; + +/* A control flow graph corresponding to a function. + */ +class ControlFlowGraphNode : public Object { + public: + /*! \brief The bindings in the graph. 0 is the entry point. */ + Array<GraphBinding> bindings; + /*! \brief The ith member is the list of predecessors (indices) to binding i in bindings. */ + Array<Array<Integer>> preds; Review Comment: Can `preds` and `succs` be moved to the `GraphBindingNode` instead? That way, we make it impossible for these three lists to erroneously have mismatched sizes, and also make it immediately clear to readers which predecessors are associated with which nodes. ########## include/tvm/relax/dataflow_analysis.h: ########## @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/dataflow_analysis.h + * \brief A reusable framework for dataflow analysis in Relax. + * Based on Adrian Sampson's course material: + * https://www.cs.cornell.edu/courses/cs6120/2020fa/lesson/4/ + * Do not confuse with dataflow pattern matching (does not use this machinery) + */ + +#ifndef TVM_RELAX_DATAFLOW_ANALYSIS_H_ +#define TVM_RELAX_DATAFLOW_ANALYSIS_H_ + +#include <tvm/relax/analysis.h> +#include <tvm/relax/expr.h> +#include <tvm/runtime/object.h> + +#include <utility> + +namespace tvm { +namespace relax { + +/*! \brief For dataflow analysis, we need to have a control flow graph. + * We will organize this graphs by bindings, which allows analyses to + * state their results for each binding in a SeqExpr. + * + * There are a few cases that have to be handled: + * 1. A normal binding (most common)ICHECK + * 2. The condition expression in an If node (a "split" point) + * 3. A merge point (the variable to which an If node is bound: it is a "merge" between + * the SeqExprs in the true and false branches) + * 4. The body expression in a SeqExpr (not actually bound) + */ +enum BindingNodeKind : int { kBinding = 0, kIfCond = 1, kIfMerge = 2, kSeqBody = 3 }; + +class GraphBindingNode : public Object { + public: + /*! \brief The SeqExpr the binding resides in. */ + SeqExpr seq; + + /*! \brief The arguments to the binding. Only the first binding in the graph has arguments + * (i.e., the function arguments). */ + Array<Var> args; + + /*! \brief Index of the binding block in the SeqExpr where the binding is found. 
+ * Convention: We put the SeqExpr body at one block past the final block. */ + size_t block_idx; + + /*! \brief Index of the binding within the binding block corresponding to this binding. + * Convention: Both the If condition and merge are mapped to the same index. + * We use the kind to distinguish. */ + size_t binding_idx; + + /*! \brief The kind of binding this is. */ + BindingNodeKind kind; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("seq", &seq); + v->Visit("args", &args); + v->Visit("block_idx", &block_idx); + v->Visit("binding_idx", &binding_idx); + v->Visit("kind", &kind); + } + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.analysis.GraphBinding"; + TVM_DECLARE_BASE_OBJECT_INFO(GraphBindingNode, Object); +}; + +/*! \brief Representation of a binding in the control flow graph */ +class GraphBinding : public ObjectRef { + public: + /*! + * \brief Create a GraphBinding. See the docs on GraphBindingNode for further details. + * + * \param seq: The SeqExpr in which the binding resides. + * \param args: The arguments to the binding (only nonempty for the first binding: + * these will be the function arguments) + * \param block_idx: The index of the BindingBlock in the SeqExpr + * where the binding resides (for the return expression, use one past the final block). + * \param binding_idx: The index of the binding in the BindingBlock corresponding to the binding. + * \param kind: The kind of binding this is. (Used especially to distinguish If node conditions + * from the merge after the If) + */ + TVM_DLL static GraphBinding Create(const SeqExpr& seq, const Array<Var>& args, size_t block_idx, Review Comment: Nit: From the signature, it looks like this should be a constructor rather than a static method. 
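To illustrate the constructor form, here is a rough, self-contained sketch. It is not the actual TVM code: `SeqExprStub` is a hypothetical stand-in for `SeqExpr`, `std::shared_ptr` stands in for TVM's `ObjectPtr`/`ObjectRef` machinery, and `args` is omitted for brevity. In TVM itself I would expect the constructor body to build the node with `make_object` and assign it to `data_`, as other `ObjectRef` subclasses do.

```cpp
#include <cstddef>
#include <memory>

// Hypothetical stand-in for tvm::relax::SeqExpr, for illustration only.
struct SeqExprStub {
  int id;
};

enum BindingNodeKind : int { kBinding = 0, kIfCond = 1, kIfMerge = 2, kSeqBody = 3 };

// Payload type, mirroring the fields of GraphBindingNode (minus `args`).
struct GraphBindingNode {
  SeqExprStub seq;
  std::size_t block_idx;
  std::size_t binding_idx;
  BindingNodeKind kind;
};

// Reference type: callers construct it directly rather than
// going through a static Create() method.
class GraphBinding {
 public:
  explicit GraphBinding(SeqExprStub seq, std::size_t block_idx, std::size_t binding_idx,
                        BindingNodeKind kind)
      : data_(std::make_shared<GraphBindingNode>(
            GraphBindingNode{seq, block_idx, binding_idx, kind})) {}

  // Pointer-style access to the payload, as with ObjectRef subclasses.
  const GraphBindingNode* operator->() const { return data_.get(); }

 private:
  std::shared_ptr<GraphBindingNode> data_;
};
```

A call site then reads `GraphBinding(seq, block_idx, binding_idx, kind)` instead of `GraphBinding::Create(...)`, which matches how most other reference types are built.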
########## include/tvm/relax/dataflow_analysis.h: ########## @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/dataflow_analysis.h + * \brief A reusable framework for dataflow analysis in Relax. + * Based on Adrian Sampson's course material: + * https://www.cs.cornell.edu/courses/cs6120/2020fa/lesson/4/ + * Do not confuse with dataflow pattern matching (does not use this machinery) + */ + +#ifndef TVM_RELAX_DATAFLOW_ANALYSIS_H_ +#define TVM_RELAX_DATAFLOW_ANALYSIS_H_ + +#include <tvm/relax/analysis.h> +#include <tvm/relax/expr.h> +#include <tvm/runtime/object.h> + +#include <utility> + +namespace tvm { +namespace relax { + +/*! \brief For dataflow analysis, we need to have a control flow graph. + * We will organize this graphs by bindings, which allows analyses to + * state their results for each binding in a SeqExpr. + * + * There are a few cases that have to be handled: + * 1. A normal binding (most common)ICHECK + * 2. The condition expression in an If node (a "split" point) + * 3. A merge point (the variable to which an If node is bound: it is a "merge" between + * the SeqExprs in the true and false branches) + * 4. 
The body expression in a SeqExpr (not actually bound) + */ +enum BindingNodeKind : int { kBinding = 0, kIfCond = 1, kIfMerge = 2, kSeqBody = 3 }; + +class GraphBindingNode : public Object { + public: + /*! \brief The SeqExpr the binding resides in. */ + SeqExpr seq; + + /*! \brief The arguments to the binding. Only the first binding in the graph has arguments + * (i.e., the function arguments). */ + Array<Var> args; + + /*! \brief Index of the binding block in the SeqExpr where the binding is found. + * Convention: We put the SeqExpr body at one block past the final block. */ + size_t block_idx; + + /*! \brief Index of the binding within the binding block corresponding to this binding. + * Convention: Both the If condition and merge are mapped to the same index. + * We use the kind to distinguish. */ + size_t binding_idx; + + /*! \brief The kind of binding this is. */ + BindingNodeKind kind; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("seq", &seq); + v->Visit("args", &args); + v->Visit("block_idx", &block_idx); + v->Visit("binding_idx", &binding_idx); + v->Visit("kind", &kind); + } + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.analysis.GraphBinding"; + TVM_DECLARE_BASE_OBJECT_INFO(GraphBindingNode, Object); +}; + +/*! \brief Representation of a binding in the control flow graph */ +class GraphBinding : public ObjectRef { + public: + /*! + * \brief Create a GraphBinding. See the docs on GraphBindingNode for further details. + * + * \param seq: The SeqExpr in which the binding resides. + * \param args: The arguments to the binding (only nonempty for the first binding: + * these will be the function arguments) + * \param block_idx: The index of the BindingBlock in the SeqExpr + * where the binding resides (for the return expression, use one past the final block). + * \param binding_idx: The index of the binding in the BindingBlock corresponding to the binding. 
+ * \param kind: The kind of binding this is. (Used especially to distinguish If node conditions + * from the merge after the If) + */ + TVM_DLL static GraphBinding Create(const SeqExpr& seq, const Array<Var>& args, size_t block_idx, + size_t binding_idx, BindingNodeKind kind); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(GraphBinding, ObjectRef, GraphBindingNode); +}; + +/* A control flow graph corresponding to a function. + */ +class ControlFlowGraphNode : public Object { + public: + /*! \brief The bindings in the graph. 0 is the entry point. */ + Array<GraphBinding> bindings; + /*! \brief The ith member is the list of predecessors (indices) to binding i in bindings. */ + Array<Array<Integer>> preds; + /*! \brief The ith member is the list of successors (indices) to binding i in bindings. */ + Array<Array<Integer>> succs; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("bindings", &bindings); + v->Visit("preds", &preds); + v->Visit("succs", &succs); + } + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.analysis.ControlFlowGraph"; + TVM_DECLARE_BASE_OBJECT_INFO(ControlFlowGraphNode, Object); +}; + +class ControlFlowGraph : public ObjectRef { + public: + /*! + * \brief Create a ControlFlowGraph. + * + * \param bindings: The bindings in the graph + * \param preds: List of lists of predecessors to each binding. + * \param succs: List of lists of successors to each binding. + */ + TVM_DLL static ControlFlowGraph Create(const Array<GraphBinding>& bindings, Review Comment: Nit: From the signature, it looks like this should be a constructor rather than a static method. ########## include/tvm/relax/dataflow_analysis.h: ########## @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/dataflow_analysis.h + * \brief A reusable framework for dataflow analysis in Relax. + * Based on Adrian Sampson's course material: + * https://www.cs.cornell.edu/courses/cs6120/2020fa/lesson/4/ + * Do not confuse with dataflow pattern matching (does not use this machinery) + */ + +#ifndef TVM_RELAX_DATAFLOW_ANALYSIS_H_ +#define TVM_RELAX_DATAFLOW_ANALYSIS_H_ + +#include <tvm/relax/analysis.h> +#include <tvm/relax/expr.h> +#include <tvm/runtime/object.h> + +#include <utility> + +namespace tvm { +namespace relax { + +/*! \brief For dataflow analysis, we need to have a control flow graph. + * We will organize this graphs by bindings, which allows analyses to + * state their results for each binding in a SeqExpr. + * + * There are a few cases that have to be handled: + * 1. A normal binding (most common)ICHECK + * 2. The condition expression in an If node (a "split" point) + * 3. A merge point (the variable to which an If node is bound: it is a "merge" between + * the SeqExprs in the true and false branches) + * 4. The body expression in a SeqExpr (not actually bound) + */ +enum BindingNodeKind : int { kBinding = 0, kIfCond = 1, kIfMerge = 2, kSeqBody = 3 }; + +class GraphBindingNode : public Object { + public: + /*! \brief The SeqExpr the binding resides in. */ + SeqExpr seq; + + /*! \brief The arguments to the binding. 
Only the first binding in the graph has arguments + * (i.e., the function arguments). */ + Array<Var> args; + + /*! \brief Index of the binding block in the SeqExpr where the binding is found. Review Comment: Rather than identifying a particular binding by `block_idx` and `binding_idx`, could we instead identify it by the variable being bound? The variables are already required to be unique, and it would avoid needing to keep track of which `size_t` is associated with which array. ########## include/tvm/relax/dataflow_analysis.h: ########## @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/dataflow_analysis.h + * \brief A reusable framework for dataflow analysis in Relax. + * Based on Adrian Sampson's course material: + * https://www.cs.cornell.edu/courses/cs6120/2020fa/lesson/4/ + * Do not confuse with dataflow pattern matching (does not use this machinery) + */ + +#ifndef TVM_RELAX_DATAFLOW_ANALYSIS_H_ +#define TVM_RELAX_DATAFLOW_ANALYSIS_H_ + +#include <tvm/relax/analysis.h> +#include <tvm/relax/expr.h> +#include <tvm/runtime/object.h> + +#include <utility> + +namespace tvm { +namespace relax { + +/*! 
\brief For dataflow analysis, we need to have a control flow graph. + * We will organize this graphs by bindings, which allows analyses to + * state their results for each binding in a SeqExpr. + * + * There are a few cases that have to be handled: + * 1. A normal binding (most common)ICHECK + * 2. The condition expression in an If node (a "split" point) + * 3. A merge point (the variable to which an If node is bound: it is a "merge" between + * the SeqExprs in the true and false branches) + * 4. The body expression in a SeqExpr (not actually bound) + */ +enum BindingNodeKind : int { kBinding = 0, kIfCond = 1, kIfMerge = 2, kSeqBody = 3 }; + +class GraphBindingNode : public Object { + public: + /*! \brief The SeqExpr the binding resides in. */ + SeqExpr seq; + + /*! \brief The arguments to the binding. Only the first binding in the graph has arguments + * (i.e., the function arguments). */ + Array<Var> args; + + /*! \brief Index of the binding block in the SeqExpr where the binding is found. + * Convention: We put the SeqExpr body at one block past the final block. */ + size_t block_idx; + + /*! \brief Index of the binding within the binding block corresponding to this binding. + * Convention: Both the If condition and merge are mapped to the same index. + * We use the kind to distinguish. */ + size_t binding_idx; + + /*! \brief The kind of binding this is. */ + BindingNodeKind kind; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("seq", &seq); + v->Visit("args", &args); + v->Visit("block_idx", &block_idx); + v->Visit("binding_idx", &binding_idx); + v->Visit("kind", &kind); + } + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.analysis.GraphBinding"; + TVM_DECLARE_BASE_OBJECT_INFO(GraphBindingNode, Object); +}; + +/*! \brief Representation of a binding in the control flow graph */ +class GraphBinding : public ObjectRef { + public: + /*! + * \brief Create a GraphBinding. 
See the docs on GraphBindingNode for further details. + * + * \param seq: The SeqExpr in which the binding resides. + * \param args: The arguments to the binding (only nonempty for the first binding: + * these will be the function arguments) + * \param block_idx: The index of the BindingBlock in the SeqExpr + * where the binding resides (for the return expression, use one past the final block). + * \param binding_idx: The index of the binding in the BindingBlock corresponding to the binding. + * \param kind: The kind of binding this is. (Used especially to distinguish If node conditions + * from the merge after the If) + */ + TVM_DLL static GraphBinding Create(const SeqExpr& seq, const Array<Var>& args, size_t block_idx, + size_t binding_idx, BindingNodeKind kind); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(GraphBinding, ObjectRef, GraphBindingNode); +}; + +/* A control flow graph corresponding to a function. + */ +class ControlFlowGraphNode : public Object { + public: + /*! \brief The bindings in the graph. 0 is the entry point. */ + Array<GraphBinding> bindings; + /*! \brief The ith member is the list of predecessors (indices) to binding i in bindings. */ + Array<Array<Integer>> preds; + /*! \brief The ith member is the list of successors (indices) to binding i in bindings. */ + Array<Array<Integer>> succs; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("bindings", &bindings); + v->Visit("preds", &preds); + v->Visit("succs", &succs); + } + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.analysis.ControlFlowGraph"; + TVM_DECLARE_BASE_OBJECT_INFO(ControlFlowGraphNode, Object); +}; + +class ControlFlowGraph : public ObjectRef { Review Comment: Looking at this structure, I think it could be generalized to also represent a data-dependency graph, and most of the functionality would also carry over. 

* The predecessors in a control-flow graph are analogous to the inputs of a data-dependency graph, and both could be represented by `Array<Var>`.
* The successors in a control-flow graph are analogous to the outputs of a data-dependency graph, and both could be represented by `Array<Var>`.
* The `DataflowAnalysis` function would operate identically in both cases, either flowing things that are known at a specific time for a control-flow graph, or flowing things that are known about a specific value for a data-dependency graph.

What are your thoughts on generalizing the utility? I think the main drawback would be if there's a fundamental assumption made about the graph structure that only holds for one of the two, but they look like they might be similar enough to have lots of overlap.
########## src/relax/analysis/dataflow_analysis.cc: ########## @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file tvm/relax/analysis/dataflow_analysis.cc + * \brief Implementation of functionality in dataflow_analysis.h + */ +#include <tvm/relax/dataflow_analysis.h> +#include <tvm/runtime/memory.h> + +#include <queue> + +namespace tvm { +namespace relax { + +TVM_REGISTER_NODE_TYPE(GraphBindingNode); + +GraphBinding GraphBinding::Create(const SeqExpr& seq, const Array<Var>& args, size_t block_idx, + size_t binding_idx, BindingNodeKind kind) { + ObjectPtr<GraphBindingNode> n = make_object<GraphBindingNode>(); + n->seq = seq; + n->args = args; + n->block_idx = block_idx; + n->binding_idx = binding_idx; + n->kind = kind; + return GraphBinding(n); +} + +TVM_REGISTER_NODE_TYPE(ControlFlowGraphNode); + +ControlFlowGraph ControlFlowGraph::Create(const Array<GraphBinding>& bindings, + const Array<Array<Integer>>& preds, + const Array<Array<Integer>>& succs) { + ObjectPtr<ControlFlowGraphNode> n = make_object<ControlFlowGraphNode>(); + n->bindings = bindings; + n->preds = preds; + n->succs = succs; + return ControlFlowGraph(n); +} + +// Extracts a basic block and updates the running lists bindings, preds, and succs. +// The return value is the index of the final binding processed in the seq expression +// (useful for processing branches). 
+size_t ExtractCFGHelper(const SeqExpr& seq, const Array<Var>& args, size_t block_idx, + size_t binding_idx, std::vector<size_t> current_preds, + std::vector<GraphBinding>* bindings, + std::vector<std::vector<size_t>>* preds, + std::vector<std::vector<size_t>>* succs) { + // case 1: We're past the end -> this is the block body (base case) + if (block_idx == seq->blocks.size()) { + bindings->push_back(GraphBinding::Create(seq, args, block_idx, 0U, BindingNodeKind::kSeqBody)); + preds->push_back(current_preds); + // the final binding has no successors + succs->push_back({}); + return bindings->size() - 1; + } + + Binding binding = seq->blocks[block_idx]->bindings[binding_idx]; + Expr binding_value = GetBoundValue(binding); + + // case 2: Ordinary binding + if (!binding_value.as<IfNode>()) { + bindings->push_back( + GraphBinding::Create(seq, args, block_idx, binding_idx, BindingNodeKind::kBinding)); + size_t idx = bindings->size() - 1; + preds->push_back(current_preds); + // successor: the next binding (there will always be at least one binding after this, + // even if it's the seq body) + succs->push_back({idx + 1}); + } else { + // case 3: dealing with a branch + auto if_node = Downcast<If>(binding_value); + // start with the cond node + bindings->push_back( + GraphBinding::Create(seq, args, block_idx, binding_idx, BindingNodeKind::kIfCond)); + size_t idx = bindings->size() - 1; + preds->push_back(current_preds); + // there will be another successor, which we will add after recursing down the branches + succs->push_back({idx + 1}); + size_t final_true_idx = ExtractCFGHelper(Downcast<SeqExpr>(if_node->true_branch), {}, 0U, 0U, + {idx}, bindings, preds, succs); + succs->at(idx).push_back(final_true_idx + 1); + size_t final_false_idx = ExtractCFGHelper(Downcast<SeqExpr>(if_node->false_branch), {}, 0U, 0U, + {idx}, bindings, preds, succs); + // now create the merge + bindings->push_back( + GraphBinding::Create(seq, {}, block_idx, binding_idx, 
BindingNodeKind::kIfMerge)); + size_t merge_idx = bindings->size() - 1; + preds->push_back({final_true_idx, final_false_idx}); + succs->push_back({merge_idx + 1}); + // update the successors of the final true and false indices as well + succs->at(final_true_idx).push_back(merge_idx); + succs->at(final_false_idx).push_back(merge_idx); + } + // move on to next binding + size_t next_block_idx = block_idx; + size_t next_binding_idx = binding_idx + 1; + if (next_binding_idx >= seq->blocks[block_idx]->bindings.size()) { + next_block_idx = block_idx + 1; + next_binding_idx = 0U; + } + return ExtractCFGHelper(seq, {}, next_block_idx, next_binding_idx, {bindings->size() - 1}, + bindings, preds, succs); +} + +ControlFlowGraph ExtractCFG(const Function& func) { + std::vector<GraphBinding> bindings; + std::vector<std::vector<size_t>> preds; + std::vector<std::vector<size_t>> succs; + ExtractCFGHelper(Downcast<SeqExpr>(func->body), func->params, 0U, 0U, {}, &bindings, &preds, + &succs); + + Array<Array<Integer>> pred_arr; + for (auto pred_vec : preds) { + Array<Integer> pred_ints; + for (auto idx : pred_vec) { + pred_ints.push_back(Integer(idx)); + } + pred_arr.push_back(pred_ints); + } + Array<Array<Integer>> succ_arr; + for (auto succ_vec : succs) { + Array<Integer> succ_ints; + for (auto idx : succ_vec) { + succ_ints.push_back(Integer(idx)); + } + succ_arr.push_back(succ_ints); + } + return ControlFlowGraph::Create(Array<GraphBinding>(bindings), pred_arr, succ_arr); +} + +std::pair<Array<ObjectRef>, Array<ObjectRef>> DataflowAnalysis( + const ControlFlowGraph& cfg, const ObjectRef& init, + std::function<ObjectRef(const GraphBinding&, const ObjectRef&)> transfer_func, + std::function<ObjectRef(const ObjectRef&, const ObjectRef&)> merge_func, bool forward) { + std::vector<ObjectRef> in_map; + std::vector<ObjectRef> out_map; + for (size_t i = 0; i < cfg->bindings.size(); i++) { + in_map.push_back(init); + out_map.push_back(init); + } + + // Modification from Adrian Sampson's 
version: + // Since there are no loops in our AST, one traversal through the CFG suffices. + // We will do BFS + std::queue<size_t> worklist; + worklist.push((forward) ? 0 : cfg->bindings.size() - 1); + while (!worklist.empty()) { + size_t idx = worklist.front(); + worklist.pop(); + Array<Integer> prev = (forward) ? cfg->preds[idx] : cfg->succs[idx]; + Array<Integer> next = (forward) ? cfg->succs[idx] : cfg->preds[idx]; + std::vector<ObjectRef>* results = (forward) ? &out_map : &in_map; + std::vector<ObjectRef>* inputs = (forward) ? &in_map : &out_map; + + // Cases (for forward analysis): + // 0 predecessors: The first block in the function + // 1 predecessor: A branch in an If node (no merge needed) + // 2 predecessors: The merge block after an If node (merge needed) + // (Analogous for successors in backward analysis) + inputs->operator[](idx) = (prev.size() == 0) ? init Review Comment: Nit: If this is declared as a reference, `std::vector<ObjectRef>& inputs = (forward) ? in_map : out_map;`, then the LHS of the assignment becomes `inputs[idx]` instead of `inputs->operator[](idx)`. ########## src/relax/analysis/liveness.cc: ########## @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file tvm/relax/analysis/liveness.cc + * \brief Implementation of liveness analysis + */ +#include <tvm/relax/analysis.h> +#include <tvm/relax/dataflow_analysis.h> +#include <tvm/relax/expr.h> +#include <tvm/relax/expr_functor.h> +#include <tvm/runtime/object.h> + +namespace tvm { +namespace relax { + +// just sets of vars. the bool value is unnecessary +using Domain = Map<Var, Bool>; + +Domain transfer_func(const GraphBinding& binding, const ObjectRef& input) { + Domain in_domain = Downcast<Domain>(input); + Domain new_domain(in_domain); + + // 1. If a var that appears in the RHS of the binding, add it (it's live) + // 2. Remove the bound var (it is not live prior to being bound) + Array<Var> vars_used; + Optional<Var> var_bound; + if (binding->kind == BindingNodeKind::kSeqBody) { + vars_used = AllVars(binding->seq->body); + } else if (binding->kind == BindingNodeKind::kIfCond) { + Binding b = binding->seq->blocks[binding->block_idx]->bindings[binding->binding_idx]; + Expr cond = Downcast<If>(GetBoundValue(b))->cond; + vars_used = AllVars(cond); + } else if (binding->kind == BindingNodeKind::kIfMerge) { + // no vars are used in the merge + vars_used = {}; + // define the merge var + var_bound = binding->seq->blocks[binding->block_idx]->bindings[binding->binding_idx]->var; + } else { + // the ordinary binding case + Binding b = binding->seq->blocks[binding->block_idx]->bindings[binding->binding_idx]; + Expr bound_value = GetBoundValue(b); + // special case: if the RHS is a function literal, we only care about the free vars + // (those captured by the closure) + if (bound_value.as<FunctionNode>()) { + vars_used = FreeVars(bound_value); + } else { + vars_used = AllVars(bound_value); + } + var_bound = b->var; + } + + for (auto var : vars_used) { + if (!new_domain.count(var)) { Review Comment: Nit: Both `.count` and `.Set` need to hash the key and find the index in the map (or if it gets changed to `unordered_set`, the set). 
Since the map already de-duplicates its keys, we could just call `new_domain.Set(var, Bool(true))` without the `if (!new_domain.count(var))` check. If `var` is already present, the call is effectively a no-op, and if it isn't, we've saved a lookup. ########## include/tvm/relax/dataflow_analysis.h: ########## @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/dataflow_analysis.h + * \brief A reusable framework for dataflow analysis in Relax. + * Based on Adrian Sampson's course material: + * https://www.cs.cornell.edu/courses/cs6120/2020fa/lesson/4/ + * Do not confuse with dataflow pattern matching (does not use this machinery) + */ + +#ifndef TVM_RELAX_DATAFLOW_ANALYSIS_H_ +#define TVM_RELAX_DATAFLOW_ANALYSIS_H_ + +#include <tvm/relax/analysis.h> +#include <tvm/relax/expr.h> +#include <tvm/runtime/object.h> + +#include <utility> + +namespace tvm { +namespace relax { + +/*! \brief For dataflow analysis, we need to have a control flow graph. + * We will organize this graphs by bindings, which allows analyses to + * state their results for each binding in a SeqExpr. + * + * There are a few cases that have to be handled: + * 1. A normal binding (most common)ICHECK + * 2. 
The condition expression in an If node (a "split" point) + * 3. A merge point (the variable to which an If node is bound: it is a "merge" between + * the SeqExprs in the true and false branches) + * 4. The body expression in a SeqExpr (not actually bound) + */ +enum BindingNodeKind : int { kBinding = 0, kIfCond = 1, kIfMerge = 2, kSeqBody = 3 }; + +class GraphBindingNode : public Object { + public: + /*! \brief The SeqExpr the binding resides in. */ + SeqExpr seq; + + /*! \brief The arguments to the binding. Only the first binding in the graph has arguments + * (i.e., the function arguments). */ + Array<Var> args; + + /*! \brief Index of the binding block in the SeqExpr where the binding is found. + * Convention: We put the SeqExpr body at one block past the final block. */ + size_t block_idx; + + /*! \brief Index of the binding within the binding block corresponding to this binding. + * Convention: Both the If condition and merge are mapped to the same index. + * We use the kind to distinguish. */ + size_t binding_idx; + + /*! \brief The kind of binding this is. */ + BindingNodeKind kind; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("seq", &seq); + v->Visit("args", &args); + v->Visit("block_idx", &block_idx); + v->Visit("binding_idx", &binding_idx); + v->Visit("kind", &kind); + } + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.analysis.GraphBinding"; + TVM_DECLARE_BASE_OBJECT_INFO(GraphBindingNode, Object); +}; + +/*! \brief Representation of a binding in the control flow graph */ +class GraphBinding : public ObjectRef { + public: + /*! + * \brief Create a GraphBinding. See the docs on GraphBindingNode for further details. + * + * \param seq: The SeqExpr in which the binding resides. 
+ * \param args: The arguments to the binding (only nonempty for the first binding: + * these will be the function arguments) + * \param block_idx: The index of the BindingBlock in the SeqExpr + * where the binding resides (for the return expression, use one past the final block). + * \param binding_idx: The index of the binding in the BindingBlock corresponding to the binding. + * \param kind: The kind of binding this is. (Used especially to distinguish If node conditions + * from the merge after the If) + */ + TVM_DLL static GraphBinding Create(const SeqExpr& seq, const Array<Var>& args, size_t block_idx, + size_t binding_idx, BindingNodeKind kind); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(GraphBinding, ObjectRef, GraphBindingNode); +}; + +/* A control flow graph corresponding to a function. + */ +class ControlFlowGraphNode : public Object { + public: + /*! \brief The bindings in the graph. 0 is the entry point. */ + Array<GraphBinding> bindings; + /*! \brief The ith member is the list of predecessors (indices) to binding i in bindings. */ + Array<Array<Integer>> preds; Review Comment: As I'm thinking on it, that would also benefit from using `Var` to track locations within an expression, rather than `size_t` indices. Each `GraphBindingNode` would hold `Array<Var> predecessors` and `Array<Var> successors`.
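
To make the `Var`-keyed suggestion in the review comments concrete, here is a rough, self-contained sketch of what such a layout might look like. It is not code from the PR: `VarName` is a hypothetical `std::string` stand-in for `relax::Var` so the snippet compiles without TVM, and `GraphNode`, `AddEdge`, and `ForwardAnalysis` are illustrative names only. It assumes the graph is acyclic (as the PR notes for Relax) and visits nodes in topological order, so each node is processed once after all of its predecessors.

```cpp
#include <cstddef>
#include <functional>
#include <queue>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for relax::Var so the sketch builds without TVM.
using VarName = std::string;

// A node keyed by variable: it stores its neighbours by name, not by index.
struct GraphNode {
  std::vector<VarName> predecessors;
  std::vector<VarName> successors;
};

using Graph = std::unordered_map<VarName, GraphNode>;
using Facts = std::set<std::string>;

// Records an edge from -> to, creating either node if it is missing.
void AddEdge(Graph* g, const VarName& from, const VarName& to) {
  (*g)[from].successors.push_back(to);
  (*g)[to].predecessors.push_back(from);
}

// One forward pass over an acyclic graph in topological order.
// The same routine works whether edges mean "runs before" (control flow)
// or "is an input of" (data dependency).
std::unordered_map<VarName, Facts> ForwardAnalysis(
    const Graph& g, const Facts& init,
    const std::function<Facts(const VarName&, const Facts&)>& transfer,
    const std::function<Facts(const Facts&, const Facts&)>& merge) {
  std::unordered_map<VarName, std::size_t> pending;
  std::queue<VarName> ready;
  for (const auto& [name, node] : g) {
    pending[name] = node.predecessors.size();
    if (node.predecessors.empty()) ready.push(name);
  }

  std::unordered_map<VarName, Facts> out;
  while (!ready.empty()) {
    VarName name = ready.front();
    ready.pop();
    const GraphNode& node = g.at(name);

    // Entry nodes start from init; otherwise merge all predecessor outputs.
    Facts in = init;
    bool first = true;
    for (const auto& pred : node.predecessors) {
      in = first ? out.at(pred) : merge(in, out.at(pred));
      first = false;
    }
    out[name] = transfer(name, in);

    // A successor becomes ready once all of its predecessors are done.
    for (const auto& succ : node.successors) {
      if (--pending[succ] == 0) ready.push(succ);
    }
  }
  return out;
}
```

With a diamond shaped like an `If` (a condition node, two branch nodes, one merge node), a transfer function that adds the node's own name, and set union as the merge, the merge node ends up holding all four names. Whether this layout fits the PR depends on how `kIfCond` and `kSeqBody` nodes, which have no bound `Var` of their own, would be keyed; the sketch leaves that open.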
