[SYSTEMML-1140] Performance conv2d_bias_add (cache-conscious transpose)

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/de1e119d
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/de1e119d
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/de1e119d

Branch: refs/heads/master
Commit: de1e119de0b2fc2a6c6a2c57bf64c4172a26890d
Parents: d0b23d6
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Fri Feb 10 06:58:34 2017 +0100
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Fri Feb 10 07:55:53 2017 +0100

----------------------------------------------------------------------
 .../sysml/runtime/matrix/data/LibMatrixDNN.java | 66 +++++++++-----------
 1 file changed, 31 insertions(+), 35 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/de1e119d/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index 29e59bd..82b0a61 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -233,39 +233,39 @@ public class LibMatrixDNN {
        }
        
        /**
-        * Performs the operation: ret += t(elem)
+        * Performs the operation for(e : elem) ret += t(e) in a cache-conscious manner
+        * by sequentially aggregating for(e : elem) tmp += e and finally transposing
+        * ret = t(tmp).
+        * 
         * @param ret left and output matrix
-        * @param elem right untransposed matrix
+        * @param elem array of right untransposed matrices (expected in dense format)
         * @param params convolution parameters
-        * @throws DMLRuntimeException if DMLRuntimeException occurs
+        * @throws DMLRuntimeException in case of unsupported inputs or output
         */
-       private static void elementWiseInPlaceTransposedAddition(MatrixBlock 
ret, MatrixBlock elem) throws DMLRuntimeException {
-               if(ret.getNumRows() != elem.getNumColumns() || 
ret.getNumColumns() != elem.getNumRows()) {
-                       throw new DMLRuntimeException("Incorrect dimensions");
-               }
-               int numRow = ret.getNumColumns();
-               if(!ret.isInSparseFormat() && !elem.isInSparseFormat()) {
-                       int iter = 0;
-                       for(int i = 0; i < elem.getNumRows(); i++) {
-                               for(int j = 0; j < elem.getNumColumns(); j++, 
iter++) {
-                                       int index = j*numRow+i;
-                                       ret.denseBlock[index] += 
elem.denseBlock[iter];
-                               }
-                       }
-               }
-               else if(!ret.isInSparseFormat() && elem.isInSparseFormat()) {
-                       if(!elem.isEmptyBlock()) {
-                               Iterator<IJV> iter = 
elem.sparseBlock.getIterator();
-                               while(iter.hasNext()) {
-                                       IJV ijv = iter.next();
-                                       int index = ijv.getJ()*numRow + 
ijv.getI();
-                                       ret.denseBlock[index] += ijv.getV(); 
-                               }
-                       }
-               }
-               else {
-                       throw new DMLRuntimeException("Sparse return format not 
supported");
+       private static void elementWiseInPlaceTransposedAddition(MatrixBlock 
ret, MatrixBlock[] elem) 
+               throws DMLRuntimeException 
+       {
+               //sanity checks non-empty and dense inputs / dense output
+               if( elem == null || elem.length==0 )
+                       throw new DMLRuntimeException("Empty input not 
supported.");
+               for( MatrixBlock e : elem )
+                       if( e.isInSparseFormat() )
+                               throw new DMLRuntimeException("Sparse input 
format not supported.");
+               if( ret.isInSparseFormat() )
+                       throw new DMLRuntimeException("Sparse output format not 
supported.");
+                               
+               //Step 1: aggregate partial blocks without transpose
+               MatrixBlock tmpAgg = elem[0]; 
+               double[] tmp = tmpAgg.denseBlock;
+               for( int k=1; k<elem.length; k++ ) {
+                       double[] tmp2 = elem[k].denseBlock;
+                       for( int i=0; i<tmp.length; i++ )
+                               tmp[i] += tmp2[i];
                }
+               
+               //Step 2: cache-conscious transpose to output
+               tmpAgg.setNonZeros(-1); //avoid early abort
+               LibMatrixReorg.transpose(tmpAgg, ret);
        }
        
        @SuppressWarnings("unused")
@@ -948,9 +948,7 @@ public class LibMatrixDNN {
                                for( Future<Long> task : taskret )
                                        params.output.nonZeros += task.get();
                                if(type == 
TaskType.LoopedIm2ColConv2dBwdFilter) {
-                                       for(MatrixBlock partialRetBlock : 
partialRetBlocks) {
-                                               
elementWiseInPlaceTransposedAddition(params.output, partialRetBlock);
-                                       }
+                                       
elementWiseInPlaceTransposedAddition(params.output, 
partialRetBlocks.toArray(new MatrixBlock[0]));
                                }
                        } 
                        catch (Exception e) {
@@ -965,9 +963,7 @@ public class LibMatrixDNN {
                                                doutReshapedBlocks, 
partialRetBlocks).call());
                                
                                if(type == 
TaskType.LoopedIm2ColConv2dBwdFilter) {
-                                       for(MatrixBlock partialRetBlock : 
partialRetBlocks) {
-                                               
elementWiseInPlaceTransposedAddition(params.output, partialRetBlock);
-                                       }
+                                       
elementWiseInPlaceTransposedAddition(params.output, 
partialRetBlocks.toArray(new MatrixBlock[0]));
                                }
                        } catch (Exception e) {
                                throw new DMLRuntimeException("Error while 
executing single-threaded " + type.name(), e);

Reply via email to