This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 344587f  Safe accumulation for computing gradient in Embedding & Take 
(#18385)
344587f is described below

commit 344587f295666e4375042d054cd5a134fdeaf517
Author: MoisesHer <[email protected]>
AuthorDate: Thu Aug 13 22:18:26 2020 -0700

    Safe accumulation for computing gradient in Embedding & Take (#18385)
    
    * Safe accumulation for computing gradient in Embedding & Take
    
    * Fix bug in TakeGrad: initialize temporary storage for safe_accumulation
    
    * fix lint
    
    * make MXNET_SAFE_ACCUMULATION compatible with Windows
    
    * Increase test coverage: small inputs & SAFE_ACCUMULATION
---
 3rdparty/mshadow/mshadow/cuda/tensor_gpu-inl.cuh |  77 ++++++++++
 3rdparty/mshadow/mshadow/tensor.h                |  26 ++++
 3rdparty/mshadow/mshadow/tensor_cpu-inl.h        |  33 ++++
 3rdparty/mshadow/mshadow/tensor_gpu-inl.h        |   8 +
 src/operator/tensor/indexing_op.cu               |  84 ++++++----
 src/operator/tensor/indexing_op.h                | 186 +++++++++++++++++++----
 tests/python/gpu/test_operator_gpu.py            | 156 ++++++++++++-------
 7 files changed, 453 insertions(+), 117 deletions(-)

diff --git a/3rdparty/mshadow/mshadow/cuda/tensor_gpu-inl.cuh 
b/3rdparty/mshadow/mshadow/cuda/tensor_gpu-inl.cuh
index 02a74b2..a00aade 100644
--- a/3rdparty/mshadow/mshadow/cuda/tensor_gpu-inl.cuh
+++ b/3rdparty/mshadow/mshadow/cuda/tensor_gpu-inl.cuh
@@ -641,6 +641,43 @@ __global__ void AddTakeGradKernel(DstPlan dst,
   }
 }
 
+/*!
+ * \brief Safe-accumulation variant of AddTakeGradKernel:
+ *        dst[index[y]] += src[y], with the sums carried out in `temp`.
+ *
+ * Each thread owns one column `xindex`; thread 0 of the block resolves the
+ * target row for every y and publishes it to the block through shared memory.
+ * \param dst   destination gradient with K rows
+ * \param temp  scratch buffer in the accumulation type, covering dst's K x xmax extent
+ * \param index row indices into dst, one per row of src
+ * \param src   source gradient with ymax rows
+ * \param ymax  number of rows of src
+ * \param xmax  number of columns of dst / src
+ * \param K     number of rows of dst
+ */
+template<bool clip, int x_bits, typename DstPlan, typename ATypePlan,
+         typename SrcPlan1, typename SrcPlan2>
+__global__ void AddTakeGradKernel(DstPlan dst,
+                                  ATypePlan temp,
+                                  SrcPlan1 index, SrcPlan2 src,
+                                  index_t ymax, index_t xmax, const int K) {
+  const unsigned x_size = 1 << x_bits;
+  const int xindex = blockIdx.x * x_size + threadIdx.x;
+  __shared__ int ptr;
+  // Widen this thread's column of dst into the accumulation buffer.
+  if (xindex < xmax) {
+    for (int y = 0; y < K; ++y) {
+      temp.REval(y, xindex) = dst.Eval(y, xindex);
+    }
+  }
+  for (unsigned y = 0; y < ymax; ++y) {
+    if (threadIdx.x == 0) {
+      ptr = index.Eval(0, y);
+      if (clip) {
+        if (ptr <= 0) ptr = 0;
+        else if (ptr >= K) ptr = K - 1;
+      } else {
+        ptr %= K;
+        if (ptr < 0) ptr += K;
+      }
+    }
+    // Make thread 0's write of `ptr` visible to the whole block.
+    __syncthreads();
+    if (xindex < xmax) {
+      temp.REval(ptr, xindex) += src.Eval(y, xindex);
+    }
+    // Every thread must finish reading `ptr` for this y before thread 0
+    // overwrites it in the next iteration; without this barrier a warp that
+    // runs ahead can change the row another warp is still accumulating into.
+    __syncthreads();
+  }
+  // Narrow the accumulated column back into dst.
+  if (xindex < xmax) {
+    for (int y = 0; y < K; ++y) {
+      dst.REval(y, xindex) = temp.Eval(y, xindex);
+    }
+  }
+}
+
 template<int warp_bits, int SZ, typename DType, typename IdxType>
 __global__ void AddTakeGradLargeBatchKernel(DType* dst,
                                             const IdxType *sorted, const 
IdxType *index,
@@ -733,6 +770,46 @@ inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
   MSHADOW_CUDA_POST_KERNEL_CHECK(AddTakeGradKernel);
 }
 
+/*!
+ * \brief Host launcher for the safe-accumulation AddTakeGradKernel:
+ *        dst[index[i]] += src[i], with the sums carried out in `temp`.
+ * \param dst   destination gradient (K x C), must be contiguous
+ * \param temp  scratch tensor in the accumulation type, same shape as dst
+ * \param index row indices into dst, one per row of src, must be contiguous
+ * \param src   source gradient (index.size(0) x C), must be contiguous
+ */
+template<bool clip = true, typename IndexType, typename DType, typename AType>
+inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
+                        Tensor<gpu, 2, AType> temp,
+                        const Tensor<gpu, 1, IndexType>& index,
+                        const Tensor<gpu, 2, DType> &src) {
+  CHECK_EQ(dst.CheckContiguous(), true);
+  CHECK_EQ(index.CheckContiguous(), true);
+  CHECK_EQ(src.CheckContiguous(), true);
+  // The kernel reads and writes temp at every (row, column) of dst.
+  CHECK_EQ(temp.shape_, dst.shape_) << "AddTakeGrad: shape mismatch";
+  const int kUnitBits = kMemUnitBits + 1;
+  dim3 dimBlock(1 << kUnitBits);
+  dim3 dimGrid((dst.size(1) + (1 << kUnitBits) - 1) >> kUnitBits);
+
+  CHECK_EQ(dst.size(1), src.size(1)) << "AddTakeGrad: shape mismatch";
+  CHECK_EQ(index.size(0), src.size(0)) << "AddTakeGrad: shape mismatch";
+  CheckLaunchParam(dimGrid, dimBlock, "AddTakeGrad");
+  cudaStream_t stream = Stream<gpu>::GetStream(dst.stream_);
+  const int K = dst.shape_[0];
+
+  // `clip` is already a compile-time constant: forward it instead of branching,
+  // so only the kernel instantiation that is actually used gets compiled.
+  AddTakeGradKernel<clip, kUnitBits>
+    <<<dimGrid, dimBlock, 0, stream>>>
+    (expr::MakePlan(dst),
+     expr::MakePlan(temp),
+     expr::MakePlan(index),
+     expr::MakePlan(src),
+     src.size(0),
+     src.size(1), K);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(AddTakeGradKernel);
+}
+
 template<typename IndexType, typename DType>
 inline void AddTakeGradLargeBatch(Tensor<gpu, 2, DType> dst,
                                   const Tensor<gpu, 1, IndexType>& sorted,
diff --git a/3rdparty/mshadow/mshadow/tensor.h 
b/3rdparty/mshadow/mshadow/tensor.h
index 8dd57e2..c92bf8d 100644
--- a/3rdparty/mshadow/mshadow/tensor.h
+++ b/3rdparty/mshadow/mshadow/tensor.h
@@ -848,6 +848,19 @@ inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
  * \param index index to take
  * \param src source output
  */
+// NOTE(review): the /*! ... */ block that ends just above documents the plain
+// GPU AddTakeGrad declared right after this overload, not this overload; the
+// safe-accumulation overload was inserted between that block and its declaration.
+/*!
+ * \brief CPU: Gradient accumulate of embedding matrix with safe accumulation.
+                   dst[index[i]] += src[i]
+ * \param dst destination
+ * \param temp temporary storage (accumulation type AType) for safe accumulation
+ * \param index index to take
+ * \param src source output
+ */
+template<bool clip = true, typename IndexType, typename DType, typename AType>
+inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
+                        Tensor<cpu, 2, AType> temp,
+                        const Tensor<cpu, 1, IndexType>& index,
+                        const Tensor<cpu, 2, DType> &src);
 template<bool clip = true, typename IndexType, typename DType>
 inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
                         const Tensor<gpu, 1, IndexType>& index,
@@ -861,6 +874,19 @@ inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
  * \param index original index of the sorted indices
  * \param src source output
  */
+// NOTE(review): the /*! ... */ block that ends just above documents
+// AddTakeGradLargeBatch (declared right after this overload), not this
+// safe-accumulation overload, which was inserted between the two.
+/*!
+ * \brief GPU: Gradient accumulate of embedding matrix with safe accumulation.
+                   dst[index[i]] += src[i]
+ * \param dst destination
+ * \param temp temporary storage (accumulation type AType) for safe accumulation
+ * \param index index to take
+ * \param src source output
+ */
+template<bool clip = true, typename IndexType, typename DType, typename AType>
+inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
+                        Tensor<gpu, 2, AType> temp,
+                        const Tensor<gpu, 1, IndexType>& index,
+                        const Tensor<gpu, 2, DType> &src);
 template<typename IndexType, typename DType>
 inline void AddTakeGradLargeBatch(Tensor<cpu, 2, DType> dst,
                                   const Tensor<cpu, 1, IndexType>& sorted,
diff --git a/3rdparty/mshadow/mshadow/tensor_cpu-inl.h 
b/3rdparty/mshadow/mshadow/tensor_cpu-inl.h
index 2d00220..5f05f0a 100644
--- a/3rdparty/mshadow/mshadow/tensor_cpu-inl.h
+++ b/3rdparty/mshadow/mshadow/tensor_cpu-inl.h
@@ -539,6 +539,39 @@ inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
   }
 }
 
+// safe accumulation
+/*!
+ * \brief CPU: dst[index[y]] += src[y] with safe accumulation. dst is first
+ *        copied into `temp` (type AType), every contribution is summed there,
+ *        and the result is copied back into dst.
+ * \param dst   destination, K x C
+ * \param temp  temporary storage in the accumulation type, same shape as dst
+ * \param index index to take; clipped to [0, K) when clip, otherwise wrapped modulo K
+ * \param src   source output, index.size(0) x C
+ */
+template<bool clip, typename IndexType, typename DType, typename AType>
+inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
+                        Tensor<cpu, 2, AType> temp,
+                        const Tensor<cpu, 1, IndexType>& index,
+                        const Tensor<cpu, 2, DType> &src) {
+  // temp is read and written over the full K x C extent of dst, and row y of src
+  // is paired with index[y]; reject mismatched shapes instead of indexing out of bounds.
+  CHECK_EQ(temp.shape_, dst.shape_) << "AddTakeGrad: shape mismatch";
+  CHECK_EQ(dst.size(1), src.size(1)) << "AddTakeGrad: shape mismatch";
+  CHECK_EQ(index.size(0), src.size(0)) << "AddTakeGrad: shape mismatch";
+  const index_t K = dst.shape_[0];
+  const index_t C = dst.shape_[1];
+  for (index_t j = 0; j < K; ++j) {
+    for (index_t i = 0; i < C; ++i) {
+      temp[j][i] = dst[j][i];
+    }
+  }
+  for (index_t y = 0; y < index.size(0); ++y) {
+    index_t j = index[y];
+    if (clip) {
+      if (j <= 0) j = 0;
+      else if (j >= K) j = K - 1;
+    } else {
+      j %= K;
+      if (j < 0) j += K;
+    }
+    for (index_t i = 0; i < C; ++i) {
+      temp[j][i] += src[y][i];
+    }
+  }
+  for (index_t j = 0; j < K; ++j) {
+    for (index_t i = 0; i < C; ++i) {
+      dst[j][i] = temp[j][i];
+    }
+  }
+}
+
 template<typename IndexType, typename DType>
 inline void AddTakeGradLargeBatch(Tensor<cpu, 2, DType> dst,
                                   const Tensor<cpu, 1, IndexType>& sorted,
diff --git a/3rdparty/mshadow/mshadow/tensor_gpu-inl.h 
b/3rdparty/mshadow/mshadow/tensor_gpu-inl.h
index e7dde27..3140259 100644
--- a/3rdparty/mshadow/mshadow/tensor_gpu-inl.h
+++ b/3rdparty/mshadow/mshadow/tensor_gpu-inl.h
@@ -239,6 +239,14 @@ inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
   cuda::AddTakeGrad<clip, IndexType, DType>(dst, index, src);
 }
 
+/*!
+ * \brief GPU entry point of the safe-accumulation AddTakeGrad; hands every
+ *        argument unchanged to cuda::AddTakeGrad with all template arguments spelled out.
+ */
+template<bool clip, typename IndexType, typename DType, typename AType>
+inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
+                        Tensor<gpu, 2, AType> temp,
+                        const Tensor<gpu, 1, IndexType>& index,
+                        const Tensor<gpu, 2, DType> &src) {
+  cuda::AddTakeGrad<clip, IndexType, DType, AType>(dst, temp, index, src);
+}
+
 template<typename IndexType, typename DType>
 inline void AddTakeGradLargeBatch(Tensor<gpu, 2, DType> dst,
                                   const Tensor<gpu, 1, IndexType>& sorted,
diff --git a/src/operator/tensor/indexing_op.cu 
b/src/operator/tensor/indexing_op.cu
index e3c8a78..f9d7a19 100644
--- a/src/operator/tensor/indexing_op.cu
+++ b/src/operator/tensor/indexing_op.cu
@@ -684,27 +684,25 @@ __global__ void EmbeddingFindBounds(const IType 
*sorted_data,
  * \param grad_out output gradient data
  * \param embbedding_dim dimension of the dense embedding
  * \param vocab_dim maximum number of unique indices in the data array: tokens 
vocabulary size
+ * \param nelems_per_load number of elements per each load based on (LType / 
DType)
  * \param req write/add/null
  */
-template <typename LType, typename DType, typename IType>
+template <typename AType, typename LType, typename DType, typename IType>
 __global__ void EmbeddingGradKernel(DType *grad_in,
-                                      const IType *original_index,
-                                      const IType *index_bounds,
-                                      const DType *grad_out,
-                                      const index_t embbedding_dim,
-                                      const index_t vocab_dim,
-                                      const int req) {
+                                    const IType *original_index,
+                                    const IType *index_bounds,
+                                    const DType *grad_out,
+                                    const index_t embbedding_dim,
+                                    const index_t vocab_dim,
+                                    const int nelems_per_load,
+                                    const int req) {
   extern __shared__ int sharedmem[];
-  LType* grad_in_row =  reinterpret_cast<LType *>(sharedmem);
-
-  // LType has to be bigger than DType, guarded in the launcher code
-  const int n_val = sizeof(DType) < sizeof(LType) ? sizeof(LType) / 
sizeof(DType) : 1;
+  AType* grad_in_row =  reinterpret_cast<AType *>(sharedmem);
   const LType *aligned_grad_out = reinterpret_cast<const LType *>(grad_out);
   LType *aligned_grad_in = reinterpret_cast<LType *>(grad_in);
-  const index_t aligned_emb_dim = embbedding_dim / n_val;
-  DType *my_grad_in_row = reinterpret_cast<DType *>(&grad_in_row[threadIdx.x]);
-  LType Lvalue[1];
-  DType* Dvalues = reinterpret_cast<DType*>(Lvalue);
+  const index_t aligned_emb_dim = embbedding_dim / nelems_per_load;
+  LType load_value[1];
+  DType* data_values = reinterpret_cast<DType*>(load_value);
 
   IType my_row = blockIdx.x;
   if (my_row < vocab_dim) {
@@ -716,29 +714,37 @@ __global__ void EmbeddingGradKernel(DType *grad_in,
     for (index_t emb_id=threadIdx.x; emb_id < aligned_emb_dim; emb_id += 
blockDim.x) {
       // Initialize grad_in
       if (req == kAddTo) {
-        grad_in_row[threadIdx.x] = aligned_grad_in[my_row * aligned_emb_dim + 
emb_id];
+        *load_value = aligned_grad_in[my_row * aligned_emb_dim + emb_id];
+        for (index_t val_id = 0; val_id < nelems_per_load; val_id++) {
+          grad_in_row[val_id * blockDim.x + threadIdx.x] = 
static_cast<AType>(data_values[val_id]);
+        }
       } else {
-        grad_in_row[threadIdx.x] = 0.0;
+        for (index_t val_id = 0; val_id < nelems_per_load; val_id++) {
+          grad_in_row[val_id * blockDim.x + threadIdx.x] = 
static_cast<AType>(0.0);
+        }
       }
       // Add all rows from grad_out according to indices in data
       for (index_t data_idx=lower_bound; data_idx < (lower_bound + 
nOccurrences); ++data_idx) {
-        *Lvalue = aligned_grad_out[original_index[data_idx] * aligned_emb_dim 
+ emb_id];
-        for (index_t val_id = 0; val_id < n_val; val_id++) {
-          my_grad_in_row[val_id] += Dvalues[val_id];
+        *load_value = aligned_grad_out[original_index[data_idx] * 
aligned_emb_dim + emb_id];
+        for (index_t val_id = 0; val_id < nelems_per_load; val_id++) {
+          grad_in_row[val_id * blockDim.x + threadIdx.x] += 
static_cast<AType>(data_values[val_id]);
         }
       }
       // Save results
-      aligned_grad_in[my_row * aligned_emb_dim + emb_id] = 
grad_in_row[threadIdx.x];
+      for (index_t val_id = 0; val_id < nelems_per_load; val_id++) {
+        data_values[val_id] = static_cast<DType>(grad_in_row[val_id * 
blockDim.x + threadIdx.x]);
+      }
+      aligned_grad_in[my_row * aligned_emb_dim + emb_id] = *load_value;
     }
   }
 }
 
-template<typename gpu, typename IType, typename DType>
+template<typename AType, typename IType, typename DType>
 void EmbeddingGradKernelCaller(const OpContext& ctx,
-                                mshadow::Tensor<gpu, 2, DType> grad_in,
-                                const mshadow::Tensor<gpu, 1, IType>& index,
-                                const mshadow::Tensor<gpu, 2, DType> &grad_out,
-                                const std::vector<OpReqType>& req) {
+                               mshadow::Tensor<gpu, 2, DType> grad_in,
+                               const mshadow::Tensor<gpu, 1, IType>& index,
+                               const mshadow::Tensor<gpu, 2, DType> &grad_out,
+                               const std::vector<OpReqType>& req) {
   using namespace mxnet_op;
   using namespace mshadow::expr;
 
@@ -792,20 +798,23 @@ void EmbeddingGradKernelCaller(const OpContext& ctx,
 
   // Compute Gradient
   int ltype = mxnet::common::cuda::get_load_type(embbedding_dim * 
sizeof(DType));
+
   MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
-    int nelems_per_thread = sizeof(LType) / sizeof(DType);
+    CHECK_LE(sizeof(DType), sizeof(LType));
+    int nelems_per_load = sizeof(LType) / sizeof(DType);
     int threads_block_grad = 32;
     int maxThreads = 1024;
-    while (threads_block_grad < (embbedding_dim/nelems_per_thread) &&
+    while (threads_block_grad < (embbedding_dim/nelems_per_load) &&
           (threads_block_grad < maxThreads))
       threads_block_grad += 32;
-    size_t required_shared = threads_block_grad * sizeof(LType);
+    size_t required_shared = threads_block_grad * nelems_per_load * 
sizeof(AType);
     dim3 blocks(vocab_dim, 1);
-    EmbeddingGradKernel<LType><<<blocks, threads_block_grad, required_shared,
+    EmbeddingGradKernel<AType, LType><<<blocks, threads_block_grad, 
required_shared,
                   Stream<gpu>::GetStream(s)>>>(
                   grad_in.dptr_, original_index.dptr_,
                   bounds_index.dptr_, grad_out.dptr_,
                   embbedding_dim, vocab_dim,
+                  nelems_per_load,
                   req[embedding::kWeight]);
   });
 }
@@ -831,9 +840,17 @@ void EmbeddingOpBackward<gpu>(const nnvm::NodeAttrs& attrs,
   const mxnet::TShape& oshape = inputs[0].shape_;
 
   Stream<gpu> *s = ctx.get_stream<gpu>();
+
   CHECK_NE(req[embedding::kWeight], kWriteInplace)
     << "Backward of Embedding does not support writing in place.";
-  MSHADOW_TYPE_SWITCH(outputs[1].type_flag_, DType, {
+  bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
+  if (!safe_acc && outputs[1].type_flag_ == mshadow::kFloat16) {
+    common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for 
EmbeddingOpBackward "
+                    "with float16 inputs. "
+                    "See https://mxnet.apache.org/api/faq/env_var "
+                    "for more details.");
+  }
+  MXNET_REAL_ACC_TYPE_SWITCH(outputs[1].type_flag_, DType, AType, {
     MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {
       Tensor < gpu, 1, IType > data = inputs[1].get_with_shape<gpu, 1, IType>(
         Shape1(ishape.ProdShape(0, ishape.ndim())), s);
@@ -842,7 +859,10 @@ void EmbeddingOpBackward<gpu>(const nnvm::NodeAttrs& attrs,
       Tensor<gpu, 2, DType> grad_in = outputs[1].get<gpu, 2, DType>(s);
 
       if (req[embedding::kWeight] == kWriteTo || req[embedding::kWeight] == 
kAddTo) {
-        EmbeddingGradKernelCaller(ctx, grad_in, data, grad_out, req);
+        if (safe_acc)
+            EmbeddingGradKernelCaller<AType>(ctx, grad_in, data, grad_out, 
req);
+        else
+            EmbeddingGradKernelCaller<DType>(ctx, grad_in, data, grad_out, 
req);
       } else {
         LOG(FATAL) << "wrong req";
       }
diff --git a/src/operator/tensor/indexing_op.h 
b/src/operator/tensor/indexing_op.h
index 5454900..7f0f0fa 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -400,20 +400,38 @@ void EmbeddingOpBackward(const nnvm::NodeAttrs& attrs,
   const mxnet::TShape& oshape = inputs[0].shape_;
 
   Stream<xpu> *s = ctx.get_stream<xpu>();
-  MSHADOW_TYPE_SWITCH(outputs[1].type_flag_, DType, {
+
+  bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
+  if (!safe_acc && outputs[1].type_flag_ == mshadow::kFloat16) {
+    common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for 
EmbeddingOpBackward "
+                    "with float16 inputs. "
+                    "See https://mxnet.apache.org/api/faq/env_var "
+                    "for more details.");
+  }
+  MXNET_REAL_ACC_TYPE_SWITCH(outputs[1].type_flag_, DType, AType, {
     MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {
       Tensor < xpu, 1, IType > data = inputs[1].get_with_shape<xpu, 1, IType>(
         Shape1(ishape.ProdShape(0, ishape.ndim())), s);
       Tensor<xpu, 2, DType> grad_out = inputs[0].get_with_shape<xpu, 2, DType>(
-      Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), 
s);
+        Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), 
s);
       Tensor<xpu, 2, DType> grad_in = outputs[1].get<xpu, 2, DType>(s);
 
-
       if (req[embedding::kWeight] == kWriteTo || req[embedding::kWeight] == 
kAddTo) {
         if (req[embedding::kWeight] == kWriteTo) {
           grad_in = scalar<DType>(0.0f);
         }
-        AddTakeGrad(grad_in, data, grad_out);
+        if (safe_acc) {
+          // Temporary storage for safe accumulation
+          size_t temp_space_size = grad_in.size(0) * grad_in.size(1) * 
sizeof(AType);
+          Tensor<xpu, 1, char> temp_space =
+            ctx.requested[embedding::kTempSpace].get_space_typed<xpu, 1, char>(
+              Shape1(temp_space_size), s);
+          Tensor<xpu, 2, AType> 
temp_grad_in(reinterpret_cast<AType*>(temp_space.dptr_),
+              grad_in.shape_, s);
+          AddTakeGrad(grad_in, temp_grad_in, data, grad_out);
+        } else {
+          AddTakeGrad(grad_in, data, grad_out);
+        }
       } else {
         LOG(FATAL) << "wrong req";
       }
@@ -696,7 +714,48 @@ struct TakeGradGeneralKernel {
   }
 };
 
-template<bool clip = true>
+struct TakeGradGeneralKernelSafeAccumulation {
+  /*!
+   * \brief Map function for general case of take grad, summing in AType before
+   *        converting the result back to DType.
+   * \param tid           global thread id; this thread writes only arr_grad[tid]
+   * \param arr_grad      ptr to in_grad
+   * \param temp          ptr to temporary space sized like arr_grad; no longer used
+   *                      (the running sum is kept in a local variable) but kept so
+   *                      existing launch sites still match this signature
+   * \param ograd         ptr to out_grad
+   * \param src_indptr    ptr to indptr to src indices
+   * \param original_idx  ptr to original indices of the inputs
+   * \param in_strides    strides of inputs
+   * \param out_strides   strides of outputs
+   * \param in_ndims      # of dims of input tensor
+   * \param out_ndims     # of dims of output tensor
+   * \param idx_ndims     # of dims of indices tensor
+   * \param axis          axis id
+   * \param K             dim size of the axis dimension; added to negative indices
+   */
+  template<typename DType, typename IType, typename AType>
+  MSHADOW_XINLINE static void Map(int tid, DType* arr_grad, AType* temp,
+                                  const DType* ograd,
+                                  const IType* src_indptr, const IType* original_idx,
+                                  mshadow::Shape<10> in_strides, mshadow::Shape<10> out_strides,
+                                  const int in_ndims, const int out_ndims, const int idx_ndims,
+                                  const int axis, const int K) {
+    (void)temp;
+    const int in_head_index = (axis == 0) ? 0 : tid / in_strides[axis - 1];
+    const int in_rest_index = (axis == 0) ? tid : tid % in_strides[axis - 1];
+    const int in_mid_index = in_rest_index / in_strides[axis];
+    const int in_tail_index = (axis == in_ndims - 1) ?
+                              0 : (in_rest_index % in_strides[axis]);
+    // Only this thread touches arr_grad[tid], so the wide running sum can stay in
+    // a local variable instead of a global-memory load + store per contribution.
+    AType sum = static_cast<AType>(arr_grad[tid]);
+    for (IType i = src_indptr[in_mid_index]; i < src_indptr[in_mid_index + 1]; ++i) {
+      int out_mid_index = original_idx[i];
+      out_mid_index = (out_mid_index < 0) ? out_mid_index + K : out_mid_index;
+      int target = in_tail_index + out_mid_index * in_strides[axis];
+      target += (axis == 0) ? 0 : in_head_index * out_strides[axis - 1];
+      sum += ograd[target];
+    }
+    arr_grad[tid] = sum;
+  }
+};
+
+template<bool clip = true, bool safe_acc = false, typename AType>
 void TakeOpBackwardImpl(mshadow::Stream<cpu>* s,
                         const OpContext& ctx,
                         const TBlob& arr,
@@ -715,14 +774,23 @@ void TakeOpBackwardImpl(mshadow::Stream<cpu>* s,
     size_t temp_storage_bytes = SortByKeyWorkspaceSize<int, int, 
cpu>(idxshape.Size());
     size_t original_idx_bytes = idxshape.Size() * sizeof(int);
     size_t src_indptr_bytes = (arrshape[axis] + 1) * sizeof(int);
-    size_t workspace_bytes = src_indptr_bytes + 2 * original_idx_bytes + 
temp_storage_bytes;
+    size_t temp_accumulation_arrgrad_bytes = 0;
+    if (safe_acc) {
+        temp_accumulation_arrgrad_bytes = arr.Size() * sizeof(AType);
+    }
+    size_t workspace_bytes = src_indptr_bytes + 2 * original_idx_bytes +
+        temp_storage_bytes + temp_accumulation_arrgrad_bytes;
     Tensor<cpu, 1, char> workspace =
       ctx.requested[0].get_space_typed<cpu, 1, char>(Shape1(workspace_bytes), 
s);
-    int* sorted_idx_ptr = reinterpret_cast<int*>(workspace.dptr_);
-    int* original_idx_ptr = reinterpret_cast<int*>(workspace.dptr_ + 
original_idx_bytes);
-    src_indptr_ptr = reinterpret_cast<int*>(workspace.dptr_ + 2 * 
original_idx_bytes);
+    AType* temp_accum_arrgrad_ptr = reinterpret_cast<AType*>(workspace.dptr_);
+    int* sorted_idx_ptr = reinterpret_cast<int*>(workspace.dptr_ + 
temp_accumulation_arrgrad_bytes);
+    int* original_idx_ptr = reinterpret_cast<int*>(workspace.dptr_ + 
original_idx_bytes +
+                                                   
temp_accumulation_arrgrad_bytes);
+    src_indptr_ptr = reinterpret_cast<int*>(workspace.dptr_ + 2 * 
original_idx_bytes +
+                                            temp_accumulation_arrgrad_bytes);
     Tensor<cpu, 1, char> temp_storage(
-      workspace.dptr_ + 2 * original_idx_bytes + src_indptr_bytes, 
Shape1(temp_storage_bytes), s);
+      workspace.dptr_ + 2 * original_idx_bytes + src_indptr_bytes + 
temp_accumulation_arrgrad_bytes,
+      Shape1(temp_storage_bytes), s);
     // Reset indptr to zero
     Kernel<set_zero, cpu>::Launch(s, arrshape[axis] + 1, src_indptr_ptr);
     // Fill original_idx
@@ -759,16 +827,23 @@ void TakeOpBackwardImpl(mshadow::Stream<cpu>* s,
       out_strides[i] = stride;
     }
     MSHADOW_TYPE_SWITCH(arr.type_flag_, DType, {
-      Kernel<TakeGradGeneralKernel, cpu>::Launch(
-        s, arrshape.Size(), arr.dptr<DType>(), ograd.dptr<DType>(), 
src_indptr_ptr,
-        original_idx_ptr, in_strides, out_strides,
-        arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis, 
static_cast<int>(arrshape[axis]));
+      if (safe_acc) {
+        Kernel<TakeGradGeneralKernelSafeAccumulation, cpu>::Launch(
+          s, arrshape.Size(), arr.dptr<DType>(), temp_accum_arrgrad_ptr, 
ograd.dptr<DType>(),
+          src_indptr_ptr, original_idx_ptr, in_strides, out_strides,
+          arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis, 
static_cast<int>(arrshape[axis]));
+      } else {
+        Kernel<TakeGradGeneralKernel, cpu>::Launch(
+          s, arrshape.Size(), arr.dptr<DType>(), ograd.dptr<DType>(),
+          src_indptr_ptr, original_idx_ptr, in_strides, out_strides,
+          arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis, 
static_cast<int>(arrshape[axis]));
+      }
     });
   });
 }
 
 #ifdef __CUDACC__
-template<bool clip = true>
+template<bool clip = true, bool safe_acc = false, typename AType>
 void TakeOpBackwardImpl(mshadow::Stream<gpu>* s,
                         const OpContext& ctx,
                         const TBlob& arr,
@@ -808,13 +883,23 @@ void TakeOpBackwardImpl(mshadow::Stream<gpu>* s,
     temp_storage_bytes = max(temp_storage_bytes, histo_temp_storage_bytes);
     size_t original_idx_bytes = idxshape.Size() * sizeof(int);
     size_t src_indptr_bytes = (arrshape[axis] + 1) * sizeof(int);
-    size_t workspace_bytes = src_indptr_bytes + 2 * original_idx_bytes + 
temp_storage_bytes;
+    size_t temp_accumulation_igrad_bytes = 0;
+    if (safe_acc) {
+        temp_accumulation_igrad_bytes = arr.Size() * sizeof(AType);
+    }
+    size_t workspace_bytes = src_indptr_bytes + 2 * original_idx_bytes +
+        temp_storage_bytes + temp_accumulation_igrad_bytes;
     Tensor<gpu, 1, char> workspace =
       ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(workspace_bytes), 
s);
-    sorted_idx_ptr = reinterpret_cast<int*>(workspace.dptr_);
-    int* original_idx_ptr = reinterpret_cast<int*>(workspace.dptr_ + 
original_idx_bytes);
-    src_indptr_ptr = reinterpret_cast<int*>(workspace.dptr_ + 2 * 
original_idx_bytes);
-    temp_storage_ptr = workspace.dptr_ + 2 * original_idx_bytes + 
src_indptr_bytes;
+    AType* temp_accum_igrad_ptr = reinterpret_cast<AType*>(workspace.dptr_);
+    sorted_idx_ptr = reinterpret_cast<int*>(workspace.dptr_ + 
temp_accumulation_igrad_bytes);
+    int* original_idx_ptr = reinterpret_cast<int*>(workspace.dptr_ + 
original_idx_bytes +
+                                                   
temp_accumulation_igrad_bytes);
+    src_indptr_ptr = reinterpret_cast<int*>(workspace.dptr_ + 2 * 
original_idx_bytes +
+                                            temp_accumulation_igrad_bytes);
+    temp_storage_ptr = workspace.dptr_ + 2 * original_idx_bytes + 
src_indptr_bytes +
+                       temp_accumulation_igrad_bytes;
+
     // Reset indptr to zero
     Kernel<set_zero, gpu>::Launch(s, arrshape[axis] + 1, src_indptr_ptr);
     // Fill original_idx
@@ -863,10 +948,19 @@ void TakeOpBackwardImpl(mshadow::Stream<gpu>* s,
       out_strides[i] = stride;
     }
     MSHADOW_TYPE_SWITCH(arr.type_flag_, DType, {
-      Kernel<TakeGradGeneralKernel, gpu>::Launch(
-        s, arrshape.Size(), arr.dptr<DType>(), ograd.dptr<DType>(),
-        src_indptr_ptr, original_idx_ptr, in_strides, out_strides,
-        arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis, 
static_cast<int>(arrshape[axis]));
+      if (safe_acc) {
+          Kernel<TakeGradGeneralKernelSafeAccumulation, gpu>::Launch(
+            s, arrshape.Size(), arr.dptr<DType>(), temp_accum_igrad_ptr, 
ograd.dptr<DType>(),
+            src_indptr_ptr, original_idx_ptr, in_strides, out_strides,
+            arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis,
+            static_cast<int>(arrshape[axis]));
+      } else {
+          Kernel<TakeGradGeneralKernel, gpu>::Launch(
+            s, arrshape.Size(), arr.dptr<DType>(), ograd.dptr<DType>(),
+            src_indptr_ptr, original_idx_ptr, in_strides, out_strides,
+            arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis,
+            static_cast<int>(arrshape[axis]));
+      }
     });
   });
 }
@@ -891,7 +985,14 @@ void TakeOpBackward(const nnvm::NodeAttrs& attrs,
   // grad_in is the gradient of the inputs in the feed-forward
   Stream<xpu> *s = ctx.get_stream<xpu>();
 
-  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {  // output data type
+  bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
+  if (!safe_acc && outputs[0].type_flag_ == mshadow::kFloat16) {
+    common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for 
TakeOpBackward "
+                    "with float16 inputs. "
+                    "See https://mxnet.apache.org/api/faq/env_var "
+                    "for more details.");
+  }
+  MXNET_REAL_ACC_TYPE_SWITCH(outputs[0].type_flag_, DType, AType, {
     MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {  // index data type
       // inputs are specified in the .cc file, which are the gradients from
       // the upper layer and the input index
@@ -925,10 +1026,25 @@ void TakeOpBackward(const nnvm::NodeAttrs& attrs,
           if (req[take_::kArr] == kWriteTo) {
             grad_in = scalar<DType>(0.0f);
           }
-          if (param.mode == take_::kClip) {
-            AddTakeGrad(grad_in, idx, grad_out);
+          if (safe_acc) {
+            // Temporary storage for safe accumulation
+            size_t temp_space_size = grad_in.size(0) * grad_in.size(1) * 
sizeof(AType);
+            Tensor<xpu, 1, char> temp_space =
+              ctx.requested[take_::kTempSpace].get_space_typed<xpu, 1, char>(
+                Shape1(temp_space_size), s);
+            Tensor<xpu, 2, AType> 
temp_grad_in(reinterpret_cast<AType*>(temp_space.dptr_),
+                grad_in.shape_, s);
+            if (param.mode == take_::kClip) {
+              AddTakeGrad(grad_in, temp_grad_in, idx, grad_out);
+            } else {
+              AddTakeGrad<false>(grad_in, temp_grad_in, idx, grad_out);
+            }
           } else {
-            AddTakeGrad<false>(grad_in, idx, grad_out);
+            if (param.mode == take_::kClip) {
+              AddTakeGrad(grad_in, idx, grad_out);
+            } else {
+              AddTakeGrad<false>(grad_in, idx, grad_out);
+            }
           }
         } else {
           LOG(FATAL) << "wrong req";
@@ -939,10 +1055,18 @@ void TakeOpBackward(const nnvm::NodeAttrs& attrs,
         const TBlob& arr = outputs[0];
         const TBlob& ograd = inputs[0];
 
-        if (param.mode == take_::kClip) {
-          TakeOpBackwardImpl<true>(s, ctx, arr, idx, ograd, actual_axis);
+        if (safe_acc) {
+          if (param.mode == take_::kClip) {
+            TakeOpBackwardImpl<true, true, AType>(s, ctx, arr, idx, ograd, 
actual_axis);
+          } else {
+            TakeOpBackwardImpl<false, true, AType>(s, ctx, arr, idx, ograd, 
actual_axis);
+          }
         } else {
-          TakeOpBackwardImpl<false>(s, ctx, arr, idx, ograd, actual_axis);
+          if (param.mode == take_::kClip) {
+            TakeOpBackwardImpl<true, false, AType>(s, ctx, arr, idx, ograd, 
actual_axis);
+          } else {
+            TakeOpBackwardImpl<false, false, AType>(s, ctx, arr, idx, ograd, 
actual_axis);
+          }
         }
       }
     });
diff --git a/tests/python/gpu/test_operator_gpu.py 
b/tests/python/gpu/test_operator_gpu.py
index 84cfe9c..519c02f 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -1522,19 +1522,23 @@ def test_lrn():
                     reason="Testing with naive engine consistently triggers 
illegal memory access. Tracked in #17713")
 def test_embedding_with_type():
     def test_embedding_helper(data_types, weight_types, low_pad, high_pad):
-        NVD = [[20, 10, 20], [200, 10, 300]]
-        for N, V, D in NVD:
-            sym = mx.sym.Embedding(name='embedding', input_dim=V, output_dim=D)
-            ctx_list = []
-            for data_type in data_types:
-                for weight_type in weight_types:
-                    ctx_list.append({'ctx': mx.gpu(0), 'embedding_data': (N,),
-                        'type_dict': {'embedding_data': data_type, 
'embedding_weight': weight_type}})
-                    ctx_list.append({'ctx': mx.cpu(0), 'embedding_data': (N,),
-                        'type_dict': {'embedding_data': data_type, 
'embedding_weight': weight_type}})
-            arg_params = {'embedding_data': np.random.randint(low=-low_pad, 
high=V+high_pad, size=(N,))}
-            check_consistency(sym, ctx_list, grad_req={'embedding_data': 
'null','embedding_weight': 'write'},
-                              arg_params=arg_params, scale=0.1)
+        NVD = [[20, 10, 20], [200, 10, 300], [10000, 4, 20]]
+        for safe_accumulation in ['0', '1', None]:
+            for N, V, D in NVD:
+                with environment('MXNET_SAFE_ACCUMULATION', safe_accumulation):
+                    if N > 1000 and safe_accumulation != '1':
+                        break
+                    sym = mx.sym.Embedding(name='embedding', input_dim=V, 
output_dim=D)
+                    ctx_list = []
+                    for data_type in data_types:
+                        for weight_type in weight_types:
+                            ctx_list.append({'ctx': mx.gpu(0), 
'embedding_data': (N,),
+                                'type_dict': {'embedding_data': data_type, 
'embedding_weight': weight_type}})
+                            ctx_list.append({'ctx': mx.cpu(0), 
'embedding_data': (N,),
+                                'type_dict': {'embedding_data': data_type, 
'embedding_weight': weight_type}})
+                    arg_params = {'embedding_data': 
np.random.randint(low=-low_pad, high=V+high_pad, size=(N,))}
+                    check_consistency(sym, ctx_list, 
grad_req={'embedding_data': 'null','embedding_weight': 'write'},
+                                      arg_params=arg_params, scale=0.1)
 
     data_types = [np.float16, np.float32, np.float64, np.int32]
     weight_types = [np.float16, np.float32, np.float64]
@@ -1547,47 +1551,91 @@ def test_embedding_with_type():
 @with_seed()
 def test_take_with_type():
     sym = mx.sym.take(name='take')
-    for data_ndim in range(2, 5):
-        for idx_ndim in range(1, 4):
-            data_shape = ()
-            for _ in range(data_ndim):
-                data_shape += (np.random.randint(low=3, high=6), )
-            idx_shape = ()
-            for _ in range(idx_ndim):
-                idx_shape += (np.random.randint(low=3, high=5), )
-            ctx_list = [{'ctx': mx.gpu(0), 'take_indices': idx_shape,
-                         'take_a': data_shape,
-                         'type_dict': {'take_indices': np.float64,
-                                       'take_a': np.float64}},
-                        {'ctx': mx.gpu(0), 'take_indices': idx_shape,
-                         'take_a': data_shape,
-                         'type_dict': {'take_indices': np.float32,
-                                       'take_a': np.float32}},
-                        {'ctx': mx.gpu(0), 'take_indices': idx_shape,
-                         'take_a': data_shape,
-                         'type_dict': {'take_indices': np.float16,
-                                       'take_a': np.float16}},
-                        {'ctx': mx.cpu(0), 'take_indices': idx_shape,
-                         'take_a': data_shape,
-                         'type_dict': {'take_indices': np.float64,
-                                       'take_a': np.float64}},
-                        {'ctx': mx.cpu(0), 'take_indices': idx_shape,
-                         'take_a': data_shape,
-                         'type_dict': {'take_indices': np.float32,
-                                       'take_a': np.float32}},
-                        {'ctx': mx.cpu(0), 'take_indices': idx_shape,
-                         'take_a': data_shape,
-                         'type_dict': {'take_indices': np.float16,
-                                       'take_a': np.float16}}]
-            arg_params = {'take_indices': np.random.randint(low=0,
-                                                            high=data_shape[0],
-                                                            size=idx_shape),
-                          'take_a': np.random.normal(size=data_shape)}
-            check_consistency(sym, ctx_list,
-                              grad_req={'take_indices': 'null',
-                                        'take_a': 'write'},
-                              arg_params=arg_params)
-
+    for safe_accumulation in ['0', '1', None]:
+        for data_ndim in range(2, 5):
+            for idx_ndim in range(1, 4):
+                data_shape = ()
+                for _ in range(data_ndim):
+                    data_shape += (np.random.randint(low=3, high=6), )
+                idx_shape = ()
+                for _ in range(idx_ndim):
+                    idx_shape += (np.random.randint(low=3, high=5), )
+                ctx_list = [{'ctx': mx.gpu(0), 'take_indices': idx_shape,
+                             'take_a': data_shape,
+                             'type_dict': {'take_indices': np.float64,
+                                           'take_a': np.float64}},
+                            {'ctx': mx.gpu(0), 'take_indices': idx_shape,
+                             'take_a': data_shape,
+                             'type_dict': {'take_indices': np.float32,
+                                           'take_a': np.float32}},
+                            {'ctx': mx.gpu(0), 'take_indices': idx_shape,
+                             'take_a': data_shape,
+                             'type_dict': {'take_indices': np.float16,
+                                           'take_a': np.float16}},
+                            {'ctx': mx.cpu(0), 'take_indices': idx_shape,
+                             'take_a': data_shape,
+                             'type_dict': {'take_indices': np.float64,
+                                           'take_a': np.float64}},
+                            {'ctx': mx.cpu(0), 'take_indices': idx_shape,
+                             'take_a': data_shape,
+                             'type_dict': {'take_indices': np.float32,
+                                           'take_a': np.float32}},
+                            {'ctx': mx.cpu(0), 'take_indices': idx_shape,
+                             'take_a': data_shape,
+                             'type_dict': {'take_indices': np.float16,
+                                           'take_a': np.float16}}]
+                arg_params = {'take_indices': np.random.randint(low=0,
+                                                                
high=data_shape[0],
+                                                                
size=idx_shape),
+                              'take_a': np.random.normal(size=data_shape)}
+                with environment('MXNET_SAFE_ACCUMULATION', safe_accumulation):
+                    check_consistency(sym, ctx_list,
+                                      grad_req={'take_indices': 'null',
+                                                'take_a': 'write'},
+                                      arg_params=arg_params)
+
+    # check a large num of indices: may underflow calculating gradient in FP16,
+    # if MXNET_SAFE_ACCUMULATION is not activated
+    with environment('MXNET_SAFE_ACCUMULATION', '1'):
+        data_size = 4
+        indices_size = 10000
+        out_dim = 20
+        data_types = [np.float16, np.float32, np.float64]
+        indices_types = [np.float16, np.float32, np.float64, np.int32]
+        # axis 0
+        sym = mx.sym.take(name='take', axis=0)
+        ctx_list = []
+        for data_type in data_types:
+            for index_type in indices_types:
+                ctx_list.append({'ctx': mx.cpu(0), 'take_indices': 
(indices_size,),
+                    'take_a': (data_size, out_dim),
+                    'type_dict': {'take_indices': index_type, 'take_a': 
data_type}})
+                ctx_list.append({'ctx': mx.gpu(0), 'take_indices': 
(indices_size,),
+                    'take_a': (data_size, out_dim),
+                    'type_dict': {'take_indices': index_type, 'take_a': 
data_type}})
+                arg_params = {'take_indices': np.random.randint(0, data_size,
+                              size=(indices_size,)),
+                              'take_a': np.random.normal(size=(data_size, 
out_dim))}
+                check_consistency(sym, ctx_list,
+                                  grad_req={'take_indices': 'null','take_a': 
'write'},
+                                  arg_params=arg_params)
+        # axis 1
+        sym = mx.sym.take(name='take', axis=1)
+        ctx_list = []
+        for data_type in data_types:
+            for index_type in indices_types:
+                ctx_list.append({'ctx': mx.cpu(0), 'take_indices': 
(indices_size,),
+                    'take_a': (data_size, out_dim),
+                    'type_dict': {'take_indices': index_type, 'take_a': 
data_type}})
+                ctx_list.append({'ctx': mx.gpu(0), 'take_indices': 
(indices_size,),
+                    'take_a': (data_size, out_dim),
+                    'type_dict': {'take_indices': index_type, 'take_a': 
data_type}})
+                arg_params = {'take_indices': np.random.randint(0, data_size,
+                              size=(indices_size,)),
+                              'take_a': np.random.normal(size=(data_size, 
out_dim))}
+                check_consistency(sym, ctx_list,
+                                  grad_req={'take_indices': 'null','take_a': 
'write'},
+                                  arg_params=arg_params)
 
 @with_seed()
 @pytest.mark.serial

Reply via email to