This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 873ab3d7a9 [Unity][Pass] Reuse prior infra to implement more complete DCE (#14334)
873ab3d7a9 is described below

commit 873ab3d7a9cffce7046046be67cdb2c4ae0b6aba
Author: Jiawei Liu <[email protected]>
AuthorDate: Sun Mar 19 21:28:54 2023 -0500

    [Unity][Pass] Reuse prior infra to implement more complete DCE (#14334)
    
    As a follow-up to https://github.com/apache/tvm/pull/14262: I noticed that I
    had previously implemented a function called `RemoveAllUnused`
    (https://github.com/apache/tvm/pull/14043) that performs function-wise DCE. It
    is more complete than the approach in https://github.com/apache/tvm/pull/14262
    because `RemoveAllUnused` can also remove dead dataflow blocks (two test cases
    are added to demonstrate this).
    
    For now:
    - `tvm.relax.transform.DeadCodeElimination` is the pass for running DCE over
      an IRModule.
    - `tvm.relax.analysis.remove_all_unused` is a function for running DCE over a
      single function.
    
    `tvm.relax.transform.DeadCodeElimination` is now implemented on top of
    `tvm.relax.analysis.remove_all_unused`. I did not unify the names of the two,
    because `RemoveAllUnused` has other callers and I wanted to be conservative
    for now.
---
 include/tvm/relax/transform.h                      |   9 +-
 python/tvm/relax/analysis/analysis.py              |   8 +-
 python/tvm/relax/transform/transform.py            |  14 +--
 src/relax/transform/dead_code_elimination.cc       |  79 +++-----------
 .../relax/test_transform_dead_code_elimination.py  | 116 +++++++++++++++++++++
 5 files changed, 146 insertions(+), 80 deletions(-)

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index eba7de1b0c..5a21f76b0b 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -387,14 +387,11 @@ TVM_DLL Pass ConvertLayout(Map<String, Array<String>> desired_layouts);
 
 /*!
  * \brief Dead code elimination.
  * Currently it removes:
  *   1. Unused local VarBindings in a DataflowBlock.
- *      The used var set is set to empty at the beginning of each DataflowBlock.
- *      We reverse scan the DataflowBlock, if a VarBinding
- *        - bindings to a dataflowvar, or
- *        - is used in the used var set
- *      We keep it and add its var to the used var set. Otherwise, we remove it.
- *   2. Unused Relax functions in the module.
+ *   2. Unused DataflowBlocks in a function.
+ *   3. Unused Relax functions in the module.
  *      We detect the call chain from the entry function, and remove all unused functions.
+ * \sa RemoveAllUnused
  * \return The Pass.
  */
diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py
index ae64c08eb1..2a2c3d87b8 100644
--- a/python/tvm/relax/analysis/analysis.py
+++ b/python/tvm/relax/analysis/analysis.py
@@ -316,13 +316,19 @@ def name_to_binding(func: Function) -> Dict[str, List[Binding]]:
 
 
 def remove_all_unused(func: Function) -> Function:
-    """Remove all unused variables from the function.
+    """Remove dead code from the function. It removes:
+    1. Unused local VarBindings in a DataflowBlock.
+    2. Unused DataflowBlocks in a function.
 
     Parameters
     ----------
     func : Function
         The input function to be analyzed.
 
+    Notes
+    -----
+    For IRModule-wise DCE, use :py:func:`tvm.relax.transform.DeadCodeElimination`.
+
     Returns
     -------
     Function
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index ebfd7a6765..b710196347 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -658,15 +658,11 @@ def ConvertLayout(desired_layouts: Dict[str, List[str]]) -> tvm.ir.transform.Pas
 
 
 def DeadCodeElimination(entry_functions: Optional[List[str]] = None) -> tvm.ir.transform.Pass:
-    """Remove dead code in the program.
+    """Remove dead code in the IRModule.
       Currently it removes:
       1. Unused local VarBindings in a DataflowBlock.
-          The used var set is set to empty at the beginning of each DataflowBlock.
-          We reverse scan the DataflowBlock, if a VarBinding
-            - bindings to a dataflowvar, or
-            - is used in the used var set
-          We keep it and add its var to the used var set. Otherwise, we remove it.
-        2. Unused Relax functions in the module.
+       2. Unused DataflowBlocks in a function.
+       3. Unused Relax functions in the module.
           We detect the call chain from the entry function, and remove all unused functions.
 
     Parameters
@@ -674,6 +670,10 @@ def DeadCodeElimination(entry_functions: Optional[List[str]] = None) -> tvm.ir.t
     entry_functions: Optional[List[str]]
         The set of entry functions to start from.
 
+    Notes
+    -----
+    For function-wise DCE, use :py:func:`tvm.relax.analysis.remove_all_unused`.
+
     Returns
     -------
     ret : tvm.transform.Pass
diff --git a/src/relax/transform/dead_code_elimination.cc b/src/relax/transform/dead_code_elimination.cc
index 3008db4b61..fe36eb28ef 100644
--- a/src/relax/transform/dead_code_elimination.cc
+++ b/src/relax/transform/dead_code_elimination.cc
@@ -21,17 +21,16 @@
 /*!
  * \file tvm/relax/transform/dead_code_elimination.cc
  * \brief Dead code elimination pass.
+ * \sa tvm/relax/ir/binding_rewrite.cc
+ *
  * Currently it removes:
  *   1. Unused local VarBindings in a DataflowBlock.
- *      The used var set is set to empty at the beginning of each DataflowBlock.
- *      We reverse scan the DataflowBlock, if a VarBinding
- *        - bindings to a dataflowvar, or
- *        - is used in the used var set
- *      We keep it and add its var to the used var set. Otherwise, we remove it.
- *   2. Unused Relax functions in the module.
+ *   2. Unused DataflowBlocks in a function.
+ *   3. Unused Relax functions in the module.
  *      We detect the call chain from the entry function, and remove all unused functions.
  */
 
+#include <tvm/relax/analysis.h>
 #include <tvm/relax/expr.h>
 #include <tvm/relax/expr_functor.h>
 #include <tvm/relax/transform.h>
@@ -104,70 +103,18 @@ IRModule RemoveUnusedFunctions(IRModule mod_, Array<runtime::String> entry_funcs
   return mod_;
 }
 
-class DeadCodeEliminator : public ExprMutator {
- private:
-  Expr VisitExpr_(const VarNode* op) final {
-    ICHECK(!used_vars_.empty());
-    used_vars_.back().insert(GetRef<Var>(op));
-    return GetRef<Expr>(op);
-  }
-
-  Expr VisitExpr_(const DataflowVarNode* op) final {
-    ICHECK(!used_vars_.empty());
-    used_vars_.back().insert(GetRef<Var>(op));
-    return GetRef<Expr>(op);
-  }
-
-  void VisitBinding_(const VarBindingNode* binding) { this->VisitExpr(binding->value); }
-
-  void VisitBinding_(const MatchCastNode* binding) {
-    this->VisitExpr(binding->value);
-    this->VisitAndCheckStructInfoFieldUnchanged(binding->struct_info);
-  }
-
-  BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final {
-    // reverse scan the data flow block to find the used vars
-    used_vars_.push_back({});
-
-    std::vector<Binding> new_bindings;
-    for (auto rit = block->bindings.rbegin(); rit != block->bindings.rend(); rit++) {
-      const Var& var = (*rit)->var;
-      // only keep the used bindings or non dataflow var bindings
-      if (used_vars_.back().count(var) || !var->IsInstance<DataflowVarNode>()) {
-        new_bindings.push_back(*rit);
-        // collect the used vars
-        this->VisitBinding((*rit));
-      }
-    }
-
-    used_vars_.pop_back();
-    // reverse the bindings
-    std::reverse(new_bindings.begin(), new_bindings.end());
-    if (new_bindings.size() == block->bindings.size()) {
-      return GetRef<BindingBlock>(block);
-    } else {
-      auto n = make_object<DataflowBlockNode>(*block);
-      n->bindings = std::move(new_bindings);
-      return BindingBlock(n);
-    }
-  }
-
-  BindingBlock VisitBindingBlock_(const BindingBlockNode* block) final {
-    return GetRef<BindingBlock>(block);
-  }
-
-  std::vector<std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>> used_vars_{{}};
-};
-
 IRModule DeadCodeElimination(const IRModule& mod, Array<runtime::String> entry_functions) {
-  DeadCodeEliminator eliminator;
-  for (const auto& gv : mod->GetGlobalVars()) {
-    auto func = mod->Lookup(gv);
+  // S1: remove unused functions to reduce the number of functions to be analyzed.
+  IRModule tmp_mod = RemoveUnusedFunctions(mod, entry_functions);
+  // S2: DCE each function; CopyOnWrite so the caller's `mod` is never mutated.
+  for (const auto& gv : tmp_mod->GetGlobalVars()) {
+    auto func = tmp_mod->Lookup(gv);
     if (func->IsInstance<FunctionNode>()) {
-      mod->Update(gv, Downcast<Function>(eliminator.VisitExpr(func)));
+      tmp_mod.CopyOnWrite()->Update(gv, RemoveAllUnused(Downcast<Function>(func)));
     }
   }
-  return RemoveUnusedFunctions(mod, entry_functions);
+  // S3: remove unused functions again as some callers may be removed in S2.
+  return RemoveUnusedFunctions(tmp_mod, entry_functions);
 }
 
 namespace transform {
diff --git a/tests/python/relax/test_transform_dead_code_elimination.py b/tests/python/relax/test_transform_dead_code_elimination.py
index 51c05b5e7c..9c6e0e0567 100644
--- a/tests/python/relax/test_transform_dead_code_elimination.py
+++ b/tests/python/relax/test_transform_dead_code_elimination.py
@@ -332,5 +332,121 @@ def test_multiple_unused_funcs():
     assert not check_if_func_exists(new_mod, "unused_func2")
 
 
+def test_unused_dfb():
+    # test if an unused dataflow block (and an unused binding) can be removed.
+    @tvm.script.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+        ) -> R.Tensor((2, 26, 26, 4), dtype="float16"):
+            # block 0 (lv3 is an unused binding)
+            with R.dataflow():
+                lv0: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(
+                    x, axes=[0, 2, 3, 1]
+                )
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv0, lv1, data_layout="NHWC", kernel_layout="OHWI", out_layout="NHWC"
+                )
+                lv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
+                    lv2, axes=[0, 3, 1, 2]
+                )
+                R.output(lv2)
+            gv3 = R.astype(lv2, dtype="float16")
+            # dead block
+            with R.dataflow():
+                lv4: R.Tensor((2, 4, 26, 26), dtype="float16") = R.permute_dims(
+                    gv3, axes=[0, 3, 1, 2]
+                )
+                R.output(lv4)
+            return gv3
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+        ) -> R.Tensor((2, 26, 26, 4), dtype="float16"):
+            # block 0
+            with R.dataflow():
+                lv0: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(
+                    x, axes=[0, 2, 3, 1]
+                )
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv0, lv1, data_layout="NHWC", kernel_layout="OHWI", out_layout="NHWC"
+                )
+                R.output(lv2)
+            gv3 = R.astype(lv2, dtype="float16")
+            return gv3
+
+    verify(Input, Expected)
+
+
+def test_unused_dfb2():
+    # test if a dataflow block whose output is only used by dead code can be removed.
+    @tvm.script.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+        ) -> R.Tensor((2, 26, 26, 4), dtype="float16"):
+            # dead block
+            with R.dataflow():
+                lv0: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(
+                    x, axes=[0, 2, 3, 1]
+                )
+                R.output(lv0)
+
+            gv_x = R.astype(x, dtype="float16")
+            gv_w = R.astype(w, dtype="float16")
+
+            with R.dataflow():
+                lv1: R.Tensor((2, 28, 28, 3), dtype="float16") = R.permute_dims(
+                    gv_x, axes=[0, 2, 3, 1]
+                )
+                lv2: R.Tensor((4, 3, 3, 3), dtype="float16") = R.permute_dims(
+                    gv_w, axes=[0, 2, 3, 1]
+                )
+                lv3: R.Tensor((2, 26, 26, 4), dtype="float16") = R.nn.conv2d(
+                    lv1, lv2, data_layout="NHWC", kernel_layout="OHWI", out_layout="NHWC"
+                )
+                # dead binding -> lv0 (used only here) is dead too, so the first block is dead.
+                lv4: R.Tensor((2, 3, 28, 28), dtype="float32") = R.permute_dims(
+                    lv0, axes=[0, 3, 1, 2]
+                )
+                R.output(lv3)
+            return lv3
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+        ) -> R.Tensor((2, 26, 26, 4), dtype="float16"):
+            gv_x = R.astype(x, dtype="float16")
+            gv_w = R.astype(w, dtype="float16")
+
+            with R.dataflow():
+                lv1: R.Tensor((2, 28, 28, 3), dtype="float16") = R.permute_dims(
+                    gv_x, axes=[0, 2, 3, 1]
+                )
+                lv2: R.Tensor((4, 3, 3, 3), dtype="float16") = R.permute_dims(
+                    gv_w, axes=[0, 2, 3, 1]
+                )
+                lv3: R.Tensor((2, 26, 26, 4), dtype="float16") = R.nn.conv2d(
+                    lv1, lv2, data_layout="NHWC", kernel_layout="OHWI", out_layout="NHWC"
+                )
+                R.output(lv3)
+            return lv3
+
+    verify(Input, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to