Repository: incubator-systemml Updated Branches: refs/heads/master c22f239e3 -> b584aecf6
[SYSTEMML-766] Extended axpy compiler/runtime support (mr, hybrid) Incl fix rewrite 'fused binary operation chain' axpy. Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/b584aecf Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/b584aecf Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/b584aecf Branch: refs/heads/master Commit: b584aecf6b3a1eb96ff83b78cc3ad7c7c6d15baa Parents: c22f239 Author: Matthias Boehm <[email protected]> Authored: Mon Jul 18 19:46:55 2016 -0700 Committer: Matthias Boehm <[email protected]> Committed: Tue Jul 19 10:58:49 2016 -0700 ---------------------------------------------------------------------- .../java/org/apache/sysml/hops/BinaryOp.java | 2 +- .../java/org/apache/sysml/hops/TernaryOp.java | 47 ++++++++++++- .../RewriteAlgebraicSimplificationStatic.java | 12 ++-- .../java/org/apache/sysml/lops/PlusMult.java | 58 ++++++++++++---- .../runtime/functionobjects/MinusMultiply.java | 18 +++-- .../runtime/functionobjects/PlusMultiply.java | 18 +++-- .../ValueFunctionWithConstant.java | 6 +- .../runtime/instructions/InstructionUtils.java | 6 ++ .../instructions/MRInstructionParser.java | 7 ++ .../instructions/cp/PlusMultCPInstruction.java | 17 +++-- .../runtime/instructions/mr/MRInstruction.java | 2 +- .../instructions/mr/PlusMultInstruction.java | 69 ++++++++++++++++++++ .../spark/PlusMultSPInstruction.java | 12 ++-- .../misc/RewriteFuseBinaryOpChainTest.java | 46 +++++++++---- 14 files changed, 249 insertions(+), 71 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/hops/BinaryOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/BinaryOp.java b/src/main/java/org/apache/sysml/hops/BinaryOp.java index 65e9232..edc327d 100644 --- a/src/main/java/org/apache/sysml/hops/BinaryOp.java +++ b/src/main/java/org/apache/sysml/hops/BinaryOp.java @@ -1335,7 +1335,7 @@ public class BinaryOp extends Hop * @param right * @return */ - private static boolean requiresReplication( Hop left, Hop right ) + public static boolean requiresReplication( Hop left, Hop right ) { return (!(left.getDim2()>=1 && right.getDim2()>=1) //cols of any input unknown ||(left.getDim2() > 1 && right.getDim2()==1 && left.getDim2()>=left.getColsInBlock() ) //col MV and more than 1 block http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/hops/TernaryOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/TernaryOp.java b/src/main/java/org/apache/sysml/hops/TernaryOp.java index 72e7624..626ad2c 100644 --- a/src/main/java/org/apache/sysml/hops/TernaryOp.java +++ b/src/main/java/org/apache/sysml/hops/TernaryOp.java @@ -31,6 +31,7 @@ import org.apache.sysml.lops.Lop; import org.apache.sysml.lops.LopsException; import org.apache.sysml.lops.PickByCount; import org.apache.sysml.lops.PlusMult; +import org.apache.sysml.lops.RepMat; import org.apache.sysml.lops.SortKeys; import org.apache.sysml.lops.Ternary; import org.apache.sysml.lops.UnaryCP; @@ -627,16 +628,58 @@ public class TernaryOp extends Hop } } } - private void constructLopsPlusMult() throws HopsException, LopsException { + + /** + * + * @throws HopsException + * @throws LopsException + */ + private void constructLopsPlusMult() + throws HopsException, LopsException + { if ( _op != OpOp3.PLUS_MULT && _op != OpOp3.MINUS_MULT ) throw new HopsException("Unexpected operation: " + _op + ", expecting " + OpOp3.PLUS_MULT + " or" + OpOp3.MINUS_MULT); ExecType et = optFindExecType(); - PlusMult plusmult = new PlusMult(getInput().get(0).constructLops(),getInput().get(1).constructLops(),getInput().get(2).constructLops(), _op, getDataType(),getValueType(), et ); + PlusMult plusmult = null; + + if( et == ExecType.CP || et == ExecType.SPARK ) { + plusmult = new PlusMult( + getInput().get(0).constructLops(), + getInput().get(1).constructLops(), + getInput().get(2).constructLops(), + _op, getDataType(),getValueType(), et ); + } + else { //MR + Hop left = getInput().get(0); + Hop right = getInput().get(2); + boolean requiresRep = BinaryOp.requiresReplication(left, right); + + Lop rightLop = right.constructLops(); + if( requiresRep ) { + Lop offset = createOffsetLop(left, (right.getDim2()<=1)); //ncol of left input (determines num replicates) + rightLop = new RepMat(rightLop, offset, (right.getDim2()<=1), right.getDataType(), right.getValueType()); + setOutputDimensions(rightLop); + setLineNumbers(rightLop); + } + + Group group1 = new Group(left.constructLops(), Group.OperationTypes.Sort, getDataType(), getValueType()); + setLineNumbers(group1); + setOutputDimensions(group1); + + Group group2 = new Group(rightLop, Group.OperationTypes.Sort, getDataType(), getValueType()); + setLineNumbers(group2); + setOutputDimensions(group2); + + plusmult = new PlusMult(group1, getInput().get(1).constructLops(), + group2, _op, getDataType(),getValueType(), et ); + } + setOutputDimensions(plusmult); setLineNumbers(plusmult); setLops(plusmult); } + @Override public String getOpString() { String s = new String(""); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index 816b55a..9ef2c05 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -25,8 +25,6 @@ import java.util.List; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.apache.sysml.api.DMLScript; -import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; import org.apache.sysml.hops.AggBinaryOp; import org.apache.sysml.hops.AggUnaryOp; import org.apache.sysml.hops.BinaryOp; @@ -34,7 +32,6 @@ import org.apache.sysml.hops.DataGenOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.Hop.OpOp1; import org.apache.sysml.hops.IndexingOp; -import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.hops.TernaryOp; import org.apache.sysml.hops.UnaryOp; import org.apache.sysml.hops.Hop.AggOp; @@ -1920,10 +1917,11 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule */ private Hop fuseBinaryOperationChain(Hop parent, Hop hi, int pos) { //pattern: X + lamda*Y -> X +* lambda Y - if( hi instanceof BinaryOp - && (((BinaryOp)hi).getOp()==OpOp2.PLUS || ((BinaryOp)hi).getOp()==OpOp2.MINUS) - && hi.getInput().get(0).getDataType()==DataType.MATRIX && hi.getInput().get(1) instanceof BinaryOp - && (DMLScript.rtplatform == RUNTIME_PLATFORM.SINGLE_NODE || OptimizerUtils.isSparkExecutionMode()) ) + if( hi instanceof BinaryOp + && (((BinaryOp)hi).getOp()==OpOp2.PLUS || ((BinaryOp)hi).getOp()==OpOp2.MINUS) + && hi.getInput().get(0).getDataType()==DataType.MATRIX + && hi.getInput().get(1) instanceof BinaryOp + && ((BinaryOp)hi.getInput().get(1)).getOp()==OpOp2.MULT ) { //Check that the inner binary Op is a product of Scalar times Matrix or viceversa Hop innerBinaryOp = hi.getInput().get(1); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/lops/PlusMult.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/PlusMult.java b/src/main/java/org/apache/sysml/lops/PlusMult.java index 65e6440..8ee8625 100644 --- a/src/main/java/org/apache/sysml/lops/PlusMult.java +++ b/src/main/java/org/apache/sysml/lops/PlusMult.java @@ -34,9 +34,9 @@ public class PlusMult extends Lop { private void init(Lop input1, Lop input2, Lop input3, ExecType et) { - this.addInput(input1); - this.addInput(input2); - this.addInput(input3); + addInput(input1); + addInput(input2); + addInput(input3); input1.addOutput(this); input2.addOutput(this); input3.addOutput(this); @@ -47,7 +47,13 @@ public class PlusMult extends Lop if ( et == ExecType.CP || et == ExecType.SPARK ){ lps.addCompatibility(JobType.INVALID); - this.lps.setProperties( inputs, et, ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob ); + lps.setProperties( inputs, et, ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob ); + } + else if( et == ExecType.MR ) { + lps.addCompatibility(JobType.GMR); + lps.addCompatibility(JobType.DATAGEN); + lps.addCompatibility(JobType.REBLOCK); + lps.setProperties( inputs, et, ExecLocation.Reduce, breaksAlignment, aligner, definesMRJob ); } } @@ -60,13 +66,15 @@ public class PlusMult extends Lop @Override public String toString() { - return "Operation = PlusMult"; } + public String getOpString() { + return (type==Lop.Type.PlusMult) ? "+*" : "-*"; + } /** - * Function to generate CP Sum of a matrix with another matrix multiplied by Scalar. + * Function to generate CP/Spark axpy. * * input1: matrix1 * input2: Scalar @@ -75,23 +83,51 @@ public class PlusMult extends Lop @Override public String getInstructions(String input1, String input2, String input3, String output) { StringBuilder sb = new StringBuilder(); + sb.append( getExecType() ); sb.append( OPERAND_DELIMITOR ); - if(type==Lop.Type.PlusMult) - sb.append( "+*" ); - else - sb.append( "-*" ); + + sb.append(getOpString()); sb.append( OPERAND_DELIMITOR ); // Matrix1 sb.append( getInputs().get(0).prepInputOperand(input1) ); sb.append( OPERAND_DELIMITOR ); - // Matrix2 + // Scalar sb.append( getInputs().get(1).prepScalarInputOperand(input2) ); sb.append( OPERAND_DELIMITOR ); + // Matrix2 + sb.append( getInputs().get(2).prepInputOperand(input3)); + sb.append( OPERAND_DELIMITOR ); + + sb.append( prepOutputOperand(output)); + + return sb.toString(); + } + + @Override + public String getInstructions(int input1, int input2, int input3, int output) + throws LopsException + { + StringBuilder sb = new StringBuilder(); + + sb.append( getExecType() ); + sb.append( OPERAND_DELIMITOR ); + + sb.append(getOpString()); + sb.append( OPERAND_DELIMITOR ); + + // Matrix1 + sb.append( getInputs().get(0).prepInputOperand(input1) ); + sb.append( OPERAND_DELIMITOR ); + // Scalar + sb.append( getInputs().get(1).prepScalarLabel() ); + sb.append( OPERAND_DELIMITOR ); + + // Matrix2 sb.append( getInputs().get(2).prepInputOperand(input3)); sb.append( OPERAND_DELIMITOR ); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java b/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java index ee7a8fb..2036cf6 100644 --- a/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java +++ b/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java @@ -23,21 +23,25 @@ import java.io.Serializable; public class MinusMultiply extends ValueFunctionWithConstant implements Serializable { - private static final long serialVersionUID = 2801982061205871665L; - public MinusMultiply() { + private MinusMultiply() { // nothing to do here } + + public static MinusMultiply getMinusMultiplyFnObject() { + //create new object as the constant is modified and hence + //cannot be shared across multiple threads (e.g., in parfor) + return new MinusMultiply(); + } + public Object clone() throws CloneNotSupportedException { // cloning is not supported for singleton classes throw new CloneNotSupportedException(); } + @Override - public double execute(double in1, double in2) - { - return in1 - _constant*in2; - + public double execute(double in1, double in2) { + return in1 - _constant*in2; } - } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java b/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java index 87eb47b..2a1eea0 100644 --- a/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java +++ b/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java @@ -23,21 +23,25 @@ import java.io.Serializable; public class PlusMultiply extends ValueFunctionWithConstant implements Serializable { - private static final long serialVersionUID = 2801982061205871665L; - public PlusMultiply() { + private PlusMultiply() { // nothing to do here } + + public static PlusMultiply getPlusMultiplyFnObject() { + //create new object as the constant is modified and hence + //cannot be shared across multiple threads (e.g., in parfor) + return new PlusMultiply(); + } + public Object clone() throws CloneNotSupportedException { // cloning is not supported for singleton classes throw new CloneNotSupportedException(); } + @Override - public double execute(double in1, double in2) - { - return in1 + _constant*in2; - + public double execute(double in1, double in2) { + return in1 + _constant*in2; } - } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/runtime/functionobjects/ValueFunctionWithConstant.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/ValueFunctionWithConstant.java b/src/main/java/org/apache/sysml/runtime/functionobjects/ValueFunctionWithConstant.java index 2820875..f23c29a 100644 --- a/src/main/java/org/apache/sysml/runtime/functionobjects/ValueFunctionWithConstant.java +++ b/src/main/java/org/apache/sysml/runtime/functionobjects/ValueFunctionWithConstant.java @@ -26,13 +26,11 @@ public abstract class ValueFunctionWithConstant extends ValueFunction implements private static final long serialVersionUID = -4985988545393861058L; protected double _constant; - public void setConstant(double constant) - { + public void setConstant(double constant) { _constant = constant; } - public double getConstant() - { + public double getConstant() { return _constant; } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java index d2f477d..a3a7c08 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java @@ -56,6 +56,7 @@ import org.apache.sysml.runtime.functionobjects.LessThanEquals; import org.apache.sysml.runtime.functionobjects.Mean; import org.apache.sysml.runtime.functionobjects.Minus; import org.apache.sysml.runtime.functionobjects.Minus1Multiply; +import org.apache.sysml.runtime.functionobjects.MinusMultiply; import org.apache.sysml.runtime.functionobjects.MinusNz; import org.apache.sysml.runtime.functionobjects.Modulus; import org.apache.sysml.runtime.functionobjects.Multiply; @@ -63,6 +64,7 @@ import org.apache.sysml.runtime.functionobjects.Multiply2; import org.apache.sysml.runtime.functionobjects.NotEquals; import org.apache.sysml.runtime.functionobjects.Or; import org.apache.sysml.runtime.functionobjects.Plus; +import org.apache.sysml.runtime.functionobjects.PlusMultiply; import org.apache.sysml.runtime.functionobjects.Power; import org.apache.sysml.runtime.functionobjects.Power2; import org.apache.sysml.runtime.functionobjects.ReduceAll; @@ -626,6 +628,10 @@ public class InstructionUtils return new BinaryOperator(Builtin.getBuiltinFnObject("max")); else if ( opcode.equalsIgnoreCase("min") ) return new BinaryOperator(Builtin.getBuiltinFnObject("min")); + else if ( opcode.equalsIgnoreCase("+*") ) + return new BinaryOperator(PlusMultiply.getPlusMultiplyFnObject()); + else if ( opcode.equalsIgnoreCase("-*") ) + return new BinaryOperator(MinusMultiply.getMinusMultiplyFnObject()); throw new DMLRuntimeException("Unknown binary opcode " + opcode); } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java index 894e7e9..0b9cb7d 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java @@ -63,6 +63,7 @@ import org.apache.sysml.runtime.instructions.mr.MatrixReshapeMRInstruction; import org.apache.sysml.runtime.instructions.mr.PMMJMRInstruction; import org.apache.sysml.runtime.instructions.mr.ParameterizedBuiltinMRInstruction; import org.apache.sysml.runtime.instructions.mr.PickByCountInstruction; +import org.apache.sysml.runtime.instructions.mr.PlusMultInstruction; import org.apache.sysml.runtime.instructions.mr.QuaternaryInstruction; import org.apache.sysml.runtime.instructions.mr.RandInstruction; import org.apache.sysml.runtime.instructions.mr.RangeBasedReIndexInstruction; @@ -182,6 +183,9 @@ public class MRInstructionParser extends InstructionParser String2MRInstructionType.put( "^2" , MRINSTRUCTION_TYPE.ArithmeticBinary); //special ^ case String2MRInstructionType.put( "*2" , MRINSTRUCTION_TYPE.ArithmeticBinary); //special * case String2MRInstructionType.put( "-nz" , MRINSTRUCTION_TYPE.ArithmeticBinary); //special - case + String2MRInstructionType.put( "+*" , MRINSTRUCTION_TYPE.ArithmeticBinary2); + String2MRInstructionType.put( "-*" , MRINSTRUCTION_TYPE.ArithmeticBinary2); + String2MRInstructionType.put( "map+" , MRINSTRUCTION_TYPE.ArithmeticBinary); String2MRInstructionType.put( "map-" , MRINSTRUCTION_TYPE.ArithmeticBinary); String2MRInstructionType.put( "map*" , MRINSTRUCTION_TYPE.ArithmeticBinary); @@ -333,6 +337,9 @@ public class MRInstructionParser extends InstructionParser } } + case ArithmeticBinary2: + return PlusMultInstruction.parseInstruction(str); + case AggregateBinary: return AggregateBinaryInstruction.parseInstruction(str); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java index 212e0b7..12bc465 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java @@ -28,13 +28,15 @@ import org.apache.sysml.runtime.instructions.InstructionUtils; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.operators.BinaryOperator; -public class PlusMultCPInstruction extends ArithmeticBinaryCPInstruction { +public class PlusMultCPInstruction extends ArithmeticBinaryCPInstruction +{ public PlusMultCPInstruction(BinaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) { super(op, in1, in2, out, opcode, str); input3=in3; } + public static PlusMultCPInstruction parseInstruction(String str) { String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); @@ -43,14 +45,11 @@ public class PlusMultCPInstruction extends ArithmeticBinaryCPInstruction { CPOperand operand2 = new CPOperand(parts[3]); //put the second matrix (parts[3]) in Operand2 to make using Binary matrix operations easier CPOperand operand3 = new CPOperand(parts[2]); CPOperand outOperand = new CPOperand(parts[4]); - BinaryOperator bOperator = null; - if(opcode.equals("+*")) - bOperator = new BinaryOperator(new PlusMultiply()); - else if (opcode.equals("-*")) - bOperator = new BinaryOperator(new MinusMultiply()); + BinaryOperator bOperator = new BinaryOperator(opcode.equals("+*") ? + PlusMultiply.getPlusMultiplyFnObject():MinusMultiply.getMinusMultiplyFnObject()); return new PlusMultCPInstruction(bOperator,operand1, operand2, operand3, outOperand, opcode,str); - } + @Override public void processInstruction( ExecutionContext ec ) throws DMLRuntimeException @@ -60,10 +59,10 @@ public class PlusMultCPInstruction extends ArithmeticBinaryCPInstruction { //get all the inputs MatrixBlock matrix1 = ec.getMatrixInput(input1.getName()); MatrixBlock matrix2 = ec.getMatrixInput(input2.getName()); - ScalarObject lambda = ec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()); + ScalarObject scalar = ec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()); //execution - ((ValueFunctionWithConstant) ((BinaryOperator)_optr).fn).setConstant(lambda.getDoubleValue()); + ((ValueFunctionWithConstant) ((BinaryOperator)_optr).fn).setConstant(scalar.getDoubleValue()); MatrixBlock out = (MatrixBlock) matrix1.binaryOperations((BinaryOperator) _optr, matrix2, new MatrixBlock()); //release the matrices http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/runtime/instructions/mr/MRInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/mr/MRInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/mr/MRInstruction.java index ea47e96..62762c1 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/mr/MRInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/mr/MRInstruction.java @@ -31,7 +31,7 @@ import org.apache.sysml.runtime.matrix.operators.Operator; public abstract class MRInstruction extends Instruction { - public enum MRINSTRUCTION_TYPE { INVALID, Append, Aggregate, ArithmeticBinary, AggregateBinary, AggregateUnary, + public enum MRINSTRUCTION_TYPE { INVALID, Append, Aggregate, ArithmeticBinary, ArithmeticBinary2, AggregateBinary, AggregateUnary, Rand, Seq, CSVReblock, CSVWrite, Transform, Reblock, Reorg, Replicate, Unary, CombineBinary, CombineUnary, CombineTernary, PickByCount, Partition, Ternary, Quaternary, CM_N_COV, Combine, MapGroupedAggregate, GroupedAggregate, RangeReIndex, ZeroOut, MMTSJ, PMMJ, MatrixReshape, ParameterizedBuiltin, Sort, MapMultChain, http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/runtime/instructions/mr/PlusMultInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/mr/PlusMultInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/mr/PlusMultInstruction.java new file mode 100644 index 0000000..95ae817 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/instructions/mr/PlusMultInstruction.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.instructions.mr; + +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.functionobjects.ValueFunctionWithConstant; +import org.apache.sysml.runtime.instructions.InstructionUtils; +import org.apache.sysml.runtime.matrix.data.MatrixValue; +import org.apache.sysml.runtime.matrix.mapred.CachedValueMap; +import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue; +import org.apache.sysml.runtime.matrix.operators.BinaryOperator; +import org.apache.sysml.runtime.matrix.operators.Operator; + + +public class PlusMultInstruction extends BinaryInstruction +{ + public PlusMultInstruction(Operator op, byte in1, byte in2, byte out, String istr) { + super(op, in1, in2, out, istr); + } + + /** + * + * @param str + * @return + * @throws DMLRuntimeException + */ + public static PlusMultInstruction parseInstruction ( String str ) + throws DMLRuntimeException + { + InstructionUtils.checkNumFields ( str, 4 ); + + String[] parts = InstructionUtils.getInstructionParts ( str ); + String opcode = parts[0]; + byte in1 = Byte.parseByte(parts[1]); + double scalar = Double.parseDouble(parts[2]); + byte in2 = Byte.parseByte(parts[3]); + byte out = Byte.parseByte(parts[4]); + + BinaryOperator bop = InstructionUtils.parseBinaryOperator(opcode); + ((ValueFunctionWithConstant) bop.fn).setConstant(scalar); + return new PlusMultInstruction(bop, in1, in2, out, str); + } + + @Override + public void processInstruction(Class<? extends MatrixValue> valueClass, CachedValueMap cachedValues, + IndexedMatrixValue tempValue, IndexedMatrixValue zeroInput, int blockRowFactor, int blockColFactor) + throws DMLRuntimeException + { + //default binary mr instruction execution (custom logic encoded in operator) + super.processInstruction(valueClass, cachedValues, tempValue, zeroInput, blockRowFactor, blockColFactor); + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java index 4b73679..c93ed0a 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java @@ -44,6 +44,7 @@ public class PlusMultSPInstruction extends ArithmeticBinarySPInstruction throw new DMLRuntimeException("Unknown opcode in PlusMultSPInstruction: " + toString()); } } + public static PlusMultSPInstruction parseInstruction(String str) throws DMLRuntimeException { String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); @@ -52,15 +53,11 @@ public class PlusMultSPInstruction extends ArithmeticBinarySPInstruction CPOperand operand2 = new CPOperand(parts[3]); //put the second matrix (parts[3]) in Operand2 to make using Binary matrix operations easier CPOperand operand3 = new CPOperand(parts[2]); CPOperand outOperand = new CPOperand(parts[4]); - BinaryOperator bOperator = null; - if(opcode.equals("+*")) - bOperator = new BinaryOperator(new PlusMultiply()); - else if (opcode.equals("-*")) - bOperator = new BinaryOperator(new MinusMultiply()); + BinaryOperator bOperator = new BinaryOperator(opcode.equals("+*") ? + PlusMultiply.getPlusMultiplyFnObject():MinusMultiply.getMinusMultiplyFnObject()); return new PlusMultSPInstruction(bOperator,operand1, operand2, operand3, outOperand, opcode,str); } - @Override public void processInstruction(ExecutionContext ec) throws DMLRuntimeException @@ -74,5 +71,4 @@ public class PlusMultSPInstruction extends ArithmeticBinarySPInstruction super.processMatrixMatrixBinaryInstruction(sec); } - -} \ No newline at end of file +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java index 7fec6b0..890a3b2 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java @@ -46,8 +46,6 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase private static final String TEST_DIR = "functions/misc/"; private static final String TEST_CLASS_DIR = TEST_DIR + RewriteFuseBinaryOpChainTest.class.getSimpleName() + "/"; - //private static final int rows = 1234; - //private static final int cols = 567; private static final double eps = Math.pow(10, -10); @Override @@ -58,44 +56,64 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase } @Test - public void testFuseBinaryPlusNoRewrite() { + public void testFuseBinaryPlusNoRewriteCP() { testFuseBinaryChain( TEST_NAME1, false, ExecType.CP ); } @Test - public void testFuseBinaryPlusRewrite() { + public void testFuseBinaryPlusRewriteCP() { testFuseBinaryChain( TEST_NAME1, true, ExecType.CP); } @Test - public void testFuseBinaryMinusNoRewrite() { + public void testFuseBinaryMinusNoRewriteCP() { testFuseBinaryChain( TEST_NAME2, false, ExecType.CP ); } @Test - public void testFuseBinaryMinusRewrite() { + public void testFuseBinaryMinusRewriteCP() { testFuseBinaryChain( TEST_NAME2, true, ExecType.CP ); } @Test - public void testSpFuseBinaryPlusNoRewrite() { + public void testFuseBinaryPlusNoRewriteSP() { testFuseBinaryChain( TEST_NAME1, false, ExecType.SPARK ); } @Test - public void testSpFuseBinaryPlusRewrite() { + public void testFuseBinaryPlusRewriteSP() { testFuseBinaryChain( TEST_NAME1, true, ExecType.SPARK ); } @Test - public void testSpFuseBinaryMinusNoRewrite() { - testFuseBinaryChain( TEST_NAME2, false, ExecType.SPARK ); + public void testFuseBinaryMinusNoRewriteSP() { + testFuseBinaryChain( TEST_NAME2, false, ExecType.SPARK ); } @Test - public void testSpFuseBinaryMinusRewrite() { - testFuseBinaryChain( TEST_NAME2, true, ExecType.SPARK ); + public void testFuseBinaryMinusRewriteSP() { + testFuseBinaryChain( TEST_NAME2, true, ExecType.SPARK ); + } + + @Test + public void testFuseBinaryPlusNoRewriteMR() { + testFuseBinaryChain( TEST_NAME1, false, ExecType.MR ); + } + + @Test + public void testFuseBinaryPlusRewriteMR() { + testFuseBinaryChain( TEST_NAME1, true, ExecType.MR ); + } + + @Test + public void testFuseBinaryMinusNoRewriteMR() { + testFuseBinaryChain( TEST_NAME2, false, ExecType.MR ); + } + + @Test + public void testFuseBinaryMinusRewriteMR() { + testFuseBinaryChain( TEST_NAME2, true, ExecType.MR ); } @@ -111,7 +129,7 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase switch( instType ){ case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break; case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break; - default: rtplatform = RUNTIME_PLATFORM.SINGLE_NODE; break; + default: rtplatform = RUNTIME_PLATFORM.HYBRID; break; } boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; @@ -142,7 +160,7 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase Assert.assertTrue(TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R")); //check for applies rewrites - if( rewrites ) { + if( rewrites && instType!=ExecType.MR ) { String prefix = (instType==ExecType.SPARK) ? Instruction.SP_INST_PREFIX : ""; Assert.assertTrue("Rewrite not applied.",Statistics.getCPHeavyHitterOpCodes() .contains(testname.equals(TEST_NAME1) ? prefix+"+*" : prefix+"-*" ));
