eric-haibin-lin closed pull request #9846: [WIP] Fix non-determinism in sparse embedding
URL: https://github.com/apache/incubator-mxnet/pull/9846
 
 
   

This is a PR merged from a forked repository. Because GitHub hides the original
diff on merge, it is reproduced below for the sake of provenance:

diff --git a/src/operator/tensor/indexing_op-inl.cuh b/src/operator/tensor/indexing_op-inl.cuh
index 4458151f178..4df1fd451ec 100644
--- a/src/operator/tensor/indexing_op-inl.cuh
+++ b/src/operator/tensor/indexing_op-inl.cuh
@@ -38,7 +38,7 @@ namespace mxnet {
 namespace op {
 const int kWarpSize = 32;
 
-template<int SZ, typename DType, typename IdxType>
+template<int SZ, bool lookup, typename DType, typename IdxType>
 __global__ void AddTakeGradLargeBatchKernel(DType* dst,
                                           // If idx_start == NULL, then in-kernel edge
                                            // detection is used
@@ -47,7 +47,9 @@ __global__ void AddTakeGradLargeBatchKernel(DType* dst,
                                            const int* idx_start_size_ptr,
                                           const IdxType *sorted, const IdxType *index,
                                            const DType *src,
-                                           int ymax, int xmax) {
+                                           int ymax, int xmax,
+                                           // table to look up positions of row_ids in dst
+                                           const nnvm::dim_t *lookup_table) {
   // Size of the shared memory is [blockDim.x*SZ*blockDim.y]*sizeof(DType)
   extern __shared__ char sh_grad_weight_char[];
   DType* sh_grad_weight = (DType*)sh_grad_weight_char;
@@ -125,7 +127,8 @@ __global__ void AddTakeGradLargeBatchKernel(DType* dst,
     }
 
     const int start_feature = threadIdx.x + blockIdx.x * blockDim.x * SZ;
-    const int dst_row = sorted_value * xmax;
+    // Lookup inclusive prefix sum table if necessary
+    const int dst_row = (lookup ? (lookup_table[sorted_value] - 1) : sorted_value) * xmax;
 
     int num_idx = idx_end - idx_begin;
     int idx0 = idx_begin + threadIdx.y*num_idx/blockDim.y;
@@ -179,7 +182,6 @@ __global__ void AddTakeGradLargeBatchKernel(DType* dst,
         }
       }
     }
-
   }
 }
 
@@ -199,6 +201,73 @@ AddTakeGradLargeBatchWorkspaceSize(size_t num_keys) {
   return (unique_bytes + counts_bytes + num_runs_bytes + temporary_bytes);
 }
 
+template<bool lookup, typename IndexType, typename DType>
+inline void AddTakeGradLargeBatchKernelLaunch(mshadow::Tensor<gpu, 2, DType> dst,
+                                              const mshadow::Tensor<gpu, 1, IndexType>& sorted,
+                                              const mshadow::Tensor<gpu, 1, IndexType>& index,
+                                              const mshadow::Tensor<gpu, 2, DType> &src,
+                                              IndexType* sum_counts_ptr,
+                                              int* num_runs_ptr,
+                                              const nnvm::dim_t* lookup_table) {
+  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(dst.stream_);
+  const int num_unique_est = min(dst.size(0), src.size(0));
+  const int max_nthread = 128;
+  const int num_y = max(src.size(0)/num_unique_est, 1);
+  const int block_dim_x = kWarpSize;
+  const int block_dim_y = min(num_y, max_nthread/block_dim_x);
+  const int SZ = min((src.size(1) + block_dim_x - 1) / block_dim_x, 4);
+  const int grid_dim_x = (src.size(1) + block_dim_x * SZ - 1) / (block_dim_x * SZ);
+  const int grid_dim_y = min(num_unique_est, mshadow::cuda::kBaseGridNum);
+  dim3 dimBlock(block_dim_x, block_dim_y);
+  dim3 dimGrid(grid_dim_x, grid_dim_y);
+  // Maximum shared memory usage: 128*4*sizeof(DType), which is 4K for 64bit DType elements
+  int shmem_size = dimBlock.x*SZ*dimBlock.y*sizeof(DType);
+
+  CHECK_EQ(dst.size(1), src.size(1)) << "AddTakeGradLargeBatch: shape mismatch";
+  CHECK_EQ(index.size(0), src.size(0)) << "AddTakeGradLargeBatch: shape mismatch";
+  mshadow::cuda::CheckLaunchParam(dimGrid, dimBlock, "AddTakeGradLargeBatch");
+
+  switch (SZ) {
+    case 1:
+    AddTakeGradLargeBatchKernel<1, lookup, DType>
+        <<<dimGrid, dimBlock, shmem_size, stream>>>
+        (dst.dptr_, sum_counts_ptr, num_runs_ptr,
+         sorted.dptr_, index.dptr_, src.dptr_,
+         static_cast<int>(src.size(0)),
+         static_cast<int>(src.size(1)), lookup_table);
+    break;
+    case 2:
+    AddTakeGradLargeBatchKernel<2, lookup, DType>
+        <<<dimGrid, dimBlock, shmem_size, stream>>>
+        (dst.dptr_, sum_counts_ptr, num_runs_ptr,
+         sorted.dptr_, index.dptr_, src.dptr_,
+         static_cast<int>(src.size(0)),
+         static_cast<int>(src.size(1)), lookup_table);
+    break;
+    case 3:
+    AddTakeGradLargeBatchKernel<3, lookup, DType>
+        <<<dimGrid, dimBlock, shmem_size, stream>>>
+        (dst.dptr_, sum_counts_ptr, num_runs_ptr,
+         sorted.dptr_, index.dptr_, src.dptr_,
+         static_cast<int>(src.size(0)),
+         static_cast<int>(src.size(1)), lookup_table);
+    break;
+    case 4:
+    AddTakeGradLargeBatchKernel<4, lookup, DType>
+        <<<dimGrid, dimBlock, shmem_size, stream>>>
+        (dst.dptr_, sum_counts_ptr, num_runs_ptr,
+         sorted.dptr_, index.dptr_, src.dptr_,
+         static_cast<int>(src.size(0)),
+         static_cast<int>(src.size(1)), lookup_table);
+    break;
+    default:
+    LOG(FATAL) << "AddTakeGradLargeBatch, incorrect value SZ " << SZ;
+    break;
+  }
+  MSHADOW_CUDA_POST_KERNEL_CHECK(AddTakeGradLargeBatchKernel);
+}
+
+
 template<typename IndexType, typename DType>
 inline void AddTakeGradLargeBatch(mshadow::Tensor<gpu, 2, DType> dst,
                                  const mshadow::Tensor<gpu, 1, IndexType>& sorted,
@@ -249,62 +318,9 @@ inline void AddTakeGradLargeBatch(mshadow::Tensor<gpu, 2, DType> dst,
     (temporary_storage, temporary_bytes, counts_out_ptr, sum_counts_ptr,
       sorted.size(0), stream);
   }
-
-  const int num_unique_est = min(dst.size(0), src.size(0));
-  const int max_nthread = 128;
-  const int num_y = max(src.size(0)/num_unique_est, 1);
-  const int block_dim_x = kWarpSize;
-  const int block_dim_y = min(num_y, max_nthread/block_dim_x);
-  const int SZ = min((src.size(1) + block_dim_x - 1) / block_dim_x, 4);
-  const int grid_dim_x = (src.size(1) + block_dim_x * SZ - 1) / (block_dim_x * SZ);
-  const int grid_dim_y = min(num_unique_est, mshadow::cuda::kBaseGridNum);
-  dim3 dimBlock(block_dim_x, block_dim_y);
-  dim3 dimGrid(grid_dim_x, grid_dim_y);
-  // Maximum shared memory usage: 128*4*sizeof(DType), which is 4K for 64bit DType elements
-  int shmem_size = dimBlock.x*SZ*dimBlock.y*sizeof(DType);
-
-  CHECK_EQ(dst.size(1), src.size(1)) << "AddTakeGradLargeBatch: shape mismatch";
-  CHECK_EQ(index.size(0), src.size(0)) << "AddTakeGradLargeBatch: shape mismatch";
-  mshadow::cuda::CheckLaunchParam(dimGrid, dimBlock, "AddTakeGradLargeBatch");
-
-  switch (SZ) {
-    case 1:
-    AddTakeGradLargeBatchKernel<1, DType>
-        <<<dimGrid, dimBlock, shmem_size, stream>>>
-        (dst.dptr_, sum_counts_ptr, num_runs_ptr,
-         sorted.dptr_, index.dptr_, src.dptr_,
-         static_cast<int>(src.size(0)),
-         static_cast<int>(src.size(1)));
-    break;
-    case 2:
-    AddTakeGradLargeBatchKernel<2, DType>
-        <<<dimGrid, dimBlock, shmem_size, stream>>>
-        (dst.dptr_, sum_counts_ptr, num_runs_ptr,
-         sorted.dptr_, index.dptr_, src.dptr_,
-         static_cast<int>(src.size(0)),
-         static_cast<int>(src.size(1)));
-    break;
-    case 3:
-    AddTakeGradLargeBatchKernel<3, DType>
-        <<<dimGrid, dimBlock, shmem_size, stream>>>
-        (dst.dptr_, sum_counts_ptr, num_runs_ptr,
-         sorted.dptr_, index.dptr_, src.dptr_,
-         static_cast<int>(src.size(0)),
-         static_cast<int>(src.size(1)));
-    break;
-    case 4:
-    AddTakeGradLargeBatchKernel<4, DType>
-        <<<dimGrid, dimBlock, shmem_size, stream>>>
-        (dst.dptr_, sum_counts_ptr, num_runs_ptr,
-         sorted.dptr_, index.dptr_, src.dptr_,
-         static_cast<int>(src.size(0)),
-         static_cast<int>(src.size(1)));
-    break;
-    default:
-    LOG(FATAL) << "AddTakeGradLargeBatch, incorrect value SZ " << SZ;
-    break;
-  }
-  MSHADOW_CUDA_POST_KERNEL_CHECK(AddTakeGradLargeBatchKernel);
+  nnvm::dim_t* lookup_table = nullptr;
+  AddTakeGradLargeBatchKernelLaunch<false>(dst, sorted, index, src, sum_counts_ptr,
+                                           num_runs_ptr, lookup_table);
 }
 
 }  // namespace op
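
The hunks above thread a new "lookup" template parameter and a lookup_table pointer
through AddTakeGradLargeBatchKernel: when lookup is true, the destination row comes
from an inclusive prefix sum over row flags instead of the raw row id, so the kernel
can write into the compacted rows of a row_sparse gradient. A minimal NumPy sketch of
that mapping follows (illustrative values only, not code from this PR):

    import numpy as np

    num_rows = 6                          # rows in the embedding weight
    row_ids = np.array([1, 4, 1, 5])      # row ids looked up in the batch

    # Mirrors MarkRowFlgKernel followed by cub::DeviceScan::InclusiveSum.
    row_flg = np.zeros(num_rows, dtype=np.int64)
    row_flg[row_ids] = 1
    lookup_table = np.cumsum(row_flg)     # inclusive prefix sum

    # lookup_table[row_id] - 1 is the compacted row of the row_sparse output,
    # i.e. what (lookup_table[sorted_value] - 1) * xmax indexes in the kernel.
    print(lookup_table[row_ids] - 1)      # [0 1 0 2]
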
diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu
index 762d8fd64c2..223a1230365 100644
--- a/src/operator/tensor/indexing_op.cu
+++ b/src/operator/tensor/indexing_op.cu
@@ -41,25 +41,6 @@ struct is_valid_check {
   }
 };
 
-
-struct AddTakeGradRspGPUKernel {
-  template<typename DType, typename IType>
-  __device__ __forceinline__ static void Map(int tid,
-                                             DType* out,
-                                             const nnvm::dim_t* prefix_sum,
-                                             const IType* data,
-                                             const DType* ograd,
-                                             const nnvm::dim_t row_length) {
-    using nnvm::dim_t;
-    const dim_t data_i = tid / row_length;
-    const dim_t grad_i = tid % row_length;
-    const dim_t irow = static_cast<dim_t>(data[data_i]);
-    const dim_t rsp_row = prefix_sum[irow] - 1;
-    const DType val = ograd[data_i * row_length + grad_i];
-    atomicAdd(static_cast<DType *>(&(out[rsp_row*row_length+grad_i])), val);
-  }
-};
-
 template<>
 void SparseEmbeddingOpForwardRspImpl<gpu>(const OpContext& ctx,
                                           const TBlob& data,
@@ -103,7 +84,6 @@ void SparseEmbeddingOpForwardRspImpl<gpu>(const OpContext& ctx,
   }
 }
 
-
 template<>
 inline void SparseEmbeddingOpBackwardRspImpl<gpu>(const OpContext& ctx,
                                                   const TBlob& ograd,
@@ -125,55 +105,98 @@ inline void SparseEmbeddingOpBackwardRspImpl<gpu>(const OpContext& ctx,
   dim_t row_length = output.shape()[1];
   dim_t data_size = static_cast<dim_t>(data.shape_.Size());
   dim_t num_threads;
-
+  if (data_size == 0) {
+    FillZerosRspImpl(s, output);
+    return;
+  }
   MSHADOW_TYPE_SWITCH(data.type_flag_, IType, {
-    MSHADOW_SGL_DBL_TYPE_SWITCH(ograd.type_flag_, DType, {
+    MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
       MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIdx), RType, {
         dim_t* prefix_sum = NULL;
-        void* d_temp_storage = NULL;
-        size_t temp_storage_bytes = 0;
-        cub::DeviceScan::InclusiveSum(d_temp_storage,
-                                      temp_storage_bytes,
+        void* temp_storage = NULL;
+        dim_t* sorted_data = NULL;
+        dim_t* original_idx = NULL;
+        // calculate resource bytes
+        size_t row_flg_storage_bytes = num_rows * sizeof(dim_t);
+        size_t sorted_data_storage_bytes = data_size * sizeof(dim_t);
+        size_t original_idx_storage_bytes = data_size * sizeof(dim_t);
+        size_t sum_workspace_bytes = 0;
+        size_t sort_workspace_size = SortByKeyWorkspaceSize<dim_t, dim_t, gpu>(data_size);
+        cub::DeviceScan::InclusiveSum(temp_storage,
+                                      sum_workspace_bytes,
                                       prefix_sum,
                                       prefix_sum,
                                       num_rows,
                                       Stream<gpu>::GetStream(s));
+        // temp_workspace is shared by inclusive sum and sort
+        size_t temp_workspace_bytes = std::max(sum_workspace_bytes, sort_workspace_size);
+        size_t total_storage_bytes = row_flg_storage_bytes + sorted_data_storage_bytes +
+                                     original_idx_storage_bytes + temp_workspace_bytes;
+
+        // request resource and split it. layout =
+        // row_flg/prefixsum, sorted_data, original_idx, temp_storage
         Tensor<gpu, 1, char> workspace = ctx.requested[0]
-            .get_space_typed<gpu, 1, char>(Shape1(num_rows * sizeof(dim_t) +
-                                           temp_storage_bytes), s);
+            .get_space_typed<gpu, 1, char>(Shape1(total_storage_bytes), s);
         prefix_sum = reinterpret_cast<dim_t*>(workspace.dptr_);
-        d_temp_storage = workspace.dptr_ + num_rows*sizeof(dim_t);
+        sorted_data = reinterpret_cast<dim_t*>(workspace.dptr_ + row_flg_storage_bytes);
+        original_idx = reinterpret_cast<dim_t*>(workspace.dptr_ + row_flg_storage_bytes +
+                                                sorted_data_storage_bytes);
+        temp_storage = workspace.dptr_ + total_storage_bytes - temp_workspace_bytes;
+        // compute row flags and prefix sum
         num_threads = num_rows;
         Fill<false>(s, TBlob(prefix_sum, Shape1(num_threads), gpu::kDevMask), kWriteTo, 0);
         Kernel<MarkRowFlgKernel, gpu>::Launch(s, data_size, prefix_sum, data.dptr<IType>());
-
-        cub::DeviceScan::InclusiveSum(d_temp_storage,
-                                      temp_storage_bytes,
+        cub::DeviceScan::InclusiveSum(temp_storage,
+                                      temp_workspace_bytes,
                                       prefix_sum,
                                       prefix_sum,
                                       num_rows,
                                       mshadow::Stream<gpu>::GetStream(s));
+        // retrieve nnr and allocate output
         dim_t nnr = 0;
         CUDA_CALL(cudaMemcpy(&nnr, &prefix_sum[num_rows-1], sizeof(dim_t),
             cudaMemcpyDeviceToHost));
-
-        if (nnr == 0) {
-          FillZerosRspImpl(s, output);
-          return;
-        }
         output.CheckAndAlloc({Shape1(nnr)});
-        RType* grad_row_idx = output.aux_data(kIdx).dptr<RType>();
         // fill row_idx array of output matrix, using the row_flg values
+        RType* grad_row_idx = output.aux_data(kIdx).dptr<RType>();
         Kernel<FillRspRowIdxKernel, gpu>::Launch(s, num_rows,
             grad_row_idx, prefix_sum, num_rows);
-        // prefill with zeros
+
+        // make a copy of the data, to be sorted
+        TBlob sorted_data_blob(sorted_data, Shape1(data_size), gpu::kDevMask);
+        auto sorted_data_tensor = sorted_data_blob.FlatTo1D<gpu, dim_t>(s);
+        mxnet_op::copy(s, sorted_data_blob, data);
+
+        // generate original idx
+        Tensor<gpu, 1, dim_t> original_idx_tensor(original_idx, Shape1(data_size), s);
+        Kernel<range_fwd, gpu>::Launch(s, data_size, 1, static_cast<dim_t>(0),
+                                       static_cast<dim_t>(1), kWriteTo, original_idx);
+        // sort data with its original idx
+        int num_bits = ilog2(num_rows - 1);
+        char* temp_storage_ptr = reinterpret_cast<char*>(temp_storage);
+        Tensor<gpu, 1, char> temp_storage_tensor(temp_storage_ptr,
+                                                 Shape1(sort_workspace_size), s);
+        SortByKey(sorted_data_tensor, original_idx_tensor, true,
+                  &temp_storage_tensor, 0, num_bits);
+        // accumulate gradients
         DType* grad_data = output.data().dptr<DType>();
         Fill<false>(s, TBlob(grad_data, Shape1(nnr * row_length), gpu::kDevMask),
             kWriteTo, 0);
-        // add the final gradients
-        num_threads = row_length * data_size;
-        Kernel<AddTakeGradRspGPUKernel, gpu>::Launch(s, num_threads, grad_data, prefix_sum,
-            data.dptr<IType>(), ograd.dptr<DType>(), row_length);
+
+        // reuse dense op backward kernel
+        {
+          dim_t* sum_counts_ptr = NULL;
+          int* num_runs_ptr = NULL;
+          mshadow::Tensor<gpu, 2, DType> dst = output.data().get<gpu, 2, DType>(s);
+          mshadow::Tensor<gpu, 1, dim_t> sorted = sorted_data_tensor;
+          mshadow::Tensor<gpu, 1, dim_t> index = original_idx_tensor;
+          const auto oshape = ograd.shape_;
+          mshadow::Tensor<gpu, 2, DType> src = ograd.get_with_shape<gpu, 2, DType>(
+              Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s);
+          nnvm::dim_t* lookup_table = prefix_sum;
+          AddTakeGradLargeBatchKernelLaunch<true>(dst, sorted, index, src, sum_counts_ptr,
+                                                  num_runs_ptr, lookup_table);
+        }
       });
     });
   });
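
In SparseEmbeddingOpBackwardRspImpl the atomicAdd-based AddTakeGradRspGPUKernel is
removed; the row ids are now copied, paired with their original positions, sorted with
SortByKey, and handed to AddTakeGradLargeBatchKernelLaunch<true>, which adds each row's
contributions in a fixed order. A rough CPU analogue in NumPy (shapes and values are
illustrative, not taken from the PR):

    import numpy as np

    row_ids = np.array([3, 0, 3, 7, 0])          # "data" fed to the embedding
    ograd = np.random.rand(len(row_ids), 4)      # one gradient row per lookup

    # Sort the ids while remembering their original positions
    # (SortByKey over sorted_data/original_idx in the diff).
    order = np.argsort(row_ids, kind='stable')
    sorted_ids = row_ids[order]

    # Accumulate duplicates in this fixed, sorted order instead of relying on
    # the order in which atomicAdd calls happen to land.
    unique_ids = np.unique(sorted_ids)
    grad = np.zeros((len(unique_ids), ograd.shape[1]))
    for out_row, uid in enumerate(unique_ids):
        for pos in order[sorted_ids == uid]:
            grad[out_row] += ograd[pos]
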
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 54809d9419d..9441cc7f64e 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -1634,7 +1634,7 @@ def check_sparse_elementwise_sum_with_shape(stype, shape, n):
 
 @with_seed()
 def test_sparse_embedding():
-    ''' test sparse embedding op on cpu '''
+    ''' test sparse embedding operator '''
     def check_sparse_embedding(executor, weight_ref, data_onehot, grad, density):
         # update weight based on density
         weight[:] = rand_ndarray(weight.shape, 'row_sparse', density=density)
@@ -1665,7 +1665,7 @@ def check_sparse_embedding(executor, weight_ref, data_onehot, grad, density):
     arg_map["data"][:] = np_data
     # init grad
     np_grad = np.random.uniform(-1, 1, exe_test.outputs[0].shape)
-    grad = mx.nd.sparse.zeros('row_sparse', np_grad.shape)
+    grad = mx.nd.zeros(np_grad.shape)
     grad[:] = np_grad
     # weight
     weight = arg_map["embed_weight"]


 
