apeforest commented on a change in pull request #17368: Faster GPU frozen BatchNorm
URL: https://github.com/apache/incubator-mxnet/pull/17368#discussion_r369401124
 
 

 ##########
 File path: src/common/cuda_utils.h
 ##########
 @@ -764,27 +764,86 @@ __device__ inline DType ldg(const DType* address) {
 #endif
 }
 
-template <typename OP, typename T>
+namespace mxnet {
+namespace common {
+/*! \brief common utils for cuda */
+namespace cuda {
+
+static constexpr const int warp_size = 32;
+
+/*! \brief Reduction inside a warp.
+ * Template parameters:
+ * NVALUES - number of values to reduce (defaults to warp_size).
+ * \param value - values to be reduced.
+ * \param redfun - function used to perform reduction.
+ */
+template <int NVALUES = warp_size, typename OP, typename T>
 __device__ inline T warp_reduce(T value, OP redfun) {
-  value = redfun(value, __shfl_down_sync(0xffffffff, value, 16));
-  value = redfun(value, __shfl_down_sync(0xffffffff, value, 8));
-  value = redfun(value, __shfl_down_sync(0xffffffff, value, 4));
-  value = redfun(value, __shfl_down_sync(0xffffffff, value, 2));
-  value = redfun(value, __shfl_down_sync(0xffffffff, value, 1));
+#pragma unroll
+  for (int i = warp_size / 2; i >= 1; i /= 2) {
+    if (NVALUES > i) value = redfun(value, __shfl_down_sync(0xffffffff, value, i));
+  }
   return value;
 }
 
-template <typename OP>
+template <int NVALUES = warp_size, typename OP>
 __device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, OP redfun) {
   float v = static_cast<float>(value);
-  v = redfun(v, __shfl_down_sync(0xffffffff, v, 16));
-  v = redfun(v, __shfl_down_sync(0xffffffff, v, 8));
-  v = redfun(v, __shfl_down_sync(0xffffffff, v, 4));
-  v = redfun(v, __shfl_down_sync(0xffffffff, v, 2));
-  v = redfun(v, __shfl_down_sync(0xffffffff, v, 1));
+#pragma unroll
+  for (int i = warp_size / 2; i >= 1; i /= 2) {
+    if (NVALUES > i) v = redfun(v, __shfl_down_sync(0xffffffff, v, i));
+  }
   return mshadow::half::half_t(v);
 }
 
+/*! \brief Reduction inside a block; requires all threads in a block to participate.
+ *         It uses a two-step approach:
+ *          - all warps in a block perform an intermediate reduction
+ *          - the first warp reduces the intermediate results.
+ * Template parameters:
+ * NTHREADS - number of threads in a block.
+ * all_reduce - whether all threads need the result of the reduction. If set to
+ *              true, then all threads return with the same value. If set to
+ *              false, then only thread 0 has the valid result. Defaults to true.
+ * \param value - value from each thread to be reduced
+ * \param redfun - function used to perform reduction
+ */
+template <int NTHREADS, bool all_reduce = true, typename OP, typename T>
+__device__ inline T reduce(const T& value, OP redfun) {
+  static_assert(NTHREADS <= warp_size * warp_size,
+                "Number of threads too large for reduction");
+  __shared__ T scratch[NTHREADS / warp_size];
+  const int my_id = threadIdx.x % warp_size;
 
 Review comment:
   Could we use more informative names than my_*?
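   For illustration, here is a minimal sketch of how the remainder of reduce might read with more descriptive names. The names lane and warp_id are suggestions only, and the broadcast step for all_reduce is an assumption about how the function is completed, not the PR's actual code; the sketch also assumes NTHREADS is a multiple of warp_size:

   template <int NTHREADS, bool all_reduce = true, typename OP, typename T>
   __device__ inline T reduce(const T& value, OP redfun) {
     static_assert(NTHREADS <= warp_size * warp_size,
                   "Number of threads too large for reduction");
     // Assumes NTHREADS is a multiple of warp_size.
     __shared__ T scratch[NTHREADS / warp_size];
     const int lane = threadIdx.x % warp_size;     // position within the warp
     const int warp_id = threadIdx.x / warp_size;  // position of the warp in the block
     // Step 1: each warp reduces its own warp_size values.
     T result = warp_reduce(value, redfun);
     if (lane == 0) scratch[warp_id] = result;
     __syncthreads();
     // Step 2: the first warp reduces the per-warp partial results.
     if (warp_id == 0) {
       // Lanes beyond the number of warps supply a defined dummy value;
       // the NVALUES template argument keeps them out of the result.
       result = (lane < NTHREADS / warp_size) ? scratch[lane] : value;
       result = warp_reduce<NTHREADS / warp_size>(result, redfun);
       if (all_reduce && lane == 0) scratch[0] = result;
     }
     if (all_reduce) {
       __syncthreads();
       result = scratch[0];
     }
     return result;
   }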

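   As a side note for readers skimming the diff, a minimal usage sketch of the warp-level primitive. The kernel name warp_sum, the SumOp functor, and the include path are hypothetical, not part of this PR:

   #include "common/cuda_utils.h"  // assuming src/ is on the include path

   // Hypothetical functor; avoids needing --extended-lambda for device lambdas.
   struct SumOp {
     __device__ float operator()(float a, float b) const { return a + b; }
   };

   // Each full warp sums warp_size consecutive floats; lane 0 writes the result.
   __global__ void warp_sum(const float* in, float* out, int n) {
     using mxnet::common::cuda::warp_size;
     const int tid = blockIdx.x * blockDim.x + threadIdx.x;
     // Pad with 0.f so every lane stays active: warp_reduce's
     // __shfl_down_sync uses the full 0xffffffff mask.
     const float v = (tid < n) ? in[tid] : 0.f;
     const float sum = mxnet::common::cuda::warp_reduce(v, SumOp{});
     if (tid % warp_size == 0 && tid < n) {
       out[tid / warp_size] = sum;
     }
   }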