This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new acd0f69 [SYSTEMDS-2856] Multi-threaded binary matrix-matrix,
matrix-scalar ops
acd0f69 is described below
commit acd0f6905c6c556725421794f4010af17f2a75c5
Author: Matthias Boehm <[email protected]>
AuthorDate: Wed Feb 10 23:01:31 2021 +0100
[SYSTEMDS-2856] Multi-threaded binary matrix-matrix, matrix-scalar ops
This patch is a first step towards extended support for multi-threaded
operations. So far, binary operations were not multi-threaded because
output allocation dominates the runtime of many operations. With parallel
allocators, future in-place updates, an increasing degree of parallelism,
and somewhat inefficient sparse-unsafe code paths, this is changing. In
this first step, we parallelize sparse-unsafe matrix-matrix operations and
sparse-safe matrix-scalar operations, which did not have a lot of
special-case handling and thus could simply be parallelized over row
partitions.
On a scenario with a 1M x 1050 input matrix (mostly dense except for one
one-hot encoded column), this patch improved the Kmeans runtime with 50
centroids, 1 run, MKL matrix multiply, and ~60 iterations from 177s to
109s (and the time of the relevant binary ops <= and -2* from 87s to 15s).
---
src/main/java/org/apache/sysds/hops/BinaryOp.java | 15 +-
src/main/java/org/apache/sysds/lops/Binary.java | 27 ++-
src/main/java/org/apache/sysds/lops/Unary.java | 4 +-
.../instructions/cp/BinaryCPInstruction.java | 2 +-
.../cp/BinaryMatrixMatrixCPInstruction.java | 7 +-
.../cp/BinaryMatrixScalarCPInstruction.java | 6 +
.../runtime/matrix/data/LibMatrixBincell.java | 218 +++++++++++++++++----
.../sysds/runtime/matrix/data/LibMatrixNative.java | 4 +-
.../sysds/runtime/matrix/data/MatrixBlock.java | 38 +++-
.../runtime/matrix/operators/BinaryOperator.java | 9 +
.../matrix/operators/LeftScalarOperator.java | 2 +-
.../matrix/operators/RightScalarOperator.java | 2 +-
.../runtime/matrix/operators/ScalarOperator.java | 16 +-
13 files changed, 282 insertions(+), 68 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java
b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index cc5d58d..10e1c8d 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -414,12 +414,13 @@ public class BinaryOp extends MultiThreadedHop
(op==OpOp2.MULT &&
HopRewriteUtils.isLiteralOfValue(right, 2d)) ? OpOp1.MULT2 : null;
Lop tmp = null;
if( ot != null ) {
- tmp = new
Unary(getInput().get(0).constructLops(),
- getInput().get(1).constructLops(), ot,
getDataType(), getValueType(), et);
+ tmp = new Unary(getInput(0).constructLops(),
getInput(1).constructLops(),
+ ot, getDataType(), getValueType(), et);
}
else { //general case
- tmp = new
Binary(getInput().get(0).constructLops(),
- getInput().get(1).constructLops(), op,
getDataType(), getValueType(), et);
+ tmp = new Binary(getInput(0).constructLops(),
getInput(1).constructLops(),
+ op, getDataType(), getValueType(), et,
+
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
}
setOutputDimensions(tmp);
setLineNumbers(tmp);
@@ -458,9 +459,9 @@ public class BinaryOp extends MultiThreadedHop
getDataType(), getValueType(),
et, OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
}
else
- binary = new
Binary(getInput().get(0).constructLops(),
-
getInput().get(1).constructLops(), op,
- getDataType(), getValueType(),
et);
+ binary = new
Binary(getInput(0).constructLops(), getInput(1).constructLops(),
+ op, getDataType(),
getValueType(), et,
+
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
setOutputDimensions(binary);
setLineNumbers(binary);
diff --git a/src/main/java/org/apache/sysds/lops/Binary.java
b/src/main/java/org/apache/sysds/lops/Binary.java
index 9ebe551..5fba53d 100644
--- a/src/main/java/org/apache/sysds/lops/Binary.java
+++ b/src/main/java/org/apache/sysds/lops/Binary.java
@@ -34,6 +34,7 @@ import org.apache.sysds.common.Types.ValueType;
public class Binary extends Lop
{
private OpOp2 operation;
+ private final int _numThreads;
/**
* Constructor to perform a binary operation.
@@ -45,9 +46,15 @@ public class Binary extends Lop
* @param vt value type
* @param et exec type
*/
+
public Binary(Lop input1, Lop input2, OpOp2 op, DataType dt, ValueType
vt, ExecType et) {
+ this(input1, input2, op, dt, vt, et, 1);
+ }
+
+ public Binary(Lop input1, Lop input2, OpOp2 op, DataType dt, ValueType
vt, ExecType et, int k) {
super(Lop.Type.Binary, dt, vt);
init(input1, input2, op, dt, vt, et);
+ _numThreads = k;
}
private void init(Lop input1, Lop input2, OpOp2 op, DataType dt,
ValueType vt, ExecType et) {
@@ -74,10 +81,20 @@ public class Binary extends Lop
@Override
public String getInstructions(String input1, String input2, String
output) {
- return InstructionUtils.concatOperands(
- getExecType().toString(), getOpcode(),
- getInputs().get(0).prepInputOperand(input1),
- getInputs().get(1).prepInputOperand(input2),
- prepOutputOperand(output));
+ if( getExecType() == ExecType.CP ) {
+ return InstructionUtils.concatOperands(
+ getExecType().name(), getOpcode(),
+ getInputs().get(0).prepInputOperand(input1),
+ getInputs().get(1).prepInputOperand(input2),
+ prepOutputOperand(output),
+ String.valueOf(_numThreads));
+ }
+ else {
+ return InstructionUtils.concatOperands(
+ getExecType().name(), getOpcode(),
+ getInputs().get(0).prepInputOperand(input1),
+ getInputs().get(1).prepInputOperand(input2),
+ prepOutputOperand(output));
+ }
}
}
diff --git a/src/main/java/org/apache/sysds/lops/Unary.java
b/src/main/java/org/apache/sysds/lops/Unary.java
index aa51477..0e34ba2 100644
--- a/src/main/java/org/apache/sysds/lops/Unary.java
+++ b/src/main/java/org/apache/sysds/lops/Unary.java
@@ -127,7 +127,9 @@ public class Unary extends Lop
|| op==OpOp1.CUMSUMPROD
|| op==OpOp1.EXP
|| op==OpOp1.LOG
- || op==OpOp1.SIGMOID;
+ || op==OpOp1.SIGMOID
+ || op==OpOp1.POW2
+ || op==OpOp1.MULT2;
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
index 2f0aad4..188b2ac 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
@@ -65,7 +65,7 @@ public abstract class BinaryCPInstruction extends
ComputationCPInstruction {
protected static String parseBinaryInstruction(String instr, CPOperand
in1, CPOperand in2, CPOperand out) {
String[] parts =
InstructionUtils.getInstructionPartsWithValueType(instr);
- InstructionUtils.checkNumFields ( parts, 3 );
+ InstructionUtils.checkNumFields ( parts, 3, 4 );
String opcode = parts[0];
in1.split(parts[1]);
in2.split(parts[2]);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
index 20ddfb1..abe815a 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
@@ -25,6 +25,7 @@ import
org.apache.sysds.runtime.compress.AbstractCompressedMatrixBlock;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.functionobjects.Multiply;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.matrix.data.LibCommonsMath;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
@@ -32,10 +33,14 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
public class BinaryMatrixMatrixCPInstruction extends BinaryCPInstruction {
private static final Log LOG =
LogFactory.getLog(BinaryMatrixMatrixCPInstruction.class.getName());
-
+
protected BinaryMatrixMatrixCPInstruction(Operator op, CPOperand in1,
CPOperand in2, CPOperand out,
String opcode, String istr) {
super(CPType.Binary, op, in1, in2, out, opcode, istr);
+ if( op instanceof BinaryOperator ) {
+ String[] parts =
InstructionUtils.getInstructionParts(istr);
+
((BinaryOperator)op).setNumThreads(Integer.parseInt(parts[parts.length-1]));
+ }
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixScalarCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixScalarCPInstruction.java
index 6b00759..04932ad 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixScalarCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixScalarCPInstruction.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.instructions.cp;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
@@ -30,6 +31,11 @@ public class BinaryMatrixScalarCPInstruction extends
BinaryCPInstruction {
protected BinaryMatrixScalarCPInstruction(Operator op, CPOperand in1,
CPOperand in2, CPOperand out,
String opcode, String istr) {
super(CPType.Binary, op, in1, in2, out, opcode, istr);
+ if( op instanceof ScalarOperator ) {
+ String[] parts =
InstructionUtils.getInstructionParts(istr);
+ if( parts.length > 4 )
+
((ScalarOperator)op).setNumThreads(Integer.parseInt(parts[parts.length-1]));
+ }
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java
index c287d4f..249d0e3 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java
@@ -19,7 +19,13 @@
package org.apache.sysds.runtime.matrix.data;
+import java.util.ArrayList;
import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.data.DenseBlock;
@@ -48,6 +54,7 @@ import org.apache.sysds.runtime.functionobjects.ValueFunction;
import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
+import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.SortUtils;
import org.apache.sysds.runtime.util.UtilFunctions;
@@ -61,6 +68,7 @@ import org.apache.sysds.runtime.util.UtilFunctions;
*/
public class LibMatrixBincell
{
+ private static final long PAR_NUMCELL_THRESHOLD2 = 16*1024; //Min 16K
elements
public enum BinaryAccessType {
MATRIX_MATRIX,
@@ -94,7 +102,7 @@ public class LibMatrixBincell
//execute binary cell operations
if(op.sparseSafe)
- safeBinaryScalar(m1, ret, op);
+ safeBinaryScalar(m1, ret, op, 0, m1.rlen);
else
unsafeBinaryScalar(m1, ret, op);
@@ -104,6 +112,48 @@ public class LibMatrixBincell
ret.examSparsity();
}
+ public static void bincellOp(MatrixBlock m1, MatrixBlock ret,
ScalarOperator op, int k) {
+ //check internal assumptions
+ if( (op.sparseSafe &&
m1.isInSparseFormat()!=ret.isInSparseFormat())
+ ||(!op.sparseSafe && ret.isInSparseFormat()) ) {
+ throw new DMLRuntimeException("Wrong output
representation for safe="+op.sparseSafe+": "+m1.isInSparseFormat()+",
"+ret.isInSparseFormat());
+ }
+
+ //fallback to single-threaded for special cases
+ if( m1.isEmpty() || !op.sparseSafe
+ || ret.getLength() < PAR_NUMCELL_THRESHOLD2 ) {
+ bincellOp(m1, ret, op);
+ return;
+ }
+
+ //preallocate dense/sparse block for multi-threaded operations
+ ret.allocateBlock();
+
+ try {
+ //execute binary cell operations
+ ExecutorService pool = CommonThreadPool.get(k);
+ ArrayList<BincellScalarTask> tasks = new ArrayList<>();
+ ArrayList<Integer> blklens =
UtilFunctions.getBalancedBlockSizesDefault(ret.rlen, k, false);
+ for( int i=0, lb=0; i<blklens.size();
lb+=blklens.get(i), i++ )
+ tasks.add(new BincellScalarTask(m1, ret, op,
lb, lb+blklens.get(i)));
+ List<Future<Long>> taskret = pool.invokeAll(tasks);
+
+ //aggregate non-zeros
+ ret.nonZeros = 0; //reset after execute
+ for( Future<Long> task : taskret )
+ ret.nonZeros += task.get();
+ pool.shutdown();
+ }
+ catch(InterruptedException | ExecutionException ex) {
+ throw new DMLRuntimeException(ex);
+ }
+
+ //ensure empty results sparse representation
+ //(no additional memory requirements)
+ if( ret.isEmptyBlock(false) )
+ ret.examSparsity();
+ }
+
/**
* matrix-matrix binary operations, MM, MV
*
@@ -117,7 +167,7 @@ public class LibMatrixBincell
if(op.sparseSafe || isSparseSafeDivide(op, m2))
safeBinary(m1, m2, ret, op);
else
- unsafeBinary(m1, m2, ret, op);
+ unsafeBinary(m1, m2, ret, op, 0, m1.rlen);
//ensure empty results sparse representation
//(no additional memory requirements)
@@ -125,6 +175,45 @@ public class LibMatrixBincell
ret.examSparsity();
}
+ public static void bincellOp(MatrixBlock m1, MatrixBlock m2,
MatrixBlock ret, BinaryOperator op, int k) {
+ //fallback to sequential computation for specialized operations
+ //TODO parallel support for all sparse safe operations
+ if( op.sparseSafe || isSparseSafeDivide(op, m2)
+ || ret.getLength() < PAR_NUMCELL_THRESHOLD2
+ || getBinaryAccessType(m1, m2) ==
BinaryAccessType.OUTER_VECTOR_VECTOR)
+ {
+ bincellOp(m1, m2, ret, op);
+ return;
+ }
+
+ //preallocate dense/sparse block for multi-threaded operations
+ ret.allocateBlock();
+
+ try {
+ //execute binary cell operations
+ ExecutorService pool = CommonThreadPool.get(k);
+ ArrayList<BincellTask> tasks = new ArrayList<>();
+ ArrayList<Integer> blklens =
UtilFunctions.getBalancedBlockSizesDefault(ret.rlen, k, false);
+ for( int i=0, lb=0; i<blklens.size();
lb+=blklens.get(i), i++ )
+ tasks.add(new BincellTask(m1, m2, ret, op, lb,
lb+blklens.get(i)));
+ List<Future<Long>> taskret = pool.invokeAll(tasks);
+
+ //aggregate non-zeros
+ ret.nonZeros = 0; //reset after execute
+ for( Future<Long> task : taskret )
+ ret.nonZeros += task.get();
+ pool.shutdown();
+ }
+ catch(InterruptedException | ExecutionException ex) {
+ throw new DMLRuntimeException(ex);
+ }
+
+ //ensure empty results sparse representation
+ //(no additional memory requirements)
+ if( ret.isEmptyBlock(false) )
+ ret.examSparsity();
+ }
+
/**
* NOTE: operations in place always require m1 and m2 to be of equal
dimensions
*
@@ -919,42 +1008,40 @@ public class LibMatrixBincell
ret.examSparsity();
}
- private static void unsafeBinary(MatrixBlock m1, MatrixBlock m2,
MatrixBlock ret, BinaryOperator op) {
- int rlen = m1.rlen;
+ private static void unsafeBinary(MatrixBlock m1, MatrixBlock m2,
MatrixBlock ret, BinaryOperator op, int rl, int ru) {
int clen = m1.clen;
BinaryAccessType atype = getBinaryAccessType(m1, m2);
- if( atype == BinaryAccessType.MATRIX_COL_VECTOR ) //MATRIX -
COL_VECTOR
- {
- for(int r=0; r<rlen; r++) {
+ long lnnz = 0;
+ if( atype == BinaryAccessType.MATRIX_COL_VECTOR ) { //MATRIX -
COL_VECTOR
+ for(int r=rl; r<ru; r++) {
double v2 = m2.quickGetValue(r, 0);
for(int c=0; c<clen; c++) {
double v1 = m1.quickGetValue(r, c);
double v = op.fn.execute( v1, v2 );
- ret.appendValue(r, c, v);
+ ret.appendValuePlain(r, c, v);
+ lnnz += (v!=0) ? 1 : 0;
}
}
}
- else if( atype == BinaryAccessType.MATRIX_ROW_VECTOR ) //MATRIX
- ROW_VECTOR
- {
- for(int r=0; r<rlen; r++)
+ else if( atype == BinaryAccessType.MATRIX_ROW_VECTOR ) {
//MATRIX - ROW_VECTOR
+ for(int r=rl; r<ru; r++)
for(int c=0; c<clen; c++) {
double v1 = m1.quickGetValue(r, c);
double v2 = m2.quickGetValue(0, c);
double v = op.fn.execute( v1, v2 );
- ret.appendValue(r, c, v);
+ ret.appendValuePlain(r, c, v);
+ lnnz += (v!=0) ? 1 : 0;
}
}
- else if( atype == BinaryAccessType.OUTER_VECTOR_VECTOR )
//VECTOR - VECTOR
- {
+ else if( atype == BinaryAccessType.OUTER_VECTOR_VECTOR ) {
//VECTOR - VECTOR
int clen2 = m2.clen;
-
if(LibMatrixOuterAgg.isCompareOperator(op)
&& m2.getNumColumns()>16 &&
SortUtils.isSorted(m2)) {
performBinOuterOperation(m1, m2, ret, op);
}
else {
- for(int r=0; r<rlen; r++) {
+ for(int r=rl; r<ru; r++) {
double v1 = m1.quickGetValue(r, 0);
for(int c=0; c<clen2; c++) {
double v2 = m2.quickGetValue(0,
c);
@@ -974,28 +1061,30 @@ public class LibMatrixBincell
double[] a = m1.getDenseBlockValues();
double[] b = m2.getDenseBlockValues();
double[] c = ret.getDenseBlockValues();
- int lnnz = 0;
- for( int i=0; i<rlen; i++ ) {
+ for( int i=rl; i<ru; i++ ) {
c[i] = op.fn.execute( a[i], b[i] );
lnnz += (c[i] != 0) ? 1 : 0;
}
- ret.nonZeros = lnnz;
}
//general case
- else
- {
- for(int r=0; r<rlen; r++)
+ else {
+ for(int r=rl; r<ru; r++)
for(int c=0; c<clen; c++) {
double v1 = m1.quickGetValue(r,
c);
double v2 = m2.quickGetValue(r,
c);
double v = op.fn.execute( v1,
v2 );
- ret.appendValue(r, c, v);
+ ret.appendValuePlain(r, c, v);
+ lnnz += (v!=0) ? 1 : 0;
}
}
}
+
+ //avoid false sharing in multi-threaded ops, while
+ //correctly setting the nnz for single-threaded ops
+ ret.nonZeros = lnnz;
}
- private static void safeBinaryScalar(MatrixBlock m1, MatrixBlock ret,
ScalarOperator op) {
+ private static void safeBinaryScalar(MatrixBlock m1, MatrixBlock ret,
ScalarOperator op, int rl, int ru) {
//early abort possible since sparsesafe
if( m1.isEmptyBlock(false) ) {
return;
@@ -1016,10 +1105,9 @@ public class LibMatrixBincell
ret.allocateSparseRowsBlock();
SparseBlock a = m1.sparseBlock;
SparseBlock c = ret.sparseBlock;
- int rlen = Math.min(m1.rlen, a.numRows());
long nnz = 0;
- for(int r=0; r<rlen; r++) {
+ for(int r=rl; r<ru; r++) {
if( a.isEmpty(r) ) continue;
int apos = a.pos(r);
@@ -1053,7 +1141,7 @@ public class LibMatrixBincell
ret.nonZeros = nnz;
}
else { //DENSE <- DENSE
- denseBinaryScalar(m1, ret, op);
+ denseBinaryScalar(m1, ret, op, rl, ru);
}
}
@@ -1078,14 +1166,15 @@ public class LibMatrixBincell
if( ret.sparse )
throw new DMLRuntimeException("Unsupported unsafe
binary scalar operations over sparse output representation.");
+ int m = m1.rlen;
+ int n = m1.clen;
+
if( m1.sparse ) //SPARSE MATRIX
{
ret.allocateDenseBlock();
SparseBlock a = m1.sparseBlock;
DenseBlock dc = ret.getDenseBlock();
- int m = m1.rlen;
- int n = m1.clen;
//init dense result with unsafe 0-value
double val0 = op.executeScalar(0);
@@ -1115,26 +1204,27 @@ public class LibMatrixBincell
ret.nonZeros = nnz;
}
else { //DENSE MATRIX
- denseBinaryScalar(m1, ret, op);
+ denseBinaryScalar(m1, ret, op, 0, m);
}
}
- private static void denseBinaryScalar(MatrixBlock m1, MatrixBlock ret,
ScalarOperator op) {
+ private static void denseBinaryScalar(MatrixBlock m1, MatrixBlock ret,
ScalarOperator op, int rl, int ru) {
//allocate dense block (if necessary), incl clear nnz
ret.allocateDenseBlock(true);
DenseBlock da = m1.getDenseBlock();
DenseBlock dc = ret.getDenseBlock();
+ int clen = m1.clen;
//compute scalar operation, incl nnz maintenance
long nnz = 0;
- for( int bi=0; bi<da.numBlocks(); bi++) {
- double[] a = da.valuesAt(bi);
- double[] c = dc.valuesAt(bi);
- int limit = da.size(bi);
- for( int i=0; i<limit; i++ ) {
- c[i] = op.executeScalar( a[i] );
- nnz += (c[i] != 0) ? 1 : 0;
+ for(int i=rl; i<ru; i++) {
+ double[] a = da.values(i);
+ double[] c = dc.values(i);
+ int apos = da.pos(i), cpos = dc.pos(i);
+ for(int j=0; j<clen; j++) {
+ c[cpos+j] = op.executeScalar( a[apos+j] );
+ nnz += (c[cpos+j] != 0) ? 1 : 0;
}
}
ret.nonZeros = nnz;
@@ -1408,4 +1498,56 @@ public class LibMatrixBincell
if( zero )
c.compact(r);
}
+
+ private static class BincellTask implements Callable<Long> {
+ private final MatrixBlock _m1;
+ private final MatrixBlock _m2;
+ private final MatrixBlock _ret;
+ private final BinaryOperator _bop;
+ private final int _rl;
+ private final int _ru;
+
+ protected BincellTask( MatrixBlock m1, MatrixBlock m2,
MatrixBlock ret, BinaryOperator bop, int rl, int ru ) {
+ _m1 = m1;
+ _m2 = m2;
+ _ret = ret;
+ _bop = bop;
+ _rl = rl;
+ _ru = ru;
+ }
+
+ @Override
+ public Long call() {
+ //execute binary operation on row partition
+ unsafeBinary(_m1, _m2, _ret, _bop, _rl, _ru);
+
+ //maintain block nnz (upper bounds inclusive)
+ return _ret.recomputeNonZeros(_rl, _ru-1);
+ }
+ }
+
+ private static class BincellScalarTask implements Callable<Long> {
+ private final MatrixBlock _m1;
+ private final MatrixBlock _ret;
+ private final ScalarOperator _sop;
+ private final int _rl;
+ private final int _ru;
+
+ protected BincellScalarTask( MatrixBlock m1, MatrixBlock ret,
ScalarOperator sop, int rl, int ru ) {
+ _m1 = m1;
+ _ret = ret;
+ _sop = sop;
+ _rl = rl;
+ _ru = ru;
+ }
+
+ @Override
+ public Long call() {
+ //execute binary operation on row partition
+ safeBinaryScalar(_m1, _ret, _sop, _rl, _ru);
+
+ //maintain block nnz (upper bounds inclusive)
+ return _ret.recomputeNonZeros(_rl, _ru-1);
+ }
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java
index d64a936..6e7ba49 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java
@@ -116,7 +116,9 @@ public class LibMatrixNative
Statistics.incrementNativeFailuresCounter();
}
//fallback to default java implementation
- LOG.warn("matrixMult: Native mat mult failed. Falling back to
java version.");
+ LOG.warn("matrixMult: Native mat mult failed. Falling back to
java version ("
+ + "loaded=" + NativeHelper.isNativeLibraryLoaded()
+ + ", sparse=" + (m1.isInSparseFormat() |
m2.isInSparseFormat()) + ")");
if (k == 1)
LibMatrixMult.matrixMult(m1, m2, ret, !examSparsity);
else
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index ed090c3..5d4b869 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -694,8 +694,7 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock, Externalizab
if( v == 0 )
return;
- if( !sparse ) //DENSE
- {
+ if( !sparse ) { //DENSE
//allocate on demand (w/o overwriting nnz)
allocateDenseBlock(false);
@@ -703,8 +702,7 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock, Externalizab
denseBlock.set(r, c, v);
nonZeros++;
}
- else //SPARSE
- {
+ else { //SPARSE
//allocation on demand (w/o overwriting nnz)
allocateSparseRowsBlock(false);
sparseBlock.allocate(r, estimatedNNzsPerRow, clen);
@@ -715,6 +713,28 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock, Externalizab
}
}
+ public void appendValuePlain(int r, int c, double v) {
+ //early abort (append guarantees no overwrite)
+ if( v == 0 )
+ return;
+
+ if( !sparse ) { //DENSE
+ //allocate on demand (w/o overwriting nnz)
+ allocateDenseBlock(false);
+
+ //set value (nnz is not maintained here)
+ denseBlock.set(r, c, v);
+ }
+ else { //SPARSE
+ //allocation on demand (w/o overwriting nnz)
+ allocateSparseRowsBlock(false);
+ sparseBlock.allocate(r, estimatedNNzsPerRow, clen);
+
+ //set value (nnz is not maintained here)
+ sparseBlock.append(r, c, v);
+ }
+ }
+
public void appendRow(int r, SparseRow row) {
appendRow(r, row, true);
}
@@ -2659,7 +2679,10 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock, Externalizab
ret.reset(rlen, clen, sp, this.nonZeros);
//core scalar operations
- LibMatrixBincell.bincellOp(this, ret, op);
+ if( op.getNumThreads() > 1 )
+ LibMatrixBincell.bincellOp(this, ret, op,
op.getNumThreads());
+ else
+ LibMatrixBincell.bincellOp(this, ret, op);
return ret;
}
@@ -2842,7 +2865,10 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock, Externalizab
ret.reset(rows, cols, resultSparse.sparse,
resultSparse.estimatedNonZeros);
//core binary cell operation
- LibMatrixBincell.bincellOp( this, that, ret, op );
+ if( op.getNumThreads() > 1 )
+ LibMatrixBincell.bincellOp( this, that, ret, op,
op.getNumThreads() );
+ else
+ LibMatrixBincell.bincellOp( this, that, ret, op );
return ret;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
b/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
index bc4cdd0..7579046 100644
---
a/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
+++
b/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
@@ -57,6 +57,7 @@ public class BinaryOperator extends Operator implements
Serializable
public final ValueFunction fn;
public final boolean commutative;
+ private int _k = 1; // num threads
public BinaryOperator(ValueFunction p) {
//binaryop is sparse-safe iff (0 op 0) == 0
@@ -70,6 +71,14 @@ public class BinaryOperator extends Operator implements
Serializable
|| p instanceof And || p instanceof Or || p instanceof
Xor;
}
+ public void setNumThreads(int k) {
+ _k = k;
+ }
+
+ public int getNumThreads() {
+ return _k;
+ }
+
/**
* Method for getting the hop binary operator type for a given function
object.
* This is used in order to use a common code path for consistency
between
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/operators/LeftScalarOperator.java
b/src/main/java/org/apache/sysds/runtime/matrix/operators/LeftScalarOperator.java
index 7a40a3f..abca742 100644
---
a/src/main/java/org/apache/sysds/runtime/matrix/operators/LeftScalarOperator.java
+++
b/src/main/java/org/apache/sysds/runtime/matrix/operators/LeftScalarOperator.java
@@ -58,7 +58,7 @@ public class LeftScalarOperator extends ScalarOperator
@Override
public ScalarOperator setConstant(double cst) {
- return new LeftScalarOperator(fn, cst);
+ return new LeftScalarOperator(fn, cst, getNumThreads());
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/operators/RightScalarOperator.java
b/src/main/java/org/apache/sysds/runtime/matrix/operators/RightScalarOperator.java
index a55ed66..fe821e0 100644
---
a/src/main/java/org/apache/sysds/runtime/matrix/operators/RightScalarOperator.java
+++
b/src/main/java/org/apache/sysds/runtime/matrix/operators/RightScalarOperator.java
@@ -56,7 +56,7 @@ public class RightScalarOperator extends ScalarOperator
@Override
public ScalarOperator setConstant(double cst) {
- return new RightScalarOperator(fn, cst);
+ return new RightScalarOperator(fn, cst, getNumThreads());
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/operators/ScalarOperator.java
b/src/main/java/org/apache/sysds/runtime/matrix/operators/ScalarOperator.java
index 8f27209..d33bbae 100644
---
a/src/main/java/org/apache/sysds/runtime/matrix/operators/ScalarOperator.java
+++
b/src/main/java/org/apache/sysds/runtime/matrix/operators/ScalarOperator.java
@@ -44,7 +44,7 @@ public abstract class ScalarOperator extends Operator
public final ValueFunction fn;
protected final double _constant;
- private final int k; //num threads
+ private int _k; //num threads
public ScalarOperator(ValueFunction p, double cst) {
this(p, cst, false);
@@ -63,13 +63,21 @@ public abstract class ScalarOperator extends Operator
|| (p instanceof Builtin &&
((Builtin)p).getBuiltinCode()==BuiltinCode.MIN && cst>=0));
fn = p;
_constant = cst;
- k = numThreads;
+ _k = numThreads;
}
public double getConstant() {
return _constant;
}
+ public void setNumThreads(int k) {
+ _k = k;
+ }
+
+ public int getNumThreads() {
+ return _k;
+ }
+
public abstract ScalarOperator setConstant(double cst);
public abstract ScalarOperator setConstant(double cst, int numThreads);
@@ -94,8 +102,4 @@ public abstract class ScalarOperator extends Operator
|| fn instanceof Builtin &&
((Builtin)fn).getBuiltinCode()==BuiltinCode.LOG_NZ)
|| fn instanceof BitwShiftL || fn instanceof BitwShiftR;
}
-
- public int getNumThreads() {
- return k;
- }
}