This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 2aad571  [SYSTEMDS-3234] Multi-threaded covariance/central-moment 
operations
2aad571 is described below

commit 2aad571d062a040c71ec3492ee1fdea6d7d4206c
Author: Matthias Boehm <mboe...@gmail.com>
AuthorDate: Sat Dec 18 16:31:32 2021 +0100

    [SYSTEMDS-3234] Multi-threaded covariance/central-moment operations
    
    Inspired by performance issues in SYSTEMDS-3233, this patch
    introduces multi-threaded cov/cm operations, which were previously
    single-threaded. These operations are mostly executed in parfor
    contexts, but if large memory requirements force a lower degree of
    parallelism in parfor, we should distribute the remaining
    parallelism to intra-operation parallelism, as for many other ops.
    
    Furthermore, this patch also cleans up instruction construction and
    parsing, as well as the core cov/cm operations, so that they share a
    common code path in LibMatrixAgg.
    
    On the scenario of SYSTEMDS-3233, this patch improved end-to-end
    performance from 261s to 144s, eliminating cov/cm from the top-2 heavy
    hitters (the top heavy hitter is now right indexing, caused by column
    indexing on a sparse matrix). On 100M-row (800MB) input vectors and
    100 operations, the total runtime improved as follows (server with
    32 vcores):
    * 100x cov(100M, 100M): 105s -> 7.7s (13.6x)
    * 100x cm(100M):        109s -> 9.1s (12x)
---
 src/main/java/org/apache/sysds/hops/BinaryOp.java  |  24 ++--
 src/main/java/org/apache/sysds/hops/TernaryOp.java |  24 ++--
 .../java/org/apache/sysds/lops/CentralMoment.java  |  37 +++--
 .../java/org/apache/sysds/lops/CoVariance.java     |  43 ++----
 .../apache/sysds/runtime/functionobjects/CM.java   |   4 +
 .../apache/sysds/runtime/functionobjects/COV.java  |   1 +
 .../runtime/functionobjects/FunctionObject.java    |   4 +
 .../runtime/instructions/InstructionUtils.java     |   9 +-
 .../cp/CentralMomentCPInstruction.java             |  43 ++----
 .../instructions/cp/CovarianceCPInstruction.java   |  47 ++----
 .../sysds/runtime/matrix/data/LibMatrixAgg.java    | 159 ++++++++++++++++++++-
 .../sysds/runtime/matrix/data/MatrixBlock.java     | 119 +--------------
 .../sysds/runtime/matrix/operators/CMOperator.java |  17 ++-
 .../runtime/matrix/operators/COVOperator.java      |  12 +-
 14 files changed, 289 insertions(+), 254 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java 
b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index 370442b..de6ccdc 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -59,7 +59,7 @@ import org.apache.sysds.runtime.meta.MatrixCharacteristics;
  *             Semantic: align indices (sort), then perform operation
  */
 
-public class BinaryOp extends MultiThreadedHop{
+public class BinaryOp extends MultiThreadedHop {
        // private static final Log LOG =  
LogFactory.getLog(BinaryOp.class.getName());
 
        //we use the full remote memory budget (but reduced by sort buffer), 
@@ -179,7 +179,9 @@ public class BinaryOp extends MultiThreadedHop{
        
        @Override
        public boolean isMultiThreadedOpType() {
-               return !getDataType().isScalar();
+               return !getDataType().isScalar()
+                       || getOp() == OpOp2.COV
+                       || getOp() == OpOp2.MOMENT;
        }
        
        @Override
@@ -279,26 +281,26 @@ public class BinaryOp extends MultiThreadedHop{
                setLops(pick);
        }
        
-       private void constructLopsCentralMoment(ExecType et) 
-       {
+       private void constructLopsCentralMoment(ExecType et) {
                // The output data type is a SCALAR if central moment 
                // gets computed in CP/SPARK, and it will be MATRIX otherwise.
                DataType dt = DataType.SCALAR;
+               int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
                CentralMoment cm = new CentralMoment(
-                               getInput().get(0).constructLops(), 
-                               getInput().get(1).constructLops(),
-                               dt, getValueType(), et);
-
+                       getInput().get(0).constructLops(), 
+                       getInput().get(1).constructLops(),
+                       dt, getValueType(), k, et);
                setLineNumbers(cm);
                cm.getOutputParameters().setDimensions(0, 0, 0, -1);
                setLops(cm);
        }
 
        private void constructLopsCovariance(ExecType et) {
+               int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
                CoVariance cov = new CoVariance(
-                               getInput().get(0).constructLops(), 
-                               getInput().get(1).constructLops(), 
-                               getDataType(), getValueType(), et);
+                       getInput().get(0).constructLops(), 
+                       getInput().get(1).constructLops(), 
+                       getDataType(), getValueType(), k, et);
                cov.getOutputParameters().setDimensions(0, 0, 0, -1);
                setLineNumbers(cov);
                setLops(cov);
diff --git a/src/main/java/org/apache/sysds/hops/TernaryOp.java 
b/src/main/java/org/apache/sysds/hops/TernaryOp.java
index 07e8f45..706a319 100644
--- a/src/main/java/org/apache/sysds/hops/TernaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/TernaryOp.java
@@ -203,18 +203,17 @@ public class TernaryOp extends MultiThreadedHop
        /**
         * Method to construct LOPs when op = CENTRAILMOMENT.
         */
-       private void constructLopsCentralMoment()
-       {       
+       private void constructLopsCentralMoment() {
                if ( _op != OpOp3.MOMENT )
                        throw new HopsException("Unexpected operation: " + _op 
+ ", expecting " + OpOp3.MOMENT );
                
                ExecType et = optFindExecType();
-               
+               int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
                CentralMoment cm = new CentralMoment(
-                               getInput().get(0).constructLops(),
-                               getInput().get(1).constructLops(),
-                               getInput().get(2).constructLops(),
-                               getDataType(), getValueType(), et);
+                       getInput().get(0).constructLops(),
+                       getInput().get(1).constructLops(),
+                       getInput().get(2).constructLops(),
+                       getDataType(), getValueType(), k, et);
                cm.getOutputParameters().setDimensions(0, 0, 0, -1);
                setLineNumbers(cm);
                setLops(cm);
@@ -228,13 +227,12 @@ public class TernaryOp extends MultiThreadedHop
                        throw new HopsException("Unexpected operation: " + _op 
+ ", expecting " + OpOp3.COV );
                
                ExecType et = optFindExecType();
-               
-               
+               int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
                CoVariance cov = new CoVariance(
-                               getInput().get(0).constructLops(), 
-                               getInput().get(1).constructLops(), 
-                               getInput().get(2).constructLops(), 
-                               getDataType(), getValueType(), et);
+                       getInput().get(0).constructLops(),
+                       getInput().get(1).constructLops(),
+                       getInput().get(2).constructLops(),
+                       getDataType(), getValueType(), k, et);
                cov.getOutputParameters().setDimensions(0, 0, 0, -1);
                setLineNumbers(cov);
                setLops(cov);
diff --git a/src/main/java/org/apache/sysds/lops/CentralMoment.java 
b/src/main/java/org/apache/sysds/lops/CentralMoment.java
index b8c89e5..78ed543 100644
--- a/src/main/java/org/apache/sysds/lops/CentralMoment.java
+++ b/src/main/java/org/apache/sysds/lops/CentralMoment.java
@@ -30,6 +30,18 @@ import org.apache.sysds.common.Types.ValueType;
  */
 public class CentralMoment extends Lop 
 {
+       private final int _numThreads;
+       
+       public CentralMoment(Lop input1, Lop input2, DataType dt, ValueType vt, 
int numThreads, ExecType et) {
+               this(input1, input2, null, dt, vt, numThreads, et);
+       }
+
+       public CentralMoment(Lop input1, Lop input2, Lop input3, DataType dt, 
ValueType vt, int numThreads, ExecType et) {
+               super(Lop.Type.CentralMoment, dt, vt);
+               init(input1, input2, input3, et);
+               _numThreads = numThreads;
+       }
+       
        /**
         * Constructor to perform central moment.
         * input1 <- data (weighted or unweighted)
@@ -54,15 +66,6 @@ public class CentralMoment extends Lop
                lps.setProperties(inputs, et);
        }
 
-       public CentralMoment(Lop input1, Lop input2, DataType dt, ValueType vt, 
ExecType et) {
-               this(input1, input2, null, dt, vt, et);
-       }
-
-       public CentralMoment(Lop input1, Lop input2, Lop input3, DataType dt, 
ValueType vt, ExecType et) {
-               super(Lop.Type.CentralMoment, dt, vt);
-               init(input1, input2, input3, et);
-       }
-
        @Override
        public String toString() {
                return "Operation = CentralMoment";
@@ -77,21 +80,27 @@ public class CentralMoment extends Lop
         */
        @Override
        public String getInstructions(String input1, String input2, String 
input3, String output) {
+               StringBuilder sb = new StringBuilder();
                if( input3 == null ) {
-                       return InstructionUtils.concatOperands(
+                       sb.append(InstructionUtils.concatOperands(
                                getExecType().toString(), "cm",
                                getInputs().get(0).prepInputOperand(input1),
                                
getInputs().get((input3!=null)?2:1).prepScalarInputOperand(getExecType()),
-                               prepOutputOperand(output));
+                               prepOutputOperand(output)));
                }
                else {
-                       return InstructionUtils.concatOperands(
+                       sb.append(InstructionUtils.concatOperands(
                                getExecType().toString(), "cm",
                                getInputs().get(0).prepInputOperand(input1),
                                getInputs().get(1).prepInputOperand(input2),
                                
getInputs().get((input3!=null)?2:1).prepScalarInputOperand(getExecType()),
-                               prepOutputOperand(output));
+                               prepOutputOperand(output)));
+               }
+               if( getExecType() == ExecType.CP ) {
+                       sb.append(OPERAND_DELIMITOR);
+                       sb.append(String.valueOf(_numThreads));
                }
+               return sb.toString();
        }
        
        /**
@@ -104,4 +113,4 @@ public class CentralMoment extends Lop
        public String getInstructions(String input1, String input2, String 
output) {
                return getInstructions(input1, input2, null, output);
        }
-}
\ No newline at end of file
+}
diff --git a/src/main/java/org/apache/sysds/lops/CoVariance.java 
b/src/main/java/org/apache/sysds/lops/CoVariance.java
index 2129b0f..dc5427d 100644
--- a/src/main/java/org/apache/sysds/lops/CoVariance.java
+++ b/src/main/java/org/apache/sysds/lops/CoVariance.java
@@ -29,44 +29,31 @@ import org.apache.sysds.common.Types.ValueType;
  */
 public class CoVariance extends Lop 
 {
-
-       public CoVariance(Lop input1, DataType dt, ValueType vt, ExecType et) {
-               super(Lop.Type.CoVariance, dt, vt);
-               init(input1, null, null, et);
-       }
+       private final int _numThreads;
        
-       public CoVariance(Lop input1, Lop input2, DataType dt, ValueType vt, 
ExecType et) {
-               this(input1, input2, null, dt, vt, et);
+       public CoVariance(Lop input1, Lop input2, DataType dt, ValueType vt, 
int numThreads, ExecType et) {
+               this(input1, input2, null, dt, vt, numThreads, et);
        }
        
-       public CoVariance(Lop input1, Lop input2, Lop input3, DataType dt, 
ValueType vt, ExecType et) {
+       public CoVariance(Lop input1, Lop input2, Lop input3, DataType dt, 
ValueType vt, int numThreads, ExecType et) {
                super(Lop.Type.CoVariance, dt, vt);
                init(input1, input2, input3, et);
+               _numThreads = numThreads;
        }
 
        private void init(Lop input1, Lop input2, Lop input3, ExecType et) {
-               /*
-                * When et = MR: covariance lop will have a single input lop, 
which
-                * denote the combined input data -- output of combinebinary, 
if unweighed;
-                * and output combineteriaty (if weighted).
-                * 
-                * When et = CP: covariance lop must have at least two input 
lops, which
-                * denote the two input columns on which covariance is 
computed. It also
-                * takes an optional third arguments, when weighted covariance 
is computed.
-                */
+               if ( input2 == null )
+                       throw new LopsException(this.printErrorLocation() + 
"Invalid inputs to covariance lop.");
+       
                addInput(input1);
                input1.addOutput(this);
-
-               if ( input2 == null ) {
-                       throw new LopsException(this.printErrorLocation() + 
"Invalid inputs to covariance lop.");
-               }
                addInput(input2);
                input2.addOutput(this);
-               
                if ( input3 != null ) {
                        addInput(input3);
                        input3.addOutput(this);
                }
+               
                lps.setProperties(inputs, et);
        }
 
@@ -102,19 +89,20 @@ public class CoVariance extends Lop
 
                sb.append( getInputs().get(0).prepInputOperand(input1));
                sb.append( OPERAND_DELIMITOR );
-
-               if( input2 != null ) {
-                       sb.append( getInputs().get(1).prepInputOperand(input2));
-                       sb.append( OPERAND_DELIMITOR );
-               }
-               
+               sb.append( getInputs().get(1).prepInputOperand(input2));
+               sb.append( OPERAND_DELIMITOR );
                if( input3 != null ) {
                        sb.append( getInputs().get(2).prepInputOperand(input3));
                        sb.append( OPERAND_DELIMITOR );
                }
                
                sb.append( prepOutputOperand(output));
+               //append #threads only for CP (consistent with the CentralMoment lop)
+               if( getExecType() == ExecType.CP ) {
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append(_numThreads);
+               }
                
                return sb.toString();
        }
-}
\ No newline at end of file
+}
diff --git a/src/main/java/org/apache/sysds/runtime/functionobjects/CM.java 
b/src/main/java/org/apache/sysds/runtime/functionobjects/CM.java
index 4c47d59..ae3e718 100644
--- a/src/main/java/org/apache/sysds/runtime/functionobjects/CM.java
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/CM.java
@@ -74,6 +74,10 @@ public class CM extends ValueFunction
                //execution due to state in cm object (buff2, buff3)    
                return new CM( type ); 
        }
+       
+       public static CM getCMFnObject(CM fn) {
+               return getCMFnObject(fn._type);
+       }
 
        public AggregateOperationTypes getAggOpType() {
                return _type;
diff --git a/src/main/java/org/apache/sysds/runtime/functionobjects/COV.java 
b/src/main/java/org/apache/sysds/runtime/functionobjects/COV.java
index 6eb85e4..89d8b23 100644
--- a/src/main/java/org/apache/sysds/runtime/functionobjects/COV.java
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/COV.java
@@ -59,6 +59,7 @@ public class COV extends ValueFunction
         * @param w2 ?
         * @return result
         */
+       @Override
        public Data execute(Data in1, double u, double v, double w2) 
        {
                CM_COV_Object cov1=(CM_COV_Object) in1;
diff --git 
a/src/main/java/org/apache/sysds/runtime/functionobjects/FunctionObject.java 
b/src/main/java/org/apache/sysds/runtime/functionobjects/FunctionObject.java
index d3f3e7d..26e99ff 100644
--- a/src/main/java/org/apache/sysds/runtime/functionobjects/FunctionObject.java
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/FunctionObject.java
@@ -81,6 +81,10 @@ public abstract class FunctionObject implements Serializable
                throw new DMLRuntimeException("execute(): should not be invoked 
from base class.");
        }
        
+       public Data execute(Data in1, double in2, double in3, double in4) {
+               throw new DMLRuntimeException("execute(): should not be invoked 
from base class.");
+       }
+       
        public Data execute(Data in1, Data in2) {
                throw new DMLRuntimeException("execute(): should not be invoked 
from base class.");
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java 
b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index a13768d..f9d9ab3 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -1063,8 +1063,7 @@ public class InstructionUtils
         * @return the instruction string with the given inputs concatenated
         */
        public static String concatOperands(String... inputs) {
-               concatBaseOperandsWithDelim(Lop.OPERAND_DELIMITOR, inputs);
-               return _strBuilders.get().toString();
+               return concatBaseOperandsWithDelim(Lop.OPERAND_DELIMITOR, 
inputs);
        }
 
        /**
@@ -1073,11 +1072,10 @@ public class InstructionUtils
         * @return concatenated input parts
         */
        public static String concatOperandParts(String... inputs) {
-               concatBaseOperandsWithDelim(Instruction.VALUETYPE_PREFIX, 
inputs);
-               return _strBuilders.get().toString();
+               return 
concatBaseOperandsWithDelim(Instruction.VALUETYPE_PREFIX, inputs);
        }
 
-       private static void concatBaseOperandsWithDelim(String delim, String... 
inputs){
+       private static String concatBaseOperandsWithDelim(String delim, 
String... inputs){
                StringBuilder sb = _strBuilders.get();
                sb.setLength(0); //reuse allocated space
                for( int i=0; i<inputs.length-1; i++ ) {
@@ -1085,6 +1083,7 @@ public class InstructionUtils
                        sb.append(delim);
                }
                sb.append(inputs[inputs.length-1]);
+               return sb.toString();
        }
        
        public static String concatStrings(String... inputs) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CentralMomentCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CentralMomentCPInstruction.java
index d787dde..2aeba90 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CentralMomentCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CentralMomentCPInstruction.java
@@ -19,8 +19,6 @@
 
 package org.apache.sysds.runtime.instructions.cp;
 
-import org.apache.sysds.common.Types.DataType;
-import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.functionobjects.CM;
@@ -37,11 +35,6 @@ public class CentralMomentCPInstruction extends 
AggregateUnaryCPInstruction {
        }
 
        public static CentralMomentCPInstruction parseInstruction(String str) {
-               CPOperand in1 = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
-               CPOperand in2 = null; 
-               CPOperand in3 = null; 
-               CPOperand out = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
-               
                String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
                String opcode = parts[0]; 
                
@@ -49,40 +42,30 @@ public class CentralMomentCPInstruction extends 
AggregateUnaryCPInstruction {
                if( !opcode.equalsIgnoreCase("cm") ) {
                        throw new DMLRuntimeException("Unsupported opcode 
"+opcode);
                }
-                       
-               if ( parts.length == 4 ) {
-                       // Example: CP.cm.mVar0.Var1.mVar2; (without weights)
-                       in2 = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
-                       parseUnaryInstruction(str, in1, in2, out);
-               }
-               else if ( parts.length == 5) {
-                       // CP.cm.mVar0.mVar1.Var2.mVar3; (with weights)
-                       in2 = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
-                       in3 = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
-                       parseUnaryInstruction(str, in1, in2, in3, out);
-               }
-       
+               
+               InstructionUtils.checkNumFields(str, 4, 5); //w/o opcode
+               CPOperand in1 = new CPOperand(parts[1]); //data
+               CPOperand in2 = new CPOperand(parts[2]); //scalar
+               CPOperand in3 = (parts.length==5) ? null : new 
CPOperand(parts[3]); //weights
+               CPOperand out = new CPOperand(parts[parts.length-2]);
+               int numThreads = Integer.parseInt(parts[parts.length-1]);
+
                /* 
                 * Exact order of the central moment MAY NOT be known at 
compilation time.
                 * We first try to parse the second argument as an integer, and 
if we fail, 
                 * we simply pass -1 so that getCMAggOpType() picks up 
AggregateOperationTypes.INVALID.
                 * It must be updated at run time in processInstruction() 
method.
                 */
-               
                int cmOrder;
                try {
-                       if ( in3 == null ) {
-                               cmOrder = Integer.parseInt(in2.getName());
-                       }
-                       else {
-                               cmOrder = Integer.parseInt(in3.getName());
-                       }
-               } catch(NumberFormatException e) {
+                       cmOrder = Integer.parseInt((in3==null) ? in2.getName() 
: in3.getName());
+               }
+               catch(NumberFormatException e) {
                        cmOrder = -1; // unknown at compilation time
                }
-               
+
                AggregateOperationTypes opType = 
CMOperator.getCMAggOpType(cmOrder);
-               CMOperator cm = new CMOperator(CM.getCMFnObject(opType), 
opType);
+               CMOperator cm = new CMOperator(CM.getCMFnObject(opType), 
opType, numThreads);
                return new CentralMomentCPInstruction(cm, in1, in2, in3, out, 
opcode, str);
        }
        
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CovarianceCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CovarianceCPInstruction.java
index a6758b3..ae495f1 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CovarianceCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CovarianceCPInstruction.java
@@ -19,8 +19,6 @@
 
 package org.apache.sysds.runtime.instructions.cp;
 
-import org.apache.sysds.common.Types.DataType;
-import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.functionobjects.COV;
@@ -30,45 +28,30 @@ import 
org.apache.sysds.runtime.matrix.operators.COVOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 
 public class CovarianceCPInstruction extends BinaryCPInstruction {
-
-       private CovarianceCPInstruction(Operator op, CPOperand in1, CPOperand 
in2, CPOperand out, String opcode,
-                       String istr) {
-               super(CPType.AggregateBinary, op, in1, in2, out, opcode, istr);
-       }
-
-       private CovarianceCPInstruction(Operator op, CPOperand in1, CPOperand 
in2, CPOperand in3, CPOperand out,
-                       String opcode, String istr) {
+       
+       private CovarianceCPInstruction(Operator op, CPOperand in1,
+               CPOperand in2, CPOperand in3, CPOperand out, String opcode, 
String istr)
+       {
                super(CPType.AggregateBinary, op, in1, in2, in3, out, opcode, 
istr);
        }
 
        public static CovarianceCPInstruction parseInstruction( String str )
        {
-               CPOperand in1 = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
-               CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
-               CPOperand in3 = null;
-               CPOperand out = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
-
                String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
                String opcode = parts[0];
 
-               if( !opcode.equalsIgnoreCase("cov") ) {
+               if( !opcode.equalsIgnoreCase("cov") )
                        throw new 
DMLRuntimeException("CovarianceCPInstruction.parseInstruction():: Unknown 
opcode " + opcode);
-               }
                
-               COVOperator cov = new COVOperator(COV.getCOMFnObject());
-               if ( parts.length == 4 ) {
-                       // CP.cov.mVar0.mVar1.mVar2
-                       parseBinaryInstruction(str, in1, in2, out);
-                       return new CovarianceCPInstruction(cov, in1, in2, out, 
opcode, str);
-               } else if ( parts.length == 5 ) {
-                       // CP.cov.mVar0.mVar1.mVar2.mVar3
-                       in3 = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
-                       parseBinaryInstruction(str, in1, in2, in3, out);
-                       return new CovarianceCPInstruction(cov, in1, in2, in3, 
out, opcode, str);
-               }
-               else {
-                       throw new DMLRuntimeException("Invalid number of 
arguments in Instruction: " + str);
-               }
+               InstructionUtils.checkNumFields(parts, 4, 5); //w/o opcode
+               CPOperand in1 = new CPOperand(parts[1]);
+               CPOperand in2 = new CPOperand(parts[2]);
+               CPOperand in3 = (parts.length==5) ? null : new 
CPOperand(parts[3]);
+               CPOperand out = new CPOperand(parts[parts.length-2]);
+               int numThreads = Integer.parseInt(parts[parts.length-1]);
+               
+               COVOperator cov = new COVOperator(COV.getCOMFnObject(), 
numThreads);
+               return new CovarianceCPInstruction(cov, in1, in2, in3, out, 
opcode, str);
        }
        
        @Override
@@ -78,7 +61,7 @@ public class CovarianceCPInstruction extends 
BinaryCPInstruction {
                MatrixBlock matBlock2 = ec.getMatrixInput(input2.getName());
                String output_name = output.getName(); 
                COVOperator cov_op = (COVOperator)_optr;
-               CM_COV_Object covobj = new CM_COV_Object();
+               CM_COV_Object covobj = null;
                
                if ( input3 == null ) {
                        // Unweighted: cov.mvar0.mvar1.out
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
index 0d3c007..864c7f2 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
@@ -81,7 +81,8 @@ import org.apache.sysds.runtime.util.UtilFunctions;
  * ak+, uak+, uark+, uack+, uasqk+, uarsqk+, uacsqk+,
  * uamin, uarmin, uacmin, uamax, uarmax, uacmax,
  * ua*, uamean, uarmean, uacmean, uavar, uarvar, uacvar,
- * uarimax, uaktrace, cumk+, cummin, cummax, cum*, tak+.
+ * uarimax, uaktrace, cumk+, cummin, cummax, cum*, tak+,
+ * cm, cov
  * 
  * TODO next opcode extensions: a+, colindexmax
  */
@@ -418,6 +419,44 @@ public class LibMatrixAgg
                return out;
        }
 
+       public static CM_COV_Object aggregateCmCov(MatrixBlock in1, MatrixBlock 
in2, MatrixBlock in3, ValueFunction fn) {
+               CM_COV_Object cmobj = new CM_COV_Object();
+               
+               // empty block handling (important for result correctness, 
otherwise
+               // we get a NaN due to 0/0 on reading out the required result)
+               if( in1.isEmptyBlock(false) && fn instanceof CM ) {
+                       fn.execute(cmobj, 0.0, in1.getNumRows());
+                       return cmobj;
+               }
+               
+               return aggregateCmCov(in1, in2, in3, fn, 0, in1.getNumRows());
+       }
+       
+       public static CM_COV_Object aggregateCmCov(MatrixBlock in1, MatrixBlock 
in2, MatrixBlock in3, ValueFunction fn, int k) {
+               if( in1.isEmptyBlock(false) || 
!satisfiesMultiThreadingConstraints(in1, k) )
+                       return aggregateCmCov(in1, in2, in3, fn);
+               
+               ExecutorService pool = CommonThreadPool.get(k);
+               try {
+                       ArrayList<AggCmCovTask> tasks = new ArrayList<>();
+                       ArrayList<Integer> blklens = 
UtilFunctions.getBalancedBlockSizesDefault(in1.rlen, k, false);
+                       for( int i=0, lb=0; i<blklens.size(); 
lb+=blklens.get(i), i++ )
+                               tasks.add(new AggCmCovTask(in1, in2, in3, fn, 
lb, lb+blklens.get(i)));
+                       List<Future<CM_COV_Object>> rtasks = 
pool.invokeAll(tasks);
+                       //aggregate all partial results (incl. the first one); get() rethrows task errors
+                       CM_COV_Object ret = rtasks.get(0).get();
+                       for( int i=1; i<rtasks.size(); i++ )
+                               fn.execute(ret, rtasks.get(i).get());
+                       return ret;
+               }
+               catch(Exception ex) {
+                       throw new DMLRuntimeException(ex);
+               }
+               finally {
+                       pool.shutdown();
+               }
+       }
+       
        public static MatrixBlock aggregateTernary(MatrixBlock in1, MatrixBlock 
in2, MatrixBlock in3, MatrixBlock ret, AggregateTernaryOperator op) {
                //early abort if any block is empty
                if( in1.isEmptyBlock(false) || in2.isEmptyBlock(false) || 
in3!=null&&in3.isEmptyBlock(false) ) {
@@ -571,6 +610,12 @@ public class LibMatrixAgg
                        && in.nonZeros > (sharedTP ? PAR_NUMCELL_THRESHOLD2 : 
PAR_NUMCELL_THRESHOLD1);
        }
        
+       public static boolean satisfiesMultiThreadingConstraints(MatrixBlock 
in,int k) {
+               boolean sharedTP = 
(InfrastructureAnalyzer.getLocalParallelism() == k);
+               return k > 1 && in.rlen > (sharedTP ? k/8 : k/2)
+                       && in.nonZeros > (sharedTP ? PAR_NUMCELL_THRESHOLD2 : 
PAR_NUMCELL_THRESHOLD1);
+       }
+       
        /**
         * Recompute outputs (e.g., maxindex or minindex) according to block 
indexes from MR.
         * TODO: this should not be part of block operations but of the MR 
instruction.
@@ -681,6 +726,94 @@ public class LibMatrixAgg
                else
                        out.binaryOperationsInPlace(laop.increOp, partout);
        }
+       
+       /**
+        * Computes the partial cm/cov aggregate over the rows [rl, ru) of the given
+        * column-vector inputs. The operation is selected by the inputs: central
+        * moment (in2 and in3 null), central moment with weights in2 or covariance
+        * of in1 and in2 (only in3 null), and covariance with weights in3.
+        *
+        * @param in1 first input column vector
+        * @param in2 weights (cm) or second input (cov), may be null
+        * @param in3 weights (cov), may be null
+        * @param fn  cm or cov function object used to accumulate the values
+        * @param rl  row lower bound (inclusive)
+        * @param ru  row upper bound (exclusive)
+        * @return new CM_COV_Object holding the aggregate of the row range
+        */
+       private static CM_COV_Object aggregateCmCov(MatrixBlock in1, MatrixBlock in2, MatrixBlock in3, ValueFunction fn, int rl, int ru) {
+               CM_COV_Object ret = new CM_COV_Object();
+               
+               if( in2 == null && in3 == null ) { //CM
+                       if(in1.sparse && in1.sparseBlock!=null) { //SPARSE
+                               SparseBlock a = in1.sparseBlock;
+                               //the sparse block may have fewer rows than the range
+                               int ru2 = Math.min(ru, a.numRows());
+                               int nzcount = 0;
+                               for(int r = rl; r < ru2; r++) {
+                                       if(a.isEmpty(r)) 
+                                               continue;
+                                       int apos = a.pos(r);
+                                       int alen = a.size(r);
+                                       double[] avals = a.values(r);
+                                       for(int i=apos; i<apos+alen; i++)
+                                               fn.execute(ret, avals[i]);
+                                       nzcount += alen;
+                               }
+                               // account for all zeros in [rl, ru): use ru (not ru2) so that
+                               // rows beyond the allocated sparse rows are counted as zeros too
+                               int nzeros = ru - rl - nzcount;
+                               if( nzeros > 0 )
+                                       fn.execute(ret, 0.0, nzeros);
+                       }
+                       else if(in1.denseBlock!=null) { //DENSE
+                               //always a column vector (checked by the MatrixBlock callers)
+                               double[] a = in1.getDenseBlockValues();
+                               for(int i=rl; i<ru; i++)
+                                       fn.execute(ret, a[i]);
+                       }
+                       else if( ru > rl ) { //EMPTY (no allocated block)
+                               // all rows are zero; account for them explicitly, otherwise
+                               // we get a NaN due to 0/0 on reading out the required result
+                               fn.execute(ret, 0.0, ru - rl);
+                       }
+               }
+               else if( in3 == null ) { //CM w/ weights, COV
+                       if (in1.sparse && in1.sparseBlock!=null) { //SPARSE
+                               for(int i = rl; i < ru; i++) {
+                                       fn.execute(ret,
+                                               in1.quickGetValue(i,0),
+                                               in2.quickGetValue(i,0));
+                               }
+                       }
+                       else if(in1.denseBlock!=null) { //DENSE
+                               //always column vectors (checked by the MatrixBlock callers)
+                               double[] a = in1.getDenseBlockValues();
+                               if( !in2.sparse ) {
+                                       // NOTE(review): an unallocated dense in2 skips all rows
+                                       // (inherited behavior); confirm this is intended for cov
+                                       if(in2.denseBlock!=null) {
+                                               double[] w = in2.getDenseBlockValues();
+                                               for( int i = rl; i < ru; i++ )
+                                                       fn.execute(ret, a[i], w[i]);
+                                       }
+                               }
+                               else {
+                                       for(int i = rl; i < ru; i++) 
+                                               fn.execute(ret, a[i], in2.quickGetValue(i,0) );
+                               }
+                       }
+               }
+               else { // COV w/ weights
+                       if(in1.sparse && in1.sparseBlock!=null) { //SPARSE
+                               for(int i = rl; i < ru; i++ ) {
+                                       fn.execute(ret,
+                                               in1.quickGetValue(i,0),
+                                               in2.quickGetValue(i,0),
+                                               in3.quickGetValue(i,0));
+                               }
+                       }
+                       else if(in1.denseBlock!=null) { //DENSE
+                               //always column vectors (checked by the MatrixBlock callers)
+                               double[] a = in1.getDenseBlockValues();
+                               
+                               if( !in2.sparse && !in3.sparse ) {
+                                       // guard both blocks: unallocated dense weights would
+                                       // otherwise yield a null array and an NPE below
+                                       if(in2.denseBlock!=null && in3.denseBlock!=null) {
+                                               double[] b = in2.getDenseBlockValues();
+                                               double[] w = in3.getDenseBlockValues();
+                                               for( int i=rl; i<ru; i++ )
+                                                       fn.execute(ret, a[i], b[i], w[i]);
+                                       }
+                               }
+                               else {
+                                       for(int i = rl; i < ru; i++) {
+                                               fn.execute(ret, a[i],
+                                                       in2.quickGetValue(i,0),
+                                                       in3.quickGetValue(i,0));
+                                       }
+                               }
+                       }
+               }
+
+               return ret;
+       }
 
        private static void aggregateTernaryDense(MatrixBlock in1, MatrixBlock 
in2, MatrixBlock in3, MatrixBlock ret, IndexFunction ixFn, int rl, int ru)
        {
@@ -3444,4 +3577,28 @@ public class LibMatrixAgg
                        return null;
                }
        }
+       
+       /**
+        * Task that computes the partial cm/cov aggregate of one row partition
+        * [rl, ru) of the inputs.
+        */
+       private static class AggCmCovTask implements Callable<CM_COV_Object> {
+               private final MatrixBlock _in1;
+               private final MatrixBlock _in2;
+               private final MatrixBlock _in3;
+               private final ValueFunction _fn;
+               private final int _rl;
+               private final int _ru;
+
+               protected AggCmCovTask(MatrixBlock in1, MatrixBlock in2, MatrixBlock in3, ValueFunction fn, int rl, int ru) {
+                       _fn = fn;
+                       _in1 = in1;
+                       _in2 = in2;
+                       _in3 = in3;
+                       _rl = rl;
+                       _ru = ru;
+               }
+               
+               @Override
+               public CM_COV_Object call() {
+                       //CM function objects are stateful (Kahan objects inside), hence
+                       //each task uses its own deep copy for correctness and to avoid
+                       //cache thrashing among threads; other functions are shared
+                       ValueFunction taskFn = _fn;
+                       if( _fn instanceof CM )
+                               taskFn = CM.getCMFnObject((CM)_fn);
+                       return aggregateCmCov(_in1, _in2, _in3, taskFn, _rl, _ru);
+               }
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 65af775..f208174 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -4726,46 +4726,11 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
        public CM_COV_Object cmOperations(CMOperator op) {
                // dimension check for input column vectors
                if ( this.getNumColumns() != 1) {
-                       throw new DMLRuntimeException("Central Moment can not 
be computed on [" 
+                       throw new DMLRuntimeException("Central Moment cannot be 
computed on [" 
                                        + this.getNumRows() + "," + 
this.getNumColumns() + "] matrix.");
                }
                
-               CM_COV_Object cmobj = new CM_COV_Object();
-               
-               // empty block handling (important for result corretness, 
otherwise
-               // we get a NaN due to 0/0 on reading out the required result)
-               if( isEmptyBlock(false) ) {
-                       op.fn.execute(cmobj, 0.0, getNumRows());
-                       return cmobj;
-               }
-               
-               int nzcount = 0;
-               if(sparse && sparseBlock!=null) //SPARSE
-               {
-                       for(int r=0; r<Math.min(rlen, sparseBlock.numRows()); 
r++)
-                       {
-                               if(sparseBlock.isEmpty(r)) 
-                                       continue;
-                               int apos = sparseBlock.pos(r);
-                               int alen = sparseBlock.size(r);
-                               double[] avals = sparseBlock.values(r);
-                               for(int i=apos; i<apos+alen; i++) {
-                                       op.fn.execute(cmobj, avals[i]);
-                                       nzcount++;
-                               }
-                       }
-                       // account for zeros in the vector
-                       op.fn.execute(cmobj, 0.0, this.getNumRows()-nzcount);
-               }
-               else if(denseBlock!=null)  //DENSE
-               {
-                       //always vector (see check above)
-                       double[] a = getDenseBlockValues();
-                       for(int i=0; i<rlen; i++)
-                               op.fn.execute(cmobj, a[i]);
-               }
-
-               return cmobj;
+               // delegate to the shared cm/cov kernel in LibMatrixAgg, which may run
+               // with op.getNumThreads() threads over row partitions of this vector
+               // NOTE(review): the removed explicit empty-block handling (NaN due to
+               // 0/0) must now be covered by LibMatrixAgg.aggregateCmCov — confirm
+               return LibMatrixAgg.aggregateCmCov(this, null, null, op.fn, 
op.getNumThreads());
        }
                
        public CM_COV_Object cmOperations(CMOperator op, MatrixBlock weights) {
@@ -4779,31 +4744,7 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
                                        + weights.getNumRows() + "," + 
weights.getNumColumns() +"]");
                }
                
-               CM_COV_Object cmobj = new CM_COV_Object();
-               if (sparse && sparseBlock!=null) //SPARSE
-               {
-                       for(int i=0; i < rlen; i++) 
-                               op.fn.execute(cmobj, this.quickGetValue(i,0), 
weights.quickGetValue(i,0));
-               }
-               else if(denseBlock!=null) //DENSE
-               {
-                       //always vectors (see check above)
-                       double[] a = getDenseBlockValues();
-                       if( !weights.sparse )
-                       {
-                               double[] w = weights.getDenseBlockValues();
-                               if(weights.denseBlock!=null)
-                                       for( int i=0; i<rlen; i++ )
-                                               op.fn.execute(cmobj, a[i], 
w[i]);
-                       }
-                       else
-                       {
-                               for(int i=0; i<rlen; i++) 
-                                       op.fn.execute(cmobj, a[i], 
weights.quickGetValue(i,0) );
-                       }
-               }
-               
-               return cmobj;
+               return LibMatrixAgg.aggregateCmCov(this, weights, null, op.fn, 
op.getNumThreads());
        }
        
        public CM_COV_Object covOperations(COVOperator op, MatrixBlock that) {
@@ -4817,30 +4758,7 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
                                        + that.getNumRows() + "," + 
that.getNumColumns() +"]");
                }
                
-               CM_COV_Object covobj = new CM_COV_Object();
-               if(sparse && sparseBlock!=null) //SPARSE
-               {
-                       for(int i=0; i < rlen; i++ ) 
-                               op.fn.execute(covobj, this.quickGetValue(i,0), 
that.quickGetValue(i,0));
-               }
-               else if(denseBlock!=null) //DENSE
-               {
-                       //always vectors (see check above)
-                       double[] a = getDenseBlockValues();
-                       if( !that.sparse ) {
-                               if(that.denseBlock!=null) {
-                                       double[] b = that.getDenseBlockValues();
-                                       for( int i=0; i<rlen; i++ )
-                                               op.fn.execute(covobj, a[i], 
b[i]);
-                               }
-                       }
-                       else {
-                               for(int i=0; i<rlen; i++)
-                                       op.fn.execute(covobj, a[i], 
that.quickGetValue(i,0));
-                       }
-               }
-               
-               return covobj;
+               return LibMatrixAgg.aggregateCmCov(this, that, null, op.fn, 
op.getNumThreads());
        }
        
        public CM_COV_Object covOperations(COVOperator op, MatrixBlock that, 
MatrixBlock weights) {
@@ -4859,34 +4777,7 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
                                        + weights.getNumRows() + "," + 
weights.getNumColumns() +"]");
                }
                
-               CM_COV_Object covobj = new CM_COV_Object();
-               if(sparse && sparseBlock!=null) //SPARSE
-               {
-                       for(int i=0; i < rlen; i++ ) 
-                               op.fn.execute(covobj, this.quickGetValue(i,0), 
that.quickGetValue(i,0), weights.quickGetValue(i,0));
-               }
-               else if(denseBlock!=null) //DENSE
-               {
-                       //always vectors (see check above)
-                       double[] a = getDenseBlockValues();
-                       
-                       if( !that.sparse && !weights.sparse )
-                       {
-                               double[] w = weights.getDenseBlockValues();
-                               if(that.denseBlock!=null) {
-                                       double[] b = that.getDenseBlockValues();
-                                       for( int i=0; i<rlen; i++ )
-                                               op.fn.execute(covobj, a[i], 
b[i], w[i]);
-                               }
-                       }
-                       else
-                       {
-                               for(int i=0; i<rlen; i++)
-                                       op.fn.execute(covobj, a[i], 
that.quickGetValue(i,0), weights.quickGetValue(i,0));
-                       }
-               }
-               
-               return covobj;
+               return LibMatrixAgg.aggregateCmCov(this, that, weights, op.fn, 
op.getNumThreads());
        }
 
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/operators/CMOperator.java 
b/src/main/java/org/apache/sysds/runtime/matrix/operators/CMOperator.java
index fb7088f..579e681 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/CMOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/CMOperator.java
@@ -25,7 +25,6 @@ import org.apache.sysds.runtime.functionobjects.ValueFunction;
 
 public class CMOperator extends Operator 
 {
-       
        private static final long serialVersionUID = 4126894676505115420L;
        
        // supported aggregates
@@ -40,23 +39,33 @@ public class CMOperator extends Operator
                INVALID
        }
 
-       public ValueFunction fn;
-       public AggregateOperationTypes aggOpType;
+       public final ValueFunction fn;
+       public final AggregateOperationTypes aggOpType;
+       public final int k;
 
        public CMOperator(ValueFunction op, AggregateOperationTypes agg) {
+               this(op, agg, 1);
+       }
+       
+       public CMOperator(ValueFunction op, AggregateOperationTypes agg, int 
numThreads) {
                super(true);
                fn = op;
                aggOpType = agg;
+               k = numThreads;
        }
 
        public AggregateOperationTypes getAggOpType() {
                return aggOpType;
        }
        
+       public int getNumThreads() {
+               return k;
+       }
+       
        public CMOperator setCMAggOp(int order) {
                AggregateOperationTypes agg = getCMAggOpType(order);
                ValueFunction fn = CM.getCMFnObject(aggOpType);
-               return new CMOperator(fn, agg);
+               return new CMOperator(fn, agg, k);
        }
        
        public static AggregateOperationTypes getCMAggOpType ( int order ) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/operators/COVOperator.java 
b/src/main/java/org/apache/sysds/runtime/matrix/operators/COVOperator.java
index 1ed9a3f..9d288db 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/COVOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/COVOperator.java
@@ -27,9 +27,19 @@ public class COVOperator extends Operator
        private static final long serialVersionUID = -8404264552880694469L;
 
        public final COV fn;
+       // degree of parallelism (number of threads) passed to cov operations
+       public final int k;
        
        public COVOperator(COV op) {
+               // single-threaded by default
+               this(op, 1);
+       }
+       
+       /**
+        * Creates a cov operator with the given number of threads.
+        * @param op cov function object
+        * @param numThreads degree of parallelism
+        */
+       public COVOperator(COV op, int numThreads) {
                super(true);
                fn = op;
+               k = numThreads;
+       }
+       
+       /** @return the number of threads (k) of this operator */
+       public int getNumThreads() {
+               return k;
        }
-}
\ No newline at end of file
+}

Reply via email to