This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 346f5d4ab0 [Unity][Pass] Remove Unused Function (#14061)
346f5d4ab0 is described below
commit 346f5d4ab03927d2d850314a1c29104a128f36bd
Author: Sunghyun Park <[email protected]>
AuthorDate: Mon Feb 20 21:44:11 2023 -0800
[Unity][Pass] Remove Unused Function (#14061)
This PR implements a pass to clean up unused functions.
Co-authored-by: masahi <[email protected]>
---
python/tvm/ir/function.py | 26 ++-
src/relax/transform/remove_unused_funcs.cc | 120 ++++++++++++
src/relax/transform/utils.h | 122 ++++++++++++
.../relax/test_transform_remove_unused_funcs.py | 211 +++++++++++++++++++++
4 files changed, 475 insertions(+), 4 deletions(-)
diff --git a/python/tvm/ir/function.py b/python/tvm/ir/function.py
index d02698edb5..b64553d31c 100644
--- a/python/tvm/ir/function.py
+++ b/python/tvm/ir/function.py
@@ -14,11 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Function defintiions."""
+"""Function definitions."""
+from typing import Union, Dict
from enum import IntEnum
import tvm.runtime
-
+from tvm.runtime.object import Object
from .expr import RelayExpr
+from .attrs import DictAttrs
from . import _ffi_api
@@ -38,7 +40,7 @@ class BaseFunc(RelayExpr):
"""Return the attrs member of the function."""
return _ffi_api.BaseFunc_Attrs(self)
- def with_attr(self, attr_key_or_dict, attr_value=None):
+ def with_attr(self, attr_key_or_dict, attr_value=None) -> "BaseFunc":
"""Create a new copy of the function and update the attribute.
Parameters
@@ -51,7 +53,7 @@ class BaseFunc(RelayExpr):
Returns
-------
- func : Function
+ func : BaseFunc
A new copy of the function
"""
# make sure we first copy so that we can safely do copy on write
@@ -67,6 +69,22 @@ class BaseFunc(RelayExpr):
res._move(), attr_key_or_dict, tvm.runtime.convert(attr_value)
)
+ def with_attrs(self, attr_map: Union[DictAttrs, Dict[str, Object]]) ->
"BaseFunc":
+ """Copy the IRModule and add the given attribute map to it.
+ Parameters
+ ----------
+ attr_map: Union[DictAttrs, Dict[str, Object]]
+ The attribute map
+ Returns
+ -------
+ func : BaseFunc
+ A new copy of the function
+ """
+ if isinstance(attr_map, tvm.ir.DictAttrs):
+ attr_map = attr_map._dict()
+
+ return _ffi_api.BaseFuncWithAttrs(self, attr_map)
+
def without_attr(self, attr_key: str) -> "BaseFunc":
"""Create a new copy of the function with an attribute without
provided key.
diff --git a/src/relax/transform/remove_unused_funcs.cc
b/src/relax/transform/remove_unused_funcs.cc
new file mode 100644
index 0000000000..5572da1338
--- /dev/null
+++ b/src/relax/transform/remove_unused_funcs.cc
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/transform/remove_unused_funcs.cc
+ * \brief Remove unused global relax functions in an IRModule.
+ */
+
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+
+#include <unordered_set>
+#include <vector>
+
+#include "utils.h"
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Detects all the functions that can possibly be called by the entry functions.
+ *
+ * Every GlobalVar reached while tracing is recorded. Relax function callees are
+ * traversed transitively; PrimFunc callees are recorded but not traversed.
+ */
+class CallTracer : public ExprVisitor {
+ public:
+  explicit CallTracer(IRModule mod) : mod_{mod}, called_funcs_{}, visiting_{} {}
+
+  void VisitExpr_(const GlobalVarNode* op) final {
+    // Callees are resolved by name, so record the module's own GlobalVar for that
+    // name: `called_funcs_` compares by pointer, and a distinct GlobalVar object
+    // carrying the same name would otherwise never match the module's key.
+    GlobalVar gvar = mod_->GetGlobalVar(op->name_hint);
+    called_funcs_.insert(gvar);
+    BaseFunc func = mod_->Lookup(gvar);
+    if (const auto* function_node = func.as<FunctionNode>()) {
+      VisitExpr(GetRef<Function>(function_node));
+    }
+    // else: Don't visit PrimFuncs -- we don't need to collect any tir.Calls therein.
+  }
+
+  void VisitExpr_(const FunctionNode* func_node) final {
+    auto func = GetRef<Function>(func_node);
+    // Each function body is traversed at most once; this also terminates the
+    // traversal for (mutually) recursive functions.
+    if (visiting_.find(func) == visiting_.end()) {
+      visiting_.insert(func);
+      for (auto param : func_node->params) {
+        ExprVisitor::VisitExpr(param);
+      }
+      ExprVisitor::VisitExpr(func_node->body);
+    }
+  }
+
+  /*!
+   * \brief Mark the function named \p entry as called and trace everything it can reach.
+   * \param entry Name of an entry function in the module.
+   */
+  void Trace(const std::string& entry) {
+    // Route through the GlobalVar handler so that a PrimFunc entry is recorded
+    // without being visited: the Relax ExprFunctor has no dispatch entry for
+    // tir::PrimFunc, so visiting one directly fails.
+    VisitExpr(mod_->GetGlobalVar(entry));
+  }
+
+  /*! \brief Whether \p gv was reached while tracing. */
+  bool check_if_called(GlobalVar gv) const { return called_funcs_.count(gv) > 0; }
+
+ private:
+  IRModule mod_;
+
+  // Record all encountered functions (as the module's own GlobalVars).
+  std::unordered_set<GlobalVar, ObjectPtrHash, ObjectPtrEqual> called_funcs_;
+
+  // Record the functions whose bodies have already been traversed.
+  std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> visiting_;
+};
+
+/*!
+ * \brief Remove functions that are not used.
+ *
+ * A function is removed only if it has internal linkage and is not reachable
+ * from any of the entry functions.
+ *
+ * \param mod The input IRModule. It is copied before anything is removed, so
+ *        other holders of the same module (e.g. the caller) are not affected.
+ * \param entry_funcs The names of the functions to start tracing from.
+ *
+ * \return The module with dead functions removed.
+ */
+IRModule RemoveUnusedFunctions(IRModule mod, Array<runtime::String> entry_funcs) {
+  CallTracer tracer(mod);
+  for (const auto& entry : entry_funcs) {
+    tracer.Trace(entry);
+  }
+
+  // Collect first, then remove, so the function map is not modified while iterating.
+  std::vector<GlobalVar> to_remove;
+  for (const auto& kv : mod->functions) {
+    // If a function has an external linkage type, we do not remove it.
+    // Otherwise, we check the function and remove it if it is not used anywhere.
+    if (kv.second->GetLinkageType() == LinkageType::kInternal &&
+        !tracer.check_if_called(kv.first)) {
+      to_remove.push_back(kv.first);
+    }
+  }
+
+  if (!to_remove.empty()) {
+    // IRModule::operator-> hands out a mutable node; removing through it directly
+    // would also strip functions from every other reference to this module.
+    IRModuleNode* mod_node = mod.CopyOnWrite();
+    for (const GlobalVar& gv : to_remove) {
+      mod_node->Remove(gv);
+    }
+  }
+  return mod;
+}
+
+} // namespace relax
+
+namespace transform {
+/*!
+ * \brief Module pass wrapper around relax::RemoveUnusedFunctions.
+ * \param entry_functions Names of the functions that tracing starts from.
+ * \return A module-level pass (opt_level 0) named "RemoveUnusedFunctions".
+ *
+ * NOTE(review): this is defined in tvm::transform, not tvm::relax::transform;
+ * confirm that matches any C++ declaration of this pass.
+ */
+Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions) {
+  auto remove_unused = [entry_functions](IRModule mod, PassContext /*ctx*/) -> IRModule {
+    return relax::RemoveUnusedFunctions(mod, entry_functions);
+  };
+  return CreateModulePass(runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>(remove_unused),
+                          /*opt_level=*/0, "RemoveUnusedFunctions", /*required=*/{});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.RemoveUnusedFunctions").set_body_typed(RemoveUnusedFunctions);
+
+} // namespace transform
+} // namespace tvm
diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h
new file mode 100644
index 0000000000..d94c1e3b3e
--- /dev/null
+++ b/src/relax/transform/utils.h
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/transform/utils.h
+ * \brief Additional utility classes and functions for working with the Relax
IR.
+ */
+#ifndef TVM_RELAX_TRANSFORM_UTILS_H_
+#define TVM_RELAX_TRANSFORM_UTILS_H_
+
+#include <tvm/ir/module.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+
+#include <string>
+#include <unordered_map>
+
+#include "../../relay/analysis/graph_partitioner.h"
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief A simple wrapper around ExprFunctor for a single argument case.
+ * The result of visit is memoized.
+ *
+ * \tparam OutputType The result type produced for each visited expression.
+ */
+template <typename OutputType>
+class MemoizedExprTranslator : public ::tvm::relax::ExprFunctor<OutputType(const Expr&)> {
+  using BaseFunctor = ::tvm::relax::ExprFunctor<OutputType(const Expr&)>;
+
+ public:
+  /*! \brief virtual destructor */
+  virtual ~MemoizedExprTranslator() {}
+
+  /*!
+   * \brief The memoized call.
+   * \param n The expression node.
+   * \return The result of the call
+   */
+  virtual OutputType VisitExpr(const Expr& n) {
+    ICHECK(n.defined());
+    auto it = memo_.find(n);
+    if (it != memo_.end()) {
+      return it->second;
+    }
+    auto res = BaseFunctor::VisitExpr(n);
+    memo_[n] = res;
+    return res;
+  }
+
+  /*!
+   * \brief Return the result previously recorded for a variable.
+   *
+   * Variables only receive results through VisitBinding_, so reaching this with an
+   * unrecorded variable means the variable is used before it is bound.
+   */
+  virtual OutputType VisitExpr_(const VarNode* vn) {
+    auto it = memo_.find(GetRef<Expr>(vn));
+    ICHECK(it != memo_.end()) << "Variable " << vn->name_hint() << " is used before being bound";
+    return it->second;
+  }
+
+  /*!
+   * \brief Translate the bound value and record it as the result of the bound variable.
+   * \param binding The variable binding.
+   * \return The translated value.
+   */
+  virtual OutputType VisitBinding_(const VarBindingNode* binding) {
+    ICHECK(memo_.find(binding->var) == memo_.end())
+        << "Variable " << binding->var->name_hint() << " is bound more than once";
+    auto v = VisitExpr(binding->value);
+    memo_[binding->var] = v;
+    return v;
+  }
+
+ protected:
+  /*! \brief Internal map used for memoization. */
+  std::unordered_map<Expr, OutputType, ObjectPtrHash, ObjectPtrEqual> memo_;
+};
+
+/*!
+ * \brief Remove unused global relax functions in an IRModule.
+ * \param mod The target module
+ * \param entry_funcs Names of the entry functions. Internal-linkage functions that
+ *        are not reachable from any of them are removed.
+ * \return The updated module.
+ */
+TVM_DLL IRModule RemoveUnusedFunctions(IRModule mod, Array<runtime::String> entry_funcs);
+
+/*!
+ * \brief Read the global symbol attached to a Relax function.
+ *
+ * \param func The function whose `global_symbol` attribute is read.
+ * \return The symbol as a std::string; fails an ICHECK if the attribute is absent.
+ */
+inline std::string GetExtSymbol(const Function& func) {
+  const auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+  ICHECK(global_symbol.defined()) << "Fail to retrieve external symbol.";
+  String symbol = global_symbol.value();
+  return static_cast<std::string>(symbol);
+}
+
+/*!
+ * \brief Fuse ops or functions according to the given partition, and group them into new
+ * functions.
+ *
+ * \param mod The input module.
+ * \param partition A mapping from a subexpression to the containing group.
+ * \param lift_constants Whether or not to lift bound constants to parameters of the
+ * grouped function.
+ * \return A new module containing grouped functions.
+ */
+IRModule MakeGroupedFunctions(
+    IRModule mod,
+    const std::unordered_map<const Object*, relay::GraphPartitioner::Group*>& partition,
+    bool lift_constants = true);
+
+} // namespace relax
+} // namespace tvm
+
+#endif // TVM_RELAX_TRANSFORM_UTILS_H_
diff --git a/tests/python/relax/test_transform_remove_unused_funcs.py
b/tests/python/relax/test_transform_remove_unused_funcs.py
new file mode 100644
index 0000000000..8a57b38508
--- /dev/null
+++ b/tests/python/relax/test_transform_remove_unused_funcs.py
@@ -0,0 +1,211 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import tvm
+import tvm.script
+import tvm.testing
+from tvm import relax
+from tvm.script import relax as R
+from tvm.script import tir as T
+
+
+def check_if_func_exists(mod, func_name):
+ gvs = [gv.name_hint for gv in mod.get_global_vars()]
+ return func_name in gvs
+
+
+def test_unused_relax_func():
+ @tvm.script.ir_module
+ class InputModule:
+ @T.prim_func
+ def tir_add(
+ x: T.Buffer[(16, 16), "float32"],
+ y: T.Buffer[(16, 16), "float32"],
+ z: T.Buffer[(16, 16), "float32"],
+ ) -> None:
+ for i, j in T.grid(16, 16):
+ with T.block("add"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ z[vi, vj] = x[vi, vj] + y[vi, vj]
+
+ @R.function
+ def unused_func(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16,
16), "float32")):
+ gv0 = R.add(x, w)
+ return gv0
+
+ @R.function
+ def main(
+ x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")
+ ) -> R.Tensor((16, 16), "float32"):
+ gv0 = R.call_tir(tir_add, (x, w), R.Tensor((16, 16),
dtype="float32"))
+ return gv0
+
+ mod = InputModule
+ assert mod
+ new_mod = relax.transform.RemoveUnusedFunctions()(mod)
+ assert check_if_func_exists(new_mod, "main")
+ assert check_if_func_exists(new_mod, "tir_add")
+ assert not check_if_func_exists(new_mod, "unused_func")
+
+
+def test_unused_relax_func_custom_entry_func():
+ @tvm.script.ir_module
+ class InputModule:
+ @T.prim_func
+ def tir_add(
+ x: T.Buffer[(16, 16), "float32"],
+ y: T.Buffer[(16, 16), "float32"],
+ z: T.Buffer[(16, 16), "float32"],
+ ) -> None:
+ for i, j in T.grid(16, 16):
+ with T.block("add"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ z[vi, vj] = x[vi, vj] + y[vi, vj]
+
+ @R.function
+ def unused_func(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16,
16), "float32")):
+ gv0 = R.add(x, w)
+ return gv0
+
+ @R.function
+ def foo(
+ x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")
+ ) -> R.Tensor((16, 16), "float32"):
+ gv0 = R.call_tir(tir_add, (x, w), R.Tensor((16, 16),
dtype="float32"))
+ return gv0
+
+ mod = InputModule
+ assert mod
+
+ # Test entry function other than "main".
+ new_mod =
relax.transform.RemoveUnusedFunctions(entry_functions=["foo"])(mod)
+ assert check_if_func_exists(new_mod, "foo")
+ assert check_if_func_exists(new_mod, "tir_add")
+ assert not check_if_func_exists(new_mod, "unused_func")
+
+
+def test_unused_relax_func_symbolic_shape():
+ # Test with relax function w/ symbolic shape.
+ @tvm.script.ir_module
+ class InputModule:
+ @T.prim_func
+ def tir_add(
+ x: T.Buffer[(16, 16), "float32"],
+ y: T.Buffer[(16, 16), "float32"],
+ z: T.Buffer[(16, 16), "float32"],
+ ) -> None:
+ for i, j in T.grid(16, 16):
+ with T.block("add"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ z[vi, vj] = x[vi, vj] + y[vi, vj]
+
+ @R.function
+ def unused_func(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n",
"k"), "float32")):
+ gv0 = R.add(x, w)
+ return gv0
+
+ @R.function
+ def main(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"),
"float32")):
+ m, k = T.var("int64"), T.var("int64")
+ gv0 = R.call_tir(tir_add, (x, w), R.Tensor((m + 1, k),
dtype="float32"))
+ return gv0
+
+ mod = InputModule
+ assert mod
+
+ new_mod = relax.transform.RemoveUnusedFunctions()(mod)
+ assert check_if_func_exists(new_mod, "main")
+ assert check_if_func_exists(new_mod, "tir_add")
+ assert not check_if_func_exists(new_mod, "unused_func")
+
+
+def test_unused_prim_func():
+ @tvm.script.ir_module
+ class InputModule:
+ @T.prim_func
+ def unused_func(
+ x: T.Buffer[(16, 16), "float32"],
+ y: T.Buffer[(16, 16), "float32"],
+ z: T.Buffer[(16, 16), "float32"],
+ ) -> None:
+ T.func_attr({"global_symbol": "tir_unused"})
+ for i, j in T.grid(16, 16):
+ with T.block("add"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ z[vi, vj] = x[vi, vj] + y[vi, vj]
+
+ @R.function
+ def relax_add(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16),
"float32")):
+ gv0 = R.add(x, w)
+ return gv0
+
+ @R.function
+ def main(
+ x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")
+ ) -> R.Tensor((16, 16), "float32"):
+ gv0 = relax_add(x, w)
+ return gv0
+
+ mod = InputModule
+ assert mod
+ new_mod = relax.transform.RemoveUnusedFunctions()(mod)
+ assert check_if_func_exists(new_mod, "main")
+ assert check_if_func_exists(new_mod, "relax_add")
+ # RemoveUnusedFunction pass won't remove the function with global symbol
for the external linkage.
+ assert check_if_func_exists(new_mod, "unused_func")
+
+
+def test_multiple_unused_funcs():
+ @tvm.script.ir_module
+ class InputModule:
+ @T.prim_func
+ def unused_func1(
+ x: T.Buffer[(16, 16), "float32"],
+ y: T.Buffer[(16, 16), "float32"],
+ z: T.Buffer[(16, 16), "float32"],
+ ) -> None:
+ T.func_attr({"global_symbol": "tir_unused"})
+ for i, j in T.grid(16, 16):
+ with T.block("add"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ z[vi, vj] = x[vi, vj] + y[vi, vj]
+
+ @R.function
+ def unused_func2(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16,
16), "float32")):
+ gv0 = R.add(x, w)
+ return gv0
+
+ @R.function
+ def main(
+ x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")
+ ) -> R.Tensor((16, 16), "float32"):
+ gv0 = R.add(x, w)
+ return gv0
+
+ mod = InputModule
+ assert mod
+
+ new_mod = relax.transform.RemoveUnusedFunctions()(mod)
+ assert check_if_func_exists(new_mod, "main")
+ # RemoveUnusedFunction pass won't remove the function with global symbol
for the external linkage.
+ assert check_if_func_exists(new_mod, "unused_func1")
+ assert not check_if_func_exists(new_mod, "unused_func2")
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])