This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new b4fbac785b [Unity] Fix `DataflowReshapeRewrite` when input has multiple buffers from tuple (#14669)
b4fbac785b is described below
commit b4fbac785b79f4657a5c3521ab7c0b0106f942e7
Author: masahi <[email protected]>
AuthorDate: Fri Apr 21 09:57:07 2023 +0900
[Unity] Fix `DataflowReshapeRewrite` when input has multiple buffers from tuple (#14669)
* use split instead of slice in CombineParallelMatmul
* add test
* wip
* Fix DataflowReshapeRewrite when input has multiple buffers from tuple
* Revert "use split instead of slice in CombineParallelMatmul"
This reverts commit 901fee93112d9e05f703beeab6b8151621ca1373.
---
src/relax/transform/rewrite_dataflow_reshape.cc | 30 ++++-
.../test_transform_rewrite_dataflow_reshape.py | 127 ++++++++++++++++++++-
2 files changed, 152 insertions(+), 5 deletions(-)
diff --git a/src/relax/transform/rewrite_dataflow_reshape.cc b/src/relax/transform/rewrite_dataflow_reshape.cc
index e5d654fba3..daef7f7409 100644
--- a/src/relax/transform/rewrite_dataflow_reshape.cc
+++ b/src/relax/transform/rewrite_dataflow_reshape.cc
@@ -23,12 +23,25 @@
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
#include "../op/tensor/manipulate.h"
namespace tvm {
namespace relax {
+std::vector<size_t> GetUsedArgsIndices(const tir::PrimFunc& fn, size_t num_args) {
+ std::vector<size_t> indices;
+ for (size_t i = 0; i < num_args; ++i) {
+ auto buffer_var = fn->buffer_map[fn->params[i]]->data;
+    if (tir::UsesVar(fn->body, [=](const tir::VarNode* var) { return var == buffer_var.get(); })) {
+ indices.push_back(i);
+ }
+ }
+ return indices;
+}
+
class DataflowReshapeRewriter : public ExprMutator {
public:
explicit DataflowReshapeRewriter(const IRModule& mod) : mod_(mod) {}
@@ -63,11 +76,22 @@ class DataflowReshapeRewriter : public ExprMutator {
// We bring the calls of reshape PrimFunc back to calls of high-level
// relax.reshape op, which will be lowered to calls of the ExternFunc
// vm.builtin.reshape in the VMBuiltinLower pass.
- Array<Expr> args = Downcast<Tuple>(call->args[1])->fields;
- ICHECK_EQ(args.size(), 1);
+
+    auto prim_fn = Downcast<tir::PrimFunc>(mod_->Lookup(Downcast<GlobalVar>(call->args[0])));
+ auto arg_tuple = Downcast<Tuple>(call->args[1])->fields;
+ auto used_arg_indices = GetUsedArgsIndices(prim_fn, arg_tuple.size());
+
+    // The number of inputs to call_tir(reshape, (...)) might not be one, since FuseOps
+    // can generate a fused TupleGetItem + reshape function whose input is a tuple. FuseTIR
+    // then flattens the tuple input so that the fused TIR reshape function ends up having
+    // multiple input buffers. But only one of them should be accessed and reshaped.
+ ICHECK_EQ(used_arg_indices.size(), 1);
+
+ auto arg = arg_tuple[used_arg_indices[0]];
+
   TensorStructInfo res_sinfo = Downcast<TensorStructInfo>(call->struct_info_);
ICHECK(res_sinfo->shape.defined());
- return reshape(args[0], res_sinfo->shape.value());
+ return reshape(arg, res_sinfo->shape.value());
}
bool IsCallingTIRReshape(const CallNode* call) {
diff --git a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
index ecf2a96064..db737f82c0 100644
--- a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
+++ b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
@@ -257,6 +257,129 @@ def test_reshape_non_dataflow():
tvm.ir.assert_structural_equal(mod, Module)
+def test_tuple_get_reshape():
+ @tvm.script.ir_module
+ class Module:
+ @T.prim_func
+ def fused_reshape5(
+ lv2_0: T.Buffer((T.int64(2), T.int64(4096), T.int64(320)),
"float16"),
+ lv2_1: T.Buffer((T.int64(2), T.int64(4096), T.int64(320)),
"float16"),
+ lv2_2: T.Buffer((T.int64(2), T.int64(4096), T.int64(320)),
"float16"),
+ T_reshape_handle_intermediate: T.Buffer(
+ (T.int64(2), T.int64(4096), T.int64(8), T.int64(40)), "float16"
+ ),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4096),
T.int64(8), T.int64(40)):
+ with T.block("T_reshape"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(
+ lv2_0[
+ (
+ ((v_ax2 * T.int64(40) + v_ax3) // T.int64(320)
+ v_ax1)
+ // T.int64(4096)
+ + v_ax0
+ )
+ % T.int64(2),
+ ((v_ax2 * T.int64(40) + v_ax3) // T.int64(320) +
v_ax1) % T.int64(4096),
+ (v_ax2 * T.int64(40) + v_ax3) % T.int64(320),
+ ]
+ )
+ T.writes(T_reshape_handle_intermediate[v_ax0, v_ax1,
v_ax2, v_ax3])
+ T_reshape_handle_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]
= lv2_0[
+ (
+ ((v_ax2 * T.int64(40) + v_ax3) // T.int64(320) +
v_ax1) // T.int64(4096)
+ + v_ax0
+ )
+ % T.int64(2),
+ ((v_ax2 * T.int64(40) + v_ax3) // T.int64(320) +
v_ax1) % T.int64(4096),
+ (v_ax2 * T.int64(40) + v_ax3) % T.int64(320),
+ ]
+
+ @R.function
+ def main(
+ lv41_1: R.Tuple(
+ R.Tensor((2, 4096, 320), dtype="float16"),
+ R.Tensor((2, 4096, 320), dtype="float16"),
+ R.Tensor((2, 4096, 320), dtype="float16"),
+ )
+ ) -> R.Tensor((2, 4096, 8, 40), dtype="float16"):
+ cls = Module
+ with R.dataflow():
+ lv: R.Tensor((2, 4096, 320), dtype="float16") = lv41_1[0]
+ lv1: R.Tensor((2, 4096, 320), dtype="float16") = lv41_1[1]
+ lv2: R.Tensor((2, 4096, 320), dtype="float16") = lv41_1[2]
+ lv645 = R.call_tir(
+ cls.fused_reshape5,
+ (lv, lv1, lv2),
+ out_sinfo=R.Tensor((2, 4096, 8, 40), dtype="float16"),
+ )
+ out: R.Tensor((2, 4096, 8, 40), dtype="float16") =
R.add(lv645, lv645)
+ R.output(out)
+ return out
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def fused_reshape5(
+ lv2_0: T.Buffer((T.int64(2), T.int64(4096), T.int64(320)),
"float16"),
+ lv2_1: T.Buffer((T.int64(2), T.int64(4096), T.int64(320)),
"float16"),
+ lv2_2: T.Buffer((T.int64(2), T.int64(4096), T.int64(320)),
"float16"),
+ T_reshape_handle_intermediate: T.Buffer(
+ (T.int64(2), T.int64(4096), T.int64(8), T.int64(40)), "float16"
+ ),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4096),
T.int64(8), T.int64(40)):
+ with T.block("T_reshape"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(
+ lv2_0[
+ (
+ ((v_ax2 * T.int64(40) + v_ax3) // T.int64(320)
+ v_ax1)
+ // T.int64(4096)
+ + v_ax0
+ )
+ % T.int64(2),
+ ((v_ax2 * T.int64(40) + v_ax3) // T.int64(320) +
v_ax1) % T.int64(4096),
+ (v_ax2 * T.int64(40) + v_ax3) % T.int64(320),
+ ]
+ )
+ T.writes(T_reshape_handle_intermediate[v_ax0, v_ax1,
v_ax2, v_ax3])
+ T_reshape_handle_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]
= lv2_0[
+ (
+ ((v_ax2 * T.int64(40) + v_ax3) // T.int64(320) +
v_ax1) // T.int64(4096)
+ + v_ax0
+ )
+ % T.int64(2),
+ ((v_ax2 * T.int64(40) + v_ax3) // T.int64(320) +
v_ax1) % T.int64(4096),
+ (v_ax2 * T.int64(40) + v_ax3) % T.int64(320),
+ ]
+
+ @R.function
+ def main(
+ lv41_1: R.Tuple(
+ R.Tensor((2, 4096, 320), dtype="float16"),
+ R.Tensor((2, 4096, 320), dtype="float16"),
+ R.Tensor((2, 4096, 320), dtype="float16"),
+ )
+ ) -> R.Tensor((2, 4096, 8, 40), dtype="float16"):
+ with R.dataflow():
+ lv: R.Tensor((2, 4096, 320), dtype="float16") = lv41_1[0]
+ lv1: R.Tensor((2, 4096, 320), dtype="float16") = lv41_1[1]
+ lv2: R.Tensor((2, 4096, 320), dtype="float16") = lv41_1[2]
+ lv645: R.Tensor((2, 4096, 8, 40), dtype="float16") = R.reshape(
+ lv, R.shape([2, 4096, 8, 40])
+ )
+ out: R.Tensor((2, 4096, 8, 40), dtype="float16") =
R.add(lv645, lv645)
+ R.output(out)
+ return out
+
+ rewritten = relax.transform.RewriteDataflowReshape()(Module)
+ tvm.ir.assert_structural_equal(rewritten, Expected)
+
+
if __name__ == "__main__":
- test_reshape_pattern_detect()
- # tvm.testing.main()
+ tvm.testing.main()