This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 344587f Safe accumulation for computing gradient in Embedding & Take (#18385)
344587f is described below
commit 344587f295666e4375042d054cd5a134fdeaf517
Author: MoisesHer <[email protected]>
AuthorDate: Thu Aug 13 22:18:26 2020 -0700
Safe accumulation for computing gradient in Embedding & Take (#18385)
* Safe accumulation for computing gradient in Embedding & Take
* Fix bug in TakeGrad: initialize temporal storage for safe_accumulation
* fix lint
* make MXNET_SAFE_ACCUMULATION compatible with Windows
* Increase test coverage: small inputs & SAFE_ACCUMULATION
---
3rdparty/mshadow/mshadow/cuda/tensor_gpu-inl.cuh | 77 ++++++++++
3rdparty/mshadow/mshadow/tensor.h | 26 ++++
3rdparty/mshadow/mshadow/tensor_cpu-inl.h | 33 ++++
3rdparty/mshadow/mshadow/tensor_gpu-inl.h | 8 +
src/operator/tensor/indexing_op.cu | 84 ++++++----
src/operator/tensor/indexing_op.h | 186 +++++++++++++++++++----
tests/python/gpu/test_operator_gpu.py | 156 ++++++++++++-------
7 files changed, 453 insertions(+), 117 deletions(-)
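For context: the new behavior is opt-in through the MXNET_SAFE_ACCUMULATION environment variable, which the backward passes below read at runtime. A minimal usage sketch, not part of the commit (Python, assuming an MXNet 1.x build that includes this change):

    import os
    os.environ['MXNET_SAFE_ACCUMULATION'] = '1'   # read when the backward pass runs

    import mxnet as mx
    import numpy as np

    # Many repeated indices into a tiny vocabulary: each row of the float16
    # weight gradient is a long sum, exactly the case float32 accumulation protects.
    data = mx.nd.array(np.random.randint(0, 4, size=(10000,)), dtype=np.float16)
    weight = mx.nd.random.normal(shape=(4, 20)).astype(np.float16)
    weight.attach_grad()
    with mx.autograd.record():
        out = mx.nd.Embedding(data, weight, input_dim=4, output_dim=20, dtype='float16')
    out.backward()
    print(weight.grad)   # accumulated in float32 internally, stored as float16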
diff --git a/3rdparty/mshadow/mshadow/cuda/tensor_gpu-inl.cuh b/3rdparty/mshadow/mshadow/cuda/tensor_gpu-inl.cuh
index 02a74b2..a00aade 100644
--- a/3rdparty/mshadow/mshadow/cuda/tensor_gpu-inl.cuh
+++ b/3rdparty/mshadow/mshadow/cuda/tensor_gpu-inl.cuh
@@ -641,6 +641,43 @@ __global__ void AddTakeGradKernel(DstPlan dst,
}
}
+template<bool clip, int x_bits, typename DstPlan, typename ATypePlan,
+ typename SrcPlan1, typename SrcPlan2>
+__global__ void AddTakeGradKernel(DstPlan dst,
+ ATypePlan temp,
+ SrcPlan1 index, SrcPlan2 src,
+ index_t ymax, index_t xmax, const int K) {
+ const unsigned x_size = 1 << x_bits;
+ const int xindex = blockIdx.x * x_size + threadIdx.x;
+ __shared__ int ptr;
+ if (xindex < xmax) {
+ for (unsigned y = 0; y < K; ++y) {
+ temp.REval(y, xindex) = dst.Eval(y, xindex);
+ }
+ }
+ for (unsigned y = 0; y < ymax; ++y) {
+ if (threadIdx.x == 0) {
+ ptr = index.Eval(0, y);
+ if (clip) {
+ if (ptr <= 0) ptr = 0;
+ else if (ptr >= K) ptr = K - 1;
+ } else {
+ ptr %= K;
+ if (ptr < 0) ptr += K;
+ }
+ }
+ __syncthreads();
+ if (xindex < xmax) {
+ temp.REval(ptr, xindex) += src.Eval(y, xindex);
+ }
+ }
+ if (xindex < xmax) {
+ for (unsigned y = 0; y < K; ++y) {
+ dst.REval(y, xindex) = temp.Eval(y, xindex);
+ }
+ }
+}
+
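The new kernel handles out-of-range indices exactly like the existing non-safe kernel; the clip/wrap branch restated in Python for clarity (illustrative only):

    def resolve_row(ptr, K, clip):
        # Map a possibly out-of-range index onto a valid row in [0, K).
        if clip:
            return min(max(ptr, 0), K - 1)   # mode='clip': saturate at the ends
        return ptr % K                       # mode='wrap': Python's % already lands in [0, K)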
template<int warp_bits, int SZ, typename DType, typename IdxType>
__global__ void AddTakeGradLargeBatchKernel(DType* dst,
const IdxType *sorted, const IdxType *index,
@@ -733,6 +770,46 @@ inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
MSHADOW_CUDA_POST_KERNEL_CHECK(AddTakeGradKernel);
}
+template<bool clip = true, typename IndexType, typename DType, typename AType>
+inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
+ Tensor<gpu, 2, AType> temp,
+ const Tensor<gpu, 1, IndexType>& index,
+ const Tensor<gpu, 2, DType> &src) {
+ CHECK_EQ(dst.CheckContiguous(), true);
+ CHECK_EQ(index.CheckContiguous(), true);
+ CHECK_EQ(src.CheckContiguous(), true);
+ const int kUnitBits = kMemUnitBits + 1;
+ dim3 dimBlock(1 << kUnitBits);
+ dim3 dimGrid((dst.size(1) + (1 << kUnitBits) - 1) >> kUnitBits);
+
+ CHECK_EQ(dst.size(1), src.size(1)) << "AddTakeGrad: shape mismatch";
+ CHECK_EQ(index.size(0), src.size(0)) << "AddTakeGrad: shape mismatch";
+ CheckLaunchParam(dimGrid, dimBlock, "AddTakeGrad");
+ cudaStream_t stream = Stream<gpu>::GetStream(dst.stream_);
+ const int K = dst.shape_[0];
+
+ if (clip) {
+ AddTakeGradKernel<true, kUnitBits>
+ <<<dimGrid, dimBlock, 0, stream>>>
+ (expr::MakePlan(dst),
+ expr::MakePlan(temp),
+ expr::MakePlan(index),
+ expr::MakePlan(src),
+ src.size(0),
+ src.size(1), K);
+ } else {
+ AddTakeGradKernel<false, kUnitBits>
+ <<<dimGrid, dimBlock, 0, stream>>>
+ (expr::MakePlan(dst),
+ expr::MakePlan(temp),
+ expr::MakePlan(index),
+ expr::MakePlan(src),
+ src.size(0),
+ src.size(1), K);
+ }
+ MSHADOW_CUDA_POST_KERNEL_CHECK(AddTakeGradKernel);
+}
+
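The launch geometry above is one thread per vectorized column, ceil-divided into blocks; the same arithmetic in Python (illustrative; the actual block size derives from mshadow's kMemUnitBits constant):

    def launch_geometry(ncols, unit_bits):
        threads_per_block = 1 << unit_bits
        nblocks = (ncols + threads_per_block - 1) >> unit_bits   # ceil(ncols / threads_per_block)
        return nblocks, threads_per_block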
template<typename IndexType, typename DType>
inline void AddTakeGradLargeBatch(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& sorted,
diff --git a/3rdparty/mshadow/mshadow/tensor.h b/3rdparty/mshadow/mshadow/tensor.h
index 8dd57e2..c92bf8d 100644
--- a/3rdparty/mshadow/mshadow/tensor.h
+++ b/3rdparty/mshadow/mshadow/tensor.h
@@ -848,6 +848,19 @@ inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
* \param index index to take
* \param src source output
*/
+template<bool clip = true, typename IndexType, typename DType, typename AType>
+inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
+ Tensor<cpu, 2, AType> temp,
+ const Tensor<cpu, 1, IndexType>& index,
+ const Tensor<cpu, 2, DType> &src);
+/*!
+ * \brief CPU/GPU: Gradient accumulation for the embedding matrix with safe accumulation:
+ dst[index[i]] += src[i]
+ * \param dst destination
+ * \param temp temporary storage for safe accumulation
+ * \param index index to take
+ * \param src source output
+ */
template<bool clip = true, typename IndexType, typename DType>
inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& index,
@@ -861,6 +874,19 @@ inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
* \param index original index of the sorted indices
* \param src source output
*/
+template<bool clip = true, typename IndexType, typename DType, typename AType>
+inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
+ Tensor<gpu, 2, AType> temp,
+ const Tensor<gpu, 1, IndexType>& index,
+ const Tensor<gpu, 2, DType> &src);
+/*!
+ * \brief CPU/GPU: Gradient accumulation for the embedding matrix with safe accumulation:
+ dst[index[i]] += src[i]
+ * \param dst destination
+ * \param temp temporary storage for safe accumulation
+ * \param index index to take
+ * \param src source output
+ */
template<typename IndexType, typename DType>
inline void AddTakeGradLargeBatch(Tensor<cpu, 2, DType> dst,
const Tensor<cpu, 1, IndexType>& sorted,
diff --git a/3rdparty/mshadow/mshadow/tensor_cpu-inl.h b/3rdparty/mshadow/mshadow/tensor_cpu-inl.h
index 2d00220..5f05f0a 100644
--- a/3rdparty/mshadow/mshadow/tensor_cpu-inl.h
+++ b/3rdparty/mshadow/mshadow/tensor_cpu-inl.h
@@ -539,6 +539,39 @@ inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
}
}
+// safe accumulation
+template<bool clip, typename IndexType, typename DType, typename AType>
+inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
+ Tensor<cpu, 2, AType> temp,
+ const Tensor<cpu, 1, IndexType>& index,
+ const Tensor<cpu, 2, DType> &src) {
+ const index_t K = dst.shape_[0];
+ const index_t C = dst.shape_[1];
+ for (index_t j = 0; j < K; ++j) {
+ for (index_t i = 0; i < C; ++i) {
+ temp[j][i] = dst[j][i];
+ }
+ }
+ for (index_t y = 0; y < index.size(0); ++y) {
+ index_t j = index[y];
+ if (clip) {
+ if (j <= 0) j = 0;
+ else if (j >= K) j = K - 1;
+ } else {
+ j %= K;
+ if (j < 0) j += K;
+ }
+ for (index_t i = 0; i < C; ++i) {
+ temp[j][i] += src[y][i];
+ }
+ }
+ for (index_t j = 0; j < K; ++j) {
+ for (index_t i = 0; i < C; ++i) {
+ dst[j][i] = temp[j][i];
+ }
+ }
+}
+
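Why the higher-precision temp buffer matters: once a float16 running sum grows large enough, adding a small term rounds to no change at all. A quick numpy check, illustrative and not part of the commit:

    import numpy as np

    grads = np.full(10000, 1e-4, dtype=np.float16)
    naive = np.float16(0)
    for g in grads:
        naive = naive + g                     # float16 adds: stalls far below 1.0
    safe = grads.astype(np.float32).sum()     # what the AType buffer computes: ~1.0
    print(naive, safe)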
template<typename IndexType, typename DType>
inline void AddTakeGradLargeBatch(Tensor<cpu, 2, DType> dst,
const Tensor<cpu, 1, IndexType>& sorted,
diff --git a/3rdparty/mshadow/mshadow/tensor_gpu-inl.h b/3rdparty/mshadow/mshadow/tensor_gpu-inl.h
index e7dde27..3140259 100644
--- a/3rdparty/mshadow/mshadow/tensor_gpu-inl.h
+++ b/3rdparty/mshadow/mshadow/tensor_gpu-inl.h
@@ -239,6 +239,14 @@ inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
cuda::AddTakeGrad<clip, IndexType, DType>(dst, index, src);
}
+template<bool clip, typename IndexType, typename DType, typename AType>
+inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
+ Tensor<gpu, 2, AType> temp,
+ const Tensor<gpu, 1, IndexType>& index,
+ const Tensor<gpu, 2, DType> &src) {
+ cuda::AddTakeGrad<clip, IndexType, DType>(dst, temp, index, src);
+}
+
template<typename IndexType, typename DType>
inline void AddTakeGradLargeBatch(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& sorted,
diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu
index e3c8a78..f9d7a19 100644
--- a/src/operator/tensor/indexing_op.cu
+++ b/src/operator/tensor/indexing_op.cu
@@ -684,27 +684,25 @@ __global__ void EmbeddingFindBounds(const IType *sorted_data,
 * \param grad_out output gradient data
 * \param embbedding_dim dimension of the dense embedding
 * \param vocab_dim maximum number of unique indices in the data array: tokens vocabulary size
+ * \param nelems_per_load number of elements per load, based on sizeof(LType)/sizeof(DType)
 * \param req write/add/null
-template <typename LType, typename DType, typename IType>
+template <typename AType, typename LType, typename DType, typename IType>
__global__ void EmbeddingGradKernel(DType *grad_in,
- const IType *original_index,
- const IType *index_bounds,
- const DType *grad_out,
- const index_t embbedding_dim,
- const index_t vocab_dim,
- const int req) {
+ const IType *original_index,
+ const IType *index_bounds,
+ const DType *grad_out,
+ const index_t embbedding_dim,
+ const index_t vocab_dim,
+ const int nelems_per_load,
+ const int req) {
extern __shared__ int sharedmem[];
- LType* grad_in_row = reinterpret_cast<LType *>(sharedmem);
-
- // LType has to be bigger than DType, guarded in the launcher code
- const int n_val = sizeof(DType) < sizeof(LType) ? sizeof(LType) / sizeof(DType) : 1;
+ AType* grad_in_row = reinterpret_cast<AType *>(sharedmem);
const LType *aligned_grad_out = reinterpret_cast<const LType *>(grad_out);
LType *aligned_grad_in = reinterpret_cast<LType *>(grad_in);
- const index_t aligned_emb_dim = embbedding_dim / n_val;
- DType *my_grad_in_row = reinterpret_cast<DType *>(&grad_in_row[threadIdx.x]);
- LType Lvalue[1];
- DType* Dvalues = reinterpret_cast<DType*>(Lvalue);
+ const index_t aligned_emb_dim = embbedding_dim / nelems_per_load;
+ LType load_value[1];
+ DType* data_values = reinterpret_cast<DType*>(load_value);
IType my_row = blockIdx.x;
if (my_row < vocab_dim) {
@@ -716,29 +714,37 @@ __global__ void EmbeddingGradKernel(DType *grad_in,
for (index_t emb_id=threadIdx.x; emb_id < aligned_emb_dim; emb_id += blockDim.x) {
// Initialize grad_in
if (req == kAddTo) {
- grad_in_row[threadIdx.x] = aligned_grad_in[my_row * aligned_emb_dim + emb_id];
+ *load_value = aligned_grad_in[my_row * aligned_emb_dim + emb_id];
+ for (index_t val_id = 0; val_id < nelems_per_load; val_id++) {
+ grad_in_row[val_id * blockDim.x + threadIdx.x] = static_cast<AType>(data_values[val_id]);
+ }
} else {
- grad_in_row[threadIdx.x] = 0.0;
+ for (index_t val_id = 0; val_id < nelems_per_load; val_id++) {
+ grad_in_row[val_id * blockDim.x + threadIdx.x] = static_cast<AType>(0.0);
+ }
}
// Add all rows from grad_out according to indices in data
for (index_t data_idx=lower_bound; data_idx < (lower_bound + nOccurrences); ++data_idx) {
- *Lvalue = aligned_grad_out[original_index[data_idx] * aligned_emb_dim + emb_id];
- for (index_t val_id = 0; val_id < n_val; val_id++) {
- my_grad_in_row[val_id] += Dvalues[val_id];
+ *load_value = aligned_grad_out[original_index[data_idx] * aligned_emb_dim + emb_id];
+ for (index_t val_id = 0; val_id < nelems_per_load; val_id++) {
+ grad_in_row[val_id * blockDim.x + threadIdx.x] += static_cast<AType>(data_values[val_id]);
}
}
// Save results
- aligned_grad_in[my_row * aligned_emb_dim + emb_id] = grad_in_row[threadIdx.x];
+ for (index_t val_id = 0; val_id < nelems_per_load; val_id++) {
+ data_values[val_id] = static_cast<DType>(grad_in_row[val_id * blockDim.x + threadIdx.x]);
+ }
+ aligned_grad_in[my_row * aligned_emb_dim + emb_id] = *load_value;
}
}
}
-template<typename gpu, typename IType, typename DType>
+template<typename AType, typename IType, typename DType>
void EmbeddingGradKernelCaller(const OpContext& ctx,
- mshadow::Tensor<gpu, 2, DType> grad_in,
- const mshadow::Tensor<gpu, 1, IType>& index,
- const mshadow::Tensor<gpu, 2, DType> &grad_out,
- const std::vector<OpReqType>& req) {
+ mshadow::Tensor<gpu, 2, DType> grad_in,
+ const mshadow::Tensor<gpu, 1, IType>& index,
+ const mshadow::Tensor<gpu, 2, DType> &grad_out,
+ const std::vector<OpReqType>& req) {
using namespace mxnet_op;
using namespace mshadow::expr;
@@ -792,20 +798,23 @@ void EmbeddingGradKernelCaller(const OpContext& ctx,
// Compute Gradient
int ltype = mxnet::common::cuda::get_load_type(embbedding_dim * sizeof(DType));
+
MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
- int nelems_per_thread = sizeof(LType) / sizeof(DType);
+ CHECK_LE(sizeof(DType), sizeof(LType));
+ int nelems_per_load = sizeof(LType) / sizeof(DType);
int threads_block_grad = 32;
int maxThreads = 1024;
- while (threads_block_grad < (embbedding_dim/nelems_per_thread) &&
+ while (threads_block_grad < (embbedding_dim/nelems_per_load) &&
(threads_block_grad < maxThreads))
threads_block_grad += 32;
- size_t required_shared = threads_block_grad * sizeof(LType);
+ size_t required_shared = threads_block_grad * nelems_per_load * sizeof(AType);
dim3 blocks(vocab_dim, 1);
- EmbeddingGradKernel<LType><<<blocks, threads_block_grad, required_shared,
+ EmbeddingGradKernel<AType, LType><<<blocks, threads_block_grad, required_shared,
Stream<gpu>::GetStream(s)>>>(
grad_in.dptr_, original_index.dptr_,
bounds_index.dptr_, grad_out.dptr_,
embbedding_dim, vocab_dim,
+ nelems_per_load,
req[embedding::kWeight]);
});
}
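Note how the shared-memory requirement changed with the kernel: each thread now stages nelems_per_load values widened to AType instead of one LType word of DType values. The sizing rule, restated in Python (illustrative):

    def required_shared_bytes(threads_per_block, sizeof_ltype, sizeof_dtype, sizeof_atype):
        # e.g. a 16-byte vector load of float16 values -> 8 elements per load
        nelems_per_load = sizeof_ltype // sizeof_dtype
        return threads_per_block * nelems_per_load * sizeof_atype

With AType = float32 over float16 data this doubles the old threads_block_grad * sizeof(LType) footprint.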
@@ -831,9 +840,17 @@ void EmbeddingOpBackward<gpu>(const nnvm::NodeAttrs& attrs,
const mxnet::TShape& oshape = inputs[0].shape_;
Stream<gpu> *s = ctx.get_stream<gpu>();
+
CHECK_NE(req[embedding::kWeight], kWriteInplace)
<< "Backward of Embedding does not support writing in place.";
- MSHADOW_TYPE_SWITCH(outputs[1].type_flag_, DType, {
+ bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
+ if (!safe_acc && outputs[1].type_flag_ == mshadow::kFloat16) {
+ common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for EmbeddingOpBackward "
+ "with float16 inputs. "
+ "See https://mxnet.apache.org/api/faq/env_var "
+ "for more details.");
+ }
+ MXNET_REAL_ACC_TYPE_SWITCH(outputs[1].type_flag_, DType, AType, {
MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {
Tensor < gpu, 1, IType > data = inputs[1].get_with_shape<gpu, 1, IType>(
Shape1(ishape.ProdShape(0, ishape.ndim())), s);
@@ -842,7 +859,10 @@ void EmbeddingOpBackward<gpu>(const nnvm::NodeAttrs& attrs,
Tensor<gpu, 2, DType> grad_in = outputs[1].get<gpu, 2, DType>(s);
if (req[embedding::kWeight] == kWriteTo || req[embedding::kWeight] == kAddTo) {
- EmbeddingGradKernelCaller(ctx, grad_in, data, grad_out, req);
+ if (safe_acc)
+ EmbeddingGradKernelCaller<AType>(ctx, grad_in, data, grad_out, req);
+ else
+ EmbeddingGradKernelCaller<DType>(ctx, grad_in, data, grad_out, req);
} else {
LOG(FATAL) << "wrong req";
}
diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h
index 5454900..7f0f0fa 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -400,20 +400,38 @@ void EmbeddingOpBackward(const nnvm::NodeAttrs& attrs,
const mxnet::TShape& oshape = inputs[0].shape_;
Stream<xpu> *s = ctx.get_stream<xpu>();
- MSHADOW_TYPE_SWITCH(outputs[1].type_flag_, DType, {
+
+ bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
+ if (!safe_acc && outputs[1].type_flag_ == mshadow::kFloat16) {
+ common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for EmbeddingOpBackward "
+ "with float16 inputs. "
+ "See https://mxnet.apache.org/api/faq/env_var "
+ "for more details.");
+ }
+ MXNET_REAL_ACC_TYPE_SWITCH(outputs[1].type_flag_, DType, AType, {
MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {
Tensor < xpu, 1, IType > data = inputs[1].get_with_shape<xpu, 1, IType>(
Shape1(ishape.ProdShape(0, ishape.ndim())), s);
Tensor<xpu, 2, DType> grad_out = inputs[0].get_with_shape<xpu, 2, DType>(
- Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s);
+ Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s);
Tensor<xpu, 2, DType> grad_in = outputs[1].get<xpu, 2, DType>(s);
-
if (req[embedding::kWeight] == kWriteTo || req[embedding::kWeight] == kAddTo) {
if (req[embedding::kWeight] == kWriteTo) {
grad_in = scalar<DType>(0.0f);
}
- AddTakeGrad(grad_in, data, grad_out);
+ if (safe_acc) {
+ // Temporary storage for safe accumulation
+ size_t temp_space_size = grad_in.size(0) * grad_in.size(1) * sizeof(AType);
+ Tensor<xpu, 1, char> temp_space =
+ ctx.requested[embedding::kTempSpace].get_space_typed<xpu, 1, char>(
+ Shape1(temp_space_size), s);
+ Tensor<xpu, 2, AType> temp_grad_in(reinterpret_cast<AType*>(temp_space.dptr_),
+ grad_in.shape_, s);
+ AddTakeGrad(grad_in, temp_grad_in, data, grad_out);
+ } else {
+ AddTakeGrad(grad_in, data, grad_out);
+ }
} else {
LOG(FATAL) << "wrong req";
}
@@ -696,7 +714,48 @@ struct TakeGradGeneralKernel {
}
};
-template<bool clip = true>
+struct TakeGradGeneralKernelSafeAccumulation {
+ /*!
+ * \brief Map function for general case of take grad
+ * \param tid global thread id
+ * \param arr_grad ptr to in_grad
+ * \param temp ptr to temporary space used to perform the accumulation
+ * \param ograd ptr to out_grad
+ * \param src_indptr ptr to indptr to src indices
+ * \param original_idx ptr to original indices of the inputs
+ * \param in_strides strides of inputs
+ * \param out_strides strides of outputs
+ * \param in_ndims # of dims of input tensor
+ * \param out_ndims # of dims of output tensor
+ * \param idx_ndims # of dims of indices tensor
+ * \param axis axis id
+ * \param K dim size of the axis dimension
+ */
+ template<typename DType, typename IType, typename AType>
+ MSHADOW_XINLINE static void Map(int tid, DType* arr_grad, AType* temp,
+ const DType* ograd,
+ const IType* src_indptr, const IType* original_idx,
+ mshadow::Shape<10> in_strides, mshadow::Shape<10> out_strides,
+ const int in_ndims, const int out_ndims, const int idx_ndims,
+ const int axis, const int K) {
+ const int in_head_index = (axis == 0) ? 0 : tid / in_strides[axis - 1];
+ const int in_rest_index = (axis == 0) ? tid : tid % in_strides[axis - 1];
+ const int in_mid_index = in_rest_index / in_strides[axis];
+ const int in_tail_index = (axis == in_ndims - 1) ?
+ 0 : (in_rest_index % in_strides[axis]);
+ temp[tid] = static_cast<AType>(arr_grad[tid]);
+ for (IType i = src_indptr[in_mid_index]; i < src_indptr[in_mid_index + 1]; ++i) {
+ int out_mid_index = original_idx[i];
+ out_mid_index = (out_mid_index < 0) ? out_mid_index + K : out_mid_index;
+ int target = in_tail_index + out_mid_index * in_strides[axis];
+ target += (axis == 0) ? 0 : in_head_index * out_strides[axis - 1];
+ temp[tid] += ograd[target];
+ }
+ arr_grad[tid] = temp[tid];
+ }
+};
+
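The kernel above follows a widen/accumulate/narrow pattern per output element; the same idea in Python with the index arithmetic elided (illustrative only):

    def accumulate_element(tid, arr_grad, temp, contributions):
        temp[tid] = float(arr_grad[tid])      # widen DType -> AType once
        for value in contributions:           # every ograd entry mapping to this element
            temp[tid] += float(value)         # all additions happen in AType
        arr_grad[tid] = temp[tid]             # single narrowing store at the end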
+template<bool clip = true, bool safe_acc = false, typename AType>
void TakeOpBackwardImpl(mshadow::Stream<cpu>* s,
const OpContext& ctx,
const TBlob& arr,
@@ -715,14 +774,23 @@ void TakeOpBackwardImpl(mshadow::Stream<cpu>* s,
size_t temp_storage_bytes = SortByKeyWorkspaceSize<int, int, cpu>(idxshape.Size());
size_t original_idx_bytes = idxshape.Size() * sizeof(int);
size_t src_indptr_bytes = (arrshape[axis] + 1) * sizeof(int);
- size_t workspace_bytes = src_indptr_bytes + 2 * original_idx_bytes + temp_storage_bytes;
+ size_t temp_accumulation_arrgrad_bytes = 0;
+ if (safe_acc) {
+ temp_accumulation_arrgrad_bytes = arr.Size() * sizeof(AType);
+ }
+ size_t workspace_bytes = src_indptr_bytes + 2 * original_idx_bytes +
+ temp_storage_bytes + temp_accumulation_arrgrad_bytes;
Tensor<cpu, 1, char> workspace =
ctx.requested[0].get_space_typed<cpu, 1, char>(Shape1(workspace_bytes), s);
- int* sorted_idx_ptr = reinterpret_cast<int*>(workspace.dptr_);
- int* original_idx_ptr = reinterpret_cast<int*>(workspace.dptr_ + original_idx_bytes);
- src_indptr_ptr = reinterpret_cast<int*>(workspace.dptr_ + 2 * original_idx_bytes);
+ AType* temp_accum_arrgrad_ptr = reinterpret_cast<AType*>(workspace.dptr_);
+ int* sorted_idx_ptr = reinterpret_cast<int*>(workspace.dptr_ + temp_accumulation_arrgrad_bytes);
+ int* original_idx_ptr = reinterpret_cast<int*>(workspace.dptr_ + original_idx_bytes +
+ temp_accumulation_arrgrad_bytes);
+ src_indptr_ptr = reinterpret_cast<int*>(workspace.dptr_ + 2 * original_idx_bytes +
+ temp_accumulation_arrgrad_bytes);
Tensor<cpu, 1, char> temp_storage(
- workspace.dptr_ + 2 * original_idx_bytes + src_indptr_bytes, Shape1(temp_storage_bytes), s);
+ workspace.dptr_ + 2 * original_idx_bytes + src_indptr_bytes + temp_accumulation_arrgrad_bytes,
+ Shape1(temp_storage_bytes), s);
// Reset indptr to zero
Kernel<set_zero, cpu>::Launch(s, arrshape[axis] + 1, src_indptr_ptr);
// Fill original_idx
@@ -759,16 +827,23 @@ void TakeOpBackwardImpl(mshadow::Stream<cpu>* s,
out_strides[i] = stride;
}
MSHADOW_TYPE_SWITCH(arr.type_flag_, DType, {
- Kernel<TakeGradGeneralKernel, cpu>::Launch(
- s, arrshape.Size(), arr.dptr<DType>(), ograd.dptr<DType>(), src_indptr_ptr,
- original_idx_ptr, in_strides, out_strides,
- arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis, static_cast<int>(arrshape[axis]));
+ if (safe_acc) {
+ Kernel<TakeGradGeneralKernelSafeAccumulation, cpu>::Launch(
+ s, arrshape.Size(), arr.dptr<DType>(), temp_accum_arrgrad_ptr, ograd.dptr<DType>(),
+ src_indptr_ptr, original_idx_ptr, in_strides, out_strides,
+ arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis, static_cast<int>(arrshape[axis]));
+ } else {
+ Kernel<TakeGradGeneralKernel, cpu>::Launch(
+ s, arrshape.Size(), arr.dptr<DType>(), ograd.dptr<DType>(),
+ src_indptr_ptr, original_idx_ptr, in_strides, out_strides,
+ arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis, static_cast<int>(arrshape[axis]));
+ }
});
});
}
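When safe_acc is on, the AType accumulation buffer is carved off the front of the shared workspace and every other pointer shifts by its size. The byte offsets, restated in Python (illustrative):

    def workspace_offsets(arr_size, idx_size, axis_len, sizeof_atype, sizeof_int=4):
        accum = arr_size * sizeof_atype        # temp_accum buffer (0 bytes if not safe_acc)
        idx = idx_size * sizeof_int            # original_idx_bytes
        indptr = (axis_len + 1) * sizeof_int   # src_indptr_bytes
        return {
            'temp_accum':   0,
            'sorted_idx':   accum,
            'original_idx': accum + idx,
            'src_indptr':   accum + 2 * idx,
            'temp_storage': accum + 2 * idx + indptr,
        }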
#ifdef __CUDACC__
-template<bool clip = true>
+template<bool clip = true, bool safe_acc = false, typename AType>
void TakeOpBackwardImpl(mshadow::Stream<gpu>* s,
const OpContext& ctx,
const TBlob& arr,
@@ -808,13 +883,23 @@ void TakeOpBackwardImpl(mshadow::Stream<gpu>* s,
temp_storage_bytes = max(temp_storage_bytes, histo_temp_storage_bytes);
size_t original_idx_bytes = idxshape.Size() * sizeof(int);
size_t src_indptr_bytes = (arrshape[axis] + 1) * sizeof(int);
- size_t workspace_bytes = src_indptr_bytes + 2 * original_idx_bytes + temp_storage_bytes;
+ size_t temp_accumulation_igrad_bytes = 0;
+ if (safe_acc) {
+ temp_accumulation_igrad_bytes = arr.Size() * sizeof(AType);
+ }
+ size_t workspace_bytes = src_indptr_bytes + 2 * original_idx_bytes +
+ temp_storage_bytes + temp_accumulation_igrad_bytes;
Tensor<gpu, 1, char> workspace =
ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(workspace_bytes), s);
- sorted_idx_ptr = reinterpret_cast<int*>(workspace.dptr_);
- int* original_idx_ptr = reinterpret_cast<int*>(workspace.dptr_ + original_idx_bytes);
- src_indptr_ptr = reinterpret_cast<int*>(workspace.dptr_ + 2 * original_idx_bytes);
- temp_storage_ptr = workspace.dptr_ + 2 * original_idx_bytes + src_indptr_bytes;
+ AType* temp_accum_igrad_ptr = reinterpret_cast<AType*>(workspace.dptr_);
+ sorted_idx_ptr = reinterpret_cast<int*>(workspace.dptr_ + temp_accumulation_igrad_bytes);
+ int* original_idx_ptr = reinterpret_cast<int*>(workspace.dptr_ + original_idx_bytes +
+ temp_accumulation_igrad_bytes);
+ src_indptr_ptr = reinterpret_cast<int*>(workspace.dptr_ + 2 * original_idx_bytes +
+ temp_accumulation_igrad_bytes);
+ temp_storage_ptr = workspace.dptr_ + 2 * original_idx_bytes + src_indptr_bytes +
+ temp_accumulation_igrad_bytes;
+
// Reset indptr to zero
Kernel<set_zero, gpu>::Launch(s, arrshape[axis] + 1, src_indptr_ptr);
// Fill original_idx
@@ -863,10 +948,19 @@ void TakeOpBackwardImpl(mshadow::Stream<gpu>* s,
out_strides[i] = stride;
}
MSHADOW_TYPE_SWITCH(arr.type_flag_, DType, {
- Kernel<TakeGradGeneralKernel, gpu>::Launch(
- s, arrshape.Size(), arr.dptr<DType>(), ograd.dptr<DType>(),
- src_indptr_ptr, original_idx_ptr, in_strides, out_strides,
- arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis, static_cast<int>(arrshape[axis]));
+ if (safe_acc) {
+ Kernel<TakeGradGeneralKernelSafeAccumulation, gpu>::Launch(
+ s, arrshape.Size(), arr.dptr<DType>(), temp_accum_igrad_ptr, ograd.dptr<DType>(),
+ src_indptr_ptr, original_idx_ptr, in_strides, out_strides,
+ arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis,
+ static_cast<int>(arrshape[axis]));
+ } else {
+ Kernel<TakeGradGeneralKernel, gpu>::Launch(
+ s, arrshape.Size(), arr.dptr<DType>(), ograd.dptr<DType>(),
+ src_indptr_ptr, original_idx_ptr, in_strides, out_strides,
+ arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis,
+ static_cast<int>(arrshape[axis]));
+ }
});
});
}
@@ -891,7 +985,14 @@ void TakeOpBackward(const nnvm::NodeAttrs& attrs,
// grad_in is the gradient of the inputs in the feed-forward
Stream<xpu> *s = ctx.get_stream<xpu>();
- MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { // output data type
+ bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
+ if (!safe_acc && outputs[0].type_flag_ == mshadow::kFloat16) {
+ common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for TakeOpBackward "
+ "with float16 inputs. "
+ "See https://mxnet.apache.org/api/faq/env_var "
+ "for more details.");
+ }
+ MXNET_REAL_ACC_TYPE_SWITCH(outputs[0].type_flag_, DType, AType, {
MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // index data type
// inputs are specified in the .cc file, which are the gradients from
// the upper layer and the input index
@@ -925,10 +1026,25 @@ void TakeOpBackward(const nnvm::NodeAttrs& attrs,
if (req[take_::kArr] == kWriteTo) {
grad_in = scalar<DType>(0.0f);
}
- if (param.mode == take_::kClip) {
- AddTakeGrad(grad_in, idx, grad_out);
+ if (safe_acc) {
+ // Temporary storage for safe accumulation
+ size_t temp_space_size = grad_in.size(0) * grad_in.size(1) * sizeof(AType);
+ Tensor<xpu, 1, char> temp_space =
+ ctx.requested[take_::kTempSpace].get_space_typed<xpu, 1, char>(
+ Shape1(temp_space_size), s);
+ Tensor<xpu, 2, AType> temp_grad_in(reinterpret_cast<AType*>(temp_space.dptr_),
+ grad_in.shape_, s);
+ if (param.mode == take_::kClip) {
+ AddTakeGrad(grad_in, temp_grad_in, idx, grad_out);
+ } else {
+ AddTakeGrad<false>(grad_in, temp_grad_in, idx, grad_out);
+ }
} else {
- AddTakeGrad<false>(grad_in, idx, grad_out);
+ if (param.mode == take_::kClip) {
+ AddTakeGrad(grad_in, idx, grad_out);
+ } else {
+ AddTakeGrad<false>(grad_in, idx, grad_out);
+ }
}
} else {
LOG(FATAL) << "wrong req";
@@ -939,10 +1055,18 @@ void TakeOpBackward(const nnvm::NodeAttrs& attrs,
const TBlob& arr = outputs[0];
const TBlob& ograd = inputs[0];
- if (param.mode == take_::kClip) {
- TakeOpBackwardImpl<true>(s, ctx, arr, idx, ograd, actual_axis);
+ if (safe_acc) {
+ if (param.mode == take_::kClip) {
+ TakeOpBackwardImpl<true, true, AType>(s, ctx, arr, idx, ograd, actual_axis);
+ } else {
+ TakeOpBackwardImpl<false, true, AType>(s, ctx, arr, idx, ograd, actual_axis);
+ }
} else {
- TakeOpBackwardImpl<false>(s, ctx, arr, idx, ograd, actual_axis);
+ if (param.mode == take_::kClip) {
+ TakeOpBackwardImpl<true, false, AType>(s, ctx, arr, idx, ograd, actual_axis);
+ } else {
+ TakeOpBackwardImpl<false, false, AType>(s, ctx, arr, idx, ograd, actual_axis);
+ }
}
}
});
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 84cfe9c..519c02f 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -1522,19 +1522,23 @@ def test_lrn():
reason="Testing with naive engine consistently triggers
illegal memory access. Tracked in #17713")
def test_embedding_with_type():
def test_embedding_helper(data_types, weight_types, low_pad, high_pad):
- NVD = [[20, 10, 20], [200, 10, 300]]
- for N, V, D in NVD:
- sym = mx.sym.Embedding(name='embedding', input_dim=V, output_dim=D)
- ctx_list = []
- for data_type in data_types:
- for weight_type in weight_types:
- ctx_list.append({'ctx': mx.gpu(0), 'embedding_data': (N,),
- 'type_dict': {'embedding_data': data_type, 'embedding_weight': weight_type}})
- ctx_list.append({'ctx': mx.cpu(0), 'embedding_data': (N,),
- 'type_dict': {'embedding_data': data_type, 'embedding_weight': weight_type}})
- arg_params = {'embedding_data': np.random.randint(low=-low_pad, high=V+high_pad, size=(N,))}
- check_consistency(sym, ctx_list, grad_req={'embedding_data': 'null','embedding_weight': 'write'},
- arg_params=arg_params, scale=0.1)
+ NVD = [[20, 10, 20], [200, 10, 300], [10000, 4, 20]]
+ for safe_accumulation in ['0', '1', None]:
+ for N, V, D in NVD:
+ with environment('MXNET_SAFE_ACCUMULATION', safe_accumulation):
+ if N > 1000 and safe_accumulation != '1':
+ break
+ sym = mx.sym.Embedding(name='embedding', input_dim=V, output_dim=D)
+ ctx_list = []
+ for data_type in data_types:
+ for weight_type in weight_types:
+ ctx_list.append({'ctx': mx.gpu(0), 'embedding_data': (N,),
+ 'type_dict': {'embedding_data': data_type, 'embedding_weight': weight_type}})
+ ctx_list.append({'ctx': mx.cpu(0), 'embedding_data': (N,),
+ 'type_dict': {'embedding_data': data_type, 'embedding_weight': weight_type}})
+ arg_params = {'embedding_data': np.random.randint(low=-low_pad, high=V+high_pad, size=(N,))}
+ check_consistency(sym, ctx_list, grad_req={'embedding_data': 'null','embedding_weight': 'write'},
+ arg_params=arg_params, scale=0.1)
data_types = [np.float16, np.float32, np.float64, np.int32]
weight_types = [np.float16, np.float32, np.float64]
@@ -1547,47 +1551,91 @@ def test_embedding_with_type():
@with_seed()
def test_take_with_type():
sym = mx.sym.take(name='take')
- for data_ndim in range(2, 5):
- for idx_ndim in range(1, 4):
- data_shape = ()
- for _ in range(data_ndim):
- data_shape += (np.random.randint(low=3, high=6), )
- idx_shape = ()
- for _ in range(idx_ndim):
- idx_shape += (np.random.randint(low=3, high=5), )
- ctx_list = [{'ctx': mx.gpu(0), 'take_indices': idx_shape,
- 'take_a': data_shape,
- 'type_dict': {'take_indices': np.float64,
- 'take_a': np.float64}},
- {'ctx': mx.gpu(0), 'take_indices': idx_shape,
- 'take_a': data_shape,
- 'type_dict': {'take_indices': np.float32,
- 'take_a': np.float32}},
- {'ctx': mx.gpu(0), 'take_indices': idx_shape,
- 'take_a': data_shape,
- 'type_dict': {'take_indices': np.float16,
- 'take_a': np.float16}},
- {'ctx': mx.cpu(0), 'take_indices': idx_shape,
- 'take_a': data_shape,
- 'type_dict': {'take_indices': np.float64,
- 'take_a': np.float64}},
- {'ctx': mx.cpu(0), 'take_indices': idx_shape,
- 'take_a': data_shape,
- 'type_dict': {'take_indices': np.float32,
- 'take_a': np.float32}},
- {'ctx': mx.cpu(0), 'take_indices': idx_shape,
- 'take_a': data_shape,
- 'type_dict': {'take_indices': np.float16,
- 'take_a': np.float16}}]
- arg_params = {'take_indices': np.random.randint(low=0,
- high=data_shape[0],
- size=idx_shape),
- 'take_a': np.random.normal(size=data_shape)}
- check_consistency(sym, ctx_list,
- grad_req={'take_indices': 'null',
- 'take_a': 'write'},
- arg_params=arg_params)
-
+ for safe_accumulation in ['0', '1', None]:
+ for data_ndim in range(2, 5):
+ for idx_ndim in range(1, 4):
+ data_shape = ()
+ for _ in range(data_ndim):
+ data_shape += (np.random.randint(low=3, high=6), )
+ idx_shape = ()
+ for _ in range(idx_ndim):
+ idx_shape += (np.random.randint(low=3, high=5), )
+ ctx_list = [{'ctx': mx.gpu(0), 'take_indices': idx_shape,
+ 'take_a': data_shape,
+ 'type_dict': {'take_indices': np.float64,
+ 'take_a': np.float64}},
+ {'ctx': mx.gpu(0), 'take_indices': idx_shape,
+ 'take_a': data_shape,
+ 'type_dict': {'take_indices': np.float32,
+ 'take_a': np.float32}},
+ {'ctx': mx.gpu(0), 'take_indices': idx_shape,
+ 'take_a': data_shape,
+ 'type_dict': {'take_indices': np.float16,
+ 'take_a': np.float16}},
+ {'ctx': mx.cpu(0), 'take_indices': idx_shape,
+ 'take_a': data_shape,
+ 'type_dict': {'take_indices': np.float64,
+ 'take_a': np.float64}},
+ {'ctx': mx.cpu(0), 'take_indices': idx_shape,
+ 'take_a': data_shape,
+ 'type_dict': {'take_indices': np.float32,
+ 'take_a': np.float32}},
+ {'ctx': mx.cpu(0), 'take_indices': idx_shape,
+ 'take_a': data_shape,
+ 'type_dict': {'take_indices': np.float16,
+ 'take_a': np.float16}}]
+ arg_params = {'take_indices': np.random.randint(low=0,
+ high=data_shape[0],
+ size=idx_shape),
+ 'take_a': np.random.normal(size=data_shape)}
+ with environment('MXNET_SAFE_ACCUMULATION', safe_accumulation):
+ check_consistency(sym, ctx_list,
+ grad_req={'take_indices': 'null',
+ 'take_a': 'write'},
+ arg_params=arg_params)
+
+ # check a large num of indices: may underflow calculating gradient in FP16,
+ # if MXNET_SAFE_ACCUMULATION is not activated
+ with environment('MXNET_SAFE_ACCUMULATION', '1'):
+ data_size = 4
+ indices_size = 10000
+ out_dim = 20
+ data_types = [np.float16, np.float32, np.float64]
+ indices_types = [np.float16, np.float32, np.float64, np.int32]
+ # axis 0
+ sym = mx.sym.take(name='take', axis=0)
+ ctx_list = []
+ for data_type in data_types:
+ for index_type in indices_types:
+ ctx_list.append({'ctx': mx.cpu(0), 'take_indices': (indices_size,),
+ 'take_a': (data_size, out_dim),
+ 'type_dict': {'take_indices': index_type, 'take_a': data_type}})
+ ctx_list.append({'ctx': mx.gpu(0), 'take_indices': (indices_size,),
+ 'take_a': (data_size, out_dim),
+ 'type_dict': {'take_indices': index_type, 'take_a': data_type}})
+ arg_params = {'take_indices': np.random.randint(0, data_size,
+ size=(indices_size,)),
+ 'take_a': np.random.normal(size=(data_size, out_dim))}
+ check_consistency(sym, ctx_list,
+ grad_req={'take_indices': 'null','take_a': 'write'},
+ arg_params=arg_params)
+ # axis 1
+ sym = mx.sym.take(name='take', axis=1)
+ ctx_list = []
+ for data_type in data_types:
+ for index_type in indices_types:
+ ctx_list.append({'ctx': mx.cpu(0), 'take_indices': (indices_size,),
+ 'take_a': (data_size, out_dim),
+ 'type_dict': {'take_indices': index_type, 'take_a': data_type}})
+ ctx_list.append({'ctx': mx.gpu(0), 'take_indices': (indices_size,),
+ 'take_a': (data_size, out_dim),
+ 'type_dict': {'take_indices': index_type, 'take_a': data_type}})
+ arg_params = {'take_indices': np.random.randint(0, data_size,
+ size=(indices_size,)),
+ 'take_a': np.random.normal(size=(data_size, out_dim))}
+ check_consistency(sym, ctx_list,
+ grad_req={'take_indices': 'null','take_a': 'write'},
+ arg_params=arg_params)
@with_seed()
@pytest.mark.serial
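To exercise both code paths locally, one would run the updated tests with and without the variable set, e.g. MXNET_SAFE_ACCUMULATION=1 pytest tests/python/gpu/test_operator_gpu.py -k "embedding_with_type or take_with_type" (this assumes a CUDA build of MXNet and pytest; the tests above also toggle the variable themselves via the environment() helper).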