[SYSTEMML-2408] Fix correctness fused mmchain XtXvy w/ minus weights This patch fixes result correctness issues of the fused mmchain operator of type XtXvy (e.g., t(X) %*% (X %*% v - y)), whose empty-input handling of v was only valid for multiply weights but not for minus weights.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/9de00dbb Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/9de00dbb Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/9de00dbb Branch: refs/heads/master Commit: 9de00dbb2d7441ed694ec4d092a9269b6f7ccccc Parents: 7f05d04 Author: Matthias Boehm <[email protected]> Authored: Mon Jun 18 23:38:41 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Mon Jun 18 23:38:41 2018 -0700 ---------------------------------------------------------------------- .../runtime/matrix/data/LibMatrixMult.java | 22 +++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/9de00dbb/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java index 9241d7e..c6189ab 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java @@ -242,7 +242,8 @@ public class LibMatrixMult */ public static void matrixMultChain(MatrixBlock mX, MatrixBlock mV, MatrixBlock mW, MatrixBlock ret, ChainType ct) { //check inputs / outputs (after that mV and mW guaranteed to be dense) - if( mX.isEmptyBlock(false) || mV.isEmptyBlock(false) || (mW !=null && mW.isEmptyBlock(false)) ) { + if( mX.isEmptyBlock(false) || (mV.isEmptyBlock(false) && ct!=ChainType.XtXvy) + || (mW !=null && mW.isEmptyBlock(false)) ) { ret.examSparsity(); //turn empty dense into sparse return; } @@ -283,7 +284,8 @@ public class LibMatrixMult */ public static void matrixMultChain(MatrixBlock mX, MatrixBlock mV, MatrixBlock mW, 
MatrixBlock ret, ChainType ct, int k) { //check inputs / outputs (after that mV and mW guaranteed to be dense) - if( mX.isEmptyBlock(false) || mV.isEmptyBlock(false) || (mW !=null && mW.isEmptyBlock(false)) ) { + if( mX.isEmptyBlock(false) || (mV.isEmptyBlock(false) && ct!=ChainType.XtXvy) + || (mW !=null && mW.isEmptyBlock(false)) ) { ret.examSparsity(); //turn empty dense into sparse return; } @@ -1614,7 +1616,8 @@ public class LibMatrixMult for( int i=rl; i < rl+bn; i++ ) { double[] avals = a.values(i); int aix = a.pos(i); - double val = dotProduct(avals, b, aix, 0, cd); + double val = (b == null) ? 0 : + dotProduct(avals, b, aix, 0, cd); val *= (weights) ? w[i] : 1; val -= (weights2) ? w[i] : 0; vectMultiplyAdd(val, avals, c, aix, 0, cd); @@ -1625,10 +1628,12 @@ public class LibMatrixMult { //compute 1st matrix-vector for row block Arrays.fill(tmp, 0); - for( int bj=0; bj<cd; bj+=blocksizeJ ) { - int bjmin = Math.min(cd-bj, blocksizeJ); - for( int i=0; i < blocksizeI; i++ ) - tmp[i] += dotProduct(a.values(bi+i), b, a.pos(bi+i,bj), bj, bjmin); + if( b != null ) { + for( int bj=0; bj<cd; bj+=blocksizeJ ) { + int bjmin = Math.min(cd-bj, blocksizeJ); + for( int i=0; i < blocksizeI; i++ ) + tmp[i] += dotProduct(a.values(bi+i), b, a.pos(bi+i,bj), bj, bjmin); + } } //multiply/subtract weights (in-place), if required @@ -1673,7 +1678,8 @@ public class LibMatrixMult double[] avals = a.values(i); //compute 1st matrix-vector dot product - double val = dotProduct(avals, b, aix, apos, 0, alen); + double val = (b == null) ? 0 : + dotProduct(avals, b, aix, apos, 0, alen); //multiply/subtract weights, if required val *= (weights) ? w[i] : 1;
