yzh119 opened a new pull request, #14161:
URL: https://github.com/apache/tvm/pull/14161
# Motivation
Currently, we have the schedule primitives `cache_read`/`cache_write`, which
allocate cache buffers and create cache stages that copy data from the buffer
being accessed to the cache buffer. However, `cache_read`/`cache_write` do not
support customized indices. For the following block:
```python
@T.prim_func
def func(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(c, (129, 129))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi + 1, vj + 1] * 2.0
```
after `cache_read("B", 0, "shared")`, we get:
```python
@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((129, 129), "float32")):
    # with T.block("root"):
    A_shared = T.alloc_buffer((128, 128), scope="shared")
    for ax0, ax1 in T.grid(128, 128):
        with T.block("A_shared"):
            v0, v1 = T.axis.remap("SS", [ax0, ax1])
            T.reads(A[v0, v1])
            T.writes(A_shared[v0, v1])
            A_shared[v0, v1] = A[v0, v1]
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A_shared[vi + 1, vj + 1])
            T.writes(B[vi, vj])
            B[vi, vj] = A_shared[vi + 1, vj + 1] * T.float32(2)
```
where we access `A_shared` using the same indices (`vi + 1, vj + 1`) as the
original block. This is not flexible, especially when we want to apply a layout
transformation while copying data from the original buffer to the cache buffer
(for example, in MMA tensorization and in FlashAttention).
This PR proposes a new interface that enables us to customize the indices used
to access the cache buffer, which is expressive enough to describe transposing
and blocking.
# Proposed API
Below is the proposed interface of `reindex_cache_read`
(`reindex_cache_write` has a similar interface):
```python
def reindex_cache_read(
    self,
    block: Union[BlockRV, str],
    read_buffer_index: int,
    storage_scope: str,
    index_map: Union[IndexMap, Callable],
    consumer_blocks: Optional[List[Union[BlockRV, str]]] = None,
) -> BlockRV:
    ...
```
Here `block`, `read_buffer_index`, `storage_scope` and `consumer_blocks` have
the same meaning as in `cache_read`. The additional argument `index_map`
specifies which indices to use to access the cache buffer, in the form of an
index map from the current block itervars to the target indices. Suppose the
block has itervars `vi, vj` and the user wants to access the cache buffer with
the customized indices `[vi // 16, vj // 16, vi % 16, vj % 16]`. The user
should then set `index_map` to
`lambda vi, vj: (vi // 16, vj // 16, vi % 16, vj % 16)`.
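To make the `index_map` semantics concrete, here is a minimal sketch in plain Python (no TVM required) that evaluates the blocking map above over a 128x128 iteration domain. The helper `cache_shape` is hypothetical and written only for illustration: it derives a cache buffer extent from the largest value of each target index, which is an assumption about how the shape could be obtained, not a description of how the primitive is implemented.

```python
from itertools import product


def index_map(vi, vj):
    # The map from the text: block itervars -> cache buffer indices.
    return (vi // 16, vj // 16, vi % 16, vj % 16)


def cache_shape(index_map, extents):
    """Hypothetical helper: smallest shape that holds every mapped index."""
    points = product(*(range(e) for e in extents))
    mapped = [index_map(*p) for p in points]
    ndim = len(mapped[0])
    return tuple(max(idx[d] for idx in mapped) + 1 for d in range(ndim))


shape = cache_shape(index_map, (128, 128))

# Count distinct target indices: if this equals the number of (vi, vj)
# points, no two block instances write to the same cache element.
distinct = len({index_map(vi, vj) for vi, vj in product(range(128), range(128))})
```

With this map, every `(vi, vj)` lands on its own cache element, so the blocked cache buffer holds the same data as a 128x128 buffer in a tiled layout.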
# Example
By applying `reindex_cache_read("B", 0, "shared", lambda i, j: (j, i))` to
`func`, we get:
```python
@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((129, 129), "float32")):
    # with T.block("root"):
    A_shared = T.alloc_buffer((128, 128), scope="shared")
    for i, j in T.grid(128, 128):
        with T.block("A_shared"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi + 1, vj + 1])
            T.writes(A_shared[vj, vi])
            A_shared[vj, vi] = A[vi + 1, vj + 1]
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A_shared[vj, vi])
            T.writes(B[vi, vj])
            B[vi, vj] = A_shared[vj, vi] * T.float32(2)
```
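The effect of the transposed cache stage can be checked with a small emulation in plain Python. This is only a sketch of the data movement, not TVM code. It assumes a small extent `N = 4` instead of 128, and it assumes `A` has one extra row and column so that the `vi + 1, vj + 1` reads stay in bounds. Both choices are made here for illustration.

```python
N = 4  # small extent for illustration; the example in the text uses 128

# Assumption: A is (N + 1) x (N + 1) so that A[vi + 1][vj + 1] is in bounds.
A = [[float(r * (N + 1) + c) for c in range(N + 1)] for r in range(N + 1)]


def index_map(vi, vj):
    # Same map as in the example: store the cache in transposed layout.
    return (vj, vi)


# Cache stage: copy one point per (vi, vj), written at the mapped indices.
A_shared = [[0.0] * N for _ in range(N)]
for vi in range(N):
    for vj in range(N):
        ci, cj = index_map(vi, vj)
        A_shared[ci][cj] = A[vi + 1][vj + 1]

# Consumer block: read the cache through the same mapped indices.
B = [[0.0] * N for _ in range(N)]
for vi in range(N):
    for vj in range(N):
        ci, cj = index_map(vi, vj)
        B[vi][vj] = A_shared[ci][cj] * 2.0

# Reference result computed directly from A, without the cache.
expected = [[A[vi + 1][vj + 1] * 2.0 for vj in range(N)] for vi in range(N)]
```

Because the cache stage and the consumer apply the same map, `B` matches the uncached result even though `A_shared` is laid out differently from `A`.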
# Notes
Unlike `cache_read`/`cache_write`, which can cache a rectangular region,
`reindex_cache_read`/`reindex_cache_write` only support caching a single point,
but this is enough to cover most use cases.
The cache stage block follows the original order of loops and block itervars
in the block. If a block itervar does not appear in the buffer access region,
it and its corresponding loop variables will be omitted. Users can then use the
`transform_block_layout` primitive to reorder the block itervars and
surrounding loops of the cache read/write block.
# Relations to Existing Schedule Primitives
- Relation with `reindex`
  - `reindex` only supports the special case of
    `reindex_cache_read`/`reindex_cache_write` where `index_map` is the
    identity map.
- Relation with `transform_layout`
  - `transform_layout` also accepts an index map as input, but it is only
    applicable to input buffers (`reindex_cache_read`/`reindex_cache_write`
    can also be applied to allocated intermediate buffers produced by other
    blocks), and it does not have a `storage_scope` field.
- Relation with `cache_read`/`cache_write`
  - `cache_read`/`cache_write` do not support customized indices.