This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new c68a00a37d [Unity][FuseTIR] Flatten and add tuple fields to parameters / arguments only when they are used (#15113)
c68a00a37d is described below
commit c68a00a37d5d483b54858fcd79f87472e0b1c147
Author: masahi <[email protected]>
AuthorDate: Sat Jun 17 16:41:24 2023 +0900
[Unity][FuseTIR] Flatten and add tuple fields to parameters / arguments only when they are used (#15113)
* wip
* wip
* works
* fix
* fixed
* comment
* more comment
* adding test
---
src/relax/transform/fuse_tir.cc | 86 +++++++++++++++++--
tests/python/relax/test_transform_fuse_tir.py | 119 ++++++++++++++++++++++++++
2 files changed, 199 insertions(+), 6 deletions(-)
diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index 601a4cff23..0f9168d062 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -22,6 +22,9 @@
#include <tvm/relax/transform.h>
#include <tvm/tir/stmt_functor.h>
+#include <unordered_map>
+#include <unordered_set>
+
#include "../../relay/analysis/graph_partitioner.h"
#include "../../support/arena.h"
#include "../../tir/ir/functor_common.h"
@@ -340,13 +343,51 @@ class FusedTIRConstructor : public ExprVisitor {
void VisitExpr_(const FunctionNode* func) final {
// Step 1. Create buffers for function params
+
+ // Record which fields in a tuple passed as a parameter are actually accessed by the function.
+ std::unordered_set<const Object*> tuple_param;
+ for (auto param : func->params) {
+ if (GetStructInfo(param)->IsInstance<TupleStructInfoNode>()) {
+ tuple_param.insert(param.get());
+ }
+ }
+
+ PostOrderVisit(func->body, [=, &tuple_param](Expr e) {
+ if (auto tup_get = e.as<TupleGetItemNode>(); tup_get && tuple_param.count(tup_get->tuple.get())) {
+ func_info_.used_tuple_field_indices[tup_get->tuple.get()].insert(tup_get->index);
+ }
+ });
+
for (const Var& relax_param : func->params) {
- if (GetStructInfo(relax_param)->IsInstance<ShapeStructInfoNode>()) {
+ auto sinfo = GetStructInfo(relax_param);
+ if (sinfo->IsInstance<ShapeStructInfoNode>()) {
// It's a symbolic shape var, no need to alloc Buffers.
continue;
}
- auto [params, buffers] = CreateParamsAndBuffers(GetStructInfo(relax_param), //
- relax_param->name_hint());
+
+ auto [params, buffers] = [=]() {
+ if (const auto* tuple = sinfo.as<TupleStructInfoNode>()) {
+ // Add only those tuple fields which are actually used by the function body into the
+ // function parameters.
+ int index = 0;
+ Array<tir::Var> params;
+ Array<tir::Buffer> buffers;
+ for (auto i : func_info_.used_tuple_field_indices[relax_param.get()]) {
+ auto [ret_params, ret_buffers] =
+ CreateParamsAndBuffers(tuple->fields[i], relax_param->name_hint(), index);
+ ICHECK_EQ(ret_params.size(), ret_buffers.size());
+ // Adding tuple field results to the end of params and buffers.
+ params.insert(params.end(), ret_params.begin(), ret_params.end());
+ buffers.insert(buffers.end(), ret_buffers.begin(), ret_buffers.end());
+ index += ret_params.size();
+ }
+ return std::make_pair(params, buffers);
+ } else {
+ return CreateParamsAndBuffers(sinfo, relax_param->name_hint());
+ }
+ }();
+
ICHECK_EQ(params.size(), buffers.size());
for (size_t i = 0; i < params.size(); ++i) {
func_info_.buffer_map.Set(params[i], buffers[i]);
@@ -468,7 +509,12 @@ class FusedTIRConstructor : public ExprVisitor {
int end_buf_idx = 0;
const TupleType& tuple_type = Downcast<TupleType>(tuple_get_item->tuple->checked_type());
for (int i = 0; i < tuple_get_item->index; ++i) {
- begin_buf_idx += GetTotalTensorSize(tuple_type->fields[i]);
+ auto it = func_info_.used_tuple_field_indices.find(tuple_get_item->tuple.get());
+ // If this tuple is not passed as a parameter, or if the field at the index i is actually
+ // used, the corresponding buffer needs to be taken into account by this function.
+ if (it == func_info_.used_tuple_field_indices.end() || it->second.count(i)) {
+ begin_buf_idx += GetTotalTensorSize(tuple_type->fields[i]);
+ }
}
end_buf_idx = begin_buf_idx + GetTotalTensorSize(tuple_type->fields[tuple_get_item->index]);
func_info_.expr2buffers.Set(
@@ -769,6 +815,8 @@ class FusedTIRConstructor : public ExprVisitor {
std::string global_name = "fused";
/*! \brief The map from symbolic var to its corresponding var in the fused function */
tir::SymbolicMatcher symbolic_var_matcher = tir::SymbolicMatcher(&symbolic_var_remap);
+ /*! \brief Record indices of tuple fields that are actually accessed. */
+ std::unordered_map<const Object*, std::unordered_set<size_t>> used_tuple_field_indices;
};
/*! \brief The IRModule */
@@ -781,6 +829,19 @@ class FusedTIRConstructor : public ExprVisitor {
tir::PrimFunc fused_tir_;
};
+std::vector<size_t> GetTupleAccessedIndices(const FunctionNode* func, const Var& tuple_var) {
+ // Need to be ordered
+ std::vector<size_t> indices;
+ PostOrderVisit(func->body, [&indices, tuple_var](Expr e) {
+ if (auto tup_get = e.as<TupleGetItemNode>(); tup_get && tup_get->tuple.same_as(tuple_var)) {
+ if (std::find(indices.begin(), indices.end(), tup_get->index) == indices.end()) {
+ indices.push_back(tup_get->index);
+ }
+ }
+ });
+ return indices;
+}
+
/*!
* \brief The helper class to fuse TIR functions and build a new module which calls the fused TIR.
*/
@@ -840,6 +901,7 @@ class TIRFuseMutator : public ExprMutator {
if (call->op->IsInstance<GlobalVarNode>()) {
// Case 1. It is a relax cross function call
GlobalVar old_gv = Downcast<GlobalVar>(call->op);
+ auto relax_func = Downcast<Function>(mod_->Lookup(old_gv));
auto it = fused_tir_funcs_.find(old_gv);
if (it != fused_tir_funcs_.end()) {
const tir::PrimFunc& fused_tir = (*it).second;
@@ -848,8 +910,20 @@ class TIRFuseMutator : public ExprMutator {
// Step a. Flatten all args since call_tir does not support Tuple value.
Array<Expr> arg_list;
Array<PrimExpr> tir_vars;
- for (const Expr& arg : call->args) {
- Array<Expr> flattened = FlattenArg(arg);
+ for (size_t i = 0; i < call->args.size(); ++i) {
+ auto arg = call->args[i];
+ Array<Expr> flattened;
+ if (GetStructInfo(relax_func->params[i])->IsInstance<TupleStructInfoNode>()) {
+ // Add only those tuple fields which are actually used by the function body
+ auto tup_get_indices = GetTupleAccessedIndices(relax_func.get(), relax_func->params[i]);
+ for (size_t tup_get_ind : tup_get_indices) {
+ auto flattened_inner = FlattenArg(builder_->Emit(TupleGetItem(arg, tup_get_ind)));
+ flattened.insert(flattened.end(), flattened_inner.begin(), flattened_inner.end());
+ }
+ } else {
+ flattened.push_back(arg);
+ }
+
for (const Expr& e : flattened) {
StructInfo sinfo = GetStructInfo(e);
if (sinfo->IsInstance<TensorStructInfoNode>()) {
diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py
index af770e0fc6..00dc714654 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -1079,5 +1079,124 @@ def test_tir_expression_in_shape():
_check(Module, Expected)
+def test_tuple_input_unused_field():
+ @I.ir_module
+ class Module:
+ @T.prim_func
+ def reshape(
+ A: T.Buffer((T.int64(4), T.int64(8), T.int64(2048)), "float32"),
+ T_reshape: T.Buffer((T.int64(4), T.int64(8), T.int64(32), T.int64(64)), "float32"),
+ ):
+ T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(8), T.int64(32), T.int64(64)):
+ with T.block("T_reshape"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(
+ A[
+ (
+ ((v_ax2 * T.int64(64) + v_ax3) // T.int64(2048) + v_ax1)
+ // T.int64(8)
+ + v_ax0
+ )
+ % T.int64(4),
+ ((v_ax2 * T.int64(64) + v_ax3) // T.int64(2048) + v_ax1) % T.int64(8),
+ (v_ax2 * T.int64(64) + v_ax3) % T.int64(2048),
+ ]
+ )
+ T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[
+ (
+ ((v_ax2 * T.int64(64) + v_ax3) // T.int64(2048) + v_ax1) // T.int64(8)
+ + v_ax0
+ )
+ % T.int64(4),
+ ((v_ax2 * T.int64(64) + v_ax3) // T.int64(2048) + v_ax1) % T.int64(8),
+ (v_ax2 * T.int64(64) + v_ax3) % T.int64(2048),
+ ]
+
+ @R.function
+ def fused_reshape(
+ lv: R.Tuple(
+ R.Tensor((4, 8, 2048), dtype="float32"), R.Tensor((4, 8, 2048), dtype="float32")
+ )
+ ) -> R.Tensor((4, 8, 32, 64), dtype="float32"):
+ R.func_attr({"Primitive": 1})
+ cls = Module
+ with R.dataflow():
+ lv1: R.Tensor((4, 8, 2048), dtype="float32") = lv[0]
+ gv = R.call_tir(
+ cls.reshape, (lv1,), out_sinfo=R.Tensor((4, 8, 32, 64), dtype="float32")
+ )
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(
+ tup: R.Tuple(
+ R.Tensor((4, 8, 2048), dtype="float32"), R.Tensor((4, 8, 2048), dtype="float32")
+ )
+ ) -> R.Tensor((4, 8, 32, 64), dtype="float32"):
+ cls = Module
+ with R.dataflow():
+ lv_1: R.Tensor((4, 8, 32, 64), dtype="float32") = cls.fused_reshape(tup)
+ R.output(lv_1)
+ return lv_1
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def fused_reshape(
+ lv_0: T.Buffer((T.int64(4), T.int64(8), T.int64(2048)), "float32"),
+ T_reshape_handle_intermediate: T.Buffer(
+ (T.int64(4), T.int64(8), T.int64(32), T.int64(64)), "float32"
+ ),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(8), T.int64(32), T.int64(64)):
+ with T.block("T_reshape"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(
+ lv_0[
+ (
+ ((v_ax2 * T.int64(64) + v_ax3) // T.int64(2048) + v_ax1)
+ // T.int64(8)
+ + v_ax0
+ )
+ % T.int64(4),
+ ((v_ax2 * T.int64(64) + v_ax3) // T.int64(2048) + v_ax1) % T.int64(8),
+ (v_ax2 * T.int64(64) + v_ax3) % T.int64(2048),
+ ]
+ )
+ T.writes(T_reshape_handle_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_reshape_handle_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv_0[
+ (
+ ((v_ax2 * T.int64(64) + v_ax3) // T.int64(2048) + v_ax1) // T.int64(8)
+ + v_ax0
+ )
+ % T.int64(4),
+ ((v_ax2 * T.int64(64) + v_ax3) // T.int64(2048) + v_ax1) % T.int64(8),
+ (v_ax2 * T.int64(64) + v_ax3) % T.int64(2048),
+ ]
+
+ @R.function
+ def main(
+ tup: R.Tuple(
+ R.Tensor((4, 8, 2048), dtype="float32"), R.Tensor((4, 8, 2048), dtype="float32")
+ )
+ ) -> R.Tensor((4, 8, 32, 64), dtype="float32"):
+ cls = Expected
+ with R.dataflow():
+ lv: R.Tensor((4, 8, 2048), dtype="float32") = tup[0]
+ lv_1 = R.call_tir(
+ cls.fused_reshape, (lv,), out_sinfo=R.Tensor((4, 8, 32, 64), dtype="float32")
+ )
+ R.output(lv_1)
+ return lv_1
+
+ _check(Module, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()