This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e997185 [Relay] Change some passes to mix mode (#6695)
e997185 is described below
commit e997185795480d24075a2e7d3fa42ccec425b5f6
Author: lixiaoquan <[email protected]>
AuthorDate: Fri Oct 16 23:47:27 2020 +0800
[Relay] Change some passes to mix mode (#6695)
---
src/relay/analysis/util.cc | 8 ++++++--
src/relay/analysis/well_formed.cc | 16 +++++++---------
src/relay/ir/expr_functor.cc | 4 +++-
src/relay/transforms/de_duplicate.cc | 6 ++++--
src/relay/transforms/fold_constant.cc | 32 ++++++++++++++++----------------
5 files changed, 36 insertions(+), 30 deletions(-)
diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc
index 59ce01c..edf8fb6 100644
--- a/src/relay/analysis/util.cc
+++ b/src/relay/analysis/util.cc
@@ -71,7 +71,7 @@ class TypeVarTVisitor : public TypeVisitor {
InsertionSet<TypeVar>* bound_type_vars_;
};
-class TypeVarEVisitor : private ExprVisitor {
+class TypeVarEVisitor : private MixedModeVisitor {
public:
explicit TypeVarEVisitor(const IRModule& mod) : mod_(mod) {}
@@ -131,6 +131,8 @@ class TypeVarEVisitor : private ExprVisitor {
return CollectAll();
}
+ using MixedModeVisitor::VisitExpr_;
+
void VisitExpr_(const FunctionNode* f) final {
for (const auto& tp : f->type_params) {
type_vars_.Insert(tp);
@@ -159,7 +161,7 @@ class TypeVarEVisitor : private ExprVisitor {
const IRModule& mod_;
};
-class VarVisitor : protected ExprVisitor, protected PatternVisitor {
+class VarVisitor : protected MixedModeVisitor, protected PatternVisitor {
public:
Array<Var> Free(const Expr& expr) {
this->VisitExpr(expr);
@@ -204,6 +206,8 @@ class VarVisitor : protected ExprVisitor, protected PatternVisitor {
vars_.Insert(v);
}
+ using MixedModeVisitor::VisitExpr_;
+
void VisitExpr_(const VarNode* var) final { vars_.Insert(GetRef<Var>(var)); }
void VisitExpr_(const FunctionNode* op) final {
diff --git a/src/relay/analysis/well_formed.cc b/src/relay/analysis/well_formed.cc
index 3e409d1..5abbbc9 100644
--- a/src/relay/analysis/well_formed.cc
+++ b/src/relay/analysis/well_formed.cc
@@ -32,7 +32,7 @@ namespace tvm {
namespace relay {
//! brief make sure each Var is bound at most once in a scope.
-class WellFormedChecker : private ExprVisitor, PatternVisitor {
+class WellFormedChecker : private MixedModeVisitor, PatternVisitor {
public:
Optional<DiagnosticContext> diag_ctx;
Span occurs_in;
@@ -79,6 +79,8 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor {
total_bound.insert(v);
}
+ using MixedModeVisitor::VisitExpr_;
+
void VisitExpr_(const VarNode* op) final {
Var v = GetRef<Var>(op);
if (current_bound.count(v) == 0) {
@@ -126,7 +128,7 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor {
// CHECK(call->attrs.defined());
CHECK(call->type_args.defined());
- ExprVisitor::VisitExpr_(call);
+ MixedModeVisitor::VisitExpr_(call);
}
void VisitClause(const Clause& c) final {
@@ -139,18 +141,14 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor {
void VisitVar(const Var& v) final { Bound(v); }
- void VisitExpr(const Expr& e) final {
+ public:
+ bool CheckWellFormed(const Expr& e) {
if (auto v = e.as<VarNode>()) {
VisitExpr_(v);
} else {
// this->occurs_in = e->span;
- ExprVisitor::VisitExpr(e);
+ VisitExpr(e);
}
- }
-
- public:
- bool CheckWellFormed(const Expr& e) {
- this->VisitExpr(e);
return well_formed;
}
};
diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc
index cbc41d2..a09179b 100644
--- a/src/relay/ir/expr_functor.cc
+++ b/src/relay/ir/expr_functor.cc
@@ -517,10 +517,12 @@ TVM_REGISTER_GLOBAL("relay.analysis.post_order_visit").set_body_typed([](Expr ex
});
// Implement bind.
-class ExprBinder : public ExprMutator, PatternMutator {
+class ExprBinder : public MixedModeMutator, PatternMutator {
public:
explicit ExprBinder(const tvm::Map<Var, Expr>& args_map) : args_map_(args_map) {}
+ using MixedModeMutator::VisitExpr_;
+
Expr VisitExpr_(const LetNode* op) final {
CHECK(!args_map_.count(op->var)) << "Cannot bind an internel variable in let";
return ExprMutator::VisitExpr_(op);
diff --git a/src/relay/transforms/de_duplicate.cc b/src/relay/transforms/de_duplicate.cc
index d90e5c5..8c62fe6 100644
--- a/src/relay/transforms/de_duplicate.cc
+++ b/src/relay/transforms/de_duplicate.cc
@@ -31,7 +31,7 @@ namespace tvm {
namespace relay {
Expr DeDup(const Expr& e) {
- class DeDupMutator : public TypeMutator, public ExprMutator, public PatternMutator {
+ class DeDupMutator : public TypeMutator, public MixedModeMutator, public PatternMutator {
public:
TypeVar Fresh(const TypeVar& tv) {
TypeVar ret = TypeVar(tv->name_hint, tv->kind);
@@ -47,12 +47,14 @@ Expr DeDup(const Expr& e) {
return ret;
}
- Expr VisitExpr(const Expr& e) final {
+ Expr DispatchVisitExpr(const Expr& e) final {
auto ret = ExprMutator::VisitExpr(e);
ret->checked_type_ = e->checked_type_;
return ret;
}
+ using MixedModeMutator::VisitExpr_;
+
Expr VisitExpr_(const VarNode* op) final {
Var v = GetRef<Var>(op);
return rename_.count(v) != 0 ? rename_.at(v) : v;
diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc
index 660aff2..8d2cba0 100644
--- a/src/relay/transforms/fold_constant.cc
+++ b/src/relay/transforms/fold_constant.cc
@@ -75,7 +75,7 @@ TVM_REGISTER_GLOBAL("relay.analysis.check_constant").set_body_typed(ConstantChec
// TODO(tvm-team) consider combine dead-code with constant folder.
// or make a more powerful partial evaluator.
-class ConstantFolder : public ExprMutator {
+class ConstantFolder : public MixedModeMutator {
public:
explicit ConstantFolder(IRModule module)
: module_(module),
@@ -89,6 +89,8 @@ class ConstantFolder : public ExprMutator {
cast_op_(Op::Get("cast")),
ndarray_size_op_(Op::Get("ndarray_size")) {}
+ using MixedModeMutator::VisitExpr_;
+
Expr VisitExpr_(const LetNode* op) final {
Expr value = this->Mutate(op->value);
if (value.as<ConstantNode>()) {
@@ -118,7 +120,7 @@ class ConstantFolder : public ExprMutator {
}
}
- Expr VisitExpr_(const CallNode* call) final {
+ Expr Rewrite_(const CallNode* call, const Expr& post) final {
if (inside_primitive) {
return GetRef<Expr>(call);
}
@@ -127,26 +129,25 @@ class ConstantFolder : public ExprMutator {
std::unordered_set<std::string> skip_list{"zeros_like", "ones_like",
"full_like", "full"};
auto origin_args = call->args;
- Expr res = ExprMutator::VisitExpr_(call);
- call = res.as<CallNode>();
+ call = post.as<CallNode>();
// We don't constant fold function with zero arguments.
// This is a heuristic that is useful.
// For example it is harmful to fold ones(shape=(4, 5)).
- if (call->args.size() == 0) return res;
+ if (call->args.size() == 0) return post;
const OpNode* op = call->op.as<OpNode>();
- if (op == nullptr) return res;
+ if (op == nullptr) return post;
if (skip_list.count(op->name)) {
- return res;
+ return post;
}
// skip stateful ops.
- if (op_stateful.get(GetRef<Op>(op), false)) return res;
+ if (op_stateful.get(GetRef<Op>(op), false)) return post;
// Try to evaluate shape_of op
if (call->op == shape_of_op_ || call->op == vm_shape_of_op_) {
- return EvaluateShapeOf(res, origin_args, call->attrs);
+ return EvaluateShapeOf(post, origin_args, call->attrs);
}
if (call->op == ndarray_size_op_) {
- return EvaluateNdarraySize(res, origin_args, call->attrs);
+ return EvaluateNdarraySize(post, origin_args, call->attrs);
}
// We should think about potentially constant evaluation over these ops too.
@@ -162,19 +163,18 @@ class ConstantFolder : public ExprMutator {
}
}
if (all_const_args) {
- return ConstEvaluate(res);
+ return ConstEvaluate(post);
} else {
- return res;
+ return post;
}
}
- Expr VisitExpr_(const TupleGetItemNode* op) final {
- Expr res = ExprMutator::VisitExpr_(op);
- op = res.as<TupleGetItemNode>();
+ Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final {
+ op = post.as<TupleGetItemNode>();
if (const auto* tuple = op->tuple.as<TupleNode>()) {
return tuple->fields[op->index];
} else {
- return res;
+ return post;
}
}