tqchen commented on a change in pull request #4886: [WIP][POC] First pass at defining a non-recursive Graph Visitor and Rewriter
URL: https://github.com/apache/incubator-tvm/pull/4886#discussion_r401215336
##########
File path: include/tvm/relay/expr_functor.h
##########
@@ -232,6 +232,181 @@ class ExprMutator
std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual> memo_;
};
+/*!
+ * \brief A wrapper around ExprVisitor which traverses the Dataflow Normal AST.
+ *
+ * DataflowVisitor treats Expr as a dataflow graph and visits it in post-DFS order.
+ *
+ * DataflowVisitor provides the same recursive API as ExprVisitor, and uses
+ * recursion to traverse most forms of the IR, but under the hood it expands nested dataflow regions
+ * of the graph and processes them iteratively to prevent stack overflows.
+ */
+class DataflowVisitor : public ::tvm::relay::ExprVisitor {
+ public:
+ /*!
+ * \brief Construct a DataflowVisitor.
+ * \param visit_limit The maximum number of times a node may be visited; defaults to 1.
+ * NOTE(review): presumably stored into visit_limit_ (a size_t) — confirm against the
+ * out-of-line definition and consider making this parameter size_t to match.
+ * Marked explicit so an int is never implicitly converted into a visitor.
+ */
+ explicit DataflowVisitor(int visit_limit = 1);
+
+ /*!
+ * \brief VisitExpr is finalized to preserve call expansion of dataflow regions.
+ * \param expr The expression to visit.
+ */
+ void VisitExpr(const Expr& expr) final;
+ /*!
+ * \brief Visit a Call node; subclasses may override this.
+ * \param op The Call node being visited.
+ */
+ void VisitExpr_(const CallNode* op) override;
+ /*!
+ * \brief Visit a Tuple node; subclasses may override this.
+ * \param op The Tuple node being visited.
+ */
+ void VisitExpr_(const TupleNode* op) override;
+ /*!
+ * \brief Visit a TupleGetItem node; subclasses may override this.
+ * \param op The TupleGetItem node being visited.
+ */
+ void VisitExpr_(const TupleGetItemNode* op) override;
+
+
+ protected:
+ /*!
+ * \brief A function to apply when reaching a leaf of the graph non-recursively.
+ * \param expr The leaf expression that was reached.
+ */
+ virtual void VisitLeaf(const Expr& expr);
+ /*!
+ * \brief A function to determine if an expression has already been visited or needs to be
+ * re-visited.
+ * \param expr The expression to check.
+ * \return Whether expr counts as already visited.
+ * NOTE(review): confirm the polarity (true presumably means "skip this node").
+ */
+ virtual bool CheckVisited(const Expr& expr);
+ /*!
+ * \brief The max number of times to visit a node.
+ * NOTE(review): declared size_t while the constructor takes int — keep the types consistent.
+ */
+ size_t visit_limit_;
+};
+
+/*! \brief Non-recursive DFS graph traversal for custom rewriting passes.
+ *
+ * ScopeMutator treats Expr as a dataflow graph and rewrites each Expr only once.
+ * The mutated results are memoized in a map and reused so that
+ * local transformations on the dataflow preserve the graph structure.
+ *
+ * ScopeMutator provides the same recursive API as ExprMutator, and uses
+ * recursion to traverse most forms of the IR, but under the hood it expands nested dataflow regions
+ * of the graph and processes them iteratively to prevent stack overflows.
+ *
+ * Uses the Rewrite_ API of ExprRewriter for a cleaner split between recursive and non-recursive behavior.
+ */
+class ScopeMutator : public ::tvm::relay::ExprMutator {
+ public:
+ Expr Mutate(const Expr& expr) final;
+ Expr VisitExpr_(const TupleNode* op) final { return Rewrite(op); };
+ Expr VisitExpr_(const CallNode* call_node) final { return
Rewrite(call_node); };
+ Expr VisitExpr_(const TupleGetItemNode* op) final { return Rewrite(op); };
+ /*!
+ * Users should override Rewrite_ methods to implement their pass. Rewrite_
functions will be
+ * able to rewrite the op only with data about the original node `pre` and
the same node with
Review comment:
document all the arguments
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services