ptrendx commented on a change in pull request #15545: Softmax optimization for 
GPU
URL: https://github.com/apache/incubator-mxnet/pull/15545#discussion_r316289235
 
 

 ##########
 File path: src/operator/nn/softmax-inl.h
 ##########
 @@ -313,71 +294,134 @@ __global__ void softmax_compute_kernel(DType *in, OType 
*out, index_t M, int axi
 
   for (index_t i = x; i < M; i += x_size) {
     val = negate ? -in[base + i*sa] : in[base + i*sa];
-    out[base + i*sa] = OP::Map((val - smax)/static_cast<DType>(temperature), 
ssum);
+    out[base + i*sa] =
+      (i < len) ? OType(OP::Map((val - smax)/static_cast<DType>(temperature), 
ssum)) : OType(0.0f);
   }
 }
 
-template<typename OP, bool negate, typename AType, typename DType, typename 
OType, int ndim>
-inline void Softmax(Stream<gpu> *s, DType *in, OType *out,
-                    Shape<ndim> shape, int axis, const double temperature) {
-  const int x_bits = 7;
-  const int x_size = 1 << x_bits;
-  index_t M = shape[axis];
-  index_t N = shape.Size()/M;
-  Shape<ndim> stride = calc_stride(shape);
-  Shape<ndim> sshape = shape;
-  sshape[axis] = 1;
+const int softmax_threads_per_block = 512;
+
+template<typename OP, bool negate, typename AType, typename LType,
+  typename DType, typename OType, typename IType>
+__global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, 
IType *length,
+                                               const index_t M, const double 
temperature,
+                                               const int rows_per_block, const 
index_t total_rows) {
+  __shared__ AType scratch[softmax_threads_per_block];
+  __shared__ LType persistent_storage[20 * 1024 / sizeof(LType)];
+  const int warp_size = 32;
+  const int threads_per_row = softmax_threads_per_block / rows_per_block;
+  const int my_local_row = threadIdx.x / threads_per_row;
+  const int my_row = blockIdx.x * rows_per_block + my_local_row;
+  if (my_row >= total_rows) return;
+  const int my_id = threadIdx.x % threads_per_row;
+  const int entries_per_load = sizeof(LType)/sizeof(DType);
+  const index_t len = length == nullptr ? M : 
static_cast<index_t>(length[my_row]);
+  // Due to usage of MSHADOW_TYPE_SWITCH macro we are generating
+  // kernels where sizeof(LType) may be less than sizeof(DType),
+  // resulting in entries_per_load being 0.
+  // This is not a valid combination and is being checked against
+  // in the launcher code. This switch here is just to silence
+  // the division by zero warning generated for such invalid cases.
+  const int row_length = entries_per_load > 0 ? M / entries_per_load : 0;
+
+  const LType* in_aligned = reinterpret_cast<const LType*>(in);
+  size_t base = my_row * row_length;
+
+  for (index_t i = my_id; i < row_length; i += threads_per_row) {
+    persistent_storage[my_local_row * row_length + i] = in_aligned[base + i];
+  }
+  DType * row = reinterpret_cast<DType *>(persistent_storage + my_local_row * 
row_length);
+  __syncthreads();
 
-  softmax_compute_kernel<x_bits, OP, negate, AType, ndim>
-    <<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
-      in, out, M, axis, sshape, stride, temperature);
-  MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_compute_kernel);
-}
+  DType my_max_value;
+  red::maximum::SetInitValue(my_max_value);
 
-template<int x_bits, typename OP, bool negate, typename AType, int ndim,
-         typename DType, typename OType, typename IType>
-__global__ void softmax_with_length_kernel(DType *in, OType *out, IType 
*length,
-                                           index_t M, int axis, Shape<ndim> 
sshape,
-                                           Shape<ndim> stride, const double 
temperature) {
-  const unsigned x_size = 1 << x_bits;
-  __shared__ AType smem[x_size];
-  index_t sa = stride[axis];
-  index_t base = unravel_dot(blockIdx.x, sshape, stride);
-  index_t x = threadIdx.x;
-  index_t len = static_cast<index_t>(length[blockIdx.x]);
-
-  red::maximum::SetInitValue(smem[x]);
-  for (index_t i = x; i < len; i += x_size) {
-    smem[x] = ::max(smem[x], negate ? -in[base + i*sa] : in[base + i*sa]);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    my_max_value = ::max(my_max_value, negate ? -row[i] : row[i]);
   }
+  scratch[threadIdx.x] = my_max_value;
   __syncthreads();
-  cuda::Reduce1D<red::maximum, x_bits>(smem);
+  for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+    if (my_id < size) {
+      scratch[threadIdx.x] = ::max(scratch[threadIdx.x], scratch[threadIdx.x + 
size]);
+    }
+    __syncthreads();
+  }
+  if (my_id < warp_size) {
+    AType my_value = warp_reduce(scratch[threadIdx.x],
+                                 [](AType x, AType y) { return ::max(x, y); });
+    scratch[threadIdx.x] = my_value;
+  }
   __syncthreads();
-  DType smax = smem[0];
+  DType smax = scratch[threadIdx.x - threadIdx.x % threads_per_row];
   __syncthreads();
 
-  red::sum::SetInitValue(smem[x]);
-  DType val;
-  for (index_t i = x; i < len; i += x_size) {
-    val = negate ? -in[base + i*sa]:in[base + i*sa];
-    smem[x] += static_cast<AType>(expf((val - smax) / 
static_cast<AType>(temperature)));
+  AType my_sum;
+  red::sum::SetInitValue(my_sum);
+
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    const DType val = negate ? -row[i] : row[i];
+    my_sum += static_cast<AType>(expf((val - smax) / 
static_cast<AType>(temperature)));
   }
+  scratch[threadIdx.x] = my_sum;
   __syncthreads();
-  cuda::Reduce1D<red::sum, x_bits>(smem);
+  for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+    if (my_id < size) {
+      scratch[threadIdx.x] += scratch[threadIdx.x + size];
+    }
+    __syncthreads();
+  }
+  if (my_id < warp_size) {
+    AType my_value = warp_reduce(scratch[threadIdx.x],
+                                 [](AType x, AType y) { return x + y;});
+    scratch[threadIdx.x] = my_value;
+  }
   __syncthreads();
-  AType ssum = smem[0];
+
+  AType ssum = scratch[threadIdx.x - threadIdx.x % threads_per_row];
   __syncthreads();
 
-  for (index_t i = x; i < M; i += x_size) {
-    val = negate ? -in[base + i*sa] : in[base + i*sa];
-    out[base + i*sa] =
-      (i < len) ? OType(OP::Map((val - smax)/static_cast<DType>(temperature), 
ssum)) : OType(0.0f);
+  for (index_t i = my_id; i < M; i += threads_per_row) {
+    const DType val = negate ? -row[i] : row[i];
+    row[i] = (i < len) ? DType(OP::Map((val - 
smax)/static_cast<DType>(temperature), ssum)) :
+                         DType(0.0f);
+  }
+  __syncthreads();
+
+  LType* out_aligned = reinterpret_cast<LType*>(out);
+
+  for (index_t i = my_id; i < row_length; i += threads_per_row) {
+    out_aligned[base + i] = persistent_storage[my_local_row * row_length + i];
   }
 }
 
+namespace {
+
+int get_rows_per_block(size_t N) {
+  const int warp_size = 32;
+  // How many read instructions should 1 thread at least do
+  const int read_instructions = 2;
+  const int num_threads = (N + read_instructions - 1) / read_instructions;
+  int num_warps = (num_threads + warp_size - 1) / warp_size;
+  // num_warps needs to be power of 2
+  int used_num_warps = 1;
+  num_warps = std::min(num_warps, softmax_threads_per_block / warp_size);
+  int tmp = num_warps;
 
 Review comment:
   Done.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to