yzh119 commented on PR #14161:
URL: https://github.com/apache/tvm/pull/14161#issuecomment-1450189921

   Some additional context (credit to discussion with @andy-yang-1 ):
   
   In some cases, `reindex_cache_read/write` has the same effect as applying `cache_read` followed by `transform_layout`; however, that is not always the case, especially when the buffer region being cached is non-trivial.
   
   For example, suppose we want to cache read buffer `A` in block `B` in the example above, with indices `vi // 4, vj // 4, vi % 4, vj % 4`. We can use `reindex_cache_read("B", 0, "shared", lambda i, j: (i // 4, j // 4, i % 4, j % 4))`, which transforms the program to:
   ```python
   @T.prim_func
   def main(A: T.Buffer((129, 129), "float32"), B: T.Buffer((128, 128), 
"float32")):
       # with T.block("root"):
       B_shared = T.alloc_buffer((32, 32, 4, 4), scope="shared")
       for i, j in T.grid(128, 128):
           with T.block("B"):
               vi, vj = T.axis.remap("SS", [i, j])
               T.reads(A[vi + 1, vj + 1])
               T.writes(B_shared[vi // 4, vj // 4, vi % 4, vj % 4])
               B_shared[vi // 4, vj // 4, vi % 4, vj % 4] = A[vi + 1, vj + 1] * 
T.float32(2)
       for i, j in T.grid(128, 128):
           with T.block("B_shared"):
               vi, vj = T.axis.remap("SS", [i, j])
               T.reads(B_shared[vi // 4, vj // 4, vi % 4, vj % 4])
               T.writes(B[vi, vj])
               B[vi, vj] = B_shared[vi // 4, vj // 4, vi % 4, vj % 4]
   ```
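   To see why the reindexed shape `(32, 32, 4, 4)` is exact here, a plain-Python sanity check (for illustration only, outside of TVM) confirms the index map is a bijection from the `128 × 128` iteration space onto that shape:
   ```python
   # Sanity check (plain Python, not TVM): the map used by reindex_cache_read
   # sends every (i, j) in the 128x128 iteration space to a distinct slot of
   # the (32, 32, 4, 4) buffer, with no slot left over -- hence no padding.
   index_map = lambda i, j: (i // 4, j // 4, i % 4, j % 4)

   slots = {index_map(i, j) for i in range(128) for j in range(128)}

   assert len(slots) == 128 * 128        # injective on the iteration space
   assert len(slots) == 32 * 32 * 4 * 4  # surjective onto the buffer shape
   ```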
   
   while applying the `cache_read` + `transform_layout` would trigger the 
following error:
   ```python
   >>> sch.cache_read("B", 0, "shared")
   >>> sch.transform_layout("B", ("read", 0), lambda i, j: (i // 4, j //4, i % 
4, j % 4))
   ScheduleError: An error occurred in the schedule primitive 
'transform_layout'.
   The IR with diagnostic is:
   # from tvm.script import ir as I
   # from tvm.script import tir as T
   
   @I.ir_module
   class Module:
       @T.prim_func
       def main(a: T.handle, c: T.handle):
           A = T.match_buffer(a, (129, 129))
           B = T.match_buffer(c, (128, 128))
           with T.block("root"):
               T.reads()
               T.writes()
               A_shared = T.alloc_buffer((129, 129), scope="shared")
               for ax0 in range(129):
                   for ax1 in range(129):
                       with T.block("A_shared"):
                           v0 = T.axis.spatial(129, ax0)
                           v1 = T.axis.spatial(129, ax1)
                           T.reads(A[v0, v1])
                           T.writes(A_shared[v0, v1])
                           A_shared[v0, v1] = A[v0, v1]
               for i in range(128):
                   for j in range(128):
                       with T.block("B"):
                           vi = T.axis.spatial(128, i)
                           vj = T.axis.spatial(128, j)
                           T.reads(A_shared[vi + 1, vj + 1])
                           T.writes(B[vi, vj])
                           B[vi, vj] = A_shared[vi + 1, vj + 1] * T.float32(2)
   Error message: The transformation T.index_map(lambda i, j: (i // 4, j // 4, 
i % 4, j % 4)) applied on buffer A_shared of shape [129, 129] would result in 
shape [33, 33, 4, 4].  However, this would introduce padding wherever axis0 == 
32 and 1 <= axis2 or axis1 == 32 and 1 <= axis3 is true.
   ```
   The reason is that, by default, `cache_read`/`cache_write` allocates a buffer covering `[vi + 1, vj + 1]` for `vi, vj \in [0, 128)`, which requires shape `(129, 129)`, and that shape does not satisfy `transform_layout`'s requirement that the index map introduce no padding.
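   The padding the error message reports can be reproduced with a little arithmetic (plain Python, for illustration): mapping a `(129, 129)` buffer through the same index map needs a `(33, 33, 4, 4)` allocation, which holds more elements than the original buffer has.
   ```python
   # Plain-Python arithmetic mirroring the error message above: transforming a
   # (129, 129) buffer with lambda i, j: (i // 4, j // 4, i % 4, j % 4) needs
   # ceil(129 / 4) = 33 blocks along each outer axis.
   rows, cols = 129, 129
   new_shape = (-(-rows // 4), -(-cols // 4), 4, 4)  # ceil division
   assert new_shape == (33, 33, 4, 4)

   # The transformed allocation holds more elements than the original buffer,
   # so transform_layout would have to introduce padding (wherever axis0 == 32
   # and axis2 >= 1, or axis1 == 32 and axis3 >= 1), which it rejects here.
   padded_elements = 33 * 33 * 4 * 4 - 129 * 129
   assert padded_elements == 783
   ```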
   
   Another frequent use case is sparse access. Suppose we want to cache read the following function:
   ```python
   @T.prim_func
   def func(a: T.handle, b: T.handle, f: T.handle) -> None:
       A = T.match_buffer(a, (128, 128))
       F = T.match_buffer(f, (128,), "int32")
       B = T.match_buffer(b, (128, 128))
       for i, j in T.grid(128, 128):
           with T.block("B"):
               vi, vj = T.axis.remap("SS", [i, j])
               B[vi, vj] = A[F[vi], vj] * 2.0
   ```
   Here, `reindex_cache_read("B", 0, "shared", lambda i, j: (i // 4, j // 4, i % 4, j % 4))` produces a correct transformation:
   ```python
   @T.prim_func
   def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), 
"float32"), F: T.Buffer((128,), "int32")):
       # with T.block("root"):
       B_shared = T.alloc_buffer((32, 32, 4, 4), scope="shared")
       for i, j in T.grid(128, 128):
           with T.block("B"):
               vi, vj = T.axis.remap("SS", [i, j])
               T.reads(A[F[vi], vj], F[vi])
               T.writes(B_shared[vi // 4, vj // 4, vi % 4, vj % 4])
               B_shared[vi // 4, vj // 4, vi % 4, vj % 4] = A[F[vi], vj] * 
T.float32(2)
       for i, j in T.grid(128, 128):
           with T.block("B_shared"):
               vi, vj = T.axis.remap("SS", [i, j])
               T.reads(B_shared[vi // 4, vj // 4, vi % 4, vj % 4])
               T.writes(B[vi, vj])
               B[vi, vj] = B_shared[vi // 4, vj // 4, vi % 4, vj % 4]
   ```
   
   `cache_read` + `transform_layout`, on the other hand, generates an incorrect transformation:
   ```python
   >>> sch.cache_read("B", 0, "shared")
   >>> sch.transform_layout("B", ("read", 0), lambda i, j: (i // 4, j //4, i % 
4, j % 4))
   @T.prim_func
   def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), 
"float32"), F: T.Buffer((128,), "int32")):
       # with T.block("root"):
       A_shared = T.alloc_buffer((32, 32, 4, 4), scope="shared")
       for ax0, ax1 in T.grid(128, 128):
           with T.block("A_shared"):
               v0, v1 = T.axis.remap("SS", [ax0, ax1])
               T.reads(A[v0, v1])
               T.writes(A_shared[v0 // 4, v1 // 4, v0 % 4, v1 % 4])
               A_shared[v0 // 4, v1 // 4, v0 % 4, v1 % 4] = A[v0, v1]
       for i, j in T.grid(128, 128):
           with T.block("B"):
               vi, vj = T.axis.remap("SS", [i, j])
               T.reads(A_shared[F[vi] // 4, vj // 4, F[vi] % 4, vj % 4], F[vi])
               T.writes(B[vi, vj])
               B[vi, vj] = A_shared[F[vi] // 4, vj // 4, F[vi] % 4, vj % 4] * 
T.float32(2)
   ```
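   One way to see the difference (a toy plain-Python illustration with a made-up `F`, not TVM code): `reindex_cache_read` computes the shared-buffer slot from the block iteration vars `(vi, vj)`, while the `cache_read` + `transform_layout` version computes it from the gathered index `F[vi]`, so the cached layout no longer tracks the block's iteration space:
   ```python
   # Toy illustration (plain Python, made-up F): compare the shared-buffer slot
   # each version uses for a given (vi, vj).
   F = [(5 * i + 3) % 128 for i in range(128)]  # hypothetical gather indices

   def slot_reindex(vi, vj):
       # reindex_cache_read: the slot is a function of the iteration vars.
       return (vi // 4, vj // 4, vi % 4, vj % 4)

   def slot_cache_read(vi, vj):
       # cache_read + transform_layout: the slot follows the gathered F[vi].
       return (F[vi] // 4, vj // 4, F[vi] % 4, vj % 4)

   # Whenever F[vi] != vi, the two versions address different slots.
   assert slot_reindex(1, 0) != slot_cache_read(1, 0)  # F[1] == 8, not 1
   ```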

