This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 9c1e74c  [REFACTOR][BOYC] Non recursive partitioning (#5493)
9c1e74c is described below

commit 9c1e74ce0727ac7aacd012b35ac068a25cbc9a42
Author: Zhi <5145158+zhi...@users.noreply.github.com>
AuthorDate: Fri May 1 13:27:44 2020 -0700

    [REFACTOR][BOYC] Non recursive partitioning (#5493)
    
    * non recursive partitioning
    
    * refactor maps
    
    * rebase upstream
    
    * refactor shared output
    
    * address comments
    
    Co-authored-by: Cody Yu <comaniac0...@gmail.com>
---
 src/relay/transforms/partition_graph.cc | 393 ++++++++++----------------------
 1 file changed, 115 insertions(+), 278 deletions(-)

diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index 3b0d6bc..634434d 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -54,39 +54,30 @@ namespace partitioning {
 static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
 static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
 
-/*!
- * \brief The checker that verifies if a Relay program is annotated correctly
- * for partitioning.
+/*! \brief This struct maintains the required metadata for a region to generate a corresponding
+ * global function and function call. Global function will be passed to the target specific codegen
+ * and function call will be used in the transform Relay graph to invoke the function in runtime.
  */
-class AnnotationChecker : public ExprVisitor {
- public:
-  bool Check() {
-    if (!found_start_ && !found_end_) {
-      LOG(WARNING) << "No compiler annotation found";
-    } else if (!found_start_) {
-      LOG(ERROR) << "compiler_begin annotation is missing";
-      return false;
-    } else if (!found_end_) {
-      LOG(ERROR) << "compiler_end annotation is missing";
-      return false;
-    }
-    return true;
-  }
+struct RegionFuncMetadata {
+  /*! \brief The call node of the generated global function for this region. */
+  Call func_call;
 
-  void VisitExpr_(const CallNode* call) final {
-    auto op_node = call->op.as<OpNode>();
-    if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
-      return;
-    } else if (call->op == compiler_begin_op) {
-      found_start_ = true;
-    } else if (call->op == compiler_end_op) {
-      found_end_ = true;
-    }
-  }
+  /*! \brief A list of argument pairs. Each pair includes (var, expr). var is used
+   * as a function node argument; input expression is used as a function call parameter.
+   */
+  std::vector<std::pair<Var, Expr>> args;
 
- private:
-  bool found_start_{false};
-  bool found_end_{false};
+  /*! \brief Map from each region output expr (compiler end) node to
+   * the corresponding function output expr.
+   */
+  std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual> region_func_out;
+
+  /*! \brief Map from each region input expression (compiler begin) to
+   * the corresponding function input variable. This cache is used to make sure
+   * a region function will not have duplicated inputs even if it refers to
+   * the same expr multiple times.
+   */
+  std::unordered_map<Expr, Var, ObjectHash, ObjectEqual> region_func_in;
 };
 
 /*! \brief This class partitions the expr labeled with begin and end annotations
@@ -124,37 +115,35 @@ class AnnotationChecker : public ExprVisitor {
  *         the compiler name.
  */
 
-class Partitioner : public ExprMutator {
+class Partitioner : public MixedModeMutator {
  public:
   explicit Partitioner(const IRModule& module) : module_(module) {
     for (auto f : module->functions) {
       GlobalVar f_var = f.first;
       BaseFunc f_func = f.second;
 
-      // Creating regionset per function in the module
+      // Creating regionset per function in the module.
       auto region_set = AnnotatedRegionSet::Create(f_func, partitioning::compiler_begin_op,
                                                    partitioning::compiler_end_op);
       regions_sets_[region_set] = f_func;
     }
   }
 
-  Expr VisitExpr_(const CallNode* call) final {
+  Expr Rewrite_(const CallNode* call, const Expr& post) final {
     auto op_node = call->op.as<OpNode>();
     if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
-      return ExprMutator::VisitExpr_(call);
+      return post;
     } else if (call->op == compiler_begin_op) {
-      // The annotation node is inserted on edge so it must have only one
-      // argument.
+      // The annotation node is inserted on edge so it must have only one argument.
       CHECK_EQ(call->args.size(), 1U);
 
       // Traverse the rest graph.
       Expr parent = call->args[0];
-      auto input_expr = VisitExpr(parent);
+      auto input_expr = Downcast<Call>(post)->args[0];
 
       // Backtrace the parent to find the first ancestor node that is not a begin or end op
       while (const auto* parent_call = parent.as<CallNode>()) {
-        if (parent_call->op == compiler_begin_op ||
-            parent_call->op == compiler_end_op) {
+        if (parent_call->op == compiler_begin_op || parent_call->op == compiler_end_op) {
           parent = parent_call->args[0];
         } else {
           break;
@@ -165,8 +154,8 @@ class Partitioner : public ExprMutator {
       int index = GetArgIdx(sg, GetRef<Call>(call));
       CHECK_NE(index, -1);
 
-      if (shared_output_.count(parent) && shared_output_[parent].count(sg)) {
-        return shared_output_[parent][sg];
+      if (region_func_meta_[sg].region_func_in.count(parent)) {
+        return region_func_meta_[sg].region_func_in[parent];
       } else {
         // The type of the created variable is the same as the compiler_begin
         // node.
@@ -177,11 +166,11 @@ class Partitioner : public ExprMutator {
 
         std::pair<Var, Expr> cand = std::make_pair(var, input_expr);
 
-        if (std::find(region_args[sg].begin(), region_args[sg].end(), cand) ==
-            region_args[sg].end()) {
-          region_args[sg].push_back(cand);
+        if (std::find(region_func_meta_[sg].args.begin(), region_func_meta_[sg].args.end(), cand) ==
+            region_func_meta_[sg].args.end()) {
+          region_func_meta_[sg].args.push_back(cand);
         }
-        shared_output_[parent][sg] = var;
+        region_func_meta_[sg].region_func_in[parent] = var;
         return std::move(var);
       }
     } else {
@@ -197,114 +186,21 @@ class Partitioner : public ExprMutator {
       BaseFunc f = GetFunc(GetRef<Call>(call));
 
       // Traverse subgraph inputs.
-      auto input = VisitExpr(call->args[0]);
+      auto input = Downcast<Call>(post)->args[0];
       CHECK(region.defined()) << "Region not defined for " << GetRef<Call>(call);
       // functions are created for each annotated regions,
       // when their first output is encountered.
       // If multiple outputs are there, a tuple node is inserted at the end.
-      // region_function_calls is map that maintains
-      // (each annotated regions) --> created function
 
-      if (region_function_calls.find(region) == region_function_calls.end()) {
-        // First time this region is encountered in the traversal.
-        // Creating the function.
+      if (!region_func_meta_[region].func_call.defined()) {
+        // First time this region is encountered in the traversal. Creating the function.
         CreateFunction(region, call);
       }
-      // Retrieve this particular output of function.
-      return GetFunctionOutput(region, GetRef<Call>(call));
-    }
-  }
-
-  Expr VisitExpr_(const TupleNode* op) final {
-    auto region = GetRegion(GetRef<Tuple>(op));
-    if (!region.defined()) {
-      return ExprMutator::VisitExpr_(op);
-    } else {
-      Array<Expr> fields;
-      for (auto field : op->fields) {
-        fields.push_back(VisitExpr(field));
-      }
-      return Tuple(fields);
-    }
-  }
-
-  Expr VisitExpr_(const TupleGetItemNode* g) final {
-    auto region = GetRegion(GetRef<TupleGetItem>(g));
-    if (!region.defined()) {
-      return ExprMutator::VisitExpr_(g);
-    } else {
-      auto t = VisitExpr(g->tuple);
-      return TupleGetItem(t, g->index);
-    }
-  }
-
-  Expr VisitExpr_(const FunctionNode* op) final {
-    auto region = GetRegion(GetRef<Function>(op));
-    if (!region.defined()) {
-      return ExprMutator::VisitExpr_(op);
-    } else {
-      Array<Var> params;
-      for (auto param : op->params) {
-        Var new_param = Downcast<Var>(VisitExpr(param));
-        params.push_back(new_param);
-      }
-      auto body = VisitExpr(op->body);
-      return Function(params, body, op->ret_type, op->type_params, op->attrs);
-    }
-  }
-
-  Expr VisitExpr_(const LetNode* op) final {
-    auto region = GetRegion(GetRef<Let>(op));
-    if (!region.defined()) {
-      return ExprMutator::VisitExpr_(op);
-    } else {
-      Var var = Downcast<Var>(VisitExpr(op->var));
-      auto value = VisitExpr(op->value);
-      auto body = VisitExpr(op->body);
-      return Let(var, value, body);
-    }
-  }
-
-  Expr VisitExpr_(const IfNode* op) final {
-    auto region = GetRegion(GetRef<If>(op));
-    if (!region.defined()) {
-      return ExprMutator::VisitExpr_(op);
-    } else {
-      auto guard = VisitExpr(op->cond);
-      auto true_b = VisitExpr(op->true_branch);
-      auto false_b = VisitExpr(op->false_branch);
-      return If(guard, true_b, false_b);
-    }
-  }
-
-  Expr VisitExpr_(const RefCreateNode* op) final {
-    auto region = GetRegion(GetRef<RefCreate>(op));
-    if (!region.defined()) {
-      return ExprMutator::VisitExpr_(op);
-    } else {
-      Expr value = VisitExpr(op->value);
-      return RefCreate(value);
-    }
-  }
 
-  Expr VisitExpr_(const RefReadNode* op) final {
-    auto region = GetRegion(GetRef<RefRead>(op));
-    if (!region.defined()) {
-      return ExprMutator::VisitExpr_(op);
-    } else {
-      Expr ref = VisitExpr(op->ref);
-      return RefRead(ref);
-    }
-  }
-
-  Expr VisitExpr_(const RefWriteNode* op) final {
-    auto region = GetRegion(GetRef<RefWrite>(op));
-    if (!region.defined()) {
-      return ExprMutator::VisitExpr_(op);
-    } else {
-      Expr ref = VisitExpr(op->ref);
-      Expr value = VisitExpr(op->value);
-      return RefWrite(ref, value);
+      // Retrieve this particular output of function.
+      Expr region_out_expr = Downcast<Call>(GetRef<Call>(call))->args[0];
+      CHECK(region_func_meta_[region].region_func_out.count(region_out_expr));
+      return region_func_meta_[region].region_func_out[region_out_expr];
     }
   }
 
@@ -370,24 +266,22 @@ class Partitioner : public ExprMutator {
   }
 
   /*!
-   * \brief This function is called first time that we encounter a compiler_end
-   * node to create the function for the subgraph.
+   * \brief Create a function and its function call for the given region. If the function has
+   * multiple outputs, a Tuple will be formed to aggregate all outputs, and TupleGetItem nodes
+   * will be created to serve output consumers.
    */
-  void CreateFunction(AnnotatedRegion region, const CallNode* call) {
-    // Create fields which is a unique list of outputs. Also populate
-    // region_return_indices_ map which maps parent of compiler_end node to
-    // corresponding index in fields.
+  void CreateFunction(AnnotatedRegion region, const CallNode* end_node) {
+    // Create fields which is a unique list of outputs.
     Array<Expr> fields;
-    int i = 0;
-    for (auto ret : region->GetOutputs()) {
-      auto ret_node = Downcast<Call>(ret)->args[0];
+    std::unordered_map<Expr, int, ObjectHash, ObjectEqual> out_expr_to_idx;
+    int out_idx = 0;
+    for (auto region_end_node : region->GetOutputs()) {
+      auto ret_node = Downcast<Call>(region_end_node)->args[0];
       // Don't duplicate outputs.
-      if (!region_return_indices_.count(region) ||
-          !region_return_indices_[region].count(ret_node)) {
-        auto ret_expr = VisitExpr(ret_node);
+      if (!out_expr_to_idx.count(ret_node)) {
+        auto ret_expr = MixedModeMutator::VisitExpr(ret_node);
         fields.push_back(ret_expr);
-        region_return_indices_[region][ret_node] = i;
-        i++;
+        out_expr_to_idx[ret_node] = out_idx++;
       }
     }
 
@@ -396,20 +290,14 @@ class Partitioner : public ExprMutator {
     Map<Var, Expr> params_bind;
 
     auto IsConstant = [](const Expr& expr) {
-      if (expr->IsInstance<ConstantNode>())
-        return true;
-      if (expr->IsInstance<TupleNode>()) {
-        auto tuple = expr.as<TupleNode>();
-        for (const auto& field : tuple->fields) {
-          if (!field->IsInstance<ConstantNode>())
-            return false;
-        }
-        return true;
-      }
-      return false;
+      if (expr->IsInstance<ConstantNode>()) return true;
+      if (!expr->IsInstance<TupleNode>()) return false;
+      const auto* tn = expr.as<TupleNode>();
+      return std::all_of(tn->fields.begin(), tn->fields.end(),
+                         [](const Expr& e) { return e->IsInstance<ConstantNode>(); });
     };
 
-    for (auto pair : region_args[region]) {
+    for (auto pair : region_func_meta_[region].args) {
       params.push_back(pair.first);
       if (IsConstant(pair.second)) {
         params_bind.Set(pair.first, pair.second);
@@ -422,23 +310,21 @@ class Partitioner : public ExprMutator {
     if (fields.size() == 1) {
       // If there are only a single output; no need to add a tuple
       global_region_func =
-          Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs());
+          Function(params, fields[0], end_node->args[0]->checked_type_, {}, DictAttrs());
     } else {
       auto tuple = Tuple(fields);
       global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
     }
 
-    std::string target = call->attrs.as<CompilerAttrs>()->compiler;
+    std::string target = end_node->attrs.as<CompilerAttrs>()->compiler;
     std::string name = target + "_" + std::to_string(region->GetID());
 
-    global_region_func = WithAttr(std::move(global_region_func), 
tvm::attr::kGlobalSymbol,
-                                  runtime::String(name));
     global_region_func =
-        WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
-    global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
-                                  tvm::runtime::String(target));
+        WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol, runtime::String(name));
+    global_region_func = WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
     global_region_func =
-        WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
+        WithAttr(std::move(global_region_func), attr::kCompiler, tvm::runtime::String(target));
+    global_region_func = WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
 
     // Constant propagation
     if (!params_bind.empty()) {
@@ -446,8 +332,7 @@ class Partitioner : public ExprMutator {
     }
 
     std::string fname = name;
-    CHECK(!module_->ContainGlobalVar(fname))
-        << "Global function " << fname << " already exists";
+    CHECK(!module_->ContainGlobalVar(fname)) << "Global function " << fname << " already exists";
     // Create a global function and add it to the IRModule for the region.
     // This way we lift the functions that should be handled by external
     // codegen to the module scope and rely on the pass manager to prevent
@@ -456,129 +341,81 @@ class Partitioner : public ExprMutator {
     GlobalVar glob_func(fname);
     module_->Add(glob_func, global_region_func);
 
-    // The return type of callnode is the same as the type of the
-    // compiler_end node.
-    auto ret = Call(glob_func, param_expr);
-    region_function_calls[region] = ret;
-  }
+    // Create a call node for the function.
+    auto call = Call(glob_func, param_expr);
+    region_func_meta_[region].func_call = call;
 
-  /*!
-   * \brief Get the return(output) of the function for compiler end node "end_arg".
-   * This will return either a Call (for a function with a single output) or a
-   * TupleGetItem (for a function with multiple outputs).
-   */
-  Expr GetFunctionOutput(AnnotatedRegion region, const Expr& end_arg) {
-    Expr arg = Downcast<Call>(end_arg)->args[0];
-    // Function has one output.
-    if (region_return_indices_[region].size() == 1) {
-      return region_function_calls[region];
-    }
-    // Function has multiple outputs.
-    // Use already made TupleGetItem.
-    if (region_return_tuplegetitem_.count(region) &&
-        region_return_tuplegetitem_[region].count(arg)) {
-      return region_return_tuplegetitem_[region][arg];
+    // Create output expr(s) for the function call.
+    if (out_expr_to_idx.size() == 1) {
+      // Single output directly uses the call node as the output expr.
+      region_func_meta_[region].region_func_out[out_expr_to_idx.begin()->first] = call;
+    } else {
+      // Multiple outputs need to create TupleGetItem nodes as output exprs.
+      for (auto pair : out_expr_to_idx) {
+        Expr region_out_expr = pair.first;  // The arg of a compiler end node of this region.
+        int idx = pair.second;              // Corresponding function output tuple index.
+        auto tuple_get_item = TupleGetItem(call, idx);
+        tuple_get_item->checked_type_ = region_out_expr->checked_type_;
+        region_func_meta_[region].region_func_out[region_out_expr] = tuple_get_item;
+      }
     }
-    // Create new TupleGetItem.
-    CHECK(region_return_indices_.count(region) &&
-          region_return_indices_[region].count(arg));
-    int index = region_return_indices_[region][arg];
-
-    auto func_call = region_function_calls[region];
-    auto tuple_get_item_ = TupleGetItem(func_call, index);
-    tuple_get_item_->checked_type_ = arg->checked_type_;
-    region_return_tuplegetitem_[region][arg] = tuple_get_item_;
-    return std::move(tuple_get_item_);
   }
 
-  /*!
-   * \brief This map maintains the already created function calls.
-   * This is required in the multi-output scenario, to link rest of the outputs
-   * to call
-   */
-  std::unordered_map<AnnotatedRegion, Call, ObjectHash, ObjectEqual> region_function_calls;
-
-  /*!
-   * \brief This map maintains arguments (of region) visits through visitor
-   * patterns. Those arguement var and expression will be used to when creating
-   * the function.
-   */
-  std::unordered_map<AnnotatedRegion, std::vector<std::pair<Var, Expr>>, ObjectHash, ObjectEqual>
-      region_args;
-
-  /*!
-   * \brief This map maintains the index of an output in the subgraph function
-   * for a given region. If there are multiple entries for a region, then the
-   * function has a tuple of multiple outputs for its return.
-   */
-  using RegionRetIndexMap = std::unordered_map<Expr, int, ObjectHash, ObjectEqual>;
-  std::unordered_map<AnnotatedRegion, RegionRetIndexMap, ObjectHash, ObjectEqual>
-      region_return_indices_;
+  /*! \brief Map from each region to its metadata of the generated function. */
+  std::unordered_map<AnnotatedRegion, RegionFuncMetadata, ObjectHash, ObjectEqual>
+      region_func_meta_;
 
-  /*!
-   * \brief This map holds already created TupleGetItem nodes for accessing
-   * outputs of a function.
-   */
-  using RegionRetTupleGetItemMap = std::unordered_map<Expr, TupleGetItem, ObjectHash, ObjectEqual>;
-  std::unordered_map<AnnotatedRegion, RegionRetTupleGetItemMap, ObjectHash, ObjectEqual>
-      region_return_tuplegetitem_;
-
-  /*!
-   * \brief Each region set is associated with a function in the module.
+  /*! \brief Each region set is associated with a function in the module.
    * This map maintains the mapping between regionsets and the function it
    * belongs to
    */
   std::unordered_map<AnnotatedRegionSet, BaseFunc, ObjectHash, ObjectEqual> regions_sets_;
 
-  /*!\brief Cache the output that is shared by different nodes. */
-  using RegionOutputMap = std::unordered_map<AnnotatedRegion, Var, ObjectHash, ObjectEqual>;
-  std::unordered_map<Expr, RegionOutputMap, ObjectHash, ObjectEqual> shared_output_;
-
   /*!\brief The IRModule used for partitioning. */
   IRModule module_;
 };
 
-class DefaultRemover : public ExprMutator {
- public:
-  explicit DefaultRemover(const IRModule& module) : module_(module) {}
+IRModule RemoveDefaultAnnotations(IRModule module) {
+  class DefaultRemover : public ExprRewriter {
+   public:
+    DefaultRemover() = default;
 
-  IRModule Remove() {
-    auto glob_funcs = module_->functions;
-    for (const auto& pair : glob_funcs) {
-      if (auto* fn = pair.second.as<FunctionNode>()) {
-        auto func = GetRef<Function>(fn);
-        func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params,
-                        func->attrs);
-        module_->Update(pair.first, func);
+    Expr Rewrite_(const CallNode* call, const Expr& post) final {
+      auto attrs = call->attrs.as<CompilerAttrs>();
+      if (attrs != nullptr && attrs->compiler == "default") {
+        return Downcast<Call>(post)->args[0];
       }
+      return post;
     }
-    return module_;
-  }
+  };
 
-  Expr VisitExpr_(const CallNode* call) final {
-    auto attrs = call->attrs.as<CompilerAttrs>();
-    if (attrs != nullptr && attrs->compiler == "default") {
-      return VisitExpr(call->args[0]);
+  auto glob_funcs = module->functions;
+  // module is mutable, hence, we make a copy of it.
+  module.CopyOnWrite();
+  for (const auto& pair : glob_funcs) {
+    if (auto* fn = pair.second.as<FunctionNode>()) {
+      auto func = GetRef<Function>(fn);
+      DefaultRemover remover;
+      auto removed = PostOrderRewrite(func->body, &remover);
+      func = Function(func->params, removed, func->ret_type, func->type_params, func->attrs);
+      module->Update(pair.first, func);
     }
-    return ExprMutator::VisitExpr_(call);
   }
-
- private:
-  IRModule module_;
-};
+  return module;
+}
 
 }  // namespace partitioning
 
 namespace transform {
 
 Pass PartitionGraph() {
-  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func =
-      [=](IRModule m, PassContext pc) {
-        // TODO(@comaniac, @zhiics): We should also handle the annotation with "default" attribute
-        // by treating them as un-annotated, but we don't have it yet. This workaround pass removes
-        // all "default" annotations and should be deleted in the future.
-        auto new_m = partitioning::DefaultRemover(m).Remove();
-        return partitioning::Partitioner(new_m).Partition();
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func = [=](IRModule m,
+                                                                            PassContext pc) {
+    // TODO(@comaniac, @zhiics): We should also handle the annotation with "default" attribute
+    // by treating them as un-annotated, but we don't have it yet. This workaround pass removes
+    // all "default" annotations and should be deleted in the future.
+    auto new_m = partitioning::RemoveDefaultAnnotations(m);
+    return partitioning::Partitioner(new_m).Partition();
   };
   auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {});
   return Sequential({partitioned, InferType()});

Reply via email to