[SYSTEMML-1736] Add a new 2D top_k utility function This function computes the top k values (i.e. probabilities) and associated indices (i.e. classes) from the input matrix X. A typical use case here is that in which X is the output of a 2D softmax layer, so each channel contains a set of normalized class probabilities, and values and indices will contain the top k probabilities and indices along the channel axis. This scenario is common in image segmentation problems.
Closes #551. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/5e7e5777 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/5e7e5777 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/5e7e5777 Branch: refs/heads/master Commit: 5e7e57774b1936736d8551bacf3dffd60bc45071 Parents: 2e78eb9 Author: Fei Hu <hufe...@gmail.com> Authored: Wed Jun 28 12:23:13 2017 -0700 Committer: Mike Dusenberry <mwdus...@us.ibm.com> Committed: Wed Jun 28 12:23:13 2017 -0700 ---------------------------------------------------------------------- scripts/nn/test/run_tests.dml | 1 + scripts/nn/test/test.dml | 85 ++++++++++++++++++++++++++++++++++---- scripts/nn/util.dml | 62 ++++++++++++++++++++------- 3 files changed, 127 insertions(+), 21 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/5e7e5777/scripts/nn/test/run_tests.dml ---------------------------------------------------------------------- diff --git a/scripts/nn/test/run_tests.dml b/scripts/nn/test/run_tests.dml index 4cc2875..f4c33d8 100644 --- a/scripts/nn/test/run_tests.dml +++ b/scripts/nn/test/run_tests.dml @@ -100,6 +100,7 @@ test::threshold() test::transpose_NCHW_to_CNHW() test::top_k_row() test::top_k() +test::top_k2d() print("---") print("Other tests complete -- look for any ERRORs or WARNINGs.") http://git-wip-us.apache.org/repos/asf/systemml/blob/5e7e5777/scripts/nn/test/test.dml ---------------------------------------------------------------------- diff --git a/scripts/nn/test/test.dml b/scripts/nn/test/test.dml index b1190fc..a1bb6cd 100644 --- a/scripts/nn/test/test.dml +++ b/scripts/nn/test/test.dml @@ -493,17 +493,17 @@ transpose_NCHW_to_CNHW = function() { * Test for `transpose_NCHW_to_CNHW` function. 
*/ print("Testing transpose_NCHW_to_CNHW function.") - + # Generate data N = 2 C = 3 H = 4 W = 5 X = matrix(seq(1, N*C*H*W), rows=N, cols=C*H*W) - + out = util::transpose_NCHW_to_CNHW(X, C) - - target = + + target = matrix("1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 @@ -511,7 +511,7 @@ transpose_NCHW_to_CNHW = function() { 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120", rows=C, cols=N*H*W) - + # Equivalency check for (i in 1:nrow(out)) { for(j in 1:ncol(out)) { @@ -790,7 +790,7 @@ threshold = function() { top_k_row = function() { /* - * Test for the top_k function. + * Test for the top_k_row function. */ print("Testing the top_k_row function.") @@ -806,7 +806,7 @@ top_k_row = function() { 4 8", rows=1, cols=3) - # Test the top 3 for the second row + # Test the top 3 for the second row. [values, indices] = util::top_k_row(X, 2, 3) check_values = test_util::check_all_equal(values, expected_values) check_indices = test_util::check_all_equal(indices, expected_indices) @@ -852,3 +852,74 @@ top_k = function() { check_values_topAll = test_util::check_all_equal(values_topAll, expected_values_topAll) check_indices_topAll = test_util::check_all_equal(indices_topAll, expected_indices_topAll) } + +top_k2d = function() { + /* + * Test for the top_k2d function. 
+ */ + print("Testing the top_k2d function.") + # Generate data, of shape (2, 3, 3, 4) + k = 2 + X = matrix("0.1 0.4 0.4 0.5 + 0.4 0.1 0.6 0.1 + 0.7 0.7 0.3 0.2 + + 0.2 0.5 0.4 0.5 + 0.4 0.1 0.6 0.1 + 0.7 0.8 0.3 0.2 + + 0.3 0.4 0.4 0.5 + 0.4 0.1 0.6 0.1 + 0.7 0.2 0.3 0.2 + + 0.1 0.4 0.4 0.5 + 0.4 0.1 0.6 0.1 + 0.7 0.7 0.3 0.2 + + 0.2 0.5 0.4 0.5 + 0.4 0.1 0.6 0.1 + 0.7 0.8 0.3 0.2 + + 0.3 0.4 0.4 0.5 + 0.4 0.1 0.6 0.1 + 0.7 0.2 0.3 0.2", rows=2, cols=3*3*4) + + expected_values = matrix("0.3 0.5 0.4 0.5 + 0.4 0.1 0.6 0.1 + 0.7 0.8 0.3 0.2 + + 0.2 0.4 0.4 0.5 + 0.4 0.1 0.6 0.1 + 0.7 0.7 0.3 0.2 + + 0.3 0.5 0.4 0.5 + 0.4 0.1 0.6 0.1 + 0.7 0.8 0.3 0.2 + + 0.2 0.4 0.4 0.5 + 0.4 0.1 0.6 0.1 + 0.7 0.7 0.3 0.2", rows=2, cols=2*3*4) + + expected_indices = matrix("3 2 1 1 + 1 1 1 1 + 1 2 1 1 + + 2 1 2 2 + 2 2 2 2 + 2 1 2 2 + + 3 2 1 1 + 1 1 1 1 + 1 2 1 1 + + 2 1 2 2 + 2 2 2 2 + 2 1 2 2", rows=2, cols=24) + + [values, indices] = util::top_k2d(X, k, 3, 3, 4) + + # Equivalency check + check_values = test_util::check_all_equal(values, expected_values) + check_indices = test_util::check_all_equal(indices, expected_indices) +} + http://git-wip-us.apache.org/repos/asf/systemml/blob/5e7e5777/scripts/nn/util.dml ---------------------------------------------------------------------- diff --git a/scripts/nn/util.dml b/scripts/nn/util.dml index fb54c43..102a507 100644 --- a/scripts/nn/util.dml +++ b/scripts/nn/util.dml @@ -204,7 +204,7 @@ threshold = function(matrix[double] X, double thresh) return (matrix[double] out) { /* * Computes an indicator matrix with values in {0, 1} depending on - * whether or not the values in X are above the input threshold + * whether or not the values in X are above the input threshold. * * Inputs: * - X: Inputs, of shape (any, any). @@ -231,7 +231,6 @@ transpose_NCHW_to_CNHW = function(matrix[double] X, int C) return (matrix[double * - out: Transposed output with C rows. 
*/ - if(1==1){} N = nrow(X) D = ncol(X) / C @@ -275,17 +274,17 @@ top_k_row = function(matrix[double] X, integer r, integer k) return (matrix[double] values, matrix[double] indices) { /* * Computes the top k values (i.e. probabilities) and associated - * indices (i.e. classes) in the rth row of the input matrix X + * indices (i.e. classes) in the rth row of the input matrix X. * * Inputs: - * - X: Inputs, of shape (N D). - * - r: Input row number of X to look for - * - k: Input number of top elements to look for + * - X: Inputs, of shape (N, D). + * - r: Input row number of X to look for. + * - k: Input number of top elements to look for. * * Outputs: - * - values: The top k values at the rth row, of shape - * (1, k) - * - indices: The class indices, of shape (1, k) + * - values: The top k values at the rth row, of shape + * (1, k). + * - indices: The class indices, of shape (1, k). */ #TODO: do r & k need to be checked in the valid range @@ -308,13 +307,13 @@ top_k = function(matrix[double] X, integer k) * indices (i.e. classes) for the input matrix X. * * Inputs: - * - X: Inputs, of shape (N D). - * - k: Input number of top elements to look for + * - X: Inputs, of shape (N, D). + * - k: Input number of top elements to look for. * * Outputs: - * - values: The top k values along a certain dimension, of shape - * (N, k) - * - indices: The indices of classes, of shape (N, K) + * - values: The top k values along a certain dimension, of shape + * (N, k). + * - indices: The indices of classes, of shape (N, K). */ N = nrow(X) D = ncol(X) @@ -327,3 +326,38 @@ top_k = function(matrix[double] X, integer k) indices[r, ] = index } } + +top_k2d = function(matrix[double] X, int k, int C, int Hin, int Win) + return (matrix[double] values, matrix[double] indices) { + /* + * Computes the top k values (i.e. probabilities) and associated + * indices (i.e. classes) for the input matrix X. + * + * Inputs: + * - X: Inputs, of shape (N, C*Hin*Win). 
+ * - k: Input number of top elements to look for. + * - C: Number of input channels (dimensionality of input depth). + * - Hin: Input height. + * - Win: Input width. + * + * Outputs: + * - values: The top k values along a certain dimension, of shape + * (N, k*Hin*Win). + * - indices: The indices of classes, of shape (N, k*Hin*Win). + */ + N = nrow(X) + + # Reshape the input matrix (N, C*Hin*Win) to (N*Hin*Win, C) + X_C_NHW = transpose_NCHW_to_CNHW(X, C) + X_NHW_C = t(X_C_NHW) + + # Compute the top k for the reshape matrix. + [values_NHW_K, indices_NHW_K] = top_k(X_NHW_C, k) # shape: (N*Hin*Win, k) + + values_K_NHW = t(values_NHW_K) + indices_K_NHW = t(indices_NHW_K) + + values = transpose_NCHW_to_CNHW(values_K_NHW, N) + indices = transpose_NCHW_to_CNHW(indices_K_NHW, N) +} +