Repository: systemml
Updated Branches:
  refs/heads/master 009561384 -> 17838a3d3
[SYSTEMML-1686] Fix 2D transpose convolution filter shape & input size

Currently, the transpose conv2d layer (`nn/layers/conv2d_transpose.dml`) has a bug in which
the filters tensor `W` has an incorrect shape, and the `conv2d_backward_data` op has an
incorrect input shape argument. This results in an exception whenever the number of input
channels `C` is not equal to the number of filters `F` (i.e. the number of output channels).
Since the transpose conv2d op is the gradient of the conv2d op, the filter tensor needs to
have the shape `(C, F, Hf, Wf)` for `F` filters, rather than `(F, C, Hf, Wf)`, in order to
map from an input with `C` channels to an output with `F` channels during the input data
gradient function (`conv2d_backward_data`) that is used in the forward pass. Additionally,
the `input_shape` argument for `conv2d_backward_data` needs to be `(N, F, Hout, Wout)`,
rather than `(N, C, Hout, Wout)`, again in order to map from an input with `C` channels to
an output with `F` channels. Our previous test cases did not catch this issue because the
tests used `C = F = 1`.

The motivation here is to think about the shape of `W` during a regular 2D conv function,
i.e. `(F, C, Hf, Wf)`. In the forward *and* backward pass of regular 2D conv, `W` maintains
the same `(F, C, Hf, Wf)` shape. Specifically, in the backward pass, the data gradient
function `conv2d_backward_data` accepts the same `W` of shape `(F, C, Hf, Wf)` to map from
gradients wrt the output, of shape `(N, F, Hout, Wout)`, to gradients wrt the input, of
shape `(N, C, Hin, Win)`. In a 2D transpose conv function, we use the 2D conv data gradient
function as the forward function. If we use the same size terminology as regular 2D conv,
we would be mapping from an input with `F` channels to an output with `C` channels (which
seems backward, but is correct since the forward pass of 2D transpose conv is defined as
the backward pass of 2D conv), and we would be using the same filter `W` of shape
`(F, C, Hf, Wf)`. Now, in order to make the terminology of the 2D transpose conv layer
consistent with the rest of the library, we would need to swap the variable names `F` and
`C` so that we are mapping from an input with `C` channels to an output with `F` channels.
Thus, we would use a filter `W` of shape `(C, F, Hf, Wf)`.

This commit fixes these issues, and improves the robustness of the associated tests.

Closes #541.
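To make the corrected shapes concrete, here is a small DML sketch (not part of this commit;
the sizes are arbitrary, and the `source` path assumes the script is run from the `scripts`
directory so the `nn` paths resolve) that exercises the layer with `C != F`, the case that
previously raised an exception:

  source("nn/layers/conv2d_transpose.dml") as conv2d_transpose

  # Arbitrary sizes with C != F, which the old (F, C*Hf*Wf) filter shape could not handle.
  N = 2        # num examples
  C = 3        # num input channels
  F = 4        # num filters (output channels)
  Hin = 5      # input height
  Win = 5      # input width
  Hf = 3       # filter height
  Wf = 3       # filter width
  stride = 2
  pad = 1
  out_pad = 1  # extra output padding, in [0, stride-1]

  X = rand(rows=N, cols=C*Hin*Win)
  [W, b] = conv2d_transpose::init(F, C, Hf, Wf)  # W: (C, F*Hf*Wf), b: (F, 1)

  # Forward pass: conv2d_backward_data maps C input channels to F output channels.
  [out, Hout, Wout] = conv2d_transpose::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride,
                                                pad, pad, out_pad, out_pad)
  print("out: " + nrow(out) + " x " + ncol(out) + ", expected: " + N + " x " + (F*Hout*Wout))

With the old `(F, C*Hf*Wf)` filter shape, the `conv2d_backward_data` call inside `forward`
fails for this configuration; with the corrected `(C, F*Hf*Wf)` shape it produces an output
of shape `(N, F*Hout*Wout)`.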
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/324bea5d
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/324bea5d
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/324bea5d

Branch: refs/heads/master
Commit: 324bea5d95c0ce4c7fe98e348b2a98ca0e888d84
Parents: 0095613
Author: Mike Dusenberry <[email protected]>
Authored: Thu Jun 15 15:20:37 2017 -0700
Committer: Mike Dusenberry <[email protected]>
Committed: Thu Jun 15 15:20:37 2017 -0700

----------------------------------------------------------------------
 scripts/nn/layers/conv2d_builtin.dml   |  2 +-
 scripts/nn/layers/conv2d_transpose.dml | 32 ++++++-------
 scripts/nn/test/grad_check.dml         | 70 ++++++++++++-----------
 scripts/nn/test/test.dml               | 63 +++++++++++++++++---------
 4 files changed, 88 insertions(+), 79 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/324bea5d/scripts/nn/layers/conv2d_builtin.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/layers/conv2d_builtin.dml b/scripts/nn/layers/conv2d_builtin.dml
index bda7a9c..6b066eb 100644
--- a/scripts/nn/layers/conv2d_builtin.dml
+++ b/scripts/nn/layers/conv2d_builtin.dml
@@ -123,7 +123,7 @@ backward = function(matrix[double] dout, int Hout, int Wout,
 
   # Partial derivatives for convolution - built-in implementation
   dW = conv2d_backward_filter(X, dout, stride=[strideh,stridew], padding=[padh,padw],
                               input_shape=[N,C,Hin,Win], filter_shape=[F,C,Hf,Wf])
-  dX = conv2d_backward_data(W, dout, stride=[strideh, stridew], padding=[padh,padw],
+  dX = conv2d_backward_data(W, dout, stride=[strideh,stridew], padding=[padh,padw],
                             input_shape=[N,C,Hin,Win], filter_shape=[F,C,Hf,Wf])
 
   # Partial derivatives for bias vector

http://git-wip-us.apache.org/repos/asf/systemml/blob/324bea5d/scripts/nn/layers/conv2d_transpose.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/layers/conv2d_transpose.dml b/scripts/nn/layers/conv2d_transpose.dml
index ecd7a1c..0838563 100644
--- a/scripts/nn/layers/conv2d_transpose.dml
+++ b/scripts/nn/layers/conv2d_transpose.dml
@@ -24,7 +24,7 @@
  *
  * Utilizes built-in convolution operators for higher performance.
  */
-
+
 forward = function(matrix[double] X, matrix[double] W, matrix[double] b,
                    int C, int Hin, int Win, int Hf, int Wf,
                    int strideh, int stridew, int padh, int padw,
@@ -37,7 +37,7 @@ forward = function(matrix[double] X, matrix[double] W, matrix[double] b,
  *
  * Inputs:
  *  - X: Inputs, of shape (N, C*Hin*Win).
- *  - W: Weights, of shape (F, C*Hf*Wf).
+ *  - W: Weights, of shape (C, F*Hf*Wf).
  *  - b: Biases, of shape (F, 1).
  *  - C: Number of input channels (dimensionality of depth).
  *  - Hin: Input height.
@@ -48,7 +48,7 @@ forward = function(matrix[double] X, matrix[double] W, matrix[double] b,
  *  - stridew: Stride over width.
  *  - padh: Padding for top and bottom sides.
  *  - padw: Padding for left and right sides.
- *  - out_padh: extra padding for top side. This should
+ *  - out_padh: extra padding for top side. This should
  *      lie in [0, strideh-1].
  *  - out_padw: extra padding for right side. This should
  *      lie in [0, stridew-1].
@@ -59,7 +59,7 @@ forward = function(matrix[double] X, matrix[double] W, matrix[double] b,
  *  - Hout: Output height.
 *  - Wout: Output width.
  */
   N = nrow(X)
-  F = nrow(W)
+  F = nrow(b)
   Hout = strideh * (Hin-1) - 2*padh + Hf + out_padh
   Wout = stridew * (Win-1) - 2*padw + Wf + out_padw
@@ -80,7 +80,7 @@ forward = function(matrix[double] X, matrix[double] W, matrix[double] b,
    * convolution.
    */
   out = conv2d_backward_data(W, X, stride=[strideh,stridew], padding=[padh,padw],
-                             input_shape=[N,C,Hout,Wout], filter_shape=[F,C,Hf,Wf])
+                             input_shape=[N,F,Hout,Wout], filter_shape=[C,F,Hf,Wf])
   out = bias_add(out, b)
 }
@@ -100,7 +100,7 @@ backward = function(matrix[double] dout, int Hout, int Wout,
  *  - Hout: Output height.
  *  - Wout: Output width.
  *  - X: Inputs, of shape (N, C*Hin*Win).
- *  - W: Weights, of shape (F, C*Hf*Wf).
+ *  - W: Weights, of shape (C, F*Hf*Wf).
  *  - b: Biases, of shape (F, 1).
  *  - C: Number of input channels (dimensionality of depth).
  *  - Hin: Input height.
@@ -114,15 +114,15 @@ backward = function(matrix[double] dout, int Hout, int Wout,
  *
  * Outputs:
  *  - dX: Gradient wrt `X`, of shape (N, C*Hin*Win).
- *  - dW: Gradient wrt `W`, of shape (F, C*Hf*Wf).
+ *  - dW: Gradient wrt `W`, of shape (C, F*Hf*Wf).
  *  - db: Gradient wrt `b`, of shape (F, 1).
  */
   N = nrow(X)
-  F = nrow(W)
+  F = nrow(b)
 
   /*
   * conv2d_backward_filter takes the input and delta map as first and
-  * second args, respectively. Given that, we need to compute the
+  * second args, respectively. Given that we need to compute the
   * grad (wrt to filter) for transpose convolution where the roles of
   * the input and output are reversed, we reverse the order of the
   * args (along with setting input_shape to the delta map shape).
@@ -134,7 +134,7 @@ backward = function(matrix[double] dout, int Hout, int Wout,
   * applying it.
   */
   dW = conv2d_backward_filter(dout, X, stride=[strideh,stridew], padding=[padh,padw],
-                              input_shape=[N,C,Hout,Wout], filter_shape=[F,C,Hf,Wf])
+                              input_shape=[N,F,Hout,Wout], filter_shape=[C,F,Hf,Wf])
 
   /*
   * Since the forward for transpose convolution makes a call to
@@ -147,7 +147,7 @@ backward = function(matrix[double] dout, int Hout, int Wout,
   * and the filter, keep in mind that the forward function rotates the
   * filter by 180 degrees before applying it.
   */
-  dX = conv2d(dout, W, input_shape=[N,C,Hout,Wout], filter_shape=[F,C,Hf,Wf],
+  dX = conv2d(dout, W, input_shape=[N,F,Hout,Wout], filter_shape=[C,F,Hf,Wf],
               stride=[strideh,stridew], padding=[padh,padw])
 
   db = rowSums(matrix(colSums(dout), rows=F, cols=Hout*Wout))
@@ -171,10 +171,10 @@ init = function(int F, int C, int Hf, int Wf)
  *  - Wf: Filter width.
  *
  * Outputs:
- *  - W: Weights, of shape (F, C*Hf*Wf).
+ *  - W: Weights, of shape (C, F*Hf*Wf).
  *  - b: Biases, of shape (F, 1).
  */
-  W = rand(rows=F, cols=C*Hf*Wf, pdf="normal") * sqrt(2/(C*Hf*Wf))
+  W = rand(rows=C, cols=F*Hf*Wf, pdf="normal") * sqrt(2/(C*Hf*Wf))
   b = matrix(0, rows=F, cols=1)
 }
@@ -187,7 +187,7 @@ init_bilinear = function(int C, int K)
  * channel-wise independent kernels of size K = 2f - f%2,
  * stride = f and pad = ceil((f-1)/2). The weights are set
  * via bilinear interpolation, bias is set to 0.
- *
+ *
  * Inputs:
  *  - C: Number of input channels (dimensionality of depth).
  *  - K: Kernel size (upsampling requires a square filter
@@ -202,7 +202,7 @@ init_bilinear = function(int C, int K)
   vect = 1 - abs(seq(0, K-1) / factor_up - center)
   weights = matrix(vect %*% t(vect), rows=1, cols=K*K)
 
-  /*
+  /*
   * To create a multi-channel channel-independent upsampling filter,
   * we need to intersperse the filter weights with 0s. For instance,
   * consider the case of 2X upsampling.
  * In this case, K=4 and we have
@@ -232,7 +232,7 @@ init_bilinear = function(int C, int K)
   * resulting row C times
   */
   repl_weights = matrix(1, rows=C, cols=1) %*% cbind(weights, matrix(0, rows=1, cols=C*K*K))
-
+
   /*
   * The above operation added extra C*K*K trailing 0s in the last row
   * that we do not need. Thus, we need to:

http://git-wip-us.apache.org/repos/asf/systemml/blob/324bea5d/scripts/nn/test/grad_check.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/grad_check.dml b/scripts/nn/test/grad_check.dml
index 5e57081..48e470c 100644
--- a/scripts/nn/test/grad_check.dml
+++ b/scripts/nn/test/grad_check.dml
@@ -619,57 +619,53 @@ conv2d_simple = function() {
 conv2d_transpose = function() {
   /*
-   * Gradient check for the 2D convolution transpose layer.
+   * Gradient check for the 2D transpose convolutional layer.
    */
-  print("Grad checking the 2D convolution transpose layer with L2 loss.")
-
-  N = 2
-  C = 2
-  Hin = 3
-  Win = 3
-  F = 2
-  Hf = 3
-  Wf = 3
+  print("Grad checking the 2D transpose convolutional layer with L2 loss.")
+
+  # Generate data
+  N = 2  # num examples
+  C = 2  # num channels
+  Hin = 3  # input height
+  Win = 3  # input width
+  F = 2  # num filters
+  Hf = 3  # filter height
+  Wf = 3  # filter width
   stride = 2
   pad = 1
   out_pad = 1
   X = rand(rows=N, cols=C*Hin*Win)
   [W,b] = conv2d_transpose::init(F, C, Hf, Wf)
-
+  # Create layers
+  [W, b] = conv2d_transpose::init(F, C, Hf, Wf)
+
+  # Compute analytical gradients of loss wrt parameters
   [out, Hout, Wout] = conv2d_transpose::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad, out_pad, out_pad)
-  y = rand(rows=N, cols=F*Hout*Wout)
-  dout = l2_loss::backward(out,y)
-  [dX, dW, db] = conv2d_transpose::backward(dout, Hout, Wout, X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
-
+
+  # Grad check
   h = 1e-5
   print(" - Grad checking X.")
   for (i in 1:nrow(X)) {
     for (j in 1:ncol(X)) {
+      # Compute numerical derivative
       old = as.scalar(X[i,j])
       X[i,j] = old - h
-      [outmh, Hout, Wout] = conv2d_transpose::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad, out_pad, out_pad)
-      lossmh = l2_loss::forward(outmh, y)
-
      X[i,j] = old + h
      [outph, Hout, Wout] = conv2d_transpose::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad, out_pad, out_pad)
-      lossph = l2_loss::forward(outph, y)
+      X[i,j] = old  # reset
+      dX_num = (lossph-lossmh) / (2*h)  # numerical derivative
-      X[i,j] = old
-
-      dX_num = (lossph-lossmh) / (2*h)
-
+      # Check error
      rel_error = test_util::check_rel_grad_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
    }
  }
@@ -677,24 +673,20 @@ conv2d_transpose = function() {
   print(" - Grad checking W.")
   for (i in 1:nrow(W)) {
     for (j in 1:ncol(W)) {
+      # Compute numerical derivative
       old = as.scalar(W[i,j])
       W[i,j] = old - h
-      [outmh, Hout, Wout] = conv2d_transpose::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad, out_pad, out_pad)
-      lossmh = l2_loss::forward(outmh, y)
-
      W[i,j] = old + h
      [outph, Hout, Wout] = conv2d_transpose::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad, out_pad, out_pad)
-      lossph = l2_loss::forward(outph, y)
+      W[i,j] = old  # reset
+      dW_num = (lossph-lossmh) / (2*h)  # numerical derivative
-      W[i,j] = old
-
-      dW_num = (lossph-lossmh) / (2*h)
-
+      # Check error
      rel_error = test_util::check_rel_grad_error(as.scalar(dW[i,j]), dW_num, lossph, lossmh)
    }
  }
@@ -702,24 +694,20 @@ conv2d_transpose = function() {
   print(" - Grad checking b.")
   for (i in 1:nrow(b)) {
     for (j in 1:ncol(b)) {
+      # Compute numerical derivative
       old = as.scalar(b[i,j])
       b[i,j] = old - h
-      [outmh, Hout, Wout] = conv2d_transpose::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad, out_pad, out_pad)
-      lossmh = l2_loss::forward(outmh, y)
-
      b[i,j] = old + h
      [outph, Hout, Wout] = conv2d_transpose::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad, out_pad, out_pad)
-      lossph = l2_loss::forward(outph, y)
+      b[i,j] = old  # reset
+      db_num = (lossph-lossmh) / (2*h)  # numerical derivative
-      b[i,j] = old
-
-      db_num = (lossph-lossmh) / (2*h)
-
+      # Check error
      rel_error = test_util::check_rel_grad_error(as.scalar(db[i,j]), db_num, lossph, lossmh)
    }
  }

http://git-wip-us.apache.org/repos/asf/systemml/blob/324bea5d/scripts/nn/test/test.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/test.dml b/scripts/nn/test/test.dml
index 0df7f0f..52fb063 100644
--- a/scripts/nn/test/test.dml
+++ b/scripts/nn/test/test.dml
@@ -111,36 +111,57 @@ conv2d = function() {
 conv2d_transpose = function() {
   /*
-   * Test for the 2D convolution transpose function.
+   * Test for the 2D transpose convolution function.
    */
-  print("Testing the 2D convolution transpose function.")
-
-  N = 1
-  C = 1
-  Hin = 2
-  Win = 2
-  F = 1
-  Hf = 3
-  Wf = 3
+  print("Testing the 2D transpose convolution function.")
+
+  # Generate data
+  N = 2  # num examples
+  C = 3  # num channels
+  Hin = 2  # input height
+  Win = 2  # input width
+  F = 2  # num filters
+  Hf = 3  # filter height
+  Wf = 3  # filter width
   stride = 1
   pad = 0
-  out_pad = 0
-
-  X = matrix(seq(1,N*C*Hin*Win), rows=N, cols=C*Hin*Win)
-  W = matrix(seq(1,F*C*Hf*Wf), rows=F, cols=C*Hf*Wf)
-  b = matrix(0, rows=F, cols=1)
+  out_pad = 0  # padding added to output
+  X = matrix(seq(1,N*C*Hin*Win), rows=N, cols=C*Hin*Win) / (N*C*Hin*Win) * 2 - 1  # normalized
+
+  # Create layer
+  W = matrix(seq(1,C*F*Hf*Wf), rows=C, cols=F*Hf*Wf) / (C*F*Hf*Wf) * 2 - 1  # normalized
+  b = matrix(seq(1,F), rows=F, cols=1) / F^2  # non-zero & non-one
 
   # Forward
-  [out, Hout, Wout] =
-      conv2d_transpose::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad, out_pad, out_pad)
+  [out, Hout, Wout] = conv2d_transpose::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride,
+                                                pad, pad, out_pad, out_pad)
 
   # Equivalency check
-  target = matrix("1 4 7 6 7 23 33 24 19 53 63 42 21 52 59 36", rows=N, cols=C*Hout*Wout)
-
+  target = matrix("1.21296299 2.03703713 1.91666663 1.02777779
                    1.83333337 3.18518519 2.98148131 1.52777767
                    1.5 2.57407403 2.37037039 1.24999988
                    0.78703707 1.25925922 1.17592585 0.69444442

                    0.87962961 1.20370364 1.08333337 0.77777773
                    1.08333337 1.60185182 1.39814818 0.94444442
                    0.75 0.99074072 0.78703701 0.66666657
                    0.62037039 0.75925928 0.67592591 0.6111111


                    0.32407406 0.37037039 0.47222221 0.36111113
                    0.38888881 0.51851851 0.75925928 0.52777779
                    0.72222215 1.24074078 1.48148155 0.91666669
                    0.56481475 0.92592585 1.06481469 0.69444442

                    0.99074078 1.53703713 1.63888896 1.11111116
                    1.63888884 2.93518519 3.17592597 1.94444442
                    1.97222221 3.65740728 3.89814806 2.33333325
                    1.39814818 2.42592597 2.56481481 1.61111116", rows=N, cols=F*Hout*Wout)
+
  for (i in 1:nrow(out)) {
    for(j in 1:ncol(out)) {
-      rel_error = test_util::check_rel_error(as.scalar(out[1,i]),
-                                             as.scalar(target[1,i]), 1e-3, 1e-4)
+      rel_error = test_util::check_rel_error(as.scalar(out[i,j]),
+                                             as.scalar(target[i,j]), 1e-3, 1e-4)
    }
  }
 }
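For reference, the numerical checks in `grad_check.dml` above follow the standard
centered-difference recipe. The sketch below is not part of the commit; the sizes, the
perturbed index, and the `l2_loss` source path are assumptions (paths are taken relative to
the `scripts` directory). It shows the same pattern standalone for a single entry of `W`,
with `C != F` so that the corrected `(C, F*Hf*Wf)` filter shape is actually exercised:

  source("nn/layers/conv2d_transpose.dml") as conv2d_transpose
  source("nn/layers/l2_loss.dml") as l2_loss

  # Small problem with C != F.
  N = 2
  C = 2
  Hin = 3
  Win = 3
  F = 3
  Hf = 3
  Wf = 3
  stride = 2
  pad = 1
  out_pad = 1
  h = 1e-5  # centered-difference step size

  X = rand(rows=N, cols=C*Hin*Win)
  [W, b] = conv2d_transpose::init(F, C, Hf, Wf)

  # Analytical gradients of the L2 loss wrt the parameters.
  [out, Hout, Wout] = conv2d_transpose::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride,
                                                pad, pad, out_pad, out_pad)
  y = rand(rows=N, cols=F*Hout*Wout)
  dout = l2_loss::backward(out, y)
  [dX, dW, db] = conv2d_transpose::backward(dout, Hout, Wout, X, W, b, C, Hin, Win, Hf, Wf,
                                            stride, stride, pad, pad)

  # Numerical gradient for one entry of W via centered differences.
  i = 1
  j = 1
  old = as.scalar(W[i,j])
  W[i,j] = old - h
  [outmh, Hout, Wout] = conv2d_transpose::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride,
                                                  pad, pad, out_pad, out_pad)
  lossmh = l2_loss::forward(outmh, y)
  W[i,j] = old + h
  [outph, Hout, Wout] = conv2d_transpose::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride,
                                                  pad, pad, out_pad, out_pad)
  lossph = l2_loss::forward(outph, y)
  W[i,j] = old  # reset
  dW_num = (lossph - lossmh) / (2*h)

  print("dW[1,1] analytical: " + as.scalar(dW[i,j]) + ", numerical: " + dW_num)

The analytical value `dW[i,j]` from `backward` and the centered-difference estimate
`dW_num` should agree to within the relative-error tolerances used by
`test_util::check_rel_grad_error` in the tests above.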
