slyubomirsky commented on code in PR #16129:
URL: https://github.com/apache/tvm/pull/16129#discussion_r1451648612


##########
src/relax/transform/dataflow_inplace.cc:
##########
@@ -0,0 +1,1017 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/dataflow_inplace.cc
+ * \brief Pass that converts eligible operator calls in dataflow blocks
+ *   into in-place versions.
+ */
+
+#include <tvm/ir/transform.h>
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/attrs/op.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/relax/utils.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "utils.h"
+
+namespace tvm {
+namespace relax {
+
+// Perform liveness analysis on a dataflow block, returning a map of vars to
+// pairs of indices (the liveness interval, from the starting index to the end 
index).
+// A starting index of -1 means the var is defined before the block starts and 
an end index
+// of block->bindings.size() (one past the last index) means it is live after 
the block ends.
+std::unordered_map<Var, std::pair<int, int>, ObjectPtrHash, ObjectPtrEqual> 
AnalyzeLiveness(
+    const DataflowBlock& block) {
+  std::unordered_map<Var, std::pair<int, int>, ObjectPtrHash, ObjectPtrEqual> 
ret;
+  for (int i = block->bindings.size() - 1; i >= 0; i--) {
+    Binding b = block->bindings[i];
+    Var defined_var = b->var;
+    Expr value = GetBoundValue(b);
+    Array<Var> used_vars;
+    // for a function literal, we consider only the free vars
+    // (those captured from the outer scope)
+    if (value.as<FunctionNode>()) {
+      used_vars = FreeVars(value);
+    } else if (value.as<TupleGetItemNode>()) {
+      // Special case: we do not consider a tuple index to be a "use."
+      // This is a bit of a hack but allows us to do operations that
+      // create tuples to be done in-place (otherwise, any index of the tuple
+      // would be considered a use and so the tuple would be live later).
+      // Hence we keep the array empty.
+    } else {
+      used_vars = AllVars(value);
+    }
+
+    for (auto var : used_vars) {
+      if (!ret.count(var)) {
+        ret[var] = {-1, i};
+      }
+    }
+
+    if (!ret.count(defined_var)) {
+      // if it's an output, then it lives past the end of the block
+      if (!defined_var.as<DataflowVarNode>()) {
+        ret[defined_var] = {i, block->bindings.size()};
+      } else {
+        // otherwise, it's live only here
+        ret[defined_var] = {i, i};
+      }
+    } else {
+      // this means the var is used later but we encountered its definition now
+      auto last_range = ret[defined_var];
+      CHECK_EQ(last_range.first, -1);
+      std::pair<int, int> new_range = {i, last_range.second};
+      ret[defined_var] = new_range;
+    }
+  }
+  return ret;
+}
+
+class AliasAnalyzer {
+ public:
+  AliasAnalyzer() : alias_map_(), tuple_map_(), mem_idx_(0) {}
+
+  // The analysis returns a map of vars to memory locations that it *could* 
map to
+  // (any unique allocation = one memory location), plus a map of memory 
locations
+  // that correspond to tuples (this maps to sets of memory locations for each 
tuple element).
+  // Note: inputs are values that should be assumed not to be aliased and are 
therefore
+  // (in the case of in-place ops) safe to overwrite. This may not be true of 
function args.
+  std::pair<std::unordered_map<Var, std::unordered_set<int>, ObjectPtrHash, 
ObjectPtrEqual>,
+            std::unordered_map<int, std::vector<std::unordered_set<int>>>>
+  Analyze(const DataflowBlock& block, const Array<Var>& inputs) {
+    for (auto input : inputs) {
+      int curr_idx = get_fresh_idx();
+      alias_map_[input] = {curr_idx};
+      if (auto* tup_info = GetStructInfoAs<TupleStructInfoNode>(input)) {
+        InsertFreshTuple(curr_idx, tup_info);
+      }
+    }
+
+    for (const Binding& binding : block->bindings) {
+      Var current_var = binding->var;
+      Expr value = GetBoundValue(binding);
+      alias_map_[current_var] = GetAliasSet(value, current_var);
+    }
+
+    return {alias_map_, tuple_map_};
+  }
+
+ private:
+  int get_fresh_idx() {
+    int ret = mem_idx_;
+    mem_idx_++;
+    return ret;
+  }
+
+  // Fresh tuple = each element is assumed to be a unique allocation
+  void InsertFreshTuple(int tup_idx, const TupleStructInfoNode* tup_info) {
+    std::vector<std::unordered_set<int>> tuple_set;
+    for (int i = 0; i < static_cast<int>(tup_info->fields.size()); i++) {
+      int curr_field = get_fresh_idx();
+      tuple_set.push_back({curr_field});
+      if (auto* nested_tup_info = 
tup_info->fields[i].as<TupleStructInfoNode>()) {
+        InsertFreshTuple(curr_field, nested_tup_info);
+      }
+    }
+    tuple_map_[tup_idx] = tuple_set;
+  }
+
+  // given a tuple index, add the given memory location indices to each 
component's
+  // alias set
+  void UpdateTupleComponents(int tup_idx, const std::unordered_set<int>& 
insert_idxs) {
+    if (tuple_map_.count(tup_idx)) {
+      auto tuple_comps = tuple_map_[tup_idx];
+      for (size_t i = 0; i < tuple_comps.size(); i++) {
+        auto comp_set = tuple_comps[i];
+
+        // if a member is a tuple, update its components as well
+        for (int member : comp_set) {
+          if (tuple_map_.count(member)) {
+            UpdateTupleComponents(member, insert_idxs);
+          }
+        }
+
+        // update after iterating to avoid iterating over the inserted elements
+        tuple_map_[tup_idx][i].insert(insert_idxs.begin(), insert_idxs.end());
+      }
+    }
+  }
+
+  // capture the given index and also its tuple components (including 
recursively)
+  // if they exist
+  void AddCapturedIndices(std::unordered_set<int>* captured_set, int idx) {
+    captured_set->insert(idx);
+    if (tuple_map_.count(idx)) {
+      for (auto comp_set : tuple_map_[idx]) {
+        for (auto tup_comp_idx : comp_set) {
+          AddCapturedIndices(captured_set, tup_comp_idx);
+        }
+      }
+    }
+  }
+
+  // Conservative extremely pessimistic assumption:
+  // assume that the result of a non-op call can be aliased to any argument
+  // or that it could be a newly allocated value.
+  // For tuples, assume all members are aliased. Yeah, it's bad.
+  // (Skip first arg is for handling call_pure_packed, where the first arg is 
an ExternFunc that we
+  // should ignore)
+  std::unordered_set<int> HandleMysteryCall(const CallNode* call_node, const 
Var& bound_var,
+                                            bool skip_first_arg = false) {
+    // the result may or may not be newly allocated
+    std::unordered_set<int> ret;
+    int res_idx = get_fresh_idx();
+    // the result may be a tuple
+    if (auto* tup_info_node = GetStructInfoAs<TupleStructInfoNode>(bound_var)) 
{
+      InsertFreshTuple(res_idx, tup_info_node);
+    }
+    AddCapturedIndices(&ret, res_idx);
+
+    for (size_t i = (skip_first_arg) ? 1 : 0; i < call_node->args.size(); i++) 
{
+      auto arg = call_node->args[i];
+      auto arg_alias_set = GetAliasSet(arg, bound_var);
+      for (int alias_idx : arg_alias_set) {
+        AddCapturedIndices(&ret, alias_idx);
+      }
+    }
+    // if the result is a tuple, the components can also potentially be 
aliased to any arg
+    // or, in fact, to each other
+    UpdateTupleComponents(res_idx, ret);
+    return ret;
+  }
+
+  // given the expression value, return the set of memory locations 
corresponding to it
+  // (the var the expression is being bound to is needed for struct info)
+  std::unordered_set<int> GetAliasSet(const Expr& value, const Var& bound_var) 
{
+    std::unordered_set<int> ret;
+
+    // cases for value:
+    // constant: it's a fresh index
+    // var: look up in alias map (-1 if not present)
+    // op call: assume it's fresh (may need to make list of exceptions)
+    // tuple: fresh entry in tuple index, recurse to determine indices for 
values
+    // function/packed call: chaos reigns, alias with any other argument
+    //   (if tuple is passed, assume also aliased with all members of the 
tuple)
+    // tuple index: -1 if tuple is not in tuple map, otherwise look up 
corresponding entry
+    // function constant: give them a fresh index (TODO: we can handle in more 
detail if this is a
+    // case we need to support) prim value: fresh index if node: should not 
happen inside dataflow
+    // block
+    if (value.as<ConstantNode>() || value.as<PrimValueNode>() || 
value.as<FunctionNode>()) {
+      // TODO(@slyubomirsky): We will probably want special handling for 
closures
+      ret.insert(get_fresh_idx());
+    } else if (auto* target_var_node = value.as<VarNode>()) {

Review Comment:
   I think this is doing a lot more than what `CanonicalizeBindings` is doing; 
I don't see an obvious correspondence between them, plus this analysis only 
checks dataflow blocks. I'm all for deduplicating if it makes sense, though.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to