This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 873ab3d7a9 [Unity][Pass] Reuse prior infra to implement more complete
DCE (#14334)
873ab3d7a9 is described below
commit 873ab3d7a9cffce7046046be67cdb2c4ae0b6aba
Author: Jiawei Liu <[email protected]>
AuthorDate: Sun Mar 19 21:28:54 2023 -0500
[Unity][Pass] Reuse prior infra to implement more complete DCE (#14334)
As a follow-up to https://github.com/apache/tvm/pull/14262, I noticed that I
had previously implemented a function called `RemoveAllUnused`
(https://github.com/apache/tvm/pull/14043) that performs function-wise DCE. It
should be more complete than https://github.com/apache/tvm/pull/14262, because
`RemoveAllUnused` can also remove dead dataflow blocks (two test cases are
added to show that).
For now:
- `tvm.relax.transform.DeadCodeElimination` is the pass for running DCE
over an IRModule.
- `tvm.relax.analysis.remove_all_unused` is a function for running DCE over
a function.
`tvm.relax.transform.DeadCodeElimination` is implemented on top of
`tvm.relax.analysis.remove_all_unused`. I did not unify the names of the two,
because `RemoveAllUnused` has other uses and I want to be conservative for
now.
---
include/tvm/relax/transform.h | 9 +-
python/tvm/relax/analysis/analysis.py | 8 +-
python/tvm/relax/transform/transform.py | 14 +--
src/relax/transform/dead_code_elimination.cc | 79 +++-----------
.../relax/test_transform_dead_code_elimination.py | 116 +++++++++++++++++++++
5 files changed, 146 insertions(+), 80 deletions(-)
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index eba7de1b0c..5a21f76b0b 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -387,14 +387,11 @@ TVM_DLL Pass ConvertLayout(Map<String, Array<String>>
desired_layouts);
/*!
* \brief Dead code elimination.
+ * \sa RemoveAllUnused
* Currently it removes:
* 1. Unused local VarBindings in a DataflowBlock.
- * The used var set is set to empty at the beginning of each
DataflowBlock.
- * We reverse scan the DataflowBlock, if a VarBinding
- * - bindings to a dataflowvar, or
- * - is used in the used var set
- * We keep it and add its var to the used var set. Otherwise, we remove
it.
- * 2. Unused Relax functions in the module.
+ * 2. Unused DataflowBlocks in a function.
+ * 3. Unused Relax functions in the module.
* We detect the call chain from the entry function, and remove all
unused functions.
* \return The Pass.
*/
diff --git a/python/tvm/relax/analysis/analysis.py
b/python/tvm/relax/analysis/analysis.py
index ae64c08eb1..2a2c3d87b8 100644
--- a/python/tvm/relax/analysis/analysis.py
+++ b/python/tvm/relax/analysis/analysis.py
@@ -316,13 +316,19 @@ def name_to_binding(func: Function) -> Dict[str,
List[Binding]]:
def remove_all_unused(func: Function) -> Function:
- """Remove all unused variables from the function.
+ """It removes:
+ 1. Unused local VarBindings in a DataflowBlock.
+ 2. Unused DataflowBlocks in a function.
Parameters
----------
func : Function
The input function to be analyzed.
+ Notes
+ -----
+ For IRModule-wise DCE, use
:py:func:`tvm.relax.transform.DeadCodeElimination`.
+
Returns
-------
Function
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index ebfd7a6765..b710196347 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -658,15 +658,11 @@ def ConvertLayout(desired_layouts: Dict[str, List[str]])
-> tvm.ir.transform.Pas
def DeadCodeElimination(entry_functions: Optional[List[str]] = None) ->
tvm.ir.transform.Pass:
- """Remove dead code in the program.
+ """Remove dead code in the IRModule.
Currently it removes:
1. Unused local VarBindings in a DataflowBlock.
- The used var set is set to empty at the beginning of each
DataflowBlock.
- We reverse scan the DataflowBlock, if a VarBinding
- - bindings to a dataflowvar, or
- - is used in the used var set
- We keep it and add its var to the used var set. Otherwise, we remove
it.
- 2. Unused Relax functions in the module.
+ 2. Unused DataflowBlocks in a function.
+ 3. Unused Relax functions in the module.
We detect the call chain from the entry function, and remove all
unused functions.
Parameters
@@ -674,6 +670,10 @@ def DeadCodeElimination(entry_functions:
Optional[List[str]] = None) -> tvm.ir.t
entry_functions: Optional[List[str]]
The set of entry functions to start from.
+ Notes
+ -----
+ For function-wise DCE, use :py:func:`tvm.relax.analysis.remove_all_unused`.
+
Returns
-------
ret : tvm.transform.Pass
diff --git a/src/relax/transform/dead_code_elimination.cc
b/src/relax/transform/dead_code_elimination.cc
index 3008db4b61..fe36eb28ef 100644
--- a/src/relax/transform/dead_code_elimination.cc
+++ b/src/relax/transform/dead_code_elimination.cc
@@ -21,17 +21,16 @@
/*!
* \file tvm/relax/transform/dead_code_elimination.cc
* \brief Dead code elimination pass.
+ * \sa tvm/relax/ir/binding_rewrite.cc
+ *
* Currently it removes:
* 1. Unused local VarBindings in a DataflowBlock.
- * The used var set is set to empty at the beginning of each
DataflowBlock.
- * We reverse scan the DataflowBlock, if a VarBinding
- * - bindings to a dataflowvar, or
- * - is used in the used var set
- * We keep it and add its var to the used var set. Otherwise, we remove
it.
- * 2. Unused Relax functions in the module.
+ * 2. Unused DataflowBlocks in a function.
+ * 3. Unused Relax functions in the module.
* We detect the call chain from the entry function, and remove all
unused functions.
*/
+#include <tvm/relax/analysis.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>
@@ -104,70 +103,18 @@ IRModule RemoveUnusedFunctions(IRModule mod_,
Array<runtime::String> entry_funcs
return mod_;
}
-class DeadCodeEliminator : public ExprMutator {
- private:
- Expr VisitExpr_(const VarNode* op) final {
- ICHECK(!used_vars_.empty());
- used_vars_.back().insert(GetRef<Var>(op));
- return GetRef<Expr>(op);
- }
-
- Expr VisitExpr_(const DataflowVarNode* op) final {
- ICHECK(!used_vars_.empty());
- used_vars_.back().insert(GetRef<Var>(op));
- return GetRef<Expr>(op);
- }
-
- void VisitBinding_(const VarBindingNode* binding) {
this->VisitExpr(binding->value); }
-
- void VisitBinding_(const MatchCastNode* binding) {
- this->VisitExpr(binding->value);
- this->VisitAndCheckStructInfoFieldUnchanged(binding->struct_info);
- }
-
- BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final {
- // reverse scan the data flow block to find the used vars
- used_vars_.push_back({});
-
- std::vector<Binding> new_bindings;
- for (auto rit = block->bindings.rbegin(); rit != block->bindings.rend();
rit++) {
- const Var& var = (*rit)->var;
- // only keep the used bindings or non dataflow var bindings
- if (used_vars_.back().count(var) || !var->IsInstance<DataflowVarNode>())
{
- new_bindings.push_back(*rit);
- // collect the used vars
- this->VisitBinding((*rit));
- }
- }
-
- used_vars_.pop_back();
- // reverse the bindings
- std::reverse(new_bindings.begin(), new_bindings.end());
- if (new_bindings.size() == block->bindings.size()) {
- return GetRef<BindingBlock>(block);
- } else {
- auto n = make_object<DataflowBlockNode>(*block);
- n->bindings = std::move(new_bindings);
- return BindingBlock(n);
- }
- }
-
- BindingBlock VisitBindingBlock_(const BindingBlockNode* block) final {
- return GetRef<BindingBlock>(block);
- }
-
- std::vector<std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>>
used_vars_{{}};
-};
-
IRModule DeadCodeElimination(const IRModule& mod, Array<runtime::String>
entry_functions) {
- DeadCodeEliminator eliminator;
- for (const auto& gv : mod->GetGlobalVars()) {
- auto func = mod->Lookup(gv);
+ // S1: remove unused functions to reduce the number of functions to be
analyzed.
+ IRModule tmp_mod = RemoveUnusedFunctions(mod, entry_functions);
+ // S2: remove unused variables in each function.
+ for (const auto& gv : tmp_mod->GetGlobalVars()) {
+ auto func = tmp_mod->Lookup(gv);
if (func->IsInstance<FunctionNode>()) {
- mod->Update(gv, Downcast<Function>(eliminator.VisitExpr(func)));
+ tmp_mod->Update(gv, RemoveAllUnused(Downcast<Function>(func)));
}
}
- return RemoveUnusedFunctions(mod, entry_functions);
+ // S3: remove unused functions again as some callers may be removed in S2.
+ return RemoveUnusedFunctions(tmp_mod, entry_functions);
}
namespace transform {
diff --git a/tests/python/relax/test_transform_dead_code_elimination.py
b/tests/python/relax/test_transform_dead_code_elimination.py
index 51c05b5e7c..9c6e0e0567 100644
--- a/tests/python/relax/test_transform_dead_code_elimination.py
+++ b/tests/python/relax/test_transform_dead_code_elimination.py
@@ -332,5 +332,121 @@ def test_multiple_unused_funcs():
assert not check_if_func_exists(new_mod, "unused_func2")
+def test_unused_dfb():
+ # test if an unused dataflow block can be removed.
+ @tvm.script.ir_module
+ class Input:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+ w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+ ) -> R.Tensor((2, 4, 26, 26), dtype="float16"):
+ # block 0
+ with R.dataflow():
+ lv0: R.Tensor((2, 28, 28, 3), dtype="float32") =
R.permute_dims(
+ x, axes=[0, 2, 3, 1]
+ )
+ lv1: R.Tensor((4, 3, 3, 3), dtype="float32") =
R.permute_dims(w, axes=[0, 2, 3, 1])
+ lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+ lv0, lv1, data_layout="NHWC", kernel_layout="OHWI",
out_layout="NHWC"
+ )
+ lv3: R.Tensor((2, 4, 26, 26), dtype="float32") =
R.permute_dims(
+ lv2, axes=[0, 3, 1, 2]
+ )
+ R.output(lv2)
+ gv3 = R.astype(lv2, dtype="float16")
+ # dead block
+ with R.dataflow():
+ lv4: R.Tensor((2, 4, 26, 26), dtype="float16") =
R.permute_dims(
+ gv3, axes=[0, 3, 1, 2]
+ )
+ R.output(lv4)
+ return gv3
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+ w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+ ) -> R.Tensor((2, 4, 26, 26), dtype="float16"):
+ # block 0
+ with R.dataflow():
+ lv0: R.Tensor((2, 28, 28, 3), dtype="float32") =
R.permute_dims(
+ x, axes=[0, 2, 3, 1]
+ )
+ lv1: R.Tensor((4, 3, 3, 3), dtype="float32") =
R.permute_dims(w, axes=[0, 2, 3, 1])
+ lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+ lv0, lv1, data_layout="NHWC", kernel_layout="OHWI",
out_layout="NHWC"
+ )
+ R.output(lv2)
+ gv3 = R.astype(lv2, dtype="float16")
+ return gv3
+
+ verify(Input, Expected)
+
+
+def test_unused_dfb2():
+ # test if a dataflow block whose only user is itself a dead binding can be removed.
+ @tvm.script.ir_module
+ class Input:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+ w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+ ) -> R.Tensor((2, 4, 26, 26), dtype="float16"):
+ # dead block
+ with R.dataflow():
+ lv0: R.Tensor((2, 28, 28, 3), dtype="float32") =
R.permute_dims(
+ x, axes=[0, 2, 3, 1]
+ )
+ R.output(lv0)
+
+ gv_x = R.astype(x, dtype="float16")
+ gv_w = R.astype(x, dtype="float16")
+
+ with R.dataflow():
+ lv1: R.Tensor((2, 28, 28, 3), dtype="float16") =
R.permute_dims(
+ gv_x, axes=[0, 2, 3, 1]
+ )
+ lv2: R.Tensor((4, 3, 3, 3), dtype="float16") = R.permute_dims(
+ gv_w, axes=[0, 2, 3, 1]
+ )
+ lv3: R.Tensor((2, 26, 26, 4), dtype="float16") = R.nn.conv2d(
+ lv1, lv2, data_layout="NHWC", kernel_layout="OHWI",
out_layout="NHWC"
+ )
+ # dead instruction -> its operand lv0 (and the block defining it) is also dead.
+ lv4: R.Tensor((2, 3, 28, 28), dtype="float32") =
R.permute_dims(
+ lv0, axes=[0, 3, 1, 2]
+ )
+ R.output(lv3)
+ return lv3
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+ w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+ ) -> R.Tensor((2, 4, 26, 26), dtype="float16"):
+ gv_x = R.astype(x, dtype="float16")
+ gv_w = R.astype(x, dtype="float16")
+
+ with R.dataflow():
+ lv1: R.Tensor((2, 28, 28, 3), dtype="float16") =
R.permute_dims(
+ gv_x, axes=[0, 2, 3, 1]
+ )
+ lv2: R.Tensor((4, 3, 3, 3), dtype="float16") = R.permute_dims(
+ gv_w, axes=[0, 2, 3, 1]
+ )
+ lv3: R.Tensor((2, 26, 26, 4), dtype="float16") = R.nn.conv2d(
+ lv1, lv2, data_layout="NHWC", kernel_layout="OHWI",
out_layout="NHWC"
+ )
+ R.output(lv3)
+ return lv3
+
+ verify(Input, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()