This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e7748aa  [Relay] Extract dataflow matcher data structure into header 
(#8774)
e7748aa is described below

commit e7748aac40bd4c263882323393ea8896837614a9
Author: Haichen Shen <[email protected]>
AuthorDate: Tue Aug 17 23:16:54 2021 -0700

    [Relay] Extract dataflow matcher data structure into header (#8774)
    
    * extract dataflow matcher data structure into a header file
    
    * lint
    
    * lint
---
 src/relay/ir/dataflow_matcher.cc     | 615 ++++++++++++++---------------------
 src/relay/ir/dataflow_matcher_impl.h | 164 ++++++++++
 2 files changed, 417 insertions(+), 362 deletions(-)

diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
index 5ce06d9..d7f130f 100644
--- a/src/relay/ir/dataflow_matcher.cc
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -29,50 +29,12 @@
 
 #include <stack>
 
-#include "indexed_graph.h"
+#include "dataflow_matcher_impl.h"
 
 namespace tvm {
 namespace relay {
 
 // Pattern Matcher
-
-class DominatorMatcher;
-
-class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const 
Expr&)> {
- public:
-  explicit DFPatternMatcher(const Expr& root_expr) : 
expr_graph_(CreateIndexedGraph(root_expr)) {}
-  bool Match(const DFPattern& pattern, const Expr& expr);
-  Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, 
Array<Expr>>(memo_); }
-  const IndexedGraph<Expr> expr_graph_;
-
- protected:
-  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
-  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
-  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
-  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
-  bool VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) 
override;
-  bool VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) 
override;
-  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) 
override;
-  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
-  bool VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr) 
override;
-  bool VisitDFPattern_(const IfPatternNode* op, const Expr& expr) override;
-  bool VisitDFPattern_(const LetPatternNode* op, const Expr& expr) override;
-  bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override;
-  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) 
override;
-  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
-  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
-  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
-  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) 
override;
-
-  void ClearMap(size_t watermark);
-  bool MatchesPath(const DominatorPatternNode* op, const Expr& expr);
-  bool DominatesParent(const DominatorPatternNode* op, const Expr& expr);
-
-  std::unordered_map<DFPattern, Array<Expr>, ObjectPtrHash, ObjectPtrEqual> 
memo_;
-  std::vector<DFPattern> matched_nodes_;
-  bool memoize_ = true;
-};
-
 bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
   memo_.clear();
   matched_nodes_.clear();
@@ -542,304 +504,251 @@ bool MatchPattern(DFPattern pattern, Expr expr) {
 
 
TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match").set_body_typed(MatchPattern);
 
-/*!
- * \brief PatternGrouper does pre-rewriting pattern matching and analysis
- *
- * This class creates a number of groups of matched expressions, ensures they 
don't overlap, and
- * returns them to the caller for post-analysis rewriting.
- *
- * This is primarily needed to support the post-dominator analysis required 
for dominator pattern
- * matching.
- */
-class PatternGrouper {
+/*! \brief Creates a new set of nodes based on Group inputs, used to create 
functions and perform
+ * group overlap analysis */
+class MatchExtractor : public ExprMutator {
  public:
-  /*! \brief Internal Group class for storing analysis */
-  struct Group {
-    Expr root_node;
-    int gid;
-    Map<DFPattern, Array<Expr>> matched_nodes;
-    std::string name;
-    Function function;
-    Array<Expr> args;
-  };
-
-  /*! \brief Return the group assignments of expressions */
-  const std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual>& 
GetGIDAssignments() {
-    return gid_assignments_;
+  explicit MatchExtractor(
+      const std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual>& 
inputs)
+      : inputs_(inputs) {}
+  const std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual>& 
GetMemo() {
+    return this->memo_;
   }
-  /*! \brief Group expressions that match the pattern */
-  const std::unordered_map<int, Group>& GroupMatches(const DFPattern& pattern, 
const Expr& pre) {
-    groups_.clear();
-    gid_assignments_.clear();
+  const std::string& GetName() { return name_; }
 
-    pattern_ = pattern;
-    pattern_graph_ = CreateIndexedGraph(pattern_);
-    auto matcher = DFPatternMatcher(pre);
-    matcher_ = &matcher;
-    this->VisitExprs();
-    return this->groups_;
+ protected:
+  Expr VisitExpr(const Expr& pre) override {
+    if (inputs_.count(pre)) {
+      return inputs_.at(pre);
+    }
+    return ExprMutator::VisitExpr(pre);
   }
+  Expr VisitExpr_(const TupleNode* op) override {
+    auto out = ExprMutator::VisitExpr_(op);
+    name_ += "Tuple_";
+    return out;
+  };
+  Expr VisitExpr_(const FunctionNode* op) override {
+    auto out = ExprMutator::VisitExpr_(op);
+    name_ += "Function";
+    return out;
+  };
+  Expr VisitExpr_(const CallNode* call_node) override {
+    auto out = ExprMutator::VisitExpr_(call_node);
+    if (auto operation = call_node->op.as<OpNode>()) {
+      name_ += operation->name + "_";
+    } else {
+      name_ += "Call_";
+    }
+    return out;
+  };
+  Expr VisitExpr_(const LetNode* op) override {
+    auto out = ExprMutator::VisitExpr_(op);
+    name_ += "Let_";
+    return out;
+  };
+  Expr VisitExpr_(const IfNode* op) override {
+    auto out = ExprMutator::VisitExpr_(op);
+    name_ += "If_";
+    return out;
+  };
+  Expr VisitExpr_(const TupleGetItemNode* op) override {
+    auto out = ExprMutator::VisitExpr_(op);
+    name_ += "TupleGetItem" + std::to_string(op->index) + "_";
+    return out;
+  };
+  Expr VisitExpr_(const MatchNode* op) override {
+    auto out = ExprMutator::VisitExpr_(op);
+    name_ += "Match_";
+    return out;
+  };
+  std::string name_;
+  const std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual> inputs_;
+};
 
- protected:
-  /*! \brief Iteratively traverse the Expression in pre-order to find subgraphs
-   *
-   * If we traverse the graph in post-order, we can run into situtations where 
a small subgraph will
-   * match the pattern. Due to options like AltPattern, a larger subgraph with 
more nodes later in
-   * the graph may also match the pattern. With post-order traversal, we mark 
the smaller subgraph
-   * as matched and fail to catch the larger subgraph. This problem is fixed 
by using pre-order
-   * traversal.
-   */
-  void VisitExprs() {
-    std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> pre_partitioned;
-    for (size_t i = matcher_->expr_graph_.topological_order_.size(); i != 0; 
--i) {
-      size_t index = i - 1;
-      Expr current = matcher_->expr_graph_.topological_order_.at(index)->ref_;
-      if (gid_assignments_.count(current) == 0) {  // Don't visit nodes we've 
already grouped
-        if (auto op = current.as<FunctionNode>()) {
-          if (op->attrs.defined() && 
op->attrs->dict.count(attr::kPartitionedFromPattern) != 0) {
-            pre_partitioned.insert(current);
-            PostOrderVisit(op->body,
-                           [&pre_partitioned](const Expr& expr) { 
pre_partitioned.insert(expr); });
-          }
-        }
-        if (pre_partitioned.count(current) == 0 && matcher_->Match(pattern_, 
current)) {
-          CreateGroup(current);
+/*! \brief Group expressions that match the pattern */
+const std::unordered_map<int, PatternGrouper::Group>& 
PatternGrouper::GroupMatches(
+    const DFPattern& pattern, const Expr& pre) {
+  groups_.clear();
+  gid_assignments_.clear();
+
+  pattern_ = pattern;
+  pattern_graph_ = CreateIndexedGraph(pattern_);
+  auto matcher = DFPatternMatcher(pre);
+  matcher_ = &matcher;
+  this->VisitExprs();
+  return this->groups_;
+}
+
+void PatternGrouper::VisitExprs() {
+  std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> pre_partitioned;
+  for (size_t i = matcher_->expr_graph_.topological_order_.size(); i != 0; 
--i) {
+    size_t index = i - 1;
+    Expr current = matcher_->expr_graph_.topological_order_.at(index)->ref_;
+    if (gid_assignments_.count(current) == 0) {  // Don't visit nodes we've 
already grouped
+      if (auto op = current.as<FunctionNode>()) {
+        if (op->attrs.defined() && 
op->attrs->dict.count(attr::kPartitionedFromPattern) != 0) {
+          pre_partitioned.insert(current);
+          PostOrderVisit(op->body,
+                         [&pre_partitioned](const Expr& expr) { 
pre_partitioned.insert(expr); });
         }
       }
+      if (pre_partitioned.count(current) == 0 && matcher_->Match(pattern_, 
current)) {
+        CreateGroup(current);
+      }
     }
   }
-  /*! \brief Creates a new set of nodes based on Group inputs, used to create 
functions and perform
-   * group overlap analysis */
-  class MatchExtractor : public ExprMutator {
-   public:
-    explicit MatchExtractor(
-        const std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual>& 
inputs)
-        : inputs_(inputs) {}
-    const std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual>& 
GetMemo() {
-      return this->memo_;
-    }
-    const std::string& GetName() { return name_; }
+}
 
-   protected:
-    Expr VisitExpr(const Expr& pre) override {
-      if (inputs_.count(pre)) {
-        return inputs_.at(pre);
+void PatternGrouper::CreateGroup(const Expr& expr) {
+  int var_number = 0;
+
+  auto node_map = matcher_->GetMemo();
+  // Get fuzzy patterns
+  std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> fuzzy_matches;
+  for (auto node : pattern_graph_.topological_order_) {
+    // Don't treat fuzzy Dominator patterns as input variables for partition
+    if (auto op = node->ref_.as<DominatorPatternNode>()) {
+      for (auto fuzzy_op : {op->parent, op->path}) {
+        for (auto match : node_map[fuzzy_op]) {
+          fuzzy_matches.insert(match);
+        }
       }
-      return ExprMutator::VisitExpr(pre);
     }
-    Expr VisitExpr_(const TupleNode* op) override {
-      auto out = ExprMutator::VisitExpr_(op);
-      name_ += "Tuple_";
-      return out;
-    };
-    Expr VisitExpr_(const FunctionNode* op) override {
-      auto out = ExprMutator::VisitExpr_(op);
-      name_ += "Function";
-      return out;
-    };
-    Expr VisitExpr_(const CallNode* call_node) override {
-      auto out = ExprMutator::VisitExpr_(call_node);
-      if (auto operation = call_node->op.as<OpNode>()) {
-        name_ += operation->name + "_";
-      } else {
-        name_ += "Call_";
+    // Don't treat Function params or body as input variables for partition
+    if (node->ref_.as<FunctionPatternNode>()) {
+      auto matches = node_map[node->ref_];
+      for (auto match : matches) {
+        auto graph = CreateIndexedGraph(match.as<FunctionNode>()->body);
+        for (auto node : graph.topological_order_) {
+          fuzzy_matches.insert(node->ref_);
+        }
       }
-      return out;
-    };
-    Expr VisitExpr_(const LetNode* op) override {
-      auto out = ExprMutator::VisitExpr_(op);
-      name_ += "Let_";
-      return out;
-    };
-    Expr VisitExpr_(const IfNode* op) override {
-      auto out = ExprMutator::VisitExpr_(op);
-      name_ += "If_";
-      return out;
-    };
-    Expr VisitExpr_(const TupleGetItemNode* op) override {
-      auto out = ExprMutator::VisitExpr_(op);
-      name_ += "TupleGetItem" + std::to_string(op->index) + "_";
-      return out;
-    };
-    Expr VisitExpr_(const MatchNode* op) override {
-      auto out = ExprMutator::VisitExpr_(op);
-      name_ += "Match_";
-      return out;
-    };
-    std::string name_;
-    const std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual> inputs_;
-  };
+    }
+  }
 
-  /*! \brief Create a group based on a matched expression */
-  void CreateGroup(const Expr& expr) {
-    int var_number = 0;
-
-    auto node_map = matcher_->GetMemo();
-    // Get fuzzy patterns
-    std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> fuzzy_matches;
-    for (auto node : pattern_graph_.topological_order_) {
-      // Don't treat fuzzy Dominator patterns input variables for partition
-      if (auto op = node->ref_.as<DominatorPatternNode>()) {
-        for (auto fuzzy_op : {op->parent, op->path}) {
-          for (auto match : node_map[fuzzy_op]) {
-            fuzzy_matches.insert(match);
-          }
-        }
+  // Create input variables
+  Group group;
+  group.root_node = expr;
+  group.matched_nodes = node_map;
+
+  std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual> inputs;
+  Array<Var> params;
+
+  for (auto node : pattern_graph_.topological_order_) {
+    auto make_input = [&](const Expr& input) {
+      if (fuzzy_matches.count(input) == 0 && input.as<OpNode>() == nullptr &&
+          input.as<FunctionNode>() == nullptr && !EmbedConst(input, 
node->ref_)) {
+        inputs[input] =
+            Var("FunctionVar_" + std::to_string(graph_number_) + "_" + 
std::to_string(var_number),
+                NullValue<Type>());
+        group.args.push_back(input);
+        params.push_back(inputs[input]);
+        var_number++;
       }
-      // Don't treat Function params or body as input variables for partition
-      if (node->ref_.as<FunctionPatternNode>()) {
+    };
+    auto tuple = node->ref_.as<TuplePatternNode>();
+    auto call = node->ref_.as<CallPatternNode>();
+    if (tuple && !tuple->fields.defined()) {
+      if (node_map.count(node->ref_)) {
         auto matches = node_map[node->ref_];
         for (auto match : matches) {
-          auto graph = CreateIndexedGraph(match.as<FunctionNode>()->body);
-          for (auto node : graph.topological_order_) {
-            fuzzy_matches.insert(node->ref_);
+          for (auto input : match.as<TupleNode>()->fields) {
+            make_input(input);
           }
         }
       }
-    }
-
-    // Create input variables
-    Group group;
-    group.root_node = expr;
-    group.matched_nodes = node_map;
-
-    std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual> inputs;
-    Array<Var> params;
-
-    for (auto node : pattern_graph_.topological_order_) {
-      auto make_input = [&](const Expr& input) {
-        if (fuzzy_matches.count(input) == 0 && input.as<OpNode>() == nullptr &&
-            input.as<FunctionNode>() == nullptr && !EmbedConst(input, 
node->ref_)) {
-          inputs[input] =
-              Var("FunctionVar_" + std::to_string(graph_number_) + "_" + 
std::to_string(var_number),
-                  NullValue<Type>());
-          group.args.push_back(input);
-          params.push_back(inputs[input]);
-          var_number++;
-        }
-      };
-      auto tuple = node->ref_.as<TuplePatternNode>();
-      auto call = node->ref_.as<CallPatternNode>();
-      if (tuple && !tuple->fields.defined()) {
-        if (node_map.count(node->ref_)) {
-          auto matches = node_map[node->ref_];
-          for (auto match : matches) {
-            for (auto input : match.as<TupleNode>()->fields) {
-              make_input(input);
-            }
-          }
-        }
-      } else if (call && !call->args.defined()) {
-        if (node_map.count(node->ref_)) {
-          auto matches = node_map[node->ref_];
-          for (auto match : matches) {
-            for (auto input : match.as<CallNode>()->args) {
-              make_input(input);
-            }
+    } else if (call && !call->args.defined()) {
+      if (node_map.count(node->ref_)) {
+        auto matches = node_map[node->ref_];
+        for (auto match : matches) {
+          for (auto input : match.as<CallNode>()->args) {
+            make_input(input);
           }
         }
-      } else if (node->inputs_.size() == 0) {
-        if (node_map.count(node->ref_)) {
-          auto matches = node_map[node->ref_];
-          for (auto match : matches) {
-            make_input(match);
-          }
+      }
+    } else if (node->inputs_.size() == 0) {
+      if (node_map.count(node->ref_)) {
+        auto matches = node_map[node->ref_];
+        for (auto match : matches) {
+          make_input(match);
         }
       }
     }
+  }
 
-    graph_number_++;
-
-    // Extract a Function. Used in Partition directly,
-    // used to determine Group overlap in other passes
-    auto extractor = MatchExtractor(inputs);
-    auto body = extractor.Mutate(expr);
-
-    group.function = Function(params, body, NullValue<Type>(), 
Array<TypeVar>());
-    group.name = extractor.GetName();
-    // Check to make sure we aren't overlapping with another group or creating 
an invalid fusion
-    // The MatchExtractor will create a new graph by replacing nodes that 
match the inputs of the
-    // pattern with the input FunctionVar* Variables. The resulting 
memoization map will only
-    // contain nodes in the expression that matched the pattern. If a 
non-input node of the pattern
-    // (i.e., some piece of computation) overlaps with the nodes in a previous 
group, we'll have a
-    // situation where we try to rewrite the same node twice in the second 
rewriting or parition
-    // pass. This isn't valid, so we check for it here. We ignore Ops, 
functions, and constants
-    // because they exist more globally outside of the fusion.
-    // Similiarly, if interior nodes in a group are used outside of the group 
fusing to a single
-    // output would create an invalid graph tranformation, so we block the 
creation of such groups.
-    auto memo = extractor.GetMemo();
-    for (auto kv : memo) {
-      // Check to ensure that this node isn't an input or a global
-      if (inputs.count(kv.first) == 0 && kv.first.as<OpNode>() == nullptr &&
-          kv.first.as<FunctionNode>() == nullptr && 
kv.first.as<ConstantNode>() == nullptr) {
-        if (gid_assignments_.count(kv.first) != 0) {
-          // check to see if the node is use in other groups
-          // Exit due to overlapping partitions
-          return;
-        } else if (kv.second != body) {
-          // if the node isn't the output of the group
-          auto node = matcher_->expr_graph_.node_map_.at(kv.first);
-          for (auto* output : node->outputs_) {
-            // and the node is used by nodes outside of the group
-            if (memo.count(output->ref_) == 0 &&
-                !matcher_->expr_graph_.node_map_.at(expr)->Dominates(output)) {
-              // Exit because nodes in this pattern's body are used outside 
the pattern
-              // fusing it would be invalid
-              return;
-            }
+  graph_number_++;
+
+  // Extract a Function. Used in Partition directly,
+  // used to determine Group overlap in other passes
+  auto extractor = MatchExtractor(inputs);
+  auto body = extractor.Mutate(expr);
+
+  group.function = Function(params, body, NullValue<Type>(), Array<TypeVar>());
+  group.name = extractor.GetName();
+  // Check to make sure we aren't overlapping with another group or creating 
an invalid fusion
+  // The MatchExtractor will create a new graph by replacing nodes that match 
the inputs of the
+  // pattern with the input FunctionVar* Variables. The resulting memoization 
map will only
+  // contain nodes in the expression that matched the pattern. If a non-input 
node of the pattern
+  // (i.e., some piece of computation) overlaps with the nodes in a previous 
group, we'll have a
+  // situation where we try to rewrite the same node twice in the second 
rewriting or partition
+  // pass. This isn't valid, so we check for it here. We ignore Ops, 
functions, and constants
+  // because they exist more globally outside of the fusion.
+  // Similarly, if interior nodes in a group are used outside of the group, 
fusing to a single
+  // output would create an invalid graph transformation, so we block the 
creation of such groups.
+  auto memo = extractor.GetMemo();
+  for (auto kv : memo) {
+    // Check to ensure that this node isn't an input or a global
+    if (inputs.count(kv.first) == 0 && kv.first.as<OpNode>() == nullptr &&
+        kv.first.as<FunctionNode>() == nullptr && kv.first.as<ConstantNode>() 
== nullptr) {
+      if (gid_assignments_.count(kv.first) != 0) {
+        // check to see if the node is used in other groups
+        // Exit due to overlapping partitions
+        return;
+      } else if (kv.second != body) {
+        // if the node isn't the output of the group
+        auto node = matcher_->expr_graph_.node_map_.at(kv.first);
+        for (auto* output : node->outputs_) {
+          // and the node is used by nodes outside of the group
+          if (memo.count(output->ref_) == 0 &&
+              !matcher_->expr_graph_.node_map_.at(expr)->Dominates(output)) {
+            // Exit because nodes in this pattern's body are used outside the 
pattern
+            // fusing it would be invalid
+            return;
           }
         }
       }
     }
-    // Assign Group Ids
-    group.gid = ++gid_;
-    for (auto kv : extractor.GetMemo()) {
-      gid_assignments_[kv.first] = gid_;
-    }
+  }
+  // Assign Group Ids
+  group.gid = ++gid_;
+  for (auto kv : extractor.GetMemo()) {
+    gid_assignments_[kv.first] = gid_;
+  }
+
+  // Save Group
+  groups_[group.gid] = std::move(group);
+}
 
-    // Save Group
-    groups_[group.gid] = std::move(group);
-  }
-
-  /*! \brief EmbedConst implements rules for embedding constants into 
partitioned functions or
-   * lifting them into the function arguments.
-   *
-   * The rules depend on what pattern the ConstantNode matched.
-   *
-   * The basic rules are:
-   *  If the constant matches ExprPattern(relay.const(*)) or a 
ConstantPattern(), embed the constant
-   * in the partitioned function. If the constant matched an AltPattern, 
recursively check the
-   * matched side of the pattern. For any other matching pattern (i.e, 
wildcard, VarPattern, etc),
-   * lift the constant into the arguments of the partitioned function.
-   */
-  bool EmbedConst(const Expr& expr, const DFPattern pattern) {
-    bool embed = false;
-    if (expr.as<ConstantNode>()) {
-      if (pattern.as<ConstantPatternNode>() != nullptr) {
+bool PatternGrouper::EmbedConst(const Expr& expr, const DFPattern pattern) {
+  bool embed = false;
+  if (expr.as<ConstantNode>()) {
+    if (pattern.as<ConstantPatternNode>() != nullptr) {
+      embed = true;
+    } else if (auto expr_pat = pattern.as<ExprPatternNode>()) {
+      if (expr_pat->expr.as<ConstantNode>()) {
         embed = true;
-      } else if (auto expr_pat = pattern.as<ExprPatternNode>()) {
-        if (expr_pat->expr.as<ConstantNode>()) {
-          embed = true;
-        }
-      } else if (auto alt_pat = pattern.as<AltPatternNode>()) {
-        if (matcher_->Match(alt_pat->left, expr)) {
-          embed = EmbedConst(expr, alt_pat->left);
-        } else {
-          embed = EmbedConst(expr, alt_pat->right);
-        }
+      }
+    } else if (auto alt_pat = pattern.as<AltPatternNode>()) {
+      if (matcher_->Match(alt_pat->left, expr)) {
+        embed = EmbedConst(expr, alt_pat->left);
+      } else {
+        embed = EmbedConst(expr, alt_pat->right);
       }
     }
-    return embed;
   }
-  // Internal State
-  DFPattern pattern_;
-  std::unordered_map<int, Group> groups_;
-  std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> 
gid_assignments_;
-  DFPatternMatcher* matcher_ = nullptr;
-  IndexedGraph<DFPattern> pattern_graph_;
-  int gid_ = 0;
-  int graph_number_ = 0;
-};
+  return embed;
+}
 
 // Rewrite
 
@@ -858,72 +767,54 @@ 
TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DFPatternCallback")
       return DFPatternCallback(pattern, function, require_type);
     });
 
-/*!
- * \brief PatternRewriter rewrites the expression by finding matches and 
allowing user callback
- * function to rewrite those matches
- *
- * The class uses PatternGrouper to support the dominator pattern.
- */
-class PatternRewriter : protected MixedModeMutator {
- public:
-  PatternRewriter(IRModule mod) : mod_(mod) {}
-  /*! \brief Rewrite can take a number of callbacks and will repeatedly 
rewrite the graph with the
-   * callbacks until it stops changing */
-  Expr Rewrite(const Array<DFPatternCallback>& callbacks, const Expr& pre) {
-    auto post = pre;
-    auto last = post;
-    // rewrite the graph until it stops changing to make sure all rewrites are 
complete
-    int count = 0;
-    bool equal = true;
-    static auto* structural_equal = 
runtime::Registry::Get("node.StructuralEqual");
-    ICHECK(structural_equal) << "node.StructuralEqual is not registered.";
-    do {
-      last = post;
-      for (auto callback : callbacks) {
-        callback_ = callback;
-        if (callback_->require_type) {
-          post = InferTypeWithModule(post, mod_);
-        }
-        auto grouper = PatternGrouper();
-        groups_ = grouper.GroupMatches(callback_->pattern, post);
-        gid_assignments_ = grouper.GetGIDAssignments();
-        memo_.clear();
-        post = this->VisitExpr(post);
-        count++;
-      }
-      equal = (*structural_equal)(last, post, false, true);
-    } while (!equal && count < 100);
-    if (count >= 100) {
-      LOG(FATAL) << "Observed 100 rewrite passes, possible conflicting 
passes?";
+Expr PatternRewriter::Rewrite(const Array<DFPatternCallback>& callbacks, const 
Expr& pre) {
+  auto post = pre;
+  auto last = post;
+  // rewrite the graph until it stops changing to make sure all rewrites are 
complete
+  int count = 0;
+  bool equal = true;
+  static auto* structural_equal = 
runtime::Registry::Get("node.StructuralEqual");
+  ICHECK(structural_equal) << "node.StructuralEqual is not registered.";
+  do {
+    last = post;
+    for (auto callback : callbacks) {
+      callback_ = callback;
+      if (callback_->require_type) {
+        post = InferTypeWithModule(post, mod_);
+      }
+      auto grouper = PatternGrouper();
+      groups_ = grouper.GroupMatches(callback_->pattern, post);
+      gid_assignments_ = grouper.GetGIDAssignments();
+      memo_.clear();
+      post = this->VisitExpr(post);
+      count++;
     }
-    return post;
+    equal = (*structural_equal)(last, post, false, true);
+  } while (!equal && count < 100);
+  if (count >= 100) {
+    LOG(FATAL) << "Observed 100 rewrite passes, possible conflicting passes?";
   }
+  return post;
+}
 
- protected:
-  Expr DispatchVisitExpr(const Expr& pre) override {
-    auto post = MixedModeMutator::DispatchVisitExpr(pre);
-    if (gid_assignments_.count(pre) && pre == 
groups_[gid_assignments_[pre]].root_node) {
-      // Convert the pre-rewrite node map to a post-rewrite node map
-      auto group = groups_[gid_assignments_[pre]];
-      std::unordered_map<DFPattern, Array<Expr>, ObjectPtrHash, 
ObjectPtrEqual> node_map;
-      for (auto kv : group.matched_nodes) {
-        Array<Expr> tmp;
-        for (size_t i = 0; i < kv.second.size(); ++i) {
-          tmp.push_back(this->memo_[kv.second[i]]);
-        }
-        node_map.insert({kv.first, tmp});
+Expr PatternRewriter::DispatchVisitExpr(const Expr& pre) {
+  auto post = MixedModeMutator::DispatchVisitExpr(pre);
+  if (gid_assignments_.count(pre) && pre == 
groups_[gid_assignments_[pre]].root_node) {
+    // Convert the pre-rewrite node map to a post-rewrite node map
+    auto group = groups_[gid_assignments_[pre]];
+    std::unordered_map<DFPattern, Array<Expr>, ObjectPtrHash, ObjectPtrEqual> 
node_map;
+    for (auto kv : group.matched_nodes) {
+      Array<Expr> tmp;
+      for (size_t i = 0; i < kv.second.size(); ++i) {
+        tmp.push_back(this->memo_[kv.second[i]]);
       }
-      // run the user callback function
-      return callback_->function(pre, post, Map<DFPattern, 
Array<Expr>>(node_map));
+      node_map.insert({kv.first, tmp});
     }
-    return post;
+    // run the user callback function
+    return callback_->function(pre, post, Map<DFPattern, 
Array<Expr>>(node_map));
   }
-
-  IRModule mod_;
-  DFPatternCallback callback_;
-  std::unordered_map<int, PatternGrouper::Group> groups_;
-  std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> 
gid_assignments_;
-};
+  return post;
+}
 
 Expr RewritePatterns(Array<DFPatternCallback> callbacks, Expr expr, IRModule 
mod) {
   return PatternRewriter(mod).Rewrite(callbacks, expr);
diff --git a/src/relay/ir/dataflow_matcher_impl.h 
b/src/relay/ir/dataflow_matcher_impl.h
new file mode 100644
index 0000000..d993d47
--- /dev/null
+++ b/src/relay/ir/dataflow_matcher_impl.h
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/ir/dataflow_matcher_impl.h
+ * \brief The auxiliary data structure for dataflow matcher.
+ */
+#ifndef TVM_RELAY_IR_DATAFLOW_MATCHER_IMPL_H_
+#define TVM_RELAY_IR_DATAFLOW_MATCHER_IMPL_H_
+
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/dataflow_pattern.h>
+#include <tvm/relay/dataflow_pattern_functor.h>
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "indexed_graph.h"
+
+namespace tvm {
+namespace relay {
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const 
Expr&)> {
+ public:
+  explicit DFPatternMatcher(const Expr& root_expr) : 
expr_graph_(CreateIndexedGraph(root_expr)) {}
+  bool Match(const DFPattern& pattern, const Expr& expr);
+  Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, 
Array<Expr>>(memo_); }
+  const IndexedGraph<Expr> expr_graph_;
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) 
override;
+  bool VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) 
override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) 
override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr) 
override;
+  bool VisitDFPattern_(const IfPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const LetPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) 
override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) 
override;
+
+  void ClearMap(size_t watermark);
+  bool MatchesPath(const DominatorPatternNode* op, const Expr& expr);
+  bool DominatesParent(const DominatorPatternNode* op, const Expr& expr);
+
+  std::unordered_map<DFPattern, Array<Expr>, ObjectPtrHash, ObjectPtrEqual> 
memo_;
+  std::vector<DFPattern> matched_nodes_;
+  bool memoize_ = true;
+};
+
+/*!
+ * \brief PatternGrouper does pre-rewriting pattern matching and analysis
+ *
+ * This class creates a number of groups of matched expressions, ensures they 
don't overlap, and
+ * returns them to the caller for post-analysis rewriting.
+ *
+ * This is primarily needed to support the post-dominator analysis required 
for dominator pattern
+ * matching.
+ */
+class PatternGrouper {
+ public:
+  /*! \brief Internal Group class for storing analysis
+   *
+   * NOTE(review): the per-field notes below are inferred from field names and types alone;
+   * confirm them against the code that fills in a Group.
+   */
+  struct Group {
+    Expr root_node;  // presumably the expression the match is rooted at -- TODO confirm
+    int gid;  // integer id of this group; presumably its key in the int-keyed group map
+    Map<DFPattern, Array<Expr>> matched_nodes;  // pattern -> expressions; presumably its matches
+    std::string name;  // NOTE(review): naming scheme is not visible in this header
+    Function function;  // presumably the function built for the matched subgraph -- confirm
+    Array<Expr> args;  // presumably the arguments passed to `function` -- confirm
+  };
+
+  /*! \brief Return the group assignments of expressions
+   *
+   * \return A const reference to the internal expression-to-group-id map (gid_assignments_).
+   */
+  inline const std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual>& 
GetGIDAssignments() {
+    return gid_assignments_;
+  }
+  /*! \brief Group expressions that match the pattern
+   *
+   * \param pattern The pattern to look for.
+   * \param pre The expression to search; presumably the root of the graph -- TODO confirm.
+   * \return A const reference to a map from int to Group; presumably keyed by group id and
+   *         backed by groups_ -- confirm against the out-of-line definition.
+   */
+  const std::unordered_map<int, Group>& GroupMatches(const DFPattern& pattern, 
const Expr& pre);
+
+ protected:
+  /*! \brief Iteratively traverse the Expression in pre-order to find subgraphs
+   *
+   * If we traverse the graph in post-order, we can run into situations where a small subgraph
+   * will match the pattern. Due to options like AltPattern, a larger subgraph with more nodes
+   * later in the graph may also match the pattern. With post-order traversal, we mark the
+   * smaller subgraph as matched and fail to catch the larger subgraph. This problem is fixed
+   * by using pre-order traversal.
+   */
+  void VisitExprs();
+
+  /*! \brief Create a group based on a matched expression
+   *
+   * \param expr The matched expression to build the group from.
+   */
+  void CreateGroup(const Expr& expr);
+
+  /*! \brief EmbedConst implements rules for embedding constants into partitioned functions or
+   * lifting them into the function arguments.
+   *
+   * The rules depend on what pattern the ConstantNode matched.
+   *
+   * The basic rules are:
+   *  If the constant matches ExprPattern(relay.const(*)) or a ConstantPattern(), embed the
+   * constant in the partitioned function. If the constant matched an AltPattern, recursively
+   * check the matched side of the pattern. For any other matching pattern (e.g., wildcard,
+   * VarPattern, etc.), lift the constant into the arguments of the partitioned function.
+   *
+   * \param expr The expression being classified; presumably a constant -- TODO confirm.
+   * \param pattern The pattern that expr matched. It is taken by const value here.
+   * \return NOTE(review): presumably true means "embed" and false means "lift" -- confirm
+   *         against the definition.
+   */
+  bool EmbedConst(const Expr& expr, const DFPattern pattern);
+  // Internal State
+  DFPattern pattern_;  // presumably the pattern handed to GroupMatches -- confirm
+  std::unordered_map<int, Group> groups_;  // groups keyed by int (presumably Group::gid)
+  std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> 
gid_assignments_;  // expression -> group id; exposed read-only via GetGIDAssignments()
+  DFPatternMatcher* matcher_ = nullptr;  // raw pointer, null by default; ownership not shown here
+  IndexedGraph<DFPattern> pattern_graph_;  // presumably the indexed graph of pattern_ -- confirm
+  int gid_ = 0;  // starts at 0; presumably the running group-id counter -- confirm
+  int graph_number_ = 0;  // starts at 0; NOTE(review): purpose not visible in this header
+};
+
+/*!
+ * \brief PatternRewriter rewrites the expression by finding matches and 
allowing user callback
+ * function to rewrite those matches
+ *
+ * The class uses PatternGrouper to support the dominator pattern.
+ */
+class PatternRewriter : protected MixedModeMutator {
+ public:
+  explicit PatternRewriter(IRModule mod) : mod_(mod) {}  // initializes mod_ from mod
+  /*! \brief Rewrite can take a number of callbacks and will repeatedly rewrite the graph with
+   * the callbacks until it stops changing
+   *
+   * \param callbacks The callbacks to apply.
+   * \param pre The expression to rewrite.
+   * \return The rewritten expression.
+   */
+  virtual Expr Rewrite(const Array<DFPatternCallback>& callbacks, const Expr& 
pre);
+
+ protected:
+  // NOTE(review): DispatchVisitExpr presumably overrides the MixedModeMutator hook of the same
+  // name, but it is declared `virtual` rather than `override` -- confirm against the base class.
+  virtual Expr DispatchVisitExpr(const Expr& pre);
+
+  IRModule mod_;  // module supplied at construction
+  DFPatternCallback callback_;  // presumably the callback currently being applied -- confirm
+  std::unordered_map<int, PatternGrouper::Group> groups_;  // int -> Group; presumably by group id
+  std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> 
gid_assignments_;  // expression -> int; presumably the group id of each expression -- confirm
+};
+
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_IR_DATAFLOW_MATCHER_IMPL_H_

Reply via email to