This is an automated email from the ASF dual-hosted git repository.

wkcn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new b30949f  [MXNet-1211] Factor and "Like" modes in BilinearResize2D 
operator (#13226)
b30949f is described below

commit b30949f3bf3606477a89955dc37f58b6eb285d99
Author: Mikhail Lobanov <[email protected]>
AuthorDate: Sun May 5 12:53:14 2019 +0300

    [MXNet-1211] Factor and "Like" modes in BilinearResize2D operator (#13226)
    
    * Added "factor" and "like" modes into BilinearResize2D operator.
    Also added tests and some fixes to visualization needed due to added modes.
    
    * Lint fix
    
    * Test fix
    
    * retrigger CI
    
    * retrigger CI
    
    * Retrigger CI
    
    * retrigger CI
    
    * retrigger ci
    
    * retrigger ci again
---
 CONTRIBUTORS.md                            |   1 +
 python/mxnet/visualization.py              |   4 +
 src/operator/contrib/bilinear_resize-inl.h | 173 +++++++++++++++++++++++++----
 src/operator/contrib/bilinear_resize.cc    |  52 ++++++---
 src/operator/contrib/bilinear_resize.cu    |  37 +++++-
 tests/python/gpu/test_gluon_transforms.py  |   4 +-
 tests/python/unittest/test_operator.py     | 116 +++++++++++++++++++
 7 files changed, 349 insertions(+), 38 deletions(-)

diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index be497e5..7daccf0 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -218,6 +218,7 @@ List of Contributors
 * [Dang Trung Kien](https://github.com/kiendang)
 * [Zach Boldyga](https://github.com/zboldyga)
 * [Gordon Reid](https://github.com/gordon1992)
+* [Mikhail Lobanov](https://github.com/lobanov-m)
 * [Ming Yang](http://ufoym.com)
 * [Satya Krishna Gorti](https://github.com/satyakrishnagorti)
 * [Neo Chien](https://github.com/cchung100m)
diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py
index 4101f74..bc8309a 100644
--- a/python/mxnet/visualization.py
+++ b/python/mxnet/visualization.py
@@ -384,6 +384,10 @@ def plot_network(symbol, title="plot", save_format='pdf', 
shape=None, dtype=None
             continue
         else:
             inputs = node["inputs"]
+
+            if node['op'] == '_contrib_BilinearResize2D':
+                inputs = [inputs[0]]
+
             for item in inputs:
                 input_node = nodes[item[0]]
                 input_name = input_node["name"]
diff --git a/src/operator/contrib/bilinear_resize-inl.h 
b/src/operator/contrib/bilinear_resize-inl.h
index ce9c6c8..4da12cb 100644
--- a/src/operator/contrib/bilinear_resize-inl.h
+++ b/src/operator/contrib/bilinear_resize-inl.h
@@ -44,6 +44,12 @@
 #include "../mxnet_op.h"
 #include "../mshadow_op.h"
 
+namespace bilinear_resize {
+enum BilinearResizeOpMode{simple, odd_scale, like, to_even_down, to_even_up, 
to_odd_down,
+  to_odd_up};
+}  // namespace bilinear_resize
+
+
 namespace mxnet {
 namespace op {
 
@@ -52,15 +58,45 @@ struct BilinearSampleParam : public 
dmlc::Parameter<BilinearSampleParam> {
   int width;
   dmlc::optional<float> scale_height;
   dmlc::optional<float> scale_width;
+  int mode;
   DMLC_DECLARE_PARAMETER(BilinearSampleParam) {
     DMLC_DECLARE_FIELD(height).set_default(1).set_range(1, 10000)
-    .describe("output height (required, but ignored if scale_height is 
defined)");
+    .describe("output height (required, but ignored if scale_height is defined 
or mode is not "
+              "\"size\")");
     DMLC_DECLARE_FIELD(width).set_default(1).set_range(1, 10000)
-    .describe("output width (required, but ignored if scale_width is 
defined)");
+    .describe("output width (required, but ignored if scale_width is defined 
or mode is not "
+              "\"size\")");
     DMLC_DECLARE_FIELD(scale_height).set_default(dmlc::optional<float>())
-    .describe("sampling scale of the height (optional, ignores height if defined)");
+    .describe("sampling scale of the height (optional, used in modes \"size\" and \"odd_scale\")");
     DMLC_DECLARE_FIELD(scale_width).set_default(dmlc::optional<float>())
-    .describe("sampling scale of the scale_width (optional, ignores width if defined)");
+    .describe("sampling scale of the width (optional, used in modes \"size\" and \"odd_scale\")");
+    DMLC_DECLARE_FIELD(mode)
+    .add_enum("size", bilinear_resize::simple)
+    .add_enum("odd_scale", bilinear_resize::odd_scale)
+    .add_enum("like", bilinear_resize::like)
+    .add_enum("to_even_down", bilinear_resize::to_even_down)
+    .add_enum("to_even_up", bilinear_resize::to_even_up)
+    .add_enum("to_odd_down", bilinear_resize::to_odd_down)
+    .add_enum("to_odd_up", bilinear_resize::to_odd_up)
+    .set_default(bilinear_resize::simple)
+    .describe("resizing mode. \"size\" - output height equals parameter \"height\" if "
+              "\"scale_height\" parameter is not defined or input height multiplied by "
+              "\"scale_height\" otherwise. Same for width; "
+              "\"odd_scale\" - if original height or width is odd, then result height is "
+              "calculated like result_h = (original_h - 1) * scale + 1; "
+              "for scale > 1 the result shape would be like if we did deconvolution with kernel "
+              "= (1, 1) and stride = (height_scale, width_scale); and for scale < 1 shape "
+              "would be like we did convolution with kernel = (1, 1) and "
+              "stride = (int(1 / height_scale), int(1 / width_scale)); "
+              "\"like\" - resize first input to the height and width of second input; "
+              "\"to_even_down\" - resize input to nearest lower even height and width "
+              "(if original height is odd then result height = original height - 1); "
+              "\"to_even_up\" - resize input to nearest bigger even height and width "
+              "(if original height is odd then result height = original height + 1); "
+              "\"to_odd_down\" - resize input to nearest lower odd height and width "
+              "(if original height is even then result height = original height - 1); "
+              "\"to_odd_up\" - resize input to nearest bigger odd height and width "
+              "(if original height is even then result height = original height + 1).");
   }
 };
 
@@ -76,7 +112,8 @@ void 
SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<cpu> *s,
 template<typename xpu, typename DType, typename AccReal>
 void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
                                               const std::vector<TBlob> &input,
-                                              const std::vector<TBlob> 
&output);
+                                              const std::vector<TBlob> &output,
+                                              bool modeLike);
 
 #if MXNET_USE_CUDA
 template<typename xpu, typename DType, typename AccReal>
@@ -87,7 +124,8 @@ void 
SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<gpu> *s,
 template<typename xpu, typename DType, typename AccReal>
 void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<gpu> *s,
                                               const std::vector<TBlob> &input,
-                                              const std::vector<TBlob> 
&output);
+                                              const std::vector<TBlob> &output,
+                                              bool modeLike);
 #endif  // MXNET_USE_CUDA
 
 template <typename xpu>
@@ -96,7 +134,9 @@ inline void BilinearSampleOpForward(const nnvm::NodeAttrs& 
attrs,
                                     const std::vector<TBlob> &inputs,
                                     const std::vector<OpReqType> &req,
                                     const std::vector<TBlob> &outputs) {
-  CHECK_EQ(inputs.size(), 1U);
+  const BilinearSampleParam& param = 
nnvm::get<BilinearSampleParam>(attrs.parsed);
+  size_t expected = param.mode == bilinear_resize::like ? 2 : 1;
+  CHECK_EQ(inputs.size(), expected);
   CHECK_EQ(outputs.size(), 1U);
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
   MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, {
@@ -111,8 +151,11 @@ inline void BilinearSampleOpBackward(const 
nnvm::NodeAttrs& attrs,
                                      const std::vector<TBlob> &inputs,
                                      const std::vector<OpReqType> &req,
                                      const std::vector<TBlob> &outputs) {
+  const BilinearSampleParam& param = 
nnvm::get<BilinearSampleParam>(attrs.parsed);
   CHECK_EQ(inputs.size(), 1U);
-  CHECK_EQ(outputs.size(), 1U);
+  bool modeLike = param.mode == bilinear_resize::like;
+  size_t expected = modeLike ? 2 : 1;
+  CHECK_EQ(outputs.size(), expected);
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
   if (IsWriting(req[0])) {
     // zero grad before backwarding
@@ -121,7 +164,7 @@ inline void BilinearSampleOpBackward(const nnvm::NodeAttrs& 
attrs,
     })
   }
   MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, {
-    SpatialUpSamplingBilinearUpdateGradInput<xpu, DType, AccReal>(s, inputs, 
outputs);
+    SpatialUpSamplingBilinearUpdateGradInput<xpu, DType, AccReal>(s, inputs, 
outputs, modeLike);
   });
 }
 
@@ -130,28 +173,123 @@ static bool BilinearSampleOpInferShape(const nnvm::NodeAttrs& attrs,
                                        mxnet::ShapeVector *in_shape,
                                        mxnet::ShapeVector *out_shape) {
   using namespace mshadow;
-  CHECK_EQ(in_shape->size(), 1U) << "Input:[data]";
   CHECK_EQ(out_shape->size(), 1U) << "Output:[data]";
   const BilinearSampleParam& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
+  size_t expected = param.mode == bilinear_resize::like ? 2 : 1;
+  CHECK_EQ(in_shape->size(), expected);
   mxnet::TShape dshape(in_shape->at(0));
   if (mxnet::op::shape_is_none(dshape)) return false;
-  if (param.scale_height.has_value()) {
-    dshape[2] = static_cast<int>(param.scale_height.value() * in_shape->at(0)[2]);
-  } else {
-    dshape[2] = param.height;
+  // int rather than int16_t: a scaled size (e.g. 300 * 200) can exceed 32767.
+  int new_height = -1;
+  int new_width = -1;
+  switch (param.mode) {
+    case bilinear_resize::simple:
+    {
+      if (param.scale_height.has_value()) {
+        new_height = static_cast<int>(param.scale_height.value() * in_shape->at(0)[2]);
+      } else {
+        new_height = param.height;
+      }
+      if (param.scale_width.has_value()) {
+        new_width = static_cast<int>(param.scale_width.value() * in_shape->at(0)[3]);
+      } else {
+        new_width = param.width;
+      }
+      break;
+    }
+    case bilinear_resize::odd_scale:
+    {
+      CHECK(param.scale_height.has_value() && param.scale_width.has_value())
+        << "mode \"odd_scale\" requires both scale_height and scale_width";
+      new_height = (dshape[2] % 2 == 0) ? static_cast<int>(dshape[2] * param.scale_height.value())
+                   : static_cast<int>((dshape[2] - 1) * param.scale_height.value()) + 1;
+      new_width = (dshape[3] % 2 == 0) ? static_cast<int>(dshape[3] * param.scale_width.value())
+                  : static_cast<int>((dshape[3] - 1) * param.scale_width.value()) + 1;
+      break;
+    }
+    case bilinear_resize::like:
+    {
+      TShape like_shape(in_shape->at(1));
+      if (mxnet::op::shape_is_none(like_shape)) return false;
+      new_height = like_shape[2];
+      new_width = like_shape[3];
+      break;
+    }
+    case bilinear_resize::to_even_down:
+    {
+      new_height = ((dshape[2] % 2) == 0) ? dshape[2] : dshape[2] - 1;
+      new_width = ((dshape[3] % 2) == 0) ? dshape[3] : dshape[3] - 1;
+      break;
+    }
+    case bilinear_resize::to_even_up:
+    {
+      new_height = ((dshape[2] % 2) == 0) ? dshape[2] : dshape[2] + 1;
+      new_width = ((dshape[3] % 2) == 0) ? dshape[3] : dshape[3] + 1;
+      break;
+    }
+    case bilinear_resize::to_odd_down:
+    {
+      new_height = ((dshape[2] % 2) == 1) ? dshape[2] : dshape[2] - 1;
+      new_width = ((dshape[3] % 2) == 1) ? dshape[3] : dshape[3] - 1;
+      break;
+    }
+    case bilinear_resize::to_odd_up:
+    {
+      new_height = ((dshape[2] % 2) == 1) ? dshape[2] : dshape[2] + 1;
+      new_width = ((dshape[3] % 2) == 1) ? dshape[3] : dshape[3] + 1;
+      break;
+    }
+    default:
+    {
+      LOG(FATAL) << "Invalid mode " << param.mode;
+    }
   }
 
-  if (param.scale_height.has_value()) {
-    dshape[3] = static_cast<int>(param.scale_width.value() * in_shape->at(0)[3]);
-  } else {
-    dshape[3] = param.width;
-  }
+  dshape[2] = new_height;
+  dshape[3] = new_width;
 
   out_shape->clear();
   out_shape->push_back(dshape);
   return true;
 }
 
+
+inline uint16_t BilinearSampleOpNumInputs(const NodeAttrs& attrs) {
+  auto& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
+  if (param.mode == bilinear_resize::like) {
+      return 2;
+  } else {
+    return 1;
+  }
+}
+
+inline uint16_t BilinearSampleOpNumBackwardInputs(const NodeAttrs& attrs) {
+  auto& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
+  if (param.mode == bilinear_resize::like) {
+    return 3;
+  } else {
+    return 1;
+  }
+}
+
+inline uint16_t BilinearSampleOpNumBackwardOutputs(const NodeAttrs& attrs) {
+  auto& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
+  if (param.mode == bilinear_resize::like) {
+    return 2;
+  } else {
+    return 1;
+  }
+}
+
+inline std::vector<std::string> BilinearSampleOpInputNames(const NodeAttrs& 
attrs) {
+  auto& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
+  if (param.mode == bilinear_resize::like) {
+    return std::vector<std::string>{"data", "like"};
+  } else {
+    return std::vector<std::string>{"data"};
+  }
+}
+
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/contrib/bilinear_resize.cc 
b/src/operator/contrib/bilinear_resize.cc
index 1288e9d..441ea53 100644
--- a/src/operator/contrib/bilinear_resize.cc
+++ b/src/operator/contrib/bilinear_resize.cc
@@ -97,7 +97,8 @@ void 
SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<cpu> *s,
 template<typename xpu, typename DType, typename AccReal>
 void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
                                               const std::vector<TBlob> &input,
-                                              const std::vector<TBlob> 
&output) {
+                                              const std::vector<TBlob> &output,
+                                              bool modeLike) {
   Tensor<xpu, 4, DType> gradOutput = input[0].get<xpu, 4, DType>(s);
   Tensor<xpu, 4, DType> gradInput = output[0].get<xpu, 4, DType>(s);
 
@@ -108,8 +109,8 @@ void 
SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
   int inputHeight = gradInput.size(2);
   int inputWidth = gradInput.size(3);
 
-  DType *data1 = gradInput.dptr_;
-  DType *data2 = gradOutput.dptr_;
+  DType *dataInput = gradInput.dptr_;
+  DType *dataOutput = gradOutput.dptr_;
   channels = nbatch * channels;
 
   // special case: same-size matching grids
@@ -118,8 +119,8 @@ void 
SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
       const int h1 = h2;
       for (int w2 = 0; w2 < outputWidth; ++w2) {
         const int w1 = w2;
-        DType* pos1 = &data1[h1 * inputWidth + w1];
-        const DType* pos2 = &data2[h2 * outputWidth + w2];
+        DType* pos1 = &dataInput[h1 * inputWidth + w1];
+        const DType* pos2 = &dataOutput[h2 * outputWidth + w2];
         for (int c = 0; c < channels; ++c) {
           pos1[0] += pos2[0];
           pos1 += inputWidth * inputHeight;
@@ -145,15 +146,32 @@ void 
SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
       const int w1p = (w1 < inputWidth - 1) ? 1 : 0;
       const DType w1lambda = w1r - w1;
       const DType w0lambda = (DType)1. - w1lambda;
-      DType* pos1 = &data1[h1 * inputWidth + w1];
-      const DType* pos2 = &data2[h2 * outputWidth + w2];
+      DType* posInput = &dataInput[h1 * inputWidth + w1];
+      const DType* posOutput = &dataOutput[h2 * outputWidth + w2];
       for (int c = 0; c < channels; ++c) {
-        pos1[0] += h0lambda * w0lambda * pos2[0];
-        pos1[w1p] += h0lambda * w1lambda * pos2[0];
-        pos1[h1p * inputWidth] += h1lambda * w0lambda * pos2[0];
-        pos1[h1p * inputWidth + w1p] += h1lambda * w1lambda * pos2[0];
-        pos1 += inputWidth * inputHeight;
-        pos2 += outputWidth * outputHeight;
+        posInput[0] += h0lambda * w0lambda * posOutput[0];
+        posInput[w1p] += h0lambda * w1lambda * posOutput[0];
+        posInput[h1p * inputWidth] += h1lambda * w0lambda * posOutput[0];
+        posInput[h1p * inputWidth + w1p] += h1lambda * w1lambda * posOutput[0];
+        posInput += inputWidth * inputHeight;
+        posOutput += outputWidth * outputHeight;
+      }
+    }
+  }
+
+  if (modeLike) {
+    Tensor<xpu, 4, DType> gradInputLike = output[1].get<xpu, 4, DType>(s);
+    int inputHeightLike = gradInputLike.size(2);
+    int inputWidthLike = gradInputLike.size(3);
+    DType *dataInputLike = gradInputLike.dptr_;
+    int channelsLike = nbatch * gradInputLike.size(1);
+    for (int h_like = 0; h_like < inputHeightLike; ++h_like) {
+      for (int w_like = 0; w_like < inputWidthLike; ++w_like) {
+        DType *posInput = &dataInputLike[h_like * inputWidthLike + w_like];
+        for (int c = 0; c < channelsLike; ++c) {
+          posInput[0] = 0;
+          posInput += inputWidthLike * inputHeightLike;
+        }
       }
     }
   }
@@ -174,19 +192,21 @@ first in one direction, and then again in the other 
direction. See the wikipedia
 for more details.
 )code" ADD_FILELINE)
 .set_attr_parser(ParamParser<BilinearSampleParam>)
-.set_num_inputs(1)
+.set_num_inputs(BilinearSampleOpNumInputs)
 .set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames", BilinearSampleOpInputNames)
 .set_attr<mxnet::FInferShape>("FInferShape", BilinearSampleOpInferShape)
 .set_attr<FCompute>("FCompute<cpu>", BilinearSampleOpForward<cpu>)
 .set_attr<nnvm::FGradient>("FGradient",
   ElemwiseGradUseNone{"_backward_contrib_BilinearResize2D"})
 .add_argument("data", "NDArray-or-Symbol", "Input data")
+.add_argument("like", "NDArray-or-Symbol", "Resize data to its height and width (mode \"like\" only)")
 .add_arguments(BilinearSampleParam::__FIELDS__());
 
 NNVM_REGISTER_OP(_backward_contrib_BilinearResize2D)
 .set_attr_parser(ParamParser<BilinearSampleParam>)
-.set_num_inputs(1)
-.set_num_outputs(1)
+.set_num_inputs(BilinearSampleOpNumBackwardInputs)
+.set_num_outputs(BilinearSampleOpNumBackwardOutputs)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FCompute>("FCompute<cpu>", BilinearSampleOpBackward<cpu>);
 
diff --git a/src/operator/contrib/bilinear_resize.cu 
b/src/operator/contrib/bilinear_resize.cu
index b0a4c4b..0753c47 100644
--- a/src/operator/contrib/bilinear_resize.cu
+++ b/src/operator/contrib/bilinear_resize.cu
@@ -32,6 +32,26 @@ namespace op {
 
 using namespace mshadow;
 
+template<typename xpu, typename Dtype, typename Acctype>
+__global__ void like_mode_kernel_backward(const int n,
+    Tensor<xpu, 4, Dtype> dataLike) {
+  // In "like" mode the second input only supplies the output height/width, so its
+  // gradient is zero. Each thread owns one (h, w) position, out of n = height * width
+  // positions (the launch site sizes the grid from heightLike * widthLike).
+  const int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index >= n) return;
+  const int batchsize = dataLike.size(0);
+  const int channels = dataLike.size(1);
+  const int width = dataLike.size(3);
+  const int w = index % width;
+  const int h = index / width;
+  // Use "b" for the batch index so the element-count parameter "n" is not shadowed.
+  for (int b = 0; b < batchsize; ++b) {
+    for (int c = 0; c < channels; ++c) {
+      dataLike[b][c][h][w] = 0;
+    }
+  }
+}
+
 // Backward (adjoint) operation 1 <- 2 (accumulates)
 template<typename xpu, typename Dtype, typename Acctype>
 __global__ void caffe_gpu_interp2_kernel_backward(const int n,
@@ -118,7 +138,8 @@ void 
SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<gpu> *s,
 template<typename xpu, typename DType, typename AccReal>
 void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<gpu> *s,
                                               const std::vector<TBlob> &input,
-                                              const std::vector<TBlob> 
&output) {
+                                              const std::vector<TBlob> &output,
+                                              bool modeLike) {
   Tensor<xpu, 4, DType> data1 = output[0].get<xpu, 4, DType>(s);
   Tensor<xpu, 4, DType> data2 = input[0].get<xpu, 4, DType>(s);
   int height1 = data1.size(2);
@@ -135,6 +156,20 @@ void 
SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<gpu> *s,
   caffe_gpu_interp2_kernel_backward<xpu, DType, AccReal>
   <<<blocks, threads, 0, stream>>>(
     num_kernels, rheight, rwidth, data1, data2);
+
+  if (modeLike) {
+    Tensor<xpu, 4, DType> dataLike = output[1].get<xpu, 4, DType>(s);
+    int heightLike = dataLike.size(2);
+    int widthLike = dataLike.size(3);
+    const int num_kernels_like = heightLike * widthLike;
+    const int num_threads_like = getNumThreads(num_kernels_like, false);
+    dim3 blocksLike(static_cast<int>(num_kernels_like / num_threads_like) + 1);
+    dim3 threadsLike(num_threads_like);
+    like_mode_kernel_backward<xpu, DType, AccReal>
+    <<<blocksLike, threadsLike, 0, stream>>>(
+      num_kernels_like, dataLike);
+  }
+
   MSHADOW_CUDA_POST_KERNEL_CHECK(SpatialUpSamplingBilinearUpdateGradInput);
 }
 
diff --git a/tests/python/gpu/test_gluon_transforms.py 
b/tests/python/gpu/test_gluon_transforms.py
index 23b34d3..599a02c 100644
--- a/tests/python/gpu/test_gluon_transforms.py
+++ b/tests/python/gpu/test_gluon_transforms.py
@@ -96,14 +96,14 @@ def test_resize():
     data_in_3d = nd.random.uniform(0, 255, (300, 300, 3))
     out_nd_3d = transforms.Resize((100, 100))(data_in_3d)
     data_in_4d_nchw = nd.moveaxis(nd.expand_dims(data_in_3d, axis=0), 3, 1)
-    data_expected_3d = 
(nd.moveaxis(nd.contrib.BilinearResize2D(data_in_4d_nchw, 100, 100), 1, 3))[0]
+    data_expected_3d = 
(nd.moveaxis(nd.contrib.BilinearResize2D(data_in_4d_nchw, height=100, 
width=100), 1, 3))[0]
     assert_almost_equal(out_nd_3d.asnumpy(), data_expected_3d.asnumpy())
 
     # Test with normal case 4D input float type
     data_in_4d = nd.random.uniform(0, 255, (2, 300, 300, 3))
     out_nd_4d = transforms.Resize((100, 100))(data_in_4d)
     data_in_4d_nchw = nd.moveaxis(data_in_4d, 3, 1)
-    data_expected_4d = 
nd.moveaxis(nd.contrib.BilinearResize2D(data_in_4d_nchw, 100, 100), 1, 3)
+    data_expected_4d = 
nd.moveaxis(nd.contrib.BilinearResize2D(data_in_4d_nchw, height=100, 
width=100), 1, 3)
     assert_almost_equal(out_nd_4d.asnumpy(), data_expected_4d.asnumpy())
 
     # Test invalid interp
diff --git a/tests/python/unittest/test_operator.py 
b/tests/python/unittest/test_operator.py
index e8bfaba..2406a1c 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -7164,6 +7164,45 @@ def test_bilinear_resize_op():
                             h1lambda*((1-w1lambda)*x[b][c][h1+h1p][w1] + \
                             w1lambda*x[b][c][h1+h1p][w1+w1p])
         return y
+    def py_bilinear_resize_backward(x, incoming_grads, mode='size'):
+        data1 = np.zeros_like(x)
+        data2 = incoming_grads
+        batchsize = data1.shape[0]
+        channels = data1.shape[1]
+        height1 = data1.shape[2]
+        width1 = data1.shape[3]
+        height2 = data2.shape[2]
+        width2 = data2.shape[3]
+        rheight = float(height1 - 1) / (height2 - 1) if (height2 > 1) else 0
+        rwidth = float(width1 - 1) / (width2 - 1) if (width2 > 1) else 0
+        # same-size grids: gradient passes straight through ("like" input still gets zeros)
+        if height1 == height2 and width1 == width2:
+            data1 += data2
+            return [data1, np.zeros_like(incoming_grads)] if mode == 'like' else [data1]
+        for h2 in range(0, height2):
+            for w2 in range(0, width2):
+                h1r = rheight * h2
+                h1 = int(h1r)
+                h1p = 1 if (h1 < height1 - 1) else 0
+                h1lambda = h1r - h1
+                h0lambda = 1 - h1lambda
+                # horizontal source index w1 (neighbour w1 + w1p) and its weights
+                w1r = rwidth * w2
+                w1 = int(w1r)
+                w1p = 1 if (w1 < width1 - 1) else 0
+                w1lambda = w1r - w1
+                w0lambda = 1 - w1lambda
+                # scatter each output gradient onto the four neighbouring input pixels
+                for n in range(0, batchsize):
+                    for c in range(0, channels):
+                        d2val = data2[n][c][h2][w2]
+                        data1[n][c][h1][w1] += h0lambda * w0lambda * d2val
+                        data1[n][c][h1][w1 + w1p] += h0lambda * w1lambda * d2val
+                        data1[n][c][h1 + h1p][w1] += h1lambda * w0lambda * d2val
+                        data1[n][c][h1 + h1p][w1 + w1p] += h1lambda * w1lambda * d2val
+        if mode == 'like':
+            return [data1, np.zeros_like(incoming_grads)]
+        return [data1]
     def check_bilinear_resize_op(shape, height, width):
         x = mx.nd.random.uniform(shape=shape)
         y = mx.nd.contrib.BilinearResize2D(x, height=height, width=width)
@@ -7173,12 +7212,89 @@ def test_bilinear_resize_op():
         y_scale = height / shape[-2]
         y = mx.nd.contrib.BilinearResize2D(x, scale_height=y_scale, 
scale_width=x_scale)
         assert_almost_equal(y.asnumpy(), py_bilinear_resize(x.asnumpy(), 
height, width))
+    def check_bilinear_resize_modes_op(shape, scale_height=None, 
scale_width=None, shape_1=None, mode=None):
+        x = mx.nd.random.uniform(shape=shape)
+        original_h = shape[2]
+        original_w = shape[3]
+        if mode == 'odd_scale':
+            assert scale_height is not None and scale_width is not None
+            new_h = int(original_h * scale_height) if (original_h % 2) == 0 
else \
+                int((original_h - 1) * scale_height) + 1
+            new_w = int(original_w * scale_width) if (original_w % 2) == 0 \
+                else int((original_w - 1) * scale_width) + 1
+            y = mx.nd.contrib.BilinearResize2D(x, scale_height=scale_height,
+                                               scale_width=scale_width,
+                                               mode='odd_scale')
+        elif mode == 'to_even_down':
+            new_h = original_h if (original_h % 2) == 0 else original_h - 1
+            new_w = original_w if (original_w % 2) == 0 else original_w - 1
+            y = mx.nd.contrib.BilinearResize2D(x, mode='to_even_down')
+        elif mode == 'to_even_up':
+            new_h = original_h if (original_h % 2) == 0 else original_h + 1
+            new_w = original_w if (original_w % 2) == 0 else original_w + 1
+            y = mx.nd.contrib.BilinearResize2D(x, mode='to_even_up')
+        elif mode == 'to_odd_down':
+            new_h = original_h if (original_h % 2) == 1 else original_h - 1
+            new_w = original_w if (original_w % 2) == 1 else original_w - 1
+            y = mx.nd.contrib.BilinearResize2D(x, mode='to_odd_down')
+        elif mode == 'to_odd_up':
+            new_h = original_h if (original_h % 2) == 1 else original_h + 1
+            new_w = original_w if (original_w % 2) == 1 else original_w + 1
+            y = mx.nd.contrib.BilinearResize2D(x, mode='to_odd_up')
+        elif mode == 'like':
+            x_1 = mx.nd.random.uniform(shape=shape_1)
+            new_h = x_1.shape[2]
+            new_w = x_1.shape[3]
+            y = mx.nd.contrib.BilinearResize2D(x, x_1, mode='like')
+        new_shape_desired = np.array([shape[0], shape[1], new_h, new_w], 
dtype='int')
+        new_shape_got = np.array(y.shape, dtype='int')
+        data_sym = mx.sym.var('data')
+        data_np = x.asnumpy()
+        expected = py_bilinear_resize(data_np, new_h, new_w)
+        out_grads = np.ones([shape[0], shape[1], new_h, new_w])
+        expected_backward = py_bilinear_resize_backward(data_np, out_grads, 
mode)
+        assert_array_equal(new_shape_desired, new_shape_got, "Desired and got 
shapes are not equal. {} vs {}".format(
+            str(new_shape_desired.tolist()), str(new_shape_got.tolist())))
+        assert_almost_equal(y.asnumpy(), expected, 1e-3, 0)
+        if mode != 'like':
+            resize_sym = mx.sym.contrib.BilinearResize2D(data_sym, None, 
scale_height=scale_height, scale_width=scale_width, mode=mode)
+            check_symbolic_forward(resize_sym, [data_np], [expected], 
rtol=1e-3)
+            check_symbolic_backward(resize_sym, [data_np], [out_grads], 
expected_backward, rtol=1e-3)
+            check_numeric_gradient(resize_sym, [data_np])
+        else:
+            data_sym_like = mx.sym.var('data_like')
+            resize_sym = mx.sym.contrib.BilinearResize2D(data_sym, 
data_sym_like, mode=mode)
+            date_np_like = x_1.asnumpy()
+            check_symbolic_forward(resize_sym, [data_np, date_np_like], 
[expected], rtol=1e-3)
+            check_symbolic_backward(resize_sym, [data_np, date_np_like], 
[out_grads], expected_backward, rtol=1e-3)
+            check_numeric_gradient(resize_sym, [data_np, date_np_like])
+
     shape = (2, 2, 10, 10)
     check_bilinear_resize_op(shape, 5, 5)
     check_bilinear_resize_op(shape, 10, 10)
     check_bilinear_resize_op(shape, 15, 15)
     check_bilinear_resize_op(shape, 3, 7)
     check_bilinear_resize_op(shape, 13, 17)
+    shape = (2, 2, 20, 20)
+    check_bilinear_resize_modes_op(shape, scale_height=0.5, scale_width=0.5, 
mode='odd_scale')
+    check_bilinear_resize_modes_op(shape, scale_height=5, scale_width=10, 
mode='odd_scale')
+    check_bilinear_resize_modes_op(shape, scale_height=0.1, scale_width=0.2, 
mode='odd_scale')
+    check_bilinear_resize_modes_op(shape, mode='to_even_down')
+    check_bilinear_resize_modes_op(shape, mode='to_even_up')
+    check_bilinear_resize_modes_op(shape, mode='to_odd_down')
+    check_bilinear_resize_modes_op(shape, mode='to_odd_up')
+    shape = (2, 2, 21, 21)
+    check_bilinear_resize_modes_op(shape, scale_height=0.5, scale_width=0.5, 
mode='odd_scale')
+    check_bilinear_resize_modes_op(shape, scale_height=5, scale_width=10, 
mode='odd_scale')
+    check_bilinear_resize_modes_op(shape, scale_height=0.1, scale_width=0.2, 
mode='odd_scale')
+    check_bilinear_resize_modes_op(shape, mode='to_even_down')
+    check_bilinear_resize_modes_op(shape, mode='to_even_up')
+    check_bilinear_resize_modes_op(shape, mode='to_odd_down')
+    check_bilinear_resize_modes_op(shape, mode='to_odd_up')
+    shape_0 = (2, 2, 21, 21)
+    shape_1 = (2, 2, 10, 10)
+    check_bilinear_resize_modes_op(shape_0, shape_1=shape_1, mode='like')
+    check_bilinear_resize_modes_op(shape_1, shape_1=shape_0, mode='like')
 
 def test_multi_proposal_op():
     # paramters

Reply via email to