eric-haibin-lin commented on a change in pull request #15545: Softmax optimization for GPU
URL: https://github.com/apache/incubator-mxnet/pull/15545#discussion_r315936560
 
 

 ##########
 File path: src/operator/nn/softmax-inl.h
 ##########
 @@ -218,6 +219,157 @@ __global__ void softmax_compute_kernel(DType *in, OType *out, index_t M, int axi
   }
 }
 
+const int softmax_threads_per_block = 512;  // sizes scratch[] in softmax_compute_kernel2; presumably also the launch block size - confirm in launcher
+
+/*!
+ * \brief Combine one value per lane of a warp with a binary functor.
+ *
+ * Every step fetches the value held by the lane `offset` positions higher
+ * (full-warp mask 0xffffffff, so all 32 lanes must call this together) and
+ * folds it into the local value with redfun. After the last step lane 0 holds
+ * the reduction over all 32 lanes; the other lanes hold partial results.
+ *
+ * \param value this lane's contribution
+ * \param redfun binary functor, called as redfun(local, fetched)
+ * \return this lane's value after the five shuffle steps
+ */
+template <typename OP, typename T>
+__device__ inline T warp_reduce(T value, OP redfun) {
+  // Shuffle distance halves every step: 16, 8, 4, 2, 1.
+#pragma unroll
+  for (int offset = 16; offset > 0; offset >>= 1) {
+    const T fetched = __shfl_down_sync(0xffffffff, value, offset);
+    value = redfun(value, fetched);
+  }
+  return value;
+}
+
+/*!
+ * \brief warp_reduce overload for mshadow half_t values.
+ *
+ * The value is converted to float once, the shuffle steps (distances 16, 8,
+ * 4, 2, 1 with the full-warp mask 0xffffffff) are carried out on the float,
+ * and the result is converted back to half_t at the end.
+ *
+ * \param value this lane's contribution
+ * \param redfun binary functor, called with two float arguments
+ * \return this lane's value after the five shuffle steps
+ */
+template <typename OP>
+__device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, OP redfun) {
+  float acc = static_cast<float>(value);
+#pragma unroll
+  for (int offset = 16; offset > 0; offset >>= 1) {
+    const float fetched = __shfl_down_sync(0xffffffff, acc, offset);
+    acc = redfun(acc, fetched);
+  }
+  return mshadow::half::half_t(acc);
+}
+
+/*!
+ * \brief Softmax-style kernel that stages whole rows in shared memory.
+ *
+ * Each block handles rows_per_block rows, and
+ * softmax_threads_per_block / rows_per_block threads cooperate on one row.
+ * A row is copied from `in` into shared memory with LType-wide loads, reduced
+ * to its maximum and to the sum of exp((x - max) / temperature), rewritten in
+ * place through OP::Map, and copied to `out` with LType-wide stores.
+ *
+ * NOTE(review): scratch[] is indexed by threadIdx.x and the final reduction
+ * step uses full-warp shuffles, so this presumably requires
+ * blockDim.x == softmax_threads_per_block and threads_per_row to be a power
+ * of two that is >= 32 - confirm in the launcher.
+ * NOTE(review): rows_per_block * M * sizeof(DType) presumably has to fit in
+ * the 20 KB persistent_storage buffer, and `in`/`out` have to be aligned for
+ * LType - confirm in the launcher.
+ * NOTE(review): results are written into the staged row as DType and stored
+ * through an LType view of `out`, so OType presumably has the same size as
+ * DType - confirm.
+ *
+ * \tparam OP functor providing Map(scaled_value, sum)
+ * \tparam negate when true every element is negated before it is used
+ * \tparam AType accumulation type of the reductions
+ * \tparam LType type used for the wide global loads and stores
+ * \tparam DType element type of the input
+ * \tparam OType element type of the output pointer
+ * \param in input data, read through an LType pointer
+ * \param out output data, written through an LType pointer
+ * \param M number of DType elements in one row
+ * \param temperature divisor applied to (x - max)
+ * \param rows_per_block number of rows handled by one block
+ * \param total_rows number of rows; rows at or past this index are not
+ *        loaded or stored
+ */
+template<typename OP, bool negate, typename AType, typename LType,
+  typename DType, typename OType>
+__global__ void softmax_compute_kernel2(const DType *in, OType *out, const index_t M,
+                                       const double temperature, int rows_per_block,
+                                       const index_t total_rows) {
+  __shared__ AType scratch[softmax_threads_per_block];
+  __shared__ LType persistent_storage[20*1024 / sizeof(LType)];
+  const int warp_size = 32;
+  const int threads_per_row = softmax_threads_per_block / rows_per_block;
+  const int my_local_row = threadIdx.x / threads_per_row;
+  // Computed in index_t so that the row index cannot wrap for large inputs.
+  const index_t my_row = static_cast<index_t>(blockIdx.x) * rows_per_block + my_local_row;
+  // Threads whose row lies past the end must not return early: every thread
+  // of the block has to reach each __syncthreads() below. They get zero-length
+  // loops instead, so they touch no global memory.
+  const bool active = my_row < total_rows;
+  const int my_id = threadIdx.x % threads_per_row;
+  const int entries_per_load = sizeof(LType)/sizeof(DType);
+  // Due to usage of MSHADOW_TYPE_SWITCH macro we are generating
+  // kernels where sizeof(LType) may be less than sizeof(DType),
+  // resulting in entries_per_load being 0.
+  // This is not a valid combination and is being checked against
+  // in the launcher code. This switch here is just to silence
+  // the division by zero warning generated for such invalid cases.
+  const int row_length = entries_per_load > 0 ? M / entries_per_load : 0;
+  const index_t my_M = active ? M : 0;
+  const index_t my_row_length = active ? row_length : 0;
+
+  const LType * in_aligned = reinterpret_cast<const LType *>(in);
+  // Widen before multiplying so the offset is not computed in 32-bit int.
+  const size_t base = static_cast<size_t>(my_row) * row_length;
+
+  // Stage the row in shared memory using LType-wide loads.
+  for (index_t i = my_id; i < my_row_length; i += threads_per_row) {
+    persistent_storage[my_local_row * row_length + i] = in_aligned[base + i];
+  }
+  DType * row = reinterpret_cast<DType *>(persistent_storage + my_local_row * row_length);
+  __syncthreads();
+
+  // Pass 1: maximum of the (optionally negated) row.
+  DType my_max_value;
+  red::maximum::SetInitValue(my_max_value);
+
+  for (index_t i = my_id; i < my_M; i += threads_per_row) {
+    my_max_value = ::max(my_max_value, negate ? -row[i] : row[i]);
+  }
+  scratch[threadIdx.x] = my_max_value;
+  __syncthreads();
+  // Tree reduction in shared memory down to one warp per row ...
+  for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+    if (my_id < size) {
+      scratch[threadIdx.x] = ::max(scratch[threadIdx.x], scratch[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  // ... then the first warp of the row finishes with shuffles.
+  if (my_id < warp_size) {
+    AType my_value = warp_reduce(scratch[threadIdx.x],
+                                 [](AType x, AType y) { return ::max(x, y); });
+    scratch[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  // The row's result sits in the scratch slot of the row's first thread.
+  DType smax = scratch[threadIdx.x - threadIdx.x % threads_per_row];
+  // scratch is reused for the sum, so everyone must have read smax first.
+  __syncthreads();
+
+  // Pass 2: sum of exp((x - max) / temperature).
+  AType my_sum;
+  red::sum::SetInitValue(my_sum);
+
+  for (index_t i = my_id; i < my_M; i += threads_per_row) {
+    const DType val = negate ? -row[i] : row[i];
+    my_sum += static_cast<AType>(expf((val - smax) / static_cast<AType>(temperature)));
+  }
+  scratch[threadIdx.x] = my_sum;
+  __syncthreads();
+  for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+    if (my_id < size) {
+      scratch[threadIdx.x] += scratch[threadIdx.x + size];
+    }
+    __syncthreads();
+  }
+  if (my_id < warp_size) {
+    AType my_value = warp_reduce(scratch[threadIdx.x],
+                                 [](AType x, AType y) { return x + y;});
+    scratch[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+
+  AType ssum = scratch[threadIdx.x - threadIdx.x % threads_per_row];
+  __syncthreads();
+
+  // Pass 3: overwrite the staged row with the final values.
+  for (index_t i = my_id; i < my_M; i += threads_per_row) {
+    const DType val = negate ? -row[i] : row[i];
+    row[i] = OP::Map((val - smax)/static_cast<DType>(temperature), ssum);
+  }
+  // The LType-wide stores below read elements written by other threads.
+  __syncthreads();
+
+  LType * out_aligned = reinterpret_cast<LType *>(out);
+
+  for (index_t i = my_id; i < my_row_length; i += threads_per_row) {
+    out_aligned[base + i] = persistent_storage[my_local_row * row_length + i];
+  }
+}
+
+namespace {
+
+/*!
+ * \brief Select a load type flag from the divisibility of N.
+ *
+ * Returns kFloat64 when N is a multiple of 8, kFloat32 when it is a multiple
+ * of 4, kFloat16 when it is a multiple of 2 and kInt8 otherwise.
+ * NOTE(review): N is presumably a row size in bytes - confirm against caller.
+ *
+ * \param N size to test
+ * \return one of kFloat64, kFloat32, kFloat16, kInt8
+ */
+int get_load_type(size_t N) {
+  if (N % 8 == 0) return kFloat64;
+  if (N % 4 == 0) return kFloat32;
+  if (N % 2 == 0) return kFloat16;
+  return kInt8;
+}
+
+int get_rows_per_block(size_t N) {
 
 Review comment:
   @ptrendx could you address this? 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to