This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 346f5d4ab0 [Unity][Pass] Remove Unused Function (#14061)
346f5d4ab0 is described below

commit 346f5d4ab03927d2d850314a1c29104a128f36bd
Author: Sunghyun Park <[email protected]>
AuthorDate: Mon Feb 20 21:44:11 2023 -0800

    [Unity][Pass] Remove Unused Function (#14061)
    
    This PR implements a pass to clean up unused functions.
    
    Co-authored-by: masahi <[email protected]>
---
 python/tvm/ir/function.py                          |  26 ++-
 src/relax/transform/remove_unused_funcs.cc         | 120 ++++++++++++
 src/relax/transform/utils.h                        | 122 ++++++++++++
 .../relax/test_transform_remove_unused_funcs.py    | 211 +++++++++++++++++++++
 4 files changed, 475 insertions(+), 4 deletions(-)

diff --git a/python/tvm/ir/function.py b/python/tvm/ir/function.py
index d02698edb5..b64553d31c 100644
--- a/python/tvm/ir/function.py
+++ b/python/tvm/ir/function.py
@@ -14,11 +14,13 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Function defintiions."""
+"""Function definitions."""
+from typing import Union, Dict
 from enum import IntEnum
 import tvm.runtime
-
+from tvm.runtime.object import Object
 from .expr import RelayExpr
+from .attrs import DictAttrs
 from . import _ffi_api
 
 
@@ -38,7 +40,7 @@ class BaseFunc(RelayExpr):
         """Return the attrs member of the function."""
         return _ffi_api.BaseFunc_Attrs(self)
 
-    def with_attr(self, attr_key_or_dict, attr_value=None):
+    def with_attr(self, attr_key_or_dict, attr_value=None) -> "BaseFunc":
         """Create a new copy of the function and update the attribute.
 
         Parameters
@@ -51,7 +53,7 @@ class BaseFunc(RelayExpr):
 
         Returns
         -------
-        func : Function
+        func : BaseFunc
             A new copy of the function
         """
         # make sure we first copy so that we can safely do copy on write
@@ -67,6 +69,22 @@ class BaseFunc(RelayExpr):
             res._move(), attr_key_or_dict, tvm.runtime.convert(attr_value)
         )
 
+    def with_attrs(self, attr_map: Union[DictAttrs, Dict[str, Object]]) -> "BaseFunc":
+        """Copy the function and add the given attribute map to it.
+        Parameters
+        ----------
+        attr_map: Union[DictAttrs, Dict[str, Object]]
+            The attribute map
+        Returns
+        -------
+        func : BaseFunc
+            A new copy of the function
+        """
+        if isinstance(attr_map, tvm.ir.DictAttrs):
+            attr_map = attr_map._dict()
+
+        return _ffi_api.BaseFuncWithAttrs(self, attr_map)
+
     def without_attr(self, attr_key: str) -> "BaseFunc":
+        """Create a new copy of the function with an attribute without the provided key.
 
diff --git a/src/relax/transform/remove_unused_funcs.cc b/src/relax/transform/remove_unused_funcs.cc
new file mode 100644
index 0000000000..5572da1338
--- /dev/null
+++ b/src/relax/transform/remove_unused_funcs.cc
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ *
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/transform/remove_unused_funcs.cc
+ * \brief Remove unused global relax functions in an IRModule.
+ */
+
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+
+#include <unordered_set>
+#include <vector>
+
+#include "utils.h"
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Detects all the functions that can possibly be called by an entry function.
+ */
+class CallTracer : ExprVisitor {
+ public:
+  explicit CallTracer(IRModule mod_) : mod_{mod_}, called_funcs_{}, visiting_{} {}
+
+  void VisitExpr_(const GlobalVarNode* op) final {
+    called_funcs_.insert(GetRef<GlobalVar>(op));
+    auto func = mod_->Lookup(op->name_hint);
+    if (const auto* function_node = func.as<FunctionNode>()) {
+      VisitExpr(GetRef<Function>(function_node));
+    }
+    // else: Don't visit PrimFuncs -- we don't need to collect any tir.Calls therein.
+  }
+
+  void VisitExpr_(const CallNode* call_node) final { ExprVisitor::VisitExpr_(call_node); }
+
+  void VisitExpr_(const FunctionNode* func_node) final {
+    auto func = GetRef<Function>(func_node);
+    if (visiting_.find(func) == visiting_.end()) {
+      visiting_.insert(func);
+      for (auto param : func_node->params) {
+        ExprVisitor::VisitExpr(param);
+      }
+      ExprVisitor::VisitExpr(func_node->body);
+    }
+  }
+
+  void Trace(std::string entry) {
+    called_funcs_.insert(mod_->GetGlobalVar(entry));
+    auto main_func = mod_->Lookup(entry);
+    VisitExpr(main_func);
+  }
+
+  bool check_if_called(GlobalVar gv) { return called_funcs_.count(gv) > 0; }
+
+ private:
+  IRModule mod_;
+
+  // Record the names of all encountered functions.
+  std::unordered_set<GlobalVar, ObjectPtrHash, ObjectPtrEqual> called_funcs_;
+
+  // Record the expressions that are being visited.
+  std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> visiting_;
+};
+
+/*!
+ * \brief Remove functions that are not used.
+ *
+ * \param mod_ IRModule.
+ * \param entry_funcs The set of functions that can be entry functions.
+ *
+ * \return The module with dead functions removed.
+ */
+IRModule RemoveUnusedFunctions(IRModule mod_, Array<runtime::String> entry_funcs) {
+  auto tracer = CallTracer(mod_);
+  for (auto entry : entry_funcs) {
+    tracer.Trace(entry);
+  }
+  auto existing_functions = mod_->functions;
+  for (auto f : existing_functions) {
+    // If a function has an external linkage type, we do not remove it.
+    // Otherwise, we check the function and remove it if it is not used anywhere.
+    if (f.second->GetLinkageType() == LinkageType::kInternal && !tracer.check_if_called(f.first)) {
+      mod_->Remove(f.first);
+    }
+  }
+  return mod_;
+}
+
+}  // namespace relax
+
+namespace transform {
+Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions) {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule m, PassContext pc) { return relax::RemoveUnusedFunctions(m, entry_functions); };
+  return CreateModulePass(pass_func, 0, "RemoveUnusedFunctions", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.RemoveUnusedFunctions").set_body_typed(RemoveUnusedFunctions);
+
+}  // namespace transform
+}  // namespace tvm
diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h
new file mode 100644
index 0000000000..d94c1e3b3e
--- /dev/null
+++ b/src/relax/transform/utils.h
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/transform/utils.h
+ * \brief Additional utility classes and functions for working with the Relax IR.
+ */
+#ifndef TVM_RELAX_TRANSFORM_UTILS_H_
+#define TVM_RELAX_TRANSFORM_UTILS_H_
+
+#include <tvm/ir/module.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+
+#include <string>
+#include <unordered_map>
+
+#include "../../relay/analysis/graph_partitioner.h"
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief A simple wrapper around ExprFunctor for a single argument case.
+ *  The result of visit is memoized.
+ */
+template <typename OutputType>
+class MemoizedExprTranslator : public ::tvm::relax::ExprFunctor<OutputType(const Expr&)> {
+  using BaseFunctor = ::tvm::relax::ExprFunctor<OutputType(const Expr&)>;
+
+ public:
+  /*! \brief virtual destructor */
+  virtual ~MemoizedExprTranslator() {}
+
+  /*!
+   * \brief The memoized call.
+   * \param n The expression node.
+   * \return The result of the call
+   */
+  virtual OutputType VisitExpr(const Expr& n) {
+    ICHECK(n.defined());
+    auto it = memo_.find(n);
+    if (it != memo_.end()) {
+      return it->second;
+    }
+    auto res = BaseFunctor::VisitExpr(n);
+    memo_[n] = res;
+    return res;
+  }
+
+  virtual OutputType VisitExpr_(const VarNode* vn) {
+    ICHECK(memo_.count(GetRef<Expr>(vn)));
+    return memo_[GetRef<Expr>(vn)];
+  }
+
+  virtual OutputType VisitBinding_(const VarBindingNode* binding) {
+    ICHECK_EQ(memo_.count(binding->var), 0);
+    auto v = VisitExpr(binding->value);
+    memo_[binding->var] = v;
+    return v;
+  }
+
+ protected:
+  /*! \brief Internal map used for memoization. */
+  std::unordered_map<Expr, OutputType, ObjectPtrHash, ObjectPtrEqual> memo_;
+};
+
+/*!
+ * \brief Remove unused global relax functions in an IRModule.
+ * \param mod The target module
+ * \param entry_functions list of entry functions
+ * \return The updated module.
+ */
+TVM_DLL IRModule RemoveUnusedFunctions(IRModule mod, Array<runtime::String> entry_funcs);
+
+/*!
+ * \brief Get the external symbol of the Relax function name.
+ *
+ * \param func The provided function.
+ * \return An external symbol.
+ */
+inline std::string GetExtSymbol(const Function& func) {
+  const auto name_node = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+  ICHECK(name_node.defined()) << "Fail to retrieve external symbol.";
+  return std::string(name_node.value());
+}
+
+/*!
+ * \brief Fuse ops or functions according to the given partition, and group them into a new
+ * function.
+ *
+ * \param mod The input module.
+ * \param partition A mapping from a subexpression to the containing group.
+ * \param lift_constants Whether or not to lift bound constants to parameters of the
+ * grouped function.
+ * \return A new module containing grouped functions.
+ */
+IRModule MakeGroupedFunctions(
+    IRModule mod,
+    const std::unordered_map<const Object*, relay::GraphPartitioner::Group*>& partition,
+    bool lift_constants = true);
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_TRANSFORM_UTILS_H_
diff --git a/tests/python/relax/test_transform_remove_unused_funcs.py b/tests/python/relax/test_transform_remove_unused_funcs.py
new file mode 100644
index 0000000000..8a57b38508
--- /dev/null
+++ b/tests/python/relax/test_transform_remove_unused_funcs.py
@@ -0,0 +1,211 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import tvm
+import tvm.script
+import tvm.testing
+from tvm import relax
+from tvm.script import relax as R
+from tvm.script import tir as T
+
+
+def check_if_func_exists(mod, func_name):
+    gvs = [gv.name_hint for gv in mod.get_global_vars()]
+    return func_name in gvs
+
+
+def test_unused_relax_func():
+    @tvm.script.ir_module
+    class InputModule:
+        @T.prim_func
+        def tir_add(
+            x: T.Buffer[(16, 16), "float32"],
+            y: T.Buffer[(16, 16), "float32"],
+            z: T.Buffer[(16, 16), "float32"],
+        ) -> None:
+            for i, j in T.grid(16, 16):
+                with T.block("add"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    z[vi, vj] = x[vi, vj] + y[vi, vj]
+
+        @R.function
+        def unused_func(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")):
+            gv0 = R.add(x, w)
+            return gv0
+
+        @R.function
+        def main(
+            x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")
+        ) -> R.Tensor((16, 16), "float32"):
+            gv0 = R.call_tir(tir_add, (x, w), R.Tensor((16, 16), dtype="float32"))
+            return gv0
+
+    mod = InputModule
+    assert mod
+    new_mod = relax.transform.RemoveUnusedFunctions()(mod)
+    assert check_if_func_exists(new_mod, "main")
+    assert check_if_func_exists(new_mod, "tir_add")
+    assert not check_if_func_exists(new_mod, "unused_func")
+
+
+def test_unused_relax_func_custom_entry_func():
+    @tvm.script.ir_module
+    class InputModule:
+        @T.prim_func
+        def tir_add(
+            x: T.Buffer[(16, 16), "float32"],
+            y: T.Buffer[(16, 16), "float32"],
+            z: T.Buffer[(16, 16), "float32"],
+        ) -> None:
+            for i, j in T.grid(16, 16):
+                with T.block("add"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    z[vi, vj] = x[vi, vj] + y[vi, vj]
+
+        @R.function
+        def unused_func(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")):
+            gv0 = R.add(x, w)
+            return gv0
+
+        @R.function
+        def foo(
+            x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")
+        ) -> R.Tensor((16, 16), "float32"):
+            gv0 = R.call_tir(tir_add, (x, w), R.Tensor((16, 16), dtype="float32"))
+            return gv0
+
+    mod = InputModule
+    assert mod
+
+    # Test entry function other than "main".
+    new_mod = relax.transform.RemoveUnusedFunctions(entry_functions=["foo"])(mod)
+    assert check_if_func_exists(new_mod, "foo")
+    assert check_if_func_exists(new_mod, "tir_add")
+    assert not check_if_func_exists(new_mod, "unused_func")
+
+
+def test_unused_relax_func_symbolic_shape():
+    # Test with relax function w/ symbolic shape.
+    @tvm.script.ir_module
+    class InputModule:
+        @T.prim_func
+        def tir_add(
+            x: T.Buffer[(16, 16), "float32"],
+            y: T.Buffer[(16, 16), "float32"],
+            z: T.Buffer[(16, 16), "float32"],
+        ) -> None:
+            for i, j in T.grid(16, 16):
+                with T.block("add"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    z[vi, vj] = x[vi, vj] + y[vi, vj]
+
+        @R.function
+        def unused_func(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")):
+            gv0 = R.add(x, w)
+            return gv0
+
+        @R.function
+        def main(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")):
+            m, k = T.var("int64"), T.var("int64")
+            gv0 = R.call_tir(tir_add, (x, w), R.Tensor((m + 1, k), dtype="float32"))
+            return gv0
+
+    mod = InputModule
+    assert mod
+
+    new_mod = relax.transform.RemoveUnusedFunctions()(mod)
+    assert check_if_func_exists(new_mod, "main")
+    assert check_if_func_exists(new_mod, "tir_add")
+    assert not check_if_func_exists(new_mod, "unused_func")
+
+
+def test_unused_prim_func():
+    @tvm.script.ir_module
+    class InputModule:
+        @T.prim_func
+        def unused_func(
+            x: T.Buffer[(16, 16), "float32"],
+            y: T.Buffer[(16, 16), "float32"],
+            z: T.Buffer[(16, 16), "float32"],
+        ) -> None:
+            T.func_attr({"global_symbol": "tir_unused"})
+            for i, j in T.grid(16, 16):
+                with T.block("add"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    z[vi, vj] = x[vi, vj] + y[vi, vj]
+
+        @R.function
+        def relax_add(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")):
+            gv0 = R.add(x, w)
+            return gv0
+
+        @R.function
+        def main(
+            x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")
+        ) -> R.Tensor((16, 16), "float32"):
+            gv0 = relax_add(x, w)
+            return gv0
+
+    mod = InputModule
+    assert mod
+    new_mod = relax.transform.RemoveUnusedFunctions()(mod)
+    assert check_if_func_exists(new_mod, "main")
+    assert check_if_func_exists(new_mod, "relax_add")
+    # RemoveUnusedFunction pass won't remove the function with global symbol for the external linkage.
+    assert check_if_func_exists(new_mod, "unused_func")
+
+
+def test_multiple_unused_funcs():
+    @tvm.script.ir_module
+    class InputModule:
+        @T.prim_func
+        def unused_func1(
+            x: T.Buffer[(16, 16), "float32"],
+            y: T.Buffer[(16, 16), "float32"],
+            z: T.Buffer[(16, 16), "float32"],
+        ) -> None:
+            T.func_attr({"global_symbol": "tir_unused"})
+            for i, j in T.grid(16, 16):
+                with T.block("add"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    z[vi, vj] = x[vi, vj] + y[vi, vj]
+
+        @R.function
+        def unused_func2(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")):
+            gv0 = R.add(x, w)
+            return gv0
+
+        @R.function
+        def main(
+            x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")
+        ) -> R.Tensor((16, 16), "float32"):
+            gv0 = R.add(x, w)
+            return gv0
+
+    mod = InputModule
+    assert mod
+
+    new_mod = relax.transform.RemoveUnusedFunctions()(mod)
+    assert check_if_func_exists(new_mod, "main")
+    # RemoveUnusedFunction pass won't remove the function with global symbol for the external linkage.
+    assert check_if_func_exists(new_mod, "unused_func1")
+    assert not check_if_func_exists(new_mod, "unused_func2")
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])

Reply via email to