Lunderberg commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1163620046
Writing out some of my thoughts, to see if there's a way to express the
constraints while only using existing TIR features. The main goals would be as
follows.
1. Allow simplification of expressions based on the values present in the
padding.
2. Allow local simplifications to take advantage of non-local constraints,
without requiring a full end-to-end analysis.
3. Specify the non-local constraints in some deducible manner that doesn't
impose a runtime performance penalty.
Next, working through various options for how the constraints could be
stored. The examples below sketch how each option would apply to an
element-wise operation that starts as follows.
```python
@T.prim_func
def func(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
    for i in T.serial(14):
        B[i] = 2 * A[i]
```
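As a concrete sketch of what the options below are manipulating (plain Python, not TVMScript, with illustrative names): packing the length-14 buffer into a padded 4x4 cache with pad value 0 lets the compute loop run over the full padded extent without a bounds check, since `2 * 0 == 0` leaves the padding intact.

```python
A = list(range(14))

# Pack A into a padded 4x4 cache, writing pad value 0 into the two
# padding slots (flattened indices 14 and 15).
AC = [[A[4 * io + ii] if 4 * io + ii < 14 else 0 for ii in range(4)]
      for io in range(4)]

# Because 2 * 0 == 0, the compute may run over the full padded extent
# with no predicate; the padding of the result is still the pad value.
BC = [[2 * AC[io][ii] for ii in range(4)] for io in range(4)]

# Only the unpacking back to the logical shape needs the predicate.
B = [BC[i // 4][i % 4] for i in range(14)]
assert B == [2 * a for a in A]
```

This is the kind of simplification goal 1 is after: knowing the pad value lets the `4 * io + ii < 14` predicate be dropped from the inner compute.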
1. Apply layout transforms on local caches. Here, the full lifetime of a
buffer is known. All TIR optimizations are done prior to hoisting the cache and
layout transformation into the graph level.
- For read caches, the pad value is whatever gets conditionally written to
the padding while generating the cache. In the example below, `AC` could be
recognized as being padded.
```python
@T.prim_func
def func(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
    AC = T.alloc_buffer([4, 4], "int32")
    for io, ii in T.grid(4, 4):
        if 4 * io + ii < 14:
            AC[io, ii] = A[4 * io + ii]
        else:
            AC[io, ii] = 0
    for i in T.serial(14):
        B[i] = 2 * AC[i // 4, i % 4]
```
- For write caches, the pad value is whatever is in the padding after the
last write to the cache. In the example below, `BC` could be recognized as
being padded.
```python
@T.prim_func
def func(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
    BC = T.alloc_buffer([4, 4], "int32")
    for io, ii in T.grid(4, 4):
        if 4 * io + ii < 14:
            BC[io, ii] = 2 * A[4 * io + ii]
        else:
            BC[io, ii] = 0
    for io, ii in T.grid(4, 4):
        if 4 * io + ii < 14:
            B[4 * io + ii] = BC[io, ii]
```
- Downside, either of the `else` statements could be eliminated as a
no-op, since they don't contribute to the output `B` value. After that
elimination, there wouldn't be any way to reconstruct the pad value.
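To make the information loss concrete (plain Python, not TVMScript; `GARBAGE` is an illustrative stand-in for whatever uninitialized memory happens to hold): once the `else` branch is gone, any consumer that relied on the pad value to drop a predicate reads whatever is in the padding.

```python
A = list(range(14))
GARBAGE = 99  # stand-in for uninitialized memory after the `else` is dropped

# The padding slots are no longer written with the pad value.
AC = [[A[4 * io + ii] if 4 * io + ii < 14 else GARBAGE for ii in range(4)]
      for io in range(4)]

# A consumer that assumed pad value 0 and dropped its predicate (e.g. a
# reduction over the full padded extent) now over-counts by the garbage.
padded_sum = sum(AC[io][ii] for io in range(4) for ii in range(4))
assert padded_sum == sum(A) + 2 * GARBAGE
```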
2. When hoisting an allocation+transformation, write the pad value to the
buffer at the start of the function from which it was hoisted. This way, the
pad value can still be used in local reasoning.
- No change needed in producers, since they would already write the pad
value to the buffer.
- For consumers, would be represented as writing `pad_value` into the
padding at the start of the function.
```python
@T.prim_func
def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
    for io, ii in T.grid(4, 4):
        if 4 * io + ii >= 14:
            AC[io, ii] = 0
    for io, ii in T.grid(4, 4):
        if 4 * io + ii < 14:
            B[4 * io + ii] = 2 * AC[io, ii]
```
- Downside, repeated unnecessary effort at the beginning of each
consumer. Avoiding it with this representation would require knowing that the
producer had written `pad_value` already, which is exactly the information
we're trying to avoid.
3. When hoisting an allocation+transformation, write the pad value to the
buffer at the start of the function from which it was hoisted, and write
`T.undef()` at the end. This way, the pad value can still be used in local
reasoning, and no-op removal can remove the repeated writes when lowering.
- No change needed in producers, since they would already write the pad
value to the buffer.
- For consumers, would be like option 2, but with an additional write of
`T.undef()` at the end of the function. When lowering, the write of
`T.undef()` would allow the first write to be removed as a no-op because it is
overwritten. The `T.undef()` can then be removed as described in the RFC.
```python
@T.prim_func
def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
    for io, ii in T.grid(4, 4):
        if 4 * io + ii >= 14:
            AC[io, ii] = 0
    for io, ii in T.grid(4, 4):
        if 4 * io + ii < 14:
            B[4 * io + ii] = 2 * AC[io, ii]
    for io, ii in T.grid(4, 4):
        if 4 * io + ii >= 14:
            AC[io, ii] = T.undef()
```
- Downside, no way to distinguish between "can assume the pad value is
zero" and "can overwrite the pad value at will". The writing of `T.undef()`
would allow any writes to the padding to be inserted as a no-op.
- Downside, wouldn't actually simplify out in cases where the pad value
is used. The first in a pair of repeated writes to the same location can only
be removed if there are no reads between the writes. After using the pad value
to eliminate `if 4 * io + ii < 14` from the compute, the dummy loop that writes
the padding could no longer be removed.
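The write-read-write obstacle can be modeled in a few lines of plain Python (not TVMScript; the trace representation is purely illustrative): a store-elimination pass may drop a write only if the next access to the same location is another write.

```python
def dead_first_writes(trace):
    """Indices of writes that are overwritten before any intervening read.

    trace: list of ("read" | "write", location) events in program order.
    """
    dead = []
    for i, (op, loc) in enumerate(trace):
        if op != "write":
            continue
        for op2, loc2 in trace[i + 1:]:
            if loc2 != loc:
                continue
            if op2 == "read":
                break          # a read observes this write; it must stay
            dead.append(i)     # next same-location access is a write
            break
    return dead

# Option 3 before the pad value is used: write pad, then write T.undef().
assert dead_first_writes([("write", "pad"), ("write", "pad")]) == [0]

# After the predicate is simplified away, the compute reads the padding,
# so the initial write can no longer be eliminated.
assert dead_first_writes(
    [("write", "pad"), ("read", "pad"), ("write", "pad")]) == []
```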
4. Use `AssertStmt` in a loop to declare known information about the buffers.
- No change needed in producers, since the pad value is already written
out.
- For consumers, would have an initial loop that asserts the pad value is
correct.
```python
@T.prim_func
def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
    for io, ii in T.grid(4, 4):
        if 4 * io + ii >= 14:
            assert AC[io, ii] == 0, "padding"
    for io, ii in T.grid(4, 4):
        if 4 * io + ii < 14:
            B[4 * io + ii] = 2 * AC[io, ii]
```
- Downside, assert statements have target-dependent handling. In
`CodeGenLLVM` and `CodeGenSPIRV`, they are treated as no-ops. In `CodeGenCPU`
and `CodeGenC`, they generate asserts. In `CodeGenCUDA`, they aren't handled
at all and would error out.
Could work around this with a lowering pass, but identifying these
conditions would require having a special string in the message, and packing
structured data into strings makes me wary.
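A toy model of that workaround (plain Python, not TVMScript; statement tuples and the `"padding"` marker string are illustrative) shows both the mechanism and why it is fragile: the pass has no structured handle on the constraint, only a magic string in the assert message.

```python
def lower_padding_asserts(stmts):
    """Split out asserts tagged with the magic "padding" message.

    Returns (lowered statements, extracted constraint expressions).
    """
    lowered, constraints = [], []
    for kind, payload in stmts:
        if kind == "assert" and payload.get("message") == "padding":
            # Keep the condition for the analyzer, drop it from codegen.
            constraints.append(payload["condition"])
            continue
        lowered.append((kind, payload))
    return lowered, constraints

stmts = [
    ("assert", {"condition": "AC[io, ii] == 0", "message": "padding"}),
    ("store", {"target": "B"}),
]
lowered, constraints = lower_padding_asserts(stmts)
assert lowered == [("store", {"target": "B"})]
assert constraints == ["AC[io, ii] == 0"]
```

Any user assert that happens to use the message `"padding"` would be silently swallowed, which is the string-packing hazard noted above.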
5. Use `AssertStmt` with implicitly-defined variables to declare known
information about the buffers.
```python
@T.prim_func
def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
    a = T.var("int32")
    b = T.var("int32")
    assert (
        AC[a, b] == 0
        or (4 * a + b < 14)
        or (a < 0)
        or (a >= 4)
        or (b < 0)
        or (b >= 4)
    ), "padding"
    for io, ii in T.grid(4, 4):
        if 4 * io + ii < 14:
            B[4 * io + ii] = 2 * AC[io, ii]
```
- Can apply to clamped texture memory, since the variables in the
assertion aren't restricted to the buffer's bounds.
- Would need to recognize the specific pattern of a `BufferLoad` being
used to define the variables in the constraint.
- The implicitly-defined variables can be written in current TIR, and the
unbound variables would ensure that the assertion never makes it into
generated code at runtime.
- Downside, implicitly-defined variables are something of a red flag.
6. Store constraints in the function attributes, either as a dictionary or
as a structured object.
```python
@T.prim_func
def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
    T.func_attr(
        {
            "buffer_constraints": [
                {
                    "buffer": AC,
                    "predicate": lambda io, ii: 4 * io + ii < 14,
                    "pad_value": lambda io, ii: 0,
                },
            ],
        }
    )
    for io, ii in T.grid(4, 4):
        if 4 * io + ii < 14:
            B[4 * io + ii] = 2 * AC[io, ii]
```
- Downside, requires transformations that change a buffer to be aware
that other structures will also need to be updated.
- Downside, requires simplifications to either be passed the entire
`PrimFunc`, or to be explicitly passed the `"buffer_constraints"` list.
- Downside, would break expectations of `IRMutatorWithAnalyzer`. The
current entry points, which accept any `Stmt` or `Expr`, would additionally
need to be given the `"buffer_constraints"` information.
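The first downside can be shown in a few lines of plain Python (not TVM code; the `Buffer` class and names are purely illustrative): because the constraint holds an independent reference to the buffer, a pass that replaces the buffer in the body leaves the attribute pointing at a stale object.

```python
class Buffer:
    """Illustrative stand-in for a TIR buffer object."""
    def __init__(self, name):
        self.name = name

AC = Buffer("AC")
func_attrs = {"buffer_constraints": [{"buffer": AC, "pad_value": 0}]}

# A transform swaps in a new buffer and rewrites only the function body...
AC_new = Buffer("AC_transformed")
body_reads = [AC_new]

# ...silently leaving the attribute's reference to the old buffer intact.
stale = func_attrs["buffer_constraints"][0]["buffer"]
assert stale is AC and stale is not AC_new
```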
7. Store constraints in the `Buffer` object, either as a dictionary or as a
structured object.
```python
@T.prim_func
def func(ac: T.handle, B: T.Buffer[14, "int32"]):
    AC = T.match_buffer(
        ac,
        shape=(4, 4),
        dtype="int32",
        constraints=[
            T.BufferConstraints(
                predicate=lambda io, ii: 4 * io + ii < 14,
                pad_value=0,
            )
        ],
    )
    for io, ii in T.grid(4, 4):
        if 4 * io + ii < 14:
            B[4 * io + ii] = 2 * AC[io, ii]
```
- Downside, introduces additional data structure in TIR.