qsqqsqqsq-intellif opened a new pull request, #17423:
URL: https://github.com/apache/tvm/pull/17423
### Overview
This PR introduces a new TIR schedule primitive **annotate_buffer_access**
that allows explicit annotation of buffer access regions for both reads and
writes.
### Motivation
TVM currently cannot infer the numerical range of floating-point calculations.
As a result, buffer access regions whose indices are computed with
floating-point arithmetic cannot be inferred accurately and fall back to the
full extent of the buffer. This new primitive addresses the limitation by
allowing the access region to be specified manually.
### Usage scenarios
This primitive is particularly useful for operations where the default
buffer region inference may not capture the precise access patterns, such as in
resize operations. It overrides the automatically inferred region for the
specified buffer.
### Example
before:
```python
from tvm.script import tir as T


@T.prim_func
def before(
    x: T.Buffer((1, 1, 32, 32), "float16"),
    resize: T.Buffer((1, 1, 16, 16), "float16"),
):
    for i0, i1, i2, i3 in T.grid(1, 1, 16, 16):
        with T.block("resize"):
            v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            # The float index math cannot be range-analyzed, so the inferred
            # read region covers the whole 32x32 spatial extent of x.
            T.reads(x[v_i0, v_i1, 0:32, 0:32])
            T.writes(resize[v_i0, v_i1, v_i2, v_i3])
            resize[v_i0, v_i1, v_i2, v_i3] = T.Cast(
                "float16",
                T.Cast(
                    "float32",
                    x[
                        v_i0,
                        v_i1,
                        T.max(T.min(T.Cast("int32", T.floor((T.Cast("float32", v_i2) + T.float32(0.5)) * T.float32(2) - T.float32(0.5) + T.float32(1.0000000000000001e-05))), 31), 0),
                        T.max(T.min(T.Cast("int32", T.floor((T.Cast("float32", v_i3) + T.float32(0.5)) * T.float32(2) - T.float32(0.5) + T.float32(1.0000000000000001e-05))), 31), 0),
                    ],
                ),
            )
```
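Apply `annotate_buffer_access` to the block:
```python
from tvm import tir

sch = tir.Schedule(before)
block = sch.get_block("resize")
# Override the region of read buffer 0 (`x`): one entry per dimension, either a
# single index or a (begin, end) tuple describing a half-open range.
sch.annotate_buffer_access(
    block, 0, "read",
    gen_new_ranges=lambda v_i0, v_i1, v_i2, v_i3: [
        v_i0,
        v_i1,
        (v_i2 * 2 - 3, v_i2 * 2 + 3),
        (v_i3 * 2 - 3, v_i3 * 2 + 3),
    ],
)
```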
after:
```python
from tvm.script import tir as T


@T.prim_func
def after(
    x: T.Buffer((1, 1, 32, 32), "float16"),
    resize: T.Buffer((1, 1, 16, 16), "float16"),
):
    for i0, i1, i2, i3 in T.grid(1, 1, 16, 16):
        with T.block("resize"):
            v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            # Explicit region supplied by annotate_buffer_access.
            T.reads(x[v_i0, v_i1, v_i2 * 2 - 3:v_i2 * 2 + 3, v_i3 * 2 - 3:v_i3 * 2 + 3])
            T.writes(resize[v_i0, v_i1, v_i2, v_i3])
            # Marks read buffer 0 as carrying a manually specified region.
            T.block_attr({"explicit_read_region": [0]})
            resize[v_i0, v_i1, v_i2, v_i3] = T.Cast(
                "float16",
                T.Cast(
                    "float32",
                    x[
                        v_i0,
                        v_i1,
                        T.max(T.min(T.Cast("int32", T.floor((T.Cast("float32", v_i2) + T.float32(0.5)) * T.float32(2) - T.float32(0.5) + T.float32(1.0000000000000001e-05))), 31), 0),
                        T.max(T.min(T.Cast("int32", T.floor((T.Cast("float32", v_i3) + T.float32(0.5)) * T.float32(2) - T.float32(0.5) + T.float32(1.0000000000000001e-05))), 31), 0),
                    ],
                ),
            )
```
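The `after` output shows how the list returned by `gen_new_ranges` maps onto the region: a bare expression such as `v_i0` becomes a single-element index, and a `(begin, end)` tuple becomes the half-open slice `begin:end`. The plain-Python sketch below reproduces that mapping on concrete integers so the resulting `(min, extent)` pairs can be inspected. `normalize_ranges` is a hypothetical helper for illustration only, not part of TVM's API; the real primitive works on symbolic TIR expressions.

```python
from typing import Callable, List, Sequence, Tuple, Union

Entry = Union[int, Tuple[int, int]]


def normalize_ranges(entries: Sequence[Entry]) -> List[Tuple[int, int]]:
    """Convert gen_new_ranges-style entries into (min, extent) pairs.

    Hypothetical helper: a bare index means a single element, while a
    (begin, end) tuple means the half-open range [begin, end).
    """
    result = []
    for entry in entries:
        if isinstance(entry, tuple):
            begin, end = entry
            result.append((begin, end - begin))
        else:
            result.append((entry, 1))
    return result


# Same lambda as in the annotate_buffer_access call, evaluated on integers.
gen_new_ranges: Callable[[int, int, int, int], List[Entry]] = (
    lambda v_i0, v_i1, v_i2, v_i3: [
        v_i0,
        v_i1,
        (v_i2 * 2 - 3, v_i2 * 2 + 3),
        (v_i3 * 2 - 3, v_i3 * 2 + 3),
    ]
)

# Output pixel (0, 0, 5, 7) reads rows [7, 13) and columns [11, 17) of x.
print(normalize_ranges(gen_new_ranges(0, 0, 5, 7)))
# [(0, 1), (0, 1), (7, 6), (11, 6)]
```

Each spatial axis of the annotated region therefore has extent 6 per output pixel, compared with the extent 32 that was inferred before.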
The primitive adds an annotation (`T.block_attr({"explicit_read_region":
[0]})`) to the block, recording that an explicit region has been provided for
the read buffer at each listed index (here, index 0). The
CompactBufferAllocation pass uses this annotation to respect the manually
specified region instead of relying on automatic inference.
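CompactBufferAllocation narrows a buffer to the union of the regions accessed inside its allocation scope, which is where a tighter explicit region pays off. The sketch below does that arithmetic by hand in plain Python for one spatial axis. It assumes a hypothetical tiling in which one allocation scope covers `tile_len` consecutive values of `v_i2` (for example, a cached copy of `x` placed under an outer tile loop), and it assumes the union is clipped to the buffer bounds `[0, 32)`. `compacted_extent` is an illustrative name, not a TVM function.

```python
from typing import Tuple


def compacted_extent(tile_start: int, tile_len: int, buf_len: int = 32) -> Tuple[int, int]:
    """Return (min, extent) of the union of the annotated regions
    [2*v - 3, 2*v + 3) for v in [tile_start, tile_start + tile_len),
    clipped to [0, buf_len)."""
    vs = range(tile_start, tile_start + tile_len)
    lo = min(2 * v - 3 for v in vs)
    hi = max(2 * v + 3 for v in vs)
    # Assumption: the union is intersected with the buffer's own shape.
    lo, hi = max(lo, 0), min(hi, buf_len)
    return lo, hi - lo


for start in (0, 4, 12):
    print(start, compacted_extent(start, 4))
```

With four output rows per tile, the per-tile footprint along this axis is at most 12 rows, whereas the inferred `0:32` region from `before` would keep all 32 rows for every tile.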
### Note
Use this primitive with care: an incorrect annotation may lead to incorrect
code generation or runtime errors. Make sure the specified region covers every
read or write that the block actually performs on the given buffer.
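For the resize example, one way to follow this advice is to evaluate the block's real index expression for every iteration and confirm that it lands inside the annotated range. The sketch below is plain Python, not a TVM facility: `source_index` emulates the float32 chain from the `before` function for one spatial axis (rounding each intermediate to float32 via `struct`), and `uncovered` is a hypothetical helper that lists the iterations whose actual read falls outside a candidate `(begin, end)` region.

```python
import math
import struct
from typing import Callable, List, Tuple


def f32(x: float) -> float:
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def source_index(v: int, in_size: int = 32) -> int:
    """Emulate max(min(int32(floor((f32(v) + 0.5) * 2 - 0.5 + 1e-05)), 31), 0)."""
    t = f32(f32(float(v)) + f32(0.5))
    t = f32(t * f32(2.0))
    t = f32(t - f32(0.5))
    t = f32(t + f32(1.0000000000000001e-05))
    return max(min(math.floor(t), in_size - 1), 0)


def uncovered(region: Callable[[int], Tuple[int, int]], extent: int = 16) -> List[int]:
    """Return the iterations v whose actual read index lies outside region(v)."""
    bad = []
    for v in range(extent):
        begin, end = region(v)
        if not (begin <= source_index(v) < end):
            bad.append(v)
    return bad


# The annotated region from the example covers every read ...
print(uncovered(lambda v: (v * 2 - 3, v * 2 + 3)))  # []
# ... while a region starting one row too late misses all of them.
print(len(uncovered(lambda v: (v * 2 + 1, v * 2 + 3))))  # 16
```

The emulation also shows that each output pixel reads exactly one input element per axis, index `2 * v`, so the six-element window used in the example is itself a conservative choice.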
cc @Hzfengsy @junrushao