This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new e73d26cb96 [Unity][Transform] Simple Dead Code Elimination (#14262)
e73d26cb96 is described below
commit e73d26cb96ea0875913ba334192232bcfcf5990a
Author: Bohan Hou <[email protected]>
AuthorDate: Sun Mar 19 11:33:18 2023 -0400
[Unity][Transform] Simple Dead Code Elimination (#14262)
This PR adds a new pass, DeadCodeElimination, which currently removes unused
local vars in dataflow blocks, as well as unused Relax functions in the module.
---
include/tvm/relax/transform.h | 23 +++-
python/tvm/relax/transform/transform.py | 58 ++++++++
...ve_unused_funcs.cc => dead_code_elimination.cc} | 101 +++++++++++---
src/relax/transform/merge_composite_functions.cc | 2 +-
src/relax/transform/run_codegen.cc | 2 +-
src/relax/transform/utils.h | 13 +-
tests/python/relax/test_transform_alter_op_impl.py | 2 +-
....py => test_transform_dead_code_elimination.py} | 147 +++++++++++++++++++--
8 files changed, 307 insertions(+), 41 deletions(-)
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index ead8b0c31e..369290e661 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -346,13 +346,6 @@ TVM_DLL Pass MergeCompositeFunctions();
*/
TVM_DLL Pass FuseTIR();
-/*!
- * \brief Remove unused global relax functions in an IRModule.
- * \param entry_functions list of entry functions
- * \return The Pass.
- */
-TVM_DLL Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions);
-
/*!
* \brief Run codegen.
* \param target_options pairs of target name and compilation options
@@ -392,6 +385,22 @@ TVM_DLL Pass AlterOpImpl(const Map<String, tir::PrimFunc>&
op_impl_map,
*/
TVM_DLL Pass ConvertLayout(Map<String, Array<String>> desired_layouts);
+
+/*!
+ * \brief Dead code elimination.
+ * Currently it removes:
+ * 1. Unused local VarBindings in a DataflowBlock.
+ * The used var set is set to empty at the beginning of each
DataflowBlock.
+ * We reverse scan the DataflowBlock, if a VarBinding
+ * -   binds to a non-dataflow var (i.e., an output var), or
+ * - is used in the used var set
+ * We keep it and add its var to the used var set. Otherwise, we remove
it.
+ * 2. Unused Relax functions in the module.
+ * We detect the call chain from the entry function, and remove all
unused functions.
+ * \return The Pass.
+ */
+TVM_DLL Pass DeadCodeElimination(Array<runtime::String> entry_functions);
+
} // namespace transform
} // namespace relax
} // namespace tvm
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index c10d0130c1..72768bf676 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -657,6 +657,64 @@ def ConvertLayout(desired_layouts: Dict[str, List[str]])
-> tvm.ir.transform.Pas
return _ffi_api.ConvertLayout(desired_layouts) # type: ignore
+def AlterOpImpl(
+ op_impl_map: Dict[str, PrimFunc],
+ op_buffer_transforms: Dict[str, List[Union[IndexMap, Callable]]],
+):
+ """Replace all PrimFunc's which have matching 'operator_name' attribute,
with replacement
+ PrimFunc that could possibly have different layouts on i/o buffers. The
layout
+ transformations on i/o buffers is present in the op_buffer_transforms map.
Inserts the layout
+ transformations in the call sites of PrimFuncs being replaced to transform
i/o
+ tensors into expected layout by new PrimFunc.
+
+ Parameters
+ ----------
+ op_impl_map: Dict[str, PrimFunc]
+ op_kind to PrimFunc map
+ op_buffer_transforms: Dict[str, List[Union[IndexMap, Callable]]
+ op_kind to layout transformation map for each of the buffers
+ Returns
+ -------
+ ret: tvm.ir.transform.Pass
+ """
+ for operator_name, transform_list in op_buffer_transforms.items():
+ l = []
+ for transform in transform_list:
+ if isinstance(transform, Callable):
+ transform = IndexMap.from_func(transform)
+ l.append(transform)
+ op_buffer_transforms[operator_name] = l
+
+ return _ffi_api.AlterOpImpl(op_impl_map, op_buffer_transforms) # type:
ignore
+
+
+def DeadCodeElimination(entry_functions: Optional[List[str]] = None) ->
tvm.ir.transform.Pass:
+ """Remove dead code in the program.
+ Currently it removes:
+ 1. Unused local VarBindings in a DataflowBlock.
+ The used var set is set to empty at the beginning of each
DataflowBlock.
+ We reverse scan the DataflowBlock, if a VarBinding
+ -   binds to a non-dataflow var (i.e., an output var), or
+ - is used in the used var set
+ We keep it and add its var to the used var set. Otherwise, we remove
it.
+ 2. Unused Relax functions in the module.
+ We detect the call chain from the entry function, and remove all
unused functions.
+
+ Parameters
+ ----------
+ entry_functions: Optional[List[str]]
+ The set of entry functions to start from.
+
+ Returns
+ -------
+ ret : tvm.ir.transform.Pass
+ The registered pass.
+ """
+ if entry_functions is None:
+ entry_functions = ["main"]
+ return _ffi_api.DeadCodeElimination(entry_functions) # type: ignore
+
+
def _wrap_class_function_pass(pass_cls, pass_info):
"""Wrap a python class as function pass."""
diff --git a/src/relax/transform/remove_unused_funcs.cc
b/src/relax/transform/dead_code_elimination.cc
similarity index 50%
rename from src/relax/transform/remove_unused_funcs.cc
rename to src/relax/transform/dead_code_elimination.cc
index 5572da1338..3008db4b61 100644
--- a/src/relax/transform/remove_unused_funcs.cc
+++ b/src/relax/transform/dead_code_elimination.cc
@@ -19,16 +19,23 @@
*/
/*!
- * \file tvm/relax/transform/remove_unused_funcs.cc
- * \brief Remove unused global relax functions in a IRModule.
+ * \file tvm/relax/transform/dead_code_elimination.cc
+ * \brief Dead code elimination pass.
+ * Currently it removes:
+ * 1. Unused local VarBindings in a DataflowBlock.
+ * The used var set is set to empty at the beginning of each
DataflowBlock.
+ * We reverse scan the DataflowBlock, if a VarBinding
+ * -   binds to a non-dataflow var (i.e., an output var), or
+ * - is used in the used var set
+ * We keep it and add its var to the used var set. Otherwise, we remove
it.
+ * 2. Unused Relax functions in the module.
+ * We detect the call chain from the entry function, and remove all
unused functions.
*/
+#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>
-#include <unordered_set>
-#include <vector>
-
#include "utils.h"
namespace tvm {
@@ -81,14 +88,6 @@ class CallTracer : ExprVisitor {
std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> visiting_;
};
-/*!
- * \brief Remove functions that are not used.
- *
- * \param mod_ IRModule.
- * \param entry_funcs The set of functions that can be entry function.
- *
- * \return The module with dead functions removed.
- */
IRModule RemoveUnusedFunctions(IRModule mod_, Array<runtime::String>
entry_funcs) {
auto tracer = CallTracer(mod_);
for (auto entry : entry_funcs) {
@@ -105,16 +104,82 @@ IRModule RemoveUnusedFunctions(IRModule mod_,
Array<runtime::String> entry_funcs
return mod_;
}
-} // namespace relax
+class DeadCodeEliminator : public ExprMutator {
+ private:
+ Expr VisitExpr_(const VarNode* op) final {
+ ICHECK(!used_vars_.empty());
+ used_vars_.back().insert(GetRef<Var>(op));
+ return GetRef<Expr>(op);
+ }
+
+ Expr VisitExpr_(const DataflowVarNode* op) final {
+ ICHECK(!used_vars_.empty());
+ used_vars_.back().insert(GetRef<Var>(op));
+ return GetRef<Expr>(op);
+ }
+
+ void VisitBinding_(const VarBindingNode* binding) {
this->VisitExpr(binding->value); }
+
+ void VisitBinding_(const MatchCastNode* binding) {
+ this->VisitExpr(binding->value);
+ this->VisitAndCheckStructInfoFieldUnchanged(binding->struct_info);
+ }
+
+ BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final {
+ // reverse scan the data flow block to find the used vars
+ used_vars_.push_back({});
+
+ std::vector<Binding> new_bindings;
+ for (auto rit = block->bindings.rbegin(); rit != block->bindings.rend();
rit++) {
+ const Var& var = (*rit)->var;
+ // only keep the used bindings or non dataflow var bindings
+ if (used_vars_.back().count(var) || !var->IsInstance<DataflowVarNode>())
{
+ new_bindings.push_back(*rit);
+ // collect the used vars
+ this->VisitBinding((*rit));
+ }
+ }
+
+ used_vars_.pop_back();
+ // reverse the bindings
+ std::reverse(new_bindings.begin(), new_bindings.end());
+ if (new_bindings.size() == block->bindings.size()) {
+ return GetRef<BindingBlock>(block);
+ } else {
+ auto n = make_object<DataflowBlockNode>(*block);
+ n->bindings = std::move(new_bindings);
+ return BindingBlock(n);
+ }
+ }
+
+ BindingBlock VisitBindingBlock_(const BindingBlockNode* block) final {
+ return GetRef<BindingBlock>(block);
+ }
+
+ std::vector<std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>>
used_vars_{{}};
+};
+
+IRModule DeadCodeElimination(const IRModule& mod, Array<runtime::String>
entry_functions) {
+ DeadCodeEliminator eliminator;
+ for (const auto& gv : mod->GetGlobalVars()) {
+ auto func = mod->Lookup(gv);
+ if (func->IsInstance<FunctionNode>()) {
+ mod->Update(gv, Downcast<Function>(eliminator.VisitExpr(func)));
+ }
+ }
+ return RemoveUnusedFunctions(mod, entry_functions);
+}
namespace transform {
-Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions) {
+
+Pass DeadCodeElimination(Array<runtime::String> entry_functions) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
- [=](IRModule m, PassContext pc) { return relax::RemoveUnusedFunctions(m,
entry_functions); };
- return CreateModulePass(pass_func, 0, "RemoveUnusedFunctions", {});
+ [=](IRModule m, PassContext pc) { return relax::DeadCodeElimination(m,
entry_functions); };
+ return CreateModulePass(pass_func, 1, "DeadCodeElimination", {});
}
-TVM_REGISTER_GLOBAL("relax.transform.RemoveUnusedFunctions").set_body_typed(RemoveUnusedFunctions);
+TVM_REGISTER_GLOBAL("relax.transform.DeadCodeElimination").set_body_typed(DeadCodeElimination);
} // namespace transform
+} // namespace relax
} // namespace tvm
diff --git a/src/relax/transform/merge_composite_functions.cc
b/src/relax/transform/merge_composite_functions.cc
index 609dd173f2..f444d5c4f6 100644
--- a/src/relax/transform/merge_composite_functions.cc
+++ b/src/relax/transform/merge_composite_functions.cc
@@ -336,7 +336,7 @@ IRModule MergeCompositeFunctions(IRModule mod) {
new_mod->Update(gvar, func);
}
// TODO(@tvm-team): Implicit pass dependency. Revisit when we have a better
way to handle this.
- return RemoveUnusedFunctions(new_mod, {"main"});
+ return DeadCodeElimination(new_mod, {"main"});
}
namespace transform {
diff --git a/src/relax/transform/run_codegen.cc
b/src/relax/transform/run_codegen.cc
index b5a4d7536f..5c6b985202 100644
--- a/src/relax/transform/run_codegen.cc
+++ b/src/relax/transform/run_codegen.cc
@@ -65,7 +65,7 @@ class CodeGenRunner : ExprMutator {
}
// TODO(@tvm-team): Implicit pass dependency. Revisit when we have a
better way to handle this.
- return RemoveUnusedFunctions(out_mod, entry_functions);
+ return DeadCodeElimination(out_mod, entry_functions);
}
using ExprMutator::VisitExpr_;
diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h
index 003519cffc..d51fe53101 100644
--- a/src/relax/transform/utils.h
+++ b/src/relax/transform/utils.h
@@ -100,12 +100,21 @@ class MemoizedExprTranslator : public
::tvm::relax::ExprFunctor<OutputType(const
};
/*!
- * \brief Remove unused global relax functions in an IRModule.
+ * \brief Dead code elimination
+ * Currently it removes:
+ * 1. Unused local VarBindings in a DataflowBlock.
+ * The used var set is set to empty at the beginning of each
DataflowBlock.
+ * We reverse scan the DataflowBlock, if a VarBinding
+ * -   binds to a non-dataflow var (i.e., an output var), or
+ * - is used in the used var set
+ * We keep it and add its var to the used var set. Otherwise, we remove
it.
+ * 2. Unused Relax functions in the module.
+ * We detect the call chain from the entry function, and remove all
unused functions.
* \param mod The target module
* \param entry_functions list of entry functions
* \return The updated module.
*/
-TVM_DLL IRModule RemoveUnusedFunctions(IRModule mod, Array<runtime::String>
entry_funcs);
+TVM_DLL IRModule DeadCodeElimination(const IRModule& mod,
Array<runtime::String> entry_funcs);
/*!
* \brief Get the external symbol of the Relax function name.
diff --git a/tests/python/relax/test_transform_alter_op_impl.py
b/tests/python/relax/test_transform_alter_op_impl.py
index e8fa29a074..77e2d4e359 100644
--- a/tests/python/relax/test_transform_alter_op_impl.py
+++ b/tests/python/relax/test_transform_alter_op_impl.py
@@ -28,7 +28,7 @@ def _check(before, expected, operator_name,
replacement_primfunc, layout_changes
after = relax.transform.AlterOpImpl(
{operator_name: replacement_primfunc}, {operator_name: layout_changes}
)(before)
- after = relax.transform.RemoveUnusedFunctions()(after)
+ after = relax.transform.DeadCodeElimination()(after)
tvm.ir.assert_structural_equal(after, expected)
diff --git a/tests/python/relax/test_transform_remove_unused_funcs.py
b/tests/python/relax/test_transform_dead_code_elimination.py
similarity index 56%
rename from tests/python/relax/test_transform_remove_unused_funcs.py
rename to tests/python/relax/test_transform_dead_code_elimination.py
index d44936c35b..51c05b5e7c 100644
--- a/tests/python/relax/test_transform_remove_unused_funcs.py
+++ b/tests/python/relax/test_transform_dead_code_elimination.py
@@ -15,13 +15,138 @@
# specific language governing permissions and limitations
# under the License.
-import pytest
import tvm
-import tvm.script
import tvm.testing
-from tvm import relax
-from tvm.script import relax as R
-from tvm.script import tir as T
+from tvm.relax.transform import DeadCodeElimination
+from tvm.script.parser import ir as I, relax as R, tir as T
+
+
+def verify(input, expected):
+ tvm.ir.assert_structural_equal(DeadCodeElimination()(input), expected)
+
+
+def test_simple():
+ @tvm.script.ir_module
+ class Input:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+ w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+ bias: R.Tensor((26, 26), dtype="float32"),
+ ) -> R.Tensor((2, 4, 26, 26), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ gv: R.Tensor((2, 28, 28, 3), dtype="float32") =
R.permute_dims(x, axes=[0, 2, 3, 1])
+ gv1: R.Tensor((4, 3, 3, 3), dtype="float32") =
R.permute_dims(w, axes=[0, 2, 3, 1])
+ gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+ gv,
+ gv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NHWC",
+ kernel_layout="OHWI",
+ out_layout="NHWC",
+ out_dtype="float32",
+ )
+ gv21: R.Tensor((2, 4, 26, 26), dtype="float32") =
R.permute_dims(
+ gv2, axes=[0, 3, 1, 2]
+ )
+ gv22: R.Tensor((2, 4, 26, 26), dtype="float32") = R.add(gv21,
bias)
+ R.output(gv2)
+ return gv2
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+ w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+ bias: R.Tensor((26, 26), dtype="float32"),
+ ) -> R.Tensor((2, 4, 26, 26), dtype="float32"):
+ with R.dataflow():
+ gv: R.Tensor((2, 28, 28, 3), dtype="float32") =
R.permute_dims(x, axes=[0, 2, 3, 1])
+ gv1: R.Tensor((4, 3, 3, 3), dtype="float32") =
R.permute_dims(w, axes=[0, 2, 3, 1])
+ gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+ gv,
+ gv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NHWC",
+ kernel_layout="OHWI",
+ out_layout="NHWC",
+ out_dtype="float32",
+ )
+ R.output(gv2)
+ return gv2
+
+ verify(Input, Expected)
+
+
+def test_2block():
+ @tvm.script.ir_module
+ class Input:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+ w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+ bias: R.Tensor((26, 26), dtype="float32"),
+ ) -> R.Tensor((2, 4, 26, 26), dtype="float16"):
+ # block 0
+ with R.dataflow():
+ gv: R.Tensor((2, 28, 28, 3), dtype="float32") =
R.permute_dims(x, axes=[0, 2, 3, 1])
+ gv1: R.Tensor((4, 3, 3, 3), dtype="float32") =
R.permute_dims(w, axes=[0, 2, 3, 1])
+ gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+ gv,
+ gv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NHWC",
+ kernel_layout="OHWI",
+ out_layout="NHWC",
+ out_dtype="float32",
+ )
+ gv21: R.Tensor((2, 4, 26, 26), dtype="float32") =
R.permute_dims(
+ gv2, axes=[0, 3, 1, 2]
+ )
+ gv22: R.Tensor((2, 4, 26, 26), dtype="float32") = R.add(gv21,
bias)
+ R.output(gv2)
+ gv3 = R.astype(gv2, dtype="float16")
+ return gv3
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+ w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+ bias: R.Tensor((26, 26), dtype="float32"),
+ ) -> R.Tensor((2, 4, 26, 26), dtype="float16"):
+ with R.dataflow():
+ gv: R.Tensor((2, 28, 28, 3), dtype="float32") =
R.permute_dims(x, axes=[0, 2, 3, 1])
+ gv1: R.Tensor((4, 3, 3, 3), dtype="float32") =
R.permute_dims(w, axes=[0, 2, 3, 1])
+ gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+ gv,
+ gv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NHWC",
+ kernel_layout="OHWI",
+ out_layout="NHWC",
+ out_dtype="float32",
+ )
+ R.output(gv2)
+ gv3: R.Tensor((2, 26, 26, 4), dtype="float16") = R.astype(gv2,
dtype="float16")
+ return gv3
+
+ verify(Input, Expected)
def check_if_func_exists(mod, func_name):
@@ -57,7 +182,7 @@ def test_unused_relax_func():
mod = InputModule
assert mod
- new_mod = relax.transform.RemoveUnusedFunctions()(mod)
+ new_mod = DeadCodeElimination()(mod)
assert check_if_func_exists(new_mod, "main")
assert check_if_func_exists(new_mod, "tir_add")
assert not check_if_func_exists(new_mod, "unused_func")
@@ -93,7 +218,7 @@ def test_unused_relax_func_custom_entry_func():
assert mod
# Test entry function other than "main".
- new_mod =
relax.transform.RemoveUnusedFunctions(entry_functions=["foo"])(mod)
+ new_mod = DeadCodeElimination(entry_functions=["foo"])(mod)
assert check_if_func_exists(new_mod, "foo")
assert check_if_func_exists(new_mod, "tir_add")
assert not check_if_func_exists(new_mod, "unused_func")
@@ -128,7 +253,7 @@ def test_unused_relax_func_symbolic_shape():
mod = InputModule
assert mod
- new_mod = relax.transform.RemoveUnusedFunctions()(mod)
+ new_mod = DeadCodeElimination()(mod)
assert check_if_func_exists(new_mod, "main")
assert check_if_func_exists(new_mod, "tir_add")
assert not check_if_func_exists(new_mod, "unused_func")
@@ -163,7 +288,7 @@ def test_unused_prim_func():
mod = InputModule
assert mod
- new_mod = relax.transform.RemoveUnusedFunctions()(mod)
+ new_mod = DeadCodeElimination()(mod)
assert check_if_func_exists(new_mod, "main")
assert check_if_func_exists(new_mod, "relax_add")
# RemoveUnusedFunction pass won't remove the function with global symbol
for the external linkage.
@@ -200,7 +325,7 @@ def test_multiple_unused_funcs():
mod = InputModule
assert mod
- new_mod = relax.transform.RemoveUnusedFunctions()(mod)
+ new_mod = DeadCodeElimination()(mod)
assert check_if_func_exists(new_mod, "main")
# RemoveUnusedFunction pass won't remove the function with global symbol
for the external linkage.
assert check_if_func_exists(new_mod, "unused_func1")
@@ -208,4 +333,4 @@ def test_multiple_unused_funcs():
if __name__ == "__main__":
- pytest.main([__file__])
+ tvm.testing.main()