This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new b4fbac785b [Unity] Fix `DataflowReshapeRewrite` when input has multiple buffers from tuple  (#14669)
b4fbac785b is described below

commit b4fbac785b79f4657a5c3521ab7c0b0106f942e7
Author: masahi <[email protected]>
AuthorDate: Fri Apr 21 09:57:07 2023 +0900

    [Unity] Fix `DataflowReshapeRewrite` when input has multiple buffers from tuple  (#14669)
    
    * use split instead of slice in CombineParallelMatmul
    
    * add test
    
    * wip
    
    * Fix DataflowReshapeRewrite when input has multiple buffers from tuple
    
    * Revert "use split instead of slice in CombineParallelMatmul"
    
    This reverts commit 901fee93112d9e05f703beeab6b8151621ca1373.
---
 src/relax/transform/rewrite_dataflow_reshape.cc    |  30 ++++-
 .../test_transform_rewrite_dataflow_reshape.py     | 127 ++++++++++++++++++++-
 2 files changed, 152 insertions(+), 5 deletions(-)

diff --git a/src/relax/transform/rewrite_dataflow_reshape.cc b/src/relax/transform/rewrite_dataflow_reshape.cc
index e5d654fba3..daef7f7409 100644
--- a/src/relax/transform/rewrite_dataflow_reshape.cc
+++ b/src/relax/transform/rewrite_dataflow_reshape.cc
@@ -23,12 +23,25 @@
 #include <tvm/relax/analysis.h>
 #include <tvm/relax/expr_functor.h>
 #include <tvm/relax/transform.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
 
 #include "../op/tensor/manipulate.h"
 
 namespace tvm {
 namespace relax {
 
+std::vector<size_t> GetUsedArgsIndices(const tir::PrimFunc& fn, size_t num_args) {
+  std::vector<size_t> indices;
+  for (size_t i = 0; i < num_args; ++i) {
+    auto buffer_var = fn->buffer_map[fn->params[i]]->data;
+    if (tir::UsesVar(fn->body, [=](const tir::VarNode* var) { return var == buffer_var.get(); })) {
+      indices.push_back(i);
+    }
+  }
+  return indices;
+}
+
 class DataflowReshapeRewriter : public ExprMutator {
  public:
   explicit DataflowReshapeRewriter(const IRModule& mod) : mod_(mod) {}
@@ -63,11 +76,22 @@ class DataflowReshapeRewriter : public ExprMutator {
     // We bring the calls of reshape PrimFunc back to calls of high-level
     // relax.reshape op, which will be lowered to calls of the ExternFunc
     // vm.builtin.reshape in the VMBuiltinLower pass.
-    Array<Expr> args = Downcast<Tuple>(call->args[1])->fields;
-    ICHECK_EQ(args.size(), 1);
+
+    auto prim_fn = 
Downcast<tir::PrimFunc>(mod_->Lookup(Downcast<GlobalVar>(call->args[0])));
+    auto arg_tuple = Downcast<Tuple>(call->args[1])->fields;
+    auto used_arg_indices = GetUsedArgsIndices(prim_fn, arg_tuple.size());
+
+    // The number of inputs to call_tir(reshape, (...)) might not be one, 
since FuseOps
+    // can generate a fused TupleGetItem + reshape function whose input is a 
tuple. FuseTIR
+    // then flattens the tuple input so that the fused TIR reshape function 
ends up having
+    // multiple input buffers. But only one of them should be accessed and 
reshaped.
+    ICHECK_EQ(used_arg_indices.size(), 1);
+
+    auto arg = arg_tuple[used_arg_indices[0]];
+
     TensorStructInfo res_sinfo = 
Downcast<TensorStructInfo>(call->struct_info_);
     ICHECK(res_sinfo->shape.defined());
-    return reshape(args[0], res_sinfo->shape.value());
+    return reshape(arg, res_sinfo->shape.value());
   }
 
   bool IsCallingTIRReshape(const CallNode* call) {
diff --git a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
index ecf2a96064..db737f82c0 100644
--- a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
+++ b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
@@ -257,6 +257,129 @@ def test_reshape_non_dataflow():
     tvm.ir.assert_structural_equal(mod, Module)
 
 
+def test_tuple_get_reshape():
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def fused_reshape5(
+            lv2_0: T.Buffer((T.int64(2), T.int64(4096), T.int64(320)), "float16"),
+            lv2_1: T.Buffer((T.int64(2), T.int64(4096), T.int64(320)), "float16"),
+            lv2_2: T.Buffer((T.int64(2), T.int64(4096), T.int64(320)), "float16"),
+            T_reshape_handle_intermediate: T.Buffer(
+                (T.int64(2), T.int64(4096), T.int64(8), T.int64(40)), "float16"
+            ),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4096), T.int64(8), T.int64(40)):
+                with T.block("T_reshape"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(
+                        lv2_0[
+                            (
+                                ((v_ax2 * T.int64(40) + v_ax3) // T.int64(320) + v_ax1)
+                                // T.int64(4096)
+                                + v_ax0
+                            )
+                            % T.int64(2),
+                            ((v_ax2 * T.int64(40) + v_ax3) // T.int64(320) + v_ax1) % T.int64(4096),
+                            (v_ax2 * T.int64(40) + v_ax3) % T.int64(320),
+                        ]
+                    )
+                    T.writes(T_reshape_handle_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_reshape_handle_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv2_0[
+                        (
+                            ((v_ax2 * T.int64(40) + v_ax3) // T.int64(320) + v_ax1) // T.int64(4096)
+                            + v_ax0
+                        )
+                        % T.int64(2),
+                        ((v_ax2 * T.int64(40) + v_ax3) // T.int64(320) + v_ax1) % T.int64(4096),
+                        (v_ax2 * T.int64(40) + v_ax3) % T.int64(320),
+                    ]
+
+        @R.function
+        def main(
+            lv41_1: R.Tuple(
+                R.Tensor((2, 4096, 320), dtype="float16"),
+                R.Tensor((2, 4096, 320), dtype="float16"),
+                R.Tensor((2, 4096, 320), dtype="float16"),
+            )
+        ) -> R.Tensor((2, 4096, 8, 40), dtype="float16"):
+            cls = Module
+            with R.dataflow():
+                lv: R.Tensor((2, 4096, 320), dtype="float16") = lv41_1[0]
+                lv1: R.Tensor((2, 4096, 320), dtype="float16") = lv41_1[1]
+                lv2: R.Tensor((2, 4096, 320), dtype="float16") = lv41_1[2]
+                lv645 = R.call_tir(
+                    cls.fused_reshape5,
+                    (lv, lv1, lv2),
+                    out_sinfo=R.Tensor((2, 4096, 8, 40), dtype="float16"),
+                )
+                out: R.Tensor((2, 4096, 8, 40), dtype="float16") = R.add(lv645, lv645)
+                R.output(out)
+            return out
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def fused_reshape5(
+            lv2_0: T.Buffer((T.int64(2), T.int64(4096), T.int64(320)), "float16"),
+            lv2_1: T.Buffer((T.int64(2), T.int64(4096), T.int64(320)), "float16"),
+            lv2_2: T.Buffer((T.int64(2), T.int64(4096), T.int64(320)), "float16"),
+            T_reshape_handle_intermediate: T.Buffer(
+                (T.int64(2), T.int64(4096), T.int64(8), T.int64(40)), "float16"
+            ),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4096), T.int64(8), T.int64(40)):
+                with T.block("T_reshape"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(
+                        lv2_0[
+                            (
+                                ((v_ax2 * T.int64(40) + v_ax3) // T.int64(320) + v_ax1)
+                                // T.int64(4096)
+                                + v_ax0
+                            )
+                            % T.int64(2),
+                            ((v_ax2 * T.int64(40) + v_ax3) // T.int64(320) + v_ax1) % T.int64(4096),
+                            (v_ax2 * T.int64(40) + v_ax3) % T.int64(320),
+                        ]
+                    )
+                    T.writes(T_reshape_handle_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_reshape_handle_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv2_0[
+                        (
+                            ((v_ax2 * T.int64(40) + v_ax3) // T.int64(320) + v_ax1) // T.int64(4096)
+                            + v_ax0
+                        )
+                        % T.int64(2),
+                        ((v_ax2 * T.int64(40) + v_ax3) // T.int64(320) + v_ax1) % T.int64(4096),
+                        (v_ax2 * T.int64(40) + v_ax3) % T.int64(320),
+                    ]
+
+        @R.function
+        def main(
+            lv41_1: R.Tuple(
+                R.Tensor((2, 4096, 320), dtype="float16"),
+                R.Tensor((2, 4096, 320), dtype="float16"),
+                R.Tensor((2, 4096, 320), dtype="float16"),
+            )
+        ) -> R.Tensor((2, 4096, 8, 40), dtype="float16"):
+            with R.dataflow():
+                lv: R.Tensor((2, 4096, 320), dtype="float16") = lv41_1[0]
+                lv1: R.Tensor((2, 4096, 320), dtype="float16") = lv41_1[1]
+                lv2: R.Tensor((2, 4096, 320), dtype="float16") = lv41_1[2]
+                lv645: R.Tensor((2, 4096, 8, 40), dtype="float16") = R.reshape(
+                    lv, R.shape([2, 4096, 8, 40])
+                )
+                out: R.Tensor((2, 4096, 8, 40), dtype="float16") = R.add(lv645, lv645)
+                R.output(out)
+            return out
+
+    rewritten = relax.transform.RewriteDataflowReshape()(Module)
+    tvm.ir.assert_structural_equal(rewritten, Expected)
+
+
 if __name__ == "__main__":
-    test_reshape_pattern_detect()
-    # tvm.testing.main()
+    tvm.testing.main()

Reply via email to