[SYSTEMML-1970] Performance conv2d backward filter (dense/sparse-sparse)

This patch makes a number of small performance improvements to the
existing conv2d backward filter function.

1) Dense/sparse-sparse: Whenever the rhs is sparse and the lhs has a
higher sparsity in terms of nnz/cells, we now flip the computation to
t(t(rhs)%*%t(lhs)), where t(rhs) and t(lhs) are piggybacked into the
im2col and rotate calls, and the final transpose collapses with the
existing transpose-add into a simple add.
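
For intuition, the flip rests on the identity A %*% B = t(t(B) %*% t(A)).
A minimal, self-contained Java sketch on plain 2D arrays (illustration
only, not the SystemML kernels):

    public class TransFlipSketch {
      // naive row-major matrix multiply: c = a %*% b
      static double[][] mult(double[][] a, double[][] b) {
        double[][] c = new double[a.length][b[0].length];
        for (int i = 0; i < a.length; i++)
          for (int k = 0; k < b.length; k++)
            for (int j = 0; j < b[0].length; j++)
              c[i][j] += a[i][k] * b[k][j];
        return c;
      }
      // plain transpose
      static double[][] t(double[][] a) {
        double[][] r = new double[a[0].length][a.length];
        for (int i = 0; i < a.length; i++)
          for (int j = 0; j < a[0].length; j++)
            r[j][i] = a[i][j];
        return r;
      }
      public static void main(String[] args) {
        double[][] lhs = {{1, 2}, {3, 4}}, rhs = {{5, 6}, {7, 8}};
        // both print [[19.0, 22.0], [43.0, 50.0]]
        System.out.println(java.util.Arrays.deepToString(mult(lhs, rhs)));
        System.out.println(java.util.Arrays.deepToString(t(mult(t(rhs), t(lhs)))));
      }
    }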

2) Avoid unnecessary allocations in sparse-dense and sparse-sparse
matrix multiplications (limit temporary arrays to min(blocksize,ru-rl)).
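
The allocation pattern, as a minimal sketch (the helper name is
hypothetical; the actual change is the curk allocation in
LibMatrixMult, see the diff below):

    class TempSizingSketch {
      // A task covering rows [rl, ru) touches at most ru-rl rows per block,
      // so capping the temporary by the row range avoids allocating (and
      // zeroing) a full block-sized array for small or tail tasks.
      static int[] allocSparsePositions(int blocksizeI, int rl, int ru) {
        return new int[Math.min(blocksizeI, ru - rl)];
      }
    }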

3) Avoid class locking: So far, transAdd used class-wide
synchronization, which caused unnecessary contention in scenarios where
multiple builtin functions run concurrently (e.g., parfor, JMLC). We
now simply synchronize on the monitor of the allocated output.
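
The locking change in simplified form (class and method names are
hypothetical; the real accumulation is the blocked transpose-add shown
in the diff below):

    class LockingSketch {
      // Before: 'static synchronized' locks LockingSketch.class, so all
      // threads serialize here, even when writing to unrelated outputs.
      static synchronized void addClassLocked(double[] part, double[] out) {
        for (int i = 0; i < part.length; i++) out[i] += part[i];
      }
      // After: synchronize on the output being accumulated into; tasks for
      // different outputs (e.g., concurrent parfor workers or JMLC
      // requests) proceed in parallel.
      static void addOutputLocked(double[] part, double[] out) {
        synchronized (out) {
          for (int i = 0; i < part.length; i++) out[i] += part[i];
        }
      }
    }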

On an end-to-end CNN application with native BLAS enabled, this patch
improved performance from 349s to 335s.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/dd513ffe
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/dd513ffe
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/dd513ffe

Branch: refs/heads/master
Commit: dd513ffee87a4efaf1d5f771a8d0ee4ccccbae67
Parents: 591a0f7
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Thu Oct 26 20:26:14 2017 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Thu Oct 26 23:38:19 2017 -0700

----------------------------------------------------------------------
 .../LibMatrixDNNConv2dBackwardFilterHelper.java | 90 +++++++++++++++++---
 .../runtime/matrix/data/LibMatrixDNNHelper.java |  2 +
 .../matrix/data/LibMatrixDNNIm2ColHelper.java   | 27 +++---
 .../runtime/matrix/data/LibMatrixMult.java      |  4 +-
 4 files changed, 95 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/dd513ffe/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dBackwardFilterHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dBackwardFilterHelper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dBackwardFilterHelper.java
index f0fd002..9698725 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dBackwardFilterHelper.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dBackwardFilterHelper.java
@@ -124,21 +124,83 @@ public class LibMatrixDNNConv2dBackwardFilterHelper {
                }
        }
        
-       private static synchronized void inplaceTransAdd(double[] a, ConvolutionParameters params) {
-               // Perform transposed addition: output of size [K, CRS] += input of size [CRS,K]
-               double [] c = params.output.denseBlock;
-               final int CRS = params.C*params.R*params.S, K = params.K;
-               final int blocksizeIJ = 128; //L2 cache
+       public static class Conv2dBackwardFilterTrans implements Callable<Long> {
+               private final int _rl, _ru; 
+               private final ConvolutionParameters _params; 
+               
+               public Conv2dBackwardFilterTrans(int rl, int ru, ConvolutionParameters params) {
+                       _rl = rl; _ru = ru;
+                       _params = params;
+               }
                
-               //cache-conscious blocked execution
-               for( int bi=0; bi<CRS; bi+=blocksizeIJ )
-                       for( int bj=0; bj<K; bj+=blocksizeIJ ) {
-                               int bimin = Math.min(bi+blocksizeIJ, CRS);
-                               int bjmin = Math.min(bj+blocksizeIJ, K);
-                               //core transpose add operation
-                               for(int i=bi, aix=bi*K; i<bimin; i++, aix+=K)
-                                       for(int j=bj, cix=i+bj*CRS; j<bjmin; j++, cix+=CRS)
-                                               c[cix] += a[aix+j];
+               @Override
+               public Long call() throws Exception {
+                       int PQ = _params.P*_params.Q, K = _params.K, CRS = _params.C*_params.R*_params.S;
+                       MatrixBlock dout = _params.input2;
+                       MatrixBlock im2ColOutBlock = new MatrixBlock(PQ, CRS, false).allocateBlock();
+                       MatrixBlock outRotate = new MatrixBlock(K, PQ, dout.sparse).allocateBlock();
+                       MatrixBlock outMM = new MatrixBlock(K, CRS, false).allocateBlock();
+                       
+                       Im2colWorker im2ColWorker = Im2colWorker.getWorker( _params.input1, im2ColOutBlock, _params, true, true);
+                       Rotate180Worker rotate180Worker = Rotate180Worker.getWorker( dout, outRotate, _params, true, true);
+                       double [] partRet = new double[CRS*_params.K];
+                       long time1 = 0; long time2 = 0;
+                       for(int n = _rl; n < _ru; n++) {
+                               // rotate180(dout[n,]) => dout_reshaped
+                               rotate180Worker.execute(n, 0);
+                               
+                               // im2col(input) => _im2ColOutBlock
+                               long t1 = DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS ? System.nanoTime() : 0;
+                               im2ColWorker.execute(n);
+                               long t2 = DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS ? System.nanoTime() : 0;
+                               
+                               outMM.reset(K, CRS, false);
+                               //Timing time = new Timing(true);
+                               LibMatrixDNNHelper.singleThreadedMatMult(outRotate, im2ColOutBlock,
+                                       outMM, !outRotate.sparse, !im2ColOutBlock.sparse, _params);
+                               long t3 = DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS ? System.nanoTime() : 0;
+                               
+                               if( !outMM.isEmptyBlock() ) //accumulate row results
+                                       LibMatrixMult.vectAdd(outMM.getDenseBlock(), partRet, 0, 0, K*CRS);
+                               
+                               if(DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS) {
+                                       time1 += t2 - t1;
+                                       time2 += t3 - t2;
+                               }
+                       }
+                       //no need to transpose because t(t(out)) cancels out
+                       inplaceAdd(partRet, _params);
+                       if(DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS) {
+                               LibMatrixDNN.loopedConvBwdFilterIm2ColTime.addAndGet(time1);
+                               LibMatrixDNN.loopedConvBwdFilterMatMultTime.addAndGet(time2);
                        }
+                       return 0L;
+               }
+       }
+       
+       private static void inplaceAdd(double[] a, ConvolutionParameters params) {
+               synchronized (params.output.denseBlock) {
+                       LibMatrixMult.vectAdd(a, params.output.denseBlock, 0, 0, a.length);
+               }
+       }
+       
+       private static void inplaceTransAdd(double[] a, ConvolutionParameters params) {
+               synchronized (params.output.denseBlock) {
+                       // Perform transposed addition: output of size [K, CRS] += input of size [CRS,K]
+                       double [] c = params.output.denseBlock;
+                       final int CRS = params.C*params.R*params.S, K = params.K;
+                       final int blocksizeIJ = 128; //L2 cache
+                       
+                       //cache-conscious blocked execution
+                       for( int bi=0; bi<CRS; bi+=blocksizeIJ )
+                               for( int bj=0; bj<K; bj+=blocksizeIJ ) {
+                                       int bimin = Math.min(bi+blocksizeIJ, CRS);
+                                       int bjmin = Math.min(bj+blocksizeIJ, K);
+                                       //core transpose add operation
+                                       for(int i=bi, aix=bi*K; i<bimin; i++, aix+=K)
+                                               for(int j=bj, cix=i+bj*CRS; j<bjmin; j++, cix+=CRS)
+                                                       c[cix] += a[aix+j];
+                               }
+               }
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/dd513ffe/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
index 55f6e4c..dfd0778 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
@@ -190,6 +190,8 @@ public class LibMatrixDNNHelper {
                        //implementation simply rotates the sparse filters into dense rows
                        if( applyNative ) 
                                ret.add(new SparseNativeConv2dBackwardFilterDense(i*taskSize, Math.min((i+1)*taskSize, params.N), params));
+                       else if( params.input2.sparse && params.input1.getSparsity() > params.input2.getSparsity() )
+                               ret.add(new Conv2dBackwardFilterTrans(i*taskSize, Math.min((i+1)*taskSize, params.N), params));
                        else if(!isEmptyDenseInput)
                                ret.add(new Conv2dBackwardFilter(i*taskSize, Math.min((i+1)*taskSize, params.N), params));
                        else

http://git-wip-us.apache.org/repos/asf/systemml/blob/dd513ffe/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2ColHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2ColHelper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2ColHelper.java
index a4a6d3d..a4b1877 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2ColHelper.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2ColHelper.java
@@ -41,10 +41,10 @@ public class LibMatrixDNNIm2ColHelper {
                                if( LOG.isTraceEnabled() ) 
                                        LOG.trace("Using 
DenseIm2colWorkerAllChannels operator to perform "
                                                + "im2col 
(stride1pad0="+stride1Pad0+", allChannels="+allChannels+").");
-                               if(allChannels && stride1Pad0 )
+                               if(allChannels && stride1Pad0 && !trans )
                                        return new DenseIm2colWorkerStride1Pad0AllChannels(input.getDenseBlock(), out.getDenseBlock(), params);
                                else if( allChannels )
-                                       return new DenseIm2colWorkerAllChannels(input.getDenseBlock(), out.getDenseBlock(), params);
+                                       return new DenseIm2colWorkerAllChannels(input.getDenseBlock(), out.getDenseBlock(), params, trans);
                                else if( stride1Pad0 )
                                        return new DenseIm2colWorkerStride1Pad0(input.getDenseBlock(), out.getDenseBlock(), params);
                                else
@@ -200,7 +200,8 @@ public class LibMatrixDNNIm2ColHelper {
                private final double[] inputArray, outputArray; 
                private final int CRS, S, R, P, Q, CHW, H, W; 
                private final int stride_h, stride_w, pad_h, pad_w;
-               public DenseIm2colWorkerAllChannels(double [] inputArray, double [] outputArray, ConvolutionParameters params) {
+               private final boolean trans;
+               public DenseIm2colWorkerAllChannels(double [] inputArray, double [] outputArray, ConvolutionParameters params, boolean trans) {
                        this.inputArray = inputArray;
                        this.outputArray = outputArray;
                        this.CRS = params.C * params.R * params.S;
@@ -208,6 +209,7 @@ public class LibMatrixDNNIm2ColHelper {
                        this.CHW = params.C*params.H*params.W;
                        this.stride_h = params.stride_h; this.stride_w = params.stride_w;
                        this.pad_h = params.pad_h; this.pad_w = params.pad_w;
+                       this.trans = trans;
                }
                
                @Override
@@ -217,23 +219,24 @@ public class LibMatrixDNNIm2ColHelper {
 
                @Override
                public void execute(int n) {
+                       //reset for selective copy
+                       Arrays.fill(outputArray, 0);
+                       
                        int nOffset = n * CHW;
                        for (int c = 0; c < CRS; ++c) {
                                int wOffset = c % S;
                                int hOffset = (c / S) % R;
                                int cInput = c / R / S;
                                for (int h = 0; h < P; ++h) {
-                                       int outOffset = (c * P + h) * Q;
+                                       int outOffset = trans ? c+(h*Q*CRS) : (c*P+h)*Q;
                                        int hPadded = h * stride_h - pad_h + hOffset;
                                        int inputOffset = nOffset + (cInput * H + hPadded) * W;
-                                       if (hPadded < 0 || hPadded >= H) {
-                                               Arrays.fill(outputArray, outOffset, outOffset+Q, 0);
-                                       } else {
-                                               for (int w = 0; w < Q; ++w) {
-                                                       int wPadded = w * stride_w - pad_w + wOffset;
-                                                       boolean assign = (wPadded >= 0 && wPadded < W);
-                                                       outputArray[outOffset + w] = assign ? inputArray[inputOffset + wPadded] : 0;
-                                               }
+                                       if (hPadded < 0 || hPadded >= H ) continue;
+                                       for (int w = 0; w < Q; ++w) {
+                                               int wPadded = w * stride_w - pad_w + wOffset;
+                                               if( wPadded >= 0 && wPadded < W )
+                                                       outputArray[outOffset + (trans?w*CRS:w)]
+                                                               = inputArray[inputOffset + wPadded];
                                        }
                                }
                        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/dd513ffe/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
index fa4d667..684f327 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
@@ -1351,7 +1351,7 @@ public class LibMatrixMult
                final int blocksizeJ = 1024; 
                
                //temporary array of current sparse positions
-               int[] curk = new int[blocksizeI];
+               int[] curk = new int[Math.min(blocksizeI, ru-rl)];
                
                //blocked execution over IKJ 
                for( int bi = rl; bi < ru; bi+=blocksizeI ) {
@@ -1429,7 +1429,7 @@ public class LibMatrixMult
                                                (int)Math.pow((double)m*cd/m1.nonZeros,2)));
                                
                                //temporary array of current sparse positions
-                               int[] curk = new int[blocksizeI];
+                               int[] curk = new int[Math.min(blocksizeI, ru-rl)];
                                
                                //blocked execution over IK 
                                for( int bi = rl; bi < ru; bi+=blocksizeI ) {
