nverke commented on code in PR #13301:
URL: https://github.com/apache/tvm/pull/13301#discussion_r1018349185
##########
src/tir/schedule/analysis/reducer.cc:
##########
@@ -572,9 +572,25 @@ bool ReductionIterNotIndexOutputBuffer(const Block& block)
{
if (!store) {
return true;
}
- ICHECK(buffer_written.count(store->buffer.get()))
- << "ValueError: The buffer \"" << store->buffer
- << "\" is written in the block but is not in the block's signature";
+ const auto* body_block = block->body.as<BlockRealizeNode>();
Review Comment:
Hmm, I am not sure I understand. Are you talking about a situation like
this?
```
from tvm.script import tir as T

@T.prim_func
def nested_reduction_loop_with_match_buffers(
    in0: T.Buffer[(4, 4, 4), "int8"],
    in1: T.Buffer[(4, 4, 4), "int8"],
    out: T.Buffer[(4, 4, 4), "int8"],
) -> None:
    # body
    # with T.block("root")
    for y in T.serial(4):
        with T.block("C"):
            T.reads(in0[y, 0:4, 0:4], in1[y, 0:4, 0:4])
            T.writes(out[y, 0:4, 0:4])
            for x in T.serial(4):
                with T.block("C"):
                    T.reads(in0[y, x, 0:4], in1[y, x, 0:4])
                    T.writes(out[y, x, 0:4])
                    # Match buffers declared in the innermost block
                    A = T.match_buffer(in0[y, x, 0:4], [4], dtype="int8",
                                       offset_factor=1)
                    B = T.match_buffer(in1[y, x, 0:4], [4], dtype="int8",
                                       offset_factor=1)
                    C = T.match_buffer(out[y, x, 0:4], [4], dtype="int8",
                                       offset_factor=1)
                    A_i8x4: T.int8x4 = A[0:4]
                    A_i32: T.int32 = T.reinterpret(A_i8x4, dtype="int32")
                    B_i8x4: T.int8x4 = B[0:4]
                    B_i32: T.int32 = T.reinterpret(B_i8x4, dtype="int32")
                    # The store targets the match buffer C, not `out` directly
                    C[0:4] = T.reinterpret(A_i32 + B_i32, dtype="int8x4")
```
My understanding is that this check runs at the level of the store statement
and has already collected all of the write regions for the parent loops, so just
adding the regions covered by the match buffers should be enough.
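To make that argument concrete, here is a toy Python model of the check. It is not TVM code: the `Block`/`Store` classes and the `unknown_stores` helper are hypothetical, loops are flattened away, and every `match_buffer` alias is assumed to be added to the allowed set regardless of its source buffer. It descends through nested blocks, accumulating write-signature buffers plus match-buffer aliases, and reports stores to anything else, which is roughly what the original `ICHECK` rejects.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set, Union

# Hypothetical toy IR: these classes only mimic the shape of TIR blocks;
# they are not TVM's API. Loops are omitted (flattened into `body`).


@dataclass
class Store:
    buffer: str  # name of the buffer being written


@dataclass
class Block:
    name: str
    writes: Set[str]  # buffers in the block's T.writes signature
    # alias name -> source buffer name, mimicking T.match_buffer
    match_buffers: Dict[str, str] = field(default_factory=dict)
    body: List[Union["Block", Store]] = field(default_factory=list)


def unknown_stores(block: Block, allowed: Optional[Set[str]] = None) -> List[str]:
    """Return buffers that are stored to but are neither in a write
    signature nor a match_buffer alias collected on the way down."""
    # Assumption: every match_buffer alias is accepted, whatever its source.
    allowed = set(allowed or ()) | block.writes | set(block.match_buffers)
    bad: List[str] = []
    for stmt in block.body:
        if isinstance(stmt, Store):
            if stmt.buffer not in allowed:
                bad.append(stmt.buffer)
        else:
            # Nested block: inherit everything collected so far.
            bad.extend(unknown_stores(stmt, allowed))
    return bad
```

For example, `unknown_stores(Block("C", {"out"}, body=[Store("C")]))` returns `["C"]`, mirroring the original failure, while `unknown_stores(Block("C", {"out"}, match_buffers={"C": "out"}, body=[Store("C")]))` returns `[]`.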
Alternatively, are you referring to something like this?
```
from tvm.script import tir as T

@T.prim_func
def nested_reduction_loop_with_match_buffers(
    in0: T.Buffer[(4, 4, 4), "int8"],
    in1: T.Buffer[(4, 4, 4), "int8"],
    out: T.Buffer[(4, 4, 4), "int8"],
) -> None:
    # body
    # with T.block("root")
    for y in T.serial(4):
        with T.block("C"):
            T.reads(in0[y, 0:4, 0:4], in1[y, 0:4, 0:4])
            T.writes(out[y, 0:4, 0:4])
            # Match buffers declared in the outer block, covering a 4x4 slice
            A = T.match_buffer(in0[y, 0:4, 0:4], [4, 4], dtype="int8",
                               offset_factor=1)
            B = T.match_buffer(in1[y, 0:4, 0:4], [4, 4], dtype="int8",
                               offset_factor=1)
            C = T.match_buffer(out[y, 0:4, 0:4], [4, 4], dtype="int8",
                               offset_factor=1)
            for x in T.serial(4):
                with T.block("C"):
                    T.reads(in0[y, x, 0:4], in1[y, x, 0:4])
                    T.writes(out[y, x, 0:4])
                    A_i8x4: T.int8x4 = A[x, 0:4]
                    A_i32: T.int32 = T.reinterpret(A_i8x4, dtype="int32")
                    B_i8x4: T.int8x4 = B[x, 0:4]
                    B_i32: T.int32 = T.reinterpret(B_i8x4, dtype="int32")
                    # The store targets the outer block's match buffer C
                    C[x, 0:4] = T.reinterpret(A_i32 + B_i32, dtype="int8x4")
```
Here I believe we are still able to pick up the match buffers from the body
block.
--
This is an automated message from the Apache Git Service.