This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new e73d26cb96 [Unity][Transform] Simple Dead Code Elimination (#14262)
e73d26cb96 is described below

commit e73d26cb96ea0875913ba334192232bcfcf5990a
Author: Bohan Hou <[email protected]>
AuthorDate: Sun Mar 19 11:33:18 2023 -0400

    [Unity][Transform] Simple Dead Code Elimination (#14262)
    
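    This PR adds a new pass, DeadCodeElimination, which currently removes unused
    local variable bindings in dataflow blocks as well as unused Relax functions;
    it replaces the former RemoveUnusedFunctions pass.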
---
 include/tvm/relax/transform.h                      |  23 +++-
 python/tvm/relax/transform/transform.py            |  58 ++++++++
 ...ve_unused_funcs.cc => dead_code_elimination.cc} | 101 +++++++++++---
 src/relax/transform/merge_composite_functions.cc   |   2 +-
 src/relax/transform/run_codegen.cc                 |   2 +-
 src/relax/transform/utils.h                        |  13 +-
 tests/python/relax/test_transform_alter_op_impl.py |   2 +-
 ....py => test_transform_dead_code_elimination.py} | 147 +++++++++++++++++++--
 8 files changed, 307 insertions(+), 41 deletions(-)

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index ead8b0c31e..369290e661 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -346,13 +346,6 @@ TVM_DLL Pass MergeCompositeFunctions();
  */
 TVM_DLL Pass FuseTIR();
 
-/*!
- * \brief Remove unused global relax functions in an IRModule.
- * \param entry_functions list of entry functions
- * \return The Pass.
- */
-TVM_DLL Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions);
-
 /*!
  * \brief Run codegen.
  * \param target_options pairs of target name and compilation options
@@ -392,6 +385,22 @@ TVM_DLL Pass AlterOpImpl(const Map<String, tir::PrimFunc>& 
op_impl_map,
  */
 TVM_DLL Pass ConvertLayout(Map<String, Array<String>> desired_layouts);
 
+
+/*!
+ * \brief Dead code elimination.
+ * Currently it removes:
+ *   1. Unused local VarBindings in a DataflowBlock.
+ *      The used var set starts empty at the beginning of each DataflowBlock.
+ *      We scan the DataflowBlock in reverse and keep a VarBinding if
+ *        - it binds a non-dataflow var (e.g. a var exported via R.output), or
+ *        - its var is in the used var set.
+ *      For every kept binding, the vars referenced by its value are added to
+ *      the used var set. All other bindings are removed.
+ *   2. Unused Relax functions in the module.
+ *      We detect the call chain from the entry functions, and remove all
+ *      unused functions.
+ * \param entry_functions Names of the functions the call chain is traced from.
+ * \return The Pass.
+ */
+TVM_DLL Pass DeadCodeElimination(Array<runtime::String> entry_functions);
+
 }  // namespace transform
 }  // namespace relax
 }  // namespace tvm
diff --git a/python/tvm/relax/transform/transform.py 
b/python/tvm/relax/transform/transform.py
index c10d0130c1..72768bf676 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -657,6 +657,64 @@ def ConvertLayout(desired_layouts: Dict[str, List[str]]) 
-> tvm.ir.transform.Pas
     return _ffi_api.ConvertLayout(desired_layouts)  # type: ignore
 
 
+def AlterOpImpl(
+    op_impl_map: Dict[str, PrimFunc],
+    op_buffer_transforms: Dict[str, List[Union[IndexMap, Callable]]],
+):
+    """Replace every PrimFunc whose 'operator_name' attribute matches a key of
+    op_impl_map with the replacement PrimFunc, which may use different layouts
+    for its i/o buffers. The layout transformations for those buffers are given
+    in op_buffer_transforms; they are inserted at the call sites of the
+    replaced PrimFuncs so that i/o tensors are converted into the layout the
+    new PrimFunc expects.
+
+    Parameters
+    ----------
+    op_impl_map: Dict[str, PrimFunc]
+        op_kind to PrimFunc map
+    op_buffer_transforms: Dict[str, List[Union[IndexMap, Callable]]]
+        op_kind to layout transformation map for each of the buffers.
+        Callable entries are converted with ``IndexMap.from_func``. Note that
+        the converted lists are written back into this dict, i.e. the
+        caller's dict is modified in place.
+
+    Returns
+    -------
+    ret: tvm.ir.transform.Pass
+    """
+    for operator_name, transform_list in op_buffer_transforms.items():
+        l = []
+        for transform in transform_list:
+            if isinstance(transform, Callable):
+                transform = IndexMap.from_func(transform)
+            l.append(transform)
+        op_buffer_transforms[operator_name] = l
+
+    return _ffi_api.AlterOpImpl(op_impl_map, op_buffer_transforms)  # type: ignore
+
+
+def DeadCodeElimination(entry_functions: Optional[List[str]] = None) -> 
tvm.ir.transform.Pass:
+    """Remove dead code in the program.
+
+    Currently it removes:
+
+    1. Unused local VarBindings in a DataflowBlock.
+       The used var set starts empty at the beginning of each DataflowBlock.
+       The DataflowBlock is scanned in reverse, and a VarBinding is kept if
+
+       - it binds a non-dataflow var (e.g. a var exported via R.output), or
+       - its var is in the used var set.
+
+       For every kept binding, the vars referenced by its value are added to
+       the used var set. All other bindings are removed.
+
+    2. Unused Relax functions in the module.
+       The call chain is traced from the entry functions, and all unused
+       functions are removed.
+
+    Parameters
+    ----------
+    entry_functions: Optional[List[str]]
+        The set of entry functions to start from. Defaults to ``["main"]``
+        when not given.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered pass.
+    """
+    if entry_functions is None:
+        entry_functions = ["main"]
+    return _ffi_api.DeadCodeElimination(entry_functions)  # type: ignore
+
+
 def _wrap_class_function_pass(pass_cls, pass_info):
     """Wrap a python class as function pass."""
 
diff --git a/src/relax/transform/remove_unused_funcs.cc 
b/src/relax/transform/dead_code_elimination.cc
similarity index 50%
rename from src/relax/transform/remove_unused_funcs.cc
rename to src/relax/transform/dead_code_elimination.cc
index 5572da1338..3008db4b61 100644
--- a/src/relax/transform/remove_unused_funcs.cc
+++ b/src/relax/transform/dead_code_elimination.cc
@@ -19,16 +19,23 @@
  */
 
 /*!
- * \file tvm/relax/transform/remove_unused_funcs.cc
- * \brief Remove unused global relax functions in a IRModule.
+ * \file tvm/relax/transform/dead_code_elimination.cc
+ * \brief Dead code elimination pass.
+ * Currently it removes:
+ *   1. Unused local VarBindings in a DataflowBlock.
+ *      The used var set starts empty at the beginning of each DataflowBlock.
+ *      We scan the DataflowBlock in reverse and keep a VarBinding if
+ *        - it binds a non-dataflow var, or
+ *        - its var is in the used var set.
+ *      For every kept binding, the vars referenced by its value are added to
+ *      the used var set. All other bindings are removed.
+ *   2. Unused Relax functions in the module.
+ *      We detect the call chain from the entry functions, and remove all
+ *      unused functions.
  */
 
+#include <tvm/relax/expr.h>
 #include <tvm/relax/expr_functor.h>
 #include <tvm/relax/transform.h>
 
-#include <unordered_set>
-#include <vector>
-
 #include "utils.h"
 
 namespace tvm {
@@ -81,14 +88,6 @@ class CallTracer : ExprVisitor {
   std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> visiting_;
 };
 
-/*!
- * \brief Remove functions that are not used.
- *
- * \param mod_ IRModule.
- * \param entry_funcs The set of functions that can be entry function.
- *
- * \return The module with dead functions removed.
- */
 IRModule RemoveUnusedFunctions(IRModule mod_, Array<runtime::String> 
entry_funcs) {
   auto tracer = CallTracer(mod_);
   for (auto entry : entry_funcs) {
@@ -105,16 +104,82 @@ IRModule RemoveUnusedFunctions(IRModule mod_, 
Array<runtime::String> entry_funcs
   return mod_;
 }
 
-}  // namespace relax
+/*!
+ * \brief Removes bindings of unused DataflowVars from DataflowBlocks.
+ *
+ * Each DataflowBlock is scanned in reverse with its own used-var set, so a
+ * binding is kept only if its var is non-dataflow or was referenced by a
+ * binding kept later in the same block.
+ *
+ * NOTE(review): kept bindings are re-emitted unchanged and the result of
+ * visiting their values is discarded, so any DataflowBlock nested inside a
+ * binding value is scanned but not rewritten — confirm this is intended.
+ * NOTE(review): a MatchCast bound to an unused DataflowVar is dropped even if
+ * it is where symbolic shape vars used later are defined — verify.
+ */
+class DeadCodeEliminator : public ExprMutator {
+ private:
+  // Record each referenced Var in the innermost used-var set.
+  Expr VisitExpr_(const VarNode* op) final {
+    ICHECK(!used_vars_.empty());
+    used_vars_.back().insert(GetRef<Var>(op));
+    return GetRef<Expr>(op);
+  }
+
+  // Same as above, for DataflowVars.
+  Expr VisitExpr_(const DataflowVarNode* op) final {
+    ICHECK(!used_vars_.empty());
+    used_vars_.back().insert(GetRef<Var>(op));
+    return GetRef<Expr>(op);
+  }
+
+  // Only collect the vars used by the value; whether the binding itself is
+  // kept is decided in VisitBindingBlock_(const DataflowBlockNode*).
+  void VisitBinding_(const VarBindingNode* binding) { 
this->VisitExpr(binding->value); }
+
+  // Like the VarBinding case, but also visits the struct info being cast to
+  // (via VisitAndCheckStructInfoFieldUnchanged).
+  void VisitBinding_(const MatchCastNode* binding) {
+    this->VisitExpr(binding->value);
+    this->VisitAndCheckStructInfoFieldUnchanged(binding->struct_info);
+  }
+
+  BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final {
+    // reverse scan the data flow block to find the used vars
+    used_vars_.push_back({});
+
+    std::vector<Binding> new_bindings;
+    for (auto rit = block->bindings.rbegin(); rit != block->bindings.rend(); 
rit++) {
+      const Var& var = (*rit)->var;
+      // only keep the used bindings or non dataflow var bindings
+      if (used_vars_.back().count(var) || !var->IsInstance<DataflowVarNode>()) 
{
+        new_bindings.push_back(*rit);
+        // collect the used vars
+        this->VisitBinding((*rit));
+      }
+    }
+
+    used_vars_.pop_back();
+    // reverse the bindings
+    std::reverse(new_bindings.begin(), new_bindings.end());
+    // Same size means nothing was dropped, so the original block is reused.
+    if (new_bindings.size() == block->bindings.size()) {
+      return GetRef<BindingBlock>(block);
+    } else {
+      auto n = make_object<DataflowBlockNode>(*block);
+      n->bindings = std::move(new_bindings);
+      return BindingBlock(n);
+    }
+  }
+
+  // Only DataflowBlocks are pruned; other binding blocks are returned as-is.
+  BindingBlock VisitBindingBlock_(const BindingBlockNode* block) final {
+    return GetRef<BindingBlock>(block);
+  }
+
+  // Stack of used-var sets, one per DataflowBlock being scanned. It starts
+  // with one entry so the ICHECKs in VisitExpr_ hold outside DataflowBlocks.
+  std::vector<std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>> 
used_vars_{{}};
+};
+
+/*!
+ * \brief Run DeadCodeEliminator on every Relax function in mod, then remove
+ *        the functions unused from entry_functions via RemoveUnusedFunctions.
+ * NOTE(review): mod->Update is called on the module that was passed in, which
+ * looks like it modifies the caller's IRModule rather than a copy — confirm.
+ */
+IRModule DeadCodeElimination(const IRModule& mod, Array<runtime::String> 
entry_functions) {
+  DeadCodeEliminator eliminator;
+  for (const auto& gv : mod->GetGlobalVars()) {
+    auto func = mod->Lookup(gv);
+    // Only Relax functions are rewritten; other module entries are left as-is.
+    if (func->IsInstance<FunctionNode>()) {
+      mod->Update(gv, Downcast<Function>(eliminator.VisitExpr(func)));
+    }
+  }
+  return RemoveUnusedFunctions(mod, entry_functions);
+}
 
 namespace transform {
-Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions) {
+
+// Expose relax::DeadCodeElimination as a module-level pass (opt level 1).
+Pass DeadCodeElimination(Array<runtime::String> entry_functions) {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
-      [=](IRModule m, PassContext pc) { return relax::RemoveUnusedFunctions(m, 
entry_functions); };
-  return CreateModulePass(pass_func, 0, "RemoveUnusedFunctions", {});
+      [=](IRModule m, PassContext pc) { return relax::DeadCodeElimination(m, 
entry_functions); };
+  return CreateModulePass(pass_func, 1, "DeadCodeElimination", {});
 }
 
-TVM_REGISTER_GLOBAL("relax.transform.RemoveUnusedFunctions").set_body_typed(RemoveUnusedFunctions);
+TVM_REGISTER_GLOBAL("relax.transform.DeadCodeElimination").set_body_typed(DeadCodeElimination);
 
 }  // namespace transform
+}  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/transform/merge_composite_functions.cc 
b/src/relax/transform/merge_composite_functions.cc
index 609dd173f2..f444d5c4f6 100644
--- a/src/relax/transform/merge_composite_functions.cc
+++ b/src/relax/transform/merge_composite_functions.cc
@@ -336,7 +336,7 @@ IRModule MergeCompositeFunctions(IRModule mod) {
     new_mod->Update(gvar, func);
   }
   // TODO(@tvm-team): Implicit pass dependency. Revisit when we have a better 
way to handle this.
-  return RemoveUnusedFunctions(new_mod, {"main"});
+  return DeadCodeElimination(new_mod, {"main"});
 }
 
 namespace transform {
diff --git a/src/relax/transform/run_codegen.cc 
b/src/relax/transform/run_codegen.cc
index b5a4d7536f..5c6b985202 100644
--- a/src/relax/transform/run_codegen.cc
+++ b/src/relax/transform/run_codegen.cc
@@ -65,7 +65,7 @@ class CodeGenRunner : ExprMutator {
     }
 
     // TODO(@tvm-team): Implicit pass dependency. Revisit when we have a 
better way to handle this.
-    return RemoveUnusedFunctions(out_mod, entry_functions);
+    return DeadCodeElimination(out_mod, entry_functions);
   }
 
   using ExprMutator::VisitExpr_;
diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h
index 003519cffc..d51fe53101 100644
--- a/src/relax/transform/utils.h
+++ b/src/relax/transform/utils.h
@@ -100,12 +100,21 @@ class MemoizedExprTranslator : public 
::tvm::relax::ExprFunctor<OutputType(const
 };
 
 /*!
- * \brief Remove unused global relax functions in an IRModule.
+ * \brief Dead code elimination
+ * Currently it removes:
+ *   1. Unused local VarBindings in a DataflowBlock.
+ *      The used var set starts empty at the beginning of each DataflowBlock.
+ *      We scan the DataflowBlock in reverse and keep a VarBinding if
+ *        - it binds a non-dataflow var, or
+ *        - its var is in the used var set.
+ *      For every kept binding, the vars referenced by its value are added to
+ *      the used var set. All other bindings are removed.
+ *   2. Unused Relax functions in the module.
+ *      We detect the call chain from the entry functions, and remove all
+ *      unused functions.
  * \param mod The target module
  * \param entry_functions list of entry functions
  * \return The updated module.
  */
-TVM_DLL IRModule RemoveUnusedFunctions(IRModule mod, Array<runtime::String> 
entry_funcs);
+TVM_DLL IRModule DeadCodeElimination(const IRModule& mod, 
Array<runtime::String> entry_funcs);
 
 /*!
  * \brief Get the external symbol of the Relax function name.
diff --git a/tests/python/relax/test_transform_alter_op_impl.py 
b/tests/python/relax/test_transform_alter_op_impl.py
index e8fa29a074..77e2d4e359 100644
--- a/tests/python/relax/test_transform_alter_op_impl.py
+++ b/tests/python/relax/test_transform_alter_op_impl.py
@@ -28,7 +28,7 @@ def _check(before, expected, operator_name, 
replacement_primfunc, layout_changes
     after = relax.transform.AlterOpImpl(
         {operator_name: replacement_primfunc}, {operator_name: layout_changes}
     )(before)
-    after = relax.transform.RemoveUnusedFunctions()(after)
+    after = relax.transform.DeadCodeElimination()(after)
     tvm.ir.assert_structural_equal(after, expected)
 
 
diff --git a/tests/python/relax/test_transform_remove_unused_funcs.py 
b/tests/python/relax/test_transform_dead_code_elimination.py
similarity index 56%
rename from tests/python/relax/test_transform_remove_unused_funcs.py
rename to tests/python/relax/test_transform_dead_code_elimination.py
index d44936c35b..51c05b5e7c 100644
--- a/tests/python/relax/test_transform_remove_unused_funcs.py
+++ b/tests/python/relax/test_transform_dead_code_elimination.py
@@ -15,13 +15,138 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import pytest
 import tvm
-import tvm.script
 import tvm.testing
-from tvm import relax
-from tvm.script import relax as R
-from tvm.script import tir as T
+from tvm.relax.transform import DeadCodeElimination
+from tvm.script.parser import ir as I, relax as R, tir as T
+
+
+def verify(input, expected):
+    """Assert that DeadCodeElimination, with default entry functions, rewrites input into expected."""
+    tvm.ir.assert_structural_equal(DeadCodeElimination()(input), expected)
+
+
+def test_simple():
+    """Unused dataflow bindings (gv21, gv22) are removed from a single dataflow
+    block, while the output gv2 and the bindings it depends on are kept."""
+    # NOTE(review): main is annotated to return (2, 4, 26, 26) but returns gv2,
+    # annotated (2, 26, 26, 4); confirm this mismatch is intended.
+    @tvm.script.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+            bias: R.Tensor((26, 26), dtype="float32"),
+        ) -> R.Tensor((2, 4, 26, 26), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                gv: R.Tensor((2, 28, 28, 3), dtype="float32") = 
R.permute_dims(x, axes=[0, 2, 3, 1])
+                gv1: R.Tensor((4, 3, 3, 3), dtype="float32") = 
R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    gv,
+                    gv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                gv21: R.Tensor((2, 4, 26, 26), dtype="float32") = 
R.permute_dims(
+                    gv2, axes=[0, 3, 1, 2]
+                )
+                gv22: R.Tensor((2, 4, 26, 26), dtype="float32") = R.add(gv21, 
bias)
+                R.output(gv2)
+            return gv2
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+            bias: R.Tensor((26, 26), dtype="float32"),
+        ) -> R.Tensor((2, 4, 26, 26), dtype="float32"):
+            with R.dataflow():
+                gv: R.Tensor((2, 28, 28, 3), dtype="float32") = 
R.permute_dims(x, axes=[0, 2, 3, 1])
+                gv1: R.Tensor((4, 3, 3, 3), dtype="float32") = 
R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    gv,
+                    gv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                R.output(gv2)
+            return gv2
+
+    verify(Input, Expected)
+
+
+def test_2block():
+    """Unused dataflow bindings (gv21, gv22) are removed from the dataflow
+    block, while the binding of gv3 in the following non-dataflow block,
+    which uses the dataflow output gv2, is kept."""
+    # NOTE(review): main is annotated to return (2, 4, 26, 26) but gv3 is
+    # derived from gv2, annotated (2, 26, 26, 4); confirm this is intended.
+    @tvm.script.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+            bias: R.Tensor((26, 26), dtype="float32"),
+        ) -> R.Tensor((2, 4, 26, 26), dtype="float16"):
+            # block 0
+            with R.dataflow():
+                gv: R.Tensor((2, 28, 28, 3), dtype="float32") = 
R.permute_dims(x, axes=[0, 2, 3, 1])
+                gv1: R.Tensor((4, 3, 3, 3), dtype="float32") = 
R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    gv,
+                    gv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                gv21: R.Tensor((2, 4, 26, 26), dtype="float32") = 
R.permute_dims(
+                    gv2, axes=[0, 3, 1, 2]
+                )
+                gv22: R.Tensor((2, 4, 26, 26), dtype="float32") = R.add(gv21, 
bias)
+                R.output(gv2)
+            gv3 = R.astype(gv2, dtype="float16")
+            return gv3
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+            bias: R.Tensor((26, 26), dtype="float32"),
+        ) -> R.Tensor((2, 4, 26, 26), dtype="float16"):
+            with R.dataflow():
+                gv: R.Tensor((2, 28, 28, 3), dtype="float32") = 
R.permute_dims(x, axes=[0, 2, 3, 1])
+                gv1: R.Tensor((4, 3, 3, 3), dtype="float32") = 
R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    gv,
+                    gv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                R.output(gv2)
+            gv3: R.Tensor((2, 26, 26, 4), dtype="float16") = R.astype(gv2, 
dtype="float16")
+            return gv3
+
+    verify(Input, Expected)
 
 
 def check_if_func_exists(mod, func_name):
@@ -57,7 +182,7 @@ def test_unused_relax_func():
 
     mod = InputModule
     assert mod
-    new_mod = relax.transform.RemoveUnusedFunctions()(mod)
+    new_mod = DeadCodeElimination()(mod)
     assert check_if_func_exists(new_mod, "main")
     assert check_if_func_exists(new_mod, "tir_add")
     assert not check_if_func_exists(new_mod, "unused_func")
@@ -93,7 +218,7 @@ def test_unused_relax_func_custom_entry_func():
     assert mod
 
     # Test entry function other than "main".
-    new_mod = 
relax.transform.RemoveUnusedFunctions(entry_functions=["foo"])(mod)
+    new_mod = DeadCodeElimination(entry_functions=["foo"])(mod)
     assert check_if_func_exists(new_mod, "foo")
     assert check_if_func_exists(new_mod, "tir_add")
     assert not check_if_func_exists(new_mod, "unused_func")
@@ -128,7 +253,7 @@ def test_unused_relax_func_symbolic_shape():
     mod = InputModule
     assert mod
 
-    new_mod = relax.transform.RemoveUnusedFunctions()(mod)
+    new_mod = DeadCodeElimination()(mod)
     assert check_if_func_exists(new_mod, "main")
     assert check_if_func_exists(new_mod, "tir_add")
     assert not check_if_func_exists(new_mod, "unused_func")
@@ -163,7 +288,7 @@ def test_unused_prim_func():
 
     mod = InputModule
     assert mod
-    new_mod = relax.transform.RemoveUnusedFunctions()(mod)
+    new_mod = DeadCodeElimination()(mod)
     assert check_if_func_exists(new_mod, "main")
     assert check_if_func_exists(new_mod, "relax_add")
     # RemoveUnusedFunction pass won't remove the function with global symbol 
for the external linkage.
@@ -200,7 +325,7 @@ def test_multiple_unused_funcs():
     mod = InputModule
     assert mod
 
-    new_mod = relax.transform.RemoveUnusedFunctions()(mod)
+    new_mod = DeadCodeElimination()(mod)
     assert check_if_func_exists(new_mod, "main")
     # RemoveUnusedFunction pass won't remove the function with global symbol 
for the external linkage.
     assert check_if_func_exists(new_mod, "unused_func1")
@@ -208,4 +333,4 @@ def test_multiple_unused_funcs():
 
 
 if __name__ == "__main__":
-    pytest.main([__file__])
+    # Run the tests in this file when it is executed directly.
+    tvm.testing.main()

Reply via email to