Repository: systemml
Updated Branches:
  refs/heads/master 428839952 -> 9e6715daf


[MINOR] Cleanup `nn` library.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/5da30b51
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/5da30b51
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/5da30b51

Branch: refs/heads/master
Commit: 5da30b51b4afa9b88bb5b91941e7f956a23becc8
Parents: 4288399
Author: Mike Dusenberry <[email protected]>
Authored: Fri Jun 30 16:31:51 2017 -0700
Committer: Mike Dusenberry <[email protected]>
Committed: Fri Jun 30 16:31:51 2017 -0700

----------------------------------------------------------------------
 scripts/nn/test/test.dml |  6 ++--
 scripts/nn/util.dml      | 69 +++++++++++++++++++------------------------
 2 files changed, 34 insertions(+), 41 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/5da30b51/scripts/nn/test/test.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/test.dml b/scripts/nn/test/test.dml
index a1bb6cd..1364583 100644
--- a/scripts/nn/test/test.dml
+++ b/scripts/nn/test/test.dml
@@ -835,19 +835,19 @@ top_k = function() {
                                      2 3 4 1", rows=2, cols=4)
 
    # test top_1
-   print("Case 1: test top_1")
+   print(" - Testing top_1.")
    [values_top1, indices_top1] = util::top_k(X, 1)
    check_values_top1 = test_util::check_all_equal(values_top1, expected_values_top1)
    check_indices_top1 = test_util::check_all_equal(indices_top1, expected_indices_top1)
 
    # test top_2
-   print("Case 2: test top_2")
+   print(" - Testing top_2.")
    [values_top2, indices_top2] = util::top_k(X, 2)
    check_values_top2 = test_util::check_all_equal(values_top2, expected_values_top2)
    check_indices_top2 = test_util::check_all_equal(indices_top2, expected_indices_top2)
 
    # test top_All
-   print("Case 3: test top_All")
+   print(" - Testing top_All.")
    [values_topAll, indices_topAll] = util::top_k(X, 4)
    check_values_topAll = test_util::check_all_equal(values_topAll, expected_values_topAll)
    check_indices_topAll = test_util::check_all_equal(indices_topAll, expected_indices_topAll)

http://git-wip-us.apache.org/repos/asf/systemml/blob/5da30b51/scripts/nn/util.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/util.dml b/scripts/nn/util.dml
index 102a507..ba345c5 100644
--- a/scripts/nn/util.dml
+++ b/scripts/nn/util.dml
@@ -216,57 +216,50 @@ threshold = function(matrix[double] X, double thresh)
   out = X > thresh
 }
 
-/*
- * Reshape util for tensors in NCHW format.
- * Transposes the 1st and 2nd dimensions.
- */
-transpose_NCHW_to_CNHW = function(matrix[double] X, int C) return (matrix[double] out){
+transpose_NCHW_to_CNHW = function(matrix[double] X, int C)
+    return (matrix[double] out) {
   /*
+   * Reshape util for tensors in NCHW format.
+   * Transposes the 1st and 2nd dimensions.
+   *
    * Inputs:
-   *  - X: Input with N rows and channels flattened within each row in
-   *      channel-major format (NCHW).
+   *  - X: Inputs, of shape (N, C*H*W).
    *  - C: Number of channels (dimensionality of depth).
    *
    * Outputs:
-   *  - out: Transposed output with C rows.
+   *  - out: Outputs with the N and C axes transposed, of
+   *      shape (C, N*H*W).
    */
-
   N = nrow(X)
   D = ncol(X) / C
 
-  /*
-   * This is an easy reshape because the channels remain intact. By
-   * reshaping X to a matrix with N*C rows, we can reduce our task to
-   * re-ordering rows (followed by the obvious reshape to achieve the
-   * required output shape with C rows).
-   *
-   * The difficult part is to obtain the permutation matrix required
-   * for re-ordering the rows. In this case, since we want to bring the
-   * ith channels from all rows together, we will need a column vector
-   * of the following form:
-   * [1, 1+C, 1+2C, ..., 1+(N-1)C,
-   *  2, 2+C, ..., 2+(N-1)C,
-   *  3, 3+C, ..., 3+(N-1)C,
-   *  .
-   *  .
-   *  .
-   *  C, 2C, ..., NC]'
-   * This vector can be produced via an outer call.
-   */
+  # This is an easy reshape because the channels remain intact. By
+  # reshaping X to a matrix with N*C rows, we can reduce our task to
+  # re-ordering rows (followed by the obvious reshape to achieve the
+  # required output shape with C rows).
+  #
+  # The difficult part is to obtain the permutation matrix required
+  # for re-ordering the rows. In this case, since we want to bring the
+  # ith channels from all rows together, we will need a column vector
+  # of the following form:
+  # [1, 1+C, 1+2C, ..., 1+(N-1)C,
+  #  2, 2+C, ..., 2+(N-1)C,
+  #  3, 3+C, ..., 3+(N-1)C,
+  #  .
+  #  .
+  #  .
+  #  C, 2C, ..., NC]'
+  # This vector can be produced via an outer call.
   col_idx = outer(seq(1,C), C*t(seq(0,N-1)), "+")
 
-  /*
-   * Generate the permutation matrix by:
-   * - reshaping the result of outer into a col
-   * - invoking table
-   */
+  # Generate the permutation matrix by:
+  # - reshaping the result of outer into a col
+  # - invoking table
   permut = table(seq(1, N*C), matrix(col_idx, rows=N*C, cols=1), N*C, N*C)
 
-  /*
-   * Generate the output by:
-   * - pre-multiplying the (reshaped) X with the permutation matrix
-   * - reshape to get the output shape with C rows
-   */
+  # Generate the output by:
+  # - pre-multiplying the (reshaped) X with the permutation matrix
+  # - reshape to get the output shape with C rows
   out = matrix(permut %*% matrix(X, rows=N*C, cols=D), rows=C, cols=N*D)
 }
 

Reply via email to