Repository: incubator-systemml
Updated Branches:
  refs/heads/master c22f239e3 -> b584aecf6


[SYSTEMML-766] Extended axpy compiler/runtime support (mr, hybrid)

Includes a fix for the 'fused binary operation chain' (axpy) rewrite.

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/b584aecf
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/b584aecf
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/b584aecf

Branch: refs/heads/master
Commit: b584aecf6b3a1eb96ff83b78cc3ad7c7c6d15baa
Parents: c22f239
Author: Matthias Boehm <[email protected]>
Authored: Mon Jul 18 19:46:55 2016 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Tue Jul 19 10:58:49 2016 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/BinaryOp.java    |  2 +-
 .../java/org/apache/sysml/hops/TernaryOp.java   | 47 ++++++++++++-
 .../RewriteAlgebraicSimplificationStatic.java   | 12 ++--
 .../java/org/apache/sysml/lops/PlusMult.java    | 58 ++++++++++++----
 .../runtime/functionobjects/MinusMultiply.java  | 18 +++--
 .../runtime/functionobjects/PlusMultiply.java   | 18 +++--
 .../ValueFunctionWithConstant.java              |  6 +-
 .../runtime/instructions/InstructionUtils.java  |  6 ++
 .../instructions/MRInstructionParser.java       |  7 ++
 .../instructions/cp/PlusMultCPInstruction.java  | 17 +++--
 .../runtime/instructions/mr/MRInstruction.java  |  2 +-
 .../instructions/mr/PlusMultInstruction.java    | 69 ++++++++++++++++++++
 .../spark/PlusMultSPInstruction.java            | 12 ++--
 .../misc/RewriteFuseBinaryOpChainTest.java      | 46 +++++++++----
 14 files changed, 249 insertions(+), 71 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/hops/BinaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/BinaryOp.java 
b/src/main/java/org/apache/sysml/hops/BinaryOp.java
index 65e9232..edc327d 100644
--- a/src/main/java/org/apache/sysml/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/BinaryOp.java
@@ -1335,7 +1335,7 @@ public class BinaryOp extends Hop
         * @param right
         * @return
         */
-       private static boolean requiresReplication( Hop left, Hop right )
+       public static boolean requiresReplication( Hop left, Hop right )
        {
                return (!(left.getDim2()>=1 && right.getDim2()>=1) //cols of 
any input unknown 
                                ||(left.getDim2() > 1 && right.getDim2()==1 && 
left.getDim2()>=left.getColsInBlock() ) //col MV and more than 1 block

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/hops/TernaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/TernaryOp.java 
b/src/main/java/org/apache/sysml/hops/TernaryOp.java
index 72e7624..626ad2c 100644
--- a/src/main/java/org/apache/sysml/hops/TernaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/TernaryOp.java
@@ -31,6 +31,7 @@ import org.apache.sysml.lops.Lop;
 import org.apache.sysml.lops.LopsException;
 import org.apache.sysml.lops.PickByCount;
 import org.apache.sysml.lops.PlusMult;
+import org.apache.sysml.lops.RepMat;
 import org.apache.sysml.lops.SortKeys;
 import org.apache.sysml.lops.Ternary;
 import org.apache.sysml.lops.UnaryCP;
@@ -627,16 +628,58 @@ public class TernaryOp extends Hop
                        }
                }
        }
-       private void constructLopsPlusMult() throws HopsException, 
LopsException {
+       
+       /**
+        * Constructs the PlusMult lop (X +* s Y or X -* s Y); for MR, both matrix inputs are grouped and the right one is replicated via RepMat when BinaryOp.requiresReplication(left, right) holds.
+        * @throws HopsException if _op is neither PLUS_MULT nor MINUS_MULT
+        * @throws LopsException if constructing the input or output lops fails
+        */
+       private void constructLopsPlusMult() 
+               throws HopsException, LopsException 
+       {   // NOTE(review): the unchanged error message below lacks a space after " or"
                if ( _op != OpOp3.PLUS_MULT && _op != OpOp3.MINUS_MULT )
                        throw new HopsException("Unexpected operation: " + _op 
+ ", expecting " + OpOp3.PLUS_MULT + " or" +  OpOp3.MINUS_MULT);
                
                ExecType et = optFindExecType();
-               PlusMult plusmult = new 
PlusMult(getInput().get(0).constructLops(),getInput().get(1).constructLops(),getInput().get(2).constructLops(),
 _op, getDataType(),getValueType(), et );
+               PlusMult plusmult = null;
+               
+               if( et == ExecType.CP || et == ExecType.SPARK ) {
+                       plusmult = new PlusMult(
+                                       getInput().get(0).constructLops(),
+                                       getInput().get(1).constructLops(),
+                                       getInput().get(2).constructLops(), 
+                                       _op, getDataType(),getValueType(), et 
);        
+               }
+               else { //MR
+                       Hop left = getInput().get(0);
+                       Hop right = getInput().get(2);
+                       boolean requiresRep = 
BinaryOp.requiresReplication(left, right);
+                       
+                       Lop rightLop = right.constructLops();
+                       if( requiresRep ) {
+                               Lop offset = createOffsetLop(left, 
(right.getDim2()<=1)); //ncol of left input (determines num replicates)
+                               rightLop = new RepMat(rightLop, offset, 
(right.getDim2()<=1), right.getDataType(), right.getValueType());
+                               setOutputDimensions(rightLop);
+                               setLineNumbers(rightLop);       
+                       }
+               
+                       Group group1 = new Group(left.constructLops(), 
Group.OperationTypes.Sort, getDataType(), getValueType());
+                       setLineNumbers(group1);
+                       setOutputDimensions(group1);
+               
+                       Group group2 = new Group(rightLop, 
Group.OperationTypes.Sort, getDataType(), getValueType());
+                       setLineNumbers(group2);
+                       setOutputDimensions(group2);
+                       
+                       plusmult = new PlusMult(group1, 
getInput().get(1).constructLops(), 
+                                       group2, _op, 
getDataType(),getValueType(), et );        
+               }
+               
                setOutputDimensions(plusmult);
                setLineNumbers(plusmult);
                setLops(plusmult);
        }
+       
        @Override
        public String getOpString() {
                String s = new String("");

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 816b55a..9ef2c05 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -25,8 +25,6 @@ import java.util.List;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
-import org.apache.sysml.api.DMLScript;
-import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.hops.AggBinaryOp;
 import org.apache.sysml.hops.AggUnaryOp;
 import org.apache.sysml.hops.BinaryOp;
@@ -34,7 +32,6 @@ import org.apache.sysml.hops.DataGenOp;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.Hop.OpOp1;
 import org.apache.sysml.hops.IndexingOp;
-import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.hops.TernaryOp;
 import org.apache.sysml.hops.UnaryOp;
 import org.apache.sysml.hops.Hop.AggOp;
@@ -1920,10 +1917,11 @@ public class RewriteAlgebraicSimplificationStatic 
extends HopRewriteRule
         */
        private Hop fuseBinaryOperationChain(Hop parent, Hop hi, int pos) {
                //pattern: X + lamda*Y -> X +* lambda Y         
-               if( hi instanceof BinaryOp
-                               && (((BinaryOp)hi).getOp()==OpOp2.PLUS || 
((BinaryOp)hi).getOp()==OpOp2.MINUS) 
-                               && 
hi.getInput().get(0).getDataType()==DataType.MATRIX && hi.getInput().get(1) 
instanceof BinaryOp 
-                               && (DMLScript.rtplatform == 
RUNTIME_PLATFORM.SINGLE_NODE || OptimizerUtils.isSparkExecutionMode()) )
+               if( hi instanceof BinaryOp 
+                       && (((BinaryOp)hi).getOp()==OpOp2.PLUS || 
((BinaryOp)hi).getOp()==OpOp2.MINUS) 
+                       && hi.getInput().get(0).getDataType()==DataType.MATRIX 
+                       && hi.getInput().get(1) instanceof BinaryOp 
+                       && ((BinaryOp)hi.getInput().get(1)).getOp()==OpOp2.MULT 
)
                {
                        //Check that the inner binary Op is a product of Scalar 
times Matrix or viceversa
                        Hop innerBinaryOp =  hi.getInput().get(1);

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/lops/PlusMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/PlusMult.java 
b/src/main/java/org/apache/sysml/lops/PlusMult.java
index 65e6440..8ee8625 100644
--- a/src/main/java/org/apache/sysml/lops/PlusMult.java
+++ b/src/main/java/org/apache/sysml/lops/PlusMult.java
@@ -34,9 +34,9 @@ public class PlusMult extends Lop
 {
        
        private void init(Lop input1, Lop input2, Lop input3, ExecType et) {
-               this.addInput(input1);
-               this.addInput(input2);
-               this.addInput(input3);
+               addInput(input1);
+               addInput(input2);
+               addInput(input3);
                input1.addOutput(this); 
                input2.addOutput(this); 
                input3.addOutput(this); 
@@ -47,7 +47,13 @@ public class PlusMult extends Lop
                
                if ( et == ExecType.CP ||  et == ExecType.SPARK ){
                        lps.addCompatibility(JobType.INVALID);
-                       this.lps.setProperties( inputs, et, 
ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob );
+                       lps.setProperties( inputs, et, 
ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob );
+               }
+               else if( et == ExecType.MR ) {
+                       lps.addCompatibility(JobType.GMR);
+                       lps.addCompatibility(JobType.DATAGEN);
+                       lps.addCompatibility(JobType.REBLOCK);
+                       lps.setProperties( inputs, et, ExecLocation.Reduce, 
breaksAlignment, aligner, definesMRJob );
                }
        }
        
@@ -60,13 +66,15 @@ public class PlusMult extends Lop
 
        @Override
        public String toString() {
-
                return "Operation = PlusMult";
        }
        
+       public String getOpString() { // opcode written by both getInstructions variants: "+*" for Lop.Type.PlusMult, otherwise "-*"
+               return (type==Lop.Type.PlusMult) ? "+*" : "-*";
+       }
        
        /**
-        * Function to generate CP Sum of a matrix with another matrix 
multiplied by Scalar.
+        * Function to generate CP/Spark axpy.
         * 
         * input1: matrix1
         * input2: Scalar
@@ -75,23 +83,51 @@ public class PlusMult extends Lop
        @Override
        public String getInstructions(String input1, String input2, String 
input3, String output) {
                StringBuilder sb = new StringBuilder();
+               
                sb.append( getExecType() );
                sb.append( OPERAND_DELIMITOR );
-               if(type==Lop.Type.PlusMult)
-                       sb.append( "+*" );
-               else
-                       sb.append( "-*" );
+               
+               sb.append(getOpString());
                sb.append( OPERAND_DELIMITOR );
                
                // Matrix1
                sb.append( getInputs().get(0).prepInputOperand(input1) );
                sb.append( OPERAND_DELIMITOR );
                
-               // Matrix2
+               // Scalar
                sb.append( getInputs().get(1).prepScalarInputOperand(input2) );
                sb.append( OPERAND_DELIMITOR );
                
+               // Matrix2
+               sb.append( getInputs().get(2).prepInputOperand(input3));
+               sb.append( OPERAND_DELIMITOR );
+               
+               sb.append( prepOutputOperand(output));
+               
+               return sb.toString();
+       }
+       
+       @Override
+       public String getInstructions(int input1, int input2, int input3, int 
output) 
+               throws LopsException 
+       {
+               StringBuilder sb = new StringBuilder();
+               
+               sb.append( getExecType() );
+               sb.append( OPERAND_DELIMITOR );
+               
+               sb.append(getOpString());
+               sb.append( OPERAND_DELIMITOR );
+               
+               // Matrix1
+               sb.append( getInputs().get(0).prepInputOperand(input1) );
+               sb.append( OPERAND_DELIMITOR );
+               
                // Scalar
+               sb.append( getInputs().get(1).prepScalarLabel() );
+               sb.append( OPERAND_DELIMITOR );
+               
+               // Matrix2
                sb.append( getInputs().get(2).prepInputOperand(input3));
                sb.append( OPERAND_DELIMITOR );
                

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java 
b/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java
index ee7a8fb..2036cf6 100644
--- a/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java
+++ b/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java
@@ -23,21 +23,25 @@ import java.io.Serializable;
 
 public class MinusMultiply extends ValueFunctionWithConstant implements 
Serializable
 {
-
        private static final long serialVersionUID = 2801982061205871665L;
        
-       public MinusMultiply() {
+       private MinusMultiply() {
                // nothing to do here
        }
+
+       public static MinusMultiply getMinusMultiplyFnObject() {
+               //create new object as the constant is modified and hence 
+               //cannot be shared across multiple threads (e.g., in parfor)
+               return new MinusMultiply();
+       }
+       // NOTE(review): no longer a singleton (the factory returns a fresh object per call); the "singleton" comment in clone() is outdated, but cloning stays unsupported
        public Object clone() throws CloneNotSupportedException {
                // cloning is not supported for singleton classes
                throw new CloneNotSupportedException();
        }
+       
        @Override
-       public double execute(double in1, double in2)
-       {
-               return in1 - _constant*in2;
-               
+       public double execute(double in1, double in2) {
+               return in1 - _constant*in2;     
        }
-       
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java 
b/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java
index 87eb47b..2a1eea0 100644
--- a/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java
+++ b/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java
@@ -23,21 +23,25 @@ import java.io.Serializable;
 
 public class PlusMultiply extends ValueFunctionWithConstant implements 
Serializable
 {
-
        private static final long serialVersionUID = 2801982061205871665L;
        
-       public PlusMultiply() {
+       private PlusMultiply() {
                // nothing to do here
        }
+
+       public static PlusMultiply getPlusMultiplyFnObject() {
+               //create new object as the constant is modified and hence 
+               //cannot be shared across multiple threads (e.g., in parfor)
+               return new PlusMultiply();
+       }
+       // NOTE(review): no longer a singleton (the factory returns a fresh object per call); the "singleton" comment in clone() is outdated, but cloning stays unsupported
        public Object clone() throws CloneNotSupportedException {
                // cloning is not supported for singleton classes
                throw new CloneNotSupportedException();
        }
+       
        @Override
-       public double execute(double in1, double in2)
-       {
-               return in1 + _constant*in2;
-               
+       public double execute(double in1, double in2) {
+               return in1 + _constant*in2;     
        }
-       
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/runtime/functionobjects/ValueFunctionWithConstant.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/functionobjects/ValueFunctionWithConstant.java
 
b/src/main/java/org/apache/sysml/runtime/functionobjects/ValueFunctionWithConstant.java
index 2820875..f23c29a 100644
--- 
a/src/main/java/org/apache/sysml/runtime/functionobjects/ValueFunctionWithConstant.java
+++ 
b/src/main/java/org/apache/sysml/runtime/functionobjects/ValueFunctionWithConstant.java
@@ -26,13 +26,11 @@ public abstract class ValueFunctionWithConstant extends 
ValueFunction implements
        private static final long serialVersionUID = -4985988545393861058L;
        protected double _constant;
        
-       public void setConstant(double constant)
-       {
+       public void setConstant(double constant) { // mutates per-object state, so fn objects carrying a constant must not be shared across threads
                _constant = constant;
        }
        
-       public double getConstant()
-       {
+       public double getConstant() { // value last stored via setConstant
                return _constant;
        }
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java 
b/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java
index d2f477d..a3a7c08 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java
@@ -56,6 +56,7 @@ import 
org.apache.sysml.runtime.functionobjects.LessThanEquals;
 import org.apache.sysml.runtime.functionobjects.Mean;
 import org.apache.sysml.runtime.functionobjects.Minus;
 import org.apache.sysml.runtime.functionobjects.Minus1Multiply;
+import org.apache.sysml.runtime.functionobjects.MinusMultiply;
 import org.apache.sysml.runtime.functionobjects.MinusNz;
 import org.apache.sysml.runtime.functionobjects.Modulus;
 import org.apache.sysml.runtime.functionobjects.Multiply;
@@ -63,6 +64,7 @@ import org.apache.sysml.runtime.functionobjects.Multiply2;
 import org.apache.sysml.runtime.functionobjects.NotEquals;
 import org.apache.sysml.runtime.functionobjects.Or;
 import org.apache.sysml.runtime.functionobjects.Plus;
+import org.apache.sysml.runtime.functionobjects.PlusMultiply;
 import org.apache.sysml.runtime.functionobjects.Power;
 import org.apache.sysml.runtime.functionobjects.Power2;
 import org.apache.sysml.runtime.functionobjects.ReduceAll;
@@ -626,6 +628,10 @@ public class InstructionUtils
                        return new 
BinaryOperator(Builtin.getBuiltinFnObject("max"));
                else if ( opcode.equalsIgnoreCase("min") ) 
                        return new 
BinaryOperator(Builtin.getBuiltinFnObject("min"));
+               else if ( opcode.equalsIgnoreCase("+*") )
+                       return new 
BinaryOperator(PlusMultiply.getPlusMultiplyFnObject());
+               else if ( opcode.equalsIgnoreCase("-*") )
+                       return new 
BinaryOperator(MinusMultiply.getMinusMultiplyFnObject());
                
                throw new DMLRuntimeException("Unknown binary opcode " + 
opcode);
        }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java 
b/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java
index 894e7e9..0b9cb7d 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java
@@ -63,6 +63,7 @@ import 
org.apache.sysml.runtime.instructions.mr.MatrixReshapeMRInstruction;
 import org.apache.sysml.runtime.instructions.mr.PMMJMRInstruction;
 import 
org.apache.sysml.runtime.instructions.mr.ParameterizedBuiltinMRInstruction;
 import org.apache.sysml.runtime.instructions.mr.PickByCountInstruction;
+import org.apache.sysml.runtime.instructions.mr.PlusMultInstruction;
 import org.apache.sysml.runtime.instructions.mr.QuaternaryInstruction;
 import org.apache.sysml.runtime.instructions.mr.RandInstruction;
 import org.apache.sysml.runtime.instructions.mr.RangeBasedReIndexInstruction;
@@ -182,6 +183,9 @@ public class MRInstructionParser extends InstructionParser
                String2MRInstructionType.put( "^2"   , 
MRINSTRUCTION_TYPE.ArithmeticBinary); //special ^ case
                String2MRInstructionType.put( "*2"   , 
MRINSTRUCTION_TYPE.ArithmeticBinary); //special * case
                String2MRInstructionType.put( "-nz"  , 
MRINSTRUCTION_TYPE.ArithmeticBinary); //special - case
+               String2MRInstructionType.put( "+*"   , 
MRINSTRUCTION_TYPE.ArithmeticBinary2); 
+               String2MRInstructionType.put( "-*"   , 
MRINSTRUCTION_TYPE.ArithmeticBinary2); 
+               
                String2MRInstructionType.put( "map+"    , 
MRINSTRUCTION_TYPE.ArithmeticBinary);
                String2MRInstructionType.put( "map-"    , 
MRINSTRUCTION_TYPE.ArithmeticBinary);
                String2MRInstructionType.put( "map*"    , 
MRINSTRUCTION_TYPE.ArithmeticBinary);
@@ -333,6 +337,9 @@ public class MRInstructionParser extends InstructionParser
                                }
                        }
                        
+                       case ArithmeticBinary2:
+                               return 
PlusMultInstruction.parseInstruction(str);
+                       
                        case AggregateBinary:
                                return 
AggregateBinaryInstruction.parseInstruction(str);
                                

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java
index 212e0b7..12bc465 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java
@@ -28,13 +28,15 @@ import 
org.apache.sysml.runtime.instructions.InstructionUtils;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
 
-public class PlusMultCPInstruction extends ArithmeticBinaryCPInstruction {
+public class PlusMultCPInstruction extends ArithmeticBinaryCPInstruction 
+{
        public PlusMultCPInstruction(BinaryOperator op, CPOperand in1, 
CPOperand in2, 
                        CPOperand in3, CPOperand out, String opcode, String 
str) 
        {
                super(op, in1, in2, out, opcode, str);
                input3=in3;
        }
+       
        public static PlusMultCPInstruction parseInstruction(String str)
        {
                String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
@@ -43,14 +45,11 @@ public class PlusMultCPInstruction extends 
ArithmeticBinaryCPInstruction {
                CPOperand operand2 = new CPOperand(parts[3]); //put the second 
matrix (parts[3]) in Operand2 to make using Binary matrix operations easier
                CPOperand operand3 = new CPOperand(parts[2]); 
                CPOperand outOperand = new CPOperand(parts[4]);
-               BinaryOperator bOperator = null;
-               if(opcode.equals("+*"))
-                       bOperator = new BinaryOperator(new PlusMultiply());
-               else if (opcode.equals("-*"))
-                       bOperator = new BinaryOperator(new MinusMultiply());
+               BinaryOperator bOperator = new 
BinaryOperator(opcode.equals("+*") ? 
+                               
PlusMultiply.getPlusMultiplyFnObject():MinusMultiply.getMinusMultiplyFnObject());
                return new PlusMultCPInstruction(bOperator,operand1, operand2, 
operand3, outOperand, opcode,str);
-               
        }
+       
        @Override
        public void processInstruction( ExecutionContext ec )
                throws DMLRuntimeException
@@ -60,10 +59,10 @@ public class PlusMultCPInstruction extends 
ArithmeticBinaryCPInstruction {
                //get all the inputs
                MatrixBlock matrix1 = ec.getMatrixInput(input1.getName());
                MatrixBlock matrix2 = ec.getMatrixInput(input2.getName());
-               ScalarObject lambda = ec.getScalarInput(input3.getName(), 
input3.getValueType(), input3.isLiteral()); 
+               ScalarObject scalar = ec.getScalarInput(input3.getName(), 
input3.getValueType(), input3.isLiteral()); 
                
                //execution
-               ((ValueFunctionWithConstant) 
((BinaryOperator)_optr).fn).setConstant(lambda.getDoubleValue());
+               ((ValueFunctionWithConstant) 
((BinaryOperator)_optr).fn).setConstant(scalar.getDoubleValue());
                MatrixBlock out = (MatrixBlock) 
matrix1.binaryOperations((BinaryOperator) _optr, matrix2, new MatrixBlock());
                
                //release the matrices

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/runtime/instructions/mr/MRInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/mr/MRInstruction.java 
b/src/main/java/org/apache/sysml/runtime/instructions/mr/MRInstruction.java
index ea47e96..62762c1 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/mr/MRInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/mr/MRInstruction.java
@@ -31,7 +31,7 @@ import org.apache.sysml.runtime.matrix.operators.Operator;
 public abstract class MRInstruction extends Instruction 
 {
        
-       public enum MRINSTRUCTION_TYPE { INVALID, Append, Aggregate, 
ArithmeticBinary, AggregateBinary, AggregateUnary, 
+       public enum MRINSTRUCTION_TYPE { INVALID, Append, Aggregate, 
ArithmeticBinary, ArithmeticBinary2, AggregateBinary, AggregateUnary, 
                Rand, Seq, CSVReblock, CSVWrite, Transform,
                Reblock, Reorg, Replicate, Unary, CombineBinary, CombineUnary, 
CombineTernary, PickByCount, Partition,
                Ternary, Quaternary, CM_N_COV, Combine, MapGroupedAggregate, 
GroupedAggregate, RangeReIndex, ZeroOut, MMTSJ, PMMJ, MatrixReshape, 
ParameterizedBuiltin, Sort, MapMultChain,

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/runtime/instructions/mr/PlusMultInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/mr/PlusMultInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/mr/PlusMultInstruction.java
new file mode 100644
index 0000000..95ae817
--- /dev/null
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/mr/PlusMultInstruction.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.instructions.mr;
+
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.functionobjects.ValueFunctionWithConstant;
+import org.apache.sysml.runtime.instructions.InstructionUtils;
+import org.apache.sysml.runtime.matrix.data.MatrixValue;
+import org.apache.sysml.runtime.matrix.mapred.CachedValueMap;
+import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
+import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysml.runtime.matrix.operators.Operator;
+
+/** MR instruction for the fused "+*" / "-*" operations (in1 + s*in2 and in1 - s*in2), where the scalar s is parsed and fixed at parse time. */
+public class PlusMultInstruction extends BinaryInstruction 
+{
+       public PlusMultInstruction(Operator op, byte in1, byte in2, byte out, 
String istr) {
+               super(op, in1, in2, out, istr);
+       }
+       
+       /**
+        * Parses an instruction with fields opcode, in1, scalar, in2, out (split via InstructionUtils.getInstructionParts).
+        * @param str instruction string; the opcode is expected to be "+*" or "-*"
+        * @return instruction whose operator already holds the parsed scalar as its constant
+        * @throws DMLRuntimeException e.g., on an unknown opcode (see InstructionUtils.parseBinaryOperator) or, presumably, a wrong field count
+        */
+       public static PlusMultInstruction parseInstruction ( String str ) 
+               throws DMLRuntimeException 
+       {       
+               InstructionUtils.checkNumFields ( str, 4 );
+               
+               String[] parts = InstructionUtils.getInstructionParts ( str );
+               String opcode = parts[0];
+               byte in1 = Byte.parseByte(parts[1]);
+               double scalar = Double.parseDouble(parts[2]); // NOTE(review): assumes the scalar field is numeric here (the PlusMult lop emits prepScalarLabel()); confirm variable labels are substituted before parsing
+               byte in2 = Byte.parseByte(parts[3]);
+               byte out = Byte.parseByte(parts[4]);
+               
+               BinaryOperator bop = 
InstructionUtils.parseBinaryOperator(opcode);
+               ((ValueFunctionWithConstant) bop.fn).setConstant(scalar); // safe to mutate: parseBinaryOperator creates a fresh fn object for "+*"/"-*"
+               return new PlusMultInstruction(bop, in1, in2, out, str);
+       }
+       
+       @Override
+       public void processInstruction(Class<? extends MatrixValue> valueClass, 
CachedValueMap cachedValues, 
+                       IndexedMatrixValue tempValue, IndexedMatrixValue 
zeroInput, int blockRowFactor, int blockColFactor)
+               throws DMLRuntimeException 
+       {
+               //default binary mr instruction execution (custom logic encoded 
in operator)
+               super.processInstruction(valueClass, cachedValues, tempValue, 
zeroInput, blockRowFactor, blockColFactor);
+       }
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java
index 4b73679..c93ed0a 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java
@@ -44,6 +44,7 @@ public class PlusMultSPInstruction extends  ArithmeticBinarySPInstruction
                        throw new DMLRuntimeException("Unknown opcode in PlusMultSPInstruction: " + toString());
                }               
        }
+       
        public static PlusMultSPInstruction parseInstruction(String str) throws DMLRuntimeException
        {
                String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
@@ -52,15 +53,11 @@ public class PlusMultSPInstruction extends  ArithmeticBinarySPInstruction
                CPOperand operand2 = new CPOperand(parts[3]);   //put the second matrix (parts[3]) in Operand2 to make using Binary matrix operations easier
                CPOperand operand3 = new CPOperand(parts[2]);
                CPOperand outOperand = new CPOperand(parts[4]);
-               BinaryOperator bOperator = null;
-               if(opcode.equals("+*"))
-                       bOperator = new BinaryOperator(new PlusMultiply());
-               else if (opcode.equals("-*"))
-                       bOperator = new BinaryOperator(new MinusMultiply());
+               BinaryOperator bOperator = new BinaryOperator(opcode.equals("+*") ? 
+                               PlusMultiply.getPlusMultiplyFnObject():MinusMultiply.getMinusMultiplyFnObject());
                return new PlusMultSPInstruction(bOperator,operand1, operand2, operand3, outOperand, opcode,str);       
        }
        
-       
        @Override
        public void processInstruction(ExecutionContext ec) 
                throws DMLRuntimeException
@@ -74,5 +71,4 @@ public class PlusMultSPInstruction extends  ArithmeticBinarySPInstruction
                super.processMatrixMatrixBinaryInstruction(sec);
        
        }
-
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
index 7fec6b0..890a3b2 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
@@ -46,8 +46,6 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase
        private static final String TEST_DIR = "functions/misc/";
        private static final String TEST_CLASS_DIR = TEST_DIR + RewriteFuseBinaryOpChainTest.class.getSimpleName() + "/";
        
-       //private static final int rows = 1234;
-       //private static final int cols = 567;
        private static final double eps = Math.pow(10, -10);
        
        @Override
@@ -58,44 +56,64 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase
        }
        
        @Test
-       public void testFuseBinaryPlusNoRewrite() {
+       public void testFuseBinaryPlusNoRewriteCP() {
                testFuseBinaryChain( TEST_NAME1, false, ExecType.CP );
        }
        
        
        @Test
-       public void testFuseBinaryPlusRewrite() {
+       public void testFuseBinaryPlusRewriteCP() {
                testFuseBinaryChain( TEST_NAME1, true, ExecType.CP);
        }
        
        @Test
-       public void testFuseBinaryMinusNoRewrite() {
+       public void testFuseBinaryMinusNoRewriteCP() {
                testFuseBinaryChain( TEST_NAME2, false, ExecType.CP );
        }
        
        @Test
-       public void testFuseBinaryMinusRewrite() {
+       public void testFuseBinaryMinusRewriteCP() {
                testFuseBinaryChain( TEST_NAME2, true, ExecType.CP );
        }
        
        @Test
-       public void testSpFuseBinaryPlusNoRewrite() {
+       public void testFuseBinaryPlusNoRewriteSP() {
                testFuseBinaryChain( TEST_NAME1, false, ExecType.SPARK );
        }
        
        @Test
-       public void testSpFuseBinaryPlusRewrite() {
+       public void testFuseBinaryPlusRewriteSP() {
                testFuseBinaryChain( TEST_NAME1, true, ExecType.SPARK );
        }
        
        @Test
-       public void testSpFuseBinaryMinusNoRewrite() {
-               testFuseBinaryChain( TEST_NAME2, false, ExecType.SPARK  );
+       public void testFuseBinaryMinusNoRewriteSP() {
+               testFuseBinaryChain( TEST_NAME2, false, ExecType.SPARK );
        }
        
        @Test
-       public void testSpFuseBinaryMinusRewrite() {
-               testFuseBinaryChain( TEST_NAME2, true, ExecType.SPARK  );
+       public void testFuseBinaryMinusRewriteSP() {
+               testFuseBinaryChain( TEST_NAME2, true, ExecType.SPARK );
+       }
+       
+       @Test
+       public void testFuseBinaryPlusNoRewriteMR() {
+               testFuseBinaryChain( TEST_NAME1, false, ExecType.MR );
+       }
+       
+       @Test
+       public void testFuseBinaryPlusRewriteMR() {
+               testFuseBinaryChain( TEST_NAME1, true, ExecType.MR );
+       }
+       
+       @Test
+       public void testFuseBinaryMinusNoRewriteMR() {
+               testFuseBinaryChain( TEST_NAME2, false, ExecType.MR );
+       }
+       
+       @Test
+       public void testFuseBinaryMinusRewriteMR() {
+               testFuseBinaryChain( TEST_NAME2, true, ExecType.MR );
        }
        
        
@@ -111,7 +129,7 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase
                switch( instType ){
                        case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
                        case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
-                       default: rtplatform = RUNTIME_PLATFORM.SINGLE_NODE; break;
+                       default: rtplatform = RUNTIME_PLATFORM.HYBRID; break;
                }
                
                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
@@ -142,7 +160,7 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase
                        Assert.assertTrue(TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"));
                        
                        //check for applies rewrites
-                       if( rewrites ) {
+                       if( rewrites && instType!=ExecType.MR  ) {
                                String prefix = (instType==ExecType.SPARK) ? Instruction.SP_INST_PREFIX  : "";
                                Assert.assertTrue("Rewrite not applied.",Statistics.getCPHeavyHitterOpCodes()
                                                .contains(testname.equals(TEST_NAME1) ? prefix+"+*" : prefix+"-*" ));

Reply via email to