This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 545e0977e3 [Relax] Allow DeadCodeElimination within
ApplyPassToFunction (#16801)
545e0977e3 is described below
commit 545e0977e327aad3f60a110961bdef733d08dbb1
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Apr 3 10:28:00 2024 -0500
[Relax] Allow DeadCodeElimination within ApplyPassToFunction (#16801)
The `tvm.ir.transform.ApplyPassToFunction` utility allows a transform to be
applied selectively to some portions of an `IRModule`, without applying
it to the entire `IRModule`. For example, it can apply an optimization
pass (e.g. `relax.transform.ExpandMatmulOfSum`) or an
interface-altering pass (e.g. `relax.transform.BundleModelParams`) to
specific functions. It does so by generating an intermediate
`IRModule` containing only the functions specified, applying the
transform to that intermediate, then merging the results.
When using `ApplyPassToFunction` to apply `DeadCodeElimination`, or a
pipeline containing `DeadCodeElimination`, this intermediate
`IRModule` may contain calls to `GlobalVar` instances that are not
within the intermediate `IRModule`. Prior to this commit, this
resulted in an error being thrown when collecting the call graph.
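This commit updates `DeadCodeElimination` to instead tolerate an
incomplete call graph: if a called `GlobalVar` cannot be found in the
`IRModule`, no functions are removed from the module, since removing
a function is only valid when the complete call graph is known.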
---
src/relax/transform/dead_code_elimination.cc | 37 ++++-
tests/python/relax/conftest.py | 22 ++-
.../relax/test_transform_dead_code_elimination.py | 155 +++++++++++++++++++++
3 files changed, 202 insertions(+), 12 deletions(-)
diff --git a/src/relax/transform/dead_code_elimination.cc
b/src/relax/transform/dead_code_elimination.cc
index 73f66d2ef3..28c7d74ef8 100644
--- a/src/relax/transform/dead_code_elimination.cc
+++ b/src/relax/transform/dead_code_elimination.cc
@@ -50,12 +50,22 @@ class CallTracer : public ExprVisitor {
explicit CallTracer(IRModule mod) : mod_{mod}, called_funcs_{}, visiting_{}
{}
void VisitExpr_(const GlobalVarNode* op) final {
- called_funcs_.insert(GetRef<GlobalVar>(op));
- auto func = mod_->Lookup(op->name_hint);
- if (const auto* function_node = func.as<FunctionNode>()) {
- VisitExpr(GetRef<Function>(function_node));
+ auto gvar = GetRef<GlobalVar>(op);
+ called_funcs_.insert(gvar);
+ if (auto func = mod_->functions.Get(gvar)) {
+ if (const auto* function_node = func.as<FunctionNode>()) {
+ VisitExpr(GetRef<Function>(function_node));
+ }
+ // else: Don't visit PrimFuncs -- we don't need to collect any tir.Calls
therein.
+ } else {
+ // The GlobalVar is not contained in the IRModule. While the
+ // input IRModule is ill-formed, this specific case is allowed
+ // for use with `relax.transform.ApplyPassToFunction`. If this
+ // occurs, DCE should not remove any internal functions from the
+ // IRModule, as their removal is only valid if we have a
+ // complete call graph.
+ all_callees_found_ = false;
}
- // else: Don't visit PrimFuncs -- we don't need to collect any tir.Calls
therein.
}
void VisitExpr_(const CallNode* call_node) final {
ExprVisitor::VisitExpr_(call_node); }
@@ -77,11 +87,24 @@ class CallTracer : public ExprVisitor {
VisitExpr(main_func);
}
- bool check_if_called(GlobalVar gv) { return called_funcs_.count(gv) > 0; }
+ /* \brief Check if a function is unreachable
+ *
+ * \param gvar The function to be checked
+ *
+ * \return True if the function can be proven to be unreachable,
+ * either directly or indirectly, from an external caller.
+ * Otherwise, false.
+ */
+ bool CheckIfProvablyUnreachable(const GlobalVar& gvar) const {
+ return all_callees_found_ && !called_funcs_.count(gvar);
+ }
private:
IRModule mod_;
+ /* \brief Whether all callees could be located within the IRModule */
+ bool all_callees_found_{true};
+
// Record the names of all encountered functions.
std::unordered_set<GlobalVar, ObjectPtrHash, ObjectPtrEqual> called_funcs_;
@@ -101,7 +124,7 @@ IRModule RemoveUnusedFunctions(
// The tracer contains all user-provided entry functions, all
// externally-callable functions, and anything that is directly or
// indirectly accessible from an entry function.
- if (!tracer.check_if_called(kv.first)) {
+ if (tracer.CheckIfProvablyUnreachable(kv.first)) {
to_remove.push_back(kv.first);
}
}
diff --git a/tests/python/relax/conftest.py b/tests/python/relax/conftest.py
index 1e12a95e52..bb5a04ef76 100644
--- a/tests/python/relax/conftest.py
+++ b/tests/python/relax/conftest.py
@@ -37,7 +37,14 @@ def pytest_configure(config):
"markers",
(
"skip_well_formed_check_before_transform: "
- "Only check for well-formed IRModule after a transform"
+ "Suppress the default well-formed check before a IRModule
transform"
+ ),
+ )
+ config.addinivalue_line(
+ "markers",
+ (
+ "skip_well_formed_check_after_transform: "
+ "Suppress the default well-formed check after a IRModule transform"
),
)
@@ -54,15 +61,20 @@ def pytest_configure(config):
# `@pytest.mark.skip_well_formed_check_before_transform`
@pytest.fixture(autouse=True)
def apply_instrument_well_formed(unit_test_marks):
-
validate_before_transform = "skip_well_formed_check_before_transform" not
in unit_test_marks
+ validate_after_transform = "skip_well_formed_check_after_transform" not in
unit_test_marks
- instrument =
WellFormedInstrument(validate_before_transform=validate_before_transform)
current = tvm.transform.PassContext.current()
+ instruments = list(current.instruments)
+
+ if validate_before_transform or validate_after_transform:
+ instruments.append(
+
WellFormedInstrument(validate_before_transform=validate_before_transform)
+ )
override = tvm.transform.PassContext(
- # Append the new instrument
- instruments=[*current.instruments, instrument],
+ # With the new WellFormedInstrument appended
+ instruments=instruments,
# Forward all other parameters
opt_level=current.opt_level,
required_pass=current.required_pass,
diff --git a/tests/python/relax/test_transform_dead_code_elimination.py
b/tests/python/relax/test_transform_dead_code_elimination.py
index c0a2d47b19..2dae252cad 100644
--- a/tests/python/relax/test_transform_dead_code_elimination.py
+++ b/tests/python/relax/test_transform_dead_code_elimination.py
@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
+import pytest
+
import tvm
import tvm.testing
from tvm.relax.transform import DeadCodeElimination
@@ -507,5 +509,158 @@ def test_extern_func():
verify(before, before)
[email protected]_well_formed_check_before_transform
[email protected]_well_formed_check_after_transform
+def test_compatibility_with_apply_pass_to_function():
+ """DeadCodeElimination can be used with ApplyPassToFunction
+
+ The `ApplyPassToFunction` utility calls another transform, where
+ only the specified functions are exposed to the internal
+ transform. This intermediate does not contain `cls.subroutine`,
+ and so the intermediate is ill-formed.
+
+ In general, IRModule transformations may assume that their inputs
+ are well-formed. In specific cases, IRModule transformations may
+ accept IRModules that are ill-formed. The `DeadCodeElimination`
+ transform allows IRModule arguments that are ill-formed due to
+ a dangling GlobalVar.
+
+ After `DeadCodeElimination` completes, the resulting function is
+ inserted in the original IRModule, providing a well-formed output
+ from `ApplyPassToFunction`.
+
+ """
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def to_be_transformed(A: R.Tensor):
+ cls = Before
+
+ B = R.add(A, A)
+ C = cls.subroutine(B)
+ D = R.multiply(C, C)
+ return C
+
+ @R.function
+ def to_be_ignored(A: R.Tensor):
+ cls = Before
+
+ B = R.add(A, A)
+ C = cls.subroutine(B)
+ D = R.multiply(C, C)
+ return C
+
+ @R.function(private=True)
+ def subroutine(arg: R.Tensor) -> R.Tensor:
+ return R.add(arg, arg)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def to_be_transformed(A: R.Tensor):
+ cls = Expected
+
+ B = R.add(A, A)
+ C = cls.subroutine(B)
+ return C
+
+ @R.function
+ def to_be_ignored(A: R.Tensor):
+ cls = Expected
+
+ B = R.add(A, A)
+ C = cls.subroutine(B)
+ D = R.multiply(C, C)
+ return C
+
+ @R.function(private=True)
+ def subroutine(arg: R.Tensor) -> R.Tensor:
+ return R.add(arg, arg)
+
+ # The well-formed check in conftest.py must be disabled, to avoid
+ # triggering on the ill-formed intermediate, so this unit test
+ # checks it explicitly.
+ assert tvm.relax.analysis.well_formed(Before)
+ After = tvm.ir.transform.ApplyPassToFunction(
+ tvm.relax.transform.DeadCodeElimination(),
+ "to_be_transformed",
+ )(Before)
+ assert tvm.relax.analysis.well_formed(After)
+ tvm.ir.assert_structural_equal(Expected, After)
+
+
[email protected]_well_formed_check_before_transform
[email protected]_well_formed_check_after_transform
+def test_well_formed_output_with_restricted_scope():
+ """DeadCodeElimination can be used with ApplyPassToFunction
+
+ If the call graph cannot be completely traced, private functions
+ should not be removed.
+
+ See `test_compatibility_with_apply_pass_to_function` for full
+ description of `DeadCodeElimination` and `ApplyPassToFunction`.
+
+ """
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(A: R.Tensor):
+ cls = Before
+
+ B = R.add(A, A)
+ C = cls.subroutine(B)
+ D = R.multiply(C, C)
+ return C
+
+ @R.function(private=True)
+ def subroutine(A: R.Tensor) -> R.Tensor:
+ cls = Before
+
+ B = R.add(A, A)
+ C = cls.subsubroutine(B)
+ D = R.multiply(C, C)
+ return C
+
+ @R.function(private=True)
+ def subsubroutine(A: R.Tensor) -> R.Tensor:
+ B = R.add(A, A)
+ C = R.multiply(B, B)
+ return B
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(A: R.Tensor):
+ cls = Expected
+
+ B = R.add(A, A)
+ C = cls.subroutine(B)
+ return C
+
+ @R.function(private=True)
+ def subroutine(A: R.Tensor) -> R.Tensor:
+ cls = Expected
+
+ B = R.add(A, A)
+ C = cls.subsubroutine(B)
+ D = R.multiply(C, C)
+ return C
+
+ @R.function(private=True)
+ def subsubroutine(A: R.Tensor) -> R.Tensor:
+ B = R.add(A, A)
+ return B
+
+ assert tvm.relax.analysis.well_formed(Before)
+ After = tvm.ir.transform.ApplyPassToFunction(
+ tvm.relax.transform.DeadCodeElimination(),
+ "main|subsubroutine",
+ )(Before)
+ assert tvm.relax.analysis.well_formed(After)
+ tvm.ir.assert_structural_equal(Expected, After)
+
+
if __name__ == "__main__":
tvm.testing.main()