MoisesHer commented on a change in pull request #19905:
URL: https://github.com/apache/incubator-mxnet/pull/19905#discussion_r585267506



##########
File path: src/operator/nn/softmax.cu
##########
@@ -22,18 +22,809 @@
  * \file softmax.cu
  * \brief GPU Implementation of softmax
  */
+#include <string>
 #include "./softmax-inl.h"
-#include "../tensor/elemwise_unary_op.h"
+#include "../../common/cuda/utils.h"
+#include "../../common/cuda/rtc.h"
+#include "../../common/cuda/rtc/vectorization-inl.h"
 
 namespace mxnet {
 namespace op {
 
+namespace {
+
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+const char softmax_common_functions[] = R"code(
+struct softmax_params {

Review comment:
       Is there a way to avoid duplicated code for this structure?

##########
File path: src/operator/nn/softmax.cu
##########
@@ -22,18 +22,809 @@
  * \file softmax.cu
  * \brief GPU Implementation of softmax
  */
+#include <string>
 #include "./softmax-inl.h"
-#include "../tensor/elemwise_unary_op.h"
+#include "../../common/cuda/utils.h"
+#include "../../common/cuda/rtc.h"
+#include "../../common/cuda/rtc/vectorization-inl.h"
 
 namespace mxnet {
 namespace op {
 
+namespace {
+
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+const char softmax_common_functions[] = R"code(
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+softmax_fwd(const DType a, const DType2 b) {
+  return op::exp(a) / b;
+}
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+log_softmax_fwd(const DType a, const DType2 b) {
+  return a - op::log(b);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+    return out * (ograd - sum);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+log_softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+    return ograd - op::exp(out) * sum;
+}
+
+)code";
+
+const char simple_softmax_kernel_fwd[] = R"code(
+__launch_bounds__(kRTCMaxThreadsPerBlock)
+__global__ void simple_softmax_kernel(const softmax_params param,
+                                      const index_t lead_dim) {
+  using LengthType = AccType<InputType1>;
+  const InputType0* input = reinterpret_cast<const 
InputType0*>(param.inputs[0]);
+  const InputType1* length = reinterpret_cast<const 
InputType1*>(param.inputs[1]);
+  const index_t len = length == nullptr
+                      ? lead_dim
+                      : 
static_cast<index_t>(LengthType::from(length[blockIdx.x]));
+  const int my_row = threadIdx.x % param.rows_per_block;
+  const int my_id = threadIdx.x / param.rows_per_block;
+  const int threads_per_row = blockDim.x / param.rows_per_block;
+  const index_t base_x = (blockIdx.x * param.rows_per_block + my_row) % 
param.stride;
+  const index_t base_n = (blockIdx.x * param.rows_per_block + my_row) / 
param.stride;
+  const index_t base = base_x + param.stride * lead_dim * base_n;
+  if (base >= param.num_elements * param.total_rows) return;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType smem[kRTCMaxThreadsPerBlock];
+  AType max;
+  red::maximum::SetInitValue(max);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    max = op::max(max, negate ? -val : val);
+  }
+  smem[threadIdx.x] = max;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::max(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::max(x, 
y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  AType smax = smem[my_row];
+  __syncthreads();
+
+  AType sum;
+  red::sum::SetInitValue(sum);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val :val;
+    sum += op::exp((val - smax) / static_cast<AType>(param.temperature));
+  }
+  smem[threadIdx.x] = sum;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::add(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::add(x, 
y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  sum = smem[my_row];
+  __syncthreads();
+
+  OutputType0* output = reinterpret_cast<OutputType0*>(param.outputs[0]);
+  for (index_t i = my_id; i < lead_dim; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val : val;
+    val = (i < len) ? OP((val - smax)/static_cast<AType>(param.temperature), 
sum) : 0;
+    if (req == OpReqType::kAddTo) {
+      if (i < len) {
+        output[base + i * param.stride] = OType::to(val +
+                                                    OType::from(output[base + 
i * param.stride]));
+      }
+    } else {
+      output[base + i * param.stride] = OType::to(val);
+    }
+  }
+}
+)code";
+
+const char softmax_stride1_kernel_fwd[] = R"code(
+__launch_bounds__(vector::vectorized_kernel_thread_num)
+__global__ void softmax_stride1_compute_kernel(const softmax_params param,
+                                               const index_t total_length,
+                                               const index_t other_dim,
+                                               const index_t N,
+                                               const index_t 
num_aligned_elements) {
+  using namespace vector;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using LengthType = AccType<InputType1>;
+  const InputType1* length = reinterpret_cast<const 
InputType1*>(param.inputs[1]);
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType scratch[vectorized_kernel_thread_num];
+  __shared__ AType persistent_storage[20 * 1024 / sizeof(AType)];
+  const int warp_size = 32;

Review comment:
       Is the `warp_size` defined in 
https://github.com/apache/incubator-mxnet/blob/ba2c3b411e953b650fd919db59b42b5535d80e83/src/common/cuda/rtc/util-inl.h#L280
 visible here?

##########
File path: src/operator/nn/softmax.cu
##########
@@ -22,18 +22,809 @@
  * \file softmax.cu
  * \brief GPU Implementation of softmax
  */
+#include <string>
 #include "./softmax-inl.h"
-#include "../tensor/elemwise_unary_op.h"
+#include "../../common/cuda/utils.h"
+#include "../../common/cuda/rtc.h"
+#include "../../common/cuda/rtc/vectorization-inl.h"
 
 namespace mxnet {
 namespace op {
 
+namespace {
+
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+const char softmax_common_functions[] = R"code(
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+softmax_fwd(const DType a, const DType2 b) {
+  return op::exp(a) / b;
+}
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+log_softmax_fwd(const DType a, const DType2 b) {
+  return a - op::log(b);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+    return out * (ograd - sum);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+log_softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+    return ograd - op::exp(out) * sum;
+}
+
+)code";
+
+const char simple_softmax_kernel_fwd[] = R"code(
+__launch_bounds__(kRTCMaxThreadsPerBlock)
+__global__ void simple_softmax_kernel(const softmax_params param,
+                                      const index_t lead_dim) {
+  using LengthType = AccType<InputType1>;
+  const InputType0* input = reinterpret_cast<const 
InputType0*>(param.inputs[0]);
+  const InputType1* length = reinterpret_cast<const 
InputType1*>(param.inputs[1]);
+  const index_t len = length == nullptr
+                      ? lead_dim
+                      : 
static_cast<index_t>(LengthType::from(length[blockIdx.x]));
+  const int my_row = threadIdx.x % param.rows_per_block;
+  const int my_id = threadIdx.x / param.rows_per_block;
+  const int threads_per_row = blockDim.x / param.rows_per_block;
+  const index_t base_x = (blockIdx.x * param.rows_per_block + my_row) % 
param.stride;
+  const index_t base_n = (blockIdx.x * param.rows_per_block + my_row) / 
param.stride;
+  const index_t base = base_x + param.stride * lead_dim * base_n;
+  if (base >= param.num_elements * param.total_rows) return;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType smem[kRTCMaxThreadsPerBlock];
+  AType max;
+  red::maximum::SetInitValue(max);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    max = op::max(max, negate ? -val : val);
+  }
+  smem[threadIdx.x] = max;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::max(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::max(x, 
y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  AType smax = smem[my_row];
+  __syncthreads();
+
+  AType sum;
+  red::sum::SetInitValue(sum);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val :val;
+    sum += op::exp((val - smax) / static_cast<AType>(param.temperature));
+  }
+  smem[threadIdx.x] = sum;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::add(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::add(x, 
y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  sum = smem[my_row];
+  __syncthreads();
+
+  OutputType0* output = reinterpret_cast<OutputType0*>(param.outputs[0]);
+  for (index_t i = my_id; i < lead_dim; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val : val;
+    val = (i < len) ? OP((val - smax)/static_cast<AType>(param.temperature), 
sum) : 0;
+    if (req == OpReqType::kAddTo) {
+      if (i < len) {
+        output[base + i * param.stride] = OType::to(val +
+                                                    OType::from(output[base + 
i * param.stride]));
+      }
+    } else {
+      output[base + i * param.stride] = OType::to(val);
+    }
+  }
+}
+)code";
+
+const char softmax_stride1_kernel_fwd[] = R"code(
+__launch_bounds__(vector::vectorized_kernel_thread_num)
+__global__ void softmax_stride1_compute_kernel(const softmax_params param,
+                                               const index_t total_length,
+                                               const index_t other_dim,
+                                               const index_t N,
+                                               const index_t 
num_aligned_elements) {
+  using namespace vector;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using LengthType = AccType<InputType1>;
+  const InputType1* length = reinterpret_cast<const 
InputType1*>(param.inputs[1]);
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType scratch[vectorized_kernel_thread_num];
+  __shared__ AType persistent_storage[20 * 1024 / sizeof(AType)];
+  const int warp_size = 32;
+  const int threads_per_row = vectorized_kernel_thread_num / 
param.rows_per_block;
+  const int my_local_row = threadIdx.x / threads_per_row;
+  const int base_row = blockIdx.x * param.rows_per_block;
+  const int my_row = base_row + my_local_row;
+  const index_t len = (length == nullptr ||
+                       my_row >= param.total_rows) ? param.num_elements
+                                                   : 
LengthType::from(length[my_row]);
+  const int my_id = threadIdx.x % threads_per_row;
+
+  AType* row;
+  if (only_full_blocks || blockIdx.x < gridDim.x - 1) {
+    // full rows_per_block rows to compute
+    VectorizedLoader<InputType0, nvec, aligned> loader(
+      reinterpret_cast<const InputType0*>(param.inputs[0]) + base_row * 
param.num_elements,
+      total_length);
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      loader.load(i, total_length);
+#pragma unroll
+      for (int j = 0; j < nvec; ++j) {
+        persistent_storage[i*nvec + j] = IType::from(loader.separate()[j]);
+      }
+    }
+    row = persistent_storage + my_local_row * param.num_elements + 
loader.alignment();
+  } else {
+    // less than rows_per_block rows to compute
+    const index_t real_length = min(total_length,
+                                    (param.total_rows - base_row) * 
param.num_elements);
+    VectorizedLoader<InputType0, nvec, false> loader(
+      reinterpret_cast<const InputType0*>(param.inputs[0]) + base_row * 
param.num_elements,
+      real_length);
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      loader.load(i, real_length);
+#pragma unroll
+      for (int j = 0; j < nvec; ++j) {
+        persistent_storage[i*nvec + j] = IType::from(loader.separate()[j]);
+      }
+    }
+    row = persistent_storage + my_local_row * param.num_elements + 
loader.alignment();
+  }
+  __syncthreads();
+
+  AType my_max_value;
+  red::maximum::SetInitValue(my_max_value);
+
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    my_max_value = ::max(my_max_value, negate ? -row[i] : row[i]);
+  }
+  AType smax;
+  if (!reduction_inside_warp) {
+    scratch[threadIdx.x] = my_max_value;
+    __syncthreads();
+    for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+      if (my_id < size) {
+        scratch[threadIdx.x] = ::max(scratch[threadIdx.x], scratch[threadIdx.x 
+ size]);
+      }
+      __syncthreads();
+    }
+    if (my_id < warp_size) {
+      AType my_value = util::grouped_warp_allreduce(scratch[threadIdx.x],
+                                                    [](AType x, AType y) { 
return op::max(x, y); },
+                                                    min(threads_per_row, 
warp_size));
+      scratch[threadIdx.x] = my_value;
+    }
+    __syncthreads();
+    smax = scratch[threadIdx.x - my_id];
+    __syncthreads();
+  } else {
+    smax = util::grouped_warp_allreduce(my_max_value,
+                                        [](AType x, AType y) { return 
op::max(x, y); },
+                                        threads_per_row);
+  }
+
+  AType my_sum;
+  red::sum::SetInitValue(my_sum);
+
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    const AType val = negate ? -row[i] : row[i];
+    my_sum += op::exp((val - smax) / static_cast<AType>(param.temperature));
+  }
+  AType ssum;
+  if (!reduction_inside_warp) {
+    scratch[threadIdx.x] = my_sum;
+    __syncthreads();
+    for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+      if (my_id < size) {
+        scratch[threadIdx.x] += scratch[threadIdx.x + size];
+      }
+      __syncthreads();
+    }
+    if (my_id < warp_size) {
+      AType my_value = util::grouped_warp_allreduce(scratch[threadIdx.x],
+                                                    [](AType x, AType y) { 
return x + y;},
+                                                    min(threads_per_row, 
warp_size));
+      scratch[threadIdx.x] = my_value;
+    }
+    __syncthreads();
+
+    ssum = scratch[threadIdx.x - my_id];
+    __syncthreads();
+  } else {
+      ssum = util::grouped_warp_allreduce(my_sum,
+                                          [](AType x, AType y) { return x + 
y;},
+                                          threads_per_row);
+  }
+
+  for (index_t i = my_id; i < param.num_elements; i += threads_per_row) {
+    const AType val = negate ? -row[i] : row[i];
+    row[i] = (i < len) ? OP((val - 
smax)/static_cast<AType>(param.temperature), ssum) :
+                         0;
+  }
+  __syncthreads();
+
+  if (only_full_blocks || blockIdx.x < gridDim.x - 1) {
+    VectorizedStorer<OutputType0, nvec, aligned> storer(
+      reinterpret_cast<OutputType0*>(param.outputs[0]) + base_row * 
param.num_elements,
+      total_length);
+
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      if (req == OpReqType::kAddTo) {
+        storer.load(i, total_length);
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(op::add(persistent_storage[i*nvec + 
j],
+                                                   
OType::from(storer.separate()[j])));
+        }
+      } else {
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(persistent_storage[i*nvec + j]);
+        }
+      }
+      storer.store(i, total_length);
+    }
+  } else {
+    const index_t real_length = min(total_length,
+                                    (param.total_rows - base_row) * 
param.num_elements);
+    VectorizedStorer<OutputType0, nvec, false> storer(
+      reinterpret_cast<OutputType0*>(param.outputs[0]) + base_row * 
param.num_elements,
+      real_length);
+
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      if (req == OpReqType::kAddTo) {
+        storer.load(i, real_length);
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(op::add(persistent_storage[i*nvec + 
j],
+                                                   
OType::from(storer.separate()[j])));
+        }
+      } else {
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(persistent_storage[i*nvec + j]);
+        }
+      }
+      storer.store(i, real_length);
+    }
+  }
+}
+)code";
+
+bool IsPower2(size_t N) {
+  return ((N & (N - 1)) == 0) && N != 0;
+}
+
+index_t RoundToPower2(index_t N) {

Review comment:
        Could this be moved to util-inl.h, if you think it can be reused by other ops?

##########
File path: src/operator/nn/softmax.cu
##########
@@ -22,18 +22,809 @@
  * \file softmax.cu
  * \brief GPU Implementation of softmax
  */
+#include <string>
 #include "./softmax-inl.h"
-#include "../tensor/elemwise_unary_op.h"
+#include "../../common/cuda/utils.h"
+#include "../../common/cuda/rtc.h"
+#include "../../common/cuda/rtc/vectorization-inl.h"
 
 namespace mxnet {
 namespace op {
 
+namespace {
+
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+const char softmax_common_functions[] = R"code(
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+softmax_fwd(const DType a, const DType2 b) {
+  return op::exp(a) / b;
+}
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+log_softmax_fwd(const DType a, const DType2 b) {
+  return a - op::log(b);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+    return out * (ograd - sum);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+log_softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+    return ograd - op::exp(out) * sum;
+}
+
+)code";
+
+const char simple_softmax_kernel_fwd[] = R"code(
+__launch_bounds__(kRTCMaxThreadsPerBlock)
+__global__ void simple_softmax_kernel(const softmax_params param,
+                                      const index_t lead_dim) {
+  using LengthType = AccType<InputType1>;
+  const InputType0* input = reinterpret_cast<const 
InputType0*>(param.inputs[0]);
+  const InputType1* length = reinterpret_cast<const 
InputType1*>(param.inputs[1]);
+  const index_t len = length == nullptr
+                      ? lead_dim
+                      : 
static_cast<index_t>(LengthType::from(length[blockIdx.x]));
+  const int my_row = threadIdx.x % param.rows_per_block;
+  const int my_id = threadIdx.x / param.rows_per_block;
+  const int threads_per_row = blockDim.x / param.rows_per_block;
+  const index_t base_x = (blockIdx.x * param.rows_per_block + my_row) % 
param.stride;
+  const index_t base_n = (blockIdx.x * param.rows_per_block + my_row) / 
param.stride;
+  const index_t base = base_x + param.stride * lead_dim * base_n;
+  if (base >= param.num_elements * param.total_rows) return;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType smem[kRTCMaxThreadsPerBlock];
+  AType max;
+  red::maximum::SetInitValue(max);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    max = op::max(max, negate ? -val : val);
+  }
+  smem[threadIdx.x] = max;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::max(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::max(x, 
y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  AType smax = smem[my_row];
+  __syncthreads();
+
+  AType sum;
+  red::sum::SetInitValue(sum);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val :val;
+    sum += op::exp((val - smax) / static_cast<AType>(param.temperature));
+  }
+  smem[threadIdx.x] = sum;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::add(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::add(x, 
y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  sum = smem[my_row];
+  __syncthreads();
+
+  OutputType0* output = reinterpret_cast<OutputType0*>(param.outputs[0]);
+  for (index_t i = my_id; i < lead_dim; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val : val;
+    val = (i < len) ? OP((val - smax)/static_cast<AType>(param.temperature), 
sum) : 0;
+    if (req == OpReqType::kAddTo) {
+      if (i < len) {
+        output[base + i * param.stride] = OType::to(val +
+                                                    OType::from(output[base + 
i * param.stride]));
+      }
+    } else {
+      output[base + i * param.stride] = OType::to(val);
+    }
+  }
+}
+)code";
+
+const char softmax_stride1_kernel_fwd[] = R"code(
+__launch_bounds__(vector::vectorized_kernel_thread_num)
+__global__ void softmax_stride1_compute_kernel(const softmax_params param,
+                                               const index_t total_length,
+                                               const index_t other_dim,
+                                               const index_t N,
+                                               const index_t 
num_aligned_elements) {
+  using namespace vector;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using LengthType = AccType<InputType1>;
+  const InputType1* length = reinterpret_cast<const 
InputType1*>(param.inputs[1]);
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType scratch[vectorized_kernel_thread_num];
+  __shared__ AType persistent_storage[20 * 1024 / sizeof(AType)];
+  const int warp_size = 32;
+  const int threads_per_row = vectorized_kernel_thread_num / 
param.rows_per_block;
+  const int my_local_row = threadIdx.x / threads_per_row;
+  const int base_row = blockIdx.x * param.rows_per_block;
+  const int my_row = base_row + my_local_row;
+  const index_t len = (length == nullptr ||
+                       my_row >= param.total_rows) ? param.num_elements
+                                                   : 
LengthType::from(length[my_row]);
+  const int my_id = threadIdx.x % threads_per_row;
+
+  AType* row;
+  if (only_full_blocks || blockIdx.x < gridDim.x - 1) {
+    // full rows_per_block rows to compute
+    VectorizedLoader<InputType0, nvec, aligned> loader(
+      reinterpret_cast<const InputType0*>(param.inputs[0]) + base_row * 
param.num_elements,
+      total_length);
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      loader.load(i, total_length);
+#pragma unroll
+      for (int j = 0; j < nvec; ++j) {
+        persistent_storage[i*nvec + j] = IType::from(loader.separate()[j]);
+      }
+    }
+    row = persistent_storage + my_local_row * param.num_elements + 
loader.alignment();
+  } else {
+    // less than rows_per_block rows to compute
+    const index_t real_length = min(total_length,
+                                    (param.total_rows - base_row) * 
param.num_elements);
+    VectorizedLoader<InputType0, nvec, false> loader(
+      reinterpret_cast<const InputType0*>(param.inputs[0]) + base_row * 
param.num_elements,
+      real_length);
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      loader.load(i, real_length);
+#pragma unroll
+      for (int j = 0; j < nvec; ++j) {
+        persistent_storage[i*nvec + j] = IType::from(loader.separate()[j]);
+      }
+    }
+    row = persistent_storage + my_local_row * param.num_elements + 
loader.alignment();
+  }
+  __syncthreads();
+
+  AType my_max_value;
+  red::maximum::SetInitValue(my_max_value);
+
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    my_max_value = ::max(my_max_value, negate ? -row[i] : row[i]);
+  }
+  AType smax;
+  if (!reduction_inside_warp) {
+    scratch[threadIdx.x] = my_max_value;
+    __syncthreads();
+    for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+      if (my_id < size) {
+        scratch[threadIdx.x] = ::max(scratch[threadIdx.x], scratch[threadIdx.x 
+ size]);
+      }
+      __syncthreads();
+    }
+    if (my_id < warp_size) {
+      AType my_value = util::grouped_warp_allreduce(scratch[threadIdx.x],
+                                                    [](AType x, AType y) { 
return op::max(x, y); },
+                                                    min(threads_per_row, 
warp_size));
+      scratch[threadIdx.x] = my_value;
+    }
+    __syncthreads();
+    smax = scratch[threadIdx.x - my_id];
+    __syncthreads();
+  } else {
+    smax = util::grouped_warp_allreduce(my_max_value,
+                                        [](AType x, AType y) { return 
op::max(x, y); },
+                                        threads_per_row);
+  }
+
+  AType my_sum;
+  red::sum::SetInitValue(my_sum);
+
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    const AType val = negate ? -row[i] : row[i];
+    my_sum += op::exp((val - smax) / static_cast<AType>(param.temperature));
+  }
+  AType ssum;
+  if (!reduction_inside_warp) {
+    scratch[threadIdx.x] = my_sum;
+    __syncthreads();
+    for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+      if (my_id < size) {
+        scratch[threadIdx.x] += scratch[threadIdx.x + size];
+      }
+      __syncthreads();
+    }
+    if (my_id < warp_size) {
+      AType my_value = util::grouped_warp_allreduce(scratch[threadIdx.x],
+                                                    [](AType x, AType y) { 
return x + y;},
+                                                    min(threads_per_row, 
warp_size));
+      scratch[threadIdx.x] = my_value;
+    }
+    __syncthreads();
+
+    ssum = scratch[threadIdx.x - my_id];
+    __syncthreads();
+  } else {
+      ssum = util::grouped_warp_allreduce(my_sum,
+                                          [](AType x, AType y) { return x + 
y;},
+                                          threads_per_row);
+  }
+
+  for (index_t i = my_id; i < param.num_elements; i += threads_per_row) {
+    const AType val = negate ? -row[i] : row[i];
+    row[i] = (i < len) ? OP((val - 
smax)/static_cast<AType>(param.temperature), ssum) :
+                         0;
+  }
+  __syncthreads();
+
+  if (only_full_blocks || blockIdx.x < gridDim.x - 1) {
+    VectorizedStorer<OutputType0, nvec, aligned> storer(
+      reinterpret_cast<OutputType0*>(param.outputs[0]) + base_row * 
param.num_elements,
+      total_length);
+
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      if (req == OpReqType::kAddTo) {
+        storer.load(i, total_length);
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(op::add(persistent_storage[i*nvec + 
j],
+                                                   
OType::from(storer.separate()[j])));
+        }
+      } else {
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(persistent_storage[i*nvec + j]);
+        }
+      }
+      storer.store(i, total_length);
+    }
+  } else {
+    const index_t real_length = min(total_length,
+                                    (param.total_rows - base_row) * 
param.num_elements);
+    VectorizedStorer<OutputType0, nvec, false> storer(
+      reinterpret_cast<OutputType0*>(param.outputs[0]) + base_row * 
param.num_elements,
+      real_length);
+
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      if (req == OpReqType::kAddTo) {
+        storer.load(i, real_length);
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(op::add(persistent_storage[i*nvec + 
j],
+                                                   
OType::from(storer.separate()[j])));
+        }
+      } else {
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(persistent_storage[i*nvec + j]);
+        }
+      }
+      storer.store(i, real_length);
+    }
+  }
+}
+)code";
+
+bool IsPower2(size_t N) {
+  return ((N & (N - 1)) == 0) && N != 0;
+}
+
+index_t RoundToPower2(index_t N) {
+  size_t ret = 1;
+  size_t copyN = N;
+  while (N >= 2) {
+    ret *= 2;
+    N /= 2;
+  }
+  if (ret < copyN) {
+    ret *= 2;
+  }
+  return ret;
+}
+
+int get_rows_per_block(const index_t row_size, const int nvec,
+                       const index_t max_storage, const int 
num_threads_per_block,
+                       const index_t total_rows, const int dev_id) {
+  CHECK(IsPower2(num_threads_per_block))
+    << "Number of threads in a block must be power of 2 to use 
get_rows_per_block function";
+  // How many read instructions should 1 thread at least do
+  const int read_instructions = 16;
+  const size_t row_size_in_vec = (row_size + nvec - 1) / nvec;
+  int desired_num_threads_per_row = (row_size_in_vec + read_instructions - 1) 
/ read_instructions;
+  desired_num_threads_per_row = RoundToPower2(desired_num_threads_per_row);
+  desired_num_threads_per_row = std::min(desired_num_threads_per_row, 
num_threads_per_block);
+  const int desired_rows_per_block = num_threads_per_block / 
desired_num_threads_per_row;
+  int actual_rows_per_block = desired_rows_per_block;
+  int num_sms = MultiprocessorCount(dev_id);
+  while (actual_rows_per_block > 1 &&
+         ((max_storage != -1 && max_storage < row_size * 
actual_rows_per_block) ||
+          (total_rows + actual_rows_per_block - 1) / actual_rows_per_block < 
num_sms)) {
+    actual_rows_per_block /= 2;
+  }
+  return actual_rows_per_block;
+}
+
+}  // namespace
+
+void SoftmaxRTCCompute::operator()(const nnvm::NodeAttrs& attrs,
+                                   const OpContext& ctx,
+                                   const std::vector<TBlob>& inputs,
+                                   const std::vector<OpReqType>& req,
+                                   const std::vector<TBlob>& outputs) {
+  using namespace mxnet_op;
+  using common::mshadow_type_info;
+  using namespace common::cuda::rtc;
+  using common::div_round;
+  if (req[0] == kNullOp || inputs[0].Size() == 0U) return;
+  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+  int axis = CheckAxis(param.axis, inputs[0].ndim());
+  const double temperature = param.temperature.has_value() ?
+                             param.temperature.value() : 1.0;
+  mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
+
+  void* length_ptr = nullptr;
+  std::string length_typename = "int";
+  if (param.use_length.value()) {
+    CHECK(inputs.size() > 1)
+      << "Mask needs to be provided when using softmax with use_length=True.";
+    length_ptr = inputs[1].dptr_;
+    length_typename = mshadow_type_info(inputs[1].type_flag_).name;
+  }
+  CHECK_EQ(outputs.size(), 1);
+  index_t M = shape[axis];
+  if (M == 0 || shape.Size() == 0) return;
+  index_t stride = 1;
+  if (axis == shape.ndim() - 2) {
+    stride = shape[shape.ndim() - 1];
+  }
+  const index_t N = shape.Size() / M;
+  softmax_params params = {{inputs[0].dptr_, length_ptr, nullptr},
+                           {outputs[0].dptr_},
+                           stride, M,
+                           temperature, 1, N};
+  std::string code = "#define OP " + OP + "\n"
+                     "const OpReqType req = " + util::to_string(req[0]) + ";\n"
+                     "const bool negate = " + std::to_string(negate) + ";\n"
+                     "using InputType1 = " + length_typename + ";\n";
+  Stream<gpu>* s = ctx.get_stream<gpu>();
+
+  constexpr int nvec = 2;
+  // Using 20 kB of shared memory for persistent storage in the optimized case
+  const size_t acc_type_size = 
std::max(mshadow_type_info(inputs[0].type_flag_).acc_size,
+                                        
mshadow_type_info(outputs[0].type_flag_).acc_size);
+  const size_t max_opt_M = 20 * 1024 / acc_type_size;
+  int rows_per_block = get_rows_per_block(M, nvec, max_opt_M,
+                                          vectorized_kernel_thread_num,
+                                          N, ctx.run_ctx.ctx.dev_id);
+  if (stride == 1 &&
+      static_cast<size_t>(M * rows_per_block) <= max_opt_M) {
+    const int warp_size = 32;

Review comment:
       Maybe you could use
https://github.com/apache/incubator-mxnet/blob/ba2c3b411e953b650fd919db59b42b5535d80e83/src/common/cuda/rtc/util-inl.h#L280 here instead?

##########
File path: src/operator/nn/softmax.cu
##########
@@ -22,18 +22,809 @@
  * \file softmax.cu
  * \brief GPU Implementation of softmax
  */
+#include <string>
 #include "./softmax-inl.h"
-#include "../tensor/elemwise_unary_op.h"
+#include "../../common/cuda/utils.h"
+#include "../../common/cuda/rtc.h"
+#include "../../common/cuda/rtc/vectorization-inl.h"
 
 namespace mxnet {
 namespace op {
 
+namespace {
+
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+const char softmax_common_functions[] = R"code(
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+softmax_fwd(const DType a, const DType2 b) {
+  return op::exp(a) / b;
+}
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+log_softmax_fwd(const DType a, const DType2 b) {
+  return a - op::log(b);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+    return out * (ograd - sum);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+log_softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+    return ograd - op::exp(out) * sum;
+}
+
+)code";
+
+const char simple_softmax_kernel_fwd[] = R"code(
+__launch_bounds__(kRTCMaxThreadsPerBlock)
+__global__ void simple_softmax_kernel(const softmax_params param,
+                                      const index_t lead_dim) {
+  using LengthType = AccType<InputType1>;
+  const InputType0* input = reinterpret_cast<const 
InputType0*>(param.inputs[0]);
+  const InputType1* length = reinterpret_cast<const 
InputType1*>(param.inputs[1]);
+  const index_t len = length == nullptr
+                      ? lead_dim
+                      : 
static_cast<index_t>(LengthType::from(length[blockIdx.x]));
+  const int my_row = threadIdx.x % param.rows_per_block;
+  const int my_id = threadIdx.x / param.rows_per_block;
+  const int threads_per_row = blockDim.x / param.rows_per_block;
+  const index_t base_x = (blockIdx.x * param.rows_per_block + my_row) % 
param.stride;
+  const index_t base_n = (blockIdx.x * param.rows_per_block + my_row) / 
param.stride;
+  const index_t base = base_x + param.stride * lead_dim * base_n;
+  if (base >= param.num_elements * param.total_rows) return;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType smem[kRTCMaxThreadsPerBlock];
+  AType max;
+  red::maximum::SetInitValue(max);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    max = op::max(max, negate ? -val : val);
+  }
+  smem[threadIdx.x] = max;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::max(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::max(x, 
y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  AType smax = smem[my_row];
+  __syncthreads();
+
+  AType sum;
+  red::sum::SetInitValue(sum);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val :val;
+    sum += op::exp((val - smax) / static_cast<AType>(param.temperature));
+  }
+  smem[threadIdx.x] = sum;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::add(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::add(x, 
y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  sum = smem[my_row];
+  __syncthreads();
+
+  OutputType0* output = reinterpret_cast<OutputType0*>(param.outputs[0]);
+  for (index_t i = my_id; i < lead_dim; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val : val;
+    val = (i < len) ? OP((val - smax)/static_cast<AType>(param.temperature), 
sum) : 0;
+    if (req == OpReqType::kAddTo) {
+      if (i < len) {
+        output[base + i * param.stride] = OType::to(val +
+                                                    OType::from(output[base + 
i * param.stride]));
+      }
+    } else {
+      output[base + i * param.stride] = OType::to(val);
+    }
+  }
+}
+)code";
+
+const char softmax_stride1_kernel_fwd[] = R"code(
+__launch_bounds__(vector::vectorized_kernel_thread_num)
+__global__ void softmax_stride1_compute_kernel(const softmax_params param,
+                                               const index_t total_length,
+                                               const index_t other_dim,
+                                               const index_t N,
+                                               const index_t 
num_aligned_elements) {
+  using namespace vector;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using LengthType = AccType<InputType1>;
+  const InputType1* length = reinterpret_cast<const 
InputType1*>(param.inputs[1]);
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType scratch[vectorized_kernel_thread_num];
+  __shared__ AType persistent_storage[20 * 1024 / sizeof(AType)];
+  const int warp_size = 32;
+  const int threads_per_row = vectorized_kernel_thread_num / 
param.rows_per_block;
+  const int my_local_row = threadIdx.x / threads_per_row;
+  const int base_row = blockIdx.x * param.rows_per_block;
+  const int my_row = base_row + my_local_row;
+  const index_t len = (length == nullptr ||
+                       my_row >= param.total_rows) ? param.num_elements
+                                                   : 
LengthType::from(length[my_row]);
+  const int my_id = threadIdx.x % threads_per_row;
+
+  AType* row;
+  if (only_full_blocks || blockIdx.x < gridDim.x - 1) {
+    // full rows_per_block rows to compute
+    VectorizedLoader<InputType0, nvec, aligned> loader(
+      reinterpret_cast<const InputType0*>(param.inputs[0]) + base_row * 
param.num_elements,
+      total_length);
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      loader.load(i, total_length);
+#pragma unroll
+      for (int j = 0; j < nvec; ++j) {
+        persistent_storage[i*nvec + j] = IType::from(loader.separate()[j]);
+      }
+    }
+    row = persistent_storage + my_local_row * param.num_elements + 
loader.alignment();
+  } else {
+    // less than rows_per_block rows to compute
+    const index_t real_length = min(total_length,
+                                    (param.total_rows - base_row) * 
param.num_elements);
+    VectorizedLoader<InputType0, nvec, false> loader(
+      reinterpret_cast<const InputType0*>(param.inputs[0]) + base_row * 
param.num_elements,
+      real_length);
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      loader.load(i, real_length);
+#pragma unroll
+      for (int j = 0; j < nvec; ++j) {
+        persistent_storage[i*nvec + j] = IType::from(loader.separate()[j]);
+      }
+    }
+    row = persistent_storage + my_local_row * param.num_elements + 
loader.alignment();
+  }
+  __syncthreads();
+
+  AType my_max_value;
+  red::maximum::SetInitValue(my_max_value);
+
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    my_max_value = ::max(my_max_value, negate ? -row[i] : row[i]);
+  }
+  AType smax;
+  if (!reduction_inside_warp) {
+    scratch[threadIdx.x] = my_max_value;
+    __syncthreads();
+    for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+      if (my_id < size) {
+        scratch[threadIdx.x] = ::max(scratch[threadIdx.x], scratch[threadIdx.x 
+ size]);
+      }
+      __syncthreads();
+    }
+    if (my_id < warp_size) {
+      AType my_value = util::grouped_warp_allreduce(scratch[threadIdx.x],
+                                                    [](AType x, AType y) { 
return op::max(x, y); },
+                                                    min(threads_per_row, 
warp_size));
+      scratch[threadIdx.x] = my_value;
+    }
+    __syncthreads();
+    smax = scratch[threadIdx.x - my_id];
+    __syncthreads();
+  } else {
+    smax = util::grouped_warp_allreduce(my_max_value,
+                                        [](AType x, AType y) { return 
op::max(x, y); },
+                                        threads_per_row);
+  }
+
+  AType my_sum;
+  red::sum::SetInitValue(my_sum);
+
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    const AType val = negate ? -row[i] : row[i];
+    my_sum += op::exp((val - smax) / static_cast<AType>(param.temperature));
+  }
+  AType ssum;
+  if (!reduction_inside_warp) {
+    scratch[threadIdx.x] = my_sum;
+    __syncthreads();
+    for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+      if (my_id < size) {
+        scratch[threadIdx.x] += scratch[threadIdx.x + size];
+      }
+      __syncthreads();
+    }
+    if (my_id < warp_size) {
+      AType my_value = util::grouped_warp_allreduce(scratch[threadIdx.x],
+                                                    [](AType x, AType y) { 
return x + y;},
+                                                    min(threads_per_row, 
warp_size));
+      scratch[threadIdx.x] = my_value;
+    }
+    __syncthreads();
+
+    ssum = scratch[threadIdx.x - my_id];
+    __syncthreads();
+  } else {
+      ssum = util::grouped_warp_allreduce(my_sum,
+                                          [](AType x, AType y) { return x + 
y;},
+                                          threads_per_row);
+  }
+
+  for (index_t i = my_id; i < param.num_elements; i += threads_per_row) {
+    const AType val = negate ? -row[i] : row[i];
+    row[i] = (i < len) ? OP((val - 
smax)/static_cast<AType>(param.temperature), ssum) :
+                         0;
+  }
+  __syncthreads();
+
+  if (only_full_blocks || blockIdx.x < gridDim.x - 1) {
+    VectorizedStorer<OutputType0, nvec, aligned> storer(
+      reinterpret_cast<OutputType0*>(param.outputs[0]) + base_row * 
param.num_elements,
+      total_length);
+
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      if (req == OpReqType::kAddTo) {
+        storer.load(i, total_length);
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(op::add(persistent_storage[i*nvec + 
j],
+                                                   
OType::from(storer.separate()[j])));
+        }
+      } else {
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(persistent_storage[i*nvec + j]);
+        }
+      }
+      storer.store(i, total_length);
+    }
+  } else {
+    const index_t real_length = min(total_length,
+                                    (param.total_rows - base_row) * 
param.num_elements);
+    VectorizedStorer<OutputType0, nvec, false> storer(
+      reinterpret_cast<OutputType0*>(param.outputs[0]) + base_row * 
param.num_elements,
+      real_length);
+
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      if (req == OpReqType::kAddTo) {
+        storer.load(i, real_length);
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(op::add(persistent_storage[i*nvec + 
j],
+                                                   
OType::from(storer.separate()[j])));
+        }
+      } else {
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(persistent_storage[i*nvec + j]);
+        }
+      }
+      storer.store(i, real_length);
+    }
+  }
+}
+)code";
+
+bool IsPower2(size_t N) {

Review comment:
       Maybe we could move this function to util-inl.h, if you think it can be
reused by other ops?

##########
File path: src/operator/nn/softmax.cu
##########
@@ -22,18 +22,809 @@
  * \file softmax.cu
  * \brief GPU Implementation of softmax
  */
+#include <string>
 #include "./softmax-inl.h"
-#include "../tensor/elemwise_unary_op.h"
+#include "../../common/cuda/utils.h"
+#include "../../common/cuda/rtc.h"
+#include "../../common/cuda/rtc/vectorization-inl.h"
 
 namespace mxnet {
 namespace op {
 
+namespace {
+
+struct softmax_params {
+  const void* inputs[3];

Review comment:
       Why 3 inputs?
   It seems you only make use of 2.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to