This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit af4f3d7683cf5dcd62d45858fb0290a607e66dfe
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Feb 13 01:31:57 2021 +0100

    [SYSTEMDS-2856] Extended multi-threading of binary and ternary operations
    
    This patch generalizes the multi-threading of binary operations from
    sparse-unsafe only to both sparse-unsafe and sparse-safe matrix-matrix
    operations, and extends it to ternary operations, which often call
    binary sparse-safe matrix operations. For an MNIST LeNet parameter
    server scenario, this patch improved end-to-end performance from 205s
    to 168s. It also slightly improved other algorithms like KMeans.
---
 src/main/java/org/apache/sysds/hops/TernaryOp.java |  16 +-
 src/main/java/org/apache/sysds/lops/Ternary.java   |  11 +-
 .../runtime/instructions/InstructionUtils.java     |   6 +-
 .../cp/ParamservBuiltinCPInstruction.java          |   6 +-
 .../instructions/cp/TernaryCPInstruction.java      |   3 +-
 .../runtime/matrix/data/LibMatrixBincell.java      | 241 ++++++++++++---------
 .../runtime/matrix/data/LibMatrixTercell.java      | 132 +++++++++++
 .../sysds/runtime/matrix/data/MatrixBlock.java     |  25 +--
 .../runtime/matrix/operators/TernaryOperator.java  |  10 +-
 9 files changed, 321 insertions(+), 129 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/TernaryOp.java 
b/src/main/java/org/apache/sysds/hops/TernaryOp.java
index f8c369c..47e42bb 100644
--- a/src/main/java/org/apache/sysds/hops/TernaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/TernaryOp.java
@@ -57,9 +57,8 @@ import org.apache.sysds.runtime.meta.MatrixCharacteristics;
  *
  * CTABLE op takes 2 extra inputs with target dimensions for padding and 
pruning.
  */
-public class TernaryOp extends Hop 
+public class TernaryOp extends MultiThreadedHop
 {
-       
        public static boolean ALLOW_CTABLE_SEQUENCE_REWRITES = true;
        
        private OpOp3 _op = null;
@@ -147,6 +146,13 @@ public class TernaryOp extends Hop
        }
        
        @Override
+       public boolean isMultiThreadedOpType() {
+               return _op == OpOp3.IFELSE
+                       || _op == OpOp3.MINUS_MULT
+                       || _op == OpOp3.PLUS_MULT;
+       }
+       
+       @Override
        public Lop constructLops() 
        {
                //return already created lops
@@ -324,13 +330,17 @@ public class TernaryOp extends Hop
 
        private void constructLopsTernaryDefault() {
                ExecType et = optFindExecType();
+               int k = 1;
                if( getInput().stream().allMatch(h -> 
h.getDataType().isScalar()) )
                        et = ExecType.CP; //always CP for pure scalar operations
+               else
+                       k= OptimizerUtils.getConstrainedNumThreads( 
_maxNumThreads );
+               
                Ternary plusmult = new Ternary(_op,
                        getInput().get(0).constructLops(),
                        getInput().get(1).constructLops(),
                        getInput().get(2).constructLops(), 
-                       getDataType(),getValueType(), et );
+                       getDataType(),getValueType(), et, k );
                setOutputDimensions(plusmult);
                setLineNumbers(plusmult);
                setLops(plusmult);
diff --git a/src/main/java/org/apache/sysds/lops/Ternary.java 
b/src/main/java/org/apache/sysds/lops/Ternary.java
index 8faea5d..a1a2d53 100644
--- a/src/main/java/org/apache/sysds/lops/Ternary.java
+++ b/src/main/java/org/apache/sysds/lops/Ternary.java
@@ -33,10 +33,12 @@ import org.apache.sysds.common.Types.ValueType;
 public class Ternary extends Lop 
 {
        private final OpOp3 _op;
-               
-       public Ternary(OpOp3 op, Lop input1, Lop input2, Lop input3, DataType 
dt, ValueType vt, ExecType et) {
+       private final int _numThreads;
+       
+       public Ternary(OpOp3 op, Lop input1, Lop input2, Lop input3, DataType 
dt, ValueType vt, ExecType et, int numThreads) {
                super(Lop.Type.Ternary, dt, vt);
                _op = op;
+               _numThreads = numThreads;
                init(input1, input2, input3, et);
        }
 
@@ -71,6 +73,11 @@ public class Ternary extends Lop
                sb.append( OPERAND_DELIMITOR );
                sb.append( prepOutputOperand(output) );
                
+               if( getExecType() == ExecType.CP && getDataType().isMatrix() ) {
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append( _numThreads );
+               }
+               
                return sb.toString();
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java 
b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index 49c3452..9245132 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -595,8 +595,12 @@ public class InstructionUtils
        }
        
        public static TernaryOperator parseTernaryOperator(String opcode) {
+               return parseTernaryOperator(opcode, 1);
+       }
+       
+       public static TernaryOperator parseTernaryOperator(String opcode, int 
numThreads) {
                return new TernaryOperator(opcode.equals("+*") ? 
PlusMultiply.getFnObject() :
-                       opcode.equals("-*") ? MinusMultiply.getFnObject() : 
IfElse.getFnObject());
+                       opcode.equals("-*") ? MinusMultiply.getFnObject() : 
IfElse.getFnObject(), numThreads);
        }
        
        /**
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index 0fa5297..a99e8ee 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -320,7 +320,8 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
 
                // Create the local workers
                List<LocalPSWorker> workers = IntStream.range(0, workerNum)
-                       .mapToObj(i -> new LocalPSWorker(i, updFunc, freq, 
getEpochs(), getBatchSize(), workerECs.get(i), ps))
+                       .mapToObj(i -> new LocalPSWorker(i, updFunc, freq,
+                               getEpochs(), getBatchSize(), workerECs.get(i), 
ps, workerNum==1))
                        .collect(Collectors.toList());
 
                // Do data partition
@@ -497,7 +498,8 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
        private void partitionLocally(PSScheme scheme, ExecutionContext ec, 
List<LocalPSWorker> workers) {
                MatrixObject features = 
ec.getMatrixObject(getParam(PS_FEATURES));
                MatrixObject labels = ec.getMatrixObject(getParam(PS_LABELS));
-               DataPartitionLocalScheme.Result result = new 
LocalDataPartitioner(scheme).doPartitioning(workers.size(), 
features.acquireReadAndRelease(), labels.acquireReadAndRelease());
+               DataPartitionLocalScheme.Result result = new 
LocalDataPartitioner(scheme)
+                       .doPartitioning(workers.size(), 
features.acquireReadAndRelease(), labels.acquireReadAndRelease());
                List<MatrixObject> pfs = result.pFeatures;
                List<MatrixObject> pls = result.pLabels;
                if (pfs.size() < workers.size()) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryCPInstruction.java
index 9d4232b..14d0090 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryCPInstruction.java
@@ -38,7 +38,8 @@ public class TernaryCPInstruction extends 
ComputationCPInstruction {
                CPOperand operand2 = new CPOperand(parts[2]);
                CPOperand operand3 = new CPOperand(parts[3]);
                CPOperand outOperand = new CPOperand(parts[4]);
-               TernaryOperator op = 
InstructionUtils.parseTernaryOperator(opcode);
+               int numThreads = parts.length>5 ? Integer.parseInt(parts[5]) : 
1;
+               TernaryOperator op = 
InstructionUtils.parseTernaryOperator(opcode, numThreads);
                return new TernaryCPInstruction(op, operand1, operand2, 
operand3, outOperand, opcode,str);
        }
        
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java
index 249d0e3..29363b0 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java
@@ -163,11 +163,19 @@ public class LibMatrixBincell
         * @param op binary operator
         */
        public static void bincellOp(MatrixBlock m1, MatrixBlock m2, 
MatrixBlock ret, BinaryOperator op) {
+               BinaryAccessType atype = getBinaryAccessType(m1, m2);
+               
+               //preallocate for consistency
+               if( atype == BinaryAccessType.MATRIX_MATRIX )
+                       ret.allocateBlock(); //chosen outside
+               
                //execute binary cell operations
+               long nnz = 0;
                if(op.sparseSafe || isSparseSafeDivide(op, m2))
-                       safeBinary(m1, m2, ret, op);
+                       nnz = safeBinary(m1, m2, ret, op, atype, 0, m1.rlen);
                else
-                       unsafeBinary(m1, m2, ret, op, 0, m1.rlen);
+                       nnz = unsafeBinary(m1, m2, ret, op, 0, m1.rlen);
+               ret.setNonZeros(nnz);
                
                //ensure empty results sparse representation 
                //(no additional memory requirements)
@@ -176,18 +184,20 @@ public class LibMatrixBincell
        }
        
        public static void bincellOp(MatrixBlock m1, MatrixBlock m2, 
MatrixBlock ret, BinaryOperator op, int k) {
+               BinaryAccessType atype = getBinaryAccessType(m1, m2);
+               
                //fallback to sequential computation for specialized operations
-               //TODO parallel support for all sparse safe operations
-               if( op.sparseSafe || isSparseSafeDivide(op, m2)
-                       || ret.getLength() < PAR_NUMCELL_THRESHOLD2
-                       || getBinaryAccessType(m1, m2) == 
BinaryAccessType.OUTER_VECTOR_VECTOR)
+               if( m1.isEmpty() || m2.isEmpty()
+                       || ret.getLength() < PAR_NUMCELL_THRESHOLD2
+                       || ((op.sparseSafe || isSparseSafeDivide(op, m2))
+                               && atype != BinaryAccessType.MATRIX_MATRIX))
                {
                        bincellOp(m1, m2, ret, op);
                        return;
                }
                
                //preallocate dense/sparse block for multi-threaded operations
-               ret.allocateBlock();
+               ret.allocateBlock(); //chosen outside
                
                try {
                        //execute binary cell operations
@@ -195,7 +205,7 @@ public class LibMatrixBincell
                        ArrayList<BincellTask> tasks = new ArrayList<>();
                        ArrayList<Integer> blklens = 
UtilFunctions.getBalancedBlockSizesDefault(ret.rlen, k, false);
                        for( int i=0, lb=0; i<blklens.size(); 
lb+=blklens.get(i), i++ )
-                               tasks.add(new BincellTask(m1, m2, ret, op, lb, 
lb+blklens.get(i)));
+                               tasks.add(new BincellTask(m1, m2, ret, op, 
atype, lb, lb+blklens.get(i)));
                        List<Future<Long>> taskret = pool.invokeAll(tasks);
                        
                        //aggregate non-zeros
@@ -286,7 +296,11 @@ public class LibMatrixBincell
        // private sparse-safe/sparse-unsafe implementations
        ///////////////////////////////////
 
-       private static void safeBinary(MatrixBlock m1, MatrixBlock m2, 
MatrixBlock ret, BinaryOperator op) {
+       private static long safeBinary(MatrixBlock m1, MatrixBlock m2, 
MatrixBlock ret, BinaryOperator op,
+               BinaryAccessType atype, int rl, int ru)
+       {
+               //NOTE: multi-threaded over rl-ru only applied for 
matrix-matrix, non-empty
+               
                boolean skipEmpty = (op.fn instanceof Multiply 
                        || isSparseSafeDivide(op, m2) );
                boolean copyLeftRightEmpty = (op.fn instanceof Plus || op.fn 
instanceof Minus 
@@ -297,12 +311,11 @@ public class LibMatrixBincell
                if( m1.isEmptyBlock(false) && m2.isEmptyBlock(false) 
                        || skipEmpty && (m1.isEmptyBlock(false) || 
m2.isEmptyBlock(false)) ) 
                {
-                       return;
+                       return 0;
                }
                
-               BinaryAccessType atype = getBinaryAccessType(m1, m2);
                if( atype == BinaryAccessType.MATRIX_COL_VECTOR //MATRIX - 
VECTOR
-                       || atype == BinaryAccessType.MATRIX_ROW_VECTOR)  
+                       || atype == BinaryAccessType.MATRIX_ROW_VECTOR)
                {
                        //note: m2 vector and hence always dense
                        if( !m1.sparse && !m2.sparse && !ret.sparse ) //DENSE 
all
@@ -318,7 +331,7 @@ public class LibMatrixBincell
                                safeBinaryMVDenseSparseMult(m1, m2, ret, op);
                        else //generic combinations
                                safeBinaryMVGeneric(m1, m2, ret, op);
-               }       
+               }
                else if( atype == BinaryAccessType.OUTER_VECTOR_VECTOR ) 
//VECTOR - VECTOR
                {
                        safeBinaryVVGeneric(m1, m2, ret, op);
@@ -334,25 +347,27 @@ public class LibMatrixBincell
                                ret.copyShallow(m2);
                        }
                        else if(m1.sparse && m2.sparse) {
-                               safeBinaryMMSparseSparse(m1, m2, ret, op);
+                               return safeBinaryMMSparseSparse(m1, m2, ret, 
op, rl, ru);
                        }
                        else if( !ret.sparse && (m1.sparse || m2.sparse) &&
                                (op.fn instanceof Plus || op.fn instanceof 
Minus ||
                                op.fn instanceof PlusMultiply || op.fn 
instanceof MinusMultiply ||
                                (op.fn instanceof Multiply && !m2.sparse ))) {
-                               safeBinaryMMSparseDenseDense(m1, m2, ret, op);
+                               return safeBinaryMMSparseDenseDense(m1, m2, 
ret, op, rl, ru);
                        }
                        else if( !ret.sparse && !m1.sparse && !m2.sparse 
                                && m1.denseBlock!=null && m2.denseBlock!=null ) 
{
-                               safeBinaryMMDenseDenseDense(m1, m2, ret, op);
+                               return safeBinaryMMDenseDenseDense(m1, m2, ret, 
op, rl, ru);
                        }
                        else if( skipEmpty && (m1.sparse || m2.sparse) ) {
-                               safeBinaryMMSparseDenseSkip(m1, m2, ret, op);
+                               return safeBinaryMMSparseDenseSkip(m1, m2, ret, 
op, rl, ru);
                        }
                        else { //generic case
-                               safeBinaryMMGeneric(m1, m2, ret, op);
+                               return safeBinaryMMGeneric(m1, m2, ret, op, rl, 
ru);
                        }
                }
+               //default catch all
+               return ret.getNonZeros();
        }
 
        private static void safeBinaryMVDense(MatrixBlock m1, MatrixBlock m2, 
MatrixBlock ret, BinaryOperator op) {
@@ -737,12 +752,11 @@ public class LibMatrixBincell
                //no need to recomputeNonZeros since maintained in append value
        }
        
-       private static void safeBinaryMMSparseSparse(MatrixBlock m1, 
MatrixBlock m2, MatrixBlock ret, BinaryOperator op) {
-               final int rlen = m1.rlen;
-               if(ret.sparse)
-                       ret.allocateSparseRowsBlock();
-               
+       private static long safeBinaryMMSparseSparse(MatrixBlock m1, 
MatrixBlock m2, MatrixBlock ret,
+               BinaryOperator op, int rl, int ru)
+       {
                //both sparse blocks existing
+               long lnnz = 0;
                if(m1.sparseBlock!=null && m2.sparseBlock!=null)
                {
                        SparseBlock lsblock = m1.sparseBlock;
@@ -751,7 +765,7 @@ public class LibMatrixBincell
                        if( ret.sparse && lsblock.isAligned(rsblock) )
                        {
                                SparseBlock c = ret.sparseBlock;
-                               for(int r=0; r<rlen; r++) 
+                               for(int r=rl; r<ru; r++) 
                                        if( !lsblock.isEmpty(r) ) {
                                                int alen = lsblock.size(r);
                                                int apos = lsblock.pos(r);
@@ -763,12 +777,12 @@ public class LibMatrixBincell
                                                        double tmp = 
op.fn.execute(avals[j], bvals[j]);
                                                        c.append(r, aix[j], 
tmp);
                                                }
-                                               ret.nonZeros += c.size(r);
+                                               lnnz += c.size(r);
                                        }
                        }
                        else //general case
                        {
-                               for(int r=0; r<rlen; r++) {
+                               for(int r=rl; r<ru; r++) {
                                        if( !lsblock.isEmpty(r) && 
!rsblock.isEmpty(r) ) {
                                                mergeForSparseBinary(op, 
lsblock.values(r), lsblock.indexes(r), lsblock.pos(r), lsblock.size(r),
                                                        rsblock.values(r), 
rsblock.indexes(r), rsblock.pos(r), rsblock.size(r), r, ret);
@@ -782,6 +796,7 @@ public class LibMatrixBincell
                                                        lsblock.pos(r), 
lsblock.size(r), 0, r, ret);
                                        }
                                        // do nothing if both not existing
+                                       lnnz += ret.recomputeNonZeros(r, r);
                                }
                        }
                }
@@ -789,55 +804,63 @@ public class LibMatrixBincell
                else if( m2.sparseBlock!=null )
                {
                        SparseBlock rsblock = m2.sparseBlock;
-                       for(int r=0; r<Math.min(rlen, rsblock.numRows()); r++) {
+                       for(int r=rl; r<Math.min(ru, rsblock.numRows()); r++) {
                                if( rsblock.isEmpty(r) ) continue;
                                appendRightForSparseBinary(op, 
rsblock.values(r), rsblock.indexes(r), 
                                        rsblock.pos(r), rsblock.size(r), 0, r, 
ret);
+                               lnnz += ret.recomputeNonZeros(r, r);
                        }
                }
                //left sparse block existing
                else
                {
                        SparseBlock lsblock = m1.sparseBlock;
-                       for(int r=0; r<rlen; r++) {
+                       for(int r=rl; r<ru; r++) {
                                if( lsblock.isEmpty(r) ) continue;
                                appendLeftForSparseBinary(op, 
lsblock.values(r), lsblock.indexes(r), 
                                        lsblock.pos(r), lsblock.size(r), 0, r, 
ret);
+                               lnnz += ret.recomputeNonZeros(r, r);
                        }
                }
+               return lnnz;
        }
        
-       private static void safeBinaryMMSparseDenseDense(MatrixBlock m1, 
MatrixBlock m2, MatrixBlock ret, BinaryOperator op) {
+       private static long safeBinaryMMSparseDenseDense(MatrixBlock m1, 
MatrixBlock m2, MatrixBlock ret,
+               BinaryOperator op, int rl, int ru)
+       {
                //specific case in order to prevent binary search on sparse 
inputs (see quickget and quickset)
-               ret.allocateDenseBlock();
                final int n = ret.clen;
                DenseBlock dc = ret.getDenseBlock();
                
                //1) process left input: assignment
-               
                if( m1.sparse && m1.sparseBlock != null ) //SPARSE left
                {
                        SparseBlock a = m1.sparseBlock;
-                       for( int bi=0; bi<dc.numBlocks(); bi++ ) {
-                               double[] c = dc.valuesAt(bi);
-                               int blen = dc.blockSize(bi);
-                               int off = bi * dc.blockSize();
-                               for( int i=0, ix=0; i<blen; i++, ix+=n ) {
-                                       int ai = off + i;
-                                       if( a.isEmpty(ai) ) continue;
-                                       int apos = a.pos(ai);
-                                       int alen = a.size(ai);
-                                       int[] aix = a.indexes(ai);
-                                       double[] avals = a.values(ai);
-                                       for(int k = apos; k < apos+alen; k++) 
-                                               c[ix+aix[k]] = avals[k];
-                               }
+                       for(int i=rl; i<ru; i++) {
+                               double[] c = dc.values(i);
+                               int cpos = dc.pos(i);
+                               if( a.isEmpty(i) ) continue;
+                               int apos = a.pos(i);
+                               int alen = a.size(i);
+                               int[] aix = a.indexes(i);
+                               double[] avals = a.values(i);
+                               for(int k = apos; k < apos+alen; k++) 
+                                       c[cpos+aix[k]] = avals[k];
                        }
                }
                else if( !m1.sparse ) //DENSE left
                {
-                       if( !m1.isEmptyBlock(false) ) 
-                               dc.set(m1.getDenseBlock());
+                       if( !m1.isEmptyBlock(false) ) {
+                               int rlbix = dc.index(rl);
+                               int rubix = dc.index(ru-1);
+                               DenseBlock da = m1.getDenseBlock();
+                               if( rlbix == rubix )
+                                       System.arraycopy(da.valuesAt(rlbix), 
da.pos(rl), dc.valuesAt(rlbix), dc.pos(rl), (ru-rl)*n);
+                               else {
+                                       for(int i=rl; i<ru; i++)
+                                               System.arraycopy(da.values(i), 
da.pos(i), dc.values(i), dc.pos(i), n);
+                               }
+                       }
                        else
                                dc.set(0);
                }
@@ -847,35 +870,32 @@ public class LibMatrixBincell
                if( m2.sparse && m2.sparseBlock!=null ) //SPARSE right
                {
                        SparseBlock a = m2.sparseBlock;
-                       for( int bi=0; bi<dc.numBlocks(); bi++ ) {
-                               double[] c = dc.valuesAt(bi);
-                               int blen = dc.blockSize(bi);
-                               int off = bi * dc.blockSize();
-                               for( int i=0, ix=0; i<blen; i++, ix+=n ) {
-                                       int ai = off + i;
-                                       if( !a.isEmpty(ai) ) {
-                                               int apos = a.pos(ai);
-                                               int alen = a.size(ai);
-                                               int[] aix = a.indexes(ai);
-                                               double[] avals = a.values(ai);
-                                               for(int k = apos; k < 
apos+alen; k++) 
-                                                       c[ix+aix[k]] = 
op.fn.execute(c[ix+aix[k]], avals[k]);
-                                       }
-                                       //exploit temporal locality of rows
-                                       lnnz += ret.recomputeNonZeros(ai, ai, 
0, n-1);
+                       for(int i=rl; i<ru; i++) {
+                               double[] c = dc.values(i);
+                               int cpos = dc.pos(i);
+                               if( !a.isEmpty(i) ) {
+                                       int apos = a.pos(i);
+                                       int alen = a.size(i);
+                                       int[] aix = a.indexes(i);
+                                       double[] avals = a.values(i);
+                                       for(int k = apos; k < apos+alen; k++) 
+                                               c[cpos+aix[k]] = 
op.fn.execute(c[cpos+aix[k]], avals[k]);
                                }
+                               //exploit temporal locality of rows
+                               lnnz += ret.recomputeNonZeros(i, i);
                        }
                }
                else if( !m2.sparse ) //DENSE right
                {
                        if( !m2.isEmptyBlock(false) ) {
-                               for( int bi=0; bi<dc.numBlocks(); bi++ ) {
-                                       double[] a = 
m2.getDenseBlock().valuesAt(bi);
-                                       double[] c = dc.valuesAt(bi);
-                                       int len = dc.size(bi);
-                                       for( int i=0; i<len; i++ ) {
-                                               c[i] = op.fn.execute(c[i], 
a[i]);
-                                               lnnz += (c[i]!=0) ? 1 : 0;
+                               DenseBlock da = m2.getDenseBlock();
+                               for( int i=rl; i<ru; i++ ) {
+                                       double[] a = da.values(i);
+                                       double[] c = dc.values(i);
+                                       int apos = da.pos(i);
+                                       for( int j = apos; j<apos+n; j++ ) {
+                                               c[j] = op.fn.execute(c[j], 
a[j]);
+                                               lnnz += (c[j]!=0) ? 1 : 0;
                                        }
                                }
                        }
@@ -886,41 +906,45 @@ public class LibMatrixBincell
                }
                
                //3) recompute nnz
-               ret.setNonZeros(lnnz);
+               return lnnz;
        }
        
-       private static void safeBinaryMMDenseDenseDense(MatrixBlock m1, 
MatrixBlock m2, MatrixBlock ret, BinaryOperator op) {
-               ret.allocateDenseBlock();
+       private static long safeBinaryMMDenseDenseDense(MatrixBlock m1, 
MatrixBlock m2, MatrixBlock ret,
+               BinaryOperator op, int rl, int ru)
+       {
                DenseBlock da = m1.getDenseBlock();
                DenseBlock db = m2.getDenseBlock();
                DenseBlock dc = ret.getDenseBlock();
                ValueFunction fn = op.fn;
+               int clen = m1.clen;
                
                //compute dense-dense binary, maintain nnz on-the-fly
                long lnnz = 0;
-               for( int bi=0; bi<da.numBlocks(); bi++ ) {
-                       double[] a = da.valuesAt(bi);
-                       double[] b = db.valuesAt(bi);
-                       double[] c = dc.valuesAt(bi);
-                       int len = da.size(bi);
-                       for( int i=0; i<len; i++ ) {
-                               c[i] = fn.execute(a[i], b[i]);
-                               lnnz += (c[i]!=0)? 1 : 0;
+               for(int i=rl; i<ru; i++) {
+                       double[] a = da.values(i);
+                       double[] b = db.values(i);
+                       double[] c = dc.values(i);
+                       int pos = da.pos(i);
+                       for(int j=pos; j<pos+clen; j++) {
+                               c[j] = fn.execute(a[j], b[j]);
+                               lnnz += (c[j]!=0)? 1 : 0;
                        }
                }
-               ret.setNonZeros(lnnz);
+               return lnnz;
        }
        
-       private static void safeBinaryMMSparseDenseSkip(MatrixBlock m1, 
MatrixBlock m2, MatrixBlock ret, BinaryOperator op) {
+       private static long safeBinaryMMSparseDenseSkip(MatrixBlock m1, 
MatrixBlock m2, MatrixBlock ret,
+               BinaryOperator op, int rl, int ru)
+       {
                SparseBlock a = m1.sparse ? m1.sparseBlock : m2.sparseBlock;
                if( a == null )
-                       return;
+                       return 0;
                
                //prepare second input and allocate output
                MatrixBlock b = m1.sparse ? m2 : m1;
-               ret.allocateBlock();
                
-               for( int i=0; i<a.numRows(); i++ ) {
+               long lnnz = 0;
+               for( int i=rl; i<Math.min(ru, a.numRows()); i++ ) {
                        if( a.isEmpty(i) ) continue;
                        int apos = a.pos(i);
                        int alen = a.size(i);
@@ -932,22 +956,28 @@ public class LibMatrixBincell
                                double in2 = b.quickGetValue(i, aix[k]);
                                if( in2==0 ) continue;
                                double val = op.fn.execute(avals[k], in2);
-                               ret.appendValue(i, aix[k], val);
+                               lnnz += (val != 0) ? 1 : 0;
+                               ret.appendValuePlain(i, aix[k], val);
                        }
                }
+               return lnnz;
        }
        
-       private static void safeBinaryMMGeneric(MatrixBlock m1, MatrixBlock m2, 
MatrixBlock ret, BinaryOperator op) {
-               int rlen = m1.rlen;
+       private static long safeBinaryMMGeneric(MatrixBlock m1, MatrixBlock m2,
+               MatrixBlock ret, BinaryOperator op, int rl, int ru)
+       {
                int clen = m2.clen;
-               for(int r=0; r<rlen; r++)
+               long lnnz = 0;
+               for(int r=rl; r<ru; r++)
                        for(int c=0; c<clen; c++) {
                                double in1 = m1.quickGetValue(r, c);
                                double in2 = m2.quickGetValue(r, c);
                                if( in1==0 && in2==0) continue;
                                double val = op.fn.execute(in1, in2);
-                               ret.appendValue(r, c, val);
+                               lnnz += (val != 0) ? 1 : 0;
+                               ret.appendValuePlain(r, c, val);
                        }
+               return lnnz;
        }
        
        /**
@@ -960,7 +990,7 @@ public class LibMatrixBincell
         * @param bOp binary operator
         * 
         */
-       private static void performBinOuterOperation(MatrixBlock m1, 
MatrixBlock m2, MatrixBlock ret, BinaryOperator bOp) {
+       private static long performBinOuterOperation(MatrixBlock m1, 
MatrixBlock m2, MatrixBlock ret, BinaryOperator bOp) {
                int rlen = m1.rlen;
                int clen = ret.clen;
                double b[] = DataConverter.convertToDoubleVector(m2);
@@ -1006,9 +1036,10 @@ public class LibMatrixBincell
                }
                ret.setNonZeros(lnnz);
                ret.examSparsity();
+               return lnnz;
        }
 
-       private static void unsafeBinary(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op, int rl, int ru) {
+       private static long unsafeBinary(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op, int rl, int ru) {
                int clen = m1.clen;
                BinaryAccessType atype = getBinaryAccessType(m1, m2);
                
@@ -1038,7 +1069,7 @@ public class LibMatrixBincell
                        int clen2 = m2.clen; 
                        if(LibMatrixOuterAgg.isCompareOperator(op) 
                                && m2.getNumColumns()>16 && SortUtils.isSorted(m2)) {
-                               performBinOuterOperation(m1, m2, ret, op);
+                               lnnz = performBinOuterOperation(m1, m2, ret, op);
                        } 
                        else {
                                for(int r=rl; r<ru; r++) {
@@ -1046,7 +1077,8 @@ public class LibMatrixBincell
                                        for(int c=0; c<clen2; c++) {
                                                double v2 = m2.quickGetValue(0, c);
                                                double v = op.fn.execute( v1, v2 );
-                                               ret.appendValue(r, c, v);
+                                               lnnz += (v != 0) ? 1 : 0;
+                                               ret.appendValuePlain(r, c, v);
                                        }
                                }
                        }
@@ -1079,9 +1111,7 @@ public class LibMatrixBincell
                        }
                }
                
-               //avoid false sharing in multi-threaded ops, while
-               //correctly setting the nnz for single-threaded ops
-               ret.nonZeros = lnnz;
+               return lnnz;
        }
 
        private static void safeBinaryScalar(MatrixBlock m1, MatrixBlock ret, ScalarOperator op, int rl, int ru) {
@@ -1504,25 +1534,28 @@ public class LibMatrixBincell
                private final MatrixBlock _m2;
                private final MatrixBlock _ret;
                private final BinaryOperator _bop;
+               BinaryAccessType _atype;
                private final int _rl;
                private final int _ru;
 
-               protected BincellTask( MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator bop, int rl, int ru ) {
+               protected BincellTask( MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator bop, BinaryAccessType atype, int rl, int ru ) {
                        _m1 = m1;
                        _m2 = m2;
                        _ret = ret;
                        _bop = bop;
+                       _atype = atype;
                        _rl = rl;
                        _ru = ru;
                }
                
                @Override
                public Long call() {
-                       //execute binary operation on row partition
-                       unsafeBinary(_m1, _m2, _ret, _bop, _rl, _ru);
-                       
-                       //maintain block nnz (upper bounds inclusive)
-                       return _ret.recomputeNonZeros(_rl, _ru-1);
+                       // execute binary operation on row partition
+                       // (including nnz maintenance)
+                       if(_bop.sparseSafe || isSparseSafeDivide(_bop, _m2))
+                               return safeBinary(_m1, _m2, _ret, _bop, _atype, _rl, _ru);
+                       else
+                               return unsafeBinary(_m1, _m2, _ret, _bop, _rl, _ru);
                }
        }
        
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixTercell.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixTercell.java
new file mode 100644
index 0000000..cb48a1f
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixTercell.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.matrix.data;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
+import org.apache.sysds.runtime.util.CommonThreadPool;
+import org.apache.sysds.runtime.util.UtilFunctions;
+
+/**
+ * Library for ternary cellwise operations.
+ * 
+ */
+public class LibMatrixTercell 
+{
+       private static final long PAR_NUMCELL_THRESHOLD = 8*1024;
+       
+       private LibMatrixTercell() {
+               //prevent instantiation via private constructor
+       }
+       
+       public static void tercellOp(MatrixBlock m1, MatrixBlock m2, MatrixBlock m3, MatrixBlock ret, TernaryOperator op)
+       {
+               final boolean s1 = (m1.rlen==1 && m1.clen==1);
+               final boolean s2 = (m2.rlen==1 && m2.clen==1);
+               final boolean s3 = (m3.rlen==1 && m3.clen==1);
+               final double d1 = s1 ? m1.quickGetValue(0, 0) : Double.NaN;
+               final double d2 = s2 ? m2.quickGetValue(0, 0) : Double.NaN;
+               final double d3 = s3 ? m3.quickGetValue(0, 0) : Double.NaN;
+               
+               //allocate dense/sparse output
+               ret.allocateBlock();
+               
+               //execute ternary cell operations
+               if( op.getNumThreads() > 1 && ret.getLength() > PAR_NUMCELL_THRESHOLD) {
+                       try {
+                               //execute ternary cell operations in parallel over row blocks
+                               ExecutorService pool = CommonThreadPool.get(op.getNumThreads());
+                               ArrayList<TercellTask> tasks = new ArrayList<>();
+                               ArrayList<Integer> blklens = UtilFunctions
+                                       .getBalancedBlockSizesDefault(ret.rlen, op.getNumThreads(), false);
+                               for( int i=0, lb=0; i<blklens.size(); lb+=blklens.get(i), i++ )
+                                       tasks.add(new TercellTask(m1, m2, m3, ret, op, s1, s2, s3, d1, d2, d3, lb, lb+blklens.get(i)));
+                               List<Future<Long>> taskret = pool.invokeAll(tasks);
+                               
+                               //aggregate non-zeros
+                               ret.nonZeros = 0; //reset after execute
+                               for( Future<Long> task : taskret )
+                                       ret.nonZeros += task.get();
+                               pool.shutdown();
+                       }
+                       catch(InterruptedException | ExecutionException ex) {
+                               throw new DMLRuntimeException(ex);
+                       }
+               }
+               else {
+                       unsafeTernary(m1, m2, m3, ret, op, s1, s2, s3, d1, d2, d3, 0, ret.rlen);
+               }
+       }
+       
+       private static void unsafeTernary(MatrixBlock m1, MatrixBlock m2, MatrixBlock m3, MatrixBlock ret,
+               TernaryOperator op, boolean s1, boolean s2, boolean s3, double d1, double d2, double d3, int rl, int ru)
+       {
+               //basic ternary operations (all combinations sparse/dense)
+               int n = ret.clen;
+               long lnnz = 0;
+               for( int i=rl; i<ru; i++ )
+                       for( int j=0; j<n; j++ ) {
+                               double in1 = s1 ? d1 : m1.quickGetValue(i, j);
+                               double in2 = s2 ? d2 : m2.quickGetValue(i, j);
+                               double in3 = s3 ? d3 : m3.quickGetValue(i, j);
+                               double val = op.fn.execute(in1, in2, in3);
+                               lnnz += (val != 0) ? 1 : 0;
+                               ret.appendValuePlain(i, j, val);
+                       }
+               
+               //set output nnz for [rl,ru) (the multi-threaded caller resets and re-aggregates it)
+               ret.nonZeros = lnnz;
+       }
+       
+       private static class TercellTask implements Callable<Long> {
+               private final MatrixBlock _m1, _m2, _m3;
+               private final boolean _s1, _s2, _s3;
+               private final double _d1, _d2, _d3;
+               private final MatrixBlock _ret;
+               private final TernaryOperator _op;
+               private final int _rl, _ru;
+
+               protected TercellTask(MatrixBlock m1, MatrixBlock m2, MatrixBlock m3, MatrixBlock ret, TernaryOperator op,
+                       boolean s1, boolean s2, boolean s3, double d1, double d2, double d3, int rl, int ru) {
+                       _m1 = m1; _m2 = m2; _m3 = m3;
+                       _s1 = s1; _s2 = s2; _s3 = s3;
+                       _d1 = d1; _d2 = d2; _d3 = d3;
+                       _ret = ret;
+                       _op = op;
+                       _rl = rl; _ru = ru;
+               }
+               
+               @Override
+               public Long call() {
+                       //execute ternary operation on row partition
+                       unsafeTernary(_m1, _m2, _m3, _ret, _op, _s1, _s2, _s3, _d1, _d2, _d3, _rl, _ru);
+                       
+                       //maintain block nnz (upper bounds inclusive)
+                       return _ret.recomputeNonZeros(_rl, _ru-1);
+               }
+       }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 5d4b869..1695465 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -2911,7 +2911,10 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
                }
                
                //prepare result
-               ret.reset(m, n, false);
+               boolean sparseOutput = (op.fn instanceof PlusMultiply || op.fn instanceof MinusMultiply)?
+                       evalSparseFormatInMemory(m, n, (s1?m*n*(d1!=0?1:0):getNonZeros())
+                               + Math.min(s2?m*n:m2.getNonZeros(), s3?m*n:m3.getNonZeros())) : false;
+               ret.reset(m, n, sparseOutput);
                
                if( op.fn instanceof IfElse && (s1 || nnz==0 || nnz==(long)m*n) ) {
                        //SPECIAL CASE for shallow-copy if-else
@@ -2933,21 +2936,15 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
                }
                else if (s2 != s3 && (op.fn instanceof PlusMultiply || op.fn instanceof MinusMultiply) ) {
                        //SPECIAL CASE for sparse-dense combinations of common +* and -*
-                       BinaryOperator bop = ((ValueFunctionWithConstant)op.fn)
-                               .setOp2Constant(s2 ? d2 : d3);
-                       LibMatrixBincell.bincellOp(this, s2 ? m3 : m2, ret, bop);
+                       BinaryOperator bop = ((ValueFunctionWithConstant)op.fn).setOp2Constant(s2 ? d2 : d3);
+                       if( op.getNumThreads() > 1 )
+                               LibMatrixBincell.bincellOp(this, s2 ? m3 : m2, ret, bop, op.getNumThreads());
+                       else
+                               LibMatrixBincell.bincellOp(this, s2 ? m3 : m2, ret, bop);
                }
                else {
-                       ret.allocateDenseBlock();
-                       
-                       //basic ternary operations
-                       for( int i=0; i<m; i++ )
-                               for( int j=0; j<n; j++ ) {
-                                       double in1 = s1 ? d1 : quickGetValue(i, j);
-                                       double in2 = s2 ? d2 : m2.quickGetValue(i, j);
-                                       double in3 = s3 ? d3 : m3.quickGetValue(i, j);
-                                       ret.appendValue(i, j, op.fn.execute(in1, in2, in3));
-                               }
+                       //DEFAULT CASE
+                       LibMatrixTercell.tercellOp(this, m2, m3, ret, op);
                        
                        //ensure correct output representation
                        ret.examSparsity();
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/operators/TernaryOperator.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/TernaryOperator.java
index 1caacd7..6ff8e89 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/TernaryOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/TernaryOperator.java
@@ -32,11 +32,17 @@ public class TernaryOperator  extends Operator implements Serializable
        private static final long serialVersionUID = 3456088891054083634L;
        
        public final TernaryValueFunction fn;
-       
-       public TernaryOperator(TernaryValueFunction p) {
+       private final int _k; // num threads
+
+       public TernaryOperator(TernaryValueFunction p, int numThreads) {
                //ternaryop is sparse-safe iff (op 0 0 0) == 0
                super (p instanceof PlusMultiply || p instanceof MinusMultiply || p instanceof IfElse);
                fn = p;
+               _k = numThreads;
+       }
+       
+       public int getNumThreads() {
+               return _k;
        }
        
        @Override

Reply via email to