Repository: systemml
Updated Branches:
  refs/heads/master 4dabd2b0f -> 192fb5582
[SYSTEMML-2046] Large dense blocks in all quaternary mm operators This patch modifies the quaternary operators wsloss, wdivmm, wsigmoid, wcemm, and wumm to support dense large blocks. Furthermore, this also includes a fix of the dense-dense mv multiply w/ tall rhs matrices. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/192fb558 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/192fb558 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/192fb558 Branch: refs/heads/master Commit: 192fb5582a23a298722e66262229f4e0ef11f97d Parents: 4dabd2b Author: Matthias Boehm <[email protected]> Authored: Tue Jan 9 19:11:04 2018 -0800 Committer: Matthias Boehm <[email protected]> Committed: Tue Jan 9 19:11:20 2018 -0800 ---------------------------------------------------------------------- .../runtime/matrix/data/LibMatrixMult.java | 750 ++++++++++--------- 1 file changed, 398 insertions(+), 352 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/192fb558/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java index c781a2e..7159a7d 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java @@ -745,7 +745,7 @@ public class LibMatrixMult } try - { + { ExecutorService pool = Executors.newFixedThreadPool(k); ArrayList<MatrixMultWDivTask> tasks = new ArrayList<>(); //create tasks (for wdivmm-left, parallelization over columns; @@ -1016,7 +1016,7 @@ public class LibMatrixMult for( int bk=0; bk<cd; bk+=blocksizeK ) { int bkmin = Math.min(bk+blocksizeK, cd); for( int i=bi; 
i<bimin; i++) - cvals[i] += dotProduct(a.values(i), bvals, a.pos(i), bk, bkmin-bk); + cvals[i] += dotProduct(a.values(i), bvals, a.pos(i,bk), bk, bkmin-bk); } } } @@ -2110,10 +2110,10 @@ public class LibMatrixMult private static void matrixMultWSLossDense(MatrixBlock mX, MatrixBlock mU, MatrixBlock mV, MatrixBlock mW, MatrixBlock ret, WeightsType wt, int rl, int ru) { - double[] x = mX.getDenseBlockValues(); - double[] u = mU.getDenseBlockValues(); - double[] v = mV.getDenseBlockValues(); - double[] w = (mW!=null)? mW.getDenseBlockValues() : null; + DenseBlock x = mX.getDenseBlock(); + DenseBlock u = mU.getDenseBlock(); + DenseBlock v = mV.getDenseBlock(); + DenseBlock w = (mW!=null)? mW.getDenseBlock() : null; final int n = mX.clen; final int cd = mU.clen; double wsloss = 0; @@ -2126,58 +2126,66 @@ public class LibMatrixMult //blocked execution - for( int bi = rl; bi < ru; bi+=blocksizeIJ ) - for( int bj = 0, bimin = Math.min(ru, bi+blocksizeIJ); bj < n; bj+=blocksizeIJ ) - { + for( int bi = rl; bi < ru; bi+=blocksizeIJ ) { + int bimin = Math.min(ru, bi+blocksizeIJ); + for( int bj = 0; bj < n; bj+=blocksizeIJ ){ int bjmin = Math.min(n, bj+blocksizeIJ); - + // Pattern 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting) - if( wt==WeightsType.POST ) - { - for( int i=bi, ix=bi*n, uix=bi*cd; i<bimin; i++, ix+=n, uix+=cd ) - for( int j=bj, vix=bj*cd; j<bjmin; j++, vix+=cd) { - double wij = w[ix+j]; + if( wt==WeightsType.POST ) { + for( int i=bi; i<bimin; i++ ) { + double[] wvals = w.values(i), xvals = x.values(i), uvals = u.values(i); + int xix = x.pos(i), uix = u.pos(i); + for( int j=bj; j<bjmin; j++ ) { + double wij = wvals[xix+j]; if( wij != 0 ) { - double uvij = dotProduct(u, v, uix, vix, cd); - wsloss += wij*(x[ix+j]-uvij)*(x[ix+j]-uvij); //^2 + double uvij = dotProduct(uvals, v.values(j), uix, v.pos(j), cd); + wsloss += wij*(xvals[xix+j]-uvij)*(xvals[xix+j]-uvij); //^2 } - } + } + } } // Pattern 1b) sum ((X!=0) * (X - U %*% t(V)) ^ 2) (post_nz weighting) - else if( 
wt==WeightsType.POST_NZ ) - { - for( int i=bi, ix=bi*n, uix=bi*cd; i<bimin; i++, ix+=n, uix+=cd ) - for( int j=bj, vix=bj*cd; j<bjmin; j++, vix+=cd) { - double xij = x[ix+j]; + else if( wt==WeightsType.POST_NZ ) { + for( int i=bi; i<bimin; i++ ) { + double[] xvals = x.values(i), uvals = u.values(i); + int xix = x.pos(i), uix = u.pos(i); + for( int j=bj; j<bjmin; j++ ) { + double xij = xvals[xix+j]; if( xij != 0 ) { - double uvij = dotProduct(u, v, uix, vix, cd); + double uvij = dotProduct(uvals, v.values(j), uix, v.pos(j), cd); wsloss += (xij-uvij)*(xij-uvij); //^2 } - } + } + } } // Pattern 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting) - else if( wt==WeightsType.PRE ) - { - for( int i=bi, ix=bi*n, uix=bi*cd; i<bimin; i++, ix+=n, uix+=cd ) - for( int j=bj, vix=bj*cd; j<bjmin; j++, vix+=cd) { - double wij = w[ix+j]; + else if( wt==WeightsType.PRE ) { + for( int i=bi; i<bimin; i++ ) { + double[] wvals = w.values(i), xvals = x.values(i), uvals = u.values(i); + int xix = x.pos(i), uix = u.pos(i); + for( int j=bj; j<bjmin; j++ ) { + double wij = wvals[xix+j]; double uvij = 0; if( wij != 0 ) - uvij = dotProduct(u, v, uix, vix, cd); - wsloss += (x[ix+j]-wij*uvij)*(x[ix+j]-wij*uvij); //^2 + uvij = dotProduct(uvals, v.values(j), uix, v.pos(j), cd); + wsloss += (xvals[xix+j]-wij*uvij)*(xvals[xix+j]-wij*uvij); //^2 } + } } // Pattern 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting) - else if( wt==WeightsType.NONE ) - { - for( int i=bi, ix=bi*n, uix=bi*cd; i<bimin; i++, ix+=n, uix+=cd ) - for( int j=bj, vix=bj*cd; j<bjmin; j++, vix+=cd) { - double uvij = dotProduct(u, v, uix, vix, cd); - wsloss += (x[ix+j]-uvij)*(x[ix+j]-uvij); //^2 + else if( wt==WeightsType.NONE ) { + for( int i=bi; i<bimin; i++ ) { + double[] xvals = x.values(i), uvals = u.values(i); + int xix = x.pos(i), uix = u.pos(i); + for( int j=bj; j<bjmin; j++) { + double uvij = dotProduct(uvals, v.values(j), uix, v.pos(j), cd); + wsloss += (xvals[xix+j]-uvij)*(xvals[xix+j]-uvij); //^2 } + } } + } } - 
ret.quickSetValue(0, 0, wsloss); } @@ -2185,48 +2193,48 @@ public class LibMatrixMult { SparseBlock x = mX.sparseBlock; SparseBlock w = (mW!=null)? mW.sparseBlock : null; - double[] u = mU.getDenseBlockValues(); - double[] v = mV.getDenseBlockValues(); + DenseBlock u = mU.getDenseBlock(); + DenseBlock v = mV.getDenseBlock(); final int n = mX.clen; final int cd = mU.clen; double wsloss = 0; // Pattern 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting) - if( wt==WeightsType.POST ) - { + if( wt==WeightsType.POST ) { // approach: iterate over W, point-wise in order to exploit sparsity - for( int i=rl, uix=rl*cd; i<ru; i++, uix+=cd ) - if( !w.isEmpty(i) ) { - int wpos = w.pos(i); - int wlen = w.size(i); - int[] wix = w.indexes(i); - double[] wval = w.values(i); - if( w.isAligned(i, x) ) { - //O(n) where n is nnz in w/x - double[] xval = x.values(i); - for( int k=wpos; k<wpos+wlen; k++ ) { - double uvij = dotProduct(u, v, uix, wix[k]*cd, cd); - wsloss += wval[k]*(xval[k]-uvij)*(xval[k]-uvij); - } + for( int i=rl; i<ru; i++ ) { + if( w.isEmpty(i) ) continue; + int wpos = w.pos(i); + int wlen = w.size(i); + int[] wix = w.indexes(i); + double[] wval = w.values(i); + double[] uvals = u.values(i); + int uix = u.pos(i); + if( w.isAligned(i, x) ) { + //O(n) where n is nnz in w/x + double[] xval = x.values(i); + for( int k=wpos; k<wpos+wlen; k++ ) { + double uvij = dotProduct(uvals, v.values(wix[k]), uix, v.pos(wix[k]), cd); + wsloss += wval[k]*(xval[k]-uvij)*(xval[k]-uvij); } - else { - //O(n log m) where n/m is nnz in w/x - for( int k=wpos; k<wpos+wlen; k++ ) { - double xi = mX.quickGetValue(i, wix[k]); - double uvij = dotProduct(u, v, uix, wix[k]*cd, cd); - wsloss += wval[k]*(xi-uvij)*(xi-uvij); - } + } + else { + //O(n log m) where n/m is nnz in w/x + for( int k=wpos; k<wpos+wlen; k++ ) { + double xi = mX.quickGetValue(i, wix[k]); + double uvij = dotProduct(uvals, v.values(wix[k]), uix, v.pos(wix[k]), cd); + wsloss += wval[k]*(xi-uvij)*(xi-uvij); } - } + } + } } // Pattern 
1b) sum ((X!=0) * (X - U %*% t(V)) ^ 2) (post weighting) - else if( wt==WeightsType.POST_NZ ) - { + else if( wt==WeightsType.POST_NZ ) { // approach: iterate over W, point-wise in order to exploit sparsity // blocked over ij, while maintaining front of column indexes, where the // blocksize is chosen such that we reuse each vector on average 8 times. final int blocksizeIJ = (int) (8L*mX.rlen*mX.clen/mX.nonZeros); - int[] curk = new int[blocksizeIJ]; + int[] curk = new int[blocksizeIJ]; for( int bi=rl; bi<ru; bi+=blocksizeIJ ) { int bimin = Math.min(ru, bi+blocksizeIJ); @@ -2235,49 +2243,49 @@ public class LibMatrixMult //blocked execution over column blocks for( int bj=0; bj<n; bj+=blocksizeIJ ) { int bjmin = Math.min(n, bj+blocksizeIJ); - for( int i=bi, uix=bi*cd; i<bimin; i++, uix+=cd ) { - if( !x.isEmpty(i) ) { - int xpos = x.pos(i); - int xlen = x.size(i); - int[] xix = x.indexes(i); - double[] xval = x.values(i); - int k = xpos + curk[i-bi]; - for( ; k<xpos+xlen && xix[k]<bjmin; k++ ) { - double uvij = dotProduct(u, v, uix, xix[k]*cd, cd); - wsloss += (xval[k]-uvij)*(xval[k]-uvij); - } - curk[i-bi] = k - xpos; + for( int i=bi; i<bimin; i++ ) { + if( x.isEmpty(i) ) continue; + int xpos = x.pos(i); + int xlen = x.size(i); + int[] xix = x.indexes(i); + double[] xval = x.values(i), uvals = u.values(i); + int uix = u.pos(i); + int k = xpos + curk[i-bi]; + for( ; k<xpos+xlen && xix[k]<bjmin; k++ ) { + double uvij = dotProduct(uvals, v.values(xix[k]), uix, v.pos(xix[k]), cd); + wsloss += (xval[k]-uvij)*(xval[k]-uvij); } + curk[i-bi] = k - xpos; } } } } // Pattern 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting) - else if( wt==WeightsType.PRE ) - { + else if( wt==WeightsType.PRE ) { // approach: iterate over all cells of X maybe sparse and dense // (note: tuning similar to pattern 3 possible but more complex) - for( int i=rl, uix=rl*cd; i<ru; i++, uix+=cd ) - for( int j=0, vix=0; j<n; j++, vix+=cd) - { + for( int i=rl; i<ru; i++ ) { + double[] uvals = u.values(i); 
+ int uix = u.pos(i); + for( int j=0; j<n; j++ ) { double xij = mX.quickGetValue(i, j); double wij = mW.quickGetValue(i, j); double uvij = 0; if( wij != 0 ) - uvij = dotProduct(u, v, uix, vix, cd); + uvij = dotProduct(uvals, v.values(j), uix, v.pos(j), cd); wsloss += (xij-wij*uvij)*(xij-wij*uvij); } + } } // Pattern 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting) - else if( wt==WeightsType.NONE ) - { + else if( wt==WeightsType.NONE ) { //approach: use sparsity-exploiting pattern rewrite sum((X-(U%*%t(V)))^2) //-> sum(X^2)-sum(2*X*(U%*%t(V))))+sum((t(U)%*%U)*(t(V)%*%V)), where each //parallel task computes sum(X^2)-sum(2*X*(U%*%t(V)))) and the last term //sum((t(U)%*%U)*(t(V)%*%V)) is computed once via two tsmm operations. final int blocksizeIJ = (int) (8L*mX.rlen*mX.clen/mX.nonZeros); - int[] curk = new int[blocksizeIJ]; + int[] curk = new int[blocksizeIJ]; for( int bi=rl; bi<ru; bi+=blocksizeIJ ) { int bimin = Math.min(ru, bi+blocksizeIJ); @@ -2286,16 +2294,18 @@ public class LibMatrixMult //blocked execution over column blocks for( int bj=0; bj<n; bj+=blocksizeIJ ) { int bjmin = Math.min(n, bj+blocksizeIJ); - for( int i=bi, uix=bi*cd; i<bimin; i++, uix+=cd ) { + for( int i=bi; i<bimin; i++ ) { if( x.isEmpty(i) ) continue; int xpos = x.pos(i); int xlen = x.size(i); int[] xix = x.indexes(i); double[] xval = x.values(i); + double[] uvals = u.values(i); + int uix = u.pos(i); int k = xpos + curk[i-bi]; for( ; k<xpos+xlen && xix[k]<bjmin; k++ ) { double xij = xval[k]; - double uvij = dotProduct(u, v, uix, xix[k]*cd, cd); + double uvij = dotProduct(uvals, v.values(xix[k]), uix, v.pos(xix[k]), cd); wsloss += xij * xij - 2 * xij * uvij; } curk[i-bi] = k - xpos; @@ -2311,7 +2321,7 @@ public class LibMatrixMult { final int n = mX.clen; final int cd = mU.clen; - double wsloss = 0; + double wsloss = 0; // Pattern 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting) if( wt==WeightsType.POST ) @@ -2320,31 +2330,32 @@ public class LibMatrixMult if( mW.sparse ) //SPARSE { 
SparseBlock w = mW.sparseBlock; - - for( int i=rl; i<ru; i++ ) - if( !w.isEmpty(i) ) { - int wpos = w.pos(i); - int wlen = w.size(i); - int[] wix = w.indexes(i); - double[] wval = w.values(i); - for( int k=wpos; k<wpos+wlen; k++ ) { - double uvij = dotProductGeneric(mU, mV, i, wix[k], cd); - double xi = mX.quickGetValue(i, wix[k]); - wsloss += wval[k]*(xi-uvij)*(xi-uvij); - } - } + for( int i=rl; i<ru; i++ ) { + if( w.isEmpty(i) ) continue; + int wpos = w.pos(i); + int wlen = w.size(i); + int[] wix = w.indexes(i); + double[] wval = w.values(i); + for( int k=wpos; k<wpos+wlen; k++ ) { + double uvij = dotProductGeneric(mU, mV, i, wix[k], cd); + double xi = mX.quickGetValue(i, wix[k]); + wsloss += wval[k]*(xi-uvij)*(xi-uvij); + } + } } else //DENSE { - double[] w = mW.getDenseBlockValues(); - - for( int i=rl, wix=rl*n; i<ru; i++, wix+=n ) + DenseBlock w = mW.getDenseBlock(); + for( int i=rl; i<ru; i++ ) { + double[] wvals = w.values(i); + int wix = w.pos(i); for( int j=0; j<n; j++) - if( w[wix+j] != 0 ) { + if( wvals[wix+j] != 0 ) { double uvij = dotProductGeneric(mU, mV, i, j, cd); double xij = mX.quickGetValue(i, j); - wsloss += w[wix+j]*(xij-uvij)*(xij-uvij); + wsloss += wvals[wix+j]*(xij-uvij)*(xij-uvij); } + } } } // Pattern 1b) sum ((X!=0) * (X - U %*% t(V)) ^ 2) (post weighting) @@ -2354,29 +2365,32 @@ public class LibMatrixMult if( mX.sparse ) //SPARSE { SparseBlock x = mX.sparseBlock; - - for( int i=rl; i<ru; i++ ) - if( !x.isEmpty(i) ) { - int xpos = x.pos(i); - int xlen = x.size(i); - int[] xix = x.indexes(i); - double[] xval = x.values(i); - for( int k=xpos; k<xpos+xlen; k++ ) { - double uvij = dotProductGeneric(mU, mV, i, xix[k], cd); - wsloss += (xval[k]-uvij)*(xval[k]-uvij); - } - } + for( int i=rl; i<ru; i++ ) { + if( x.isEmpty(i) ) continue; + int xpos = x.pos(i); + int xlen = x.size(i); + int[] xix = x.indexes(i); + double[] xval = x.values(i); + for( int k=xpos; k<xpos+xlen; k++ ) { + double uvij = dotProductGeneric(mU, mV, i, xix[k], cd); + wsloss 
+= (xval[k]-uvij)*(xval[k]-uvij); + } + } } else //DENSE { - double[] x = mX.getDenseBlockValues(); - - for( int i=rl, xix=rl*n; i<ru; i++, xix+=n ) - for( int j=0; j<n; j++) - if( x[xix+j] != 0 ) { + DenseBlock x = mX.getDenseBlock(); + for( int i=rl; i<ru; i++ ) { + double[] xvals = x.values(i); + int xix = x.pos(i); + for( int j=0; j<n; j++) { + double xij = xvals[xix+j]; + if( xij != 0 ) { double uvij = dotProductGeneric(mU, mV, i, j, cd); - wsloss += (x[xix+j]-uvij)*(x[xix+j]-uvij); + wsloss += (xij-uvij)*(xij-uvij); } + } + } } } // Pattern 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting) @@ -2384,8 +2398,7 @@ public class LibMatrixMult { // approach: iterate over all cells of X maybe sparse and dense for( int i=rl; i<ru; i++ ) - for( int j=0; j<n; j++) - { + for( int j=0; j<n; j++) { double xij = mX.quickGetValue(i, j); double wij = mW.quickGetValue(i, j); double uvij = 0; @@ -2415,17 +2428,20 @@ public class LibMatrixMult double uvij = dotProductGeneric(mU, mV, i, xix[k], cd); wsloss += xij * xij - 2 * xij * uvij; } - } + } } else { //DENSE - double[] x = mX.getDenseBlockValues(); - for( int i=rl, xix=rl*n; i<ru; i++, xix+=n ) + DenseBlock x = mX.getDenseBlock(); + for( int i=rl; i<ru; i++ ) { + double[] xvals = x.values(i); + int xix = x.pos(i); for( int j=0; j<n; j++) - if( x[xix+j] != 0 ) { - double xij = x[xix+j]; + if( xvals[xix+j] != 0 ) { + double xij = xvals[xix+j]; double uvij = dotProductGeneric(mU, mV, i, j, cd); wsloss += xij * xij - 2 * xij * uvij; } + } } } @@ -2446,11 +2462,11 @@ public class LibMatrixMult private static void matrixMultWSigmoidDense(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WSigmoidType wt, int rl, int ru) throws DMLRuntimeException - { - double[] w = mW.getDenseBlockValues(); - double[] c = ret.getDenseBlockValues(); - double[] u = mU.getDenseBlockValues(); - double[] v = mV.getDenseBlockValues(); + { + DenseBlock w = mW.getDenseBlock(); + DenseBlock c = ret.getDenseBlock(); + DenseBlock u = 
mU.getDenseBlock(); + DenseBlock v = mV.getDenseBlock(); final int n = mW.clen; final int cd = mU.clen; @@ -2467,19 +2483,23 @@ public class LibMatrixMult final int blocksizeIJ = 16; //u/v block (max at typical L2 size) //blocked execution - for( int bi = rl; bi < ru; bi+=blocksizeIJ ) - for( int bj = 0, bimin = Math.min(ru, bi+blocksizeIJ); bj < n; bj+=blocksizeIJ ) - { + for( int bi = rl; bi < ru; bi+=blocksizeIJ ) { + int bimin = Math.min(ru, bi+blocksizeIJ); + for( int bj = 0; bj < n; bj+=blocksizeIJ ) { int bjmin = Math.min(n, bj+blocksizeIJ); - //core wsigmoid computation - for( int i=bi, ix=bi*n, uix=bi*cd; i<bimin; i++, ix+=n, uix+=cd ) - for( int j=bj, vix=bj*cd; j<bjmin; j++, vix+=cd) { - double wij = w[ix+j]; + for( int i=bi; i<bimin; i++ ) { + double[] wvals = w.values(i), uvals = u.values(i), cvals = c.values(i); + int wix = w.pos(i), uix = u.pos(i); + for( int j=bj; j<bjmin; j++) { + double wij = wvals[wix+j]; if( wij != 0 ) - c[ix+j] = wsigmoid(wij, u, v, uix, vix, flagminus, flaglog, cd); + cvals[wix+j] = wsigmoid(wij, uvals, v.values(j), + uix, v.pos(j), flagminus, flaglog, cd); } + } } + } } private static void matrixMultWSigmoidSparseDense(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WSigmoidType wt, int rl, int ru) @@ -2487,27 +2507,29 @@ public class LibMatrixMult { SparseBlock w = mW.sparseBlock; SparseBlock c = ret.sparseBlock; - double[] u = mU.getDenseBlockValues(); - double[] v = mV.getDenseBlockValues(); + DenseBlock u = mU.getDenseBlock(); + DenseBlock v = mV.getDenseBlock(); final int cd = mU.clen; - boolean flagminus = (wt==WSigmoidType.MINUS || wt==WSigmoidType.LOG_MINUS); + boolean flagminus = (wt==WSigmoidType.MINUS || wt==WSigmoidType.LOG_MINUS); boolean flaglog = (wt==WSigmoidType.LOG || wt==WSigmoidType.LOG_MINUS); //approach: iterate over non-zeros of w, selective mm computation - for( int i=rl, uix=rl*cd; i<ru; i++, uix+=cd ) - if( !w.isEmpty(i) ) { - int wpos = w.pos(i); - int wlen = w.size(i); - int[] wix 
= w.indexes(i); - double[] wval = w.values(i); - - c.allocate(i, wlen); - for( int k=wpos; k<wpos+wlen; k++ ) { - double cval = wsigmoid(wval[k], u, v, uix, wix[k]*cd, flagminus, flaglog, cd); - c.append(i, wix[k], cval); - } + for( int i=rl; i<ru; i++ ) { + if( w.isEmpty(i) ) continue; + int wpos = w.pos(i); + int wlen = w.size(i); + int[] wix = w.indexes(i); + double[] wval = w.values(i); + double[] uvals = u.values(i); + int uix = u.pos(i); + c.allocate(i, wlen); + for( int k=wpos; k<wpos+wlen; k++ ) { + double cval = wsigmoid(wval[k], uvals, v.values(wix[k]), + uix, v.pos(wix[k]), flagminus, flaglog, cd); + c.append(i, wix[k], cval); } + } } private static void matrixMultWSigmoidGeneric (MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WSigmoidType wt, int rl, int ru) @@ -2525,39 +2547,38 @@ public class LibMatrixMult //w and c always in same representation SparseBlock w = mW.sparseBlock; SparseBlock c = ret.sparseBlock; - - for( int i=rl; i<ru; i++ ) - if( !w.isEmpty(i) ) { - int wpos = w.pos(i); - int wlen = w.size(i); - int[] wix = w.indexes(i); - double[] wval = w.values(i); - - c.allocate(i, wlen); - for( int k=wpos; k<wpos+wlen; k++ ) { - double cval = wsigmoid(wval[k], mU, mV, i, wix[k], flagminus, flaglog, cd); - c.append(i, wix[k], cval); - } - } + for( int i=rl; i<ru; i++ ) { + if( w.isEmpty(i) ) continue; + int wpos = w.pos(i); + int wlen = w.size(i); + int[] wix = w.indexes(i); + double[] wval = w.values(i); + c.allocate(i, wlen); + for( int k=wpos; k<wpos+wlen; k++ ) { + double cval = wsigmoid(wval[k], mU, mV, i, wix[k], flagminus, flaglog, cd); + c.append(i, wix[k], cval); + } + } } else //DENSE { //w and c always in same representation - double[] w = mW.getDenseBlockValues(); - double[] c = ret.getDenseBlockValues(); - - for( int i=rl, ix=rl*n; i<ru; i++ ) - for( int j=0; j<n; j++, ix++) { - double wij = w[ix]; - if( wij != 0 ) { - c[ix] = wsigmoid(wij, mU, mV, i, j, flagminus, flaglog, cd); - } + DenseBlock w = 
mW.getDenseBlock(); + DenseBlock c = ret.getDenseBlock(); + for( int i=rl; i<ru; i++ ) { + double[] wvals = w.values(i), cvals = c.values(i); + int ix = w.pos(i); + for( int j=0; j<n; j++ ) { + double wij = wvals[ix+j]; + if( wij != 0 ) + cvals[ix+j] = wsigmoid(wij, mU, mV, i, j, flagminus, flaglog, cd); } + } } } private static void matrixMultWDivMMDense(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock mX, MatrixBlock ret, WDivMMType wt, int rl, int ru, int cl, int cu) - throws DMLRuntimeException + throws DMLRuntimeException { final boolean basic = wt.isBasic(); final boolean left = wt.isLeft(); @@ -2566,14 +2587,13 @@ public class LibMatrixMult final boolean four = wt.hasFourInputs(); final boolean scalar = wt.hasScalar(); final double eps = scalar ? mX.quickGetValue(0, 0) : 0; - final int n = mW.clen; final int cd = mU.clen; - double[] w = mW.getDenseBlockValues(); - double[] u = mU.getDenseBlockValues(); - double[] v = mV.getDenseBlockValues(); - double[] x = (mX==null) ? null : mX.getDenseBlockValues(); - double[] c = ret.getDenseBlockValues(); + DenseBlock w = mW.getDenseBlock(); + DenseBlock u = mU.getDenseBlock(); + DenseBlock v = mV.getDenseBlock(); + DenseBlock x = (mX==null) ? 
null : mX.getDenseBlock(); + DenseBlock c = ret.getDenseBlock(); //approach: iterate over non-zeros of w, selective mm computation //cache-conscious blocking: due to blocksize constraint (default 1000), @@ -2582,26 +2602,32 @@ public class LibMatrixMult final int blocksizeIJ = 16; //u/v block (max at typical L2 size) //blocked execution - for( int bi = rl; bi < ru; bi+=blocksizeIJ ) - for( int bj = cl, bimin = Math.min(ru, bi+blocksizeIJ); bj < cu; bj+=blocksizeIJ ) - { + for( int bi = rl; bi < ru; bi+=blocksizeIJ ) { + int bimin = Math.min(ru, bi+blocksizeIJ); + for( int bj = cl; bj < cu; bj+=blocksizeIJ ) { int bjmin = Math.min(cu, bj+blocksizeIJ); - //core wsigmoid computation - for( int i=bi, ix=bi*n, uix=bi*cd; i<bimin; i++, ix+=n, uix+=cd ) - for( int j=bj, vix=bj*cd; j<bjmin; j++, vix+=cd) - if( w[ix+j] != 0 ) { + for( int i=bi; i<bimin; i++ ) { + double[] wvals = w.values(i), uvals = u.values(i); + double[] xvals = four ? x.values(i) : null; + int wix = w.pos(i), uix = u.pos(i); + for( int j=bj; j<bjmin; j++ ) + if( wvals[wix+j] != 0 ) { + double[] cvals = c.values((basic||!left) ? 
i : j); if( basic ) - c[ix+j] = w[ix+j] * dotProduct(u, v, uix, vix, cd); - else if( four ) //left/right + cvals[wix+j] = wvals[wix+j] * dotProduct(uvals, v.values(j), uix, v.pos(j), cd); + else if( four ) { //left/right if (scalar) - wdivmm(w[ix+j], eps, u, v, c, uix, vix, left, scalar, cd); + wdivmm(wvals[wix+j], eps, uvals, v.values(j), cvals, uix, v.pos(j), left, scalar, cd); else - wdivmm(w[ix+j], x[ix+j], u, v, c, uix, vix, left, scalar, cd); + wdivmm(wvals[wix+j], xvals[wix+j], uvals, v.values(j), cvals, uix, v.pos(j), left, scalar, cd); + } else //left/right minus/default - wdivmm(w[ix+j], u, v, c, uix, vix, left, mult, minus, cd); + wdivmm(wvals[wix+j], uvals, v.values(j), cvals, uix, v.pos(j), left, mult, minus, cd); } + } } + } } private static void matrixMultWDivMMSparseDense(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock mX, MatrixBlock ret, WDivMMType wt, int rl, int ru, int cl, int cu) @@ -2617,9 +2643,9 @@ public class LibMatrixMult final int cd = mU.clen; SparseBlock w = mW.sparseBlock; - double[] u = mU.getDenseBlockValues(); - double[] v = mV.getDenseBlockValues(); - double[] c = ret.getDenseBlockValues(); + DenseBlock u = mU.getDenseBlock(); + DenseBlock v = mV.getDenseBlock(); + DenseBlock c = ret.getDenseBlock(); SparseBlock x = (mX==null) ? 
null : mX.sparseBlock; //approach: iterate over non-zeros of w, selective mm computation @@ -2651,18 +2677,21 @@ public class LibMatrixMult { int bjmin = Math.min(cu, bj+blocksizeJ); //core wdivmm block matrix mult - for( int i=bi, uix=bi*cd; i<bimin; i++, uix+=cd ) { + for( int i=bi; i<bimin; i++ ) { if( w.isEmpty(i) ) continue; int wpos = w.pos(i); int wlen = w.size(i); int[] wix = w.indexes(i); double[] wval = w.values(i); + double[] uvals = u.values(i); + int uix = u.pos(i); int k = wpos + curk[i-bi]; if( basic ) { for( ; k<wpos+wlen && wix[k]<bjmin; k++ ) - ret.appendValue( i, wix[k], wval[k] * dotProduct(u, v, uix, wix[k]*cd, cd)); + ret.appendValue( i, wix[k], wval[k] *dotProduct( + uvals, v.values(wix[k]), uix, v.pos(wix[k]), cd)); } else if( four ) { //left/right //checking alignment per row is ok because early abort if false, @@ -2670,21 +2699,31 @@ public class LibMatrixMult if( !scalar && w.isAligned(i, x) ) { //O(n) where n is nnz in w/x double[] xvals = x.values(i); - for( ; k<wpos+wlen && wix[k]<bjmin; k++ ) - wdivmm(wval[k], xvals[k], u, v, c, uix, wix[k]*cd, left, scalar, cd); + for( ; k<wpos+wlen && wix[k]<bjmin; k++ ) { + double[] cvals = c.values(left ? wix[k] : i); + wdivmm(wval[k], xvals[k], uvals, v.values(wix[k]), + cvals, uix, v.pos(wix[k]), left, scalar, cd); + } } else { //scalar or O(n log m) where n/m are nnz in w/x - for( ; k<wpos+wlen && wix[k]<bjmin; k++ ) + for( ; k<wpos+wlen && wix[k]<bjmin; k++ ) { + double[] cvals = c.values(left ? 
wix[k] : i); if (scalar) - wdivmm(wval[k], eps, u, v, c, uix, wix[k]*cd, left, scalar, cd); + wdivmm(wval[k], eps, uvals, v.values(wix[k]), + cvals, uix, v.pos(wix[k]), left, scalar, cd); else - wdivmm(wval[k], x.get(i, wix[k]), u, v, c, uix, wix[k]*cd, left, scalar, cd); + wdivmm(wval[k], x.get(i, wix[k]), uvals, + v.values(wix[k]), cvals, uix, v.pos(wix[k]), left, scalar, cd); + } } } else { //left/right minus default - for( ; k<wpos+wlen && wix[k]<bjmin; k++ ) - wdivmm(wval[k], u, v, c, uix, wix[k]*cd, left, mult, minus, cd); + for( ; k<wpos+wlen && wix[k]<bjmin; k++ ) { + double[] cvals = c.values(left ? wix[k] : i); + wdivmm(wval[k], uvals, v.values(wix[k]), cvals, + uix, v.pos(wix[k]), left, mult, minus, cd); + } } curk[i-bi] = k - wpos; } @@ -2702,11 +2741,10 @@ public class LibMatrixMult final boolean four = wt.hasFourInputs(); final boolean scalar = wt.hasScalar(); final double eps = scalar ? mX.quickGetValue(0, 0) : 0; - final int n = mW.clen; final int cd = mU.clen; //output always in dense representation - double[] c = ret.getDenseBlockValues(); + DenseBlock c = ret.getDenseBlock(); //approach: iterate over non-zeros of w, selective mm computation if( mW.sparse ) //SPARSE @@ -2714,55 +2752,58 @@ public class LibMatrixMult SparseBlock w = mW.sparseBlock; for( int i=rl; i<ru; i++ ) { - if( !w.isEmpty(i) ) { - int wpos = w.pos(i); - int wlen = w.size(i); - int[] wix = w.indexes(i); - double[] wval = w.values(i); - int k = (cl==0) ? 0 : w.posFIndexGTE(i,cl); - k = (k>=0) ? wpos+k : wpos+wlen; - for( ; k<wpos+wlen && wix[k]<cu; k++ ) { - if( basic ) { - double uvij = dotProductGeneric(mU,mV, i, wix[k], cd); - ret.appendValue(i, wix[k], uvij); - } - else if( four ) { //left/right - double xij = scalar ? 
eps : mX.quickGetValue(i, wix[k]); - wdivmm(wval[k], xij, mU, mV, c, i, wix[k], left, scalar, cd); - } - else { //left/right minus/default - wdivmm(wval[k], mU, mV, c, i, wix[k], left, mult, minus, cd); - } + if( w.isEmpty(i) ) continue; + int wpos = w.pos(i); + int wlen = w.size(i); + int[] wix = w.indexes(i); + double[] wval = w.values(i); + int k = (cl==0) ? 0 : w.posFIndexGTE(i,cl); + k = (k>=0) ? wpos+k : wpos+wlen; + for( ; k<wpos+wlen && wix[k]<cu; k++ ) { + double[] cvals = c.values((basic||!left) ? i : wix[k]); + if( basic ) { + double uvij = dotProductGeneric(mU,mV, i, wix[k], cd); + ret.appendValue(i, wix[k], uvij); + } + else if( four ) { //left/right + double xij = scalar ? eps : mX.quickGetValue(i, wix[k]); + wdivmm(wval[k], xij, mU, mV, cvals, i, wix[k], left, scalar, cd); } - } + else { //left/right minus/default + wdivmm(wval[k], mU, mV, cvals, i, wix[k], left, mult, minus, cd); + } + } } } else //DENSE { - double[] w = mW.getDenseBlockValues(); - - for( int i=rl, ix=rl*n; i<ru; i++, ix+=n ) + DenseBlock w = mW.getDenseBlock(); + for( int i=rl; i<ru; i++ ) { + double[] wvals = w.values(i); + int ix = w.pos(i); for( int j=cl; j<cu; j++) - if( w[ix+j] != 0 ) { + if( wvals[ix+j] != 0 ) { + double[] cvals = c.values((basic||!left) ? i : j); if( basic ) { - c[ix+j] = dotProductGeneric(mU,mV, i, j, cd); + cvals[ix+j] = dotProductGeneric(mU,mV, i, j, cd); } else if( four ) { //left/right double xij = scalar ? 
eps : mX.quickGetValue(i, j); - wdivmm(w[ix+j], xij, mU, mV, c, i, j, left, scalar, cd); + wdivmm(wvals[ix+j], xij, mU, mV, cvals, i, j, left, scalar, cd); } else { //left/right minus/default - wdivmm(w[ix+j], mU, mV, c, i, j, left, mult, minus, cd); + wdivmm(wvals[ix+j], mU, mV, cvals, i, j, left, mult, minus, cd); } } + } } } private static void matrixMultWCeMMDense(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, double eps, MatrixBlock ret, WCeMMType wt, int rl, int ru) { - double[] w = mW.getDenseBlockValues(); - double[] u = mU.getDenseBlockValues(); - double[] v = mV.getDenseBlockValues(); + DenseBlock w = mW.getDenseBlock(); + DenseBlock u = mU.getDenseBlock(); + DenseBlock v = mV.getDenseBlock(); final int n = mW.clen; final int cd = mU.clen; double wceval = 0; @@ -2773,32 +2814,34 @@ public class LibMatrixMult final int blocksizeIJ = 16; //u/v block (max at typical L2 size) //blocked execution - for( int bi = rl; bi < ru; bi+=blocksizeIJ ) - for( int bj = 0, bimin = Math.min(ru, bi+blocksizeIJ); bj < n; bj+=blocksizeIJ ) - { + for( int bi = rl; bi < ru; bi+=blocksizeIJ ) { + int bimin = Math.min(ru, bi+blocksizeIJ); + for( int bj = 0; bj < n; bj+=blocksizeIJ ) { int bjmin = Math.min(n, bj+blocksizeIJ); - - for( int i=bi, ix=bi*n, uix=bi*cd; i<bimin; i++, ix+=n, uix+=cd ) - for( int j=bj, vix=bj*cd; j<bjmin; j++, vix+=cd) { - double wij = w[ix+j]; + for( int i=bi; i<bimin; i++ ) { + double[] wvals = w.values(i), uvals = u.values(i); + int wix = w.pos(i), uix = u.pos(i); + for( int j=bj; j<bjmin; j++ ) { + double wij = wvals[wix+j]; if( wij != 0 ) { - double uvij = dotProduct(u, v, uix, vix, cd); + double uvij = dotProduct(uvals, v.values(j), uix, v.pos(j), cd); wceval += wij * Math.log(uvij + eps); } } + } + } } - ret.quickSetValue(0, 0, wceval); } private static void matrixMultWCeMMSparseDense(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, double eps, MatrixBlock ret, WCeMMType wt, int rl, int ru) { SparseBlock w = mW.sparseBlock; - double[] u = 
mU.getDenseBlockValues(); - double[] v = mV.getDenseBlockValues(); + DenseBlock u = mU.getDenseBlock(); + DenseBlock v = mV.getDenseBlock(); final int n = mW.clen; final int cd = mU.clen; - double wceval = 0; + double wceval = 0; // approach: iterate over W, point-wise in order to exploit sparsity // blocked over ij, while maintaining front of column indexes, where the @@ -2813,22 +2856,23 @@ public class LibMatrixMult //blocked execution over column blocks for( int bj=0; bj<n; bj+=blocksizeIJ ) { int bjmin = Math.min(n, bj+blocksizeIJ); - for( int i=bi, uix=bi*cd; i<bimin; i++, uix+=cd ) { + for( int i=bi; i<bimin; i++ ) { if( w.isEmpty(i) ) continue; int wpos = w.pos(i); int wlen = w.size(i); int[] wix = w.indexes(i); - double[] wval = w.values(i); + double[] wvals = w.values(i); + double[] uvals = u.values(i); + int uix = u.pos(i); int k = wpos + curk[i-bi]; for( ; k<wpos+wlen && wix[k]<bjmin; k++ ) { - double uvij = dotProduct(u, v, uix, wix[k]*cd, cd); - wceval += wval[k] * Math.log(uvij + eps); + double uvij = dotProduct(uvals, v.values(wix[k]), uix, v.pos(wix[k]), cd); + wceval += wvals[k] * Math.log(uvij + eps); } curk[i-bi] = k - wpos; } } } - ret.quickSetValue(0, 0, wceval); } @@ -2842,71 +2886,76 @@ public class LibMatrixMult if( mW.sparse ) //SPARSE { SparseBlock w = mW.sparseBlock; - - for( int i=rl; i<ru; i++ ) - if( !w.isEmpty(i) ) { - int wpos = w.pos(i); - int wlen = w.size(i); - int[] wix = w.indexes(i); - double[] wval = w.values(i); - for( int k=wpos; k<wpos+wlen; k++ ) { - double uvij = dotProductGeneric(mU, mV, i, wix[k], cd); - wceval += wval[k] * Math.log(uvij + eps); - } - } + for( int i=rl; i<ru; i++ ) { + if( w.isEmpty(i) ) continue; + int wpos = w.pos(i); + int wlen = w.size(i); + int[] wix = w.indexes(i); + double[] wval = w.values(i); + for( int k=wpos; k<wpos+wlen; k++ ) { + double uvij = dotProductGeneric(mU, mV, i, wix[k], cd); + wceval += wval[k] * Math.log(uvij + eps); + } + } } else //DENSE { - double[] w = 
mW.getDenseBlockValues(); - - for( int i=rl, ix=rl*n; i<ru; i++ ) - for( int j=0; j<n; j++, ix++) { - double wij = w[ix]; + DenseBlock w = mW.getDenseBlock(); + for( int i=rl; i<ru; i++ ) { + double[] wvals = w.values(i); + int wix = w.pos(i); + for( int j=0; j<n; j++ ) { + double wij = wvals[wix+j]; if( wij != 0 ) { double uvij = dotProductGeneric(mU, mV, i, j, cd); wceval += wij * Math.log(uvij + eps); } } + } } ret.quickSetValue(0, 0, wceval); } private static void matrixMultWuMMDense(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WUMMType wt, ValueFunction fn, int rl, int ru) - throws DMLRuntimeException - { - double[] w = mW.getDenseBlockValues(); - double[] c = ret.getDenseBlockValues(); - double[] u = mU.getDenseBlockValues(); - double[] v = mV.getDenseBlockValues(); + throws DMLRuntimeException + { + DenseBlock w = mW.getDenseBlock(); + DenseBlock c = ret.getDenseBlock(); + DenseBlock u = mU.getDenseBlock(); + DenseBlock v = mV.getDenseBlock(); final int n = mW.clen; final int cd = mU.clen; //note: cannot compute U %*% t(V) in-place of result w/ regular mm because //t(V) comes in transformed form and hence would require additional memory - boolean flagmult = (wt==WUMMType.MULT); + boolean flagmult = (wt==WUMMType.MULT); //approach: iterate over non-zeros of w, selective mm computation //cache-conscious blocking: due to blocksize constraint (default 1000), - //a blocksize of 16 allows to fit blocks of UV into L2 cache (256KB) + //a blocksize of 16 allows to fit blocks of UV into L2 cache (256KB) - final int blocksizeIJ = 16; //u/v block (max at typical L2 size) + final int blocksizeIJ = 16; //u/v block (max at typical L2 size) //blocked execution - for( int bi = rl; bi < ru; bi+=blocksizeIJ ) - for( int bj = 0, bimin = Math.min(ru, bi+blocksizeIJ); bj < n; bj+=blocksizeIJ ) - { + for( int bi = rl; bi < ru; bi+=blocksizeIJ ) { + int bimin = Math.min(ru, bi+blocksizeIJ); + for( int bj = 0; bj < n; bj+=blocksizeIJ ) { int bjmin = Math.min(n, 
bj+blocksizeIJ); - //core wsigmoid computation - for( int i=bi, ix=bi*n, uix=bi*cd; i<bimin; i++, ix+=n, uix+=cd ) - for( int j=bj, vix=bj*cd; j<bjmin; j++, vix+=cd) { - double wij = w[ix+j]; + for( int i=bi; i<bimin; i++ ) { + double[] wvals = w.values(i), uvals = u.values(i), cvals = c.values(i); + int wix = w.pos(i), uix = u.pos(i); + for( int j=bj; j<bjmin; j++ ) { + double wij = wvals[wix+j]; if( wij != 0 ) - c[ix+j] = wumm(wij, u, v, uix, vix, flagmult, fn, cd); + cvals[wix+j] = wumm(wij, uvals, v.values(j), + uix, v.pos(j), flagmult, fn, cd); } + } } + } } private static void matrixMultWuMMSparseDense(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WUMMType wt, ValueFunction fn, int rl, int ru) @@ -2914,70 +2963,67 @@ public class LibMatrixMult { SparseBlock w = mW.sparseBlock; SparseBlock c = ret.sparseBlock; - double[] u = mU.getDenseBlockValues(); - double[] v = mV.getDenseBlockValues(); + DenseBlock u = mU.getDenseBlock(); + DenseBlock v = mV.getDenseBlock(); final int cd = mU.clen; + boolean flagmult = (wt==WUMMType.MULT); - boolean flagmult = (wt==WUMMType.MULT); - //approach: iterate over non-zeros of w, selective mm computation - for( int i=rl, uix=rl*cd; i<ru; i++, uix+=cd ) - if( !w.isEmpty(i) ) { - int wpos = w.pos(i); - int wlen = w.size(i); - int[] wix = w.indexes(i); - double[] wval = w.values(i); - - c.allocate(i, wlen); - for( int k=wpos; k<wpos+wlen; k++ ) { - double cval = wumm(wval[k], u, v, uix, wix[k]*cd, flagmult, fn, cd); - c.append(i, wix[k], cval); - } + for( int i=rl; i<ru; i++ ) { + if( w.isEmpty(i) ) continue; + int wpos = w.pos(i); + int wlen = w.size(i); + int[] wix = w.indexes(i); + double[] wvals = w.values(i); + double[] uvals = u.values(i); + int uix = u.pos(i); + c.allocate(i, wlen); + for( int k=wpos; k<wpos+wlen; k++ ) { + double cval = wumm(wvals[k], uvals, v.values(wix[k]), + uix, v.pos(wix[k]), flagmult, fn, cd); + c.append(i, wix[k], cval); } + } } private static void matrixMultWuMMGeneric 
(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WUMMType wt, ValueFunction fn, int rl, int ru) throws DMLRuntimeException { - final int n = mW.clen; + final int n = mW.clen; final int cd = mU.clen; - - boolean flagmult = (wt==WUMMType.MULT); + boolean flagmult = (wt==WUMMType.MULT); //approach: iterate over non-zeros of w, selective mm computation - if( mW.sparse ) //SPARSE - { + if( mW.sparse ) { //SPARSE //w and c always in same representation SparseBlock w = mW.sparseBlock; SparseBlock c = ret.sparseBlock; - - for( int i=rl; i<ru; i++ ) - if( !w.isEmpty(i) ) { - int wpos = w.pos(i); - int wlen = w.size(i); - int[] wix = w.indexes(i); - double[] wval = w.values(i); - - c.allocate(i, wlen); - for( int k=wpos; k<wpos+wlen; k++ ) { - double cval = wumm(wval[k], mU, mV, i, wix[k], flagmult, fn, cd); - c.append(i, wix[k], cval); - } - } + for( int i=rl; i<ru; i++ ) { + if( w.isEmpty(i) ) continue; + int wpos = w.pos(i); + int wlen = w.size(i); + int[] wix = w.indexes(i); + double[] wval = w.values(i); + c.allocate(i, wlen); + for( int k=wpos; k<wpos+wlen; k++ ) { + double cval = wumm(wval[k], mU, mV, i, wix[k], flagmult, fn, cd); + c.append(i, wix[k], cval); + } + } } - else //DENSE - { + else { //DENSE //w and c always in same representation - double[] w = mW.getDenseBlockValues(); - double[] c = ret.getDenseBlockValues(); - - for( int i=rl, ix=rl*n; i<ru; i++ ) - for( int j=0; j<n; j++, ix++) { - double wij = w[ix]; - if( wij != 0 ) { - c[ix] = wumm(wij, mU, mV, i, j, flagmult, fn, cd); - } + DenseBlock w = mW.getDenseBlock(); + DenseBlock c = ret.getDenseBlock(); + for( int i=rl; i<ru; i++ ) { + double[] wvals = w.values(i), cvals = c.values(i); + int ix = w.pos(i); + for( int j=0; j<n; j++) { + double wij = wvals[ix+j]; + if( wij != 0 ) + cvals[ix+j] = wumm(wij, mU, mV, i, j, flagmult, fn, cd); } + } } } @@ -3471,17 +3517,17 @@ public class LibMatrixMult private static void wdivmm( final double wij, double[] u, double[] v, double[] c, final int 
uix, final int vix, final boolean left, final boolean mult, final boolean minus, final int len ) { - //compute dot product over ui vj + //compute dot product over ui vj double uvij = dotProduct(u, v, uix, vix, len); //compute core wdivmm double tmpval = minus ? uvij - wij : - mult ? wij * uvij : wij / uvij; + mult ? wij * uvij : wij / uvij; //prepare inputs for final mm int bix = left ? uix : vix; int cix = left ? vix : uix; - double[] b = left ? u : v; + double[] b = left ? u : v; //compute final mm output vectMultiplyAdd(tmpval, b, c, bix, cix, len); @@ -3498,7 +3544,7 @@ public class LibMatrixMult //prepare inputs for final mm int bix = left ? uix : vix; int cix = left ? vix : uix; - double[] b = left ? u : v; + double[] b = left ? u : v; //compute final mm output vectMultiplyAdd(tmpval, b, c, bix, cix, len); @@ -3511,12 +3557,12 @@ public class LibMatrixMult //compute core wdivmm double wtmp = minus ? uvij - wij : - mult ? wij * uvij : wij / uvij; + mult ? wij * uvij : wij / uvij; //prepare inputs for final mm int bix = left ? uix : vix; int cix = left ? vix*len : uix*len; - MatrixBlock b = left ? u : v; + MatrixBlock b = left ? u : v; //compute final mm for( int k2=0; k2<len; k2++ ) @@ -3534,7 +3580,7 @@ public class LibMatrixMult //prepare inputs for final mm int bix = left ? uix : vix; int cix = left ? vix*len : uix*len; - MatrixBlock b = left ? u : v; + MatrixBlock b = left ? u : v; //compute final mm for( int k2=0; k2<len; k2++ )
