This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9b5a7a457f [IR] Provide well-formed intermediate in
ApplyPassToFunction (#16843)
9b5a7a457f is described below
commit 9b5a7a457fc967bc38155abc1a71431603c76009
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Apr 5 13:21:52 2024 -0500
[IR] Provide well-formed intermediate in ApplyPassToFunction (#16843)
Prior to this commit, `ApplyPassToFunction` removed functions from the
`IRModule` to hide them from the inner `ir.transform.Pass`. The
dangling `GlobalVar` references to those functions meant that the
intermediate `IRModule` was ill-formed This commit updates the
`ApplyPassToFunction` utility to instead replace the functions with
`ExternFunc` nodes. This still prevents the inner `ir.transform.Pass`
from having visibility into functions that should not be mutated, but
provides a well-formed `IRModule`.
---
src/ir/apply_pass_to_function.cc | 136 +++++++++++++++++++++
src/ir/transform.cc | 32 +----
.../relax/test_transform_dead_code_elimination.py | 4 -
3 files changed, 137 insertions(+), 35 deletions(-)
diff --git a/src/ir/apply_pass_to_function.cc b/src/ir/apply_pass_to_function.cc
new file mode 100644
index 0000000000..7f7bc7e90a
--- /dev/null
+++ b/src/ir/apply_pass_to_function.cc
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/ir/apply_pass_to_function.cc
+ * \brief Utility transformation that applies an inner pass to a subset of an
IRModule
+ */
+#include <tvm/ir/transform.h>
+#include <tvm/relax/expr.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/function.h>
+
+#include <unordered_set>
+
+#include "../runtime/regex.h"
+
+namespace tvm {
+namespace transform {
+
+namespace {
+BaseFunc BaseFuncWithAttr(BaseFunc func, const std::string& attr_key,
ObjectRef attr_value) {
+ if (auto tir = func.as<tir::PrimFunc>()) {
+ return WithAttr(tir.value(), attr_key, attr_value);
+ } else if (auto relax = func.as<relax::Function>()) {
+ return WithAttr(relax.value(), attr_key, attr_value);
+ } else {
+ return func;
+ }
+}
+
+BaseFunc BaseFuncWithoutAttr(BaseFunc func, const std::string& attr_key) {
+ if (auto tir = func.as<tir::PrimFunc>()) {
+ return WithoutAttr(tir.value(), attr_key);
+ } else if (auto relax = func.as<relax::Function>()) {
+ return WithoutAttr(relax.value(), attr_key);
+ } else {
+ return func;
+ }
+}
+} // namespace
+
+Pass ApplyPassToFunction(Pass pass, String func_name_regex,
+ bool error_if_no_function_matches_regex) {
+ auto pass_name =
+ static_cast<const std::stringstream&>(std::stringstream() <<
"ApplyPassTo" << func_name_regex)
+ .str();
+
+ auto pass_func = [pass, func_name_regex, error_if_no_function_matches_regex](
+ IRModule mod, PassContext) -> IRModule {
+ bool at_least_one_function_matched_regex = false;
+ std::unordered_set<String> keep_original_version;
+ std::unordered_set<String> internal_functions;
+ IRModule subset;
+
+ for (auto [gvar, func] : mod->functions) {
+ std::string name = gvar->name_hint;
+ if (tvm::runtime::regex_match(name, func_name_regex)) {
+ at_least_one_function_matched_regex = true;
+ if (!func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined()) {
+ // Function may be mutated, but is an internal function. Mark
+ // it as externally-exposed, so that any call-tracing internal
+ // transforms do not remove this function, in case it its
+ // callers are not being mutated.
+
+ internal_functions.insert(gvar->name_hint);
+ func = BaseFuncWithAttr(func, tvm::attr::kGlobalSymbol,
gvar->name_hint);
+ }
+ } else {
+ // Function may not be mutated. Replace it with a
+ // `relax::ExternFunc` to prevent references to it from
+ // dangling.
+ keep_original_version.insert(gvar->name_hint);
+ func = relax::ExternFunc("dummy_" + name);
+ func->struct_info_ = gvar->struct_info_;
+ func->checked_type_ = gvar->checked_type_;
+ }
+
+ subset->Add(gvar, func);
+ }
+
+ if (error_if_no_function_matches_regex) {
+ CHECK(at_least_one_function_matched_regex)
+ << "No function matched regex '" << func_name_regex << "', out of
functions " << [&]() {
+ Array<String> function_names;
+ for (const auto& [gvar, func] : mod->functions) {
+ function_names.push_back(gvar->name_hint);
+ }
+ return function_names;
+ }();
+ }
+
+ IRModule new_subset = pass(subset);
+ if (new_subset.same_as(subset)) {
+ return mod;
+ }
+
+ auto write_ptr = mod.CopyOnWrite();
+ for (auto [gvar, func] : new_subset->functions) {
+ if (!keep_original_version.count(gvar->name_hint)) {
+ if (auto it = write_ptr->global_var_map_.find(gvar->name_hint);
+ it != write_ptr->global_var_map_.end()) {
+ write_ptr->Remove((*it).second);
+ }
+ if (internal_functions.count(gvar->name_hint)) {
+ func = BaseFuncWithoutAttr(func, tvm::attr::kGlobalSymbol);
+ }
+ write_ptr->Add(gvar, func);
+ }
+ }
+
+ return mod;
+ };
+
+ return CreateModulePass(pass_func, 0, pass_name, {});
+}
+
+TVM_REGISTER_GLOBAL("transform.ApplyPassToFunction").set_body_typed(ApplyPassToFunction);
+
+} // namespace transform
+} // namespace tvm
diff --git a/src/ir/transform.cc b/src/ir/transform.cc
index 3eb64fec84..dc67822411 100644
--- a/src/ir/transform.cc
+++ b/src/ir/transform.cc
@@ -25,6 +25,7 @@
#include <tvm/ir/transform.h>
#include <tvm/node/repr_printer.h>
#include <tvm/node/structural_hash.h>
+#include <tvm/relax/expr.h>
#include <tvm/relax/tuning_api.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
@@ -532,37 +533,6 @@ Pass CreateModulePass(const
runtime::TypedPackedFunc<IRModule(IRModule, PassCont
return ModulePass(pass_func, pass_info);
}
-Pass ApplyPassToFunction(Pass pass, String func_name_regex,
- bool error_if_no_function_matches_regex) {
- auto pass_name =
- static_cast<const std::stringstream&>(std::stringstream() <<
"ApplyPassTo" << func_name_regex)
- .str();
-
- auto pass_func = [pass, func_name_regex](IRModule mod, PassContext) ->
IRModule {
- IRModule subset;
-
- for (const auto& [gvar, func] : mod->functions) {
- std::string name = gvar->name_hint;
- if (tvm::runtime::regex_match(name, func_name_regex)) {
- subset->Add(gvar, func);
- }
- }
-
- if (subset->functions.size()) {
- IRModule new_subset = pass(subset);
- if (!new_subset.same_as(subset)) {
- mod.CopyOnWrite()->Update(new_subset);
- }
- }
-
- return mod;
- };
-
- return CreateModulePass(pass_func, 0, pass_name, {});
-}
-
-TVM_REGISTER_GLOBAL("transform.ApplyPassToFunction").set_body_typed(ApplyPassToFunction);
-
TVM_REGISTER_NODE_TYPE(PassInfoNode);
TVM_REGISTER_GLOBAL("transform.PassInfo")
diff --git a/tests/python/relax/test_transform_dead_code_elimination.py
b/tests/python/relax/test_transform_dead_code_elimination.py
index 2dae252cad..0cb0d46247 100644
--- a/tests/python/relax/test_transform_dead_code_elimination.py
+++ b/tests/python/relax/test_transform_dead_code_elimination.py
@@ -509,8 +509,6 @@ def test_extern_func():
verify(before, before)
[email protected]_well_formed_check_before_transform
[email protected]_well_formed_check_after_transform
def test_compatibility_with_apply_pass_to_function():
"""DeadCodeElimination can be used with ApplyPassToFunction
@@ -590,8 +588,6 @@ def test_compatibility_with_apply_pass_to_function():
tvm.ir.assert_structural_equal(Expected, After)
[email protected]_well_formed_check_before_transform
[email protected]_well_formed_check_after_transform
def test_well_formed_output_with_restricted_scope():
"""DeadCodeElimination can be used with ApplyPassToFunction