Lunderberg commented on code in PR #15986:
URL: https://github.com/apache/tvm/pull/15986#discussion_r1373212442
##########
tests/python/unittest/test_tir_transform_lower_opaque_block.py:
##########
@@ -250,6 +250,34 @@ def transformed_strided_buffer_func(
C[i0 * 4 + i1, j] = B[i1, j] * T.float32(2)
+@T.prim_func
+def compacted_symbolic_strided_buffer_func(a: T.handle) -> None:
+ n = T.int64()
+ A = T.match_buffer(a, (1, n, 10240), "float32")
+ for i, j, k in T.grid(((n + 63) // 64 * 4 + 7) // 8, 2, 160):
+ with T.block(""):
+ T.reads(A[0, i * 128 + j * 32:i * 128 + j * 32 + 96, k * 64:k * 64 + 64])
+ A_pad_shared_dyn = T.alloc_buffer((1, T.min((n + 63) // 64 * 64, 96), 64), "float32", strides=(72 * T.min((n + 63) // 64 * 64, 96), 72, 1), scope="shared.dyn")
+ for ax0, ax1 in T.grid(96, 64):
+ with T.block("A_pad_shared.dyn"):
+ T.where(i * 128 + j * 32 + ax0 < (n + 63) // 64 * 64)
+ T.reads(A[0, i * 128 + j * 32 + ax0, k * 64 + ax1])
+ T.writes(A_pad_shared_dyn[0, ax0, ax1])
+ A_pad_shared_dyn[0, ax0, ax1] = T.if_then_else(i * 128 + j * 32 + ax0 < n, A[0, i * 128 + j * 32 + ax0, k * 64 + ax1], T.float16(0))
+
+
+@T.prim_func
+def transformed_symbolic_strided_buffer_func(a: T.handle):
+ n = T.int64()
+ A = T.match_buffer(a, (1, n, 10240))
+ for i, j, k in T.grid(((n + T.int64(63)) // T.int64(64) * T.int64(4) + T.int64(7)) // T.int64(8), 2, 160):
Review Comment:
Does the test case depend on `T.int64` datatypes? If not, this would be
much more readable by using `T.int32`. Because it is the default integer type
in TVMScript, it wouldn't require the explicit type conversions (e.g. `(n +
63)` instead of `(n + T.int64(63))`).
##########
tests/python/unittest/test_tir_transform_lower_opaque_block.py:
##########
@@ -250,6 +250,34 @@ def transformed_strided_buffer_func(
C[i0 * 4 + i1, j] = B[i1, j] * T.float32(2)
+@T.prim_func
+def compacted_symbolic_strided_buffer_func(a: T.handle) -> None:
+ n = T.int64()
+ A = T.match_buffer(a, (1, n, 10240), "float32")
+ for i, j, k in T.grid(((n + 63) // 64 * 4 + 7) // 8, 2, 160):
Review Comment:
Unrelated, the presence of this expression is kind of odd to me. Assuming
this example came from a TIR printout, I would have expected `((n + 63) // 64 *
4 + 7) // 8` to be simplified to the equivalent `(n + 127) // 128`. The fact
that it didn't simplify may indicate that I should take a look at the
`CanonicalSimplifier`.
##########
tests/python/unittest/test_tir_transform_lower_opaque_block.py:
##########
@@ -250,6 +250,34 @@ def transformed_strided_buffer_func(
C[i0 * 4 + i1, j] = B[i1, j] * T.float32(2)
+@T.prim_func
+def compacted_symbolic_strided_buffer_func(a: T.handle) -> None:
+ n = T.int64()
+ A = T.match_buffer(a, (1, n, 10240), "float32")
+ for i, j, k in T.grid(((n + 63) // 64 * 4 + 7) // 8, 2, 160):
+ with T.block(""):
+ T.reads(A[0, i * 128 + j * 32:i * 128 + j * 32 + 96, k * 64:k * 64 + 64])
+ A_pad_shared_dyn = T.alloc_buffer((1, T.min((n + 63) // 64 * 64, 96), 64), "float32", strides=(72 * T.min((n + 63) // 64 * 64, 96), 72, 1), scope="shared.dyn")
+ for ax0, ax1 in T.grid(96, 64):
+ with T.block("A_pad_shared.dyn"):
+ T.where(i * 128 + j * 32 + ax0 < (n + 63) // 64 * 64)
+ T.reads(A[0, i * 128 + j * 32 + ax0, k * 64 + ax1])
+ T.writes(A_pad_shared_dyn[0, ax0, ax1])
+ A_pad_shared_dyn[0, ax0, ax1] = T.if_then_else(i * 128 + j * 32 + ax0 < n, A[0, i * 128 + j * 32 + ax0, k * 64 + ax1], T.float16(0))
+
+
+@T.prim_func
+def transformed_symbolic_strided_buffer_func(a: T.handle):
+ n = T.int64()
+ A = T.match_buffer(a, (1, n, 10240))
+ for i, j, k in T.grid(((n + T.int64(63)) // T.int64(64) * T.int64(4) + T.int64(7)) // T.int64(8), 2, 160):
+ A_pad_shared_dyn = T.allocate([1, T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)), 72], "float32", "shared.dyn")
+ A_pad_shared_dyn_1 = T.decl_buffer((1, T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)), 64), data=A_pad_shared_dyn, strides=(T.int64(72) * T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)), 72, 1), scope="shared.dyn")
Review Comment:
The expression `T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96))` occurs
frequently, and makes it difficult to read. Can this be pulled out into
`padded_size = T.meta_var(T.min((n + T.int64(63)) // T.int64(64) *
T.int64(64), T.int64(96)))`? The generated TIR will still contain the full expression, but
the test case can be easier to read.
##########
tests/python/unittest/test_tir_transform_lower_opaque_block.py:
##########
@@ -250,6 +250,34 @@ def transformed_strided_buffer_func(
C[i0 * 4 + i1, j] = B[i1, j] * T.float32(2)
+@T.prim_func
+def compacted_symbolic_strided_buffer_func(a: T.handle) -> None:
+ n = T.int64()
+ A = T.match_buffer(a, (1, n, 10240), "float32")
+ for i, j, k in T.grid(((n + 63) // 64 * 4 + 7) // 8, 2, 160):
+ with T.block(""):
+ T.reads(A[0, i * 128 + j * 32:i * 128 + j * 32 + 96, k * 64:k * 64 + 64])
+ A_pad_shared_dyn = T.alloc_buffer((1, T.min((n + 63) // 64 * 64, 96), 64), "float32", strides=(72 * T.min((n + 63) // 64 * 64, 96), 72, 1), scope="shared.dyn")
+ for ax0, ax1 in T.grid(96, 64):
+ with T.block("A_pad_shared.dyn"):
+ T.where(i * 128 + j * 32 + ax0 < (n + 63) // 64 * 64)
Review Comment:
This looks like the same `T.reads` and `T.writes` annotations as would be
automatically inferred from the block's body. Unless the test depends on a
specific override to use non-default read/write annotations, it should be
removed for readability.
##########
src/tir/transforms/inject_ptx_async_copy.cc:
##########
@@ -79,7 +80,7 @@ class PTXAsyncCopyInjector : public StmtMutator {
if (indices_lanes == 1) {
auto src_offset = load->indices[0];
auto dst_offset = store->indices[0];
- Array<PrimExpr> args = {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
+ Array<PrimExpr> args = {store->buffer->data, mul(dst_offset, PrimExpr(index_factor)),
Review Comment:
Since your description mentions this as a separate bug, can it either be
split out into a separate PR, or (since it is a relatively small change), have
a test case added for it?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]