[SYSTEMML-1970] Performance conv2d-backward-data (for sparse filter)

This patch follows up on the recent modification of conv2d backward
filter by similarly applying a sparse rotate to conv2d backward data.
Furthermore, it removes unnecessary allocations per input row and
replaces the global nnz recomputation with thread-local nnz maintenance.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/78a3808e
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/78a3808e
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/78a3808e

Branch: refs/heads/master
Commit: 78a3808e0aaefb0c6f6959611ef119695d4d1d3e
Parents: b261661
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Sun Oct 22 17:57:29 2017 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Sun Oct 22 17:57:29 2017 -0700

----------------------------------------------------------------------
 .../sysml/runtime/matrix/data/LibMatrixDNN.java  |  4 ++--
 .../LibMatrixDNNConv2dBackwardDataHelper.java    | 19 ++++++++++---------
 .../LibMatrixDNNConv2dBackwardFilterHelper.java  | 17 ++++++++---------
 .../matrix/data/LibMatrixDNNConv2dHelper.java    |  2 +-
 .../runtime/matrix/data/LibMatrixDNNHelper.java  |  2 +-
 5 files changed, 22 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/78a3808e/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index b967780..ac66e51 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -186,10 +186,10 @@ public class LibMatrixDNN {
                if(isEligibleForConv2dBackwardDataDense(params))
                        
Statistics.numNativeSparseConv2dBwdDataCalls.increment();
                
-               
execute(LibMatrixDNNHelper.getConv2dBackwardDataWorkers(params), params);
+               long nnz = 
execute(LibMatrixDNNHelper.getConv2dBackwardDataWorkers(params), params);
                
                //post-processing: maintain nnz
-               outputBlock.recomputeNonZeros(); 
+               outputBlock.setNonZeros(nnz);
                outputBlock.examSparsity();
        }
        

http://git-wip-us.apache.org/repos/asf/systemml/blob/78a3808e/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dBackwardDataHelper.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dBackwardDataHelper.java
 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dBackwardDataHelper.java
index 04c13e6..cd50000 100644
--- 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dBackwardDataHelper.java
+++ 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dBackwardDataHelper.java
@@ -78,22 +78,22 @@ public class LibMatrixDNNConv2dBackwardDataHelper {
                        int PQ = _params.P*_params.Q; int K = _params.K; int 
CRS = _params.C*_params.R*_params.S;
                        MatrixBlock filter = _params.input1;
                        MatrixBlock dout = _params.input2;
-                       MatrixBlock dout_reshaped = new MatrixBlock(PQ, K, 
false);
-                       dout_reshaped.allocateDenseBlock();
+                       MatrixBlock outRotate = new MatrixBlock(PQ, K, 
dout.sparse);
+                       MatrixBlock outMM = new MatrixBlock(PQ, CRS, false);
+                       outRotate.allocateBlock();
                        LibMatrixDNNRotate180Helper.Rotate180Worker 
rotate180Worker = 
-                                       
LibMatrixDNNRotate180Helper.Rotate180Worker.getWorker( dout, dout_reshaped, 
_params, true, false);
+                               
LibMatrixDNNRotate180Helper.Rotate180Worker.getWorker( dout, outRotate, 
_params, true, false);
                        long time1 = 0; long time2 = 0;
                        for(int n = _rl; n < _ru; n++)  {
                                // rotate180(dout[n,]) => dout_reshaped
                                rotate180Worker.execute(n, 0);
-                               
                                // dout_reshaped %*% filter => temp
-                               MatrixBlock temp = new MatrixBlock(PQ, CRS, 
false);
                                long t1 = DMLScript.STATISTICS && 
LibMatrixDNN.DISPLAY_STATISTICS ? System.nanoTime() : 0;
-                               
LibMatrixDNNHelper.singleThreadedMatMult(dout_reshaped, filter, temp, true, 
false, _params);
+                               outMM.reset(PQ, CRS, false);
+                               
LibMatrixDNNHelper.singleThreadedMatMult(outRotate, filter, outMM, 
!outRotate.sparse, false, _params);
                                long t2 = DMLScript.STATISTICS && 
LibMatrixDNN.DISPLAY_STATISTICS ? System.nanoTime() : 0;
                                // col2im(temp) => output[n,] 
-                               LibMatrixDNNHelper.doCol2imOverSingleImage(n, 
temp, _params);
+                               LibMatrixDNNHelper.doCol2imOverSingleImage(n, 
outMM, _params);
                                long t3 = DMLScript.STATISTICS && 
LibMatrixDNN.DISPLAY_STATISTICS ? System.nanoTime() : 0;
                                
                                if(DMLScript.STATISTICS && 
LibMatrixDNN.DISPLAY_STATISTICS) {
@@ -105,8 +105,9 @@ public class LibMatrixDNNConv2dBackwardDataHelper {
                                
LibMatrixDNN.loopedConvBwdDataMatMultTime.addAndGet(time1);
                                
LibMatrixDNN.loopedConvBwdDataCol2ImTime.addAndGet(time2);
                        }
-                       return 0L;
+                       
+                       //multi-threaded nnz maintenance of current working set
+                       return _params.output.recomputeNonZeros(_rl, _ru-1);
                }
-               
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/78a3808e/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dBackwardFilterHelper.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dBackwardFilterHelper.java
 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dBackwardFilterHelper.java
index de45b81..f0fd002 100644
--- 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dBackwardFilterHelper.java
+++ 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dBackwardFilterHelper.java
@@ -86,13 +86,12 @@ public class LibMatrixDNNConv2dBackwardFilterHelper {
                        int PQ = _params.P*_params.Q, K = _params.K, CRS = 
_params.C*_params.R*_params.S;
                        MatrixBlock dout = _params.input2;
                        MatrixBlock im2ColOutBlock = new MatrixBlock(CRS, PQ, 
false);
-                       MatrixBlock dout_reshaped = new MatrixBlock(PQ, K, 
dout.sparse);
-                       MatrixBlock temp = new MatrixBlock(CRS, K, false);
-                       dout_reshaped.allocateBlock();
-                       temp.allocateBlock();
+                       MatrixBlock outRotate = new MatrixBlock(PQ, K, 
dout.sparse);
+                       MatrixBlock outMM = new MatrixBlock(CRS, K, false);
+                       outRotate.allocateBlock();
                        
                        Im2colWorker im2ColWorker = Im2colWorker.getWorker( 
_params.input1, im2ColOutBlock, _params, true, false);
-                       Rotate180Worker rotate180Worker = 
Rotate180Worker.getWorker( dout, dout_reshaped, _params, true, false);
+                       Rotate180Worker rotate180Worker = 
Rotate180Worker.getWorker( dout, outRotate, _params, true, false);
                        double [] partRet = new double[CRS*_params.K];
                        long time1 = 0; long time2 = 0;
                        for(int n = _rl; n < _ru; n++) {
@@ -104,12 +103,12 @@ public class LibMatrixDNNConv2dBackwardFilterHelper {
                                im2ColWorker.execute(n);
                                long t2 = DMLScript.STATISTICS && 
LibMatrixDNN.DISPLAY_STATISTICS ? System.nanoTime() : 0;
                                
-                               temp.reset(CRS, K, false);
-                               
LibMatrixDNNHelper.singleThreadedMatMult(im2ColOutBlock, dout_reshaped, temp, 
true, true, _params);
+                               outMM.reset(CRS, K, false);
+                               
LibMatrixDNNHelper.singleThreadedMatMult(im2ColOutBlock, outRotate, outMM, 
!im2ColOutBlock.sparse, !outRotate.sparse, _params);
                                long t3 = DMLScript.STATISTICS && 
LibMatrixDNN.DISPLAY_STATISTICS ? System.nanoTime() : 0;
                                
-                               if( !temp.isEmptyBlock() ) //accumulate row 
results
-                                       
LibMatrixMult.vectAdd(temp.getDenseBlock(), partRet, 0, 0, K*CRS);
+                               if( !outMM.isEmptyBlock() ) //accumulate row 
results
+                                       
LibMatrixMult.vectAdd(outMM.getDenseBlock(), partRet, 0, 0, K*CRS);
                                
                                if(DMLScript.STATISTICS && 
LibMatrixDNN.DISPLAY_STATISTICS) {
                                        time1 += t2 - t1;

http://git-wip-us.apache.org/repos/asf/systemml/blob/78a3808e/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dHelper.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dHelper.java
 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dHelper.java
index dd44de2..6a0205e 100644
--- 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dHelper.java
+++ 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dHelper.java
@@ -219,7 +219,7 @@ public class LibMatrixDNNConv2dHelper {
                                
                                // t(_im2ColOutBlock) %*% t(filter) => 
t(matMultOutBlock)
                                outMM.reset(outMM.rlen, outMM.clen, false);
-                               
LibMatrixDNNHelper.singleThreadedMatMult(outIm2col, _params.input2, outMM, 
false, true, _params);
+                               
LibMatrixDNNHelper.singleThreadedMatMult(outIm2col, _params.input2, outMM, 
false, false, _params);
                                
                                // Copy the matrix matMultOutBlock of shape [K 
X PQ] to params.output.denseBlock + destPos
                                partialCopyTrans(outMM, _params.output, n*K*PQ, 
K, PQ);

http://git-wip-us.apache.org/repos/asf/systemml/blob/78a3808e/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
index 6117b90..92eb79b 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
@@ -201,7 +201,7 @@ public class LibMatrixDNNHelper {
                int taskSize = (int)(Math.ceil((double)params.N / k));
                
                boolean isEmptyDenseInput = (!params.input1.isInSparseFormat() 
&& params.input1.denseBlock == null) || 
-                                                                               
                                                
(!params.input2.isInSparseFormat() && params.input2.denseBlock == null);
+                       (!params.input2.isInSparseFormat() && 
params.input2.denseBlock == null);
                
                for(int i = 0; i*taskSize < params.N; i++) {
                        
if(LibMatrixDNN.isEligibleForConv2dBackwardDataDense(params)) 

Reply via email to