Repository: systemml Updated Branches: refs/heads/master 9f808c43e -> 5e7e57774
[SYSTEMML-1678] Add a new 1D top_k utility function This function computes the top k values (i.e. probabilities) and associated indices (i.e. classes) from the input matrix X. A typical use case is that in which X is the output of a softmax layer, and values and indices will contain rows with the top k probabilities and class indices. Closes #551. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/2e78eb9a Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/2e78eb9a Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/2e78eb9a Branch: refs/heads/master Commit: 2e78eb9a56148f9e27ea1cb646e2d6dd528d251e Parents: 9f808c4 Author: Fei Hu <[email protected]> Authored: Wed Jun 28 12:23:08 2017 -0700 Committer: Mike Dusenberry <[email protected]> Committed: Wed Jun 28 12:23:08 2017 -0700 ---------------------------------------------------------------------- scripts/nn/test/run_tests.dml | 2 ++ scripts/nn/test/test.dml | 66 +++++++++++++++++++++++++++++++++++++- scripts/nn/util.dml | 59 ++++++++++++++++++++++++++++++++++ 3 files changed, 126 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/2e78eb9a/scripts/nn/test/run_tests.dml ---------------------------------------------------------------------- diff --git a/scripts/nn/test/run_tests.dml b/scripts/nn/test/run_tests.dml index cca0d0d..4cc2875 100644 --- a/scripts/nn/test/run_tests.dml +++ b/scripts/nn/test/run_tests.dml @@ -98,6 +98,8 @@ test::padding() test::tanh() test::threshold() test::transpose_NCHW_to_CNHW() +test::top_k_row() +test::top_k() print("---") print("Other tests complete -- look for any ERRORs or WARNINGs.") http://git-wip-us.apache.org/repos/asf/systemml/blob/2e78eb9a/scripts/nn/test/test.dml ---------------------------------------------------------------------- diff --git a/scripts/nn/test/test.dml 
b/scripts/nn/test/test.dml index b0899e3..b1190fc 100644 --- a/scripts/nn/test/test.dml +++ b/scripts/nn/test/test.dml @@ -768,7 +768,7 @@ tanh = function() { threshold = function() { /* - * Test for threshold function. + * Test for the threshold function. */ print("Testing the threshold function.") @@ -788,3 +788,67 @@ threshold = function() { out = test_util::check_all_equal(indicator_matrix, target_matrix) } +top_k_row = function() { + /* + * Test for the top_k_row function. + */ + print("Testing the top_k_row function.") + + #Generate data + X = matrix("2 3 2 1 19 20 31 12 3 4 5 60 9 2 + 3 6 15 18 6 12 1 17 3 0 4 19 1 6", rows=2, cols=14) + r = 2 + k = 3 + expected_values = matrix("19 + 18 + 17", rows=1, cols=3) + expected_indices = matrix("12 + 4 + 8", rows=1, cols=3) + + # Test the top 3 for the second row + [values, indices] = util::top_k_row(X, 2, 3) + check_values = test_util::check_all_equal(values, expected_values) + check_indices = test_util::check_all_equal(indices, expected_indices) +} + +top_k = function() { + /* + * Test for the top_k function. 
+ */ + print("Testing the top_k function.") + + # Generate data + X = matrix("0.1 0.3 0.2 0.4 + 0.1 0.3 0.3 0.2", rows=2, cols=4) + expected_values_top1 = matrix("0.4 + 0.3", rows=2, cols=1) + expected_indices_top1 = matrix("4 + 2", rows=2, cols=1) + expected_values_top2 = matrix("0.4 0.3 + 0.3 0.3", rows=2, cols=2) + expected_indices_top2 = matrix("4 2 + 2 3", rows=2, cols=2) + expected_values_topAll = matrix("0.4 0.3 0.2 0.1 + 0.3 0.3 0.2 0.1", rows=2, cols=4) + expected_indices_topAll = matrix("4 2 3 1 + 2 3 4 1", rows=2, cols=4) + + # test top_1 + print("Case 1: test top_1") + [values_top1, indices_top1] = util::top_k(X, 1) + check_values_top1 = test_util::check_all_equal(values_top1, expected_values_top1) + check_indices_top1 = test_util::check_all_equal(indices_top1, expected_indices_top1) + + # test top_2 + print("Case 2: test top_2") + [values_top2, indices_top2] = util::top_k(X, 2) + check_values_top2 = test_util::check_all_equal(values_top2, expected_values_top2) + check_indices_top2 = test_util::check_all_equal(indices_top2, expected_indices_top2) + + # test top_All + print("Case 3: test top_All") + [values_topAll, indices_topAll] = util::top_k(X, 4) + check_values_topAll = test_util::check_all_equal(values_topAll, expected_values_topAll) + check_indices_topAll = test_util::check_all_equal(indices_topAll, expected_indices_topAll) +} http://git-wip-us.apache.org/repos/asf/systemml/blob/2e78eb9a/scripts/nn/util.dml ---------------------------------------------------------------------- diff --git a/scripts/nn/util.dml b/scripts/nn/util.dml index 329f22f..fb54c43 100644 --- a/scripts/nn/util.dml +++ b/scripts/nn/util.dml @@ -230,6 +230,8 @@ transpose_NCHW_to_CNHW = function(matrix[double] X, int C) return (matrix[double * Outputs: * - out: Transposed output with C rows. 
*/ + + if(1==1){} N = nrow(X) D = ncol(X) / C @@ -268,3 +270,60 @@ transpose_NCHW_to_CNHW = function(matrix[double] X, int C) return (matrix[double */ out = matrix(permut %*% matrix(X, rows=N*C, cols=D), rows=C, cols=N*D) } + +top_k_row = function(matrix[double] X, integer r, integer k) + return (matrix[double] values, matrix[double] indices) { + /* + * Computes the top k values (i.e. probabilities) and associated + * indices (i.e. classes) in the rth row of the input matrix X. + * + * Inputs: + * - X: Inputs, of shape (N, D). + * - r: Input row number of X to look for + * - k: Input number of top elements to look for + * + * Outputs: + * - values: The top k values at the rth row, of shape + * (1, k) + * - indices: The class indices, of shape (1, k) + */ + + #TODO: do r & k need to be checked in the valid range + row = X[r, ] + row_t = t(row) + indices = order(target=row_t, by=1, decreasing=TRUE, index.return=TRUE) + indices = t(indices) + indices = indices[1, 1:k] + + values = matrix(0, rows=1, cols=k) + for (i in 1:k) { + values[1, i] = row[1, as.scalar(indices[1, i])] + } +} + +top_k = function(matrix[double] X, integer k) + return (matrix[double] values, matrix[double] indices) { + /* + * Computes the top k values (i.e. probabilities) and associated + * indices (i.e. classes) for the input matrix X. + * + * Inputs: + * - X: Inputs, of shape (N, D). + * - k: Input number of top elements to look for + * + * Outputs: + * - values: The top k values along a certain dimension, of shape + * (N, k) + * - indices: The indices of classes, of shape (N, k) + */ + N = nrow(X) + D = ncol(X) + values = matrix(0, rows=N, cols=k) + indices = matrix(0, rows=N, cols=k) + + parfor (r in 1:N) { + [value, index] = top_k_row(X, r, k) + values[r, ] = value + indices[r, ] = index + } +}
