yzh119 opened a new pull request, #14161:
URL: https://github.com/apache/tvm/pull/14161

   # Motivation
   Currently, we have the schedule primitives `cache_read`/`cache_write`, which 
allocate cache buffers and create cache stages that copy data between the buffer 
being accessed and the cache buffer. However, `cache_read`/`cache_write` do not 
support customized indices. For the following block:
   ```python
   @T.prim_func
   def func(a: T.handle, c: T.handle) -> None:
       A = T.match_buffer(a, (128, 128))
       B = T.match_buffer(c, (129, 129))
       for i, j in T.grid(128, 128):
           with T.block("B"):
               vi, vj = T.axis.remap("SS", [i, j])
               B[vi, vj] = A[vi + 1, vj + 1] * 2.0
   ```
   after `cache_read("B", 0, "shared")`, we get:
   ```python
   @T.prim_func
   def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((129, 129), 
"float32")):
       # with T.block("root"):
       A_shared = T.alloc_buffer((128, 128), scope="shared")
       for ax0, ax1 in T.grid(128, 128):
           with T.block("A_shared"):
               v0, v1 = T.axis.remap("SS", [ax0, ax1])
               T.reads(A[v0, v1])
               T.writes(A_shared[v0, v1])
               A_shared[v0, v1] = A[v0, v1]
       for i, j in T.grid(128, 128):
           with T.block("B"):
               vi, vj = T.axis.remap("SS", [i, j])
               T.reads(A_shared[vi + 1, vj + 1])
               T.writes(B[vi, vj])
               B[vi, vj] = A_shared[vi + 1, vj + 1] * T.float32(2)
   ```
   where we access `A_shared` using the same indices (`vi + 1, vj + 1`) as the 
original block. This is not flexible, especially when we want to perform a 
layout transformation while copying data from the original buffer to the cache 
buffer (e.g. in MMA tensorization and in FlashAttention).
   
   This PR proposes a new interface that lets us customize the indices used to 
access the cache buffer, which is expressive enough to describe transposing and 
blocking.
   
   # Proposed API
   
   Below is the proposed interface of `reindex_cache_read` 
(`reindex_cache_write` has a similar interface):
   ```python
   def reindex_cache_read(
       self,
       block: Union[BlockRV, str],
       read_buffer_index: int,
       storage_scope: str,
       index_map: Union[IndexMap, Callable],
       consumer_blocks: Optional[List[Union[BlockRV, str]]] = None,
   ) -> BlockRV:
       ...
   ```
   Here `block`, `read_buffer_index`, `storage_scope`, and `consumer_blocks` 
have the same meaning as in `cache_read`. The additional argument `index_map` 
specifies the indices used to access the cache buffer, in the form of an index 
map from the current block itervars to the target indices. Suppose the block 
has itervars `vi, vj` and the user wants to access the cache buffer with the 
customized indices `[vi // 16, vj // 16, vi % 16, vj % 16]`; then the user 
should set `index_map` to `lambda vi, vj: (vi // 16, vj // 16, vi % 16, 
vj % 16)`.
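   The blocking index map above can be illustrated in plain Python (no TVM 
required; the function names here are just for illustration). Because the map 
is a bijection between the original indices and the blocked indices, every 
element of the original buffer has exactly one slot in the cache buffer:
   ```python
   # Sketch of the blocking index map (vi, vj) -> (vi // 16, vj // 16, vi % 16, vj % 16)
   def blocked_index_map(vi, vj):
       return (vi // 16, vj // 16, vi % 16, vj % 16)

   # Its inverse recovers the original indices from the blocked ones.
   def inverse(bi, bj, ti, tj):
       return (bi * 16 + ti, bj * 16 + tj)

   # Round trip: the map is invertible at every point.
   assert inverse(*blocked_index_map(35, 70)) == (35, 70)

   # Injectivity over a 32x32 domain: all 1024 blocked indices are distinct.
   assert len({blocked_index_map(i, j) for i in range(32) for j in range(32)}) == 32 * 32
   ```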
   
   # Example
   By applying `reindex_cache_read("B", 0, "shared", lambda i, j: (j, i))` to 
`func`, we get:
   ```python
   @T.prim_func
   def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((129, 129), 
"float32")):
       # with T.block("root"):
       A_shared = T.alloc_buffer((128, 128), scope="shared")
       for i, j in T.grid(128, 128):
           with T.block("A_shared"):
               vi, vj = T.axis.remap("SS", [i, j])
               T.reads(A[vi + 1, vj + 1])
               T.writes(A_shared[vj, vi])
               A_shared[vj, vi] = A[vi + 1, vj + 1]
       for i, j in T.grid(128, 128):
           with T.block("B"):
               vi, vj = T.axis.remap("SS", [i, j])
               T.reads(A_shared[vj, vi])
               T.writes(B[vi, vj])
               B[vi, vj] = A_shared[vj, vi] * T.float32(2)
   ```
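   The transformation preserves semantics because the cache stage and the 
consumer use the same index map: whatever the stage writes at `A_shared[j, i]`, 
the consumer reads back at `A_shared[j, i]`. A tiny pure-Python simulation of 
the example above (using an 8x8 domain rather than 128x128, so all accesses 
stay in bounds) makes this concrete:
   ```python
   N = 8
   A = [[float(i * 10 + j) for j in range(N + 1)] for i in range(N + 1)]

   # Cache stage: copy A into A_shared using the transposed index map (i, j) -> (j, i).
   A_shared = [[0.0] * N for _ in range(N)]
   for i in range(N):
       for j in range(N):
           A_shared[j][i] = A[i + 1][j + 1]

   # Consumer block: read the cache buffer with the same transposed indices.
   B = [[A_shared[j][i] * 2.0 for j in range(N)] for i in range(N)]

   # Reference: the original block without the cache stage.
   B_ref = [[A[i + 1][j + 1] * 2.0 for j in range(N)] for i in range(N)]
   assert B == B_ref
   ```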
   
   # Notes
   Unlike `cache_read`/`cache_write`, which allow caching a rectangular region, 
`reindex_cache_read` only allows caching a single point, but this is enough to 
cover most use cases.
   
   The cache stage block follows the original order of loops and block itervars 
in the block. If a block itervar does not appear in the buffer access region, 
it and its corresponding loop variable are omitted. Users can then use the 
`transform_block_layout` primitive to reorder the block itervars and 
surrounding loops of the cache read/write block.
   
   # Relations to Existing Schedule Primitives
   
   - Relation with `reindex`
        - `reindex` is the special case of 
`reindex_cache_read`/`reindex_cache_write` where `index_map` is the identity map.
   - Relation with `transform_layout`
        - `transform_layout` also accepts an index map as input, but it is only 
applicable to input buffers (whereas `reindex_cache_read/write` can apply to 
allocated intermediate buffers produced by other blocks), and it does not have 
a `storage_scope` field.
    - Relation with `cache_read`/`cache_write`
       - `cache_read`/`cache_write` do not support customized indices.

