This is an automated email from the ASF dual-hosted git repository.
zhic pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 24c2f5c make simplify inference iterative (#8246)
24c2f5c is described below
commit 24c2f5c1a893d7b1a42301a7ad671fbe6788fc94
Author: Matthew Brookhart <[email protected]>
AuthorDate: Mon Jun 14 17:28:48 2021 -0600
make simplify inference iterative (#8246)
---
src/relay/transforms/simplify_inference.cc | 8 +++-----
1 file changed, 3 insertions(+), 5 deletions(-)
diff --git a/src/relay/transforms/simplify_inference.cc b/src/relay/transforms/simplify_inference.cc
index 7e58766..846bc08 100644
--- a/src/relay/transforms/simplify_inference.cc
+++ b/src/relay/transforms/simplify_inference.cc
@@ -178,7 +178,7 @@ Expr L2NormToInferUnpack(const Attrs attrs, Expr data) {
return Divide(data, sqrt);
}
-class InferenceSimplifier : public ExprMutator {
+class InferenceSimplifier : public MixedModeMutator {
public:
InferenceSimplifier()
: batch_norm_op_(Op::Get("nn.batch_norm")),
@@ -188,8 +188,7 @@ class InferenceSimplifier : public ExprMutator {
group_norm_op_(Op::Get("nn.group_norm")),
l2_norm_op_(Op::Get("nn.l2_normalize")) {}
- Expr VisitExpr_(const TupleGetItemNode* n) final {
- Expr new_e = ExprMutator::VisitExpr_(n);
+ Expr Rewrite_(const TupleGetItemNode* n, const Expr& new_e) final {
const auto* new_n = new_e.as<TupleGetItemNode>();
if (new_n->index != 0) {
return new_e;
@@ -205,8 +204,7 @@ class InferenceSimplifier : public ExprMutator {
return new_e;
}
- Expr VisitExpr_(const CallNode* n) {
- auto new_n = ExprMutator::VisitExpr_(n);
+ Expr Rewrite_(const CallNode* n, const Expr& new_n) {
if (n->op == batch_norm_op_) {
ty_map_[new_n.as<CallNode>()->args[0]] = n->args[0]->checked_type();
} else if (n->op == layer_norm_op_) {