yjk21 opened a new issue #4540: TVM CUDA schedule: Incorrect shared memory size loading a padded tile URL: https://github.com/apache/incubator-tvm/issues/4540 First of all, TVM is amazing work! It's very enjoyable working with it. I am trying to get the following to work: - I have a 512x512 input - I pad it with 3 pixels in each direction, i.e. the padded input is 518x518 - I compute a stencil on this The computation: ```python import tvm in_size = 512 pad = 3 pad_size = in_size + 2 * pad A = tvm.placeholder((in_size, in_size), name='A') # padding from the conv example Apad = tvm.compute( (pad_size, pad_size), lambda yy, xx: tvm.if_then_else( tvm.all(yy >= pad, yy - pad < in_size, xx >= pad, xx - pad < in_size), A[yy-pad, xx-pad], tvm.const(0., "float32")), name='Apad') Out = tvm.compute((pad_size, pad_size), lambda j,i: tvm.if_then_else( tvm.all(i>0, j>0, i < pad_size - 1, j < pad_size - 1), (Apad[j-1,i-1] + Apad[j-1,i]+ Apad[j-1,i+1] +Apad[j,i-1] + Apad[j,i]+ Apad[j,i+1] +Apad[j+1,i-1] + Apad[j+1,i] +Apad[j+1,i+1]) , tvm.const(0.,"float32")), name ='Out') ``` I am trying the following schedule, with the goal of computing a TX×TY output tile that loads a 10×10 tile into shared memory: ```python s = tvm.create_schedule(Out.op) TX = 8 TY = 8 block_x = tvm.thread_axis("blockIdx.x") thread_x = tvm.thread_axis((0, TX), "threadIdx.x") thread_y = tvm.thread_axis((0, TY), "threadIdx.y") j,i = Out.op.axis ii, io = s[Out].split(i, nparts=TX) ji, jo = s[Out].split(j, nparts=TY) s[Out].reorder(jo, io, ji, ii) ijo = s[Out].fuse(jo, io) Ashm = s.cache_read(Apad, 'shared', [Out]) s[Ashm].compute_at(s[Out], ii) s[Apad].compute_inline() s[Out].bind(ijo, block_x) s[Out].bind(ii, thread_x) s[Out].bind(ji, thread_y) ja,ia= s[Ashm].op.axis iao,iai = s[Ashm].split(ia, nparts=65) jao,jai = s[Ashm].split(ja, nparts=65) s[Ashm].reorder(jao, iao, jai, iai) ijao = s[Ashm].fuse(jao, iao) s[Ashm].bind(jai, thread_y) s[Ashm].bind(iai, thread_x) s[Ashm].bind(ijao, block_x) ``` This is lowered into
the following code, which allocates 209764 floats of shared memory: ``` // attr [Out] storage_scope = "global" allocate Out[float32 * 268324] produce Out { // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 4225 // attr [Apad.shared] storage_scope = "shared" allocate Apad.shared[float32 * 209764] // attr [iter_var(threadIdx.y, range(min=0, ext=8), threadIdx.y)] thread_extent = 8 // attr [iter_var(threadIdx.x, range(min=0, ext=8), threadIdx.x)] thread_extent = 8 produce Apad.shared { if (likely((((floordiv(blockIdx.x, 65)*8) + threadIdx.y) < 458))) { if (likely((((floormod(blockIdx.x, 65)*8) + threadIdx.x) < 458))) { if (likely((1 <= ((floordiv(blockIdx.x, 65)*9) + threadIdx.y)))) { if (likely((((floordiv(blockIdx.x, 65)*9) + threadIdx.y) < 519))) { if (likely((1 <= ((floormod(blockIdx.x, 65)*9) + threadIdx.x)))) { if (likely((((floormod(blockIdx.x, 65)*9) + threadIdx.x) < 519))) { Apad.shared[((((floordiv(blockIdx.x, 65)*3664) + (threadIdx.y*458)) + (floormod(blockIdx.x, 65)*8)) + threadIdx.x)] = tvm_if_then_else(((((4 <= ((floordiv(blockIdx.x, 65)*9) + threadIdx.y)) && (((floordiv(blockIdx.x, 65)*9) + threadIdx.y) < 516)) && (4 <= ((floormod(blockIdx.x, 65)*9) + threadIdx.x))) && (((floormod(blockIdx.x, 65)*9) + threadIdx.x) < 516)), A[(((((floordiv(blockIdx.x, 65)*4608) + (threadIdx.y*512)) + (floormod(blockIdx.x, 65)*9)) + threadIdx.x) - 2052)], 0f) } } } } } } } if (likely((((threadIdx.y*65) + floordiv(blockIdx.x, 65)) < 518))) { if (likely((((threadIdx.x*65) + floormod(blockIdx.x, 65)) < 518))) { if (likely((((threadIdx.y*65) + floordiv(blockIdx.x, 65)) < 518))) { if (likely((((threadIdx.x*65) + floormod(blockIdx.x, 65)) < 518))) { Out[((((threadIdx.y*33670) + (floordiv(blockIdx.x, 65)*518)) + (threadIdx.x*65)) + floormod(blockIdx.x, 65))] = tvm_if_then_else(((((0 < ((threadIdx.x*65) + floormod(blockIdx.x, 65))) && (0 < ((threadIdx.y*65) + floordiv(blockIdx.x, 65)))) && (((threadIdx.x*65) + floormod(blockIdx.x, 65)) < 517)) && (((threadIdx.y*65) + 
floordiv(blockIdx.x, 65)) < 517)), ((((((((Apad.shared[((threadIdx.y*29770) + (threadIdx.x*65))] + Apad.shared[(((threadIdx.y*29770) + (threadIdx.x*65)) + 1)]) + Apad.shared[(((threadIdx.y*29770) + (threadIdx.x*65)) + 2)]) + Apad.shared[(((threadIdx.y*29770) + (threadIdx.x*65)) + 458)]) + Apad.shared[(((threadIdx.y*29770) + (threadIdx.x*65)) + 459)]) + Apad.shared[(((threadIdx.y*29770) + (threadIdx.x*65)) + 460)]) + Apad.shared[(((threadIdx.y*29770) + (threadIdx.x*65)) + 916)]) + Apad.shared[(((threadIdx.y*29770) + (threadIdx.x*65)) + 917)]) + Apad.shared[(((threadIdx.y*29770) + (threadIdx.x*65)) + 918)]), 0f) } } } } } ``` Of course, it is very likely that my code is wrong. Could you please have a look if this is an issue? Thanks Young-Jun
---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: [email protected] With regards, Apache Git Services
