This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 7de8a53 [RELAY] Non-recursive Graph Visitor and Rewriter (#4886)
7de8a53 is described below
commit 7de8a539b1e73627309308b49c6c69625efc4d5a
Author: Matthew Brookhart <[email protected]>
AuthorDate: Fri Apr 3 14:35:55 2020 -0700
[RELAY] Non-recursive Graph Visitor and Rewriter (#4886)
* First pass at defining a non-recursive Graph Visitor and Rewriter
autoformat
remove a currently empty test until testing is solidified
* Make CalcDep from Dead Code Elimination non-recursive
* Partially working, not passing all tests yet
passes tests when disabling GetExprRefCount, I think I have a bug in visit
counting
fix GetExprRefCount
Fix a subtle bug with nested recursive/non-recursive scopes
* Refactor
* improve comments
* respond to review comments on comments
* Fix a problem with default recursion for dataflow nodes
mark DataflowVisitor methods as override
* implement ScopeMutator
* convert forward_rewrite to ScopeMutator, remove DataflowMutator
* rewrite ExprRewriter and convert fast_math to use it
* switch BiasAddSimplifier to ExprRewriter
fix a clang warning
fix cpp lint
fix doc param error
* respond to review comments
* fix a typo in the iterative looping
* add a regression test for GetExprRefCount issue
* Normalize naming
* fix lint
* First pass at defining a non-recursive Graph Visitor and Rewriter
autoformat
remove a currently empty test until testing is solidified
* Make CalcDep from Dead Code Elimination non-recursive
* Partially working, not passing all tests yet
passes tests when disabling GetExprRefCount, I think I have a bug in visit
counting
fix GetExprRefCount
Fix a subtle bug with nested recursive/non-recursive scopes
* Refactor
* improve comments
* respond to review comments on comments
* Fix a problem with default recursion for dataflow nodes
mark DataflowVisitor methods as override
* implement ScopeMutator
* convert forward_rewrite to ScopeMutator, remove DataflowMutator
* rewrite ExprRewriter and convert fast_math to use it
* switch BiasAddSimplifier to ExprRewriter
fix a clang warning
fix cpp lint
fix doc param error
* respond to review comments
* fix a typo in the iterative looping
* add a regression test for GetExprRefCount issue
* Normalize naming
* fix lint
* respond to review comments
---
include/tvm/relay/analysis.h | 11 ++
include/tvm/relay/expr_functor.h | 183 +++++++++++++++++++++++++++++++
src/relay/analysis/util.cc | 2 +-
src/relay/ir/expr_functor.cc | 158 +++++++++++++++++++++++++-
src/relay/transforms/canonicalize_ops.cc | 9 +-
src/relay/transforms/dead_code.cc | 9 +-
src/relay/transforms/fast_math.cc | 18 +--
src/relay/transforms/forward_rewrite.cc | 66 +++++------
src/relay/transforms/pass_util.h | 8 --
tests/cpp/relay_build_module_test.cc | 17 +++
10 files changed, 416 insertions(+), 65 deletions(-)
diff --git a/include/tvm/relay/analysis.h b/include/tvm/relay/analysis.h
index e04b4e6..a2c0c75 100644
--- a/include/tvm/relay/analysis.h
+++ b/include/tvm/relay/analysis.h
@@ -30,6 +30,7 @@
#include <tvm/ir/module.h>
#include <tvm/relay/type.h>
#include <string>
+#include <unordered_map>
namespace tvm {
namespace relay {
@@ -225,6 +226,16 @@ TVM_DLL Map<Expr, Integer>
CollectDeviceAnnotationOps(const Expr& expr);
*/
TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const IRModule& mod);
+/*!
+ * \brief Get reference counter of each internal ExprNode in body.
+ *
+ * \param body The body expression.
+ *
+ * \return The reference count mapping.
+ */
+TVM_DLL std::unordered_map<const Object*, size_t>
+GetExprRefCount(const Expr& body);
+
} // namespace relay
} // namespace tvm
diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h
index d1c5ca1..6f8ac69 100644
--- a/include/tvm/relay/expr_functor.h
+++ b/include/tvm/relay/expr_functor.h
@@ -233,6 +233,189 @@ class ExprMutator
};
/*!
+ * \brief A wrapper around ExprVisitor which traverses the Dataflow Normal AST.
+ *
+ * MixedModeVisitor treats Expr as dataflow graph, and visits in post-DFS order
+ *
+ * MixedModeVisitor provides the same recursive API as ExprVisitor, and uses
+ * recursion to traverse most forms of the IR, but under the hood it expands
nested dataflow regions
+ * of the graph and processes them iteratatively to prevent stack overflows
+ */
+class MixedModeVisitor : public ::tvm::relay::ExprVisitor {
+ public:
+ /*! \brief The constructor of MixedModeVisitor
+ * \param visit_limit The number of times to allow visitation to a node.
Usually 1, ocassionally
+ * higher (i.e., 2 for dead code elimiation), limited to 10 as a sanity
check.
+ */
+ explicit MixedModeVisitor(int visit_limit = 1);
+
+ /*!
+ * \brief VisitExpr is finalized to preserve call expansion of dataflow
regions
+ */
+ void VisitExpr(const Expr& expr) final;
+ void VisitExpr_(const CallNode* op) override;
+ void VisitExpr_(const TupleNode* op) override;
+ void VisitExpr_(const TupleGetItemNode* op) override;
+
+ protected:
+ /*!
+ * \brief A function to apply when reaching a leaf of the graph
non-recursively
+ */
+ virtual void VisitLeaf(const Expr& expr);
+ /*!
+ * \brief A function to determine if an expression has already been visited
or needs to be
+ * re-visited
+ */
+ virtual bool CheckVisited(const Expr& expr);
+ /*!
+ * \brief The max number of times to visit a node
+ */
+ size_t visit_limit_;
+};
+
+/*! \brief Non-recursive DFS Graph Traversal for Custom Rewriting Passes
+ *
+ * MixedModeMutator treats Expr as dataflow graph, and only Rewrites each Expr
once.
+ * The mutated results are memoized in a map and reused so that
+ * local transformation on the dataflow preserves the graph structure.
+ *
+ * MixedModeMutator provides the same recursive API as ExprMutator, and uses
+ * recursion to traverse most forms of the IR, but under the hood it expands
nested dataflow regions
+ * of the graph and processes them iteratatively to prevent stack overflows
+ *
+ * Uses Rewrite_ API of ExprRewriter for a cleaner split between recrusive and
non-recursive behavior.
+ */
+class MixedModeMutator : public ::tvm::relay::ExprMutator {
+ public:
+ Expr VisitExpr(const Expr& expr) final;
+ virtual Expr DispatchVisitExpr(const Expr& expr);
+ Expr VisitExpr_(const TupleNode* op) final { return Rewrite(op); };
+ Expr VisitExpr_(const CallNode* call_node) final { return
Rewrite(call_node); };
+ Expr VisitExpr_(const TupleGetItemNode* op) final { return Rewrite(op); };
+ /*!
+ * \brief Users should override Rewrite_ methods to implement their pass.
Rewrite_ functions will be
+ * able to rewrite the op only with data about the original node `pre` and
the same node with
+ * modified inputs `post` and should not recurse.
+ *
+ * \param pre The expression node before rewriting.
+ * \param post The expression with rewritten inputs.
+ */
+ virtual Expr Rewrite_(const TupleNode* pre, const Expr& post) { return post;}
+ virtual Expr Rewrite_(const CallNode* pre, const Expr& post) { return post; }
+ virtual Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) {
return post; }
+
+ protected:
+ /*! \brief Implement Rewrite API by calling ExprMutator's VisitExpr_(op) to
get a `post` node with
+ * changed inputs.
+ */
+ template <typename T>
+ Expr Rewrite(const T* op) {
+ Expr post = ExprMutator::VisitExpr_(op);
+ return Rewrite_(op, post);
+ }
+
+ virtual void VisitLeaf(const Expr& expr);
+ virtual bool CheckVisited(const Expr& expr);
+};
+
+#define RELAY_EXPR_REWRITER_DISPATCH(OP)
\
+ vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, const
Expr& post) { \
+ return self->Rewrite_(static_cast<const OP*>(n.get()), post);
\
+ });
+
+#define EXPR_REWRITER_REWRITE_DEFAULT \
+ { return post; }
+
+/*! \brief A non-iterating Expression Rewriter
+ *
+ * ExprRewriter provides a Rewrite interface for modifying graphs in Post-DFS
order.
+ *
+ * The expectation is that ExprRewriter objects will be passed to
PostOrderRewrite, which will
+ * non-recursively unroll the graph and call Rewriting on inputs. It will then
pass the original
+ * node, called `pre`, and a node recreated with any alterned inputs, called
`post`, to the
+ * ExprRewriter. The ExprRewriter can then use the information in those two
nodes to do more complex
+ * graph rewriting.
+ */
+class ExprRewriter {
+ private:
+ using TSelf = ExprRewriter;
+ using FType = tvm::NodeFunctor<Expr(const ObjectRef& n, TSelf* self, const
Expr& post)>;
+
+ public:
+ /*! \brief virtual destructor */
+ virtual ~ExprRewriter() {}
+ /*!
+ * \brief Same as call.
+ * \param pre The expression node before rewriting.
+ * \param post The expression node with rewritten inputs.
+ * \return The result of the call
+ */
+ Expr operator()(const Expr& pre, const Expr& post) {
+ return Rewrite(pre, post);
+ }
+ /*!
+ * \brief The functor call.
+ * \param pre The expression node before rewriting.
+ * \param post The expression node with rewritten inputs.
+ * \return The result of the call
+ */
+ virtual Expr Rewrite(const Expr& pre, const Expr& post) {
+ CHECK(pre.defined());
+ static FType vtable = InitVTable();
+ return vtable(pre, this, post);
+ }
+ // Functions that can be overriden by subclass, should not recurse
+ virtual Expr Rewrite_(const VarNode* pre, const Expr& post)
EXPR_REWRITER_REWRITE_DEFAULT;
+ virtual Expr Rewrite_(const GlobalVarNode* pre, const Expr& post)
EXPR_REWRITER_REWRITE_DEFAULT;
+ virtual Expr Rewrite_(const ConstantNode* pre, const Expr& post)
EXPR_REWRITER_REWRITE_DEFAULT;
+ virtual Expr Rewrite_(const TupleNode* pre, const Expr& post)
EXPR_REWRITER_REWRITE_DEFAULT;
+ virtual Expr Rewrite_(const FunctionNode* pre, const Expr& post)
EXPR_REWRITER_REWRITE_DEFAULT;
+ virtual Expr Rewrite_(const CallNode* pre, const Expr& post)
EXPR_REWRITER_REWRITE_DEFAULT;
+ virtual Expr Rewrite_(const LetNode* pre, const Expr& post)
EXPR_REWRITER_REWRITE_DEFAULT;
+ virtual Expr Rewrite_(const IfNode* pre, const Expr& post)
EXPR_REWRITER_REWRITE_DEFAULT;
+ virtual Expr Rewrite_(const OpNode* pre, const Expr& post)
EXPR_REWRITER_REWRITE_DEFAULT;
+ virtual Expr Rewrite_(const TupleGetItemNode* pre,
+ const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
+ virtual Expr Rewrite_(const RefCreateNode* pre, const Expr& post)
EXPR_REWRITER_REWRITE_DEFAULT;
+ virtual Expr Rewrite_(const RefReadNode* pre, const Expr& post)
EXPR_REWRITER_REWRITE_DEFAULT;
+ virtual Expr Rewrite_(const RefWriteNode* pre, const Expr& post)
EXPR_REWRITER_REWRITE_DEFAULT;
+ virtual Expr Rewrite_(const ConstructorNode* pre, const Expr& post)
EXPR_REWRITER_REWRITE_DEFAULT;
+ virtual Expr Rewrite_(const MatchNode* pre, const Expr& post)
EXPR_REWRITER_REWRITE_DEFAULT;
+
+ private:
+ // initialize the vtable.
+ static FType InitVTable() {
+ FType vtable;
+ // Set dispatch
+ RELAY_EXPR_REWRITER_DISPATCH(ConstantNode);
+ RELAY_EXPR_REWRITER_DISPATCH(TupleNode);
+ RELAY_EXPR_REWRITER_DISPATCH(VarNode);
+ RELAY_EXPR_REWRITER_DISPATCH(GlobalVarNode);
+ RELAY_EXPR_REWRITER_DISPATCH(FunctionNode);
+ RELAY_EXPR_REWRITER_DISPATCH(CallNode);
+ RELAY_EXPR_REWRITER_DISPATCH(LetNode);
+ RELAY_EXPR_REWRITER_DISPATCH(IfNode);
+ RELAY_EXPR_REWRITER_DISPATCH(OpNode);
+ RELAY_EXPR_REWRITER_DISPATCH(TupleGetItemNode);
+ RELAY_EXPR_REWRITER_DISPATCH(RefCreateNode);
+ RELAY_EXPR_REWRITER_DISPATCH(RefReadNode);
+ RELAY_EXPR_REWRITER_DISPATCH(RefWriteNode);
+ RELAY_EXPR_REWRITER_DISPATCH(ConstructorNode);
+ RELAY_EXPR_REWRITER_DISPATCH(MatchNode);
+ return vtable;
+ }
+};
+
+/*! \brief Non-recursive DFS Graph Traversal for Custom Rewriting Passes
+ *
+ * PostOrderRewrite does a non-recursive traversal of the graph in Post-DFS
order and calls the
+ * ExprRewriter's Rewrite functions on nodes once their inputs are rewritten.
At each rewrite call,
+ * PostOrderRewrite provides the original node and the node with altered
inputs for use by the
+ * ExprRewriter.
+ */
+Expr PostOrderRewrite(const Expr& expr, ExprRewriter* rewriter);
+
+/*!
* \brief recursively visit the ir in post DFS order node, apply fvisit
* Each node is guaranteed to be visited only once.
* \param node The ir to be visited.
diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc
index 6132532..a86faeb 100644
--- a/src/relay/analysis/util.cc
+++ b/src/relay/analysis/util.cc
@@ -330,7 +330,7 @@ TVM_REGISTER_GLOBAL("relay.analysis.all_type_vars")
*/
std::unordered_map<const Object*, size_t>
GetExprRefCount(const Expr& body) {
- class ExprRefCounter : private ExprVisitor {
+ class ExprRefCounter : private MixedModeVisitor {
public:
std::unordered_map<const Object*, size_t>
Get(const Expr& body) {
diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc
index 11e85d5..cb5d06f 100644
--- a/src/relay/ir/expr_functor.cc
+++ b/src/relay/ir/expr_functor.cc
@@ -29,8 +29,162 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
+#include <stack>
+
namespace tvm {
namespace relay {
+/*!
+ * \brief A function to iteratively traverse dataflow regions of a graph
+ *
+ * ExpandDataflow manually manages a stack and performs DFS to determine the
processing
+ * order of nodes in an input graph.
+ *
+ * If it finds a dataflow node (Call, Tuple, TupleGetItem), it checks if the
arguments to that node
+ * need to be processed via fcheck_visited. If so, the function pushes those
arguments to the stack
+ * and continues iteratively to process the top of the stack. When it finds a
node that doesn't
+ * match the dataflow types, or a node who's inputs have all been processed,
it visits the current
+ * leaf via fvisit_leaf.
+ *
+ * This function should be used internally to other classes to implement
mixed-mode traversals. The
+ * expectation is that fvisit_leaf will perform recursive analysis within
mixed-mode traversal if it
+ * hits a non-dataflow node.
+ *
+ * fcheck_visited and fvisit_leaf are templated to encourage compiler inlining.
+ */
+template <typename FCheckVisited, typename FVisitLeaf>
+void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf
fvisit_leaf) {
+ std::stack<std::pair<Expr, bool>> stack;
+ auto fpush_to_stack = [&fcheck_visited, &stack](const Expr& expr) {
+ // The second state of the stack indicate whether the child has been
+ // expanded in the pre-order.
+ // NOTE: function will be inlined.
+ if (!fcheck_visited(expr)) {
+ stack.push({expr, false});
+ }
+ };
+ fpush_to_stack(expr);
+ while (stack.size() > 0) {
+ auto node = stack.top().first;
+ if (fcheck_visited(node)) {
+ // if this node was visited through another path
+ // after being added to the stack ignore it.
+ stack.pop();
+ } else if (stack.top().second) {
+ // all the children have already been expanded.
+ // we can just run post order visit on it.
+ fvisit_leaf(node);
+ stack.pop();
+ } else if (const CallNode* op = node.as<CallNode>()) {
+ // mark expanded = true
+ stack.top().second = true;
+ // push the children to the stack in reverse order
+ // to match recursive processing order
+ for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
+ fpush_to_stack(*it);
+ }
+ fpush_to_stack(op->op);
+ } else if (const TupleNode* op = node.as<TupleNode>()) {
+ stack.top().second = true;
+ // push the children to the stack in reverse order
+ // to match recursive processing order
+ for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
+ fpush_to_stack(*it);
+ }
+ } else if (const TupleGetItemNode* op = node.as<TupleGetItemNode>()) {
+ stack.top().second = true;
+ fpush_to_stack(op->tuple);
+ } else {
+ // No need to expand the children directly run visit.
+ fvisit_leaf(node);
+ stack.pop();
+ }
+ }
+}
+
+MixedModeVisitor::MixedModeVisitor(int visit_limit) {
+ CHECK(visit_limit > 0) << "Dataflow visit limit must be greater than 0";
+ CHECK(visit_limit < 10) << "Dataflow visit limit must be less than 10";
+ visit_limit_ = visit_limit;
+}
+
+void MixedModeVisitor::VisitLeaf(const Expr& expr) {
+ if (visit_counter_[expr.get()] < visit_limit_) {
+ ExprFunctor::VisitExpr(expr);
+ }
+ visit_counter_[expr.get()]++;
+}
+
+bool MixedModeVisitor::CheckVisited(const Expr& expr) {
+ if (visit_counter_[expr.get()] < visit_limit_) {
+ return false;
+ } else {
+ visit_counter_[expr.get()]++;
+ return true;
+ }
+}
+
+void MixedModeVisitor::VisitExpr(const Expr& expr) {
+ auto fcheck_visited = [this](const Expr& expr) { return
this->CheckVisited(expr); };
+ auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr);
};
+ if (visit_counter_[expr.get()] < visit_limit_) {
+ ExpandDataflow(expr, fcheck_visited, fvisit_leaf);
+ }
+}
+
+// Overwrite the VisitExpr so we don't recurse for dataflow nodes
+void MixedModeVisitor::VisitExpr_(const CallNode* op) {}
+
+// Overwrite the VisitExpr so we don't recurse for dataflow nodes
+void MixedModeVisitor::VisitExpr_(const TupleNode* op) {}
+
+// Overwrite the VisitExpr so we don't recurse for dataflow nodes
+void MixedModeVisitor::VisitExpr_(const TupleGetItemNode* op) {}
+
+void MixedModeMutator::VisitLeaf(const Expr& expr) {
+ if (!memo_.count(expr)) {
+ this->DispatchVisitExpr(expr);
+ }
+}
+
+bool MixedModeMutator::CheckVisited(const Expr& expr) {
+ if (memo_.count(expr)) {
+ return true;
+ } else {
+ return false;
+ }
+}
+
+Expr MixedModeMutator::DispatchVisitExpr(const Expr& expr) {
+ return ExprMutator::VisitExpr(expr);
+}
+
+Expr MixedModeMutator::VisitExpr(const Expr& expr) {
+ auto fcheck_visited = [this](const Expr& expr) { return
this->CheckVisited(expr); };
+ auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr);
};
+ if (memo_.count(expr)) {
+ return memo_[expr];
+ } else {
+ ExpandDataflow(expr, fcheck_visited, fvisit_leaf);
+ Expr ret = this->DispatchVisitExpr(expr);
+ memo_[expr] = ret;
+ return ret;
+ }
+}
+
+class PostOrderRewriter : public MixedModeMutator {
+ public:
+ explicit PostOrderRewriter(ExprRewriter* rewriter) : rewriter_(rewriter) {}
+ Expr DispatchVisitExpr(const Expr& expr) final {
+ auto post = ExprFunctor::VisitExpr(expr);
+ return rewriter_->Rewrite(expr, post);
+ }
+ protected:
+ ExprRewriter* rewriter_;
+};
+
+Expr PostOrderRewrite(const Expr& expr, ExprRewriter* rewriter) {
+ return PostOrderRewriter(rewriter).VisitExpr(expr);
+}
Expr ExprMutator::VisitExpr(const Expr& expr) {
auto it = this->memo_.find(expr);
@@ -211,12 +365,12 @@ Expr ExprMutator::VisitExpr_(const MatchNode* m) {
for (const Clause& p : m->clauses) {
clauses.push_back(VisitClause(p));
}
- return Match(VisitExpr(m->data), clauses, m->complete);
+ return Match(Mutate(m->data), clauses, m->complete);
}
Clause ExprMutator::VisitClause(const Clause& c) {
Pattern p = VisitPattern(c->lhs);
- return Clause(p, VisitExpr(c->rhs));
+ return Clause(p, Mutate(c->rhs));
}
Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; }
diff --git a/src/relay/transforms/canonicalize_ops.cc
b/src/relay/transforms/canonicalize_ops.cc
index bcb7f9d..97a128d 100644
--- a/src/relay/transforms/canonicalize_ops.cc
+++ b/src/relay/transforms/canonicalize_ops.cc
@@ -32,12 +32,12 @@
namespace tvm {
namespace relay {
-class BiasAddSimplifier : public ExprMutator {
+class BiasAddSimplifier : public ExprRewriter {
public:
BiasAddSimplifier() : bias_add_op_(Op::Get("nn.bias_add")) {}
- Expr VisitExpr_(const CallNode* n) {
- auto new_n = ExprMutator::VisitExpr_(n);
+ Expr Rewrite_(const CallNode* n, const Expr& post) override {
+ auto new_n = post;
if (n->op == bias_add_op_) {
Call call = Downcast<Call>(new_n);
CHECK_EQ(call->args.size(), 2);
@@ -63,7 +63,8 @@ class BiasAddSimplifier : public ExprMutator {
};
Expr CanonicalizeOps(const Expr& e) {
- return BiasAddSimplifier().Mutate(e);
+ auto rewriter = BiasAddSimplifier();
+ return PostOrderRewrite(e, &rewriter);
}
namespace transform {
diff --git a/src/relay/transforms/dead_code.cc
b/src/relay/transforms/dead_code.cc
index f4058b2..a0d093f 100644
--- a/src/relay/transforms/dead_code.cc
+++ b/src/relay/transforms/dead_code.cc
@@ -92,7 +92,7 @@ class Eliminator : private ExprMutator {
};
// calculate the dependency graph from expression
-class CalcDep : private ExprVisitor {
+class CalcDep : protected MixedModeVisitor {
public:
static Expr Eliminate(const Expr& e, bool inline_once) {
FindDef fd;
@@ -104,11 +104,14 @@ class CalcDep : private ExprVisitor {
}
private:
- explicit CalcDep(const VarMap<Expr>& expr_map) : expr_map_(expr_map) { }
+ explicit CalcDep(const VarMap<Expr>& expr_map)
+ : MixedModeVisitor(2), expr_map_(expr_map) {}
VarMap<Expr> expr_map_;
VarMap<size_t> use_map_;
- void VisitExpr(const Expr& e) final {
+ using MixedModeVisitor::VisitExpr_;
+
+ void VisitLeaf(const Expr& e) final {
visit_counter_[e.get()]++;
// The dce code seprate variable into three parts:
// used 0 times (remove)
diff --git a/src/relay/transforms/fast_math.cc
b/src/relay/transforms/fast_math.cc
index 898f760..861566f 100644
--- a/src/relay/transforms/fast_math.cc
+++ b/src/relay/transforms/fast_math.cc
@@ -31,20 +31,19 @@
namespace tvm {
namespace relay {
-class FastMathMutator : public ExprMutator {
+class FastMathMutator : public ExprRewriter {
public:
FastMathMutator()
: exp_op_(Op::Get("exp")),
tanh_op_(Op::Get("tanh")) {}
- Expr VisitExpr_(const CallNode* n) {
- auto new_n = ExprMutator::VisitExpr_(n);
- if (n->op == exp_op_) {
- return FastExp(new_n.as<CallNode>()->args[0]);
- } else if (n->op == tanh_op_) {
- return FastTanh(new_n.as<CallNode>()->args[0]);
+ Expr Rewrite_(const CallNode* pre, const Expr& post) override {
+ if (pre->op == exp_op_) {
+ return FastExp(post.as<CallNode>()->args[0]);
+ } else if (pre->op == tanh_op_) {
+ return FastTanh(post.as<CallNode>()->args[0]);
}
- return new_n;
+ return post;
}
private:
@@ -56,7 +55,8 @@ class FastMathMutator : public ExprMutator {
};
Expr FastMath(const Expr& e) {
- return FastMathMutator().Mutate(e);
+ auto rewriter = FastMathMutator();
+ return PostOrderRewrite(e, &rewriter);
}
namespace transform {
diff --git a/src/relay/transforms/forward_rewrite.cc
b/src/relay/transforms/forward_rewrite.cc
index 1d9d2b6..f01c4fa 100644
--- a/src/relay/transforms/forward_rewrite.cc
+++ b/src/relay/transforms/forward_rewrite.cc
@@ -22,6 +22,7 @@
* \file forward_rewrite.cc
* \brief Apply rewriting rules in a forward fashion.
*/
+#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
@@ -33,32 +34,25 @@ namespace relay {
// Realizer class that realizes the expression
// Note that we can take benefit of its internal memo
// so that calling realize repeatively won't hurt perf.
-class TempRealizer : private ExprMutator {
+class TempRealizer : private MixedModeMutator {
public:
Expr Realize(Expr expr) {
- return VisitExpr(expr);
+ return Mutate(expr);
}
private:
- Expr VisitExpr(const Expr& expr) final {
- auto it = memo_.find(expr);
- if (it != memo_.end()) {
- return it->second;
+ Expr DispatchVisitExpr(const Expr& expr) final {
+ Expr res;
+ if (const auto* temp = expr.as<TempExprNode>()) {
+ res = temp->Realize();
} else {
- Expr res;
- if (const auto* temp = expr.as<TempExprNode>()) {
- res = temp->Realize();
-
- } else {
- res = ExprFunctor::VisitExpr(expr);
- }
- memo_[res] = res;
- return res;
+ res = MixedModeMutator::DispatchVisitExpr(expr);
}
+ return res;
}
};
-class ForwardRewriter : private ExprMutator {
+class ForwardRewriter : private MixedModeMutator {
public:
ForwardRewriter(const OpMap<FForwardRewrite>* rewrite_map,
std::function<ObjectRef(const Call&)> fcontext,
@@ -76,11 +70,11 @@ class ForwardRewriter : private ExprMutator {
// Transform expression.
- Expr Rewrite(Expr expr) {
+ Expr Rewrite(const Expr& expr) {
if (fmulti_ref_trigger_ != nullptr) {
ref_counter_ = GetExprRefCount(expr);
}
- return this->VisitExpr(expr);
+ return realizer_.Realize(this->VisitExpr(expr));
}
private:
@@ -96,15 +90,10 @@ class ForwardRewriter : private ExprMutator {
// internal realizer
TempRealizer realizer_;
- Expr VisitExpr(const Expr& expr) final {
- // by default always realize.
- return realizer_.Realize(ExprMutator::VisitExpr(expr));
- }
-
// Visit and allow non-realized version.
- Expr GetTempExpr(const Expr& expr) {
+ Expr GetTempExpr(const Expr& expr, const Expr& post) {
if (fmulti_ref_trigger_ != nullptr) {
- Expr ret = ExprMutator::VisitExpr(expr);
+ Expr ret = post;
auto it = ref_counter_.find(expr.get());
CHECK(it != ref_counter_.end());
if (it->second > 1) {
@@ -112,13 +101,13 @@ class ForwardRewriter : private ExprMutator {
}
return ret;
} else {
- return ExprMutator::VisitExpr(expr);
+ return post;
}
}
// Automatic fold TupleGetItem.
- Expr VisitExpr_(const TupleGetItemNode* op) final {
- Expr tuple = this->GetTempExpr(op->tuple);
+ Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final {
+ Expr tuple = this->GetTempExpr(op->tuple,
post.as<TupleGetItemNode>()->tuple);
if (const auto* ptuple = tuple.as<TupleNode>()) {
return ptuple->fields[op->index];
} else {
@@ -130,13 +119,14 @@ class ForwardRewriter : private ExprMutator {
}
}
- Expr VisitExpr_(const TupleNode* op) final {
+ Expr Rewrite_(const TupleNode* op, const Expr& post) final {
tvm::Array<Expr> fields;
bool all_fields_unchanged = true;
- for (auto field : op->fields) {
- auto new_field = this->GetTempExpr(field);
+ const auto* post_node = post.as<TupleNode>();
+ for (size_t i = 0; i < op->fields.size(); ++i) {
+ auto new_field = this->GetTempExpr(op->fields[i], post_node->fields[i]);
fields.push_back(new_field);
- all_fields_unchanged &= new_field.same_as(field);
+ all_fields_unchanged &= new_field.same_as(op->fields[i]);
}
if (all_fields_unchanged) {
@@ -146,7 +136,7 @@ class ForwardRewriter : private ExprMutator {
}
}
- Expr VisitExpr_(const CallNode* call_node) final {
+ Expr Rewrite_(const CallNode* call_node, const Expr& post) final {
const Call& ref_call = GetRef<Call>(call_node);
PackedFunc frewrite;
if (rewrite_func_) {
@@ -155,17 +145,17 @@ class ForwardRewriter : private ExprMutator {
CHECK(rewrite_map_);
frewrite = rewrite_map_->get(call_node->op, nullptr);
}
-
- auto new_op = this->Mutate(call_node->op);
+ const auto* post_node = post.as<CallNode>();
+ auto new_op = post_node->op;
bool unchanged = call_node->op.same_as(new_op);
Array<Expr> call_args;
- for (auto arg : call_node->args) {
- Expr new_arg = this->GetTempExpr(arg);
+ for (size_t i = 0; i < call_node->args.size(); ++i) {
+ Expr new_arg = this->GetTempExpr(call_node->args[i], post_node->args[i]);
if (frewrite == nullptr) {
new_arg = realizer_.Realize(new_arg);
}
- unchanged &= new_arg.same_as(arg);
+ unchanged &= new_arg.same_as(call_node->args[i]);
call_args.push_back(new_arg);
}
// try to rewrite.
diff --git a/src/relay/transforms/pass_util.h b/src/relay/transforms/pass_util.h
index 6a69cf9..56b0645 100644
--- a/src/relay/transforms/pass_util.h
+++ b/src/relay/transforms/pass_util.h
@@ -35,14 +35,6 @@ namespace tvm {
namespace relay {
/*!
- * \brief Get reference counter of each internal ExprNode in body.
- * \param body The body expression.
- * \return The reference count mapping.
- */
-std::unordered_map<const Object*, size_t>
-GetExprRefCount(const Expr& body);
-
-/*!
* \brief Check if expr is positive constant.
* \param expr The expression to be checked.
* \return Whether all elements of expr is positive constant.
diff --git a/tests/cpp/relay_build_module_test.cc
b/tests/cpp/relay_build_module_test.cc
index fa94271..f5658fb 100644
--- a/tests/cpp/relay_build_module_test.cc
+++ b/tests/cpp/relay_build_module_test.cc
@@ -161,6 +161,23 @@ TEST(Relay, BuildModule) {
}
}
+TEST(Relay, GetExprRefCount) {
+ auto tensor_type = relay::TensorType({2, 3}, DataType::Float(32));
+ auto a = relay::Var("a", tensor_type);
+ auto add_op = relay::Op::Get("add");
+ auto relu_op = relay::Op::Get("nn.relu");
+ auto x = relay::Call(relu_op, {a}, tvm::Attrs(), {});
+ auto y = relay::Call(relu_op, {x}, tvm::Attrs(), {});
+ auto z = relay::Call(add_op, {y, x}, tvm::Attrs(), {});
+ auto ref_count = GetExprRefCount(z);
+ CHECK(ref_count[a.get()] == 1);
+ CHECK(ref_count[relu_op.get()] == 2);
+ CHECK(ref_count[add_op.get()] == 1);
+ CHECK(ref_count[x.get()] == 2);
+ CHECK(ref_count[y.get()] == 1);
+ CHECK(ref_count[z.get()] == 1);
+}
+
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";