ptrendx commented on a change in pull request #19905:
URL: https://github.com/apache/incubator-mxnet/pull/19905#discussion_r592660682



##########
File path: src/operator/nn/softmax.cu
##########
@@ -22,18 +22,809 @@
  * \file softmax.cu
  * \brief GPU Implementation of softmax
  */
+#include <string>
 #include "./softmax-inl.h"
-#include "../tensor/elemwise_unary_op.h"
+#include "../../common/cuda/utils.h"
+#include "../../common/cuda/rtc.h"
+#include "../../common/cuda/rtc/vectorization-inl.h"
 
 namespace mxnet {
 namespace op {
 
+namespace {
+
+// Host-side parameter block for the runtime-compiled (RTC) softmax kernels;
+// the kernels take it by value (`const softmax_params param`).
+// An identical definition is embedded in the softmax_common_functions RTC
+// source string in this file. Keep both in sync field-for-field so the byte
+// layout seen by the device matches the one filled in on the host.
+struct softmax_params {
+  const void* inputs[3];   // inputs[0]: data; inputs[1]: optional per-row lengths (nullptr = none); inputs[2]: not read by the forward kernel shown here -- presumably used by backward, confirm
+  void* outputs[1];        // outputs[0]: result buffer written by the kernel
+  index_t stride;          // element distance between consecutive entries along the softmax axis (kernel reads input[base + i * stride])
+  index_t num_elements;    // used in the kernel's out-of-range check together with total_rows -- exact meaning to be confirmed at the launch site
+  double temperature;      // (x - max) is divided by this before exp/OP
+  int rows_per_block;      // number of independent rows one thread block processes
+  index_t total_rows;      // used with num_elements in the out-of-range check
+};
+
+// Shared device-side source for the RTC softmax kernels. Everything between
+// R"code( and )code" is kernel source text compiled at runtime, so any edit
+// inside the literal (including comments) changes the compiled program.
+// Contents:
+//  - a copy of softmax_params that must stay identical to the host struct
+//    of the same name in this file;
+//  - softmax_fwd(a, b)               = exp(a) / b
+//  - log_softmax_fwd(a, b)           = a - log(b)
+//  - softmax_bwd(ograd, out, sum)    = out * (ograd - sum)
+//  - log_softmax_bwd(ograd, out, sum) = ograd - exp(out) * sum
+//    Each helper returns type_util::mixed_type of its argument types.
+// Presumably concatenated ahead of a kernel string such as
+// simple_softmax_kernel_fwd, which uses softmax_params and an OP without
+// defining them -- confirm at the compilation call site.
+const char softmax_common_functions[] = R"code(
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+softmax_fwd(const DType a, const DType2 b) {
+  return op::exp(a) / b;
+}
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+log_softmax_fwd(const DType a, const DType2 b) {
+  return a - op::log(b);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+    return out * (ograd - sum);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+log_softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+    return ograd - op::exp(out) * sum;
+}
+
+)code";
+
+// RTC source for the generic softmax forward kernel `simple_softmax_kernel`.
+// The string uses names it does not define itself: InputType0/InputType1,
+// OutputType0, OP, negate, req, warp_size, plus the AccType/red::/op::/util::
+// helpers. These are expected from the surrounding RTC compilation (e.g.
+// softmax_common_functions and generated definitions); confirm at the
+// launch site.
+//
+// Work layout, as implied by the index math:
+//  - each block covers param.rows_per_block rows; thread t serves row
+//    t % rows_per_block and is lane t / rows_per_block of the
+//    blockDim.x / rows_per_block threads assigned to that row;
+//  - element i of a row is input[base + i * param.stride].
+// Per row the kernel makes three passes over the input:
+//  1) max of the (optionally negated, if `negate`) values;
+//  2) sum of exp((x - max) / temperature) over the first `len` elements;
+//  3) out = OP((x - max) / temperature, sum) for i < len.
+// `len` is lead_dim when inputs[1] is nullptr, otherwise read from it.
+// For i >= len the output is written as 0, except with req == kAddTo, where
+// those elements are left untouched and valid ones are accumulated.
+// Shared memory: a static kRTCMaxThreadsPerBlock-element AType array. Both
+// block reductions halve from blockDim.x / 2 down to warp_size and then
+// finish with util::strided_grouped_warp_reduce in the first warp. This
+// presumably requires blockDim.x to be a power of two and rows_per_block to
+// divide warp_size -- confirm the launcher guarantees it.
+//
+// NOTE(review): `len` is read from length[blockIdx.x] (one length per
+// block), but each block handles rows blockIdx.x * rows_per_block + my_row.
+// With rows_per_block > 1 and a length input this appears to pick the wrong
+// row's length -- confirm the launcher forces rows_per_block == 1 in that
+// case, or index by the row instead.
+// NOTE(review): out-of-range threads `return` early while the rest of the
+// block later calls __syncthreads(). CUDA requires __syncthreads() to be
+// reached by all threads of the block, so this relies on exited threads not
+// blocking the barrier. Consider a per-thread `valid` flag instead.
+// NOTE(review): confirm the units of the early-exit bound. If num_elements
+// is the total element count, comparing an element offset against
+// num_elements * total_rows would not exclude the padding rows of the last
+// block.
+const char simple_softmax_kernel_fwd[] = R"code(
+__launch_bounds__(kRTCMaxThreadsPerBlock)
+__global__ void simple_softmax_kernel(const softmax_params param,
+                                      const index_t lead_dim) {
+  using LengthType = AccType<InputType1>;
+  const InputType0* input = reinterpret_cast<const 
InputType0*>(param.inputs[0]);
+  const InputType1* length = reinterpret_cast<const 
InputType1*>(param.inputs[1]);
+  const index_t len = length == nullptr
+                      ? lead_dim
+                      : 
static_cast<index_t>(LengthType::from(length[blockIdx.x]));
+  const int my_row = threadIdx.x % param.rows_per_block;
+  const int my_id = threadIdx.x / param.rows_per_block;
+  const int threads_per_row = blockDim.x / param.rows_per_block;
+  const index_t base_x = (blockIdx.x * param.rows_per_block + my_row) % 
param.stride;
+  const index_t base_n = (blockIdx.x * param.rows_per_block + my_row) / 
param.stride;
+  const index_t base = base_x + param.stride * lead_dim * base_n;
+  if (base >= param.num_elements * param.total_rows) return;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType smem[kRTCMaxThreadsPerBlock];
+  AType max;
+  red::maximum::SetInitValue(max);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    max = op::max(max, negate ? -val : val);
+  }
+  smem[threadIdx.x] = max;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::max(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::max(x, 
y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  AType smax = smem[my_row];
+  __syncthreads();
+
+  AType sum;
+  red::sum::SetInitValue(sum);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val :val;
+    sum += op::exp((val - smax) / static_cast<AType>(param.temperature));
+  }
+  smem[threadIdx.x] = sum;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::add(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::add(x, 
y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  sum = smem[my_row];
+  __syncthreads();
+
+  OutputType0* output = reinterpret_cast<OutputType0*>(param.outputs[0]);
+  for (index_t i = my_id; i < lead_dim; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val : val;
+    val = (i < len) ? OP((val - smax)/static_cast<AType>(param.temperature), 
sum) : 0;
+    if (req == OpReqType::kAddTo) {
+      if (i < len) {
+        output[base + i * param.stride] = OType::to(val +
+                                                    OType::from(output[base + 
i * param.stride]));
+      }
+    } else {
+      output[base + i * param.stride] = OType::to(val);
+    }
+  }
+}
+)code";
+
+const char softmax_stride1_kernel_fwd[] = R"code(
+__launch_bounds__(vector::vectorized_kernel_thread_num)
+__global__ void softmax_stride1_compute_kernel(const softmax_params param,
+                                               const index_t total_length,
+                                               const index_t other_dim,
+                                               const index_t N,
+                                               const index_t 
num_aligned_elements) {
+  using namespace vector;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using LengthType = AccType<InputType1>;
+  const InputType1* length = reinterpret_cast<const 
InputType1*>(param.inputs[1]);
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType scratch[vectorized_kernel_thread_num];
+  __shared__ AType persistent_storage[20 * 1024 / sizeof(AType)];
+  const int warp_size = 32;

Review comment:
       Should be, I will remove that.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to