This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 545e0977e3 [Relax] Allow DeadCodeElimination within ApplyPassToFunction (#16801)
545e0977e3 is described below

commit 545e0977e327aad3f60a110961bdef733d08dbb1
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Apr 3 10:28:00 2024 -0500

    [Relax] Allow DeadCodeElimination within ApplyPassToFunction (#16801)
    
    The `tvm.ir.transform.ApplyPassToFunction` utility allows a transform
    to be applied selectively to some portions of an `IRModule`, without
    applying it to the entire `IRModule`.  For example, it can be used to
    apply an optimization pass (e.g. `relax.transform.ExpandMatmulOfSum`)
    or an interface-altering pass (e.g.
    `relax.transform.BundleModelParams`) to specific functions.  It does
    so by generating an intermediate `IRModule` containing only the
    functions specified, applying the transform to that intermediate,
    then merging the results.
    
    When using `ApplyPassToFunction` to apply `DeadCodeElimination`, or a
    pipeline containing `DeadCodeElimination`, this intermediate
    `IRModule` may contain calls to `GlobalVar` instances that are not
    within the intermediate `IRModule`.  Prior to this commit, this
    resulted in an error being thrown when collecting the call graph.
    This commit updates `DeadCodeElimination` to instead handle incomplete
    call-graph collection.
---
 src/relax/transform/dead_code_elimination.cc       |  37 ++++-
 tests/python/relax/conftest.py                     |  22 ++-
 .../relax/test_transform_dead_code_elimination.py  | 155 +++++++++++++++++++++
 3 files changed, 202 insertions(+), 12 deletions(-)

diff --git a/src/relax/transform/dead_code_elimination.cc b/src/relax/transform/dead_code_elimination.cc
index 73f66d2ef3..28c7d74ef8 100644
--- a/src/relax/transform/dead_code_elimination.cc
+++ b/src/relax/transform/dead_code_elimination.cc
@@ -50,12 +50,22 @@ class CallTracer : public ExprVisitor {
   explicit CallTracer(IRModule mod) : mod_{mod}, called_funcs_{}, visiting_{} {}
 
   void VisitExpr_(const GlobalVarNode* op) final {
-    called_funcs_.insert(GetRef<GlobalVar>(op));
-    auto func = mod_->Lookup(op->name_hint);
-    if (const auto* function_node = func.as<FunctionNode>()) {
-      VisitExpr(GetRef<Function>(function_node));
+    auto gvar = GetRef<GlobalVar>(op);
+    called_funcs_.insert(gvar);
+    if (auto func = mod_->functions.Get(gvar)) {
+      if (const auto* function_node = func.as<FunctionNode>()) {
+        VisitExpr(GetRef<Function>(function_node));
+      }
+      // else: Don't visit PrimFuncs -- we don't need to collect any tir.Calls therein.
+    } else {
+      // The GlobalVar is not contained in the IRModule.  While the
+      // input IRModule is ill-formed, this specific case is allowed
+      // for use with `relax.transform.ApplyPassToFunction`.  If this
+      // occurs, DCE should not remove any internal functions from the
+      // IRModule, as their removal is only valid if we have a
+      // complete call graph.
+      all_callees_found_ = false;
     }
-    // else: Don't visit PrimFuncs -- we don't need to collect any tir.Calls therein.
   }
 
   void VisitExpr_(const CallNode* call_node) final { ExprVisitor::VisitExpr_(call_node); }
@@ -77,11 +87,24 @@ class CallTracer : public ExprVisitor {
     VisitExpr(main_func);
   }
 
-  bool check_if_called(GlobalVar gv) { return called_funcs_.count(gv) > 0; }
+  /* \brief Check if a function is unreachable
+   *
+   * \param gvar The function to be checked
+   *
+   * \return True if the function can be proven to be unreachable,
+   * either directly or indirectly, from an external caller.
+   * Otherwise, false.
+   */
+  bool CheckIfProvablyUnreachable(const GlobalVar& gvar) const {
+    return all_callees_found_ && !called_funcs_.count(gvar);
+  }
 
  private:
   IRModule mod_;
 
+  /* \brief Whether all callees could be located within the IRModule */
+  bool all_callees_found_{true};
+
   // Record the names of all encountered functions.
   std::unordered_set<GlobalVar, ObjectPtrHash, ObjectPtrEqual> called_funcs_;
 
@@ -101,7 +124,7 @@ IRModule RemoveUnusedFunctions(
     // The tracer contains all user-provided entry functions, all
     // externally-callable functions, and anything that is directly or
     // indirectly accessible from an entry function.
-    if (!tracer.check_if_called(kv.first)) {
+    if (tracer.CheckIfProvablyUnreachable(kv.first)) {
       to_remove.push_back(kv.first);
     }
   }
diff --git a/tests/python/relax/conftest.py b/tests/python/relax/conftest.py
index 1e12a95e52..bb5a04ef76 100644
--- a/tests/python/relax/conftest.py
+++ b/tests/python/relax/conftest.py
@@ -37,7 +37,14 @@ def pytest_configure(config):
         "markers",
         (
             "skip_well_formed_check_before_transform: "
-            "Only check for well-formed IRModule after a transform"
+            "Suppress the default well-formed check before a IRModule transform"
+        ),
+    )
+    config.addinivalue_line(
+        "markers",
+        (
+            "skip_well_formed_check_after_transform: "
+            "Suppress the default well-formed check after a IRModule transform"
         ),
     )
 
@@ -54,15 +61,20 @@ def pytest_configure(config):
 # `@pytest.mark.skip_well_formed_check_before_transform`
 @pytest.fixture(autouse=True)
 def apply_instrument_well_formed(unit_test_marks):
-
     validate_before_transform = "skip_well_formed_check_before_transform" not in unit_test_marks
+    validate_after_transform = "skip_well_formed_check_after_transform" not in unit_test_marks
 
-    instrument = WellFormedInstrument(validate_before_transform=validate_before_transform)
     current = tvm.transform.PassContext.current()
+    instruments = list(current.instruments)
+
+    if validate_before_transform or validate_after_transform:
+        instruments.append(
+            WellFormedInstrument(validate_before_transform=validate_before_transform)
+        )
 
     override = tvm.transform.PassContext(
-        # Append the new instrument
-        instruments=[*current.instruments, instrument],
+        # With the new WellFormedInstrument appended
+        instruments=instruments,
         # Forward all other parameters
         opt_level=current.opt_level,
         required_pass=current.required_pass,
diff --git a/tests/python/relax/test_transform_dead_code_elimination.py b/tests/python/relax/test_transform_dead_code_elimination.py
index c0a2d47b19..2dae252cad 100644
--- a/tests/python/relax/test_transform_dead_code_elimination.py
+++ b/tests/python/relax/test_transform_dead_code_elimination.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import pytest
+
 import tvm
 import tvm.testing
 from tvm.relax.transform import DeadCodeElimination
@@ -507,5 +509,158 @@ def test_extern_func():
     verify(before, before)
 
 
+@pytest.mark.skip_well_formed_check_before_transform
+@pytest.mark.skip_well_formed_check_after_transform
+def test_compatibility_with_apply_pass_to_function():
+    """DeadCodeElimination can be used with ApplyPassToFunction
+
+    The `ApplyPassToFunction` utility calls another transform, where
+    only the specified functions are exposed to the internal
+    transform.  This intermediate does not contain `cls.subroutine`,
+    and so the intermediate is ill-formed.
+
+    In general, IRModule transformations may assume that their inputs
+    are well-formed.  In specific cases, IRModule transformations may
+    accept IRModules that are ill-formed.  The `DeadCodeElimination`
+    transform allows IRModule arguments that are ill-formed due to
+    a dangling GlobalVar.
+
+    After `DeadCodeElimination` completes, the resulting function is
+    inserted in the original IRModule, providing a well-formed output
+    from `ApplyPassToFunction`.
+
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def to_be_transformed(A: R.Tensor):
+            cls = Before
+
+            B = R.add(A, A)
+            C = cls.subroutine(B)
+            D = R.multiply(C, C)
+            return C
+
+        @R.function
+        def to_be_ignored(A: R.Tensor):
+            cls = Before
+
+            B = R.add(A, A)
+            C = cls.subroutine(B)
+            D = R.multiply(C, C)
+            return C
+
+        @R.function(private=True)
+        def subroutine(arg: R.Tensor) -> R.Tensor:
+            return R.add(arg, arg)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def to_be_transformed(A: R.Tensor):
+            cls = Expected
+
+            B = R.add(A, A)
+            C = cls.subroutine(B)
+            return C
+
+        @R.function
+        def to_be_ignored(A: R.Tensor):
+            cls = Expected
+
+            B = R.add(A, A)
+            C = cls.subroutine(B)
+            D = R.multiply(C, C)
+            return C
+
+        @R.function(private=True)
+        def subroutine(arg: R.Tensor) -> R.Tensor:
+            return R.add(arg, arg)
+
+    # The well-formed check in conftest.py must be disabled, to avoid
+    # triggering on the ill-formed intermediate, so this unit test
+    # checks it explicitly.
+    assert tvm.relax.analysis.well_formed(Before)
+    After = tvm.ir.transform.ApplyPassToFunction(
+        tvm.relax.transform.DeadCodeElimination(),
+        "to_be_transformed",
+    )(Before)
+    assert tvm.relax.analysis.well_formed(After)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
+@pytest.mark.skip_well_formed_check_before_transform
+@pytest.mark.skip_well_formed_check_after_transform
+def test_well_formed_output_with_restricted_scope():
+    """DeadCodeElimination can be used with ApplyPassToFunction
+
+    If the call graph cannot be completely traced, private functions
+    should not be removed.
+
+    See `test_compatibility_with_apply_pass_to_function` for full
+    description of `DeadCodeElimination` and `ApplyPassToFunction`.
+
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor):
+            cls = Before
+
+            B = R.add(A, A)
+            C = cls.subroutine(B)
+            D = R.multiply(C, C)
+            return C
+
+        @R.function(private=True)
+        def subroutine(A: R.Tensor) -> R.Tensor:
+            cls = Before
+
+            B = R.add(A, A)
+            C = cls.subsubroutine(B)
+            D = R.multiply(C, C)
+            return C
+
+        @R.function(private=True)
+        def subsubroutine(A: R.Tensor) -> R.Tensor:
+            B = R.add(A, A)
+            C = R.multiply(B, B)
+            return B
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(A: R.Tensor):
+            cls = Expected
+
+            B = R.add(A, A)
+            C = cls.subroutine(B)
+            return C
+
+        @R.function(private=True)
+        def subroutine(A: R.Tensor) -> R.Tensor:
+            cls = Expected
+
+            B = R.add(A, A)
+            C = cls.subsubroutine(B)
+            D = R.multiply(C, C)
+            return C
+
+        @R.function(private=True)
+        def subsubroutine(A: R.Tensor) -> R.Tensor:
+            B = R.add(A, A)
+            return B
+
+    assert tvm.relax.analysis.well_formed(Before)
+    After = tvm.ir.transform.ApplyPassToFunction(
+        tvm.relax.transform.DeadCodeElimination(),
+        "main|subsubroutine",
+    )(Before)
+    assert tvm.relax.analysis.well_formed(After)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to