DzAvril commented on pull request #10652: URL: https://github.com/apache/tvm/pull/10652#issuecomment-1076272361
After fixing the two bugs above, double buffering works in the final CUDA code, but it causes a precision drop. Quoting the description in [[PASS] InjectDoubleBuffer #405](https://github.com/apache/tvm/pull/405), the pass rewrites the source code below into the target code.

Source
```C++
for (i, 0, 100) {
  allocate B[float32 * 4]
  for (j, 0, 4) {
    B[j] = A[((i*4) + j)]
  }
  for (j, 0, 4) {
    A[j] = (B[j] + 1.000000f)
  }
}
```

Target
```C++
allocate B[float32 * 2 * 4]
for (j, 0, 4) {
  B[j] = A[j]
}
for (i, 0, 99) {
  // prefetch next iteration
  for (j, 0, 4) {
    B[((((i + 1) % 2)*4) + j)] = A[(((i*4) + j) + 4)]
  }
  for (j, 0, 4) {
    A[j] = (B[(((i % 2)*4) + j)] + 1.000000f)
  }
}
for (j, 0, 4) {
  A[j] = (B[(j + 4)] + 1.000000f)
}
```

In the target code, the size of `B` is doubled. In the main loop, each iteration first reads the next iteration's data into one half of `B` and then processes the other half, so computation can hide the latency of reading from global memory.

As described in the previous comment, enabling double buffering in the Tensor Core conv2d template produces a call node `tir.tvm_access_ptr`: it takes an address inside the doubled buffer `im2col_reshape.shared` and passes it to another intrinsic for processing. Part of the lowered TIR is shown below (note: this block is not generated by the test script attached above; it is only for easier explanation):

```C++
for (k.outer.outer.outer: int32, 0, 2) {
  if ((k.outer.outer.outer + 1) < 3) {
    attr [im2col_reshape.shared] "double_buffer_write" = 1;
    for (ax0.ax1.outer.fused.outer.outer.outer_1: int32, 0, 4) {
      attr [IterVar(threadIdx.z, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1;
      attr [IterVar(threadIdx.y, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 2;
      attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
      im2col_reshape.shared[(broadcast((floormod((k.outer.outer.outer + 1), 2)*2560), 8) + ramp(((((ax0.ax1.outer.fused.outer.outer.outer_1*640) + (threadIdx.y*320)) + (floordiv(threadIdx.x, 4)*40)) + (floormod(threadIdx.x, 4)*8)), 1, 8))] =
        (int8x8*)placeholder_7[ramp((((((((blockIdx.x*12288) + (ax0.outer.outer*6144)) + (ax0.ax1.outer.fused.outer.outer.outer_1*1536)) + (threadIdx.y*768)) + (floordiv(threadIdx.x, 4)*96)) + ((k.outer.outer.outer + 1)*32)) + (floormod(threadIdx.x, 4)*8)), 1, 8)]
    }
  }
  if ((k.outer.outer.outer + 1) < 3) {
    attr [placeholder_reshape.shared] "double_buffer_write" = 1;
    attr [IterVar(threadIdx.z, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1;
    attr [IterVar(threadIdx.y, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 2;
    attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
    if (((threadIdx.y*8) + floordiv(threadIdx.x, 4)) < 8) {
      if (((threadIdx.y*32) + threadIdx.x) < 32) {
        if (threadIdx.y < 1) {
          placeholder_reshape.shared[(broadcast((floormod((k.outer.outer.outer + 1), 2)*320), 8) + ramp((((threadIdx.y*320) + (floordiv(threadIdx.x, 4)*40)) + (floormod(threadIdx.x, 4)*8)), 1, 8))] =
            (int8x8*)placeholder_8[ramp((((((threadIdx.y*768) + (blockIdx.y*768)) + (floordiv(threadIdx.x, 4)*96)) + ((k.outer.outer.outer + 1)*32)) + (floormod(threadIdx.x, 4)*8)), 1, 8)]
        }
      }
    }
  }
  for (k.outer.inner: int32, 0, 2) {
    allocate(im2col_reshape.shared.wmma.matrix_a: Pointer(wmma.matrix_a int8), int8, [64, 16]), storage_scope = wmma.matrix_a {
      for (ax0.outer: int32, 0, 2) {
        @tir.tvm_load_matrix_sync(im2col_reshape.shared.wmma.matrix_a, 32, 8, 16, ax0.outer,
          @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), im2col_reshape.shared, ((ax0.outer*1280) + (k.outer.inner*16)), 1280, 1, dtype=handle),
          40, "row_major", dtype=handle)
      }
      allocate(placeholder_reshape.shared.wmma.matrix_b: Pointer(wmma.matrix_b int8), int8, [8, 16]), storage_scope = wmma.matrix_b {
        @tir.tvm_load_matrix_sync(placeholder_reshape.shared.wmma.matrix_b, 32, 8, 16, 0,
          @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), placeholder_reshape.shared, (k.outer.inner*16), 320, 1, dtype=handle),
          40, "col_major", dtype=handle)
        for (i.c.outer: int32, 0, 2) {
          @tir.tvm_mma_sync(implicit_gemm_conv.wmma.accumulator, i.c.outer, im2col_reshape.shared.wmma.matrix_a, i.c.outer, placeholder_reshape.shared.wmma.matrix_b, 0, implicit_gemm_conv.wmma.accumulator, i.c.outer, dtype=handle)
        }
      }
    }
  }
}
```

In the first iteration of the loop `for (k.outer.outer.outer: int32, 0, 2)`, we load data from global memory into the second half of `im2col_reshape.shared` and process the first half. In the second iteration we load into the first half, but we still process the first half, because the offset passed to `tir.tvm_access_ptr` never switches between the two halves.

> I guess the author expects double buffer just in load node or store node, so double buffer in call node is not in his/her expectation.

As I guessed in the previous comment, the author did not expect a double-buffered variable to appear as an argument of a call node. So the fix is to handle double buffering in call nodes as well.
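To make the failure mode concrete, here is a minimal C++ sketch (a hypothetical stand-in, not TVM code) of a doubled buffer whose producer alternates halves while the consumer's read offset is never switched, which is what an un-rewritten `tir.tvm_access_ptr` amounts to:

```cpp
#include <cassert>
#include <vector>

// Hypothetical model of the lowered loop: the prologue loads iteration 0
// into half 0; each iteration i then prefetches iteration i+1 into half
// (i + 1) % 2 and consumes iteration i's data. A correct consumer reads
// half i % 2; the bug is equivalent to always reading half 0.
std::vector<int> run(int iters, int n, bool switch_read_half) {
  std::vector<int> input(iters * n), output(iters * n);
  for (int k = 0; k < iters * n; ++k) input[k] = k;

  std::vector<int> buf(2 * n);
  for (int j = 0; j < n; ++j) buf[j] = input[j];  // prologue: fill half 0
  for (int i = 0; i < iters; ++i) {
    if (i + 1 < iters) {  // prefetch the next iteration into the other half
      for (int j = 0; j < n; ++j)
        buf[((i + 1) % 2) * n + j] = input[(i + 1) * n + j];
    }
    // consume iteration i; the buggy variant ignores the half switch
    int base = switch_read_half ? (i % 2) * n : 0;
    for (int j = 0; j < n; ++j)
      output[i * n + j] = buf[base + j] + 1;
  }
  return output;
}
```

With the switch, every output element equals its input plus one. Without it, odd iterations consume data that was just prefetched for the next iteration, while even iterations happen to be correct: exactly the kind of silent corruption that surfaces as a precision drop.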
In `DoubleBufferInjector`, that means rewriting the offset argument of `tir.tvm_access_ptr`:

```C++
// inject_double_buffer: DoubleBufferInjector
PrimExpr VisitExpr_(const CallNode* op) final {
  if (op->op.same_as(builtin::tvm_access_ptr())) {
    const VarNode* buf = op->args[1].as<VarNode>();
    auto it = dbuffer_info_.find(buf);
    if (it != dbuffer_info_.end()) {
      const StorageEntry& e = it->second;
      ICHECK(e.stride.defined());
      ICHECK(e.switch_read_var.defined());
      Array<PrimExpr> args;
      // dtype
      args.push_back(op->args[0]);
      // data
      args.push_back(op->args[1]);
      // offset: add the term that switches between the two halves
      args.push_back(e.switch_read_var * e.stride + op->args[2]);
      // extent
      args.push_back(op->args[3]);
      // rw_mask
      args.push_back(op->args[4]);
      return Call(op->dtype, op->op, args);
    } else {
      return GetRef<PrimExpr>(op);
    }
  } else {
    return StmtExprMutator::VisitExpr_(op);
  }
}
```
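The arithmetic the fix performs is small but worth spelling out: `e.stride` is the size of one half of the doubled buffer, and `e.switch_read_var` alternates between 0 and 1 across outer iterations, so the rewritten offset hops between halves. A hypothetical sketch in plain C++ (the names and the `% 2` modeling are illustrative, not the TVM types):

```cpp
#include <cassert>

// Hypothetical model of the rewritten tvm_access_ptr offset:
// new_offset = switch_read_var * stride + old_offset, where
// switch_read_var = outer_iter % 2 selects the half that holds
// the current iteration's data.
int rewritten_offset(int outer_iter, int stride, int old_offset) {
  int switch_read_var = outer_iter % 2;
  return switch_read_var * stride + old_offset;
}
```

With `stride = 2560` (one half of `im2col_reshape.shared` in the TIR above), successive outer iterations read from `old_offset` and `old_offset + 2560` alternately, matching the halves that the `double_buffer_write` stores fill.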
