[SYSTEMML-1675] Add a new 2D depthwise transpose convolution layer

This adds a new 2D depthwise transpose convolution layer.  A depthwise
transpose convolution (1) applies a separate filter to each group of M
input channels, condensing each group to 1 output channel, and (2)
concatenates the results into a single volume with C/M output channels.
This is in contrast to a regular 2D transpose convolution, in which all
of the filters would be applied to all of the input channels at once.
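
For reference, here is a minimal usage sketch in DML (the dimensions are
borrowed from the new gradient check; this snippet is illustrative and
not part of the committed files):

  source("nn/layers/conv2d_transpose_depthwise.dml") as conv2d_transpose_depthwise

  N = 2       # num examples
  C = 8       # num input channels
  Hin = 3     # input height
  Win = 3     # input width
  M = 4       # depth of each filter, so there are C/M = 2 filters / output channels
  Hf = 3      # filter height
  Wf = 3      # filter width
  stride = 2
  pad = 1
  out_pad = 1
  X = rand(rows=N, cols=C*Hin*Win)
  [W, b] = conv2d_transpose_depthwise::init(C, M, Hf, Wf)
  [out, Hout, Wout] = conv2d_transpose_depthwise::forward(X, W, b, C, Hin, Win, M, Hf, Wf,
                                                          stride, stride, pad, pad,
                                                          out_pad, out_pad)
  # Hout = stride*(Hin-1) - 2*pad + Hf + out_pad = 4 - 2 + 3 + 1 = 6 (likewise Wout = 6),
  # so `out` has shape (N, C/M*Hout*Wout) = (2, 2*6*6).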

In addition to the new layer, this also adds the associated unit and
gradient tests.

Closes #542.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/c83e99af
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/c83e99af
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/c83e99af

Branch: refs/heads/master
Commit: c83e99af755100e37f3e4fbfed20b5c455a635d5
Parents: f2d975f
Author: Mike Dusenberry <[email protected]>
Authored: Mon Jun 19 13:52:59 2017 -0700
Committer: Mike Dusenberry <[email protected]>
Committed: Mon Jun 19 13:52:59 2017 -0700

----------------------------------------------------------------------
 .../nn/layers/conv2d_transpose_depthwise.dml    | 198 +++++++++++++++++++
 scripts/nn/test/grad_check.dml                  | 104 ++++++++++
 scripts/nn/test/run_tests.dml                   |   2 +
 scripts/nn/test/test.dml                        |  59 ++++++
 4 files changed, 363 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/c83e99af/scripts/nn/layers/conv2d_transpose_depthwise.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/layers/conv2d_transpose_depthwise.dml b/scripts/nn/layers/conv2d_transpose_depthwise.dml
new file mode 100644
index 0000000..fdd7c10
--- /dev/null
+++ b/scripts/nn/layers/conv2d_transpose_depthwise.dml
@@ -0,0 +1,198 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * 2D Depthwise Transpose Convolutional layer.
+ *
+ * Utilizes built-in convolution operators for higher performance.
+ */
+source("nn/util.dml") as util
+
+forward = function(matrix[double] X, matrix[double] W, matrix[double] b,
+                   int C, int Hin, int Win, int M, int Hf, int Wf,
+                   int strideh, int stridew, int padh, int padw,
+                   int out_padh, int out_padw)
+    return (matrix[double] out, int Hout, int Wout){
+  /*
+   * Computes the forward pass for a 2D depthwise spatial transpose
+   * convolutional layer with C/M filters of depth M.  The input data
+   * has N examples, each represented as a 3D volume with C channels
+   * unrolled into a single vector.  For each group of M input channels,
+   * a 2D transpose convolution is applied with 1 unique filter,
+   * yielding 1 output channel per group of M input channels.
+   * The resulting C/M separate output channels are then concatenated
+   * together channel-wise into a single volume of C/M output channels.
+   *
+   * Inputs:
+   *  - X: Inputs, of shape (N, C*Hin*Win).
+   *  - W: Weights, of shape (C/M, M*Hf*Wf).
+   *  - b: Biases, of shape (C/M, 1).
+   *  - C: Number of input channels (dimensionality of depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - M: Depth of each filter (C must be divisible by M).
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *  - strideh: Stride over height.
+   *  - stridew: Stride over width.
+   *  - padh: Padding for top and bottom sides.
+   *  - padw: Padding for left and right sides.
+   *  - out_padh: extra padding for top side. This should
+   *      lie in [0, strideh-1].
+   *  - out_padw: extra padding for right side. This should
+   *      lie in [0, stridew-1].
+   *
+   * Outputs:
+   *  - out: Outputs, of shape (N, C/M*Hout*Wout).
+   *  - Hout: Output height.
+   *  - Wout: Output width.
+   */
+  N = nrow(X)
+  F = nrow(W)
+  Hout = strideh*(Hin-1) - 2*padh + Hf + out_padh
+  Wout = stridew*(Win-1) - 2*padw + Wf + out_padw
+
+  # create output volume
+  out = matrix(0, rows=N, cols=C/M*Hout*Wout)
+
+  # depthwise transpose convolution
+  # TODO: Explore usage of parfor loops more to determine if they can provide a performance
+  # benefit.  Initial tests show that they are slower than the regular for loop, likely because
+  # they cause a reduction from a multithreaded conv2d op to a singlethreaded version.  For a
+  # number of filters C/M >> the number of examples, it's possible that the parfor loop could be
+  # faster.
+  #parfor (f in 1:F, check=0) {  # each channel
+  for (f in 1:F) {
+    # compute gradient wrt data of conv2d using 1 filter and M input channels
+    w = matrix(W[f,], rows=M, cols=Hf*Wf)  # 1 filter, of shape (M, 1*Hf*Wf)
+    Xm = X[,((f-1)*M*Hin*Win + 1):f*M*Hin*Win]  # M input channels, of shape (N, M*Hin*Win)
+    outm = conv2d_backward_data(w, Xm, stride=[strideh,stridew], padding=[padh,padw],
+                                input_shape=[N,1,Hout,Wout], filter_shape=[M,1,Hf,Wf])
+
+    # store
+    out[,((f-1)*Hout*Wout + 1):f*Hout*Wout] = outm  # outm has shape (N, 1*Hout*Wout)
+  }
+
+  # add bias term to each output filter
+  out = bias_add(out, b)
+}
+
+backward = function(matrix[double] dout, int Hout, int Wout,
+                    matrix[double] X, matrix[double] W, matrix[double] b,
+                    int C, int Hin, int Win, int M, int Hf, int Wf,
+                    int strideh, int stridew, int padh, int padw)
+    return (matrix[double] dX, matrix[double] dW, matrix[double] db){
+  /*
+   * Computes the backward pass for a 2D depthwise spatial transpose
+   * convolutional layer with C/M filters of depth M.
+   *
+   * Inputs:
+   *  - dout: Gradient wrt `out` from upstream, of
+   *      shape (N, C/M*Hout*Wout).
+   *  - Hout: Output height.
+   *  - Wout: Output width.
+   *  - X: Inputs, of shape (N, C*Hin*Win).
+   *  - W: Weights, of shape (C/M, M*Hf*Wf).
+   *  - b: Biases, of shape (C/M, 1).
+   *  - C: Number of input channels (dimensionality of depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - M: Depth of each filter (C must be divisible by M).
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *  - strideh: Stride over height.
+   *  - stridew: Stride over width.
+   *  - padh: Padding for top and bottom sides.
+   *  - padw: Padding for left and right sides.
+   *
+   * Outputs:
+   *  - dX: Gradient wrt `X`, of shape (N, C*Hin*Win).
+   *  - dW: Gradient wrt `W`, of shape (C/M, M*Hf*Wf).
+   *  - db: Gradient wrt `b`, of shape (C/M, 1).
+   */
+  N = nrow(X)
+  F = nrow(W)
+
+  # create gradient volumes
+  dX = matrix(0, rows=N, cols=C*Hin*Win)
+  dW = matrix(0, rows=C/M, cols=M*Hf*Wf)
+  db = matrix(0, rows=C/M, cols=1)
+
+  # depthwise transpose convolution
+  for (f in 1:F) {
+    # extract 1 gradient channel, 1 depth-1 filter, and M input channels, since the forward pass
+    # maps M input channels to 1 output channel for each filter
+    doutf = dout[,((f-1)*Hout*Wout + 1):f*Hout*Wout]  # shape (N, 1*Hout*Wout)
+    w = matrix(W[f,], rows=M, cols=Hf*Wf)  # 1 filter, of shape (M, 1*Hf*Wf)
+    Xm = X[,((f-1)*M*Hin*Win + 1):f*M*Hin*Win]  # M input channels, of shape (N, M*Hin*Win)
+
+    # compute gradients:
+    # conv2d_backward_filter takes the input and the gradient wrt the output
+    # as its first and second args, respectively.  Given that we need to
+    # compute the grad wrt the filter for transpose convolution, where
+    # the roles of the input and output are reversed, we reverse the
+    # order of the args (along with setting input_shape to the dout
+    # shape).
+    dw = conv2d_backward_filter(doutf, Xm, stride=[strideh,stridew], padding=[padh,padw],
+                                input_shape=[N,1,Hout,Wout], filter_shape=[M,1,Hf,Wf])
+    # Since the forward pass for transpose convolution makes a call to
+    # conv2d_backward_data, to compute its derivative wrt the data
+    # we can run conv2d by applying the filter to the grad wrt the
+    # output (this makes sense because transpose convolution is the
+    # 'reverse' of convolution).  It's easy to see that this will produce
+    # an output of the required size.
+    dXm = conv2d(doutf, w, input_shape=[N,1,Hout,Wout], filter_shape=[M,1,Hf,Wf],
+                 stride=[strideh,stridew], padding=[padh,padw])
+
+    # store
+    dX[,((f-1)*M*Hin*Win + 1):f*M*Hin*Win] = dXm
+    dW[f,] = matrix(dw, rows=1, cols=M*Hf*Wf)
+  }
+
+  # partial derivatives for bias vector
+  db = util::channel_sums(dout, C/M, Hout, Wout)
+}
+
+init = function(int C, int M, int Hf, int Wf)
+    return (matrix[double] W, matrix[double] b){
+  /*
+   * Utility function to initialize the parameters of this layer.
+   *
+   * We use the heuristic by He et al., which limits the magnification
+   * of inputs/gradients during forward/backward passes by scaling
+   * unit-Gaussian weights by a factor of sqrt(2/n), under the
+   * assumption of relu neurons.
+   *  - http://arxiv.org/abs/1502.01852
+   *
+   * Inputs:
+   *  - C: Number of input channels (dimensionality of depth).
+   *  - M: Depth of each filter (C must be divisible by M).
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *
+   * Outputs:
+   *  - W: Weights, of shape (C/M, M*Hf*Wf).
+   *  - b: Biases, of shape (C/M, 1).
+   */
+  W = rand(rows=C/M, cols=M*Hf*Wf, pdf="normal") * sqrt(2/(M*Hf*Wf))
+  b = matrix(0, rows=C/M, cols=1)
+}
+

http://git-wip-us.apache.org/repos/asf/systemml/blob/c83e99af/scripts/nn/test/grad_check.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/grad_check.dml b/scripts/nn/test/grad_check.dml
index 67aeac1..fcb45cd 100644
--- a/scripts/nn/test/grad_check.dml
+++ b/scripts/nn/test/grad_check.dml
@@ -29,6 +29,7 @@ source("nn/layers/conv2d.dml") as conv2d
 source("nn/layers/conv2d_builtin.dml") as conv2d_builtin
 source("nn/layers/conv2d_depthwise.dml") as conv2d_depthwise
 source("nn/layers/conv2d_transpose.dml") as conv2d_transpose
+source("nn/layers/conv2d_transpose_depthwise.dml") as conv2d_transpose_depthwise
 source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
 source("nn/layers/dropout.dml") as dropout
 source("nn/layers/l1_loss.dml") as l1_loss
@@ -809,6 +810,109 @@ conv2d_transpose = function() {
   }
 }
 
+conv2d_transpose_depthwise = function() {
+  /*
+   * Gradient check for the 2D depthwise transpose convolutional layer.
+   */
+  print("Grad checking the 2D depthwise transpose convolutional layer with L2 loss.")
+
+  # Generate data
+  N = 2  # num examples
+  C = 8  # num channels
+  Hin = 3  # input height
+  Win = 3  # input width
+  M = 4  # depth of filters
+  Hf = 3  # filter height
+  Wf = 3  # filter width
+  stride = 2
+  pad = 1
+  out_pad = 1
+  X = rand(rows=N, cols=C*Hin*Win)
+
+  # Create layers
+  [W, b] = conv2d_transpose_depthwise::init(C, M, Hf, Wf)
+
+  # Compute analytical gradients of loss wrt parameters
+  [out, Hout, Wout] = conv2d_transpose_depthwise::forward(X, W, b, C, Hin, Win, M, Hf, Wf,
+                                                          stride, stride, pad, pad,
+                                                          out_pad, out_pad)
+  y = rand(rows=N, cols=C/M*Hout*Wout)
+  dout = l2_loss::backward(out,y)
+  [dX, dW, db] = conv2d_transpose_depthwise::backward(dout, Hout, Wout, X, W, b, C, Hin, Win, M,
+                                                      Hf, Wf, stride, stride, pad, pad)
+
+  # Grad check
+  h = 1e-5
+  print(" - Grad checking X.")
+  for (i in 1:nrow(X)) {
+    for (j in 1:ncol(X)) {
+      # Compute numerical derivative
+      old = as.scalar(X[i,j])
+      X[i,j] = old - h
+      [outmh, Hout, Wout] = conv2d_transpose_depthwise::forward(X, W, b, C, Hin, Win, M, Hf, Wf,
+                                                                stride, stride, pad, pad,
+                                                                out_pad, out_pad)
+      lossmh = l2_loss::forward(outmh, y)
+      X[i,j] = old + h
+      [outph, Hout, Wout] = conv2d_transpose_depthwise::forward(X, W, b, C, Hin, Win, M, Hf, Wf,
+                                                                stride, stride, pad, pad,
+                                                                out_pad, out_pad)
+      lossph = l2_loss::forward(outph, y)
+      X[i,j] = old  # reset
+      dX_num = (lossph-lossmh) / (2*h)  # numerical derivative
+
+      # Check error
+      rel_error = test_util::check_rel_grad_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+    }
+  }
+
+  print(" - Grad checking W.")
+  for (i in 1:nrow(W)) {
+    for (j in 1:ncol(W)) {
+      # Compute numerical derivative
+      old = as.scalar(W[i,j])
+      W[i,j] = old - h
+      [outmh, Hout, Wout] = conv2d_transpose_depthwise::forward(X, W, b, C, Hin, Win, M, Hf, Wf,
+                                                                stride, stride, pad, pad,
+                                                                out_pad, out_pad)
+      lossmh = l2_loss::forward(outmh, y)
+      W[i,j] = old + h
+      [outph, Hout, Wout] = conv2d_transpose_depthwise::forward(X, W, b, C, Hin, Win, M, Hf, Wf,
+                                                                stride, stride, pad, pad,
+                                                                out_pad, out_pad)
+      lossph = l2_loss::forward(outph, y)
+      W[i,j] = old  # reset
+      dW_num = (lossph-lossmh) / (2*h)  # numerical derivative
+
+      # Check error
+      rel_error = test_util::check_rel_grad_error(as.scalar(dW[i,j]), dW_num, lossph, lossmh)
+    }
+  }
+
+  print(" - Grad checking b.")
+  for (i in 1:nrow(b)) {
+    for (j in 1:ncol(b)) {
+      # Compute numerical derivative
+      old = as.scalar(b[i,j])
+      b[i,j] = old - h
+      [outmh, Hout, Wout] = conv2d_transpose_depthwise::forward(X, W, b, C, Hin, Win, M, Hf, Wf,
+                                                                stride, stride, pad, pad,
+                                                                out_pad, out_pad)
+      lossmh = l2_loss::forward(outmh, y)
+      b[i,j] = old + h
+      [outph, Hout, Wout] = conv2d_transpose_depthwise::forward(X, W, b, C, Hin, Win, M, Hf, Wf,
+                                                                stride, stride, pad, pad,
+                                                                out_pad, out_pad)
+      lossph = l2_loss::forward(outph, y)
+      b[i,j] = old  # reset
+      db_num = (lossph-lossmh) / (2*h)  # numerical derivative
+
+      # Check error
+      rel_error = test_util::check_rel_grad_error(as.scalar(db[i,j]), db_num, lossph, lossmh)
+    }
+  }
+}
+
 cross_entropy_loss = function() {
   /*
    * Gradient check for the cross-entropy loss function.

http://git-wip-us.apache.org/repos/asf/systemml/blob/c83e99af/scripts/nn/test/run_tests.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/run_tests.dml b/scripts/nn/test/run_tests.dml
index ec6fcff..5f3ca6e 100644
--- a/scripts/nn/test/run_tests.dml
+++ b/scripts/nn/test/run_tests.dml
@@ -47,6 +47,7 @@ grad_check::conv2d_builtin()
 grad_check::conv2d_simple()
 grad_check::conv2d_depthwise()
 grad_check::conv2d_transpose()
+grad_check::conv2d_transpose_depthwise()
 grad_check::dropout()
 grad_check::lstm()
 grad_check::max_pool2d()
@@ -89,6 +90,7 @@ test::batch_norm2d()
 test::conv2d()
 test::conv2d_depthwise()
 test::conv2d_transpose()
+test::conv2d_transpose_depthwise()
 test::cross_entropy_loss()
 test::im2col()
 test::max_pool2d()

http://git-wip-us.apache.org/repos/asf/systemml/blob/c83e99af/scripts/nn/test/test.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/test.dml b/scripts/nn/test/test.dml
index c60aab9..37f9f73 100644
--- a/scripts/nn/test/test.dml
+++ b/scripts/nn/test/test.dml
@@ -28,6 +28,7 @@ source("nn/layers/conv2d.dml") as conv2d
 source("nn/layers/conv2d_builtin.dml") as conv2d_builtin
 source("nn/layers/conv2d_depthwise.dml") as conv2d_depthwise
 source("nn/layers/conv2d_transpose.dml") as conv2d_transpose
+source("nn/layers/conv2d_transpose_depthwise.dml") as conv2d_transpose_depthwise
 source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
 source("nn/layers/max_pool2d.dml") as max_pool2d
 source("nn/layers/max_pool2d_builtin.dml") as max_pool2d_builtin
@@ -326,6 +327,64 @@ conv2d_transpose = function() {
   }
 }
 
+conv2d_transpose_depthwise = function() {
+  /*
+   * Test for the 2D depthwise transpose convolution function.
+   */
+  print("Testing the 2D depthwise transpose convolution function.")
+
+  # Generate data
+  N = 2  # num examples
+  C = 4  # num channels
+  Hin = 2  # input height
+  Win = 2  # input width
+  M = 2  # depth of each filter
+  Hf = 3  # filter height
+  Wf = 3  # filter width
+  stride = 1
+  pad = 0
+  out_pad = 0  # padding added to output
+  X = matrix(seq(1,N*C*Hin*Win), rows=N, cols=C*Hin*Win) / (N*C*Hin*Win) * 2 - 1  # normalized
+
+  # Create layer
+  W = matrix(seq(1,C/M*M*Hf*Wf), rows=C/M, cols=M*Hf*Wf) / (C/M*M*Hf*Wf) * 2 - 1  # normalized
+  b = matrix(seq(1,C/M), rows=C/M, cols=1) / (C/M)^2  # non-zero & non-one
+
+  # Forward
+  [out, Hout, Wout] = conv2d_transpose_depthwise::forward(X, W, b, C, Hin, Win, M, Hf, Wf,
+                                                          stride, stride, pad, pad,
+                                                          out_pad, out_pad)
+
+  # Equivalency check
+  target = matrix("1.44097221  2.45486116  2.28125     1.1875
+                   2.1875      3.80555558  3.48611116  1.72916663
+                   1.6875      2.84722233  2.52777767  1.27083325
+                   0.80902779  1.24652779  1.10069442  0.625
+
+                   0.37152776  0.24652773  0.18402778  0.35416669
+                   0.21527778 -0.02777781 -0.12500003  0.22916666
+                   0.04861115 -0.31944442 -0.41666669  0.10416666
+                   0.32291669  0.20486113  0.1701389   0.375
+
+
+                   0.05208334 -0.21180555 -0.16319445  0.02083334
+                  -0.25694442 -0.8611111  -0.7361111  -0.27083331
+                  -0.09027778 -0.4861111  -0.3611111  -0.0625
+                   0.08680556 -0.08680557 -0.01041669  0.125
+
+                   0.98263896  1.57986116  1.73958337  1.1875
+                   1.77083337  3.30555558  3.65277791  2.22916675
+                   2.27083325  4.34722233  4.69444466  2.77083349
+                   1.60069442  2.87152767  3.05902767  1.875     ", rows=N, cols=C/M*Hout*Wout)
+
+  for (i in 1:nrow(out)) {
+    for(j in 1:ncol(out)) {
+      rel_error = test_util::check_rel_error(as.scalar(out[i,j]),
+                                             as.scalar(target[i,j]), 1e-3, 1e-4)
+    }
+  }
+}
+
 cross_entropy_loss = function() {
   /*
    * Test for the cross-entropy loss function.

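Note: the new tests are registered in run_tests.dml above.  They can also be
invoked individually; a sketch (assuming the script is launched from the
scripts directory, as with the other nn tests):

  source("nn/test/grad_check.dml") as grad_check
  source("nn/test/test.dml") as test
  grad_check::conv2d_transpose_depthwise()
  test::conv2d_transpose_depthwise()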