Repository: systemml Updated Branches: refs/heads/master 192fb5582 -> 45eec2d25
[MINOR] Cleanup tensor index computation in convolution operations Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/c9977b73 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/c9977b73 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/c9977b73 Branch: refs/heads/master Commit: c9977b736b38e8c2b9ec34645f062337e1273c94 Parents: 192fb55 Author: Matthias Boehm <[email protected]> Authored: Tue Jan 9 20:40:31 2018 -0800 Committer: Matthias Boehm <[email protected]> Committed: Tue Jan 9 20:40:31 2018 -0800 ---------------------------------------------------------------------- .../runtime/matrix/data/LibMatrixDNNHelper.java | 131 ++++++++++--------- .../data/LibMatrixDNNPoolingBackwardHelper.java | 32 ++--- .../data/LibMatrixDNNRotate180Helper.java | 13 +- 3 files changed, 92 insertions(+), 84 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/c9977b73/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java index a5564cd..e99c2ef 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java @@ -38,6 +38,16 @@ import org.apache.sysml.utils.Statistics; public class LibMatrixDNNHelper { + protected static class CellIndex3 { + public int ix1; + public int ix2; + public int ix3; + @Override + public String toString() { + return "("+ix1+", "+ix2+", "+ix3+")"; + } + } + // *********************************** low-level runtime operator selection *********************************************** // *********************************** based on runtime properties (sparsity, native, etc) ******************************** // These methods help reduce branch miss predictions and instruction-cache misses. @@ -275,18 +285,26 @@ public class LibMatrixDNNHelper { // *********************************** utility methods ****************************************************** + protected static CellIndex3 computeTensorIndexes(int j, int H, int W) { + return computeTensorIndexes(j, H, W, new CellIndex3()); + } + /** - * Computes tensor indexes from column index such that column index is equal to ret[0]*HW + ret[1]*W + ret[2] + * Computes tensor indexes from a linearized column index such that + * the column index is equal to ix1*NM + ix2*M + ix3 * * @param j column index - * @param ret tensor indexes - * @param H second last dimension - * @param W last dimension + * @param N second last dimension + * @param M last dimension + * @param ret output object for reuse + * @return tensor indexes */ - static void computeTensorIndexes(int j, int [] ret, int H, int W) { - ret[0] = j / (H*W); - ret[1] = (j - ret[0]*(H*W))/W; - ret[2] = j % W; + protected static CellIndex3 computeTensorIndexes(int j, int N, int M, CellIndex3 ret) { + int tmp = j / M; + ret.ix1 = tmp / N; + ret.ix2 = tmp % N; + ret.ix3 = j % M; + return ret; } //Split a filter of size [K, CRS] into c filters of [K, RS] @@ -310,23 +328,21 @@ public class LibMatrixDNNHelper { } } else { + SparseBlock sblock = _params.input2.sparseBlock; + CellIndex3 ix = new CellIndex3(); for(int k = 0; k < _params.K; k++) { - if( !_params.input2.sparseBlock.isEmpty(k) ) { - int [] tensorIndexes = new int[3]; - // Find maxIndex - int apos = _params.input2.sparseBlock.pos(k); - int alen = _params.input2.sparseBlock.size(k); - int[] aix = _params.input2.sparseBlock.indexes(k); - double[] avals = _params.input2.sparseBlock.values(k); - for(int j=apos; j<apos+alen; j++) { - computeTensorIndexes(aix[j], tensorIndexes, _params.R, _params.S); - if(c != tensorIndexes[0]) - continue; - int r = tensorIndexes[1]; - int s = tensorIndexes[2]; - outputArr[k*RS + r*S + s] = avals[j]; - nnz += outputArr[k*RS + r*S + s] != 0 ? 1 : 0; - } + if( sblock.isEmpty(k) ) continue; + // Find maxIndex + int apos = sblock.pos(k); + int alen = sblock.size(k); + int[] aix = sblock.indexes(k); + double[] avals = sblock.values(k); + for(int j=apos; j<apos+alen; j++) { + ix = computeTensorIndexes(aix[j], _params.R, _params.S, ix); + if(c != ix.ix1) + continue; + outputArr[k*RS + ix.ix2*S + ix.ix3] = avals[j]; + nnz += outputArr[k*RS + ix.ix2*S + ix.ix3] != 0 ? 1 : 0; } } } @@ -415,8 +431,6 @@ public class LibMatrixDNNHelper { * @throws DMLRuntimeException if error occurs */ static int getMaxIndexSparse(int p, int q, int inputOffset, int n, int c, MatrixBlock input, ConvolutionParameters params, boolean performReluBackward) throws DMLRuntimeException { - int [] tensorIndexes = new int[3]; - int start_h = params.start_indexes_h[p]; int end_h = params.end_indexes_h[p]; int start_w = params.start_indexes_w[q]; @@ -432,18 +446,19 @@ public class LibMatrixDNNHelper { // if start_index_h < 0 || start_index_w < 0 || end_index_h >= params.H || end_index_w >= params.W // input.isEmptyBlock() check is done by the caller - if( !input.sparseBlock.isEmpty(n) ) { + CellIndex3 ix = new CellIndex3(); + SparseBlock sblock = input.sparseBlock; + if( !sblock.isEmpty(n) ) { // Find maxIndex - int apos = input.sparseBlock.pos(n); - int alen = input.sparseBlock.size(n); - int[] aix = input.sparseBlock.indexes(n); - double[] avals = input.sparseBlock.values(n); + int apos = sblock.pos(n); + int alen = sblock.size(n); + int[] aix = sblock.indexes(n); + double[] avals = sblock.values(n); for(int j=apos; j<apos+alen; j++) { - computeTensorIndexes(aix[j], tensorIndexes, params.H, params.W); - if(c != tensorIndexes[0]) - continue; - int h = tensorIndexes[1]; - int w = tensorIndexes[2]; + ix = computeTensorIndexes(aix[j], params.H, params.W, ix); + if(c != ix.ix1) continue; + int h = ix.ix2; + int w = ix.ix3; if(h >= start_h && h < end_h && w >= start_w && w < end_w) { double val = performReluBackward && avals[j] < 0 ? 0 : avals[j]; if(maxVal < val) { @@ -514,32 +529,26 @@ public class LibMatrixDNNHelper { if(!input.isEmptyBlock()) { int outOffset = outputN*params.C*params.H*params.W; int HW = params.H*params.W; - int [] tensorIndexes = new int[3]; + CellIndex3 ix = new CellIndex3(); + SparseBlock sblock = input.sparseBlock; for(int i = 0; i < input.getNumRows(); i++) { - if( !input.sparseBlock.isEmpty(i) ) { - computeTensorIndexes(i, tensorIndexes, params.P, params.Q); - int p = tensorIndexes[1]; - int q = tensorIndexes[2]; - int tmpP = p*params.stride_h - params.pad_h; - int tmpQ = q*params.stride_w - params.pad_w; - if(tensorIndexes[0] != 0) - throw new DMLRuntimeException("Incorrect tensor indexes: " + tensorIndexes[0] + " != 0 <" + p + " " + q + " " + tensorIndexes[0] + params.P + " " + params.Q + ">"); - - int apos = input.sparseBlock.pos(i); - int alen = input.sparseBlock.size(i); - int[] aix = input.sparseBlock.indexes(i); - double[] avals = input.sparseBlock.values(i); - for(int j = apos; j < apos+alen; j++) { - computeTensorIndexes(aix[j], tensorIndexes, params.R, params.S); - int c = tensorIndexes[0]; - int r = tensorIndexes[1]; - int s = tensorIndexes[2]; - int h = tmpP + r; - int w = tmpQ + s; - if(h >= 0 && h < params.H && w >= 0 && w < params.W) { - int outIndex = outOffset + c*HW + h*params.W + w; - outputArray[outIndex] += avals[j]; - } + if( sblock.isEmpty(i) ) continue; + ix = computeTensorIndexes(i, params.P, params.Q, ix); + int tmpP = ix.ix2*params.stride_h - params.pad_h; + int tmpQ = ix.ix3*params.stride_w - params.pad_w; + if(ix.ix1 != 0) + throw new DMLRuntimeException("Incorrect tensor indexes: "+ ix + ", " + params.P + " " + params.Q); + int apos = sblock.pos(i); + int alen = sblock.size(i); + int[] aix = sblock.indexes(i); + double[] avals = sblock.values(i); + for(int j = apos; j < apos+alen; j++) { + ix = computeTensorIndexes(aix[j], params.R, params.S, ix); + int h = tmpP + ix.ix2; + int w = tmpQ + ix.ix3; + if(h >= 0 && h < params.H && w >= 0 && w < params.W) { + int outIndex = outOffset + ix.ix1*HW + h*params.W + w; + outputArray[outIndex] += avals[j]; } } } http://git-wip-us.apache.org/repos/asf/systemml/blob/c9977b73/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java index 6e5e978..3dfb545 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java @@ -21,6 +21,8 @@ package org.apache.sysml.runtime.matrix.data; import java.util.Arrays; import java.util.concurrent.Callable; +import org.apache.sysml.runtime.matrix.data.LibMatrixDNNHelper.CellIndex3; + /** * This class contains the set of operators used for performing pooling backward */ @@ -99,24 +101,22 @@ public class LibMatrixDNNPoolingBackwardHelper { @Override public Long call() throws Exception { + CellIndex3 ix = new CellIndex3(); double[] out = output.getDenseBlockValues(); + SparseBlock sblock = dout.sparseBlock; for(int n = _rl; n < _ru; n++) { - if( !dout.sparseBlock.isEmpty(n) ) { - int [] tensorIndexes = new int[3]; - int apos = dout.sparseBlock.pos(n); - int alen = dout.sparseBlock.size(n); - int[] aix = dout.sparseBlock.indexes(n); - double[] avals = dout.sparseBlock.values(n); - for(int j = apos; j < apos+alen; j++) { - LibMatrixDNNHelper.computeTensorIndexes(aix[j], tensorIndexes, P, Q); - int c = tensorIndexes[0]; - int p = tensorIndexes[1]; - int q = tensorIndexes[2]; - final int inputOffset = n*CHW + c*HW; - int maxIndex = LibMatrixDNNHelper.getMaxIndex(p, q, inputOffset, inputArray, _params, performReluBackward); - if(maxIndex != -1) - out[maxIndex] += avals[j]; - } + if( sblock.isEmpty(n) ) continue; + int apos = sblock.pos(n); + int alen = sblock.size(n); + int[] aix = sblock.indexes(n); + double[] avals = sblock.values(n); + for(int j = apos; j < apos+alen; j++) { + ix = LibMatrixDNNHelper.computeTensorIndexes(aix[j], P, Q, ix); + final int inputOffset = n*CHW + ix.ix1*HW; + int maxIndex = LibMatrixDNNHelper.getMaxIndex(ix.ix2, ix.ix3, + inputOffset, inputArray, _params, performReluBackward); + if(maxIndex != -1) + out[maxIndex] += avals[j]; } } //thread-local nnz maintenance http://git-wip-us.apache.org/repos/asf/systemml/blob/c9977b73/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRotate180Helper.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRotate180Helper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRotate180Helper.java index 1ca6ad1..74e2baa 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRotate180Helper.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRotate180Helper.java @@ -18,6 +18,8 @@ */ package org.apache.sysml.runtime.matrix.data; +import org.apache.sysml.runtime.matrix.data.LibMatrixDNNHelper.CellIndex3; + /** * This class contains the different implementation of rotate180 operation */ @@ -90,21 +92,18 @@ public class LibMatrixDNNRotate180Helper { if( sblock==null || sblock.isEmpty(inputN) ) return; + CellIndex3 ix = new CellIndex3(); int outputOffset = outputN*params.P*params.Q; - int [] tensorIndexes = new int[3]; int apos = sblock.pos(inputN); int alen = sblock.size(inputN); int[] aix = sblock.indexes(inputN); double[] avals = sblock.values(inputN); for(int j = apos; j < apos+alen; j++) { - LibMatrixDNNHelper.computeTensorIndexes(aix[j], tensorIndexes, params.P, params.Q); - int k = tensorIndexes[0]; - int p = tensorIndexes[1]; - int q = tensorIndexes[2]; + ix = LibMatrixDNNHelper.computeTensorIndexes(aix[j], params.P, params.Q, ix); if( trans ) - out.appendValue(k, outputOffset + p*params.Q + q, avals[j]); + out.appendValue(ix.ix1, outputOffset + ix.ix2*params.Q + ix.ix3, avals[j]); else - out.appendValue(outputOffset + p*params.Q + q, k, avals[j]); + out.appendValue(outputOffset + ix.ix2*params.Q + ix.ix3, ix.ix1, avals[j]); } } }
