This is an automated email from the ASF dual-hosted git repository.
moisesher pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new c692770 [PERF] Moving GPU softmax to RTC and optimizations (#19905)
c692770 is described below
commit c6927704d8ceb4597f33a45037277d462ee67045
Author: Przemyslaw Tredak <[email protected]>
AuthorDate: Mon Apr 26 09:16:09 2021 -0700
[PERF] Moving GPU softmax to RTC and optimizations (#19905)
* Moving softmax to RTC
* Fix from rebase
* Fixes from review
* Fix
* Fix FPE in the softmax grad.
---
src/common/cuda/rtc/backward_functions-inl.h | 161 +++--
src/common/cuda/rtc/forward_functions-inl.h | 126 ++--
src/common/cuda/rtc/reducer-inl.h | 399 ++++++++++++
src/common/cuda/rtc/util-inl.h | 113 +++-
src/common/cuda/rtc/vectorization-inl.h | 24 +-
src/common/cuda/utils.cc | 20 +-
src/common/cuda/utils.h | 8 +
src/common/utils.h | 21 +
src/operator/mxnet_op.h | 17 +-
src/operator/nn/log_softmax.cu | 5 +-
src/operator/nn/softmax-inl.h | 357 +----------
src/operator/nn/softmax.cu | 781 ++++++++++++++++++++++-
src/operator/nn/softmin.cu | 5 +-
src/operator/tensor/elemwise_binary_scalar_op.cc | 12 +-
14 files changed, 1484 insertions(+), 565 deletions(-)
diff --git a/src/common/cuda/rtc/backward_functions-inl.h
b/src/common/cuda/rtc/backward_functions-inl.h
index 168dc68..50f0c67 100644
--- a/src/common/cuda/rtc/backward_functions-inl.h
+++ b/src/common/cuda/rtc/backward_functions-inl.h
@@ -32,217 +32,217 @@ const char backward_function_definitions[] = R"code(
namespace op {
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_relu(const DTypeGrad grad, const DType val) {
if (isnan(val)) return val;
return val > 0 ? grad : 0;
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_sigmoid(const DTypeGrad grad, const DType out) {
return grad * out * (1 - out);
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_softrelu(const DTypeGrad grad, const DType val) {
- const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
+ const mixed_type<DTypeGrad, DType> v = val;
return grad * sigmoid(v);
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_softsign(const DTypeGrad grad, const DType val) {
- const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
+ const mixed_type<DTypeGrad, DType> v = val;
const auto ap1 = 1 + op::abs(v);
return grad / (ap1 * ap1);
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_abs(const DTypeGrad grad, const DType val) {
- const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
+ const mixed_type<DTypeGrad, DType> v = val;
return grad * op::sign(v);
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_exp(const DTypeGrad grad, const DType val) {
- const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
+ const mixed_type<DTypeGrad, DType> v = val;
return grad * op::exp(v);
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_expm1(const DTypeGrad grad, const DType val) {
return backward_exp(grad, val);
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_log(const DTypeGrad grad, const DType val) {
return grad / val;
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_log10(const DTypeGrad grad, const DType val) {
return grad / (val * op::log(static_cast<DTypeGrad>(10)));
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_log2(const DTypeGrad grad, const DType val) {
return grad / (val * op::log(static_cast<DTypeGrad>(2)));
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_log1p(const DTypeGrad grad, const DType val) {
return grad / (1 + val);
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_sin(const DTypeGrad grad, const DType val) {
- const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
+ const mixed_type<DTypeGrad, DType> v = val;
return grad * op::cos(v);
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_cos(const DTypeGrad grad, const DType val) {
- const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
+ const mixed_type<DTypeGrad, DType> v = val;
return -grad * op::sin(v);
}
// Uses output from tan
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_tan(const DTypeGrad grad, const DType out) {
return grad * (out * out + 1);
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_arcsin(const DTypeGrad grad, const DType val) {
- const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
+ const mixed_type<DTypeGrad, DType> v = val;
return grad / op::sqrt(1 - v*v);
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_arccos(const DTypeGrad grad, const DType val) {
- const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
+ const mixed_type<DTypeGrad, DType> v = val;
return -grad / op::sqrt(1 - v*v);
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_arctan(const DTypeGrad grad, const DType val) {
return grad / (1 + val*val);
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_degrees(const DTypeGrad grad, const DType /* val */) {
return op::degrees(grad);
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_radians(const DTypeGrad grad, const DType /* val */) {
return op::radians(grad);
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_sinh(const DTypeGrad grad, const DType val) {
- const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
+ const mixed_type<DTypeGrad, DType> v = val;
return grad * op::cosh(v);
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_cosh(const DTypeGrad grad, const DType val) {
- const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
+ const mixed_type<DTypeGrad, DType> v = val;
return grad * op::sinh(v);
}
// Uses tanh output
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_tanh(const DTypeGrad grad, const DType out) {
return grad * (1 - out * out);
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_arcsinh(const DTypeGrad grad, const DType val) {
- const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
+ const mixed_type<DTypeGrad, DType> v = val;
return grad / op::sqrt(v * v + 1);
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_arccosh(const DTypeGrad grad, const DType val) {
- const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
+ const mixed_type<DTypeGrad, DType> v = val;
return grad / op::sqrt(v * v - 1);
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_arctanh(const DTypeGrad grad, const DType val) {
return grad / (1 - val * val);
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_sqrt(const DTypeGrad grad, const DType out) {
return 0.5 * grad / out;
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_rsqrt(const DTypeGrad grad, const DType val) {
- const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
+ const mixed_type<DTypeGrad, DType> v = val;
const auto inv = 1 / v;
return -0.5 * grad * op::sqrt(inv) * inv;
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_cbrt(const DTypeGrad grad, const DType out) {
return grad / (3.0f * out * out);
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_rcbrt(const DTypeGrad grad, const DType val) {
- const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
+ const mixed_type<DTypeGrad, DType> v = val;
const auto inv = 1 / v;
return -1.f/3.f * grad * op::cbrt(inv) * inv;
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_square(const DTypeGrad grad, const DType val) {
return 2 * val * grad;
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
rdiv_grad(const DType val,
const DType2 val2) {
return -val2 / (val * val);
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
div_grad(const DType val,
const DType2 val2) {
- const typename type_util::mixed_type<DType, DType2>::type temp = val2;
+ const mixed_type<DType, DType2> temp = val2;
return op::reciprocal(temp);
}
@@ -283,87 +283,87 @@ __device__ inline DType rmod_grad(const DType val,
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
power_grad(const DType val,
const DType2 val2) {
return op::power(val, val2 - 1.f) * val2;
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
power_rgrad(const DType val,
const DType2 val2) {
- const typename type_util::mixed_type<DType, DType2>::type temp = val;
+ const mixed_type<DType, DType2> temp = val;
return op::power(val, val2) * op::log(temp);
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
rpower_grad(const DType val,
const DType2 val2) {
- const typename type_util::mixed_type<DType, DType2>::type temp = val2;
+ const mixed_type<DType, DType2> temp = val2;
return val * op::log(temp);
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
hypot_grad_left(const DType val,
const DType2 val2) {
return val / op::hypot(val, val2);
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
hypot_grad_right(const DType val,
const DType2 val2) {
return val2 / op::hypot(val, val2);
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
copysign_grad(const DType val,
const DType2 val2) {
return (val >= 0 && val2 >= 0) || (val < 0 && val2 < 0) ? 1 : -1;
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
arctan2_grad(const DType val,
const DType2 val2) {
return val2 / (val * val + val2 * val2);
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
rarctan2_grad(const DType val,
const DType2 val2) {
return val / (val * val + val2 * val2);
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
arctan2_rgrad(const DType val,
const DType2 val2) {
return -rarctan2_grad(val, val2);
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
ldexp_grad(const DType val,
const DType2 val2) {
return op::power(static_cast<DType>(2), val2);
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
rldexp_grad(const DType val,
const DType2 val2) {
- using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
- return val2 * op::power(static_cast<mixed_type>(2), val) *
op::log(static_cast<mixed_type>(2));
+ using type = mixed_type<DType, DType2>;
+ return val2 * op::power(static_cast<type>(2), val) *
op::log(static_cast<type>(2));
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_clip(const DTypeGrad grad, const DType val,
const float a_min, const float a_max) {
if (val > a_max || val < a_min) {
@@ -374,35 +374,32 @@ backward_clip(const DTypeGrad grad, const DType val,
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_reciprocal(const DTypeGrad grad, const DType val) {
return -grad / (val * val);
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_erf(const DTypeGrad grad, const DType val) {
- using mixed_type = typename type_util::mixed_type<DTypeGrad, DType>::type;
- const mixed_type v = val;
- constexpr mixed_type my_pi = pi;
+ const mixed_type<DTypeGrad, DType> v = val;
+ constexpr mixed_type<DTypeGrad, DType> my_pi = pi;
return 2.0f / op::sqrt(my_pi) * op::exp(-(v*v)) * grad;
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_erfinv(const DTypeGrad grad, const DType val) {
- using mixed_type = typename type_util::mixed_type<DTypeGrad, DType>::type;
- constexpr mixed_type my_pi = pi;
- const mixed_type g = grad;
- const mixed_type v = val;
+ constexpr mixed_type<DTypeGrad, DType> my_pi = pi;
+ const mixed_type<DTypeGrad, DType> g = grad;
+ const mixed_type<DTypeGrad, DType> v = val;
return 0.5f * op::sqrt(my_pi) * op::exp(v * v) * g;
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_gamma(const DTypeGrad grad, const DType val) {
- using mixed_type = typename type_util::mixed_type<DTypeGrad, DType>::type;
- const mixed_type v = val;
+ const mixed_type<DTypeGrad, DType> v = val;
if (type_util::is_same<DTypeGrad, double>::value) {
return grad * op::gamma(v) * op::special_functions::cephes::psi<double>(v);
} else {
@@ -411,10 +408,9 @@ backward_gamma(const DTypeGrad grad, const DType val) {
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_gammaln(const DTypeGrad grad, const DType val) {
- using mixed_type = typename type_util::mixed_type<DTypeGrad, DType>::type;
- const mixed_type v = val;
+ const mixed_type<DTypeGrad, DType> v = val;
if (type_util::is_same<DTypeGrad, double>::value) {
return grad * op::special_functions::cephes::psi<double>(v);
} else {
@@ -423,10 +419,9 @@ backward_gammaln(const DTypeGrad grad, const DType val) {
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_digamma(const DTypeGrad grad, const DType val) {
- using mixed_type = typename type_util::mixed_type<DTypeGrad, DType>::type;
- const mixed_type v = val;
+ const mixed_type<DTypeGrad, DType> v = val;
if (type_util::is_same<DTypeGrad, double>::value) {
return grad * op::special_functions::trigamma<double>(v);
} else {
@@ -435,7 +430,7 @@ backward_digamma(const DTypeGrad grad, const DType val) {
}
template <typename DType, typename DTypeGrad>
-__device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
+__device__ inline mixed_type<DTypeGrad, DType>
backward_gelu(const DTypeGrad grad, const DType val) {
return 0.5f * (grad + grad * op::erf(val / op::sqrt(2.0f)) +
val * backward_erf(grad, val / op::sqrt(2.0f)) /
op::sqrt(2.0f));
diff --git a/src/common/cuda/rtc/forward_functions-inl.h
b/src/common/cuda/rtc/forward_functions-inl.h
index f85916f..f4d08e6 100644
--- a/src/common/cuda/rtc/forward_functions-inl.h
+++ b/src/common/cuda/rtc/forward_functions-inl.h
@@ -32,6 +32,7 @@ const char function_definitions_util[] = R"code(
#define INT_MAX (2147483647)
namespace op {
+using type_util::mixed_type;
template <typename DType>
struct LoadType {
@@ -241,44 +242,44 @@ __device__ inline bool_t isfinite(const DType val) {
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
add(const DType a, const DType2 b) {
return a + b;
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
sub(const DType a, const DType2 b) {
return a - b;
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
rsub(const DType a, const DType2 b) {
return b - a;
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
mul(const DType a, const DType2 b) {
return a * b;
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
div(const DType a, const DType2 b) {
return a / b;
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
rdiv(const DType a, const DType2 b) {
return b / a;
}
#define DEFINE_BINARY_MATH_FUNC(name, double_version, float_version) \
template <typename DType, typename DType2> \
-__device__ inline typename type_util::mixed_type<DType, DType2>::type \
+__device__ inline mixed_type<DType, DType2> \
name (const DType a, const DType2 b) { \
if (type_util::has_double_or_integral<DType, DType2>::value) { \
return double_version ((double)a, (double)b); \
@@ -288,7 +289,7 @@ name (const DType a, const DType2 b) { \
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
power (const DType a, const DType2 b) {
if (type_util::has_double<DType, DType2>::value) {
return ::pow ((double)a, (double)b); \
@@ -298,34 +299,34 @@ power (const DType a, const DType2 b) {
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
rpow(const DType a, const DType2 b) {
return power(b, a);
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
max(const DType a, const DType2 b) {
if (isnan(a)) return a;
return a > b ? a : b;
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
fmax(const DType a, const DType2 b) {
if (isnan(b)) return a;
return a > b ? a : b;
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
min(const DType a, const DType2 b) {
if (isnan(a)) return a;
return a < b ? a : b;
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
fmin(const DType a, const DType2 b) {
if (isnan(b)) return a;
return a < b ? a : b;
@@ -334,7 +335,7 @@ fmin(const DType a, const DType2 b) {
DEFINE_BINARY_MATH_FUNC(hypot, ::hypot, ::hypotf)
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
mod(const DType a, const DType2 b) {
if (b == 0) {
return 0;
@@ -359,7 +360,7 @@ mod(const DType a, const DType2 b) {
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
fmod(const DType a, const DType2 b) {
if (b == 0) {
return 0;
@@ -368,110 +369,98 @@ fmod(const DType a, const DType2 b) {
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
rmod(const DType a, const DType2 b) {
return op::mod(b, a);
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
rfmod(const DType a, const DType2 b) {
return op::fmod(b, a);
}
template <typename DType, typename DType2>
__device__ inline DType equal(const DType a, const DType2 b) {
- using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
- const mixed_type real_a = a;
- const mixed_type real_b = b;
+ const mixed_type<DType, DType2> real_a = a;
+ const mixed_type<DType, DType2> real_b = b;
return real_a == real_b ? 1 : 0;
}
template <typename DType, typename DType2>
__device__ inline DType not_equal(const DType a, const DType2 b) {
- using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
- const mixed_type real_a = a;
- const mixed_type real_b = b;
+ const mixed_type<DType, DType2> real_a = a;
+ const mixed_type<DType, DType2> real_b = b;
return real_a != real_b ? 1 : 0;
}
template <typename DType, typename DType2>
__device__ inline DType greater(const DType a, const DType2 b) {
- using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
- const mixed_type real_a = a;
- const mixed_type real_b = b;
+ const mixed_type<DType, DType2> real_a = a;
+ const mixed_type<DType, DType2> real_b = b;
return real_a > real_b ? 1 : 0;
}
template <typename DType, typename DType2>
__device__ inline DType greater_equal(const DType a, const DType2 b) {
- using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
- const mixed_type real_a = a;
- const mixed_type real_b = b;
+ const mixed_type<DType, DType2> real_a = a;
+ const mixed_type<DType, DType2> real_b = b;
return real_a >= real_b ? 1 : 0;
}
template <typename DType, typename DType2>
__device__ inline DType less(const DType a, const DType2 b) {
- using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
- const mixed_type real_a = a;
- const mixed_type real_b = b;
+ const mixed_type<DType, DType2> real_a = a;
+ const mixed_type<DType, DType2> real_b = b;
return real_a < real_b ? 1 : 0;
}
template <typename DType, typename DType2>
__device__ inline DType less_equal(const DType a, const DType2 b) {
- using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
- const mixed_type real_a = a;
- const mixed_type real_b = b;
+ const mixed_type<DType, DType2> real_a = a;
+ const mixed_type<DType, DType2> real_b = b;
return real_a <= real_b ? 1 : 0;
}
template <typename DType, typename DType2>
__device__ inline bool_t np_equal(const DType a, const DType2 b) {
- using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
- const mixed_type real_a = a;
- const mixed_type real_b = b;
+ const mixed_type<DType, DType2> real_a = a;
+ const mixed_type<DType, DType2> real_b = b;
return real_a == real_b ? true : false;
}
template <typename DType, typename DType2>
__device__ inline bool_t np_not_equal(const DType a, const DType2 b) {
- using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
- const mixed_type real_a = a;
- const mixed_type real_b = b;
+ const mixed_type<DType, DType2> real_a = a;
+ const mixed_type<DType, DType2> real_b = b;
return real_a != real_b ? true : false;
}
template <typename DType, typename DType2>
__device__ inline bool_t np_greater(const DType a, const DType2 b) {
- using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
- const mixed_type real_a = a;
- const mixed_type real_b = b;
+ const mixed_type<DType, DType2> real_a = a;
+ const mixed_type<DType, DType2> real_b = b;
return real_a > real_b ? true : false;
}
template <typename DType, typename DType2>
__device__ inline bool_t np_greater_equal(const DType a, const DType2 b) {
- using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
- const mixed_type real_a = a;
- const mixed_type real_b = b;
+ const mixed_type<DType, DType2> real_a = a;
+ const mixed_type<DType, DType2> real_b = b;
return real_a >= real_b ? true : false;
}
template <typename DType, typename DType2>
__device__ inline bool_t np_less(const DType a, const DType2 b) {
- using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
- const mixed_type real_a = a;
- const mixed_type real_b = b;
+ const mixed_type<DType, DType2> real_a = a;
+ const mixed_type<DType, DType2> real_b = b;
return real_a < real_b ? true : false;
}
template <typename DType, typename DType2>
__device__ inline bool_t np_less_equal(const DType a, const DType2 b) {
- using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
- const mixed_type real_a = a;
- const mixed_type real_b = b;
+ const mixed_type<DType, DType2> real_a = a;
+ const mixed_type<DType, DType2> real_b = b;
return real_a <= real_b ? true : false;
}
@@ -501,7 +490,7 @@ __device__ inline DType2 rcopysign(const DType a, const
DType2 b) {
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
lcm(const DType a, const DType2 b) {
if (type_util::is_integral<DType>::value &&
type_util::is_integral<DType2>::value) {
@@ -542,7 +531,7 @@ lcm(const DType a, const DType2 b) {
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
gcd(const DType a, const DType2 b) {
if (type_util::is_integral<DType>::value &&
type_util::is_integral<DType2>::value) {
@@ -585,42 +574,39 @@ gcd(const DType a, const DType2 b) {
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
bitwise_xor(const DType a,
+__device__ inline mixed_type<DType, DType2> bitwise_xor(const DType a,
const
DType2 b) {
- using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
- const mixed_type real_a = a;
- const mixed_type real_b = b;
+ const mixed_type<DType, DType2> real_a = a;
+ const mixed_type<DType, DType2> real_b = b;
return real_a ^ real_b;
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
bitwise_or(const DType a,
+__device__ inline mixed_type<DType, DType2> bitwise_or(const DType a,
const
DType2 b) {
- using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
- const mixed_type real_a = a;
- const mixed_type real_b = b;
+ const mixed_type<DType, DType2> real_a = a;
+ const mixed_type<DType, DType2> real_b = b;
return real_a | real_b;
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
bitwise_and(const DType a,
+__device__ inline mixed_type<DType, DType2> bitwise_and(const DType a,
const
DType2 b) {
- using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
- const mixed_type real_a = a;
- const mixed_type real_b = b;
+ const mixed_type<DType, DType2> real_a = a;
+ const mixed_type<DType, DType2> real_b = b;
return real_a & real_b;
}
DEFINE_BINARY_MATH_FUNC(arctan2, ::atan2, ::atan2f)
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
rarctan2(const DType a, const DType2 b) {
return arctan2(b, a);
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
ldexp(const DType a, const DType2 b) {
if (type_util::has_double_or_integral<DType, DType2>::value) {
return a * ::pow(2.0, static_cast<double>(b));
@@ -630,7 +616,7 @@ ldexp(const DType a, const DType2 b) {
}
template <typename DType, typename DType2>
-__device__ inline typename type_util::mixed_type<DType, DType2>::type
+__device__ inline mixed_type<DType, DType2>
rldexp(const DType a, const DType2 b) {
return ldexp(b, a);
}
diff --git a/src/common/cuda/rtc/reducer-inl.h
b/src/common/cuda/rtc/reducer-inl.h
index 93b7027..259d0e0 100644
--- a/src/common/cuda/rtc/reducer-inl.h
+++ b/src/common/cuda/rtc/reducer-inl.h
@@ -94,6 +94,405 @@ struct sum {
residual = 0;
}
};
+
+/*! \brief maximum reducer */
+struct maximum {
+ /*! \brief do reduction into dst */
+ template<typename DType, typename DType2>
+ __device__ inline static void Reduce(volatile DType& dst, volatile DType2
src) { // NOLINT(*)
+ if (!util::isnan(dst)) {
+ if (!(dst >= src)) dst = src;
+ }
+ }
+ /*! \brief do reduction into dst */
+ template<typename DType, typename DType2>
+ __device__ inline static void Reduce(volatile DType& dst, volatile DType2
src,
+ volatile DType& none) {
+ Reduce(dst, src);
+ }
+ /*! \brief combine the results of two reducers */
+ template<typename DType>
+ __device__ inline static void Merge(volatile DType& dst_val, volatile DType&
src_val) {
+ Reduce(dst_val, src_val);
+ }
+ /*! \brief combine the results of two reducers */
+ template<typename DType>
+ __device__ inline static void Merge(volatile DType& dst_val, volatile DType&
dst_residual,
+ volatile DType& src_val, volatile DType&
src_residual) {
+ Reduce(dst_val, src_val);
+ }
+ /*! \brief finalize reduction result */
+ template<typename DType>
+ __device__ inline static void Finalize(volatile DType& dst) {}
+ /*! \brief finalize reduction result */
+ template<typename DType>
+ __device__ inline static void Finalize(volatile DType& dst, volatile DType&
none) {}
+ /*!
+ *\brief set the initial value during reduction
+ */
+ template<typename DType>
+ __device__ inline static void SetInitValue(DType &initv) {
+ initv = -2*DBL_MAX;
+ }
+ /*!
+ *\brief set the initial value during reduction
+ */
+ template<typename DType>
+ __device__ inline static void SetInitValue(DType &initv, DType &none) {
+ SetInitValue(initv);
+ }
+};
+
+/*! \brief minimum reducer */
+struct minimum {
+ /*! \brief do reduction into dst */
+ template<typename DType, typename DType2>
+ __device__ inline static void Reduce(volatile DType& dst, volatile DType2
src) {
+ if (!util::isnan(dst)) {
+ if (!(dst <= src)) dst = src;
+ }
+ }
+ /*! \brief do reduction into dst */
+ template<typename DType, typename DType2>
+ __device__ inline static void Reduce(volatile DType& dst, volatile DType2
src,
+ volatile DType& none) {
+ Reduce(dst, src);
+ }
+ /*! \brief combine the results of two reducers */
+ template<typename DType>
+ __device__ inline static void Merge(volatile DType& dst_val, volatile DType&
src_val) {
+ Reduce(dst_val, src_val);
+ }
+ /*! \brief combine the results of two reducers */
+ template<typename DType>
+ __device__ inline static void Merge(volatile DType& dst_val, volatile DType&
dst_residual,
+ volatile DType& src_val, volatile DType&
src_residual) {
+ Reduce(dst_val, src_val);
+ }
+ /*! \brief finalize reduction result */
+ template<typename DType>
+ __device__ inline static void Finalize(volatile DType& dst) {}
+ /*! \brief finalize reduction result */
+ template<typename DType>
+ __device__ inline static void Finalize(volatile DType& dst, volatile DType&
none) {}
+ /*!
+ *\brief set the initial value during reduction
+ */
+ template<typename DType>
+ __device__ inline static void SetInitValue(DType &initv) {
+ initv = 2*DBL_MAX;
+ }
+ /*!
+ *\brief set the initial value during reduction
+ */
+ template<typename DType>
+ __device__ inline static void SetInitValue(DType &initv, DType &none) {
+ SetInitValue(initv);
+ }
+};
+
+/*! \brief product reducer */
+struct product {
+ /*! \brief do reduction into dst */
+ template<typename DType, typename DType2>
+ __device__ inline static void Reduce(volatile DType& dst, volatile DType2
src) {
+ dst = op::mul(dst, src);
+ }
+ /*! \brief do reduction into dst */
+ template<typename DType, typename DType2>
+ __device__ inline static void Reduce(volatile DType& dst, volatile DType2
src,
+ volatile DType& none) {
+ Reduce(dst, src);
+ }
+ /*! \brief combine the results of two reducers */
+ template<typename DType>
+ __device__ inline static void Merge(volatile DType& dst_val, volatile DType&
src_val) {
+ Reduce(dst_val, src_val);
+ }
+ /*! \brief combine the results of two reducers */
+ template<typename DType>
+ __device__ inline static void Merge(volatile DType& dst_val, volatile DType&
dst_residual,
+ volatile DType& src_val, volatile DType&
src_residual) {
+ Reduce(dst_val, src_val);
+ }
+ /*! \brief finalize reduction result */
+ template<typename DType>
+ __device__ inline static void Finalize(volatile DType& dst) {}
+ /*! \brief finalize reduction result */
+ template<typename DType>
+ __device__ inline static void Finalize(volatile DType& dst, volatile DType&
none) {}
+ /*!
+ *\brief set the initial value during reduction
+ */
+ template<typename DType>
+ __device__ inline static void SetInitValue(DType &initv) {
+ initv = 1;
+ }
+ /*!
+ *\brief set the initial value during reduction
+ */
+ template<typename DType>
+ __device__ inline static void SetInitValue(DType &initv, DType &none) {
+ SetInitValue(initv);
+ }
+};
+
+/*! \brief sum reducer that ignores NaN values in the input */
+struct nansum {
+ /*! \brief do reduction into dst */
+ template<typename DType, typename DType2>
+ __device__ inline static void Reduce(volatile DType& dst, volatile DType2
src) {
+ if (util::isnan(src)) return;
+ dst = op::add(dst, src);
+ }
+ /*! \brief do reduction into dst */
+ template<typename DType>
+ __device__ inline static void Reduce(volatile DType& dst, volatile DType src,
+ volatile DType& residual) {
+ if (util::isnan(src)) return;
+ DType y = src - residual;
+ DType t = dst + y;
+ residual = (t - dst) - y;
+ dst = t;
+ }
+ /*! \brief combine the results of two reducers */
+ template<typename DType>
+ __device__ inline static void Merge(volatile DType& dst_val, volatile DType&
src_val) {
+ Reduce(dst_val, src_val);
+ }
+ /*! \brief combine the results of two reducers */
+ template<typename DType>
+ __device__ inline static void Merge(volatile DType& dst_val, volatile DType&
dst_residual,
+ volatile DType& src_val, volatile DType&
src_residual) {
+ DType t1 = dst_val + src_val;
+ DType e = t1 - src_val;
+ DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual +
src_residual;
+ dst_val = t1 + t2;
+ dst_residual = t2 - (dst_val - t1);
+ }
+ /*! \brief finalize reduction result */
+ template<typename DType>
+ __device__ inline static void Finalize(volatile DType& dst) {}
+ /*! \brief finalize reduction result */
+ template<typename DType>
+ __device__ inline static void Finalize(volatile DType& dst, volatile DType&
none) {}
+ /*!
+ *\brief set the initial value during reduction
+ */
+ template<typename DType>
+ __device__ inline static void SetInitValue(DType & initv) {
+ initv = 0;
+ }
+ /*!
+ *\brief set the initial value during reduction
+ */
+ template<typename DType>
+ __device__ inline static void SetInitValue(DType &initv, DType &residual) {
+ SetInitValue(initv);
+ residual = 0;
+ }
+};
+
+/*! \brief product reducer that ignores NaN values in the input */
+struct nanprod {
+ /*! \brief do reduction into dst */
+ template<typename DType, typename DType2>
+ __device__ inline static void Reduce(volatile DType& dst, volatile DType2
src) {
+ if (util::isnan(src)) return;
+ dst = op::mul(dst, src);
+ }
+ /*! \brief do reduction into dst */
+ template<typename DType>
+ __device__ inline static void Reduce(volatile DType& dst, volatile DType src,
+ volatile DType& none) {
+ Reduce(dst, src);
+ }
+ /*! \brief combine the results of two reducers */
+ template<typename DType>
+ __device__ inline static void Merge(volatile DType& dst_val, volatile DType&
src_val) {
+ Reduce(dst_val, src_val);
+ }
+ /*! \brief combine the results of two reducers */
+ template<typename DType>
+ __device__ inline static void Merge(volatile DType& dst_val, volatile DType&
dst_residual,
+ volatile DType& src_val, volatile DType&
src_residual) {
+ Reduce(dst_val, src_val);
+ }
+ /*! \brief finalize reduction */
+ template<typename DType>
+ __device__ inline static void Finalize(volatile DType& dst) {}
+ /*! \brief finalize reduction */
+ template<typename DType>
+ __device__ inline static void Finalize(volatile DType& dst, volatile DType&
none) {}
+ /*!
+ *\brief set the initial value during reduction
+ */
+ template<typename DType>
+ __device__ inline static void SetInitValue(DType & initv) {
+ initv = 1;
+ }
+ /*!
+ *\brief set the initial value during reduction
+ */
+ template<typename DType>
+ __device__ inline static void SetInitValue(DType &initv, DType &none) {
+ SetInitValue(initv);
+ }
+};
+
+struct nrm2 {
+ /*! \brief do reduction into dst */
+ template<typename AType, typename DType>
+ __device__ inline static void Reduce(volatile AType& sum_of_squares,
volatile DType src) {
+ sum_of_squares = op::add(sum_of_squares, src * src);
+ }
+ /*! \brief do stable reduction into dst */
+ template<typename AType, typename DType>
+ __device__ inline static void Reduce(volatile AType& sum_of_squares,
+ volatile DType src, volatile DType&
scale) {
+ if (src != 0) {
+ DType abs = op::abs(src);
+ if (scale < abs) {
+ sum_of_squares = 1 + sum_of_squares * (scale / abs) * (scale / abs);
+ scale = abs;
+ } else {
+ sum_of_squares = sum_of_squares + (abs / scale) * (abs / scale);
+ }
+ }
+ }
+ /*! \brief combine the results of two reducers */
+ template<typename DType>
+ __device__ inline static void Merge(volatile DType& dst_val, volatile DType&
src_val) {
+ dst_val = op::add(dst_val, src_val);
+ }
+ /*! \brief combine the results of two reducers */
+ template<typename DType>
+ __device__ inline static void Merge(volatile DType& dst_ssq, volatile DType&
dst_scale,
+ volatile DType& src_ssq, volatile DType&
src_scale) {
+ if (dst_scale != 0 && dst_scale >= src_scale) {
+ dst_ssq = dst_ssq + src_ssq * (src_scale / dst_scale) * (src_scale /
dst_scale);
+ } else if (src_scale != 0 && dst_scale < src_scale) {
+ dst_ssq = src_ssq + dst_ssq * (dst_scale / src_scale) * (dst_scale /
src_scale);
+ dst_scale = src_scale;
+ }
+ }
+ /*! \brief finalize reduction result */
+ template<typename DType>
+ __device__ inline static void Finalize(volatile DType& sum_of_squares) {
+ sum_of_squares = op::sqrt(sum_of_squares);
+ }
+ /*! \brief finalize reduction result */
+ template<typename DType>
+ __device__ inline static void Finalize(volatile DType& sum_of_squares,
volatile DType& scale) {
+ sum_of_squares = scale * op::sqrt(sum_of_squares);
+ }
+ /*!
+ *\brief set the initial value during reduction
+ */
+ template<typename DType>
+ __device__ inline static void SetInitValue(DType &sum_of_squares) {
+ sum_of_squares = 0;
+ }
+ /*!
+ *\brief set the initial value during reduction
+ */
+ template<typename DType>
+ __device__ inline static void SetInitValue(DType &sum_of_squares, DType
&scale) {
+ SetInitValue(sum_of_squares);
+ scale = 0;
+ }
+};
+
+struct nrmlp {
+ double lp;
+ /* \brief power for Lp norm */
+ __device__ inline static double lp_power(volatile double src, volatile
double p) {
+ if (p != 0.0) {
+ if (src == 0.0) {
+ return src;
+ } else {
+ return op::power(src, p);
+ }
+ } else { // 0-norm, sparsity
+ return static_cast<double>(src != 0);
+ }
+ }
+
+ /*! \brief do reduction into dst */
+ template<typename AType, typename DType>
+ __device__ inline void Reduce(volatile AType& sum_of_powers, volatile DType
src) {
+ if (src != 0) {
+ sum_of_powers += AType(lp_power(static_cast<double>(src), lp));
+ }
+ }
+
+ /*! \brief do stable reduction into dst */
+ template<typename AType, typename DType>
+ __device__ inline void Reduce(volatile AType& sum_of_powers, volatile DType
src,
+ volatile DType& scale) {
+ if (src != 0) {
+ DType src_abs = op::abs(src);
+ if (scale < src_abs) {
+ sum_of_powers = sum_of_powers *
AType(lp_power(static_cast<double>(scale / src_abs), lp));
+ sum_of_powers = sum_of_powers + 1;
+ scale = src_abs;
+ } else {
+ sum_of_powers = sum_of_powers +
AType(lp_power(static_cast<double>(src_abs / scale), lp));
+ }
+ }
+ }
+
+ /*! \brief combine the results of two reducers */
+ template<typename DType>
+ __device__ inline static void Merge(volatile DType& dst_val, volatile DType&
src_val) {
+ dst_val = dst_val + src_val;
+ }
+
+ /*! \brief combine the results of two reducers */
+ template<typename DType>
+ __device__ inline static void Merge(volatile DType& dst_ssq, volatile DType&
dst_scale,
+ volatile DType& src_ssq, volatile DType&
src_scale) {
+ if (dst_scale != 0 && dst_scale >= src_scale) {
+ dst_ssq = dst_ssq + src_ssq *
DType(lp_power(static_cast<double>(src_scale / dst_scale), 2));
+ } else if (src_scale != 0 && dst_scale < src_scale) {
+ dst_ssq = src_ssq + dst_ssq *
DType(lp_power(static_cast<double>(dst_scale / src_scale), 2));
+ dst_scale = src_scale;
+ }
+ }
+
+ /*! \brief finalize reduction result */
+ template<typename DType>
+ __device__ inline void Finalize(volatile DType& sum_of_powers) {
+ if (lp != 0.0) {
+ sum_of_powers = DType(lp_power(static_cast<double>(sum_of_powers), 1.0 /
lp));
+ }
+ }
+
+ /*! \brief finalize reduction result */
+ template<typename DType>
+ __device__ inline void Finalize(volatile DType& sum_of_powers, volatile
DType& scale) {
+ if (lp != 0.0) {
+ sum_of_powers = scale *
DType(lp_power(static_cast<double>(sum_of_powers), 1.0 / lp));
+ }
+ }
+
+ /*!
+ *\brief set the initial value during reduction
+ */
+ template<typename DType>
+ __device__ inline static void SetInitValue(DType &sum_of_powers) {
+ sum_of_powers = 0;
+ }
+
+ /*!
+ *\brief set the initial value during reduction
+ */
+ template<typename DType>
+ __device__ inline static void SetInitValue(DType &sum_of_powers, DType
&scale) {
+ SetInitValue(sum_of_powers);
+ scale = 0;
+ }
+};
} // namespace red
)code";
diff --git a/src/common/cuda/rtc/util-inl.h b/src/common/cuda/rtc/util-inl.h
index 372390f..b426603 100644
--- a/src/common/cuda/rtc/util-inl.h
+++ b/src/common/cuda/rtc/util-inl.h
@@ -174,74 +174,97 @@ struct enable_if<true> {
};
template <typename T, typename U, class Enable = void>
-struct mixed_type;
+struct mixed_type_helper;
template <typename T>
-struct mixed_type<T, float64, typename enable_if<!is_same<float64,
T>::value>::type> {
+struct mixed_type_helper<T, float64, typename enable_if<!is_same<float64,
T>::value>::type> {
using type = float64;
};
template <typename T>
-struct mixed_type<float64, T> {
+struct mixed_type_helper<float64, T> {
using type = float64;
};
template <typename T>
-struct mixed_type<T, float32, typename enable_if<!is_same<float64, T>::value &&
- !is_same<float32,
T>::value>::type> {
+struct mixed_type_helper<T, float32, typename enable_if<!is_same<float64,
T>::value &&
+ !is_same<float32,
T>::value>::type> {
using type = float32;
};
template <typename T>
-struct mixed_type<float32, T, typename enable_if<!is_same<float64,
T>::value>::type> {
+struct mixed_type_helper<float32, T, typename enable_if<!is_same<float64,
T>::value>::type> {
using type = float32;
};
template <typename T>
-struct mixed_type<T, float16, typename enable_if<is_same<float16, T>::value ||
- is_integral<T>::value>::type>
{
+struct mixed_type_helper<T, float16, typename enable_if<is_same<float16,
T>::value ||
+
is_integral<T>::value>::type> {
using type = float16;
};
template <typename T>
-struct mixed_type<float16, T, typename enable_if<is_integral<T>::value>::type>
{
+struct mixed_type_helper<float16, T, typename
enable_if<is_integral<T>::value>::type> {
using type = float16;
};
template <typename T, typename U>
-struct mixed_type<T, U, typename enable_if<is_integral<T>::value &&
- is_integral<U>::value &&
- !is_same<U, bool_t>::value &&
- sizeof(T) <= sizeof(U)>::type> {
+struct mixed_type_helper<T, U, typename enable_if<is_integral<T>::value &&
+ is_integral<U>::value &&
+ !is_same<U, bool_t>::value &&
+ sizeof(T) <=
sizeof(U)>::type> {
using type = U;
};
template <typename T, typename U>
-struct mixed_type<U, T, typename enable_if<is_integral<T>::value &&
- is_integral<U>::value &&
- !is_same<U, bool_t>::value &&
- sizeof(T) < sizeof(U)>::type> {
+struct mixed_type_helper<U, T, typename enable_if<is_integral<T>::value &&
+ is_integral<U>::value &&
+ !is_same<U, bool_t>::value &&
+ sizeof(T) <
sizeof(U)>::type> {
using type = U;
};
template <typename T>
-struct mixed_type<T, bool_t, typename enable_if<is_integral<T>::value &&
- sizeof(T) <
sizeof(bool_t)>::type> {
+struct mixed_type_helper<T, bool_t, typename enable_if<is_integral<T>::value &&
+ sizeof(T) <
sizeof(bool_t)>::type> {
using type = index_t;
};
template <typename T>
-struct mixed_type<bool_t, T, typename enable_if<is_integral<T>::value &&
- sizeof(T) <
sizeof(bool_t)>::type> {
+struct mixed_type_helper<bool_t, T, typename enable_if<is_integral<T>::value &&
+ sizeof(T) <
sizeof(bool_t)>::type> {
using type = index_t;
};
template <typename T>
-struct mixed_type<T, bool_t, typename enable_if<is_integral<T>::value &&
- sizeof(T) ==
sizeof(bool_t)>::type> {
+struct mixed_type_helper<T, bool_t, typename enable_if<is_integral<T>::value &&
+ sizeof(T) ==
sizeof(bool_t)>::type> {
using type = T;
};
+template <typename... Ts>
+struct multi_mixed_type_helper;
+
+template <>
+struct multi_mixed_type_helper<> {
+ using type = void;
+};
+
+template <typename T>
+struct multi_mixed_type_helper<T> {
+ using type = T;
+};
+
+template <typename T, typename U, typename... Ts>
+struct multi_mixed_type_helper<T, U, Ts...> {
+ using type = typename mixed_type_helper<T,
+ typename multi_mixed_type_helper<U,
+
Ts...>::type>::type;
+};
+
+template <typename... Ts>
+using mixed_type = typename multi_mixed_type_helper<Ts...>::type;
+
} // namespace type_util
)code";
@@ -254,6 +277,7 @@ enum class OpReqType {
};
constexpr int kRTCMaxThreadsPerBlock = 512;
+constexpr int warp_size = 32;
namespace util {
@@ -377,6 +401,49 @@ __device__ inline bool isnan(volatile const float16 &val) {
return ::isnan(__half2float(const_cast<const float16&>(val)));
}
+template <int NVALUES = warp_size, typename OP, typename T>
+__device__ inline T warp_reduce(T value, OP redfun) {
+#pragma unroll
+ for (int i = warp_size / 2; i >= 1; i /= 2) {
+ if (NVALUES > i) value = redfun(value, __shfl_down_sync(0xffffffff, value,
i));
+ }
+ return value;
+}
+
+template <typename OP, typename T>
+__device__ inline T grouped_warp_reduce(T value, OP redfun, const int
group_size) {
+ for (int i = 1; i < group_size; i *= 2) {
+ value = redfun(value, __shfl_down_sync(0xffffffff, value, i));
+ }
+ return value;
+}
+
+template <typename OP, typename T>
+__device__ inline T grouped_warp_allreduce(T value, OP redfun, const int
group_size) {
+ value = grouped_warp_reduce(value, redfun, group_size);
+ return __shfl_sync(0xffffffff, value, 0, group_size);
+}
+
+template <typename OP, typename T>
+__device__ inline T strided_grouped_warp_reduce(T value, OP redfun, const int
group_size) {
+ for (int i = warp_size / 2; i >= group_size; i /= 2) {
+ value = redfun(value, __shfl_down_sync(0xffffffff, value, i));
+ }
+ return value;
+}
+
+template <typename OP, typename T>
+__device__ inline T strided_grouped_warp_allreduce(T value, OP redfun, const
int group_size) {
+ value = strided_grouped_warp_reduce(value, redfun, group_size);
+ for (int i = group_size; i < warp_size; i *= 2) {
+ T tmp = __shfl_up_sync(0xffffffff, value, i);
+ if (threadIdx.x % warp_size >= i) {
+ value = tmp;
+ }
+ }
+ return value;
+}
+
} // namespace util
)code";
} // namespace rtc
diff --git a/src/common/cuda/rtc/vectorization-inl.h
b/src/common/cuda/rtc/vectorization-inl.h
index 5cbc459..96205fc 100644
--- a/src/common/cuda/rtc/vectorization-inl.h
+++ b/src/common/cuda/rtc/vectorization-inl.h
@@ -41,6 +41,8 @@ const char vectorization_support_string[] = R"code(
namespace vector {
+constexpr int vectorized_kernel_thread_num = 512;
+
template <int size>
struct VectorType {
static_assert(size <= 32, "VectorType needs to have size of at most 32B");
@@ -166,7 +168,7 @@ class VectorizedAccessor {
if (aligned) {
alignment_ = 0;
aligned_ptr_ = reinterpret_cast<LType*>(ptr);
- n_elems_ = (size + nvec- 1) / nvec;
+ n_elems_ = (size + nvec - 1) / nvec;
} else {
size_t ptr_as_number = reinterpret_cast<size_t>(ptr);
alignment_ = (ptr_as_number % sizeof(LType)) / sizeof(DType);
@@ -360,6 +362,8 @@ constexpr int vectorized_kernel_thread_num = 512;
* \param lead_input_num number of input to use for checking alignment
* (in case only a subset of inputs is used vectorized).
* Default is 0.
+ * \param blocks if provided and not 0, will launch the specified number of
thread blocks.
+ * Default is 0.
*/
template <typename Params>
void VectorizedKernelRTCLauncher(const std::string ¶meters,
@@ -373,7 +377,8 @@ void VectorizedKernelRTCLauncher(const std::string
¶meters,
const std::vector<TBlob> &inputs,
const std::vector<TBlob> &outputs,
const int dev_id,
- const int lead_input_num = 0) {
+ const int lead_input_num = 0,
+ const index_t blocks = 0) {
const index_t N = lead_dim * other_dim;
nvec = std::min(nvec, 4); // Use at most 4-wide vectors
if (N != 0) {
@@ -435,11 +440,16 @@ void VectorizedKernelRTCLauncher(const std::string
¶meters,
lead_dim, nvec,
common::mshadow_type_info(
inputs[lead_input_num].type_flag_).size);
- size_t num_elements = other_dim * num_aligned_elements;
constexpr int threads = vectorized_kernel_thread_num;
- constexpr int max_blocks = 65535;
- index_t blocks = std::min(static_cast<int>((num_elements + threads - 1) /
threads),
- max_blocks);
+ index_t num_blocks;
+ if (blocks != 0) {
+ num_blocks = blocks;
+ } else {
+ size_t num_elements = other_dim * num_aligned_elements;
+ num_blocks = (num_elements + threads - 1) / threads;
+ constexpr int max_blocks = 65535;
+ num_blocks = std::min(static_cast<int>(num_blocks), max_blocks);
+ }
std::vector<const void*> args = {¶ms, &lead_dim, &other_dim,
&N, &num_aligned_elements};
auto function = common::cuda::rtc::get_function(kernel_builder,
@@ -448,7 +458,7 @@ void VectorizedKernelRTCLauncher(const std::string
¶meters,
dev_id);
common::cuda::rtc::launch(function,
- {static_cast<unsigned int>(blocks), 1, 1},
+ {static_cast<unsigned int>(num_blocks), 1, 1},
{static_cast<unsigned int>(threads), 1, 1},
0, s, &args);
}
diff --git a/src/common/cuda/utils.cc b/src/common/cuda/utils.cc
index b87c393..7aa936d 100644
--- a/src/common/cuda/utils.cc
+++ b/src/common/cuda/utils.cc
@@ -29,6 +29,7 @@
#include <algorithm>
#include "utils.h"
+#include "../utils.h"
#if MXNET_USE_CUDA
@@ -36,25 +37,6 @@ namespace mxnet {
namespace common {
namespace cuda {
-namespace {
- bool IsPower2(size_t N) {
- return ((N & (N - 1)) == 0) && N != 0;
- }
-
- size_t RoundToPower2(size_t N) {
- size_t ret = 1;
- size_t copyN = N;
- while (N >= 2) {
- ret *= 2;
- N /= 2;
- }
- if (ret < copyN) {
- ret *= 2;
- }
- return ret;
- }
-} // namespace
-
int get_load_type(size_t N) {
using namespace mshadow;
if (N % 8 == 0) {
diff --git a/src/common/cuda/utils.h b/src/common/cuda/utils.h
index fc4d40c..a203ba5 100644
--- a/src/common/cuda/utils.h
+++ b/src/common/cuda/utils.h
@@ -811,6 +811,14 @@ __device__ inline T warp_reduce(T value, OP redfun) {
return value;
}
+template <typename OP, typename T>
+__device__ inline T grouped_warp_allreduce(T value, OP redfun, const int
group_size) {
+ for (int i = 1; i < group_size; i *= 2) {
+ value = redfun(value, __shfl_down_sync(0xffffffff, value, i));
+ }
+ return __shfl_sync(0xffffffff, value, 0, group_size);
+}
+
template <int NValues = warp_size, typename OP>
__device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t
value, OP redfun) {
float v = static_cast<float>(value);
diff --git a/src/common/utils.h b/src/common/utils.h
index dfd32ac..40376e9 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -977,6 +977,27 @@ inline void AlignedMemFree(void* ptr) {
}
+inline index_t div_round(const index_t a, const index_t b) {
+ return (a + b - 1) / b;
+}
+
+inline bool IsPower2(size_t N) {
+ return ((N & (N - 1)) == 0) && N != 0;
+}
+
+inline size_t RoundToPower2(size_t N) {
+ size_t ret = 1;
+ size_t copyN = N;
+ while (N >= 2) {
+ ret *= 2;
+ N /= 2;
+ }
+ if (ret < copyN) {
+ ret *= 2;
+ }
+ return ret;
+}
+
} // namespace common
} // namespace mxnet
#endif // MXNET_COMMON_UTILS_H_
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index 05c8b9a..2251ff8 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -31,6 +31,7 @@
#include <mxnet/engine.h>
#include <mxnet/op_attr_types.h>
#include <algorithm>
+#include <limits>
#include "./operator_tune.h"
#include "../engine/openmp.h"
@@ -367,40 +368,30 @@ struct AccType<mshadow::half::half_t> {
break; \
case mshadow::kUint8: \
{ \
- typedef uint8_t DType; \
- typedef uint8_t AType; \
LOG(FATAL) << "This operation only support " \
"floating point types not uint8"; \
} \
break; \
case mshadow::kInt8: \
{ \
- typedef int8_t DType; \
- typedef int8_t AType; \
LOG(FATAL) << "This operation only support " \
"floating point types not int8"; \
} \
break; \
case mshadow::kInt32: \
{ \
- typedef int32_t DType; \
- typedef int32_t AType; \
LOG(FATAL) << "This operation only support " \
"floating point types, not int32"; \
} \
break; \
case mshadow::kInt64: \
{ \
- typedef int64_t DType; \
- typedef int64_t AType; \
LOG(FATAL) << "This operation only support " \
"floating point types, not int64"; \
} \
break; \
case mshadow::kBool: \
{ \
- typedef bool DType; \
- typedef int64_t AType; \
LOG(FATAL) << "This operation only support " \
"floating point types, not bool"; \
} \
@@ -475,21 +466,18 @@ struct AccType<mshadow::half::half_t> {
switch (type) { \
case mshadow::kFloat32: \
{ \
- typedef float DType; \
LOG(FATAL) << "This operation only support " \
"integer types, not float32"; \
} \
break; \
case mshadow::kFloat64: \
{ \
- typedef double DType; \
LOG(FATAL) << "This operation only support " \
"integer types, not float64"; \
} \
break; \
case mshadow::kFloat16: \
{ \
- typedef mshadow::half::half_t DType; \
LOG(FATAL) << "This operation only support " \
"integer types, not float16"; \
} \
@@ -532,21 +520,18 @@ struct AccType<mshadow::half::half_t> {
switch (type) { \
case mshadow::kFloat32: \
{ \
- typedef float DType; \
LOG(FATAL) << "This operation only support " \
"integer types, not float32"; \
} \
break; \
case mshadow::kFloat64: \
{ \
- typedef double DType; \
LOG(FATAL) << "This operation only support " \
"integer types, not float64"; \
} \
break; \
case mshadow::kFloat16: \
{ \
- typedef mshadow::half::half_t DType; \
LOG(FATAL) << "This operation only support " \
"integer types, not float16"; \
} \
diff --git a/src/operator/nn/log_softmax.cu b/src/operator/nn/log_softmax.cu
index 396a4e8..485290d 100644
--- a/src/operator/nn/log_softmax.cu
+++ b/src/operator/nn/log_softmax.cu
@@ -29,11 +29,10 @@ namespace mxnet {
namespace op {
NNVM_REGISTER_OP(log_softmax)
-.set_attr<FCompute>("FCompute<gpu>", SoftmaxCompute<gpu,
mxnet_op::log_softmax_fwd>);
+.set_attr<FCompute>("FCompute<gpu>", SoftmaxRTCCompute{"log_softmax_fwd"});
NNVM_REGISTER_OP(_backward_log_softmax)
-.set_attr<FCompute>("FCompute<gpu>", SoftmaxGradCompute<gpu, mshadow_op::left,
-
mxnet_op::log_softmax_bwd>);
+.set_attr<FCompute>("FCompute<gpu>", SoftmaxRTCGradCompute{"op::left",
"log_softmax_bwd"});
NNVM_REGISTER_OP(masked_log_softmax)
.set_attr<FCompute>("FCompute<gpu>", MaskedSoftmaxCompute<gpu,
mxnet_op::log_softmax_fwd,
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index 3f037f9..59af426 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -34,7 +34,6 @@
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "../tensor/broadcast_reduce_op.h"
-#include "../../common/cuda/utils.h"
using mshadow::red::limits::MinValue;
@@ -325,143 +324,8 @@ inline void MaskedSoftmaxGrad(Stream<cpu> *s, DType *out,
DType *ograd,
}
#ifdef __CUDACC__
-template<int x_bits, typename OP, bool negate, typename AType, int ndim,
- typename DType, typename OType, typename IType>
-__global__ void softmax_compute_kernel(DType *in, OType *out, IType *length,
- index_t M, int axis, Shape<ndim> sshape,
- Shape<ndim> stride, const double
temperature) {
- const unsigned x_size = 1 << x_bits;
- __shared__ AType smem[x_size];
- index_t sa = stride[axis];
- index_t base = unravel_dot(blockIdx.x, sshape, stride);
- index_t x = threadIdx.x;
- const index_t len = length == nullptr ? M :
static_cast<index_t>(length[blockIdx.x]);
-
- red::maximum::SetInitValue(smem[x]);
- for (index_t i = x; i < len; i += x_size) {
- smem[x] = ::max(smem[x], negate ? -in[base + i*sa] : in[base + i*sa]);
- }
- __syncthreads();
- cuda::Reduce1D<red::maximum, x_bits>(smem);
- __syncthreads();
- DType smax = smem[0];
- __syncthreads();
-
- red::sum::SetInitValue(smem[x]);
- DType val;
- for (index_t i = x; i < len; i += x_size) {
- val = negate ? -in[base + i*sa]:in[base + i*sa];
- smem[x] += static_cast<AType>(expf((val - smax) /
static_cast<AType>(temperature)));
- }
- __syncthreads();
- cuda::Reduce1D<red::sum, x_bits>(smem);
- __syncthreads();
- AType ssum = smem[0];
- __syncthreads();
-
- for (index_t i = x; i < M; i += x_size) {
- val = negate ? -in[base + i*sa] : in[base + i*sa];
- out[base + i*sa] =
- (i < len) ? OType(OP::Map((val - smax)/static_cast<DType>(temperature),
ssum)) : OType(0.0f);
- }
-}
-
const int softmax_threads_per_block = 512;
-template<typename OP, bool negate, typename AType, typename LType,
- typename DType, typename OType, typename IType>
-__global__ void softmax_stride1_compute_kernel(const DType *in, OType *out,
IType *length,
- const index_t M, const double
temperature,
- const int rows_per_block, const
index_t total_rows) {
- __shared__ AType scratch[softmax_threads_per_block];
- __shared__ LType persistent_storage[20 * 1024 / sizeof(LType)];
- const int warp_size = 32;
- const int threads_per_row = softmax_threads_per_block / rows_per_block;
- const int my_local_row = threadIdx.x / threads_per_row;
- const int my_row = blockIdx.x * rows_per_block + my_local_row;
- if (my_row >= total_rows) return;
- const int my_id = threadIdx.x % threads_per_row;
- const int entries_per_load = sizeof(LType)/sizeof(DType);
- const index_t len = length == nullptr ? M :
static_cast<index_t>(length[my_row]);
- // Due to usage of MSHADOW_TYPE_SWITCH macro we are generating
- // kernels where sizeof(LType) may be less than sizeof(DType),
- // resulting in entries_per_load being 0.
- // This is not a valid combination and is being checked against
- // in the launcher code. This switch here is just to silence
- // the division by zero warning generated for such invalid cases.
- const int row_length = entries_per_load > 0 ? M / entries_per_load : 0;
-
- const LType* in_aligned = reinterpret_cast<const LType*>(in);
- size_t base = my_row * row_length;
-
- for (index_t i = my_id; i < row_length; i += threads_per_row) {
- persistent_storage[my_local_row * row_length + i] = in_aligned[base + i];
- }
- DType * row = reinterpret_cast<DType *>(persistent_storage + my_local_row *
row_length);
- __syncthreads();
-
- DType my_max_value;
- red::maximum::SetInitValue(my_max_value);
-
- for (index_t i = my_id; i < len; i += threads_per_row) {
- my_max_value = ::max(my_max_value, negate ? -row[i] : row[i]);
- }
- scratch[threadIdx.x] = my_max_value;
- __syncthreads();
- for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
- if (my_id < size) {
- scratch[threadIdx.x] = ::max(scratch[threadIdx.x], scratch[threadIdx.x +
size]);
- }
- __syncthreads();
- }
- if (my_id < warp_size) {
- AType my_value = common::cuda::warp_reduce(scratch[threadIdx.x],
- [](AType x, AType y) { return ::max(x, y); });
- scratch[threadIdx.x] = my_value;
- }
- __syncthreads();
- DType smax = scratch[threadIdx.x - threadIdx.x % threads_per_row];
- __syncthreads();
-
- AType my_sum;
- red::sum::SetInitValue(my_sum);
-
- for (index_t i = my_id; i < len; i += threads_per_row) {
- const DType val = negate ? -row[i] : row[i];
- my_sum += static_cast<AType>(expf((val - smax) /
static_cast<AType>(temperature)));
- }
- scratch[threadIdx.x] = my_sum;
- __syncthreads();
- for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
- if (my_id < size) {
- scratch[threadIdx.x] += scratch[threadIdx.x + size];
- }
- __syncthreads();
- }
- if (my_id < warp_size) {
- AType my_value = common::cuda::warp_reduce(scratch[threadIdx.x],
- [](AType x, AType y) { return x + y;});
- scratch[threadIdx.x] = my_value;
- }
- __syncthreads();
-
- AType ssum = scratch[threadIdx.x - threadIdx.x % threads_per_row];
- __syncthreads();
-
- for (index_t i = my_id; i < M; i += threads_per_row) {
- const DType val = negate ? -row[i] : row[i];
- row[i] = (i < len) ? DType(OP::Map((val -
smax)/static_cast<DType>(temperature), ssum)) :
- DType(0.0f);
- }
- __syncthreads();
-
- LType* out_aligned = reinterpret_cast<LType*>(out);
-
- for (index_t i = my_id; i < row_length; i += threads_per_row) {
- out_aligned[base + i] = persistent_storage[my_local_row * row_length + i];
- }
-}
-
template<int ndim>
MSHADOW_XINLINE index_t get_mask_position(const index_t idx, const
Shape<ndim>& data_shape,
const Shape<ndim>& mask_shape, int axis, index_t* stride_axis) {
@@ -665,45 +529,6 @@ __global__ void masked_softmax_stride1_kernel(const DType
*in, DType *out, bool
}
}
-template<typename OP, bool negate, typename AType, typename DType, typename
OType,
- typename IType, int ndim>
-inline void Softmax(Stream<gpu> *s, DType *in, OType *out, IType *length,
- Shape<ndim> shape, int axis, const double temperature) {
- const int x_bits = 7;
- const int x_size = 1 << x_bits;
- index_t M = shape[axis];
- if (M == 0 || shape.Size() == 0) return;
- index_t N = shape.Size()/M;
- Shape<ndim> stride = calc_stride(shape);
- Shape<ndim> sshape = shape;
- sshape[axis] = 1;
-
- const size_t DSize = sizeof(DType);
- // Using 20 kB of shared memory for persistent storage in the optimized case
- const size_t max_opt_M = 20 * 1024 / DSize;
- if (stride[axis] == 1 &&
- static_cast<size_t>(M) <= max_opt_M &&
- std::is_same<DType, OType>::value) {
- int ltype = mxnet::common::cuda::get_load_type(M * sizeof(DType));
- MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
- int rows_per_block = mxnet::common::cuda::get_rows_per_block(M *
-
sizeof(DType) / sizeof(LType),
-
softmax_threads_per_block);
- int nblocks = (N + rows_per_block - 1) / rows_per_block;
- CHECK_LE(sizeof(DType), sizeof(LType));
- softmax_stride1_compute_kernel<OP, negate, AType, LType>
- <<<nblocks, softmax_threads_per_block, 0,
mshadow::Stream<gpu>::GetStream(s)>>>(
- in, out, length, M, temperature, rows_per_block, N);
- });
- MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_stride1_compute_kernel);
- } else {
- softmax_compute_kernel<x_bits, OP, negate, AType, ndim>
- <<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
- in, out, length, M, axis, sshape, stride, temperature);
- MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_compute_kernel);
- }
-}
-
template<typename OP, bool masked_neg_inf, bool negate,
typename AType, typename DType, typename OType, int ndim>
inline void MaskedSoftmax(Stream<gpu> *s, DType *in, OType *out, bool *mask,
@@ -785,120 +610,6 @@ inline void MaskedSoftmax(Stream<gpu> *s, DType *in,
OType *out, bool *mask,
}
template<typename OP1, typename OP2, int Req, bool negate, typename AType,
typename LType,
- typename DType, typename OType, typename IType>
-__global__ void softmax_stride1_grad_kernel(const OType *out, const OType
*ograd,
- DType *igrad, const IType *length,
- const index_t M,
- const double temperature,
- const int rows_per_block,
- const index_t total_rows) {
- __shared__ AType scratch[softmax_threads_per_block];
- __shared__ LType persistent_storage[20 * 1024 / sizeof(LType)];
- const int warp_size = 32;
- const int threads_per_row = softmax_threads_per_block / rows_per_block;
- const int my_local_row = threadIdx.x / threads_per_row;
- const int my_row = blockIdx.x * rows_per_block + my_local_row;
- if (my_row >= total_rows) return;
- const int my_id = threadIdx.x % threads_per_row;
- const int entries_per_load = sizeof(LType)/sizeof(DType);
- const index_t len = length == nullptr ? M :
static_cast<index_t>(length[my_row]);
- // Due to usage of MSHADOW_TYPE_SWITCH macro we are generating
- // kernels where sizeof(LType) may be less than sizeof(DType),
- // resulting in entries_per_load being 0.
- // This is not a valid combination and is being checked against
- // in the launcher code. This switch here is just to silence
- // the division by zero warning generated for such invalid cases.
- const int row_length = entries_per_load > 0 ? M / entries_per_load : 0;
-
- const LType* out_aligned = reinterpret_cast<const LType*>(out);
- const LType* ograd_aligned = reinterpret_cast<const LType*>(ograd);
- size_t base = my_row * row_length;
-
- for (index_t i = my_id; i < row_length; i += threads_per_row) {
- persistent_storage[my_local_row * row_length * 2 + i] = out_aligned[base +
i];
- persistent_storage[my_local_row * row_length * 2 + row_length + i] =
ograd_aligned[base + i];
- }
- DType * row = reinterpret_cast<DType *>(persistent_storage + my_local_row *
row_length * 2);
- __syncthreads();
-
- AType my_sum_value;
- red::sum::SetInitValue(my_sum_value);
-
- for (index_t i = my_id; i < len; i += threads_per_row) {
- my_sum_value += OP1::Map(row[i + M], row[i]);
- }
- scratch[threadIdx.x] = my_sum_value;
- __syncthreads();
- for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
- if (my_id < size) {
- scratch[threadIdx.x] = scratch[threadIdx.x] + scratch[threadIdx.x +
size];
- }
- __syncthreads();
- }
- if (my_id < warp_size) {
- AType my_value = common::cuda::warp_reduce(scratch[threadIdx.x],
- [](AType x, AType y) { return x + y; });
- scratch[threadIdx.x] = my_value;
- }
- __syncthreads();
- AType ssum = scratch[threadIdx.x - threadIdx.x % threads_per_row];
- __syncthreads();
-
- for (index_t i = my_id; i < M; i += threads_per_row) {
- const DType val =
- negate ?
- -OP2::Map(row[i + M], row[i], ssum) :
- OP2::Map(row[i + M], row[i], ssum);
- row[i] = (i < len) ? DType(val / static_cast<DType>(temperature)) :
- DType(0.0f);
- if (Req == kAddTo) {
- row[i] += igrad[my_row * M + i];
- }
- }
- __syncthreads();
-
- LType* igrad_aligned = reinterpret_cast<LType*>(igrad);
-
- for (index_t i = my_id; i < row_length; i += threads_per_row) {
- igrad_aligned[base + i] = persistent_storage[my_local_row * row_length * 2
+ i];
- }
-}
-
-template<int x_bits, typename OP1, typename OP2, int Req, bool negate,
typename AType, int ndim,
- typename DType, typename OType, typename IType>
-__global__ void softmax_grad_kernel(OType *out, OType *ograd, DType *igrad,
- const IType *length, index_t M, int axis,
- Shape<ndim> sshape, Shape<ndim> stride,
- const double temperature) {
- const unsigned x_size = 1 << x_bits;
- __shared__ AType smem[x_size];
- index_t sa = stride[axis];
- index_t base = unravel_dot(blockIdx.x, sshape, stride);
- index_t x = threadIdx.x;
- index_t len = length != nullptr ? static_cast<index_t>(length[blockIdx.x]) :
M;
-
- red::sum::SetInitValue(smem[x]);
- for (index_t i = x; i < len; i += x_size) {
- smem[x] += OP1::Map(ograd[base + i*sa], out[base + i*sa]);
- }
- __syncthreads();
- cuda::Reduce1D<red::sum, x_bits>(smem);
- __syncthreads();
- AType ssum = smem[0];
- __syncthreads();
-
- DType final_result;
- for (index_t i = x; i < M; i += x_size) {
- final_result =
- negate ?
- -OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum) :
- OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum);
- final_result = (i < len) ? final_result : DType(0.0f);
- KERNEL_ASSIGN(igrad[base + i*sa], Req, final_result /
static_cast<DType>(temperature));
- }
-}
-
-template<typename OP1, typename OP2, int Req, bool negate, typename AType,
typename LType,
typename LTypeMask, typename DType, typename OType, int ndim>
__global__ void masked_softmax_stride1_grad_kernel(const OType *out, const
OType *ograd,
DType *igrad, const bool
*in_mask,
@@ -1042,48 +753,6 @@ __global__ void masked_softmax_grad_kernel(OType *out,
OType *ograd, DType *igra
}
template<typename OP1, typename OP2, int Req, bool negate, typename AType, int
ndim,
- typename DType, typename OType, typename IType>
-inline void SoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
- DType *igrad, IType *length, Shape<ndim> shape, int
axis,
- const double temperature) {
- const int x_bits = 7;
- const int x_size = 1 << x_bits;
- index_t M = shape[axis];
- if (M == 0 || shape.Size() == 0) return;
- index_t N = shape.Size()/M;
- Shape<ndim> stride = calc_stride(shape);
- Shape<ndim> sshape = shape;
- sshape[axis] = 1;
-
- const size_t DSize = sizeof(DType);
- // Using 20 kB of shared memory for persistent storage in the optimized case
- // Need to store both out and ograd, so M can be only half compared to
- // forward pass.
- const size_t max_opt_M = 20 * 1024 / DSize / 2;
- if (stride[axis] == 1 &&
- static_cast<size_t>(M) <= max_opt_M &&
- std::is_same<DType, OType>::value) {
- int ltype = mxnet::common::cuda::get_load_type(M * sizeof(DType));
- MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
- int rows_per_block = mxnet::common::cuda::get_rows_per_block(M *
-
sizeof(DType) / sizeof(LType),
-
softmax_threads_per_block);
- int nblocks = (N + rows_per_block - 1) / rows_per_block;
- CHECK_LE(sizeof(DType), sizeof(LType));
- softmax_stride1_grad_kernel<OP1, OP2, Req, negate, AType, LType>
- <<<nblocks, softmax_threads_per_block, 0,
mshadow::Stream<gpu>::GetStream(s)>>>(
- out, ograd, igrad, length, M, temperature, rows_per_block, N);
- });
- MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_stride1_grad_kernel);
- } else {
- softmax_grad_kernel<x_bits, OP1, OP2, Req, negate, AType, ndim>
- <<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
- out, ograd, igrad, length, M, axis, sshape, stride, temperature);
- MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_grad_kernel);
- }
-}
-
-template<typename OP1, typename OP2, int Req, bool negate, typename AType, int
ndim,
typename DType, typename OType>
inline void MaskedSoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
DType *igrad, bool *mask, Shape<ndim> data_shape,
@@ -1554,6 +1223,32 @@ void MaskedSoftmaxCompute(const nnvm::NodeAttrs& attrs,
});
}
+#if MXNET_USE_CUDA
+
+struct SoftmaxRTCCompute {
+ std::string OP;
+ bool negate = false;
+
+ void operator()(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs);
+};
+
+struct SoftmaxRTCGradCompute {
+ std::string OP1;
+ std::string OP2;
+ bool negate = false;
+
+ void operator()(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs);
+};
+
+#endif
template<typename xpu, typename OP1, typename OP2, bool negate = false>
void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
diff --git a/src/operator/nn/softmax.cu b/src/operator/nn/softmax.cu
index c75f543..c8a05b8 100644
--- a/src/operator/nn/softmax.cu
+++ b/src/operator/nn/softmax.cu
@@ -22,18 +22,791 @@
* \file softmax.cu
* \brief GPU Implementation of softmax
*/
+#include <string>
#include "./softmax-inl.h"
-#include "../tensor/elemwise_unary_op.h"
+#include "../../common/cuda/utils.h"
+#include "../../common/utils.h"
+#include "../../common/cuda/rtc.h"
+#include "../../common/cuda/rtc/vectorization-inl.h"
namespace mxnet {
namespace op {
+namespace {
+
+struct softmax_params {
+ const void* inputs[3];
+ void* outputs[1];
+ index_t stride;
+ index_t num_elements;
+ double temperature;
+ int rows_per_block;
+ index_t total_rows;
+};
+
+const char softmax_common_functions[] = R"code(
+struct softmax_params {
+ const void* inputs[3];
+ void* outputs[1];
+ index_t stride;
+ index_t num_elements;
+ double temperature;
+ int rows_per_block;
+ index_t total_rows;
+};
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+softmax_fwd(const DType a, const DType2 b) {
+ return op::exp(a) / b;
+}
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+log_softmax_fwd(const DType a, const DType2 b) {
+ return a - op::log(b);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+ return out * (ograd - sum);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+log_softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+ return ograd - op::exp(out) * sum;
+}
+
+)code";
+
+const char simple_softmax_kernel_fwd[] = R"code(
+__launch_bounds__(kRTCMaxThreadsPerBlock)
+__global__ void simple_softmax_kernel(const softmax_params param,
+ const index_t lead_dim) {
+ using LengthType = AccType<InputType1>;
+ const InputType0* input = reinterpret_cast<const
InputType0*>(param.inputs[0]);
+ const InputType1* length = reinterpret_cast<const
InputType1*>(param.inputs[1]);
+ const index_t len = length == nullptr
+ ? lead_dim
+ :
static_cast<index_t>(LengthType::from(length[blockIdx.x]));
+ const int my_row = threadIdx.x % param.rows_per_block;
+ const int my_id = threadIdx.x / param.rows_per_block;
+ const int threads_per_row = blockDim.x / param.rows_per_block;
+ const index_t base_x = (blockIdx.x * param.rows_per_block + my_row) %
param.stride;
+ const index_t base_n = (blockIdx.x * param.rows_per_block + my_row) /
param.stride;
+ const index_t base = base_x + param.stride * lead_dim * base_n;
+ if (base >= param.num_elements * param.total_rows) return;
+ using IType = AccType<InputType0>;
+ using OType = AccType<OutputType0>;
+ using AType = type_util::mixed_type<typename IType::type,
+ typename OType::type>;
+ __shared__ AType smem[kRTCMaxThreadsPerBlock];
+ AType max;
+ red::maximum::SetInitValue(max);
+ for (index_t i = my_id; i < len; i += threads_per_row) {
+ auto val = IType::from(input[base + i * param.stride]);
+ max = op::max(max, negate ? -val : val);
+ }
+ smem[threadIdx.x] = max;
+ __syncthreads();
+ for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+ if (threadIdx.x < size) {
+ smem[threadIdx.x] = op::max(smem[threadIdx.x], smem[threadIdx.x + size]);
+ }
+ __syncthreads();
+ }
+ if (threadIdx.x < warp_size) {
+ AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+ [](AType x, AType y)
+ { return op::max(x,
y); },
+ param.rows_per_block);
+ smem[threadIdx.x] = my_value;
+ }
+ __syncthreads();
+ AType smax = smem[my_row];
+ __syncthreads();
+
+ AType sum;
+ red::sum::SetInitValue(sum);
+ for (index_t i = my_id; i < len; i += threads_per_row) {
+ auto val = IType::from(input[base + i * param.stride]);
+ val = negate ? -val :val;
+ sum += op::exp((val - smax) / static_cast<AType>(param.temperature));
+ }
+ smem[threadIdx.x] = sum;
+ __syncthreads();
+ for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+ if (threadIdx.x < size) {
+ smem[threadIdx.x] = op::add(smem[threadIdx.x], smem[threadIdx.x + size]);
+ }
+ __syncthreads();
+ }
+ if (threadIdx.x < warp_size) {
+ AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+ [](AType x, AType y)
+ { return op::add(x,
y); },
+ param.rows_per_block);
+ smem[threadIdx.x] = my_value;
+ }
+ __syncthreads();
+ sum = smem[my_row];
+ __syncthreads();
+
+ OutputType0* output = reinterpret_cast<OutputType0*>(param.outputs[0]);
+ for (index_t i = my_id; i < lead_dim; i += threads_per_row) {
+ auto val = IType::from(input[base + i * param.stride]);
+ val = negate ? -val : val;
+ val = (i < len) ? OP((val - smax)/static_cast<AType>(param.temperature),
sum) : 0;
+ if (req == OpReqType::kAddTo) {
+ if (i < len) {
+ output[base + i * param.stride] = OType::to(val +
+ OType::from(output[base +
i * param.stride]));
+ }
+ } else {
+ output[base + i * param.stride] = OType::to(val);
+ }
+ }
+}
+)code";
+
+const char softmax_stride1_kernel_fwd[] = R"code(
+__launch_bounds__(vector::vectorized_kernel_thread_num)
+__global__ void softmax_stride1_compute_kernel(const softmax_params param,
+ const index_t total_length,
+ const index_t other_dim,
+ const index_t N,
+ const index_t
num_aligned_elements) {
+ using namespace vector;
+ using IType = AccType<InputType0>;
+ using OType = AccType<OutputType0>;
+ using LengthType = AccType<InputType1>;
+ const InputType1* length = reinterpret_cast<const
InputType1*>(param.inputs[1]);
+ using AType = type_util::mixed_type<typename IType::type,
+ typename OType::type>;
+ __shared__ AType scratch[vectorized_kernel_thread_num];
+ __shared__ AType persistent_storage[20 * 1024 / sizeof(AType)];
+ const int threads_per_row = vectorized_kernel_thread_num /
param.rows_per_block;
+ const int my_local_row = threadIdx.x / threads_per_row;
+ const int base_row = blockIdx.x * param.rows_per_block;
+ const int my_row = base_row + my_local_row;
+ const index_t len = (length == nullptr ||
+ my_row >= param.total_rows) ? param.num_elements
+ :
LengthType::from(length[my_row]);
+ const int my_id = threadIdx.x % threads_per_row;
+
+ AType* row;
+ if (only_full_blocks || blockIdx.x < gridDim.x - 1) {
+ // full rows_per_block rows to compute
+ VectorizedLoader<InputType0, nvec, aligned> loader(
+ reinterpret_cast<const InputType0*>(param.inputs[0]) + base_row *
param.num_elements,
+ total_length);
+ for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+ loader.load(i, total_length);
+#pragma unroll
+ for (int j = 0; j < nvec; ++j) {
+ persistent_storage[i*nvec + j] = IType::from(loader.separate()[j]);
+ }
+ }
+ row = persistent_storage + my_local_row * param.num_elements +
loader.alignment();
+ } else {
+ // less than rows_per_block rows to compute
+ const index_t real_length = min(total_length,
+ (param.total_rows - base_row) *
param.num_elements);
+ VectorizedLoader<InputType0, nvec, false> loader(
+ reinterpret_cast<const InputType0*>(param.inputs[0]) + base_row *
param.num_elements,
+ real_length);
+ for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+ loader.load(i, real_length);
+#pragma unroll
+ for (int j = 0; j < nvec; ++j) {
+ persistent_storage[i*nvec + j] = IType::from(loader.separate()[j]);
+ }
+ }
+ row = persistent_storage + my_local_row * param.num_elements +
loader.alignment();
+ }
+ __syncthreads();
+
+ AType my_max_value;
+ red::maximum::SetInitValue(my_max_value);
+
+ for (index_t i = my_id; i < len; i += threads_per_row) {
+ my_max_value = ::max(my_max_value, negate ? -row[i] : row[i]);
+ }
+ AType smax;
+ if (!reduction_inside_warp) {
+ scratch[threadIdx.x] = my_max_value;
+ __syncthreads();
+ for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+ if (my_id < size) {
+ scratch[threadIdx.x] = ::max(scratch[threadIdx.x], scratch[threadIdx.x
+ size]);
+ }
+ __syncthreads();
+ }
+ if (my_id < warp_size) {
+ AType my_value = util::grouped_warp_allreduce(scratch[threadIdx.x],
+ [](AType x, AType y) {
return op::max(x, y); },
+ min(threads_per_row,
warp_size));
+ scratch[threadIdx.x] = my_value;
+ }
+ __syncthreads();
+ smax = scratch[threadIdx.x - my_id];
+ __syncthreads();
+ } else {
+ smax = util::grouped_warp_allreduce(my_max_value,
+ [](AType x, AType y) { return
op::max(x, y); },
+ threads_per_row);
+ }
+
+ AType my_sum;
+ red::sum::SetInitValue(my_sum);
+
+ for (index_t i = my_id; i < len; i += threads_per_row) {
+ const AType val = negate ? -row[i] : row[i];
+ my_sum += op::exp((val - smax) / static_cast<AType>(param.temperature));
+ }
+ AType ssum;
+ if (!reduction_inside_warp) {
+ scratch[threadIdx.x] = my_sum;
+ __syncthreads();
+ for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+ if (my_id < size) {
+ scratch[threadIdx.x] += scratch[threadIdx.x + size];
+ }
+ __syncthreads();
+ }
+ if (my_id < warp_size) {
+ AType my_value = util::grouped_warp_allreduce(scratch[threadIdx.x],
+ [](AType x, AType y) {
return x + y;},
+ min(threads_per_row,
warp_size));
+ scratch[threadIdx.x] = my_value;
+ }
+ __syncthreads();
+
+ ssum = scratch[threadIdx.x - my_id];
+ __syncthreads();
+ } else {
+ ssum = util::grouped_warp_allreduce(my_sum,
+ [](AType x, AType y) { return x +
y;},
+ threads_per_row);
+ }
+
+ for (index_t i = my_id; i < param.num_elements; i += threads_per_row) {
+ const AType val = negate ? -row[i] : row[i];
+ row[i] = (i < len) ? OP((val -
smax)/static_cast<AType>(param.temperature), ssum) :
+ 0;
+ }
+ __syncthreads();
+
+ if (only_full_blocks || blockIdx.x < gridDim.x - 1) {
+ VectorizedStorer<OutputType0, nvec, aligned> storer(
+ reinterpret_cast<OutputType0*>(param.outputs[0]) + base_row *
param.num_elements,
+ total_length);
+
+ for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+ if (req == OpReqType::kAddTo) {
+ storer.load(i, total_length);
+#pragma unroll
+ for (int j = 0; j < nvec; ++j) {
+ storer.separate()[j] = OType::to(op::add(persistent_storage[i*nvec +
j],
+
OType::from(storer.separate()[j])));
+ }
+ } else {
+#pragma unroll
+ for (int j = 0; j < nvec; ++j) {
+ storer.separate()[j] = OType::to(persistent_storage[i*nvec + j]);
+ }
+ }
+ storer.store(i, total_length);
+ }
+ } else {
+ const index_t real_length = min(total_length,
+ (param.total_rows - base_row) *
param.num_elements);
+ VectorizedStorer<OutputType0, nvec, false> storer(
+ reinterpret_cast<OutputType0*>(param.outputs[0]) + base_row *
param.num_elements,
+ real_length);
+
+ for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+ if (req == OpReqType::kAddTo) {
+ storer.load(i, real_length);
+#pragma unroll
+ for (int j = 0; j < nvec; ++j) {
+ storer.separate()[j] = OType::to(op::add(persistent_storage[i*nvec +
j],
+
OType::from(storer.separate()[j])));
+ }
+ } else {
+#pragma unroll
+ for (int j = 0; j < nvec; ++j) {
+ storer.separate()[j] = OType::to(persistent_storage[i*nvec + j]);
+ }
+ }
+ storer.store(i, real_length);
+ }
+ }
+}
+)code";
+
+int get_rows_per_block(const index_t row_size, const int nvec,
+ const index_t max_storage, const int
num_threads_per_block,
+ const index_t total_rows, const int dev_id) {
+ CHECK(common::IsPower2(num_threads_per_block))
+ << "Number of threads in a block must be power of 2 to use
get_rows_per_block function";
+ // How many read instructions should 1 thread at least do
+ const int read_instructions = 16;
+ const size_t row_size_in_vec = (row_size + nvec - 1) / nvec;
+ int desired_num_threads_per_row = (row_size_in_vec + read_instructions - 1)
/ read_instructions;
+ desired_num_threads_per_row =
common::RoundToPower2(desired_num_threads_per_row);
+ desired_num_threads_per_row = std::min(desired_num_threads_per_row,
num_threads_per_block);
+ const int desired_rows_per_block = num_threads_per_block /
desired_num_threads_per_row;
+ int actual_rows_per_block = desired_rows_per_block;
+ int num_sms = MultiprocessorCount(dev_id);
+ while (actual_rows_per_block > 1 &&
+ ((max_storage != -1 && max_storage < row_size *
actual_rows_per_block) ||
+ (total_rows + actual_rows_per_block - 1) / actual_rows_per_block <
num_sms)) {
+ actual_rows_per_block /= 2;
+ }
+ return actual_rows_per_block;
+}
+
+} // namespace
+
+void SoftmaxRTCCompute::operator()(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
+ using namespace mxnet_op;
+ using common::mshadow_type_info;
+ using namespace common::cuda::rtc;
+ using common::div_round;
+ if (req[0] == kNullOp || inputs[0].Size() == 0U) return;
+ const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+ int axis = CheckAxis(param.axis, inputs[0].ndim());
+ const double temperature = param.temperature.has_value() ?
+ param.temperature.value() : 1.0;
+ mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
+
+ void* length_ptr = nullptr;
+ std::string length_typename = "int";
+ if (param.use_length.value()) {
+ CHECK(inputs.size() > 1)
+ << "Mask needs to be provided when using softmax with use_length=True.";
+ length_ptr = inputs[1].dptr_;
+ length_typename = mshadow_type_info(inputs[1].type_flag_).name;
+ }
+ CHECK_EQ(outputs.size(), 1);
+ index_t M = shape[axis];
+ if (M == 0 || shape.Size() == 0) return;
+ index_t stride = 1;
+ if (axis == shape.ndim() - 2) {
+ stride = shape[shape.ndim() - 1];
+ }
+ const index_t N = shape.Size() / M;
+ softmax_params params = {{inputs[0].dptr_, length_ptr, nullptr},
+ {outputs[0].dptr_},
+ stride, M,
+ temperature, 1, N};
+ std::string code = "#define OP " + OP + "\n"
+ "const OpReqType req = " + util::to_string(req[0]) + ";\n"
+ "const bool negate = " + std::to_string(negate) + ";\n"
+ "using InputType1 = " + length_typename + ";\n";
+ Stream<gpu>* s = ctx.get_stream<gpu>();
+
+ constexpr int nvec = 2;
+ // Using 20 kB of shared memory for persistent storage in the optimized case
+ const size_t acc_type_size =
std::max(mshadow_type_info(inputs[0].type_flag_).acc_size,
+
mshadow_type_info(outputs[0].type_flag_).acc_size);
+ const size_t max_opt_M = 20 * 1024 / acc_type_size;
+ int rows_per_block = get_rows_per_block(M, nvec, max_opt_M,
+ vectorized_kernel_thread_num,
+ N, ctx.run_ctx.ctx.dev_id);
+ constexpr int warp_size = common::cuda::warp_size;
+ if (stride == 1 &&
+ static_cast<size_t>(M * rows_per_block) <= max_opt_M) {
+ code += "const bool only_full_blocks = " + std::to_string(N %
rows_per_block == 0) + ";\n"
+ "const bool reduction_inside_warp = " +
+ std::to_string(vectorized_kernel_thread_num / rows_per_block <=
warp_size) + ";\n";
+ params.rows_per_block = rows_per_block;
+ int nblocks = (N + rows_per_block - 1) / rows_per_block;
+ VectorizedKernelRTCLauncher(code + softmax_common_functions,
"softmax_stride1_compute_kernel",
+ softmax_stride1_kernel_fwd, nvec,
+ M * rows_per_block, N / rows_per_block, s,
params,
+ inputs, outputs,
+ ctx.run_ctx.ctx.dev_id, 0, nblocks);
+ MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_stride1_compute_kernel);
+ } else {
+ code += "using InputType0 = " +
mshadow_type_info(inputs[0].type_flag_).name + ";\n"
+ "using OutputType0 = " +
mshadow_type_info(outputs[0].type_flag_).name + ";\n";
+ std::vector<const void*> args;
+ args.emplace_back(¶ms);
+ args.emplace_back(&M);
+ int num_threads = std::min(static_cast<size_t>(128),
+ common::RoundToPower2(div_round(M, warp_size) *
warp_size));
+ if (stride != 1) {
+ const int num_sms = MultiprocessorCount(ctx.run_ctx.ctx.dev_id);
+ const index_t rows_per_sm = div_round(N, (512 / num_threads) * num_sms);
+ params.rows_per_block = std::min(static_cast<size_t>(warp_size),
+ common::RoundToPower2(rows_per_sm));
+ }
+ const auto& kernel_func = get_function(code + softmax_common_functions,
+ "simple_softmax_kernel",
+ simple_softmax_kernel_fwd,
+ ctx.run_ctx.ctx.dev_id);
+ launch(kernel_func, div_round(N, params.rows_per_block), num_threads, 0,
s, &args);
+ MSHADOW_CUDA_POST_KERNEL_CHECK(simple_softmax_kernel);
+ }
+}
+
+const char simple_softmax_kernel_bwd[] = R"code(
+__launch_bounds__(kRTCMaxThreadsPerBlock)
+__global__ void simple_softmax_grad_kernel(const softmax_params param,
+ const index_t lead_dim) {
+ using LengthType = AccType<InputType2>;
+ const InputType0* out = reinterpret_cast<const InputType0*>(param.inputs[0]);
+ const InputType1* ograd = reinterpret_cast<const
InputType1*>(param.inputs[1]);
+ const InputType2* length = reinterpret_cast<const
InputType2*>(param.inputs[2]);
+ const index_t len = length == nullptr
+ ? lead_dim
+ :
static_cast<index_t>(LengthType::from(length[blockIdx.x]));
+ const int my_row = threadIdx.x % param.rows_per_block;
+ const int my_id = threadIdx.x / param.rows_per_block;
+ const int threads_per_row = blockDim.x / param.rows_per_block;
+ const index_t base_x = (blockIdx.x * param.rows_per_block + my_row) %
param.stride;
+ const index_t base_n = (blockIdx.x * param.rows_per_block + my_row) /
param.stride;
+ const index_t base = base_x + param.stride * lead_dim * base_n;
+ if (base >= param.num_elements * param.total_rows) return;
+ using IType0 = AccType<InputType0>;
+ using IType1 = AccType<InputType1>;
+ using OType = AccType<OutputType0>;
+ using AType = type_util::mixed_type<typename IType0::type,
+ typename IType1::type,
+ typename OType::type>;
+ __shared__ AType smem[kRTCMaxThreadsPerBlock];
+ AType sum;
+ red::sum::SetInitValue(sum);
+ for (index_t i = my_id; i < len; i += threads_per_row) {
+ auto out_val = IType0::from(out[base + i * param.stride]);
+ auto ograd_val = IType1::from(ograd[base + i * param.stride]);
+ sum += OP1(ograd_val, out_val);
+ }
+ smem[threadIdx.x] = sum;
+ __syncthreads();
+ for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+ if (threadIdx.x < size) {
+ smem[threadIdx.x] = smem[threadIdx.x] + smem[threadIdx.x + size];
+ }
+ __syncthreads();
+ }
+ if (threadIdx.x < warp_size) {
+ AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+ [](AType x, AType y) {
return x + y; },
+ param.rows_per_block);
+ smem[threadIdx.x] = my_value;
+ }
+ __syncthreads();
+ sum = smem[my_row];
+ __syncthreads();
+
+ OutputType0* igrad = reinterpret_cast<OutputType0*>(param.outputs[0]);
+ for (index_t i = my_id; i < lead_dim; i += threads_per_row) {
+ auto out_val = IType0::from(out[base + i * param.stride]);
+ auto ograd_val = IType1::from(ograd[base + i * param.stride]);
+ auto val = (i < len) ? OP2(ograd_val, out_val, sum) /
static_cast<AType>(param.temperature) : 0;
+ val = negate ? -val : val;
+ if (req == OpReqType::kAddTo) {
+ if (i < len) {
+ igrad[base + i * param.stride] = OType::to(val +
+ OType::from(igrad[base + i
* param.stride]));
+ }
+ } else {
+ igrad[base + i * param.stride] = OType::to(val);
+ }
+ }
+}
+)code";
+
+const char softmax_stride1_kernel_bwd[] = R"code(
+__launch_bounds__(vector::vectorized_kernel_thread_num)
+__global__ void softmax_stride1_compute_grad_kernel(const softmax_params param,
+ const index_t total_length,
+ const index_t other_dim,
+ const index_t N,
+ const index_t
num_aligned_elements) {
+ using namespace vector;
+ using IType0 = AccType<InputType0>;
+ using IType1 = AccType<InputType1>;
+ using OType = AccType<OutputType0>;
+ using LengthType = AccType<InputType2>;
+ const InputType2* length = reinterpret_cast<const
InputType2*>(param.inputs[2]);
+ using AType = type_util::mixed_type<typename IType0::type,
+ typename IType1::type,
+ typename OType::type>;
+ __shared__ AType scratch[vectorized_kernel_thread_num];
+ __shared__ AType output_persistent_storage[10 * 1024 / sizeof(AType)];
+ __shared__ AType ograd_persistent_storage[10 * 1024 / sizeof(AType)];
+ const int warp_size = 32;
+ const int threads_per_row = vectorized_kernel_thread_num /
param.rows_per_block;
+ const int my_local_row = threadIdx.x / threads_per_row;
+ const int base_row = blockIdx.x * param.rows_per_block;
+ const int my_row = base_row + my_local_row;
+ const index_t len = (length == nullptr ||
+ my_row >= param.total_rows) ? param.num_elements
+ :
LengthType::from(length[my_row]);
+ const int my_id = threadIdx.x % threads_per_row;
+
+ AType* output_row;
+ AType* ograd_row;
+ if (only_full_blocks || blockIdx.x < gridDim.x - 1) {
+ // full rows_per_block rows to compute
+ VectorizedLoader<InputType0, nvec, aligned> output_loader(
+ reinterpret_cast<const InputType0*>(param.inputs[0]) + base_row *
param.num_elements,
+ total_length);
+ VectorizedLoader<InputType1, nvec, aligned> ograd_loader(
+ reinterpret_cast<const InputType1*>(param.inputs[1]) + base_row *
param.num_elements,
+ total_length);
+ for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+ output_loader.load(i, total_length);
+ ograd_loader.load(i, total_length);
+#pragma unroll
+ for (int j = 0; j < nvec; ++j) {
+ output_persistent_storage[i*nvec + j] =
IType0::from(output_loader.separate()[j]);
+ ograd_persistent_storage[i*nvec + j] =
IType1::from(ograd_loader.separate()[j]);
+ }
+ }
+ output_row = output_persistent_storage +
+ my_local_row * param.num_elements +
+ output_loader.alignment();
+ ograd_row = ograd_persistent_storage +
+ my_local_row * param.num_elements +
+ ograd_loader.alignment();
+ } else {
+ // less than rows_per_block rows to compute
+ const index_t real_length = min(total_length,
+ (param.total_rows - base_row) *
param.num_elements);
+ VectorizedLoader<InputType0, nvec, false> output_loader(
+ reinterpret_cast<const InputType0*>(param.inputs[0]) + base_row *
param.num_elements,
+ real_length);
+ VectorizedLoader<InputType1, nvec, false> ograd_loader(
+ reinterpret_cast<const InputType1*>(param.inputs[1]) + base_row *
param.num_elements,
+ real_length);
+ for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+ output_loader.load(i, real_length);
+ ograd_loader.load(i, real_length);
+#pragma unroll
+ for (int j = 0; j < nvec; ++j) {
+ output_persistent_storage[i*nvec + j] =
IType0::from(output_loader.separate()[j]);
+ ograd_persistent_storage[i*nvec + j] =
IType1::from(ograd_loader.separate()[j]);
+ }
+ }
+ output_row = output_persistent_storage +
+ my_local_row * param.num_elements +
+ output_loader.alignment();
+ ograd_row = ograd_persistent_storage +
+ my_local_row * param.num_elements +
+ ograd_loader.alignment();
+ }
+ __syncthreads();
+
+ AType my_sum;
+ red::sum::SetInitValue(my_sum);
+
+ for (index_t i = my_id; i < len; i += threads_per_row) {
+ const AType val = OP1(ograd_row[i], output_row[i]);
+ my_sum += val;
+ }
+ AType ssum;
+ if (!reduction_inside_warp) {
+ scratch[threadIdx.x] = my_sum;
+ __syncthreads();
+ for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+ if (my_id < size) {
+ scratch[threadIdx.x] += scratch[threadIdx.x + size];
+ }
+ __syncthreads();
+ }
+ if (my_id < warp_size) {
+ AType my_value = util::grouped_warp_allreduce(scratch[threadIdx.x],
+ [](AType x, AType y) {
return x + y;},
+ min(threads_per_row,
warp_size));
+ scratch[threadIdx.x] = my_value;
+ }
+ __syncthreads();
+
+ ssum = scratch[threadIdx.x - my_id];
+ __syncthreads();
+ } else {
+ ssum = util::grouped_warp_allreduce(my_sum,
+ [](AType x, AType y) { return x +
y;},
+ threads_per_row);
+ }
+
+ for (index_t i = my_id; i < param.num_elements; i += threads_per_row) {
+ AType val = (i < len)
+ ? OP2(ograd_row[i], output_row[i], ssum) /
static_cast<AType>(param.temperature)
+ : 0;
+ output_row[i] = negate ? -val : val;
+ }
+ __syncthreads();
+
+ if (only_full_blocks || blockIdx.x < gridDim.x - 1) {
+ VectorizedStorer<OutputType0, nvec, aligned> storer(
+ reinterpret_cast<OutputType0*>(param.outputs[0]) + base_row *
param.num_elements,
+ total_length);
+
+ for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+ if (req == OpReqType::kAddTo) {
+ storer.load(i, total_length);
+#pragma unroll
+ for (int j = 0; j < nvec; ++j) {
+ storer.separate()[j] =
OType::to(op::add(output_persistent_storage[i*nvec + j],
+
OType::from(storer.separate()[j])));
+ }
+ } else {
+#pragma unroll
+ for (int j = 0; j < nvec; ++j) {
+ storer.separate()[j] = OType::to(output_persistent_storage[i*nvec +
j]);
+ }
+ }
+ storer.store(i, total_length);
+ }
+ } else {
+ const index_t real_length = min(total_length,
+ (param.total_rows - base_row) *
param.num_elements);
+ VectorizedStorer<OutputType0, nvec, false> storer(
+ reinterpret_cast<OutputType0*>(param.outputs[0]) + base_row *
param.num_elements,
+ real_length);
+
+ for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+ if (req == OpReqType::kAddTo) {
+ storer.load(i, real_length);
+#pragma unroll
+ for (int j = 0; j < nvec; ++j) {
+ storer.separate()[j] =
OType::to(op::add(output_persistent_storage[i*nvec + j],
+
OType::from(storer.separate()[j])));
+ }
+ } else {
+#pragma unroll
+ for (int j = 0; j < nvec; ++j) {
+ storer.separate()[j] = OType::to(output_persistent_storage[i*nvec +
j]);
+ }
+ }
+ storer.store(i, real_length);
+ }
+ }
+}
+)code";
+
+void SoftmaxRTCGradCompute::operator()(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
+ using namespace mxnet_op;
+ using common::mshadow_type_info;
+ using namespace common::cuda::rtc;
+ using common::div_round;
+ Stream<gpu>* s = ctx.get_stream<gpu>();
+ if (softmax_use_length(attrs)) {
+ if (req[1] != kNullOp) {
+ cudaMemsetAsync(outputs[1].dptr_, 0,
+ outputs[1].Size() *
mshadow_type_info(outputs[1].type_flag_).size,
+ Stream<gpu>::GetStream(s));
+ }
+ }
+ if (req[0] == kNullOp || inputs[0].Size() == 0U) return;
+ const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+ int axis = CheckAxis(param.axis, inputs[0].ndim());
+ const double temperature = param.temperature.has_value() ?
+ param.temperature.value() : 1.0;
+ mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
+
+ int out_idx = softmax_has_dtype_override(attrs) ? 2 : 1;
+ out_idx = softmax_use_length(attrs) ? 3 : out_idx;
+
+ void* length_ptr = nullptr;
+ std::string length_typename = "int";
+ if (softmax_use_length(attrs)) {
+ length_ptr = inputs[2].dptr_;
+ length_typename = mshadow_type_info(inputs[2].type_flag_).name;
+ }
+ index_t M = shape[axis];
+ if (M == 0 || shape.Size() == 0) return;
+ index_t stride = 1;
+ if (axis == shape.ndim() - 2) {
+ stride = shape[shape.ndim() - 1];
+ }
+ const index_t N = shape.Size() / M;
+ softmax_params params = {{inputs[out_idx].dptr_, inputs[0].dptr_,
length_ptr},
+ {outputs[0].dptr_},
+ stride, M,
+ temperature, 1, N};
+ std::string code = "#define OP1 " + OP1 + "\n"
+ "#define OP2 " + OP2 + "\n"
+ "const OpReqType req = " + util::to_string(req[0]) + ";\n"
+ "const bool negate = " + std::to_string(negate) + ";\n"
+ "using InputType2 = " + length_typename + ";\n";
+
+ constexpr int nvec = 2;
+ // Using 20 kB of shared memory for persistent storage in the optimized case
+ const size_t acc_type_size =
std::max(mshadow_type_info(inputs[0].type_flag_).acc_size,
+
mshadow_type_info(outputs[0].type_flag_).acc_size);
+ const size_t max_opt_M = 10 * 1024 / acc_type_size;
+ int rows_per_block = get_rows_per_block(M, nvec, max_opt_M,
+ vectorized_kernel_thread_num,
+ N, ctx.run_ctx.ctx.dev_id);
+ params.rows_per_block = rows_per_block;
+ bool debug_softmax = dmlc::GetEnv("DEBUG_SOFTMAX_GRAD", false);
+ if (!debug_softmax && stride == 1 &&
+ static_cast<size_t>(M * rows_per_block) <= max_opt_M) {
+ const int warp_size = 32;
+ code += "const bool only_full_blocks = " + std::to_string(N %
rows_per_block == 0) + ";\n"
+ "const bool reduction_inside_warp = " +
+ std::to_string(vectorized_kernel_thread_num / rows_per_block <=
warp_size) + ";\n";
+ int nblocks = div_round(N, rows_per_block);
+ std::vector<TBlob> new_inputs = {inputs[out_idx], inputs[0]};
+ if (softmax_use_length(attrs)) {
+ new_inputs.emplace_back(inputs[2]);
+ }
+ std::vector<TBlob> new_outputs = {outputs[0]};
+ VectorizedKernelRTCLauncher(code + softmax_common_functions,
+ "softmax_stride1_compute_grad_kernel",
+ softmax_stride1_kernel_bwd, nvec,
+ M * rows_per_block, N / rows_per_block, s,
params,
+ new_inputs, new_outputs,
+ ctx.run_ctx.ctx.dev_id, 0, nblocks);
+ MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_stride1_compute_grad_kernel);
+ } else {
+ code += "using InputType0 = " +
mshadow_type_info(inputs[out_idx].type_flag_).name + ";\n"
+ "using InputType1 = " +
mshadow_type_info(inputs[0].type_flag_).name + ";\n"
+ "using OutputType0 = " +
mshadow_type_info(outputs[0].type_flag_).name + ";\n";
+ std::vector<const void*> args;
+ args.emplace_back(¶ms);
+ args.emplace_back(&M);
+ const int warp_size = 32;
+ int num_threads = std::min(static_cast<size_t>(128),
+ common::RoundToPower2(div_round(M, warp_size) *
warp_size));
+ if (stride != 1) {
+ const int num_sms = MultiprocessorCount(ctx.run_ctx.ctx.dev_id);
+ const index_t rows_per_sm = div_round(N, (512 / num_threads) * num_sms);
+ params.rows_per_block = std::min(static_cast<size_t>(warp_size),
+ common::RoundToPower2(rows_per_sm));
+ }
+ const auto& kernel_func = get_function(code + softmax_common_functions,
+ "simple_softmax_grad_kernel",
+ simple_softmax_kernel_bwd,
+ ctx.run_ctx.ctx.dev_id);
+ launch(kernel_func, div_round(N, params.rows_per_block), num_threads, 0,
s, &args);
+ MSHADOW_CUDA_POST_KERNEL_CHECK(simple_softmax_grad_kernel);
+ }
+}
+
NNVM_REGISTER_OP(softmax)
-.set_attr<FCompute>("FCompute<gpu>", SoftmaxCompute<gpu,
mxnet_op::softmax_fwd>);
+.set_attr<FCompute>("FCompute<gpu>", SoftmaxRTCCompute{"softmax_fwd"});
NNVM_REGISTER_OP(_backward_softmax)
-.set_attr<FCompute>("FCompute<gpu>", SoftmaxGradCompute<gpu,
op::mshadow_op::mul,
-
mxnet_op::softmax_bwd>);
+.set_attr<FCompute>("FCompute<gpu>", SoftmaxRTCGradCompute{"op::mul",
"softmax_bwd"});
+
NNVM_REGISTER_OP(masked_softmax)
.set_attr<FCompute>("FCompute<gpu>", MaskedSoftmaxCompute<gpu,
mxnet_op::softmax_fwd,
false>);
diff --git a/src/operator/nn/softmin.cu b/src/operator/nn/softmin.cu
index d00d0bd..b6f56ce 100644
--- a/src/operator/nn/softmin.cu
+++ b/src/operator/nn/softmin.cu
@@ -29,11 +29,10 @@ namespace mxnet {
namespace op {
NNVM_REGISTER_OP(softmin)
-.set_attr<FCompute>("FCompute<gpu>", SoftmaxCompute<gpu,
mxnet_op::softmax_fwd, true>);
+.set_attr<FCompute>("FCompute<gpu>", SoftmaxRTCCompute{"softmax_fwd", true});
NNVM_REGISTER_OP(_backward_softmin)
-.set_attr<FCompute>("FCompute<gpu>", SoftmaxGradCompute<gpu,
op::mshadow_op::mul,
- mxnet_op::softmax_bwd,
true>);
+.set_attr<FCompute>("FCompute<gpu>", SoftmaxRTCGradCompute{"op::mul",
"softmax_bwd", true});
} // namespace op
} // namespace mxnet
diff --git a/src/operator/tensor/elemwise_binary_scalar_op.cc
b/src/operator/tensor/elemwise_binary_scalar_op.cc
index f09bf21..d7fde47 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op.cc
+++ b/src/operator/tensor/elemwise_binary_scalar_op.cc
@@ -50,6 +50,7 @@ __global__ void binary_scalar_kernel(const
binary_scalar_kernel_params params,
const index_t N,
const index_t num_aligned_elements) {
using namespace vector;
+ using type_util::mixed_type;
VectorizedLoader<InputType0, nvec, aligned> loader(
reinterpret_cast<const InputType0*>(params.inputs[0]), N);
VectorizedStorer<OutputType0, nvec, aligned> storer(
@@ -72,9 +73,8 @@ __global__ void binary_scalar_kernel(const
binary_scalar_kernel_params params,
const auto input = IType::from(loader.separate()[i]);
// enables returning different type
const auto temp = OP(input,
- static_cast<typename type_util::mixed_type<typename
IType::type,
- typename
OType::type>::type>
- (params.scalar));
+ static_cast<mixed_type<typename IType::type,
+ typename
OType::type>>(params.scalar));
if (req == OpReqType::kAddTo) {
// temp2 may have a wider type than either temp
@@ -171,6 +171,7 @@ __global__ void binary_scalar_kernel_bwd(const
binary_scalar_kernel_params param
const index_t N,
const index_t num_aligned_elements) {
using namespace vector;
+ using type_util::mixed_type;
VectorizedLoader<InputType0, nvec, aligned> ograd_loader(
reinterpret_cast<const InputType0*>(params.inputs[0]), N);
VectorizedLoader<InputType1, nvec, aligned> input_loader(
@@ -199,9 +200,8 @@ __global__ void binary_scalar_kernel_bwd(const
binary_scalar_kernel_params param
// enables returning different type
const auto temp = op::mul(ograd,
OP(input,
- static_cast<typename
type_util::mixed_type<typename IType::type,
-
typename OType::type>
- ::type>(params.scalar)));
+ static_cast<mixed_type<typename IType::type,
+ typename
OType::type>>(params.scalar)));
if (req == OpReqType::kAddTo) {
// temp2 may have a wider type than either temp