This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3918e14389 [REFACTOR][IR] Inline ApplyPassToFunction into relax
decompose_ops, delete the util (#19612)
3918e14389 is described below
commit 3918e143893147beaf673cfc9cf269fe652b7a8f
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue May 26 15:30:20 2026 -0400
[REFACTOR][IR] Inline ApplyPassToFunction into relax decompose_ops, delete
the util (#19612)
## Summary
`ApplyPassToFunction` is a general-purpose wrapper that runs a pass on
only those functions in an IRModule whose names match a regex. Its only
in-tree production callers are `DecomposeOpsForInference` /
`DecomposeOpsForTraining` in `src/relax/transform/decompose_ops.cc`, and
both callers always supply a literal function name (never a regex
pattern). Inlining the logic as a file-local helper shrinks the public
`tvm::transform` API and removes a general-purpose abstraction (regex
selection, optional error when nothing matches) that exists only to
support this one use case.
- Inline the helper as `ApplyDecomposeToFunction` (exact-name match, not
regex, so a `func_name` containing regex metacharacters is now matched
literally) in `src/relax/transform/decompose_ops.cc`
- Delete `src/ir/apply_pass_to_function.cc`, its `transform.h`
declaration, and the Python wrapper in `python/tvm/ir/transform.py`
- Remove two DCE tests
(`test_compatibility_with_apply_pass_to_function`,
`test_well_formed_output_with_restricted_scope`) that exercised DCE only
through the removed utility, i.e. they tested the utility's plumbing
rather than standalone DCE behavior
---
include/tvm/ir/transform.h | 25 ----
python/tvm/ir/transform.py | 43 ------
src/ir/apply_pass_to_function.cc | 139 -------------------
src/relax/transform/decompose_ops.cc | 88 +++++++++++-
.../relax/test_transform_dead_code_elimination.py | 149 ---------------------
5 files changed, 86 insertions(+), 358 deletions(-)
diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h
index 6d4f5c333c..436987ae78 100644
--- a/include/tvm/ir/transform.h
+++ b/include/tvm/ir/transform.h
@@ -529,31 +529,6 @@ TVM_DLL Pass
CreateModulePass(std::function<IRModule(IRModule, PassContext)> pas
int opt_level, ffi::String name,
ffi::Array<ffi::String> required,
bool traceable = false);
-/*
- * \brief Utility to apply a pass to specific functions in an IRModule
- *
- * TVM uses IRModule to IRModule transformations at all stages of
- * lowering. These transformations may be useful when hand-writing an
- * optimized model, or to perform optimizations on specific kernels
- * within an IRModule. This utility allows a pass to be applied to a
- * specified function, without altering other functions in the module.
- *
- * \param pass The IRModule to IRModule pass to be applied.
- *
- * \param func_name_regex A regex used to select the functions to be
- * updated. The pass will be applied to all functions whose name
- * matches the regex.
- *
- * \param error_if_no_function_matches_regex Specifies the behavior if
- * an IRModule does not contain any function matching the provided
- * regex. If true, an error will be raised. If false (default),
- * the IRModule will be returned unmodified.
- *
- * \return The modified IRModule to IRModule pass.
- */
-TVM_DLL Pass ApplyPassToFunction(Pass pass, ffi::String func_name_regex,
- bool error_if_no_function_matches_regex =
false);
-
/*!
* \brief A special trace pass that prints the header and IR to LOG(INFO).
* \param header The header to be attached to the output.
diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py
index 3e22a2b908..0f0ad89e62 100644
--- a/python/tvm/ir/transform.py
+++ b/python/tvm/ir/transform.py
@@ -365,46 +365,3 @@ def PrintIR(header=""):
The pass
"""
return _ffi_transform_api.PrintIR(header)
-
-
-def ApplyPassToFunction(
- transform: Pass,
- func_name_regex: str,
- error_if_no_function_matches_regex: bool = False,
-) -> Pass:
- """Utility to apply a pass to specific functions in an IRModule
-
- TVM uses IRModule to IRModule transformations at all stages of
- lowering. These transformations may be useful when hand-writing an
- optimized model, or to perform optimizations on specific kernels
- within an IRModule. This utility allows a pass to be applied to a
- specified function, without altering other functions in the module.
-
- Parameters
- ----------
- transform: Pass
-
- The IRModule to IRModule pass to be applied.
-
- func_name_regex: str
-
- A regex used to select the functions to be updated. The pass
- will be applied to all functions whose name matches the regex.
-
- error_if_no_function_matches_regex: bool
-
- Specifies the behavior if an IRModule does not contain any
- function matching the provided regex. If true, an error will
- be raised. If false (default), the IRModule will be returned
- unmodified.
-
- Returns
- -------
- new_transform: Pass
-
- The modified IRModule to IRModule pass.
-
- """
- return _ffi_transform_api.ApplyPassToFunction(
- transform, func_name_regex, error_if_no_function_matches_regex
- )
diff --git a/src/ir/apply_pass_to_function.cc b/src/ir/apply_pass_to_function.cc
deleted file mode 100644
index 1524ea9fc2..0000000000
--- a/src/ir/apply_pass_to_function.cc
+++ /dev/null
@@ -1,139 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file src/ir/apply_pass_to_function.cc
- * \brief Utility transformation that applies an inner pass to a subset of an
IRModule
- */
-#include <tvm/ffi/function.h>
-#include <tvm/ffi/reflection/registry.h>
-#include <tvm/ir/transform.h>
-#include <tvm/relax/expr.h>
-#include <tvm/tirx/function.h>
-
-#include <unordered_set>
-
-#include "../runtime/regex.h"
-
-namespace tvm {
-namespace transform {
-
-namespace {
-BaseFunc BaseFuncWithAttr(BaseFunc func, const std::string& attr_key, Any
attr_value) {
- if (auto tirx = func.as<tirx::PrimFunc>()) {
- return WithAttr(tirx.value(), attr_key, attr_value);
- } else if (auto relax = func.as<relax::Function>()) {
- return WithAttr(relax.value(), attr_key, attr_value);
- } else {
- return func;
- }
-}
-
-BaseFunc BaseFuncWithoutAttr(BaseFunc func, const std::string& attr_key) {
- if (auto tirx = func.as<tirx::PrimFunc>()) {
- return WithoutAttr(tirx.value(), attr_key);
- } else if (auto relax = func.as<relax::Function>()) {
- return WithoutAttr(relax.value(), attr_key);
- } else {
- return func;
- }
-}
-} // namespace
-
-Pass ApplyPassToFunction(Pass pass, ffi::String func_name_regex,
- bool error_if_no_function_matches_regex) {
- auto pass_name =
- static_cast<const std::stringstream&>(std::stringstream() <<
"ApplyPassTo" << func_name_regex)
- .str();
-
- auto pass_func = [pass, func_name_regex, error_if_no_function_matches_regex](
- IRModule mod, PassContext) -> IRModule {
- bool at_least_one_function_matched_regex = false;
- std::unordered_set<ffi::String> keep_original_version;
- std::unordered_set<ffi::String> internal_functions;
- IRModule subset;
-
- for (auto [gvar, func] : mod->functions) {
- std::string name = gvar->name_hint;
- if (tvm::runtime::regex_match(name, func_name_regex)) {
- at_least_one_function_matched_regex = true;
- if (!func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol).has_value())
{
- // Function may be mutated, but is an internal function. Mark
- // it as externally-exposed, so that any call-tracing internal
- // transforms do not remove this function, in case it its
- // callers are not being mutated.
-
- internal_functions.insert(gvar->name_hint);
- func = BaseFuncWithAttr(func, tvm::attr::kGlobalSymbol,
gvar->name_hint);
- }
- } else {
- // Function may not be mutated. Replace it with a
- // `relax::ExternFunc` to prevent references to it from
- // dangling.
- keep_original_version.insert(gvar->name_hint);
- func = relax::ExternFunc("dummy_" + name);
- func->struct_info_ = gvar->struct_info_;
- }
-
- subset->Add(gvar, func);
- }
-
- if (error_if_no_function_matches_regex) {
- TVM_FFI_ICHECK(at_least_one_function_matched_regex)
- << "No function matched regex '" << func_name_regex << "', out of
functions " << [&]() {
- ffi::Array<ffi::String> function_names;
- for (const auto& [gvar, func] : mod->functions) {
- function_names.push_back(gvar->name_hint);
- }
- return function_names;
- }();
- }
-
- IRModule new_subset = pass(subset);
- if (new_subset.same_as(subset)) {
- return mod;
- }
-
- auto write_ptr = mod.CopyOnWrite();
- for (auto [gvar, func] : new_subset->functions) {
- if (!keep_original_version.count(gvar->name_hint)) {
- if (auto it = write_ptr->global_var_map_.find(gvar->name_hint);
- it != write_ptr->global_var_map_.end()) {
- write_ptr->Remove((*it).second);
- }
- if (internal_functions.count(gvar->name_hint)) {
- func = BaseFuncWithoutAttr(func, tvm::attr::kGlobalSymbol);
- }
- write_ptr->Add(gvar, func);
- }
- }
-
- return mod;
- };
-
- return CreateModulePass(pass_func, 0, pass_name, {});
-}
-
-TVM_FFI_STATIC_INIT_BLOCK() {
- namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("transform.ApplyPassToFunction", ApplyPassToFunction);
-}
-
-} // namespace transform
-} // namespace tvm
diff --git a/src/relax/transform/decompose_ops.cc
b/src/relax/transform/decompose_ops.cc
index c53d9b0f3a..1c5e658493 100644
--- a/src/relax/transform/decompose_ops.cc
+++ b/src/relax/transform/decompose_ops.cc
@@ -25,6 +25,9 @@
#include <tvm/relax/attrs/nn.h>
#include <tvm/relax/struct_info.h>
#include <tvm/relax/transform.h>
+#include <tvm/tirx/function.h>
+
+#include <unordered_set>
#include "utils.h"
@@ -212,6 +215,87 @@ class OpDecomposer : public ExprMutator {
namespace transform {
+namespace {
+
+/*! \brief Add/remove an attr on a PrimFunc or relax::Function; any other BaseFunc kind is returned unchanged. */
+BaseFunc BaseFuncWithAttr(BaseFunc func, const std::string& attr_key, Any
attr_value) {
+ if (auto tirx = func.as<tirx::PrimFunc>()) {
+ return WithAttr(tirx.value(), attr_key, attr_value);
+ } else if (auto relax_fn = func.as<relax::Function>()) {
+ return WithAttr(relax_fn.value(), attr_key, attr_value);
+ } else {
+ return func;
+ }
+}
+
+BaseFunc BaseFuncWithoutAttr(BaseFunc func, const std::string& attr_key) {
+ if (auto tirx = func.as<tirx::PrimFunc>()) {
+ return WithoutAttr(tirx.value(), attr_key);
+ } else if (auto relax_fn = func.as<relax::Function>()) {
+ return WithoutAttr(relax_fn.value(), attr_key);
+ } else {
+ return func;
+ }
+}
+
+/*!
+ * \brief Apply \p pass to the single function named \p func_name.
+ *
+ * Other functions are replaced by ExternFunc stubs whose rewrites are
+ * discarded; the transformed target, plus any new functions the pass adds,
+ * is merged back into the original module. Uses exact name match (not a
+ * regex) because all in-tree callers supply a literal function name.
+ */
+Pass ApplyDecomposeToFunction(Pass pass, ffi::String func_name) {
+ auto pass_func = [pass, func_name](IRModule mod, PassContext) -> IRModule {
+ std::unordered_set<ffi::String> keep_original_version;
+ std::unordered_set<ffi::String> internal_functions;
+ IRModule subset;
+
+ for (auto [gvar, func] : mod->functions) {
+ if (gvar->name_hint == func_name) {
+ if (!func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol).has_value())
{
+ // Temporarily expose internal functions (attr is stripped again on
+ // write-back) so call-tracing transforms in the pass keep them.
+ internal_functions.insert(gvar->name_hint);
+ func = BaseFuncWithAttr(func, tvm::attr::kGlobalSymbol,
gvar->name_hint);
+ }
+ } else {
+ // Stub out non-target functions (same struct info) so references stay valid.
+ keep_original_version.insert(gvar->name_hint);
+ func = relax::ExternFunc("dummy_" + std::string(gvar->name_hint));
+ func->struct_info_ = gvar->struct_info_;
+ }
+ subset->Add(gvar, func);
+ }
+
+ IRModule new_subset = pass(subset);
+ if (new_subset.same_as(subset)) {
+ return mod;  // the pass made no changes
+ }
+
+ auto write_ptr = mod.CopyOnWrite();
+ for (auto [gvar, func] : new_subset->functions) {
+ if (!keep_original_version.count(gvar->name_hint)) {  // target or newly added function
+ if (auto it = write_ptr->global_var_map_.find(gvar->name_hint);
+ it != write_ptr->global_var_map_.end()) {
+ write_ptr->Remove((*it).second);  // drop the old definition before re-adding
+ }
+ if (internal_functions.count(gvar->name_hint)) {
+ func = BaseFuncWithoutAttr(func, tvm::attr::kGlobalSymbol);  // undo the temporary exposure
+ }
+ write_ptr->Add(gvar, func);
+ }
+ }
+ return mod;
+ };
+
+ std::string pass_name = "ApplyDecomposeTo" + std::string(func_name);
+ return CreateModulePass(pass_func, 0, pass_name, {});
+}
+
+} // namespace
+
Pass MutateOpsForTraining() {
auto pass_func = [](Function func, IRModule, PassContext) -> Function {
TrainingOperatorMutator mutator;
@@ -236,7 +320,7 @@ Pass DecomposeOps() {
Pass DecomposeOpsForInference(ffi::Optional<ffi::String> func_name) {
if (func_name) {
- return ApplyPassToFunction(DecomposeOps(), func_name.value());
+ return ApplyDecomposeToFunction(DecomposeOps(), func_name.value());
} else {
return DecomposeOps();
}
@@ -246,7 +330,7 @@ Pass DecomposeOpsForTraining(ffi::Optional<ffi::String>
func_name) {
auto module_pass = tvm::transform::Sequential({MutateOpsForTraining(),
DecomposeOps()},
"DecomposeOpsForTraining");
if (func_name) {
- return ApplyPassToFunction(module_pass, func_name.value());
+ return ApplyDecomposeToFunction(module_pass, func_name.value());
} else {
return module_pass;
}
diff --git a/tests/python/relax/test_transform_dead_code_elimination.py
b/tests/python/relax/test_transform_dead_code_elimination.py
index 82eeba354f..87366137b1 100644
--- a/tests/python/relax/test_transform_dead_code_elimination.py
+++ b/tests/python/relax/test_transform_dead_code_elimination.py
@@ -572,155 +572,6 @@ def test_extern_func():
verify(before, before)
-def test_compatibility_with_apply_pass_to_function():
- """DeadCodeElimination can be used with ApplyPassToFunction
-
- The `ApplyPassToFunction` utility calls another transform, where
- only the specified functions are exposed to the internal
- transform. This intermediate does not contain `cls.subroutine`,
- and so the intermediate is ill-formed.
-
- In general, IRModule transformations may assume that their inputs
- are well-formed. In specific cases, IRModule transformations may
- accept IRModules that are ill-formed. The `DeadCodeElimination`
- transform allows IRModule arguments that are ill-formed due to
- a dangling GlobalVar.
-
- After `DeadCodeElimination` completes, the resulting function is
- inserted in the original IRModule, providing a well-formed output
- from `ApplyPassToFunction`.
-
- """
-
- @I.ir_module(s_tir=True)
- class Before:
- @R.function
- def to_be_transformed(A: R.Tensor):
- cls = Before
-
- B = R.add(A, A)
- C = cls.subroutine(B)
- D = R.multiply(C, C)
- return C
-
- @R.function
- def to_be_ignored(A: R.Tensor):
- cls = Before
-
- B = R.add(A, A)
- C = cls.subroutine(B)
- D = R.multiply(C, C)
- return C
-
- @R.function(private=True)
- def subroutine(arg: R.Tensor) -> R.Tensor:
- return R.add(arg, arg)
-
- @I.ir_module(s_tir=True)
- class Expected:
- @R.function
- def to_be_transformed(A: R.Tensor):
- cls = Expected
-
- B = R.add(A, A)
- C = cls.subroutine(B)
- return C
-
- @R.function
- def to_be_ignored(A: R.Tensor):
- cls = Expected
-
- B = R.add(A, A)
- C = cls.subroutine(B)
- D = R.multiply(C, C)
- return C
-
- @R.function(private=True)
- def subroutine(arg: R.Tensor) -> R.Tensor:
- return R.add(arg, arg)
-
- # The well-formed check in conftest.py must be disabled, to avoid
- # triggering on the ill-formed intermediate, so this unit test
- # checks it explicitly.
- assert tvm.relax.analysis.well_formed(Before)
- After = tvm.ir.transform.ApplyPassToFunction(
- tvm.relax.transform.DeadCodeElimination(),
- "to_be_transformed",
- )(Before)
- assert tvm.relax.analysis.well_formed(After)
- tvm.ir.assert_structural_equal(Expected, After)
-
-
-def test_well_formed_output_with_restricted_scope():
- """DeadCodeElimination can be used with ApplyPassToFunction
-
- If the call graph cannot be completely traced, private functions
- should not be removed.
-
- See `test_compatibility_with_apply_pass_to_function` for full
- description of `DeadCodeElimination` and `ApplyPassToFunction`.
-
- """
-
- @I.ir_module(s_tir=True)
- class Before:
- @R.function
- def main(A: R.Tensor):
- cls = Before
-
- B = R.add(A, A)
- C = cls.subroutine(B)
- D = R.multiply(C, C)
- return C
-
- @R.function(private=True)
- def subroutine(A: R.Tensor) -> R.Tensor:
- cls = Before
-
- B = R.add(A, A)
- C = cls.subsubroutine(B)
- D = R.multiply(C, C)
- return C
-
- @R.function(private=True)
- def subsubroutine(A: R.Tensor) -> R.Tensor:
- B = R.add(A, A)
- C = R.multiply(B, B)
- return B
-
- @I.ir_module(s_tir=True)
- class Expected:
- @R.function
- def main(A: R.Tensor):
- cls = Expected
-
- B = R.add(A, A)
- C = cls.subroutine(B)
- return C
-
- @R.function(private=True)
- def subroutine(A: R.Tensor) -> R.Tensor:
- cls = Expected
-
- B = R.add(A, A)
- C = cls.subsubroutine(B)
- D = R.multiply(C, C)
- return C
-
- @R.function(private=True)
- def subsubroutine(A: R.Tensor) -> R.Tensor:
- B = R.add(A, A)
- return B
-
- assert tvm.relax.analysis.well_formed(Before)
- After = tvm.ir.transform.ApplyPassToFunction(
- tvm.relax.transform.DeadCodeElimination(),
- "main|subsubroutine",
- )(Before)
- assert tvm.relax.analysis.well_formed(After)
- tvm.ir.assert_structural_equal(Expected, After)
-
-
def test_recursively_defined_lambda():
"""DCE may be applied to recursively-defined functions