YXY-0922 commented on code in PR #17161:
URL: https://github.com/apache/tvm/pull/17161#discussion_r1680496799
##########
src/tir/schedule/transform.cc:
##########
@@ -340,7 +340,9 @@ Optional<LoopRV> TileWithTensorIntrin(const tir::Schedule&
sch, const tir::Block
}
auto consumers = sch->GetConsumers(block_rv);
for (const auto& consumer : consumers) {
- sch->ComputeInline(consumer);
+ auto sref = sch->GetSRef(consumer);
+ if (!tir::IsOutputBlock(sch->state(), sref,
tir::GetScopeRoot(sch->state(), sref, true)))
Review Comment:
Sure, I encountered this bug while using meta_schedule to tune a conv2d
operator. Here is the TIR example:
```python
import tvm
from tvm import te, topi, tir
from tvm.ir.module import IRModule
from tvm.script import tir as T
from tvm.tir.schedule.transform import tile_with_tensor_intrin
from tvm.tir.tensor_intrin.cuda import
WMMA_SYNC_16x16x16_f16f16f16_TRANS_INTRIN
@tvm.script.ir_module
class conv2d_Module:
@T.prim_func
def main(A: T.Buffer((16, 3, 224, 224), "float16"), B: T.Buffer((64, 3,
7, 7), "float16"), conv2d_nchw: T.Buffer((16, 64, 112, 112), "float16")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
pad_temp = T.alloc_buffer((16, 3, 230, 230), "float16")
conv2d_nchw_reindex = T.alloc_buffer((200704, 64), "float16")
pad_temp_reindex = T.alloc_buffer((200704, 147), "float16")
B_reindex = T.alloc_buffer((64, 147), "float16")
for i0, i1, i2, i3 in T.grid(16, 3, 230, 230):
with T.block("pad_temp"):
v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2,
i3])
T.reads(A[v_i0, v_i1, v_i2 - 3, v_i3 - 3])
T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(3 <= v_i2
and v_i2 < 227 and 3 <= v_i3 and v_i3 < 227, A[v_i0, v_i1, v_i2 - 3, v_i3 - 3],
T.float16(0))
for ax0, ax1 in T.grid(200704, 147):
with T.block("pad_temp_reindex_reindex"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
T.reads(pad_temp[v0 // 12544, v1 // 49, v0 % 12544 // 112 *
2 + v1 % 49 // 7, v0 % 112 * 2 + v1 % 7])
T.writes(pad_temp_reindex[v0, v1])
pad_temp_reindex[v0, v1] = pad_temp[v0 // 12544, v1 // 49,
v0 % 12544 // 112 * 2 + v1 % 49 // 7, v0 % 112 * 2 + v1 % 7]
for ax0, ax1 in T.grid(64, 147):
with T.block("B_reindex_reindex"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
T.reads(B[v0, v1 // 49, v1 % 49 // 7, v1 % 7])
T.writes(B_reindex[v0, v1])
B_reindex[v0, v1] = B[v0, v1 // 49, v1 % 49 // 7, v1 % 7]
for ax0, ax1, ax2 in T.grid(200704, 64, 147):
with T.block("conv2d_nchw"):
v0, v1, v2 = T.axis.remap("SSR", [ax0, ax1, ax2])
T.reads(pad_temp_reindex[v0, v2], B_reindex[v1, v2])
T.writes(conv2d_nchw_reindex[v0, v1])
T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
with T.init():
conv2d_nchw_reindex[v0, v1] = T.float16(0)
conv2d_nchw_reindex[v0, v1] = conv2d_nchw_reindex[v0, v1] +
pad_temp_reindex[v0, v2] * B_reindex[v1, v2]
for ax0, ax1 in T.grid(200704, 64):
with T.block("conv2d_nchw_reindex"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
T.reads(conv2d_nchw_reindex[v0, v1])
T.writes(conv2d_nchw[v0 // 12544, v1, v0 % 12544 // 112, v0
% 112])
conv2d_nchw[v0 // 12544, v1, v0 % 12544 // 112, v0 % 112] =
conv2d_nchw_reindex[v0, v1]
sch = tvm.tir.Schedule(conv2d_Module)
intrin = WMMA_SYNC_16x16x16_f16f16f16_TRANS_INTRIN
block = sch.get_block("conv2d_nchw")
tiled_loop = tile_with_tensor_intrin(sch, block, intrin, True)
print(sch.mod)
```
And the output is:
```python
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer((16, 3, 224, 224), "float16"), B: T.Buffer((64, 3,
7, 7), "float16"), conv2d_nchw: T.Buffer((16, 64, 112, 112), "float16")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
pad_temp = T.alloc_buffer((16, 3, 230, 230), "float16")
conv2d_nchw_reindex = T.alloc_buffer((200704, 64), "float16")
pad_temp_reindex_pad = T.alloc_buffer((200704, 160), "float16")
B_reindex_pad = T.alloc_buffer((64, 160), "float16")
for i0, i1, i2, i3 in T.grid(16, 3, 230, 230):
with T.block("pad_temp"):
v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2,
i3])
T.reads(A[v_i0, v_i1, v_i2 - 3, v_i3 - 3])
T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(3 <= v_i2
and v_i2 < 227 and 3 <= v_i3 and v_i3 < 227, A[v_i0, v_i1, v_i2 - 3, v_i3 - 3],
T.float16(0))
for i0, i1 in T.grid(200704, 160):
with T.block("pad_temp_reindex_pad"):
v0, v1 = T.axis.remap("SS", [i0, i1])
T.reads(pad_temp[v0 // 12544, v1 // 49, v0 % 12544 // 112 *
2 + v1 % 49 // 7, v0 % 112 * 2 + v1 % 7])
T.writes(pad_temp_reindex_pad[v0, v1])
pad_temp_reindex_pad[v0, v1] = T.if_then_else(v1 < 147,
pad_temp[v0 // 12544, v1 // 49, v0 % 12544 // 112 * 2 + v1 % 49 // 7, v0 % 112
* 2 + v1 % 7], T.float16(0))
for i0, i1 in T.grid(64, 160):
with T.block("B_reindex_pad"):
v0, v1 = T.axis.remap("SS", [i0, i1])
T.reads(B[v0, v1 // 49, v1 % 49 // 7, v1 % 7])
T.writes(B_reindex_pad[v0, v1])
B_reindex_pad[v0, v1] = T.if_then_else(v1 < 147, B[v0, v1 //
49, v1 % 49 // 7, v1 % 7], T.float16(0))
for ax0_0, ax1_0, ax2_0, ax0_1, ax1_1, ax2_1 in T.grid(12544, 4, 10,
16, 16, 16):
with T.block("conv2d_nchw"):
v0 = T.axis.spatial(200704, ax0_0 * 16 + ax0_1)
v1 = T.axis.spatial(64, ax1_0 * 16 + ax1_1)
v2 = T.axis.reduce(160, ax2_0 * 16 + ax2_1)
T.reads(pad_temp_reindex_pad[v0, v2], B_reindex_pad[v1, v2])
T.writes(conv2d_nchw_reindex[v0, v1])
T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
with T.init():
conv2d_nchw_reindex[v0, v1] = T.float16(0)
conv2d_nchw_reindex[v0, v1] = conv2d_nchw_reindex[v0, v1] +
pad_temp_reindex_pad[v0, v2] * B_reindex_pad[v1, v2]
for ax0, ax1 in T.grid(200704, 64):
with T.block("conv2d_nchw_reindex"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
T.reads(conv2d_nchw_reindex[v0, v1])
T.writes(conv2d_nchw[v0 // 12544, v1, v0 % 12544 // 112, v0
% 112])
conv2d_nchw[v0 // 12544, v1, v0 % 12544 // 112, v0 % 112] =
conv2d_nchw_reindex[v0, v1]
```
The product of the extents of the three reduction axes is 3 × 7 × 7 = 147, which is not a multiple of the intrinsic size 16, hence padding is required.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]