zxybazh opened a new issue, #16614:
URL: https://github.com/apache/tvm/issues/16614
### Expected behavior
MetaSchedule tuning works for the given Conv2d workload.
### Actual behavior
Tuning triggers `ValueError: The block no longer exists in the IRModule` while applying the multi-level tiling with tensor intrinsic schedule rule. I noticed that `state->tensor_core_reindex_store` points to a block that has already been merged into another block via ComputeInline during the application of `TileWithTensorIntrin`. A minimal sketch of this failure mode is included after the reproduction script below.
### Environment
Latest TVM Main
### Steps to reproduce
```
import tvm
from tvm.script import tir as T
from tvm import meta_schedule as ms


@T.prim_func(private=True)
def fused_conv2d_add1(
    reshape3: T.Buffer((T.int64(50), T.int64(8), T.int64(72), T.int64(128)), "float16"),
    conv_in_weight: T.Buffer((T.int64(320), T.int64(8), T.int64(3), T.int64(3)), "float16"),
    lv23: T.Buffer((T.int64(1), T.int64(320), T.int64(1), T.int64(1)), "float16"),
    T_add_intermediate: T.Buffer((T.int64(50), T.int64(320), T.int64(72), T.int64(128)), "float16"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    pad_temp = T.alloc_buffer((T.int64(50), T.int64(8), T.int64(74), T.int64(130)), "float16")
    conv2d_nchw_intermediate = T.alloc_buffer((T.int64(50), T.int64(320), T.int64(72), T.int64(128)), "float16")
    for i0, i1, i2, i3 in T.grid(T.int64(50), T.int64(8), T.int64(74), T.int64(130)):
        with T.block("pad_temp"):
            v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.reads(reshape3[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)])
            T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
            pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(
                T.int64(1) <= v_i2 and v_i2 < T.int64(73) and T.int64(1) <= v_i3 and v_i3 < T.int64(129),
                reshape3[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)],
                T.float16(0),
            )
    for nn, ff, yy, xx, rc, ry, rx in T.grid(T.int64(50), T.int64(320), T.int64(72), T.int64(128), T.int64(8), T.int64(3), T.int64(3)):
        with T.block("conv2d_nchw"):
            v_nn, v_ff, v_yy, v_xx, v_rc, v_ry, v_rx = T.axis.remap("SSSSRRR", [nn, ff, yy, xx, rc, ry, rx])
            T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], conv_in_weight[v_ff, v_rc, v_ry, v_rx])
            T.writes(conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx])
            with T.init():
                conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx] = T.float16(0)
            conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx] = (
                conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx]
                + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * conv_in_weight[v_ff, v_rc, v_ry, v_rx]
            )
    for ax0, ax1, ax2, ax3 in T.grid(T.int64(50), T.int64(320), T.int64(72), T.int64(128)):
        with T.block("T_add"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            T.reads(conv2d_nchw_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv23[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
            T.writes(T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
            T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = (
                conv2d_nchw_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] + lv23[T.int64(0), v_ax1, T.int64(0), T.int64(0)]
            )


target = tvm.target.Target("nvidia/nvidia-a10g")
func = fused_conv2d_add1
ms.tune_tir(func, target=target, max_trials_global=100, work_dir="./temp")
```
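For context, here is a minimal sketch of the failure mode as I read it (a hypothetical illustration, not the actual code path inside `TileWithTensorIntrin`): once ComputeInline removes a block, a previously obtained block reference can no longer be resolved against the schedule. It reuses `fused_conv2d_add1` from the script above.
```
# Hypothetical illustration only: after compute_inline removes a block,
# resolving a stale BlockRV against the schedule fails with the same
# "block no longer exists" error.
sch = tvm.tir.Schedule(fused_conv2d_add1)
pad_block = sch.get_block("pad_temp")
sch.compute_inline(pad_block)  # pad_temp is merged into its consumer conv2d_nchw
sch.get(pad_block)             # ValueError: The block no longer exists in the IRModule
```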