This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new c68a00a37d [Unity][FuseTIR] Flatten and add tuple fields to parameters / arguments only when they are used (#15113)
c68a00a37d is described below

commit c68a00a37d5d483b54858fcd79f87472e0b1c147
Author: masahi <[email protected]>
AuthorDate: Sat Jun 17 16:41:24 2023 +0900

    [Unity][FuseTIR] Flatten and add tuple fields to parameters / arguments only when they are used (#15113)
    
    * wip
    
    * wip
    
    * works
    
    * fix
    
    * fixed
    
    * comment
    
    * more comment
    
    * adding test
---
 src/relax/transform/fuse_tir.cc               |  86 +++++++++++++++++--
 tests/python/relax/test_transform_fuse_tir.py | 119 ++++++++++++++++++++++++++
 2 files changed, 199 insertions(+), 6 deletions(-)

diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index 601a4cff23..0f9168d062 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -22,6 +22,9 @@
 #include <tvm/relax/transform.h>
 #include <tvm/tir/stmt_functor.h>
 
+#include <unordered_map>
+#include <unordered_set>
+
 #include "../../relay/analysis/graph_partitioner.h"
 #include "../../support/arena.h"
 #include "../../tir/ir/functor_common.h"
@@ -340,13 +343,51 @@ class FusedTIRConstructor : public ExprVisitor {
 
   void VisitExpr_(const FunctionNode* func) final {
     // Step 1. Create buffers for function params
+
+    // Record which fields in a tuple passed as a parameter are actually 
accessed by the function.
+    std::unordered_set<const Object*> tuple_param;
+    for (auto param : func->params) {
+      if (GetStructInfo(param)->IsInstance<TupleStructInfoNode>()) {
+        tuple_param.insert(param.get());
+      }
+    }
+
+    PostOrderVisit(func->body, [=, &tuple_param](Expr e) {
+      if (auto tup_get = e.as<TupleGetItemNode>();
+          tup_get && tuple_param.count(tup_get->tuple.get())) {
+        
func_info_.used_tuple_field_indices[tup_get->tuple.get()].insert(tup_get->index);
+      }
+    });
+
     for (const Var& relax_param : func->params) {
-      if (GetStructInfo(relax_param)->IsInstance<ShapeStructInfoNode>()) {
+      auto sinfo = GetStructInfo(relax_param);
+      if (sinfo->IsInstance<ShapeStructInfoNode>()) {
         // It's a symbolic shape var, no need to alloc Buffers.
         continue;
       }
-      auto [params, buffers] = 
CreateParamsAndBuffers(GetStructInfo(relax_param),  //
-                                                      
relax_param->name_hint());
+
+      auto [params, buffers] = [=]() {
+        if (const auto* tuple = sinfo.as<TupleStructInfoNode>()) {
+          // Add only those tuple fields which are actually used by the 
function body into the
+          // function parameters.
+          int index = 0;
+          Array<tir::Var> params;
+          Array<tir::Buffer> buffers;
+          for (auto i : 
func_info_.used_tuple_field_indices[relax_param.get()]) {
+            auto [ret_params, ret_buffers] =
+                CreateParamsAndBuffers(tuple->fields[i], 
relax_param->name_hint(), index);
+            ICHECK_EQ(ret_params.size(), ret_buffers.size());
+            // Adding tuple field results to the end of params and buffers.
+            params.insert(params.end(), ret_params.begin(), ret_params.end());
+            buffers.insert(buffers.end(), ret_buffers.begin(), 
ret_buffers.end());
+            index += ret_params.size();
+          }
+          return std::make_pair(params, buffers);
+        } else {
+          return CreateParamsAndBuffers(sinfo, relax_param->name_hint());
+        }
+      }();
+
       ICHECK_EQ(params.size(), buffers.size());
       for (size_t i = 0; i < params.size(); ++i) {
         func_info_.buffer_map.Set(params[i], buffers[i]);
@@ -468,7 +509,12 @@ class FusedTIRConstructor : public ExprVisitor {
       int end_buf_idx = 0;
       const TupleType& tuple_type = 
Downcast<TupleType>(tuple_get_item->tuple->checked_type());
       for (int i = 0; i < tuple_get_item->index; ++i) {
-        begin_buf_idx += GetTotalTensorSize(tuple_type->fields[i]);
+        auto it = 
func_info_.used_tuple_field_indices.find(tuple_get_item->tuple.get());
+        // If this tuple is not passed as a parameter, or if the field at the 
index i is actually
+        // used, the corresponding buffer needs to be taken into account by 
this function.
+        if (it == func_info_.used_tuple_field_indices.end() || 
it->second.count(i)) {
+          begin_buf_idx += GetTotalTensorSize(tuple_type->fields[i]);
+        }
       }
       end_buf_idx = begin_buf_idx + 
GetTotalTensorSize(tuple_type->fields[tuple_get_item->index]);
       func_info_.expr2buffers.Set(
@@ -769,6 +815,8 @@ class FusedTIRConstructor : public ExprVisitor {
     std::string global_name = "fused";
     /*! \brief The map from symbolic var to its corresponding var in the fused 
function */
     tir::SymbolicMatcher symbolic_var_matcher = 
tir::SymbolicMatcher(&symbolic_var_remap);
+    /*! \brief Record indices of tuple fields that are actually accessed. */
+    std::unordered_map<const Object*, std::unordered_set<size_t>> 
used_tuple_field_indices;
   };
 
   /*! \brief The IRModule */
@@ -781,6 +829,19 @@ class FusedTIRConstructor : public ExprVisitor {
   tir::PrimFunc fused_tir_;
 };
 
+std::vector<size_t> GetTupleAccessedIndices(const FunctionNode* func, const 
Var& tuple_var) {
+  // Deterministic order: distinct TupleGetItem indices in first-visit (post-order) order, not sorted.
+  std::vector<size_t> indices;
+  PostOrderVisit(func->body, [&indices, tuple_var](Expr e) {
+    if (auto tup_get = e.as<TupleGetItemNode>(); tup_get && 
tup_get->tuple.same_as(tuple_var)) {
+      if (std::find(indices.begin(), indices.end(), tup_get->index) == 
indices.end()) {
+        indices.push_back(tup_get->index);
+      }
+    }
+  });
+  return indices;  // NOTE(review): callee params are built from an unordered_set and buffer offsets assume ascending index; confirm this order matches.
+}
+
 /*!
  * \brief The helper class to fuse TIR functions and build a new module which 
calls the fused TIR.
  */
@@ -840,6 +901,7 @@ class TIRFuseMutator : public ExprMutator {
     if (call->op->IsInstance<GlobalVarNode>()) {
       // Case 1. It is a relax cross function call
       GlobalVar old_gv = Downcast<GlobalVar>(call->op);
+      auto relax_func = Downcast<Function>(mod_->Lookup(old_gv));
       auto it = fused_tir_funcs_.find(old_gv);
       if (it != fused_tir_funcs_.end()) {
         const tir::PrimFunc& fused_tir = (*it).second;
@@ -848,8 +910,20 @@ class TIRFuseMutator : public ExprMutator {
         // Step a. Flatten all args since call_tir does not support Tuple 
value.
         Array<Expr> arg_list;
         Array<PrimExpr> tir_vars;
-        for (const Expr& arg : call->args) {
-          Array<Expr> flattened = FlattenArg(arg);
+        for (size_t i = 0; i < call->args.size(); ++i) {
+          auto arg = call->args[i];
+          Array<Expr> flattened;
+          if 
(GetStructInfo(relax_func->params[i])->IsInstance<TupleStructInfoNode>()) {
+            // Add only those tuple fields which are actually used by the 
function body
+            auto tup_get_indices = GetTupleAccessedIndices(relax_func.get(), 
relax_func->params[i]);
+            for (size_t tup_get_ind : tup_get_indices) {
+              auto flattened_inner = 
FlattenArg(builder_->Emit(TupleGetItem(arg, tup_get_ind)));
+              flattened.insert(flattened.end(), flattened_inner.begin(), 
flattened_inner.end());
+            }
+          } else {
+            flattened.push_back(arg);
+          }
+
           for (const Expr& e : flattened) {
             StructInfo sinfo = GetStructInfo(e);
             if (sinfo->IsInstance<TensorStructInfoNode>()) {
diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py
index af770e0fc6..00dc714654 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -1079,5 +1079,124 @@ def test_tir_expression_in_shape():
     _check(Module, Expected)
 
 
+def test_tuple_input_unused_field():  # fused_reshape reads only lv[0]; only tup[0] should be passed
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def reshape(
+            A: T.Buffer((T.int64(4), T.int64(8), T.int64(2048)), "float32"),
+            T_reshape: T.Buffer((T.int64(4), T.int64(8), T.int64(32), 
T.int64(64)), "float32"),
+        ):
+            T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(8), 
T.int64(32), T.int64(64)):
+                with T.block("T_reshape"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
+                    T.reads(
+                        A[
+                            (
+                                ((v_ax2 * T.int64(64) + v_ax3) // 
T.int64(2048) + v_ax1)
+                                // T.int64(8)
+                                + v_ax0
+                            )
+                            % T.int64(4),
+                            ((v_ax2 * T.int64(64) + v_ax3) // T.int64(2048) + 
v_ax1) % T.int64(8),
+                            (v_ax2 * T.int64(64) + v_ax3) % T.int64(2048),
+                        ]
+                    )
+                    T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[
+                        (
+                            ((v_ax2 * T.int64(64) + v_ax3) // T.int64(2048) + 
v_ax1) // T.int64(8)
+                            + v_ax0
+                        )
+                        % T.int64(4),
+                        ((v_ax2 * T.int64(64) + v_ax3) // T.int64(2048) + 
v_ax1) % T.int64(8),
+                        (v_ax2 * T.int64(64) + v_ax3) % T.int64(2048),
+                    ]
+
+        @R.function
+        def fused_reshape(
+            lv: R.Tuple(
+                R.Tensor((4, 8, 2048), dtype="float32"), R.Tensor((4, 8, 
2048), dtype="float32")
+            )
+        ) -> R.Tensor((4, 8, 32, 64), dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            cls = Module
+            with R.dataflow():
+                lv1: R.Tensor((4, 8, 2048), dtype="float32") = lv[0]
+                gv = R.call_tir(
+                    cls.reshape, (lv1,), out_sinfo=R.Tensor((4, 8, 32, 64), 
dtype="float32")
+                )
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            tup: R.Tuple(
+                R.Tensor((4, 8, 2048), dtype="float32"), R.Tensor((4, 8, 
2048), dtype="float32")
+            )
+        ) -> R.Tensor((4, 8, 32, 64), dtype="float32"):
+            cls = Module
+            with R.dataflow():
+                lv_1: R.Tensor((4, 8, 32, 64), dtype="float32") = 
cls.fused_reshape(tup)
+                R.output(lv_1)
+            return lv_1
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def fused_reshape(
+            lv_0: T.Buffer((T.int64(4), T.int64(8), T.int64(2048)), "float32"),
+            T_reshape_handle_intermediate: T.Buffer(
+                (T.int64(4), T.int64(8), T.int64(32), T.int64(64)), "float32"
+            ),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(8), 
T.int64(32), T.int64(64)):
+                with T.block("T_reshape"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
+                    T.reads(
+                        lv_0[
+                            (
+                                ((v_ax2 * T.int64(64) + v_ax3) // 
T.int64(2048) + v_ax1)
+                                // T.int64(8)
+                                + v_ax0
+                            )
+                            % T.int64(4),
+                            ((v_ax2 * T.int64(64) + v_ax3) // T.int64(2048) + 
v_ax1) % T.int64(8),
+                            (v_ax2 * T.int64(64) + v_ax3) % T.int64(2048),
+                        ]
+                    )
+                    T.writes(T_reshape_handle_intermediate[v_ax0, v_ax1, 
v_ax2, v_ax3])
+                    T_reshape_handle_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] 
= lv_0[
+                        (
+                            ((v_ax2 * T.int64(64) + v_ax3) // T.int64(2048) + 
v_ax1) // T.int64(8)
+                            + v_ax0
+                        )
+                        % T.int64(4),
+                        ((v_ax2 * T.int64(64) + v_ax3) // T.int64(2048) + 
v_ax1) % T.int64(8),
+                        (v_ax2 * T.int64(64) + v_ax3) % T.int64(2048),
+                    ]
+
+        @R.function
+        def main(
+            tup: R.Tuple(
+                R.Tensor((4, 8, 2048), dtype="float32"), R.Tensor((4, 8, 
2048), dtype="float32")
+            )
+        ) -> R.Tensor((4, 8, 32, 64), dtype="float32"):
+            cls = Expected
+            with R.dataflow():
+                lv: R.Tensor((4, 8, 2048), dtype="float32") = tup[0]
+                lv_1 = R.call_tir(
+                    cls.fused_reshape, (lv,), out_sinfo=R.Tensor((4, 8, 32, 
64), dtype="float32")
+                )
+                R.output(lv_1)
+            return lv_1
+
+    _check(Module, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to