haojin2 commented on a change in pull request #17084: [numpy] add op median
URL: https://github.com/apache/incubator-mxnet/pull/17084#discussion_r359132646
##########
File path: src/operator/numpy/np_broadcast_reduce_op.h
##########
@@ -843,6 +846,354 @@ void NumpyBroadcastToForward(const nnvm::NodeAttrs&
attrs,
req, outputs, expanded_ishape);
}
+// Operator parameters for np.median: which axes to reduce over and whether
+// the reduced axes are retained as size-1 dimensions in the output.
+struct NumpyMedianParam : public dmlc::Parameter<NumpyMedianParam> {
+ // Reduction axes; an empty optional means reduce over the flattened array.
+ dmlc::optional<mxnet::Tuple<int>> axis;
+ // If true, keep each reduced axis in the result with size one.
+ bool keepdims;
+ DMLC_DECLARE_PARAMETER(NumpyMedianParam) {
+ DMLC_DECLARE_FIELD(axis)
+ .set_default(dmlc::optional<mxnet::Tuple<int>>())
+ .describe("Axis or axes along which the medians are computed. "
+ "The default is to compute the "
+ "median along a flattened version of the array.");
+ DMLC_DECLARE_FIELD(keepdims)
+ .set_default(false)
+ .describe("If this is set to `True`, the reduced axes are left "
+ "in the result as dimension with size one.");
+ }
+};
+
+// Kernel computing one output median per thread index `i`.
+// `a_sort` holds the input already sorted along its last axis with shape
+// `t_shape` (non-reduced dims first, then the flattened/sorted reduction
+// axis of length t_shape[NDim-1]); `r_shape` is the output index space.
+template<int NDim>
+struct median_forward {
+ template<typename DType>
+ MSHADOW_XINLINE static void Map(int i,
+ DType* out,
+ const DType* a_sort,
+ mshadow::Shape<NDim> t_shape,
+ mshadow::Shape<NDim> r_shape) {
+ using namespace mshadow;
+ using namespace mxnet_op;
+ using namespace std;
+
+ // Map the flat output index back to multi-dimensional output coordinates.
+ auto r_coord = unravel(i, r_shape);
+
+ Shape<NDim> t_coord(t_shape);
+
+ // Copy the non-reduced coordinates; r_coord has a leading size-1 dim,
+ // so output dim j+1 corresponds to sorted-input dim j.
+ for (int j = 0; j < NDim-1; ++j) {
+ t_coord[j] = r_coord[j+1];
+ }
+
+ // Fractional position of the median within the sorted reduction axis.
+ float idx = 0.5 * (t_shape[NDim-1]-1);
+
+ if (floor(idx) == ceil(idx)) {
+ // Odd element count: the median is the single middle element.
+ int idx_below = floor(idx);
+ t_coord[NDim-1] = idx_below;
+ size_t t_idx1 = ravel(t_coord, t_shape);
+ out[i] = a_sort[t_idx1];
+ } else{
+ // Even element count: average the two middle elements.
+ // NOTE(review): for integral DType, (x1 + x2) / 2 truncates and may
+ // overflow for large values — confirm the op registers float output
+ // (numpy returns float here) or that integral inputs are excluded.
+ int idx_below = floor(idx);
+ t_coord[NDim-1] = idx_below;
+ size_t t_idx1 = ravel(t_coord, t_shape);
+ size_t t_idx2 = t_idx1 + 1;
+ DType x1 = a_sort[t_idx1];
+ DType x2 = a_sort[t_idx2];
+ out[i] = (x1 + x2) / 2;}
+ }
+};
+
+template<typename xpu>
+void NumpyMedianForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
+ if (req[0] == kNullOp)
+ return;
+
+ using namespace mxnet;
+ using namespace mxnet_op;
+ CHECK_EQ(inputs.size(), 1U);
+ CHECK_EQ(outputs.size(), 1U);
+
+ Stream<xpu> *s = ctx.get_stream<xpu>();
+
+ NumpyMedianParam param = nnvm::get<NumpyMedianParam>(attrs.parsed);
+ const TBlob& a = inputs[0];
+ const TBlob& r = outputs[0];
+
+ auto small = NumpyReduceAxesShapeImpl(a.shape_, param.axis, false);
+
+ dmlc::optional<mxnet::Tuple<int>> axis = param.axis;
+
+ TShape r_shape;
+ r_shape = TShape(small.ndim()+1, 1);
+ for (int i = 1; i < r_shape.ndim(); ++i) {
+ r_shape[i] = small[i-1];
+ }
+
+ TShape axes;
+ if (!axis.has_value()) {
+ axes = TShape(a.shape_.ndim(), 1);
+ for (int i = 0; i < a.shape_.ndim(); ++i) {
+ axes[i] = i;
+ }
+ } else {
+ auto axis_tuple = axis.value();
+ axes = TShape(axis_tuple.ndim(), 1);
+ for (int i = 0; i < axis_tuple.ndim(); ++i) {
+ axes[i] = axis_tuple[i];
+ }
+ }
+
+ TShape t_axes(a.shape_.ndim(), 1);
+ int j = 0;
+ for (int i = 0; i < t_axes.ndim(); ++i) {
+ bool red = false;
+ for (int k = 0; k < axes.ndim(); ++k) {
+ if (axes[k] == i) {
+ red = true;
+ }
+ }
+ if (!red) {
+ t_axes[j] = i;
+ j++;
+ }
+ }
+ for (int jj = j; jj < t_axes.ndim(); ++jj) {
+ t_axes[jj] = axes[jj-j];
+ }
+
+ TShape t_shape(small.ndim()+1, 1);
+ for (int i = 0; i < small.ndim(); ++i) {
+ t_shape[i] = small[i];
+ }
+ size_t red_size = 1;
+ for (int i = 0; i < axes.ndim(); ++i) {
+ red_size *= a.shape_[axes[i]];
+ }
+ t_shape[t_shape.ndim()-1] = red_size;
+ TShape t_shape_ex(a.shape_.ndim(), 1);
+ for (int i = 0; i < small.ndim(); ++i) {
+ t_shape_ex[i] = small[i];
+ }
+ for (int i = small.ndim(); i < a.shape_.ndim(); ++i) {
+ t_shape_ex[i] = a.shape_[axes[i-small.ndim()]];
+ }
+
+ TopKParam topk_param = TopKParam();
+ topk_param.axis = dmlc::optional<int>(-1);
+ topk_param.is_ascend = true;
+ topk_param.k = 0;
+ topk_param.ret_typ = topk_enum::kReturnValue;
+
+ MSHADOW_TYPE_SWITCH(a.type_flag_, DType, {
+ using namespace mshadow::expr;
+ Tensor<xpu, 1, char> workspace;
+ Tensor<xpu, 1, char> temp_workspace;
+ Tensor<xpu, 1, DType> sorted_dat;
+ Tensor<xpu, 1, index_t> indices, sel_indices;
+ size_t batch_size = 0;
+ index_t element_num = 0; // number of batches + the size of each batch
+ int axis_topk = 0;
+ bool do_transpose = false;
+ bool is_ascend = false;
+ index_t k = 0;
+ size_t alignment = std::max(sizeof(DType), sizeof(index_t));
+ mxnet::TShape target_shape;
+
+ size_t temp_data_size = a.shape_.Size() * sizeof(DType);
+ size_t idx_size = a.shape_.Size() * sizeof(index_t);
+ size_t temp_mem_size = 2 * temp_data_size + idx_size;
+ size_t temp_size = std::max(
+ mxnet::op::SortByKeyWorkspaceSize<index_t, DType,
xpu>(a.shape_.Size()),
+ mxnet::op::SortByKeyWorkspaceSize<DType, index_t,
xpu>(a.shape_.Size()));
+
+ temp_size = std::max(temp_size,
+ mxnet::op::SortByKeyWorkspaceSize<index_t, index_t,
xpu>(a.shape_.Size()));
+ // Additional temp space for gpu full sorts for batch ids.
+ temp_size += PadBytes(sizeof(index_t) * a.shape_.Size(), alignment);
+ // Temp space for cpu sorts.
+ temp_size = std::max(temp_size, sizeof(DType) * a.shape_.Size());
+
+ size_t workspace_size = temp_size + PadBytes(sizeof(DType) *
a.shape_.Size(), alignment)
+ + PadBytes(sizeof(index_t) *
a.shape_.Size(), alignment);
+ workspace_size += PadBytes(sizeof(index_t) * batch_size * k, alignment);
+ temp_mem_size += workspace_size * 2;
+
+ Tensor<xpu, 1, char> temp_mem =
+ ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(temp_mem_size),
s);
+ DType* trans_ptr, *sort_ptr;
+ char* workspace_curr_ptr;
+ index_t* idx_ptr;
+ if (sizeof(DType) >= sizeof(index_t)) {
+ trans_ptr = reinterpret_cast<DType*>(temp_mem.dptr_);
+ sort_ptr = reinterpret_cast<DType*>(temp_mem.dptr_ + temp_data_size);
+ idx_ptr = reinterpret_cast<index_t*>(temp_mem.dptr_ + 2 *
temp_data_size);
+ } else {
+ idx_ptr = reinterpret_cast<index_t*>(temp_mem.dptr_);
+ trans_ptr = reinterpret_cast<DType*>(temp_mem.dptr_ + idx_size);
+ sort_ptr = reinterpret_cast<DType*>(temp_mem.dptr_ + temp_data_size +
idx_size);
+ }
+ workspace_curr_ptr = temp_mem.dptr_ + 2 * temp_data_size + idx_size;
+
+ TBlob a_trans = TBlob(trans_ptr, t_shape_ex, xpu::kDevMask);
+
+ TransposeImpl<xpu>(ctx.run_ctx, a, a_trans, t_axes);
+
+ TBlob a_sort = TBlob(sort_ptr, t_shape, xpu::kDevMask);
+ TBlob a_idx = TBlob(idx_ptr, t_shape, xpu::kDevMask);
+
+ std::vector<OpReqType> req_TopK = {kWriteTo, kNullOp};
+ TBlob src = a_trans.reshape(t_shape);
+ std::vector<TBlob> ret = {a_sort, a_idx};
+ TopKParam parameter = topk_param;
+
+ ParseTopKParam(src.shape_, parameter,
+ &target_shape, &batch_size, &element_num, &axis_topk,
+ &k, &do_transpose, &is_ascend);
+ CHECK_LE(element_num, mxnet::common::MaxIntegerValue<index_t>())
+ << "'index_t' does not have a sufficient precision to represent "
+ << "the indices of the input array. The total element_num is "
+ << element_num << ", but the selected index_t can only represent "
+ << mxnet::common::MaxIntegerValue<index_t>() << " elements";
+ Tensor<xpu, 3, DType> dat = src.FlatTo3D<xpu, DType>(axis_topk, axis_topk,
s);
+
+ sorted_dat = Tensor<xpu, 1,
DType>(reinterpret_cast<DType*>(workspace_curr_ptr),
+ Shape1(src.Size()), s); // contain sorted dat
+ workspace_curr_ptr += PadBytes(sizeof(DType) * src.Size(), alignment);
+ indices = Tensor<xpu, 1,
index_t>(reinterpret_cast<index_t*>(workspace_curr_ptr),
+ Shape1(src.Size()), s); // indices in the original matrix
+ workspace_curr_ptr += PadBytes(sizeof(index_t) * src.Size(), alignment);
+
+ if (parameter.ret_typ == topk_enum::kReturnMask) {
+ sel_indices = Tensor<xpu, 1,
index_t>(reinterpret_cast<index_t*>(workspace_curr_ptr),
+ Shape1(batch_size * k), s);
+ workspace_curr_ptr += PadBytes(sizeof(index_t) * batch_size * k,
alignment);
+ CHECK_EQ(sel_indices.CheckContiguous(), true);
+ }
+
+ if (std::is_same<xpu, cpu>::value) {
+ Tensor<xpu, 1, DType> flattened_data;
+ if (do_transpose) {
+ flattened_data = Tensor<xpu, 1,
DType>(reinterpret_cast<DType*>(workspace_curr_ptr),
+ Shape1(src.Size()), s);
+ workspace_curr_ptr += sizeof(DType) * src.Size();
+ flattened_data = reshape(transpose(dat, Shape3(0, 2, 1)),
Shape1(src.Size()));
+ CHECK_EQ(flattened_data.CheckContiguous(), true);
+ } else {
+ flattened_data = src.FlatTo1D<xpu, DType>(s);
+ }
+ // `temp_workspace` stores the flattened data
+ temp_workspace = Tensor<xpu, 1,
char>(reinterpret_cast<char*>(flattened_data.dptr_),
+
Shape1(sizeof(DType)*src.Size()), s);
+ CHECK_EQ(temp_workspace.CheckContiguous(), true);
+ } else {
+ if (do_transpose) {
+ sorted_dat = reshape(transpose(dat, Shape3(0, 2, 1)),
Shape1(src.Size()));
+ } else {
+ sorted_dat = reshape(dat, Shape1(src.Size()));
+ }
+ CHECK_EQ(sorted_dat.CheckContiguous(), true);
+ temp_workspace = Tensor<xpu, 1, char>(workspace_curr_ptr,
Shape1(temp_size), s);
+ workspace_curr_ptr += temp_size;
+ }
+
+ mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size * element_num, 1,
index_t{0}, index_t{1},
+ kWriteTo, indices.dptr_);
+ CHECK_EQ(indices.CheckContiguous(), true);
+
+ // 2. Perform inplace batch sort.
+ TopKSort(sorted_dat, indices, temp_workspace, k, element_num, is_ascend,
s);
+
+ // 3. Assign results to the ret blob
+ if (parameter.ret_typ == topk_enum::kReturnMask) {
+ Tensor<xpu, 1, DType> ret_mask = ret[0].FlatTo1D<xpu, DType>(s);
+ ret_mask = scalar<DType>(0);
+ sel_indices = reshape(slice<1>(
+ inplace_reshape(indices,
+ Shape2(batch_size,
+ element_num)), 0, k),
+ Shape1(batch_size * k));
+ if (do_transpose) {
+ mxnet::TShape src_shape = src.shape_.FlatTo3D(axis_topk);
+ CHECK_EQ(sel_indices.CheckContiguous(), true);
+ sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0],
src_shape[2],
+ src_shape[1]), Shape3(0, 2, 1));
+ }
+ if (req_TopK[0] == kNullOp) {
+ return;
+ } else if (req_TopK[0] == kWriteTo) {
+ mxnet_op::Kernel<fill_ind_to_one, xpu>::Launch(s, batch_size * k,
+ sel_indices.dptr_,
ret_mask.dptr_);
+ } else {
+ LOG(FATAL) << "req=" << req_TopK[0] << " is not supported yet.";
+ }
+ } else if (parameter.ret_typ == topk_enum::kReturnIndices) {
+ if (do_transpose) {
+ Tensor<xpu, 3, index_t> ret_indices = ret[0].FlatTo3D<xpu,
index_t>(axis_topk,
+
axis_topk, s);
+ ASSIGN_DISPATCH(ret_indices, req_TopK[0],
tcast<index_t>(F<mshadow_op::mod>(transpose(
+ slice<2>(inplace_reshape(indices,
+ Shape3(ret_indices.shape_[0],
+ ret_indices.shape_[2],
+ element_num)),
+ 0, k),
+ Shape3(0, 2, 1)), element_num)));
+ } else {
+ Tensor<xpu, 2, index_t> ret_indices =
+ ret[0].get_with_shape<xpu, 2, index_t>(Shape2(batch_size, k), s);
+ ASSIGN_DISPATCH(ret_indices, req_TopK[0],
tcast<index_t>(F<mshadow_op::mod>(slice<1>(
+ inplace_reshape(indices, Shape2(batch_size,
element_num)), 0, k),
+ element_num)));
+ }
+ } else {
+ if (do_transpose) {
+ Tensor<xpu, 3, DType> ret_value = ret[0].FlatTo3D<xpu,
DType>(axis_topk, axis_topk, s);
+ Tensor<xpu, 3, index_t> ret_indices = ret[1].FlatTo3D<xpu,
index_t>(axis_topk,
+
axis_topk, s);
+ ASSIGN_DISPATCH(ret_value, req_TopK[0], transpose(
+ slice<2>(inplace_reshape(sorted_dat,
+ Shape3(ret_value.shape_[0],
ret_value.shape_[2],
+ element_num)), 0, k), Shape3(0, 2, 1)));
+ ASSIGN_DISPATCH(ret_indices, req_TopK[1],
tcast<index_t>(F<mshadow_op::mod>(transpose(
+ slice<2>(inplace_reshape(indices,
+ Shape3(ret_indices.shape_[0],
+ ret_indices.shape_[2],
+ element_num)),
+ 0, k), Shape3(0, 2, 1)), element_num)));
+ } else {
+ Tensor<xpu, 2, DType> ret_value =
+ ret[0].get_with_shape<xpu, 2, DType>(Shape2(batch_size, k), s);
+ Tensor<xpu, 2, index_t> ret_indices =
+ ret[1].get_with_shape<xpu, 2, index_t>(Shape2(batch_size, k), s);
+ ASSIGN_DISPATCH(ret_value, req_TopK[0],
+ slice<1>(inplace_reshape(sorted_dat, Shape2(batch_size,
element_num)), 0, k));
+ ASSIGN_DISPATCH(ret_indices, req_TopK[1],
tcast<index_t>(F<mshadow_op::mod>(slice<1>(
+ inplace_reshape(indices, Shape2(batch_size, element_num)),
0, k), element_num)));
+ }
+ }
+
+
Review comment:
get rid of extra blank lines
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services