slyubomirsky commented on code in PR #16116:
URL: https://github.com/apache/tvm/pull/16116#discussion_r1411508481


##########
src/relax/transform/remove_unused_parameters.cc:
##########
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/relax/utils.h>
+
+#include <algorithm>
+#include <optional>
+#include <tuple>
+
+namespace tvm {
+namespace relax {
+
+namespace {
+
+template <typename T>
+using PSet = std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual>;
+
+template <typename T, typename U>
+using PMap = std::unordered_map<T, U, ObjectPtrHash, ObjectPtrEqual>;
+
+struct CalleeAnalysis {
+  // The updated function
+  Function func;
+
+  // A mutator that updates calls at the call site.
+  std::function<Array<Expr>(Array<Expr>)> arg_updater;
+};
+
+std::optional<CalleeAnalysis> AnalyzeCallee(Function func) {
+  bool is_exposed = 
func->attrs.GetAttr<String>(tvm::attr::kGlobalSymbol).defined();
+  if (is_exposed) return std::nullopt;
+
+  auto free_relax_vars = [&]() -> PSet<Var> {
+    auto array_free_vars = FreeVars(func->body);
+    return {array_free_vars.begin(), array_free_vars.end()};
+  }();
+
+  std::vector<bool> parameter_mask;
+  parameter_mask.reserve(func->params.size());
+
+  Array<Var> params;
+  for (const auto& param : func->params) {
+    bool is_used = free_relax_vars.count(param);
+    parameter_mask.push_back(is_used);
+    if (is_used) {
+      params.push_back(param);
+    }
+  }
+
+  if (func->params.size() == params.size()) {
+    // Early bail-out for the common case where the function uses all
+    // of its parameters.
+    return std::nullopt;
+  }
+
+  // Even if a parameter is unused, it may provide definitions for
+  // symbolic variables.  We still want to remove the relax variable
+  // to reduce computational steps in the parent, but we need to
+  // provide the symbolic variables the other steps.
+  auto defined_tir_params = [&]() -> PSet<tir::Var> {
+    auto param_sinfo =
+        TupleStructInfo(params.Map([](const auto& var) { return 
GetStructInfo(var); }));
+    auto arr = DefinableTIRVarsInStructInfo(param_sinfo);
+    return {arr.begin(), arr.end()};
+  }();
+
+  // Use an array to define the order of the symbolic variables
+  Array<tir::Var> free_tir_vars;
+  for (const auto& tir_var : FreeSymbolicVars(func->body)) {
+    if (!defined_tir_params.count(tir_var)) {
+      free_tir_vars.push_back(tir_var);
+    }
+  }
+
+  for (const auto& tir_var : free_tir_vars) {
+    Var relax_var("param_" + tir_var->name_hint, PrimStructInfo(tir_var));
+    params.push_back(relax_var);
+  }
+
+  FuncStructInfo new_sinfo(params.Map([](const auto& var) { return 
GetStructInfo(var); }),
+                           func->ret_struct_info,
+                           
Downcast<FuncStructInfo>(func->struct_info_)->purity);
+
+  auto arg_updater = [parameter_mask, old_relax_params = func->params,
+                      free_tir_vars](Array<Expr> old_args) -> Array<Expr> {
+    ICHECK_EQ(old_args.size(), parameter_mask.size())
+        << "Call provides " << old_args.size() << ", but the callee accepts "
+        << parameter_mask.size() << " parameters";
+
+    Array<Expr> new_args;
+    for (size_t i = 0; i < old_args.size(); i++) {
+      if (parameter_mask.at(i)) {
+        new_args.push_back(old_args[i]);
+      }
+    }
+
+    if (free_tir_vars.size()) {
+      Map<Var, Expr> old_binding;
+      for (size_t i = 0; i < old_relax_params.size(); i++) {
+        old_binding.Set(old_relax_params[i], old_args[i]);
+      }
+      arith::Analyzer analyzer;
+      auto tir_binding = InferSymbolicVarMap(old_binding, &analyzer);
+
+      for (const auto& tir_var : free_tir_vars) {
+        new_args.push_back(PrimValue(tir_binding.at(tir_var)));
+      }
+    }
+
+    return new_args;
+  };
+
+  auto write_ptr = func.CopyOnWrite();
+  write_ptr->params = params;
+  write_ptr->struct_info_ = new_sinfo;
+
+  return CalleeAnalysis{func, arg_updater};
+}
+
+class CallSiteMutator : public ExprMutator {
+ public:
+  explicit CallSiteMutator(PMap<GlobalVar, std::function<Call(Call)>> 
callsite_updaters)
+      : callsite_updaters_(callsite_updaters) {}
+
+  using ExprMutator::VisitExpr_;
+
+  Expr VisitExpr_(const CallNode* op) override {
+    auto node = Downcast<Call>(ExprMutator::VisitExpr_(op));
+
+    if (auto gvar = node->op.as<GlobalVar>()) {
+      if (auto it = callsite_updaters_.find(gvar.value()); it != 
callsite_updaters_.end()) {
+        node = it->second(std::move(node));
+      }
+    }
+
+    return node;
+  }
+
+  PMap<GlobalVar, std::function<Call(Call)>> callsite_updaters_;
+};
+
+}  // namespace
+
+namespace transform {
+
+Pass RemoveUnusedParameters() {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule mod, PassContext pc) -> IRModule {
+    PMap<GlobalVar, std::function<Call(Call)>> callsite_updaters;
+
+    {
+      IRModule new_callees;
+
+      for (const auto& [gvar, base_func] : mod->functions) {
+        if (auto func = base_func.as<Function>()) {
+          if (auto callee_res = AnalyzeCallee(func.value())) {
+            auto new_func = callee_res->func;
+            GlobalVar new_gvar(gvar->name_hint, new_func->checked_type_);
+            new_gvar->struct_info_ = new_func->struct_info_;
+            new_callees->Add(new_gvar, new_func);
+
+            callsite_updaters[gvar] = [old_gvar = gvar, new_gvar,
+                                       arg_updater = 
callee_res->arg_updater](Call call) -> Call {
+              ICHECK(call->op.same_as(old_gvar)) << "InternalError: "
+                                                 << "Updater should be applied 
to " << old_gvar
+                                                 << ", but was applied to " << 
call->op;
+              auto write_ptr = call.CopyOnWrite();
+              write_ptr->op = new_gvar;
+              write_ptr->args = arg_updater(call->args);
+              return call;
+            };
+          }
+        }
+      }
+
+      if (callsite_updaters.empty()) {
+        return mod;
+      }
+      auto write_ptr = mod.CopyOnWrite();
+      for (const auto& it : callsite_updaters) {
+        write_ptr->Remove(it.first);
+      }
+      write_ptr->Update(new_callees);

Review Comment:
   It might be worth adding a comment explaining that this transfers over the 
functions that remain unchanged. I had to think for a second to realize what 
this was doing.



##########
src/relax/transform/remove_unused_parameters.cc:
##########
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/relax/utils.h>
+
+#include <algorithm>
+#include <optional>
+#include <tuple>
+
+namespace tvm {
+namespace relax {
+
+namespace {
+
+template <typename T>
+using PSet = std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual>;
+
+template <typename T, typename U>
+using PMap = std::unordered_map<T, U, ObjectPtrHash, ObjectPtrEqual>;
+
+struct CalleeAnalysis {
+  // The updated function
+  Function func;
+
+  // A mutator that updates calls at the call site.
+  std::function<Array<Expr>(Array<Expr>)> arg_updater;
+};
+
+std::optional<CalleeAnalysis> AnalyzeCallee(Function func) {
+  bool is_exposed = 
func->attrs.GetAttr<String>(tvm::attr::kGlobalSymbol).defined();
+  if (is_exposed) return std::nullopt;
+
+  auto free_relax_vars = [&]() -> PSet<Var> {
+    auto array_free_vars = FreeVars(func->body);
+    return {array_free_vars.begin(), array_free_vars.end()};
+  }();
+
+  std::vector<bool> parameter_mask;
+  parameter_mask.reserve(func->params.size());
+
+  Array<Var> params;
+  for (const auto& param : func->params) {
+    bool is_used = free_relax_vars.count(param);
+    parameter_mask.push_back(is_used);
+    if (is_used) {
+      params.push_back(param);
+    }
+  }
+
+  if (func->params.size() == params.size()) {
+    // Early bail-out for the common case where the function uses all
+    // of its parameters.
+    return std::nullopt;
+  }
+
+  // Even if a parameter is unused, it may provide definitions for
+  // symbolic variables.  We still want to remove the relax variable
+  // to reduce computational steps in the parent, but we need to
+  // provide the symbolic variables the other steps.
+  auto defined_tir_params = [&]() -> PSet<tir::Var> {
+    auto param_sinfo =
+        TupleStructInfo(params.Map([](const auto& var) { return 
GetStructInfo(var); }));
+    auto arr = DefinableTIRVarsInStructInfo(param_sinfo);
+    return {arr.begin(), arr.end()};
+  }();
+
+  // Use an array to define the order of the symbolic variables
+  Array<tir::Var> free_tir_vars;
+  for (const auto& tir_var : FreeSymbolicVars(func->body)) {
+    if (!defined_tir_params.count(tir_var)) {
+      free_tir_vars.push_back(tir_var);
+    }
+  }
+
+  for (const auto& tir_var : free_tir_vars) {
+    Var relax_var("param_" + tir_var->name_hint, PrimStructInfo(tir_var));
+    params.push_back(relax_var);
+  }
+
+  FuncStructInfo new_sinfo(params.Map([](const auto& var) { return 
GetStructInfo(var); }),
+                           func->ret_struct_info,
+                           
Downcast<FuncStructInfo>(func->struct_info_)->purity);
+
+  auto arg_updater = [parameter_mask, old_relax_params = func->params,

Review Comment:
   >old_relax_params = func->params
   
   I did not know about this syntax, very cool



##########
src/relax/transform/remove_unused_parameters.cc:
##########
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/relax/utils.h>
+
+#include <algorithm>
+#include <optional>
+#include <tuple>
+
+namespace tvm {
+namespace relax {
+
+namespace {
+
+template <typename T>
+using PSet = std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual>;
+
+template <typename T, typename U>
+using PMap = std::unordered_map<T, U, ObjectPtrHash, ObjectPtrEqual>;
+
+struct CalleeAnalysis {
+  // The updated function
+  Function func;
+
+  // A mutator that updates calls at the call site.
+  std::function<Array<Expr>(Array<Expr>)> arg_updater;

Review Comment:
   The comment should probably explain how the function is meant to be used. 
The word "mutator" could be confused with ExprMutator, so it might be better 
to be a little more specific.



##########
src/relax/transform/remove_unused_parameters.cc:
##########
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/relax/utils.h>
+
+#include <algorithm>
+#include <optional>
+#include <tuple>
+
+namespace tvm {
+namespace relax {
+
+namespace {
+
+template <typename T>
+using PSet = std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual>;

Review Comment:
   What does the P stand for?



##########
tests/python/relax/test_transform_remove_unused_parameters.py:
##########
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.script import ir as I, relax as R, tir as T
+
+
+class BaseCompare(tvm.testing.CompareBeforeAfter):
+    transform = tvm.relax.transform.RemoveUnusedParameters()
+
+
+class TestSimple(BaseCompare):
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor, B: R.Tensor):
+            return Before.func(A, B)
+
+        @R.function(private=True)
+        def func(A: R.Tensor, B: R.Tensor) -> R.Tensor:
+            return A
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(A: R.Tensor, B: R.Tensor):
+            return Expected.func(A)
+
+        @R.function(private=True)
+        def func(A: R.Tensor) -> R.Tensor:
+            return A
+
+
+class TestSymbolicVariables(BaseCompare):
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], 
"float32"):
+            return Before.func(A)
+
+        @R.function(private=True)
+        def func(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], 
"float32"):
+            m = T.int64()
+            n = T.int64()
+            return R.zeros(R.shape([m, n]), dtype="float32")
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], 
"float32"):
+            m = T.int64()
+            n = T.int64()
+            out: R.Tensor([m, n], "float32") = Expected.func(R.prim_value(n), 
R.prim_value(m))
+            return out
+
+        @R.function(private=True)
+        def func(
+            param_n: R.Prim(value="n"), param_m: R.Prim(value="m")
+        ) -> R.Tensor(["m", "n"], "float32"):
+            m = T.int64()
+            n = T.int64()
+            return R.zeros(R.shape([m, n]), dtype="float32")
+
+
+class TestNoExtraSymbolicVariables(BaseCompare):
+    """Don't add symbolic variables if they can be inferred."""
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], 
"float32"):
+            return Before.func(A)
+
+        @R.function(private=True)
+        def func(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], 
"float32"):
+            m = T.int64()
+            n = T.int64()
+            zeros = R.zeros(R.shape([m, n]), dtype="float32")
+            out = R.add(A, zeros)
+            return out
+
+    Expected = Before
+

Review Comment:
   One test case that might have interesting results would be with an inner 
function. I think (haven't thought it through all the way) that you might have 
to separately invoke the transformation on inner functions. Granted, the issue 
is moot if lambda lifting is used earlier.



##########
src/relax/transform/remove_unused_parameters.cc:
##########
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/relax/utils.h>
+
+#include <algorithm>
+#include <optional>
+#include <tuple>
+
+namespace tvm {
+namespace relax {
+
+namespace {
+
+template <typename T>
+using PSet = std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual>;
+
+template <typename T, typename U>
+using PMap = std::unordered_map<T, U, ObjectPtrHash, ObjectPtrEqual>;
+
+struct CalleeAnalysis {
+  // The updated function
+  Function func;
+
+  // A mutator that updates calls at the call site.
+  std::function<Array<Expr>(Array<Expr>)> arg_updater;
+};
+
+std::optional<CalleeAnalysis> AnalyzeCallee(Function func) {
+  bool is_exposed = 
func->attrs.GetAttr<String>(tvm::attr::kGlobalSymbol).defined();
+  if (is_exposed) return std::nullopt;
+
+  auto free_relax_vars = [&]() -> PSet<Var> {
+    auto array_free_vars = FreeVars(func->body);
+    return {array_free_vars.begin(), array_free_vars.end()};
+  }();
+
+  std::vector<bool> parameter_mask;
+  parameter_mask.reserve(func->params.size());
+
+  Array<Var> params;
+  for (const auto& param : func->params) {
+    bool is_used = free_relax_vars.count(param);
+    parameter_mask.push_back(is_used);
+    if (is_used) {
+      params.push_back(param);
+    }
+  }
+
+  if (func->params.size() == params.size()) {
+    // Early bail-out for the common case where the function uses all
+    // of its parameters.
+    return std::nullopt;
+  }
+
+  // Even if a parameter is unused, it may provide definitions for
+  // symbolic variables.  We still want to remove the relax variable
+  // to reduce computational steps in the parent, but we need to
+  // provide the symbolic variables the other steps.
+  auto defined_tir_params = [&]() -> PSet<tir::Var> {
+    auto param_sinfo =
+        TupleStructInfo(params.Map([](const auto& var) { return 
GetStructInfo(var); }));
+    auto arr = DefinableTIRVarsInStructInfo(param_sinfo);
+    return {arr.begin(), arr.end()};
+  }();
+
+  // Use an array to define the order of the symbolic variables
+  Array<tir::Var> free_tir_vars;
+  for (const auto& tir_var : FreeSymbolicVars(func->body)) {
+    if (!defined_tir_params.count(tir_var)) {
+      free_tir_vars.push_back(tir_var);
+    }
+  }
+
+  for (const auto& tir_var : free_tir_vars) {
+    Var relax_var("param_" + tir_var->name_hint, PrimStructInfo(tir_var));
+    params.push_back(relax_var);
+  }
+
+  FuncStructInfo new_sinfo(params.Map([](const auto& var) { return 
GetStructInfo(var); }),
+                           func->ret_struct_info,
+                           
Downcast<FuncStructInfo>(func->struct_info_)->purity);
+
+  auto arg_updater = [parameter_mask, old_relax_params = func->params,
+                      free_tir_vars](Array<Expr> old_args) -> Array<Expr> {
+    ICHECK_EQ(old_args.size(), parameter_mask.size())
+        << "Call provides " << old_args.size() << ", but the callee accepts "
+        << parameter_mask.size() << " parameters";
+
+    Array<Expr> new_args;
+    for (size_t i = 0; i < old_args.size(); i++) {
+      if (parameter_mask.at(i)) {
+        new_args.push_back(old_args[i]);
+      }
+    }
+
+    if (free_tir_vars.size()) {
+      Map<Var, Expr> old_binding;
+      for (size_t i = 0; i < old_relax_params.size(); i++) {
+        old_binding.Set(old_relax_params[i], old_args[i]);
+      }
+      arith::Analyzer analyzer;
+      auto tir_binding = InferSymbolicVarMap(old_binding, &analyzer);
+
+      for (const auto& tir_var : free_tir_vars) {
+        new_args.push_back(PrimValue(tir_binding.at(tir_var)));
+      }
+    }
+
+    return new_args;
+  };
+
+  auto write_ptr = func.CopyOnWrite();
+  write_ptr->params = params;
+  write_ptr->struct_info_ = new_sinfo;
+
+  return CalleeAnalysis{func, arg_updater};
+}
+
+class CallSiteMutator : public ExprMutator {
+ public:
+  explicit CallSiteMutator(PMap<GlobalVar, std::function<Call(Call)>> 
callsite_updaters)
+      : callsite_updaters_(callsite_updaters) {}
+
+  using ExprMutator::VisitExpr_;
+
+  Expr VisitExpr_(const CallNode* op) override {
+    auto node = Downcast<Call>(ExprMutator::VisitExpr_(op));
+
+    if (auto gvar = node->op.as<GlobalVar>()) {
+      if (auto it = callsite_updaters_.find(gvar.value()); it != 
callsite_updaters_.end()) {
+        node = it->second(std::move(node));
+      }
+    }
+
+    return node;
+  }
+
+  PMap<GlobalVar, std::function<Call(Call)>> callsite_updaters_;
+};
+
+}  // namespace
+
+namespace transform {
+
+Pass RemoveUnusedParameters() {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule mod, PassContext pc) -> IRModule {
+    PMap<GlobalVar, std::function<Call(Call)>> callsite_updaters;
+
+    {
+      IRModule new_callees;
+
+      for (const auto& [gvar, base_func] : mod->functions) {
+        if (auto func = base_func.as<Function>()) {
+          if (auto callee_res = AnalyzeCallee(func.value())) {
+            auto new_func = callee_res->func;
+            GlobalVar new_gvar(gvar->name_hint, new_func->checked_type_);
+            new_gvar->struct_info_ = new_func->struct_info_;
+            new_callees->Add(new_gvar, new_func);
+
+            callsite_updaters[gvar] = [old_gvar = gvar, new_gvar,
+                                       arg_updater = 
callee_res->arg_updater](Call call) -> Call {
+              ICHECK(call->op.same_as(old_gvar)) << "InternalError: "
+                                                 << "Updater should be applied 
to " << old_gvar
+                                                 << ", but was applied to " << 
call->op;
+              auto write_ptr = call.CopyOnWrite();
+              write_ptr->op = new_gvar;
+              write_ptr->args = arg_updater(call->args);
+              return call;
+            };
+          }
+        }
+      }
+
+      if (callsite_updaters.empty()) {
+        return mod;
+      }
+      auto write_ptr = mod.CopyOnWrite();
+      for (const auto& it : callsite_updaters) {
+        write_ptr->Remove(it.first);
+      }
+      write_ptr->Update(new_callees);
+    }
+
+    CallSiteMutator mutator(std::move(callsite_updaters));
+
+    IRModule caller_updates;
+
+    for (const auto& [gvar, base_func] : mod->functions) {
+      if (auto func = base_func.as<Function>()) {
+        auto mutated = Downcast<Function>(mutator.VisitExpr(func.value()));
+        if (!mutated.same_as(base_func)) {
+          caller_updates->Add(gvar, mutated);
+        }
+      }
+    }
+
+    if (caller_updates->functions.size()) {
+      mod.CopyOnWrite()->Update(caller_updates);
+    }
+    return mod;
+  };
+  auto inner_pass = CreateModulePass(pass_func, 0, 
"RemoveUnusedParametersInner", {});
+  return tvm::transform::Sequential(
+      {
+          inner_pass,
+          CanonicalizeBindings(),
+          DeadCodeElimination({}),

Review Comment:
   Are there specific reasons to include the other passes or is it just to 
clean things up?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to