Repository: systemml Updated Branches: refs/heads/master 240297bd5 -> a3ab19768
[SYSTEMML-2080] Fix correctness of sparse relu-backward CP operations This patch fixes a correctness issue in relu backward over sparse inputs that was introduced by a recent refactoring and cleanup. In detail, the binary in-place multiply was mistakenly replaced by a binary in-place plus. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/a3ab1976 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/a3ab1976 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/a3ab1976 Branch: refs/heads/master Commit: a3ab197686b96873b0a7686710e083906903aa69 Parents: 240297b Author: Matthias Boehm <[email protected]> Authored: Wed Jan 24 22:01:28 2018 -0800 Committer: Matthias Boehm <[email protected]> Committed: Wed Jan 24 22:14:34 2018 -0800 ---------------------------------------------------------------------- .../runtime/matrix/data/LibMatrixDNNRelu.java | 39 +++++++++++++------- 1 file changed, 25 insertions(+), 14 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/a3ab1976/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRelu.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRelu.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRelu.java index f1c3ecb..c8f2f41 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRelu.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRelu.java @@ -79,7 +79,7 @@ public class LibMatrixDNNRelu } else { scalarOperations(_params.input1, out, n, _rl, _ru, GT0); // (X > 0) - binaryOperationInPlacePlus(_params.input2, out, n, _rl, _ru); // (X > 0) * dout + binaryOperationInPlaceMult(_params.input2, out, n, _rl, _ru); // (X > 0) * dout } return 0L; } @@ -113,24 +113,35 @@ public class 
LibMatrixDNNRelu } } - private static void binaryOperationInPlacePlus(MatrixBlock src, + private static void binaryOperationInPlaceMult(MatrixBlock src, DenseBlock c, int destNumCols, int src_rl, int src_ru) throws DMLRuntimeException { - if( src.isEmptyBlock(false) ) - return; //do nothing (add 0); + if( src.isEmptyBlock(false) ) { + c.set(src_rl, src_rl, 0, destNumCols, 0); + return; + } if(src.isInSparseFormat()) { for(int i = src_rl; i < src_ru; i++) { - if( src.getSparseBlock().isEmpty(i) ) continue; - int apos = src.getSparseBlock().pos(i); - int alen = src.getSparseBlock().size(i); - int[] aix = src.getSparseBlock().indexes(i); - double[] avals = src.getSparseBlock().values(i); - double[] cvals = c.values(i); - int cix = c.pos(i); - for(int j = apos; j < apos+alen; j++) - cvals[ cix+aix[j] ] += avals[j]; + if( !src.getSparseBlock().isEmpty(i) ) { + int apos = src.getSparseBlock().pos(i); + int alen = src.getSparseBlock().size(i); + int[] aix = src.getSparseBlock().indexes(i); + double[] avals = src.getSparseBlock().values(i); + double[] cvals = c.values(i); + int cix = c.pos(i); + int prevDestIndex = 0; + for(int j = apos; j < apos+alen; j++) { + c.set(i, i+1, prevDestIndex, aix[j], 0); + prevDestIndex = aix[j]+1; + cvals[ cix+aix[j] ] *= avals[j]; + } + c.set(i, i+1, prevDestIndex, destNumCols, 0); + } + else { + c.set(i, i+1, 0, destNumCols, 0); + } } } else { //DENSE @@ -139,7 +150,7 @@ public class LibMatrixDNNRelu double[] avals = a.values(i), cvals = c.values(i); int aix = a.pos(i), cix = c.pos(i); for(int j=0; j<destNumCols; j++) - cvals[cix+j] += avals[aix+j]; + cvals[cix+j] *= avals[aix+j]; } } }
