This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 2b8de1629b [MINOR] JIT optimize LibMatrixBinCell
2b8de1629b is described below

commit 2b8de1629b935d0b75caf38e4295c706980f0ce7
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Thu Oct 26 18:24:54 2023 +0200

    [MINOR] JIT optimize LibMatrixBinCell
    
    This commit moves some of the code inside LibMatrixBincell around to
    encourage JIT compilation of some methods. Specifically, the following
    methods have been introduced:
    
    - safeBinaryMVSparseRowVector
    - fillZeroValuesEmpty
    - fillZeroValuesDense
    - fillZeroValuesSparse
    - safeBinaryMMDenseDenseDensePM_Vec (Plus Multiply kernel vectorized)
    - safeBinaryMMDenseDenseDensePM     (Plus Multiply kernel small input)
    - safeBinaryMMDenseDenseDenseContiguous (This one makes a big difference)
    - safeBinaryMMDenseDenseDenseGeneric
    
    In particular, safeBinaryMMDenseDenseDenseContiguous,
    safeBinaryMMDenseDenseDensePM and safeBinaryMMDenseDenseDensePM_Vec
    improve the performance considerably.
    
    Performance in LM_cg, taken from the stats output
    (operator, time, invocation count):
    
     +*  3.123   3000 (Before)
     +*  1.991   3000 (After)
    
     +   1.125   2021 (Before)
     +   0.703   2015 (After)
    
    This is training on Criteo 100k rows.
---
 .../runtime/matrix/data/LibMatrixBincell.java      | 430 +++++++++++++--------
 .../sysds/runtime/matrix/data/LibMatrixMult.java   |   2 +-
 2 files changed, 269 insertions(+), 163 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java
index e53f09a7f4..e5ec7a0020 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java
@@ -851,85 +851,93 @@ public class LibMatrixBincell {
        private static void safeBinaryMVSparse(MatrixBlock m1, MatrixBlock m2, 
MatrixBlock ret, BinaryOperator op) {
                boolean isMultiply = (op.fn instanceof Multiply);
                boolean skipEmpty = (isMultiply || isSparseSafeDivide(op, m2));
-               
-               int rlen = m1.rlen;
-               int clen = m1.clen;
-               SparseBlock a = m1.sparseBlock;
                BinaryAccessType atype = getBinaryAccessType(m1, m2);
-               
-               //early abort on skip and empty
-               if( skipEmpty && (m1.isEmptyBlock(false) || 
m2.isEmptyBlock(false) ) )
+
+               // early abort on skip and empty
+               if(skipEmpty && (m1.isEmptyBlock(false) || 
m2.isEmptyBlock(false)))
                        return; // skip entire empty block
-               
-               //allocate once in order to prevent repeated reallocation
-               if( ret.sparse )
+
+               // allocate once in order to prevent repeated reallocation
+               if(ret.sparse)
                        ret.allocateSparseRowsBlock();
-               
-               if( atype == BinaryAccessType.MATRIX_COL_VECTOR )
-               {
-                       for( int i=0; i<rlen; i++ ) {
-                               double v2 = m2.quickGetValue(i, 0);
-                               
-                               if( (skipEmpty && (a==null || a.isEmpty(i) || 
v2 == 0 ))
-                                       || ((a==null || a.isEmpty(i)) && v2 == 
0) )
-                               {
-                                       continue; //skip empty rows
-                               }
-                                       
-                               if( isMultiply && v2==1 ) { //ROW COPY
-                                       if( a != null && !a.isEmpty(i)  )
-                                               ret.appendRow(i, a.get(i));
-                               }
-                               else { //GENERAL CASE
-                                       int lastIx = -1;
-                                       if( a != null && !a.isEmpty(i) ) {
-                                               int apos = a.pos(i);
-                                               int alen = a.size(i);
-                                               int[] aix = a.indexes(i);
-                                               double[] avals = a.values(i);
-                                               for( int j=apos; j<apos+alen; 
j++ ) {
-                                                       //empty left
-                                                       fillZeroValues(op, v2, 
ret, skipEmpty, i, lastIx+1, aix[j]);
-                                                       //actual value
-                                                       double v = 
op.fn.execute( avals[j], v2 );
-                                                       ret.appendValue(i, 
aix[j], v);  
-                                                       lastIx = aix[j];
-                                               }
-                                       }
-                                       //empty left
-                                       fillZeroValues(op, v2, ret, skipEmpty, 
i, lastIx+1, clen);
-                               }
+
+               if(atype == BinaryAccessType.MATRIX_COL_VECTOR)
+                       safeBinaryMVSparseColVector(m1, m2, ret, op);
+               else if(atype == BinaryAccessType.MATRIX_ROW_VECTOR)
+                       safeBinaryMVSparseRowVector(m1, m2, ret, op);
+       }
+
+       private static void safeBinaryMVSparseColVector(MatrixBlock m1, 
MatrixBlock m2, MatrixBlock ret, BinaryOperator op) {
+               boolean isMultiply = (op.fn instanceof Multiply);
+               boolean skipEmpty = (isMultiply || isSparseSafeDivide(op, m2));
+
+               int rlen = m1.rlen;
+               int clen = m1.clen;
+               SparseBlock a = m1.sparseBlock;
+               for(int i = 0; i < rlen; i++) {
+                       double v2 = m2.quickGetValue(i, 0);
+
+                       if((skipEmpty && (a == null || a.isEmpty(i) || v2 == 
0)) || ((a == null || a.isEmpty(i)) && v2 == 0)) {
+                               continue; // skip empty rows
                        }
-               }
-               else if( atype == BinaryAccessType.MATRIX_ROW_VECTOR )
-               {
-                       for( int i=0; i<rlen; i++ ) {
-                               if( skipEmpty && (a==null || a.isEmpty(i)) )
-                                       continue; //skip empty rows
-                               if( skipEmpty && ret.sparse )
-                                       ret.sparseBlock.allocate(i, a.size(i));
+
+                       if(isMultiply && v2 == 1) { // ROW COPY
+                               if(a != null && !a.isEmpty(i))
+                                       ret.appendRow(i, a.get(i));
+                       }
+                       else { // GENERAL CASE
                                int lastIx = -1;
-                               if( a!=null && !a.isEmpty(i) ) {
+                               if(a != null && !a.isEmpty(i)) {
                                        int apos = a.pos(i);
                                        int alen = a.size(i);
                                        int[] aix = a.indexes(i);
                                        double[] avals = a.values(i);
-                                       for( int j=apos; j<apos+alen; j++ ) {
-                                               //empty left
-                                               fillZeroValues(op, m2, ret, 
skipEmpty, i, lastIx+1, aix[j]);
-                                               //actual value
-                                               double v2 = m2.quickGetValue(0, 
aix[j]);
-                                               double v = op.fn.execute( 
avals[j], v2 );
+                                       for(int j = apos; j < apos + alen; j++) 
{
+                                               // empty left
+                                               fillZeroValues(op, v2, ret, 
skipEmpty, i, lastIx + 1, aix[j]);
+                                               // actual value
+                                               double v = 
op.fn.execute(avals[j], v2);
                                                ret.appendValue(i, aix[j], v);
                                                lastIx = aix[j];
                                        }
                                }
-                               //empty left
-                               fillZeroValues(op, m2, ret, skipEmpty, i, 
lastIx+1, clen);
+                               // empty left
+                               fillZeroValues(op, v2, ret, skipEmpty, i, 
lastIx + 1, clen);
                        }
                }
-               
-               //no need to recomputeNonZeros since maintained in append value
+       }
+
+       private static void safeBinaryMVSparseRowVector(MatrixBlock m1, 
MatrixBlock m2, MatrixBlock ret, BinaryOperator op) {
+               boolean isMultiply = (op.fn instanceof Multiply);
+               boolean skipEmpty = (isMultiply || isSparseSafeDivide(op, m2));
+
+               int rlen = m1.rlen;
+               int clen = m1.clen;
+               SparseBlock a = m1.sparseBlock;
+               for(int i = 0; i < rlen; i++) {
+                       if(skipEmpty && (a == null || a.isEmpty(i)))
+                               continue; // skip empty rows
+                       if(skipEmpty && ret.sparse)
+                               ret.sparseBlock.allocate(i, a.size(i));
+                       int lastIx = -1;
+                       if(a != null && !a.isEmpty(i)) {
+                               int apos = a.pos(i);
+                               int alen = a.size(i);
+                               int[] aix = a.indexes(i);
+                               double[] avals = a.values(i);
+                               for(int j = apos; j < apos + alen; j++) {
+                                       // empty left
+                                       fillZeroValues(op, m2, ret, skipEmpty, 
i, lastIx + 1, aix[j]);
+                                       // actual value
+                                       double v2 = m2.quickGetValue(0, aix[j]);
+                                       double v = op.fn.execute(avals[j], v2);
+                                       ret.appendValue(i, aix[j], v);
+                                       lastIx = aix[j];
+                               }
+                       }
+                       // empty left
+                       fillZeroValues(op, m2, ret, skipEmpty, i, lastIx + 1, 
clen);
+               }
        }
        
        private static final void fillZeroValues(BinaryOperator op, double v2, 
MatrixBlock ret, boolean skipEmpty, int rpos, int cpos, int len) {
@@ -948,58 +956,84 @@ public class LibMatrixBincell {
                int cpos, int len) {
                if(skipEmpty)
                        return;
+               else if(m2.isEmpty()) 
+                       fillZeroValuesEmpty(op, m2, ret, skipEmpty, rpos, cpos, 
len);
+               else if(m2.isInSparseFormat()) 
+                       fillZeroValuesSparse(op, m2, ret, skipEmpty, rpos, 
cpos, len);
+               else 
+                       fillZeroValuesDense(op, m2, ret, skipEmpty, rpos, cpos, 
len);
+       }
 
-               final double zero =  op.fn.execute(0.0, 0.0);
+       private static void fillZeroValuesEmpty(BinaryOperator op, MatrixBlock 
m2, MatrixBlock ret, boolean skipEmpty,
+               int rpos, int cpos, int len) {
+               final double zero = op.fn.execute(0.0, 0.0);
                final boolean zeroIsZero = zero == 0.0;
-               if(m2.isEmpty()){
-                               if(!zeroIsZero){
-                                       while(cpos < len)
-                                               // TODO change this to a fill 
operation.
-                                               ret.appendValue(rpos, cpos++, 
zero);
-                               }
+               if(!zeroIsZero) {
+                       while(cpos < len)
+                               // TODO change this to a fill operation.
+                               ret.appendValue(rpos, cpos++, zero);
                }
-               else if(m2.isInSparseFormat()) {
-                       final SparseBlock sb = m2.getSparseBlock();
-                       if(sb.isEmpty(0)){
-                               if(!zeroIsZero){
-                                       while(cpos < len)
-                                               ret.appendValue(rpos, cpos++, 
zero);
-                               }
+       }
+
+       private static void fillZeroValuesDense(BinaryOperator op, MatrixBlock 
m2, MatrixBlock ret, boolean skipEmpty,
+               int rpos, int cpos, int len) {
+               final DenseBlock db = m2.getDenseBlock();
+               final double[] vals = db.values(0);
+               final SparseBlock r = ret.getSparseBlock();
+               if(ret.isInSparseFormat() && r instanceof SparseBlockMCSR) {
+                       SparseBlockMCSR mCSR = (SparseBlockMCSR) r;
+                       mCSR.allocate(rpos, cpos, len);
+                       SparseRow sr = mCSR.get(rpos);
+                       for(int k = cpos; k < len; k++) {
+                               sr.append(k, op.fn.execute(0, vals[k]));
                        }
-                       else{
-                               int apos = sb.pos(0);
-                               final int alen = sb.size(0) + apos;
-                               final int[] aix = sb.indexes(0);
-                               final double[] vals = sb.values(0);
-                               // skip aix pos until inside range of cpos and 
len
-                               while( apos < alen && aix[apos] < len && cpos > 
aix[apos]){
-                                       apos++;
-                               }
-                               // for each point in the sparse range
-                               for(; apos < alen && aix[apos] < len; apos++){
-                                       if(!zeroIsZero){
-                                               while(cpos < len  && cpos < 
aix[apos]){
-                                                       ret.appendValue(rpos, 
cpos++, zero);
-                                               }
-                                       }
-                                       cpos = aix[apos];
-                                       final double v = op.fn.execute(0, 
vals[apos]);
-                                       ret.appendValue(rpos, aix[apos], v);
-                                       // cpos++;
-                               }
-                               // process tail.
+               }
+               else {
+                       // def
+                       for(int k = cpos; k < len; k++) {
+                               ret.appendValue(rpos, k, op.fn.execute(0, 
vals[k]));
+                       }
+               }
+       }
+
+       private static void fillZeroValuesSparse(BinaryOperator op, MatrixBlock 
m2, MatrixBlock ret, boolean skipEmpty,
+               int rpos, int cpos, int len) {
+
+               final double zero = op.fn.execute(0.0, 0.0);
+               final boolean zeroIsZero = zero == 0.0;
+               final SparseBlock sb = m2.getSparseBlock();
+               if(sb.isEmpty(0)) {
+                       if(!zeroIsZero) {
+                               while(cpos < len)
+                                       ret.appendValue(rpos, cpos++, zero);
+                       }
+               }
+               else {
+                       int apos = sb.pos(0);
+                       final int alen = sb.size(0) + apos;
+                       final int[] aix = sb.indexes(0);
+                       final double[] vals = sb.values(0);
+                       // skip aix pos until inside range of cpos and len
+                       while(apos < alen && aix[apos] < len && cpos > 
aix[apos]) {
+                               apos++;
+                       }
+                       // for each point in the sparse range
+                       for(; apos < alen && aix[apos] < len; apos++) {
                                if(!zeroIsZero) {
-                                       while(cpos < len) {
+                                       while(cpos < len && cpos < aix[apos]) {
                                                ret.appendValue(rpos, cpos++, 
zero);
                                        }
                                }
-                       }
-               }
-               else {
-                       final DenseBlock db = m2.getDenseBlock();
-                       final double[] vals = db.values(0);
-                       for(int k = cpos; k < len; k++){
-                               ret.appendValue(rpos, k, op.fn.execute(0, 
vals[k]));
+                               cpos = aix[apos];
+                               final double v = op.fn.execute(0, vals[apos]);
+                               ret.appendValue(rpos, aix[apos], v);
+                               // cpos++;
+                       }
+                       // process tail.
+                       if(!zeroIsZero) {
+                               while(cpos < len) {
+                                       ret.appendValue(rpos, cpos++, zero);
+                               }
                        }
                }
        }
@@ -1313,40 +1347,86 @@ public class LibMatrixBincell {
        }
        
        private static long safeBinaryMMDenseDenseDense(MatrixBlock m1, 
MatrixBlock m2, MatrixBlock ret,
-               BinaryOperator op, int rl, int ru)
-       {
-               boolean isPM = m1.clen >= 512 & (op.fn instanceof PlusMultiply 
| op.fn instanceof MinusMultiply);
-               double cntPM = !isPM ? Double.NaN : (op.fn instanceof 
PlusMultiply ?
-                       ((PlusMultiply)op.fn).getConstant() : -1d * 
((MinusMultiply)op.fn).getConstant());
+               BinaryOperator op, int rl, int ru){
+               final int clen = m1.clen;
+               final boolean isPM = (op.fn instanceof PlusMultiply || op.fn 
instanceof MinusMultiply);
                
                //guard for postponed allocation in single-threaded exec
-               if( !ret.isAllocated() )
+               if(!ret.isAllocated())
                        ret.allocateDenseBlock();
                
-               DenseBlock da = m1.getDenseBlock();
-               DenseBlock db = m2.getDenseBlock();
-               DenseBlock dc = ret.getDenseBlock();
-               ValueFunction fn = op.fn;
-               int clen = m1.clen;
+               final DenseBlock da = m1.getDenseBlock();
+               final DenseBlock db = m2.getDenseBlock();
+               final DenseBlock dc = ret.getDenseBlock();
                
-               //compute dense-dense binary, maintain nnz on-the-fly
+               if(isPM && clen >= 64)
+                       return safeBinaryMMDenseDenseDensePM_Vec(da, db, dc, 
op, rl, ru, clen);
+               else if(da.isContiguous() && db.isContiguous() && 
dc.isContiguous()) {
+                       if(op.fn instanceof PlusMultiply)
+                               return safeBinaryMMDenseDenseDensePM(da, db, 
dc, op, rl, ru, clen);
+                       else
+                               return 
safeBinaryMMDenseDenseDenseContiguous(da, db, dc, op, rl, ru, clen);
+               }
+               else
+                       return safeBinaryMMDenseDenseDenseGeneric(da, db, dc, 
op, rl, ru, clen);
+       }
+
+       private static final long safeBinaryMMDenseDenseDensePM_Vec(DenseBlock 
da, DenseBlock db, DenseBlock dc, BinaryOperator op,
+               int rl, int ru, int clen) {
+               final double cntPM = (op.fn instanceof PlusMultiply ? 
((PlusMultiply) op.fn).getConstant() : -1d *
+                       ((MinusMultiply) op.fn).getConstant());
                long lnnz = 0;
-               for(int i=rl; i<ru; i++) {
-                       double[] a = da.values(i);
-                       double[] b = db.values(i);
-                       double[] c = dc.values(i);
+               for(int i = rl; i < ru; i++) {
+                       final double[] a = da.values(i);
+                       final double[] b = db.values(i);
+                       final double[] c = dc.values(i);
                        int pos = da.pos(i);
-                       
-                       if( isPM ) {
-                               System.arraycopy(a, pos, c, pos, clen);
-                               LibMatrixMult.vectMultiplyAdd(cntPM, b, c, pos, 
pos, clen);
-                               lnnz += UtilFunctions.computeNnz(c, pos, clen);
-                       }
-                       else {
-                               for(int j=pos; j<pos+clen; j++) {
-                                       c[j] = fn.execute(a[j], b[j]);
-                                       lnnz += (c[j]!=0)? 1 : 0;
-                               }
+                       System.arraycopy(a, pos, c, pos, clen);
+                       LibMatrixMult.vectMultiplyAdd(cntPM, b, c, pos, pos, 
clen);
+                       lnnz += UtilFunctions.computeNnz(c, pos, clen);
+               }
+               return lnnz;
+       }
+
+       private static final long safeBinaryMMDenseDenseDensePM(DenseBlock da, 
DenseBlock db, DenseBlock dc, BinaryOperator op,
+               int rl, int ru, int clen) {
+               long lnnz = 0;
+               final double[] a = da.values(0);
+               final double[] b = db.values(0);
+               final double[] c = dc.values(0);
+               final double d = ((PlusMultiply)op.fn).getConstant();
+               for(int i = da.pos(rl); i < da.pos(ru); i++) {
+                       c[i] = a[i] + d * b[i];
+                       lnnz += (c[i] != 0) ? 1 : 0;
+               }
+               return lnnz;
+       }
+
+               private static final long 
safeBinaryMMDenseDenseDenseContiguous(DenseBlock da, DenseBlock db, DenseBlock 
dc, BinaryOperator op,
+               int rl, int ru, int clen) {
+               long lnnz = 0;
+               final double[] a = da.values(0);
+               final double[] b = db.values(0);
+               final double[] c = dc.values(0);
+               for(int i = da.pos(rl); i < da.pos(ru); i++) {
+                       c[i] += op.fn.execute(a[i], b[i]);
+                       lnnz += (c[i] != 0) ? 1 : 0;
+               }
+               return lnnz;
+       }
+
+       private static final long safeBinaryMMDenseDenseDenseGeneric(DenseBlock 
da, DenseBlock db, DenseBlock dc,
+               BinaryOperator op, int rl, int ru, int clen) {
+               final ValueFunction fn = op.fn;
+               long lnnz = 0;
+               for(int i = rl; i < ru; i++) {
+                       final double[] a = da.values(i);
+                       final double[] b = db.values(i);
+                       final double[] c = dc.values(i);
+                       int pos = da.pos(i);
+                       for(int j = pos; j < pos + clen; j++) {
+                               c[j] = fn.execute(a[j], b[j]);
+                               lnnz += (c[j] != 0) ? 1 : 0;
                        }
                }
                return lnnz;
@@ -1749,11 +1829,9 @@ public class LibMatrixBincell {
        }
 
        private static void safeBinaryInPlaceMatrixMatrix(MatrixBlock m1ret, 
MatrixBlock m2, BinaryOperator op) {
-               if(op.fn instanceof Plus && m1ret.isEmpty()) {
+               if(op.fn instanceof Plus && m1ret.isEmpty() && 
!m1ret.isAllocated())
                        m1ret.copy(m2);
-                       return;
-               }
-               if(m1ret.sparse && m2.sparse)
+               else if(m1ret.sparse && m2.sparse)
                        safeBinaryInPlaceSparse(m1ret, m2, op);
                else if(!m1ret.sparse && !m2.sparse)
                        safeBinaryInPlaceDense(m1ret, m2, op);
@@ -1875,43 +1953,72 @@ public class LibMatrixBincell {
 
        private static void safeBinaryInPlaceDense(MatrixBlock m1ret, 
MatrixBlock m2, BinaryOperator op) {
                // prepare outputs
-               m1ret.allocateDenseBlock();
+               if(!m1ret.isAllocated()) // allocate
+                       m1ret.allocateDenseBlock();
+
+               if(m2.isEmptyBlock(false))
+                       safeBinaryInPlaceDenseEmpty(m1ret, op);
+               else if(op.fn instanceof Plus)
+                       safeBinaryInPlaceDensePlus(m1ret, m2, op);
+               else
+                       safeBinaryInPlaceDenseGeneric(m1ret, m2, op);
+       }
+
+       private static void safeBinaryInPlaceDenseEmpty(MatrixBlock m1ret, 
BinaryOperator op) {
                DenseBlock a = m1ret.getDenseBlock();
-               DenseBlock b = m2.getDenseBlock();
                final int rlen = m1ret.rlen;
                final int clen = m1ret.clen;
-
                long lnnz = 0;
-               if(m2.isEmptyBlock(false)) {
-                       for(int r = 0; r < rlen; r++) {
-                               double[] avals = a.values(r);
-                               for(int c = 0, ix = a.pos(r); c < clen; c++, 
ix++) {
-                                       double tmp = op.fn.execute(avals[ix], 
0);
-                                       lnnz += (avals[ix] = tmp) != 0 ? 1 : 0;
-                               }
+               for(int r = 0; r < rlen; r++) {
+                       double[] avals = a.values(r);
+                       for(int c = 0, ix = a.pos(r); c < clen; c++, ix++) {
+                               double tmp = op.fn.execute(avals[ix], 0);
+                               lnnz += (avals[ix] = tmp) != 0 ? 1 : 0;
                        }
                }
-               else if(op.fn instanceof Plus) {
+               m1ret.setNonZeros(lnnz);
+       }
+
+       private static void safeBinaryInPlaceDensePlus(MatrixBlock m1ret, 
MatrixBlock m2, BinaryOperator op) {
+               DenseBlock a = m1ret.getDenseBlock();
+               DenseBlock b = m2.getDenseBlock();
+               final int rlen = m1ret.rlen;
+               final int clen = m1ret.clen;
+               long lnnz = 0;
+               if(a.isContiguous() && b.isContiguous()){
+                       final double[] avals = a.values(0);
+                       final double[] bvals = b.values(0);
+                       for(int i = 0; i < avals.length; i++)
+                               lnnz += (avals[i] += bvals[i]) == 0 ? 0 : 1;
+               }
+               else{
                        for(int r = 0; r < rlen; r++) {
-                               int aix = a.pos(r), bix = b.pos(r);
-                               double[] avals = a.values(r), bvals = 
b.values(r);
+                               final int aix = a.pos(r), bix = b.pos(r);
+                               final double[] avals = a.values(r), bvals = 
b.values(r);
                                LibMatrixMult.vectAdd(bvals, avals, bix, aix, 
clen);
                                lnnz += UtilFunctions.computeNnz(avals, aix, 
clen);
                        }
                }
-               else {
-                       for(int r = 0; r < rlen; r++) {
-                               double[] avals = a.values(r), bvals = 
b.values(r);
-                               for(int c = 0, ix = a.pos(r); c < clen; c++, 
ix++) {
-                                       double tmp = op.fn.execute(avals[ix], 
bvals[ix]);
-                                       lnnz += (avals[ix] = tmp) != 0 ? 1 : 0;
-                               }
+               m1ret.setNonZeros(lnnz);
+       }
+
+       private static void safeBinaryInPlaceDenseGeneric(MatrixBlock m1ret, 
MatrixBlock m2, BinaryOperator op) {
+               DenseBlock a = m1ret.getDenseBlock();
+               DenseBlock b = m2.getDenseBlock();
+               final int rlen = m1ret.rlen;
+               final int clen = m1ret.clen;
+               long lnnz = 0;
+               for(int r = 0; r < rlen; r++) {
+                       double[] avals = a.values(r), bvals = b.values(r);
+                       for(int c = 0, ix = a.pos(r); c < clen; c++, ix++) {
+                               double tmp = op.fn.execute(avals[ix], 
bvals[ix]);
+                               lnnz += (avals[ix] = tmp) != 0 ? 1 : 0;
                        }
                }
-
                m1ret.setNonZeros(lnnz);
        }
 
+
        private static void safeBinaryInPlaceDenseConst(MatrixBlock m1ret, 
double m2, BinaryOperator op) {
                // prepare outputs
                m1ret.allocateDenseBlock();
@@ -1984,8 +2091,7 @@ public class LibMatrixBincell {
                        }
        }
        
-       private static void unsafeBinaryInPlace(MatrixBlock m1ret, MatrixBlock 
m2, BinaryOperator op)
-       {
+       private static void unsafeBinaryInPlace(MatrixBlock m1ret, MatrixBlock 
m2, BinaryOperator op){
                int rlen = m1ret.rlen;
                int clen = m1ret.clen;
                BinaryAccessType atype = getBinaryAccessType(m1ret, m2);
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
index 6fda33ad09..3fa91bb31e 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
@@ -364,7 +364,7 @@ public class LibMatrixMult
         * The parameter k (k&gt;=1) determines the max parallelism k' with 
k'=min(k, vcores, m1.rlen).
         * 
         * NOTE: This multi-threaded mmchain operation has additional memory 
requirements of k*ncol(X)*8bytes 
-        * for partial aggregation. Current max memory: 256KB; otherwise 
redirectly to sequential execution.
+        * for partial aggregation. Current max memory: 256KB; otherwise 
redirect to sequential execution.
         * 
         * @param mX X matrix
         * @param mV v matrix

Reply via email to