Repository: incubator-systemml
Updated Branches:
  refs/heads/master 16e990928 -> 169a2da5f


[SYSTEMML-1408] Add padding parameters to max-pooling layers

This adds padding parameters to the max-pooling layers, along with the
associated tests. It also includes some general code formatting updates.
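
For reference, a minimal sketch of calling the updated layer with the new
padding arguments (the signatures match the nn/layers/max_pool.dml and
max_pool_builtin.dml changes below; the toy shapes and values here are
illustrative only):

    source("nn/layers/max_pool.dml") as max_pool

    # Toy input: N=2 examples, C=3 channels, 4x4 images.
    N = 2
    C = 3
    Hin = 4
    Win = 4
    X = rand(rows=N, cols=C*Hin*Win)

    # 2x2 pooling with stride 2 and 1 pixel of padding on each side.
    [out, Hout, Wout] = max_pool::forward(X, C, Hin, Win, Hf=2, Wf=2,
                                          strideh=2, stridew=2, padh=1, padw=1)
    # Hout = (Hin + 2*padh - Hf) / strideh + 1 = 3, and likewise Wout = 3.

    # Backward pass, given an upstream gradient of shape (N, C*Hout*Wout).
    dout = rand(rows=N, cols=C*Hout*Wout)
    dX = max_pool::backward(dout, Hout, Wout, X, C, Hin, Win, Hf=2, Wf=2,
                            strideh=2, stridew=2, padh=1, padw=1)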

Closes #434.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/15ccb7c0
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/15ccb7c0
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/15ccb7c0

Branch: refs/heads/master
Commit: 15ccb7c03016b2c4eafd2d060852a265f28a070a
Parents: 16e9909
Author: Mike Dusenberry <[email protected]>
Authored: Wed Mar 22 16:44:52 2017 -0700
Committer: Mike Dusenberry <[email protected]>
Committed: Wed Mar 22 16:44:52 2017 -0700

----------------------------------------------------------------------
 .../SystemML-NN/examples/get_mnist_data.sh      |   8 +-
 .../examples/mnist_lenet-predict.dml            |   4 +-
 .../SystemML-NN/examples/mnist_lenet-train.dml  |   2 +-
 .../SystemML-NN/examples/mnist_lenet.dml        |  41 ++--
 .../examples/mnist_softmax-predict.dml          |   4 +-
 .../examples/mnist_softmax-train.dml            |   2 +-
 .../staging/SystemML-NN/nn/layers/affine.dml    |   4 +-
 scripts/staging/SystemML-NN/nn/layers/conv.dml  |  12 +-
 .../SystemML-NN/nn/layers/conv_builtin.dml      |   9 +-
 .../nn/layers/cross_entropy_loss.dml            |   4 +-
 .../staging/SystemML-NN/nn/layers/dropout.dml   |   4 +-
 .../staging/SystemML-NN/nn/layers/l1_loss.dml   |   4 +-
 .../staging/SystemML-NN/nn/layers/l1_reg.dml    |   2 +-
 .../staging/SystemML-NN/nn/layers/l2_loss.dml   |   4 +-
 .../staging/SystemML-NN/nn/layers/l2_reg.dml    |   2 +-
 .../staging/SystemML-NN/nn/layers/log_loss.dml  |   6 +-
 scripts/staging/SystemML-NN/nn/layers/lstm.dml  |  10 +-
 .../staging/SystemML-NN/nn/layers/max_pool.dml  |  54 +++--
 .../SystemML-NN/nn/layers/max_pool_builtin.dml  |  23 +-
 scripts/staging/SystemML-NN/nn/layers/rnn.dml   |   6 +-
 .../staging/SystemML-NN/nn/layers/softmax.dml   |   2 +-
 .../staging/SystemML-NN/nn/test/conv_simple.dml |  35 +--
 .../staging/SystemML-NN/nn/test/grad_check.dml  | 218 +++++++++--------
 .../SystemML-NN/nn/test/max_pool_simple.dml     |  83 +++++--
 scripts/staging/SystemML-NN/nn/test/test.dml    | 231 ++++++++++++++++---
 scripts/staging/SystemML-NN/nn/util.dml         |  10 +-
 26 files changed, 537 insertions(+), 247 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/15ccb7c0/scripts/staging/SystemML-NN/examples/get_mnist_data.sh
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/examples/get_mnist_data.sh 
b/scripts/staging/SystemML-NN/examples/get_mnist_data.sh
index 6fed70b..deb0c40 100755
--- a/scripts/staging/SystemML-NN/examples/get_mnist_data.sh
+++ b/scripts/staging/SystemML-NN/examples/get_mnist_data.sh
@@ -8,9 +8,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -23,6 +23,6 @@
 DIR="$(cd "$(dirname "$0")" && pwd)"
 mkdir -p $DIR/data/mnist/
 cd $DIR/data/mnist/
-curl -O http://pjreddie.com/media/files/mnist_train.csv
-curl -O http://pjreddie.com/media/files/mnist_test.csv
+curl -O https://pjreddie.com/media/files/mnist_train.csv
+curl -O https://pjreddie.com/media/files/mnist_test.csv
 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/15ccb7c0/scripts/staging/SystemML-NN/examples/mnist_lenet-predict.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/examples/mnist_lenet-predict.dml 
b/scripts/staging/SystemML-NN/examples/mnist_lenet-predict.dml
index fc8e904..775926c 100644
--- a/scripts/staging/SystemML-NN/examples/mnist_lenet-predict.dml
+++ b/scripts/staging/SystemML-NN/examples/mnist_lenet-predict.dml
@@ -41,7 +41,7 @@
 # Outputs:
 #  - probs: File containing class probability predictions for each
 #     image.
-# 
+#
 # Data:
 # The X file should contain images of handwritten digits,
 # where each example is a 28x28 pixel image of grayscale values in
@@ -79,7 +79,7 @@ b3 = read($model_dir+"/b3")
 W4 = read($model_dir+"/W4")
 b4 = read($model_dir+"/b4")
 
-# Predict classes 
+# Predict classes
 probs = mnist_lenet::predict(X, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4)
 
 # Output results

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/15ccb7c0/scripts/staging/SystemML-NN/examples/mnist_lenet-train.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/examples/mnist_lenet-train.dml 
b/scripts/staging/SystemML-NN/examples/mnist_lenet-train.dml
index d555a41..c23029f 100644
--- a/scripts/staging/SystemML-NN/examples/mnist_lenet-train.dml
+++ b/scripts/staging/SystemML-NN/examples/mnist_lenet-train.dml
@@ -41,7 +41,7 @@
 #  - W1, W2, W3, W4: Files containing the trained weights of the model.
 #  - b1, b2, b3, b4: Files containing the trained biases of the model.
 #  - accuracy: File containing the final accuracy on the test data.
-# 
+#
 # Data:
 # The MNIST dataset contains labeled images of handwritten digits,
 # where each example is a 28x28 pixel image of grayscale values in

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/15ccb7c0/scripts/staging/SystemML-NN/examples/mnist_lenet.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/examples/mnist_lenet.dml 
b/scripts/staging/SystemML-NN/examples/mnist_lenet.dml
index 22a793c..f991487 100644
--- a/scripts/staging/SystemML-NN/examples/mnist_lenet.dml
+++ b/scripts/staging/SystemML-NN/examples/mnist_lenet.dml
@@ -114,13 +114,17 @@ train = function(matrix[double] X, matrix[double] y,
 
       # Compute forward pass
       ## layer 1: conv1 -> relu1 -> pool1
-      [outc1, Houtc1, Woutc1] = conv::forward(X_batch, W1, b1, C, Hin, Win, 
Hf, Wf, stride, stride, pad, pad)
+      [outc1, Houtc1, Woutc1] = conv::forward(X_batch, W1, b1, C, Hin, Win, 
Hf, Wf, stride, stride,
+                                              pad, pad)
       outr1 = relu::forward(outc1)
-      [outp1, Houtp1, Woutp1] = max_pool::forward(outr1, F1, Houtc1, Woutc1, 
Hf=2, Wf=2, strideh=2, stridew=2) 
+      [outp1, Houtp1, Woutp1] = max_pool::forward(outr1, F1, Houtc1, Woutc1, 
Hf=2, Wf=2,
+                                                  strideh=2, stridew=2, padh=0,
padw=0)
       ## layer 2: conv2 -> relu2 -> pool2
-      [outc2, Houtc2, Woutc2] = conv::forward(outp1, W2, b2, F1, Houtp1, 
Woutp1, Hf, Wf, stride, stride, pad, pad)
+      [outc2, Houtc2, Woutc2] = conv::forward(outp1, W2, b2, F1, Houtp1, 
Woutp1, Hf, Wf,
+                                              stride, stride, pad, pad)
       outr2 = relu::forward(outc2)
-      [outp2, Houtp2, Woutp2] = max_pool::forward(outr2, F2, Houtc2, Woutc2, 
Hf=2, Wf=2, strideh=2, stridew=2) 
+      [outp2, Houtp2, Woutp2] = max_pool::forward(outr2, F2, Houtc2, Woutc2, 
Hf=2, Wf=2,
+                                                  strideh=2, stridew=2, padh=0,
padw=0)
       ## layer 3:  affine3 -> relu3 -> dropout
       outa3 = affine::forward(outp2, W3, b3)
       outr3 = relu::forward(outa3)
@@ -146,7 +150,8 @@ train = function(matrix[double] X, matrix[double] y,
         accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(y_val))
 
         # Output results
-        print("Epoch: " + e + ", Iter: " + i + ", Train Loss: " + loss + ", 
Train Accuracy: " + accuracy + ", Val Loss: " + loss_val + ", Val Accuracy: " + 
accuracy_val)
+        print("Epoch: " + e + ", Iter: " + i + ", Train Loss: " + loss + ", 
Train Accuracy: "
+              + accuracy + ", Val Loss: " + loss_val + ", Val Accuracy: " + 
accuracy_val)
       }
 
       # Compute data backward pass
@@ -160,13 +165,17 @@ train = function(matrix[double] X, matrix[double] y,
       douta3 = relu::backward(doutr3, outa3)
       [doutp2, dW3, db3] = affine::backward(douta3, outp2, W3, b3)
       ## layer 2: conv2 -> relu2 -> pool2
-      doutr2 = max_pool::backward(doutp2, Houtp2, Woutp2, outr2, F2, Houtc2, 
Woutc2, Hf=2, Wf=2, strideh=2, stridew=2)
+      doutr2 = max_pool::backward(doutp2, Houtp2, Woutp2, outr2, F2, Houtc2, 
Woutc2, Hf=2, Wf=2,
+                                  strideh=2, stridew=2, padh=0, padw=0)
       doutc2 = relu::backward(doutr2, outc2)
-      [doutp1, dW2, db2] = conv::backward(doutc2, Houtc2, Woutc2, outp1, W2, 
b2, F1, Houtp1, Woutp1, Hf, Wf, stride, stride, pad, pad)
+      [doutp1, dW2, db2] = conv::backward(doutc2, Houtc2, Woutc2, outp1, W2, 
b2, F1,
+                                          Houtp1, Woutp1, Hf, Wf, stride, 
stride, pad, pad)
       ## layer 1: conv1 -> relu1 -> pool1
-      doutr1 = max_pool::backward(doutp1, Houtp1, Woutp1, outr1, F1, Houtc1, 
Woutc1, Hf=2, Wf=2, strideh=2, stridew=2)
+      doutr1 = max_pool::backward(doutp1, Houtp1, Woutp1, outr1, F1, Houtc1, 
Woutc1, Hf=2, Wf=2,
+                                  strideh=2, stridew=2, padh=0, padw=0)
       doutc1 = relu::backward(doutr1, outc1)
-      [dX_batch, dW1, db1] = conv::backward(doutc1, Houtc1, Woutc1, X_batch, 
W1, b1, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+      [dX_batch, dW1, db1] = conv::backward(doutc1, Houtc1, Woutc1, X_batch, 
W1, b1, C, Hin, Win,
+                                            Hf, Wf, stride, stride, pad, pad)
 
       # Compute regularization backward pass
       dW1_reg = l2_reg::backward(W1, lambda)
@@ -251,13 +260,17 @@ predict = function(matrix[double] X, int C, int Hin, int 
Win,
 
     # Compute forward pass
     ## layer 1: conv1 -> relu1 -> pool1
-    [outc1, Houtc1, Woutc1] = conv::forward(X_batch, W1, b1, C, Hin, Win, Hf, 
Wf, stride, stride, pad, pad)
+    [outc1, Houtc1, Woutc1] = conv::forward(X_batch, W1, b1, C, Hin, Win, Hf, 
Wf, stride, stride,
+                                            pad, pad)
     outr1 = relu::forward(outc1)
-    [outp1, Houtp1, Woutp1] = max_pool::forward(outr1, F1, Houtc1, Woutc1, 
Hf=2, Wf=2, strideh=2, stridew=2) 
+    [outp1, Houtp1, Woutp1] = max_pool::forward(outr1, F1, Houtc1, Woutc1, 
Hf=2, Wf=2,
+                                                strideh=2, stridew=2, padh=0,
padw=0)
     ## layer 2: conv2 -> relu2 -> pool2
-    [outc2, Houtc2, Woutc2] = conv::forward(outp1, W2, b2, F1, Houtp1, Woutp1, 
Hf, Wf, stride, stride, pad, pad)
+    [outc2, Houtc2, Woutc2] = conv::forward(outp1, W2, b2, F1, Houtp1, Woutp1, 
Hf, Wf,
+                                            stride, stride, pad, pad)
     outr2 = relu::forward(outc2)
-    [outp2, Houtp2, Woutp2] = max_pool::forward(outr2, F2, Houtc2, Woutc2, 
Hf=2, Wf=2, strideh=2, stridew=2) 
+    [outp2, Houtp2, Woutp2] = max_pool::forward(outr2, F2, Houtc2, Woutc2, 
Hf=2, Wf=2,
+                                                strideh=2, stridew=2, padh=0,
padw=0)
     ## layer 3:  affine3 -> relu3
     outa3 = affine::forward(outp2, W3, b3)
     outr3 = relu::forward(outa3)
@@ -281,7 +294,7 @@ eval = function(matrix[double] probs, matrix[double] y)
    *
    * Inputs:
    *  - probs: Class probabilities, of shape (N, K).
-   *  - y: Target matrix, of shape (N, 
+   *  - y: Target matrix, of shape (N, K).
    *
    * Outputs:
    *  - loss: Scalar loss, of shape (1).

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/15ccb7c0/scripts/staging/SystemML-NN/examples/mnist_softmax-predict.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/examples/mnist_softmax-predict.dml 
b/scripts/staging/SystemML-NN/examples/mnist_softmax-predict.dml
index bc7d158..52f31fd 100644
--- a/scripts/staging/SystemML-NN/examples/mnist_softmax-predict.dml
+++ b/scripts/staging/SystemML-NN/examples/mnist_softmax-predict.dml
@@ -37,7 +37,7 @@
 # Outputs:
 #  - probs: File containing class probability predictions for each
 #     image.
-# 
+#
 # Data:
 # The X file should contain images of handwritten digits,
 # where each example is a 28x28 pixel image of grayscale values in
@@ -66,7 +66,7 @@ X = X / 255.0
 W = read($model_dir+"/W")
 b = read($model_dir+"/b")
 
-# Predict classes 
+# Predict classes
 probs = mnist_softmax::predict(X, W, b)
 
 # Output results

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/15ccb7c0/scripts/staging/SystemML-NN/examples/mnist_softmax-train.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/examples/mnist_softmax-train.dml 
b/scripts/staging/SystemML-NN/examples/mnist_softmax-train.dml
index 39bb9d8..dff192e 100644
--- a/scripts/staging/SystemML-NN/examples/mnist_softmax-train.dml
+++ b/scripts/staging/SystemML-NN/examples/mnist_softmax-train.dml
@@ -38,7 +38,7 @@
 #  - W: File containing the trained weights of the model.
 #  - b: File containing the trained biases of the model.
 #  - accuracy: File containing the final accuracy on the test data.
-# 
+#
 # Data:
 # The MNIST dataset contains labeled images of handwritten digits,
 # where each example is a 28x28 pixel image of grayscale values in

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/15ccb7c0/scripts/staging/SystemML-NN/nn/layers/affine.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/affine.dml 
b/scripts/staging/SystemML-NN/nn/layers/affine.dml
index e7e4fd8..6a4c210 100644
--- a/scripts/staging/SystemML-NN/nn/layers/affine.dml
+++ b/scripts/staging/SystemML-NN/nn/layers/affine.dml
@@ -69,7 +69,7 @@ init = function(int D, int M)
    *
    * Note: This is just a convenience function, and parameters
    * may be initialized manually if needed.
-   * 
+   *
    * We use the heuristic by He et al. [http://arxiv.org/abs/1502.01852],
    * which limits the magnification of inputs/gradients during
    * forward/backward passes by scaling unit-Gaussian weights by a
@@ -84,6 +84,6 @@ init = function(int D, int M)
    *  - b: Biases vector, of shape (1, M).
    */
   W = rand(rows=D, cols=M, pdf="normal") * sqrt(2.0/D)
-  b = matrix(0, rows=1, cols=M) 
+  b = matrix(0, rows=1, cols=M)
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/15ccb7c0/scripts/staging/SystemML-NN/nn/layers/conv.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/conv.dml 
b/scripts/staging/SystemML-NN/nn/layers/conv.dml
index 100bc12..4036bbc 100644
--- a/scripts/staging/SystemML-NN/nn/layers/conv.dml
+++ b/scripts/staging/SystemML-NN/nn/layers/conv.dml
@@ -69,7 +69,7 @@ forward = function(matrix[double] X, matrix[double] W, 
matrix[double] b,
   F = nrow(W)
   Hout = as.integer((Hin + 2 * padh - Hf) / strideh + 1)
   Wout = as.integer((Win + 2 * padw - Wf) / stridew + 1)
-  
+
   # Create output volume
   out = matrix(0, rows=N, cols=F*Hout*Wout)
 
@@ -124,7 +124,7 @@ backward = function(matrix[double] dout, int Hout, int Wout,
    */
   N = nrow(X)
   F = nrow(W)
-  
+
   # Create gradient volumes
   # Note: Create convenience gradient volumes for dW and db that will
   # allow for one gradient to be stored per example, allowing for
@@ -151,8 +151,8 @@ backward = function(matrix[double] dout, int Hout, int Wout,
 
     # Compute dX
     dXn_padded_cols = t(W) %*% doutn  # shape (C*Hf*Wf, Hout*Wout)
-    dXn_padded =
-      util::col2im(dXn_padded_cols, C, Hin+2*padh, Win+2*padw, Hf, Wf, 
strideh, stridew, "add")
+    dXn_padded = util::col2im(dXn_padded_cols, C, Hin+2*padh, Win+2*padw, Hf, 
Wf,
+                              strideh, stridew, "add")
     dXn = util::unpad_image(dXn_padded, Hin, Win, padh, padw)
     dX[n,] = matrix(dXn, rows=1, cols=C*Hin*Win)  # reshape
   }
@@ -170,7 +170,7 @@ init = function(int F, int C, int Hf, int Wf)
    *
    * Note: This is just a convenience function, and parameters
    * may be initialized manually if needed.
-   * 
+   *
    * We use the heuristic by He et al. [http://arxiv.org/abs/1502.01852],
    * which limits the magnification of inputs/gradients during
    * forward/backward passes by scaling unit-Gaussian weights by a
@@ -187,6 +187,6 @@ init = function(int F, int C, int Hf, int Wf)
    *  - b: Biases vector, of shape (F, 1).
    */
   W = rand(rows=F, cols=C*Hf*Wf, pdf="normal") * sqrt(2.0/(C*Hf*Wf))
-  b = matrix(0, rows=F, cols=1) 
+  b = matrix(0, rows=F, cols=1)
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/15ccb7c0/scripts/staging/SystemML-NN/nn/layers/conv_builtin.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/conv_builtin.dml 
b/scripts/staging/SystemML-NN/nn/layers/conv_builtin.dml
index 3113ccf..44df74a 100644
--- a/scripts/staging/SystemML-NN/nn/layers/conv_builtin.dml
+++ b/scripts/staging/SystemML-NN/nn/layers/conv_builtin.dml
@@ -60,10 +60,9 @@ forward = function(matrix[double] X, matrix[double] W, 
matrix[double] b,
    */
   N = nrow(X)
   F = nrow(W)
-  # TODO: We should eliminate this in a seperate PR
   Hout = as.integer((Hin + 2 * padh - Hf) / strideh + 1)
   Wout = as.integer((Win + 2 * padw - Wf) / stridew + 1)
-  
+
   # Convolution - built-in implementation
   out = conv2d(X, W, input_shape=[N,C,Hin,Win], filter_shape=[F,C,Hf,Wf],
                stride=[strideh,stridew], padding=[padh,padw])
@@ -105,7 +104,7 @@ backward = function(matrix[double] dout, int Hout, int Wout,
    */
   N = nrow(X)
   F = nrow(W)
-  
+
   # Partial derivatives for convolution - built-in implementation
   dW = conv2d_backward_filter(X, dout, stride=[strideh,stridew], 
padding=[padh,padw],
                               input_shape=[N,C,Hin,Win], 
filter_shape=[F,C,Hf,Wf])
@@ -123,7 +122,7 @@ init = function(int F, int C, int Hf, int Wf)
    *
    * Note: This is just a convenience function, and parameters
    * may be initialized manually if needed.
-   * 
+   *
    * We use the heuristic by He et al. [http://arxiv.org/abs/1502.01852],
    * which limits the magnification of inputs/gradients during
    * forward/backward passes by scaling unit-Gaussian weights by a
@@ -140,6 +139,6 @@ init = function(int F, int C, int Hf, int Wf)
    *  - b: Biases vector, of shape (F, 1).
    */
   W = rand(rows=F, cols=C*Hf*Wf, pdf="normal") * sqrt(2.0/(C*Hf*Wf))
-  b = matrix(0, rows=F, cols=1) 
+  b = matrix(0, rows=F, cols=1)
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/15ccb7c0/scripts/staging/SystemML-NN/nn/layers/cross_entropy_loss.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/cross_entropy_loss.dml 
b/scripts/staging/SystemML-NN/nn/layers/cross_entropy_loss.dml
index 9e3e7cd..f9cd507 100644
--- a/scripts/staging/SystemML-NN/nn/layers/cross_entropy_loss.dml
+++ b/scripts/staging/SystemML-NN/nn/layers/cross_entropy_loss.dml
@@ -26,7 +26,7 @@
  *  vectors of class probs.
  * L = (1/N) sum(L_i) for i=1 to N, where N is the number of examples.
  */
-forward = function(matrix[double] pred, matrix[double] y) 
+forward = function(matrix[double] pred, matrix[double] y)
     return (double loss) {
   /*
    * Computes the forward pass for a cross-entropy loss function.  The
@@ -50,7 +50,7 @@ forward = function(matrix[double] pred, matrix[double] y)
   loss = sum(losses) / N
 }
 
-backward = function(matrix[double] pred, matrix[double] y) 
+backward = function(matrix[double] pred, matrix[double] y)
     return (matrix[double] dpred) {
   /*
    * Computes the backward pass of a cross-entropy loss function.  The

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/15ccb7c0/scripts/staging/SystemML-NN/nn/layers/dropout.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/dropout.dml 
b/scripts/staging/SystemML-NN/nn/layers/dropout.dml
index 6b46305..2b1bd1d 100644
--- a/scripts/staging/SystemML-NN/nn/layers/dropout.dml
+++ b/scripts/staging/SystemML-NN/nn/layers/dropout.dml
@@ -47,10 +47,10 @@ forward = function(matrix[double] X, double p, int seed)
   # to create a dropout mask.  Fortunately, SystemML has a `sparsity` 
parameter on
  # the `rand` function that allows us to create a mask directly.
   if (seed == -1) {
-       mask = rand(rows=nrow(X), cols=ncol(X), min=1, max=1, sparsity=p)
+    mask = rand(rows=nrow(X), cols=ncol(X), min=1, max=1, sparsity=p)
   }
   else {
-       mask = rand(rows=nrow(X), cols=ncol(X), min=1, max=1, sparsity=p, 
seed=seed)
+    mask = rand(rows=nrow(X), cols=ncol(X), min=1, max=1, sparsity=p, 
seed=seed)
   }
   out = X * mask / p
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/15ccb7c0/scripts/staging/SystemML-NN/nn/layers/l1_loss.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/l1_loss.dml 
b/scripts/staging/SystemML-NN/nn/layers/l1_loss.dml
index 6c625e8..7d6c821 100644
--- a/scripts/staging/SystemML-NN/nn/layers/l1_loss.dml
+++ b/scripts/staging/SystemML-NN/nn/layers/l1_loss.dml
@@ -25,7 +25,7 @@
  * L_i = sum_j(abs((pred_i)_j - (y_i)_j)) for all j.
  * L = (1/N) sum(L_i) for i=1 to N, where N is the number of examples.
  */
-forward = function(matrix[double] pred, matrix[double] y) 
+forward = function(matrix[double] pred, matrix[double] y)
     return (double loss) {
   /*
    * Computes the forward pass for an L1 loss function.  The inputs
@@ -46,7 +46,7 @@ forward = function(matrix[double] pred, matrix[double] y)
   loss = sum(losses) / N
 }
 
-backward = function(matrix[double] pred, matrix[double] y) 
+backward = function(matrix[double] pred, matrix[double] y)
     return (matrix[double] dpred) {
   /*
    * Computes the backward pass for an L1 loss function.  The inputs

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/15ccb7c0/scripts/staging/SystemML-NN/nn/layers/l1_reg.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/l1_reg.dml 
b/scripts/staging/SystemML-NN/nn/layers/l1_reg.dml
index 28de74c..b2175ab 100644
--- a/scripts/staging/SystemML-NN/nn/layers/l1_reg.dml
+++ b/scripts/staging/SystemML-NN/nn/layers/l1_reg.dml
@@ -46,7 +46,7 @@ backward = function(matrix[double] X, double lambda) return 
(matrix[double] dX)
    *  - lambda: Regularization strength.
    *
    * Outputs:
-   *  - dX: Gradient wrt X, of same shape as X. 
+   *  - dX: Gradient wrt X, of same shape as X.
    */
   dX = lambda * sign(X)
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/15ccb7c0/scripts/staging/SystemML-NN/nn/layers/l2_loss.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/l2_loss.dml 
b/scripts/staging/SystemML-NN/nn/layers/l2_loss.dml
index c4a8618..9f27cc2 100644
--- a/scripts/staging/SystemML-NN/nn/layers/l2_loss.dml
+++ b/scripts/staging/SystemML-NN/nn/layers/l2_loss.dml
@@ -25,7 +25,7 @@
  * L_i = (1/2) 2norm(pred_i - y_i)^2
  * L = (1/N) sum(L_i) for i=1 to N, where N is the number of examples.
  */
-forward = function(matrix[double] pred, matrix[double] y) 
+forward = function(matrix[double] pred, matrix[double] y)
     return (double loss) {
   /*
    * Computes the forward pass for an L2 loss function.  The inputs
@@ -46,7 +46,7 @@ forward = function(matrix[double] pred, matrix[double] y)
   loss = sum(losses) / N
 }
 
-backward = function(matrix[double] pred, matrix[double] y) 
+backward = function(matrix[double] pred, matrix[double] y)
     return (matrix[double] dpred) {
   /*
    * Computes the backward pass for an L2 loss function.  The inputs

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/15ccb7c0/scripts/staging/SystemML-NN/nn/layers/l2_reg.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/l2_reg.dml 
b/scripts/staging/SystemML-NN/nn/layers/l2_reg.dml
index 22df974..44f2a54 100644
--- a/scripts/staging/SystemML-NN/nn/layers/l2_reg.dml
+++ b/scripts/staging/SystemML-NN/nn/layers/l2_reg.dml
@@ -46,7 +46,7 @@ backward = function(matrix[double] X, double lambda) return 
(matrix[double] dX)
    *  - lambda: Regularization strength.
    *
    * Outputs:
-   *  - dX: Gradient wrt X, of same shape as X. 
+   *  - dX: Gradient wrt X, of same shape as X.
    */
   dX = lambda * X
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/15ccb7c0/scripts/staging/SystemML-NN/nn/layers/log_loss.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/log_loss.dml 
b/scripts/staging/SystemML-NN/nn/layers/log_loss.dml
index 0bcb02e..ad5e561 100644
--- a/scripts/staging/SystemML-NN/nn/layers/log_loss.dml
+++ b/scripts/staging/SystemML-NN/nn/layers/log_loss.dml
@@ -23,10 +23,10 @@
  * Log loss function.
  *
  * L_i = -y_i*log(pred_i) - (1-y_i)*log(1-pred_i), where y_i is a
- *  binary target, and pred_i is a probability of y=1. 
+ *  binary target, and pred_i is a probability of y=1.
  * L = (1/N) sum(L_i) for i=1 to N, where N is the number of examples.
  */
-forward = function(matrix[double] pred, matrix[double] y) 
+forward = function(matrix[double] pred, matrix[double] y)
     return (double loss) {
   /*
    * Computes the forward pass for a log loss function.
@@ -48,7 +48,7 @@ forward = function(matrix[double] pred, matrix[double] y)
   loss = sum(losses) / N
 }
 
-backward = function(matrix[double] pred, matrix[double] y) 
+backward = function(matrix[double] pred, matrix[double] y)
     return (matrix[double] dpred) {
   /*
    * Computes the backward pass for a log loss function.

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/15ccb7c0/scripts/staging/SystemML-NN/nn/layers/lstm.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/lstm.dml 
b/scripts/staging/SystemML-NN/nn/layers/lstm.dml
index b0fdd52..0dd9f4c 100644
--- a/scripts/staging/SystemML-NN/nn/layers/lstm.dml
+++ b/scripts/staging/SystemML-NN/nn/layers/lstm.dml
@@ -81,7 +81,7 @@ forward = function(matrix[double] X, matrix[double] W, 
matrix[double] b, int T,
   }
   # caches to be used during the backward pass for performance
   cache_out = matrix(0, rows=T, cols=N*M)
-  cache_c = matrix(0, rows=T, cols=N*M) 
+  cache_c = matrix(0, rows=T, cols=N*M)
   cache_ifog = matrix(0, rows=T, cols=N*4*M)
 
   for (t in 1:T) {  # each timestep
@@ -113,7 +113,7 @@ forward = function(matrix[double] X, matrix[double] W, 
matrix[double] b, int T,
   }
 }
 
-backward = function(matrix[double] dout, matrix[double] dc, 
+backward = function(matrix[double] dout, matrix[double] dc,
                     matrix[double] X, matrix[double] W, matrix[double] b, int 
T, int D,
                     boolean given_sequences, matrix[double] out0, 
matrix[double] c0,
                     matrix[double] cache_out, matrix[double] cache_c, 
matrix[double] cache_ifog)
@@ -197,7 +197,7 @@ backward = function(matrix[double] dout, matrix[double] dc,
     dc_prev = f * dct  # shape (N, M)
     di = g * dct  # input gate, shape (N, M)
     dg = i * dct  # g gate, shape (N, M)
-    
+
     di_raw = i * (1-i) * di
     df_raw = f * (1-f) * df
     do_raw = o * (1-o) * do
@@ -228,7 +228,7 @@ init = function(int N, int D, int M)
    *
    * Note: This is just a convenience function, and parameters
    * may be initialized manually if needed.
-   * 
+   *
    * We use the Glorot uniform heuristic which limits the magnification
    * of inputs/gradients during forward/backward passes by scaling
    * uniform weights by a factor of sqrt(6/(fan_in + fan_out)).
@@ -248,7 +248,7 @@ init = function(int N, int D, int M)
   fan_out = 4*M
   scale = sqrt(6/(fan_in+fan_out))
   W = rand(rows=D+M, cols=4*M, min=-scale, max=scale, pdf="uniform")
-  b = matrix(0, rows=1, cols=4*M) 
+  b = matrix(0, rows=1, cols=4*M)
   out0 = matrix(0, rows=N, cols=M)
   c0 = matrix(0, rows=N, cols=M)
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/15ccb7c0/scripts/staging/SystemML-NN/nn/layers/max_pool.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/max_pool.dml 
b/scripts/staging/SystemML-NN/nn/layers/max_pool.dml
index 94f93e2..ec7d431 100644
--- a/scripts/staging/SystemML-NN/nn/layers/max_pool.dml
+++ b/scripts/staging/SystemML-NN/nn/layers/max_pool.dml
@@ -25,7 +25,7 @@
 source("nn/util.dml") as util
 
 forward = function(matrix[double] X, int C, int Hin, int Win, int Hf, int Wf,
-                   int strideh, int stridew)
+                   int strideh, int stridew, int padh, int padw)
     return (matrix[double] out, int Hout, int Wout) {
   /*
    * Computes the forward pass for a 2D spatial max pooling layer.
@@ -46,6 +46,10 @@ forward = function(matrix[double] X, int C, int Hin, int 
Win, int Hf, int Wf,
    *  - Wf: Filter width.
    *  - strideh: Stride over height.
    *  - stridew: Stride over width.
+   *  - padh: Padding for top and bottom sides.
+   *      A typical value is 0.
+   *  - padw: Padding for left and right sides.
+   *      A typical value is 0.
    *
    * Outputs:
    *  - out: Outputs, of shape (N, C*Hout*Wout).
@@ -53,8 +57,8 @@ forward = function(matrix[double] X, int C, int Hin, int Win, 
int Hf, int Wf,
    *  - Wout: Output width.
    */
   N = nrow(X)
-  Hout = as.integer((Hin - Hf) / strideh + 1)
-  Wout = as.integer((Win - Wf) / stridew + 1)
+  Hout = as.integer((Hin + 2 * padh - Hf) / strideh + 1)
+  Wout = as.integer((Win + 2 * padw - Wf) / stridew + 1)
 
   # Create output volume
   out = matrix(0, rows=N, cols=C*Hout*Wout)
@@ -62,11 +66,16 @@ forward = function(matrix[double] X, int C, int Hin, int 
Win, int Hf, int Wf,
   # Max pooling - im2col implementation
   parfor (n in 1:N) {  # all examples
     img = matrix(X[n,], rows=C, cols=Hin*Win)  # reshape
-    img_maxes = matrix(0, rows=C, cols=Hout*Wout)  # zeros
 
+    if (padh > 0 | padw > 0) {
+      # Pad image
+      img = util::pad_image(img, Hin, Win, padh, padw)  # shape (C, 
(Hin+2*padh)*(Win+2*padw))
+    }
+
+    img_maxes = matrix(0, rows=C, cols=Hout*Wout)  # zeros
     parfor (c in 1:C) {  # all channels
       # Extract local image slice patches into columns with im2col, of shape 
(Hf*Wf, Hout*Wout)
-      img_slice_cols = util::im2col(img[c,], Hin, Win, Hf, Wf, strideh, 
stridew)
+      img_slice_cols = util::im2col(img[c,], Hin+2*padh, Win+2*padw, Hf, Wf, 
strideh, stridew)
 
       # Max pooling on patches
       img_maxes[c,] = colMaxs(img_slice_cols)
@@ -76,8 +85,9 @@ forward = function(matrix[double] X, int C, int Hin, int Win, 
int Hf, int Wf,
   }
 }
 
-backward = function(matrix[double] dout, int Hout, int Wout, matrix[double] X, 
int C,
-                    int Hin, int Win, int Hf, int Wf, int strideh, int stridew)
+backward = function(matrix[double] dout, int Hout, int Wout, matrix[double] X,
+                    int C, int Hin, int Win, int Hf, int Wf,
+                    int strideh, int stridew, int padh, int padw)
     return (matrix[double] dX) {
   /*
    * Computes the backward pass for a 2D spatial max pooling layer.
@@ -96,22 +106,31 @@ backward = function(matrix[double] dout, int Hout, int 
Wout, matrix[double] X, i
    *  - Wf: Filter width.
    *  - strideh: Stride over height.
    *  - stridew: Stride over width.
+   *  - padh: Padding for top and bottom sides.
+   *      A typical value is 0.
+   *  - padw: Padding for left and right sides.
+   *      A typical value is 0.
    *
    * Outputs:
    *  - dX: Gradient wrt X, of shape (N, C*Hin*Win).
    */
   N = nrow(X)
-  
+
   # Create gradient volume
   dX = matrix(0, rows=N, cols=C*Hin*Win)
-  
+
   # Gradient of max pooling
   parfor (n in 1:N, check=0) {  # all examples
     img = matrix(X[n,], rows=C, cols=Hin*Win)
-    dimg = matrix(0, rows=C, cols=Hin*Win)
+    if (padh > 0 | padw > 0) {
+      # Pad image
+      img = util::pad_image(img, Hin, Win, padh, padw)  # shape (C, 
(Hin+2*padh)*(Win+2*padw))
+    }
+
+    dimg = matrix(0, rows=C, cols=(Hin+2*padh)*(Win+2*padw))
     parfor (c in 1:C, check=0) {  # all channels
-      img_slice = matrix(img[c,], rows=Hin, cols=Win)
-      dimg_slice = matrix(0, rows=Hin, cols=Win)
+      img_slice = matrix(img[c,], rows=Hin+2*padh, cols=Win+2*padw)
+      dimg_slice = matrix(0, rows=Hin+2*padh, cols=Win+2*padw)
       for (hout in 1:Hout, check=0) {  # all output rows
         hin = (hout-1) * strideh + 1
         for (wout in 1:Wout) {  # all output columns
@@ -120,11 +139,16 @@ backward = function(matrix[double] dout, int Hout, int 
Wout, matrix[double] X, i
           max_val_ind = img_slice_patch == max(img_slice_patch)  # max value 
indicator matrix
           # gradient passes through only for the max value(s) in this patch
           dimg_slice_patch = max_val_ind * dout[n, (c-1)*Hout*Wout + 
(hout-1)*Wout + wout]
-          dimg_slice[hin:hin+Hf-1, win:win+Wf-1] =
-            dimg_slice[hin:hin+Hf-1, win:win+Wf-1] + dimg_slice_patch
+          dimg_slice[hin:hin+Hf-1, win:win+Wf-1] = dimg_slice[hin:hin+Hf-1, 
win:win+Wf-1]
+                                                   + dimg_slice_patch
         }
       }
-      dimg[c,] = matrix(dimg_slice, rows=1, cols=Hin*Win)
+      dimg[c,] = matrix(dimg_slice, rows=1, cols=(Hin+2*padh)*(Win+2*padw))
+    }
+
+    if (padh > 0 | padw > 0) {
+      # Unpad image gradient
+      dimg = util::unpad_image(dimg, Hin, Win, padh, padw)  # shape (C, Hin*Win)
     }
     dX[n,] = matrix(dimg, rows=1, cols=C*Hin*Win)
   }
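
As a side note, a tiny sketch of the pad/unpad round trip that the padded
forward and backward passes above rely on (util::pad_image and
util::unpad_image are the helpers called in the hunks above; the 2x2 toy
image is illustrative):

    source("nn/util.dml") as util

    C = 1
    Hin = 2
    Win = 2
    padh = 1
    padw = 1
    img = matrix("1 2 3 4", rows=C, cols=Hin*Win)  # a 2x2 image, flattened
    # Pad the borders, giving shape (C, (Hin+2*padh)*(Win+2*padw)),
    # i.e. a 4x4 image per channel.
    img_padded = util::pad_image(img, Hin, Win, padh, padw)
    # Strip the padding again, recovering the original (C, Hin*Win) layout.
    img_back = util::unpad_image(img_padded, Hin, Win, padh, padw)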

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/15ccb7c0/scripts/staging/SystemML-NN/nn/layers/max_pool_builtin.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/max_pool_builtin.dml 
b/scripts/staging/SystemML-NN/nn/layers/max_pool_builtin.dml
index 97e991a..ae2b4a1 100644
--- a/scripts/staging/SystemML-NN/nn/layers/max_pool_builtin.dml
+++ b/scripts/staging/SystemML-NN/nn/layers/max_pool_builtin.dml
@@ -23,7 +23,7 @@
  * Max pooling layer.
  */
 forward = function(matrix[double] X, int C, int Hin, int Win, int Hf, int Wf,
-                   int strideh, int stridew)
+                   int strideh, int stridew, int padh, int padw)
     return (matrix[double] out, int Hout, int Wout) {
   /*
    * Computes the forward pass for a 2D spatial max pooling layer.
@@ -44,6 +44,10 @@ forward = function(matrix[double] X, int C, int Hin, int 
Win, int Hf, int Wf,
    *  - Wf: Filter width.
    *  - strideh: Stride over height.
    *  - stridew: Stride over width.
+   *  - padh: Padding for top and bottom sides.
+   *      A typical value is 0.
+   *  - padw: Padding for left and right sides.
+   *      A typical value is 0.
    *
    * Outputs:
    *  - out: Outputs, of shape (N, C*Hout*Wout).
@@ -55,12 +59,13 @@ forward = function(matrix[double] X, int C, int Hin, int 
Win, int Hf, int Wf,
   Wout = as.integer((Win - Wf) / stridew + 1)
 
   # Max pooling - built-in implementation
-  out = max_pool(X, input_shape=[N,C,Hin,Win], pool_size=[Hf,Wf], 
stride=[strideh,stridew],
-                 padding=[0,0])
+  out = max_pool(X, input_shape=[N,C,Hin,Win], pool_size=[Hf,Wf],
+                 stride=[strideh,stridew], padding=[padh,padw])
 }
 
-backward = function(matrix[double] dout, int Hout, int Wout, matrix[double] X, 
int C,
-                    int Hin, int Win, int Hf, int Wf, int strideh, int stridew)
+backward = function(matrix[double] dout, int Hout, int Wout, matrix[double] X,
+                    int C, int Hin, int Win, int Hf, int Wf,
+                    int strideh, int stridew, int padh, int padw)
     return (matrix[double] dX) {
   /*
    * Computes the backward pass for a 2D spatial max pooling layer.
@@ -79,14 +84,18 @@ backward = function(matrix[double] dout, int Hout, int 
Wout, matrix[double] X, i
    *  - Wf: Filter width.
    *  - strideh: Stride over height.
    *  - stridew: Stride over width.
+   *  - padh: Padding for top and bottom sides.
+   *      A typical value is 0.
+   *  - padw: Padding for left and right sides.
+   *      A typical value is 0.
    *
    * Outputs:
    *  - dX: Gradient wrt X, of shape (N, C*Hin*Win).
    */
   N = nrow(X)
-  
+
   # Gradient of max pooling
   dX = max_pool_backward(X, dout, input_shape=[N,C,Hin,Win], pool_size=[Hf,Wf],
-                         stride=[strideh,stridew], padding=[0,0])
+                         stride=[strideh,stridew], padding=[padh,padw])
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/15ccb7c0/scripts/staging/SystemML-NN/nn/layers/rnn.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/rnn.dml 
b/scripts/staging/SystemML-NN/nn/layers/rnn.dml
index 6c432bd..cd3eefe 100644
--- a/scripts/staging/SystemML-NN/nn/layers/rnn.dml
+++ b/scripts/staging/SystemML-NN/nn/layers/rnn.dml
@@ -82,7 +82,7 @@ forward = function(matrix[double] X, matrix[double] W, 
matrix[double] b, int T,
 
 backward = function(matrix[double] dout, matrix[double] X, matrix[double] W, 
matrix[double] b,
                     int T, int D, boolean given_sequences, matrix[double] out0,
-                    matrix[double] cache_out) 
+                    matrix[double] cache_out)
     return (matrix[double] dX, matrix[double] dW, matrix[double] db, 
matrix[double] dout0) {
   /*
    * Computes the backward pass for a simple RNN layer with M neurons.
@@ -157,7 +157,7 @@ init = function(int N, int D, int M)
    *
    * Note: This is just a convenience function, and parameters
    * may be initialized manually if needed.
-   * 
+   *
    * We use the Glorot uniform heuristic which limits the magnification
    * of inputs/gradients during forward/backward passes by scaling
    * uniform weights by a factor of sqrt(6/(fan_in + fan_out)).
@@ -176,7 +176,7 @@ init = function(int N, int D, int M)
   fan_out = M
   scale = sqrt(6/(fan_in+fan_out))
   W = rand(rows=D+M, cols=M, min=-scale, max=scale, pdf="uniform")
-  b = matrix(0, rows=1, cols=M) 
+  b = matrix(0, rows=1, cols=M)
   out0 = matrix(0, rows=N, cols=M)
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/15ccb7c0/scripts/staging/SystemML-NN/nn/layers/softmax.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/softmax.dml 
b/scripts/staging/SystemML-NN/nn/layers/softmax.dml
index 64f257b..854e8a8 100644
--- a/scripts/staging/SystemML-NN/nn/layers/softmax.dml
+++ b/scripts/staging/SystemML-NN/nn/layers/softmax.dml
@@ -61,7 +61,7 @@ backward = function(matrix[double] dprobs, matrix[double] 
scores)
    * dprobs_ij/dscores_ij = probs_ij * (1 - probs_ij)
    * dprobs_ik/dscores_ij = -probs_ik * probs_ij, for all k != j
    *
-   * dloss/dscores_ij = dloss/dprobs_ij * dprobs_ij/dscores_ij + 
+   * dloss/dscores_ij = dloss/dprobs_ij * dprobs_ij/dscores_ij +
    *                    sum_{k!=j}(dloss/dprobs_ik * dprobs_ik/dscores_ij)
    *
    * Inputs:

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/15ccb7c0/scripts/staging/SystemML-NN/nn/test/conv_simple.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/test/conv_simple.dml 
b/scripts/staging/SystemML-NN/nn/test/conv_simple.dml
index f065668..fb9d02c 100644
--- a/scripts/staging/SystemML-NN/nn/test/conv_simple.dml
+++ b/scripts/staging/SystemML-NN/nn/test/conv_simple.dml
@@ -58,7 +58,7 @@ forward = function(matrix[double] X, matrix[double] W, 
matrix[double] b,
   F = nrow(W)
   Hout = as.integer((Hin + 2 * padh - Hf) / strideh + 1)
   Wout = as.integer((Win + 2 * padw - Wf) / stridew + 1)
-  
+
   # Create output volume
   out = matrix(0, rows=N, cols=F*Hout*Wout)
 
@@ -83,11 +83,11 @@ forward = function(matrix[double] X, matrix[double] W, 
matrix[double] b,
           Xn_padded_patch = matrix(0, rows=C, cols=Hf*Wf)  # zeros
           parfor (c in 1:C, check=0) {
             Xn_padded_slice = matrix(Xn_padded[c,], rows=Hin+2*padh, 
cols=Win+2*padw)  # reshape
-            Xn_padded_patch[c,] = 
-              matrix(Xn_padded_slice[h0:h0-1+Hf, w0:w0-1+Wf], rows=1, 
cols=Hf*Wf)  # reshape
+            Xn_padded_patch[c,] = matrix(Xn_padded_slice[h0:h0-1+Hf, 
w0:w0-1+Wf], rows=1,
+                                         cols=Hf*Wf)  # reshape
           }
-          out[n, (f-1)*Hout*Wout + (hout-1)*Wout + wout] = 
-            W[f,] %*% matrix(Xn_padded_patch, rows=C*Hf*Wf, cols=1) + b[f,]
+          out[n, (f-1)*Hout*Wout + (hout-1)*Wout + wout] =
+              W[f,] %*% matrix(Xn_padded_patch, rows=C*Hf*Wf, cols=1) + b[f,]
         }
       }
     }
@@ -131,7 +131,7 @@ backward = function(matrix[double] dout, int Hout, int Wout,
   F = nrow(W)
   Hout = as.integer((Hin + 2 * padh - Hf) / strideh + 1)
   Wout = as.integer((Win + 2 * padw - Wf) / stridew + 1)
-  
+
   # Create gradient volumes
   dX = matrix(0, rows=N, cols=C*Hin*Win)
   dW = matrix(0, rows=F, cols=C*Hf*Wf)
@@ -160,16 +160,17 @@ backward = function(matrix[double] dout, int Hout, int 
Wout,
                                     rows=C, cols=Hf*Wf)  # reshape
           for (c in 1:C) {
             Xn_padded_slice = matrix(Xn_padded[c,], rows=Hin+2*padh, 
cols=Win+2*padw)  # reshape
-            Xn_padded_patch[c,] = 
-              matrix(Xn_padded_slice[h0:h0-1+Hf, w0:w0-1+Wf], rows=1, 
cols=Hf*Wf)  # reshape
+            Xn_padded_patch[c,] = matrix(Xn_padded_slice[h0:h0-1+Hf, 
w0:w0-1+Wf],
+                                         rows=1, cols=Hf*Wf)  # reshape
             dXn_padded_slice = matrix(0, rows=Hin+2*padh, cols=Win+2*padw)
-            dXn_padded_slice[h0:h0-1+Hf, w0:w0-1+Wf] =
-              matrix(dXn_padded_patch[c,], rows=Hf, cols=Wf)  # reshape
-            dXn_padded[c,] = dXn_padded[c,] +
-              matrix(dXn_padded_slice, rows=1, cols=(Hin+2*padh)*(Win+2*padw))
+            dXn_padded_slice[h0:h0-1+Hf, w0:w0-1+Wf] = 
matrix(dXn_padded_patch[c,],
+                                                              rows=Hf, 
cols=Wf)  # reshape
+            dXn_padded[c,] = dXn_padded[c,] + matrix(dXn_padded_slice,
+                                                     rows=1, 
cols=(Hin+2*padh)*(Win+2*padw))
           }
-          dW[f,] = dW[f,] + matrix(Xn_padded_patch, rows=1, cols=C*Hf*Wf) *
-            dout[n, (f-1)*Hout*Wout + (hout-1)*Wout + wout]
+          dW[f,] = dW[f,]
+                   + matrix(Xn_padded_patch, rows=1, cols=C*Hf*Wf)
+                   * dout[n, (f-1)*Hout*Wout + (hout-1)*Wout + wout]
           db[f,] = db[f,] + dout[n, (f-1)*Hout*Wout + (hout-1)*Wout + wout]
         }
       }
@@ -179,7 +180,7 @@ backward = function(matrix[double] dout, int Hout, int Wout,
     parfor (c in 1:C, check=0) {
       dXn_padded_slice = matrix(dXn_padded[c,], rows=(Hin+2*padh), 
cols=(Win+2*padw))
       dXn_slice = dXn_padded_slice[padh+1:padh+Hin, padw+1:padw+Win]
-      dXn[c, ] = matrix(dXn_slice, rows=1, cols=Hin*Win)
+      dXn[c,] = matrix(dXn_slice, rows=1, cols=Hin*Win)
     }
     dX[n,] = matrix(dXn, rows=1, cols=C*Hin*Win)
   }
@@ -189,7 +190,7 @@ init = function(int F, int C, int Hf, int Wf)
     return (matrix[double] W, matrix[double] b) {
   /*
    * Initialize the parameters of this layer.
-   * 
+   *
    * We use the heuristic by He et al. [http://arxiv.org/abs/1502.01852],
    * which limits the magnification of inputs/gradients during
    * forward/backward passes by scaling unit-Gaussian weights by a
@@ -206,6 +207,6 @@ init = function(int F, int C, int Hf, int Wf)
    *  - b: Biases vector, of shape (F, 1).
    */
   W = rand(rows=F, cols=C*Hf*Wf, pdf="normal") * sqrt(2.0/(C*Hf*Wf))
-  b = matrix(0, rows=F, cols=1) 
+  b = matrix(0, rows=F, cols=1)
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/15ccb7c0/scripts/staging/SystemML-NN/nn/test/grad_check.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/test/grad_check.dml 
b/scripts/staging/SystemML-NN/nn/test/grad_check.dml
index 5500c4f..db923df 100644
--- a/scripts/staging/SystemML-NN/nn/test/grad_check.dml
+++ b/scripts/staging/SystemML-NN/nn/test/grad_check.dml
@@ -50,7 +50,7 @@ check_rel_error = function(double dw_a, double dw_n, double 
lossph, double lossm
    * Check and report any issues with the relative error measure between
    * the analytical and numerical partial derivatives.
    *
-   *  - Issues an "ERROR" statement for relative errors > 1e-2, 
+   *  - Issues an "ERROR" statement for relative errors > 1e-2,
    *  indicating that the gradient is likely incorrect.
    *  - Issues a "WARNING" statement for relative errors < 1e-2
    *  but > 1e-4, indicating that the gradient may be incorrect.
@@ -66,7 +66,7 @@ check_rel_error = function(double dw_a, double dw_n, double 
lossph, double lossm
    */
   # Compute relative error
   rel_error = util::compute_rel_error(dw_a, dw_n)
-  
+
   # Evaluate relative error
   thresh_error = 1e-2
   thresh_warn = 1e-4
@@ -186,8 +186,8 @@ conv = function() {
   # Compute analytical gradients of loss wrt parameters
   [out, Hout, Wout] = conv::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, 
stride, pad, pad)
   dout = l2_loss::backward(out, y)
-  [dX, dW, db] =
-    conv::backward(dout, Hout, Wout, X, W, b, C, Hin, Win, Hf, Wf, stride, 
stride, pad, pad)
+  [dX, dW, db] = conv::backward(dout, Hout, Wout, X, W, b, C, Hin, Win, Hf, 
Wf, stride, stride,
+                                pad, pad)
 
   # Grad check
   h = 1e-5
@@ -274,8 +274,8 @@ conv_builtin = function() {
   # Compute analytical gradients of loss wrt parameters
   [out, Hout, Wout] = conv_builtin::forward(X, W, b, C, Hin, Win, Hf, Wf, 
stride, stride, pad, pad)
   dout = l2_loss::backward(out, y)
-  [dX, dW, db] =
-    conv_builtin::backward(dout, Hout, Wout, X, W, b, C, Hin, Win, Hf, Wf, 
stride, stride, pad, pad)
+  [dX, dW, db] = conv_builtin::backward(dout, Hout, Wout, X, W, b, C, Hin, 
Win, Hf, Wf,
+                                        stride, stride, pad, pad)
 
   # Grad check
   h = 1e-5
@@ -285,10 +285,12 @@ conv_builtin = function() {
       # Compute numerical derivative
       old = as.scalar(X[i,j])
       X[i,j] = old - h
-      [outmh, Hout, Wout] = conv_builtin::forward(X, W, b, C, Hin, Win, Hf, 
Wf, stride, stride, pad, pad)
+      [outmh, Hout, Wout] = conv_builtin::forward(X, W, b, C, Hin, Win, Hf, 
Wf, stride, stride,
+                                                  pad, pad)
       lossmh = l2_loss::forward(outmh, y)
       X[i,j] = old + h
-      [outph, Hout, Wout] = conv_builtin::forward(X, W, b, C, Hin, Win, Hf, 
Wf, stride, stride, pad, pad)
+      [outph, Hout, Wout] = conv_builtin::forward(X, W, b, C, Hin, Win, Hf, 
Wf, stride, stride,
+                                                  pad, pad)
       lossph = l2_loss::forward(outph, y)
       X[i,j] = old  # reset
       dX_num = (lossph - lossmh) / (2 * h) # numerical derivative
@@ -304,10 +306,12 @@ conv_builtin = function() {
       # Compute numerical derivative
       old = as.scalar(W[i,j])
       W[i,j] = old - h
-      [outmh, Hout, Wout] = conv_builtin::forward(X, W, b, C, Hin, Win, Hf, 
Wf, stride, stride, pad, pad)
+      [outmh, Hout, Wout] = conv_builtin::forward(X, W, b, C, Hin, Win, Hf, 
Wf, stride, stride,
+                                                  pad, pad)
       lossmh = l2_loss::forward(outmh, y)
       W[i,j] = old + h
-      [outph, Hout, Wout] = conv_builtin::forward(X, W, b, C, Hin, Win, Hf, 
Wf, stride, stride, pad, pad)
+      [outph, Hout, Wout] = conv_builtin::forward(X, W, b, C, Hin, Win, Hf, 
Wf, stride, stride,
+                                                  pad, pad)
       lossph = l2_loss::forward(outph, y)
       W[i,j] = old  # reset
       dW_num = (lossph - lossmh) / (2 * h) # numerical derivative
@@ -323,10 +327,12 @@ conv_builtin = function() {
       # Compute numerical derivative
       old = as.scalar(b[i,j])
       b[i,j] = old - h
-      [outmh, Hout, Wout] = conv_builtin::forward(X, W, b, C, Hin, Win, Hf, 
Wf, stride, stride, pad, pad)
+      [outmh, Hout, Wout] = conv_builtin::forward(X, W, b, C, Hin, Win, Hf, 
Wf, stride, stride,
+                                                  pad, pad)
       lossmh = l2_loss::forward(outmh, y)
       b[i,j] = old + h
-      [outph, Hout, Wout] = conv_builtin::forward(X, W, b, C, Hin, Win, Hf, 
Wf, stride, stride, pad, pad)
+      [outph, Hout, Wout] = conv_builtin::forward(X, W, b, C, Hin, Win, Hf, 
Wf, stride, stride,
+                                                  pad, pad)
       lossph = l2_loss::forward(outph, y)
       b[i,j] = old  # reset
       db_num = (lossph - lossmh) / (2 * h) # numerical derivative
@@ -362,8 +368,8 @@ conv_simple = function() {
   # Compute analytical gradients of loss wrt parameters
   [out, Hout, Wout] = conv_simple::forward(X, W, b, C, Hin, Win, Hf, Wf, 
stride, stride, pad, pad)
   dout = l2_loss::backward(out, y)
-  [dX, dW, db] =
-    conv_simple::backward(dout, Hout, Wout, X, W, b, C, Hin, Win, Hf, Wf, 
stride, stride, pad, pad)
+  [dX, dW, db] = conv_simple::backward(dout, Hout, Wout, X, W, b, C, Hin, Win, 
Hf, Wf,
+                                       stride, stride, pad, pad)
 
   # Grad check
   h = 1e-5
@@ -373,10 +379,12 @@ conv_simple = function() {
       # Compute numerical derivative
       old = as.scalar(X[i,j])
       X[i,j] = old - h
-      [outmh, Hout, Wout] = conv_simple::forward(X, W, b, C, Hin, Win, Hf, Wf, 
stride, stride, pad, pad)
+      [outmh, Hout, Wout] = conv_simple::forward(X, W, b, C, Hin, Win, Hf, Wf, 
stride, stride,
+                                                 pad, pad)
       lossmh = l2_loss::forward(outmh, y)
       X[i,j] = old + h
-      [outph, Hout, Wout] = conv_simple::forward(X, W, b, C, Hin, Win, Hf, Wf, 
stride, stride, pad, pad)
+      [outph, Hout, Wout] = conv_simple::forward(X, W, b, C, Hin, Win, Hf, Wf, 
stride, stride,
+                                                 pad, pad)
       lossph = l2_loss::forward(outph, y)
       X[i,j] = old  # reset
       dX_num = (lossph - lossmh) / (2 * h) # numerical derivative
@@ -392,10 +400,12 @@ conv_simple = function() {
       # Compute numerical derivative
       old = as.scalar(W[i,j])
       W[i,j] = old - h
-      [outmh, Hout, Wout] = conv_simple::forward(X, W, b, C, Hin, Win, Hf, Wf, 
stride, stride, pad, pad)
+      [outmh, Hout, Wout] = conv_simple::forward(X, W, b, C, Hin, Win, Hf, Wf, 
stride, stride,
+                                                 pad, pad)
       lossmh = l2_loss::forward(outmh, y)
       W[i,j] = old + h
-      [outph, Hout, Wout] = conv_simple::forward(X, W, b, C, Hin, Win, Hf, Wf, 
stride, stride, pad, pad)
+      [outph, Hout, Wout] = conv_simple::forward(X, W, b, C, Hin, Win, Hf, Wf, 
stride, stride,
+                                                 pad, pad)
       lossph = l2_loss::forward(outph, y)
       W[i,j] = old  # reset
       dW_num = (lossph - lossmh) / (2 * h) # numerical derivative
@@ -411,10 +421,12 @@ conv_simple = function() {
       # Compute numerical derivative
       old = as.scalar(b[i,j])
       b[i,j] = old - h
-      [outmh, Hout, Wout] = conv_simple::forward(X, W, b, C, Hin, Win, Hf, Wf, 
stride, stride, pad, pad)
+      [outmh, Hout, Wout] = conv_simple::forward(X, W, b, C, Hin, Win, Hf, Wf, 
stride, stride,
+                                                 pad, pad)
       lossmh = l2_loss::forward(outmh, y)
       b[i,j] = old + h
-      [outph, Hout, Wout] = conv_simple::forward(X, W, b, C, Hin, Win, Hf, Wf, 
stride, stride, pad, pad)
+      [outph, Hout, Wout] = conv_simple::forward(X, W, b, C, Hin, Win, Hf, Wf, 
stride, stride,
+                                                 pad, pad)
       lossph = l2_loss::forward(outph, y)
       b[i,j] = old  # reset
       db_num = (lossph - lossmh) / (2 * h) # numerical derivative
@@ -830,30 +842,36 @@ max_pool = function() {
   Wf = 2  # pool filter width
   stride = 2
   X = rand(rows=N, cols=C*Hin*Win)
-  y = rand(rows=N, cols=C*2*2)
-
-  # Compute analytical gradients of loss wrt parameters
-  [out, Hout, Wout] = max_pool::forward(X, C, Hin, Win, Hf, Wf, stride, stride)
-  dout = l2_loss::backward(out, y)
-  dX = max_pool::backward(dout, Hout, Wout, X, C, Hin, Win, Hf, Wf, stride, 
stride)
-
-  # Grad check
-  h = 1e-5
-  for (i in 1:nrow(X)) {
-    for (j in 1:ncol(X)) {
-      # Compute numerical derivative
-      old = as.scalar(X[i,j])
-      X[i,j] = old - h
-      [outmh, Hout, Wout] = max_pool::forward(X, C, Hin, Win, Hf, Wf, stride, 
stride)
-      lossmh = l2_loss::forward(outmh, y)
-      X[i,j] = old + h
-      [outph, Hout, Wout] = max_pool::forward(X, C, Hin, Win, Hf, Wf, stride, 
stride)
-      lossph = l2_loss::forward(outph, y)
-      X[i,j] = old  # reset
-      dX_num = (lossph - lossmh) / (2 * h) # numerical derivative
 
-      # Check error
-      rel_error = check_rel_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+  for (pad in 0:1) {
+    print(" - Grad checking w/ pad="+pad+".")
+    Hout = as.integer((Hin + 2 * pad - Hf) / stride + 1)
+    Wout = as.integer((Win + 2 * pad - Wf) / stride + 1)
+    y = rand(rows=N, cols=C*Hout*Wout)
+
+    # Compute analytical gradients of loss wrt parameters
+    [out, Hout, Wout] = max_pool::forward(X, C, Hin, Win, Hf, Wf, stride, 
stride, pad, pad)
+    dout = l2_loss::backward(out, y)
+    dX = max_pool::backward(dout, Hout, Wout, X, C, Hin, Win, Hf, Wf, stride, 
stride, pad, pad)
+
+    # Grad check
+    h = 1e-5
+    for (i in 1:nrow(X)) {
+      for (j in 1:ncol(X)) {
+        # Compute numerical derivative
+        old = as.scalar(X[i,j])
+        X[i,j] = old - h
+        [outmh, Hout, Wout] = max_pool::forward(X, C, Hin, Win, Hf, Wf, 
stride, stride, pad, pad)
+        lossmh = l2_loss::forward(outmh, y)
+        X[i,j] = old + h
+        [outph, Hout, Wout] = max_pool::forward(X, C, Hin, Win, Hf, Wf, 
stride, stride, pad, pad)
+        lossph = l2_loss::forward(outph, y)
+        X[i,j] = old  # reset
+        dX_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+        # Check error
+        rel_error = check_rel_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+      }
     }
   }
 }
@@ -873,30 +891,39 @@ max_pool_builtin = function() {
   Wf = 2  # pool filter width
   stride = 2
   X = rand(rows=N, cols=C*Hin*Win)
-  y = rand(rows=N, cols=C*2*2)
-
-  # Compute analytical gradients of loss wrt parameters
-  [out, Hout, Wout] = max_pool_builtin::forward(X, C, Hin, Win, Hf, Wf, 
stride, stride)
-  dout = l2_loss::backward(out, y)
-  dX = max_pool_builtin::backward(dout, Hout, Wout, X, C, Hin, Win, Hf, Wf, 
stride, stride)
-
-  # Grad check
-  h = 1e-5
-  for (i in 1:nrow(X)) {
-    for (j in 1:ncol(X)) {
-      # Compute numerical derivative
-      old = as.scalar(X[i,j])
-      X[i,j] = old - h
-      [outmh, Hout, Wout] = max_pool_builtin::forward(X, C, Hin, Win, Hf, Wf, 
stride, stride)
-      lossmh = l2_loss::forward(outmh, y)
-      X[i,j] = old + h
-      [outph, Hout, Wout] = max_pool_builtin::forward(X, C, Hin, Win, Hf, Wf, 
stride, stride)
-      lossph = l2_loss::forward(outph, y)
-      X[i,j] = old  # reset
-      dX_num = (lossph - lossmh) / (2 * h) # numerical derivative
 
-      # Check error
-      rel_error = check_rel_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+  for (pad in 0:1) {
+    print(" - Grad checking w/ pad="+pad+".")
+    Hout = as.integer((Hin + 2 * pad - Hf) / stride + 1)
+    Wout = as.integer((Win + 2 * pad - Wf) / stride + 1)
+    y = rand(rows=N, cols=C*Hout*Wout)
+
+    # Compute analytical gradients of loss wrt parameters
+    [out, Hout, Wout] = max_pool_builtin::forward(X, C, Hin, Win, Hf, Wf, 
stride, stride, pad, pad)
+    dout = l2_loss::backward(out, y)
+    dX = max_pool_builtin::backward(dout, Hout, Wout, X, C, Hin, Win, Hf, Wf, 
stride, stride,
+                                    pad, pad)
+
+    # Grad check
+    h = 1e-5
+    for (i in 1:nrow(X)) {
+      for (j in 1:ncol(X)) {
+        # Compute numerical derivative
+        old = as.scalar(X[i,j])
+        X[i,j] = old - h
+        [outmh, Hout, Wout] = max_pool_builtin::forward(X, C, Hin, Win, Hf, 
Wf, stride, stride,
+                                                        pad, pad)
+        lossmh = l2_loss::forward(outmh, y)
+        X[i,j] = old + h
+        [outph, Hout, Wout] = max_pool_builtin::forward(X, C, Hin, Win, Hf, 
Wf, stride, stride,
+                                                        pad, pad)
+        lossph = l2_loss::forward(outph, y)
+        X[i,j] = old  # reset
+        dX_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+        # Check error
+        rel_error = check_rel_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+      }
     }
   }
 }
@@ -916,30 +943,39 @@ max_pool_simple = function() {
   Wf = 2  # pool filter width
   stride = 2
   X = rand(rows=N, cols=C*Hin*Win)
-  y = rand(rows=N, cols=C*2*2)
-
-  # Compute analytical gradients of loss wrt parameters
-  [out, Hout, Wout] = max_pool_simple::forward(X, C, Hin, Win, Hf, Wf, stride, 
stride)
-  dout = l2_loss::backward(out, y)
-  dX = max_pool_simple::backward(dout, Hout, Wout, X, C, Hin, Win, Hf, Wf, 
stride, stride)
-
-  # Grad check
-  h = 1e-5
-  for (i in 1:nrow(X)) {
-    for (j in 1:ncol(X)) {
-      # Compute numerical derivative
-      old = as.scalar(X[i,j])
-      X[i,j] = old - h
-      [outmh, Hout, Wout] = max_pool_simple::forward(X, C, Hin, Win, Hf, Wf, 
stride, stride)
-      lossmh = l2_loss::forward(outmh, y)
-      X[i,j] = old + h
-      [outph, Hout, Wout] = max_pool_simple::forward(X, C, Hin, Win, Hf, Wf, 
stride, stride)
-      lossph = l2_loss::forward(outph, y)
-      X[i,j] = old  # reset
-      dX_num = (lossph - lossmh) / (2 * h) # numerical derivative
 
-      # Check error
-      rel_error = check_rel_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+  for (pad in 0:1) {
+    print(" - Grad checking w/ pad="+pad+".")
+    Hout = as.integer((Hin + 2 * pad - Hf) / stride + 1)
+    Wout = as.integer((Win + 2 * pad - Wf) / stride + 1)
+    y = rand(rows=N, cols=C*Hout*Wout)
+
+    # Compute analytical gradients of loss wrt parameters
+    [out, Hout, Wout] = max_pool_simple::forward(X, C, Hin, Win, Hf, Wf, 
stride, stride, pad, pad)
+    dout = l2_loss::backward(out, y)
+    dX = max_pool_simple::backward(dout, Hout, Wout, X, C, Hin, Win, Hf, Wf, 
stride, stride,
+                                   pad, pad)
+
+    # Grad check
+    h = 1e-5
+    for (i in 1:nrow(X)) {
+      for (j in 1:ncol(X)) {
+        # Compute numerical derivative
+        old = as.scalar(X[i,j])
+        X[i,j] = old - h
+        [outmh, Hout, Wout] = max_pool_simple::forward(X, C, Hin, Win, Hf, Wf, 
stride, stride,
+                                                       pad, pad)
+        lossmh = l2_loss::forward(outmh, y)
+        X[i,j] = old + h
+        [outph, Hout, Wout] = max_pool_simple::forward(X, C, Hin, Win, Hf, Wf, 
stride, stride,
+                                                       pad, pad)
+        lossph = l2_loss::forward(outph, y)
+        X[i,j] = old  # reset
+        dX_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+        # Check error
+        rel_error = check_rel_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+      }
     }
   }
 }
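
The Hout/Wout assignments above follow the usual pooling output-size formula,
Hout = (Hin + 2*pad - Hf) / stride + 1 (and likewise for Wout). For example,
with the 4x4 inputs, 2x2 filters, and stride 2 used in the known-answer tests
further down, pad=0 gives (4 + 0 - 2)/2 + 1 = 2 (a 2x2 output per channel),
while pad=1 gives (4 + 2 - 2)/2 + 1 = 3 (a 3x3 output per channel). That change
in output size is why y is now regenerated as rand(rows=N, cols=C*Hout*Wout)
inside the pad loop instead of being fixed at C*2*2.
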
@@ -1252,7 +1288,7 @@ two_layer_affine_l2_net = function() {
 
   # Compute analytical gradients
   [pred, loss, dX, dW1, db1, dW2, db2] = two_layer_affine_l2_net_run(X, y, W1, 
b1, W2, b2)
-  
+
   # Grad check
   h = 1e-5
   print(" - Grad checking X.")
@@ -1356,7 +1392,7 @@ two_layer_affine_l2_net_run = function(matrix[double] X, 
matrix[double] y,
 
   # Compute backward pass
   [dX, dpred, daout, dhout, dW1, db1, dW2, db2] =
-    two_layer_affine_l2_net_backward(X, y, pred, aout, hout, W1, b1, W2, b2)
+      two_layer_affine_l2_net_backward(X, y, pred, aout, hout, W1, b1, W2, b2)
 }
 
 two_layer_affine_l2_net_forward = function(matrix[double] X, matrix[double] y,

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/15ccb7c0/scripts/staging/SystemML-NN/nn/test/max_pool_simple.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/test/max_pool_simple.dml 
b/scripts/staging/SystemML-NN/nn/test/max_pool_simple.dml
index 08938e5..12db116 100644
--- a/scripts/staging/SystemML-NN/nn/test/max_pool_simple.dml
+++ b/scripts/staging/SystemML-NN/nn/test/max_pool_simple.dml
@@ -25,7 +25,7 @@
  * This implementation is intended to be a simple, reference version.
  */
 forward = function(matrix[double] X, int C, int Hin, int Win, int Hf, int Wf,
-                   int strideh, int stridew)
+                   int strideh, int stridew, int padh, int padw)
     return (matrix[double] out, int Hout, int Wout) {
   /*
    * Computes the forward pass for a 2D spatial max pooling layer.
@@ -43,6 +43,10 @@ forward = function(matrix[double] X, int C, int Hin, int 
Win, int Hf, int Wf,
    *  - Wf: Filter width.
    *  - strideh: Stride over height.
    *  - stridew: Stride over width.
+   *  - padh: Padding for top and bottom sides.
+   *      A typical value is 0.
+   *  - padw: Padding for left and right sides.
+   *      A typical value is 0.
    *
    * Outputs:
    *  - out: Outputs, of shape (N, C*Hout*Wout).
@@ -50,31 +54,43 @@ forward = function(matrix[double] X, int C, int Hin, int 
Win, int Hf, int Wf,
    *  - Wout: Output width.
    */
   N = nrow(X)
-  Hout = as.integer((Hin - Hf) / strideh + 1)
-  Wout = as.integer((Win - Wf) / stridew + 1)
+  Hout = as.integer((Hin + 2 * padh - Hf) / strideh + 1)
+  Wout = as.integer((Win + 2 * padw - Wf) / stridew + 1)
 
   # Create output volume
-  out = matrix(0, rows=N, cols=C*Hout*Wout)  
+  out = matrix(0, rows=N, cols=C*Hout*Wout)
 
   # Max pooling
   parfor (n in 1:N, check=0) {  # all examples
-    img = matrix(X[n,], rows=C, cols=Hin*Win)
+    Xn = matrix(X[n,], rows=C, cols=Hin*Win)
+
+    # Pad image
+    Xn_padded = matrix(0, rows=C, cols=(Hin+2*padh)*(Win+2*padw))  # zeros
+    parfor (c in 1:C) {
+      Xn_slice = matrix(Xn[c,], rows=Hin, cols=Win)  # depth slice C reshaped
+      Xn_padded_slice = matrix(Xn_padded[c,], rows=Hin+2*padh, cols=Win+2*padw)
+      Xn_padded_slice[padh+1:padh+Hin, padw+1:padw+Win] = Xn_slice
+      Xn_padded[c, ] = matrix(Xn_padded_slice, rows=1, 
cols=(Hin+2*padh)*(Win+2*padw))  # reshape
+    }
+    img = Xn_padded  # shape (C, (Hin+2*padh)*(Win+2*padw))
+
     parfor (c in 1:C, check=0) {  # all channels
-      img_slice = matrix(img[c,], rows=Hin, cols=Win)
+      img_slice = matrix(img[c,], rows=Hin+2*padh, cols=Win+2*padw)
       parfor (hout in 1:Hout, check=0) {  # all output rows
         hin = (hout-1) * strideh + 1
         parfor (wout in 1:Wout, check=0) {  # all output columns
           win = (wout-1) * stridew + 1
-          out[n, (c-1)*Hout*Wout + (hout-1)*Wout + wout] =
-            max(img_slice[hin:hin+Hf-1, win:win+Wf-1])
+          out[n, (c-1)*Hout*Wout + (hout-1)*Wout + wout] = 
max(img_slice[hin:hin+Hf-1,
+                                                               win:win+Wf-1])
         }
       }
     }
   }
 }
 
-backward = function(matrix[double] dout, int Hout, int Wout, matrix[double] X, 
int C, 
-                    int Hin, int Win, int Hf, int Wf, int strideh, int stridew)
+backward = function(matrix[double] dout, int Hout, int Wout, matrix[double] X,
+                    int C, int Hin, int Win, int Hf, int Wf,
+                    int strideh, int stridew, int padh, int padw)
     return (matrix[double] dX) {
   /*
    * Computes the backward pass for a 2D spatial max pooling layer.
@@ -93,22 +109,37 @@ backward = function(matrix[double] dout, int Hout, int 
Wout, matrix[double] X, i
    *  - Wf: Filter width.
    *  - strideh: Stride over height.
    *  - stridew: Stride over width.
+   *  - padh: Padding for top and bottom sides.
+   *      A typical value is 0.
+   *  - padw: Padding for left and right sides.
+   *      A typical value is 0.
    *
    * Outputs:
    *  - dX: Gradient wrt X, of shape (N, C*Hin*Win).
    */
   N = nrow(X)
-  
+
   # Create gradient volume
   dX = matrix(0, rows=N, cols=C*Hin*Win)
-  
+
   # Gradient of max pooling
-  parfor (n in 1:N, check=0) {  # all examples
-    img = matrix(X[n,], rows=C, cols=Hin*Win)
-    dimg = matrix(0, rows=C, cols=Hin*Win)
-    parfor (c in 1:C, check=0) {  # all channels
-      img_slice = matrix(img[c,], rows=Hin, cols=Win)
-      dimg_slice = matrix(0, rows=Hin, cols=Win)
+  for (n in 1:N) {  # all examples
+    Xn = matrix(X[n,], rows=C, cols=Hin*Win)
+
+    # Pad image
+    Xn_padded = matrix(0, rows=C, cols=(Hin+2*padh)*(Win+2*padw))  # zeros
+    parfor (c in 1:C) {
+      Xn_slice = matrix(Xn[c,], rows=Hin, cols=Win)  # depth slice C reshaped
+      Xn_padded_slice = matrix(Xn_padded[c,], rows=Hin+2*padh, cols=Win+2*padw)
+      Xn_padded_slice[padh+1:padh+Hin, padw+1:padw+Win] = Xn_slice
+      Xn_padded[c, ] = matrix(Xn_padded_slice, rows=1, 
cols=(Hin+2*padh)*(Win+2*padw))  # reshape
+    }
+    img = Xn_padded
+
+    dimg = matrix(0, rows=C, cols=(Hin+2*padh)*(Win+2*padw))
+    for (c in 1:C) {  # all channels
+      img_slice = matrix(img[c,], rows=Hin+2*padh, cols=Win+2*padw)
+      dimg_slice = matrix(0, rows=Hin+2*padh, cols=Win+2*padw)
       for (hout in 1:Hout, check=0) {  # all output rows
         hin = (hout-1) * strideh + 1
         for (wout in 1:Wout) {  # all output columns
@@ -117,13 +148,21 @@ backward = function(matrix[double] dout, int Hout, int 
Wout, matrix[double] X, i
           max_val_ind = img_slice_patch == max(img_slice_patch)  # max value 
indicator matrix
           # gradient passes through only for the max value(s) in this patch
           dimg_slice_patch = max_val_ind * dout[n, (c-1)*Hout*Wout + 
(hout-1)*Wout + wout]
-          dimg_slice[hin:hin+Hf-1, win:win+Wf-1] =
-            dimg_slice[hin:hin+Hf-1, win:win+Wf-1] + dimg_slice_patch
+          dimg_slice[hin:hin+Hf-1, win:win+Wf-1] = dimg_slice[hin:hin+Hf-1, 
win:win+Wf-1]
+                                                   + dimg_slice_patch
         }
       }
-      dimg[c,] = matrix(dimg_slice, rows=1, cols=Hin*Win)
+      dimg[c,] = matrix(dimg_slice, rows=1, cols=(Hin+2*padh)*(Win+2*padw))
+    }
+
+    # Unpad derivs on input
+    dXn = matrix(0, rows=C, cols=Hin*Win)
+    parfor (c in 1:C, check=0) {
+      dXn_padded_slice = matrix(dimg[c,], rows=(Hin+2*padh), cols=(Win+2*padw))
+      dXn_slice = dXn_padded_slice[padh+1:padh+Hin, padw+1:padw+Win]
+      dXn[c, ] = matrix(dXn_slice, rows=1, cols=Hin*Win)
     }
-    dX[n,] = matrix(dimg, rows=1, cols=C*Hin*Win)
+    dX[n,] = matrix(dXn, rows=1, cols=C*Hin*Win)
   }
 }
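
Two details of the padded backward pass above are easy to miss. First, the
input slice is zero-padded to (Hin+2*padh) x (Win+2*padw), the gradient is
accumulated in that padded shape, and only the interior is copied back, so dX
always has shape (N, C*Hin*Win). Second, max_val_ind is a 0/1 indicator matrix,
so every element that ties for the maximum of a patch receives the full
upstream gradient for that output cell. A standalone sketch of the routing step
for a single 2x2 patch (values are illustrative only, not part of this patch):

  patch = matrix("3 7 7 1", rows=2, cols=2)  # one pooling window
  dout_cell = 0.5                            # upstream gradient for this output cell
  max_val_ind = patch == max(patch)          # 1 where the patch max occurs, else 0
  dpatch = max_val_ind * dout_cell           # both 7s receive the full 0.5
  print(toString(dpatch))
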
 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/15ccb7c0/scripts/staging/SystemML-NN/nn/test/test.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/test/test.dml 
b/scripts/staging/SystemML-NN/nn/test/test.dml
index d0e83f5..5052fa6 100644
--- a/scripts/staging/SystemML-NN/nn/test/test.dml
+++ b/scripts/staging/SystemML-NN/nn/test/test.dml
@@ -55,18 +55,20 @@ conv = function() {
 
   # Forward
   [out, Hout, Wout] = conv::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, 
stride, pad, pad)
-  [out_simple, Hout_simple, Wout_simple] =
-    conv_simple::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, 
pad)
-  [out_builtin, Hout_builtin, Wout_builtin] =
-    conv_builtin::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, 
pad)
+  [out_simple, Hout_simple, Wout_simple] = conv_simple::forward(X, W, b, C, 
Hin, Win, Hf, Wf,
+                                                                stride, 
stride, pad, pad)
+  [out_builtin, Hout_builtin, Wout_builtin] = conv_builtin::forward(X, W, b, 
C, Hin, Win, Hf, Wf,
+                                                                    stride, 
stride, pad, pad)
 
   # Equivalency check
   out = matrix(out, rows=1, cols=N*F*Hout*Wout)
   out_simple = matrix(out_simple, rows=1, cols=N*F*Hout*Wout)
   out_builtin = matrix(out_builtin, rows=1, cols=N*F*Hout*Wout)
   for (i in 1:length(out)) {
-    rel_error = util::check_rel_error(as.scalar(out[1,i]), 
as.scalar(out_simple[1,i]), 1e-10, 1e-12)
-    rel_error = util::check_rel_error(as.scalar(out[1,i]), 
as.scalar(out_builtin[1,i]), 1e-10, 1e-12)
+    rel_error = util::check_rel_error(as.scalar(out[1,i]),
+                                      as.scalar(out_simple[1,i]), 1e-10, 1e-12)
+    rel_error = util::check_rel_error(as.scalar(out[1,i]),
+                                      as.scalar(out_builtin[1,i]), 1e-10, 
1e-12)
   }
 }
 
@@ -86,12 +88,12 @@ cross_entropy_loss = function() {
   pred = matrix(0, rows=N, cols=K)
   y = rand(rows=N, cols=K, min=0, max=1, pdf="uniform")
   y = y / rowSums(y)  # normalized probs
-  
+
   loss = cross_entropy_loss::forward(pred, y)
-  
+
   inf = 1/0
   if (loss == inf) {
-      print("ERROR: The cross-entropy loss function ouptuts infinity for 
all-zero predictions.")
+    print("ERROR: The cross-entropy loss function ouptuts infinity for 
all-zero predictions.")
   }
 }
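
The check above exists because log(0) would make the loss infinite when every
predicted probability is exactly zero. A common way to keep the loss finite is
to clip the predictions away from zero before taking the log; the following is
only a hypothetical sketch of that idea, not necessarily what
cross_entropy_loss::forward actually does:

  N = 4
  K = 3
  y = rand(rows=N, cols=K, min=0, max=1)
  y = y / rowSums(y)                 # normalized probs, as in the test above
  pred = matrix(0, rows=N, cols=K)   # worst case: all-zero predictions
  eps = 1e-12                        # hypothetical clipping constant
  loss = -sum(y * log(max(pred, eps))) / N  # finite instead of infinity
  print("loss: " + loss)
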
 
@@ -144,7 +146,7 @@ padding = function() {
 
   # Pad image
   x_pad = util::pad_image(x, Hin, Win, pad, pad)
-  
+
   # Check for padded rows & columns
   for (c in 1:C) {
     x_pad_slice = matrix(x_pad[c,], rows=Hin+2*pad, cols=Win+2*pad)
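
For reference, util::pad_image takes one channel per row and returns a matrix
of shape (C, (Hin+2*padh)*(Win+2*padw)), which is why each channel of x_pad is
reshaped to rows=Hin+2*pad, cols=Win+2*pad above. A small usage sketch,
assuming zero padding on all four sides (which is what this test then
verifies); the values are illustrative only:

  source("nn/util.dml") as util

  C = 1
  Hin = 2
  Win = 2
  pad = 1
  x = matrix("1 2 3 4", rows=C, cols=Hin*Win)      # one 2x2 channel
  x_pad = util::pad_image(x, Hin, Win, pad, pad)   # shape (C, 4*4)
  print(toString(matrix(x_pad[1,], rows=Hin+2*pad, cols=Win+2*pad)))
  # expected, assuming zeros around the border:
  # 0 0 0 0
  # 0 1 2 0
  # 0 3 4 0
  # 0 0 0 0
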
@@ -184,39 +186,206 @@ max_pool = function() {
   stride = 2
   X = rand(rows=N, cols=C*Hin*Win, pdf="normal")
 
-  # Forward
-  [out, Hout, Wout] = max_pool::forward(X, C, Hin, Win, Hf, Wf, stride, stride)
-  [out_simple, Hout_simple, Wout_simple] =
-    max_pool_simple::forward(X, C, Hin, Win, Hf, Wf, stride, stride)
-  [out_builtin, Hout_builtin, Wout_builtin] =
-    max_pool_builtin::forward(X, C, Hin, Win, Hf, Wf, stride, stride)
+  for (padh in 0:3) {
+    for (padw in 0:3) {
+      print(" - Testing w/ padh="+padh+" & padw="+padw+".")
+      if (1==1) {}  # force correct printing
+      print("   - Testing forward")
+      [out, Hout, Wout] = max_pool::forward(X, C, Hin, Win, Hf, Wf, stride, 
stride, padh, padw)
+      [out_simple, Hout_simple, Wout_simple] = max_pool_simple::forward(X, C, 
Hin, Win, Hf, Wf,
+                                                                        
stride, stride, padh, padw)
+      [out_builtin, Hout_builtin, Wout_builtin] = max_pool_builtin::forward(X, 
C, Hin, Win, Hf, Wf,
+                                                                            
stride, stride,
+                                                                            
padh, padw)
 
-  # Equivalency check
-  out = matrix(out, rows=1, cols=N*C*Hout*Wout)
-  out_simple = matrix(out_simple, rows=1, cols=N*C*Hout*Wout)
-  out_builtin = matrix(out_builtin, rows=1, cols=N*C*Hout*Wout)
-  for (i in 1:length(out)) {
-    rel_error = util::check_rel_error(as.scalar(out[1,i]), 
as.scalar(out_simple[1,i]), 1e-10, 1e-12)
-    rel_error = util::check_rel_error(as.scalar(out[1,i]), 
as.scalar(out_builtin[1,i]), 1e-10, 1e-12)
+      # Equivalency check
+      out = matrix(out, rows=1, cols=N*C*Hout*Wout)
+      out_simple = matrix(out_simple, rows=1, cols=N*C*Hout*Wout)
+      out_builtin = matrix(out_builtin, rows=1, cols=N*C*Hout*Wout)
+      for (i in 1:length(out)) {
+        rel_error = util::check_rel_error(as.scalar(out[1,i]),
+                                          as.scalar(out_simple[1,i]), 1e-10, 
1e-12)
+        rel_error = util::check_rel_error(as.scalar(out[1,i]),
+                                          as.scalar(out_builtin[1,i]), 1e-10, 
1e-12)
+      }
+
+      print("   - Testing backward")
+      dout = rand(rows=N, cols=C*Hout*Wout, pdf="normal")
+      dX = max_pool::backward(dout, Hout, Wout, X, C, Hin, Win, Hf, Wf, 
stride, stride, padh, padw)
+      dX_simple = max_pool_simple::backward(dout, Hout_simple, Wout_simple, X, 
C, Hin, Win, Hf, Wf,
+                                            stride, stride, padh, padw)
+      dX_builtin = max_pool_builtin::backward(dout, Hout_builtin, 
Wout_builtin, X, C, Hin, Win,
+                                              Hf, Wf, stride, stride, padh, 
padw)
+
+      # Equivalency check
+      dX = matrix(dX, rows=1, cols=N*C*Hin*Win)
+      dX_simple = matrix(dX_simple, rows=1, cols=N*C*Hin*Win)
+      dX_builtin = matrix(dX_builtin, rows=1, cols=N*C*Hin*Win)
+      for (i in 1:length(dX)) {
+        rel_error = util::check_rel_error(as.scalar(dX[1,i]),
+                                          as.scalar(dX_simple[1,i]), 1e-10, 
1e-12)
+        rel_error = util::check_rel_error(as.scalar(dX[1,i]),
+                                          as.scalar(dX_builtin[1,i]), 1e-10, 
1e-12)
+      }
+    }
   }
 
   # ---
-  # Check for correct behavior
-  # Generate data
+  print(" - Testing for correct behavior against known answer w/ pad=0.")
+  # generate data
+  # -- channel 1
+  #  1  2  3  4
+  #  5  6  7  8
+  #  9 10 11 12
+  # 13 14 15 16
+  # -- channel 2
+  #  1  5  9 13
+  #  2  6 10 14
+  #  3  7 11 15
+  #  4  8 12 16
   C = 2  # num channels
   Hin = 4  # input height
   Win = 4  # input width
   X = matrix(seq(1,16,1), rows=Hin, cols=Win)
-  X = matrix(rbind(X, t(X)), rows=1, cols=C*Hin*Win)
-  X = rbind(X, X)  # N=2
+  X = matrix(rbind(X, t(X)), rows=1, cols=C*Hin*Win)  # C=2
+  X = rbind(X, X)  # n=2
+  pad = 0
 
-  # Forward
-  [out, Hout, Wout] = max_pool::forward(X, C, Hin, Win, Hf, Wf, stride, stride)
+  # forward
+  [out, Hout, Wout] = max_pool::forward(X, C, Hin, Win, Hf, Wf, stride, 
stride, pad, pad)
+  [out_simple, Hout_simple, Wout_simple] = max_pool_simple::forward(X, C, Hin, 
Win, Hf, Wf,
+                                                                    stride, 
stride, pad, pad)
+  [out_builtin, Hout_builtin, Wout_builtin] = max_pool_builtin::forward(X, C, 
Hin, Win, Hf, Wf,
+                                                                        
stride, stride, pad, pad)
 
-  # Equivalency check
+  # equivalency check
+  # -- channel 1
+  #   6  8
+  #  14 16
+  # -- channel 2
+  #  6  14
+  #  8  16
   target = matrix("6 8 14 16 6 14 8 16", rows=1, cols=C*Hout*Wout)
-  target = rbind(target, target)  # N=2
+  target = rbind(target, target)  # n=2
+  tmp = util::check_all_equal(out, target)
+  tmp = util::check_all_equal(out_simple, target)
+  tmp = util::check_all_equal(out_builtin, target)
+
+  print(" - Testing for correct behavior against known answer w/ pad=1.")
+  # generate data
+  # -- channel 1
+  #  0  0  0  0  0  0
+  #  0  1  2  3  4  0
+  #  0  5  6  7  8  0
+  #  0  9 10 11 12  0
+  #  0 13 14 15 16  0
+  #  0  0  0  0  0  0
+  # -- channel 2
+  #  0  0  0  0  0  0
+  #  0  1  5  9 13  0
+  #  0  2  6 10 14  0
+  #  0  3  7 11 15  0
+  #  0  4  8 12 16  0
+  #  0  0  0  0  0  0
+  pad = 1
+
+  # forward
+  [out, Hout, Wout] = max_pool::forward(X, C, Hin, Win, Hf, Wf, stride, 
stride, pad, pad)
+  [out_simple, Hout_simple, Wout_simple] = max_pool_simple::forward(X, C, Hin, 
Win, Hf, Wf,
+                                                                    stride, 
stride, pad, pad)
+  [out_builtin, Hout_builtin, Wout_builtin] = max_pool_builtin::forward(X, C, 
Hin, Win, Hf, Wf,
+                                                                        
stride, stride, pad, pad)
+
+  # equivalency check
+  # -- channel 1
+  #  1  3  4
+  #  9 11 12
+  # 13 15 16
+  # -- channel 2
+  #  1  9 13
+  #  3 11 15
+  #  4 12 16
+  target = matrix("1 3 4 9 11 12 13 15 16 1 9 13 3 11 15 4 12 16", rows=1, 
cols=C*Hout*Wout)
+  target = rbind(target, target)  # n=2
+  tmp = util::check_all_equal(out, target)
+  tmp = util::check_all_equal(out_simple, target)
+  tmp = util::check_all_equal(out_builtin, target)
+
+  print(" - Testing for correct behavior against known answer w/ all negative 
matrix w/ pad=0.")
+  # generate data
+  # -- channel 1
+  #  -1  -2  -3  -4
+  #  -5  -6  -7  -8
+  #  -9 -10 -11 -12
+  # -13 -14 -15 -16
+  # -- channel 2
+  #  -1  -5  -9 -13
+  #  -2  -6 -10 -14
+  #  -3  -7 -11 -15
+  #  -4  -8 -12 -16
+  X = X * -1
+  pad = 0
+
+  # forward
+  [out, Hout, Wout] = max_pool::forward(X, C, Hin, Win, Hf, Wf, stride, 
stride, pad, pad)
+  [out_simple, Hout_simple, Wout_simple] = max_pool_simple::forward(X, C, Hin, 
Win, Hf, Wf,
+                                                                    stride, 
stride, pad, pad)
+  [out_builtin, Hout_builtin, Wout_builtin] = max_pool_builtin::forward(X, C, 
Hin, Win, Hf, Wf,
+                                                                        
stride, stride, pad, pad)
+
+  # equivalency check
+  # -- channel 1
+  #  -1  -3
+  #  -9 -11
+  # -- channel 2
+  #  -1  -9
+  #  -3 -11
+  target = matrix("-1 -3 -9 -11 -1 -9 -3 -11", rows=1, cols=C*Hout*Wout)
+  target = rbind(target, target)  # n=2
+  tmp = util::check_all_equal(out, target)
+  tmp = util::check_all_equal(out_simple, target)
+  tmp = util::check_all_equal(out_builtin, target)
+
+
+  print(" - Testing for correct behavior against known answer w/ all negative 
matrix w/ pad=1.")
+  # generate data
+  # -- channel 1
+  #  0   0   0   0   0  0
+  #  0  -1  -2  -3  -4  0
+  #  0  -5  -6  -7  -8  0
+  #  0  -9 -10 -11 -12  0
+  #  0 -13 -14 -15 -16  0
+  #  0   0   0   0   0  0
+  # -- channel 2
+  #  0   0   0   0   0  0
+  #  0  -1  -5  -9 -13  0
+  #  0  -2  -6 -10 -14  0
+  #  0  -3  -7 -11 -15  0
+  #  0  -4  -8 -12 -16  0
+  #  0   0   0   0   0  0
+  pad = 1
+
+  # forward
+  [out, Hout, Wout] = max_pool::forward(X, C, Hin, Win, Hf, Wf, stride, 
stride, pad, pad)
+  [out_simple, Hout_simple, Wout_simple] = max_pool_simple::forward(X, C, Hin, 
Win, Hf, Wf,
+                                                                    stride, 
stride, pad, pad)
+  [out_builtin, Hout_builtin, Wout_builtin] = max_pool_builtin::forward(X, C, 
Hin, Win, Hf, Wf,
+                                                                        
stride, stride, pad, pad)
+
+  # equivalency check
+  # -- channel 1
+  #  0  0  0
+  #  0 -6  0
+  #  0  0  0
+  # -- channel 2
+  #  0  0  0
+  #  0 -6  0
+  #  0  0  0
+  target = matrix("0 0 0 0 -6 0 0 0 0 0 0 0 0 -6 0 0 0 0", rows=1, 
cols=C*Hout*Wout)
+  target = rbind(target, target)  # n=2
   tmp = util::check_all_equal(out, target)
+  tmp = util::check_all_equal(out_simple, target)
+  tmp = util::check_all_equal(out_builtin, target)
 }
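
To see where the pad=1 targets above come from: padding the 4x4 channel to 6x6
and pooling with a 2x2 window at stride 2 gives a 3x3 output, and every window
that overlaps the border contains padding zeros plus at most one original
value. For channel 1 the nine window maxima are

  max(0,0,0,1)=1    max(0,0,2,3)=3      max(0,0,4,0)=4
  max(0,5,0,9)=9    max(6,7,10,11)=11   max(8,0,12,0)=12
  max(0,13,0,0)=13  max(14,15,0,0)=15   max(16,0,0,0)=16

which matches the target "1 3 4 9 11 12 13 15 16". The same reasoning explains
the all-negative case with pad=1: every border window's maximum is a padding
zero, and only the central window, max(-6,-7,-10,-11) = -6, sees no padding.
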
 
 tanh = function() {

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/15ccb7c0/scripts/staging/SystemML-NN/nn/util.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/util.dml 
b/scripts/staging/SystemML-NN/nn/util.dml
index 38fddc3..6870a5f 100644
--- a/scripts/staging/SystemML-NN/nn/util.dml
+++ b/scripts/staging/SystemML-NN/nn/util.dml
@@ -57,7 +57,7 @@ check_all_equal = function(matrix[double] X1, matrix[double] 
X2)
 
   # Evaluate relative error
   if (!equivalent) {
-      print("ERROR: The two matrices are not equivalent.")
+    print("ERROR: The two matrices are not equivalent.")
   }
 }
 
@@ -102,12 +102,12 @@ check_rel_error = function(double x1, double x2, double 
thresh_error, double thr
 
   # Evaluate relative error
   if (rel_error > thresh_error) {
-      print("ERROR: Relative error " + rel_error + " > " + thresh_error + " 
with " + x1 +
-            " vs " + x2 + ".")
+    print("ERROR: Relative error " + rel_error + " > " + thresh_error + " with 
" + x1 +
+          " vs " + x2 + ".")
   }
   else if (rel_error > thresh_warn & rel_error <= thresh_error) {
-      print("WARNING: Relative error " + rel_error + " > " + thresh_warn + " & 
<= " + thresh_error +
-            " with " + x1 + " vs " + x2 + ".")
+    print("WARNING: Relative error " + rel_error + " > " + thresh_warn + " & 
<= " + thresh_error +
+          " with " + x1 + " vs " + x2 + ".")
   }
 }
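
For reference, util::check_rel_error is the helper used throughout test.dml
above: it takes the two values to compare plus an error threshold and a warning
threshold, prints an ERROR or WARNING line when the relative error crosses the
corresponding threshold, and returns the computed relative error. A small usage
sketch with the same thresholds test.dml uses (the values are illustrative):

  source("nn/util.dml") as util

  x1 = 0.123456
  x2 = 0.123457
  rel_error = util::check_rel_error(x1, x2, 1e-10, 1e-12)
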
 
