[SYSTEMML-2408] Fix correctness fused mmchain XtXvy w/ minus weights

This patch fixes result correctness issues of the fused mmchain operator
of type XtXvy (e.g., t(X) %*% (X %*% v - y)), whose empty-input handling
of v was only valid for multiply weights but not for minus weights.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/9de00dbb
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/9de00dbb
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/9de00dbb

Branch: refs/heads/master
Commit: 9de00dbb2d7441ed694ec4d092a9269b6f7ccccc
Parents: 7f05d04
Author: Matthias Boehm <[email protected]>
Authored: Mon Jun 18 23:38:41 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Mon Jun 18 23:38:41 2018 -0700

----------------------------------------------------------------------
 .../runtime/matrix/data/LibMatrixMult.java      | 22 +++++++++++++-------
 1 file changed, 14 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/9de00dbb/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
index 9241d7e..c6189ab 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
@@ -242,7 +242,8 @@ public class LibMatrixMult
         */
        public static void matrixMultChain(MatrixBlock mX, MatrixBlock mV, 
MatrixBlock mW, MatrixBlock ret, ChainType ct) {
                //check inputs / outputs (after that mV and mW guaranteed to be 
dense)
-               if( mX.isEmptyBlock(false) || mV.isEmptyBlock(false) || (mW 
!=null && mW.isEmptyBlock(false)) ) {
+               if( mX.isEmptyBlock(false) || (mV.isEmptyBlock(false) && 
ct!=ChainType.XtXvy)
+                       || (mW !=null && mW.isEmptyBlock(false)) ) {
                        ret.examSparsity(); //turn empty dense into sparse
                        return;
                }
@@ -283,7 +284,8 @@ public class LibMatrixMult
         */
        public static void matrixMultChain(MatrixBlock mX, MatrixBlock mV, 
MatrixBlock mW, MatrixBlock ret, ChainType ct, int k) {
                //check inputs / outputs (after that mV and mW guaranteed to be 
dense)
-               if( mX.isEmptyBlock(false) || mV.isEmptyBlock(false) || (mW 
!=null && mW.isEmptyBlock(false)) ) {
+               if( mX.isEmptyBlock(false) || (mV.isEmptyBlock(false) && 
ct!=ChainType.XtXvy)
+                       || (mW !=null && mW.isEmptyBlock(false)) ) {
                        ret.examSparsity(); //turn empty dense into sparse
                        return;
                }
@@ -1614,7 +1616,8 @@ public class LibMatrixMult
                for( int i=rl; i < rl+bn; i++ ) {
                        double[] avals = a.values(i);
                        int aix = a.pos(i);
-                       double val = dotProduct(avals, b, aix, 0, cd);
+                       double val = (b == null) ? 0 :
+                               dotProduct(avals, b, aix, 0, cd);
                        val *= (weights) ? w[i] : 1;
                        val -= (weights2) ? w[i] : 0;
                        vectMultiplyAdd(val, avals, c, aix, 0, cd);
@@ -1625,10 +1628,12 @@ public class LibMatrixMult
                {
                        //compute 1st matrix-vector for row block
                        Arrays.fill(tmp, 0);
-                       for( int bj=0; bj<cd; bj+=blocksizeJ ) {
-                               int bjmin = Math.min(cd-bj, blocksizeJ);
-                               for( int i=0; i < blocksizeI; i++ )
-                                       tmp[i] += dotProduct(a.values(bi+i), b, 
a.pos(bi+i,bj), bj, bjmin);
+                       if( b != null ) {
+                               for( int bj=0; bj<cd; bj+=blocksizeJ ) {
+                                       int bjmin = Math.min(cd-bj, blocksizeJ);
+                                       for( int i=0; i < blocksizeI; i++ )
+                                               tmp[i] += 
dotProduct(a.values(bi+i), b, a.pos(bi+i,bj), bj, bjmin);
+                               }
                        }
                        
                        //multiply/subtract weights (in-place), if required
@@ -1673,7 +1678,8 @@ public class LibMatrixMult
                        double[] avals = a.values(i);
                        
                        //compute 1st matrix-vector dot product
-                       double val = dotProduct(avals, b, aix, apos, 0, alen);
+                       double val = (b == null) ? 0 :
+                               dotProduct(avals, b, aix, apos, 0, alen);
                        
                        //multiply/subtract weights, if required
                        val *= (weights) ? w[i] : 1;

Reply via email to