yzh119 commented on PR #14161:
URL: https://github.com/apache/tvm/pull/14161#issuecomment-1450189921
Some additional context (credit to discussion with @andy-yang-1 ):
In some cases, `reindex_cache_read/write` has the same effect as first
applying `cache_read` and then applying `transform_layout`; however, that's not
always the case, especially when the buffer region to cache read is non-trivial.
For example, suppose we want to cache read buffer `A` in block `B` in the
example above, with indices `vi // 4, vj // 4, vi % 4, vj % 4`, we can use
`reindex_cache_read("B", 0, "shared", lambda i, j: (i // 4, j // 4, i % 4, j %
4))` which would transform the program to:
```python
@T.prim_func
def main(A: T.Buffer((129, 129), "float32"), B: T.Buffer((128, 128),
"float32")):
# with T.block("root"):
B_shared = T.alloc_buffer((32, 32, 4, 4), scope="shared")
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
T.reads(A[vi + 1, vj + 1])
T.writes(B_shared[vi // 4, vj // 4, vi % 4, vj % 4])
B_shared[vi // 4, vj // 4, vi % 4, vj % 4] = A[vi + 1, vj + 1] *
T.float32(2)
for i, j in T.grid(128, 128):
with T.block("B_shared"):
vi, vj = T.axis.remap("SS", [i, j])
T.reads(B_shared[vi // 4, vj // 4, vi % 4, vj % 4])
T.writes(B[vi, vj])
B[vi, vj] = B_shared[vi // 4, vj // 4, vi % 4, vj % 4]
```
while applying the `cache_read` + `transform_layout` would trigger the
following error:
```python
>>> sch.cache_read("B", 0, "shared")
>>> sch.transform_layout("B", ("read", 0), lambda i, j: (i // 4, j //4, i %
4, j % 4))
ScheduleError: An error occurred in the schedule primitive
'transform_layout'.
The IR with diagnostic is:
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(a: T.handle, c: T.handle):
A = T.match_buffer(a, (129, 129))
B = T.match_buffer(c, (128, 128))
with T.block("root"):
T.reads()
T.writes()
A_shared = T.alloc_buffer((129, 129), scope="shared")
for ax0 in range(129):
for ax1 in range(129):
with T.block("A_shared"):
v0 = T.axis.spatial(129, ax0)
v1 = T.axis.spatial(129, ax1)
T.reads(A[v0, v1])
T.writes(A_shared[v0, v1])
A_shared[v0, v1] = A[v0, v1]
for i in range(128):
for j in range(128):
with T.block("B"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j)
T.reads(A_shared[vi + 1, vj + 1])
T.writes(B[vi, vj])
B[vi, vj] = A_shared[vi + 1, vj + 1] * T.float32(2)
Error message: The transformation T.index_map(lambda i, j: (i // 4, j // 4,
i % 4, j % 4)) applied on buffer A_shared of shape [129, 129] would result in
shape [33, 33, 4, 4]. However, this would introduce padding wherever axis0 ==
32 and 1 <= axis2 or axis1 == 32 and 1 <= axis3 is true.
```
The reason is that by default `cache_read`/`cache_write` would allocate a
buffer that covers `[vi + 1, vj + 1], vi \in [0, 128), vj \in [0, 128)`, which
is `(129, 129)` and does not meet the requirements of `transform_layout`.
Another frequent use case is sparse access. Suppose we want to cache read
the following function:
```
@T.prim_func
def func(a: T.handle, b: T.handle, f: T.handle) -> None:
A = T.match_buffer(a, (128, 128))
F = T.match_buffer(f, (128,), "int32")
B = T.match_buffer(b, (128, 128))
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[F[vi], vj] * 2.0
```
Here, `reindex_cache_read("B", 0, "shared", lambda i, j: (i // 4, j // 4, i
% 4, j % 4))` gets you a correct transformation:
```python
@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128),
"float32"), F: T.Buffer((128,), "int32")):
# with T.block("root"):
B_shared = T.alloc_buffer((32, 32, 4, 4), scope="shared")
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
T.reads(A[F[vi], vj], F[vi])
T.writes(B_shared[vi // 4, vj // 4, vi % 4, vj % 4])
B_shared[vi // 4, vj // 4, vi % 4, vj % 4] = A[F[vi], vj] *
T.float32(2)
for i, j in T.grid(128, 128):
with T.block("B_shared"):
vi, vj = T.axis.remap("SS", [i, j])
T.reads(B_shared[vi // 4, vj // 4, vi % 4, vj % 4])
T.writes(B[vi, vj])
B[vi, vj] = B_shared[vi // 4, vj // 4, vi % 4, vj % 4]
```
`cache_read` + `transform_layout` would generate a wrong transformation:
```python
>>> sch.cache_read("B", 0, "shared")
>>> sch.transform_layout("B", ("read", 0), lambda i, j: (i // 4, j //4, i %
4, j % 4))
@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128),
"float32"), F: T.Buffer((128,), "int32")):
# with T.block("root"):
A_shared = T.alloc_buffer((32, 32, 4, 4), scope="shared")
for ax0, ax1 in T.grid(128, 128):
with T.block("A_shared"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
T.reads(A[v0, v1])
T.writes(A_shared[v0 // 4, v1 // 4, v0 % 4, v1 % 4])
A_shared[v0 // 4, v1 // 4, v0 % 4, v1 % 4] = A[v0, v1]
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
T.reads(A_shared[F[vi] // 4, vj // 4, F[vi] % 4, vj % 4], F[vi])
T.writes(B[vi, vj])
B[vi, vj] = A_shared[F[vi] // 4, vj // 4, F[vi] % 4, vj % 4] *
T.float32(2)
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]