[SYSTEMML-2142] Performance sparse im2col (constraint handling). This patch improves the performance of sparse im2col operations (the general case for arbitrary pad and stride), because sparse im2col dominates the runtime of ultra-sparse conv2d operations by more than an order of magnitude. In a scenario of conv2d operations over ultra-sparse inputs with sparsity 1e-5, this patch improves performance by more than 2.5x.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/401b7965 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/401b7965 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/401b7965 Branch: refs/heads/master Commit: 401b796578fb9dbdc2b835ce84263ea61910c9e3 Parents: 08c7f00 Author: Matthias Boehm <mboe...@gmail.com> Authored: Fri Feb 9 22:23:42 2018 -0800 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Fri Feb 9 22:23:42 2018 -0800 ---------------------------------------------------------------------- .../runtime/matrix/data/LibMatrixDNNIm2Col.java | 46 +++++++++----------- 1 file changed, 21 insertions(+), 25 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/401b7965/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2Col.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2Col.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2Col.java index 4af4933..d1f4b30 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2Col.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2Col.java @@ -187,7 +187,7 @@ public class LibMatrixDNNIm2Col { avals[j], R, S, RS, P, trans); else appendInputValueToIm2colOutput(output, ix.ix1, ix.ix2, ix.ix3, avals[j], - R, S, P, Q, stride_h, stride_w, pad_h, pad_w, trans); + R, S, RS, P, Q, stride_h, stride_w, pad_h, pad_w, trans); } output.sortSparseRows(); @@ -198,51 +198,47 @@ public class LibMatrixDNNIm2Col { * Appends the value corresponding to the given [, cInput, hInput, wInput] to the appropriate im2col location of the output * * @param output output matrix block - * @param cInput input channel index (zero-based) - * @param hInput input height index (zero-based) - * @param 
wInput input width index (zero-based) + * @param c input channel index (zero-based) + * @param h input height index (zero-based) + * @param w input width index (zero-based) * @param value input value * @param R filter height * @param S filter width + * @param RS R*S * @param P output height * @param Q output width * @param stride_h stride height * @param stride_w stride width * @param pad_h pad height * @param pad_w pad width + * @param trans transposed output */ - private static void appendInputValueToIm2colOutput(MatrixBlock output, int cInput, int hInput, int wInput, double value, - int R, int S, int P, int Q, int stride_h, int stride_w, int pad_h, int pad_w, boolean trans) { - int RS = R*S; + private static void appendInputValueToIm2colOutput(MatrixBlock output, int c, int h, int w, double value, + int R, int S, int RS, int P, int Q, int stride_h, int stride_w, int pad_h, int pad_w, boolean trans) + { // For the given h,w index, insert avals[j] into respective r,s,p,q locations - // Constraints: for(int r = 0; r < R; r++) { if(0 <= p && p < P && (hInput - r + pad_h) % stride_h == 0) { ... 
} } // Constraint 1: p >= 0 and p = (hInput - r + pad_h) / stride_h // Therefore, r <= hInput + pad_h // Constraint 2: p < P and p = (hInput - r + pad_h) / stride_h // Therefore, hInput + pad_h - P*stride_h < r // Math.max(0, hInput + pad_h - P*stride_h + 1) <= r <= Math.min(R-1, hInput + pad_h) - int rMin = Math.max(0, hInput + pad_h - P*stride_h + 1); - int rMax = Math.min(R-1, hInput + pad_h); - int sMin = Math.max(0, wInput + pad_w - Q*stride_w + 1); - int sMax = Math.min(S-1, wInput + pad_w); + int rMin = Math.max(0, h + pad_h - P*stride_h + 1); + int rMax = Math.min(R-1, h + pad_h); + int sMin = Math.max(0, w + pad_w - Q*stride_w + 1); + int sMax = Math.min(S-1, w + pad_w); // Constraint 3: (hInput - r + pad_h) % stride_h == 0 - while((hInput - rMin + pad_h) % stride_h != 0 && rMin <= rMax) rMin++; - while((wInput - sMin + pad_w) % stride_w != 0 && sMin <= sMax) sMin++; + rMin += Math.min((h-rMin+pad_h) % stride_h, rMax-rMin+1); + sMin += Math.min((w-sMin+pad_w) % stride_w, sMax-sMin+1); - for(int r = rMin; r <= rMax; r += stride_h) { + for( int r=rMin, ix=c*RS+rMin*S; r<=rMax; r+=stride_h, ix+=stride_h*S ) { // Only append value if h == hInput, where h = (r - pad_h) + p*stride_h and 0 <= p < P // Therefore, p = (hInput - r + pad_h) / stride_h. Use the same logic for q. - final int p = (hInput - r + pad_h) / stride_h; - final int pQ = p*Q; - final int outRowIndex = cInput*RS + r*S; - for(int s = sMin; s <= sMax; s += stride_w) { - int q = (wInput - s + pad_w) / stride_w; - // chw -> [crs, pq] - if( trans ) - output.appendValue(pQ + q, outRowIndex + s, value); - else - output.appendValue(outRowIndex + s, pQ + q, value); + final int pQ = (h - r + pad_h) / stride_h * Q; + for(int s=sMin, ws=w-sMin+pad_w; s<=sMax; s+=stride_w, ws-=stride_w) { + int q = ws / stride_w; // chw -> [crs, pq] + output.appendValue(trans ? pQ+q : ix+s, + trans ? ix+s : pQ+q, value); } } }