zhreshold commented on a change in pull request #14139: Performance improvement in Normalize GPU Kernel
URL: https://github.com/apache/incubator-mxnet/pull/14139#discussion_r256202475
 
 

 ##########
 File path: src/operator/image/image_random.cu
 ##########
 @@ -111,6 +110,232 @@ void ToTensorImplCUDA(mshadow::Stream<gpu> *s,
         MSHADOW_CUDA_POST_KERNEL_CHECK(ToTensorCudaKernel);
 }
 
+// Normalize Kernel for 3D input
+template<typename xpu, typename DType>
+__global__ void NormalizeCudaKernel(const Tensor<xpu, 3, DType> input,
+                                    const Tensor<xpu, 3, DType> output,
+                                    const int req,
+                                    const int N,
+                                    const int H,
+                                    const int W,
+                                    const int C,
+                                    const float mean_d0,
+                                    const float mean_d1,
+                                    const float mean_d2,
+                                    const float std_d0,
+                                    const float std_d1,
+                                    const float std_d2) {
+    // We process one image per thread block.
+    // In the 3D case there is only one block, so blockIdx.x is not used.
+
+    float mean = mean_d0;
+    float std = std_d0;
+    for (int c = 0; c < C; ++c) {
+        switch (c) {
+            case 0 : mean = mean_d0;
+                     std = std_d0;
+                     break;
+            case 1 : mean = mean_d1;
+                     std = std_d1;
+                     break;
+            case 2 : mean = mean_d2;
+                     std = std_d2;
+                     break;
+        }
+        for (int h = threadIdx.y; h < H; h += blockDim.y) {
+            for (int w = threadIdx.x; w < W; w += blockDim.x) {
+                KERNEL_ASSIGN(output[c][h][w], req,
+                              (input[c][h][w] - mean) / std);
+            }
+        }
+    }
+}
+
+// Normalize Kernel for 4D input
+template<typename xpu, typename DType>
+__global__ void NormalizeCudaKernel(const Tensor<xpu, 4, DType> input,
+                                    const Tensor<xpu, 4, DType> output,
+                                    const int req,
+                                    const int N,
+                                    const int H,
+                                    const int W,
+                                    const int C,
+                                    const float mean_d0,
+                                    const float mean_d1,
+                                    const float mean_d2,
+                                    const float std_d0,
+                                    const float std_d1,
+                                    const float std_d2) {
+    // We process one image per thread block.
+    const int n = blockIdx.x;
+
+    float mean = mean_d0;
+    float std = std_d0;
+    for (int c = 0; c < C; ++c) {
+        switch (c) {
+            case 0 : mean = mean_d0;
+                     std = std_d0;
+                     break;
+            case 1 : mean = mean_d1;
+                     std = std_d1;
+                     break;
+            case 2 : mean = mean_d2;
+                     std = std_d2;
+                     break;
+        }
+        for (int h = threadIdx.y; h < H; h += blockDim.y) {
+            for (int w = threadIdx.x; w < W; w += blockDim.x) {
+                KERNEL_ASSIGN(output[n][c][h][w], req,
+                              (input[n][c][h][w] - mean) / std);
+            }
+        }
+    }
+}
+
+template<typename DType, typename T>
+void NormalizeImplCUDA(mshadow::Stream<gpu> *s,
+                       const T input,
+                       const T output,
+                       const int req,
+                       const float mean_d0,
+                       const float mean_d1,
+                       const float mean_d2,
+                       const float std_d0,
+                       const float std_d1,
+                       const float std_d2) {
+    int blocks, H, W, C, N;
+    cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
+    if (std::is_same<T, Tensor<gpu, 3, DType>>::value) {
+        // 3D Input - (C, H, W)
+        N = 0;
+        C = input.size(0);
+        H = input.size(1);
+        W = input.size(2);
+        blocks = 1;
+    } else {
+        // 4D Input - (N, C, H, W)
+        N = input.size(0);
+        C = input.size(1);
+        H = input.size(2);
+        W = input.size(3);
+        blocks = N > 0 ? N : 1;
+    }
+    // One block per image.
+    // Number of threads = (16, 16) is optimal, because,
 
 Review comment:
  The common choice is to fully occupy each SM: use the maximum number of threads per block, as in https://github.com/dmlc/mshadow/blob/master/mshadow/cuda/tensor_gpu-inl.cuh#L30, and calculate the number of blocks accordingly. The benchmark results are only valid for certain configurations.
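
  A minimal, self-contained sketch of that launch strategy (not the PR's code; the kernel name, file name, and constants below are illustrative, and the value 1024 is assumed to match mshadow's kMaxThreadsPerBlock from the link above): flatten the (N, C, H, W) tensor, launch full-size thread blocks, derive the block count from the total element count, and walk the data with a grid-stride loop.

  // normalize_launch_sketch.cu -- illustrative only, not part of this PR.
  #include <cuda_runtime.h>
  #include <cstdio>

  __global__ void NormalizeFlatKernel(const float* in, float* out,
                                      const int total, const int hw, const int chw,
                                      const float3 mean, const float3 stdev) {
    // Grid-stride loop over all N*C*H*W elements.
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < total;
         i += blockDim.x * gridDim.x) {
      const int c = (i % chw) / hw;  // channel of element i in NCHW layout
      const float m = (c == 0) ? mean.x : (c == 1) ? mean.y : mean.z;
      const float s = (c == 0) ? stdev.x : (c == 1) ? stdev.y : stdev.z;
      out[i] = (in[i] - m) / s;
    }
  }

  int main() {
    const int N = 2, C = 3, H = 224, W = 224;
    const int total = N * C * H * W;
    float *in, *out;
    cudaMallocManaged(&in, total * sizeof(float));
    cudaMallocManaged(&out, total * sizeof(float));
    for (int i = 0; i < total; ++i) in[i] = 0.5f;

    const int threads = 1024;  // assumed max threads per block
    const int blocks = (total + threads - 1) / threads;
    NormalizeFlatKernel<<<blocks, threads>>>(
        in, out, total, H * W, C * H * W,
        make_float3(0.485f, 0.456f, 0.406f),   // per-channel mean
        make_float3(0.229f, 0.224f, 0.225f));  // per-channel std
    cudaDeviceSynchronize();
    printf("out[0] = %f\n", out[0]);  // expect (0.5 - 0.485) / 0.229
    cudaFree(in);
    cudaFree(out);
    return 0;
  }

  With a fixed one-block-per-image grid and a (16, 16) block, occupancy depends on the image and batch dimensions, which is why a benchmark for one configuration does not generalize; a flattened launch like the sketch above keeps every block at full thread count regardless of H, W, and N.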

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
