nverke commented on code in PR #13301:
URL: https://github.com/apache/tvm/pull/13301#discussion_r1018349185


##########
src/tir/schedule/analysis/reducer.cc:
##########
@@ -572,9 +572,25 @@ bool ReductionIterNotIndexOutputBuffer(const Block& block) {
     if (!store) {
       return true;
     }
-    ICHECK(buffer_written.count(store->buffer.get()))
-        << "ValueError: The buffer \"" << store->buffer
-        << "\" is written in the block but is not in the block's signature";
+    const auto* body_block = block->body.as<BlockRealizeNode>();

Review Comment:
   Hmm, I am not sure I understand. Are you talking about a situation like this?
   ```
   from tvm.script import tir as T

   @T.prim_func
   def nested_reduction_loop_with_match_buffers(
       in0: T.Buffer[(4, 4, 4), "int8"],
       in1: T.Buffer[(4, 4, 4), "int8"],
       out: T.Buffer[(4, 4, 4), "int8"],
   ) -> None:
       # body
       # with T.block("root")
       for y in T.serial(4):
           with T.block("C"):
               T.reads(in0[y, 0:4, 0:4], in1[y, 0:4, 0:4])
               T.writes(out[y, 0:4, 0:4])
               for x in T.serial(4):
                   with T.block("C"):
                       T.reads(in0[y, x, 0:4], in1[y, x, 0:4])
                       T.writes(out[y, x, 0:4])
                       A = T.match_buffer(in0[y, x, 0:4], [4], dtype="int8", offset_factor=1)
                       B = T.match_buffer(in1[y, x, 0:4], [4], dtype="int8", offset_factor=1)
                       C = T.match_buffer(out[y, x, 0:4], [4], dtype="int8", offset_factor=1)
                       A_i8x4: T.int8x4 = A[0:4]
                       A_i32: T.int32 = T.reinterpret(A_i8x4, dtype="int32")
                       B_i8x4: T.int8x4 = B[0:4]
                       B_i32: T.int32 = T.reinterpret(B_i8x4, dtype="int32")
                       C[0:4] = T.reinterpret(A_i32 + B_i32, dtype="int8x4")
   ```
   My understanding is that this check operates at the level of the store statement and has already collected all of the write regions for the parent loops, so just adding the regions that the match buffers reference should be enough.
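For illustration, the rule described above can be modeled in a few lines of plain Python. This is a toy sketch, not TVM's C++ implementation: `MatchBuffer`, `store_is_allowed`, and the string buffer names are hypothetical. A store is accepted if its buffer is in the block's write set, or if it is a match buffer whose source buffer is in that set.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class MatchBuffer:
    buffer: str  # hypothetical: name of the buffer declared by T.match_buffer, e.g. "C"
    source: str  # hypothetical: name of the buffer whose region it views, e.g. "out"


def store_is_allowed(store_buffer: str, writes: List[str],
                     match_buffers: List[MatchBuffer]) -> bool:
    """Toy version of the check: is `store_buffer` covered by the writes?"""
    written = set(writes)
    for mb in match_buffers:
        # A store through a match buffer writes a region of its source buffer,
        # so accept the match buffer when its source is already being written.
        if mb.source in written:
            written.add(mb.buffer)
    return store_buffer in written


# Mirrors the first example: only `out` is written; A, B, C are match buffers.
mbs = [MatchBuffer("A", "in0"), MatchBuffer("B", "in1"), MatchBuffer("C", "out")]
print(store_is_allowed("C", ["out"], mbs))  # True
print(store_is_allowed("C", ["out"], []))   # False: match buffers ignored
```

Ignoring the match buffers reproduces the situation where the original `ICHECK` fires, because the store targets `C` rather than `out`.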
   
   Alternatively, are you referring to something like this?
   ```
   from tvm.script import tir as T

   @T.prim_func
   def nested_reduction_loop_with_match_buffers(
       in0: T.Buffer[(4, 4, 4), "int8"],
       in1: T.Buffer[(4, 4, 4), "int8"],
       out: T.Buffer[(4, 4, 4), "int8"],
   ) -> None:
       # body
       # with T.block("root")
       for y in T.serial(4):
           with T.block("C"):
               T.reads(in0[y, 0:4, 0:4], in1[y, 0:4, 0:4])
               T.writes(out[y, 0:4, 0:4])
               A = T.match_buffer(in0[y, 0:4, 0:4], [4, 4], dtype="int8", offset_factor=1)
               B = T.match_buffer(in1[y, 0:4, 0:4], [4, 4], dtype="int8", offset_factor=1)
               C = T.match_buffer(out[y, 0:4, 0:4], [4, 4], dtype="int8", offset_factor=1)
               for x in T.serial(4):
                   with T.block("C"):
                       T.reads(in0[y, x, 0:4], in1[y, x, 0:4])
                       T.writes(out[y, x, 0:4])
                       A_i8x4: T.int8x4 = A[x, 0:4]
                       A_i32: T.int32 = T.reinterpret(A_i8x4, dtype="int32")
                       B_i8x4: T.int8x4 = B[x, 0:4]
                       B_i32: T.int32 = T.reinterpret(B_i8x4, dtype="int32")
                       C[x, 0:4] = T.reinterpret(A_i32 + B_i32, dtype="int8x4")
   ```
   Here, I believe we are still able to pick up the match buffers from the body block.