This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new fe9d2fe57d [Unity][Transform] Implement ExpandTupleArguments (#16115)
fe9d2fe57d is described below

commit fe9d2fe57d45ceaef26f04c45e3d908f311516a7
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Dec 1 08:21:04 2023 -0600

    [Unity][Transform] Implement ExpandTupleArguments (#16115)
    
    [Unity] Implement ExpandTupleArguments transform
    
    Currently, the `FuseOps` and `FuseTIR` passes have a large amount of
    added complexity to identify and handle partial use of tuple
    arguments.  The handling of partial tuple use could be significantly
    simpler if performed in multiple steps.
    
    1. Perform `FuseOps`.  Any tuple variables that are used by the fused
       function are passed as-is.
    
    2. Expand any parameters that are passed as a tuple.  Any unused
       tensors that were included in a partially-used tuple will be
       converted to unused parameters.
    
    3. Remove any unused parameters.  Any unused tensors that were
       included in a partially-used tuple will be removed in this
       step.
    
    4. Perform `FuseTIR`.  No checking for tuple arguments, either partial
       or full, is required at this step.
    
    This PR implements `relax.transform.ExpandTupleArguments`, which is
    step (2) in this process.
---
 include/tvm/relax/transform.h                      |   6 +
 python/tvm/relax/transform/__init__.py             |   1 +
 python/tvm/relax/transform/transform.py            |  10 ++
 src/relax/transform/expand_tuple_arguments.cc      | 187 +++++++++++++++++++++
 .../relax/test_transform_expand_tuple_args.py      |  79 +++++++++
 5 files changed, 283 insertions(+)

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index b043765a69..8c6417d19d 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -275,6 +275,12 @@ TVM_DLL Pass LiftTransformParams();
  */
 TVM_DLL Pass UpdateVDevice(VDevice new_vdevice, int64_t index);
 
+/*! \brief Expand tuple arguments to internal functions
+ *
+ * \return The Pass
+ */
+TVM_DLL Pass ExpandTupleArguments();
+
 /*! \brief Remove unused outputs from internal functions
  *
  * \return The Pass
diff --git a/python/tvm/relax/transform/__init__.py 
b/python/tvm/relax/transform/__init__.py
index 0ce0ebba11..b6887160c8 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -33,6 +33,7 @@ from .transform import (
     DecomposeOpsForInference,
     DecomposeOpsForTraining,
     EliminateCommonSubexpr,
+    ExpandTupleArguments,
     FewShotTuning,
     FoldConstant,
     FunctionPass,
diff --git a/python/tvm/relax/transform/transform.py 
b/python/tvm/relax/transform/transform.py
index 428f8c24ef..1beb535f0b 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -558,6 +558,16 @@ def FoldConstant() -> tvm.ir.transform.Pass:
     return _ffi_api.FoldConstant()  # type: ignore
 
 
+def ExpandTupleArguments() -> tvm.ir.transform.Pass:
+    """Expand tuple arguments to internal functions
+
+    Returns
+    -------
+    ret: tvm.ir.transform.Pass
+    """
+    return _ffi_api.ExpandTupleArguments()  # type: ignore
+
+
 def RemoveUnusedOutputs() -> tvm.ir.transform.Pass:
     """Remove unused outputs from internal functions
 
diff --git a/src/relax/transform/expand_tuple_arguments.cc 
b/src/relax/transform/expand_tuple_arguments.cc
new file mode 100644
index 0000000000..c61832bbab
--- /dev/null
+++ b/src/relax/transform/expand_tuple_arguments.cc
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+
+#include <algorithm>
+#include <tuple>
+
+namespace tvm {
+namespace relax {
+
+namespace {
+
+template <typename T, typename U>
+using PMap = std::unordered_map<T, U, ObjectPtrHash, ObjectPtrEqual>;
+
+Optional<Function> ExpandParams(Function func) {
+  bool is_exposed = 
func->attrs.GetAttr<String>(tvm::attr::kGlobalSymbol).defined();
+  if (is_exposed) return NullOpt;
+
+  bool has_tuple_param = std::any_of(
+      func->params.begin(), func->params.end(),
+      [](const Var& param) -> bool { return 
param->struct_info_.as<TupleStructInfoNode>(); });
+
+  if (!has_tuple_param) return NullOpt;
+
+  Array<Var> params;
+  Array<Binding> bindings;
+
+  std::function<void(const Var&)> expand_param = [&](const Var& param) {
+    if (auto sinfo = param->struct_info_.as<TupleStructInfoNode>()) {
+      Array<Expr> internal_tuple;
+      for (size_t i = 0; i < sinfo->fields.size(); i++) {
+        auto name = static_cast<const std::stringstream&>(std::stringstream()
+                                                          << 
param->name_hint() << "_" << i)
+                        .str();
+        Var new_param(name, sinfo->fields[i]);
+        internal_tuple.push_back(new_param);
+        expand_param(new_param);
+      }
+      bindings.push_back(VarBinding(param, Tuple(internal_tuple)));
+    } else {
+      params.push_back(param);
+    }
+  };
+
+  for (const auto& param : func->params) {
+    expand_param(param);
+  }
+
+  FuncStructInfo new_sinfo(params.Map([](const auto& var) { return 
GetStructInfo(var); }),
+                           func->ret_struct_info,
+                           
Downcast<FuncStructInfo>(func->struct_info_)->purity);
+
+  auto write_ptr = func.CopyOnWrite();
+  write_ptr->params = params;
+  write_ptr->body = SeqExpr({BindingBlock(bindings)}, func->body);
+  write_ptr->struct_info_ = new_sinfo;
+
+  return func;
+}
+
+class TupleExpander : public ExprMutator {
+ public:
+  explicit TupleExpander(PMap<GlobalVar, GlobalVar> callees) : 
replacements_(callees) {}
+
+  using ExprMutator::VisitExpr_;
+
+  Expr VisitExpr_(const CallNode* op) override {
+    auto node = Downcast<Call>(ExprMutator::VisitExpr_(op));
+
+    if (auto gvar = node->op.as<GlobalVar>()) {
+      if (auto it = replacements_.find(gvar.value()); it != 
replacements_.end()) {
+        Array<Expr> new_args;
+
+        std::function<void(const Expr&)> expand_arg = [&](const Expr& arg) {
+          if (auto sinfo = arg->struct_info_.as<TupleStructInfoNode>()) {
+            for (size_t i = 0; i < sinfo->fields.size(); i++) {
+              expand_arg(TupleGetItem(arg, i));
+            }
+          } else {
+            new_args.push_back(arg);
+          }
+        };
+
+        for (const auto& arg : node->args) {
+          expand_arg(arg);
+        }
+
+        auto write_ptr = node.CopyOnWrite();
+        write_ptr->op = it->second;
+        write_ptr->args = new_args;
+      }
+    }
+
+    return node;
+  }
+
+  PMap<GlobalVar, GlobalVar> replacements_;
+};
+
+}  // namespace
+
+namespace transform {
+
+Pass ExpandTupleArguments() {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule mod, PassContext pc) -> IRModule {
+    PMap<GlobalVar, GlobalVar> gvar_replacements;
+
+    {
+      PMap<GlobalVar, Function> new_callees;
+
+      for (const auto& [gvar, base_func] : mod->functions) {
+        if (auto func = base_func.as<Function>()) {
+          if (auto opt = ExpandParams(func.value())) {
+            auto new_func = opt.value();
+            GlobalVar new_gvar(gvar->name_hint, new_func->checked_type_);
+            new_gvar->struct_info_ = new_func->struct_info_;
+            gvar_replacements[gvar] = new_gvar;
+            new_callees[new_gvar] = new_func;
+          }
+        }
+      }
+
+      if (gvar_replacements.empty()) {
+        return mod;
+      }
+      auto write_ptr = mod.CopyOnWrite();
+      for (auto [old_gvar, new_gvar] : gvar_replacements) {
+        write_ptr->Remove(old_gvar);
+        write_ptr->Add(new_gvar, new_callees.at(new_gvar));
+      }
+    }
+
+    TupleExpander mutator(std::move(gvar_replacements));
+
+    IRModule caller_updates;
+
+    for (const auto& [gvar, base_func] : mod->functions) {
+      if (auto func = base_func.as<Function>()) {
+        auto mutated = Downcast<Function>(mutator.VisitExpr(func.value()));
+        if (!mutated.same_as(base_func)) {
+          caller_updates->Add(gvar, mutated);
+        }
+      }
+    }
+
+    if (caller_updates->functions.size()) {
+      mod.CopyOnWrite()->Update(caller_updates);
+    }
+    return mod;
+  };
+  auto inner_pass = CreateModulePass(pass_func, 0, 
"ExpandTupleArgumentsInner", {});
+
+  return tvm::transform::Sequential(
+      {
+          inner_pass,
+          CanonicalizeBindings(),
+          DeadCodeElimination({}),
+      },
+      "ExpandTupleArguments");
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.ExpandTupleArguments").set_body_typed(ExpandTupleArguments);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/tests/python/relax/test_transform_expand_tuple_args.py 
b/tests/python/relax/test_transform_expand_tuple_args.py
new file mode 100644
index 0000000000..a90db1d84d
--- /dev/null
+++ b/tests/python/relax/test_transform_expand_tuple_args.py
@@ -0,0 +1,79 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.script import ir as I, relax as R, tir as T
+
+
+class BaseCompare(tvm.testing.CompareBeforeAfter):
+    transform = tvm.relax.transform.ExpandTupleArguments()
+
+
+class TestSimple(BaseCompare):
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor, B: R.Tensor):
+            return Before.func((A, B))
+
+        @R.function(private=True)
+        def func(args: R.Tuple([R.Tensor, R.Tensor])) -> R.Tensor:
+            return args[0]
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(A: R.Tensor, B: R.Tensor):
+            return Expected.func(A, B)
+
+        @R.function(private=True)
+        def func(A: R.Tensor, B: R.Tensor) -> R.Tensor:
+            return A
+
+
+class TestNested(BaseCompare):
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor, B: R.Tensor, C: R.Tensor, D: R.Tensor) -> 
R.Tensor:
+            return Before.func(((A, B), (C, D)))
+
+        @R.function(private=True)
+        def func(
+            args: R.Tuple(
+                [
+                    R.Tuple([R.Tensor, R.Tensor]),
+                    R.Tuple([R.Tensor, R.Tensor]),
+                ]
+            )
+        ) -> R.Tensor:
+            return args[0][1]
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(A: R.Tensor, B: R.Tensor, C: R.Tensor, D: R.Tensor) -> 
R.Tensor:
+            return Expected.func(A, B, C, D)
+
+        @R.function(private=True)
+        def func(A: R.Tensor, B: R.Tensor, C: R.Tensor, D: R.Tensor) -> 
R.Tensor:
+            return B
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to