This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new fc324d0f2c [Unity][Transform] Implement RemoveUnusedParameters (#16116)
fc324d0f2c is described below

commit fc324d0f2cff35f171c702d30be6a10280bafddb
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Dec 1 15:47:31 2023 -0600

    [Unity][Transform] Implement RemoveUnusedParameters (#16116)
    
    * [Unity] Implement RemoveUnusedParameters transform
    
    Currently, the `FuseOps` and `FuseTIR` passes have a large amount of
    added complexity to identify and handle partial use of tuple
    arguments.  Handling the partial use of tuples could be significantly
    simpler if performed in multiple steps.
    
    1. Perform `FuseOps`.  Any tuple variables that are used by the fused
       function are passed as-is.
    
    2. Expand any parameters that are passed as a tuple.  Any unused
       tensors that were included in a partially-used tuple will be
       converted to unused parameters.
    
    3. Remove any unused parameters.  Any unused tensors that were
       included in a partially-used tuple will be removed in this
       step.
    
    4. Perform `FuseTIR`.  No checking for tuple arguments, either partial
       or full, is required at this step.
    
    This PR implements `relax.transform.RemoveUnusedParameters`, which is
    step (3) in this process.
    
    * Update based on review comments
---
 include/tvm/relax/transform.h                      |   6 +
 python/tvm/relax/transform/__init__.py             |   1 +
 python/tvm/relax/transform/transform.py            |  10 +
 src/relax/transform/remove_unused_parameters.cc    | 260 +++++++++++++++++++++
 .../test_transform_remove_unused_parameters.py     | 101 ++++++++
 5 files changed, 378 insertions(+)

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 8c6417d19d..f743bb53d0 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -281,6 +281,12 @@ TVM_DLL Pass UpdateVDevice(VDevice new_vdevice, int64_t 
index);
  */
 TVM_DLL Pass ExpandTupleArguments();
 
+/*! \brief Remove unused parameters from internal functions, updating the calls to those functions
+ *
+ * \return The Pass
+ */
+TVM_DLL Pass RemoveUnusedParameters();
+
 /*! \brief Remove unused outputs from internal functions
  *
  * \return The Pass
diff --git a/python/tvm/relax/transform/__init__.py 
b/python/tvm/relax/transform/__init__.py
index b6887160c8..c3f037da5f 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -56,6 +56,7 @@ from .transform import (
     PatternCheckContext,
     RealizeVDevice,
     RemovePurityChecking,
+    RemoveUnusedParameters,
     RemoveUnusedOutputs,
     RewriteCUDAGraph,
     RewriteDataflowReshape,
diff --git a/python/tvm/relax/transform/transform.py 
b/python/tvm/relax/transform/transform.py
index 1beb535f0b..0af89b7d9e 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -568,6 +568,16 @@ def ExpandTupleArguments() -> tvm.ir.transform.Pass:
     return _ffi_api.ExpandTupleArguments()  # type: ignore
 
 
+def RemoveUnusedParameters() -> tvm.ir.transform.Pass:
+    """Remove unused parameters from internal functions
+
+    Returns
+    -------
+    ret: tvm.ir.transform.Pass
+    """
+    return _ffi_api.RemoveUnusedParameters()  # type: ignore
+
+
 def RemoveUnusedOutputs() -> tvm.ir.transform.Pass:
     """Remove unused outputs from internal functions
 
diff --git a/src/relax/transform/remove_unused_parameters.cc 
b/src/relax/transform/remove_unused_parameters.cc
new file mode 100644
index 0000000000..d053d56f32
--- /dev/null
+++ b/src/relax/transform/remove_unused_parameters.cc
@@ -0,0 +1,260 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/relax/utils.h>
+
+#include <algorithm>
+#include <optional>
+#include <tuple>
+
+#include "utils.h"
+
+namespace tvm {
+namespace relax {
+
+namespace {
+
+template <typename T>
+using PSet = std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual>;  // set hashed/compared via ObjectPtrHash/ObjectPtrEqual
+
+template <typename T, typename U>
+using PMap = std::unordered_map<T, U, ObjectPtrHash, ObjectPtrEqual>;  // map with keys hashed/compared via ObjectPtrHash/ObjectPtrEqual
+
+/*! \brief Describes the modifications to be made for a function whose unused parameters are removed */
+struct CalleeAnalysis {
+  /*! \brief The updated private function, carrying the new parameter list */
+  Function func;
+
+  /*! \brief A function that updates the callsite arguments
+   *
+   * \param old_args The arguments used to call the original function
+   *
+   * \return The arguments to be used for the modified function
+   */
+  std::function<Array<Expr>(Array<Expr>)> arg_updater;
+};
+
+std::optional<CalleeAnalysis> AnalyzeCallee(Function func) {  // std::nullopt means `func` needs no update
+  bool is_exposed = 
func->attrs.GetAttr<String>(tvm::attr::kGlobalSymbol).defined();
+  if (is_exposed) return std::nullopt;  // functions with a global symbol are left as-is
+
+  auto free_relax_vars = [&]() -> PSet<Var> {
+    auto array_free_vars = FreeVars(func->body);
+    return {array_free_vars.begin(), array_free_vars.end()};
+  }();
+
+  std::vector<bool> parameter_mask;  // parameter_mask[i]: whether func->params[i] is kept
+  parameter_mask.reserve(func->params.size());
+
+  Array<Var> params;  // parameter list of the updated function
+  for (const auto& param : func->params) {
+    bool is_used = free_relax_vars.count(param);
+    parameter_mask.push_back(is_used);
+    if (is_used) {
+      params.push_back(param);
+    }
+  }
+
+  if (func->params.size() == params.size()) {
+    // Early bail-out for the common case where the function uses all
+    // of its parameters.
+    return std::nullopt;
+  }
+
+  // Even if a parameter is unused, it may provide definitions for
+  // symbolic variables.  We still want to remove the relax variable
+  // to reduce computational steps in the parent, but we need to
+  // provide those symbolic variables through additional parameters.
+  auto defined_tir_params = [&]() -> PSet<tir::Var> {
+    auto param_sinfo =
+        TupleStructInfo(params.Map([](const auto& var) { return 
GetStructInfo(var); }));
+    auto arr = DefinableTIRVarsInStructInfo(param_sinfo);
+    return {arr.begin(), arr.end()};
+  }();
+
+  // Use an array to define the order of the symbolic variables
+  Array<tir::Var> free_tir_vars;  // symbolic variables of the body not definable from the kept parameters
+  for (const auto& tir_var : FreeSymbolicVars(func->body)) {
+    if (!defined_tir_params.count(tir_var)) {
+      free_tir_vars.push_back(tir_var);
+    }
+  }
+
+  for (const auto& tir_var : free_tir_vars) {  // append one PrimStructInfo parameter per such variable
+    Var relax_var("param_" + tir_var->name_hint, PrimStructInfo(tir_var));
+    params.push_back(relax_var);
+  }
+
+  FuncStructInfo new_sinfo(params.Map([](const auto& var) { return 
GetStructInfo(var); }),
+                           func->ret_struct_info,
+                           
Downcast<FuncStructInfo>(func->struct_info_)->purity);
+
+  auto arg_updater = [parameter_mask, old_relax_params = func->params,  // maps old-signature arguments onto the new signature
+                      free_tir_vars](Array<Expr> old_args) -> Array<Expr> {
+    ICHECK_EQ(old_args.size(), parameter_mask.size())
+        << "Call provides " << old_args.size() << ", but the callee accepts "
+        << parameter_mask.size() << " parameters";
+
+    Array<Expr> new_args;
+    for (size_t i = 0; i < old_args.size(); i++) {
+      if (parameter_mask.at(i)) {  // keep only arguments whose parameter is kept
+        new_args.push_back(old_args[i]);
+      }
+    }
+
+    if (free_tir_vars.size()) {  // extra PrimValue arguments, in the same order as the appended parameters
+      Map<Var, Expr> old_binding;
+      for (size_t i = 0; i < old_relax_params.size(); i++) {
+        old_binding.Set(old_relax_params[i], old_args[i]);
+      }
+      arith::Analyzer analyzer;
+      auto tir_binding = InferSymbolicVarMap(old_binding, &analyzer);
+
+      for (const auto& tir_var : free_tir_vars) {
+        new_args.push_back(PrimValue(tir_binding.at(tir_var)));
+      }
+    }
+
+    return new_args;
+  };
+
+  auto write_ptr = func.CopyOnWrite();  // the body is unchanged; only params and struct info are replaced
+  write_ptr->params = params;
+  write_ptr->struct_info_ = new_sinfo;
+
+  return CalleeAnalysis{func, arg_updater};
+}
+
+class CallSiteMutator : public ExprMutator {  // applies the registered updater to every call whose op is a known GlobalVar
+ public:
+  explicit CallSiteMutator(PMap<GlobalVar, std::function<Call(Call)>> 
callsite_updaters)
+      : callsite_updaters_(callsite_updaters) {}
+
+  using ExprMutator::VisitExpr_;
+
+  Expr VisitExpr_(const FunctionNode* op) override {
+    auto node = ExprMutator::VisitExpr_(op);
+
+    // If a function was modified, that means it called into a private
+    // function that now takes a reduced number of arguments.  Some
+    // bindings in the calling scope, previously used to define those
+    // unused arguments, may be able to be removed as a result.
+    if (node.get() != op) {
+      node = RemoveAllUnused(node);
+    }
+    return node;
+  }
+
+  Expr VisitExpr_(const CallNode* op) override {
+    auto node = Downcast<Call>(ExprMutator::VisitExpr_(op));  // visit sub-expressions first, then rewrite this call
+
+    if (auto gvar = node->op.as<GlobalVar>()) {
+      if (auto it = callsite_updaters_.find(gvar.value()); it != 
callsite_updaters_.end()) {
+        node = it->second(std::move(node));
+      }
+    }
+
+    return node;
+  }
+
+  PMap<GlobalVar, std::function<Call(Call)>> callsite_updaters_;  // keyed by the GlobalVar being called
+};
+
+}  // namespace
+
+namespace transform {
+
+Pass RemoveUnusedParameters() {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule mod, PassContext pc) -> IRModule {
+    PMap<GlobalVar, std::function<Call(Call)>> callsite_updaters;  // keyed by the original GlobalVar of each updated callee
+
+    {
+      IRModule new_callees;  // updated callees, registered under freshly created GlobalVars
+
+      for (const auto& [gvar, base_func] : mod->functions) {
+        if (auto func = base_func.as<Function>()) {
+          if (auto callee_res = AnalyzeCallee(func.value())) {
+            auto new_func = callee_res->func;
+            GlobalVar new_gvar(gvar->name_hint, new_func->checked_type_);
+            new_gvar->struct_info_ = new_func->struct_info_;
+            new_callees->Add(new_gvar, new_func);
+
+            callsite_updaters[gvar] = [old_gvar = gvar, new_gvar,
+                                       arg_updater = 
callee_res->arg_updater](Call call) -> Call {
+              ICHECK(call->op.same_as(old_gvar)) << "InternalError: "
+                                                 << "Updater should be applied 
to " << old_gvar
+                                                 << ", but was applied to " << 
call->op;
+              auto write_ptr = call.CopyOnWrite();
+              write_ptr->op = new_gvar;
+              write_ptr->args = arg_updater(call->args);
+              return call;
+            };
+          }
+        }
+      }
+
+      if (callsite_updaters.empty()) {  // no callee changed: return the module as received
+        return mod;
+      }
+      auto write_ptr = mod.CopyOnWrite();
+
+      // Remove any private subroutines that have unused parameters,
+      // then add the updated versions.  The new private functions
+      // have the same name, but require a new GlobalVar to hold the
+      // updated StructInfo.  As a result, calling `Update()` without
+      // first calling `Remove()` would introduce a duplicate name and
+      // produce an error.
+      for (const auto& it : callsite_updaters) {
+        write_ptr->Remove(it.first);
+      }
+      write_ptr->Update(new_callees);
+    }
+
+    CallSiteMutator mutator(std::move(callsite_updaters));
+
+    IRModule caller_updates;  // functions changed by the call-site rewrite, collected before updating `mod`
+
+    for (const auto& [gvar, base_func] : mod->functions) {
+      if (auto func = base_func.as<Function>()) {
+        auto mutated = Downcast<Function>(mutator.VisitExpr(func.value()));
+        if (!mutated.same_as(base_func)) {
+          caller_updates->Add(gvar, mutated);
+        }
+      }
+    }
+
+    if (caller_updates->functions.size()) {
+      mod.CopyOnWrite()->Update(caller_updates);
+    }
+    return mod;
+  };
+  return CreateModulePass(pass_func, 0, "RemoveUnusedParameters", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.RemoveUnusedParameters")
+    .set_body_typed(RemoveUnusedParameters);  // registers the pass constructor as a global function
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/tests/python/relax/test_transform_remove_unused_parameters.py 
b/tests/python/relax/test_transform_remove_unused_parameters.py
new file mode 100644
index 0000000000..82c8d0bd1d
--- /dev/null
+++ b/tests/python/relax/test_transform_remove_unused_parameters.py
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.script import ir as I, relax as R, tir as T
+
+
+class BaseCompare(tvm.testing.CompareBeforeAfter):  # shared base selecting RemoveUnusedParameters as the transform under test
+    transform = tvm.relax.transform.RemoveUnusedParameters()
+
+
+class TestSimple(BaseCompare):  # unused parameter `B` of private `func` is dropped, along with the call-site argument
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor, B: R.Tensor):
+            return Before.func(A, B)
+
+        @R.function(private=True)
+        def func(A: R.Tensor, B: R.Tensor) -> R.Tensor:
+            return A
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(A: R.Tensor, B: R.Tensor):  # `main` is not private, so it keeps its unused `B`
+            return Expected.func(A)
+
+        @R.function(private=True)
+        def func(A: R.Tensor) -> R.Tensor:
+            return A
+
+
+class TestSymbolicVariables(BaseCompare):  # `A` is unused but defines `m` and `n`, which become explicit R.Prim parameters
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], 
"float32"):
+            return Before.func(A)
+
+        @R.function(private=True)
+        def func(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], 
"float32"):
+            m = T.int64()
+            n = T.int64()
+            return R.zeros(R.shape([m, n]), dtype="float32")
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], 
"float32"):
+            m = T.int64()
+            n = T.int64()
+            out: R.Tensor([m, n], "float32") = Expected.func(R.prim_value(n), 
R.prim_value(m))
+            return out
+
+        @R.function(private=True)
+        def func(
+            param_n: R.Prim(value="n"), param_m: R.Prim(value="m")
+        ) -> R.Tensor(["m", "n"], "float32"):
+            m = T.int64()
+            n = T.int64()
+            return R.zeros(R.shape([m, n]), dtype="float32")
+
+
+class TestNoExtraSymbolicVariables(BaseCompare):
+    """Don't add symbolic-variable parameters if they can be inferred from a parameter that is kept."""
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], 
"float32"):
+            return Before.func(A)
+
+        @R.function(private=True)
+        def func(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], 
"float32"):
+            m = T.int64()
+            n = T.int64()
+            zeros = R.zeros(R.shape([m, n]), dtype="float32")
+            out = R.add(A, zeros)  # `A` is used here, so it is kept and still defines `m` and `n`
+            return out
+
+    Expected = Before
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to