Repository: systemml
Updated Branches:
  refs/heads/master c95019fd9 -> f14255f46


[SYSTEMML-2108] Performance CP ternary +* and -* operations

Since the introduction of the general ternary operation framework for
ifelse (which also subsumed the specific +* and -* operations), the +*
and -* operations have shown non-negligible overhead, especially for
sparse-dense combinations. Hence, this patch adds a special case for
matrix-scalar-matrix and matrix-matrix-scalar inputs that routes these
operations to the binary operation framework: the scalar is folded into
the function object as a constant, so that, e.g., +*(A, s, B) = A + s*B
is computed as a cell-wise binary operation over A and B.

On LeNet over MNIST, +* and -* consumed 28% of the execution time; this
patch reduces the runtime of these operations by more than 2x.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/f14255f4
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/f14255f4
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/f14255f4

Branch: refs/heads/master
Commit: f14255f464017a0f3dea1d335160b25810fe20a3
Parents: c95019f
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Fri Feb 2 22:59:52 2018 -0800
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Fri Feb 2 22:59:52 2018 -0800

----------------------------------------------------------------------
 .../runtime/functionobjects/MinusMultiply.java  | 22 ++++++++++++++++++--
 .../runtime/functionobjects/PlusMultiply.java   | 22 ++++++++++++++++++--
 .../functionobjects/TernaryValueFunction.java   |  5 +++++
 .../sysml/runtime/matrix/data/MatrixBlock.java  | 14 ++++++++++---
 .../matrix/operators/BinaryOperator.java        |  3 +++
 5 files changed, 59 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/f14255f4/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java b/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java
index 794571f..1e3d093 100644
--- a/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java
+++ b/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java
@@ -21,14 +21,23 @@ package org.apache.sysml.runtime.functionobjects;
 
 import java.io.Serializable;
 
-public class MinusMultiply extends TernaryValueFunction implements Serializable
+import 
org.apache.sysml.runtime.functionobjects.TernaryValueFunction.ValueFunctionWithConstant;
+import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
+
+public class MinusMultiply extends TernaryValueFunction implements 
ValueFunctionWithConstant, Serializable
 {
        private static final long serialVersionUID = 2801982061205871665L;
        
        private static MinusMultiply singleObj = null;
 
+       private final double _cnt;
+       
        private MinusMultiply() {
-               // nothing to do here
+               _cnt = 1;
+       }
+       
+       private MinusMultiply(double cnt) {
+               _cnt = cnt;
        }
 
        public static MinusMultiply getFnObject() {
@@ -41,4 +50,13 @@ public class MinusMultiply extends TernaryValueFunction 
implements Serializable
        public double execute(double in1, double in2, double in3) {
                return in1 - in2 * in3;
        }
+       
+       public BinaryOperator setOp2Constant(double cnt) {
+               return new BinaryOperator(new MinusMultiply(cnt));
+       }
+       
+       @Override
+       public double execute(double in1, double in2) {
+               return in1 - _cnt * in2;
+       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/f14255f4/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java b/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java
index cb821f5..041527f 100644
--- a/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java
+++ b/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java
@@ -21,14 +21,23 @@ package org.apache.sysml.runtime.functionobjects;
 
 import java.io.Serializable;
 
-public class PlusMultiply extends TernaryValueFunction implements Serializable
+import 
org.apache.sysml.runtime.functionobjects.TernaryValueFunction.ValueFunctionWithConstant;
+import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
+
+public class PlusMultiply extends TernaryValueFunction implements 
ValueFunctionWithConstant, Serializable
 {
        private static final long serialVersionUID = 2801982061205871665L;
 
        private static PlusMultiply singleObj = null;
 
+       private final double _cnt;
+       
        private PlusMultiply() {
-               // nothing to do here
+               _cnt = 1;
+       }
+       
+       private PlusMultiply(double cnt) {
+               _cnt = cnt;
        }
 
        public static PlusMultiply getFnObject() {
@@ -41,4 +50,13 @@ public class PlusMultiply extends TernaryValueFunction 
implements Serializable
        public double execute(double in1, double in2, double in3) {
                return in1 + in2 * in3;
        }
+
+       public BinaryOperator setOp2Constant(double cnt) {
+               return new BinaryOperator(new PlusMultiply(cnt));
+       }
+       
+       @Override
+       public double execute(double in1, double in2) {
+               return in1 + _cnt * in2;
+       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/f14255f4/src/main/java/org/apache/sysml/runtime/functionobjects/TernaryValueFunction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/TernaryValueFunction.java b/src/main/java/org/apache/sysml/runtime/functionobjects/TernaryValueFunction.java
index c317010..9629746 100644
--- a/src/main/java/org/apache/sysml/runtime/functionobjects/TernaryValueFunction.java
+++ b/src/main/java/org/apache/sysml/runtime/functionobjects/TernaryValueFunction.java
@@ -22,6 +22,7 @@ package org.apache.sysml.runtime.functionobjects;
 import java.io.Serializable;
 
 import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
 
 public abstract class TernaryValueFunction extends ValueFunction implements 
Serializable
 {
@@ -29,4 +30,8 @@ public abstract class TernaryValueFunction extends 
ValueFunction implements Seri
        
        public abstract double execute ( double in1, double in2, double in3 )
                throws DMLRuntimeException;
+       
+       public interface ValueFunctionWithConstant {
+               public BinaryOperator setOp2Constant(double cnt);
+       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/f14255f4/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
index e06c8c1..654cf53 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
@@ -51,14 +51,17 @@ import org.apache.sysml.runtime.functionobjects.IfElse;
 import org.apache.sysml.runtime.functionobjects.KahanFunction;
 import org.apache.sysml.runtime.functionobjects.KahanPlus;
 import org.apache.sysml.runtime.functionobjects.KahanPlusSq;
+import org.apache.sysml.runtime.functionobjects.MinusMultiply;
 import org.apache.sysml.runtime.functionobjects.Multiply;
 import org.apache.sysml.runtime.functionobjects.Plus;
+import org.apache.sysml.runtime.functionobjects.PlusMultiply;
 import org.apache.sysml.runtime.functionobjects.ReduceAll;
 import org.apache.sysml.runtime.functionobjects.ReduceCol;
 import org.apache.sysml.runtime.functionobjects.ReduceRow;
 import org.apache.sysml.runtime.functionobjects.RevIndex;
 import org.apache.sysml.runtime.functionobjects.SortIndex;
 import org.apache.sysml.runtime.functionobjects.SwapIndex;
+import 
org.apache.sysml.runtime.functionobjects.TernaryValueFunction.ValueFunctionWithConstant;
 import org.apache.sysml.runtime.instructions.cp.CM_COV_Object;
 import org.apache.sysml.runtime.instructions.cp.KahanObject;
 import org.apache.sysml.runtime.instructions.cp.ScalarObject;
@@ -2803,9 +2806,8 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
                //prepare result
                ret.reset(m, n, false);
                
-               if( op.fn instanceof IfElse && (s1 || nnz==0 || nnz==(long)m*n) 
)
-               {
-                       //special case for shallow-copy if-else
+               if( op.fn instanceof IfElse && (s1 || nnz==0 || nnz==(long)m*n) 
) {
+                       //SPECIAL CASE for shallow-copy if-else
                        boolean expr = s1 ? (d1 != 0) : (nnz==(long)m*n);
                        MatrixBlock tmp = expr ? m2 : m3;
                        if( tmp.rlen==m && tmp.clen==n ) {
@@ -2822,6 +2824,12 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
                                }
                        }
                }
+               else if (s2 != s3 && (op.fn instanceof PlusMultiply || op.fn 
instanceof MinusMultiply) ) {
+                       //SPECIAL CASE for sparse-dense combinations of common 
+* and -*
+                       BinaryOperator bop = ((ValueFunctionWithConstant)op.fn)
+                               .setOp2Constant(s2 ? d2 : d3);
+                       LibMatrixBincell.bincellOp(this, s2 ? m3 : m2, ret, 
bop);
+               }
                else {
                        ret.allocateDenseBlock();
                        

http://git-wip-us.apache.org/repos/asf/systemml/blob/f14255f4/src/main/java/org/apache/sysml/runtime/matrix/operators/BinaryOperator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/operators/BinaryOperator.java b/src/main/java/org/apache/sysml/runtime/matrix/operators/BinaryOperator.java
index 5245db5..e3b9a06 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/operators/BinaryOperator.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/operators/BinaryOperator.java
@@ -39,12 +39,14 @@ import 
org.apache.sysml.runtime.functionobjects.IntegerDivide;
 import org.apache.sysml.runtime.functionobjects.LessThan;
 import org.apache.sysml.runtime.functionobjects.LessThanEquals;
 import org.apache.sysml.runtime.functionobjects.Minus;
+import org.apache.sysml.runtime.functionobjects.MinusMultiply;
 import org.apache.sysml.runtime.functionobjects.MinusNz;
 import org.apache.sysml.runtime.functionobjects.Modulus;
 import org.apache.sysml.runtime.functionobjects.Multiply;
 import org.apache.sysml.runtime.functionobjects.NotEquals;
 import org.apache.sysml.runtime.functionobjects.Or;
 import org.apache.sysml.runtime.functionobjects.Plus;
+import org.apache.sysml.runtime.functionobjects.PlusMultiply;
 import org.apache.sysml.runtime.functionobjects.Power;
 import org.apache.sysml.runtime.functionobjects.ValueFunction;
 import org.apache.sysml.runtime.functionobjects.Xor;
@@ -58,6 +60,7 @@ public class BinaryOperator  extends Operator implements 
Serializable
        public BinaryOperator(ValueFunction p) {
                //binaryop is sparse-safe iff (0 op 0) == 0
                super (p instanceof Plus || p instanceof Multiply || p 
instanceof Minus
+                       || p instanceof PlusMultiply || p instanceof 
MinusMultiply
                        || p instanceof And || p instanceof Or || p instanceof 
Xor
                        || p instanceof BitwAnd || p instanceof BitwOr || p 
instanceof BitwXor
                        || p instanceof BitwShiftL || p instanceof BitwShiftR);

Reply via email to