[SYSTEMML-2142] Performance sparse im2col (constraint handling)

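This patch improves the performance of sparse im2col operations (the
general case with arbitrary padding and stride), which dominate the
runtime of ultra-sparse conv2d operations by more than an order of
magnitude. The first filter offsets that satisfy the stride constraints
are now computed in closed form instead of being incremented one at a
time, R*S is passed in rather than recomputed per input value, and the
output indices are maintained incrementally inside the loops. For conv2d
operations over ultra-sparse inputs with sparsity 1e-5, this patch
improves performance by more than 2.5x.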
 

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/401b7965
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/401b7965
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/401b7965

Branch: refs/heads/master
Commit: 401b796578fb9dbdc2b835ce84263ea61910c9e3
Parents: 08c7f00
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Fri Feb 9 22:23:42 2018 -0800
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Fri Feb 9 22:23:42 2018 -0800

----------------------------------------------------------------------
 .../runtime/matrix/data/LibMatrixDNNIm2Col.java | 46 +++++++++-----------
 1 file changed, 21 insertions(+), 25 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/401b7965/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2Col.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2Col.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2Col.java
index 4af4933..d1f4b30 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2Col.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2Col.java
@@ -187,7 +187,7 @@ public class LibMatrixDNNIm2Col {
                                                avals[j], R, S, RS, P, trans);
                                else
                                        appendInputValueToIm2colOutput(output, 
ix.ix1, ix.ix2, ix.ix3, avals[j], 
-                                               R, S, P, Q, stride_h, stride_w, 
pad_h, pad_w, trans);
+                                               R, S, RS, P, Q, stride_h, 
stride_w, pad_h, pad_w, trans);
                        }
                        
                        output.sortSparseRows();
@@ -198,51 +198,47 @@ public class LibMatrixDNNIm2Col {
         * Appends the value corresponding to the given [, cInput, hInput, 
wInput] to the appropriate im2col location of the output
         * 
         * @param output output matrix block
-        * @param cInput input channel index (zero-based)
-        * @param hInput input height index (zero-based)
-        * @param wInput input width index (zero-based)
+        * @param c input channel index (zero-based)
+        * @param h input height index (zero-based)
+        * @param w input width index (zero-based)
         * @param value input value
         * @param R filter height
         * @param S filter width
+        * @param RS R*S
         * @param P output height
         * @param Q output width
         * @param stride_h stride height
         * @param stride_w stride width
         * @param pad_h pad height
         * @param pad_w pad width
+        * @param trans transposed output
         */
-       private static void appendInputValueToIm2colOutput(MatrixBlock output, 
int cInput, int hInput, int wInput, double value, 
-                       int R, int S, int P, int Q, int stride_h, int stride_w, 
int pad_h, int pad_w, boolean trans) {
-               int RS = R*S;
+       private static void appendInputValueToIm2colOutput(MatrixBlock output, 
int c, int h, int w, double value, 
+               int R, int S, int RS, int P, int Q, int stride_h, int stride_w, 
int pad_h, int pad_w, boolean trans)
+       {
                // For the given h,w index, insert avals[j] into respective 
r,s,p,q locations
-               
                // Constraints: for(int r = 0; r < R; r++) { if(0 <= p && p < P 
&& (hInput - r + pad_h) % stride_h == 0) { ... } }
                // Constraint 1: p >= 0 and p = (hInput - r + pad_h)  / stride_h
                // Therefore,  r <= hInput + pad_h 
                // Constraint 2: p < P and p = (hInput - r + pad_h)  / stride_h
                // Therefore,  hInput + pad_h - P*stride_h < r
                // Math.max(0, hInput + pad_h - P*stride_h + 1) <= r <= 
Math.min(R-1, hInput + pad_h)
-               int rMin = Math.max(0, hInput + pad_h - P*stride_h + 1);
-               int rMax = Math.min(R-1, hInput + pad_h);
-               int sMin = Math.max(0, wInput + pad_w - Q*stride_w + 1);
-               int sMax = Math.min(S-1, wInput + pad_w);
+               int rMin = Math.max(0, h + pad_h - P*stride_h + 1);
+               int rMax = Math.min(R-1, h + pad_h);
+               int sMin = Math.max(0, w + pad_w - Q*stride_w + 1);
+               int sMax = Math.min(S-1, w + pad_w);
                // Constraint 3: (hInput - r + pad_h) % stride_h == 0
-               while((hInput - rMin + pad_h) % stride_h != 0 && rMin <= rMax) 
rMin++;
-               while((wInput - sMin + pad_w) % stride_w != 0 && sMin <= sMax) 
sMin++;
+               rMin += Math.min((h-rMin+pad_h) % stride_h, rMax-rMin+1);
+               sMin += Math.min((w-sMin+pad_w) % stride_w, sMax-sMin+1);
                
-               for(int r = rMin; r <= rMax; r += stride_h) {
+               for( int r=rMin, ix=c*RS+rMin*S; r<=rMax; r+=stride_h, 
ix+=stride_h*S ) {
                        // Only append value if h == hInput, where h = (r - 
pad_h) + p*stride_h and 0 <= p < P
                        // Therefore, p = (hInput - r + pad_h)  / stride_h. Use 
the same logic for q.
-                       final int p = (hInput - r + pad_h)  / stride_h;
-                       final int pQ = p*Q;
-                       final int outRowIndex = cInput*RS + r*S;
-                       for(int s = sMin; s <= sMax; s += stride_w) {
-                               int q = (wInput - s + pad_w)  / stride_w;
-                               // chw -> [crs, pq]
-                               if( trans )
-                                       output.appendValue(pQ + q, outRowIndex 
+ s, value);
-                               else
-                                       output.appendValue(outRowIndex + s, pQ 
+ q, value);
+                       final int pQ = (h - r + pad_h) / stride_h * Q;
+                       for(int s=sMin, ws=w-sMin+pad_w; s<=sMax; s+=stride_w, 
ws-=stride_w) {
+                               int q = ws / stride_w; // chw -> [crs, pq]
+                               output.appendValue(trans ? pQ+q : ix+s,
+                                       trans ? ix+s : pQ+q, value);
                        }
                }
        }

Reply via email to