Repository: systemml
Updated Branches:
  refs/heads/master 9f808c43e -> 5e7e57774


[SYSTEMML-1678] Add a new 1D top_k utility function

This function computes the top k values (i.e. probabilities) and
associated indices (i.e. classes) for each row of the input matrix X.
A typical use case is one in which X is the output of a softmax layer,
so that the rows of values and indices contain the top k probabilities
and their class indices.

Closes #551.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/2e78eb9a
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/2e78eb9a
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/2e78eb9a

Branch: refs/heads/master
Commit: 2e78eb9a56148f9e27ea1cb646e2d6dd528d251e
Parents: 9f808c4
Author: Fei Hu <[email protected]>
Authored: Wed Jun 28 12:23:08 2017 -0700
Committer: Mike Dusenberry <[email protected]>
Committed: Wed Jun 28 12:23:08 2017 -0700

----------------------------------------------------------------------
 scripts/nn/test/run_tests.dml |  2 ++
 scripts/nn/test/test.dml      | 66 +++++++++++++++++++++++++++++++++++++-
 scripts/nn/util.dml           | 59 ++++++++++++++++++++++++++++++++++
 3 files changed, 126 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/2e78eb9a/scripts/nn/test/run_tests.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/run_tests.dml b/scripts/nn/test/run_tests.dml
index cca0d0d..4cc2875 100644
--- a/scripts/nn/test/run_tests.dml
+++ b/scripts/nn/test/run_tests.dml
@@ -98,6 +98,8 @@ test::padding()
 test::tanh()
 test::threshold()
 test::transpose_NCHW_to_CNHW()
+test::top_k_row()
+test::top_k()
 
 print("---")
 print("Other tests complete -- look for any ERRORs or WARNINGs.")

http://git-wip-us.apache.org/repos/asf/systemml/blob/2e78eb9a/scripts/nn/test/test.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/test.dml b/scripts/nn/test/test.dml
index b0899e3..b1190fc 100644
--- a/scripts/nn/test/test.dml
+++ b/scripts/nn/test/test.dml
@@ -768,7 +768,7 @@ tanh = function() {
 
 threshold = function() {
   /*
-   * Test for threshold function.
+   * Test for the threshold function.
    */
   print("Testing the threshold function.")
 
@@ -788,3 +788,67 @@ threshold = function() {
   out = test_util::check_all_equal(indicator_matrix, target_matrix)
 }
 
+top_k_row = function() {
+  /*
+   * Test for the top_k_row function.
+   */
+  print("Testing the top_k_row function.")
+
+  # Generate data
+  X = matrix("2 3 2 1 19 20 31 12 3 4 5 60 9 2
+              3 6 15 18 6 12 1 17 3 0 4 19 1 6", rows=2, cols=14)
+  r = 2  # row of X to search
+  k = 3  # number of top elements to return
+  expected_values = matrix("19
+                            18
+                            17", rows=1, cols=3)
+  expected_indices = matrix("12
+                             4
+                             8", rows=1, cols=3)
+
+  # Top k of row r: 19 (col 12), 18 (col 4), 17 (col 8)
+  [values, indices] = util::top_k_row(X, r, k)
+  check_values = test_util::check_all_equal(values, expected_values)
+  check_indices = test_util::check_all_equal(indices, expected_indices)
+}
+
+top_k = function() {
+  /*
+   * Test for the top_k function.
+   */
+  print("Testing the top_k function.")
+
+  # Generate data; row 2 has a tie (0.3 in cols 2 and 3), earlier col first
+  X = matrix("0.1 0.3 0.2 0.4
+              0.1 0.3 0.3 0.2", rows=2, cols=4)
+  expected_values_top1 = matrix("0.4
+                                 0.3", rows=2, cols=1)
+  expected_indices_top1 = matrix("4
+                                  2", rows=2, cols=1)
+  expected_values_top2 = matrix("0.4 0.3
+                                 0.3 0.3", rows=2, cols=2)
+  expected_indices_top2 = matrix("4 2
+                                  2 3", rows=2, cols=2)
+  expected_values_topAll = matrix("0.4 0.3 0.2 0.1
+                                   0.3 0.3 0.2 0.1", rows=2, cols=4)
+  expected_indices_topAll = matrix("4 2 3 1
+                                    2 3 4 1", rows=2, cols=4)
+
+  # test top_1
+  print("Case 1: test top_1")
+  [values_top1, indices_top1] = util::top_k(X, 1)
+  check_values_top1 = test_util::check_all_equal(values_top1, expected_values_top1)
+  check_indices_top1 = test_util::check_all_equal(indices_top1, expected_indices_top1)
+
+  # test top_2
+  print("Case 2: test top_2")
+  [values_top2, indices_top2] = util::top_k(X, 2)
+  check_values_top2 = test_util::check_all_equal(values_top2, expected_values_top2)
+  check_indices_top2 = test_util::check_all_equal(indices_top2, expected_indices_top2)
+
+  # test top_All
+  print("Case 3: test top_All")
+  [values_topAll, indices_topAll] = util::top_k(X, 4)
+  check_values_topAll = test_util::check_all_equal(values_topAll, expected_values_topAll)
+  check_indices_topAll = test_util::check_all_equal(indices_topAll, expected_indices_topAll)
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/2e78eb9a/scripts/nn/util.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/util.dml b/scripts/nn/util.dml
index 329f22f..fb54c43 100644
--- a/scripts/nn/util.dml
+++ b/scripts/nn/util.dml
@@ -230,6 +230,8 @@ transpose_NCHW_to_CNHW = function(matrix[double] X, int C) return (matrix[double
    * Outputs:
    *  - out: Transposed output with C rows.
    */
+
+  if(1==1){}
   N = nrow(X)
   D = ncol(X) / C
 
@@ -268,3 +270,60 @@ transpose_NCHW_to_CNHW = function(matrix[double] X, int C) return (matrix[double
    */
   out = matrix(permut %*% matrix(X, rows=N*C, cols=D), rows=C, cols=N*D)
 }
+
+top_k_row = function(matrix[double] X, integer r, integer k)
+    return (matrix[double] values, matrix[double] indices) {
+  /*
+   * Computes the top k values (i.e. probabilities) and associated
+   * indices (i.e. classes) in the rth row of the input matrix X.
+   *
+   * Inputs:
+   *  - X: Inputs, of shape (N, D).
+   *  - r: Row of X to search, with 1 <= r <= N.
+   *  - k: Number of top elements to return, with 1 <= k <= D.
+   *
+   * Outputs:
+   *  - values: The top k values of row r, in decreasing order, of shape (1, k).
+   *  - indices: The column indices (i.e. classes) of `values`, of shape (1, k).
+   */
+  if (r < 1 | r > nrow(X)) {
+    stop("top_k_row: r must be in [1, nrow(X)], but got r = " + r)
+  }
+  if (k < 1 | k > ncol(X)) {
+    stop("top_k_row: k must be in [1, ncol(X)], but got k = " + k)
+  }
+  row_t = t(X[r, ])  # (D, 1) column, since order() sorts by column `by=1`
+  # Sort row r in decreasing order, once for the positions and once for the
+  # values, rather than gathering the values with a per-element loop.
+  sorted_indices = order(target=row_t, by=1, decreasing=TRUE, index.return=TRUE)
+  sorted_values = order(target=row_t, by=1, decreasing=TRUE, index.return=FALSE)
+  indices = t(sorted_indices[1:k, ])
+  values = t(sorted_values[1:k, ])
+}
+
+top_k = function(matrix[double] X, integer k)
+    return (matrix[double] values, matrix[double] indices) {
+  /*
+   * Computes the top k values (i.e. probabilities) and associated
+   * indices (i.e. classes) for each row of the input matrix X.
+   *
+   * Inputs:
+   *  - X: Inputs, of shape (N, D).
+   *  - k: Number of top elements to return per row, with 1 <= k <= D.
+   *
+   * Outputs:
+   *  - values: The top k values of each row, in decreasing order, of shape (N, k).
+   *  - indices: The column indices (i.e. classes) of `values`, of shape (N, k).
+   */
+  if (k < 1 | k > ncol(X)) {
+    stop("top_k: k must be in [1, ncol(X)], but got k = " + k)
+  }
+  N = nrow(X)
+  values = matrix(0, rows=N, cols=k)
+  indices = matrix(0, rows=N, cols=k)
+  parfor (r in 1:N) {  # each iteration writes only row r, so rows are independent
+    [value, index] = top_k_row(X, r, k)
+    values[r, ] = value
+    indices[r, ] = index
+  }
+}

Reply via email to