Repository: incubator-systemml
Updated Branches:
  refs/heads/master 0f4571810 -> c944ad1dc


[SYSTEMML-1409][SYSTEMML-1410] New Batch Normalization Layers

This commit adds a new batch normalization layer, `batch_norm`, and a
new spatial batch normalization layer, `spatial_batch_norm`.

The batch normalization layer uses the per-feature sample mean and
per-feature uncorrected sample variance during training to
normalize each feature of the input data.

The spatial batch normalization layer uses the per-channel sample mean
and per-channel uncorrected sample variance during training to
normalize each channel of the input data.

Additionally, these layers introduce learnable parameters (gamma, beta)
to control the amount of normalization.

   y = ((x-mean) / sqrt(var+eps)) * gamma + beta
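
As a concrete, purely illustrative instance of this formula: for a
single feature with batch values (1, 2, 3), the batch mean is 2 and the
uncorrected batch variance is 2/3, so with gamma = 1, beta = 0, and a
negligible eps the normalized outputs are approximately (-1.22, 0, 1.22).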

Finally, these implementations maintain exponential moving averages of
the mean and variance during training for use during testing.
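
As a rough usage sketch (not code from this commit; `X`, `D`, and `dout`
are assumed to be defined elsewhere, and the hyperparameter values are
just the typical defaults used in the tests), a training step with the
new `batch_norm` layer looks like:

   source("nn/layers/batch_norm.dml") as batch_norm

   # Initialize gamma, beta, and the moving averages for D features.
   [gamma, beta, ema_mean, ema_var] = batch_norm::init(D)

   # Forward pass in 'train' mode; the updated moving averages are
   # carried forward and later used in 'test' mode.
   [out, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
       batch_norm::forward(X, gamma, beta, 'train', ema_mean, ema_var, 0.9, 1e-5)

   # Backward pass, given upstream gradients dout.
   [dX, dgamma, dbeta] = batch_norm::backward(dout, out, ema_mean_upd, ema_var_upd,
                                              cache_mean, cache_var, cache_norm,
                                              X, gamma, beta, 'train', ema_mean,
                                              ema_var, 0.9, 1e-5)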

Closes #444.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/c944ad1d
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/c944ad1d
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/c944ad1d

Branch: refs/heads/master
Commit: c944ad1dca3dd0110c874e227f7511c9a182f26a
Parents: 0f45718
Author: Mike Dusenberry <mwdus...@us.ibm.com>
Authored: Thu Mar 30 14:37:11 2017 -0700
Committer: Mike Dusenberry <mwdus...@us.ibm.com>
Committed: Thu Mar 30 14:37:11 2017 -0700

----------------------------------------------------------------------
 .../SystemML-NN/nn/layers/batch_norm.dml        | 208 ++++++++++++++++
 .../nn/layers/spatial_batch_norm.dml            | 235 +++++++++++++++++++
 .../staging/SystemML-NN/nn/test/grad_check.dml  | 216 +++++++++++++++++
 scripts/staging/SystemML-NN/nn/test/test.dml    | 132 ++++++++++-
 scripts/staging/SystemML-NN/nn/test/tests.dml   |   9 +
 scripts/staging/SystemML-NN/nn/util.dml         |  19 ++
 6 files changed, 816 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c944ad1d/scripts/staging/SystemML-NN/nn/layers/batch_norm.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/batch_norm.dml b/scripts/staging/SystemML-NN/nn/layers/batch_norm.dml
new file mode 100644
index 0000000..d332e8c
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/layers/batch_norm.dml
@@ -0,0 +1,208 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * Batch normalization layer.
+ */
+forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta,
+                   string mode, matrix[double] ema_mean, matrix[double] ema_var,
+                   double mu, double epsilon)
+    return (matrix[double] out, matrix[double] ema_mean_upd, matrix[double] ema_var_upd,
+            matrix[double] cache_mean, matrix[double] cache_var, matrix[double] cache_norm) {
+  /*
+   * Computes the forward pass for a batch normalization layer.
+   *
+   * A batch normalization layer uses the per-feature sample mean and
+   * per-feature uncorrected sample variance during training to
+   * normalize each feature of the input data.  Additionally, it
+   * introduces learnable parameters (gamma, beta) to control the
+   * amount of normalization.
+   *
+   *    y = ((x-mean) / sqrt(var+eps)) * gamma + beta
+   *
+   * This implementation maintains exponential moving averages of the
+   * mean and variance during training for use during testing.
+   *
+   * Reference:
+   *  - Batch Normalization: Accelerating Deep Network Training by
+   *    Reducing Internal Covariate Shift, S. Ioffe & C. Szegedy, 2015
+   *    - https://arxiv.org/abs/1502.03167
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, D).
+   *  - gamma: Scale parameters, of shape (1, D).
+   *  - beta: Shift parameters, of shape (1, D).
+   *  - mode: 'train' or 'test' to indicate if the model is currently
+   *      being trained or tested.  During training, the current batch
+   *      mean and variance will be used to normalize the inputs, while
+   *      during testing, the exponential average of the mean and
+   *      variance over all previous batches will be used.
+   *  - ema_mean: Exponential moving average of the mean, of
+   *      shape (1, D).
+   *  - ema_var: Exponential moving average of the variance, of
+   *      shape (1, D).
+   *  - mu: Momentum value for moving averages.
+   *      Typical values are in the range of [0.9, 0.999].
+   *  - epsilon: Smoothing term to avoid divide by zero errors.
+   *      Typical values are in the range of [1e-5, 1e-3].
+   *
+   * Outputs:
+   *  - out: Outputs, of shape (N, D).
+   *  - ema_mean_upd: Updated exponential moving average of the mean,
+   *      of shape (1, D).
+   *  - ema_var_upd: Updated exponential moving average of the variance,
+   *      of shape (1, D).
+   *  - cache_mean: Cache of the batch mean, of shape (1, D).
+   *      Note: This is used for performance during training.
+   *  - cache_var: Cache of the batch variance, of shape (1, D).
+   *      Note: This is used for performance during training.
+   *  - cache_norm: Cache of the normalized inputs, of shape (N, D).
+   *      Note: This is used for performance during training.
+   */
+  N = nrow(X)
+
+  if(mode == 'train') {
+    # Compute feature-wise mean and variance
+    mean = colMeans(X)  # shape (1, D)
+    # var = (1/N) * colSums((X-mean)^2)
+    var = colVars(X) * ((N-1)/N)  # compute uncorrected variance, of shape (1, D)
+    # Update moving averages
+    ema_mean_upd = mu*ema_mean + (1-mu)*mean
+    ema_var_upd = mu*ema_var + (1-mu)*var
+  }
+  else {
+    # Use moving averages of mean and variance during testing
+    mean = ema_mean
+    var = ema_var
+    ema_mean_upd = ema_mean
+    ema_var_upd = ema_var
+  }
+
+  # Normalize, shift, and scale
+  # norm = (X-mean)*(var+epsilon)^(-1/2)
+  norm = (X-mean) / sqrt(var+epsilon)  # shape (N, D)
+  out = norm*gamma + beta  # shape (N, D)
+
+  # Save variables for backward pass
+  cache_mean = mean
+  cache_var = var
+  cache_norm = norm
+}
+
+backward = function(matrix[double] dout, matrix[double] out,
+                    matrix[double] ema_mean_upd, matrix[double] ema_var_upd,
+                    matrix[double] cache_mean, matrix[double] cache_var, matrix[double] cache_norm,
+                    matrix[double] X, matrix[double] gamma, matrix[double] beta,
+                    string mode, matrix[double] ema_mean, matrix[double] ema_var,
+                    double mu, double epsilon)
+      return (matrix[double] dX, matrix[double] dgamma, matrix[double] dbeta) {
+  /*
+   * Computes the backward pass for a batch normalization layer.
+   *
+   * Inputs:
+   *  - dout: Derivatives from upstream, of shape (N, D).
+   *  - out: Outputs from the forward pass, of shape (N, D).
+   *  - ema_mean_upd: Updated exponential moving average of the mean
+   *      from the forward pass, of shape (1, D).
+   *  - ema_var_upd: Updated exponential moving average of the variance
+   *      from the forward pass, of shape (1, D).
+   *  - cache_mean: Cache of the batch mean from the forward pass, of
+   *      shape (1, D).  Note: This is used for performance during
+   *      training.
+   *  - cache_var: Cache of the batch variance from the forward pass,
+   *      of shape (1, D).  Note: This is used for performance during
+   *      training.
+   *  - cache_norm: Cache of the normalized inputs from the forward
+   *      pass, of shape (N, D).  Note: This is used for performance
+   *      during training.
+   *  - X: Input data matrix to the forward pass, of shape (N, D).
+   *  - gamma: Scale parameters, of shape (1, D).
+   *  - beta: Shift parameters, of shape (1, D).
+   *  - mode: 'train' or 'test' to indicate if the model is currently
+   *      being trained or tested.  During training, the current batch
+   *      mean and variance will be used to normalize the inputs, while
+   *      during testing, the exponential average of the mean and
+   *      variance over all previous batches will be used.
+   *  - ema_mean: Exponential moving average of the mean, of
+   *      shape (1, D).
+   *  - ema_var: Exponential moving average of the variance, of
+   *      shape (1, D).
+   *  - mu: Momentum value for moving averages.
+   *      Typical values are in the range of [0.9, 0.999].
+   *  - epsilon: Smoothing term to avoid divide by zero errors.
+   *      Typical values are in the range of [1e-5, 1e-3].
+   *
+   * Outputs:
+   *  - dX: Gradient wrt X, of shape (N, D).
+   *  - dgamma: Gradient wrt gamma, of shape (1, D).
+   *  - dbeta: Gradient wrt beta, of shape (1, D).
+   *
+   */
+  N = nrow(X)
+  mean = cache_mean
+  var = cache_var
+  norm = cache_norm
+  centered = X-mean
+
+  if (mode == 'train') {
+    # Compute gradients during training
+    dgamma = colSums(norm*dout)  # shape (1, D)
+    dbeta = colSums(dout)  # shape (1, D)
+    dnorm = dout * gamma  # shape (N, D)
+    dvar = (-1/2) * colSums(centered * (var+epsilon)^(-3/2) * dnorm)  # shape (1, D)
+    dmean = colSums((-dnorm/sqrt(var+epsilon)) + ((-2/N)*centered*dvar))  # shape (1, D)
+    dX = (dnorm/sqrt(var+epsilon)) + ((2/N)*centered*dvar) + ((1/N)*dmean)  # shape (N, D)
+  }
+  else {
+    # Compute gradients during testing
+    dgamma = colSums(norm*dout)  # shape (1, D)
+    dbeta = colSums(dout)  # shape (1, D)
+    dnorm = dout * gamma  # shape (N, D)
+    dX = dnorm / sqrt(var+epsilon)  # shape (N, D)
+  }
+}
+
+init = function(int D)
+    return (matrix[double] gamma, matrix[double] beta,
+            matrix[double] ema_mean, matrix[double] ema_var) {
+  /*
+   * Initialize the parameters of this layer.
+   *
+   * Note: This is just a convenience function, and parameters
+   * may be initialized manually if needed.
+   *
+   * Inputs:
+   *  - D: Dimensionality of the input features.
+   *
+   * Outputs:
+   *  - gamma: Scale parameters, of shape (1, D).
+   *  - beta: Shift parameters, of shape (1, D).
+   *  - ema_mean: Exponential moving average of the mean, of
+   *      shape (1, D).
+   *  - ema_var: Exponential moving average of the variance, of
+   *      shape (1, D).
+   */
+   gamma = matrix(1, rows=1, cols=D)
+   beta = matrix(0, rows=1, cols=D)
+   ema_mean = matrix(0, rows=1, cols=D)
+   ema_var = matrix(1, rows=1, cols=D)
+}
+

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c944ad1d/scripts/staging/SystemML-NN/nn/layers/spatial_batch_norm.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/spatial_batch_norm.dml b/scripts/staging/SystemML-NN/nn/layers/spatial_batch_norm.dml
new file mode 100644
index 0000000..53ca989
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/layers/spatial_batch_norm.dml
@@ -0,0 +1,235 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * Spatial batch normalization layer.
+ */
+source("nn/util.dml") as util
+
+forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta,
+                   int C, int Hin, int Win, string mode,
+                   matrix[double] ema_mean, matrix[double] ema_var,
+                   double mu, double epsilon)
+    return (matrix[double] out, matrix[double] ema_mean_upd, matrix[double] ema_var_upd,
+            matrix[double] cache_mean, matrix[double] cache_var, matrix[double] cache_norm) {
+  /*
+   * Computes the forward pass for a spatial batch normalization layer.
+   *
+   * A spatial batch normalization layer uses the per-channel sample
+   * mean and per-channel uncorrected sample variance during training
+   * to normalize each channel of the input data.  Additionally, it
+   * introduces learnable parameters (gamma, beta) to control the
+   * amount of normalization.
+   *
+   *    y = ((x-mean) / sqrt(var+eps)) * gamma + beta
+   *
+   * This implementation maintains exponential moving averages of the
+   * mean and variance during training for use during testing.
+   *
+   * Reference:
+   *  - Batch Normalization: Accelerating Deep Network Training by
+   *    Reducing Internal Covariate Shift, S. Ioffe & C. Szegedy, 2015
+   *    - https://arxiv.org/abs/1502.03167
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, C*Hin*Win).
+   *  - gamma: Scale parameters, of shape (C, 1).
+   *  - beta: Shift parameters, of shape (C, 1).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - mode: 'train' or 'test' to indicate if the model is currently
+   *      being trained or tested.  During training, the current batch
+   *      mean and variance will be used to normalize the inputs, while
+   *      during testing, the exponential average of the mean and
+   *      variance over all previous batches will be used.
+   *  - ema_mean: Exponential moving average of the mean, of
+   *      shape (C, 1).
+   *  - ema_var: Exponential moving average of the variance, of
+   *      shape (C, 1).
+   *  - mu: Momentum value for moving averages.
+   *      Typical values are in the range of [0.9, 0.999].
+   *  - epsilon: Smoothing term to avoid divide by zero errors.
+   *      Typical values are in the range of [1e-5, 1e-3].
+   *
+   * Outputs:
+   *  - out: Outputs, of shape (N, C*Hin*Win).
+   *  - ema_mean_upd: Updated exponential moving average of the mean,
+   *      of shape (C, 1).
+   *  - ema_var_upd: Updated exponential moving average of the variance,
+   *      of shape (C, 1).
+   *  - cache_mean: Cache of the batch mean, of shape (C, 1).
+   *      Note: This is used for performance during training.
+   *  - cache_var: Cache of the batch variance, of shape (C, 1).
+   *      Note: This is used for performance during training.
+   *  - cache_norm: Cache of the normalized inputs, of
+   *      shape (C, N*Hin*Win). Note: This is used for performance
+   *      during training.
+   */
+  N = nrow(X)
+
+  if(mode == 'train') {
+    # Compute channel-wise mean and variance
+    # Since we don't have tensors, we will compute the means and variances in a piece-wise fashion.
+    #  - mean of total group is mean of subgroup means
+    #  - variance is the mean of the subgroup variances + the variance of the subgroup means
+    subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win)
+    subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win)  # uncorrected variances
+    mean = rowMeans(subgrp_means)  # shape (C, 1)
+    var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))  # shape (C, 1)
+    # Update moving averages
+    ema_mean_upd = mu*ema_mean + (1-mu)*mean
+    ema_var_upd = mu*ema_var + (1-mu)*var
+  }
+  else {
+    # Use moving averages of mean and variance during testing
+    mean = ema_mean
+    var = ema_var
+    ema_mean_upd = ema_mean
+    ema_var_upd = ema_var
+  }
+
+  # Normalize, shift, and scale
+  # norm = (X-mean)*(var+epsilon)^(-1/2)
+  #      = (X-mean) / sqrt(var+epsilon)
+  centered = bias_add(X, -mean)  # shape (N, C*Hin*Win)
+  norm = bias_multiply(centered, 1/sqrt(var+epsilon))  # shape (N, C*Hin*Win)
+  # out = norm*gamma + beta
+  scaled = bias_multiply(norm, gamma)  # shape (N, C*Hin*Win)
+  out = bias_add(scaled, beta)  # shape (N, C*Hin*Win)
+
+  # Save variables for backward pass
+  cache_mean = mean
+  cache_var = var
+  cache_norm = norm
+}
+
+backward = function(matrix[double] dout, matrix[double] out,
+                    matrix[double] ema_mean_upd, matrix[double] ema_var_upd,
+                    matrix[double] cache_mean, matrix[double] cache_var, matrix[double] cache_norm,
+                    matrix[double] X, matrix[double] gamma, matrix[double] beta,
+                    int C, int Hin, int Win, string mode,
+                    matrix[double] ema_mean, matrix[double] ema_var,
+                    double mu, double epsilon)
+      return (matrix[double] dX, matrix[double] dgamma, matrix[double] dbeta) {
+  /*
+   * Computes the backward pass for a spatial batch normalization layer.
+   *
+   * Inputs:
+   *  - dout: Derivatives from upstream, of shape (N, C*Hin*Win).
+   *  - out: Outputs from the forward pass, of shape (N, C*Hin*Win).
+   *  - ema_mean_upd: Updated exponential moving average of the mean
+   *      from the forward pass, of shape (C, 1).
+   *  - ema_var_upd: Updated exponential moving average of the variance
+   *      from the forward pass, of shape (C, 1).
+   *  - cache_mean: Cache of the batch mean from the forward pass, of
+   *      shape (C, 1).  Note: This is used for performance during
+   *      training.
+   *  - cache_var: Cache of the batch variance from the forward pass,
+   *      of shape (C, 1).  Note: This is used for performance during
+   *      training.
+   *  - cache_norm: Cache of the normalized inputs from the forward
+   *      pass, of shape (C, N*Hin*Win).  Note: This is used for
+   *      performance during training.
+   *  - X: Input data matrix to the forward pass, of
+   *      shape (N, C*Hin*Win).
+   *  - gamma: Scale parameters, of shape (C, 1).
+   *  - beta: Shift parameters, of shape (C, 1).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - mode: 'train' or 'test' to indicate if the model is currently
+   *      being trained or tested.  During training, the current batch
+   *      mean and variance will be used to normalize the inputs, while
+   *      during testing, the exponential average of the mean and
+   *      variance over all previous batches will be used.
+   *  - ema_mean: Exponential moving average of the mean, of
+   *      shape (C, 1).
+   *  - ema_var: Exponential moving average of the variance, of
+   *      shape (C, 1).
+   *  - mu: Momentum value for moving averages.
+   *      Typical values are in the range of [0.9, 0.999].
+   *  - epsilon: Smoothing term to avoid divide by zero errors.
+   *      Typical values are in the range of [1e-5, 1e-3].
+   *
+   * Outputs:
+   *  - dX: Gradient wrt X, of shape (N, C*Hin*Win).
+   *  - dgamma: Gradient wrt gamma, of shape (C, 1).
+   *  - dbeta: Gradient wrt beta, of shape (C, 1).
+   *
+   */
+  N = nrow(X)
+  mean = cache_mean
+  var = cache_var
+  norm = cache_norm
+  centered = bias_add(X, -mean)  # shape (N, C*Hin*Win)
+
+  if (mode == 'train') {
+    # Compute gradients during training
+    dgamma = util::channel_sums(norm*dout, C, Hin, Win)  # shape (C, 1)
+    dbeta = util::channel_sums(dout, C, Hin, Win)  # shape (C, 1)
+    dnorm = bias_multiply(dout, gamma)  # shape (N, C*Hin*Win)
+    dvar = util::channel_sums((-1/2) * bias_multiply(centered, (var+epsilon)^(-3/2)) * dnorm,
+                              C, Hin, Win)  # shape (C, 1)
+    dmean_norm_branch = util::channel_sums(bias_multiply(dnorm, -1/sqrt(var+epsilon)), C, Hin, Win)
+    dmean_var_branch =  util::channel_sums((-2/(N*Hin*Win)) * centered, C, Hin, Win)
+    dmean_var_branch = dmean_var_branch * dvar  # we can't use a function within an expression yet
+    dmean = dmean_norm_branch + dmean_var_branch  # shape (C, 1)
+    dX_norm_branch = bias_multiply(dnorm, 1/sqrt(var+epsilon))
+    dX_mean_branch = (1/(N*Hin*Win)) * bias_add(matrix(0, rows=1, cols=C*Hin*Win), dmean)
+    dX_var_branch = (2/(N*Hin*Win)) * bias_multiply(centered, dvar)
+    dX = dX_norm_branch + dX_mean_branch + dX_var_branch  # shape (N, C*Hin*Win)
+  }
+  else {
+    # Compute gradients during testing
+    dgamma = util::channel_sums(norm*dout, C, Hin, Win)  # shape (C, 1)
+    dbeta = util::channel_sums(dout, C, Hin, Win)  # shape (C, 1)
+    dnorm = bias_multiply(dout, gamma)  # shape (N, C*Hin*Win)
+    dX = bias_multiply(dnorm, 1/sqrt(var+epsilon))  # shape (N, C*Hin*Win)
+  }
+}
+
+init = function(int C)
+    return (matrix[double] gamma, matrix[double] beta,
+            matrix[double] ema_mean, matrix[double] ema_var) {
+  /*
+   * Initialize the parameters of this layer.
+   *
+   * Note: This is just a convenience function, and parameters
+   * may be initialized manually if needed.
+   *
+   * Inputs:
+   *  - C: Number of input channels (dimensionality of input depth).
+   *
+   * Outputs:
+   *  - gamma: Scale parameters, of shape (C, 1).
+   *  - beta: Shift parameters, of shape (C, 1).
+   *  - ema_mean: Exponential moving average of the mean, of
+   *      shape (C, 1).
+   *  - ema_var: Exponential moving average of the variance, of
+   *      shape (C, 1).
+   */
+   gamma = matrix(1, rows=C, cols=1)
+   beta = matrix(0, rows=C, cols=1)
+   ema_mean = matrix(0, rows=C, cols=1)
+   ema_var = matrix(1, rows=C, cols=1)
+}
+

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c944ad1d/scripts/staging/SystemML-NN/nn/test/grad_check.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/test/grad_check.dml b/scripts/staging/SystemML-NN/nn/test/grad_check.dml
index db923df..6b90d56 100644
--- a/scripts/staging/SystemML-NN/nn/test/grad_check.dml
+++ b/scripts/staging/SystemML-NN/nn/test/grad_check.dml
@@ -23,6 +23,7 @@
  * Gradient checks for various architectures.
  */
 source("nn/layers/affine.dml") as affine
+source("nn/layers/batch_norm.dml") as batch_norm
 source("nn/layers/conv.dml") as conv
 source("nn/layers/conv_builtin.dml") as conv_builtin
 source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
@@ -39,6 +40,7 @@ source("nn/layers/relu.dml") as relu
 source("nn/layers/rnn.dml") as rnn
 source("nn/layers/sigmoid.dml") as sigmoid
 source("nn/layers/softmax.dml") as softmax
+source("nn/layers/spatial_batch_norm.dml") as spatial_batch_norm
 source("nn/layers/tanh.dml") as tanh
 source("nn/test/conv_simple.dml") as conv_simple
 source("nn/test/max_pool_simple.dml") as max_pool_simple
@@ -161,6 +163,108 @@ affine = function() {
   }
 }
 
+batch_norm = function() {
+  /*
+   * Gradient check for the batch normalization layer.
+   */
+  print("Grad checking the batch normalization layer with L2 loss.")
+
+  # Generate data
+  N = 3 # num examples
+  D = 100 # num features
+  mu = 0.9  # momentum
+  eps = 1e-5  # epsilon
+  X = rand(rows=N, cols=D)
+  y = rand(rows=N, cols=D)
+  gamma = rand(rows=1, cols=D)
+  beta = rand(rows=1, cols=D)
+  ema_mean = rand(rows=1, cols=D)
+  ema_var = rand(rows=1, cols=D)
+  #[dummy, dummy, ema_mean, ema_var] = batch_norm::init(D)
+
+  # Check training & testing modes
+  for (i in 1:2) {
+    if (i == 1)
+      mode = 'train'
+    else
+      mode = 'test'
+    print(" - Grad checking the '"+mode+"' mode.")
+
+    # Compute analytical gradients of loss wrt parameters
+    [out, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+        batch_norm::forward(X, gamma, beta, mode, ema_mean, ema_var, mu, eps)
+    dout = l2_loss::backward(out, y)
+    [dX, dgamma, dbeta] = batch_norm::backward(dout, out, ema_mean_upd, ema_var_upd,
+                                               cache_mean, cache_var, cache_norm,
+                                               X, gamma, beta, mode, ema_mean, ema_var, mu, eps)
+
+    # Grad check
+    h = 1e-5
+    print("   - Grad checking X.")
+    for (i in 1:nrow(X)) {
+      for (j in 1:ncol(X)) {
+        # Compute numerical derivative
+        old = as.scalar(X[i,j])
+        X[i,j] = old - h
+        [outmh, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+            batch_norm::forward(X, gamma, beta, mode, ema_mean, ema_var, mu, eps)
+        lossmh = l2_loss::forward(outmh, y)
+        X[i,j] = old + h
+        [outph, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+            batch_norm::forward(X, gamma, beta, mode, ema_mean, ema_var, mu, eps)
+        lossph = l2_loss::forward(outph, y)
+        X[i,j] = old  # reset
+        dX_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+        # Check error
+        rel_error = check_rel_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+      }
+    }
+
+    print("   - Grad checking gamma.")
+    for (i in 1:nrow(gamma)) {
+      for (j in 1:ncol(gamma)) {
+        # Compute numerical derivative
+        old = as.scalar(gamma[i,j])
+        gamma[i,j] = old - h
+        [outmh, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+            batch_norm::forward(X, gamma, beta, mode, ema_mean, ema_var, mu, eps)
+        lossmh = l2_loss::forward(outmh, y)
+        gamma[i,j] = old + h
+        [outph, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+            batch_norm::forward(X, gamma, beta, mode, ema_mean, ema_var, mu, eps)
+        lossph = l2_loss::forward(outph, y)
+        gamma[i,j] = old  # reset
+        dgamma_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+        # Check error
+        rel_error = check_rel_error(as.scalar(dgamma[i,j]), dgamma_num, lossph, lossmh)
+      }
+    }
+
+    print("   - Grad checking beta.")
+    for (i in 1:nrow(beta)) {
+      for (j in 1:ncol(beta)) {
+        # Compute numerical derivative
+        old = as.scalar(beta[i,j])
+        beta[i,j] = old - h
+        [outmh, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+            batch_norm::forward(X, gamma, beta, mode, ema_mean, ema_var, mu, eps)
+        lossmh = l2_loss::forward(outmh, y)
+        beta[i,j] = old + h
+        [outph, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+            batch_norm::forward(X, gamma, beta, mode, ema_mean, ema_var, mu, eps)
+        lossph = l2_loss::forward(outph, y)
+        beta[i,j] = old  # reset
+        dbeta_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+        # Check error
+        rel_error = check_rel_error(as.scalar(dbeta[i,j]), dbeta_num, lossph, lossmh)
+      }
+    }
+  }
+}
+
 conv = function() {
   /*
    * Gradient check for the convolutional layer using `im2col`.
@@ -1203,6 +1307,118 @@ softmax = function() {
   }
 }
 
+spatial_batch_norm = function() {
+  /*
+   * Gradient check for the spatial batch normalization layer.
+   */
+  print("Grad checking the spatial batch normalization layer with L2 loss.")
+
+  # Generate data
+  N = 2  # num examples
+  C = 2  # num channels
+  Hin = 5  # input height
+  Win = 5  # input width
+  mu = 0.9  # momentum
+  eps = 1e-5  # epsilon
+  X = rand(rows=N, cols=C*Hin*Win)
+  y = rand(rows=N, cols=C*Hin*Win)
+  gamma = rand(rows=C, cols=1)
+  beta = rand(rows=C, cols=1)
+  ema_mean = rand(rows=C, cols=1)
+  ema_var = rand(rows=C, cols=1)
+  #[dummy, dummy, ema_mean, ema_var] = spatial_batch_norm::init(C)
+
+  # Check training & testing modes
+  for (i in 1:2) {
+    if (i == 1)
+      mode = 'train'
+    else
+      mode = 'test'
+    print(" - Grad checking the '"+mode+"' mode.")
+
+    # Compute analytical gradients of loss wrt parameters
+    [out, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+        spatial_batch_norm::forward(X, gamma, beta, C, Hin, Win, mode, ema_mean, ema_var, mu, eps)
+    dout = l2_loss::backward(out, y)
+    [dX, dgamma, dbeta] = spatial_batch_norm::backward(dout, out, ema_mean_upd, ema_var_upd,
+                                                       cache_mean, cache_var, cache_norm,
+                                                       X, gamma, beta, C, Hin, Win, mode,
+                                                       ema_mean, ema_var, mu, eps)
+
+    # Grad check
+    h = 1e-5
+    print("   - Grad checking X.")
+    for (i in 1:nrow(X)) {
+      for (j in 1:ncol(X)) {
+        # Compute numerical derivative
+        old = as.scalar(X[i,j])
+        X[i,j] = old - h
+        [outmh, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+            spatial_batch_norm::forward(X, gamma, beta, C, Hin, Win, mode,
+                                        ema_mean, ema_var, mu, eps)
+        lossmh = l2_loss::forward(outmh, y)
+        X[i,j] = old + h
+        [outph, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+            spatial_batch_norm::forward(X, gamma, beta, C, Hin, Win, mode,
+                                        ema_mean, ema_var, mu, eps)
+        lossph = l2_loss::forward(outph, y)
+        X[i,j] = old  # reset
+        dX_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+        # Check error
+        rel_error = check_rel_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+      }
+    }
+
+    print("   - Grad checking gamma.")
+    for (i in 1:nrow(gamma)) {
+      for (j in 1:ncol(gamma)) {
+        # Compute numerical derivative
+        old = as.scalar(gamma[i,j])
+        gamma[i,j] = old - h
+        [outmh, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+            spatial_batch_norm::forward(X, gamma, beta, C, Hin, Win, mode,
+                                        ema_mean, ema_var, mu, eps)
+        lossmh = l2_loss::forward(outmh, y)
+        gamma[i,j] = old + h
+        [outph, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+            spatial_batch_norm::forward(X, gamma, beta, C, Hin, Win, mode,
+                                        ema_mean, ema_var, mu, eps)
+        lossph = l2_loss::forward(outph, y)
+        gamma[i,j] = old  # reset
+        dgamma_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+        # Check error
+        rel_error = check_rel_error(as.scalar(dgamma[i,j]), dgamma_num, lossph, lossmh)
+      }
+    }
+
+    print("   - Grad checking beta.")
+    for (i in 1:nrow(beta)) {
+      for (j in 1:ncol(beta)) {
+        # Compute numerical derivative
+        old = as.scalar(beta[i,j])
+        beta[i,j] = old - h
+        [outmh, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+            spatial_batch_norm::forward(X, gamma, beta, C, Hin, Win, mode,
+                                        ema_mean, ema_var, mu, eps)
+        lossmh = l2_loss::forward(outmh, y)
+        beta[i,j] = old + h
+        [outph, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+            spatial_batch_norm::forward(X, gamma, beta, C, Hin, Win, mode,
+                                        ema_mean, ema_var, mu, eps)
+        lossph = l2_loss::forward(outph, y)
+        beta[i,j] = old  # reset
+        dbeta_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+        # Check error
+        rel_error = check_rel_error(as.scalar(dbeta[i,j]), dbeta_num, lossph, lossmh)
+      }
+    }
+  }
+}
+
 tanh = function() {
   /*
    * Gradient check for the hyperbolic tangent (tanh) nonlinearity layer.

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c944ad1d/scripts/staging/SystemML-NN/nn/test/test.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/test/test.dml b/scripts/staging/SystemML-NN/nn/test/test.dml
index dd24a55..b25fae2 100644
--- a/scripts/staging/SystemML-NN/nn/test/test.dml
+++ b/scripts/staging/SystemML-NN/nn/test/test.dml
@@ -22,16 +22,51 @@
 /*
  * Various tests, not including gradient checks.
  */
+source("nn/layers/batch_norm.dml") as batch_norm
 source("nn/layers/conv.dml") as conv
 source("nn/layers/conv_builtin.dml") as conv_builtin
 source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
 source("nn/layers/max_pool.dml") as max_pool
 source("nn/layers/max_pool_builtin.dml") as max_pool_builtin
+source("nn/layers/spatial_batch_norm.dml") as spatial_batch_norm
 source("nn/layers/tanh.dml") as tanh
 source("nn/test/conv_simple.dml") as conv_simple
 source("nn/test/max_pool_simple.dml") as max_pool_simple
 source("nn/util.dml") as util
 
+batch_norm = function() {
+  /*
+   * Test for the `batch_norm` function.
+   */
+  print("Testing the batch_norm function.")
+
+  # Generate data
+  N = 4  # Number of examples
+  D = 4  # Number of features
+  mode = 'train'  # execution mode
+  mu = 0.9  # momentum of moving averages
+  eps = 1e-5  # smoothing term
+  X = matrix(seq(1,16), rows=N, cols=D)
+
+  # Create layer
+  [gamma, beta, ema_mean, ema_var] = batch_norm::init(D)
+
+  # Forward
+  [out, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+      batch_norm::forward(X, gamma, beta, mode, ema_mean, ema_var, mu, eps)
+
+  # Equivalency check
+  target = matrix("-1.34160721 -1.34160721 -1.34160733 -1.34160709
+                   -0.44720244 -0.44720244 -0.44720244 -0.44720232
+                    0.44720244  0.44720232  0.44720244  0.44720244
+                    1.34160733  1.34160721  1.34160733  1.34160733", rows=1, cols=N*D)
+  out = matrix(out, rows=1, cols=N*D)
+  for (i in 1:length(out)) {
+    rel_error = util::check_rel_error(as.scalar(out[1,i]),
+                                      as.scalar(target[1,i]), 1e-3, 1e-4)
+  }
+}
+
 conv = function() {
   /*
    * Test for the `conv` functions.
@@ -189,8 +224,8 @@ max_pool = function() {
   for (padh in 0:3) {
     for (padw in 0:3) {
       print(" - Testing w/ padh="+padh+" & padw="+padw+".")
-      if (1==1) {}  # force correct printing
-      print("   - Testing forward")
+      #if (1==1) {}  # force correct printing
+      #print("   - Testing forward")
       [out, Hout, Wout] = max_pool::forward(X, C, Hin, Win, Hf, Wf, stride, stride, padh, padw)
       [out_simple, Hout_simple, Wout_simple] = max_pool_simple::forward(X, C, Hin, Win, Hf, Wf,
                                                                         stride, stride, padh, padw)
@@ -209,7 +244,7 @@ max_pool = function() {
                                          as.scalar(out_builtin[1,i]), 1e-10, 1e-12)
       }
 
-      print("   - Testing backward")
+      #print("   - Testing backward")
       dout = rand(rows=N, cols=C*Hout*Wout, pdf="normal")
       dX = max_pool::backward(dout, Hout, Wout, X, C, Hin, Win, Hf, Wf, stride, stride, padh, padw)
       dX_simple = max_pool_simple::backward(dout, Hout_simple, Wout_simple, X, C, Hin, Win, Hf, Wf,
@@ -388,6 +423,97 @@ max_pool = function() {
   tmp = util::check_all_equal(out_builtin, target)
 }
 
+spatial_batch_norm = function() {
+  /*
+   * Test for the `spatial_batch_norm` function.
+   */
+  print("Testing the spatial_batch_norm function.")
+
+  # Generate data
+  N = 2  # Number of examples
+  C = 3  # num channels
+  Hin = 4  # input height
+  Win = 5  # input width
+  mode = 'train'  # execution mode
+  mu = 0.9  # momentum of moving averages
+  eps = 1e-5  # smoothing term
+  X = matrix("70  29 23 55 72
+              42  98 68 48 39
+              34  73 44  6 40
+              74  18 18 53 53
+
+              63  85 72 61 72
+              32  36 23 29 63
+               9  43 43 49 43
+              31  43 89 94 50
+
+              62  12 32 41 87
+              25  48 99 52 61
+              12  83 60 55 34
+              30  42 68 88 51
+
+
+              67  59 62 67 84
+               8  76 24 19 57
+              10  89 63 72  2
+              59  56 16 15 70
+
+              32  69 55 39 93
+              84  36  4 30 40
+              70 100 36 76 59
+              69  15 40 24 34
+
+              51  67 11 13 32
+              66  85 55 85 38
+              32  35 17 83 34
+              55  58 52  0 99", rows=N, cols=C*Hin*Win)
+
+  # Create layer
+  [gamma, beta, ema_mean, ema_var] = spatial_batch_norm::init(C)
+
+  # Forward
+  [out, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+      spatial_batch_norm::forward(X, gamma, beta, C, Hin, Win, mode, ema_mean, ema_var, mu, eps)
+
+  # Equivalency check
+  target = matrix("0.86215019 -0.76679718 -1.00517964  0.26619387  0.94161105
+                  -0.25030172  1.97460198  0.78268933 -0.01191914 -0.36949289
+                  -0.56814504  0.98134136 -0.17084086 -1.68059683 -0.32976246
+                   1.02107191 -1.20383179 -1.20383179  0.18673301  0.18673301
+
+                   0.50426388  1.41921711  0.87856293  0.42108631  0.87856293
+                  -0.78498828 -0.61863315 -1.15928721 -0.90975463  0.50426388
+                  -1.74153018 -0.32751167 -0.32751167 -0.07797909 -0.32751167
+                  -0.82657707 -0.32751167  1.58557224  1.79351616 -0.0363903
+
+                   0.4607178  -1.49978399 -0.71558321 -0.36269283  1.44096887
+                  -0.99005347 -0.08822262  1.91148913  0.06861746  0.42150795
+                  -1.49978399  1.28412855  0.38229787  0.18624771 -0.63716316
+                  -0.79400325 -0.32348287  0.69597805  1.48017895  0.0294075
+
+
+                   0.74295878  0.42511559  0.54430676  0.74295878  1.41837597
+                  -1.60113597  1.10053277 -0.96544927 -1.16410136  0.34565473
+                  -1.52167511  1.61702824  0.5840373   0.94161105 -1.83951855
+                   0.42511559  0.30592418 -1.28329265 -1.32302308  0.86215019
+
+                  -0.78498828  0.75379658  0.17155361 -0.4938668   1.75192738
+                   1.37762833 -0.61863315 -1.9494741  -0.86816585 -0.45227802
+                   0.79538536  2.04304862 -0.61863315  1.04491806  0.33790874
+                   0.75379658 -1.49199748 -0.45227802 -1.11769855 -0.70181072
+
+                   0.0294075   0.65676796 -1.53899395 -1.46057391 -0.71558321
+                   0.61755812  1.36254871  0.18624771  1.36254871 -0.48032296
+                  -0.71558321 -0.59795308 -1.30373383  1.28412855 -0.63716316
+                   0.18624771  0.30387771  0.06861746 -1.97030437  1.91148913", rows=1,
+                                                                                cols=N*C*Hin*Win)
+  out = matrix(out, rows=1, cols=N*C*Hin*Win)
+  for (i in 1:length(out)) {
+    rel_error = util::check_rel_error(as.scalar(out[1,i]),
+                                      as.scalar(target[1,i]), 1e-3, 1e-4)
+  }
+}
+
 tanh = function() {
   /*
    * Test for the `tanh` forward function.

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c944ad1d/scripts/staging/SystemML-NN/nn/test/tests.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/test/tests.dml b/scripts/staging/SystemML-NN/nn/test/tests.dml
index fd897ef..86bb77b 100644
--- a/scripts/staging/SystemML-NN/nn/test/tests.dml
+++ b/scripts/staging/SystemML-NN/nn/test/tests.dml
@@ -29,11 +29,15 @@ print("")
 print("Starting grad checks.")
 print("---")
 
+# Loss functions
 tmp = grad_check::cross_entropy_loss()
 tmp = grad_check::l1_loss()
 tmp = grad_check::l2_loss()
 tmp = grad_check::log_loss()
+
+# Other layers
 tmp = grad_check::affine()
+tmp = grad_check::batch_norm()
 tmp = grad_check::conv_simple()
 tmp = grad_check::conv()
 tmp = grad_check::conv_builtin()
@@ -48,7 +52,10 @@ tmp = grad_check::relu()
 tmp = grad_check::rnn()
 tmp = grad_check::sigmoid()
 tmp = grad_check::softmax()
+tmp = grad_check::spatial_batch_norm()
 tmp = grad_check::tanh()
+
+# Example model
 tmp = grad_check::two_layer_affine_l2_net()
 
 print("---")
@@ -62,11 +69,13 @@ print("")
 print("Starting other tests.")
 print("---")
 
+tmp = test::batch_norm()
 tmp = test::im2col()
 tmp = test::padding()
 tmp = test::conv()
 tmp = test::cross_entropy_loss()
 tmp = test::max_pool()
+tmp = test::spatial_batch_norm()
 tmp = test::tanh()
 
 print("---")

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c944ad1d/scripts/staging/SystemML-NN/nn/util.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/util.dml b/scripts/staging/SystemML-NN/nn/util.dml
index fdf0f76..dd0ac19 100644
--- a/scripts/staging/SystemML-NN/nn/util.dml
+++ b/scripts/staging/SystemML-NN/nn/util.dml
@@ -111,6 +111,25 @@ check_rel_error = function(double x1, double x2, double thresh_error, double thr
   }
 }
 
+channel_sums = function(matrix[double] X, int C, int Hin, int Win)
+    return (matrix[double] out) {
+  /*
+   * Computes a channel-wise summation over a 4D input.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, C*Hin*Win).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *
+   * Outputs:
+   *  - out: Outputs, of shape (C, 1).
+   */
+  # Here we sum each column, reshape to (C, Hin*Win), and sum each row to result in the summation
+  # for each channel.
+  out = rowSums(matrix(colSums(X), rows=C, cols=Hin*Win))  # shape (C, 1)
+}
+
 im2col = function(matrix[double] img, int Hin, int Win, int Hf, int Wf, int strideh, int stridew)
     return (matrix[double] img_cols) {
   /*

