This is an automated email from the ASF dual-hosted git repository.

marisa pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new b6db7e3  [RELAY] Basic block normal form (#6152)
b6db7e3 is described below

commit b6db7e33f7a589fdf5dce062d8488ce2f83a3727
Author: Haibin Lin <linhaibin.e...@gmail.com>
AuthorDate: Mon Aug 3 21:38:51 2020 -0700

    [RELAY] Basic block normal form (#6152)
    
    * initial commit
    
    * refactor utils
    
    * add util
    
    * revert anf test
    
    * update test
    
    * fix logging
    
    * fix scope bug
    
    * complete tests
    
    * remove logging
    
    * revert refactoring
    
    * add one more test case
    
    * fix missing var binding
    
    * fix test
    
    * fix lint
    
    * fix lint
    
    * fix clang-format
    
    * fix lint
    
    * fix lint
    
    * commit missing code
    
    * add analysis api
    
    * fix lint
    
    * fix lint
    
    * lint
    
    * add test for func
    
    * address CR
    
    * fix typo
    
    * fix return type
    
    * fix lint
    
    * refactor classes
    
    * fix lint
    
    * remove prints
    
    * address comments
    
    Co-authored-by: Ubuntu <ubuntu@ip-172-31-42-138.ec2.internal>
---
 include/tvm/relay/analysis.h                       |   9 +
 include/tvm/relay/transform.h                      |  15 +
 python/tvm/relay/analysis/analysis.py              |  15 +
 python/tvm/relay/transform/transform.py            |  15 +
 src/relay/analysis/dependency_graph.cc             |   4 +
 src/relay/backend/build_module.cc                  |   1 +
 src/relay/backend/vm/compiler.cc                   |   1 +
 src/relay/transforms/let_list.h                    |   6 +
 src/relay/transforms/pass_util.h                   |  88 ++++
 src/relay/transforms/to_a_normal_form.cc           | 299 ++++++-------
 src/relay/transforms/to_basic_block_normal_form.cc | 104 +++++
 .../relay/test_analysis_basic_block_normal_form.py | 206 +++++++++
 .../relay/test_pass_to_basic_block_normal_form.py  | 482 +++++++++++++++++++++
 13 files changed, 1098 insertions(+), 147 deletions(-)

diff --git a/include/tvm/relay/analysis.h b/include/tvm/relay/analysis.h
index 8eda7dd..c65bb41 100644
--- a/include/tvm/relay/analysis.h
+++ b/include/tvm/relay/analysis.h
@@ -67,6 +67,15 @@ TVM_DLL Kind KindCheck(const Type& t, const IRModule& mod);
 TVM_DLL bool ConstantCheck(const Expr& e);
 
 /*!
+ * \brief Check whether an expression is in the basic block normal form.
+ *
+ * \param e the expression.
+ *
+ * \return whether the expression is in the basic block normal form.
+ */
+TVM_DLL bool BasicBlockNormalFormCheck(const Expr& e);
+
+/*!
  * \brief Check that each Var is only bound once.
  *
  * For example, the expression `let x = 1 in let x = 2 in 3` binds x twice.
diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index d995301..cf14feb 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -117,6 +117,21 @@ TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
 TVM_DLL Pass RewriteAnnotatedOps(int fallback_device);
 
 /*!
+ * \brief Turn an expression to Basic Block Normal Form.
+ *
+ * We define a block as a group of expressions implied by the scope structure.
+ *
+ * Each graph node can only belong to a single block.
+ *
+ * Any value that is used in multiple blocks must be referred to by a Var
+ * that is defined in a block whose scope is the least common ancestor of
+ * all the blocks in which the value is used.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass ToBasicBlockNormalForm();
+
+/*!
  * \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
  *
  * It will turn an expression that is in a graph form (with sharing implicit),
diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py
index 632af46..165e39a 100644
--- a/python/tvm/relay/analysis/analysis.py
+++ b/python/tvm/relay/analysis/analysis.py
@@ -106,6 +106,21 @@ def check_constant(expr):
     """
     return _ffi_api.check_constant(expr)
 
+def check_basic_block_normal_form(expr):
+    """Check whether an expression is in the basic block form
+
+    Parameters
+    ----------
+    expr : tvm.relay.Expr
+        The input expression
+
+    Returns
+    -------
+    result : bool
+        Whether the expression is in the basic block normal form.
+    """
+    return _ffi_api.check_basic_block_normal_form(expr)
+
 
 def free_vars(expr):
     """Get free Vars from expression expr in Post DFS order.
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index 7db0687..3abc382 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -488,6 +488,21 @@ def ToANormalForm():
     """
     return _ffi_api.ToANormalForm()
 
+def ToBasicBlockNormalForm():
+    """Turn an expression to Basic Block Normal Form.
+    We define a block as a group of expressions implied by the scope structure.
+    Each graph node can only belong to a single block.
+    Any value that is used in multiple blocks must be referred to by a Var
+    that is defined in a block whose scope is the least common ancestor of
+    all the blocks in which the value is used.
+
+    Returns
+    -------
+    ret: tvm.transform.Pass
+        The registered pass that transforms an expression into Basic Block Normal Form.
+    """
+    return _ffi_api.ToBasicBlockNormalForm()
+
 
 def ToCPS(expr, mod=None):
     """
diff --git a/src/relay/analysis/dependency_graph.cc b/src/relay/analysis/dependency_graph.cc
index 5db8338..de61800 100644
--- a/src/relay/analysis/dependency_graph.cc
+++ b/src/relay/analysis/dependency_graph.cc
@@ -137,6 +137,9 @@ class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
     DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(f)];
     DependencyGraph::Node* b = NewNode(true);
     Depend(n, b);
+    for (const auto& p : f->params) {
+      Depend(b, p);
+    }
     Depend(b, f->body);
     graph_.post_dfs_order.push_back(b);
   }
@@ -145,6 +148,7 @@ class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
     DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(l)];
     DependencyGraph::Node* b = NewNode(true);
     Depend(n, b);
+    Depend(b, l->var);
     Depend(b, l->value);
     Depend(b, l->body);
     graph_.post_dfs_order.push_back(b);
diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index 4d84c48..533619e 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -253,6 +253,7 @@ class RelayBuildModule : public runtime::ModuleNode {
     Array<Pass> pass_seqs;
     Array<runtime::String> entry_functions{"main"};
     pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
+    pass_seqs.push_back(transform::ToBasicBlockNormalForm());
 
     // Run all dialect legalization passes.
     pass_seqs.push_back(relay::qnn::transform::Legalize());
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index b811911..a98f1ef 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -927,6 +927,7 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
   Array<Pass> pass_seqs;
   Array<runtime::String> entry_functions{"main"};
   pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
+  pass_seqs.push_back(transform::ToBasicBlockNormalForm());
   // Run all dialect legalization passes.
   pass_seqs.push_back(relay::qnn::transform::Legalize());
 
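
Both the build module and the VM compiler now normalize to basic block form
right after removing unused functions; a rough Python analogue of that
pipeline prefix (a sketch, not the compiler's actual code path):

    import tvm
    from tvm import relay

    seq = tvm.transform.Sequential([
        relay.transform.RemoveUnusedFunctions(['main']),
        relay.transform.ToBasicBlockNormalForm(),
    ])
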
diff --git a/src/relay/transforms/let_list.h b/src/relay/transforms/let_list.h
index c0e0b3a..c925dc0 100644
--- a/src/relay/transforms/let_list.h
+++ b/src/relay/transforms/let_list.h
@@ -107,6 +107,12 @@ class LetList {
     return ret;
   }
 
+  /*! \brief get the number of let bindings in the let list.
+   *
+   *  \return the let list size.
+   */
+  size_t size() const { return lets_.size(); }
+
   /*! \brief generate an LetList and wrap the result automatically.
    *
    *  \param f a function that generate the unwrapped Expr.
diff --git a/src/relay/transforms/pass_util.h b/src/relay/transforms/pass_util.h
index 5f58762..50d0fbb 100644
--- a/src/relay/transforms/pass_util.h
+++ b/src/relay/transforms/pass_util.h
@@ -31,6 +31,11 @@
 
 #include <memory>
 #include <unordered_map>
+#include <unordered_set>
+#include <utility>
+
+#include "../analysis/dependency_graph.h"
+#include "let_list.h"
 
 namespace tvm {
 namespace relay {
@@ -184,6 +189,89 @@ struct TreeBranchNode : TreeNode<ConditionObjectPtr> {
   ~TreeBranchNode() {}
 };
 
+struct ScopeNode;
+using Scope = std::shared_ptr<ScopeNode>;
+using NodeScopeMap = std::unordered_map<DependencyGraph::Node*, Scope>;
+using ExprSet = std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual>;
+
+/* Invariant: when parent is null level is 0
+ * Invariant: when parent is not null level is 1 + parent->level
+ */
+struct ScopeNode {
+  // the level of the scope
+  size_t level;
+  // the parent scope
+  Scope parent;
+  // the corresponding let list which holds all let bindings in the scope
+  std::shared_ptr<LetList> let_list = std::make_shared<LetList>();
+  explicit ScopeNode(const Scope& parent) : level(1 + parent->level), parent(parent) {}
+  ScopeNode() : level(0) {}
+};
+
+/*! \brief Calculate the scope of nodes in the dependency graph by least common ancestor.
+ *
+ *  \param dg the input dependency graph
+ *  \return a pair of the node -> scope mapping for all nodes and the set of
+ *  expressions whose scope is lifted due to dependency
+ */
+std::pair<NodeScopeMap, ExprSet> CalcScope(const DependencyGraph& dg);
+
+/*! \brief find the least common ancestor of lhs scope and rhs scope.
+ */
+Scope LCA(Scope lhs, Scope rhs);
+
+/* Special care is needed to handle local recursion.
+ * Fill additionally takes a (possibly null) Var argument;
+ * if it is not null, Fill is required to bind the transformed result to that var.
+ */
+class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
+ public:
+  static Expr ToANormalForm(const Expr& e, const DependencyGraph& dg, NodeScopeMap* node_scope);
+
+  // For basic block normal form, bind expressions only if the original
+  // expression's scope should be lifted.
+  static Expr ToBasicBlockNormalForm(const Expr& e, const DependencyGraph& dg,
+                                     NodeScopeMap* node_scope, ExprSet* lifted);
+
+ private:
+  const DependencyGraph& dg_;
+  NodeScopeMap* node_scope_ = nullptr;
+  std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual> memo;
+  // a set of Expressions to include for let bindings. If set to nullptr
+  // all Exprs will be pushed to the let list.
+  ExprSet* include_set_ = nullptr;
+
+  Fill(const DependencyGraph& dg, NodeScopeMap* node_scope, ExprSet* include_set)
+      : dg_(dg), node_scope_(node_scope), include_set_(include_set) {}
+
+  Scope GetScope(const Expr& e);
+  Scope GetSubScope(const Expr& e, size_t i);
+
+  Expr VisitExpr(const Expr& e, const Var& v) final;
+  Expr VisitExpr(const Expr& e);
+
+  Expr Atomic(const Expr& e, const Var& v);
+  // Bind expression `now` to var `v` if the original expression is in the
+  // include set, or if v is already defined (e.g. coming from a Let
+  // expression). Otherwise return `now` directly.
+  Expr Compound(const Expr& orig, const Expr& now, const Var& v);
+
+  Expr VisitExpr_(const CallNode* c, const Var& v) final;
+  Expr VisitExpr_(const TupleNode* t, const Var& v) final;
+  Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final;
+  Expr VisitExpr_(const RefCreateNode* r, const Var& v) final;
+  Expr VisitExpr_(const RefReadNode* r, const Var& v) final;
+  Expr VisitExpr_(const RefWriteNode* r, const Var& v) final;
+  Expr VisitExpr_(const IfNode* i, const Var& v) final;
+  Expr VisitExpr_(const FunctionNode* f, const Var& v) final;
+  Expr VisitExpr_(const LetNode* l, const Var& v) final;
+  Expr VisitExpr_(const ConstantNode* c, const Var& v) final;
+  Expr VisitExpr_(const VarNode* vn, const Var& v) final;
+  Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final;
+  Expr VisitExpr_(const OpNode* op, const Var& v) final;
+  Expr VisitExpr_(const ConstructorNode* c, const Var& v) final;
+  Expr VisitExpr_(const MatchNode* m, const Var& v) final;
+};
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_TRANSFORMS_PASS_UTIL_H_
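
LCA walks the deeper scope up to the shallower one's level and then climbs
both in lockstep until they meet; the same logic in Python (hypothetical
Scope class, for illustration only):

    class Scope:
        def __init__(self, parent=None):
            self.parent = parent
            self.level = 0 if parent is None else parent.level + 1

    def lca(lhs, rhs):
        # equalize depth first, then move both scopes up together
        while lhs is not rhs:
            if lhs.level > rhs.level:
                lhs = lhs.parent
            elif lhs.level < rhs.level:
                rhs = rhs.parent
            else:
                lhs, rhs = lhs.parent, rhs.parent
        return lhs
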
diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc
index 8d10242..06e0d56 100644
--- a/src/relay/transforms/to_a_normal_form.cc
+++ b/src/relay/transforms/to_a_normal_form.cc
@@ -36,23 +36,6 @@
 namespace tvm {
 namespace relay {
 
-struct ScopeNode;
-using Scope = std::shared_ptr<ScopeNode>;
-
-/* Invariant: when parent is null level is 0
- *
- * Invariant: when parent is not null level is 1 + parent->level
- */
-struct ScopeNode {
-  size_t level;
-  Scope parent;
-  std::shared_ptr<LetList> ll = std::make_shared<LetList>();
-  explicit ScopeNode(const Scope& parent) : level(1 + parent->level), parent(parent) {}
-  ScopeNode() : level(0) {}
-};
-
-Scope ChildScope(const Scope& s) { return std::make_shared<ScopeNode>(s); }
-
 Scope LCA(Scope lhs, Scope rhs) {
   while (lhs != rhs) {
     if (lhs->level > rhs->level) {
@@ -67,10 +50,16 @@ Scope LCA(Scope lhs, Scope rhs) {
   return lhs;
 }
 
-std::unordered_map<DependencyGraph::Node*, Scope> CalcScope(const DependencyGraph& dg) {
-  std::unordered_map<DependencyGraph::Node*, Scope> expr_scope;
+std::pair<NodeScopeMap, ExprSet> CalcScope(const DependencyGraph& dg) {
+  NodeScopeMap expr_scope;
+  ExprSet lifted_exprs;
+  std::unordered_map<DependencyGraph::Node*, Expr> node_to_expr;
+  for (auto expr_node : dg.expr_node) {
+    node_to_expr[expr_node.second] = expr_node.first;
+  }
   bool global_scope_used = false;
   Scope global_scope = std::make_shared<ScopeNode>();
+
   for (auto it = dg.post_dfs_order.rbegin(); it != dg.post_dfs_order.rend(); ++it) {
     DependencyGraph::Node* n = *it;
     auto iit = n->parents.head;
@@ -81,171 +70,187 @@ std::unordered_map<DependencyGraph::Node*, Scope> CalcScope(const DependencyGrap
       global_scope_used = true;
     } else {
       s = expr_scope.at(iit->value);
+      const auto original_s = s;
       iit = iit->next;
       for (; iit != nullptr; iit = iit->next) {
         s = LCA(s, expr_scope.at(iit->value));
       }
+      if (s != original_s && node_to_expr.find(n) != node_to_expr.end()) {
+        // filter out exprs whose scope does not matter
+        Expr expr = node_to_expr[n];
+        if (!expr.as<OpNode>()) {
+          lifted_exprs.insert(expr);
+        }
+      }
+    }
+    if (n->new_scope) {
+      auto child_scope = std::make_shared<ScopeNode>(s);
+      expr_scope.insert({n, child_scope});
+    } else {
+      expr_scope.insert({n, s});
     }
-    expr_scope.insert({n, n->new_scope ? ChildScope(s) : s});
   }
   CHECK(global_scope_used);
-  return expr_scope;
+  return std::make_pair(expr_scope, lifted_exprs);
 }
 
-/* Special care is needed to handle local recursion.
- * Fill additionally take a (possibly null) Var argument,
- * If it is not null, Fill is required to bind the transformed result to that var.
- */
-class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
- public:
-  static Expr ToANormalForm(const Expr& e, const DependencyGraph& dg,
-                            std::unordered_map<DependencyGraph::Node*, Scope>* node_scope) {
-    Fill fi(dg, node_scope);
-    return fi.GetScope(e)->ll->Get(fi.VisitExpr(e));
-  }
-
- private:
-  const DependencyGraph& dg_;
-  std::unordered_map<DependencyGraph::Node*, Scope>* node_scope_;
-  std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual> memo;
+Expr Fill::ToANormalForm(const Expr& e, const DependencyGraph& dg, NodeScopeMap* node_scope) {
+  Fill fi(dg, node_scope, nullptr);
+  return fi.GetScope(e)->let_list->Get(fi.VisitExpr(e));
+}
 
-  Fill(const DependencyGraph& dg, std::unordered_map<DependencyGraph::Node*, Scope>* node_scope)
-      : dg_(dg), node_scope_(node_scope) {}
+// For basic block normal form, bind expressions only if the original
+// expression's scope should be lifted.
+Expr Fill::ToBasicBlockNormalForm(const Expr& e, const DependencyGraph& dg,
+                                  NodeScopeMap* node_scope, ExprSet* lifted) {
+  Fill fi(dg, node_scope, lifted);
+  auto var = fi.VisitExpr(e);
+  return fi.GetScope(e)->let_list->Get(var);
+}
 
-  Scope GetScope(const Expr& e) { return node_scope_->at(dg_.expr_node.at(e)); }
+Scope Fill::GetScope(const Expr& e) { return node_scope_->at(dg_.expr_node.at(e)); }
 
-  Scope GetSubScope(const Expr& e, size_t i) {
-    DependencyGraph::Node* n = dg_.expr_node.at(e);
-    auto h = n->children.head;
-    while (i != 0) {
-      CHECK(h);
-      --i;
-      h = h->next;
-    }
+Scope Fill::GetSubScope(const Expr& e, size_t i) {
+  DependencyGraph::Node* n = dg_.expr_node.at(e);
+  auto h = n->children.head;
+  while (i != 0) {
     CHECK(h);
-    return node_scope_->at(h->value);
+    --i;
+    h = h->next;
   }
+  CHECK(h);
+  return node_scope_->at(h->value);
+}
 
-  Expr VisitExpr(const Expr& e, const Var& v) final {
-    if (memo.count(e) == 0) {
-      memo.insert({e, ExprFunctor<Expr(const Expr&, const Var&)>::VisitExpr(e, v)});
-    } else if (v.defined()) {
-      GetScope(e)->ll->Push(v, memo.at(e));
-    }
-    auto ret = memo.at(e);
-    CHECK(IsAtomic(ret));
-    return ret;
+Expr Fill::VisitExpr(const Expr& e, const Var& v) {
+  if (memo.count(e) == 0) {
+    memo.insert({e, ExprFunctor<Expr(const Expr&, const Var&)>::VisitExpr(e, v)});
+  } else if (v.defined()) {
+    GetScope(e)->let_list->Push(v, memo.at(e));
   }
+  auto ret = memo.at(e);
+  // if no include_set is specified, every expression should be atomic.
+  if (include_set_ == nullptr) CHECK(IsAtomic(ret));
+  return ret;
+}
 
-  Expr VisitExpr(const Expr& e) { return this->VisitExpr(e, Var()); }
+Expr Fill::VisitExpr(const Expr& e) { return this->VisitExpr(e, Var()); }
 
-  Expr Atomic(const Expr& e, const Var& v) { return v.defined() ? GetScope(e)->ll->Push(v, e) : e; }
+Expr Fill::Atomic(const Expr& e, const Var& v) {
+  return v.defined() ? GetScope(e)->let_list->Push(v, e) : e;
+}
 
-  Expr Compound(const Expr& orig, const Expr& now, const Var& v) {
-    Var var = v.defined() ? v : Var(String("x"), Type());
-    return GetScope(orig)->ll->Push(var, now);
+// Bind expression `now` to var `v` if the original expression is in the
+// include set, or if v is already defined (e.g. coming from a Let
+// expression). Otherwise return `now` directly.
+Expr Fill::Compound(const Expr& orig, const Expr& now, const Var& v) {
+  Var var = v.defined() ? v : Var(String("x"), Type());
+  bool not_included = include_set_ && include_set_->find(orig) == include_set_->end();
+  if (!v.defined() && not_included) {
+    return now;
+  } else {
+    return GetScope(orig)->let_list->Push(var, now);
   }
+}
 
-  Expr VisitExpr_(const CallNode* c, const Var& v) final {
-    Expr e = GetRef<Expr>(c);
-    std::vector<Expr> args;
-    for (const auto& a : c->args) {
-      args.push_back(VisitExpr(a));
-    }
-    return Compound(e, Call(VisitExpr(c->op), args, c->attrs, c->type_args), v);
+Expr Fill::VisitExpr_(const CallNode* c, const Var& v) {
+  Expr e = GetRef<Expr>(c);
+  std::vector<Expr> args;
+  for (const auto& a : c->args) {
+    args.push_back(VisitExpr(a));
   }
+  return Compound(e, Call(VisitExpr(c->op), args, c->attrs, c->type_args), v);
+}
 
-  Expr VisitExpr_(const TupleNode* t, const Var& v) final {
-    Expr e = GetRef<Expr>(t);
-    std::vector<Expr> fields;
-    for (const auto& a : t->fields) {
-      fields.push_back(VisitExpr(a));
-    }
-    return Compound(e, Tuple(fields), v);
+Expr Fill::VisitExpr_(const TupleNode* t, const Var& v) {
+  Expr e = GetRef<Expr>(t);
+  std::vector<Expr> fields;
+  for (const auto& a : t->fields) {
+    fields.push_back(VisitExpr(a));
   }
+  return Compound(e, Tuple(fields), v);
+}
 
-  Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final {
-    Expr e = GetRef<Expr>(t);
-    return Compound(e, TupleGetItem(VisitExpr(t->tuple), t->index), v);
-  }
+Expr Fill::VisitExpr_(const TupleGetItemNode* t, const Var& v) {
+  Expr e = GetRef<Expr>(t);
+  return Compound(e, TupleGetItem(VisitExpr(t->tuple), t->index), v);
+}
 
-  Expr VisitExpr_(const RefCreateNode* r, const Var& v) final {
-    Expr e = GetRef<Expr>(r);
-    return Compound(e, RefCreate(VisitExpr(r->value)), v);
-  }
+Expr Fill::VisitExpr_(const RefCreateNode* r, const Var& v) {
+  Expr e = GetRef<Expr>(r);
+  return Compound(e, RefCreate(VisitExpr(r->value)), v);
+}
 
-  Expr VisitExpr_(const RefReadNode* r, const Var& v) final {
-    Expr e = GetRef<Expr>(r);
-    return Compound(e, RefRead(VisitExpr(r->ref)), v);
-  }
+Expr Fill::VisitExpr_(const RefReadNode* r, const Var& v) {
+  Expr e = GetRef<Expr>(r);
+  return Compound(e, RefRead(VisitExpr(r->ref)), v);
+}
 
-  Expr VisitExpr_(const RefWriteNode* r, const Var& v) final {
-    Expr e = GetRef<Expr>(r);
-    return Compound(e, RefWrite(VisitExpr(r->ref), VisitExpr(r->value)), v);
-  }
+Expr Fill::VisitExpr_(const RefWriteNode* r, const Var& v) {
+  Expr e = GetRef<Expr>(r);
+  return Compound(e, RefWrite(VisitExpr(r->ref), VisitExpr(r->value)), v);
+}
 
-  Expr VisitExpr_(const IfNode* i, const Var& v) final {
-    Expr e = GetRef<Expr>(i);
-    Expr ret = If(VisitExpr(i->cond), GetSubScope(e, 1)->ll->Get(VisitExpr(i->true_branch)),
-                  GetSubScope(e, 2)->ll->Get(VisitExpr(i->false_branch)));
-    return Compound(e, ret, v);
-  }
+Expr Fill::VisitExpr_(const IfNode* i, const Var& v) {
+  Expr e = GetRef<Expr>(i);
+  Expr ret = If(VisitExpr(i->cond), GetSubScope(e, 1)->let_list->Get(VisitExpr(i->true_branch)),
+                GetSubScope(e, 2)->let_list->Get(VisitExpr(i->false_branch)));
+  return Compound(e, ret, v);
+}
 
-  Expr VisitExpr_(const FunctionNode* f, const Var& v) final {
-    Expr e = GetRef<Expr>(f);
-    Expr ret;
-    if (f->HasNonzeroAttr(attr::kPrimitive)) {
-      ret = e;
-    } else {
-      ret = Function(f->params, GetSubScope(e, 0)->ll->Get(VisitExpr(f->body)), f->ret_type,
-                     f->type_params, f->attrs);
-    }
-    return Compound(e, ret, v);
+Expr Fill::VisitExpr_(const FunctionNode* f, const Var& v) {
+  Expr e = GetRef<Expr>(f);
+  Expr ret;
+  if (f->HasNonzeroAttr(attr::kPrimitive)) {
+    ret = e;
+  } else {
+    ret = Function(f->params, GetSubScope(e, 0)->let_list->Get(VisitExpr(f->body)), f->ret_type,
+                   f->type_params, f->attrs);
   }
+  return Compound(e, ret, v);
+}
 
-  Expr VisitExpr_(const LetNode* l, const Var& v) final {
-    Expr e = GetRef<Expr>(l);
-    VisitExpr(l->value, l->var);
-    Expr ret = GetSubScope(e, 0)->ll->Get(VisitExpr(l->body));
-    return Compound(e, ret, v);
-  }
+Expr Fill::VisitExpr_(const LetNode* l, const Var& v) {
+  Expr e = GetRef<Expr>(l);
+  VisitExpr(l->value, l->var);
+  Expr ret = GetSubScope(e, 0)->let_list->Get(VisitExpr(l->body));
+  return Compound(e, ret, v);
+}
 
-  Expr VisitExpr_(const ConstantNode* c, const Var& v) final {
-    Expr e = GetRef<Expr>(c);
-    return Compound(e, e, v);
-  }
+Expr Fill::VisitExpr_(const ConstantNode* c, const Var& v) {
+  Expr e = GetRef<Expr>(c);
+  return Compound(e, e, v);
+}
 
-  Expr VisitExpr_(const VarNode* vn, const Var& v) final {
-    Expr e = GetRef<Expr>(vn);
-    return Atomic(e, v);
-  }
+Expr Fill::VisitExpr_(const VarNode* vn, const Var& v) {
+  Expr e = GetRef<Expr>(vn);
+  return Atomic(e, v);
+}
 
-  Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final {
-    GlobalVar gv = GetRef<GlobalVar>(gvn);
-    return Atomic(gv, v);
-  }
+Expr Fill::VisitExpr_(const GlobalVarNode* gvn, const Var& v) {
+  GlobalVar gv = GetRef<GlobalVar>(gvn);
+  return Atomic(gv, v);
+}
 
-  Expr VisitExpr_(const OpNode* op, const Var& v) final {
-    Expr e = GetRef<Expr>(op);
-    return Atomic(e, v);
-  }
+Expr Fill::VisitExpr_(const OpNode* op, const Var& v) {
+  Expr e = GetRef<Expr>(op);
+  return Atomic(e, v);
+}
 
-  Expr VisitExpr_(const ConstructorNode* c, const Var& v) final {
-    Expr e = GetRef<Expr>(c);
-    return Atomic(e, v);
-  }
+Expr Fill::VisitExpr_(const ConstructorNode* c, const Var& v) {
+  Expr e = GetRef<Expr>(c);
+  return Atomic(e, v);
+}
 
-  Expr VisitExpr_(const MatchNode* m, const Var& v) final {
-    Expr e = GetRef<Expr>(m);
-    Expr data = VisitExpr(m->data);
-    std::vector<Clause> clauses;
-    for (const Clause& c : m->clauses) {
-      clauses.push_back(
-          Clause(c->lhs, GetSubScope(e, 1 + clauses.size())->ll->Get(VisitExpr(c->rhs))));
-    }
-    return Compound(e, Match(data, clauses, m->complete), v);
+Expr Fill::VisitExpr_(const MatchNode* m, const Var& v) {
+  Expr e = GetRef<Expr>(m);
+  Expr data = VisitExpr(m->data);
+  std::vector<Clause> clauses;
+  for (const Clause& c : m->clauses) {
+    clauses.push_back(
+        Clause(c->lhs, GetSubScope(e, 1 + clauses.size())->let_list->Get(VisitExpr(c->rhs))));
   }
-};
+  return Compound(e, Match(data, clauses, m->complete), v);
+}
 
 Expr ToANormalFormAux(const Expr& e) {
  /* When you lift a lambda, what is inside is also being lifted.
@@ -269,8 +274,8 @@ Expr ToANormalFormAux(const Expr& e) {
   * Every scope additionally contains a LetList which collects all values of that scope.
    * We do an additional pass to fill all the LetList and we are done.
    */
-  std::unordered_map<DependencyGraph::Node*, Scope> node_scope = CalcScope(dg);
-  return Fill::ToANormalForm(e, dg, &node_scope);
+  std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);
+  return Fill::ToANormalForm(e, dg, &scopes.first);
 }
 
 IRModule ToANormalForm(const IRModule& m) {
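
CalcScope now also reports which expressions need lifting: nodes are visited
in reverse post-DFS order, a node's scope is the LCA of its parents' scopes,
and any expression whose LCA differs from its first parent's scope is
recorded. A compact sketch of that bookkeeping (hypothetical node and scope
objects; the OpNode filter and the global-scope check are omitted):

    def calc_scope(post_dfs_order, node_to_expr, global_scope):
        expr_scope, lifted = {}, set()
        for n in reversed(post_dfs_order):
            if not n.parents:
                s = global_scope
            else:
                s = expr_scope[n.parents[0]]
                original = s
                for p in n.parents[1:]:
                    s = lca(s, expr_scope[p])
                # the scope moved up: this expr must be let-bound higher up
                if s is not original and n in node_to_expr:
                    lifted.add(node_to_expr[n])
            expr_scope[n] = Scope(s) if n.new_scope else s
        return expr_scope, lifted
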
diff --git a/src/relay/transforms/to_basic_block_normal_form.cc b/src/relay/transforms/to_basic_block_normal_form.cc
new file mode 100644
index 0000000..5fc01e1
--- /dev/null
+++ b/src/relay/transforms/to_basic_block_normal_form.cc
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_basic_block_normal_form.cc
+ *
+ * \brief Turn an expression into basic block normal form.
+ */
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/support/logging.h>
+
+#include "../../support/arena.h"
+#include "../analysis/dependency_graph.h"
+#include "let_list.h"
+#include "pass_util.h"
+
+namespace tvm {
+namespace relay {
+
+Expr ToBasicBlockNormalFormAux(const Expr& e) {
+  // calculate all the dependency between nodes.
+  support::Arena arena;
+  DependencyGraph dg = DependencyGraph::Create(&arena, e);
+  /* The scope of the whole expr is global.
+   * The scope of any subexpr is the least common ancestor of all its incoming edges.
+   * We also record the set of expressions whose scope is lifted.
+   */
+  std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);
+  return Fill::ToBasicBlockNormalForm(e, dg, &scopes.first, &scopes.second);
+}
+
+IRModule ToBasicBlockNormalForm(const IRModule& mod) {
+  DLOG(INFO) << "ToBBlock:" << std::endl << mod;
+
+  tvm::Map<GlobalVar, Function> updates;
+  auto funcs = mod->functions;
+  for (const auto& it : funcs) {
+    CHECK_EQ(FreeVars(it.second).size(), 0) << "Expected no free variables";
+    if (const auto* n = it.second.as<FunctionNode>()) {
+      if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
+    }
+    Expr ret = TransformF([&](const Expr& e) { return ToBasicBlockNormalFormAux(e); }, it.second);
+    updates.Set(it.first, Downcast<Function>(ret));
+  }
+
+  for (auto pair : updates) {
+    mod->Add(pair.first, pair.second, true);
+  }
+
+  DLOG(INFO) << "ToBBlock: transformed" << std::endl << mod;
+
+  return mod;
+}
+
+bool BasicBlockNormalFormCheck(const Expr& e) {
+  // calculate all the dependency between nodes.
+  support::Arena arena;
+  DependencyGraph dg = DependencyGraph::Create(&arena, e);
+  std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);
+  for (auto expr : scopes.second) {
+    LOG(FATAL) << "The expression below violates the basic block normal form in that "
+               << "its scope should be lifted:\n"
+               << expr;
+  }
+  return scopes.second.size() == 0;
+}
+
+TVM_REGISTER_GLOBAL("relay.analysis.check_basic_block_normal_form")
+    .set_body_typed(BasicBlockNormalFormCheck);
+
+namespace transform {
+
+Pass ToBasicBlockNormalForm() {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule m, PassContext pc) { return relay::ToBasicBlockNormalForm(m); };
+  return CreateModulePass(pass_func, 1, "ToBasicBlockNormalForm", {});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.ToBasicBlockNormalForm")
+    .set_body_typed(ToBasicBlockNormalForm);
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
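
End to end, the pass establishes exactly the invariant the checker verifies;
a round-trip sketch (same program as test_invalid_if2 and test_if below):

    import tvm
    from tvm import relay
    from tvm.relay.analysis import check_basic_block_normal_form

    x = relay.var('x', shape=(), dtype='float32')
    one = relay.const(1, dtype='float32')
    two = relay.const(2, dtype='float32')
    v1 = relay.add(x, one)  # used in both branches of the If
    body = relay.If(relay.equal(x, two),
                    relay.multiply(v1, two),
                    relay.multiply(v1, one))
    func = relay.Function([x], body)
    # check_basic_block_normal_form(func) would raise here: `v1` floats
    # into the true branch although the false branch uses it too
    mod = tvm.IRModule.from_expr(func)
    mod = relay.transform.ToBasicBlockNormalForm()(mod)
    check_basic_block_normal_form(mod['main'])  # satisfied after the pass
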
diff --git a/tests/python/relay/test_analysis_basic_block_normal_form.py b/tests/python/relay/test_analysis_basic_block_normal_form.py
new file mode 100644
index 0000000..dfd7dd1
--- /dev/null
+++ b/tests/python/relay/test_analysis_basic_block_normal_form.py
@@ -0,0 +1,206 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import tvm
+from tvm import relay
+from tvm.relay.analysis import check_basic_block_normal_form
+
+def test_one_block():
+    x = relay.var('x')
+    y = relay.add(x, x)
+    z = relay.add(x, y)
+    check_basic_block_normal_form(z)
+
+def test_let():
+    x = relay.var('x')
+    y = relay.var('y')
+    body = relay.Let(y, x, y)
+    check_basic_block_normal_form(body)
+
+@pytest.mark.xfail(raises=tvm.error.TVMError)
+def test_invalid_if():
+    cond = relay.var('cond', dtype='bool', shape=())
+    shared = relay.var('shared')
+    true_branch = shared
+    false_branch = relay.add(shared, shared)
+    body = relay.If(cond, true_branch, false_branch)
+    """
+    The program below violates basic block normal form, as the scope of %shared
+    is ambiguous and should not be in that of the true branch.
+
+    free_var %cond: bool
+    if (%cond) {
+      free_var %shared
+      %shared
+    } else {
+      add(%shared, %shared)
+    }
+    """
+    check_basic_block_normal_form(body)
+
+def test_valid_if():
+    cond = relay.var('cond', dtype='bool', shape=())
+    shared = relay.var('shared')
+    true_branch = shared
+    false_branch = relay.add(shared, shared)
+    body = relay.If(cond, true_branch, false_branch)
+    shared_bound = relay.var('shared_bound', shape=(1,), dtype='float32')
+    body = relay.Let(shared, shared_bound, body)
+    """
+    The program below uses let binding to control the scope of %shared, which
+    follows the basic block normal form.
+
+    free_var %shared_bound: Tensor[(1), float32]
+    let %shared = %shared_bound;
+    free_var %cond: bool
+    if (%cond) {
+      %shared
+    } else {
+      add(%shared, %shared)
+    }
+    """
+    check_basic_block_normal_form(body)
+
+@pytest.mark.xfail(raises=tvm.error.TVMError)
+def test_invalid_if2():
+    """
+    fn (%x: float32) {
+      %0 = equal(%x, 2f);
+      if (%0) {
+        %1 = add(%x, 1f);
+        multiply(%1, 2f)
+      } else {
+        multiply(%1, 1f)
+      }
+    }
+    """
+    x = relay.var('x', shape=(), dtype='float32')
+    one = relay.const(1, dtype='float32')
+    two = relay.const(2, dtype='float32')
+    v1 = relay.add(x, one)
+    v2 = relay.equal(x, two)
+    true_branch = relay.multiply(v1, two)
+    false_branch = relay.multiply(v1, one)
+    body = relay.If(v2, true_branch, false_branch)
+    func = relay.Function([x], body)
+    check_basic_block_normal_form(func)
+
+def test_valid_if2():
+    """
+    fn (%x: float32) {
+      let %v1 = add(%x, 1f);
+      %0 = equal(%x, 2f);
+      if (%0) {
+        multiply(%v1, 2f)
+      } else {
+        multiply(%v1, 1f)
+      }
+    }
+    """
+    x = relay.var('x', shape=(), dtype='float32')
+    one = relay.const(1, dtype='float32')
+    two = relay.const(2, dtype='float32')
+    v1 = relay.var('v1')
+    v2 = relay.equal(x, two)
+    true_branch = relay.multiply(v1, two)
+    false_branch = relay.multiply(v1, one)
+    body = relay.If(v2, true_branch, false_branch)
+    body = relay.Let(v1, relay.add(x, one), body)
+    func = relay.Function([x], body)
+    check_basic_block_normal_form(func)
+
+@pytest.mark.xfail(raises=tvm.error.TVMError)
+def test_func():
+    x = relay.var('x', shape=(1,), dtype='float32')
+    y = relay.var('y', shape=(1,), dtype='float32')
+    z = relay.var('z', shape=(1,), dtype='float32')
+    x2 = relay.add(x, x)
+    func_a = relay.Function([y], relay.add(x2, y))
+    func_b = relay.Function([z], relay.add(x2, z))
+    body = relay.Tuple([func_a, func_b])
+    body = relay.Function([x], body)
+    """
+    fn (%x: Tensor[(1), float32]) {
+      %1 = fn (%y: Tensor[(1), float32]) {
+        %0 = add(%x, %x);
+        add(%0, %y)
+      };
+      %2 = fn (%z: Tensor[(1), float32]) {
+        add(%0, %z)
+      };
+      (%1, %2)
+    }
+    """
+    check_basic_block_normal_form(body)
+
+@pytest.mark.xfail(raises=tvm.error.TVMError)
+def test_higher_order_return():
+    x = relay.var('x', shape=(1,), dtype='float32')
+    y = relay.var('y', shape=(1,), dtype='float32')
+    z = relay.var('z', shape=(1,), dtype='float32')
+    x2 = relay.add(x, x)
+    func_a = relay.Function([y], relay.add(x2, y))
+    func_b = relay.Function([z], relay.add(x2, z))
+    body = relay.Tuple([func_a, func_b])
+    body = relay.Function([x], body)
+    """
+    fn (%x: Tensor[(1), float32]) {
+      %1 = fn (%y: Tensor[(1), float32]) {
+        %0 = add(%x, %x);
+        add(%0, %y)
+      };
+      %2 = fn (%z: Tensor[(1), float32]) {
+        add(%0, %z)
+      };
+      (%1, %2)
+    }
+    """
+    check_basic_block_normal_form(body)
+
+
+@pytest.mark.xfail(raises=tvm.error.TVMError)
+def test_higher_order_nested():
+    x = relay.var('x', dtype='float32', shape=(1,))
+    s = relay.var('s', dtype='float32', shape=(1,))
+    shared = relay.add(s, s)
+    func_true = relay.Function([x], relay.add(x, shared))
+    choice_t = relay.FuncType([], relay.scalar_type('bool'))
+    f = relay.Var('f', choice_t)
+    z = relay.Var('z')
+    body = relay.If(f(), func_true, relay.Function([z], relay.add(z, shared)))
+    top = relay.Function([f, s], body)
+    """
+    fn (%f: fn () -> bool, %s: Tensor[(1), float32]) {
+      %0 = %f();
+      if (%0) {
+        fn (%x: Tensor[(1), float32]) {
+          %1 = add(%s, %s);
+          add(%x, %1)
+        }
+      } else {
+        fn (%z) {
+          add(%z, %1)
+        }
+      }
+    }
+    """
+    check_basic_block_normal_form(top)
+
+
+if __name__ == '__main__':
+    pytest.main([__file__])
diff --git a/tests/python/relay/test_pass_to_basic_block_normal_form.py b/tests/python/relay/test_pass_to_basic_block_normal_form.py
new file mode 100644
index 0000000..05c6544
--- /dev/null
+++ b/tests/python/relay/test_pass_to_basic_block_normal_form.py
@@ -0,0 +1,482 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import numpy as np
+import tvm
+from tvm import te
+from tvm import relay
+from tvm.relay.analysis import detect_feature
+from tvm.relay import op, create_executor, transform
+from tvm.relay.prelude import Prelude
+from tvm.relay.testing import add_nat_definitions, count
+from tvm.relay.analysis import Feature
+from tvm.relay.analysis import check_basic_block_normal_form
+
+
+def run_opt_pass(expr, passes):
+    passes = passes if isinstance(passes, list) else [passes]
+    mod = tvm.IRModule.from_expr(expr)
+    seq = tvm.transform.Sequential(passes)
+    with tvm.transform.PassContext(opt_level=3):
+        mod = seq(mod)
+    entry = mod["main"]
+    return entry if isinstance(expr, relay.Function) else entry.body
+
+
+def check_eval(expr, expected_result, mod=None, rtol=1e-07):
+    ctx = tvm.context("llvm", 0)
+    intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
+
+    result = intrp.evaluate(expr)
+    np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
+
+
+def test_no_explicit_bind():
+    x = relay.const(1)
+    y = op.add(x, x)
+    z = op.add(y, y)
+    f = relay.Function([], op.add(z, z))
+    """
+    fn () {
+      %0 = add(1, 1);
+      %1 = add(%0, %0);
+      add(%1, %1)
+    }
+    """
+    assert not Feature.fLet in detect_feature(f)
+    bblock = run_opt_pass(f, transform.ToBasicBlockNormalForm())
+    assert Feature.fLet not in detect_feature(bblock)
+    check_eval(f(), 8.0)
+    check_eval(bblock(), 8.0)
+    check_basic_block_normal_form(bblock)
+
+def test_top_level_nested_if():
+    x = relay.var('x', shape=(), dtype='bool')
+    y = relay.var('y', shape=(), dtype='float32')
+    z = relay.var('z', shape=(), dtype='float32')
+    cond_t = relay.const(True)
+    cond_f = relay.const(False)
+    one = relay.const(1, dtype='float32')
+    three = relay.const(3, dtype='float32')
+    y2 = relay.add(y, y)
+    z2 = relay.add(z, z)
+    true_branch = relay.If(cond_t, relay.add(z2, y2), relay.add(three, y2))
+    false_branch = relay.If(cond_f, z2, one)
+    body = relay.If(x, true_branch, false_branch)
+    """
+    free_var %x: bool
+    if (%x) {
+      if (True) {
+        free_var %z: float32
+        %0 = add(%z, %z);
+        free_var %y: float32
+        %1 = add(%y, %y);
+        add(%0, %1)
+      } else {
+        add(3f, %1)
+      }
+    } else {
+      if (False) {
+        %0
+      } else {
+        1f
+      }
+    }
+    """
+    def expected():
+        x = relay.var('x', shape=(), dtype='bool')
+        y = relay.var('y', shape=(), dtype='float32')
+        z = relay.var('z', shape=(), dtype='float32')
+        cond_t = relay.const(True)
+        cond_f = relay.const(False)
+        one = relay.const(1, dtype='float32')
+        three = relay.const(3, dtype='float32')
+        y2 = relay.var('y2')
+        z2 = relay.var('z2')
+        true_branch = relay.If(cond_t, relay.add(z2, y2), relay.add(three, y2))
+        true_branch = relay.Let(y2, relay.add(y, y), true_branch)
+        false_branch = relay.If(cond_f, z2, one)
+        body = relay.If(x, true_branch, false_branch)
+        body = relay.Let(z2, relay.add(z, z), body)
+        return body
+
+    bblock = run_opt_pass(body, [transform.ToBasicBlockNormalForm()])
+    """
+    free_var %z: float32
+    let %x: float32 = add(%z, %z) /* ty=float32 */;
+    free_var %x1: bool
+    if (%x1) {
+      free_var %y: float32
+      let %x2: float32 = add(%y, %y) /* ty=float32 */;
+      if (True /* ty=bool */) {
+        add(%x, %x2) /* ty=float32 */
+      } else {
+        add(3f /* ty=float32 */, %x2) /* ty=float32 */
+      }
+    } else {
+      if (False /* ty=bool */) {
+        %x
+      } else {
+        1f /* ty=float32 */
+      }
+    }
+    """
+    expected_output = run_opt_pass(expected(), transform.InferType())
+    assert tvm.ir.structural_equal(bblock, expected_output, map_free_vars=True)
+
+def test_nested_if():
+    x = relay.var('x', shape=(), dtype='bool')
+    y = relay.var('y', shape=(), dtype='float32')
+    cond_t = relay.const(True)
+    cond_f = relay.const(False)
+    one = relay.const(1, dtype='float32')
+    two = relay.const(2, dtype='float32')
+    three = relay.const(3, dtype='float32')
+    y2 = relay.add(y, y)
+    true_branch = relay.If(cond_t, y2, relay.add(three, y2))
+    false_branch = relay.If(cond_f, two, one)
+    body = relay.If(x, true_branch, false_branch)
+    """
+    free_var %x: bool
+    if (%x) {
+      if (True) {
+        free_var %y: float32
+        %0 = add(%y, %y);
+        %0
+      } else {
+        add(3f, %0)
+      }
+    } else {
+      if (False) {
+        2f
+      } else {
+        1f
+      }
+    }
+    """
+    def expected():
+        x = relay.var('x', shape=(), dtype='bool')
+        y = relay.var('y', shape=(), dtype='float32')
+        cond_t = relay.const(True)
+        cond_f = relay.const(False)
+        one = relay.const(1, dtype='float32')
+        two = relay.const(2, dtype='float32')
+        three = relay.const(3, dtype='float32')
+        y2 = relay.var('y2')
+        true_branch = relay.If(cond_t, y2, relay.add(three, y2))
+        true_branch = relay.Let(y2, relay.add(y, y), true_branch)
+        false_branch = relay.If(cond_f, two, one)
+        body = relay.If(x, true_branch, false_branch)
+        return body
+
+    bblock = run_opt_pass(body, [transform.ToBasicBlockNormalForm()])
+    """
+    free_var %x: bool
+    if (%x) {
+      free_var %y: float32
+      let %x1: float32 = add(%y, %y) /* ty=float32 */;
+      if (True /* ty=bool */) {
+        %x1
+      } else {
+        add(3f /* ty=float32 */, %x1) /* ty=float32 */
+      }
+    } else {
+      if (False /* ty=bool */) {
+        2f /* ty=float32 */
+      } else {
+        1f /* ty=float32 */
+      }
+    }
+    """
+    expected_output = run_opt_pass(expected(), transform.InferType())
+    assert tvm.ir.structural_equal(bblock, expected_output, map_free_vars=True)
+    check_basic_block_normal_form(bblock)
+
+
+# make sure we do not infinite loop.
+# it is too large so we won't check for the exact program.
+def test_recursion():
+    """
+    Program:
+       let f(n: i32) -> i32 = {
+          m = (n * 2)
+          if (n == 0) {
+              return m;
+          } else {
+              return m + f(n - 1);
+          }
+       }
+       f(5);
+    """
+    mod = tvm.IRModule()
+    i64 = relay.TensorType((), 'int64')
+    f = relay.GlobalVar("f")
+    n = relay.Var("n", i64)
+    m = n * relay.const(2, 'int64')
+    cond = relay.equal(n, relay.const(0, 'int64'))
+    false_branch = m + f(n - relay.const(1, 'int64'))
+    funcbody = relay.If(cond, m, false_branch)
+    value = relay.Function([n], funcbody, i64, [])
+    mod[f] = value
+    check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
+    old_f = mod[f]
+    mod = transform.ToBasicBlockNormalForm()(mod)
+    f = mod[f]
+    check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
+    check_basic_block_normal_form(f)
+
+def test_ref():
+    i = relay.Var('i')
+    iv = relay.Var('iv')
+    u = relay.Var('u')
+    uv = relay.Var('uv')
+    body = relay.add(iv, uv)
+    body = relay.Let(uv, relay.RefRead(i), body)
+    body = relay.Let(u, relay.RefWrite(i, relay.const(2)), body)
+    body = relay.Let(iv, relay.RefRead(i), body)
+    body = relay.Let(i, relay.RefCreate(relay.const(1)), body)
+    check_eval(body, 3)
+    opt_body = run_opt_pass(body, transform.ToBasicBlockNormalForm())
+    check_eval(opt_body, 3)
+    check_basic_block_normal_form(opt_body)
+
+
+def test_nat_add():
+    mod = tvm.IRModule()
+    p = Prelude(mod)
+    add_nat_definitions(p)
+    nat = p.nat
+    add = p.add
+    s = p.s
+    z = p.z
+    ctx = tvm.context("llvm", 0)
+    intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
+    assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat())
+    assert count(p, intrp.evaluate(add(s(z()), s(z())))) == 2
+    expr = add(s(z()), s(z()))
+    f = relay.GlobalVar("f")
+    mod[f] = relay.Function([], expr)
+    mod = transform.ToBasicBlockNormalForm()(mod)
+    opt_expr = mod["f"]
+    assert count(p, intrp.evaluate(opt_expr.body)) == 2
+    assert not Feature.fLet in detect_feature(mod[add])
+    check_basic_block_normal_form(opt_expr)
+
+def test_let():
+    def test_let1():
+        x = relay.Var("x")
+        c = relay.const(4.0, 'float32')
+        body = relay.Let(x, c, x)
+        body = run_opt_pass(body, transform.InferType())
+        """
+        let %x: float32 = 4f /* ty=float32 */;
+        %x
+        """
+        opt_body = run_opt_pass(body, transform.ToBasicBlockNormalForm())
+        assert tvm.ir.structural_equal(body, opt_body)
+        check_basic_block_normal_form(opt_body)
+        
+    def test_let1_1():
+        x = relay.Var("y")
+        d = relay.const(4.0, 'float32')
+        body = relay.Let(x, d, relay.add(x,x))
+        body = run_opt_pass(body, transform.InferType())
+        opt_body = run_opt_pass(body, transform.ToBasicBlockNormalForm())
+        assert tvm.ir.structural_equal(body, opt_body)
+        check_basic_block_normal_form(opt_body)
+    
+    def test_let2():
+        x = relay.Var("x")
+        y = relay.Var("y")
+        d = relay.const(4.0, 'float32')
+        body = relay.Let(y, x, x)
+        body = relay.Let(x, d, body)
+        body = run_opt_pass(body, transform.InferType())
+        check_eval(body, 4)
+
+        def expected():
+            x = relay.Var("x")
+            y = relay.Var("y")
+            d = relay.const(4.0, 'float32')
+            body = relay.Let(y, x, y)
+            body = relay.Let(x, d, body)
+            return body
+
+        opt_body = run_opt_pass(body, transform.ToBasicBlockNormalForm())
+        expected_body = run_opt_pass(expected(), transform.InferType())
+        assert tvm.ir.structural_equal(opt_body, expected_body)
+        check_basic_block_normal_form(opt_body)
+
+    def test_let3():
+        x = relay.Var("x")
+        y = relay.Var("y")
+        z = relay.Var("z")
+        c = relay.const(3.0, 'float32')
+        d = relay.const(4.0, 'float32')
+        body = relay.Let(z, x + y, x + z)
+        body = relay.Let(x, d, body)
+        body = relay.Let(y, c, body)
+        body = run_opt_pass(body, transform.InferType())
+        opt_body = run_opt_pass(body, transform.ToBasicBlockNormalForm())
+        assert tvm.ir.structural_equal(body, opt_body)
+        check_basic_block_normal_form(opt_body)
+
+    test_let1()
+    test_let1_1()
+    test_let2()
+    test_let3()
+
+def test_function():
+    t = relay.TensorType((), 'float32')
+    x = relay.Var("x", t)
+    f = relay.Function([x], x + x)
+    d = relay.const(4.0, 'float32')
+    bblock = run_opt_pass(f, transform.ToBasicBlockNormalForm())
+    assert isinstance(bblock, relay.Function)
+    check_eval(f(d), 8)
+    check_eval(bblock(d), 8)
+    check_basic_block_normal_form(bblock)
+
+def test_gradient_if():
+    x = relay.var("a", shape=(1, 16))
+    y = relay.var("y", shape=(1, 16))
+    cond = relay.var("cond", shape=(), dtype='uint1')
+    net = relay.If(cond, x, x)
+    net = relay.add(x, net)
+    net = relay.Function([cond,x,y], net)
+    mod = tvm.IRModule.from_expr(net)
+    mod = relay.transform.ToBasicBlockNormalForm()(mod)
+    net_grad = relay.transform.gradient(mod["main"], mode='higher_order')
+    mod["main"] = net_grad
+    mod_grad = relay.transform.ToBasicBlockNormalForm()(mod)
+    check_basic_block_normal_form(mod_grad['main'])
+    check_basic_block_normal_form(mod['main'])
+
+def test_if():
+    def if_expr(x):
+        """
+        free_var %x: float32
+        %0 = equal(%x, 2f);
+        if (%0) {
+          %1 = add(%x, 1f);
+          multiply(%1, 2f)
+        } else {
+          multiply(%1, 1f)
+        }
+        """
+        one = relay.const(1, dtype='float32')
+        two = relay.const(2, dtype='float32')
+        v1 = relay.add(x, one)
+        v2 = relay.equal(x, two)
+        true_branch = relay.multiply(v1, two)
+        false_branch = relay.multiply(v1, one)
+        body = relay.If(v2, true_branch, false_branch)
+        return body
+
+    def expected_if_expr(x):
+        """
+        free_var %x: float32
+        let %v1: float32 = add(%x, 1f /* ty=float32 */) /* ty=float32 */;
+        %0 = equal(%x, 2f /* ty=float32 */) /* ty=bool */;
+        if (%0) {
+          multiply(%v1, 2f /* ty=float32 */) /* ty=float32 */
+        } else {
+          multiply(%v1, 1f /* ty=float32 */) /* ty=float32 */
+        }
+        """
+        one = relay.const(1, dtype='float32')
+        two = relay.const(2, dtype='float32')
+        v1 = relay.var('v1')
+        v2 = relay.equal(x, two)
+        true_branch = relay.multiply(v1, two)
+        false_branch = relay.multiply(v1, one)
+        body = relay.If(v2, true_branch, false_branch)
+        body = relay.Let(v1, relay.add(x, one), body)
+        return body
+
+    x = relay.var('x', shape=(), dtype='float32')
+    body = if_expr(x)
+    expected_body = expected_if_expr(x)
+    bblock = run_opt_pass(body, transform.ToBasicBlockNormalForm())
+    expected_bblock = run_opt_pass(expected_body, transform.InferType())
+    assert tvm.ir.structural_equal(bblock, expected_bblock, map_free_vars=True)
+    check_basic_block_normal_form(bblock)
+
+    func = relay.Function([x], body)
+    expected_func = relay.Function([x], expected_body)
+    bblock = run_opt_pass(func, transform.ToBasicBlockNormalForm())
+    expected_bblock = run_opt_pass(expected_func, transform.InferType())
+    assert tvm.ir.structural_equal(bblock, expected_bblock)
+    check_basic_block_normal_form(bblock)
+
+def test_higher_order_return():
+    x = relay.var('x', shape=(1,), dtype='float32')
+    y = relay.var('y', shape=(1,), dtype='float32')
+    z = relay.var('z', shape=(1,), dtype='float32')
+    x2 = relay.add(x, x)
+    func_a = relay.Function([y], relay.add(x2, y))
+    func_b = relay.Function([z], relay.add(x2, z))
+    body = relay.Tuple([func_a, func_b])
+    body = relay.Function([x], body)
+    """
+    fn (%x: Tensor[(1), float32]) {
+      %1 = fn (%y: Tensor[(1), float32]) {
+        %0 = add(%x, %x);
+        add(%0, %y)
+      };
+      %2 = fn (%z: Tensor[(1), float32]) {
+        add(%0, %z)
+      };
+      (%1, %2)
+    }
+    """
+
+    bblock = run_opt_pass(body, transform.ToBasicBlockNormalForm())
+    check_basic_block_normal_form(bblock)
+
+
+def test_higher_order_nested():
+    x = relay.var('x', dtype='float32', shape=(1,))
+    s = relay.var('s', dtype='float32', shape=(1,))
+    shared = relay.add(s, s)
+    func_true = relay.Function([x], relay.add(x, shared))
+    choice_t = relay.FuncType([], relay.scalar_type('bool'))
+    f = relay.Var('f', choice_t)
+    z = relay.Var('z')
+    body = relay.If(f(), func_true, relay.Function([z], relay.add(z, shared)))
+    top = relay.Function([f, s], body)
+    """
+    fn (%f: fn () -> bool, %s: Tensor[(1), float32]) {
+      %0 = %f();
+      if (%0) {
+        fn (%x: Tensor[(1), float32]) {
+          %1 = add(%s, %s);
+          add(%x, %1)
+        }
+      } else {
+        fn (%z) {
+          add(%z, %1)
+        }
+      }
+    }
+    """
+
+    bblock = run_opt_pass(top, transform.ToBasicBlockNormalForm())
+    check_basic_block_normal_form(bblock)
+
+if __name__ == '__main__':
+    pytest.main([__file__])
