abhikran-quic opened a new issue, #14859:
URL: https://github.com/apache/tvm/issues/14859

   When `compute_at` is used before `transform_layout`, the zero initialization 
of the padded region of the input `global.ddr` cache buffer is emitted after the 
loop nest that copies data from the input buffer and consumes it, which leads to 
incorrect output. This happens because `transform_layout` selects 
`EpiloguePlan` instead of `ReplacementPlan`.
   
   ### Expected behavior
   
   The padded region of the input cache buffer should be initialized to 0 before the buffer is used.
   
   ### Actual behavior
   
   The zero initialization is emitted in the wrong place: after the buffer has already been filled and read.
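The difference between the expected and actual ordering can be illustrated with a toy model. The sketch below is plain Python, not TVM code: the `run` helper, the step names, and the 6-element buffer are hypothetical stand-ins for the cache-read copy, the padding loop, and a consumer that reads the padded region.

```python
# Toy model of the ordering problem (not TVM code). The cache buffer is a
# Python list; None marks memory that was never written. The last two slots
# play the role of the padding that `transform_layout` adds.
UNINIT = None


def run(steps):
    """Run the named steps in order and return what the consumer read."""
    data = [1.0, 2.0, 3.0, 4.0]   # stands in for the input buffer A
    buf = [UNINIT] * 6            # 4 data slots + 2 padding slots
    seen = []
    for step in steps:
        if step == "copy":        # producer: copy A into the cache buffer
            buf[:4] = data
        elif step == "pad":       # fill the padding with pad_value=0.0
            buf[4:] = [0.0, 0.0]
        elif step == "consume":   # consumer reads the whole padded buffer
            seen = list(buf)
    return seen


# Expected: padding is initialized before the buffer is used.
expected = run(["copy", "pad", "consume"])
# Actual order in the lowered IR: the pad loop runs after copy and consume.
actual = run(["copy", "consume", "pad"])
print(expected)  # [1.0, 2.0, 3.0, 4.0, 0.0, 0.0]
print(actual)    # [1.0, 2.0, 3.0, 4.0, None, None]
```

In the second ordering the consumer observes uninitialized padding even though the buffer ends up fully initialized afterwards.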
   
   ### Environment
   
   OS: Ubuntu 18.04
   TVM: 0.13
   
   ### Steps to reproduce
   
   Example test case
   
   ```
   import tvm
   from tvm import te
   from tvm.script import tir as T, ir as I
   
   @T.prim_func
   def matmul(A: T.Buffer((128, 136), "float32"), B: T.Buffer((136, 140), "float32"), C: T.Buffer((128, 140), "float32")):
       T.func_attr({"global_symbol": "main", "tir.noalias": True})
       # with T.block("root"):
       for h, w, r in T.grid(128, 140, 136):
           with T.block("C"):
               v_h, v_w, v_r = T.axis.remap("SSR", [h, w, r])
               T.reads(A[v_h, v_r], B[v_r, v_w])
               T.writes(C[v_h, v_w])
               with T.init():
                   C[v_h, v_w] = T.float32(0)
               C[v_h, v_w] = C[v_h, v_w] + A[v_h, v_r] * B[v_r, v_w]
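   func = matmul
   mod = tvm.IRModule.from_expr(func)
   sch = tvm.tir.Schedule(mod)

   # Index map: tile the (128, 136) buffer into 32x32 blocks -> (4, 5, 32, 32)
   def transpose_fn(h, w):
       return [h//32, w//32, h%32, w%32]

   block = sch.get_block("C")
   # Stage A (read buffer 0 of "C") into a "global.ddr" cache buffer
   read_block = sch.cache_read(block, 0, "global.ddr")
   # Move the copy under the `w` loop of the matmul (this ordering triggers the bug)
   sch.compute_at(read_block, sch.get_loops(block)[1])
   # Retile the cache buffer read by "C"; 136 is not a multiple of 32, so padding is needed
   sch.transform_layout(block, ("read", 0), transpose_fn, pad_value=0.0)
   print(tvm.lower(sch.mod))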
   ```
   
   This produces the following output IR:
   
   ```
   @I.ir_module
   class Module:
       @T.prim_func
       def main(A: T.Buffer((128, 136), "float32"), B: T.Buffer((136, 140), "float32"), C: T.Buffer((128, 140), "float32")):
           T.func_attr({"global_symbol": "main", "tir.noalias": True})
           A_global_ddr = T.allocate([20480], "float32", "global.ddr")
           A_global_ddr_1 = T.Buffer((20480,), data=A_global_ddr, scope="global.ddr")
           for h, w in T.grid(128, 140):
               for ax0 in range(136):
                   A_1 = T.Buffer((17408,), data=A.data)
                   A_global_ddr_1[h // 32 * 5120 + ax0 // 32 * 1024 + h % 32 * 32 + ax0 % 32] = A_1[h * 136 + ax0]
               for r in range(136):
                   cse_var_1: T.int32 = h * 140 + w
                   C_1 = T.Buffer((17920,), data=C.data)
                   if r == 0:
                       C_1[cse_var_1] = T.float32(0)
                   B_1 = T.Buffer((19040,), data=B.data)
                   C_1[cse_var_1] = C_1[cse_var_1] + A_global_ddr_1[h // 32 * 5120 + r // 32 * 1024 + h % 32 * 32 + r % 32] * B_1[r * 140 + w]
           for axis0, axis1, axis2, axis3 in T.grid(4, 5, 32, 32):
               if axis1 == 4 and 8 <= axis3:
                   A_global_ddr_1[axis0 * 5120 + axis1 * 1024 + axis2 * 32 + axis3] = T.float32(0)
   ```
   
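   Here, the zero initialization of `A_global_ddr_1` (the final 
`T.grid(4, 5, 32, 32)` loop) is placed after the `h, w` loop nest that both 
fills and reads the buffer. It will yield incorrect output if the padded region 
of `A_global_ddr_1` is read anywhere in the pipeline before that final loop runs.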
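The guard on the zero-init loop can be checked against the layout mapping. This pure-Python sketch needs no TVM; `tiled_index` and `flat_offset` are hypothetical helpers that mirror `transpose_fn` and the strides in the printed IR. It counts how many of the 20480 elements of `A_global_ddr_1` the copy loop writes and how many fall under `axis1 == 4 and 8 <= axis3`.

```python
# Sanity check of the padding predicate in the lowered IR (plain Python, no TVM).
# `tiled_index` and `flat_offset` are hypothetical helpers, not TVM APIs.
H, W, TILE = 128, 136, 32


def tiled_index(h, w):
    # Same mapping as `transpose_fn` in the test case.
    return (h // TILE, w // TILE, h % TILE, w % TILE)


def flat_offset(a0, a1, a2, a3):
    # Row-major strides of the (4, 5, 32, 32) buffer: 5120, 1024, 32, 1.
    return a0 * 5120 + a1 * 1024 + a2 * 32 + a3


# Padded tiled shape: ceil(128 / 32) = 4 and ceil(136 / 32) = 5.
shape = (-(-H // TILE), -(-W // TILE), TILE, TILE)
total = shape[0] * shape[1] * shape[2] * shape[3]

# Offsets written by the copy loop (`for ax0 in range(136)` in the IR).
written = {flat_offset(*tiled_index(h, w)) for h in range(H) for w in range(W)}

# Offsets covered by the zero-init loop guard `axis1 == 4 and 8 <= axis3`.
padding = {
    flat_offset(a0, a1, a2, a3)
    for a0 in range(shape[0])
    for a1 in range(shape[1])
    for a2 in range(shape[2])
    for a3 in range(shape[3])
    if a1 == 4 and 8 <= a3
}

print(shape, total, len(written), len(padding))
# (4, 5, 32, 32) 20480 17408 3072
```

The two sets are disjoint and together cover every element, so the 3072 padded elements are written only by the final loop and hold uninitialized values until it runs.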
   
   If the `compute_at` and `transform_layout` calls in the test case are swapped, 
the output is correct because `transform_layout` then selects `ReplacementPlan`.
   
   ### Triage
   
   Please refer to the list of label tags 
[here](https://github.com/apache/tvm/wiki/Issue-Triage-Labels) to find the 
relevant tags and add them below in a bullet format (example below).
   
   * needs-triage
   