wrongtest opened a new pull request #9527:
URL: https://github.com/apache/tvm/pull/9527


   Hi there~ This PR enhances the `compute_at` and `reverse_compute_at` primitives. Binding a block into loops may create non-trivial iter bounds. Complex iter bounds are neither human-friendly nor compatible with backend passes that target bounds and conditions (e.g., loop partition). So this PR tries to detect some of these complex bounds and expresses them as block predicates instead, which keeps the IR structure simpler.
   
   Below is a working example. We want to create spatial tiles and read each tile's data from a cache, so the schedule operation is to `compute_at` the cache-read block into the tiled loops.
   ```python
   from tvm.script import tir as T

   @T.prim_func
   def tiled_pooling_read_cache(a: T.handle, b: T.handle) -> None:
       X = T.match_buffer(a, [224, 224], dtype="float32")
       Y = T.match_buffer(b, [224, 224], dtype="float32")
       cache = T.alloc_buffer([224, 224], dtype="float32")
       for hh, ww in T.grid(224, 224):
           with T.block("cache"):
               h, w = T.axis.remap("SS", [hh, ww])
               T.reads([X[h, w]])
               T.writes([cache[h, w]])
               cache[h, w] = X[h, w]
       for hh_0, ww_0, hh_1, ww_1, khh, kww in T.grid(28, 28, 8, 8, 3, 3):
           with T.block("compute"):
               h = T.axis.spatial(224, hh_0 * 8 + hh_1)
               w = T.axis.spatial(224, ww_0 * 8 + ww_1)
               kh, kw = T.axis.remap("RR", [khh, kww])
               T.reads([Y[h, w], cache[h + kh - 1, w + kw - 1]])
               T.writes([Y[h, w]])
               with T.init():
                   Y[h, w] = 0.0
               Y[h, w] = T.max(Y[h, w], T.if_then_else(
                   T.likely(1 <= h + kh, dtype="bool") and \
                   T.likely(h + kh < 225, dtype="bool") and \
                   T.likely(1 <= w + kw, dtype="bool") and \
                   T.likely(w + kw < 225, dtype="bool"),
                   cache[h + kh - 1, w + kw - 1], 0.0, dtype="float32"))
   ```
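To see where a 10-wide cache tile comes from, the following plain-Python sketch (not TVM code; the helper name `required_rows` is hypothetical) enumerates which `cache` rows one `compute` tile reads. Each tile covers `h` in `[8*hh_0, 8*hh_0 + 7]` and reads `cache[h + kh - 1, ...]` for `kh` in `[0, 2]`, so before any clipping to the buffer shape the required rows are `[8*hh_0 - 1, 8*hh_0 + 8]`.

```python
def required_rows(hh_0: int, tile: int = 8, window: int = 3) -> range:
    """Rows of `cache` read by one compute tile, before clipping to [0, 224)."""
    rows = set()
    for hh_1 in range(tile):
        h = hh_0 * tile + hh_1
        for kh in range(window):
            rows.add(h + kh - 1)  # same index expression as cache[h + kh - 1, ...]
    lo, hi = min(rows), max(rows)
    assert rows == set(range(lo, hi + 1))  # the required region is contiguous
    return range(lo, hi + 1)


for hh_0 in (0, 5, 27):
    r = required_rows(hh_0)
    print(hh_0, r.start, r.stop - 1, len(r))
```

The unclipped region always has extent 10 but can fall outside the buffer (row -1 for the first tile, row 224 for the last); that out-of-range part is what the two outputs below handle differently.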
   The mainline code produces:
   ```python
   @T.prim_func
   def func(a: T.handle, b: T.handle) -> None:
       X = T.match_buffer(a, [224, 224], dtype="float32")
       Y = T.match_buffer(b, [224, 224], dtype="float32")
       # body
       # with T.block("root")
       cache = T.alloc_buffer([224, 224], dtype="float32")
       for hh_0, ww_0 in T.grid(28, 28):
           for ax0 in T.serial(0, T.min(hh_0 * 8 + 8, 223) + 1 - T.max(hh_0 * 8 - 1, 0)):
               for ax1 in T.serial(0, T.min(ww_0 * 8 + 8, 223) + 1 - T.max(ww_0 * 8 - 1, 0)):
                   with T.block("cache"):
                       h = T.axis.spatial(224, T.max(hh_0 * 8 - 1, 0) + ax0)
                       w = T.axis.spatial(224, T.max(ww_0 * 8 - 1, 0) + ax1)
                       T.reads([X[h, w]])
                       T.writes([cache[h, w]])
                       cache[h, w] = X[h, w]
           for hh_1, ww_1, khh, kww in T.grid(8, 8, 3, 3):
               with T.block("compute"):
                   ...
   ```
   This PR produces:
   ```python
   @T.prim_func
   def tiled_pooling_read_cache_after_compute_at(a: T.handle, b: T.handle) -> None:
       X = T.match_buffer(a, [224, 224], dtype="float32")
       Y = T.match_buffer(b, [224, 224], dtype="float32")
       cache = T.alloc_buffer([224, 224], dtype="float32")
       for hh_0, ww_0 in T.grid(28, 28):
           for ax0, ax1 in T.grid(10, 10):
               with T.block("cache"):
                   h = T.axis.spatial(224, hh_0 * 8 - 1 + ax0)
                   w = T.axis.spatial(224, ww_0 * 8 - 1 + ax1)
                   T.where(1 <= hh_0 * 8 + ax0 and hh_0 * 8 + ax0 < 225 and 1 <= ww_0 * 8 + ax1 and ww_0 * 8 + ax1 < 225)
                   T.reads([X[h, w]])
                   T.writes([cache[h, w]])
                   cache[h, w] = X[h, w]
           for hh_1, ww_1, khh, kww in T.grid(8, 8, 3, 3):
               with T.block("compute"):
                   ...
   ```
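The two outputs should cache exactly the same rows; the PR output simply iterates over a fixed extent of 10 and masks out-of-range iterations with `T.where`. This plain-Python sketch (not TVM code; the function names are hypothetical) transcribes the row-index arithmetic of each version for a given `hh_0` and checks that they agree for every tile. The column dimension (`ww_0`, `ax1`) is symmetric and omitted.

```python
def mainline_rows(hh_0: int) -> list:
    # Eager intersection: min/max-based start and extent, as in the mainline output.
    start = max(hh_0 * 8 - 1, 0)
    extent = min(hh_0 * 8 + 8, 223) + 1 - start
    return [start + ax0 for ax0 in range(extent)]


def predicated_rows(hh_0: int) -> list:
    # Fixed extent 10 plus a block predicate, as in the PR output.
    rows = []
    for ax0 in range(10):
        if 1 <= hh_0 * 8 + ax0 < 225:  # the T.where condition for the row axis
            rows.append(hh_0 * 8 - 1 + ax0)
    return rows


mismatches = [t for t in range(28) if mainline_rows(t) != predicated_rows(t)]
print("mismatching tiles:", mismatches)  # mismatching tiles: []
```

In both versions the first and last tiles load 9 rows and the interior tiles load 10; only the way the bound is expressed differs.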
   
   The modification delays the intersection of two int sets: the one deduced from the required uses, and the one enforced by the buffer shape or the original iter bound. Intersecting them directly can create complex min/max expressions, so instead a `BlockVarDomainInfo` class is added to hold the two int sets separately, named `dom` and `bound`. The implementation then chooses between two options with a heuristic:
   1. use the intersection of `dom` and `bound` as the iter domain if it is simple enough;
   2. otherwise, use `dom` as the iter domain and add a block predicate for `bound`.
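The heuristic can be illustrated with a small Python sketch. This is a hypothetical stand-in, not the PR's C++ implementation: it models int sets as closed integer intervals, makes `dom` a function of the outer loop value, and replaces the symbolic "simple enough" check with a brute-force test over all outer values. The results `"dom"` and `"bound"` are both cases of option 1 (the intersection collapses to one of the two inputs); `"dom+predicate"` is option 2.

```python
from dataclasses import dataclass
from typing import Callable, Iterable, Optional, Tuple

Interval = Tuple[int, int]  # closed integer interval [lo, hi]


def intersect(a: Interval, b: Interval) -> Interval:
    return (max(a[0], b[0]), min(a[1], b[1]))


@dataclass
class BlockVarDomainInfo:
    """Hypothetical sketch: keep `dom` and `bound` apart instead of intersecting eagerly."""
    dom: Callable[[int], Interval]  # deduced from required uses; depends on the outer loop var
    bound: Interval                 # enforced by buffer shape / original iter bound

    def choose(self, outer_values: Iterable[int]) -> Tuple[str, Optional[Interval]]:
        doms = [self.dom(v) for v in outer_values]
        # Brute-force stand-in for proving that the intersection simplifies symbolically.
        if all(intersect(d, self.bound) == d for d in doms):
            return "dom", None  # option 1: intersection == dom, bound is redundant
        if all(intersect(d, self.bound) == self.bound for d in doms):
            return "bound", None  # option 1: intersection == bound
        return "dom+predicate", self.bound  # option 2: iterate dom, guard with bound


halo = BlockVarDomainInfo(dom=lambda t: (8 * t - 1, 8 * t + 8), bound=(0, 223))
no_halo = BlockVarDomainInfo(dom=lambda t: (8 * t, 8 * t + 7), bound=(0, 223))
print(halo.choose(range(28)))     # ('dom+predicate', (0, 223))
print(no_halo.choose(range(28)))  # ('dom', None)
```

For the halo case the sketch keeps the 10-wide `dom` and guards it with the buffer bound, which mirrors the `T.grid(10, 10)` plus `T.where` form in the PR output above. A real implementation would rely on an arithmetic analyzer to prove these equalities symbolically rather than by enumeration.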
   
   The PR also adds minimal support for analyzing floordiv/floormod in the provide-required region mapping.
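As a rough illustration of why floordiv/floormod need care when mapping a required region, here is interval arithmetic for the two operators in plain Python. This is an assumption-labeled sketch, not the PR's analysis, and the helper names are hypothetical. Python's `//` and `%` use floor semantics, matching `floordiv`/`floormod` for a positive divisor. `floordiv` is monotone, so an interval maps endpoint-wise; `floormod` maps endpoint-wise only when the interval stays within one period, and otherwise the tightest single covering interval is the full `[0, c - 1]`.

```python
from typing import Tuple

Interval = Tuple[int, int]  # closed integer interval [lo, hi]


def floordiv_interval(x: Interval, c: int) -> Interval:
    """Image of [lo, hi] under floordiv(., c) for c > 0 (monotone, so endpoint-wise)."""
    assert c > 0
    return (x[0] // c, x[1] // c)


def floormod_interval(x: Interval, c: int) -> Interval:
    """Smallest interval covering floormod(i, c) for i in [lo, hi], c > 0."""
    assert c > 0
    lo, hi = x
    if lo // c == hi // c:  # stays within one period: no wrap-around
        return (lo % c, hi % c)
    return (0, c - 1)  # crosses a multiple of c: values wrap, cover the full period


# A hypothetical consumer reading A[floordiv(i, 4), floormod(i, 4)] for i in [5, 11]:
print(floordiv_interval((5, 11), 4))  # (1, 2)
print(floormod_interval((5, 11), 4))  # (0, 3)
print(floormod_interval((5, 7), 4))   # (1, 3)
```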

