masahi commented on a change in pull request #8636:
URL: https://github.com/apache/tvm/pull/8636#discussion_r688790802



##########
File path: python/tvm/topi/gpu/conv2d_nhwc.py
##########
@@ -85,15 +87,17 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
     thread_yz = te.thread_axis((0, vthread_n), "vthread", name="vy")
 
     # Schedule for output
-    ni, hi, wi, fi = s[output].op.axis
-    bz = s[output].fuse(hi, wi)
+    ni, _, wi, fi = s[output].op.axis
+    bz = wi
+    fi, vec = s[output].split(fi, factor=vec_factor)
+    s[output].vectorize(vec)

Review comment:
      Is this supposed to vectorize the conv2d inner loop? Based on the generated 
code, I think it only vectorizes the last stage, which can be the copy from local to 
global memory or a fused activation computation. I wonder where the 6-7x perf 
improvement comes from?
   
   Here is an example of generated code where `vec_factor` is fixed to 2.
   ```cuda
   extern "C" __global__ void __launch_bounds__(16) 
default_function_kernel0(float* __restrict__ A, float* __restrict__ W, float* 
__restrict__ Conv2dOutput) {
     float Conv2dOutput_local[8];
     __shared__ float PaddedInput_shared[96];
     __shared__ float W_shared[256];
     float PaddedInput_shared_local[2];
     float W_shared_local[4];
     for (int yy = 0; yy < 8; ++yy) {
       for (int nn_c_init = 0; nn_c_init < 2; ++nn_c_init) {
         for (int ff_c_init = 0; ff_c_init < 4; ++ff_c_init) {
           Conv2dOutput_local[(((nn_c_init * 4) + ff_c_init))] = 0.000000e+00f;
         }
       }
       for (int rc_outer = 0; rc_outer < 8; ++rc_outer) {
         for (int ry = 0; ry < 5; ++ry) {
           for (int rx = 0; rx < 5; ++rx) {
             __syncthreads();
             for (int ax0_ax3_fused_outer_outer = 0; ax0_ax3_fused_outer_outer 
< 4; ++ax0_ax3_fused_outer_outer) {
               PaddedInput_shared[((((ax0_ax3_fused_outer_outer * 24) + 
(((int)threadIdx.y) * 4)) + ((int)threadIdx.x)))] = (((((2 <= ((yy * 2) + ry)) 
&& (((yy * 2) + ry) < 18)) && (2 <= ((((int)blockIdx.z) * 2) + rx))) && 
(((((int)blockIdx.z) * 2) + rx) < 18)) ? A[((((((((((ax0_ax3_fused_outer_outer 
* 32768) + (yy * 4096)) + (ry * 2048)) + (((int)blockIdx.z) * 256)) + (rx * 
128)) + (rc_outer * 16)) + (((int)threadIdx.y) * 4)) + ((int)threadIdx.x)) - 
4352))] : 0.000000e+00f);
             }
             for (int ax2_ax3_fused_outer_outer_outer = 0; 
ax2_ax3_fused_outer_outer_outer < 8; ++ax2_ax3_fused_outer_outer_outer) {
               ((float2*)(W_shared + ((((ax2_ax3_fused_outer_outer_outer * 32) 
+ (((int)threadIdx.y) * 8)) + (((int)threadIdx.x) * 2)))))[0] = ((float2*)(W + 
((((((((ry * 81920) + (rx * 16384)) + (rc_outer * 2048)) + 
(ax2_ax3_fused_outer_outer_outer * 256)) + ((((((int)threadIdx.y) * 8) + 
(((int)threadIdx.x) * 2)) >> 4) * 128)) + (((int)blockIdx.x) * 16)) + 
(((((int)threadIdx.y) * 8) + (((int)threadIdx.x) * 2)) & 15)))))[0];
             }
             __syncthreads();
             for (int rc_inner = 0; rc_inner < 16; ++rc_inner) {
               for (int ax0 = 0; ax0 < 2; ++ax0) {
                 if (((((int)threadIdx.y) * 2) + ax0) < 4) {
                   PaddedInput_shared_local[(ax0)] = 
PaddedInput_shared[((((((int)threadIdx.y) * 48) + (ax0 * 24)) + rc_inner))];
                 }
               }
               for (int ax3 = 0; ax3 < 4; ++ax3) {
                 W_shared_local[(ax3)] = W_shared[((((rc_inner * 16) + 
(((int)threadIdx.x) * 4)) + ax3))];
               }
               for (int nn_c = 0; nn_c < 2; ++nn_c) {
                 for (int ff_c = 0; ff_c < 4; ++ff_c) {
                   if (((((int)threadIdx.y) * 2) + nn_c) < 4) {
                     Conv2dOutput_local[(((nn_c * 4) + ff_c))] = 
(Conv2dOutput_local[(((nn_c * 4) + ff_c))] + (PaddedInput_shared_local[(nn_c)] 
* W_shared_local[(ff_c)]));
                   }
                 }
               }
             }
           }
         }
       }
       for (int nn_inner = 0; nn_inner < 2; ++nn_inner) {
         for (int ff_outer_inner = 0; ff_outer_inner < 2; ++ff_outer_inner) {
           if (((((int)threadIdx.y) * 2) + nn_inner) < 4) {
             if (((int)threadIdx.y) < 2) {
               ((float2*)(Conv2dOutput + ((((((((((int)threadIdx.y) * 16384) + 
(nn_inner * 8192)) + (yy * 1024)) + (((int)blockIdx.z) * 128)) + 
(((int)blockIdx.x) * 16)) + (((int)threadIdx.x) * 4)) + (ff_outer_inner * 
2)))))[0] = ((float2*)(Conv2dOutput_local + (((nn_inner * 4) + (ff_outer_inner 
* 2)))))[0];
             }
           }
         }
       }
     }
   }
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to