[SYSTEMML-1140] Performance conv2d_bias_add (cache-conscious transpose) Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/de1e119d Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/de1e119d Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/de1e119d
Branch: refs/heads/master Commit: de1e119de0b2fc2a6c6a2c57bf64c4172a26890d Parents: d0b23d6 Author: Matthias Boehm <mboe...@gmail.com> Authored: Fri Feb 10 06:58:34 2017 +0100 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Fri Feb 10 07:55:53 2017 +0100 ---------------------------------------------------------------------- .../sysml/runtime/matrix/data/LibMatrixDNN.java | 66 +++++++++----------- 1 file changed, 31 insertions(+), 35 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/de1e119d/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java index 29e59bd..82b0a61 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java @@ -233,39 +233,39 @@ public class LibMatrixDNN { } /** - * Performs the operation: ret += t(elem) + * Performs the operation for(e : elem) ret += t(e) in a cache-conscious manner + * by sequentially aggregating for(e : elem) tmp += e and finally transposing + * ret = t(tmp). + * * @param ret left and output matrix - * @param elem right untransposed matrix + * @param elem array of right untransposed matrices (expected in dense format) * @param params convolution parameters - * @throws DMLRuntimeException if DMLRuntimeException occurs + * @throws DMLRuntimeException in case of unsupported inputs or output */ - private static void elementWiseInPlaceTransposedAddition(MatrixBlock ret, MatrixBlock elem) throws DMLRuntimeException { - if(ret.getNumRows() != elem.getNumColumns() || ret.getNumColumns() != elem.getNumRows()) { - throw new DMLRuntimeException("Incorrect dimensions"); - } - int numRow = ret.getNumColumns(); - if(!ret.isInSparseFormat() && !elem.isInSparseFormat()) { - int iter = 0; - for(int i = 0; i < elem.getNumRows(); i++) { - for(int j = 0; j < elem.getNumColumns(); j++, iter++) { - int index = j*numRow+i; - ret.denseBlock[index] += elem.denseBlock[iter]; - } - } - } - else if(!ret.isInSparseFormat() && elem.isInSparseFormat()) { - if(!elem.isEmptyBlock()) { - Iterator<IJV> iter = elem.sparseBlock.getIterator(); - while(iter.hasNext()) { - IJV ijv = iter.next(); - int index = ijv.getJ()*numRow + ijv.getI(); - ret.denseBlock[index] += ijv.getV(); - } - } - } - else { - throw new DMLRuntimeException("Sparse return format not supported"); + private static void elementWiseInPlaceTransposedAddition(MatrixBlock ret, MatrixBlock[] elem) + throws DMLRuntimeException + { + //sanity checks non-empty and dense inputs / dense output + if( elem == null || elem.length==0 ) + throw new DMLRuntimeException("Empty input not supported."); + for( MatrixBlock e : elem ) + if( e.isInSparseFormat() ) + throw new DMLRuntimeException("Sparse input format not supported."); + if( ret.isInSparseFormat() ) + throw new DMLRuntimeException("Sparse output format not supported."); + + //Step 1: aggregate partial blocks without transpose + MatrixBlock tmpAgg = elem[0]; + double[] tmp = tmpAgg.denseBlock; + for( int k=1; k<elem.length; k++ ) { + double[] tmp2 = elem[k].denseBlock; + for( int i=0; i<tmp.length; i++ ) + tmp[i] += tmp2[i]; } + + //Step 2: cache-conscious transpose to output + tmpAgg.setNonZeros(-1); //avoid early abort + LibMatrixReorg.transpose(tmpAgg, ret); } @SuppressWarnings("unused") @@ -948,9 +948,7 @@ public class LibMatrixDNN { for( Future<Long> task : taskret ) params.output.nonZeros += task.get(); if(type == TaskType.LoopedIm2ColConv2dBwdFilter) { - for(MatrixBlock partialRetBlock : partialRetBlocks) { - elementWiseInPlaceTransposedAddition(params.output, partialRetBlock); - } + elementWiseInPlaceTransposedAddition(params.output, partialRetBlocks.toArray(new MatrixBlock[0])); } } catch (Exception e) { @@ -965,9 +963,7 @@ public class LibMatrixDNN { doutReshapedBlocks, partialRetBlocks).call()); if(type == TaskType.LoopedIm2ColConv2dBwdFilter) { - for(MatrixBlock partialRetBlock : partialRetBlocks) { - elementWiseInPlaceTransposedAddition(params.output, partialRetBlock); - } + elementWiseInPlaceTransposedAddition(params.output, partialRetBlocks.toArray(new MatrixBlock[0])); } } catch (Exception e) { throw new DMLRuntimeException("Error while executing single-threaded " + type.name(), e);