piiswrong closed pull request #9688: [MXNET-108] Adding BilinearResize2D and AdaptiveAvgPool2d operators
URL: https://github.com/apache/incubator-mxnet/pull/9688
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/docs/api/python/ndarray/contrib.md b/docs/api/python/ndarray/contrib.md
index 3dcb6d18f95..25cabed808e 100644
--- a/docs/api/python/ndarray/contrib.md
+++ b/docs/api/python/ndarray/contrib.md
@@ -34,6 +34,8 @@ In the rest of this document, we list routines provided by the `ndarray.contrib`
 .. autosummary::
     :nosignatures:
 
+    AdaptiveAvgPooling2D
+    BilinearResize2D
     CTCLoss
     DeformableConvolution
     DeformablePSROIPooling
diff --git a/docs/api/python/symbol/contrib.md b/docs/api/python/symbol/contrib.md
index 7f5cc4bb3ff..1af18bbf86d 100644
--- a/docs/api/python/symbol/contrib.md
+++ b/docs/api/python/symbol/contrib.md
@@ -34,6 +34,8 @@ In the rest of this document, we list routines provided by the `symbol.contrib`
 .. autosummary::
     :nosignatures:
 
+    AdaptiveAvgPooling2D
+    BilinearResize2D
     CTCLoss
     DeformableConvolution
     DeformablePSROIPooling
diff --git a/src/operator/contrib/adaptive_avg_pooling-inl.h b/src/operator/contrib/adaptive_avg_pooling-inl.h
new file mode 100644
index 00000000000..7331c7bd47a
--- /dev/null
+++ b/src/operator/contrib/adaptive_avg_pooling-inl.h
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file adaptive_avg_pooling-inl.h
+ * \brief adaptive average pooling operator
+ * \author Hang Zhang
+*/
+#ifndef MXNET_OPERATOR_CONTRIB_ADAPTIVE_AVG_POOLING_INL_H_
+#define MXNET_OPERATOR_CONTRIB_ADAPTIVE_AVG_POOLING_INL_H_
+
+#include <dmlc/logging.h>
+#include <dmlc/parameter.h>
+#include <mxnet/operator.h>
+#include <mxnet/ndarray.h>
+#include <map>
+#include <vector>
+#include <string>
+#include <utility>
+/* contrib
+#include "../ndarray/ndarray_function.h"
+#include "./operator_common.h"
+#include "./mxnet_op.h"
+#include "./mshadow_op.h"
+*/
+#include "../../ndarray/ndarray_function.h"
+#include "../operator_common.h"
+#include "../mxnet_op.h"
+#include "../mshadow_op.h"
+
+namespace mxnet {
+namespace op {
+
+/*!
+ * \brief Parameters of the AdaptiveAvgPooling2D operator.
+ *
+ * output_size may be empty (global pooling to 1x1), a single int (square
+ * output) or a (height, width) pair; AdaptiveAvgPoolOpInferShape maps it to
+ * the spatial dimensions of the output.
+ */
+struct AdaptiveAvgPoolParam : public dmlc::Parameter<AdaptiveAvgPoolParam> {
+  TShape output_size;
+  DMLC_DECLARE_PARAMETER(AdaptiveAvgPoolParam) {
+    DMLC_DECLARE_FIELD(output_size).set_default(TShape())
+    .describe("int (output size) or a tuple of int for output (height, width).");
+  }
+};
+
+/*!
+ * \brief Whether the request overwrites the output (kWriteTo / kWriteInplace),
+ *        as opposed to kNullOp or kAddTo.
+ * NOTE(review): an identical helper is defined in bilinear_resize-inl.h in the
+ * same namespace; including both headers in one translation unit would be a
+ * redefinition.
+ */
+static inline bool IsWriting(const OpReqType ort) {
+  switch (ort) {
+    case kWriteTo:
+    case kWriteInplace:
+      return true;
+    default:
+      return false;
+  }
+}
+
+// Device-specific kernels for the forward and backward pass. The overload is
+// chosen by the stream type: the Stream<cpu> versions are defined in
+// adaptive_avg_pooling.cc, the Stream<gpu> versions (declared only when
+// MXNET_USE_CUDA is set) in adaptive_avg_pooling.cu.
+//
+// UpdateOutput:    input[0] is the data, output[0] receives the pooled result.
+// UpdateGradInput: input[0] is the gradient w.r.t. the pooled output,
+//                  output[0] is the gradient w.r.t. the data (added into).
+template<typename xpu, typename DType, typename AccReal>
+void AdaptiveAvgPoolUpdateOutput(mshadow::Stream<cpu> *s,
+                                 const std::vector<TBlob> &input,
+                                 const std::vector<TBlob> &output);
+
+template<typename xpu, typename DType, typename AccReal>
+void AdaptiveAvgPoolUpdateGradInput(mshadow::Stream<cpu> *s,
+                                    const std::vector<TBlob> &input,
+                                    const std::vector<TBlob> &output);
+
+#if MXNET_USE_CUDA
+template<typename xpu, typename DType, typename AccReal>
+void AdaptiveAvgPoolUpdateOutput(mshadow::Stream<gpu> *s,
+                                 const std::vector<TBlob> &input,
+                                 const std::vector<TBlob> &output);
+
+template<typename xpu, typename DType, typename AccReal>
+void AdaptiveAvgPoolUpdateGradInput(mshadow::Stream<gpu> *s,
+                                    const std::vector<TBlob> &input,
+                                    const std::vector<TBlob> &output);
+#endif  // MXNET_USE_CUDA
+
+/*!
+ * \brief Forward pass of AdaptiveAvgPooling2D (FCompute).
+ *
+ * Dispatches on the input dtype and writes the pooled result into outputs[0].
+ * A kNullOp request skips the computation so the output memory is not
+ * touched; any other request overwrites outputs[0] (kAddTo does not
+ * accumulate).
+ */
+template <typename xpu>
+inline void AdaptiveAvgPoolOpForward(const nnvm::NodeAttrs& attrs,
+                                     const OpContext &ctx,
+                                     const std::vector<TBlob> &inputs,
+                                     const std::vector<OpReqType> &req,
+                                     const std::vector<TBlob> &outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  // The caller does not want this output: do not write to its buffer.
+  if (req[0] == kNullOp) return;
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, {
+    AdaptiveAvgPoolUpdateOutput<xpu, DType, AccReal>(s, inputs, outputs);
+  });
+}
+
+
+/*!
+ * \brief Backward pass of AdaptiveAvgPooling2D (FCompute).
+ *
+ * inputs[0] is the gradient w.r.t. the pooled output and outputs[0] the
+ * gradient w.r.t. the data. The update kernels add into outputs[0], so it is
+ * zeroed first for kWriteTo / kWriteInplace, while kAddTo keeps the existing
+ * values. kNullOp skips the computation entirely.
+ */
+template <typename xpu>
+inline void AdaptiveAvgPoolOpBackward(const nnvm::NodeAttrs& attrs,
+                                      const OpContext &ctx,
+                                      const std::vector<TBlob> &inputs,
+                                      const std::vector<OpReqType> &req,
+                                      const std::vector<TBlob> &outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  // The gradient is not requested: the accumulating kernels must not run.
+  if (req[0] == kNullOp) return;
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  if (IsWriting(req[0])) {
+    // zero grad before backwarding
+    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      Fill<false>(s, outputs[0], kWriteTo, 0);
+    })
+  }
+  MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, {
+    AdaptiveAvgPoolUpdateGradInput<xpu, DType, AccReal>(s, inputs, outputs);
+  });
+}
+
+
+/*!
+ * \brief Shape inference: (N, C, H, W) -> (N, C, out_h, out_w).
+ *
+ * output_size empty -> 1x1, one value -> square, two values -> (h, w).
+ * Returns false while the input shape is still unknown.
+ */
+static bool AdaptiveAvgPoolOpInferShape(const nnvm::NodeAttrs& attrs,
+                                       std::vector<TShape> *in_shape,
+                                       std::vector<TShape> *out_shape) {
+  using namespace mshadow;
+  CHECK_EQ(in_shape->size(), 1U) << "Input:[data]";
+  CHECK_EQ(out_shape->size(), 1U) << "Output:[data]";
+  const AdaptiveAvgPoolParam& param = nnvm::get<AdaptiveAvgPoolParam>(attrs.parsed);
+  TShape dshape(in_shape->at(0));
+  if (dshape.ndim() == 0) return false;
+  // Indices 2 and 3 are written below, so anything other than NCHW would
+  // index past the end of the shape.
+  CHECK_EQ(dshape.ndim(), 4U)
+    << "AdaptiveAvgPooling2D expects 4D input (N, C, H, W), got " << dshape;
+  CHECK_LE(param.output_size.ndim(), 2U)
+    << "output_size must be an int or a (height, width) tuple, got " << param.output_size;
+  if (param.output_size.ndim() == 0) {
+    // global average pooling
+    dshape[2] = 1;
+    dshape[3] = 1;
+  } else if (param.output_size.ndim() == 1) {
+    dshape[2] = param.output_size[0];
+    dshape[3] = param.output_size[0];
+  } else {
+    dshape[2] = param.output_size[0];
+    dshape[3] = param.output_size[1];
+  }
+  out_shape->clear();
+  out_shape->push_back(dshape);
+  return true;
+}
+
+/*!
+ * \brief Type inference: the output has the same dtype as the input.
+ *
+ * The compute functions access input[0] and output[0] with the same DType
+ * (get<xpu, 4, DType>), so the output must not be promoted to the
+ * accumulation type (float32 for float16 input). The real-type switch still
+ * rejects non-floating-point inputs at inference time.
+ */
+static bool AdaptiveAvgPoolOpInferType(const nnvm::NodeAttrs& attrs,
+                                       std::vector<int> *in_type,
+                                       std::vector<int> *out_type) {
+  using namespace mshadow;
+  CHECK_EQ(in_type->size(), 1U);
+  int dtype = (*in_type)[0];
+  CHECK_NE(dtype, -1) << "First input must have specified type";
+  int dtype_out = 0;
+  MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
+      dtype_out = mshadow::DataType<DTypeX>::kFlag; });
+  out_type->clear();
+  out_type->push_back(dtype_out);
+  return true;
+}
+
+/*!
+ * \brief Storage-type inference: dense input and output, FCompute dispatch.
+ */
+static inline bool AdaptiveAvgPoolOpStorageType(const nnvm::NodeAttrs &attrs,
+                                                const int dev_mask,
+                                                DispatchMode *dispatch_mode,
+                                                std::vector<int> *in_attrs,
+                                                std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1);
+  CHECK_EQ(out_attrs->size(), 1);
+  *dispatch_mode = DispatchMode::kFCompute;
+  // unknown (-1) input storage becomes dense; known types are left as is
+  for (auto it = in_attrs->begin(); it != in_attrs->end(); ++it) {
+    if (*it == -1) {
+      *it = kDefaultStorage;
+    }
+  }
+  // every output is dense
+  for (int &stype : *out_attrs) {
+    stype = kDefaultStorage;
+  }
+  return true;
+}
+
+/*!
+ * \brief Element stride of dimension idx: the product of the sizes of all
+ *        dimensions after idx (1 for the last one).
+ *
+ * Computed from the sizes alone, so it matches the memory layout only for a
+ * contiguous tensor. Types are qualified with mshadow:: rather than relying on
+ * a `using namespace mshadow;` directive, which in a header would leak into
+ * every includer.
+ */
+template<typename xpu, int Dim, typename DType>
+MSHADOW_XINLINE int get_stride(const mshadow::Tensor<xpu, Dim, DType> &tensor, int idx) {
+  int stride = 1;
+  for (int i = Dim - 1; i > idx; --i) {
+    stride *= tensor.size(i);
+  }
+  return stride;
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_CONTRIB_ADAPTIVE_AVG_POOLING_INL_H_
diff --git a/src/operator/contrib/adaptive_avg_pooling.cc b/src/operator/contrib/adaptive_avg_pooling.cc
new file mode 100644
index 00000000000..079571177cb
--- /dev/null
+++ b/src/operator/contrib/adaptive_avg_pooling.cc
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file adaptive_avg_pooling.cc
+ * \brief adaptive average pooling operator
+ * \author Hang Zhang
+*/
+#include "adaptive_avg_pooling-inl.h"
+// #include "elemwise_op_common.h"
+#include "../elemwise_op_common.h"
+
+// The pooling window of output index a (out of b outputs) over an input axis
+// of length c is [START_IND(a, b, c), END_IND(a, b, c)), i.e.
+// [floor(a*c/b), ceil((a+1)*c/b)). Exact 64-bit integer arithmetic (operands
+// are non-negative) is used rather than float, which cannot represent a*c
+// exactly beyond 2^24; arguments are parenthesised for safe expansion.
+#define START_IND(a, b, c) static_cast<int>((static_cast<int64_t>(a) * (c)) / (b))
+#define END_IND(a, b, c) \
+  static_cast<int>(((static_cast<int64_t>(a) + 1) * (c) + (b) - 1) / (b))
+
+namespace mxnet {
+namespace op {
+
+using namespace mshadow;
+
+/*!
+ * \brief CPU adaptive average pooling of one batch item.
+ *
+ * \param input_p   first input element of the item; planes, rows and columns
+ *                  are addressed through istrideD / istrideH / istrideW
+ * \param output_p  first output element of the item, written densely as
+ *                  sizeD x osizeH x osizeW
+ * Each output value is the mean over its [START_IND, END_IND) window in both
+ * spatial axes. Offsets are computed in int64_t because products such as
+ * d * osizeH * osizeW can exceed the range of int for large tensors.
+ */
+template<typename real>
+static void SpatialAdaptiveAveragePooling_updateOutput_frame(
+          real *input_p,
+          real *output_p,
+          int64_t sizeD,
+          int64_t isizeH,
+          int64_t isizeW,
+          int64_t osizeH,
+          int64_t osizeW,
+          int64_t istrideD,
+          int64_t istrideH,
+          int64_t istrideW) {
+  int64_t d;
+#pragma omp parallel for private(d) \
+num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
+  for (d = 0; d < sizeD; d++) {
+    const int64_t outOffset = d * osizeH * osizeW;
+    const real *plane_in = input_p + d * istrideD;
+    for (int64_t oh = 0; oh < osizeH; oh++) {
+      const int istartH = START_IND(oh, osizeH, isizeH);
+      const int iendH   = END_IND(oh, osizeH, isizeH);
+      const int kH = iendH - istartH;
+      const int64_t outOffsetH = outOffset + oh * osizeW;
+
+      for (int64_t ow = 0; ow < osizeW; ow++) {
+        const int istartW = START_IND(ow, osizeW, isizeW);
+        const int iendW   = END_IND(ow, osizeW, isizeW);
+        const int kW = iendW - istartW;
+
+        // top-left corner of the pooling window
+        const real *ip = plane_in + istartH * istrideH + istartW * istrideW;
+
+        /* compute local average: */
+        real sum = 0;
+        for (int64_t ih = 0; ih < kH; ih++) {
+          for (int64_t iw = 0; iw < kW; iw++) {
+            sum += ip[ih * istrideH + iw * istrideW];
+          }
+        }
+
+        /* set output to local average */
+        output_p[outOffsetH + ow] = sum / kW / kH;
+      }
+    }
+  }
+}
+
+/*!
+ * \brief CPU backward of adaptive average pooling for one batch item.
+ *
+ * Spreads every output-gradient value evenly over its pooling window and adds
+ * (+=) the shares into gradInput_p, which is addressed densely as
+ * sizeD x isizeH x isizeW. The caller is expected to have zeroed gradInput_p
+ * for write requests.
+ */
+template<typename real>
+static void SpatialAdaptiveAveragePooling_updateGradInput_frame(
+          real *gradInput_p,
+          real *gradOutput_p,
+          int64_t sizeD,
+          int64_t isizeH,
+          int64_t isizeW,
+          int64_t osizeH,
+          int64_t osizeW) {
+  const int64_t iplane = isizeH * isizeW;
+  const int64_t oplane = osizeH * osizeW;
+  int64_t d;
+#pragma omp parallel for private(d) \
+num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
+  for (d = 0; d < sizeD; d++) {
+    real *gin = gradInput_p + d * iplane;
+    const real *gout = gradOutput_p + d * oplane;
+
+    for (int64_t oh = 0; oh < osizeH; ++oh) {
+      const int h0 = START_IND(oh, osizeH, isizeH);
+      const int h1 = END_IND(oh, osizeH, isizeH);
+
+      for (int64_t ow = 0; ow < osizeW; ++ow) {
+        const int w0 = START_IND(ow, osizeW, isizeW);
+        const int w1 = END_IND(ow, osizeW, isizeW);
+        // share of this output's gradient received by each window element
+        const real share = gout[oh * osizeW + ow] / (h1 - h0) / (w1 - w0);
+
+        for (int ih = h0; ih < h1; ++ih) {
+          real *row = gin + ih * isizeW;
+          for (int iw = w0; iw < w1; ++iw) {
+            row[iw] += share;
+          }
+        }
+      }
+    }
+  }
+}
+
+
+/*!
+ * \brief CPU forward: pools input[0] (N, C, H, W) into output[0], one batch
+ *        item per OpenMP iteration.
+ */
+template<typename xpu, typename DType, typename AccReal>
+void AdaptiveAvgPoolUpdateOutput(mshadow::Stream<cpu> *s,
+                                 const std::vector<TBlob> &input,
+                                 const std::vector<TBlob> &output) {
+  Tensor<xpu, 4, DType> in = input[0].get<xpu, 4, DType>(s);
+  Tensor<xpu, 4, DType> out = output[0].get<xpu, 4, DType>(s);
+
+  const int64_t batch  = in.size(0);
+  const int64_t planes = in.size(1);
+  const int64_t inH    = in.size(2);
+  const int64_t inW    = in.size(3);
+  const int64_t outH   = out.size(2);
+  const int64_t outW   = out.size(3);
+
+  // element strides of the input along N, C, H and W
+  const int64_t strideN = get_stride<xpu, 4, DType>(in, 0);
+  const int64_t strideC = get_stride<xpu, 4, DType>(in, 1);
+  const int64_t strideH = get_stride<xpu, 4, DType>(in, 2);
+  const int64_t strideW = get_stride<xpu, 4, DType>(in, 3);
+
+  // the output of one batch item is a dense planes x outH x outW block
+  const int64_t outItem = planes * outH * outW;
+
+  int64_t b;
+#pragma omp parallel for private(b) \
+num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
+  for (b = 0; b < batch; b++) {
+    SpatialAdaptiveAveragePooling_updateOutput_frame<DType>(
+      in.dptr_ + b * strideN, out.dptr_ + b * outItem,
+      planes, inH, inW, outH, outW,
+      strideC, strideH, strideW);
+  }
+}
+
+
+/*!
+ * \brief CPU backward: input[0] holds the gradient w.r.t. the pooled output,
+ *        output[0] receives (+=) the gradient w.r.t. the data.
+ */
+template<typename xpu, typename DType, typename AccReal>
+void AdaptiveAvgPoolUpdateGradInput(mshadow::Stream<cpu> *s,
+                                    const std::vector<TBlob> &input,
+                                    const std::vector<TBlob> &output) {
+  Tensor<xpu, 4, DType> grad_out = input[0].get<xpu, 4, DType>(s);
+  Tensor<xpu, 4, DType> grad_in = output[0].get<xpu, 4, DType>(s);
+
+  const int64_t batch  = grad_in.size(0);
+  const int64_t planes = grad_in.size(1);
+  const int64_t inH    = grad_in.size(2);
+  const int64_t inW    = grad_in.size(3);
+  const int64_t outH   = grad_out.size(2);
+  const int64_t outW   = grad_out.size(3);
+
+  // both gradients are addressed as dense per-item blocks
+  const int64_t inItem  = planes * inH * inW;
+  const int64_t outItem = planes * outH * outW;
+
+  int64_t b;
+#pragma omp parallel for private(b) \
+num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
+  for (b = 0; b < batch; b++) {
+    SpatialAdaptiveAveragePooling_updateGradInput_frame<DType>(
+      grad_in.dptr_ + b * inItem, grad_out.dptr_ + b * outItem,
+      planes, inH, inW, outH, outW);
+  }
+}
+
+
+// Make AdaptiveAvgPoolParam parsable from operator attributes.
+DMLC_REGISTER_PARAMETER(AdaptiveAvgPoolParam);
+
+// Forward operator (CPU FCompute here; the GPU FCompute is attached in
+// adaptive_avg_pooling.cu).
+NNVM_REGISTER_OP(_contrib_AdaptiveAvgPooling2D)
+.describe(R"code(
+Applies a 2D adaptive average pooling over a 4D input with the shape of (NCHW).
+The pooling kernel and stride sizes are automatically chosen for desired 
output sizes.
+
+- If a single integer is provided for output_size, the output size is
+(N x C x output_size x output_size) for any input (NCHW).
+
+- If a tuple of integers (height, width) are provided for output_size, the 
output size is
+(N x C x height x width) for any input (NCHW).
+
+)code" ADD_FILELINE)
+.set_attr_parser(ParamParser<AdaptiveAvgPoolParam>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FInferShape>("FInferShape", AdaptiveAvgPoolOpInferShape)
+.set_attr<nnvm::FInferType>("FInferType", AdaptiveAvgPoolOpInferType)
+.set_attr<FInferStorageType>("FInferStorageType", AdaptiveAvgPoolOpStorageType)
+.set_attr<FCompute>("FCompute<cpu>", AdaptiveAvgPoolOpForward<cpu>)
+// The backward node takes only the output gradient (its single input), not
+// the forward data or result.
+.set_attr<nnvm::FGradient>("FGradient",
+  ElemwiseGradUseNone{"_backward_contrib_AdaptiveAvgPooling2D"})
+.add_argument("data", "NDArray-or-Symbol", "Input data")
+.add_arguments(AdaptiveAvgPoolParam::__FIELDS__());
+
+// Backward operator: input = gradient w.r.t. the pooled output,
+// output = gradient w.r.t. the data.
+NNVM_REGISTER_OP(_backward_contrib_AdaptiveAvgPooling2D)
+.set_attr_parser(ParamParser<AdaptiveAvgPoolParam>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FInferStorageType>("FInferStorageType", AdaptiveAvgPoolOpStorageType)
+.set_attr<FCompute>("FCompute<cpu>", AdaptiveAvgPoolOpBackward<cpu>);
+
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/contrib/adaptive_avg_pooling.cu b/src/operator/contrib/adaptive_avg_pooling.cu
new file mode 100644
index 00000000000..375c420a044
--- /dev/null
+++ b/src/operator/contrib/adaptive_avg_pooling.cu
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file adaptive_avg_pooling.cu
+ * \brief adaptive average pooling operator
+ * \author Hang Zhang
+*/
+#include <cuda_runtime_api.h>
+#include <algorithm>
+#include "adaptive_avg_pooling-inl.h"
+
+// The pooling window of output index a (out of b outputs) over an input axis
+// of length c is [START_IND(a, b, c), END_IND(a, b, c)), i.e.
+// [floor(a*c/b), ceil((a+1)*c/b)). Exact 64-bit integer arithmetic (operands
+// are non-negative) avoids both int overflow of a*c and float rounding.
+// Keep in sync with adaptive_avg_pooling.cc.
+#define START_IND(a, b, c) static_cast<int>((static_cast<int64_t>(a) * (c)) / (b))
+#define END_IND(a, b, c) \
+  static_cast<int>(((static_cast<int64_t>(a) + 1) * (c) + (b) - 1) / (b))
+#define CUDA_MAX_THREADS 1024   // this is safe, in reality 256 is our limit
+
+namespace mxnet {
+namespace op {
+
+using namespace mshadow;
+
+/*! \brief Value conversion In -> Out callable from host and device code. */
+template<typename In, typename Out>
+struct ScalarConvert {
+  static __host__ __device__ __forceinline__ Out to(const In v) {
+    return static_cast<Out>(v);
+  }
+};
+
+/*
+ * Description:
+ *    this function adaptively average pools an input 4D tensor along dimensions 2 and 3
+ *    4D input, 4D output
+ *
+ *    blockIdx.x selects one input/output plane; threadIdx.y together with
+ *    blockIdx.y steps over output rows and threadIdx.x over output columns.
+ *    Plane offsets are computed in 64 bit because plane * osizeH * osizeW can
+ *    exceed the range of int for large tensors.
+ */
+template <typename T>
+__global__ void adaptiveaveragepool(const T *input, T *output,
+                        int isizeH, int isizeW,
+                        int osizeH, int osizeW,
+                        int64_t istrideD, int64_t istrideH, int64_t istrideW) {
+  // select input/output plane based on block ID
+  const int64_t plane = blockIdx.x;
+  output += plane * osizeH * osizeW;
+  input += plane * istrideD;
+
+  const int ostartH = blockDim.y * blockIdx.y + threadIdx.y;
+  const int ostepH = blockDim.y * gridDim.y;
+  const int ostartW = threadIdx.x;
+  const int ostepW = blockDim.x;
+
+  // For all output pixels...
+  for (int oh = ostartH; oh < osizeH; oh += ostepH) {
+    const int istartH = START_IND(oh, osizeH, isizeH);
+    const int iendH   = END_IND(oh, osizeH, isizeH);
+    const int kH = iendH - istartH;
+
+    for (int ow = ostartW; ow < osizeW; ow += ostepW) {
+      const int istartW = START_IND(ow, osizeW, isizeW);
+      const int iendW   = END_IND(ow, osizeW, isizeW);
+      const int kW = iendW - istartW;
+
+      // Compute the average pooling over corresponding input pixels
+      const T *ptr_input = input + istartH * istrideH + istartW * istrideW;
+      T sum = ScalarConvert<int, T>::to(0);
+      for (int ih = 0; ih < kH; ++ih) {
+        for (int iw = 0; iw < kW; ++iw) {
+          sum += ptr_input[iw * istrideW];
+        }
+        ptr_input += istrideH;  // next input line
+      }
+      // Update output
+      output[static_cast<int64_t>(oh) * osizeW + ow] = sum / kH / kW;
+    }
+  }
+}
+
+/*
+ * Description:
+ *    this function computes the gradInput from gradOutput
+ *    (uses atomic add, because windows of neighbouring outputs can overlap
+ *    and different threads may update the same input element)
+ *
+ *    Launch layout matches adaptiveaveragepool. Plane offsets are computed in
+ *    64 bit to avoid int overflow for large tensors.
+ */
+template <typename T>
+__global__ void atomicadaptiveaveragegradinput(
+  T *gradInput, const T *gradOutput,
+  int isizeH, int isizeW, int osizeH, int osizeW
+) {
+  // select input/output plane based on block ID
+  const int64_t plane = blockIdx.x;
+  gradOutput += plane * osizeH * osizeW;
+  gradInput += plane * isizeH * isizeW;
+
+  const int ostartH = blockDim.y * blockIdx.y + threadIdx.y;
+  const int ostepH = blockDim.y * gridDim.y;
+  const int ostartW = threadIdx.x;
+  const int ostepW = blockDim.x;
+
+  // For all output pixels...
+  for (int oh = ostartH; oh < osizeH; oh += ostepH) {
+    const int istartH = START_IND(oh, osizeH, isizeH);
+    const int iendH   = END_IND(oh, osizeH, isizeH);
+    const int kH = iendH - istartH;
+
+    for (int ow = ostartW; ow < osizeW; ow += ostepW) {
+      const int istartW = START_IND(ow, osizeW, isizeW);
+      const int iendW   = END_IND(ow, osizeW, isizeW);
+      const int kW = iendW - istartW;
+
+      // Distribute this output's gradient over its input window
+      T *ptr_gradInput = gradInput + static_cast<int64_t>(istartH) * isizeW + istartW;
+      const T grad_delta = gradOutput[static_cast<int64_t>(oh) * osizeW + ow] / kW / kH;
+
+      for (int ih = 0; ih < kH; ++ih) {
+        for (int iw = 0; iw < kW; ++iw) {
+          // atomic add since different threads could update same variable
+          atomicAdd(&(ptr_gradInput[iw]), grad_delta);
+        }
+        ptr_gradInput += isizeW;  // next input line
+      }
+    }
+  }
+}
+
+
+/*!
+ * \brief GPU forward: launches adaptiveaveragepool with one block column per
+ *        (batch, channel) plane of input[0].
+ */
+template<typename xpu, typename DType, typename AccReal>
+void AdaptiveAvgPoolUpdateOutput(mshadow::Stream<gpu> *s,
+                                 const std::vector<TBlob> &input,
+                                 const std::vector<TBlob> &output) {
+  Tensor<xpu, 4, DType> itensor = input[0].get<xpu, 4, DType>(s);
+  Tensor<xpu, 4, DType> otensor = output[0].get<xpu, 4, DType>(s);
+
+  DType *input_data = itensor.dptr_;
+  DType *output_data = otensor.dptr_;
+
+  int64_t sizeB  = itensor.size(0);
+  int64_t sizeD  = itensor.size(1);
+  int64_t isizeH = itensor.size(2);
+  int64_t isizeW = itensor.size(3);
+
+  int64_t istrideD = get_stride<xpu, 4, DType>(itensor, 1);
+  int64_t istrideH = get_stride<xpu, 4, DType>(itensor, 2);
+  int64_t istrideW = get_stride<xpu, 4, DType>(itensor, 3);
+
+  int64_t osizeH = otensor.size(2);
+  int64_t osizeW = otensor.size(3);
+
+  // Empty batch or channel axis: nothing to compute. Returning here also
+  // avoids the division by sizeD below and a launch with a zero grid dimension.
+  if (sizeB == 0 || sizeD == 0) return;
+
+  // cuda blocks & threads: blockIdx.x selects the plane, blockIdx.y splits
+  // the output rows when there are few planes.
+  const int blocksH = std::max(static_cast<int>(16L / sizeD), 1);
+  dim3 blocks(sizeB * sizeD, blocksH);
+  dim3 threads(32, 8);
+
+  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
+  // run averagepool kernel
+  adaptiveaveragepool <<<blocks, threads, 0, stream>>> (
+    input_data, output_data, isizeH, isizeW, osizeH, osizeW,
+    istrideD, istrideH, istrideW);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(AdaptiveAvgPoolUpdateOutput);
+}
+
+/*!
+ * \brief GPU backward: input[0] is the gradient w.r.t. the pooled output,
+ *        output[0] receives the gradient w.r.t. the data via atomic adds.
+ */
+template<typename xpu, typename DType, typename AccReal>
+void AdaptiveAvgPoolUpdateGradInput(mshadow::Stream<gpu> *s,
+                                    const std::vector<TBlob> &input,
+                                    const std::vector<TBlob> &output) {
+  Tensor<xpu, 4, DType> gradOut = input[0].get<xpu, 4, DType>(s);
+  Tensor<xpu, 4, DType> gradIn = output[0].get<xpu, 4, DType>(s);
+
+  DType *gradOutput_data = gradOut.dptr_;
+  DType *gradInput_data = gradIn.dptr_;
+
+  int64_t sizeB  = gradIn.size(0);
+  int64_t sizeD  = gradIn.size(1);
+  int64_t isizeH = gradIn.size(2);
+  int64_t isizeW = gradIn.size(3);
+
+  int64_t osizeH = gradOut.size(2);
+  int64_t osizeW = gradOut.size(3);
+
+  // Empty batch or channel axis: nothing to compute. Returning here also
+  // avoids the division by sizeD below and a launch with a zero grid dimension.
+  if (sizeB == 0 || sizeD == 0) return;
+
+  // cuda blocks & threads: blockIdx.x selects the plane, blockIdx.y splits
+  // the output rows when there are few planes.
+  const int blocksH = std::max(static_cast<int>(16L / sizeD), 1);
+  dim3 blocks(sizeB * sizeD, blocksH);
+  dim3 threads(32, 8);
+
+  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
+  // run updateGradInput kernel, accumulate gradients atomically
+  atomicadaptiveaveragegradinput <<<blocks, threads, 0, stream>>> (
+    gradInput_data, gradOutput_data, isizeH, isizeW, osizeH, osizeW);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(AdaptiveAvgPoolUpdateGradInput);
+}
+
+// Attach the GPU FCompute functions to the operators registered in
+// adaptive_avg_pooling.cc.
+NNVM_REGISTER_OP(_contrib_AdaptiveAvgPooling2D)
+.set_attr<FCompute>("FCompute<gpu>", AdaptiveAvgPoolOpForward<gpu>);
+
+NNVM_REGISTER_OP(_backward_contrib_AdaptiveAvgPooling2D)
+.set_attr<FCompute>("FCompute<gpu>", AdaptiveAvgPoolOpBackward<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/contrib/bilinear_resize-inl.h b/src/operator/contrib/bilinear_resize-inl.h
new file mode 100644
index 00000000000..b73ead9eba5
--- /dev/null
+++ b/src/operator/contrib/bilinear_resize-inl.h
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file bilinear_resize-inl.h
+ * \brief bilinear resize operator
+ * \author Hang Zhang
+*/
+#ifndef MXNET_OPERATOR_CONTRIB_BILINEAR_RESIZE_INL_H_
+#define MXNET_OPERATOR_CONTRIB_BILINEAR_RESIZE_INL_H_
+
+#include <dmlc/logging.h>
+#include <dmlc/parameter.h>
+#include <mxnet/operator.h>
+#include <mxnet/ndarray.h>
+#include <map>
+#include <vector>
+#include <string>
+#include <utility>
+/* contrib
+#include "../ndarray/ndarray_function.h"
+#include "./operator_common.h"
+#include "./mxnet_op.h"
+#include "./mshadow_op.h"
+*/
+#include "../../ndarray/ndarray_function.h"
+#include "../operator_common.h"
+#include "../mxnet_op.h"
+#include "../mshadow_op.h"
+
+namespace mxnet {
+namespace op {
+
+/*!
+ * \brief Parameters of BilinearResize2D: the target spatial size.
+ *
+ * Both fields have no default and are restricted by set_range(1, 1000).
+ * BilinearSampleOpInferShape writes height to output dim 2 and width to dim 3.
+ */
+struct BilinearSampleParam : public dmlc::Parameter<BilinearSampleParam> {
+  int height;  // output size along dim 2
+  int width;   // output size along dim 3
+  DMLC_DECLARE_PARAMETER(BilinearSampleParam) {
+    DMLC_DECLARE_FIELD(height).set_range(1, 1000)
+    .describe("output height (required)");
+    DMLC_DECLARE_FIELD(width).set_range(1, 1000)
+    .describe("output width (required)");
+  }
+};
+
+/*!
+ * \brief Whether the request overwrites the output (kWriteTo / kWriteInplace),
+ *        as opposed to kNullOp or kAddTo.
+ * NOTE(review): an identical helper is defined in adaptive_avg_pooling-inl.h in
+ * the same namespace; including both headers in one translation unit would be
+ * a redefinition.
+ */
+static inline bool IsWriting(const OpReqType ort) {
+  switch (ort) {
+    case kWriteTo:
+    case kWriteInplace:
+      return true;
+    default:
+      return false;
+  }
+}
+
+// Device-specific kernels for the forward and backward pass, selected by the
+// stream type. The Stream<cpu> forward is defined in bilinear_resize.cc; the
+// Stream<gpu> overloads are declared only when MXNET_USE_CUDA is set and are
+// presumably defined in a .cu file not shown here (TODO confirm).
+//
+// UpdateOutput:    input[0] is the data, output[0] the resized result.
+// UpdateGradInput: input[0] is the gradient w.r.t. the resized output,
+//                  output[0] the gradient w.r.t. the data.
+template<typename xpu, typename DType, typename AccReal>
+void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<cpu> *s,
+                                           const std::vector<TBlob> &input,
+                                           const std::vector<TBlob> &output);
+
+template<typename xpu, typename DType, typename AccReal>
+void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
+                                              const std::vector<TBlob> &input,
+                                              const std::vector<TBlob> 
&output);
+
+#if MXNET_USE_CUDA
+template<typename xpu, typename DType, typename AccReal>
+void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<gpu> *s,
+                                           const std::vector<TBlob> &input,
+                                           const std::vector<TBlob> &output);
+
+template<typename xpu, typename DType, typename AccReal>
+void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<gpu> *s,
+                                              const std::vector<TBlob> &input,
+                                              const std::vector<TBlob> 
&output);
+#endif  // MXNET_USE_CUDA
+
+/*!
+ * \brief Forward pass of BilinearResize2D (FCompute).
+ *
+ * Resizes inputs[0] to the (height, width) given in the parameters and writes
+ * the result into outputs[0]. A kNullOp request skips the computation so the
+ * output memory is not touched.
+ */
+template <typename xpu>
+inline void BilinearSampleOpForward(const nnvm::NodeAttrs& attrs,
+                                    const OpContext &ctx,
+                                    const std::vector<TBlob> &inputs,
+                                    const std::vector<OpReqType> &req,
+                                    const std::vector<TBlob> &outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  // The caller does not want this output: do not write to its buffer.
+  if (req[0] == kNullOp) return;
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, {
+    SpatialUpSamplingBilinearUpdateOutput<xpu, DType, AccReal>(s, inputs, outputs);
+  });
+}
+
+
+/*!
+ * \brief Backward pass of BilinearResize2D (FCompute).
+ *
+ * inputs[0] is the gradient w.r.t. the resized output, outputs[0] the gradient
+ * w.r.t. the data. outputs[0] is zeroed for kWriteTo / kWriteInplace before the
+ * update kernel runs (the kernel presumably accumulates into it). kNullOp
+ * skips the computation so the gradient buffer is left untouched.
+ */
+template <typename xpu>
+inline void BilinearSampleOpBackward(const nnvm::NodeAttrs& attrs,
+                                     const OpContext &ctx,
+                                     const std::vector<TBlob> &inputs,
+                                     const std::vector<OpReqType> &req,
+                                     const std::vector<TBlob> &outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  // The gradient is not requested: the update kernel must not write it.
+  if (req[0] == kNullOp) return;
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  if (IsWriting(req[0])) {
+    // zero grad before backwarding
+    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      Fill<false>(s, outputs[0], kWriteTo, 0);
+    })
+  }
+  MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, {
+    SpatialUpSamplingBilinearUpdateGradInput<xpu, DType, AccReal>(s, inputs, outputs);
+  });
+}
+
+
+/*!
+ * \brief Shape inference: (N, C, H, W) -> (N, C, height, width).
+ *
+ * Returns false while the input shape is still unknown.
+ */
+static bool BilinearSampleOpInferShape(const nnvm::NodeAttrs& attrs,
+                                       std::vector<TShape> *in_shape,
+                                       std::vector<TShape> *out_shape) {
+  using namespace mshadow;
+  CHECK_EQ(in_shape->size(), 1U) << "Input:[data]";
+  CHECK_EQ(out_shape->size(), 1U) << "Output:[data]";
+  const BilinearSampleParam& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
+  TShape dshape(in_shape->at(0));
+  if (dshape.ndim() == 0) return false;
+  // Indices 2 and 3 are written below, so anything other than NCHW would
+  // index past the end of the shape.
+  CHECK_EQ(dshape.ndim(), 4U)
+    << "BilinearResize2D expects 4D input (N, C, H, W), got " << dshape;
+  dshape[2] = param.height;
+  dshape[3] = param.width;
+  out_shape->clear();
+  out_shape->push_back(dshape);
+  return true;
+}
+
+/*!
+ * \brief Type inference: the output has the same dtype as the input.
+ *
+ * The forward implementation accesses input[0] and output[0] with the same
+ * DType (get<xpu, 4, DType>), so the output must not be promoted to the
+ * accumulation type (float32 for float16 input). The real-type switch still
+ * rejects non-floating-point inputs at inference time.
+ */
+static bool BilinearSampleOpInferType(const nnvm::NodeAttrs& attrs,
+                                      std::vector<int> *in_type,
+                                      std::vector<int> *out_type) {
+  using namespace mshadow;
+  CHECK_EQ(in_type->size(), 1U);
+  int dtype = (*in_type)[0];
+  CHECK_NE(dtype, -1) << "First input must have specified type";
+  int dtype_out = 0;
+  MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
+      dtype_out = mshadow::DataType<DTypeX>::kFlag; });
+  out_type->clear();
+  out_type->push_back(dtype_out);
+  return true;
+}
+
+/*!
+ * \brief Storage-type inference: dense input and output, FCompute dispatch.
+ */
+static inline bool BilinearSampleOpStorageType(const nnvm::NodeAttrs &attrs,
+                                               const int dev_mask,
+                                               DispatchMode *dispatch_mode,
+                                               std::vector<int> *in_attrs,
+                                               std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1);
+  CHECK_EQ(out_attrs->size(), 1);
+  *dispatch_mode = DispatchMode::kFCompute;
+  // unknown (-1) input storage becomes dense; known types are left as is
+  for (auto it = in_attrs->begin(); it != in_attrs->end(); ++it) {
+    if (*it == -1) {
+      *it = kDefaultStorage;
+    }
+  }
+  // every output is dense
+  for (int &stype : *out_attrs) {
+    stype = kDefaultStorage;
+  }
+  return true;
+}
+
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_CONTRIB_BILINEAR_RESIZE_INL_H_
diff --git a/src/operator/contrib/bilinear_resize.cc b/src/operator/contrib/bilinear_resize.cc
new file mode 100644
index 00000000000..e1248ce97bb
--- /dev/null
+++ b/src/operator/contrib/bilinear_resize.cc
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file bilinear_resize.cc
+ * \brief bilinear resize operator
+ * \author Hang Zhang
+*/
+#include "bilinear_resize-inl.h"
+// #include "elemwise_op_common.h"
+#include "../elemwise_op_common.h"
+
+namespace mxnet {
+namespace op {
+
+using namespace mshadow;
+
+/*!
+ * \brief CPU forward pass of BilinearResize2D.
+ *
+ * Resizes every (batch, channel) plane of input[0] to the spatial size of
+ * output[0] by bilinear interpolation. Source coordinates are
+ * out_index * (in - 1) / (out - 1), so the corner pixels of input and output
+ * coincide; a 1-pixel output dimension samples index 0.
+ *
+ * Coordinates, weights and the blend are computed in AccReal, which is the
+ * same as the GPU kernel (Acctype). Float64 inputs are therefore not sampled
+ * at float32 precision, and float16 inputs are blended in the wider type
+ * before being converted back to DType.
+ *
+ * \param s      CPU stream, used only to build tensor views.
+ * \param input  input[0]: 4-D source tensor (N, C, H_in, W_in), contiguous.
+ * \param output output[0]: 4-D destination tensor (N, C, H_out, W_out).
+ */
+template<typename xpu, typename DType, typename AccReal>
+void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<cpu> *s,
+                                           const std::vector<TBlob> &input,
+                                           const std::vector<TBlob> &output) {
+  Tensor<xpu, 4, DType> itensor = input[0].get<xpu, 4, DType>(s);
+  Tensor<xpu, 4, DType> otensor = output[0].get<xpu, 4, DType>(s);
+  const int nbatch = otensor.size(0);
+  const int outputHeight = otensor.size(2);
+  const int outputWidth = otensor.size(3);
+  const int inputHeight = itensor.size(2);
+  const int inputWidth = itensor.size(3);
+  // Batch and channel are treated identically, so walk them as one axis.
+  const int channels = nbatch * static_cast<int>(otensor.size(1));
+
+  const DType *idata = itensor.dptr_;
+  DType *odata = otensor.dptr_;
+  // Special case: same spatial size. Both buffers are contiguous with the
+  // same shape, so the result is a straight element-wise copy.
+  if (inputHeight == outputHeight && inputWidth == outputWidth) {
+    const size_t total = otensor.shape_.Size();
+    for (size_t i = 0; i < total; ++i) {
+      odata[i] = idata[i];
+    }
+    return;
+  }
+  const AccReal rheight = (outputHeight > 1)
+      ? static_cast<AccReal>(inputHeight - 1) / (outputHeight - 1) : AccReal(0);
+  const AccReal rwidth = (outputWidth > 1)
+      ? static_cast<AccReal>(inputWidth - 1) / (outputWidth - 1) : AccReal(0);
+  const int inputPlane = inputHeight * inputWidth;
+  const int outputPlane = outputHeight * outputWidth;
+  for (int h2 = 0; h2 < outputHeight; ++h2) {
+    const AccReal h1r = rheight * h2;
+    const int h1 = static_cast<int>(h1r);
+    // Offset to the next source row; 0 on the last row so reads stay in bounds.
+    const int h1p = (h1 < inputHeight - 1) ? 1 : 0;
+    const AccReal h1lambda = h1r - h1;
+    const AccReal h0lambda = AccReal(1) - h1lambda;
+    const int below = h1p * inputWidth;
+    for (int w2 = 0; w2 < outputWidth; ++w2) {
+      const AccReal w1r = rwidth * w2;
+      const int w1 = static_cast<int>(w1r);
+      // Offset to the next source column; 0 on the last column.
+      const int w1p = (w1 < inputWidth - 1) ? 1 : 0;
+      const AccReal w1lambda = w1r - w1;
+      const AccReal w0lambda = AccReal(1) - w1lambda;
+      const DType* pos1 = &idata[h1 * inputWidth + w1];
+      DType* pos2 = &odata[h2 * outputWidth + w2];
+      for (int c = 0; c < channels; ++c) {
+        const AccReal top = w0lambda * static_cast<AccReal>(pos1[0])
+                          + w1lambda * static_cast<AccReal>(pos1[w1p]);
+        const AccReal bottom = w0lambda * static_cast<AccReal>(pos1[below])
+                             + w1lambda * static_cast<AccReal>(pos1[below + w1p]);
+        pos2[0] = static_cast<DType>(h0lambda * top + h1lambda * bottom);
+        pos1 += inputPlane;
+        pos2 += outputPlane;
+      }
+    }
+  }
+}
+
+
+/*!
+ * \brief CPU backward pass of BilinearResize2D (adjoint of the forward pass).
+ *
+ * Scatters each output-gradient pixel back onto the four source pixels it was
+ * interpolated from, weighted by the same bilinear weights as the forward
+ * pass. Weights and products are formed in AccReal, consistent with the GPU
+ * kernel (Acctype) and the CPU forward pass.
+ *
+ * The result is ACCUMULATED (+=) into output[0]. NOTE(review): presumably the
+ * caller zero-fills output[0] for write requests; confirm in
+ * BilinearSampleOpBackward.
+ *
+ * \param s      CPU stream, used only to build tensor views.
+ * \param input  input[0]: output gradient (N, C, H_out, W_out).
+ * \param output output[0]: data gradient (N, C, H_in, W_in), accumulated into.
+ */
+template<typename xpu, typename DType, typename AccReal>
+void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
+                                              const std::vector<TBlob> &input,
+                                              const std::vector<TBlob> &output) {
+  Tensor<xpu, 4, DType> gradOutput = input[0].get<xpu, 4, DType>(s);
+  Tensor<xpu, 4, DType> gradInput = output[0].get<xpu, 4, DType>(s);
+
+  const int nbatch = gradInput.size(0);
+  const int outputHeight = gradOutput.size(2);
+  const int outputWidth = gradOutput.size(3);
+  const int inputHeight = gradInput.size(2);
+  const int inputWidth = gradInput.size(3);
+  // Batch and channel are treated identically, so walk them as one axis.
+  const int channels = nbatch * static_cast<int>(gradInput.size(1));
+
+  DType *data1 = gradInput.dptr_;
+  const DType *data2 = gradOutput.dptr_;
+
+  // Special case: same spatial size. The forward pass was an identity copy, so
+  // the gradient passes straight through (element-wise accumulate).
+  if (inputHeight == outputHeight && inputWidth == outputWidth) {
+    const size_t total = gradInput.shape_.Size();
+    for (size_t i = 0; i < total; ++i) {
+      data1[i] += data2[i];
+    }
+    return;
+  }
+  const AccReal rheight = (outputHeight > 1)
+      ? static_cast<AccReal>(inputHeight - 1) / (outputHeight - 1) : AccReal(0);
+  const AccReal rwidth = (outputWidth > 1)
+      ? static_cast<AccReal>(inputWidth - 1) / (outputWidth - 1) : AccReal(0);
+  const int inputPlane = inputHeight * inputWidth;
+  const int outputPlane = outputHeight * outputWidth;
+  for (int h2 = 0; h2 < outputHeight; ++h2) {
+    const AccReal h1r = rheight * h2;
+    const int h1 = static_cast<int>(h1r);
+    // Offset to the next source row; 0 on the last row so writes stay in bounds.
+    const int h1p = (h1 < inputHeight - 1) ? 1 : 0;
+    const AccReal h1lambda = h1r - h1;
+    const AccReal h0lambda = AccReal(1) - h1lambda;
+    const int below = h1p * inputWidth;
+    for (int w2 = 0; w2 < outputWidth; ++w2) {
+      const AccReal w1r = rwidth * w2;
+      const int w1 = static_cast<int>(w1r);
+      // Offset to the next source column; 0 on the last column.
+      const int w1p = (w1 < inputWidth - 1) ? 1 : 0;
+      const AccReal w1lambda = w1r - w1;
+      const AccReal w0lambda = AccReal(1) - w1lambda;
+      DType* pos1 = &data1[h1 * inputWidth + w1];
+      const DType* pos2 = &data2[h2 * outputWidth + w2];
+      for (int c = 0; c < channels; ++c) {
+        const AccReal grad = static_cast<AccReal>(pos2[0]);
+        // When an offset is 0 several taps hit the same element; the
+        // sequential += keeps all contributions.
+        pos1[0] += static_cast<DType>(h0lambda * w0lambda * grad);
+        pos1[w1p] += static_cast<DType>(h0lambda * w1lambda * grad);
+        pos1[below] += static_cast<DType>(h1lambda * w0lambda * grad);
+        pos1[below + w1p] += static_cast<DType>(h1lambda * w1lambda * grad);
+        pos1 += inputPlane;
+        pos2 += outputPlane;
+      }
+    }
+  }
+}
+
+
+// Registers the BilinearResize2D hyper-parameter struct (BilinearSampleParam)
+// with dmlc so its fields can be parsed from operator attributes.
+DMLC_REGISTER_PARAMETER(BilinearSampleParam);
+
+// Forward operator, called from Python as contrib.BilinearResize2D.
+// Only the CPU FCompute is attached here; the GPU FCompute is attached in
+// bilinear_resize.cu.
+NNVM_REGISTER_OP(_contrib_BilinearResize2D)
+.describe(R"code(
+Perform 2D resizing (upsampling or downsampling) for 4D input using bilinear 
interpolation.
+
+Expected input is a 4 dimensional NDArray (NCHW) and the output
+with the shape of (N x C x height x width). 
+The key idea of bilinear interpolation is to perform linear interpolation
+first in one direction, and then again in the other direction. See the 
wikipedia of
+`Bilinear interpolation  
<https://en.wikipedia.org/wiki/Bilinear_interpolation>`_
+for more details.
+)code" ADD_FILELINE)
+.set_attr_parser(ParamParser<BilinearSampleParam>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FInferShape>("FInferShape", BilinearSampleOpInferShape)
+.set_attr<nnvm::FInferType>("FInferType", BilinearSampleOpInferType)
+.set_attr<FInferStorageType>("FInferStorageType", BilinearSampleOpStorageType)
+.set_attr<FCompute>("FCompute<cpu>", BilinearSampleOpForward<cpu>)
+// The gradient node is _backward_contrib_BilinearResize2D. ElemwiseGradUseNone
+// forwards only the output gradient to it, which matches that op's single input.
+.set_attr<nnvm::FGradient>("FGradient",
+  ElemwiseGradUseNone{"_backward_contrib_BilinearResize2D"})
+.add_argument("data", "NDArray-or-Symbol", "Input data")
+.add_arguments(BilinearSampleParam::__FIELDS__());
+
+// Backward operator. Input 0 is the output gradient and output 0 is the data
+// gradient. It reuses BilinearSampleParam and the dense-only storage inference;
+// the GPU FCompute is attached in bilinear_resize.cu.
+// NOTE(review): no FInferShape/FInferType is registered here; presumably they
+// are resolved from the forward node via TIsBackward. Confirm.
+NNVM_REGISTER_OP(_backward_contrib_BilinearResize2D)
+.set_attr_parser(ParamParser<BilinearSampleParam>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FInferStorageType>("FInferStorageType", BilinearSampleOpStorageType)
+.set_attr<FCompute>("FCompute<cpu>", BilinearSampleOpBackward<cpu>);
+
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/contrib/bilinear_resize.cu b/src/operator/contrib/bilinear_resize.cu
new file mode 100644
index 00000000000..f01c9c2fa13
--- /dev/null
+++ b/src/operator/contrib/bilinear_resize.cu
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file bilinear_resize.cu
+ * \brief bilinear resize operator
+ * \author Hang Zhang
+*/
+#include <cuda_runtime_api.h>
+#include <algorithm>
+#include "bilinear_resize-inl.h"
+
+namespace mxnet {
+namespace op {
+
+using namespace mshadow;
+
+// Host/device helper that converts a value of type In to type Out.
+// The kernels below use it to narrow Acctype results back to Dtype.
+template<typename In, typename Out>
+struct ScalarConvert {
+  static __host__ __device__ __forceinline__ Out to(const In v) {
+    return static_cast<Out>(v);
+  }
+};
+
+
+// Upper bound on threads per block used by the launchers in this file.
+static const unsigned MAX_BLOCK_SIZE = 512U;
+
+// Returns the smallest power-of-two block size, starting at 32, that covers
+// nElem threads. The result is capped at MAX_BLOCK_SIZE, or at
+// MAX_BLOCK_SIZE / 2 when `smaller` is true. A negative nElem wraps to a huge
+// unsigned value and therefore gets the cap.
+static unsigned getNumThreads(int nElem, const bool smaller) {
+  const unsigned wanted = static_cast<unsigned>(nElem);
+  const unsigned cap = smaller ? (MAX_BLOCK_SIZE >> 1) : MAX_BLOCK_SIZE;
+  unsigned size = 32U;
+  while (size < cap && size < wanted) {
+    size <<= 1;
+  }
+  return size;
+}
+
+// Forward kernel: one thread per output pixel (h2, w2), where n is the number
+// of output pixels (height2 * width2). Each thread loops over every
+// (batch, channel) plane and writes the bilinearly interpolated value from
+// data1 into data2. Threads with index >= n return immediately.
+template<typename xpu, typename Dtype, typename Acctype>
+__global__ void caffe_gpu_interp2_kernel(const int n,
+    const Acctype rheight, const Acctype rwidth,
+    const Tensor<xpu, 4, Dtype> data1,
+    Tensor<xpu, 4, Dtype> data2) {
+  const int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index >= n) {
+    return;
+  }
+  const int batchsize = data1.size(0);
+  const int channels = data1.size(1);
+  const int height1 = data1.size(2);
+  const int width1 = data1.size(3);
+  const int height2 = data2.size(2);
+  const int width2 = data2.size(3);
+  const int h2 = index / width2;  // output row, 0..height2-1
+  const int w2 = index % width2;  // output column, 0..width2-1
+
+  if (height1 == height2 && width1 == width2) {
+    // Same spatial size: copy this pixel in every plane.
+    for (int b = 0; b < batchsize; ++b) {
+      for (int c = 0; c < channels; ++c) {
+        const Dtype src = data1[b][c][h2][w2];
+        data2[b][c][h2][w2] = src;
+      }
+    }
+    return;
+  }
+
+  // Source row and its neighbour offset (0 on the last row), plus row weights.
+  const Acctype h1r = rheight * h2;
+  const int h1 = h1r;
+  const int h1p = (h1 < height1 - 1) ? 1 : 0;
+  const Acctype h1lambda = h1r - h1;
+  const Acctype h0lambda = Acctype(1) - h1lambda;
+  // Source column and its neighbour offset (0 on the last column), plus weights.
+  const Acctype w1r = rwidth * w2;
+  const int w1 = w1r;
+  const int w1p = (w1 < width1 - 1) ? 1 : 0;
+  const Acctype w1lambda = w1r - w1;
+  const Acctype w0lambda = Acctype(1) - w1lambda;
+
+  for (int b = 0; b < batchsize; ++b) {
+    for (int c = 0; c < channels; ++c) {
+      const Acctype blended =
+          h0lambda * (w0lambda * data1[b][c][h1][w1]
+                      + w1lambda * data1[b][c][h1][w1 + w1p])
+          + h1lambda * (w0lambda * data1[b][c][h1 + h1p][w1]
+                        + w1lambda * data1[b][c][h1 + h1p][w1 + w1p]);
+      data2[b][c][h2][w2] = ScalarConvert<Acctype, Dtype>::to(blended);
+    }
+  }
+}
+
+// Backward (adjoint) kernel, data1 <- data2, accumulating. There is one thread
+// per output-gradient pixel (h2, w2), where n = height2 * width2. Each thread
+// scatters its gradient onto the four source taps with atomicAdd, because
+// neighbouring threads may target the same data1 element.
+template<typename xpu, typename Dtype, typename Acctype>
+__global__ void caffe_gpu_interp2_kernel_backward(const int n,
+    const Acctype rheight, const Acctype rwidth,
+    Tensor<xpu, 4, Dtype> data1, const Tensor<xpu, 4, Dtype> data2) {
+  const int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index >= n) {
+    return;
+  }
+  const int batchsize = data1.size(0);
+  const int channels = data1.size(1);
+  const int height1 = data1.size(2);
+  const int width1 = data1.size(3);
+  const int height2 = data2.size(2);
+  const int width2 = data2.size(3);
+  const int h2 = index / width2;  // gradient row, 0..height2-1
+  const int w2 = index % width2;  // gradient column, 0..width2-1
+
+  if (height1 == height2 && width1 == width2) {
+    // Same spatial size: each thread owns its pixel, so a plain += is enough.
+    for (int b = 0; b < batchsize; ++b) {
+      for (int c = 0; c < channels; ++c) {
+        const Dtype g = data2[b][c][h2][w2];
+        data1[b][c][h2][w2] += g;
+      }
+    }
+    return;
+  }
+
+  // Same source coordinates and weights as the forward kernel.
+  const Acctype h1r = rheight * h2;
+  const int h1 = h1r;
+  const int h1p = (h1 < height1 - 1) ? 1 : 0;
+  const Acctype h1lambda = h1r - h1;
+  const Acctype h0lambda = Acctype(1) - h1lambda;
+  const Acctype w1r = rwidth * w2;
+  const int w1 = w1r;
+  const int w1p = (w1 < width1 - 1) ? 1 : 0;
+  const Acctype w1lambda = w1r - w1;
+  const Acctype w0lambda = Acctype(1) - w1lambda;
+
+  for (int b = 0; b < batchsize; ++b) {
+    for (int c = 0; c < channels; ++c) {
+      const Dtype g = data2[b][c][h2][w2];
+      atomicAdd(&data1[b][c][h1][w1],
+                ScalarConvert<Acctype, Dtype>::to(h0lambda * w0lambda * g));
+      atomicAdd(&data1[b][c][h1][w1 + w1p],
+                ScalarConvert<Acctype, Dtype>::to(h0lambda * w1lambda * g));
+      atomicAdd(&data1[b][c][h1 + h1p][w1],
+                ScalarConvert<Acctype, Dtype>::to(h1lambda * w0lambda * g));
+      atomicAdd(&data1[b][c][h1 + h1p][w1 + w1p],
+                ScalarConvert<Acctype, Dtype>::to(h1lambda * w1lambda * g));
+    }
+  }
+}
+
+/*!
+ * \brief GPU forward pass of BilinearResize2D.
+ *
+ * Launches caffe_gpu_interp2_kernel with one thread per output pixel. The
+ * block size is derived from that work count, not from the input size, and
+ * the grid is rounded up so every output pixel is covered. An empty output
+ * launches nothing.
+ *
+ * \param s      GPU stream the kernel is enqueued on.
+ * \param input  input[0]: source tensor (N, C, H_in, W_in).
+ * \param output output[0]: destination tensor (N, C, H_out, W_out).
+ */
+template<typename xpu, typename DType, typename AccReal>
+void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<gpu> *s,
+                                           const std::vector<TBlob> &input,
+                                           const std::vector<TBlob> &output) {
+  Tensor<xpu, 4, DType> idata = input[0].get<xpu, 4, DType>(s);
+  Tensor<xpu, 4, DType> odata = output[0].get<xpu, 4, DType>(s);
+  const int outputHeight = odata.size(2);
+  const int outputWidth = odata.size(3);
+  const int inputHeight = idata.size(2);
+  const int inputWidth = idata.size(3);
+
+  // Corner-aligned scale factors; a 1-pixel output dimension samples index 0.
+  const AccReal rheight = (outputHeight > 1)
+      ? static_cast<AccReal>(inputHeight - 1) / (outputHeight - 1) : AccReal(0);
+  const AccReal rwidth = (outputWidth > 1)
+      ? static_cast<AccReal>(inputWidth - 1) / (outputWidth - 1) : AccReal(0);
+  // One thread per output pixel; each thread loops over all batch/channel planes.
+  const int num_kernels = outputHeight * outputWidth;
+  if (num_kernels <= 0) {
+    return;  // nothing to write; also avoids an empty-grid launch
+  }
+  const int num_threads = static_cast<int>(getNumThreads(num_kernels, false));
+  dim3 blocks((num_kernels + num_threads - 1) / num_threads);
+  dim3 threads(num_threads);
+  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
+  caffe_gpu_interp2_kernel<xpu, DType, AccReal>
+  <<<blocks, threads, 0, stream>>>(
+    num_kernels, rheight, rwidth, idata, odata);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(SpatialUpSamplingBilinearUpdateOutput);
+}
+
+/*!
+ * \brief GPU backward pass of BilinearResize2D.
+ *
+ * input[0] is the output gradient (data2) and output[0] is the data gradient
+ * (data1). caffe_gpu_interp2_kernel_backward ACCUMULATES into data1 with
+ * atomicAdd. NOTE(review): presumably the caller zero-fills output[0] for
+ * write requests; confirm in BilinearSampleOpBackward.
+ *
+ * The launch uses one thread per output-gradient pixel, with the block size
+ * derived from that work count and a rounded-up grid. An empty gradient
+ * launches nothing.
+ */
+template<typename xpu, typename DType, typename AccReal>
+void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<gpu> *s,
+                                              const std::vector<TBlob> &input,
+                                              const std::vector<TBlob> &output) {
+  Tensor<xpu, 4, DType> data1 = output[0].get<xpu, 4, DType>(s);
+  Tensor<xpu, 4, DType> data2 = input[0].get<xpu, 4, DType>(s);
+  const int height1 = data1.size(2);
+  const int width1 = data1.size(3);
+  const int height2 = data2.size(2);
+  const int width2 = data2.size(3);
+  const AccReal rheight = (height2 > 1)
+      ? static_cast<AccReal>(height1 - 1) / (height2 - 1) : AccReal(0);
+  const AccReal rwidth = (width2 > 1)
+      ? static_cast<AccReal>(width1 - 1) / (width2 - 1) : AccReal(0);
+  // One thread per output-gradient pixel.
+  const int num_kernels = height2 * width2;
+  if (num_kernels <= 0) {
+    return;  // nothing to propagate; also avoids an empty-grid launch
+  }
+  const int num_threads = static_cast<int>(getNumThreads(num_kernels, false));
+  dim3 blocks((num_kernels + num_threads - 1) / num_threads);
+  dim3 threads(num_threads);
+  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
+  caffe_gpu_interp2_kernel_backward<xpu, DType, AccReal>
+  <<<blocks, threads, 0, stream>>>(
+    num_kernels, rheight, rwidth, data1, data2);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(SpatialUpSamplingBilinearUpdateGradInput);
+}
+
+// Attach the GPU FCompute implementations to the forward and backward
+// operators, whose other attributes are registered in bilinear_resize.cc.
+NNVM_REGISTER_OP(_contrib_BilinearResize2D)
+.set_attr<FCompute>("FCompute<gpu>", BilinearSampleOpForward<gpu>);
+
+NNVM_REGISTER_OP(_backward_contrib_BilinearResize2D)
+.set_attr<FCompute>("FCompute<gpu>", BilinearSampleOpBackward<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 61b4478d4d1..8d66b2a74ac 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -5280,6 +5280,84 @@ def check_squeeze_op(shape, axis=None):
     check_numeric_gradient(test, [data_tmp])
 
 @with_seed()
+def test_adaptive_avg_pool_op():
+    def py_adaptive_avg_pool(x, height, width):
+        # 2D per frame adaptive avg pool
+        def adaptive_avg_pool_frame(x, y):
+            isizeH, isizeW = x.shape
+            osizeH, osizeW = y.shape
+            for oh in range(osizeH):
+                istartH = int(np.floor(1.0 * (oh * isizeH) / osizeH))
+                iendH = int(np.ceil(1.0 * (oh + 1) * isizeH / osizeH))
+                kH = iendH - istartH
+                for ow in range(osizeW):
+                    istartW = int(np.floor(1.0 * (ow * isizeW) / osizeW))
+                    iendW = int(np.ceil(1.0 * (ow + 1) * isizeW / osizeW))
+                    kW = iendW - istartW
+                    xsum = 0
+                    for ih in range(kH):
+                        for iw in range(kW):
+                            xsum += x[istartH+ih][istartW+iw]
+                    y[oh][ow] = xsum / kH / kW
+
+        B,C,_,_ = x.shape
+        y = np.empty([B,C,height, width], dtype=x.dtype)
+        for b in range(B):
+            for c in range(C):
+                adaptive_avg_pool_frame(x[b][c], y[b][c])
+        return y
+    def check_adaptive_avg_pool_op(shape, output_height, output_width=None):
+        x = mx.nd.random.uniform(shape=shape)
+        if output_width is None:
+            y = mx.nd.contrib.AdaptiveAvgPooling2D(x, 
output_size=output_height)
+            npy = py_adaptive_avg_pool(x.asnumpy(), output_height, 
output_height)
+        else:
+            y = mx.nd.contrib.AdaptiveAvgPooling2D(x, 
output_size=(output_height, output_width))
+            npy = py_adaptive_avg_pool(x.asnumpy(), output_height, 
output_width)
+        assert_almost_equal(y.asnumpy(), npy)
+    shape = (2, 2, 10, 10)
+    for i in range(1, 11):
+        check_adaptive_avg_pool_op(shape, i)
+        for j in range(1, 11):
+            check_adaptive_avg_pool_op(shape, i, j)
+
+@with_seed()
+def test_bilinear_resize_op():
+    def py_bilinear_resize(x, outputHeight, outputWidth):
+        batch, channel, inputHeight, inputWidth = x.shape
+        if outputHeight == inputHeight and outputWidth == inputWidth:
+            return x
+        y = np.empty([batch, channel, outputHeight, outputWidth]) 
+        rheight = 1.0 * (inputHeight - 1) / (outputHeight - 1) if outputHeight 
> 1 else 0.0
+        rwidth = 1.0 * (inputWidth - 1) / (outputWidth - 1) if outputWidth > 1 
else 0.0
+        for h2 in range(outputHeight):
+            h1r = 1.0 * h2 * rheight
+            h1 = int(np.floor(h1r))
+            h1lambda = h1r - h1
+            h1p = 1 if h1 < (inputHeight - 1) else 0
+            for w2 in range(outputWidth):
+                w1r = 1.0 * w2 * rwidth
+                w1 = int(np.floor(w1r))
+                w1lambda = w1r - w1
+                w1p = 1 if w1 < (inputHeight - 1) else 0
+                for b in range(batch):
+                    for c in range(channel):
+                        y[b][c][h2][w2] = 
(1-h1lambda)*((1-w1lambda)*x[b][c][h1][w1] + \
+                            w1lambda*x[b][c][h1][w1+w1p]) + \
+                            h1lambda*((1-w1lambda)*x[b][c][h1+h1p][w1] + \
+                            w1lambda*x[b][c][h1+h1p][w1+w1p])
+        return y
+    def check_bilinear_resize_op(shape, height, width):
+        x = mx.nd.random.uniform(shape=shape)
+        y = mx.nd.contrib.BilinearResize2D(x, height=height, width=width)
+        assert_almost_equal(y.asnumpy(), py_bilinear_resize(x.asnumpy(), 
height, width))
+    shape = (2, 2, 10, 10)
+    check_bilinear_resize_op(shape, 5, 5)
+    check_bilinear_resize_op(shape, 10, 10)
+    check_bilinear_resize_op(shape, 15, 15)
+    check_bilinear_resize_op(shape, 3, 7)
+    check_bilinear_resize_op(shape, 13, 17)
+
 def test_multi_proposal_op():
     # paramters
     feature_stride = 16


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to