haojin2 commented on a change in pull request #17084: [numpy] add op median
URL: https://github.com/apache/incubator-mxnet/pull/17084#discussion_r359127664
##########
File path: src/operator/numpy/np_broadcast_reduce_op.h
##########
@@ -843,6 +846,354 @@ void NumpyBroadcastToForward(const nnvm::NodeAttrs& attrs,
req, outputs, expanded_ishape);
}
+/*!
+ * \brief Parameter set for the median operator: which axes to reduce over
+ *        and whether the reduced axes are kept in the result.
+ */
+struct NumpyMedianParam : public dmlc::Parameter<NumpyMedianParam> {
+  /*! \brief Optional tuple of axes; defaults to an empty optional. */
+  dmlc::optional<mxnet::Tuple<int>> axis;
+  /*! \brief Keep reduced axes as size-one dimensions; defaults to false. */
+  bool keepdims;
+
+  DMLC_DECLARE_PARAMETER(NumpyMedianParam) {
+    DMLC_DECLARE_FIELD(axis)
+      .set_default(dmlc::optional<mxnet::Tuple<int>>())
+      .describe("Axis or axes along which the medians are computed. The default is "
+                "to compute the median along a flattened version of the array.");
+    DMLC_DECLARE_FIELD(keepdims)
+      .set_default(false)
+      .describe("If this is set to `True`, the reduced axes are left in the result "
+                "as dimension with size one.");
+  }
+};
+
+/*!
+ * \brief Element-wise kernel that writes one median value per output index.
+ *
+ * For output index i the kernel unravels i against r_shape, copies
+ * coordinates 1..NDim-1 of that result into positions 0..NDim-2 of a
+ * coordinate in t_shape, and then reads the middle element(s) along the last
+ * axis of a_sort.
+ *
+ * NOTE(review): a_sort is presumably already sorted along its last axis and
+ * that axis is assumed to hold at least one element - confirm against the
+ * caller, which is not visible here.
+ *
+ * \tparam NDim number of dimensions of both t_shape and r_shape.
+ */
+template<int NDim>
+struct median_forward {
+  /*!
+   * \param i       flat index into out.
+   * \param out     destination buffer; out[i] receives the result.
+   * \param a_sort  source buffer addressed through t_shape.
+   * \param t_shape shape used to ravel coordinates into a_sort.
+   * \param r_shape shape used to unravel i.
+   */
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i,
+                                  DType* out,
+                                  const DType* a_sort,
+                                  mshadow::Shape<NDim> t_shape,
+                                  mshadow::Shape<NDim> r_shape) {
+    using namespace mshadow;
+    using namespace mxnet_op;
+
+    const auto r_coord = unravel(i, r_shape);
+
+    // Start from a copy of t_shape; every entry is overwritten below.
+    Shape<NDim> t_coord(t_shape);
+    for (int j = 0; j < NDim-1; ++j) {
+      t_coord[j] = r_coord[j+1];
+    }
+
+    // Exact integer arithmetic: the previous float-based index/parity test
+    // (0.5 * (len - 1), floor == ceil) loses the ".5" and even whole values
+    // once the axis length exceeds what a float mantissa can represent,
+    // silently picking the wrong element or skipping the average.
+    const auto last_pos = t_shape[NDim-1] - 1;  // index of the final element on the axis
+    const auto mid_below = last_pos / 2;        // lower (or only) middle position
+
+    t_coord[NDim-1] = mid_below;
+    const size_t lower_idx = ravel(t_coord, t_shape);
+
+    if (last_pos % 2 == 0) {
+      // Odd number of elements: a single middle element.
+      out[i] = a_sort[lower_idx];
+    } else {
+      // Even number of elements: mean of the two middle neighbours, which
+      // are adjacent in memory because the last axis is the fastest-varying
+      // one under ravel.
+      const DType x1 = a_sort[lower_idx];
+      const DType x2 = a_sort[lower_idx + 1];
+      out[i] = (x1 + x2) / 2;
+    }
+  }
+};
+
+template<typename xpu>
+void NumpyMedianForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
+ if (req[0] == kNullOp)
+ return;
Review comment:
```c++
if (req[0] == kNullOp) return;
```
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services