Repository: systemml Updated Branches: refs/heads/master c95019fd9 -> f14255f46
[SYSTEMML-2108] Performance CP ternary +* and -* operations Since the introduction of the general ternary operation framework for ifelse (which also subsumed the specific +* and -* operations), the +* and -* operations have shown non-negligible overhead, especially for sparse-dense combinations. Hence, this patch adds a special case for matrix-scalar-matrix and matrix-matrix-scalar operations that routes them to the binary operation framework: the scalar operand is folded into a binary function object, so that, e.g., A + s*B is computed as a binary cell-wise operation over A and B with the constant s. On LeNet over MNIST, +* and -* consumed 28% of the total execution time; this patch reduces the runtime of these operations by more than 2x. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/f14255f4 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/f14255f4 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/f14255f4 Branch: refs/heads/master Commit: f14255f464017a0f3dea1d335160b25810fe20a3 Parents: c95019f Author: Matthias Boehm <[email protected]> Authored: Fri Feb 2 22:59:52 2018 -0800 Committer: Matthias Boehm <[email protected]> Committed: Fri Feb 2 22:59:52 2018 -0800 ---------------------------------------------------------------------- .../runtime/functionobjects/MinusMultiply.java | 22 ++++++++++++++++++-- .../runtime/functionobjects/PlusMultiply.java | 22 ++++++++++++++++++-- .../functionobjects/TernaryValueFunction.java | 5 +++++ .../sysml/runtime/matrix/data/MatrixBlock.java | 14 ++++++++++--- .../matrix/operators/BinaryOperator.java | 3 +++ 5 files changed, 59 insertions(+), 7 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/f14255f4/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java b/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java 
index 794571f..1e3d093 100644 --- a/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java +++ b/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java @@ -21,14 +21,23 @@ package org.apache.sysml.runtime.functionobjects; import java.io.Serializable; -public class MinusMultiply extends TernaryValueFunction implements Serializable +import org.apache.sysml.runtime.functionobjects.TernaryValueFunction.ValueFunctionWithConstant; +import org.apache.sysml.runtime.matrix.operators.BinaryOperator; + +public class MinusMultiply extends TernaryValueFunction implements ValueFunctionWithConstant, Serializable { private static final long serialVersionUID = 2801982061205871665L; private static MinusMultiply singleObj = null; + private final double _cnt; + private MinusMultiply() { - // nothing to do here + _cnt = 1; + } + + private MinusMultiply(double cnt) { + _cnt = cnt; } public static MinusMultiply getFnObject() { @@ -41,4 +50,13 @@ public class MinusMultiply extends TernaryValueFunction implements Serializable public double execute(double in1, double in2, double in3) { return in1 - in2 * in3; } + + public BinaryOperator setOp2Constant(double cnt) { + return new BinaryOperator(new MinusMultiply(cnt)); + } + + @Override + public double execute(double in1, double in2) { + return in1 - _cnt * in2; + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/f14255f4/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java b/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java index cb821f5..041527f 100644 --- a/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java +++ b/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java @@ -21,14 +21,23 @@ package org.apache.sysml.runtime.functionobjects; import 
java.io.Serializable; -public class PlusMultiply extends TernaryValueFunction implements Serializable +import org.apache.sysml.runtime.functionobjects.TernaryValueFunction.ValueFunctionWithConstant; +import org.apache.sysml.runtime.matrix.operators.BinaryOperator; + +public class PlusMultiply extends TernaryValueFunction implements ValueFunctionWithConstant, Serializable { private static final long serialVersionUID = 2801982061205871665L; private static PlusMultiply singleObj = null; + private final double _cnt; + private PlusMultiply() { - // nothing to do here + _cnt = 1; + } + + private PlusMultiply(double cnt) { + _cnt = cnt; } public static PlusMultiply getFnObject() { @@ -41,4 +50,13 @@ public class PlusMultiply extends TernaryValueFunction implements Serializable public double execute(double in1, double in2, double in3) { return in1 + in2 * in3; } + + public BinaryOperator setOp2Constant(double cnt) { + return new BinaryOperator(new PlusMultiply(cnt)); + } + + @Override + public double execute(double in1, double in2) { + return in1 + _cnt * in2; + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/f14255f4/src/main/java/org/apache/sysml/runtime/functionobjects/TernaryValueFunction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/TernaryValueFunction.java b/src/main/java/org/apache/sysml/runtime/functionobjects/TernaryValueFunction.java index c317010..9629746 100644 --- a/src/main/java/org/apache/sysml/runtime/functionobjects/TernaryValueFunction.java +++ b/src/main/java/org/apache/sysml/runtime/functionobjects/TernaryValueFunction.java @@ -22,6 +22,7 @@ package org.apache.sysml.runtime.functionobjects; import java.io.Serializable; import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.matrix.operators.BinaryOperator; public abstract class TernaryValueFunction extends ValueFunction implements Serializable { @@ -29,4 +30,8 @@ 
public abstract class TernaryValueFunction extends ValueFunction implements Seri public abstract double execute ( double in1, double in2, double in3 ) throws DMLRuntimeException; + + public interface ValueFunctionWithConstant { + public BinaryOperator setOp2Constant(double cnt); + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/f14255f4/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java index e06c8c1..654cf53 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java @@ -51,14 +51,17 @@ import org.apache.sysml.runtime.functionobjects.IfElse; import org.apache.sysml.runtime.functionobjects.KahanFunction; import org.apache.sysml.runtime.functionobjects.KahanPlus; import org.apache.sysml.runtime.functionobjects.KahanPlusSq; +import org.apache.sysml.runtime.functionobjects.MinusMultiply; import org.apache.sysml.runtime.functionobjects.Multiply; import org.apache.sysml.runtime.functionobjects.Plus; +import org.apache.sysml.runtime.functionobjects.PlusMultiply; import org.apache.sysml.runtime.functionobjects.ReduceAll; import org.apache.sysml.runtime.functionobjects.ReduceCol; import org.apache.sysml.runtime.functionobjects.ReduceRow; import org.apache.sysml.runtime.functionobjects.RevIndex; import org.apache.sysml.runtime.functionobjects.SortIndex; import org.apache.sysml.runtime.functionobjects.SwapIndex; +import org.apache.sysml.runtime.functionobjects.TernaryValueFunction.ValueFunctionWithConstant; import org.apache.sysml.runtime.instructions.cp.CM_COV_Object; import org.apache.sysml.runtime.instructions.cp.KahanObject; import org.apache.sysml.runtime.instructions.cp.ScalarObject; @@ -2803,9 +2806,8 @@ public class MatrixBlock 
extends MatrixValue implements CacheBlock, Externalizab //prepare result ret.reset(m, n, false); - if( op.fn instanceof IfElse && (s1 || nnz==0 || nnz==(long)m*n) ) - { - //special case for shallow-copy if-else + if( op.fn instanceof IfElse && (s1 || nnz==0 || nnz==(long)m*n) ) { + //SPECIAL CASE for shallow-copy if-else boolean expr = s1 ? (d1 != 0) : (nnz==(long)m*n); MatrixBlock tmp = expr ? m2 : m3; if( tmp.rlen==m && tmp.clen==n ) { @@ -2822,6 +2824,12 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab } } } + else if (s2 != s3 && (op.fn instanceof PlusMultiply || op.fn instanceof MinusMultiply) ) { + //SPECIAL CASE for sparse-dense combinations of common +* and -* + BinaryOperator bop = ((ValueFunctionWithConstant)op.fn) + .setOp2Constant(s2 ? d2 : d3); + LibMatrixBincell.bincellOp(this, s2 ? m3 : m2, ret, bop); + } else { ret.allocateDenseBlock(); http://git-wip-us.apache.org/repos/asf/systemml/blob/f14255f4/src/main/java/org/apache/sysml/runtime/matrix/operators/BinaryOperator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/operators/BinaryOperator.java b/src/main/java/org/apache/sysml/runtime/matrix/operators/BinaryOperator.java index 5245db5..e3b9a06 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/operators/BinaryOperator.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/operators/BinaryOperator.java @@ -39,12 +39,14 @@ import org.apache.sysml.runtime.functionobjects.IntegerDivide; import org.apache.sysml.runtime.functionobjects.LessThan; import org.apache.sysml.runtime.functionobjects.LessThanEquals; import org.apache.sysml.runtime.functionobjects.Minus; +import org.apache.sysml.runtime.functionobjects.MinusMultiply; import org.apache.sysml.runtime.functionobjects.MinusNz; import org.apache.sysml.runtime.functionobjects.Modulus; import org.apache.sysml.runtime.functionobjects.Multiply; import 
org.apache.sysml.runtime.functionobjects.NotEquals; import org.apache.sysml.runtime.functionobjects.Or; import org.apache.sysml.runtime.functionobjects.Plus; +import org.apache.sysml.runtime.functionobjects.PlusMultiply; import org.apache.sysml.runtime.functionobjects.Power; import org.apache.sysml.runtime.functionobjects.ValueFunction; import org.apache.sysml.runtime.functionobjects.Xor; @@ -58,6 +60,7 @@ public class BinaryOperator extends Operator implements Serializable public BinaryOperator(ValueFunction p) { //binaryop is sparse-safe iff (0 op 0) == 0 super (p instanceof Plus || p instanceof Multiply || p instanceof Minus + || p instanceof PlusMultiply || p instanceof MinusMultiply || p instanceof And || p instanceof Or || p instanceof Xor || p instanceof BitwAnd || p instanceof BitwOr || p instanceof BitwXor || p instanceof BitwShiftL || p instanceof BitwShiftR);
