This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new cee48d8  fix consistency of cpu/gpu in stn (#7374)
cee48d8 is described below

commit cee48d8e5f9afa0c341eb6d53157853285b6e6dc
Author: Wei Wu <tornadom...@users.noreply.github.com>
AuthorDate: Sat Aug 12 08:08:36 2017 +0800

    fix consistency of cpu/gpu in stn (#7374)
    
    * fix consistency of cpu/gpu in stn
    
    * add consistent test of stn
    
    add consistent test of stn
    
    * add consistent test of stn
---
 src/operator/spatial_transformer.cc   | 72 ++++++++++++++++++++++------------
 src/operator/spatial_transformer.cu   | 74 +++++++++++++++++++++++------------
 tests/python/gpu/test_operator_gpu.py | 14 ++++++-
 3 files changed, 111 insertions(+), 49 deletions(-)

diff --git a/src/operator/spatial_transformer.cc b/src/operator/spatial_transformer.cc
index 0d8ee29..51b0ebf 100644
--- a/src/operator/spatial_transformer.cc
+++ b/src/operator/spatial_transformer.cc
@@ -27,6 +27,10 @@
 
 namespace mshadow {
 template<typename DType>
+bool between(DType value, int lowerBound, int upperBound) {
+  return (value >= lowerBound && value <= upperBound);
+}
+template<typename DType>
 inline void BilinearSamplingForward(const Tensor<cpu, 4, DType> &output,
                                     const Tensor<cpu, 4, DType> &input,
                                     const Tensor<cpu, 3, DType> grid_src) {
@@ -43,19 +47,28 @@ inline void BilinearSamplingForward(const Tensor<cpu, 4, DType> &output,
           index_t grid_index = n * o_h * o_w * 2 + h * o_w + w;
           DType y_real = (*(grid + grid_index + o_h * o_w) + 1) * (i_h - 1) / 
2;
           DType x_real = (*(grid + grid_index) + 1) * (i_w - 1) / 2;
-          index_t top_left_y = std::min(i_h, std::max(0, 
static_cast<int>(floor(y_real))));
-          index_t top_left_x = std::min(i_w, std::max(0, 
static_cast<int>(floor(x_real))));
+          int top_left_y = static_cast<int>(floor(y_real));
+          int top_left_x = static_cast<int>(floor(x_real));
           DType top_left_y_w = 1.0 - (y_real - top_left_y);
           DType top_left_x_w = 1.0 - (x_real - top_left_x);
-          index_t data_index = n * i_c * i_h * i_w + c * i_h * i_w + 
top_left_y * i_w + top_left_x;
-          DType top_left_v = *(data + data_index);
-          DType top_right_v = *(data + data_index + 1);
-          DType bottom_left_v = *(data + data_index + i_w);
-          DType bottom_right_v = *(data + data_index + i_w + 1);
+          int data_index = n * i_c * i_h * i_w + c * i_h * i_w +
+                           top_left_y * i_w + top_left_x;
+          DType top_left_v = 0;
+          DType top_right_v = 0;
+          DType bottom_left_v = 0;
+          DType bottom_right_v = 0;
+          if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1))
+            top_left_v = *(data + data_index);
+          if (between(top_left_x + 1, 0, i_w-1) && between(top_left_y, 0, 
i_h-1))
+            top_right_v = *(data + data_index + 1);
+          if (between(top_left_x, 0, i_w-1) && between(top_left_y + 1, 0, 
i_h-1))
+            bottom_left_v = *(data + data_index + i_w);
+          if (between(top_left_x+1, 0, i_w-1) && between(top_left_y + 1, 0, 
i_h-1))
+            bottom_right_v = *(data + data_index + i_w + 1);
           *(out+out_index) = top_left_v * top_left_y_w * top_left_x_w +
-                             top_right_v * top_left_y_w * (1.0 - top_left_x_w) 
+
-                             bottom_left_v * (1.0 - top_left_y_w) * 
top_left_x_w +
-                             bottom_right_v * (1.0 - top_left_y_w) * (1.0 - 
top_left_x_w);
+                              top_right_v * top_left_y_w * (1.0 - 
top_left_x_w) +
+                              bottom_left_v * (1.0 - top_left_y_w) * 
top_left_x_w +
+                              bottom_right_v * (1.0 - top_left_y_w) * (1.0 - 
top_left_x_w);
         }
       }
     }
@@ -82,8 +95,8 @@ inline void BilinearSamplingBackward(const Tensor<cpu, 4, DType> &input_grad,
           index_t grid_src_index = n * o_h * o_w * 2 + h * o_w + w;
           DType y_real = (*(grid_src + grid_src_index + o_h * o_w) + 1) * (i_h 
- 1) / 2;
           DType x_real = (*(grid_src + grid_src_index) + 1) * (i_w - 1) / 2;
-          index_t top_left_y = std::min(i_h, std::max(0, 
static_cast<int>(floor(y_real))));
-          index_t top_left_x = std::min(i_w, std::max(0, 
static_cast<int>(floor(x_real))));
+          index_t top_left_y = static_cast<int>(floor(y_real));
+          index_t top_left_x = static_cast<int>(floor(x_real));
           DType top_left_y_w = 1.0 - (y_real - top_left_y);
           DType top_left_x_w = 1.0 - (x_real - top_left_x);
           for (index_t c = 0; c < static_cast<index_t>(o_c); ++c) {
@@ -91,18 +104,29 @@ inline void BilinearSamplingBackward(const Tensor<cpu, 4, DType> &input_grad,
             index_t data_index = n * i_c * i_h * i_w + c * i_h * i_w + 
top_left_y * i_w
                                  + top_left_x;
             // calc 4 vertex value in input data
-            DType top_left_v = *(data + data_index);
-            DType top_right_v = *(data + data_index + 1);
-            DType bottom_left_v = *(data + data_index + i_w);
-            DType bottom_right_v = *(data + data_index + i_w + 1);
-            // calc input grad
-            *(g_input + data_index) += *(grad + grad_index) * top_left_y_w * 
top_left_x_w;
-            *(g_input + data_index + 1) += *(grad + grad_index) * top_left_y_w
-                                           * (1.0 - top_left_x_w);
-            *(g_input + data_index+ i_w) += *(grad + grad_index) * (1.0 - 
top_left_y_w)
-                                            * top_left_x_w;
-            *(g_input + data_index+ i_w + 1) += *(grad + grad_index) * (1.0 - 
top_left_y_w)
-                                                * (1.0 - top_left_x_w);
+            DType top_left_v = 0;
+            DType top_right_v = 0;
+            DType bottom_left_v = 0;
+            DType bottom_right_v = 0;
+            if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, 
i_h-1)) {
+              *(g_input + data_index) += *(grad + grad_index) * top_left_y_w * 
top_left_x_w;
+              top_left_v = *(data + data_index);
+            }
+            if (between(top_left_x+1, 0, i_w-1) && between(top_left_y, 0, 
i_h-1)) {
+              *(g_input + data_index + 1) += *(grad + grad_index) * 
top_left_y_w
+                                             * (1.0 - top_left_x_w);
+              top_right_v = *(data + data_index + 1);
+            }
+            if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, 
i_h-1)) {
+              *(g_input + data_index+ i_w) += *(grad + grad_index) * (1.0 - 
top_left_y_w)
+                                              * top_left_x_w;
+              bottom_left_v = *(data + data_index + i_w);
+            }
+            if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, 
i_h-1)) {
+              *(g_input + data_index+ i_w + 1) += *(grad + grad_index) * (1.0 
- top_left_y_w)
+                                                  * (1.0 - top_left_x_w);
+              bottom_right_v = *(data + data_index + i_w + 1);
+            }
             // calc weight grad of top_left_w, then multiple -1 is the grad of 
grid_src
             top_left_y_gw -= *(grad + grad_index) * (top_right_v - 
bottom_right_v +
                              (top_left_v - top_right_v - bottom_left_v + 
bottom_right_v)
diff --git a/src/operator/spatial_transformer.cu b/src/operator/spatial_transformer.cu
index b3d635c..d5e4480 100644
--- a/src/operator/spatial_transformer.cu
+++ b/src/operator/spatial_transformer.cu
@@ -31,6 +31,10 @@
 
 namespace mshadow {
 template<typename DType>
+__device__ bool between(DType value, int lowerBound, int upperBound) {
+  return (value >= lowerBound && value <= upperBound);
+}
+template<typename DType>
 __global__ void BilinearSamplingForwardKernel(const int i_c, const int i_h,
                                               const int i_w, const DType* data,
                                               const DType* grid, const int o_n,
@@ -48,19 +52,27 @@ __global__ void BilinearSamplingForwardKernel(const int i_c, const int i_h,
     index_t grid_index = n * o_h * o_w * 2 + h * o_w + w;
     DType y_real = (*(grid + grid_index + o_h * o_w) + 1) * (i_h - 1) / 2;
     DType x_real = (*(grid + grid_index) + 1) * (i_w - 1) / 2;
-    index_t top_left_y = min(i_h, max(0, static_cast<int>(floor(y_real))));
-    index_t top_left_x = min(i_w, max(0, static_cast<int>(floor(x_real))));
+    int top_left_y = static_cast<int>(floor(y_real));
+    int top_left_x = static_cast<int>(floor(x_real));
     DType top_left_y_w = 1.0 - (y_real - top_left_y);
     DType top_left_x_w = 1.0 - (x_real - top_left_x);
-    index_t data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * 
i_w + top_left_x;
-    DType top_left_v = *(data + data_index);
-    DType top_right_v = *(data + data_index + 1);
-    DType bottom_left_v = *(data + data_index + i_w);
-    DType bottom_right_v = *(data + data_index + i_w + 1);
+    int data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w + 
top_left_x;
+    DType top_left_v = 0;
+    DType top_right_v = 0;
+    DType bottom_left_v = 0;
+    DType bottom_right_v = 0;
+    if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1))
+      top_left_v = *(data + data_index);
+    if (between(top_left_x + 1, 0, i_w-1) && between(top_left_y, 0, i_h-1))
+      top_right_v = *(data + data_index + 1);
+    if (between(top_left_x, 0, i_w-1) && between(top_left_y + 1, 0, i_h-1))
+      bottom_left_v = *(data + data_index + i_w);
+    if (between(top_left_x+1, 0, i_w-1) && between(top_left_y + 1, 0, i_h-1))
+      bottom_right_v = *(data + data_index + i_w + 1);
     *(out+out_index) = top_left_v * top_left_y_w * top_left_x_w +
-                       top_right_v * top_left_y_w * (1.0 - top_left_x_w) +
-                       bottom_left_v * (1.0 - top_left_y_w) * top_left_x_w +
-                       bottom_right_v * (1.0 - top_left_y_w) * (1.0 - 
top_left_x_w);
+                        top_right_v * top_left_y_w * (1.0 - top_left_x_w) +
+                        bottom_left_v * (1.0 - top_left_y_w) * top_left_x_w +
+                        bottom_right_v * (1.0 - top_left_y_w) * (1.0 - 
top_left_x_w);
     }
 }
 
@@ -83,29 +95,43 @@ __global__ void BilinearSamplingBackwardKernel(const int i_c, const int i_h,
     index_t grid_src_index = n * o_h * o_w * 2 + h * o_w + w;
     DType y_real = (*(grid_src + grid_src_index + o_h * o_w) + 1) * (i_h - 1) 
/ 2;
     DType x_real = (*(grid_src + grid_src_index) + 1) * (i_w - 1) / 2;
-    index_t top_left_y = min(i_h, max(0, static_cast<int>(floor(y_real))));
-    index_t top_left_x = min(i_w, max(0, static_cast<int>(floor(x_real))));
+    int top_left_y = static_cast<int>(floor(y_real));
+    int top_left_x = static_cast<int>(floor(x_real));
     DType top_left_y_w = 1.0 - (y_real - top_left_y);
     DType top_left_x_w = 1.0 - (x_real - top_left_x);
     for (index_t c = 0; c < o_c; ++c) {
       index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
       index_t data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * 
i_w + top_left_x;
       // calc 4 vertex value in input data
-      DType top_left_v = *(data + data_index);
-      DType top_right_v = *(data + data_index + 1);
-      DType bottom_left_v = *(data + data_index + i_w);
-      DType bottom_right_v = *(data + data_index + i_w + 1);
+      DType top_left_v = 0;
+      DType top_right_v = 0;
+      DType bottom_left_v = 0;
+      DType bottom_right_v = 0;
       // calc input grad
-      *(g_input + data_index) += *(grad + grad_index) * top_left_y_w * 
top_left_x_w;
-      *(g_input + data_index + 1) += *(grad + grad_index) * top_left_y_w * 
(1.0 - top_left_x_w);
-      *(g_input + data_index+ i_w) += *(grad + grad_index) * (1.0 - 
top_left_y_w) * top_left_x_w;
-      *(g_input + data_index+ i_w + 1) += *(grad + grad_index) * (1.0 - 
top_left_y_w) *
-                                          (1.0 - top_left_x_w);
+      if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
+        *(g_input + data_index) += *(grad + grad_index) * top_left_y_w * 
top_left_x_w;
+        top_left_v = *(data + data_index);
+      }
+      if (between(top_left_x+1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
+        *(g_input + data_index + 1) += *(grad + grad_index) * top_left_y_w * 
(1.0 - top_left_x_w);
+        top_right_v = *(data + data_index + 1);
+      }
+      if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
+        *(g_input + data_index+ i_w) += *(grad + grad_index) * (1.0 - 
top_left_y_w) * top_left_x_w;
+        bottom_left_v = *(data + data_index + i_w);
+      }
+      if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
+        *(g_input + data_index+ i_w + 1) += *(grad + grad_index) * (1.0 - 
top_left_y_w) *
+                                            (1.0 - top_left_x_w);
+        bottom_right_v = *(data + data_index + i_w + 1);
+      }
       // calc weight grad of top_left_w, then multiple -1 is the grad of 
grid_src
       top_left_y_gw -= *(grad + grad_index) * (top_right_v - bottom_right_v +
-                       (top_left_v - top_right_v - bottom_left_v + 
bottom_right_v) * top_left_x_w);
-      top_left_x_gw -= *(grad + grad_index) * (bottom_left_v - bottom_right_v 
+ (top_left_v -
-                       top_right_v - bottom_left_v + bottom_right_v) * 
top_left_y_w);
+                       (top_left_v - top_right_v - bottom_left_v + 
bottom_right_v)
+                       * top_left_x_w);
+      top_left_x_gw -= *(grad + grad_index) * (bottom_left_v - bottom_right_v +
+                       (top_left_v - top_right_v - bottom_left_v + 
bottom_right_v)
+                       * top_left_y_w);
     }
     // calc grid_src grad
     *(grid_src + grid_src_index + o_h * o_w) = top_left_y_gw * (i_h - 1) / 2;
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index c80b9e3..cd8e85a 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -623,7 +623,6 @@ def test_bilinear_sampler_with_type():
     check_consistency(sym, ctx_list)
     check_consistency(sym, ctx_list, grad_req="add")
 
-
 def test_grid_generator_with_type():
     data = mx.sym.Variable('data')
     sym = mx.sym.GridGenerator(data=data, transform_type='affine', 
target_shape=(20, 20))
@@ -637,6 +636,19 @@ def test_grid_generator_with_type():
     check_consistency(sym, ctx_list)
     check_consistency(sym, ctx_list, grad_req="add")
 
+def test_spatial_transformer_with_type():
+    np.random.seed(1234)
+    data = mx.sym.Variable('data')
+    loc = mx.sym.Flatten(data)
+    loc = mx.sym.FullyConnected(data=loc, num_hidden=10)
+    loc = mx.sym.Activation(data=loc, act_type='relu')
+    loc = mx.sym.FullyConnected(data=loc, num_hidden=6)
+    sym = mx.sym.SpatialTransformer(data=data, loc=loc, target_shape=(10, 10),
+                                    transform_type="affine", 
sampler_type="bilinear")
+    ctx_list = [{'ctx': mx.gpu(0), 'data': (1, 5, 10, 10), 'type_dict': 
{'data': np.float32}},
+                {'ctx': mx.cpu(0), 'data': (1, 5, 10, 10), 'type_dict': 
{'data': np.float32}}]
+    check_consistency(sym, ctx_list)
+    check_consistency(sym, ctx_list, grad_req="add")
 
 # Checking max pooling consistency over the data sets of different float types 
is problematic
 # as one max value in a float32 data set may not be the max value in a float16 
data set.

-- 
To stop receiving notification emails like this one, please contact
['"comm...@mxnet.apache.org" <comm...@mxnet.apache.org>'].

Reply via email to