[SYSTEMML-1676] Add a new 2D softmax layer to the `nn` library

A 2D softmax layer would accept a tensor of shape (N,C,H,W), where the C
axis contains scores for C classes, and output a tensor of the same shape,
with the scores transformed to normalized probabilities. The typical use
case would be a segmentation problem, in which every pixel has a
multiclass prediction.
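As a minimal usage sketch (hypothetical shapes and variable names; the layer
operates on the flattened (N, C*Hin*Win) layout used throughout the `nn`
library):

  source("nn/layers/softmax2d.dml") as softmax2d

  N = 4    # num examples
  C = 10   # num classes
  Hin = 8  # example height
  Win = 8  # example width

  # Unnormalized per-pixel class scores, e.g. the output of a previous layer.
  scores = rand(rows=N, cols=C*Hin*Win)
  # Per-pixel probabilities over the C classes, same shape as `scores`.
  probs = softmax2d::forward(scores, C)
  # Backward pass, given an upstream gradient (placeholder values here).
  dprobs = rand(rows=N, cols=C*Hin*Win)
  dscores = softmax2d::backward(dprobs, scores, C)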
Closes #555.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/9e6715da
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/9e6715da
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/9e6715da

Branch: refs/heads/master
Commit: 9e6715dafa7d47104ad8b6c6c1d39fc9423905aa
Parents: 5da30b5
Author: Fei Hu <[email protected]>
Authored: Fri Jun 30 16:32:08 2017 -0700
Committer: Mike Dusenberry <[email protected]>
Committed: Fri Jun 30 16:32:08 2017 -0700

----------------------------------------------------------------------
 scripts/nn/layers/softmax2d.dml | 116 +++++++++++++++++++++++++++++++++++
 scripts/nn/test/grad_check.dml  |  56 +++++++++++++++++
 scripts/nn/test/run_tests.dml   |   2 +
 scripts/nn/test/test.dml        |  59 ++++++++++++++++++
 4 files changed, 233 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/9e6715da/scripts/nn/layers/softmax2d.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/layers/softmax2d.dml b/scripts/nn/layers/softmax2d.dml
new file mode 100644
index 0000000..aad587d
--- /dev/null
+++ b/scripts/nn/layers/softmax2d.dml
@@ -0,0 +1,116 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * 2D Softmax classifier layer.
+ */
+
+source("nn/util.dml") as util
+source("nn/layers/softmax.dml") as softmax
+
+forward = function(matrix[double] scores, int C)
+    return (matrix[double] probs) {
+  /*
+   * Computes the forward pass for a softmax2d classifier.  The input
+   * has four dimensions (N, C, Hin, Win), i.e. N 2D examples of shape
+   * (Hin, Win), where each pixel has C values that are interpreted as
+   * unnormalized log-probabilities for each of the C classes.  The
+   * softmax function transforms these values to normalized
+   * probabilities across the C classes, for every pixel of every
+   * example.
+   *
+   * This can be interpreted as a generalization of the sigmoid
+   * function to multiple classes.
+   *
+   *   `probs_ijk = e^scores_ijk / sum(e^scores_ijk)`
+   *
+   * In this equation, `probs_ijk` is the C-dimensional vector of
+   * normalized probabilities for pixel `(j, k)` of example `i`.
+   *
+   * Inputs:
+   *  - scores: Inputs, of shape (N, C*Hin*Win).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *
+   * Outputs:
+   *  - probs: Outputs, of shape (N, C*Hin*Win).
+   */
+
+  # For numerical stability, we subtract the max score of an example from all scores for that
+  # example.  This is equivalent to the original formulation:
+  #   e^scores_ijk / sum(e^scores_ijk) == C*e^scores_ijk / C*sum(e^scores_ijk)
+  #                                    == e^(scores_ijk+log(C)) / sum(e^(scores_ijk+log(C)))
+  # Setting log(C) = -max(scores_ijk):
+  #                                    == e^(scores_ijk-max(scores_ijk)) / sum(e^(scores_ijk-max(scores_ijk)))
+
+  N = nrow(scores)
+
+  # Transpose the matrix from (N, C*H*W) to (N*H*W, C)
+  scores_C_NHW = util::transpose_NCHW_to_CNHW(scores, C)
+  scores_NHW_C = t(scores_C_NHW)
+
+  probs_NHW_C = softmax::forward(scores_NHW_C)
+
+  # Transpose the matrix from (N*H*W, C) to (N, C*H*W)
+  probs_C_NHW = t(probs_NHW_C)
+  probs = util::transpose_NCHW_to_CNHW(probs_C_NHW, N)
+}
+
+backward = function(matrix[double] dprobs, matrix[double] scores, int C)
+    return (matrix[double] dscores) {
+  /*
+   * Computes the backward pass for a softmax2d classifier.
+   *
+   * Note that dscores_ij has multiple source branches:
+   *
+   *   ```
+   *   dprobs_ij/dscores_ij = probs_ij * (1 - probs_ij)
+   *   dprobs_ik/dscores_ij = -probs_ik * probs_ij, for all k != j
+   *
+   *   dloss/dscores_ij =
+   *      (dloss/dprobs_ij * dprobs_ij/dscores_ij)
+   *      + sum_{k!=j}(dloss/dprobs_ik * dprobs_ik/dscores_ij)
+   *   ```
+   *
+   * Inputs:
+   *  - dprobs: Gradient wrt `probs` from upstream, of shape (N, C*Hin*Win).
+   *  - scores: Inputs, of shape (N, C*Hin*Win).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *
+   * Outputs:
+   *  - dscores: Gradient wrt `scores`, of shape (N, C*Hin*Win).
+   */
+
+  N = nrow(scores)
+
+  # Transpose the matrix from (N, C*H*W) to (N*H*W, C)
+  dprobs_C_NHW = util::transpose_NCHW_to_CNHW(dprobs, C)
+  dprobs_NHW_C = t(dprobs_C_NHW)
+
+  # Transpose the matrix from (N, C*H*W) to (N*H*W, C)
+  scores_C_NHW = util::transpose_NCHW_to_CNHW(scores, C)
+  scores_NHW_C = t(scores_C_NHW)
+
+  dscores_NHW_C = softmax::backward(dprobs_NHW_C, scores_NHW_C)
+
+  # Transpose the matrix from (N*H*W, C) to (N, C*H*W)
+  dscores_C_NHW = t(dscores_NHW_C)
+  dscores = util::transpose_NCHW_to_CNHW(dscores_C_NHW, N)
+}


http://git-wip-us.apache.org/repos/asf/systemml/blob/9e6715da/scripts/nn/test/grad_check.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/grad_check.dml b/scripts/nn/test/grad_check.dml
index fcb45cd..6844f40 100644
--- a/scripts/nn/test/grad_check.dml
+++ b/scripts/nn/test/grad_check.dml
@@ -46,10 +46,12 @@ source("nn/layers/scale_shift1d.dml") as scale_shift1d
 source("nn/layers/scale_shift2d.dml") as scale_shift2d
 source("nn/layers/sigmoid.dml") as sigmoid
 source("nn/layers/softmax.dml") as softmax
+source("nn/layers/softmax2d.dml") as softmax2d
 source("nn/layers/tanh.dml") as tanh
 source("nn/test/conv2d_simple.dml") as conv2d_simple
 source("nn/test/max_pool2d_simple.dml") as max_pool2d_simple
 source("nn/test/util.dml") as test_util
+source("nn/util.dml") as util
 
 affine = function() {
   /*
@@ -1844,6 +1846,60 @@ softmax = function() {
   }
 }
 
+softmax2d = function() {
+  /*
+   * Gradient check for the 2D softmax layer.
+   */
+  print("Grad checking the 2D softmax layer with L2 loss.")
+
+  # Generate data
+  N = 3    # num examples
+  C = 10   # num classes
+  Hin = 5  # example height
+  Win = 5  # example width
+
+  X = rand(rows=N, cols=C*Hin*Win)
+  y = rand(rows=N, cols=C*Hin*Win, min=0, max=1, pdf="uniform")
+  y_C_NHW = util::transpose_NCHW_to_CNHW(y, C)
+  y_NHW_C = t(y_C_NHW)
+  y_NHW_C = y_NHW_C / rowSums(y_NHW_C)
+
+  # Compute analytical gradients of loss wrt parameters
+  out = softmax2d::forward(X, C)
+  out_C_NHW = util::transpose_NCHW_to_CNHW(out, C)
+  out_NHW_C = t(out_C_NHW)
+
+  dout_NHW_C = l2_loss::backward(out_NHW_C, y_NHW_C)
+  dout_C_NHW = t(dout_NHW_C)
+  dout = util::transpose_NCHW_to_CNHW(dout_C_NHW, N)
+  dX = softmax2d::backward(dout, X, C)
+
+  # Grad check
+  h = 1e-5
+  for (i in 1:nrow(X)) {
+    for (j in 1:ncol(X)) {
+      # Compute numerical derivative
+      old = as.scalar(X[i,j])
+      X[i,j] = old - h
+      outmh = softmax2d::forward(X, C)
+      outmh_C_NHW = util::transpose_NCHW_to_CNHW(outmh, C)
+      outmh_NHW_C = t(outmh_C_NHW)
+      lossmh = l2_loss::forward(outmh_NHW_C, y_NHW_C)
+
+      X[i,j] = old + h
+      outph = softmax2d::forward(X, C)
+      outph_C_NHW = util::transpose_NCHW_to_CNHW(outph, C)
+      outph_NHW_C = t(outph_C_NHW)
+      lossph = l2_loss::forward(outph_NHW_C, y_NHW_C)
+      X[i,j] = old  # reset
+      dX_num = (lossph-lossmh) / (2*h)  # numerical derivative
+
+      # Check error
+      rel_error = test_util::check_rel_grad_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+    }
+  }
+}
+
 tanh = function() {
   /*
    * Gradient check for the hyperbolic tangent (tanh) nonlinearity


http://git-wip-us.apache.org/repos/asf/systemml/blob/9e6715da/scripts/nn/test/run_tests.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/run_tests.dml b/scripts/nn/test/run_tests.dml
index f4c33d8..0662ffa 100644
--- a/scripts/nn/test/run_tests.dml
+++ b/scripts/nn/test/run_tests.dml
@@ -59,6 +59,7 @@ grad_check::scale_shift1d()
 grad_check::scale_shift2d()
 grad_check::sigmoid()
 grad_check::softmax()
+grad_check::softmax2d()
 grad_check::tanh()
 
 print("")
@@ -101,6 +102,7 @@ test::transpose_NCHW_to_CNHW()
 test::top_k_row()
 test::top_k()
 test::top_k2d()
+test::softmax2d()
 
 print("---")
 print("Other tests complete -- look for any ERRORs or WARNINGs.")


http://git-wip-us.apache.org/repos/asf/systemml/blob/9e6715da/scripts/nn/test/test.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/test.dml b/scripts/nn/test/test.dml
index 1364583..adaef5c 100644
--- a/scripts/nn/test/test.dml
+++ b/scripts/nn/test/test.dml
@@ -33,6 +33,7 @@ source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
 source("nn/layers/max_pool2d.dml") as max_pool2d
 source("nn/layers/max_pool2d_builtin.dml") as max_pool2d_builtin
 source("nn/layers/tanh.dml") as tanh
+source("nn/layers/softmax2d.dml") as softmax2d
 source("nn/test/conv2d_simple.dml") as conv2d_simple
 source("nn/test/max_pool2d_simple.dml") as max_pool2d_simple
 source("nn/test/util.dml") as test_util
@@ -923,3 +924,61 @@ top_k2d = function() {
   check_indices = test_util::check_all_equal(indices, expected_indices)
 }
 
+softmax2d = function() {
+  /*
+   * Test for 2D softmax function.
+   */
+  print("Testing the 2D softmax function.")
+
+  N = 2    # num examples
+  C = 3    # num classes
+  Hin = 3  # example height
+  Win = 3  # example width
+  X = matrix("10.0 10.0 0.0
+              10.0 0.0 0.0
+              10.0 0.0 0.0
+              0.0 10.0 0.0
+              0.0 10.0 0.0
+              0.0 10.0 0.0
+              0.0 0.0 10.0
+              0.0 0.0 10.0
+              0.0 0.0 10.0
+              10.0 10.0 0.0
+              10.0 0.0 0.0
+              10.0 0.0 0.0
+              0.0 10.0 0.0
+              0.0 10.0 0.0
+              0.0 10.0 0.0
+              0.0 0.0 10.0
+              0.0 0.0 10.0
+              0.0 0.0 10.0", rows=N, cols=C*Hin*Win)
+
+  probs_expected = matrix("9.99909163e-01 4.99988675e-01 4.53958055e-05
+                           9.99909163e-01 4.53958055e-05 4.53958055e-05
+                           9.99909163e-01 4.53958055e-05 4.53958055e-05
+                           4.53958055e-05 4.99988675e-01 4.53958055e-05
+                           4.53958055e-05 9.99909163e-01 4.53958055e-05
+                           4.53958055e-05 9.99909163e-01 4.53958055e-05
+                           4.53958055e-05 2.26994507e-05 9.99909163e-01
+                           4.53958055e-05 4.53958055e-05 9.99909163e-01
+                           4.53958055e-05 4.53958055e-05 9.99909163e-01
+                           9.99909163e-01 4.99988675e-01 4.53958055e-05
+                           9.99909163e-01 4.53958055e-05 4.53958055e-05
+                           9.99909163e-01 4.53958055e-05 4.53958055e-05
+                           4.53958055e-05 4.99988675e-01 4.53958055e-05
+                           4.53958055e-05 9.99909163e-01 4.53958055e-05
+                           4.53958055e-05 9.99909163e-01 4.53958055e-05
+                           4.53958055e-05 2.26994507e-05 9.99909163e-01
+                           4.53958055e-05 4.53958055e-05 9.99909163e-01
+                           4.53958055e-05 4.53958055e-05 9.99909163e-01", rows=N, cols=C*Hin*Win)
+  probs = softmax2d::forward(X, C)
+
+  # Equivalency check
+  for (i in 1:nrow(probs)) {
+    for (j in 1:ncol(probs)) {
+      rel_error = test_util::check_rel_error(as.scalar(probs[i,j]), as.scalar(probs_expected[i,j]),
+                                             1e-5, 1e-6)
+    }
+  }
+}
+
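A further sanity check one could add to the test above (a sketch, not part of
this commit; it assumes `X`, `C`, `softmax2d`, and the `util` reshaping helper
are available as in the surrounding test code): after the forward pass, each
pixel's probabilities should sum to 1 across the C classes.

  probs = softmax2d::forward(X, C)
  # Reshape from (N, C*Hin*Win) to (N*Hin*Win, C) so that each row holds one
  # pixel's probabilities over the C classes.
  probs_C_NHW = util::transpose_NCHW_to_CNHW(probs, C)
  probs_NHW_C = t(probs_C_NHW)
  # Each row should sum to ~1.
  sums = rowSums(probs_NHW_C)
  max_dev = max(abs(sums - 1))
  print("Max deviation of per-pixel probability sums from 1: " + max_dev)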
