piiswrong closed pull request #10997: speed up of topk-operator
URL: https://github.com/apache/incubator-mxnet/pull/10997
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as GitHub would not otherwise display it after the merge):

diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h
index 606406dfe0b..105ee8b90db 100644
--- a/src/operator/tensor/ordering_op-inl.h
+++ b/src/operator/tensor/ordering_op-inl.h
@@ -152,6 +152,174 @@ inline void ParseTopKParam(const TShape& src_shape, const TopKParam& param, TSha
                                       << *element_num << ", get k = " << *k;
 }
 
+using namespace mshadow;
+
+template<typename xpu>
+void TopKSort(const Tensor<xpu, 1, real_t>& dat,
+              const Tensor<xpu, 1, int>& ind,
+              const Tensor<xpu, 1, char>& work,
+              int K, int N, bool is_ascend,
+              Stream<xpu> *s);
+
+template<>
+MSHADOW_FORCE_INLINE void TopKSort<cpu>(const Tensor<cpu, 1, real_t>& dat,
+                                        const Tensor<cpu, 1, int>& ind,
+                                        const Tensor<cpu, 1, char>& work,
+                                        int K, int N, bool is_ascend,
+                                        Stream<cpu> *s) {
+  // Use full sort when K is relatively large.
+  const bool full_sort(K*8 > N);
+  // Batch size.
+  const int M(dat.size(0)/N);
+  const int omp_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount());
+  #pragma omp parallel for num_threads(omp_threads)
+  for (int i = 0; i < M; ++i) {
+    real_t *vals = dat.dptr_;
+    int *indices = ind.dptr_+i*N;
+    if (is_ascend) {
+      if (full_sort) {
+        std::sort(indices, indices+N,
+                  [&](const int& i1, const int& i2){ return vals[i1] < 
vals[i2]; });
+      } else {
+        std::partial_sort(indices, indices+K, indices+N,
+                          [&](const int& i1, const int& i2){ return vals[i1] < 
vals[i2]; });
+      }
+    } else {
+      if (full_sort) {
+        std::sort(indices, indices+N,
+                  [&](const int& i1, const int& i2){ return vals[i1] > 
vals[i2]; });
+      } else {
+        std::partial_sort(indices, indices+K, indices+N,
+                          [&](const int& i1, const int& i2){ return vals[i1] > 
vals[i2]; });
+      }
+    }
+    real_t *buff = reinterpret_cast<real_t*>(work.dptr_)+i*K;
+    for (int j = 0; j < K; ++j) {
+      buff[j] = vals[indices[j]];
+    }
+    std::copy(buff, buff+K, &vals[i*N]);
+  }
+}
+
+#ifdef __CUDACC__
+
+template<typename DType>
+MSHADOW_XINLINE bool TopKCompare(DType val1, int ind1, DType val2, int ind2, 
bool is_ascend) {
+  // Negative indices denote undefined values which are considered arbitrary 
small resp. large.
+  return (ind2 < 0) || (ind1 >= 0 && ((is_ascend && val1 < val2) || 
(!is_ascend && val1 > val2)));
+}
+
+template<typename DType>
+MSHADOW_XINLINE void MergeTopK(int K, DType *val1, int *ind1, DType *val2, int 
*ind2,
+                               bool is_ascend) {
+  // In-place merge of two sorted top-K lists into val1/ind1. First determine 
the intervals
+  // [0,..,i1], [0,..i2] of the two lists that will be part of the merged list.
+  int i1(K-1), i2(K-1);
+  for (int i = 0; i < K; ++i) {
+    if (TopKCompare(val1[i1], ind1[i1], val2[i2], ind2[i2], is_ascend)) {
+      --i2;
+    } else {
+      --i1;
+    }
+  }
+  // Now merge the lists from back to front.
+  for (int i = K; i--;) {
+    if (i2 < 0 || i1 >= 0 && TopKCompare(val2[i2], ind2[i2], val1[i1], 
ind1[i1], is_ascend)) {
+      val1[i] = val1[i1];
+      ind1[i] = ind1[i1];
+      --i1;
+    } else {
+      val1[i] = val2[i2];
+      ind1[i] = ind2[i2];
+      --i2;
+    }
+  }
+}
+
+template<typename DType>
+__global__ void PartialSortSmallK(int K, int N, DType *val, int *ind, bool 
is_ascend) {
+  // Buffer for pairwise reduction.
+  extern __shared__ int buff[];
+  // Start of buffer sections associated with this thread.
+  const int offset(threadIdx.x*K);
+  int *ind_buff = &buff[offset];
+  DType *val_buff = reinterpret_cast<DType*>(&buff[blockDim.x*K])+offset;
+  // Initialize top-K values for this thread.
+  for (int i = 0; i < K; ++i) {
+    ind_buff[i] = -1;
+  }
+  // Range of values this thread cares about. Each thread block processes
+  // a different batch item (i.e. a different set of ind/val where we
+  // have to select the top-K elements). All threads within the same
+  // block work on the same batch item.
+  const int first(blockIdx.x*N+threadIdx.x), last((blockIdx.x+1)*N);
+  // Select top-K from this range and store it sorted in the buffer.
+  // We assume a small K, so linear insertion is o.k.
+  for (int i = first; i < last; i += blockDim.x) {
+    DType cur_val(val[i]);
+    int cur_ind(ind[i]);
+    for (int j = K; j-- && TopKCompare(cur_val, cur_ind, val_buff[j], 
ind_buff[j], is_ascend); ) {
+      if (j+1 < K) {
+        val_buff[j+1] = val_buff[j];
+        ind_buff[j+1] = ind_buff[j];
+      }
+      val_buff[j] = cur_val;
+      ind_buff[j] = cur_ind;
+    }
+  }
+  // Recursive merge of sorted lists for this thread block. Note that 
blockDim.x is not
+  // necessary a power of two, therefore the additional checks for last_s.
+  for (unsigned int s = (blockDim.x+1)/2, last_s = blockDim.x;
+       last_s > 1; last_s = s, s = (s+1)/2) {
+    __syncthreads();
+    if (threadIdx.x < s && threadIdx.x+s < last_s) {
+      MergeTopK(K, val_buff, ind_buff, val_buff+s*K, ind_buff+s*K, is_ascend);
+    }
+  }
+  // Final updates on master thread.
+  if (threadIdx.x == 0) {
+    for (int i = 0; i < K; ++i) {
+      ind[blockIdx.x*N+i] = ind_buff[i];
+      val[blockIdx.x*N+i] = val_buff[i];
+    }
+  }
+}
+
+template<>
+MSHADOW_FORCE_INLINE void TopKSort<gpu>(const Tensor<gpu, 1, real_t>& dat,
+                                        const Tensor<gpu, 1, int>& ind,
+                                        const Tensor<gpu, 1, char>& work,
+                                        int K, int N, bool is_ascend,
+                                        Stream<gpu> *s) {
+  // Use full sort for all but very small K for which we
+  // can do a partial sort entirely within shared memory.
+  const bool full_sort(K > 5);
+  // Batch size.
+  const int M(dat.size(0)/N);
+  if (full_sort) {
+    // Divide workspace into two parts. The first one is needed to store batch 
ids.
+    const int id_size(sizeof(int)*ind.size(0));
+    Tensor<gpu, 1, int> batch_id(reinterpret_cast<int*>(work.dptr_), 
Shape1(ind.size(0)), s);
+    Tensor<gpu, 1, char> sort_work(work.dptr_+id_size, 
Shape1(work.size(0)-id_size), s);
+    mxnet::op::SortByKey(dat, ind, is_ascend, &sort_work);
+    if (M > 1) {
+      // Back to back sorting. Note that mxnet::op::SortByKey is a stable sort.
+      batch_id = ind / N;
+      mxnet::op::SortByKey(batch_id, dat, true, &sort_work);
+      batch_id = ind / N;
+      mxnet::op::SortByKey(batch_id, ind, true, &sort_work);
+    }
+  } else {
+    const int nthreads(mshadow::cuda::kBaseThreadNum);
+    PartialSortSmallK<<<M, nthreads, nthreads*K*(sizeof(int)+sizeof(real_t)),
+                        mshadow::Stream<gpu>::GetStream(s)>>>
+                        (K, N, dat.dptr_, ind.dptr_, is_ascend);
+  }
+}
+
+#endif
+
+
 /*!
    * \brief Implementation of the TopK operation
    *
@@ -180,7 +348,7 @@ void TopKImpl(RunContext ctx,
   Tensor<xpu, 1, char> workspace;
   Tensor<xpu, 1, char> temp_workspace;
   Tensor<xpu, 1, real_t> sorted_dat;
-  Tensor<xpu, 1, int> indices, batch_id, sel_indices;
+  Tensor<xpu, 1, int> indices, sel_indices;
   Tensor<xpu, 2, real_t> mask_val;
   int batch_size, element_num;  // number of batches + the size of each batch
   int axis = 0;
@@ -191,10 +359,16 @@ void TopKImpl(RunContext ctx,
   ParseTopKParam(src.shape_, param,
                  &target_shape, &batch_size, &element_num, &axis, &k, 
&do_transpose, &is_ascend);
   Tensor<xpu, 3, real_t> dat = src.FlatTo3D<xpu, real_t>(axis, axis, s);
-  size_t temp_size = mxnet::op::SortByKeyWorkspaceSize<int, int, 
xpu>(src.Size());
+  size_t temp_size = 0;
+  // Temp space needed by the gpu-based full sorts.
+  temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize<int, int, 
xpu>(src.Size()));
   temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize<int, 
real_t, xpu>(src.Size()));
   temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize<real_t, 
int, xpu>(src.Size()));
-  size_t workspace_size = temp_size + sizeof(real_t) * src.Size() + 
sizeof(int) * src.Size() * 2;
+  // Additional temp space for gpu full sorts for batch ids.
+  temp_size += sizeof(int) * src.Size();
+  // Temp space for cpu sorts.
+  temp_size = std::max(temp_size, sizeof(real_t) * src.Size());
+  size_t workspace_size = temp_size + sizeof(real_t) * src.Size() + 
sizeof(int) * src.Size();
   if (param.ret_typ == topk_enum::kReturnMask) {
     workspace_size += sizeof(int) * batch_size * k + sizeof(real_t) * 
batch_size * k;
   }
@@ -206,9 +380,6 @@ void TopKImpl(RunContext ctx,
   indices = Tensor<xpu, 1, int>(reinterpret_cast<int*>(workspace_curr_ptr),
                                 Shape1(src.Size()), s);  // indices in the 
original matrix
   workspace_curr_ptr += sizeof(int) * src.Size();
-  batch_id = Tensor<xpu, 1, int>(reinterpret_cast<int*>(workspace_curr_ptr),
-                                 Shape1(src.Size()), s);  // batch id in the 
original matrix
-  workspace_curr_ptr += sizeof(int) * src.Size();
   if (do_transpose) {
     sorted_dat = reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size()));
   } else {
@@ -232,19 +403,11 @@ void TopKImpl(RunContext ctx,
   }
   temp_workspace = Tensor<xpu, 1, char>(workspace_curr_ptr, Shape1(temp_size), 
s);  // temp space
   workspace_curr_ptr += temp_size;
-  // 2. Perform inplace batch sort using the `SortByKey` in MShadow
+
+  // 2. Perform inplace batch sort.
   // After sorting, each batch in `sorted_dat` will be sorted in the 
corresponding order
-  //   and the `indices` will contain the corresponding index in `sorted_dat`
-  // Sort the data and keep record of the correspondence to global indices.
-  mxnet::op::SortByKey(sorted_dat, indices, is_ascend, &temp_workspace);
-  // Calculate the corresponding batch indices of the elements
-  batch_id = indices / element_num;
-  // Since the SortByKey performs stable sort, the second SortByKey will 
reorder
-  //   the sorted_dat based on the order of the batch_id
-  mxnet::op::SortByKey(batch_id, sorted_dat, true, &temp_workspace);
-  // Reorder the indices
-  batch_id = indices / element_num;
-  mxnet::op::SortByKey(batch_id, indices, true, &temp_workspace);
+  // up to the k-th element and the `indices` will contain the corresponding 
index in `sorted_dat`
+  TopKSort(sorted_dat, indices, temp_workspace, k, element_num, is_ascend, s);
 
   // 3. Assign results to the ret blob
   if (param.ret_typ == topk_enum::kReturnMask) {
@@ -264,7 +427,7 @@ void TopKImpl(RunContext ctx,
     }
     IndexFill(ret_mask, sel_indices, mask_val);
   } else if (param.ret_typ == topk_enum::kReturnIndices) {
-    indices -= batch_id * element_num;
+    indices = F<mshadow_op::mod>(indices, element_num);
     if (do_transpose) {
       Tensor<xpu, 3, real_t> ret_indices = ret[0].FlatTo3D<xpu, real_t>(axis, 
axis, s);
       ret_indices = tcast<real_t>(transpose(
@@ -281,7 +444,7 @@ void TopKImpl(RunContext ctx,
                       inplace_reshape(indices, Shape2(batch_size, 
element_num)), 0, k));
     }
   } else {
-    indices -= batch_id * element_num;
+    indices = F<mshadow_op::mod>(indices, element_num);
     if (do_transpose) {
       Tensor<xpu, 3, real_t> ret_value = ret[0].FlatTo3D<xpu, real_t>(axis, 
axis, s);
       Tensor<xpu, 3, real_t> ret_indices = ret[1].FlatTo3D<xpu, real_t>(axis, 
axis, s);


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to