DzAvril commented on pull request #10652:
URL: https://github.com/apache/tvm/pull/10652#issuecomment-1076272361


   After fixing the two bugs above, double buffering works in the final CUDA 
code, but it causes a loss of precision.
   Quoting the description in [[PASS] InjectDoubleBuffer 
#405](https://github.com/apache/tvm/pull/405): double buffering transforms the 
source code into the target code shown below.
   Source
   ```C++
   for (i, 0, 100) {
     allocate B[float32 * 4]
     for (i, 0, 4) {
       B[i] = A[((i*4) + i)]
     }
     for (i, 0, 4) {
       A[i] = (B[i] + 1.000000f)
     }
   }
   ```
   Target 
   ```C++
   allocate B[float32 * 2 * 4]
   for (i, 0, 4) {
     B[i] = A[i]
   }
   for (i, 0, 99) { 
     // prefetch next iteration
     for (i, 0, 4) {
       B[((((i + 1) % 2)*4) + i)] = A[(((i*4) + i) + 4)]
     }
     for (i, 0, 4) {
       A[i] = (B[(((i % 2)*4) + i)] + 1.000000f)
     }
   }
   for (i, 0, 4) {
     A[i] = (B[(i + 4)] + 1.000000f)
   }
   ```
   In the target code, the size of B is doubled. In the second for loop, data 
is first read into the second half of B, and then the first half of B is 
processed. This way the computation can hide the latency of reading from global 
memory.
   As described in the previous comment, double buffering in the tensor core 
conv2d template introduces a call node `tir.tvm_access_ptr`; this function 
reads data from the doubled buffer `im2col_reshape.shared` and passes it to 
another function for processing. Part of the lowered TIR is shown below (note: 
the code block below is not generated by the test script attached above; it is 
only provided for easier explanation):
   ```C++
   for (k.outer.outer.outer: int32, 0, 2) {
       if ((k.outer.outer.outer + 1) < 3) {
       attr [im2col_reshape.shared] "double_buffer_write" = 1;
       for (ax0.ax1.outer.fused.outer.outer.outer_1: int32, 0, 4) {
           attr [IterVar(threadIdx.z, (nullptr), "ThreadIndex", "threadIdx.z")] 
"thread_extent" = 1;
           attr [IterVar(threadIdx.y, (nullptr), "ThreadIndex", "threadIdx.y")] 
"thread_extent" = 2;
           attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] 
"thread_extent" = 32;
           im2col_reshape.shared[(broadcast((floormod((k.outer.outer.outer + 
1), 2)*2560), 8) + ramp(((((ax0.ax1.outer.fused.outer.outer.outer_1*640) + 
(threadIdx.y*320)) + (floordiv(threadIdx.x, 4)*40)) + (floormod(threadIdx.x, 
4)*8)), 1, 8))] = (int8x8*)placeholder_7[ramp((((((((blockIdx.x*12288) + 
(ax0.outer.outer*6144)) + (ax0.ax1.outer.fused.outer.outer.outer_1*1536)) + 
(threadIdx.y*768)) + (floordiv(threadIdx.x, 4)*96)) + ((k.outer.outer.outer + 
1)*32)) + (floormod(threadIdx.x, 4)*8)), 1, 8)]
       }
       }
       if ((k.outer.outer.outer + 1) < 3) {
       attr [placeholder_reshape.shared] "double_buffer_write" = 1;
       attr [IterVar(threadIdx.z, (nullptr), "ThreadIndex", "threadIdx.z")] 
"thread_extent" = 1;
       attr [IterVar(threadIdx.y, (nullptr), "ThreadIndex", "threadIdx.y")] 
"thread_extent" = 2;
       attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] 
"thread_extent" = 32;
       if (((threadIdx.y*8) + floordiv(threadIdx.x, 4)) < 8) {
           if (((threadIdx.y*32) + threadIdx.x) < 32) {
           if (threadIdx.y < 1) {
               
placeholder_reshape.shared[(broadcast((floormod((k.outer.outer.outer + 1), 
2)*320), 8) + ramp((((threadIdx.y*320) + (floordiv(threadIdx.x, 4)*40)) + 
(floormod(threadIdx.x, 4)*8)), 1, 8))] = 
(int8x8*)placeholder_8[ramp((((((threadIdx.y*768) + (blockIdx.y*768)) + 
(floordiv(threadIdx.x, 4)*96)) + ((k.outer.outer.outer + 1)*32)) + 
(floormod(threadIdx.x, 4)*8)), 1, 8)]
           }
           }
       }
       }
       for (k.outer.inner: int32, 0, 2) {
       allocate(im2col_reshape.shared.wmma.matrix_a: Pointer(wmma.matrix_a 
int8), int8, [64, 16]), storage_scope = wmma.matrix_a {
           for (ax0.outer: int32, 0, 2) {
           @tir.tvm_load_matrix_sync(im2col_reshape.shared.wmma.matrix_a, 32, 
8, 16, ax0.outer, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), 
im2col_reshape.shared, ((ax0.outer*1280) + (k.outer.inner*16)), 1280, 1, 
dtype=handle), 40, "row_major", dtype=handle)
           }
           allocate(placeholder_reshape.shared.wmma.matrix_b: 
Pointer(wmma.matrix_b int8), int8, [8, 16]), storage_scope = wmma.matrix_b {
           @tir.tvm_load_matrix_sync(placeholder_reshape.shared.wmma.matrix_b, 
32, 8, 16, 0, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), 
placeholder_reshape.shared, (k.outer.inner*16), 320, 1, dtype=handle), 40, 
"col_major", dtype=handle)
           for (i.c.outer: int32, 0, 2) {
               @tir.tvm_mma_sync(implicit_gemm_conv.wmma.accumulator, 
i.c.outer, im2col_reshape.shared.wmma.matrix_a, i.c.outer, 
placeholder_reshape.shared.wmma.matrix_b, 0, 
implicit_gemm_conv.wmma.accumulator, i.c.outer, dtype=handle)
           }
           }
       }
       }
   }
   ```
   In the first iteration of the loop `for (k.outer.outer.outer: int32, 0, 2)`, 
we load data from global memory into the second half of 
`im2col_reshape.shared`, and we process the first half of 
`im2col_reshape.shared`. In the second iteration we load data into the first 
half of `im2col_reshape.shared`, but we still process the first half of 
`im2col_reshape.shared`.
   > I guess the author expects double buffer just in load node or store node, 
so double buffer in call node is not in his/her expectation.
   
   As I guessed in the previous comment, the author did not expect the double 
buffer to appear as an argument of a call node. So the solution is to handle 
the double buffer in the call node as well.
   ```C++
   // inject_double_buffer:DoubleBufferInjector
     PrimExpr VisitExpr_(const CallNode* op) final {
       if (op->op.same_as(builtin::tvm_access_ptr())) {
         const VarNode* buf = op->args[1].as<VarNode>();
         auto it = dbuffer_info_.find(buf);
         if (it != dbuffer_info_.end()) {
           const StorageEntry& e = it->second;
           ICHECK(e.stride.defined());
           ICHECK(e.switch_read_var.defined());
           Array<PrimExpr> args;
           // dtype
           args.push_back(op->args[0]);
           // data
           args.push_back(op->args[1]);
           // offset
           args.push_back(e.switch_read_var * e.stride + op->args[2]);
           // extent
           args.push_back(op->args[3]);
           // rw_mask
           args.push_back(op->args[4]);
           return Call(op->dtype, op->op, args);
         } else {
           return GetRef<PrimExpr>(op);
         }
       } else {
         return StmtExprMutator::VisitExpr_(op);
       }
     }
   ```


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to