This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3918e14389 [REFACTOR][IR] Inline ApplyPassToFunction into relax decompose_ops, delete the util (#19612)
3918e14389 is described below

commit 3918e143893147beaf673cfc9cf269fe652b7a8f
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue May 26 15:30:20 2026 -0400

    [REFACTOR][IR] Inline ApplyPassToFunction into relax decompose_ops, delete the util (#19612)
    
    ## Summary
    
    `ApplyPassToFunction` is a general-purpose wrapper that runs a pass on
    only the functions in an IRModule whose name matches a regex. Its sole
    in-tree production callers are `DecomposeOpsForInference` /
    `DecomposeOpsForTraining` in `src/relax/transform/decompose_ops.cc`, and
    both callers always supply a literal function name (never a regex
    pattern). Inlining the logic as a file-local helper simplifies the
    module-level context and removes an abstraction that exists only to
    support one use case.
    
    - Inline the helper as `ApplyDecomposeToFunction` (exact-name match, not
    regex) in `src/relax/transform/decompose_ops.cc`
    - Delete `src/ir/apply_pass_to_function.cc`, its `transform.h`
    declaration, and the Python wrapper in `python/tvm/ir/transform.py`
    - Remove two DCE tests
    (`test_compatibility_with_apply_pass_to_function`,
    `test_well_formed_output_with_restricted_scope`) that tested the
    utility's plumbing rather than DCE behavior
---
 include/tvm/ir/transform.h                         |  25 ----
 python/tvm/ir/transform.py                         |  43 ------
 src/ir/apply_pass_to_function.cc                   | 139 -------------------
 src/relax/transform/decompose_ops.cc               |  88 +++++++++++-
 .../relax/test_transform_dead_code_elimination.py  | 149 ---------------------
 5 files changed, 86 insertions(+), 358 deletions(-)

diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h
index 6d4f5c333c..436987ae78 100644
--- a/include/tvm/ir/transform.h
+++ b/include/tvm/ir/transform.h
@@ -529,31 +529,6 @@ TVM_DLL Pass 
CreateModulePass(std::function<IRModule(IRModule, PassContext)> pas
                               int opt_level, ffi::String name, 
ffi::Array<ffi::String> required,
                               bool traceable = false);
 
-/*
- * \brief Utility to apply a pass to specific functions in an IRModule
- *
- * TVM uses IRModule to IRModule transformations at all stages of
- * lowering.  These transformations may be useful when hand-writing an
- * optimized model, or to perform optimizations on specific kernels
- * within an IRModule.  This utility allows a pass to be applied to a
- * specified function, without altering other functions in the module.
- *
- * \param pass The IRModule to IRModule pass to be applied.
- *
- * \param func_name_regex A regex used to select the functions to be
- * updated.  The pass will be applied to all functions whose name
- * matches the regex.
- *
- * \param error_if_no_function_matches_regex Specifies the behavior if
- *     an IRModule does not contain any function matching the provided
- *     regex.  If true, an error will be raised.  If false (default),
- *     the IRModule will be returned unmodified.
- *
- * \return The modified IRModule to IRModule pass.
- */
-TVM_DLL Pass ApplyPassToFunction(Pass pass, ffi::String func_name_regex,
-                                 bool error_if_no_function_matches_regex = 
false);
-
 /*!
  * \brief A special trace pass that prints the header and IR to LOG(INFO).
  * \param header The header to be attached to the output.
diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py
index 3e22a2b908..0f0ad89e62 100644
--- a/python/tvm/ir/transform.py
+++ b/python/tvm/ir/transform.py
@@ -365,46 +365,3 @@ def PrintIR(header=""):
     The pass
     """
     return _ffi_transform_api.PrintIR(header)
-
-
-def ApplyPassToFunction(
-    transform: Pass,
-    func_name_regex: str,
-    error_if_no_function_matches_regex: bool = False,
-) -> Pass:
-    """Utility to apply a pass to specific functions in an IRModule
-
-    TVM uses IRModule to IRModule transformations at all stages of
-    lowering.  These transformations may be useful when hand-writing an
-    optimized model, or to perform optimizations on specific kernels
-    within an IRModule.  This utility allows a pass to be applied to a
-    specified function, without altering other functions in the module.
-
-    Parameters
-    ----------
-    transform: Pass
-
-        The IRModule to IRModule pass to be applied.
-
-    func_name_regex: str
-
-        A regex used to select the functions to be updated.  The pass
-        will be applied to all functions whose name matches the regex.
-
-    error_if_no_function_matches_regex: bool
-
-        Specifies the behavior if an IRModule does not contain any
-        function matching the provided regex.  If true, an error will
-        be raised.  If false (default), the IRModule will be returned
-        unmodified.
-
-    Returns
-    -------
-    new_transform: Pass
-
-        The modified IRModule to IRModule pass.
-
-    """
-    return _ffi_transform_api.ApplyPassToFunction(
-        transform, func_name_regex, error_if_no_function_matches_regex
-    )
diff --git a/src/ir/apply_pass_to_function.cc b/src/ir/apply_pass_to_function.cc
deleted file mode 100644
index 1524ea9fc2..0000000000
--- a/src/ir/apply_pass_to_function.cc
+++ /dev/null
@@ -1,139 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file src/ir/apply_pass_to_function.cc
- * \brief Utility transformation that applies an inner pass to a subset of an 
IRModule
- */
-#include <tvm/ffi/function.h>
-#include <tvm/ffi/reflection/registry.h>
-#include <tvm/ir/transform.h>
-#include <tvm/relax/expr.h>
-#include <tvm/tirx/function.h>
-
-#include <unordered_set>
-
-#include "../runtime/regex.h"
-
-namespace tvm {
-namespace transform {
-
-namespace {
-BaseFunc BaseFuncWithAttr(BaseFunc func, const std::string& attr_key, Any 
attr_value) {
-  if (auto tirx = func.as<tirx::PrimFunc>()) {
-    return WithAttr(tirx.value(), attr_key, attr_value);
-  } else if (auto relax = func.as<relax::Function>()) {
-    return WithAttr(relax.value(), attr_key, attr_value);
-  } else {
-    return func;
-  }
-}
-
-BaseFunc BaseFuncWithoutAttr(BaseFunc func, const std::string& attr_key) {
-  if (auto tirx = func.as<tirx::PrimFunc>()) {
-    return WithoutAttr(tirx.value(), attr_key);
-  } else if (auto relax = func.as<relax::Function>()) {
-    return WithoutAttr(relax.value(), attr_key);
-  } else {
-    return func;
-  }
-}
-}  // namespace
-
-Pass ApplyPassToFunction(Pass pass, ffi::String func_name_regex,
-                         bool error_if_no_function_matches_regex) {
-  auto pass_name =
-      static_cast<const std::stringstream&>(std::stringstream() << 
"ApplyPassTo" << func_name_regex)
-          .str();
-
-  auto pass_func = [pass, func_name_regex, error_if_no_function_matches_regex](
-                       IRModule mod, PassContext) -> IRModule {
-    bool at_least_one_function_matched_regex = false;
-    std::unordered_set<ffi::String> keep_original_version;
-    std::unordered_set<ffi::String> internal_functions;
-    IRModule subset;
-
-    for (auto [gvar, func] : mod->functions) {
-      std::string name = gvar->name_hint;
-      if (tvm::runtime::regex_match(name, func_name_regex)) {
-        at_least_one_function_matched_regex = true;
-        if (!func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol).has_value()) 
{
-          // Function may be mutated, but is an internal function.  Mark
-          // it as externally-exposed, so that any call-tracing internal
-          // transforms do not remove this function, in case it its
-          // callers are not being mutated.
-
-          internal_functions.insert(gvar->name_hint);
-          func = BaseFuncWithAttr(func, tvm::attr::kGlobalSymbol, 
gvar->name_hint);
-        }
-      } else {
-        // Function may not be mutated.  Replace it with a
-        // `relax::ExternFunc` to prevent references to it from
-        // dangling.
-        keep_original_version.insert(gvar->name_hint);
-        func = relax::ExternFunc("dummy_" + name);
-        func->struct_info_ = gvar->struct_info_;
-      }
-
-      subset->Add(gvar, func);
-    }
-
-    if (error_if_no_function_matches_regex) {
-      TVM_FFI_ICHECK(at_least_one_function_matched_regex)
-          << "No function matched regex '" << func_name_regex << "', out of 
functions " << [&]() {
-               ffi::Array<ffi::String> function_names;
-               for (const auto& [gvar, func] : mod->functions) {
-                 function_names.push_back(gvar->name_hint);
-               }
-               return function_names;
-             }();
-    }
-
-    IRModule new_subset = pass(subset);
-    if (new_subset.same_as(subset)) {
-      return mod;
-    }
-
-    auto write_ptr = mod.CopyOnWrite();
-    for (auto [gvar, func] : new_subset->functions) {
-      if (!keep_original_version.count(gvar->name_hint)) {
-        if (auto it = write_ptr->global_var_map_.find(gvar->name_hint);
-            it != write_ptr->global_var_map_.end()) {
-          write_ptr->Remove((*it).second);
-        }
-        if (internal_functions.count(gvar->name_hint)) {
-          func = BaseFuncWithoutAttr(func, tvm::attr::kGlobalSymbol);
-        }
-        write_ptr->Add(gvar, func);
-      }
-    }
-
-    return mod;
-  };
-
-  return CreateModulePass(pass_func, 0, pass_name, {});
-}
-
-TVM_FFI_STATIC_INIT_BLOCK() {
-  namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("transform.ApplyPassToFunction", ApplyPassToFunction);
-}
-
-}  // namespace transform
-}  // namespace tvm
diff --git a/src/relax/transform/decompose_ops.cc 
b/src/relax/transform/decompose_ops.cc
index c53d9b0f3a..1c5e658493 100644
--- a/src/relax/transform/decompose_ops.cc
+++ b/src/relax/transform/decompose_ops.cc
@@ -25,6 +25,9 @@
 #include <tvm/relax/attrs/nn.h>
 #include <tvm/relax/struct_info.h>
 #include <tvm/relax/transform.h>
+#include <tvm/tirx/function.h>
+
+#include <unordered_set>
 
 #include "utils.h"
 
@@ -212,6 +215,87 @@ class OpDecomposer : public ExprMutator {
 
 namespace transform {
 
+namespace {
+
+/*! \brief Helper: add or remove an attribute on a BaseFunc */
+BaseFunc BaseFuncWithAttr(BaseFunc func, const std::string& attr_key, Any 
attr_value) {
+  if (auto tirx = func.as<tirx::PrimFunc>()) {
+    return WithAttr(tirx.value(), attr_key, attr_value);
+  } else if (auto relax_fn = func.as<relax::Function>()) {
+    return WithAttr(relax_fn.value(), attr_key, attr_value);
+  } else {
+    return func;
+  }
+}
+
+BaseFunc BaseFuncWithoutAttr(BaseFunc func, const std::string& attr_key) {
+  if (auto tirx = func.as<tirx::PrimFunc>()) {
+    return WithoutAttr(tirx.value(), attr_key);
+  } else if (auto relax_fn = func.as<relax::Function>()) {
+    return WithoutAttr(relax_fn.value(), attr_key);
+  } else {
+    return func;
+  }
+}
+
+/*!
+ * \brief Apply a pass to a single named function within an IRModule.
+ *
+ * Replaces all other functions with dummy ExternFunc stubs so that the
+ * pass does not see them, then restores the original module.  Uses
+ * exact name match (not a regex) because all in-tree callers supply a
+ * literal function name.
+ */
+Pass ApplyDecomposeToFunction(Pass pass, ffi::String func_name) {
+  auto pass_func = [pass, func_name](IRModule mod, PassContext) -> IRModule {
+    std::unordered_set<ffi::String> keep_original_version;
+    std::unordered_set<ffi::String> internal_functions;
+    IRModule subset;
+
+    for (auto [gvar, func] : mod->functions) {
+      if (gvar->name_hint == func_name) {
+        if (!func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol).has_value()) 
{
+          // Mark internal functions as externally-exposed so that
+          // call-tracing transforms inside the pass do not remove them.
+          internal_functions.insert(gvar->name_hint);
+          func = BaseFuncWithAttr(func, tvm::attr::kGlobalSymbol, 
gvar->name_hint);
+        }
+      } else {
+        // Replace non-target functions with stubs to keep references intact.
+        keep_original_version.insert(gvar->name_hint);
+        func = relax::ExternFunc("dummy_" + std::string(gvar->name_hint));
+        func->struct_info_ = gvar->struct_info_;
+      }
+      subset->Add(gvar, func);
+    }
+
+    IRModule new_subset = pass(subset);
+    if (new_subset.same_as(subset)) {
+      return mod;
+    }
+
+    auto write_ptr = mod.CopyOnWrite();
+    for (auto [gvar, func] : new_subset->functions) {
+      if (!keep_original_version.count(gvar->name_hint)) {
+        if (auto it = write_ptr->global_var_map_.find(gvar->name_hint);
+            it != write_ptr->global_var_map_.end()) {
+          write_ptr->Remove((*it).second);
+        }
+        if (internal_functions.count(gvar->name_hint)) {
+          func = BaseFuncWithoutAttr(func, tvm::attr::kGlobalSymbol);
+        }
+        write_ptr->Add(gvar, func);
+      }
+    }
+    return mod;
+  };
+
+  std::string pass_name = "ApplyDecomposeTo" + std::string(func_name);
+  return CreateModulePass(pass_func, 0, pass_name, {});
+}
+
+}  // namespace
+
 Pass MutateOpsForTraining() {
   auto pass_func = [](Function func, IRModule, PassContext) -> Function {
     TrainingOperatorMutator mutator;
@@ -236,7 +320,7 @@ Pass DecomposeOps() {
 
 Pass DecomposeOpsForInference(ffi::Optional<ffi::String> func_name) {
   if (func_name) {
-    return ApplyPassToFunction(DecomposeOps(), func_name.value());
+    return ApplyDecomposeToFunction(DecomposeOps(), func_name.value());
   } else {
     return DecomposeOps();
   }
@@ -246,7 +330,7 @@ Pass DecomposeOpsForTraining(ffi::Optional<ffi::String> 
func_name) {
   auto module_pass = tvm::transform::Sequential({MutateOpsForTraining(), 
DecomposeOps()},
                                                 "DecomposeOpsForTraining");
   if (func_name) {
-    return ApplyPassToFunction(module_pass, func_name.value());
+    return ApplyDecomposeToFunction(module_pass, func_name.value());
   } else {
     return module_pass;
   }
diff --git a/tests/python/relax/test_transform_dead_code_elimination.py 
b/tests/python/relax/test_transform_dead_code_elimination.py
index 82eeba354f..87366137b1 100644
--- a/tests/python/relax/test_transform_dead_code_elimination.py
+++ b/tests/python/relax/test_transform_dead_code_elimination.py
@@ -572,155 +572,6 @@ def test_extern_func():
     verify(before, before)
 
 
-def test_compatibility_with_apply_pass_to_function():
-    """DeadCodeElimination can be used with ApplyPassToFunction
-
-    The `ApplyPassToFunction` utility calls another transform, where
-    only the specified functions are exposed to the internal
-    transform.  This intermediate does not contain `cls.subroutine`,
-    and so the intermediate is ill-formed.
-
-    In general, IRModule transformations may assume that their inputs
-    are well-formed.  In specific cases, IRModule transformations may
-    accept IRModules that are ill-formed.  The `DeadCodeElimination`
-    transform allows IRModule arguments that are ill-formed due to
-    a dangling GlobalVar.
-
-    After `DeadCodeElimination` completes, the resulting function is
-    inserted in the original IRModule, providing a well-formed output
-    from `ApplyPassToFunction`.
-
-    """
-
-    @I.ir_module(s_tir=True)
-    class Before:
-        @R.function
-        def to_be_transformed(A: R.Tensor):
-            cls = Before
-
-            B = R.add(A, A)
-            C = cls.subroutine(B)
-            D = R.multiply(C, C)
-            return C
-
-        @R.function
-        def to_be_ignored(A: R.Tensor):
-            cls = Before
-
-            B = R.add(A, A)
-            C = cls.subroutine(B)
-            D = R.multiply(C, C)
-            return C
-
-        @R.function(private=True)
-        def subroutine(arg: R.Tensor) -> R.Tensor:
-            return R.add(arg, arg)
-
-    @I.ir_module(s_tir=True)
-    class Expected:
-        @R.function
-        def to_be_transformed(A: R.Tensor):
-            cls = Expected
-
-            B = R.add(A, A)
-            C = cls.subroutine(B)
-            return C
-
-        @R.function
-        def to_be_ignored(A: R.Tensor):
-            cls = Expected
-
-            B = R.add(A, A)
-            C = cls.subroutine(B)
-            D = R.multiply(C, C)
-            return C
-
-        @R.function(private=True)
-        def subroutine(arg: R.Tensor) -> R.Tensor:
-            return R.add(arg, arg)
-
-    # The well-formed check in conftest.py must be disabled, to avoid
-    # triggering on the ill-formed intermediate, so this unit test
-    # checks it explicitly.
-    assert tvm.relax.analysis.well_formed(Before)
-    After = tvm.ir.transform.ApplyPassToFunction(
-        tvm.relax.transform.DeadCodeElimination(),
-        "to_be_transformed",
-    )(Before)
-    assert tvm.relax.analysis.well_formed(After)
-    tvm.ir.assert_structural_equal(Expected, After)
-
-
-def test_well_formed_output_with_restricted_scope():
-    """DeadCodeElimination can be used with ApplyPassToFunction
-
-    If the call graph cannot be completely traced, private functions
-    should not be removed.
-
-    See `test_compatibility_with_apply_pass_to_function` for full
-    description of `DeadCodeElimination` and `ApplyPassToFunction`.
-
-    """
-
-    @I.ir_module(s_tir=True)
-    class Before:
-        @R.function
-        def main(A: R.Tensor):
-            cls = Before
-
-            B = R.add(A, A)
-            C = cls.subroutine(B)
-            D = R.multiply(C, C)
-            return C
-
-        @R.function(private=True)
-        def subroutine(A: R.Tensor) -> R.Tensor:
-            cls = Before
-
-            B = R.add(A, A)
-            C = cls.subsubroutine(B)
-            D = R.multiply(C, C)
-            return C
-
-        @R.function(private=True)
-        def subsubroutine(A: R.Tensor) -> R.Tensor:
-            B = R.add(A, A)
-            C = R.multiply(B, B)
-            return B
-
-    @I.ir_module(s_tir=True)
-    class Expected:
-        @R.function
-        def main(A: R.Tensor):
-            cls = Expected
-
-            B = R.add(A, A)
-            C = cls.subroutine(B)
-            return C
-
-        @R.function(private=True)
-        def subroutine(A: R.Tensor) -> R.Tensor:
-            cls = Expected
-
-            B = R.add(A, A)
-            C = cls.subsubroutine(B)
-            D = R.multiply(C, C)
-            return C
-
-        @R.function(private=True)
-        def subsubroutine(A: R.Tensor) -> R.Tensor:
-            B = R.add(A, A)
-            return B
-
-    assert tvm.relax.analysis.well_formed(Before)
-    After = tvm.ir.transform.ApplyPassToFunction(
-        tvm.relax.transform.DeadCodeElimination(),
-        "main|subsubroutine",
-    )(Before)
-    assert tvm.relax.analysis.well_formed(After)
-    tvm.ir.assert_structural_equal(Expected, After)
-
-
 def test_recursively_defined_lambda():
     """DCE may be applied to recursively-defined functions
 

Reply via email to