wrongtest-intellif commented on code in PR #15961:
URL: https://github.com/apache/tvm/pull/15961#discussion_r1368057844
##########
src/tir/transforms/inject_ptx_async_copy.cc:
##########
@@ -113,9 +116,11 @@ class PTXAsyncCopyInjector : public StmtMutator {
return PrimExpr();
}();
if (src_offset.defined() && dst_offset.defined()) {
Review Comment:
Could you add a unit test for this case? Thank you!
##########
src/tir/transforms/ir_utils.cc:
##########
@@ -417,7 +417,8 @@ Array<PrimExpr> GetBufferAllocationShape(const Buffer&
buffer) {
if (buffer->strides.size()) {
ICHECK_EQ(buffer->shape.size(), buffer->strides.size());
for (size_t i = buffer->strides.size() - 1; i > 0; --i) {
- ICHECK(is_zero(floormod(buffer->strides[i - 1], buffer->strides[i])));
Review Comment:
Could you add a unit test for this change? E.g., check the TIR before and
after `LowerOpaqueBlock`.
The whole script is shown below; of course, we only need to keep the part of
it that includes `lv55_reindex_pad_shared_dyn = T.alloc_buffer((T.int64(1), T.min((n +
T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)), T.int64(64)),
"float16", strides=(T.int64(72) * T.min((n + T.int64(63)) // T.int64(64) *
T.int64(64), T.int64(96)), T.int64(72), T.int64(1)), scope="shared.dyn")`.
```python
@T.prim_func(private=True)
def fused_NT_matmul4_add1(p_lv55: T.handle, lv11: T.Buffer((T.int64(2560),
T.int64(10240)), "float16"), lv12: T.Buffer((T.int64(2560),), "float16"),
p_output0: T.handle):
T.func_attr({"tir.noalias": T.bool(True)})
n = T.int64()
lv55 = T.match_buffer(p_lv55, (T.int64(1), n, T.int64(10240)),
"float16")
var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n,
T.int64(2560)), "float16")
# with T.block("root"):
for blockIdx_z in T.thread_binding(T.int64(1), thread="blockIdx.z"):
for blockIdx_x in T.thread_binding(((n + T.int64(63)) //
T.int64(64) * T.int64(4) + T.int64(7)) // T.int64(8), thread="blockIdx.x"):
for blockIdx_y in T.thread_binding(T.int64(20),
thread="blockIdx.y"):
for threadIdx_x in T.thread_binding(T.int64(32),
thread="threadIdx.x"):
for threadIdx_y in T.thread_binding(T.int64(4),
thread="threadIdx.y"):
with T.block(""):
T.reads(lv55[T.int64(0), blockIdx_x *
T.int64(128):blockIdx_x * T.int64(128) + T.int64(128),
T.int64(0):T.int64(10240)], lv11[blockIdx_y * T.int64(128):blockIdx_y *
T.int64(128) + T.int64(128), T.int64(0):T.int64(10240)], lv12[blockIdx_y *
T.int64(128) + threadIdx_y // T.int64(2) * T.int64(64):blockIdx_y *
T.int64(128) + threadIdx_y // T.int64(2) * T.int64(64) + T.int64(64)])
T.writes(var_T_add_intermediate[T.int64(0),
blockIdx_x * T.int64(128) + threadIdx_y % T.int64(2) * T.int64(64):blockIdx_x *
T.int64(128) + threadIdx_y % T.int64(2) * T.int64(64) + T.int64(64), blockIdx_y
* T.int64(128) + threadIdx_y // T.int64(2) * T.int64(64):blockIdx_y *
T.int64(128) + threadIdx_y // T.int64(2) * T.int64(64) + T.int64(64)])
var_NT_matmul_intermediate_reindex_pad_shared_dyn = T.alloc_buffer((T.int64(1),
T.int64(128), T.int64(128)), "float16", scope="shared.dyn")
var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator =
T.alloc_buffer((T.int64(1), T.int64(64), T.int64(64)), "float16",
scope="wmma.accumulator")
for ax1_0_3, ax2_0_3 in T.grid(T.int64(2),
T.int64(2)):
for ax1_0_4_init, ax2_0_4_init in
T.grid(T.int64(2), T.int64(2)):
with T.block("NT_matmul_o_init"):
T.where(blockIdx_x * T.int64(2)
+ threadIdx_y % T.int64(2) - (n + T.int64(63)) // T.int64(64) < T.int64(0))
T.reads()
T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[T.int64(0),
ax1_0_3 * T.int64(32) + ax1_0_4_init * T.int64(16):ax1_0_3 * T.int64(32) +
ax1_0_4_init * T.int64(16) + T.int64(16), ax2_0_3 * T.int64(32) + ax2_0_4_init
* T.int64(16):ax2_0_3 * T.int64(32) + ax2_0_4_init * T.int64(16) + T.int64(16)])
with T.block("NT_matmul_init_o"):
T.reads()
T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[T.int64(0),
ax1_0_3 * T.int64(32) + ax1_0_4_init * T.int64(16):ax1_0_3 * T.int64(32) +
ax1_0_4_init * T.int64(16) + T.int64(16), ax2_0_3 * T.int64(32) + ax2_0_4_init
* T.int64(16):ax2_0_3 * T.int64(32) + ax2_0_4_init * T.int64(16) + T.int64(16)])
T.tvm_fill_fragment(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator.data,
16, 16, 16, ax1_0_3 * T.int64(8) + ax1_0_4_init * T.int64(4) + ax2_0_3 *
T.int64(2) + ax2_0_4_init, T.float32(0))
for ax3_0_0 in range(T.int64(160)):
with T.block(""):
T.reads(lv55[T.int64(0),
blockIdx_x * T.int64(128) + ax1_0_3 * T.int64(32):blockIdx_x * T.int64(128) +
ax1_0_3 * T.int64(32) + T.int64(96), ax3_0_0 * T.int64(64):ax3_0_0 *
T.int64(64) + T.int64(64)], lv11[blockIdx_y * T.int64(128) + ax2_0_3 *
T.int64(32):blockIdx_y * T.int64(128) + ax2_0_3 * T.int64(32) + T.int64(96),
ax3_0_0 * T.int64(64):ax3_0_0 * T.int64(64) + T.int64(64)],
var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[T.int64(0),
ax1_0_3 * T.int64(32):ax1_0_3 * T.int64(32) + T.int64(32), ax2_0_3 *
T.int64(32):ax2_0_3 * T.int64(32) + T.int64(32)])
T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[T.int64(0),
ax1_0_3 * T.int64(32):ax1_0_3 * T.int64(32) + T.int64(32), ax2_0_3 *
T.int64(32):ax2_0_3 * T.int64(32) + T.int64(32)])
lv55_reindex_pad_shared_dyn =
T.alloc_buffer((T.int64(1), T.min((n + T.int64(63)) // T.int64(64) *
T.int64(64), T.int64(96)), T.int64(64)), "float16", strides=(T.int64(72) *
T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)),
T.int64(72), T.int64(1)), scope="shared.dyn")
lv11_reindex_shared_dyn =
T.alloc_buffer((T.int64(1), T.int64(96), T.int64(64)), "float16",
strides=(T.int64(6912), T.int64(72), T.int64(1)), scope="shared.dyn")
for ax0_ax1_fused_0 in
range(T.int64(12)):
for ax0_ax1_fused_3 in
T.vectorized(T.int64(4)):
with
T.block("lv55_reindex_pad_shared.dyn"):
T.where(blockIdx_x *
T.int64(16) + ax1_0_3 * T.int64(4) + ax0_ax1_fused_0 - (n + T.int64(63)) //
T.int64(64) * T.int64(8) < T.int64(0))
T.reads(lv55[T.int64(0), blockIdx_x * T.int64(128) + ax1_0_3 * T.int64(32) +
ax0_ax1_fused_0 * T.int64(8) + threadIdx_y * T.int64(2) + threadIdx_x //
T.int64(16), ax3_0_0 * T.int64(64) + threadIdx_x % T.int64(16) * T.int64(4) +
ax0_ax1_fused_3])
T.writes(lv55_reindex_pad_shared_dyn[T.int64(0), ax0_ax1_fused_0 * T.int64(8) +
threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x %
T.int64(16) * T.int64(4) + ax0_ax1_fused_3])
T.block_attr({"buffer_dim_align": [[0, 1, 16, 8]]})
lv55_reindex_pad_shared_dyn[T.int64(0), ax0_ax1_fused_0 * T.int64(8) +
threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x %
T.int64(16) * T.int64(4) + ax0_ax1_fused_3] = T.if_then_else(blockIdx_x *
T.int64(128) + ax1_0_3 * T.int64(32) + ax0_ax1_fused_0 * T.int64(8) +
threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16) < n, lv55[T.int64(0),
blockIdx_x * T.int64(128) + ax1_0_3 * T.int64(32) + ax0_ax1_fused_0 *
T.int64(8) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), ax3_0_0 *
T.int64(64) + threadIdx_x % T.int64(16) * T.int64(4) + ax0_ax1_fused_3],
T.float16(0))
for ax0_ax1_fused_0 in
range(T.int64(12)):
for ax0_ax1_fused_3 in
T.vectorized(T.int64(4)):
with
T.block("lv11_reindex_shared.dyn"):
T.reads(lv11[blockIdx_y * T.int64(128) + ax2_0_3 * T.int64(32) +
ax0_ax1_fused_0 * T.int64(8) + threadIdx_y * T.int64(2) + threadIdx_x //
T.int64(16), ax3_0_0 * T.int64(64) + threadIdx_x % T.int64(16) * T.int64(4) +
ax0_ax1_fused_3])
T.writes(lv11_reindex_shared_dyn[T.int64(0), ax0_ax1_fused_0 * T.int64(8) +
threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x %
T.int64(16) * T.int64(4) + ax0_ax1_fused_3])
T.block_attr({"buffer_dim_align": [[0, 1, 16, 8]]})
lv11_reindex_shared_dyn[T.int64(0), ax0_ax1_fused_0 * T.int64(8) + threadIdx_y
* T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x % T.int64(16) *
T.int64(4) + ax0_ax1_fused_3] = lv11[blockIdx_y * T.int64(128) + ax2_0_3 *
T.int64(32) + ax0_ax1_fused_0 * T.int64(8) + threadIdx_y * T.int64(2) +
threadIdx_x // T.int64(16), ax3_0_0 * T.int64(64) + threadIdx_x % T.int64(16) *
T.int64(4) + ax0_ax1_fused_3]
for ax3_0_1 in range(T.int64(4)):
with T.block(""):
T.reads(lv55_reindex_pad_shared_dyn[T.int64(0), threadIdx_y % T.int64(2) *
T.int64(64):threadIdx_y % T.int64(2) * T.int64(64) + T.int64(32), ax3_0_1 *
T.int64(16):ax3_0_1 * T.int64(16) + T.int64(16)],
lv11_reindex_shared_dyn[T.int64(0), threadIdx_y // T.int64(2) *
T.int64(64):threadIdx_y // T.int64(2) * T.int64(64) + T.int64(32), ax3_0_1 *
T.int64(16):ax3_0_1 * T.int64(16) + T.int64(16)],
var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[T.int64(0),
ax1_0_3 * T.int64(32):ax1_0_3 * T.int64(32) + T.int64(32), ax2_0_3 *
T.int64(32):ax2_0_3 * T.int64(32) + T.int64(32)])
T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[T.int64(0),
ax1_0_3 * T.int64(32):ax1_0_3 * T.int64(32) + T.int64(32), ax2_0_3 *
T.int64(32):ax2_0_3 * T.int64(32) + T.int64(32)])
lv55_reindex_pad_shared_dyn_wmma_matrix_a = T.alloc_buffer((T.int64(1),
T.int64(32), T.int64(16)), "float16", scope="wmma.matrix_a")
lv11_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((T.int64(1),
T.int64(32), T.int64(16)), "float16", scope="wmma.matrix_b")
for ax0_0 in
T.unroll(T.int64(2)):
for ax1_0 in
T.unroll(T.int64(1)):
for ax0_1, ax1_1
in T.grid(T.int64(16), T.int64(16)):
with
T.block("lv55_reindex_pad_shared.dyn_wmma.matrix_a"):
T.where(blockIdx_x * T.int64(2) + threadIdx_y % T.int64(2) - (n + T.int64(63))
// T.int64(64) < T.int64(0))
T.reads(lv55_reindex_pad_shared_dyn[T.int64(0), threadIdx_y % T.int64(2) *
T.int64(64) + ax0_0 * T.int64(16) + ax0_1, ax3_0_1 * T.int64(16) + ax1_1])
T.writes(lv55_reindex_pad_shared_dyn_wmma_matrix_a[T.int64(0), ax0_0 *
T.int64(16) + ax0_1, ax1_1])
lv55_reindex_pad_shared_dyn_wmma_matrix_a[T.int64(0), ax0_0 * T.int64(16) +
ax0_1, ax1_1] = lv55_reindex_pad_shared_dyn[T.int64(0), threadIdx_y %
T.int64(2) * T.int64(64) + ax0_0 * T.int64(16) + ax0_1, ax3_0_1 * T.int64(16) +
ax1_1]
for ax0, ax1 in
T.grid(T.int64(32), T.int64(16)):
with
T.block("lv11_reindex_shared.dyn_wmma.matrix_b"):
T.reads(lv11_reindex_shared_dyn[T.int64(0), threadIdx_y // T.int64(2) *
T.int64(64) + ax0, ax3_0_1 * T.int64(16) + ax1])
T.writes(lv11_reindex_shared_dyn_wmma_matrix_b[T.int64(0), ax0, ax1])
lv11_reindex_shared_dyn_wmma_matrix_b[T.int64(0), ax0, ax1] =
lv11_reindex_shared_dyn[T.int64(0), threadIdx_y // T.int64(2) * T.int64(64) +
ax0, ax3_0_1 * T.int64(16) + ax1]
for ax1_0_4, ax2_0_4 in
T.grid(T.int64(2), T.int64(2)):
with
T.block("NT_matmul_o_update"):
T.where(blockIdx_x * T.int64(2) + threadIdx_y % T.int64(2) - (n + T.int64(63))
// T.int64(64) < T.int64(0))
T.reads(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[T.int64(0),
ax1_0_3 * T.int64(32) + ax1_0_4 * T.int64(16):ax1_0_3 * T.int64(32) + ax1_0_4
* T.int64(16) + T.int64(16), ax2_0_3 * T.int64(32) + ax2_0_4 *
T.int64(16):ax2_0_3 * T.int64(32) + ax2_0_4 * T.int64(16) + T.int64(16)],
lv55_reindex_pad_shared_dyn_wmma_matrix_a[T.int64(0), ax1_0_4 *
T.int64(16):ax1_0_4 * T.int64(16) + T.int64(16), T.int64(0):T.int64(16)],
lv11_reindex_shared_dyn_wmma_matrix_b[T.int64(0), ax2_0_4 * T.int64(16):ax2_0_4
* T.int64(16) + T.int64(16), T.int64(0):T.int64(16)])
T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[T.int64(0),
ax1_0_3 * T.int64(32) + ax1_0_4 * T.int64(16):ax1_0_3 * T.int64(32) + ax1_0_4
* T.int64(16) + T.int64(16), ax2_0_3 * T.int64(32) + ax2_0_4 *
T.int64(16):ax2_0_3 * T.int64(32) + ax2_0_4 * T.int64(16) + T.int64(16)])
for ax1_1,
ax2_1, ax3_1 in T.grid(T.int64(16), T.int64(16), T.int64(16)):
with
T.block("NT_matmul"):
T.reads(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[T.int64(0),
ax1_0_3 * T.int64(32) + ax1_0_4 * T.int64(16) + ax1_1, ax2_0_3 * T.int64(32) +
ax2_0_4 * T.int64(16) + ax2_1],
lv55_reindex_pad_shared_dyn_wmma_matrix_a[T.int64(0), ax1_0_4 * T.int64(16) +
ax1_1, ax3_1], lv11_reindex_shared_dyn_wmma_matrix_b[T.int64(0), ax2_0_4 *
T.int64(16) + ax2_1, ax3_1])
T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[T.int64(0),
ax1_0_3 * T.int64(32) + ax1_0_4 * T.int64(16) + ax1_1, ax2_0_3 * T.int64(32) +
ax2_0_4 * T.int64(16) + ax2_1])
var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[T.int64(0),
ax1_0_3 * T.int64(32) + ax1_0_4 * T.int64(16) + ax1_1, ax2_0_3 * T.int64(32) +
ax2_0_4 * T.int64(16) + ax2_1] =
var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[T.int64(0),
ax1_0_3 * T.int64(32) + ax1_0_4 * T.int64(16) + ax1_1, ax2_0_3 * T.int64(32) +
ax2_0_4 * T.int64(16) + ax2_1] +
lv55_reindex_pad_shared_dyn_wmma_matrix_a[T.int64(0), ax1_0_4 * T.int64(16) +
ax1_1, ax3_1] * lv11_reindex_shared_dyn_wmma_matrix_b[T.int64(0), ax2_0_4 *
T.int64(16) + ax2_1, ax3_1]
for ax0_0, ax1_0, ax0_1, ax1_1 in
T.grid(T.int64(4), T.int64(4), T.int64(16), T.int64(16)):
with
T.block("var_NT_matmul_intermediate_reindex_pad_shared.dyn_wmma.accumulator"):
T.where(blockIdx_x * T.int64(2) +
threadIdx_y % T.int64(2) - (n + T.int64(63)) // T.int64(64) < T.int64(0))
T.reads(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[T.int64(0),
ax0_0 * T.int64(16) + ax0_1, ax1_0 * T.int64(16) + ax1_1])
T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn[T.int64(0),
threadIdx_y % T.int64(2) * T.int64(64) + ax0_0 * T.int64(16) + ax0_1,
threadIdx_y // T.int64(2) * T.int64(64) + ax1_0 * T.int64(16) + ax1_1])
var_NT_matmul_intermediate_reindex_pad_shared_dyn[T.int64(0), threadIdx_y %
T.int64(2) * T.int64(64) + ax0_0 * T.int64(16) + ax0_1, threadIdx_y //
T.int64(2) * T.int64(64) + ax1_0 * T.int64(16) + ax1_1] =
var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[T.int64(0),
ax0_0 * T.int64(16) + ax0_1, ax1_0 * T.int64(16) + ax1_1]
for ax0_ax1_fused_0 in range(T.int64(32)):
for ax0_ax1_fused_2 in
T.vectorized(T.int64(4)):
with
T.block("var_NT_matmul_intermediate_reindex_pad_shared.dyn"):
T.where(blockIdx_x * T.int64(2)
+ threadIdx_y % T.int64(2) - (n + T.int64(63)) // T.int64(64) < T.int64(0))
T.reads(var_NT_matmul_intermediate_reindex_pad_shared_dyn[T.int64(0),
threadIdx_y % T.int64(2) * T.int64(64) + ax0_ax1_fused_0 * T.int64(2) +
threadIdx_x // T.int64(16), threadIdx_y // T.int64(2) * T.int64(64) +
threadIdx_x % T.int64(16) * T.int64(4) + ax0_ax1_fused_2], lv12[blockIdx_y *
T.int64(128) + threadIdx_y // T.int64(2) * T.int64(64) + threadIdx_x %
T.int64(16) * T.int64(4) + ax0_ax1_fused_2])
T.writes(var_T_add_intermediate[T.int64(0), blockIdx_x * T.int64(128) +
threadIdx_y % T.int64(2) * T.int64(64) + ax0_ax1_fused_0 * T.int64(2) +
threadIdx_x // T.int64(16), blockIdx_y * T.int64(128) + threadIdx_y //
T.int64(2) * T.int64(64) + threadIdx_x % T.int64(16) * T.int64(4) +
ax0_ax1_fused_2])
T.block_attr({"buffer_dim_align": [[0, 1, 16, 4]]})
if blockIdx_x * T.int64(128) +
threadIdx_y % T.int64(2) * T.int64(64) + ax0_ax1_fused_0 * T.int64(2) +
threadIdx_x // T.int64(16) < n:
var_T_add_intermediate[T.int64(0), blockIdx_x * T.int64(128) + threadIdx_y %
T.int64(2) * T.int64(64) + ax0_ax1_fused_0 * T.int64(2) + threadIdx_x //
T.int64(16), blockIdx_y * T.int64(128) + threadIdx_y // T.int64(2) *
T.int64(64) + threadIdx_x % T.int64(16) * T.int64(4) + ax0_ax1_fused_2] =
var_NT_matmul_intermediate_reindex_pad_shared_dyn[T.int64(0), threadIdx_y %
T.int64(2) * T.int64(64) + ax0_ax1_fused_0 * T.int64(2) + threadIdx_x //
T.int64(16), threadIdx_y // T.int64(2) * T.int64(64) + threadIdx_x %
T.int64(16) * T.int64(4) + ax0_ax1_fused_2] + lv12[blockIdx_y * T.int64(128) +
threadIdx_y // T.int64(2) * T.int64(64) + threadIdx_x % T.int64(16) *
T.int64(4) + ax0_ax1_fused_2]
```
##########
src/tir/transforms/inject_ptx_async_copy.cc:
##########
@@ -79,8 +80,10 @@ class PTXAsyncCopyInjector : public StmtMutator {
if (indices_lanes == 1) {
auto src_offset = load->indices[0];
auto dst_offset = store->indices[0];
- Array<PrimExpr> args = {store->buffer->data, tir::Mul(dst_offset,
PrimExpr(index_factor)),
Review Comment:
We can use the helper function `TVM_DLL PrimExpr mul(PrimExpr a, PrimExpr b,
Span span = Span())` to do automatic integer type upcasting.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]