This is an automated email from the ASF dual-hosted git repository.

moisesher pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new c692770  [PERF] Moving GPU softmax to RTC and optimizations (#19905)
c692770 is described below

commit c6927704d8ceb4597f33a45037277d462ee67045
Author: Przemyslaw Tredak <[email protected]>
AuthorDate: Mon Apr 26 09:16:09 2021 -0700

    [PERF] Moving GPU softmax to RTC and optimizations (#19905)
    
    * Moving softmax to RTC
    
    * Fix from rebase
    
    * Fixes from review
    
    * Fix
    
    * Fix FPE in the softmax grad.
---
 src/common/cuda/rtc/backward_functions-inl.h     | 161 +++--
 src/common/cuda/rtc/forward_functions-inl.h      | 126 ++--
 src/common/cuda/rtc/reducer-inl.h                | 399 ++++++++++++
 src/common/cuda/rtc/util-inl.h                   | 113 +++-
 src/common/cuda/rtc/vectorization-inl.h          |  24 +-
 src/common/cuda/utils.cc                         |  20 +-
 src/common/cuda/utils.h                          |   8 +
 src/common/utils.h                               |  21 +
 src/operator/mxnet_op.h                          |  17 +-
 src/operator/nn/log_softmax.cu                   |   5 +-
 src/operator/nn/softmax-inl.h                    | 357 +----------
 src/operator/nn/softmax.cu                       | 781 ++++++++++++++++++++++-
 src/operator/nn/softmin.cu                       |   5 +-
 src/operator/tensor/elemwise_binary_scalar_op.cc |  12 +-
 14 files changed, 1484 insertions(+), 565 deletions(-)

diff --git a/src/common/cuda/rtc/backward_functions-inl.h b/src/common/cuda/rtc/backward_functions-inl.h
index 168dc68..50f0c67 100644
--- a/src/common/cuda/rtc/backward_functions-inl.h
+++ b/src/common/cuda/rtc/backward_functions-inl.h
@@ -32,217 +32,217 @@ const char backward_function_definitions[] = R"code(
 namespace op {
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_relu(const DTypeGrad grad, const DType val) {
   if (isnan(val)) return val;
   return val > 0 ? grad : 0;
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_sigmoid(const DTypeGrad grad, const DType out) {
   return grad * out * (1 - out);
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_softrelu(const DTypeGrad grad, const DType val) {
-  const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
+  const mixed_type<DTypeGrad, DType> v = val;
   return grad * sigmoid(v);
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_softsign(const DTypeGrad grad, const DType val) {
-  const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
+  const mixed_type<DTypeGrad, DType> v = val;
   const auto ap1 = 1 + op::abs(v);
   return grad / (ap1 * ap1);
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_abs(const DTypeGrad grad, const DType val) {
-  const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
+  const mixed_type<DTypeGrad, DType> v = val;
   return grad * op::sign(v);
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_exp(const DTypeGrad grad, const DType val) {
-  const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
+  const mixed_type<DTypeGrad, DType> v = val;
   return grad * op::exp(v);
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_expm1(const DTypeGrad grad, const DType val) {
   return backward_exp(grad, val);
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_log(const DTypeGrad grad, const DType val) {
   return grad / val;
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_log10(const DTypeGrad grad, const DType val) {
   return grad / (val * op::log(static_cast<DTypeGrad>(10)));
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_log2(const DTypeGrad grad, const DType val) {
   return grad / (val * op::log(static_cast<DTypeGrad>(2)));
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_log1p(const DTypeGrad grad, const DType val) {
   return grad / (1 + val);
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_sin(const DTypeGrad grad, const DType val) {
-  const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
+  const mixed_type<DTypeGrad, DType> v = val;
   return grad * op::cos(v);
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_cos(const DTypeGrad grad, const DType val) {
-  const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
+  const mixed_type<DTypeGrad, DType> v = val;
   return -grad * op::sin(v);
 }
 
 // Uses output from tan
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_tan(const DTypeGrad grad, const DType out) {
   return grad * (out * out + 1);
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_arcsin(const DTypeGrad grad, const DType val) {
-  const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
+  const mixed_type<DTypeGrad, DType> v = val;
   return grad / op::sqrt(1 - v*v);
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_arccos(const DTypeGrad grad, const DType val) {
-  const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
+  const mixed_type<DTypeGrad, DType> v = val;
   return -grad / op::sqrt(1 - v*v);
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_arctan(const DTypeGrad grad, const DType val) {
   return grad / (1 + val*val);
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_degrees(const DTypeGrad grad, const DType /* val */) {
   return op::degrees(grad);
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_radians(const DTypeGrad grad, const DType /* val */) {
   return op::radians(grad);
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_sinh(const DTypeGrad grad, const DType val) {
-  const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
+  const mixed_type<DTypeGrad, DType> v = val;
   return grad * op::cosh(v);
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_cosh(const DTypeGrad grad, const DType val) {
-  const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
+  const mixed_type<DTypeGrad, DType> v = val;
   return grad * op::sinh(v);
 }
 
 // Uses tanh output
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_tanh(const DTypeGrad grad, const DType out) {
   return grad * (1 - out * out);
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_arcsinh(const DTypeGrad grad, const DType val) {
-  const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
+  const mixed_type<DTypeGrad, DType> v = val;
   return grad / op::sqrt(v * v + 1);
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_arccosh(const DTypeGrad grad, const DType val) {
-  const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
+  const mixed_type<DTypeGrad, DType> v = val;
   return grad / op::sqrt(v * v - 1);
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_arctanh(const DTypeGrad grad, const DType val) {
   return grad / (1 - val * val);
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_sqrt(const DTypeGrad grad, const DType out) {
   return 0.5 * grad / out;
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_rsqrt(const DTypeGrad grad, const DType val) {
-  const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
+  const mixed_type<DTypeGrad, DType> v = val;
   const auto inv = 1 / v;
   return -0.5 * grad * op::sqrt(inv) * inv;
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_cbrt(const DTypeGrad grad, const DType out) {
   return grad / (3.0f * out * out);
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_rcbrt(const DTypeGrad grad, const DType val) {
-  const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
+  const mixed_type<DTypeGrad, DType> v = val;
   const auto inv = 1 / v;
   return -1.f/3.f * grad * op::cbrt(inv) * inv;
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_square(const DTypeGrad grad, const DType val) {
   return 2 * val * grad;
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 rdiv_grad(const DType val,
           const DType2 val2) {
   return -val2 / (val * val);
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 div_grad(const DType val,
          const DType2 val2) {
-  const typename type_util::mixed_type<DType, DType2>::type temp = val2;
+  const mixed_type<DType, DType2> temp = val2;
   return op::reciprocal(temp);
 }
 
@@ -283,87 +283,87 @@ __device__ inline DType rmod_grad(const DType val,
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 power_grad(const DType val,
            const DType2 val2) {
   return op::power(val, val2 - 1.f) * val2;
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 power_rgrad(const DType val,
             const DType2 val2) {
-  const typename type_util::mixed_type<DType, DType2>::type temp = val;
+  const mixed_type<DType, DType2> temp = val;
   return op::power(val, val2) * op::log(temp);
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 rpower_grad(const DType val,
             const DType2 val2) {
-  const typename type_util::mixed_type<DType, DType2>::type temp = val2;
+  const mixed_type<DType, DType2> temp = val2;
   return val * op::log(temp);
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 hypot_grad_left(const DType val,
                 const DType2 val2) {
   return val / op::hypot(val, val2);
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 hypot_grad_right(const DType val,
                  const DType2 val2) {
   return val2 / op::hypot(val, val2);
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 copysign_grad(const DType val,
               const DType2 val2) {
   return (val >= 0 && val2 >= 0) || (val < 0 && val2 < 0) ? 1 : -1;
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 arctan2_grad(const DType val,
              const DType2 val2) {
   return val2 / (val * val + val2 * val2);
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 rarctan2_grad(const DType val,
               const DType2 val2) {
   return val / (val * val + val2 * val2);
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 arctan2_rgrad(const DType val,
               const DType2 val2) {
   return -rarctan2_grad(val, val2);
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 ldexp_grad(const DType val,
            const DType2 val2) {
   return op::power(static_cast<DType>(2), val2);
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 rldexp_grad(const DType val,
             const DType2 val2) {
-  using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
-  return val2 * op::power(static_cast<mixed_type>(2), val) * op::log(static_cast<mixed_type>(2));
+  using type = mixed_type<DType, DType2>;
+  return val2 * op::power(static_cast<type>(2), val) * op::log(static_cast<type>(2));
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_clip(const DTypeGrad grad, const DType val,
               const float a_min, const float a_max) {
   if (val > a_max || val < a_min) {
@@ -374,35 +374,32 @@ backward_clip(const DTypeGrad grad, const DType val,
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_reciprocal(const DTypeGrad grad, const DType val) {
   return -grad / (val * val);
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_erf(const DTypeGrad grad, const DType val) {
-  using mixed_type = typename type_util::mixed_type<DTypeGrad, DType>::type;
-  const mixed_type v = val;
-  constexpr mixed_type my_pi = pi;
+  const mixed_type<DTypeGrad, DType> v = val;
+  constexpr mixed_type<DTypeGrad, DType> my_pi = pi;
   return 2.0f / op::sqrt(my_pi) * op::exp(-(v*v)) * grad;
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_erfinv(const DTypeGrad grad, const DType val) {
-  using mixed_type = typename type_util::mixed_type<DTypeGrad, DType>::type;
-  constexpr mixed_type my_pi = pi;
-  const mixed_type g = grad;
-  const mixed_type v = val;
+  constexpr mixed_type<DTypeGrad, DType> my_pi = pi;
+  const mixed_type<DTypeGrad, DType> g = grad;
+  const mixed_type<DTypeGrad, DType> v = val;
   return 0.5f * op::sqrt(my_pi) * op::exp(v * v) * g;
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_gamma(const DTypeGrad grad, const DType val) {
-  using mixed_type = typename type_util::mixed_type<DTypeGrad, DType>::type;
-  const mixed_type v = val;
+  const mixed_type<DTypeGrad, DType> v = val;
   if (type_util::is_same<DTypeGrad, double>::value) {
     return grad * op::gamma(v) * op::special_functions::cephes::psi<double>(v);
   } else {
@@ -411,10 +408,9 @@ backward_gamma(const DTypeGrad grad, const DType val) {
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_gammaln(const DTypeGrad grad, const DType val) {
-  using mixed_type = typename type_util::mixed_type<DTypeGrad, DType>::type;
-  const mixed_type v = val;
+  const mixed_type<DTypeGrad, DType> v = val;
   if (type_util::is_same<DTypeGrad, double>::value) {
     return grad * op::special_functions::cephes::psi<double>(v);
   } else {
@@ -423,10 +419,9 @@ backward_gammaln(const DTypeGrad grad, const DType val) {
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_digamma(const DTypeGrad grad, const DType val) {
-  using mixed_type = typename type_util::mixed_type<DTypeGrad, DType>::type;
-  const mixed_type v = val;
+  const mixed_type<DTypeGrad, DType> v = val;
   if (type_util::is_same<DTypeGrad, double>::value) {
     return grad * op::special_functions::trigamma<double>(v);
   } else {
@@ -435,7 +430,7 @@ backward_digamma(const DTypeGrad grad, const DType val) {
 }
 
 template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_gelu(const DTypeGrad grad, const DType val) {
   return 0.5f * (grad + grad * op::erf(val / op::sqrt(2.0f)) +
                  val * backward_erf(grad, val / op::sqrt(2.0f)) / op::sqrt(2.0f));
diff --git a/src/common/cuda/rtc/forward_functions-inl.h b/src/common/cuda/rtc/forward_functions-inl.h
index f85916f..f4d08e6 100644
--- a/src/common/cuda/rtc/forward_functions-inl.h
+++ b/src/common/cuda/rtc/forward_functions-inl.h
@@ -32,6 +32,7 @@ const char function_definitions_util[] = R"code(
 #define INT_MAX (2147483647)
 
 namespace op {
+using type_util::mixed_type;
 
 template <typename DType>
 struct LoadType {
@@ -241,44 +242,44 @@ __device__ inline bool_t isfinite(const DType val) {
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 add(const DType a, const DType2 b) {
   return a + b;
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 sub(const DType a, const DType2 b) {
   return a - b;
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 rsub(const DType a, const DType2 b) {
   return b - a;
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 mul(const DType a, const DType2 b) {
   return a * b;
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 div(const DType a, const DType2 b) {
   return a / b;
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 rdiv(const DType a, const DType2 b) {
   return b / a;
 }
 
 #define DEFINE_BINARY_MATH_FUNC(name, double_version, float_version) \
 template <typename DType, typename DType2> \
-__device__ inline typename type_util::mixed_type<DType, DType2>::type \
+__device__ inline mixed_type<DType, DType2> \
 name (const DType a, const DType2 b) { \
   if (type_util::has_double_or_integral<DType, DType2>::value) { \
     return double_version ((double)a, (double)b); \
@@ -288,7 +289,7 @@ name (const DType a, const DType2 b) { \
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 power (const DType a, const DType2 b) {
   if (type_util::has_double<DType, DType2>::value) {
     return ::pow ((double)a, (double)b); \
@@ -298,34 +299,34 @@ power (const DType a, const DType2 b) {
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 rpow(const DType a, const DType2 b) {
   return power(b, a);
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 max(const DType a, const DType2 b) {
   if (isnan(a)) return a;
   return a > b ? a : b;
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 fmax(const DType a, const DType2 b) {
   if (isnan(b)) return a;
   return a > b ? a : b;
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 min(const DType a, const DType2 b) {
   if (isnan(a)) return a;
   return a < b ? a : b;
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 fmin(const DType a, const DType2 b) {
   if (isnan(b)) return a;
   return a < b ? a : b;
@@ -334,7 +335,7 @@ fmin(const DType a, const DType2 b) {
 DEFINE_BINARY_MATH_FUNC(hypot, ::hypot, ::hypotf)
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 mod(const DType a, const DType2 b) {
   if (b == 0) {
     return 0;
@@ -359,7 +360,7 @@ mod(const DType a, const DType2 b) {
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 fmod(const DType a, const DType2 b) {
   if (b == 0) {
     return 0;
@@ -368,110 +369,98 @@ fmod(const DType a, const DType2 b) {
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 rmod(const DType a, const DType2 b) {
   return op::mod(b, a);
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 rfmod(const DType a, const DType2 b) {
   return op::fmod(b, a);
 }
 
 template <typename DType, typename DType2>
 __device__ inline DType equal(const DType a, const DType2 b) {
-  using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
-  const mixed_type real_a = a;
-  const mixed_type real_b = b;
+  const mixed_type<DType, DType2> real_a = a;
+  const mixed_type<DType, DType2> real_b = b;
   return real_a == real_b ? 1 : 0;
 }
 
 template <typename DType, typename DType2>
 __device__ inline DType not_equal(const DType a, const DType2 b) {
-  using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
-  const mixed_type real_a = a;
-  const mixed_type real_b = b;
+  const mixed_type<DType, DType2> real_a = a;
+  const mixed_type<DType, DType2> real_b = b;
   return real_a != real_b ? 1 : 0;
 }
 
 template <typename DType, typename DType2>
 __device__ inline DType greater(const DType a, const DType2 b) {
-  using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
-  const mixed_type real_a = a;
-  const mixed_type real_b = b;
+  const mixed_type<DType, DType2> real_a = a;
+  const mixed_type<DType, DType2> real_b = b;
   return real_a > real_b ? 1 : 0;
 }
 
 template <typename DType, typename DType2>
 __device__ inline DType greater_equal(const DType a, const DType2 b) {
-  using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
-  const mixed_type real_a = a;
-  const mixed_type real_b = b;
+  const mixed_type<DType, DType2> real_a = a;
+  const mixed_type<DType, DType2> real_b = b;
   return real_a >= real_b ? 1 : 0;
 }
 
 template <typename DType, typename DType2>
 __device__ inline DType less(const DType a, const DType2 b) {
-  using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
-  const mixed_type real_a = a;
-  const mixed_type real_b = b;
+  const mixed_type<DType, DType2> real_a = a;
+  const mixed_type<DType, DType2> real_b = b;
   return real_a < real_b ? 1 : 0;
 }
 
 template <typename DType, typename DType2>
 __device__ inline DType less_equal(const DType a, const DType2 b) {
-  using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
-  const mixed_type real_a = a;
-  const mixed_type real_b = b;
+  const mixed_type<DType, DType2> real_a = a;
+  const mixed_type<DType, DType2> real_b = b;
   return real_a <= real_b ? 1 : 0;
 }
 
 template <typename DType, typename DType2>
 __device__ inline bool_t np_equal(const DType a, const DType2 b) {
-  using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
-  const mixed_type real_a = a;
-  const mixed_type real_b = b;
+  const mixed_type<DType, DType2> real_a = a;
+  const mixed_type<DType, DType2> real_b = b;
   return real_a == real_b ? true : false;
 }
 
 template <typename DType, typename DType2>
 __device__ inline bool_t np_not_equal(const DType a, const DType2 b) {
-  using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
-  const mixed_type real_a = a;
-  const mixed_type real_b = b;
+  const mixed_type<DType, DType2> real_a = a;
+  const mixed_type<DType, DType2> real_b = b;
   return real_a != real_b ? true : false;
 }
 
 template <typename DType, typename DType2>
 __device__ inline bool_t np_greater(const DType a, const DType2 b) {
-  using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
-  const mixed_type real_a = a;
-  const mixed_type real_b = b;
+  const mixed_type<DType, DType2> real_a = a;
+  const mixed_type<DType, DType2> real_b = b;
   return real_a > real_b ? true : false;
 }
 
 template <typename DType, typename DType2>
 __device__ inline bool_t np_greater_equal(const DType a, const DType2 b) {
-  using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
-  const mixed_type real_a = a;
-  const mixed_type real_b = b;
+  const mixed_type<DType, DType2> real_a = a;
+  const mixed_type<DType, DType2> real_b = b;
   return real_a >= real_b ? true : false;
 }
 
 template <typename DType, typename DType2>
 __device__ inline bool_t np_less(const DType a, const DType2 b) {
-  using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
-  const mixed_type real_a = a;
-  const mixed_type real_b = b;
+  const mixed_type<DType, DType2> real_a = a;
+  const mixed_type<DType, DType2> real_b = b;
   return real_a < real_b ? true : false;
 }
 
 template <typename DType, typename DType2>
 __device__ inline bool_t np_less_equal(const DType a, const DType2 b) {
-  using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
-  const mixed_type real_a = a;
-  const mixed_type real_b = b;
+  const mixed_type<DType, DType2> real_a = a;
+  const mixed_type<DType, DType2> real_b = b;
   return real_a <= real_b ? true : false;
 }
 
@@ -501,7 +490,7 @@ __device__ inline DType2 rcopysign(const DType a, const DType2 b) {
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 lcm(const DType a, const DType2 b) {
   if (type_util::is_integral<DType>::value &&
       type_util::is_integral<DType2>::value) {
@@ -542,7 +531,7 @@ lcm(const DType a, const DType2 b) {
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 gcd(const DType a, const DType2 b) {
   if (type_util::is_integral<DType>::value &&
       type_util::is_integral<DType2>::value) {
@@ -585,42 +574,39 @@ gcd(const DType a, const DType2 b) {
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type bitwise_xor(const DType a,
+__device__ inline mixed_type<DType, DType2> bitwise_xor(const DType a,
                                                                        const 
DType2 b) {
-  using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
-  const mixed_type real_a = a;
-  const mixed_type real_b = b;
+  const mixed_type<DType, DType2> real_a = a;
+  const mixed_type<DType, DType2> real_b = b;
   return real_a ^ real_b;
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type bitwise_or(const DType a,
+__device__ inline mixed_type<DType, DType2> bitwise_or(const DType a,
                                                                        const 
DType2 b) {
-  using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
-  const mixed_type real_a = a;
-  const mixed_type real_b = b;
+  const mixed_type<DType, DType2> real_a = a;
+  const mixed_type<DType, DType2> real_b = b;
   return real_a | real_b;
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type bitwise_and(const DType a,
+__device__ inline mixed_type<DType, DType2> bitwise_and(const DType a,
                                                                        const 
DType2 b) {
-  using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
-  const mixed_type real_a = a;
-  const mixed_type real_b = b;
+  const mixed_type<DType, DType2> real_a = a;
+  const mixed_type<DType, DType2> real_b = b;
   return real_a & real_b;
 }
 
 DEFINE_BINARY_MATH_FUNC(arctan2, ::atan2, ::atan2f)
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
 rarctan2(const DType a, const DType2 b) {
   return arctan2(b, a);
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>  // a * 2^b
 ldexp(const DType a, const DType2 b) {
   if (type_util::has_double_or_integral<DType, DType2>::value) {
     return a * ::pow(2.0, static_cast<double>(b));
@@ -630,7 +616,7 @@ ldexp(const DType a, const DType2 b) {
 }
 
 template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>  // ldexp with the arguments swapped
 rldexp(const DType a, const DType2 b) {
   return ldexp(b, a);
 }
diff --git a/src/common/cuda/rtc/reducer-inl.h b/src/common/cuda/rtc/reducer-inl.h
index 93b7027..259d0e0 100644
--- a/src/common/cuda/rtc/reducer-inl.h
+++ b/src/common/cuda/rtc/reducer-inl.h
@@ -94,6 +94,405 @@ struct sum {
     residual = 0;
   }
 };
+
+/*! \brief maximum reducer */
+struct maximum {
+  /*! \brief do reduction into dst; a NaN already in dst is kept, a NaN src replaces dst */
+  template<typename DType, typename DType2>
+  __device__ inline static void Reduce(volatile DType& dst,  volatile DType2 src) { // NOLINT(*)
+    if (!util::isnan(dst)) {
+      if (!(dst >= src)) dst = src;
+    }
+  }
+  /*! \brief do reduction into dst; the third argument is unused */
+  template<typename DType, typename DType2>
+  __device__ inline static void Reduce(volatile DType& dst,  volatile DType2 src,
+                                       volatile DType& none) {
+    Reduce(dst, src);
+  }
+  /*! \brief combine the results of two reducers */
+  template<typename DType>
+  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
+    Reduce(dst_val, src_val);
+  }
+  /*! \brief combine the results of two reducers; the residuals are ignored */
+  template<typename DType>
+  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual,
+                                      volatile DType& src_val, volatile DType& src_residual) {
+    Reduce(dst_val, src_val);
+  }
+  /*! \brief finalize reduction result (nothing to do) */
+  template<typename DType>
+  __device__ inline static void Finalize(volatile DType& dst) {}
+  /*! \brief finalize reduction result (nothing to do) */
+  template<typename DType>
+  __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {}
+  /*! NOTE(review): -inf cannot be converted to an integral DType (UB); confirm usage
+   *\brief set the initial value during reduction: -2*DBL_MAX overflows to -inf as a double
+   */
+  template<typename DType>
+  __device__ inline static void SetInitValue(DType &initv) {
+    initv = -2*DBL_MAX;
+  }
+  /*!
+   *\brief set the initial value during reduction; the second argument is unused
+   */
+  template<typename DType>
+  __device__ inline static void SetInitValue(DType &initv, DType &none) {
+    SetInitValue(initv);
+  }
+};
+
+/*! \brief minimum reducer */
+struct minimum {
+  /*! \brief do reduction into dst; a NaN already in dst is kept, a NaN src replaces dst */
+  template<typename DType, typename DType2>
+  __device__ inline static void Reduce(volatile DType& dst,  volatile DType2 src) {
+    if (!util::isnan(dst)) {
+      if (!(dst <= src)) dst = src;
+    }
+  }
+  /*! \brief do reduction into dst; the third argument is unused */
+  template<typename DType, typename DType2>
+  __device__ inline static void Reduce(volatile DType& dst,  volatile DType2 src,
+                                       volatile DType& none) {
+    Reduce(dst, src);
+  }
+  /*! \brief combine the results of two reducers */
+  template<typename DType>
+  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
+    Reduce(dst_val, src_val);
+  }
+  /*! \brief combine the results of two reducers; the residuals are ignored */
+  template<typename DType>
+  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual,
+                                      volatile DType& src_val, volatile DType& src_residual) {
+    Reduce(dst_val, src_val);
+  }
+  /*! \brief finalize reduction result (nothing to do) */
+  template<typename DType>
+  __device__ inline static void Finalize(volatile DType& dst) {}
+  /*! \brief finalize reduction result (nothing to do) */
+  template<typename DType>
+  __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {}
+  /*! NOTE(review): +inf cannot be converted to an integral DType (UB); confirm usage
+   *\brief set the initial value during reduction: 2*DBL_MAX overflows to +inf as a double
+   */
+  template<typename DType>
+  __device__ inline static void SetInitValue(DType &initv) {
+    initv = 2*DBL_MAX;
+  }
+  /*!
+   *\brief set the initial value during reduction; the second argument is unused
+   */
+  template<typename DType>
+  __device__ inline static void SetInitValue(DType &initv, DType &none) {
+    SetInitValue(initv);
+  }
+};
+
+/*! \brief product reducer */
+struct product {
+  /*! \brief do reduction into dst: dst = dst * src */
+  template<typename DType, typename DType2>
+  __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) {
+    dst = op::mul(dst, src);
+  }
+  /*! \brief do reduction into dst; the third argument is unused */
+  template<typename DType, typename DType2>
+  __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src,
+                                       volatile DType& none) {
+    Reduce(dst, src);
+  }
+  /*! \brief combine the results of two reducers */
+  template<typename DType>
+  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
+    Reduce(dst_val, src_val);
+  }
+  /*! \brief combine the results of two reducers; the residuals are ignored */
+  template<typename DType>
+  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual,
+                                      volatile DType& src_val, volatile DType& src_residual) {
+    Reduce(dst_val, src_val);
+  }
+  /*! \brief finalize reduction result (nothing to do) */
+  template<typename DType>
+  __device__ inline static void Finalize(volatile DType& dst) {}
+  /*! \brief finalize reduction result (nothing to do) */
+  template<typename DType>
+  __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {}
+  /*!
+  *\brief set the initial value during reduction: 1, the multiplicative identity
+  */
+  template<typename DType>
+  __device__ inline static void SetInitValue(DType &initv) {
+    initv = 1;
+  }
+  /*!
+  *\brief set the initial value during reduction; the second argument is unused
+  */
+  template<typename DType>
+  __device__ inline static void SetInitValue(DType &initv, DType &none) {
+    SetInitValue(initv);
+  }
+};
+
+/*! \brief sum reducer that ignores NaN values in the input */
+struct nansum {
+  /*! \brief do reduction into dst; NaN src values are skipped */
+  template<typename DType, typename DType2>
+  __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) {
+    if (util::isnan(src)) return;
+    dst = op::add(dst, src);
+  }
+  /*! \brief Kahan-compensated reduction into dst; NaN src values are skipped */
+  template<typename DType>
+  __device__ inline static void Reduce(volatile DType& dst, volatile DType src,
+                                       volatile DType& residual) {
+    if (util::isnan(src)) return;
+    DType y = src - residual;
+    DType t = dst + y;
+    residual = (t - dst) - y;
+    dst = t;
+  }
+  /*! \brief combine the results of two reducers */
+  template<typename DType>
+  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
+    Reduce(dst_val, src_val);
+  }
+  /*! \brief combine two compensated partial sums (TwoSum of the values plus both residuals) */
+  template<typename DType>
+  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual,
+                                      volatile DType& src_val, volatile DType& src_residual) {
+    DType t1 = dst_val + src_val;
+    DType e = t1 - dst_val;  // src_val as actually added; must pair with (src_val - e) below
+    DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual;
+    dst_val = t1 + t2;
+    dst_residual = t2 - (dst_val - t1);  // NOTE(review): sign opposite to Kahan Reduce residual
+  }
+  /*! \brief finalize reduction result (nothing to do) */
+  template<typename DType>
+  __device__ inline static void Finalize(volatile DType& dst) {}
+  /*! \brief finalize reduction result (the residual is not folded back in) */
+  template<typename DType>
+  __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {}
+  /*!
+  *\brief set the initial value during reduction: 0, the additive identity
+  */
+  template<typename DType>
+  __device__ inline static void SetInitValue(DType & initv) {
+      initv = 0;
+  }
+  /*!
+   *\brief set the initial value during reduction and clear the compensation residual
+   */
+  template<typename DType>
+  __device__ inline static void SetInitValue(DType &initv, DType &residual) {
+    SetInitValue(initv);
+    residual = 0;
+  }
+};
+
+/*! \brief product reducer that ignores NaN values in the input */
+struct nanprod {
+  /*! \brief do reduction into dst; NaN src values are skipped */
+  template<typename DType, typename DType2>
+  __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) {
+    if (util::isnan(src)) return;
+    dst = op::mul(dst, src);
+  }
+  /*! \brief do reduction into dst; the third argument is unused */
+  template<typename DType>
+  __device__ inline static void Reduce(volatile DType& dst, volatile DType src,
+                                       volatile DType& none) {
+    Reduce(dst, src);
+  }
+  /*! \brief combine the results of two reducers */
+  template<typename DType>
+  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
+    Reduce(dst_val, src_val);
+  }
+  /*! \brief combine the results of two reducers; the residuals are ignored */
+  template<typename DType>
+  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual,
+                                      volatile DType& src_val, volatile DType& src_residual) {
+    Reduce(dst_val, src_val);
+  }
+  /*! \brief finalize reduction (nothing to do) */
+  template<typename DType>
+  __device__ inline static void Finalize(volatile DType& dst) {}
+  /*! \brief finalize reduction (nothing to do) */
+  template<typename DType>
+  __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {}
+  /*!
+  *\brief set the initial value during reduction: 1, the multiplicative identity
+  */
+  template<typename DType>
+  __device__ inline static void SetInitValue(DType & initv) {
+    initv = 1;
+  }
+  /*!
+  *\brief set the initial value during reduction; the second argument is unused
+  */
+  template<typename DType>
+  __device__ inline static void SetInitValue(DType &initv, DType &none) {
+    SetInitValue(initv);
+  }
+};
+
+struct nrm2 {  // L2 norm: sqrt(sum src^2); the scaled variant keeps max |src| for stability
+  /*! \brief do reduction into dst: adds src * src */
+  template<typename AType, typename DType>
+  __device__ inline static void Reduce(volatile AType& sum_of_squares, volatile DType src) {
+    sum_of_squares = op::add(sum_of_squares, src * src);
+  }
+  /*! \brief do stable reduction into dst; the true sum of squares is scale^2 * sum_of_squares */
+  template<typename AType, typename DType>
+  __device__ inline static void Reduce(volatile AType& sum_of_squares,
+                                       volatile DType src, volatile DType& scale) {
+    if (src != 0) {
+      DType abs = op::abs(src);
+      if (scale < abs) {
+        sum_of_squares = 1 + sum_of_squares * (scale / abs) * (scale / abs);
+        scale = abs;
+      } else {
+        sum_of_squares = sum_of_squares + (abs / scale) * (abs / scale);
+      }
+    }
+  }
+  /*! \brief combine the results of two reducers (plain sums of squares) */
+  template<typename DType>
+  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
+    dst_val = op::add(dst_val, src_val);
+  }
+  /*! \brief combine two scaled results, rescaling to the larger scale */
+  template<typename DType>
+  __device__ inline static void Merge(volatile DType& dst_ssq, volatile DType& dst_scale,
+                                      volatile DType& src_ssq, volatile DType& src_scale) {
+    if (dst_scale != 0 && dst_scale >= src_scale) {
+      dst_ssq = dst_ssq + src_ssq * (src_scale / dst_scale) * (src_scale / dst_scale);
+    } else if (src_scale != 0 && dst_scale < src_scale) {
+      dst_ssq = src_ssq + dst_ssq * (dst_scale / src_scale) * (dst_scale / src_scale);
+      dst_scale = src_scale;
+    }
+  }
+  /*! \brief finalize reduction result: square root of the sum of squares */
+  template<typename DType>
+  __device__ inline static void Finalize(volatile DType& sum_of_squares) {
+    sum_of_squares = op::sqrt(sum_of_squares);
+  }
+  /*! \brief finalize scaled reduction result: scale * sqrt(sum_of_squares) */
+  template<typename DType>
+  __device__ inline static void Finalize(volatile DType& sum_of_squares, volatile DType& scale) {
+    sum_of_squares = scale * op::sqrt(sum_of_squares);
+  }
+  /*!
+   *\brief set the initial value during reduction: 0
+   */
+  template<typename DType>
+  __device__ inline static void SetInitValue(DType &sum_of_squares) {
+    sum_of_squares = 0;
+  }
+  /*!
+   *\brief set the initial value during reduction: sum 0, scale 0
+   */
+  template<typename DType>
+  __device__ inline static void SetInitValue(DType &sum_of_squares, DType &scale) {
+    SetInitValue(sum_of_squares);
+    scale = 0;
+  }
+};
+
+struct nrmlp {  // Lp norm: (sum |src|^lp)^(1/lp); lp == 0 counts the nonzero elements
+  double lp;  // order p of the norm
+  /*! \brief src^p for p != 0 (src == 0 stays 0); for p == 0 returns 1.0 if src != 0 else 0.0 */
+  __device__ inline static double lp_power(volatile double src, volatile double p) {
+    if (p != 0.0) {
+      if (src == 0.0) {
+        return src;
+      } else {
+        return op::power(src, p);
+      }
+    } else {  // 0-norm, sparsity
+      return static_cast<double>(src != 0);
+    }
+  }
+
+  /*! \brief do reduction into dst: accumulates |src|^lp (magnitude keeps odd/fractional lp valid) */
+  template<typename AType, typename DType>
+  __device__ inline void Reduce(volatile AType& sum_of_powers, volatile DType src) {
+    if (src != 0) {
+      sum_of_powers += AType(lp_power(static_cast<double>(op::abs(src)), lp));
+    }
+  }
+
+  /*! \brief do stable reduction into dst; the true sum is scale^lp * sum_of_powers */
+  template<typename AType, typename DType>
+  __device__ inline void Reduce(volatile AType& sum_of_powers, volatile DType src,
+                                volatile DType& scale) {
+    if (src != 0) {
+      DType src_abs = op::abs(src);
+      if (scale < src_abs) {
+        sum_of_powers = sum_of_powers * AType(lp_power(static_cast<double>(scale / src_abs), lp));
+        sum_of_powers = sum_of_powers + 1;
+        scale = src_abs;
+      } else {
+        sum_of_powers = sum_of_powers + AType(lp_power(static_cast<double>(src_abs / scale), lp));
+      }
+    }
+  }
+
+  /*! \brief combine the results of two reducers (plain sums of powers) */
+  template<typename DType>
+  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
+    dst_val = dst_val + src_val;
+  }
+
+  /*! \brief combine scaled results. NOTE(review): rescales with exponent 2, not lp - confirm */
+  template<typename DType>
+  __device__ inline static void Merge(volatile DType& dst_ssq, volatile DType& dst_scale,
+                                      volatile DType& src_ssq, volatile DType& src_scale) {
+    if (dst_scale != 0 && dst_scale >= src_scale) {
+      dst_ssq = dst_ssq + src_ssq * DType(lp_power(static_cast<double>(src_scale / dst_scale), 2));
+    } else if (src_scale != 0 && dst_scale < src_scale) {
+      dst_ssq = src_ssq + dst_ssq * DType(lp_power(static_cast<double>(dst_scale / src_scale), 2));
+      dst_scale = src_scale;
+    }
+  }
+
+  /*! \brief finalize reduction result: the 1/lp root (skipped when lp == 0) */
+  template<typename DType>
+  __device__ inline void Finalize(volatile DType& sum_of_powers) {
+    if (lp != 0.0) {
+      sum_of_powers = DType(lp_power(static_cast<double>(sum_of_powers), 1.0 / lp));
+    }
+  }
+
+  /*! \brief finalize scaled result: scale * root (skipped when lp == 0) */
+  template<typename DType>
+  __device__ inline void Finalize(volatile DType& sum_of_powers, volatile DType& scale) {
+    if (lp != 0.0) {
+      sum_of_powers = scale * DType(lp_power(static_cast<double>(sum_of_powers), 1.0 / lp));
+    }
+  }
+
+  /*!
+   *\brief set the initial value during reduction: 0
+   */
+  template<typename DType>
+  __device__ inline static void SetInitValue(DType &sum_of_powers) {
+    sum_of_powers = 0;
+  }
+
+  /*!
+   *\brief set the initial value during reduction: sum 0, scale 0
+   */
+  template<typename DType>
+  __device__ inline static void SetInitValue(DType &sum_of_powers, DType &scale) {
+    SetInitValue(sum_of_powers);
+    scale = 0;
+  }
+};
 }  // namespace red
 
 )code";
diff --git a/src/common/cuda/rtc/util-inl.h b/src/common/cuda/rtc/util-inl.h
index 372390f..b426603 100644
--- a/src/common/cuda/rtc/util-inl.h
+++ b/src/common/cuda/rtc/util-inl.h
@@ -174,74 +174,97 @@ struct enable_if<true> {
 };
 
 template <typename T, typename U, class Enable = void>
-struct mixed_type;
+struct mixed_type_helper;  // ::type is the promoted common type of T and U (specializations below)
 
 template <typename T>
-struct mixed_type<T, float64, typename enable_if<!is_same<float64, T>::value>::type> {
+struct mixed_type_helper<T, float64, typename enable_if<!is_same<float64, T>::value>::type> {
   using type = float64;
 };
 
 template <typename T>
-struct mixed_type<float64, T> {
+struct mixed_type_helper<float64, T> {
   using type = float64;
 };
 
 template <typename T>
-struct mixed_type<T, float32, typename enable_if<!is_same<float64, T>::value &&
-                                                 !is_same<float32, T>::value>::type> {
+struct mixed_type_helper<T, float32, typename enable_if<!is_same<float64, T>::value &&
+                                                        !is_same<float32, T>::value>::type> {
   using type = float32;
 };
 
 template <typename T>
-struct mixed_type<float32, T, typename enable_if<!is_same<float64, T>::value>::type> {
+struct mixed_type_helper<float32, T, typename enable_if<!is_same<float64, T>::value>::type> {
   using type = float32;
 };
 
 template <typename T>
-struct mixed_type<T, float16, typename enable_if<is_same<float16, T>::value ||
-                                                 is_integral<T>::value>::type> {
+struct mixed_type_helper<T, float16, typename enable_if<is_same<float16, T>::value ||
+                                                        is_integral<T>::value>::type> {
   using type = float16;
 };
 
 template <typename T>
-struct mixed_type<float16, T, typename enable_if<is_integral<T>::value>::type> {
+struct mixed_type_helper<float16, T, typename enable_if<is_integral<T>::value>::type> {
   using type = float16;
 };
 
 template <typename T, typename U>
-struct mixed_type<T, U, typename enable_if<is_integral<T>::value &&
-                                           is_integral<U>::value &&
-                                           !is_same<U, bool_t>::value &&
-                                           sizeof(T) <= sizeof(U)>::type> {
+struct mixed_type_helper<T, U, typename enable_if<is_integral<T>::value &&
+                                                  is_integral<U>::value &&
+                                                  !is_same<U, bool_t>::value &&
+                                                  sizeof(T) <= sizeof(U)>::type> {
   using type = U;
 };
 
 template <typename T, typename U>
-struct mixed_type<U, T, typename enable_if<is_integral<T>::value &&
-                                           is_integral<U>::value &&
-                                           !is_same<U, bool_t>::value &&
-                                           sizeof(T) < sizeof(U)>::type> {
+struct mixed_type_helper<U, T, typename enable_if<is_integral<T>::value &&
+                                                  is_integral<U>::value &&
+                                                  !is_same<U, bool_t>::value &&
+                                                  sizeof(T) < sizeof(U)>::type> {
   using type = U;
 };
 
 template <typename T>
-struct mixed_type<T, bool_t, typename enable_if<is_integral<T>::value &&
-                                                sizeof(T) < sizeof(bool_t)>::type> {
+struct mixed_type_helper<T, bool_t, typename enable_if<is_integral<T>::value &&
+                                                       sizeof(T) < sizeof(bool_t)>::type> {
   using type = index_t;
 };
 
 template <typename T>
-struct mixed_type<bool_t, T, typename enable_if<is_integral<T>::value &&
-                                                sizeof(T) < sizeof(bool_t)>::type> {
+struct mixed_type_helper<bool_t, T, typename enable_if<is_integral<T>::value &&
+                                                       sizeof(T) < sizeof(bool_t)>::type> {
   using type = index_t;
 };
 
 template <typename T>
-struct mixed_type<T, bool_t, typename enable_if<is_integral<T>::value &&
-                                                sizeof(T) == sizeof(bool_t)>::type> {
+struct mixed_type_helper<T, bool_t, typename enable_if<is_integral<T>::value &&
+                                                       sizeof(T) == sizeof(bool_t)>::type> {
   using type = T;
 };
 
+template <typename... Ts>
+struct multi_mixed_type_helper;  // folds mixed_type_helper right-to-left over Ts
+
+template <>
+struct multi_mixed_type_helper<> {
+    using type = void;  // empty type list
+};
+
+template <typename T>
+struct multi_mixed_type_helper<T> {
+    using type = T;
+};
+
+template <typename T, typename U, typename... Ts>
+struct multi_mixed_type_helper<T, U, Ts...> {
+    using type = typename mixed_type_helper<T,
+                                            typename multi_mixed_type_helper<U,
+                                                                             Ts...>::type>::type;
+};
+
+template <typename... Ts>
+using mixed_type = typename multi_mixed_type_helper<Ts...>::type;  // e.g. <T, float64> -> float64
+
 }  // namespace type_util
 )code";
 
@@ -254,6 +277,7 @@ enum class OpReqType {
 };
 
 constexpr int kRTCMaxThreadsPerBlock = 512;
+constexpr int warp_size = 32;  // lanes per warp assumed by the util:: warp reduction helpers
 
 namespace util {
 
@@ -377,6 +401,49 @@ __device__ inline bool isnan(volatile const float16 &val) {
   return ::isnan(__half2float(const_cast<const float16&>(val)));
 }
 
+template <int NVALUES = warp_size, typename OP, typename T>
+__device__ inline T warp_reduce(T value, OP redfun) {  // full-warp mask: all 32 lanes must call
+#pragma unroll
+  for (int i = warp_size / 2; i >= 1; i /= 2) {
+    if (NVALUES > i) value = redfun(value, __shfl_down_sync(0xffffffff, value, i));
+  }
+  return value;  // lane 0: reduction of lanes [0, NVALUES) when NVALUES is a power of two
+}
+
+template <typename OP, typename T>
+__device__ inline T grouped_warp_reduce(T value, OP redfun, const int group_size) {
+  for (int i = 1; i < group_size; i *= 2) {
+    value = redfun(value, __shfl_down_sync(0xffffffff, value, i));
+  }
+  return value;  // first lane of each group of group_size lanes (power of two) holds its result
+}
+
+template <typename OP, typename T>
+__device__ inline T grouped_warp_allreduce(T value, OP redfun, const int group_size) {
+  value = grouped_warp_reduce(value, redfun, group_size);
+  return __shfl_sync(0xffffffff, value, 0, group_size);  // width: power of two <= warp_size
+}
+
+template <typename OP, typename T>
+__device__ inline T strided_grouped_warp_reduce(T value, OP redfun, const int group_size) {
+  for (int i = warp_size / 2; i >= group_size; i /= 2) {
+    value = redfun(value, __shfl_down_sync(0xffffffff, value, i));
+  }
+  return value;  // lane l < group_size: reduction of lanes l, l + group_size, l + 2*group_size...
+}
+
+template <typename OP, typename T>
+__device__ inline T strided_grouped_warp_allreduce(T value, OP redfun, const int group_size) {
+  value = strided_grouped_warp_reduce(value, redfun, group_size);
+  for (int i = group_size; i < warp_size; i *= 2) {  // copy results up to lanes >= group_size
+    T tmp = __shfl_up_sync(0xffffffff, value, i);
+    if (threadIdx.x % warp_size >= i) {  // NOTE(review): lane id from threadIdx.x assumes 1-D blocks
+      value = tmp;
+    }
+  }
+  return value;
+}
+
 }  // namespace util
 )code";
 }  // namespace rtc
diff --git a/src/common/cuda/rtc/vectorization-inl.h b/src/common/cuda/rtc/vectorization-inl.h
index 5cbc459..96205fc 100644
--- a/src/common/cuda/rtc/vectorization-inl.h
+++ b/src/common/cuda/rtc/vectorization-inl.h
@@ -41,6 +41,8 @@ const char vectorization_support_string[] = R"code(
 
 namespace vector {
 
+constexpr int vectorized_kernel_thread_num = 512;  // threads per block; same value as the host-side constant
+
 template <int size>
 struct VectorType {
     static_assert(size <= 32, "VectorType needs to have size of at most 32B");
@@ -166,7 +168,7 @@ class VectorizedAccessor {
     if (aligned) {
       alignment_ = 0;
       aligned_ptr_ = reinterpret_cast<LType*>(ptr);
-      n_elems_ = (size + nvec- 1) / nvec;
+      n_elems_ = (size + nvec - 1) / nvec;
     } else {
       size_t ptr_as_number = reinterpret_cast<size_t>(ptr);
       alignment_ = (ptr_as_number % sizeof(LType)) / sizeof(DType);
@@ -360,6 +362,8 @@ constexpr int vectorized_kernel_thread_num = 512;
  *  \param lead_input_num number of input to use for checking alignment
  *                        (in case only a subset of inputs is used vectorized).
  *                        Default is 0.
+ *  \param blocks if provided and not 0, will launch the specified number of thread blocks.
+ *                Default is 0.
  */
 template <typename Params>
 void VectorizedKernelRTCLauncher(const std::string &parameters,
@@ -373,7 +377,8 @@ void VectorizedKernelRTCLauncher(const std::string &parameters,
                                  const std::vector<TBlob> &inputs,
                                  const std::vector<TBlob> &outputs,
                                  const int dev_id,
-                                 const int lead_input_num = 0) {
+                                 const int lead_input_num = 0,
+                                 const index_t blocks = 0) {
   const index_t N = lead_dim * other_dim;
   nvec = std::min(nvec, 4);  // Use at most 4-wide vectors
   if (N != 0) {
@@ -435,11 +440,16 @@ void VectorizedKernelRTCLauncher(const std::string &parameters,
                                      lead_dim, nvec,
                                      common::mshadow_type_info(
                                        inputs[lead_input_num].type_flag_).size);
-    size_t num_elements = other_dim * num_aligned_elements;
     constexpr int threads = vectorized_kernel_thread_num;
-    constexpr int max_blocks = 65535;
-    index_t blocks = std::min(static_cast<int>((num_elements + threads - 1) / threads),
-                              max_blocks);
+    index_t num_blocks;  // grid size: caller-provided, else derived from the element count
+    if (blocks != 0) {
+      num_blocks = blocks;
+    } else {
+      size_t num_elements = other_dim * num_aligned_elements;
+      num_blocks = (num_elements + threads - 1) / threads;
+      constexpr int max_blocks = 65535;
+      num_blocks = std::min<index_t>(num_blocks, max_blocks);  // clamp without narrowing to int
+    }
     std::vector<const void*> args = {&params, &lead_dim, &other_dim,
                                      &N, &num_aligned_elements};
     auto function = common::cuda::rtc::get_function(kernel_builder,
@@ -448,7 +458,7 @@ void VectorizedKernelRTCLauncher(const std::string &parameters,
                                                     dev_id);
 
     common::cuda::rtc::launch(function,
-                              {static_cast<unsigned int>(blocks), 1, 1},
+                              {static_cast<unsigned int>(num_blocks), 1, 1},
                               {static_cast<unsigned int>(threads), 1, 1},
                               0, s, &args);
   }
diff --git a/src/common/cuda/utils.cc b/src/common/cuda/utils.cc
index b87c393..7aa936d 100644
--- a/src/common/cuda/utils.cc
+++ b/src/common/cuda/utils.cc
@@ -29,6 +29,7 @@
 #include <algorithm>
 
 #include "utils.h"
+#include "../utils.h"
 
 #if MXNET_USE_CUDA
 
@@ -36,25 +37,6 @@ namespace mxnet {
 namespace common {
 namespace cuda {
 
-namespace {
-  bool IsPower2(size_t N) {
-    return ((N & (N - 1)) == 0) && N != 0;
-  }
-
-  size_t RoundToPower2(size_t N) {
-    size_t ret = 1;
-    size_t copyN = N;
-    while (N >= 2) {
-      ret *= 2;
-      N /= 2;
-    }
-    if (ret < copyN) {
-      ret *= 2;
-    }
-    return ret;
-  }
-}  // namespace
-
 int get_load_type(size_t N) {
   using namespace mshadow;
   if (N % 8 == 0) {
diff --git a/src/common/cuda/utils.h b/src/common/cuda/utils.h
index fc4d40c..a203ba5 100644
--- a/src/common/cuda/utils.h
+++ b/src/common/cuda/utils.h
@@ -811,6 +811,14 @@ __device__ inline T warp_reduce(T value, OP redfun) {
   return value;
 }
 
+template <typename OP, typename T>
+__device__ inline T grouped_warp_allreduce(T value, OP redfun, const int group_size) {
+  for (int i = 1; i < group_size; i *= 2) {  // group_size: power of two <= warp_size
+    value = redfun(value, __shfl_down_sync(0xffffffff, value, i));
+  }
+  return __shfl_sync(0xffffffff, value, 0, group_size);  // broadcast each group's first lane
+}
+
+
 template <int NValues = warp_size, typename OP>
 __device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t 
value, OP redfun) {
   float v = static_cast<float>(value);
diff --git a/src/common/utils.h b/src/common/utils.h
index dfd32ac..40376e9 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -977,6 +977,27 @@ inline void AlignedMemFree(void* ptr) {
 }
 
 
+inline index_t div_round(const index_t a, const index_t b) {  // ceil(a / b) for a >= 0, b > 0
+  return (a + b - 1) / b;
+}
+
+inline bool IsPower2(size_t N) {  // true iff N is a nonzero power of two
+  return ((N & (N - 1)) == 0) && N != 0;
+}
+
+inline size_t RoundToPower2(size_t N) {  // smallest power of two >= N; returns 1 for N == 0
+  size_t ret = 1;
+  size_t copyN = N;
+  while (N >= 2) {  // ret ends as the largest power of two <= N (for N >= 1)
+    ret *= 2;
+    N /= 2;
+  }
+  if (ret < copyN) {
+    ret *= 2;
+  }
+  return ret;
+}
+
 }  // namespace common
 }  // namespace mxnet
 #endif  // MXNET_COMMON_UTILS_H_
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index 05c8b9a..2251ff8 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -31,6 +31,7 @@
 #include <mxnet/engine.h>
 #include <mxnet/op_attr_types.h>
 #include <algorithm>
+#include <limits>
 #include "./operator_tune.h"
 #include "../engine/openmp.h"
 
@@ -367,40 +368,30 @@ struct AccType<mshadow::half::half_t> {
     break;                                                 \
   case mshadow::kUint8:                                    \
     {                                                      \
-      typedef uint8_t DType;                               \
-      typedef uint8_t AType;                               \
       LOG(FATAL) << "This operation only support "         \
                     "floating point types not uint8";      \
     }                                                      \
     break;                                                 \
   case mshadow::kInt8:                                     \
     {                                                      \
-      typedef int8_t DType;                                \
-      typedef int8_t AType;                                \
       LOG(FATAL) << "This operation only support "         \
                     "floating point types not int8";       \
     }                                                      \
     break;                                                 \
   case mshadow::kInt32:                                    \
     {                                                      \
-      typedef int32_t DType;                               \
-      typedef int32_t AType;                               \
       LOG(FATAL) << "This operation only support "         \
                     "floating point types, not int32";     \
     }                                                      \
     break;                                                 \
   case mshadow::kInt64:                                    \
     {                                                      \
-      typedef int64_t DType;                               \
-      typedef int64_t AType;                               \
       LOG(FATAL) << "This operation only support "         \
                     "floating point types, not int64";     \
     }                                                      \
     break;                                                 \
   case mshadow::kBool:                                     \
     {                                                      \
-      typedef bool DType;                                  \
-      typedef int64_t AType;                               \
       LOG(FATAL) << "This operation only support "         \
                     "floating point types, not bool";      \
     }                                                      \
@@ -475,21 +466,18 @@ struct AccType<mshadow::half::half_t> {
   switch (type) {                                          \
   case mshadow::kFloat32:                                  \
     {                                                      \
-      typedef float DType;                                 \
       LOG(FATAL) << "This operation only support "         \
                     "integer types, not float32";          \
     }                                                      \
     break;                                                 \
   case mshadow::kFloat64:                                  \
     {                                                      \
-      typedef double DType;                                \
       LOG(FATAL) << "This operation only support "         \
                     "integer types, not float64";          \
     }                                                      \
     break;                                                 \
   case mshadow::kFloat16:                                  \
     {                                                      \
-      typedef mshadow::half::half_t DType;                 \
       LOG(FATAL) << "This operation only support "         \
                     "integer types, not float16";          \
     }                                                      \
@@ -532,21 +520,18 @@ struct AccType<mshadow::half::half_t> {
   switch (type) {                                          \
   case mshadow::kFloat32:                                  \
     {                                                      \
-      typedef float DType;                                 \
       LOG(FATAL) << "This operation only support "         \
                     "integer types, not float32";          \
     }                                                      \
     break;                                                 \
   case mshadow::kFloat64:                                  \
     {                                                      \
-      typedef double DType;                                \
       LOG(FATAL) << "This operation only support "         \
                     "integer types, not float64";          \
     }                                                      \
     break;                                                 \
   case mshadow::kFloat16:                                  \
     {                                                      \
-      typedef mshadow::half::half_t DType;                 \
       LOG(FATAL) << "This operation only support "         \
                     "integer types, not float16";          \
     }                                                      \
diff --git a/src/operator/nn/log_softmax.cu b/src/operator/nn/log_softmax.cu
index 396a4e8..485290d 100644
--- a/src/operator/nn/log_softmax.cu
+++ b/src/operator/nn/log_softmax.cu
@@ -29,11 +29,10 @@ namespace mxnet {
 namespace op {
 
 NNVM_REGISTER_OP(log_softmax)
-.set_attr<FCompute>("FCompute<gpu>", SoftmaxCompute<gpu, mxnet_op::log_softmax_fwd>);
+.set_attr<FCompute>("FCompute<gpu>", SoftmaxRTCCompute{"log_softmax_fwd"});
 
 NNVM_REGISTER_OP(_backward_log_softmax)
-.set_attr<FCompute>("FCompute<gpu>", SoftmaxGradCompute<gpu, mshadow_op::left,
-                                                        mxnet_op::log_softmax_bwd>);
+.set_attr<FCompute>("FCompute<gpu>", SoftmaxRTCGradCompute{"op::left", "log_softmax_bwd"});
 
 NNVM_REGISTER_OP(masked_log_softmax)
 .set_attr<FCompute>("FCompute<gpu>", MaskedSoftmaxCompute<gpu, mxnet_op::log_softmax_fwd,
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index 3f037f9..59af426 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -34,7 +34,6 @@
 #include "../mxnet_op.h"
 #include "../operator_common.h"
 #include "../tensor/broadcast_reduce_op.h"
-#include "../../common/cuda/utils.h"
 
 using mshadow::red::limits::MinValue;
 
@@ -325,143 +324,8 @@ inline void MaskedSoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
 }
 
 #ifdef __CUDACC__
-template<int x_bits, typename OP, bool negate, typename AType, int ndim,
-         typename DType, typename OType, typename IType>
-__global__ void softmax_compute_kernel(DType *in, OType *out, IType *length,
-                                       index_t M, int axis, Shape<ndim> sshape,
-                                       Shape<ndim> stride, const double temperature) {
-  const unsigned x_size = 1 << x_bits;
-  __shared__ AType smem[x_size];
-  index_t sa = stride[axis];
-  index_t base = unravel_dot(blockIdx.x, sshape, stride);
-  index_t x = threadIdx.x;
-  const index_t len = length == nullptr ? M : static_cast<index_t>(length[blockIdx.x]);
-
-  red::maximum::SetInitValue(smem[x]);
-  for (index_t i = x; i < len; i += x_size) {
-    smem[x] = ::max(smem[x], negate ? -in[base + i*sa] : in[base + i*sa]);
-  }
-  __syncthreads();
-  cuda::Reduce1D<red::maximum, x_bits>(smem);
-  __syncthreads();
-  DType smax = smem[0];
-  __syncthreads();
-
-  red::sum::SetInitValue(smem[x]);
-  DType val;
-  for (index_t i = x; i < len; i += x_size) {
-    val = negate ? -in[base + i*sa]:in[base + i*sa];
-    smem[x] += static_cast<AType>(expf((val - smax) / static_cast<AType>(temperature)));
-  }
-  __syncthreads();
-  cuda::Reduce1D<red::sum, x_bits>(smem);
-  __syncthreads();
-  AType ssum = smem[0];
-  __syncthreads();
-
-  for (index_t i = x; i < M; i += x_size) {
-    val = negate ? -in[base + i*sa] : in[base + i*sa];
-    out[base + i*sa] =
-      (i < len) ? OType(OP::Map((val - smax)/static_cast<DType>(temperature), ssum)) : OType(0.0f);
-  }
-}
-
 const int softmax_threads_per_block = 512;
 
-template<typename OP, bool negate, typename AType, typename LType,
-  typename DType, typename OType, typename IType>
-__global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, IType *length,
-                                               const index_t M, const double temperature,
-                                               const int rows_per_block, const index_t total_rows) {
-  __shared__ AType scratch[softmax_threads_per_block];
-  __shared__ LType persistent_storage[20 * 1024 / sizeof(LType)];
-  const int warp_size = 32;
-  const int threads_per_row = softmax_threads_per_block / rows_per_block;
-  const int my_local_row = threadIdx.x / threads_per_row;
-  const int my_row = blockIdx.x * rows_per_block + my_local_row;
-  if (my_row >= total_rows) return;
-  const int my_id = threadIdx.x % threads_per_row;
-  const int entries_per_load = sizeof(LType)/sizeof(DType);
-  const index_t len = length == nullptr ? M : static_cast<index_t>(length[my_row]);
-  // Due to usage of MSHADOW_TYPE_SWITCH macro we are generating
-  // kernels where sizeof(LType) may be less than sizeof(DType),
-  // resulting in entries_per_load being 0.
-  // This is not a valid combination and is being checked against
-  // in the launcher code. This switch here is just to silence
-  // the division by zero warning generated for such invalid cases.
-  const int row_length = entries_per_load > 0 ? M / entries_per_load : 0;
-
-  const LType* in_aligned = reinterpret_cast<const LType*>(in);
-  size_t base = my_row * row_length;
-
-  for (index_t i = my_id; i < row_length; i += threads_per_row) {
-    persistent_storage[my_local_row * row_length + i] = in_aligned[base + i];
-  }
-  DType * row = reinterpret_cast<DType *>(persistent_storage + my_local_row * row_length);
-  __syncthreads();
-
-  DType my_max_value;
-  red::maximum::SetInitValue(my_max_value);
-
-  for (index_t i = my_id; i < len; i += threads_per_row) {
-    my_max_value = ::max(my_max_value, negate ? -row[i] : row[i]);
-  }
-  scratch[threadIdx.x] = my_max_value;
-  __syncthreads();
-  for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
-    if (my_id < size) {
-      scratch[threadIdx.x] = ::max(scratch[threadIdx.x], scratch[threadIdx.x + size]);
-    }
-    __syncthreads();
-  }
-  if (my_id < warp_size) {
-    AType my_value = common::cuda::warp_reduce(scratch[threadIdx.x],
-                                 [](AType x, AType y) { return ::max(x, y); });
-    scratch[threadIdx.x] = my_value;
-  }
-  __syncthreads();
-  DType smax = scratch[threadIdx.x - threadIdx.x % threads_per_row];
-  __syncthreads();
-
-  AType my_sum;
-  red::sum::SetInitValue(my_sum);
-
-  for (index_t i = my_id; i < len; i += threads_per_row) {
-    const DType val = negate ? -row[i] : row[i];
-    my_sum += static_cast<AType>(expf((val - smax) / static_cast<AType>(temperature)));
-  }
-  scratch[threadIdx.x] = my_sum;
-  __syncthreads();
-  for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
-    if (my_id < size) {
-      scratch[threadIdx.x] += scratch[threadIdx.x + size];
-    }
-    __syncthreads();
-  }
-  if (my_id < warp_size) {
-    AType my_value = common::cuda::warp_reduce(scratch[threadIdx.x],
-                                 [](AType x, AType y) { return x + y;});
-    scratch[threadIdx.x] = my_value;
-  }
-  __syncthreads();
-
-  AType ssum = scratch[threadIdx.x - threadIdx.x % threads_per_row];
-  __syncthreads();
-
-  for (index_t i = my_id; i < M; i += threads_per_row) {
-    const DType val = negate ? -row[i] : row[i];
-    row[i] = (i < len) ? DType(OP::Map((val - smax)/static_cast<DType>(temperature), ssum)) :
-                         DType(0.0f);
-  }
-  __syncthreads();
-
-  LType* out_aligned = reinterpret_cast<LType*>(out);
-
-  for (index_t i = my_id; i < row_length; i += threads_per_row) {
-    out_aligned[base + i] = persistent_storage[my_local_row * row_length + i];
-  }
-}
-
 template<int ndim>
 MSHADOW_XINLINE index_t get_mask_position(const index_t idx, const Shape<ndim>& data_shape,
   const Shape<ndim>& mask_shape, int axis, index_t* stride_axis) {
@@ -665,45 +529,6 @@ __global__ void masked_softmax_stride1_kernel(const DType *in, DType *out, bool
   }
 }
 
-template<typename OP, bool negate, typename AType, typename DType, typename OType,
-         typename IType, int ndim>
-inline void Softmax(Stream<gpu> *s, DType *in, OType *out, IType *length,
-                    Shape<ndim> shape, int axis, const double temperature) {
-  const int x_bits = 7;
-  const int x_size = 1 << x_bits;
-  index_t M = shape[axis];
-  if (M == 0 || shape.Size() == 0) return;
-  index_t N = shape.Size()/M;
-  Shape<ndim> stride = calc_stride(shape);
-  Shape<ndim> sshape = shape;
-  sshape[axis] = 1;
-
-  const size_t DSize = sizeof(DType);
-  // Using 20 kB of shared memory for persistent storage in the optimized case
-  const size_t max_opt_M = 20 * 1024 / DSize;
-  if (stride[axis] == 1 &&
-      static_cast<size_t>(M) <= max_opt_M &&
-      std::is_same<DType, OType>::value) {
-    int ltype = mxnet::common::cuda::get_load_type(M * sizeof(DType));
-    MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
-      int rows_per_block = mxnet::common::cuda::get_rows_per_block(M *
-                                                                   sizeof(DType) / sizeof(LType),
-                                                                   softmax_threads_per_block);
-      int nblocks = (N + rows_per_block - 1) / rows_per_block;
-      CHECK_LE(sizeof(DType), sizeof(LType));
-      softmax_stride1_compute_kernel<OP, negate, AType, LType>
-        <<<nblocks, softmax_threads_per_block, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
-          in, out, length, M, temperature, rows_per_block, N);
-    });
-    MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_stride1_compute_kernel);
-  } else {
-    softmax_compute_kernel<x_bits, OP, negate, AType, ndim>
-      <<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
-        in, out, length, M, axis, sshape, stride, temperature);
-    MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_compute_kernel);
-  }
-}
-
 template<typename OP, bool masked_neg_inf, bool negate,
          typename AType, typename DType, typename OType, int ndim>
 inline void MaskedSoftmax(Stream<gpu> *s, DType *in, OType *out, bool *mask,
@@ -785,120 +610,6 @@ inline void MaskedSoftmax(Stream<gpu> *s, DType *in, OType *out, bool *mask,
 }
 
 template<typename OP1, typename OP2, int Req, bool negate, typename AType, typename LType,
-  typename DType, typename OType, typename IType>
-__global__ void softmax_stride1_grad_kernel(const OType *out, const OType *ograd,
-                                            DType *igrad, const IType *length,
-                                            const index_t M,
-                                            const double temperature,
-                                            const int rows_per_block,
-                                            const index_t total_rows) {
-  __shared__ AType scratch[softmax_threads_per_block];
-  __shared__ LType persistent_storage[20 * 1024 / sizeof(LType)];
-  const int warp_size = 32;
-  const int threads_per_row = softmax_threads_per_block / rows_per_block;
-  const int my_local_row = threadIdx.x / threads_per_row;
-  const int my_row = blockIdx.x * rows_per_block + my_local_row;
-  if (my_row >= total_rows) return;
-  const int my_id = threadIdx.x % threads_per_row;
-  const int entries_per_load = sizeof(LType)/sizeof(DType);
-  const index_t len = length == nullptr ? M : static_cast<index_t>(length[my_row]);
-  // Due to usage of MSHADOW_TYPE_SWITCH macro we are generating
-  // kernels where sizeof(LType) may be less than sizeof(DType),
-  // resulting in entries_per_load being 0.
-  // This is not a valid combination and is being checked against
-  // in the launcher code. This switch here is just to silence
-  // the division by zero warning generated for such invalid cases.
-  const int row_length = entries_per_load > 0 ? M / entries_per_load : 0;
-
-  const LType* out_aligned = reinterpret_cast<const LType*>(out);
-  const LType* ograd_aligned = reinterpret_cast<const LType*>(ograd);
-  size_t base = my_row * row_length;
-
-  for (index_t i = my_id; i < row_length; i += threads_per_row) {
-    persistent_storage[my_local_row * row_length * 2 + i] = out_aligned[base + i];
-    persistent_storage[my_local_row * row_length * 2 + row_length + i] = ograd_aligned[base + i];
-  }
-  DType * row = reinterpret_cast<DType *>(persistent_storage + my_local_row * row_length * 2);
-  __syncthreads();
-
-  AType my_sum_value;
-  red::sum::SetInitValue(my_sum_value);
-
-  for (index_t i = my_id; i < len; i += threads_per_row) {
-    my_sum_value += OP1::Map(row[i + M], row[i]);
-  }
-  scratch[threadIdx.x] = my_sum_value;
-  __syncthreads();
-  for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
-    if (my_id < size) {
-      scratch[threadIdx.x] = scratch[threadIdx.x] + scratch[threadIdx.x + size];
-    }
-    __syncthreads();
-  }
-  if (my_id < warp_size) {
-    AType my_value = common::cuda::warp_reduce(scratch[threadIdx.x],
-                                 [](AType x, AType y) { return x + y; });
-    scratch[threadIdx.x] = my_value;
-  }
-  __syncthreads();
-  AType ssum = scratch[threadIdx.x - threadIdx.x % threads_per_row];
-  __syncthreads();
-
-  for (index_t i = my_id; i < M; i += threads_per_row) {
-    const DType val =
-      negate ?
-      -OP2::Map(row[i + M], row[i], ssum) :
-      OP2::Map(row[i + M], row[i], ssum);
-    row[i] = (i < len) ? DType(val / static_cast<DType>(temperature)) :
-                         DType(0.0f);
-    if (Req == kAddTo) {
-      row[i] += igrad[my_row * M + i];
-    }
-  }
-  __syncthreads();
-
-  LType* igrad_aligned = reinterpret_cast<LType*>(igrad);
-
-  for (index_t i = my_id; i < row_length; i += threads_per_row) {
-    igrad_aligned[base + i] = persistent_storage[my_local_row * row_length * 2 + i];
-  }
-}
-
-template<int x_bits, typename OP1, typename OP2, int Req, bool negate, typename AType, int ndim,
-         typename DType, typename OType, typename IType>
-__global__ void softmax_grad_kernel(OType *out, OType *ograd, DType *igrad,
-                                    const IType *length, index_t M, int axis,
-                                    Shape<ndim> sshape, Shape<ndim> stride,
-                                    const double temperature) {
-  const unsigned x_size = 1 << x_bits;
-  __shared__ AType smem[x_size];
-  index_t sa = stride[axis];
-  index_t base = unravel_dot(blockIdx.x, sshape, stride);
-  index_t x = threadIdx.x;
-  index_t len = length != nullptr ? static_cast<index_t>(length[blockIdx.x]) : M;
-
-  red::sum::SetInitValue(smem[x]);
-  for (index_t i = x; i < len; i += x_size) {
-    smem[x] += OP1::Map(ograd[base + i*sa], out[base + i*sa]);
-  }
-  __syncthreads();
-  cuda::Reduce1D<red::sum, x_bits>(smem);
-  __syncthreads();
-  AType ssum = smem[0];
-  __syncthreads();
-
-  DType final_result;
-  for (index_t i = x; i < M; i += x_size) {
-    final_result =
-      negate ?
-      -OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum) :
-      OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum);
-    final_result = (i < len) ? final_result : DType(0.0f);
-    KERNEL_ASSIGN(igrad[base + i*sa], Req, final_result / static_cast<DType>(temperature));
-  }
-}
-
-template<typename OP1, typename OP2, int Req, bool negate, typename AType, typename LType,
   typename LTypeMask, typename DType, typename OType, int ndim>
 __global__ void masked_softmax_stride1_grad_kernel(const OType *out, const OType *ograd,
                                                    DType *igrad, const bool *in_mask,
@@ -1042,48 +753,6 @@ __global__ void masked_softmax_grad_kernel(OType *out, OType *ograd, DType *igra
 }
 
 template<typename OP1, typename OP2, int Req, bool negate, typename AType, int ndim,
-         typename DType, typename OType, typename IType>
-inline void SoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
-                        DType *igrad, IType *length, Shape<ndim> shape, int axis,
-                        const double temperature) {
-  const int x_bits = 7;
-  const int x_size = 1 << x_bits;
-  index_t M = shape[axis];
-  if (M == 0 || shape.Size() == 0) return;
-  index_t N = shape.Size()/M;
-  Shape<ndim> stride = calc_stride(shape);
-  Shape<ndim> sshape = shape;
-  sshape[axis] = 1;
-
-  const size_t DSize = sizeof(DType);
-  // Using 20 kB of shared memory for persistent storage in the optimized case
-  // Need to store both out and ograd, so M can be only half compared to
-  // forward pass.
-  const size_t max_opt_M = 20 * 1024 / DSize / 2;
-  if (stride[axis] == 1 &&
-      static_cast<size_t>(M) <= max_opt_M &&
-      std::is_same<DType, OType>::value) {
-    int ltype = mxnet::common::cuda::get_load_type(M * sizeof(DType));
-    MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
-      int rows_per_block = mxnet::common::cuda::get_rows_per_block(M *
-                                                                   sizeof(DType) / sizeof(LType),
-                                                                   softmax_threads_per_block);
-      int nblocks = (N + rows_per_block - 1) / rows_per_block;
-      CHECK_LE(sizeof(DType), sizeof(LType));
-      softmax_stride1_grad_kernel<OP1, OP2, Req, negate, AType, LType>
-        <<<nblocks, softmax_threads_per_block, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
-          out, ograd, igrad, length, M, temperature, rows_per_block, N);
-    });
-    MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_stride1_grad_kernel);
-  } else {
-    softmax_grad_kernel<x_bits, OP1, OP2, Req, negate, AType, ndim>
-      <<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
-        out, ograd, igrad, length, M, axis, sshape, stride, temperature);
-    MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_grad_kernel);
-  }
-}
-
-template<typename OP1, typename OP2, int Req, bool negate, typename AType, int ndim,
          typename DType, typename OType>
 inline void MaskedSoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
                               DType *igrad, bool *mask, Shape<ndim> data_shape,
@@ -1554,6 +1223,32 @@ void MaskedSoftmaxCompute(const nnvm::NodeAttrs& attrs,
   });
 }
 
+#if MXNET_USE_CUDA
+
+struct SoftmaxRTCCompute {
+  std::string OP;
+  bool negate = false;
+
+  void operator()(const nnvm::NodeAttrs& attrs,
+                  const OpContext& ctx,
+                  const std::vector<TBlob>& inputs,
+                  const std::vector<OpReqType>& req,
+                  const std::vector<TBlob>& outputs);
+};
+
+struct SoftmaxRTCGradCompute {
+  std::string OP1;
+  std::string OP2;
+  bool negate = false;
+
+  void operator()(const nnvm::NodeAttrs& attrs,
+                  const OpContext& ctx,
+                  const std::vector<TBlob>& inputs,
+                  const std::vector<OpReqType>& req,
+                  const std::vector<TBlob>& outputs);
+};
+
+#endif
 
 template<typename xpu, typename OP1, typename OP2, bool negate = false>
 void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
diff --git a/src/operator/nn/softmax.cu b/src/operator/nn/softmax.cu
index c75f543..c8a05b8 100644
--- a/src/operator/nn/softmax.cu
+++ b/src/operator/nn/softmax.cu
@@ -22,18 +22,791 @@
  * \file softmax.cu
  * \brief GPU Implementation of softmax
  */
+#include <string>
 #include "./softmax-inl.h"
-#include "../tensor/elemwise_unary_op.h"
+#include "../../common/cuda/utils.h"
+#include "../../common/utils.h"
+#include "../../common/cuda/rtc.h"
+#include "../../common/cuda/rtc/vectorization-inl.h"
 
 namespace mxnet {
 namespace op {
 
+namespace {
+
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+const char softmax_common_functions[] = R"code(
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+softmax_fwd(const DType a, const DType2 b) {
+  return op::exp(a) / b;
+}
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+log_softmax_fwd(const DType a, const DType2 b) {
+  return a - op::log(b);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+    return out * (ograd - sum);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+log_softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+    return ograd - op::exp(out) * sum;
+}
+
+)code";
+
+const char simple_softmax_kernel_fwd[] = R"code(
+__launch_bounds__(kRTCMaxThreadsPerBlock)
+__global__ void simple_softmax_kernel(const softmax_params param,
+                                      const index_t lead_dim) {
+  using LengthType = AccType<InputType1>;
+  const InputType0* input = reinterpret_cast<const InputType0*>(param.inputs[0]);
+  const InputType1* length = reinterpret_cast<const InputType1*>(param.inputs[1]);
+  const index_t len = length == nullptr
+                      ? lead_dim
+                      : static_cast<index_t>(LengthType::from(length[blockIdx.x]));
+  const int my_row = threadIdx.x % param.rows_per_block;
+  const int my_id = threadIdx.x / param.rows_per_block;
+  const int threads_per_row = blockDim.x / param.rows_per_block;
+  const index_t base_x = (blockIdx.x * param.rows_per_block + my_row) % param.stride;
+  const index_t base_n = (blockIdx.x * param.rows_per_block + my_row) / param.stride;
+  const index_t base = base_x + param.stride * lead_dim * base_n;
+  if (base >= param.num_elements * param.total_rows) return;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType smem[kRTCMaxThreadsPerBlock];
+  AType max;
+  red::maximum::SetInitValue(max);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    max = op::max(max, negate ? -val : val);
+  }
+  smem[threadIdx.x] = max;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::max(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::max(x, y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  AType smax = smem[my_row];
+  __syncthreads();
+
+  AType sum;
+  red::sum::SetInitValue(sum);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val :val;
+    sum += op::exp((val - smax) / static_cast<AType>(param.temperature));
+  }
+  smem[threadIdx.x] = sum;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::add(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::add(x, y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  sum = smem[my_row];
+  __syncthreads();
+
+  OutputType0* output = reinterpret_cast<OutputType0*>(param.outputs[0]);
+  for (index_t i = my_id; i < lead_dim; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val : val;
+    val = (i < len) ? OP((val - smax)/static_cast<AType>(param.temperature), sum) : 0;
+    if (req == OpReqType::kAddTo) {
+      if (i < len) {
+        output[base + i * param.stride] = OType::to(val +
+                                                    OType::from(output[base + i * param.stride]));
+      }
+    } else {
+      output[base + i * param.stride] = OType::to(val);
+    }
+  }
+}
+)code";
+
+const char softmax_stride1_kernel_fwd[] = R"code(
+__launch_bounds__(vector::vectorized_kernel_thread_num)
+__global__ void softmax_stride1_compute_kernel(const softmax_params param,
+                                               const index_t total_length,
+                                               const index_t other_dim,
+                                               const index_t N,
+                                               const index_t num_aligned_elements) {
+  using namespace vector;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using LengthType = AccType<InputType1>;
+  const InputType1* length = reinterpret_cast<const InputType1*>(param.inputs[1]);
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType scratch[vectorized_kernel_thread_num];
+  __shared__ AType persistent_storage[20 * 1024 / sizeof(AType)];
+  const int threads_per_row = vectorized_kernel_thread_num / param.rows_per_block;
+  const int my_local_row = threadIdx.x / threads_per_row;
+  const int base_row = blockIdx.x * param.rows_per_block;
+  const int my_row = base_row + my_local_row;
+  const index_t len = (length == nullptr ||
+                       my_row >= param.total_rows) ? param.num_elements
+                                                   : LengthType::from(length[my_row]);
+  const int my_id = threadIdx.x % threads_per_row;
+
+  AType* row;
+  if (only_full_blocks || blockIdx.x < gridDim.x - 1) {
+    // full rows_per_block rows to compute
+    VectorizedLoader<InputType0, nvec, aligned> loader(
+      reinterpret_cast<const InputType0*>(param.inputs[0]) + base_row * param.num_elements,
+      total_length);
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      loader.load(i, total_length);
+#pragma unroll
+      for (int j = 0; j < nvec; ++j) {
+        persistent_storage[i*nvec + j] = IType::from(loader.separate()[j]);
+      }
+    }
+    row = persistent_storage + my_local_row * param.num_elements + loader.alignment();
+  } else {
+    // less than rows_per_block rows to compute
+    const index_t real_length = min(total_length,
+                                    (param.total_rows - base_row) * param.num_elements);
+    VectorizedLoader<InputType0, nvec, false> loader(
+      reinterpret_cast<const InputType0*>(param.inputs[0]) + base_row * param.num_elements,
+      real_length);
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      loader.load(i, real_length);
+#pragma unroll
+      for (int j = 0; j < nvec; ++j) {
+        persistent_storage[i*nvec + j] = IType::from(loader.separate()[j]);
+      }
+    }
+    row = persistent_storage + my_local_row * param.num_elements + loader.alignment();
+  }
+  __syncthreads();
+
+  AType my_max_value;
+  red::maximum::SetInitValue(my_max_value);
+
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    my_max_value = ::max(my_max_value, negate ? -row[i] : row[i]);
+  }
+  AType smax;
+  if (!reduction_inside_warp) {
+    scratch[threadIdx.x] = my_max_value;
+    __syncthreads();
+    for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+      if (my_id < size) {
+        scratch[threadIdx.x] = ::max(scratch[threadIdx.x], scratch[threadIdx.x + size]);
+      }
+      __syncthreads();
+    }
+    if (my_id < warp_size) {
+      AType my_value = util::grouped_warp_allreduce(scratch[threadIdx.x],
+                                                    [](AType x, AType y) { return op::max(x, y); },
+                                                    min(threads_per_row, warp_size));
+      scratch[threadIdx.x] = my_value;
+    }
+    __syncthreads();
+    smax = scratch[threadIdx.x - my_id];
+    __syncthreads();
+  } else {
+    smax = util::grouped_warp_allreduce(my_max_value,
+                                        [](AType x, AType y) { return op::max(x, y); },
+                                        threads_per_row);
+  }
+
+  AType my_sum;
+  red::sum::SetInitValue(my_sum);
+
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    const AType val = negate ? -row[i] : row[i];
+    my_sum += op::exp((val - smax) / static_cast<AType>(param.temperature));
+  }
+  AType ssum;
+  if (!reduction_inside_warp) {
+    scratch[threadIdx.x] = my_sum;
+    __syncthreads();
+    for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+      if (my_id < size) {
+        scratch[threadIdx.x] += scratch[threadIdx.x + size];
+      }
+      __syncthreads();
+    }
+    if (my_id < warp_size) {
+      AType my_value = util::grouped_warp_allreduce(scratch[threadIdx.x],
+                                                    [](AType x, AType y) { return x + y;},
+                                                    min(threads_per_row, warp_size));
+      scratch[threadIdx.x] = my_value;
+    }
+    __syncthreads();
+
+    ssum = scratch[threadIdx.x - my_id];
+    __syncthreads();
+  } else {
+      ssum = util::grouped_warp_allreduce(my_sum,
+                                          [](AType x, AType y) { return x + y;},
+                                          threads_per_row);
+  }
+
+  for (index_t i = my_id; i < param.num_elements; i += threads_per_row) {
+    const AType val = negate ? -row[i] : row[i];
+    row[i] = (i < len) ? OP((val - smax)/static_cast<AType>(param.temperature), ssum) :
+                         0;
+  }
+  __syncthreads();
+
+  if (only_full_blocks || blockIdx.x < gridDim.x - 1) {
+    VectorizedStorer<OutputType0, nvec, aligned> storer(
+      reinterpret_cast<OutputType0*>(param.outputs[0]) + base_row * param.num_elements,
+      total_length);
+
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      if (req == OpReqType::kAddTo) {
+        storer.load(i, total_length);
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(op::add(persistent_storage[i*nvec + j],
+                                                   OType::from(storer.separate()[j])));
+        }
+      } else {
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(persistent_storage[i*nvec + j]);
+        }
+      }
+      storer.store(i, total_length);
+    }
+  } else {
+    const index_t real_length = min(total_length,
+                                    (param.total_rows - base_row) * param.num_elements);
+    VectorizedStorer<OutputType0, nvec, false> storer(
+      reinterpret_cast<OutputType0*>(param.outputs[0]) + base_row * param.num_elements,
+      real_length);
+
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      if (req == OpReqType::kAddTo) {
+        storer.load(i, real_length);
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(op::add(persistent_storage[i*nvec + j],
+                                                   OType::from(storer.separate()[j])));
+        }
+      } else {
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(persistent_storage[i*nvec + j]);
+        }
+      }
+      storer.store(i, real_length);
+    }
+  }
+}
+)code";
+
+// Picks how many rows (each row_size elements long) one block of num_threads_per_block
+// threads should process in the optimized softmax kernels.
+// The starting point gives each row enough threads that every thread issues at least 16
+// nvec-wide loads (rounded up to a power of two, capped at the block size). The row count
+// is then halved while either
+//  - the rows would not fit into max_storage elements (max_storage == -1 means unlimited), or
+//  - the number of blocks covering total_rows would be smaller than the SM count of dev_id.
+// Halving stops once a single row per block remains.
+int get_rows_per_block(const index_t row_size, const int nvec,
+                       const index_t max_storage, const int num_threads_per_block,
+                       const index_t total_rows, const int dev_id) {
+  CHECK(common::IsPower2(num_threads_per_block))
+    << "Number of threads in a block must be power of 2 to use get_rows_per_block function";
+  // Lower bound on the number of vectorized loads one thread should issue for its row.
+  constexpr int kMinLoadsPerThread = 16;
+  const size_t vecs_in_row = (row_size + nvec - 1) / nvec;
+  const int raw_threads_per_row =
+      static_cast<int>((vecs_in_row + kMinLoadsPerThread - 1) / kMinLoadsPerThread);
+  const int threads_per_row =
+      std::min(static_cast<int>(common::RoundToPower2(raw_threads_per_row)),
+               num_threads_per_block);
+  const int sm_count = MultiprocessorCount(dev_id);
+  // A candidate row count is rejected when its rows overflow the storage budget or when
+  // it would produce fewer blocks than there are SMs.
+  const auto too_many_rows = [&](const int rows) {
+    const bool over_budget = max_storage != -1 && max_storage < row_size * rows;
+    const bool underfills_gpu = (total_rows + rows - 1) / rows < sm_count;
+    return over_budget || underfills_gpu;
+  };
+  int rows = num_threads_per_block / threads_per_row;
+  while (rows > 1 && too_many_rows(rows)) {
+    rows /= 2;
+  }
+  return rows;
+}
+
+}  // namespace
+
+// GPU forward of softmax / softmin (FCompute<gpu>) built on runtime-compiled (RTC) kernels.
+// A preamble of compile-time definitions (OP, req, negate, type aliases, ...) is prepended
+// to softmax_common_functions, and one of two kernels is compiled and launched:
+//  - softmax_stride1_compute_kernel when the softmax axis is contiguous (stride == 1) and
+//    rows_per_block rows of M elements fit into the 20 kB shared-memory budget;
+//  - simple_softmax_kernel otherwise.
+// inputs[0] is the data; with use_length, inputs[1] supplies one length per row.
+// Only req[0] / outputs[0] are used.
+void SoftmaxRTCCompute::operator()(const nnvm::NodeAttrs& attrs,
+                                   const OpContext& ctx,
+                                   const std::vector<TBlob>& inputs,
+                                   const std::vector<OpReqType>& req,
+                                   const std::vector<TBlob>& outputs) {
+  using namespace mxnet_op;
+  using common::mshadow_type_info;
+  using namespace common::cuda::rtc;
+  using common::div_round;
+  // Nothing to do when the output is not requested or the input is empty.
+  if (req[0] == kNullOp || inputs[0].Size() == 0U) return;
+  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+  int axis = CheckAxis(param.axis, inputs[0].ndim());
+  const double temperature = param.temperature.has_value() ?
+                             param.temperature.value() : 1.0;
+  // NOTE(review): AxisShapeCompact presumably collapses the shape around `axis` so that the
+  // softmax axis ends up last or second-to-last (the stride logic below relies on it).
+  mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
+
+  // Optional per-row lengths. Their element type becomes InputType1 in the RTC source;
+  // "int" is only a placeholder when no length tensor is passed (length_ptr stays null).
+  void* length_ptr = nullptr;
+  std::string length_typename = "int";
+  if (param.use_length.value()) {
+    CHECK(inputs.size() > 1)
+      << "Mask needs to be provided when using softmax with use_length=True.";
+    length_ptr = inputs[1].dptr_;
+    length_typename = mshadow_type_info(inputs[1].type_flag_).name;
+  }
+  CHECK_EQ(outputs.size(), 1);
+  // M: length of the softmax axis. stride: distance between consecutive elements of one
+  // row (not 1 only when the axis is second-to-last). N: number of rows.
+  index_t M = shape[axis];
+  if (M == 0 || shape.Size() == 0) return;
+  index_t stride = 1;
+  if (axis == shape.ndim() - 2) {
+    stride = shape[shape.ndim() - 1];
+  }
+  const index_t N = shape.Size() / M;
+  // NOTE(review): positional initialization; the kernels read param.stride,
+  // param.num_elements, param.temperature, param.rows_per_block and param.total_rows, so
+  // the trailing {M, temperature, 1, N} presumably map to those fields - confirm.
+  softmax_params params = {{inputs[0].dptr_, length_ptr, nullptr},
+                           {outputs[0].dptr_},
+                           stride, M,
+                           temperature, 1, N};
+  // Compile-time definitions shared by both kernels.
+  std::string code = "#define OP " + OP + "\n"
+                     "const OpReqType req = " + util::to_string(req[0]) + ";\n"
+                     "const bool negate = " + std::to_string(negate) + ";\n"
+                     "using InputType1 = " + length_typename + ";\n";
+  Stream<gpu>* s = ctx.get_stream<gpu>();
+
+  constexpr int nvec = 2;
+  // Using 20 kB of shared memory for persistent storage in the optimized case
+  const size_t acc_type_size = 
std::max(mshadow_type_info(inputs[0].type_flag_).acc_size,
+                                        
mshadow_type_info(outputs[0].type_flag_).acc_size);
+  const size_t max_opt_M = 20 * 1024 / acc_type_size;
+  // Rows processed by one block of the optimized kernel (see get_rows_per_block).
+  int rows_per_block = get_rows_per_block(M, nvec, max_opt_M,
+                                          vectorized_kernel_thread_num,
+                                          N, ctx.run_ctx.ctx.dev_id);
+  constexpr int warp_size = common::cuda::warp_size;
+  // Optimized path: contiguous rows that fit into the shared-memory budget.
+  if (stride == 1 &&
+      static_cast<size_t>(M * rows_per_block) <= max_opt_M) {
+    code += "const bool only_full_blocks = " + std::to_string(N % 
rows_per_block == 0) + ";\n"
+            "const bool reduction_inside_warp = " +
+            std::to_string(vectorized_kernel_thread_num / rows_per_block <= 
warp_size) + ";\n";
+    params.rows_per_block = rows_per_block;
+    int nblocks = (N + rows_per_block - 1) / rows_per_block;
+    VectorizedKernelRTCLauncher(code + softmax_common_functions, 
"softmax_stride1_compute_kernel",
+                                softmax_stride1_kernel_fwd, nvec,
+                                M * rows_per_block, N / rows_per_block, s, 
params,
+                                inputs, outputs,
+                                ctx.run_ctx.ctx.dev_id, 0, nblocks);
+    MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_stride1_compute_kernel);
+  } else {
+    // Generic path for any stride. The block has at most 128 threads (M rounded up to a
+    // multiple of warp_size, then to a power of two). For stride != 1 several rows (at
+    // most warp_size) share one block; the count is derived from N and the SM count.
+    code += "using InputType0 = " + 
mshadow_type_info(inputs[0].type_flag_).name + ";\n"
+            "using OutputType0 = " + 
mshadow_type_info(outputs[0].type_flag_).name + ";\n";
+    std::vector<const void*> args;
+    args.emplace_back(&params);
+    args.emplace_back(&M);
+    int num_threads = std::min(static_cast<size_t>(128),
+                               common::RoundToPower2(div_round(M, warp_size) * 
warp_size));
+    if (stride != 1) {
+      const int num_sms = MultiprocessorCount(ctx.run_ctx.ctx.dev_id);
+      const index_t rows_per_sm = div_round(N, (512 / num_threads) * num_sms);
+      params.rows_per_block = std::min(static_cast<size_t>(warp_size),
+                                       common::RoundToPower2(rows_per_sm));
+    }
+    const auto& kernel_func = get_function(code + softmax_common_functions,
+                                           "simple_softmax_kernel",
+                                           simple_softmax_kernel_fwd,
+                                           ctx.run_ctx.ctx.dev_id);
+    launch(kernel_func, div_round(N, params.rows_per_block), num_threads, 0, 
s, &args);
+    MSHADOW_CUDA_POST_KERNEL_CHECK(simple_softmax_kernel);
+  }
+}
+
+// Generic softmax/softmin backward kernel (any stride), compiled at runtime by
+// SoftmaxRTCGradCompute under the name simple_softmax_grad_kernel.
+// param.inputs = {forward output, output gradient, length (may be null)}.
+// Thread t of block b works on global row b * rows_per_block + t % rows_per_block and visits
+// every (blockDim.x / rows_per_block)-th element of that row (elements are param.stride
+// apart in memory).
+// For each row the kernel:
+//  - sums OP1(ograd, out) over the first `len` elements (len = the row's length entry, or
+//    lead_dim without lengths);
+//  - writes OP2(ograd, out, sum) / temperature to param.outputs[0], negated when `negate`,
+//    and 0 past `len`. With kAddTo, elements past `len` are left untouched.
+// NOTE(review): the shared-memory reduction presumably requires rows_per_block <= warp_size
+// (the host caps it there) - confirm against util::strided_grouped_warp_reduce.
+const char simple_softmax_kernel_bwd[] = R"code(
+__launch_bounds__(kRTCMaxThreadsPerBlock)
+__global__ void simple_softmax_grad_kernel(const softmax_params param,
+                                           const index_t lead_dim) {
+  using LengthType = AccType<InputType2>;
+  const InputType0* out = reinterpret_cast<const InputType0*>(param.inputs[0]);
+  const InputType1* ograd = reinterpret_cast<const InputType1*>(param.inputs[1]);
+  const InputType2* length = reinterpret_cast<const InputType2*>(param.inputs[2]);
+  const int my_row = threadIdx.x % param.rows_per_block;
+  const int my_id = threadIdx.x / param.rows_per_block;
+  const int threads_per_row = blockDim.x / param.rows_per_block;
+  // Global row index; it is also the position of this row's entry in the length tensor
+  // (a block covers rows_per_block rows, so blockIdx.x alone is not the row).
+  const index_t row = static_cast<index_t>(blockIdx.x) * param.rows_per_block + my_row;
+  const index_t base_x = row % param.stride;
+  const index_t base_n = row / param.stride;
+  const index_t base = base_x + param.stride * lead_dim * base_n;
+  // Threads without a row (tail of the last block) must not return yet: every thread of
+  // the block has to reach the __syncthreads() barriers and warp shuffles below.
+  const bool valid_row = base < param.num_elements * param.total_rows;
+  const index_t len = !valid_row
+                      ? 0
+                      : (length == nullptr
+                         ? lead_dim
+                         : static_cast<index_t>(LengthType::from(length[row])));
+  using IType0 = AccType<InputType0>;
+  using IType1 = AccType<InputType1>;
+  using OType = AccType<OutputType0>;
+  using AType = type_util::mixed_type<typename IType0::type,
+                                      typename IType1::type,
+                                      typename OType::type>;
+  __shared__ AType smem[kRTCMaxThreadsPerBlock];
+  AType sum;
+  red::sum::SetInitValue(sum);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto out_val = IType0::from(out[base + i * param.stride]);
+    auto ograd_val = IType1::from(ograd[base + i * param.stride]);
+    sum += OP1(ograd_val, out_val);
+  }
+  smem[threadIdx.x] = sum;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = smem[threadIdx.x] + smem[threadIdx.x + size];
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y) { return x + y; },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  sum = smem[my_row];
+  __syncthreads();
+
+  // No barriers follow, so threads without a row can leave now.
+  if (!valid_row) return;
+  OutputType0* igrad = reinterpret_cast<OutputType0*>(param.outputs[0]);
+  for (index_t i = my_id; i < lead_dim; i += threads_per_row) {
+    auto out_val = IType0::from(out[base + i * param.stride]);
+    auto ograd_val = IType1::from(ograd[base + i * param.stride]);
+    auto val = (i < len) ? OP2(ograd_val, out_val, sum) / static_cast<AType>(param.temperature) : 0;
+    val = negate ? -val : val;
+    if (req == OpReqType::kAddTo) {
+      if (i < len) {
+        igrad[base + i * param.stride] = OType::to(val +
+                                                   OType::from(igrad[base + i * param.stride]));
+      }
+    } else {
+        igrad[base + i * param.stride] = OType::to(val);
+    }
+  }
+}
+)code";
+
+// Optimized softmax/softmin backward kernel for contiguous rows (stride == 1), compiled at
+// runtime by SoftmaxRTCGradCompute as softmax_stride1_compute_grad_kernel.
+// Block b handles rows [b * rows_per_block, (b + 1) * rows_per_block). In the last block
+// there may be fewer rows, unless only_full_blocks is set.
+//  1. The forward output (param.inputs[0]) and the output gradient (param.inputs[1]) of
+//     those rows are loaded with nvec-wide vectorized loads into two __shared__ buffers
+//     of 10 * 1024 / sizeof(AType) elements each, converted to the accumulation type.
+//  2. threads_per_row = vectorized_kernel_thread_num / rows_per_block threads per row
+//     sum OP1(ograd, out) over the first `len` elements. len is the row's length entry,
+//     or num_elements without lengths. The sum is reduced in shared memory when
+//     !reduction_inside_warp, otherwise with a grouped warp all-reduce.
+//  3. OP2(ograd, out, sum) / temperature (negated when `negate`, 0 past `len`) overwrites
+//     the output buffer in place. It is then stored to param.outputs[0], or added to it
+//     for kAddTo.
+// The host must keep rows_per_block * num_elements within one shared buffer.
+// threads_per_row is derived from vectorized_kernel_thread_num rather than blockDim.x, so
+// the block size is assumed to equal that constant.
+// NOTE(review): results are written at output_loader's alignment offset but stored using
+// the storer's own alignment. This presumably relies on the launcher only vectorizing
+// when all tensors share the same alignment - confirm.
+const char softmax_stride1_kernel_bwd[] = R"code(
+__launch_bounds__(vector::vectorized_kernel_thread_num)
+__global__ void softmax_stride1_compute_grad_kernel(const softmax_params param,
+                                                    const index_t total_length,
+                                                    const index_t other_dim,
+                                                    const index_t N,
+                                                    const index_t 
num_aligned_elements) {
+  using namespace vector;
+  using IType0 = AccType<InputType0>;
+  using IType1 = AccType<InputType1>;
+  using OType = AccType<OutputType0>;
+  using LengthType = AccType<InputType2>;
+  const InputType2* length = reinterpret_cast<const 
InputType2*>(param.inputs[2]);
+  using AType = type_util::mixed_type<typename IType0::type,
+                                      typename IType1::type,
+                                      typename OType::type>;
+  __shared__ AType scratch[vectorized_kernel_thread_num];
+  __shared__ AType output_persistent_storage[10 * 1024 / sizeof(AType)];
+  __shared__ AType ograd_persistent_storage[10 * 1024 / sizeof(AType)];
+  const int warp_size = 32;
+  const int threads_per_row = vectorized_kernel_thread_num / 
param.rows_per_block;
+  const int my_local_row = threadIdx.x / threads_per_row;
+  const int base_row = blockIdx.x * param.rows_per_block;
+  const int my_row = base_row + my_local_row;
+  const index_t len = (length == nullptr ||
+                       my_row >= param.total_rows) ? param.num_elements
+                                                   : 
LengthType::from(length[my_row]);
+  const int my_id = threadIdx.x % threads_per_row;
+
+  AType* output_row;
+  AType* ograd_row;
+  if (only_full_blocks || blockIdx.x < gridDim.x - 1) {
+    // full rows_per_block rows to compute
+    VectorizedLoader<InputType0, nvec, aligned> output_loader(
+      reinterpret_cast<const InputType0*>(param.inputs[0]) + base_row * 
param.num_elements,
+      total_length);
+    VectorizedLoader<InputType1, nvec, aligned> ograd_loader(
+      reinterpret_cast<const InputType1*>(param.inputs[1]) + base_row * 
param.num_elements,
+      total_length);
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      output_loader.load(i, total_length);
+      ograd_loader.load(i, total_length);
+#pragma unroll
+      for (int j = 0; j < nvec; ++j) {
+        output_persistent_storage[i*nvec + j] = 
IType0::from(output_loader.separate()[j]);
+        ograd_persistent_storage[i*nvec + j] = 
IType1::from(ograd_loader.separate()[j]);
+      }
+    }
+    output_row = output_persistent_storage +
+                 my_local_row * param.num_elements +
+                 output_loader.alignment();
+    ograd_row = ograd_persistent_storage +
+                my_local_row * param.num_elements +
+                ograd_loader.alignment();
+  } else {
+    // less than rows_per_block rows to compute
+    const index_t real_length = min(total_length,
+                                    (param.total_rows - base_row) * 
param.num_elements);
+    VectorizedLoader<InputType0, nvec, false> output_loader(
+      reinterpret_cast<const InputType0*>(param.inputs[0]) + base_row * 
param.num_elements,
+      real_length);
+    VectorizedLoader<InputType1, nvec, false> ograd_loader(
+      reinterpret_cast<const InputType1*>(param.inputs[1]) + base_row * 
param.num_elements,
+      real_length);
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      output_loader.load(i, real_length);
+      ograd_loader.load(i, real_length);
+#pragma unroll
+      for (int j = 0; j < nvec; ++j) {
+        output_persistent_storage[i*nvec + j] = 
IType0::from(output_loader.separate()[j]);
+        ograd_persistent_storage[i*nvec + j] = 
IType1::from(ograd_loader.separate()[j]);
+      }
+    }
+    output_row = output_persistent_storage +
+                 my_local_row * param.num_elements +
+                 output_loader.alignment();
+    ograd_row = ograd_persistent_storage +
+                my_local_row * param.num_elements +
+                ograd_loader.alignment();
+  }
+  __syncthreads();
+
+  AType my_sum;
+  red::sum::SetInitValue(my_sum);
+
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    const AType val = OP1(ograd_row[i], output_row[i]);
+    my_sum += val;
+  }
+  AType ssum;
+  if (!reduction_inside_warp) {
+    scratch[threadIdx.x] = my_sum;
+    __syncthreads();
+    for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+      if (my_id < size) {
+        scratch[threadIdx.x] += scratch[threadIdx.x + size];
+      }
+      __syncthreads();
+    }
+    if (my_id < warp_size) {
+      AType my_value = util::grouped_warp_allreduce(scratch[threadIdx.x],
+                                                    [](AType x, AType y) { 
return x + y;},
+                                                    min(threads_per_row, 
warp_size));
+      scratch[threadIdx.x] = my_value;
+    }
+    __syncthreads();
+
+    ssum = scratch[threadIdx.x - my_id];
+    __syncthreads();
+  } else {
+      ssum = util::grouped_warp_allreduce(my_sum,
+                                          [](AType x, AType y) { return x + 
y;},
+                                          threads_per_row);
+  }
+
+  for (index_t i = my_id; i < param.num_elements; i += threads_per_row) {
+    AType val = (i < len)
+                ? OP2(ograd_row[i], output_row[i], ssum) / 
static_cast<AType>(param.temperature)
+                : 0;
+    output_row[i] = negate ? -val : val;
+  }
+  __syncthreads();
+
+  if (only_full_blocks || blockIdx.x < gridDim.x - 1) {
+    VectorizedStorer<OutputType0, nvec, aligned> storer(
+      reinterpret_cast<OutputType0*>(param.outputs[0]) + base_row * 
param.num_elements,
+      total_length);
+
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      if (req == OpReqType::kAddTo) {
+        storer.load(i, total_length);
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = 
OType::to(op::add(output_persistent_storage[i*nvec + j],
+                                                   
OType::from(storer.separate()[j])));
+        }
+      } else {
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(output_persistent_storage[i*nvec + 
j]);
+        }
+      }
+      storer.store(i, total_length);
+    }
+  } else {
+    const index_t real_length = min(total_length,
+                                    (param.total_rows - base_row) * 
param.num_elements);
+    VectorizedStorer<OutputType0, nvec, false> storer(
+      reinterpret_cast<OutputType0*>(param.outputs[0]) + base_row * 
param.num_elements,
+      real_length);
+
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      if (req == OpReqType::kAddTo) {
+        storer.load(i, real_length);
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = 
OType::to(op::add(output_persistent_storage[i*nvec + j],
+                                                   
OType::from(storer.separate()[j])));
+        }
+      } else {
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(output_persistent_storage[i*nvec + 
j]);
+        }
+      }
+      storer.store(i, real_length);
+    }
+  }
+}
+)code";
+
+// GPU backward of softmax / softmin (FCompute<gpu>) built on runtime-compiled (RTC) kernels.
+// Inputs:
+//  - inputs[0]: the output gradient;
+//  - inputs[out_idx]: the forward output;
+//  - inputs[2]: the per-row lengths, when use_length is set.
+// Outputs:
+//  - outputs[0]: receives the input gradient;
+//  - outputs[1]: zero-filled when use_length is set (presumably the gradient of the
+//    length input).
+// Kernel choice:
+//  - softmax_stride1_compute_grad_kernel when the softmax axis is contiguous (stride == 1)
+//    and rows_per_block rows fit into its two 10 kB shared-memory buffers;
+//  - simple_softmax_grad_kernel otherwise, or whenever DEBUG_SOFTMAX_GRAD is set.
+void SoftmaxRTCGradCompute::operator()(const nnvm::NodeAttrs& attrs,
+                                       const OpContext& ctx,
+                                       const std::vector<TBlob>& inputs,
+                                       const std::vector<OpReqType>& req,
+                                       const std::vector<TBlob>& outputs) {
+  using namespace mxnet_op;
+  using common::mshadow_type_info;
+  using namespace common::cuda::rtc;
+  using common::div_round;
+  Stream<gpu>* s = ctx.get_stream<gpu>();
+  if (softmax_use_length(attrs) && req[1] != kNullOp) {
+    const cudaError_t memset_err =
+      cudaMemsetAsync(outputs[1].dptr_, 0,
+                      outputs[1].Size() * mshadow_type_info(outputs[1].type_flag_).size,
+                      Stream<gpu>::GetStream(s));
+    CHECK_EQ(memset_err, cudaSuccess)
+      << "cudaMemsetAsync failed: " << cudaGetErrorString(memset_err);
+  }
+  if (req[0] == kNullOp || inputs[0].Size() == 0U) return;
+  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+  int axis = CheckAxis(param.axis, inputs[0].ndim());
+  const double temperature = param.temperature.has_value() ?
+                             param.temperature.value() : 1.0;
+  mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
+
+  // Position of the forward output among the inputs.
+  int out_idx = softmax_has_dtype_override(attrs) ? 2 : 1;
+  out_idx = softmax_use_length(attrs) ? 3 : out_idx;
+
+  void* length_ptr = nullptr;
+  std::string length_typename = "int";
+  if (softmax_use_length(attrs)) {
+    length_ptr = inputs[2].dptr_;
+    length_typename = mshadow_type_info(inputs[2].type_flag_).name;
+  }
+  index_t M = shape[axis];
+  if (M == 0 || shape.Size() == 0) return;
+  index_t stride = 1;
+  if (axis == shape.ndim() - 2) {
+    stride = shape[shape.ndim() - 1];
+  }
+  const index_t N = shape.Size() / M;
+  // Kernel inputs are {forward output, ograd, length}. The literal 1 is presumably the
+  // initial rows_per_block, as in the forward pass.
+  softmax_params params = {{inputs[out_idx].dptr_, inputs[0].dptr_, length_ptr},
+                           {outputs[0].dptr_},
+                           stride, M,
+                           temperature, 1, N};
+  std::string code = "#define OP1 " + OP1 + "\n"
+                     "#define OP2 " + OP2 + "\n"
+                     "const OpReqType req = " + util::to_string(req[0]) + ";\n"
+                     "const bool negate = " + std::to_string(negate) + ";\n"
+                     "using InputType2 = " + length_typename + ";\n";
+
+  constexpr int nvec = 2;
+  constexpr int warp_size = common::cuda::warp_size;
+  // The optimized kernel keeps two shared-memory buffers (output and ograd) of 10 kB each,
+  // holding elements of the mixed accumulation type of all three tensors. The forward
+  // output's type is included as well, so max_opt_M never exceeds the buffers.
+  const size_t acc_type_size =
+    std::max(mshadow_type_info(inputs[out_idx].type_flag_).acc_size,
+             std::max(mshadow_type_info(inputs[0].type_flag_).acc_size,
+                      mshadow_type_info(outputs[0].type_flag_).acc_size));
+  const size_t max_opt_M = 10 * 1024 / acc_type_size;
+  const int rows_per_block = get_rows_per_block(M, nvec, max_opt_M,
+                                                vectorized_kernel_thread_num,
+                                                N, ctx.run_ctx.ctx.dev_id);
+  // Setting DEBUG_SOFTMAX_GRAD forces the generic kernel.
+  const bool debug_softmax = dmlc::GetEnv("DEBUG_SOFTMAX_GRAD", false);
+  if (!debug_softmax && stride == 1 &&
+      static_cast<size_t>(M * rows_per_block) <= max_opt_M) {
+    code += "const bool only_full_blocks = " + std::to_string(N % rows_per_block == 0) + ";\n"
+            "const bool reduction_inside_warp = " +
+            std::to_string(vectorized_kernel_thread_num / rows_per_block <= warp_size) + ";\n";
+    // Only the optimized kernel uses this rows_per_block (mirrors the forward pass).
+    params.rows_per_block = rows_per_block;
+    int nblocks = div_round(N, rows_per_block);
+    std::vector<TBlob> new_inputs = {inputs[out_idx], inputs[0]};
+    if (softmax_use_length(attrs)) {
+      new_inputs.emplace_back(inputs[2]);
+    }
+    std::vector<TBlob> new_outputs = {outputs[0]};
+    VectorizedKernelRTCLauncher(code + softmax_common_functions,
+                                "softmax_stride1_compute_grad_kernel",
+                                softmax_stride1_kernel_bwd, nvec,
+                                M * rows_per_block, N / rows_per_block, s, params,
+                                new_inputs, new_outputs,
+                                ctx.run_ctx.ctx.dev_id, 0, nblocks);
+    MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_stride1_compute_grad_kernel);
+  } else {
+    // The generic kernel keeps the initial rows_per_block unless stride != 1. The optimized
+    // kernel's value (reachable here with DEBUG_SOFTMAX_GRAD) can exceed the block size and
+    // would leave zero threads per row.
+    code += "using InputType0 = " + mshadow_type_info(inputs[out_idx].type_flag_).name + ";\n"
+            "using InputType1 = " + mshadow_type_info(inputs[0].type_flag_).name + ";\n"
+            "using OutputType0 = " + mshadow_type_info(outputs[0].type_flag_).name + ";\n";
+    std::vector<const void*> args;
+    args.emplace_back(&params);
+    args.emplace_back(&M);
+    int num_threads = std::min(static_cast<size_t>(128),
+                               common::RoundToPower2(div_round(M, warp_size) * warp_size));
+    if (stride != 1) {
+      const int num_sms = MultiprocessorCount(ctx.run_ctx.ctx.dev_id);
+      const index_t rows_per_sm = div_round(N, (512 / num_threads) * num_sms);
+      params.rows_per_block = std::min(static_cast<size_t>(warp_size),
+                                       common::RoundToPower2(rows_per_sm));
+    }
+    const auto& kernel_func = get_function(code + softmax_common_functions,
+                                           "simple_softmax_grad_kernel",
+                                           simple_softmax_kernel_bwd,
+                                           ctx.run_ctx.ctx.dev_id);
+    launch(kernel_func, div_round(N, params.rows_per_block), num_threads, 0, s, &args);
+    MSHADOW_CUDA_POST_KERNEL_CHECK(simple_softmax_grad_kernel);
+  }
+}
+
 NNVM_REGISTER_OP(softmax)
-.set_attr<FCompute>("FCompute<gpu>", SoftmaxCompute<gpu, 
mxnet_op::softmax_fwd>);
+// GPU forward goes through the runtime-compiled (RTC) functor; "softmax_fwd" presumably
+// ends up as OP in the generated kernel source.
+.set_attr<FCompute>("FCompute<gpu>", SoftmaxRTCCompute{"softmax_fwd"});
 
 NNVM_REGISTER_OP(_backward_softmax)
-.set_attr<FCompute>("FCompute<gpu>", SoftmaxGradCompute<gpu, 
op::mshadow_op::mul,
-                                                        
mxnet_op::softmax_bwd>);
+// GPU backward: "op::mul" and "softmax_bwd" presumably become OP1 / OP2 of the RTC
+// gradient kernels.
+.set_attr<FCompute>("FCompute<gpu>", SoftmaxRTCGradCompute{"op::mul", 
"softmax_bwd"});
+
 NNVM_REGISTER_OP(masked_softmax)
+// masked_softmax still uses the templated, non-RTC MaskedSoftmaxCompute.
 .set_attr<FCompute>("FCompute<gpu>", MaskedSoftmaxCompute<gpu, 
mxnet_op::softmax_fwd,
                                                           false>);
diff --git a/src/operator/nn/softmin.cu b/src/operator/nn/softmin.cu
index d00d0bd..b6f56ce 100644
--- a/src/operator/nn/softmin.cu
+++ b/src/operator/nn/softmin.cu
@@ -29,11 +29,10 @@ namespace mxnet {
 namespace op {
 
 NNVM_REGISTER_OP(softmin)
-.set_attr<FCompute>("FCompute<gpu>", SoftmaxCompute<gpu, 
mxnet_op::softmax_fwd, true>);
+// softmin reuses the softmax RTC functors. The trailing `true` presumably sets `negate`
+// (softmin(x) == softmax(-x)), matching the old SoftmaxCompute<..., true> - confirm.
+.set_attr<FCompute>("FCompute<gpu>", SoftmaxRTCCompute{"softmax_fwd", true});
 
 NNVM_REGISTER_OP(_backward_softmin)
-.set_attr<FCompute>("FCompute<gpu>", SoftmaxGradCompute<gpu, 
op::mshadow_op::mul,
-                                                        mxnet_op::softmax_bwd, 
true>);
+.set_attr<FCompute>("FCompute<gpu>", SoftmaxRTCGradCompute{"op::mul", 
"softmax_bwd", true});
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/tensor/elemwise_binary_scalar_op.cc 
b/src/operator/tensor/elemwise_binary_scalar_op.cc
index f09bf21..d7fde47 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op.cc
+++ b/src/operator/tensor/elemwise_binary_scalar_op.cc
@@ -50,6 +50,7 @@ __global__ void binary_scalar_kernel(const 
binary_scalar_kernel_params params,
                                      const index_t N,
                                      const index_t num_aligned_elements) {
   using namespace vector;
+  using type_util::mixed_type;
   VectorizedLoader<InputType0, nvec, aligned> loader(
     reinterpret_cast<const InputType0*>(params.inputs[0]), N);
   VectorizedStorer<OutputType0, nvec, aligned> storer(
@@ -72,9 +73,8 @@ __global__ void binary_scalar_kernel(const 
binary_scalar_kernel_params params,
       const auto input = IType::from(loader.separate()[i]);
       // enables returning different type
       const auto temp = OP(input,
-                           static_cast<typename type_util::mixed_type<typename 
IType::type,
-                                                                      typename 
OType::type>::type>
-                             (params.scalar));
+                           static_cast<mixed_type<typename IType::type,
+                                                  typename 
OType::type>>(params.scalar));
 
       if (req == OpReqType::kAddTo) {
         // temp2 may have a wider type than either temp
@@ -171,6 +171,7 @@ __global__ void binary_scalar_kernel_bwd(const 
binary_scalar_kernel_params param
                                          const index_t N,
                                          const index_t num_aligned_elements) {
   using namespace vector;
+  using type_util::mixed_type;
   VectorizedLoader<InputType0, nvec, aligned> ograd_loader(
     reinterpret_cast<const InputType0*>(params.inputs[0]), N);
   VectorizedLoader<InputType1, nvec, aligned> input_loader(
@@ -199,9 +200,8 @@ __global__ void binary_scalar_kernel_bwd(const 
binary_scalar_kernel_params param
       // enables returning different type
       const auto temp = op::mul(ograd,
                                 OP(input,
-                                   static_cast<typename 
type_util::mixed_type<typename IType::type,
-                                                                              
typename OType::type>
-                                               ::type>(params.scalar)));
+                                   static_cast<mixed_type<typename IType::type,
+                                                          typename 
OType::type>>(params.scalar)));
 
       if (req == OpReqType::kAddTo) {
         // temp2 may have a wider type than either temp

Reply via email to