haojin2 commented on a change in pull request #17234: Op Quantile/Percentile [Numpy]
URL: https://github.com/apache/incubator-mxnet/pull/17234#discussion_r364660646
 
 

 ##########
 File path: src/operator/tensor/ordering_op-inl.h
 ##########
 @@ -558,6 +558,202 @@ void TopKImpl(const RunContext &ctx,
   }
 }
 
+/*!
+ * \brief Compute the workspace size (in bytes) needed to run TopK on `src`.
+ *
+ * \param src           input tensor; only its shape and total size are read.
+ * \param param         TopK parameters, decoded with ParseTopKParam.
+ * \param temp_size_ptr optional out-param. When non-null, it receives the size of
+ *                      the scratch ("temp") region used by the sorts.
+ * \return temp size, plus src.Size() DType values and src.Size() index_t values
+ *         (each rounded with PadBytes to `alignment`). For
+ *         ret_typ == kReturnMask it adds batch_size * k index_t values, also
+ *         rounded with PadBytes.
+ */
+template<typename xpu, typename DType>
+size_t TopK_Workspace_Cal(const TBlob& src,
+                          const TopKParam& param,
+                          size_t *temp_size_ptr) {
+  using namespace mshadow;
+
+  // Only batch_size and k are used below. The other out-params are required
+  // by ParseTopKParam's signature.
+  size_t batch_size = 0;
+  index_t element_num = 0;
+  int axis = 0;
+  bool do_transpose = false;
+  bool is_ascend = false;
+  index_t k = 0;
+  const size_t alignment = std::max(sizeof(DType), sizeof(index_t));
+  mxnet::TShape target_shape;
+  ParseTopKParam(src.shape_, param,
+                 &target_shape, &batch_size, &element_num, &axis, &k,
+                 &do_transpose, &is_ascend);
+
+  // Temp space needed by the full sorts: the largest requirement over the
+  // key/value type combinations that may be sorted.
+  size_t temp_size = std::max(
+      mxnet::op::SortByKeyWorkspaceSize<index_t, DType, xpu>(src.Size()),
+      mxnet::op::SortByKeyWorkspaceSize<DType, index_t, xpu>(src.Size()));
+  temp_size = std::max(temp_size,
+      mxnet::op::SortByKeyWorkspaceSize<index_t, index_t, xpu>(src.Size()));
+  // Additional temp space for gpu full sorts for batch ids.
+  temp_size += PadBytes(sizeof(index_t) * src.Size(), alignment);
+  // Temp space for cpu sorts.
+  temp_size = std::max(temp_size, sizeof(DType) * src.Size());
+  if (temp_size_ptr != nullptr) {
+    *temp_size_ptr = temp_size;
+  }
+
+  size_t workspace_size = temp_size
+                        + PadBytes(sizeof(DType) * src.Size(), alignment)
+                        + PadBytes(sizeof(index_t) * src.Size(), alignment);
+  // The mask return type needs an extra buffer of batch_size * k indices.
+  if (param.ret_typ == topk_enum::kReturnMask) {
+    workspace_size += PadBytes(sizeof(index_t) * batch_size * k, alignment);
+  }
+  return workspace_size;
+}
+
+template<typename xpu, typename DType, typename IDType>
+void TopK_Workspace_Impl(const RunContext &ctx,
 
 Review comment:
   Suggest renaming this function to `TopKImplwithWorkspace`.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to