Repository: systemml Updated Branches: refs/heads/master f516e4bdc -> 345682404
[SYSTEMML-1728] Reshape util to convert tensors in NCHW to CNHW format Closes #552. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/34568240 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/34568240 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/34568240 Branch: refs/heads/master Commit: 345682404c3fb1348484c375e811ee3f5805a691 Parents: f516e4b Author: prithvirajsen <[email protected]> Authored: Thu Jun 22 11:05:24 2017 -0700 Committer: Glenn Weidner <[email protected]> Committed: Thu Jun 22 11:05:25 2017 -0700 ---------------------------------------------------------------------- scripts/nn/test/run_tests.dml | 1 + scripts/nn/test/test.dml | 33 ++++++++++++++++++++++++ scripts/nn/util.dml | 52 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 86 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/34568240/scripts/nn/test/run_tests.dml ---------------------------------------------------------------------- diff --git a/scripts/nn/test/run_tests.dml b/scripts/nn/test/run_tests.dml index 5f3ca6e..cca0d0d 100644 --- a/scripts/nn/test/run_tests.dml +++ b/scripts/nn/test/run_tests.dml @@ -97,6 +97,7 @@ test::max_pool2d() test::padding() test::tanh() test::threshold() +test::transpose_NCHW_to_CNHW() print("---") print("Other tests complete -- look for any ERRORs or WARNINGs.") http://git-wip-us.apache.org/repos/asf/systemml/blob/34568240/scripts/nn/test/test.dml ---------------------------------------------------------------------- diff --git a/scripts/nn/test/test.dml b/scripts/nn/test/test.dml index 37f9f73..b0899e3 100644 --- a/scripts/nn/test/test.dml +++ b/scripts/nn/test/test.dml @@ -488,6 +488,39 @@ padding = function() { } } +transpose_NCHW_to_CNHW = function() { + /* + * Test for `transpose_NCHW_to_CNHW` function. + */ + print("Testing transpose_NCHW_to_CNHW function.") + + # Generate data + N = 2 + C = 3 + H = 4 + W = 5 + X = matrix(seq(1, N*C*H*W), rows=N, cols=C*H*W) + + out = util::transpose_NCHW_to_CNHW(X, C) + + target = + matrix("1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 + 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 + 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 + 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 + 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 + 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120", + rows=C, cols=N*H*W) + + # Equivalency check + for (i in 1:nrow(out)) { + for(j in 1:ncol(out)) { + rel_error = test_util::check_rel_error(as.scalar(out[i,j]), + as.scalar(target[i,j]), 1e-10, 1e-12) + } + } +} + max_pool2d = function() { /* * Test for the 2D max pooling functions. http://git-wip-us.apache.org/repos/asf/systemml/blob/34568240/scripts/nn/util.dml ---------------------------------------------------------------------- diff --git a/scripts/nn/util.dml b/scripts/nn/util.dml index c4da16a..329f22f 100644 --- a/scripts/nn/util.dml +++ b/scripts/nn/util.dml @@ -216,3 +216,55 @@ threshold = function(matrix[double] X, double thresh) out = X > thresh } +/* + * Reshape util for tensors in NCHW format. + * Transposes the 1st and 2nd dimensions. + */ +transpose_NCHW_to_CNHW = function(matrix[double] X, int C) return (matrix[double] out){ + /* + * Inputs: + * - X: Input with N rows and channels flattened within each row in + * channel-major format (NCHW). + * - C: Number of channels (dimensionality of depth). + * + * Outputs: + * - out: Transposed output with C rows. + */ + N = nrow(X) + D = ncol(X) / C + + /* + * This is an easy reshape because the channels remain intact. By + * reshaping X to a matrix with N*C rows, we can reduce our task to + * re-ordering rows (followed by the obvious reshape to achieve the + * required output shape with C rows). + * + * The difficult part is to obtain the permutation matrix required + * for re-ordering the rows. In this case, since we want to bring the + * ith channels from all rows together, we will need a column vector + * of the following form: + * [1, 1+C, 1+2C, ..., 1+(N-1)C, + * 2, 2+C, ..., 2+(N-1)C, + * 3, 3+C, ..., 3+(N-1)C, + * . + * . + * . + * C, 2C, ..., NC]' + * This vector can be produced via an outer call. + */ + col_idx = outer(seq(1,C), C*t(seq(0,N-1)), "+") + + /* + * Generate the permutation matrix by: + * - reshaping the result of outer into a col + * - invoking table + */ + permut = table(seq(1, N*C), matrix(col_idx, rows=N*C, cols=1), N*C, N*C) + + /* + * Generate the output by: + * - pre-multiplying the (reshaped) X with the permutation matrix + * - reshape to get the output shape with C rows + */ + out = matrix(permut %*% matrix(X, rows=N*C, cols=D), rows=C, cols=N*D) +}
