Repository: systemml
Updated Branches:
  refs/heads/master 428839952 -> 9e6715daf
[MINOR] Cleanup `nn` library.

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/5da30b51
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/5da30b51
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/5da30b51

Branch: refs/heads/master
Commit: 5da30b51b4afa9b88bb5b91941e7f956a23becc8
Parents: 4288399
Author: Mike Dusenberry <[email protected]>
Authored: Fri Jun 30 16:31:51 2017 -0700
Committer: Mike Dusenberry <[email protected]>
Committed: Fri Jun 30 16:31:51 2017 -0700

----------------------------------------------------------------------
 scripts/nn/test/test.dml |  6 ++--
 scripts/nn/util.dml      | 69 +++++++++++++++++++------------------------
 2 files changed, 34 insertions(+), 41 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/systemml/blob/5da30b51/scripts/nn/test/test.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/test.dml b/scripts/nn/test/test.dml
index a1bb6cd..1364583 100644
--- a/scripts/nn/test/test.dml
+++ b/scripts/nn/test/test.dml
@@ -835,19 +835,19 @@ top_k = function() {
                        2 3 4 1", rows=2, cols=4)
 
   # test top_1
-  print("Case 1: test top_1")
+  print(" - Testing top_1.")
   [values_top1, indices_top1] = util::top_k(X, 1)
   check_values_top1 = test_util::check_all_equal(values_top1, expected_values_top1)
   check_indices_top1 = test_util::check_all_equal(indices_top1, expected_indices_top1)
 
   # test top_2
-  print("Case 2: test top_2")
+  print(" - Testing top_2.")
   [values_top2, indices_top2] = util::top_k(X, 2)
   check_values_top2 = test_util::check_all_equal(values_top2, expected_values_top2)
   check_indices_top2 = test_util::check_all_equal(indices_top2, expected_indices_top2)
 
   # test top_All
-  print("Case 3: test top_All")
+  print(" - Testing top_All.")
   [values_topAll, indices_topAll] = util::top_k(X, 4)
   check_values_topAll = test_util::check_all_equal(values_topAll, expected_values_topAll)
   check_indices_topAll = test_util::check_all_equal(indices_topAll, expected_indices_topAll)

http://git-wip-us.apache.org/repos/asf/systemml/blob/5da30b51/scripts/nn/util.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/util.dml b/scripts/nn/util.dml
index 102a507..ba345c5 100644
--- a/scripts/nn/util.dml
+++ b/scripts/nn/util.dml
@@ -216,57 +216,50 @@ threshold = function(matrix[double] X, double thresh)
   out = X > thresh
 }
 
-/*
- * Reshape util for tensors in NCHW format.
- * Transposes the 1st and 2nd dimensions.
- */
-transpose_NCHW_to_CNHW = function(matrix[double] X, int C) return (matrix[double] out){
+transpose_NCHW_to_CNHW = function(matrix[double] X, int C)
+    return (matrix[double] out) {
   /*
+   * Reshape util for tensors in NCHW format.
+   * Transposes the 1st and 2nd dimensions.
+   *
    * Inputs:
-   *  - X: Input with N rows and channels flattened within each row in
-   *      channel-major format (NCHW).
+   *  - X: Inputs, of shape (N, C*H*W).
    *  - C: Number of channels (dimensionality of depth).
    *
    * Outputs:
-   *  - out: Transposed output with C rows.
+   *  - out: Outputs with the N and C axes transposed, of
+   *      shape (C, N*H*W).
    */
-
   N = nrow(X)
   D = ncol(X) / C
 
-  /*
-   * This is an easy reshape because the channels remain intact. By
-   * reshaping X to a matrix with N*C rows, we can reduce our task to
-   * re-ordering rows (followed by the obvious reshape to achieve the
-   * required output shape with C rows).
-   *
-   * The difficult part is to obtain the permutation matrix required
-   * for re-ordering the rows. In this case, since we want to bring the
-   * ith channels from all rows together, we will need a column vector
-   * of the following form:
-   *  [1, 1+C, 1+2C, ..., 1+(N-1)C,
-   *   2, 2+C, ..., 2+(N-1)C,
-   *   3, 3+C, ..., 3+(N-1)C,
-   *   .
-   *   .
-   *   .
-   *   C, 2C, ..., NC]'
-   * This vector can be produced via an outer call.
-   */
+  # This is an easy reshape because the channels remain intact. By
+  # reshaping X to a matrix with N*C rows, we can reduce our task to
+  # re-ordering rows (followed by the obvious reshape to achieve the
+  # required output shape with C rows).
+  #
+  # The difficult part is to obtain the permutation matrix required
+  # for re-ordering the rows. In this case, since we want to bring the
+  # ith channels from all rows together, we will need a column vector
+  # of the following form:
+  #   [1, 1+C, 1+2C, ..., 1+(N-1)C,
+  #    2, 2+C, ..., 2+(N-1)C,
+  #    3, 3+C, ..., 3+(N-1)C,
+  #    .
+  #    .
+  #    .
+  #    C, 2C, ..., NC]'
+  # This vector can be produced via an outer call.
   col_idx = outer(seq(1,C), C*t(seq(0,N-1)), "+")
 
-  /*
-   * Generate the permutation matrix by:
-   *  - reshaping the result of outer into a col
-   *  - invoking table
-   */
+  # Generate the permutation matrix by:
+  #   - reshaping the result of outer into a col
+  #   - invoking table
   permut = table(seq(1, N*C), matrix(col_idx, rows=N*C, cols=1), N*C, N*C)
 
-  /*
-   * Generate the output by:
-   *  - pre-multiplying the (reshaped) X with the permutation matrix
-   *  - reshape to get the output shape with C rows
-   */
+  # Generate the output by:
+  #   - pre-multiplying the (reshaped) X with the permutation matrix
+  #   - reshape to get the output shape with C rows
   out = matrix(permut %*% matrix(X, rows=N*C, cols=D), rows=C, cols=N*D)
 }
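The permutation trick in `transpose_NCHW_to_CNHW` can be checked outside SystemML with a small pure-Python sketch. It is not SystemML code. Matrices are modeled as lists of rows, and reshapes are assumed to be row-major, which is what the DML `matrix(...)` calls rely on. The helper names (`reshape_rowwise`, `matmul`, `transpose_nchw_to_cnhw`, `direct_transpose`) are hypothetical. The sketch builds `col_idx` with an outer sum, turns it into a 0/1 permutation matrix as `table` does, pre-multiplies the (N*C) x D reshape of X, and compares the result against a direct axis swap.

```python
from typing import List

Matrix = List[List[float]]


def reshape_rowwise(X: Matrix, rows: int, cols: int) -> Matrix:
    """Row-major reshape (assumed to match DML's matrix(X, rows=..., cols=...))."""
    flat = [v for row in X for v in row]
    assert len(flat) == rows * cols
    return [flat[r * cols:(r + 1) * cols] for r in range(rows)]


def matmul(A: Matrix, B: Matrix) -> Matrix:
    """Plain dense matrix product, standing in for DML's %*%."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose_nchw_to_cnhw(X: Matrix, C: int) -> Matrix:
    """Hypothetical Python mirror of the DML transpose_NCHW_to_CNHW."""
    N = len(X)
    D = len(X[0]) // C  # D = H*W
    # outer(seq(1,C), C*t(seq(0,N-1)), "+"): a C x N matrix whose entry for
    # channel c (1-based) and example n is the 1-based row index c + C*n.
    col_idx = [[c + C * n for n in range(N)] for c in range(1, C + 1)]
    # Flattening row by row gives [1, 1+C, ..., 1+(N-1)C, 2, 2+C, ..., NC].
    flat_idx = [i for row in col_idx for i in row]
    # table(seq(1, N*C), flat_idx, N*C, N*C): a 1 at (i, flat_idx[i]);
    # DML indices are 1-based, so subtract 1 for Python lists.
    permut = [[0.0] * (N * C) for _ in range(N * C)]
    for i, j in enumerate(flat_idx):
        permut[i][j - 1] = 1.0
    # Pre-multiply the (N*C) x D reshape of X, then reshape to C rows.
    reordered = matmul(permut, reshape_rowwise(X, N * C, D))
    return reshape_rowwise(reordered, C, N * D)


def direct_transpose(X: Matrix, C: int) -> Matrix:
    """Reference result: swap the N and C axes by explicit indexing."""
    N, D = len(X), len(X[0]) // C
    return [[X[n][c * D + d] for n in range(N) for d in range(D)] for c in range(C)]


if __name__ == "__main__":
    # N=2 examples, C=3 channels, H*W=2 values per channel.
    N, C, D = 2, 3, 2
    X = [[float(100 * n + 10 * c + d) for c in range(C) for d in range(D)]
         for n in range(N)]
    out = transpose_nchw_to_cnhw(X, C)
    print(out)
    # [[0.0, 1.0, 100.0, 101.0], [10.0, 11.0, 110.0, 111.0], [20.0, 21.0, 120.0, 121.0]]
    assert out == direct_transpose(X, C)
```

Each output row c collects channel c from every example in order, which is the CNHW layout with shape (C, N*H*W). Materializing the N*C x N*C permutation matrix costs quadratic memory in N*C. That is acceptable here because SystemML stores it sparsely.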
