wrongtest commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1152928725
Thanks for all the great discussions! It is exciting that we will have a
more powerful ability to handle things like padding and imperfect tiles.
Since our team relies on the S-TIR code path, we are extremely interested
in the S-TIR side of the story, and I would appreciate some details on
S-TIR padding. I would like to use a [127, 127, 127] matmul to illustrate my
questions :)
```python
@T.prim_func
def matmul(
    A: T.Buffer[(127, 127), "float32"],
    B: T.Buffer[(127, 127), "float32"],
    C: T.Buffer[(127, 127), "float32"],
):
    for i, j, k in T.grid(127, 127, 127):
        with T.block("compute"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] += A[vi, vk] * B[vk, vj]
```
In the current S-TIR state, we can construct the padded loops and buffers with
existing primitives, using the "split and then fuse" trick:
```python
s = tvm.tir.Schedule(matmul)
blk = s.get_block("compute")
i, j, k = s.get_loops(blk)
s.fuse(*s.split(i, factors=[4, 32]))
s.fuse(*s.split(j, factors=[4, 32]))
s.fuse(*s.split(k, factors=[4, 32]))
s.transform_layout(blk, "A", lambda i, k: ((i // 32) * 32 + i % 32, (k // 32) * 32 + k % 32))
s.transform_layout(blk, "B", lambda k, j: ((k // 32) * 32 + k % 32, (j // 32) * 32 + j % 32))
s.transform_layout(blk, "C", lambda i, j: ((i // 32) * 32 + i % 32, (j // 32) * 32 + j % 32))
```
After simplification, we get:
```python
@T.prim_func
def func(
    A: T.Buffer[(128, 128), "float32"],
    B: T.Buffer[(128, 128), "float32"],
    C: T.Buffer[(128, 128), "float32"],
):
    for i_0_i_1_fused, j_0_j_1_fused, k_0_k_1_fused in T.grid(128, 128, 128):
        with T.block("compute"):
            vi = T.axis.spatial(127, i_0_i_1_fused)
            vj = T.axis.spatial(127, j_0_j_1_fused)
            vk = T.axis.reduce(127, k_0_k_1_fused)
            T.where(i_0_i_1_fused < 127 and j_0_j_1_fused < 127 and k_0_k_1_fused < 127)
            T.reads(A[vi, vk], B[vk, vj])
            T.writes(C[vi, vj])
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
```
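As a quick sanity check (plain numpy, outside TVM), the padded-buffer form above still computes the original 127x127 matmul, because the `T.where` predicate restricts the iteration space back to the valid region:

```python
import numpy as np

# Plain-numpy sanity check (not TVM): the 128x128 padded buffers plus the
# `< 127` predicate must reproduce the original 127x127 matmul.
rng = np.random.default_rng(0)
a = rng.random((127, 127)).astype(np.float32)
b = rng.random((127, 127)).astype(np.float32)

# Padded 128x128 buffers, as produced by transform_layout above.
a_pad = np.zeros((128, 128), np.float32)
b_pad = np.zeros((128, 128), np.float32)
a_pad[:127, :127] = a
b_pad[:127, :127] = b

# The T.where predicate keeps only the iterations with vi, vj, vk < 127.
c = a_pad[:127, :127] @ b_pad[:127, :127]
np.testing.assert_allclose(c, a @ b, rtol=1e-4)
```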
The only thing left is the padding condition:
`T.where(i_0_i_1_fused < 127 and j_0_j_1_fused < 127 and k_0_k_1_fused < 127)`.
I believe this brings us to the point in the current RFC about the trade-off
between over-computation and branching. Below are some of my questions:
1. What happens when we change this to `s.transform_layout(..., pad_value=0)`
(if we want over-computation)?
   - (possible behavior 1) Insert padding-filling code as a producer block of
     `compute`.
     - Since the effect is immediate, maybe we do not need `BufferConstraint`
       annotations afterwards?
   - (possible behavior 2) Annotate the buffers and let lowering passes handle
     the padding.
     - We may require `BufferConstraint` to direct the lowering passes.
   - (possible behavior 3) Pass `BufferConstraint` upwards to the graph level.
     - We would then assume the param buffers match the constraint and not
       write the edge values.
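To make (possible behavior 1) concrete, here is a rough plain-Python sketch (the function name and shapes are mine, not an actual TVM API) of what an inserted padding-filling producer block would do. Every element of the padded buffer is defined before `compute` reads it, which is why no separate `BufferConstraint` seems necessary afterwards:

```python
import numpy as np

def pad_with_producer(a, padded_shape, pad_value=0.0):
    """Mimics a hypothetical "A_pad" producer block: copy the valid
    elements and write pad_value into the out-of-bounds remainder."""
    a_pad = np.empty(padded_shape, a.dtype)
    n, m = a.shape
    for i in range(padded_shape[0]):
        for j in range(padded_shape[1]):
            a_pad[i, j] = a[i, j] if (i < n and j < m) else pad_value
    return a_pad

a = np.arange(1.0, 5.0, dtype=np.float32).reshape(2, 2)
a_pad = pad_with_producer(a, (4, 4))
# a_pad now holds `a` in its top-left 2x2 corner and zeros elsewhere.
```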
2. For (1.2) and (1.3), it seems that encoding the `BufferConstraint` into the
buffer object is not the only choice.
   - For S-TIR (correct me if I am wrong), at least in common cases the
     constraint can be treated as local to the transformed block. What if we
     encode the constraint into the block itself, as part of its memory-access
     properties? We found previously that the block memory annotations
     `T.reads` and `T.writes` (`BufferRegion`) have the limitation that they
     lose conditional-access information. Maybe we could also combine
     `BufferConstraint` with `BufferRegion`?
   - For graph-level annotations, IIUC, the graph level conceptually uses
     "Tensor"-typed values instead of "Buffer"s. Maybe we would still need
     another construct rather than a `Buffer` with a `BufferConstraint` field?
     We could also consider instantiating graph-level transformations
     explicitly; this is our current solution:
     https://discuss.tvm.apache.org/t/introducing-ty-nnp-backend-with-end2end-tensorir-integration/11807/4.
   - Nevertheless, if we finally decide to extend the buffer node structure, I
     hope the `BufferConstraint` can have an explicit lifetime within TIR
     lowering, so that storage-related passes afterwards are not affected,
     especially customized passes developed by vendors.
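Purely to illustrate the "encode the constraint into the block" idea, a hypothetical annotation could sit next to `T.reads`/`T.writes` and carry the conditional-access information that `BufferRegion` currently loses. The syntax below is made up and is not part of TVM; it is only a sketch of the shape such an annotation might take:

```python
with T.block("compute"):
    vi = T.axis.spatial(127, i_0_i_1_fused)
    vj = T.axis.spatial(127, j_0_j_1_fused)
    vk = T.axis.reduce(127, k_0_k_1_fused)
    T.reads(A[vi, vk], B[vk, vj])
    T.writes(C[vi, vj])
    # Hypothetical (non-existing) block-local constraint: reads in the
    # padded region [127, 128) of each axis are known to be pad_value.
    T.block_constraint(A, pad_region=[(127, 128), (127, 128)], value=0.0)
    ...
```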
3. On reduce-axis padding, as mentioned in
https://github.com/apache/tvm-rfcs/pull/77#discussion_r894899301:
   - At the TIR level, since schedule primitives should preserve semantic
     correctness, how do we prove that the padding along the `k` dimension can
     only be zero, especially when we generally do not know the op is a
     "matmul"? I think this is important if we want to use padded
     `transform_layout` in auto-scheduling applications.
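For what it is worth, the "zero only" requirement can be seen numerically (a plain-numpy sketch, not TVM): when we over-compute along `k`, the pad value must be the identity of the reduction, which for the `+=` update is 0; any other value changes the result:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.random((127, 127)).astype(np.float32)
b = rng.random((127, 127)).astype(np.float32)
ref = a @ b

def matmul_with_padded_k(pad_value):
    # Pad the shared k dimension from 127 to 128 with pad_value, then
    # reduce over all 128 iterations (i.e. over-compute along k).
    a_pad = np.full((127, 128), pad_value, np.float32)
    b_pad = np.full((128, 127), pad_value, np.float32)
    a_pad[:, :127] = a
    b_pad[:127, :] = b
    return a_pad @ b_pad

ok = np.allclose(matmul_with_padded_k(0.0), ref, rtol=1e-4)   # zero is the + identity
bad = np.allclose(matmul_with_padded_k(1.0), ref, rtol=1e-4)  # 1*1 leaks into every element
```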
cc @Lunderberg @tqchen @vinx13 @Hzfengsy