This is an automated email from the ASF dual-hosted git repository.

csullivan pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new dfc77eb129 [Unity] Extend RemoveAllUnused to support relax::Expr 
(#15807)
dfc77eb129 is described below

commit dfc77eb1299b2120467f5fcb8f5ee06d370d3ff5
Author: Eric Lunderberg <[email protected]>
AuthorDate: Tue Sep 26 18:09:33 2023 -0500

    [Unity] Extend RemoveAllUnused to support relax::Expr (#15807)
    
    * [Unity] Extend RemoveAllUnused to support relax::Expr
    
    Prior to this commit, the `RemoveAllUnused` utility could only
    apply to `relax::Function` instances.  This commit extends the allowed
    usage to apply to any `relax::Expr` that contains variable bindings,
    such as `relax::SeqExpr`.
    
    * Update unit tests for new behavior
    
    Because `RemoveAllUnused` now handles unused variables in a
    non-dataflow binding block, passes that use it as a utility function
    needed a few tests updated.
---
 include/tvm/relax/analysis.h                       |  19 +++-
 src/relax/analysis/udchain.cc                      |   7 +-
 src/relax/ir/binding_rewrite.cc                    |  51 ++++-----
 src/relax/ir/dataflow_matcher.cc                   |   2 +-
 src/relax/transform/dead_code_elimination.cc       |   2 +-
 src/relax/transform/fold_constant.cc               |   2 +-
 src/relax/transform/gradient.cc                    |   2 +-
 tests/python/relax/test_analysis.py                | 126 ++++++++++++++++++---
 tests/python/relax/test_transform_fold_constant.py |  15 +--
 tests/python/relax/test_tuning_api.py              |   1 -
 10 files changed, 157 insertions(+), 70 deletions(-)

diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h
index f515ba6201..82fb73b1bd 100644
--- a/include/tvm/relax/analysis.h
+++ b/include/tvm/relax/analysis.h
@@ -399,18 +399,25 @@ TVM_DLL Map<Var, Array<Var>> DataflowBlockUseDef(const 
DataflowBlock& dfb);
 /*!
  * \brief Get the use-def chain of variables inside a function.
  *
- * \param fn The function to be analyzed.
- * \return A map from variable definitions to a set of uses and variables 
needed by return value.
+ * \param expr The expression to be analyzed.
+ *
+ * \return A tuple of variable usage and variable outputs.  The first
+ * element is a map from variable definitions to the set of downstream
+ * users of that definition.  The second element is a list of
+ * variables whose usage occurs outside of any variable binding,
+ * typically the output body of a relax::Function or a relax::SeqExpr.
  */
-std::pair<Map<Var, Array<Var>>, Array<Var>> FunctionUseDef(const Function& fn);
+std::pair<Map<Var, Array<Var>>, Array<Var>> FunctionUseDef(const Expr& expr);
 
 /*!
  * \brief Remove unused statements inside DataflowBlocks.
  *
- * \param fn The function to remove unused statements.
- * \return The function that contains no unused statements in DataflowBlock.
+ * \param expr The expression (typically a relax::Function) from which
+ * to remove unused statements.
+ *
+ * \return The updated expression with no unused statements in DataflowBlock.
  */
-TVM_DLL Function RemoveAllUnused(const Function fn);
+TVM_DLL Expr RemoveAllUnused(Expr expr);
 
 /*!
  * \brief Annotate Op Pattern Kind for PrimFunc, which is used in relax 
FuseOps.
diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc
index 54c7307b96..fc0ffd3292 100644
--- a/src/relax/analysis/udchain.cc
+++ b/src/relax/analysis/udchain.cc
@@ -40,14 +40,15 @@ class UDChain : public relax::ExprVisitor {
   // nullptr users means it is the output of the function.
   std::map<const VarNode*, std::set<const VarNode*>> to_users;
 
-  const VarNode* cur_user_;
+  const VarNode* cur_user_{nullptr};
 
   void VisitBinding_(const VarBindingNode* binding) override {
     // init
+    auto cache = cur_user_;
     cur_user_ = binding->var.get();
     this->VisitVarDef(binding->var);
     this->VisitExpr(binding->value);
-    cur_user_ = nullptr;
+    cur_user_ = cache;
   }
 
   void VisitExpr_(const VarNode* op) override { 
to_users[op].insert(cur_user_); }
@@ -59,7 +60,7 @@ class UDChain : public relax::ExprVisitor {
 };
 
 std::pair<runtime::Map<Var, runtime::Array<Var>>, runtime::Array<Var>> 
FunctionUseDef(
-    const Function& fn) {
+    const Expr& fn) {
   UDChain udchain;
   udchain.VisitExpr(fn);
 
diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc
index ae48b6bd69..f5820eca80 100644
--- a/src/relax/ir/binding_rewrite.cc
+++ b/src/relax/ir/binding_rewrite.cc
@@ -240,38 +240,31 @@ class RemoveUnusedVars : public ExprMutator {
   RemoveUnusedVars(Map<Var, Array<Var>> users, Array<Var> fn_outputs)
       : RemoveUnusedVars(GetUnusedVars(users, fn_outputs)) {}
 
-  BindingBlock VisitBindingBlock_(const BindingBlockNode* block) override {
-    builder_->BeginBindingBlock();
-    for (Binding binding : block->bindings) {
-      bool can_remove = [&]() -> bool {
-        if (!unused_vars.count(binding->var)) {
-          return false;
-        }
-        auto var_binding = binding.as<VarBindingNode>();
-        if (!var_binding) {
-          return false;
-        }
-        return var_binding->value->IsInstance<FunctionNode>();
-      }();
-      if (!can_remove) {
-        VisitBinding(binding);
-      }
+  void VisitBinding_(const VarBindingNode* binding) override {
+    bool can_remove = unused_vars.count(binding->var) &&
+                      (in_dataflow_block_ || 
!ContainsImpureCall(binding->value));
+    if (!can_remove) {
+      ExprMutator::VisitBinding_(binding);
     }
-    return builder_->EndBlock();
   }
 
   BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) override {
-    auto prev_dfb = GetRef<DataflowBlock>(block);
-    builder_->BeginDataflowBlock();
-    for (Binding binding : block->bindings) {
-      if (!unused_vars.count(binding->var) || binding.as<MatchCastNode>()) {
-        VisitBinding(binding);
-      }
+    bool capture_output = (block == caught_rewrite.get());
+
+    bool cache = in_dataflow_block_;
+    in_dataflow_block_ = true;
+    BindingBlock output = ExprMutator::VisitBindingBlock_(block);
+    in_dataflow_block_ = cache;
+
+    if (capture_output) {
+      caught_rewrite = Downcast<DataflowBlock>(output);
     }
-    auto new_dfb = builder_->EndBlock();
-    if (caught_rewrite == prev_dfb) caught_rewrite = 
Downcast<DataflowBlock>(new_dfb);
-    return std::move(new_dfb);
+
+    return std::move(output);
   }
+
+ private:
+  bool in_dataflow_block_{false};
 };
 
 void DataflowBlockRewriteNode::RemoveUnused(Var unused, bool allow_undef) {
@@ -322,10 +315,10 @@ void DataflowBlockRewriteNode::RemoveAllUnused() {
 TVM_REGISTER_GLOBAL("relax.dfb_rewrite_remove_all_unused")
     .set_body_typed([](DataflowBlockRewrite rwt) { rwt->RemoveAllUnused(); });
 
-Function RemoveAllUnused(Function fn) {
-  auto [users, outputs] = FunctionUseDef(fn);
+Expr RemoveAllUnused(Expr expr) {
+  auto [users, outputs] = FunctionUseDef(expr);
   RemoveUnusedVars remover(users, outputs);
-  return Downcast<Function>(remover.VisitExpr_(fn.get()));
+  return remover.VisitExpr(std::move(expr));
 }
 
 
TVM_REGISTER_GLOBAL("relax.analysis.remove_all_unused").set_body_typed(RemoveAllUnused);
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index 3b17f7bd1d..e85e3c4d51 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -934,7 +934,7 @@ class PatternRewriter : ExprMutator {
       params.insert(p.get());
     }
     PatternRewriter rewriter(pat, rewriter_func, params);
-    return RemoveAllUnused(Downcast<Function>(rewriter.VisitExpr(f)));
+    return Downcast<Function>(RemoveAllUnused(rewriter.VisitExpr(f)));
   }
 
   void VisitBinding_(const VarBindingNode* binding) final {
diff --git a/src/relax/transform/dead_code_elimination.cc 
b/src/relax/transform/dead_code_elimination.cc
index 494665ec71..6d9f25296a 100644
--- a/src/relax/transform/dead_code_elimination.cc
+++ b/src/relax/transform/dead_code_elimination.cc
@@ -137,7 +137,7 @@ IRModule DeadCodeElimination(const IRModule& arg_mod, 
Array<runtime::String> ent
     IRModule updates;
     for (const auto& [gvar, base_func] : mod->functions) {
       if (auto opt = base_func.as<Function>()) {
-        auto new_func = RemoveAllUnused(opt.value());
+        auto new_func = Downcast<Function>(RemoveAllUnused(opt.value()));
         if (!new_func.same_as(base_func)) {
           updates->Add(gvar, new_func);
         }
diff --git a/src/relax/transform/fold_constant.cc 
b/src/relax/transform/fold_constant.cc
index 8a78c98144..d6da79c484 100644
--- a/src/relax/transform/fold_constant.cc
+++ b/src/relax/transform/fold_constant.cc
@@ -34,7 +34,7 @@ class ConstantFolder : public ExprMutator {
  public:
   static Function Fold(Function func, IRModule ctx_module) {
     ConstantFolder folder(std::move(ctx_module));
-    func = RemoveAllUnused(Downcast<Function>(folder(func)));
+    func = Downcast<Function>(RemoveAllUnused(folder(func)));
     return func;
   }
 
diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc
index e09a6d6232..8560db6a04 100644
--- a/src/relax/transform/gradient.cc
+++ b/src/relax/transform/gradient.cc
@@ -632,7 +632,7 @@ class GradientMutator : private ExprMutator {
     new_func = CallTIRWithGradEliminator::Transform(new_func);
 
     if (remove_all_unused) {
-      new_func = RemoveAllUnused(new_func);
+      new_func = Downcast<Function>(RemoveAllUnused(new_func));
     }
 
     // Step 5.3 mark the transformed function as public
diff --git a/tests/python/relax/test_analysis.py 
b/tests/python/relax/test_analysis.py
index d5545a0a56..dfee026206 100644
--- a/tests/python/relax/test_analysis.py
+++ b/tests/python/relax/test_analysis.py
@@ -88,6 +88,13 @@ def test_chained_remove_all_unused():
 
 
 def test_binding_block_remove_all_unused():
+    """Remove unused dataflow bindings
+
+    Removal of unused bindings may not remove side effects.  Since
+    bindings within a dataflow block are guaranteed not to have side
+    effects, they may be removed if unused.
+    """
+
     @tvm.script.ir_module
     class IdentityUnused:
         @R.function
@@ -117,24 +124,49 @@ def test_binding_block_remove_all_unused():
     tvm.ir.assert_structural_equal(optimized, GroundTruth["main"])
 
 
-def test_binding_block_remove_all_unused_without_dataflow():
-    @tvm.script.ir_module
-    class IdentityUnused:
-        @R.function
-        def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
-            lv0 = x
-            unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), 
dtype="float32"))
-            unused1 = R.call_dps_packed(
-                "my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32")
-            )
-            z = R.call_packed("vm.builtin.copy", lv0, 
sinfo_args=(R.Tensor((32, 32), "float32")))
-            return z
+def test_binding_block_remove_unused_pure_without_dataflow():
+    """Remove unused pure bindings outside a dataflow block
 
-    optimized = remove_all_unused(IdentityUnused["main"])
+    Removal of unused bindings may not remove side effects.  Unused
+    bindings whose value is a pure operation
+    (e.g. `R.call_dps_packed`) may be removed, even if outside of a
+    dataflow block.
+    """
 
-    GroundTruth = IdentityUnused
+    @R.function(private=True)
+    def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+        lv0 = x
+        unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), 
dtype="float32"))
+        unused1 = R.call_dps_packed("my_dps_func", (unused0,), R.Tensor((32, 
32), dtype="float32"))
+        return x
 
-    tvm.ir.assert_structural_equal(optimized, GroundTruth["main"])
+    @R.function(private=True)
+    def expected(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+        return x
+
+    after = remove_all_unused(before)
+    tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_binding_block_keep_impure_without_dataflow():
+    """Keep unused impure bindings outside a dataflow block
+
+    Removal of unused bindings may not remove side effects.  Unused
+    bindings whose value is an impure operation (e.g. `R.call_packed`)
+    may not be removed, as outside of a dataflow block they may
+    contain side effects.
+    """
+
+    @R.function(private=True)
+    def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+        lv0 = x
+        y = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 
32), "float32")))
+        return y
+
+    expected = before
+
+    after = remove_all_unused(before)
+    tvm.ir.assert_structural_equal(expected, after)
 
 
 def test_binding_block_remove_all_unused_func_without_dataflow():
@@ -226,6 +258,70 @@ def 
test_edge_binding_block_fake_unused_remove_all_unused2():
     tvm.ir.assert_structural_equal(optimized, IdentityUnused["main"])
 
 
+def test_remove_all_unused_from_dataflow_block():
+    """Like test_chained_remove_all_unused, but on a SeqExpr"""
+
+    @R.function
+    def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+        with R.dataflow():
+            lv0 = x
+            unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), 
dtype="float32"))
+            unused1 = R.call_dps_packed(
+                "my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32")
+            )
+            R.output(lv0)
+        return lv0
+
+    @R.function
+    def expected(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+        with R.dataflow():
+            lv0 = x
+            R.output(lv0)
+        return lv0
+
+    after = remove_all_unused(before.body)
+    tvm.ir.assert_structural_equal(expected.body, after, map_free_vars=True)
+
+
+def test_remove_all_unused_from_binding_block():
+    """Like test_chained_remove_all_unused, but on a SeqExpr without a dataflow block"""
+
+    @R.function
+    def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+        lv0 = x
+        unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), 
dtype="float32"))
+        unused1 = R.call_dps_packed("my_dps_func", (unused0,), R.Tensor((32, 
32), dtype="float32"))
+        return lv0
+
+    @R.function
+    def expected(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+        lv0 = x
+        return lv0
+
+    after = remove_all_unused(before.body)
+    tvm.ir.assert_structural_equal(expected.body, after, map_free_vars=True)
+
+
+def test_retain_impure_calls_unused_in_binding_block():
+    """An impure call may have side effects, and must be kept"""
+
+    @R.function
+    def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+        lv0 = x
+        unused0 = R.call_packed("my_impure_call", x, sinfo_args=R.Tensor((32, 
32), dtype="float32"))
+        unused1 = R.call_dps_packed("my_unused_call", (lv0,), R.Tensor((32, 
32), dtype="float32"))
+        return lv0
+
+    @R.function
+    def expected(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+        lv0 = x
+        unused0 = R.call_packed("my_impure_call", x, sinfo_args=R.Tensor((32, 
32), dtype="float32"))
+        return lv0
+
+    after = remove_all_unused(before.body)
+    tvm.ir.assert_structural_equal(expected.body, after, map_free_vars=True)
+
+
 def test_name_to_binding_var_shadowing():
     @R.function
     def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
diff --git a/tests/python/relax/test_transform_fold_constant.py 
b/tests/python/relax/test_transform_fold_constant.py
index c2a3bd5092..a4dffba114 100644
--- a/tests/python/relax/test_transform_fold_constant.py
+++ b/tests/python/relax/test_transform_fold_constant.py
@@ -73,7 +73,6 @@ def test_one_fold_addone():
 
         @R.function
         def expected(c1: R.Tensor((16, 16), "float32")):
-            lv0 = c1
             return c1
 
     c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16)
@@ -104,7 +103,6 @@ def test_one_fold_transpose():
 
         @R.function
         def expected(c1: R.Tensor((3, 2), "float32")):
-            lv0 = c1
             return c1
 
     c0_np = np.arange(2 * 3).astype("float32").reshape(2, 3)
@@ -135,8 +133,6 @@ def test_two_hop_addone():
 
         @R.function
         def expected(c1: R.Tensor((2, 2), "float32"), c2: R.Tensor((2, 2), 
"float32")):
-            lv0 = c1
-            lv1 = c2
             return c2
 
     c0_np = np.arange((2 * 2)).astype("float32").reshape(2, 2)
@@ -218,7 +214,7 @@ def test_fold_mixed_case():
             lv2 = relax.call_tir(cls.sub, (c0, lv1), R.Tensor((16, 16), 
dtype="float32"))
             # this line can not be folded because x's shape is unknown
             lv3 = relax.call_tir(cls.sub, (lv2, x), R.Tensor((16, 16), 
dtype="float32"))
-            return lv3
+            return (lv0, lv3)
 
         @R.function
         def expected(
@@ -226,19 +222,15 @@ def test_fold_mixed_case():
             c1: R.Tensor((16, 16), "float32"),
             c2: R.Tensor((16, 16), "float32"),
             x: R.Tensor("float32", ndim=2),
-        ) -> R.Tensor:
+        ):
             n, m = T.int64(), T.int64()
             cls = Module
             x0 = R.match_cast(x, R.Tensor((n, m), "float32"))
             # this line cannot be folded because n is unknown
             lv0 = relax.call_tir(cls.addone, (c0,), R.Tensor((n, 16), 
dtype="float32"))
-            # this line can be folded
-            lv1 = c1
-            # this line can be folded because all inputs are const
-            lv2 = c2
             # this line can not be folded because x's shape is unknown
             lv3 = relax.call_tir(cls.sub, (c2, x), R.Tensor((16, 16), 
dtype="float32"))
-            return lv3
+            return (lv0, lv3)
 
     c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16)
     c1_np = c0_np + 1
@@ -268,7 +260,6 @@ def test_int32_fold():
 
         @R.function
         def expected(c1: R.Tensor((16, 16), "int32")):
-            lv0 = c1
             return c1
 
     c0_np = np.arange((16 * 16)).astype("int32").reshape(16, 16)
diff --git a/tests/python/relax/test_tuning_api.py 
b/tests/python/relax/test_tuning_api.py
index 5c2f165dc3..082c9ce16a 100644
--- a/tests/python/relax/test_tuning_api.py
+++ b/tests/python/relax/test_tuning_api.py
@@ -64,7 +64,6 @@ class TestModule:
     # Expected IRModule after transformation.
     @R.function
     def expected(c1: R.Tensor((16, 16), "int32")):
-        lv0 = c1
         return c1
 
 

Reply via email to