Lunderberg commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1165713753
> Talking about “constraints”, it is also useful to talk about categories of
> them, roughly we can divide them into three categories.
I like this breakdown, and agree. In this categorization, what I've been
calling "constraints" would be "assumptions". Double-checking in `builtin.h`,
it looks like we don't currently have a TIR equivalent of `__builtin_assume`.
For usage of assumptions, I think the key would be to insert an assumption
whenever the information that could otherwise prove it is hoisted out of the
PrimFunc. That would provide non-local information that could be used by the
PrimFunc to allow local simplifications.
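As a hedged sketch of what such an assumption could buy us (plain Python standing in for a hypothetical `T.assume`-style builtin, since no TIR equivalent exists yet): if a graph-level pass hoists out the padding that guarantees divisibility, it can leave an assumption behind, and the PrimFunc body simplifies locally.

```python
# Hypothetical sketch: how a __builtin_assume-style hint could enable a
# local simplification.  The assertion models the hoisted-out assumption;
# none of this is existing TVM API.

def ceildiv_naive(n: int) -> int:
    # Without any assumption, (n + 3) // 4 is the general ceiling division.
    return (n + 3) // 4

def ceildiv_assuming_multiple_of_4(n: int) -> int:
    # If the padding proof was hoisted out of the PrimFunc, an assumption
    # inserted in its place lets the body simplify to a plain division.
    assert n % 4 == 0, "assumption left behind when the proof was hoisted"
    return n // 4

# The simplified form agrees with the general one wherever the
# assumption holds.
for n in (0, 4, 12, 100):
    assert ceildiv_naive(n) == ceildiv_assuming_multiple_of_4(n)
```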
> transformation of PrimFunc do not change the PrimFunc interface: this is
> really important so we can transform a PrimFunc without worrying about how the
> graph interacts with it (as the interface remains the same, we can lift out the
> blocks earlier)
I don't think we can make this strong of a statement, as it would also
forbid fusing operators together or hoisting a stage out of a PrimFunc. In
both cases, the signature of the resulting PrimFunc may be different than it
was before. This shows up in the example, as the interface of `grow` is
different from the transformed `grow_packed`.
As a slightly less general statement, I would say that transformations of a
PrimFunc *in isolation* may not change the PrimFunc's interface. So an
optimization search to improve the performance of a single subgraph may not
change the layout of its own arguments, nor may it change assumptions of what
is present in the padding, as those would change its interface. However, a
graph-level transform would be allowed to fuse subgraphs, to hoist stages out
of a PrimFunc, to alter the layout of a PrimFunc's input, or to alter the
assumptions provided about the inputs. In general, a PrimFunc's interface
could only be changed when calls into the PrimFunc are also modified to remain
compatible.
Is there a better term than "scheduling primitive" to describe layout
transformations that impact input/output buffers? I think the difference is
between context-independent transformations that may be performed on a PrimFunc
without changing its interface, as opposed to context-dependent transformations
that may only be performed as part of a graph-level transformation.
> Each function needs to have its own TIR analysis of how it flows things
> back, for example, in the case of `addone`, we can safely flow PadMapping back,
> changing `addone` to `addone_packed` by analyzing the TIR. If the addone is
> elemwise exp however, we need to insert a select operator (because `exp(0)=1`)
> the message to input becomes `PadMapping(constraint, pad_value=undef)`.
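A small numeric sketch of the elementwise-exp case from the quote (plain Python, illustrative only): if the input's padding holds 0, a naive elementwise `exp` maps it to `exp(0) = 1`, so either a select must be inserted on the padding predicate or the pad value flowed back to the input must be left unconstrained.

```python
import math

data = [1.0, 2.0, 3.0]          # logical extent 3
padded = data + [0.0]           # padded to physical extent 4 with pad_value=0

# Naive elementwise exp: the padding becomes exp(0) = 1, not 0, so the
# claim "padding holds 0" no longer survives the operator.
naive = [math.exp(x) for x in padded]
assert naive[3] == 1.0

# With a select on the padding predicate (i < 3), the pad value is
# restored to 0, at the cost of the extra select in the compute.
selected = [math.exp(padded[i]) if i < 3 else 0.0 for i in range(4)]
assert selected[3] == 0.0
```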
Would this handle cases where there are multiple different options for how
an operator could be implemented? Otherwise, I'm not sure how this would
handle cases where multiple different sets of layouts/constraints could be
inferred from different TIR-level schedules of the same operator. As examples,
the drop-down has 6 different implementations of `addone`, each of which would
allow different hoistable pad/crop operations.
<details>
<summary>Click to expand</summary>
<br>
```python
# Implementation 1, no preproc/postproc are present.
#
# No hoistable layout transformations. Could be fused with a layout
# transformation, but doesn't otherwise provide any constraints.
@T.prim_func
def addone(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
    for i in T.serial(14):
        with T.block("compute"):
            B[i] = A[i] + 1


# Implementation 2, pad input/output, but never access the padding of
# either input or output.
#
# In back-propagation of constraints, the T.undef() that is cropped
# from BC could be narrowed to a known value provided from the
# successor. However, AC's padding is never written to, so could
# propagate T.undef() back to preceding function.
@T.prim_func
def addone(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
    for io, ii in T.grid(4, 4):
        with T.block():
            T.block_attr("preproc", "pad")
            if 4 * io + ii < 14:
                AC[io, ii] = A[4 * io + ii]

    for i in T.serial(14):
        with T.block("compute"):
            BC[i // 4, i % 4] = AC[i // 4, i % 4] + 1

    for i in T.serial(14):
        with T.block():
            T.block_attr("postproc", ["crop", T.undef()])
            B[i] = BC[i // 4, i % 4]


# Implementation 3, pad input with known value, but never access
# padding of output.
#
# In back-propagation of constraints, the T.undef() that is cropped
# from BC could be narrowed to a known value provided from the
# successor. AC's padding is written to, so this would propagate
# `PadMapping(predicate, pad_value=0)` to the previous operator.
@T.prim_func
def addone(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
    for io, ii in T.grid(4, 4):
        with T.block():
            T.block_attr("preproc", "pad")
            AC[io, ii] = T.if_then_else(4 * io + ii < 14, A[4 * io + ii], 0)

    for i in T.serial(14):
        with T.block("compute"):
            BC[i // 4, i % 4] = AC[i // 4, i % 4] + 1

    for i in T.serial(14):
        with T.block():
            T.block_attr("postproc", ["crop", T.undef()])
            B[i] = BC[i // 4, i % 4]


# Implementation 4, pad input with arbitrary value, provide no
# guarantees in output.
#
# In back-propagation of constraints, the T.undef() that is cropped
# from BC could be narrowed to a known value provided from the
# successor. AC's padding is written to, so this would propagate
# `PadMapping(predicate, pad_value=BC_pad_value - 1)` to the
# previous operator.
@T.prim_func
def addone(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
    for io, ii in T.grid(4, 4):
        with T.block():
            T.block_attr("preproc", "pad")
            AC[io, ii] = T.if_then_else(4 * io + ii < 14, A[4 * io + ii], T.undef())

    for io, ii in T.grid(4, 4):
        with T.block("compute"):
            BC[io, ii] = AC[io, ii] + 1

    for i in T.serial(14):
        with T.block():
            T.block_attr("postproc", ["crop", T.undef()])
            B[i] = BC[i // 4, i % 4]


# Implementation 5, pad input with known value, analysis of TIR
# successfully propagates pad value through to provide assumption when
# cropping.
#
# In back-propagation of constraints, the output assumption is fixed.
# Unless the operator following addone has included the constraint 1
# as the required value in its padding, the crop/pad pair wouldn't be
# able to be removed. AC's padding is written to, and would propagate
# `PadMapping(predicate, pad_value=0)` to the previous operator.
@T.prim_func
def addone(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
    for io, ii in T.grid(4, 4):
        with T.block():
            T.block_attr("preproc", "pad")
            AC[io, ii] = T.if_then_else(4 * io + ii < 14, A[4 * io + ii], 0)

    for io, ii in T.grid(4, 4):
        with T.block("compute"):
            BC[io, ii] = AC[io, ii] + 1

    for i in T.serial(14):
        with T.block():
            T.block_attr("postproc", ["crop", 1])
            B[i] = BC[i // 4, i % 4]


# Implementation 6, pad input with known value, analysis of TIR can't
# successfully propagate pad value through to the output.
#
# In back-propagation of constraints, the output assumption is fixed.
# Since we don't provide an assumption of what will be returned, the
# graph-level pair of `crop(T.undef())` followed by `pad(x)` could
# only be canceled out if `x` is `T.undef()`. AC's padding is written
# to, and would propagate `PadMapping(predicate, pad_value=0)` to
# the previous operator.
@T.prim_func
def addone(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
    for io, ii in T.grid(4, 4):
        with T.block():
            T.block_attr("preproc", "pad")
            AC[io, ii] = T.if_then_else(4 * io + ii < 14, A[4 * io + ii], 0)

    for io, ii in T.grid(4, 4):
        with T.block("compute"):
            BC[io, ii] = AC[io, ii] + 1

    for i in T.serial(14):
        with T.block():
            T.block_attr("postproc", ["crop", T.undef()])
            B[i] = BC[i // 4, i % 4]
```
</details>
I think the main change is that the temporary stages with annotation will
need to allow multiple possibilities, rather than a single definitive layout.
These options could then be searched at the graph level to decide on the
appropriate layout. After that is decided, the temporary stage could be
selected and the transformations hoisted.
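A toy sketch of what that graph-level search might look like (plain Python; the candidate names, fields, and costs are all made up for illustration): each PrimFunc exposes several candidate pad/crop annotations, and the search keeps only those compatible with what the producer provides, then picks the cheapest.

```python
# Candidate implementations of an operator, each with the pad value it
# requires on its input (None = no padding needed), the pad value it
# guarantees on its output, and an illustrative cost.
candidates = {
    "impl1": {"pad_in": None, "pad_out": None, "cost": 3},
    "impl3": {"pad_in": 0, "pad_out": None, "cost": 2},
    "impl5": {"pad_in": 0, "pad_out": 1, "cost": 1},
}

def pick(candidates, provided_in_pad):
    # Keep only implementations whose input requirement is compatible
    # with what the producer already guarantees, then take the cheapest.
    ok = {name: c for name, c in candidates.items()
          if c["pad_in"] in (None, provided_in_pad)}
    return min(ok, key=lambda name: ok[name]["cost"])

# If the producer guarantees pad_value=0 on its output, the cheapest
# compatible candidate wins.
assert pick(candidates, 0) == "impl5"
```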
> But extra amount of care is needed when we attempt to move
> `crop_with_pad_assume`, as it really depends on the value property of its input.
Completely agreed. I think this is true at both the TIR and graph levels,
that allowing assumptions means ensuring that the assumption isn't changed
after it is used for simplifications. The advantage of writing the assumptions
at the graph level is that specific pairs of functions (such as
`crop_with_pad_assume(pad_value)` followed by `pad_with_value(pad_value)`) can
be identified as no-ops, without needing a full proof of it.
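A minimal model of that pairwise cancellation (plain Python; the helper names mirror the graph-level operators discussed above but are illustrative, not existing TVM API):

```python
def crop_with_pad_assume(buf, logical, pad_value):
    # Crop to the logical extent, recording the assumption that the
    # cropped-away region held `pad_value`.
    return buf[:logical], pad_value

def pad_with_value(buf, physical, pad_value):
    # Re-pad to the physical extent with a known value.
    return buf + [pad_value] * (physical - len(buf))

padded = [1, 2, 3, 0]                      # physical extent 4, logical 3
cropped, assumed = crop_with_pad_assume(padded, 3, 0)
repadded = pad_with_value(cropped, 4, 0)

# Because the assumed pad value matches the re-pad value, the pair can
# be recognized as a no-op by matching the values alone, without a full
# proof about the producer.
assert assumed == 0 and repadded == padded
```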
I think the main rules that would need to be followed when handling
assumptions would be the following three.
1. An assumption may be inserted wherever it can be statically proven, or
asserted by a user about user-supplied input.
2. An assumption may be removed only if it can be statically proven.
Assertions from a user about user-supplied input may never be removed, as they
may have already been used to perform irreversible simplifications.
3. Static provers must reset all assumptions about a variable when
`T.undef()` is assigned to it, even though these assignments are removed during
lowering.
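Rule 3 can be modeled with a tiny tracker (plain Python sketch; `UNDEF` stands in for `T.undef()`, and the class is purely illustrative): any assumption known about a buffer must be dropped the moment an undef value is written to it, even though the write itself disappears during lowering.

```python
UNDEF = object()   # stand-in for T.undef()

class AssumptionTracker:
    """Toy model of a static prover's assumption set."""

    def __init__(self):
        self.known = {}            # var -> assumed property

    def assume(self, var, prop):
        self.known[var] = prop

    def assign(self, var, value):
        if value is UNDEF:
            # Overwriting with undef invalidates everything previously
            # proven or assumed about this variable (rule 3).
            self.known.pop(var, None)

tracker = AssumptionTracker()
tracker.assume("A_padding", "== 0")
assert "A_padding" in tracker.known

tracker.assign("A_padding", UNDEF)
assert "A_padding" not in tracker.known
```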
The restriction against changing a PrimFunc's interface falls out directly
from rule #1. Since an assumption that restricts the values of an input cannot
be statically proven, these assumptions may not be modified.