This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 61fbf4262d [Unity] Fix FuseTIR when the same buffer is read multiple 
times with different access pattern (#14603)
61fbf4262d is described below

commit 61fbf4262d1729a97174fb118f05852fd8afc291
Author: masahi <[email protected]>
AuthorDate: Wed Apr 12 23:18:01 2023 +0900

    [Unity] Fix FuseTIR when the same buffer is read multiple times with 
different access pattern (#14603)
    
    When the same buffer is read multiple times with different access patterns 
in a single expression, the structural-equality check (`ICHECK` in 
`UnionAccessRegion`, removed by this commit) fails.
    But this case should be allowed. For example, in the following subgraph, 
inp_0 is used twice in different read regions; see the test case for details. 
This subgraph arises if we run ConvertLayout on the SD UNet from 
web-stable-diffusion.
---
 src/relax/transform/fuse_tir.cc               |  43 +++------
 tests/python/relax/test_transform_fuse_tir.py | 120 ++++++++++++++++++++++++++
 2 files changed, 133 insertions(+), 30 deletions(-)

diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index 183d395b60..432ddca0a7 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -117,15 +117,13 @@ class SymbolicMatcher : ExprFunctor<bool(const PrimExpr& 
n, const PrimExpr& othe
 /*!
  * \brief Substitute a given source buffer with a given target buffer in 
statements or expressions.
  */
-class FuseTIRBufferSubstitor : private StmtExprMutator {
+class FuseTIRBufferSubstitutor : private StmtExprMutator {
  public:
-  explicit FuseTIRBufferSubstitor(const Map<Buffer, Buffer>& buffer_map,
-                                  const Map<Var, Var>& var_map) {
+  explicit FuseTIRBufferSubstitutor(const Map<Buffer, Buffer>& buffer_map,
+                                    const Map<Var, Var>& var_map) {
     buffer_remap_ = buffer_map;
     var_remap_ = var_map;
-    for (const auto& kv : buffer_map) {
-      const Buffer& src = kv.first;
-      const Buffer& tgt = kv.second;
+    for (const auto& [src, tgt] : buffer_map) {
       var_remap_.Set(src->data, tgt->data);
     }
   }
@@ -246,8 +244,6 @@ class FuseTIRBufferSubstitor : private StmtExprMutator {
   Map<tir::Buffer, tir::Buffer> buffer_remap_;
   /*! \brief Mapping from src tir var to tgt var. */
   Map<tir::Var, tir::Var> var_remap_;
-  /*! \brief The structural equality checker */
-  StructuralEqual structural_equal_;
 
   Array<tir::BufferRegion> UnionAccessRegion(const Array<BufferRegion>& 
regions) const {
     // For now we only allow Buffer access the same elements.
@@ -262,8 +258,6 @@ class FuseTIRBufferSubstitor : private StmtExprMutator {
       if (it == buffer_region_set.end()) {
         ret.push_back(region);
         buffer_region_set[region->buffer.get()] = region->region;
-      } else {
-        ICHECK(structural_equal_(region->region, it->second));
       }
     }
 
@@ -351,10 +345,8 @@ class FusedTIRConstructor : public ExprVisitor {
         // It's a symbolic shape var, no need to alloc Buffers.
         continue;
       }
-      auto ret = CreateParamsAndBuffers(GetStructInfo(relax_param),  //
-                                        relax_param->name_hint());
-      const Array<tir::Var>& params = ret.first;
-      const Array<tir::Buffer>& buffers = ret.second;
+      auto [params, buffers] = 
CreateParamsAndBuffers(GetStructInfo(relax_param),  //
+                                                      
relax_param->name_hint());
       ICHECK_EQ(params.size(), buffers.size());
       for (size_t i = 0; i < params.size(); ++i) {
         func_info_.buffer_map.Set(params[i], buffers[i]);
@@ -384,10 +376,8 @@ class FusedTIRConstructor : public ExprVisitor {
     // Step 4. Append symbolic vars
     const relax::Var& last_relax_param = func->params.back();
     if (GetStructInfo(last_relax_param)->IsInstance<ShapeStructInfoNode>()) {
-      auto ret =
+      auto [params, buffers] =
           CreateParamsAndBuffers(GetStructInfo(last_relax_param), 
last_relax_param->name_hint());
-      const Array<tir::Var>& params = ret.first;
-      const Array<tir::Buffer>& buffers = ret.second;
       ICHECK(buffers.empty());
       for (size_t i = 0; i < params.size(); ++i) {
         func_info_.params.push_back(params[i]);
@@ -682,9 +672,7 @@ class FusedTIRConstructor : public ExprVisitor {
              "list.";
       if (index == -1) index = 0;
       for (size_t i = 0; i < tuple->fields.size(); ++i) {
-        auto ret = CreateParamsAndBuffers(tuple->fields[i], name_hint, index);
-        const Array<tir::Var>& ret_params = ret.first;
-        const Array<tir::Buffer>& ret_buffers = ret.second;
+        auto [ret_params, ret_buffers] = 
CreateParamsAndBuffers(tuple->fields[i], name_hint, index);
         ICHECK_EQ(ret_params.size(), ret_buffers.size());
         // Adding tuple field results to the end of params and buffers.
         params.insert(params.end(), ret_params.begin(), ret_params.end());
@@ -714,19 +702,18 @@ class FusedTIRConstructor : public ExprVisitor {
   tir::PrimFunc ConstructFunc() {
     Map<String, ObjectRef> attr_map;
     attr_map.Set("tir.noalias", tir::const_true());
-    tir::FuseTIRBufferSubstitor substitor(func_info_.buffer_subst_map,
-                                          func_info_.symbolic_var_remap);
+    tir::FuseTIRBufferSubstitutor subst(func_info_.buffer_subst_map, 
func_info_.symbolic_var_remap);
     ICHECK(func_info_.global_name != "fused");
     // Remove output buffers from func_info_.alloc_buffers
     Array<tir::Buffer> alloc_buffers;
     for (const tir::Buffer& buf : func_info_.alloc_buffers) {
       if (func_info_.output_buffers.count(buf.get()) == 0) {
-        alloc_buffers.push_back(substitor.SubstituteAllocatedBuffer(buf));
+        alloc_buffers.push_back(subst.SubstituteAllocatedBuffer(buf));
       }
     }
     tir::Stmt body = 
tir::BlockNameDeduplicator()(tir::SeqStmt::Flatten(func_info_.bodies));
 
-    body = substitor.Substitute(body);
+    body = subst.Substitute(body);
     body = tir::Block({}, {}, {}, "root", std::move(body), NullOpt, 
alloc_buffers);
     body = tir::BlockRealize({}, Bool(true), Downcast<tir::Block>(body));
     tir::PrimFunc func(func_info_.params, body, VoidType(), 
func_info_.buffer_map,
@@ -804,9 +791,7 @@ class TIRFuseMutator : public ExprMutator {
     // Since TIRFuseMutator will delete bunch of PrimFunc, we create an empty 
block builder.
     TIRFuseMutator mutator(mod);
     // Step 1. Fuse all primitive relax functions, store the result in 
`fused_tir_funcs_`
-    for (const auto& kv : mod->functions) {
-      const GlobalVar& gv = kv.first;
-      const BaseFunc& func = kv.second;
+    for (const auto& [gv, func] : mod->functions) {
       // Only fuse primitive relax functions
       if (func->IsInstance<relax::FunctionNode>() && 
func->HasNonzeroAttr(attr::kPrimitive)) {
         tir::PrimFunc fused_tir = FusedTIRConstructor::GetFusedTIR(mod, gv);
@@ -816,9 +801,7 @@ class TIRFuseMutator : public ExprMutator {
 
     // Step 2. Update all non-primitive relax functions and add it, with the 
dependent function,
     // into the new IRModule
-    for (const auto& kv : mod->functions) {
-      const GlobalVar& gv = kv.first;
-      const BaseFunc& func = kv.second;
+    for (const auto& [gv, func] : mod->functions) {
       if (func->IsInstance<relax::FunctionNode>() && 
!func->HasNonzeroAttr(attr::kPrimitive)) {
         relax::Function update_func = 
Downcast<Function>(mutator.VisitExpr(func));
         mutator.builder_->AddFunction(update_func, gv->name_hint);
diff --git a/tests/python/relax/test_transform_fuse_tir.py 
b/tests/python/relax/test_transform_fuse_tir.py
index 8b856d3cc5..c7aa7984be 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -883,5 +883,125 @@ def test_symbolic_var_in_call_tir_args():
     _check(Before, Expected)
 
 
+def test_same_buffer_multiple_read():
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def concatenate(
+            rxplaceholder: T.Buffer((T.int64(1), T.int64(4), T.int64(64), 
T.int64(64)), "float32"),
+            rxplaceholder_1: T.Buffer(
+                (T.int64(1), T.int64(4), T.int64(64), T.int64(64)), "float32"
+            ),
+            T_concat: T.Buffer((T.int64(2), T.int64(4), T.int64(64), 
T.int64(64)), "float32"),
+        ):
+            T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4), 
T.int64(64), T.int64(64)):
+                with T.block("T_concat"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
+                    T.reads(
+                        rxplaceholder_1[v_ax0 - T.int64(1), v_ax1, v_ax2, 
v_ax3],
+                        rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3],
+                    )
+                    T.writes(T_concat[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_concat[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else(
+                        T.int64(1) <= v_ax0,
+                        rxplaceholder_1[v_ax0 - T.int64(1), v_ax1, v_ax2, 
v_ax3],
+                        rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3],
+                    )
+
+        @T.prim_func
+        def transpose2(
+            rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(64), 
T.int64(64)), "float32"),
+            T_transpose: T.Buffer((T.int64(2), T.int64(64), T.int64(64), 
T.int64(4)), "float32"),
+        ):
+            T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(64), 
T.int64(64), T.int64(4)):
+                with T.block("T_transpose"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
+                    T.reads(rxplaceholder[v_ax0, v_ax3, v_ax1, v_ax2])
+                    T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[
+                        v_ax0, v_ax3, v_ax1, v_ax2
+                    ]
+
+        @R.function
+        def fused_concatenate_transpose2(
+            inp_0: R.Tensor((1, 4, 64, 64), dtype="float32")
+        ) -> R.Tensor((2, 64, 64, 4), dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            cls = Module
+            with R.dataflow():
+                lv = R.call_tir(
+                    cls.concatenate,
+                    (inp_0, inp_0),
+                    out_sinfo=R.Tensor((2, 4, 64, 64), dtype="float32"),
+                )
+                gv = R.call_tir(
+                    cls.transpose2, (lv,), out_sinfo=R.Tensor((2, 64, 64, 4), 
dtype="float32")
+                )
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 4, 64, 64), dtype="float32")
+        ) -> R.Tensor((2, 64, 64, 4), dtype="float32"):
+            R.func_attr({"num_input": 3})
+            cls = Module
+            with R.dataflow():
+                lv = cls.fused_concatenate_transpose2(inp_0)
+                R.output(lv)
+            return lv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def fused_concatenate_transpose2(
+            inp_0: T.Buffer((T.int64(1), T.int64(4), T.int64(64), 
T.int64(64)), "float32"),
+            T_transpose_handle_intermediate: T.Buffer(
+                (T.int64(2), T.int64(64), T.int64(64), T.int64(4)), "float32"
+            ),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            T_concat_handle_intermediate = T.alloc_buffer(
+                (T.int64(2), T.int64(4), T.int64(64), T.int64(64))
+            )
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4), 
T.int64(64), T.int64(64)):
+                with T.block("T_concat"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
+                    T.reads(inp_0[v_ax0 - T.int64(1), v_ax1, v_ax2, v_ax3])
+                    T.writes(T_concat_handle_intermediate[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                    T_concat_handle_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = 
T.if_then_else(
+                        T.int64(1) <= v_ax0,
+                        inp_0[v_ax0 - T.int64(1), v_ax1, v_ax2, v_ax3],
+                        inp_0[v_ax0, v_ax1, v_ax2, v_ax3],
+                    )
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(64), 
T.int64(64), T.int64(4)):
+                with T.block("T_transpose"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
+                    T.reads(T_concat_handle_intermediate[v_ax0, v_ax3, v_ax1, 
v_ax2])
+                    T.writes(T_transpose_handle_intermediate[v_ax0, v_ax1, 
v_ax2, v_ax3])
+                    T_transpose_handle_intermediate[
+                        v_ax0, v_ax1, v_ax2, v_ax3
+                    ] = T_concat_handle_intermediate[v_ax0, v_ax3, v_ax1, 
v_ax2]
+
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 4, 64, 64), dtype="float32")
+        ) -> R.Tensor((2, 64, 64, 4), dtype="float32"):
+            R.func_attr({"num_input": 3})
+            cls = Expected
+            with R.dataflow():
+                lv = R.call_tir(
+                    cls.fused_concatenate_transpose2,
+                    (inp_0,),
+                    out_sinfo=R.Tensor((2, 64, 64, 4), dtype="float32"),
+                )
+                R.output(lv)
+            return lv
+
+    _check(Module, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to