abhikran-quic opened a new issue, #14859:
URL: https://github.com/apache/tvm/issues/14859
When `compute_at` is applied before `transform_layout`, the zero initialization of the padded input buffer in `global.ddr` scope is emitted after the data has been copied from the input buffer, which leads to incorrect output. This happens because `transform_layout` selects `EpiloguePlan` instead of `ReplacementPlan`.
### Expected behavior
The input buffer should be initialized to 0 before it is used.
### Actual behavior
The input buffer is initialized to 0 at the wrong place: the zero fill is emitted after the data copy.
### Environment
OS: Ubuntu 18.04
TVM: 0.13
### Steps to reproduce
Example test case
```
import tvm
from tvm import te
from tvm.script import tir as T, ir as I


@T.prim_func
def matmul(
    A: T.Buffer((128, 136), "float32"),
    B: T.Buffer((136, 140), "float32"),
    C: T.Buffer((128, 140), "float32"),
):
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # with T.block("root"):
    for h, w, r in T.grid(128, 140, 136):
        with T.block("C"):
            v_h, v_w, v_r = T.axis.remap("SSR", [h, w, r])
            T.reads(A[v_h, v_r], B[v_r, v_w])
            T.writes(C[v_h, v_w])
            with T.init():
                C[v_h, v_w] = T.float32(0)
            C[v_h, v_w] = C[v_h, v_w] + A[v_h, v_r] * B[v_r, v_w]


func = matmul
mod = tvm.IRModule.from_expr(func)
sch = tvm.tir.Schedule(mod)


def transpose_fn(h, w):
    # Tile the (128, 136) buffer into 32x32 blocks. 136 is not a multiple
    # of 32, so the transformed shape (4, 5, 32, 32) needs padding.
    return [h // 32, w // 32, h % 32, w % 32]


block = sch.get_block("C")
# Cache input A (read buffer index 0 of block "C") in "global.ddr" scope.
read_block = sch.cache_read(block, 0, "global.ddr")
# Move the cache-read block under loop `w` of block "C".
sch.compute_at(read_block, sch.get_loops(block)[1])
# Change the layout of the cached buffer, filling the padding with 0.0.
sch.transform_layout(block, ("read", 0), transpose_fn, pad_value=0.0)
print(tvm.lower(sch.mod))
```
This produces the following IR:
```
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((128, 136), "float32"), B: T.Buffer((136, 140), "float32"), C: T.Buffer((128, 140), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A_global_ddr = T.allocate([20480], "float32", "global.ddr")
        A_global_ddr_1 = T.Buffer((20480,), data=A_global_ddr, scope="global.ddr")
        for h, w in T.grid(128, 140):
            for ax0 in range(136):
                A_1 = T.Buffer((17408,), data=A.data)
                A_global_ddr_1[h // 32 * 5120 + ax0 // 32 * 1024 + h % 32 * 32 + ax0 % 32] = A_1[h * 136 + ax0]
            for r in range(136):
                cse_var_1: T.int32 = h * 140 + w
                C_1 = T.Buffer((17920,), data=C.data)
                if r == 0:
                    C_1[cse_var_1] = T.float32(0)
                B_1 = T.Buffer((19040,), data=B.data)
                C_1[cse_var_1] = C_1[cse_var_1] + A_global_ddr_1[h // 32 * 5120 + r // 32 * 1024 + h % 32 * 32 + r % 32] * B_1[r * 140 + w]
        for axis0, axis1, axis2, axis3 in T.grid(4, 5, 32, 32):
            if axis1 == 4 and 8 <= axis3:
                A_global_ddr_1[axis0 * 5120 + axis1 * 1024 + axis2 * 32 + axis3] = T.float32(0)
```
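Here, the zero initialization of the padding in `A_global_ddr_1` is misplaced. It is emitted after the `h`/`w` loop nest that copies data into the buffer and reads from it. This yields incorrect output if `A_global_ddr_1` is used further down the pipeline, because any read of the padded region that happens before the trailing zero fill sees uninitialized values.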
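The padding predicate in the trailing loop can be checked by hand. The following is a minimal pure-Python sketch that needs no TVM. It assumes the row-major flattening shown in the printed IR, with strides 5120, 1024, 32, and 1, and re-declares `transpose_fn` from the repro. The sketch shows three things:

- The tiled buffer holds 20480 elements.
- Of those, 17408 receive data from `A`.
- The remaining 3072 are exactly the elements matched by `axis1 == 4 and 8 <= axis3`. These are the elements that must be zero before the buffer is consumed.

```python
# Mirrors `transpose_fn` from the repro: logical (h, w) of A (128 x 136)
# maps to the tiled layout (4, 5, 32, 32).
def transpose_fn(h, w):
    return [h // 32, w // 32, h % 32, w % 32]


def flat_offset(a0, a1, a2, a3):
    # Row-major flattening of (4, 5, 32, 32): strides 5*32*32, 32*32, 32, 1.
    return a0 * 5120 + a1 * 1024 + a2 * 32 + a3


# Flattened offsets that receive real data copied from A.
written = {flat_offset(*transpose_fn(h, w)) for h in range(128) for w in range(136)}

# Flattened offsets selected by the predicate `axis1 == 4 and 8 <= axis3`.
padding = {
    flat_offset(a0, a1, a2, a3)
    for a0 in range(4)
    for a1 in range(5)
    for a2 in range(32)
    for a3 in range(32)
    if a1 == 4 and 8 <= a3
}

total = 4 * 5 * 32 * 32
print(total, len(written), len(padding))  # 20480 17408 3072
```

The two sets are disjoint, and together they cover the whole allocation. The zero fill therefore only touches the padding, but it must run before any consumer reads that padding.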
If the `compute_at` and `transform_layout` calls in the test case are swapped, the output is correct, because `transform_layout` then selects `ReplacementPlan`.
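The ordering problem can be illustrated with a toy model in plain Python. This is not TVM's implementation, and every name below is hypothetical. The model makes these assumptions:

- A short list stands in for the uninitialized `global.ddr` allocation.
- `NaN` stands in for the garbage that allocation holds.
- The consumer reads the whole padded buffer, standing in for a later stage that touches the padding.

The epilogue-style variant zeros the padding in a separate step that runs after the consumer, mirroring the placement in the IR above. The other variant roughly corresponds to folding the pad value into the producer, which is the effect the report attributes to `ReplacementPlan`.

```python
import math

N, PADDED = 6, 8  # toy sizes: logical length vs. padded buffer length


def producer_copy(src, buf):
    # Copy only the logical data; the padded tail is left untouched.
    for i in range(N):
        buf[i] = src[i]


def producer_with_padding(src, buf):
    # One loop over the padded extent that writes the pad value (0.0)
    # wherever the logical index is out of range.
    for i in range(PADDED):
        buf[i] = src[i] if i < N else 0.0


def fill_padding(buf):
    # Separate, epilogue-style zero fill of the padded tail.
    for i in range(N, PADDED):
        buf[i] = 0.0


def consumer(buf):
    # Downstream stage that reads the whole padded buffer.
    return sum(buf)


src = [1.0] * N

# Epilogue-style: the zero fill runs only after the consumer has read the buffer.
buf = [math.nan] * PADDED  # uninitialized allocation
producer_copy(src, buf)
bad = consumer(buf)
fill_padding(buf)

# Padding written by the producer itself, before the consumer runs.
buf = [math.nan] * PADDED
producer_with_padding(src, buf)
good = consumer(buf)

print(bad, good)  # nan 6.0
```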
### Triage
Please refer to the list of label tags
[here](https://github.com/apache/tvm/wiki/Issue-Triage-Labels) to find the
relevant tags and add them below in a bullet format (example below).
* needs-triage