YXY-0922 commented on code in PR #17161:
URL: https://github.com/apache/tvm/pull/17161#discussion_r1680496799


##########
src/tir/schedule/transform.cc:
##########
@@ -340,7 +340,9 @@ Optional<LoopRV> TileWithTensorIntrin(const tir::Schedule& 
sch, const tir::Block
     }
     auto consumers = sch->GetConsumers(block_rv);
     for (const auto& consumer : consumers) {
-      sch->ComputeInline(consumer);
+      auto sref = sch->GetSRef(consumer);
+      if (!tir::IsOutputBlock(sch->state(), sref, 
tir::GetScopeRoot(sch->state(), sref, true)))

Review Comment:
   Sure. I encountered this bug while using MetaSchedule (`meta_schedule`) to tune a conv2d operator. Here is a TIR example that reproduces it:
   
   ```python
   # Reproducer: tile the "conv2d_nchw" block with a 16x16x16 WMMA tensor
   # intrinsic via tile_with_tensor_intrin. The reduction extent is
   # 3 * 7 * 7 = 147, which is not a multiple of 16, so padding is needed.
   import tvm
   from tvm import te, topi, tir
   from tvm.ir.module import IRModule
   from tvm.script import tir as T
   from tvm.tir.schedule.transform import tile_with_tensor_intrin
   from tvm.tir.tensor_intrin.cuda import WMMA_SYNC_16x16x16_f16f16f16_TRANS_INTRIN


   @tvm.script.ir_module
   class conv2d_Module:
       @T.prim_func
       def main(A: T.Buffer((16, 3, 224, 224), "float16"), B: T.Buffer((64, 3, 7, 7), "float16"), conv2d_nchw: T.Buffer((16, 64, 112, 112), "float16")):
           T.func_attr({"tir.noalias": T.bool(True)})
           # with T.block("root"):
           pad_temp = T.alloc_buffer((16, 3, 230, 230), "float16")
           conv2d_nchw_reindex = T.alloc_buffer((200704, 64), "float16")
           pad_temp_reindex = T.alloc_buffer((200704, 147), "float16")
           B_reindex = T.alloc_buffer((64, 147), "float16")
           for i0, i1, i2, i3 in T.grid(16, 3, 230, 230):
               with T.block("pad_temp"):
                   v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                   T.reads(A[v_i0, v_i1, v_i2 - 3, v_i3 - 3])
                   T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
                   pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(3 <= v_i2 and v_i2 < 227 and 3 <= v_i3 and v_i3 < 227, A[v_i0, v_i1, v_i2 - 3, v_i3 - 3], T.float16(0))
           for ax0, ax1 in T.grid(200704, 147):
               with T.block("pad_temp_reindex_reindex"):
                   v0, v1 = T.axis.remap("SS", [ax0, ax1])
                   T.reads(pad_temp[v0 // 12544, v1 // 49, v0 % 12544 // 112 * 2 + v1 % 49 // 7, v0 % 112 * 2 + v1 % 7])
                   T.writes(pad_temp_reindex[v0, v1])
                   pad_temp_reindex[v0, v1] = pad_temp[v0 // 12544, v1 // 49, v0 % 12544 // 112 * 2 + v1 % 49 // 7, v0 % 112 * 2 + v1 % 7]
           for ax0, ax1 in T.grid(64, 147):
               with T.block("B_reindex_reindex"):
                   v0, v1 = T.axis.remap("SS", [ax0, ax1])
                   T.reads(B[v0, v1 // 49, v1 % 49 // 7, v1 % 7])
                   T.writes(B_reindex[v0, v1])
                   B_reindex[v0, v1] = B[v0, v1 // 49, v1 % 49 // 7, v1 % 7]
           # The block to be tensorized: a (200704 x 64) x 147 GEMM-style reduction.
           for ax0, ax1, ax2 in T.grid(200704, 64, 147):
               with T.block("conv2d_nchw"):
                   v0, v1, v2 = T.axis.remap("SSR", [ax0, ax1, ax2])
                   T.reads(pad_temp_reindex[v0, v2], B_reindex[v1, v2])
                   T.writes(conv2d_nchw_reindex[v0, v1])
                   T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
                   with T.init():
                       conv2d_nchw_reindex[v0, v1] = T.float16(0)
                   conv2d_nchw_reindex[v0, v1] = conv2d_nchw_reindex[v0, v1] + pad_temp_reindex[v0, v2] * B_reindex[v1, v2]
           # Consumer of "conv2d_nchw". It writes the function's output buffer
           # conv2d_nchw, i.e. it is an output block -- the case the PR guards
           # with an IsOutputBlock check before ComputeInline.
           for ax0, ax1 in T.grid(200704, 64):
               with T.block("conv2d_nchw_reindex"):
                   v0, v1 = T.axis.remap("SS", [ax0, ax1])
                   T.reads(conv2d_nchw_reindex[v0, v1])
                   T.writes(conv2d_nchw[v0 // 12544, v1, v0 % 12544 // 112, v0 % 112])
                   conv2d_nchw[v0 // 12544, v1, v0 % 12544 // 112, v0 % 112] = conv2d_nchw_reindex[v0, v1]

   sch = tvm.tir.Schedule(conv2d_Module)

   intrin = WMMA_SYNC_16x16x16_f16f16f16_TRANS_INTRIN
   block = sch.get_block("conv2d_nchw")

   # 4th argument enables padding: the 147-long reduction must be padded to a
   # multiple of the intrinsic's 16-wide K dimension.
   tiled_loop = tile_with_tensor_intrin(sch, block, intrin, True)

   print(sch.mod)
   ```
   
   And the output is:
   ```python
   # from tvm.script import ir as I
   # from tvm.script import tir as T

   @I.ir_module
   class Module:
       @T.prim_func
       def main(A: T.Buffer((16, 3, 224, 224), "float16"), B: T.Buffer((64, 3, 7, 7), "float16"), conv2d_nchw: T.Buffer((16, 64, 112, 112), "float16")):
           T.func_attr({"tir.noalias": T.bool(True)})
           # with T.block("root"):
           pad_temp = T.alloc_buffer((16, 3, 230, 230), "float16")
           conv2d_nchw_reindex = T.alloc_buffer((200704, 64), "float16")
           pad_temp_reindex_pad = T.alloc_buffer((200704, 160), "float16")
           B_reindex_pad = T.alloc_buffer((64, 160), "float16")
           for i0, i1, i2, i3 in T.grid(16, 3, 230, 230):
               with T.block("pad_temp"):
                   v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                   T.reads(A[v_i0, v_i1, v_i2 - 3, v_i3 - 3])
                   T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
                   pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(3 <= v_i2 and v_i2 < 227 and 3 <= v_i3 and v_i3 < 227, A[v_i0, v_i1, v_i2 - 3, v_i3 - 3], T.float16(0))
           for i0, i1 in T.grid(200704, 160):
               with T.block("pad_temp_reindex_pad"):
                   v0, v1 = T.axis.remap("SS", [i0, i1])
                   T.reads(pad_temp[v0 // 12544, v1 // 49, v0 % 12544 // 112 * 2 + v1 % 49 // 7, v0 % 112 * 2 + v1 % 7])
                   T.writes(pad_temp_reindex_pad[v0, v1])
                   pad_temp_reindex_pad[v0, v1] = T.if_then_else(v1 < 147, pad_temp[v0 // 12544, v1 // 49, v0 % 12544 // 112 * 2 + v1 % 49 // 7, v0 % 112 * 2 + v1 % 7], T.float16(0))
           for i0, i1 in T.grid(64, 160):
               with T.block("B_reindex_pad"):
                   v0, v1 = T.axis.remap("SS", [i0, i1])
                   T.reads(B[v0, v1 // 49, v1 % 49 // 7, v1 % 7])
                   T.writes(B_reindex_pad[v0, v1])
                   B_reindex_pad[v0, v1] = T.if_then_else(v1 < 147, B[v0, v1 // 49, v1 % 49 // 7, v1 % 7], T.float16(0))
           for ax0_0, ax1_0, ax2_0, ax0_1, ax1_1, ax2_1 in T.grid(12544, 4, 10, 16, 16, 16):
               with T.block("conv2d_nchw"):
                   v0 = T.axis.spatial(200704, ax0_0 * 16 + ax0_1)
                   v1 = T.axis.spatial(64, ax1_0 * 16 + ax1_1)
                   v2 = T.axis.reduce(160, ax2_0 * 16 + ax2_1)
                   T.reads(pad_temp_reindex_pad[v0, v2], B_reindex_pad[v1, v2])
                   T.writes(conv2d_nchw_reindex[v0, v1])
                   T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
                   with T.init():
                       conv2d_nchw_reindex[v0, v1] = T.float16(0)
                   conv2d_nchw_reindex[v0, v1] = conv2d_nchw_reindex[v0, v1] + pad_temp_reindex_pad[v0, v2] * B_reindex_pad[v1, v2]
           for ax0, ax1 in T.grid(200704, 64):
               with T.block("conv2d_nchw_reindex"):
                   v0, v1 = T.axis.remap("SS", [ax0, ax1])
                   T.reads(conv2d_nchw_reindex[v0, v1])
                   T.writes(conv2d_nchw[v0 // 12544, v1, v0 % 12544 // 112, v0 % 112])
                   conv2d_nchw[v0 // 12544, v1, v0 % 12544 // 112, v0 % 112] = conv2d_nchw_reindex[v0, v1]
   ```
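   The product of the three reduction axes is 3 × 7 × 7 = 147, which is not a multiple of the intrinsic's 16-wide K dimension, so padding is required (the output above pads the reduction extent from 147 to 160).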



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to