[SYSTEMML-2055] New ifelse builtin function, ternary instruction framework This patch introduces a vectorized D = ifelse(A, B, C) builtin function whose output is defined cellwise as D[i,j] = (A[i,j] != 0) ? B[i,j] : C[i,j] for matrix inputs; it also supports all combinations of matrix and scalar inputs. Similar to the binary cellwise instruction framework, this patch also introduces a new ternary instruction framework for all backends. A current restriction is that all matrix inputs need to be of equal shape - in the future we will extend this to matrix-vector operations as well. This new instruction framework now also hosts the existing physical operators for plus_mult and minus_mult (axpy).
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ce9e42fe Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ce9e42fe Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ce9e42fe Branch: refs/heads/master Commit: ce9e42fefb6277b2b8c04db65658e955c7444e2d Parents: 5743810 Author: Matthias Boehm <[email protected]> Authored: Wed Jan 17 17:52:50 2018 -0800 Committer: Matthias Boehm <[email protected]> Committed: Thu Jan 18 11:20:08 2018 -0800 ---------------------------------------------------------------------- src/main/java/org/apache/sysml/hops/Hop.java | 11 +- .../java/org/apache/sysml/hops/TernaryOp.java | 105 +++--- .../hops/cost/CostEstimatorStaticRuntime.java | 19 +- .../sysml/hops/rewrite/HopRewriteUtils.java | 8 + .../RewriteAlgebraicSimplificationDynamic.java | 10 +- src/main/java/org/apache/sysml/lops/Ctable.java | 2 +- src/main/java/org/apache/sysml/lops/Lop.java | 2 +- .../java/org/apache/sysml/lops/PlusMult.java | 118 ------- .../java/org/apache/sysml/lops/Ternary.java | 115 ++++++ .../java/org/apache/sysml/lops/compile/Dag.java | 6 +- .../sysml/parser/BuiltinFunctionExpression.java | 42 ++- .../org/apache/sysml/parser/DMLTranslator.java | 7 +- .../org/apache/sysml/parser/Expression.java | 1 + .../sysml/runtime/functionobjects/IfElse.java | 44 +++ .../runtime/functionobjects/MinusMultiply.java | 21 +- .../runtime/functionobjects/PlusMultiply.java | 23 +- .../functionobjects/TernaryValueFunction.java | 32 ++ .../ValueFunctionWithConstant.java | 36 -- .../instructions/CPInstructionParser.java | 30 +- .../runtime/instructions/InstructionUtils.java | 11 +- .../instructions/MRInstructionParser.java | 13 +- .../instructions/SPInstructionParser.java | 20 +- .../runtime/instructions/cp/CPInstruction.java | 2 +- .../runtime/instructions/cp/CPOperand.java | 11 + .../instructions/cp/PlusMultCPInstruction.java | 72 ---- .../instructions/cp/ScalarObjectFactory.java | 10 + 
.../instructions/cp/TernaryCPInstruction.java | 81 +++++ .../runtime/instructions/mr/MRInstruction.java | 2 +- .../instructions/mr/PlusMultInstruction.java | 61 ---- .../instructions/mr/TernaryInstruction.java | 105 ++++++ .../spark/PlusMultSPInstruction.java | 71 ---- .../instructions/spark/SPInstruction.java | 2 +- .../spark/TernarySPInstruction.java | 200 +++++++++++ .../runtime/matrix/MatrixCharacteristics.java | 4 + .../sysml/runtime/matrix/data/MatrixBlock.java | 49 +++ .../matrix/operators/BinaryOperator.java | 3 - .../matrix/operators/TernaryOperator.java | 46 +++ .../functions/ternary/FullIfElseTest.java | 349 +++++++++++++++++++ .../scripts/functions/ternary/TernaryIfElse.R | 45 +++ .../scripts/functions/ternary/TernaryIfElse.dml | 43 +++ .../functions/ternary/ZPackageSuite.java | 1 + 41 files changed, 1348 insertions(+), 485 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/hops/Hop.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java index ae8c559..3ed837e 100644 --- a/src/main/java/org/apache/sysml/hops/Hop.java +++ b/src/main/java/org/apache/sysml/hops/Hop.java @@ -40,6 +40,7 @@ import org.apache.sysml.lops.LopProperties.ExecType; import org.apache.sysml.lops.LopsException; import org.apache.sysml.lops.Nary; import org.apache.sysml.lops.ReBlock; +import org.apache.sysml.lops.Ternary; import org.apache.sysml.lops.Unary; import org.apache.sysml.lops.UnaryCP; import org.apache.sysml.parser.Expression.DataType; @@ -1074,7 +1075,7 @@ public abstract class Hop implements ParseInfo // Operations that require 3 operands public enum OpOp3 { - QUANTILE, INTERQUANTILE, CTABLE, CENTRALMOMENT, COVARIANCE, PLUS_MULT, MINUS_MULT + QUANTILE, INTERQUANTILE, CTABLE, CENTRALMOMENT, COVARIANCE, PLUS_MULT, MINUS_MULT, 
IFELSE } // Operations that require 4 operands @@ -1349,6 +1350,14 @@ public abstract class Hop implements ParseInfo HopsOpOp1LopsUS.put(OpOp1.STOP, org.apache.sysml.lops.UnaryCP.OperationTypes.STOP); } + protected static final HashMap<OpOp3, Ternary.OperationType> HopsOpOp3Lops; + static { + HopsOpOp3Lops = new HashMap<>(); + HopsOpOp3Lops.put(OpOp3.PLUS_MULT, Ternary.OperationType.PLUS_MULT); + HopsOpOp3Lops.put(OpOp3.MINUS_MULT, Ternary.OperationType.MINUS_MULT); + HopsOpOp3Lops.put(OpOp3.IFELSE, Ternary.OperationType.IFELSE); + } + /** * Maps from a multiple (variable number of operands) Hop operation type to * the corresponding Lop operation type. This is called in the MultipleOp http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/hops/TernaryOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/TernaryOp.java b/src/main/java/org/apache/sysml/hops/TernaryOp.java index dc40982..49f67b1 100644 --- a/src/main/java/org/apache/sysml/hops/TernaryOp.java +++ b/src/main/java/org/apache/sysml/hops/TernaryOp.java @@ -31,8 +31,7 @@ import org.apache.sysml.lops.Group; import org.apache.sysml.lops.Lop; import org.apache.sysml.lops.LopsException; import org.apache.sysml.lops.PickByCount; -import org.apache.sysml.lops.PlusMult; -import org.apache.sysml.lops.RepMat; +import org.apache.sysml.lops.Ternary; import org.apache.sysml.lops.SortKeys; import org.apache.sysml.lops.Ctable; import org.apache.sysml.lops.UnaryCP; @@ -136,6 +135,7 @@ public class TernaryOp extends Hop case CTABLE: case INTERQUANTILE: case QUANTILE: + case IFELSE: return false; case MINUS_MULT: case PLUS_MULT: @@ -175,7 +175,8 @@ public class TernaryOp extends Hop case PLUS_MULT: case MINUS_MULT: - constructLopsPlusMult(); + case IFELSE: + constructLopsTernaryDefault(); break; default: @@ -186,10 +187,10 @@ public class TernaryOp extends Hop catch(LopsException e) { throw new 
HopsException(this.printErrorLocation() + "error constructing Lops for TernaryOp Hop " , e); } - + //add reblock/checkpoint lops if necessary constructAndSetLopsDataFlowProperties(); - + return getLops(); } @@ -643,45 +644,49 @@ public class TernaryOp extends Hop } } - private void constructLopsPlusMult() + private void constructLopsTernaryDefault() throws HopsException, LopsException { - if ( _op != OpOp3.PLUS_MULT && _op != OpOp3.MINUS_MULT ) - throw new HopsException("Unexpected operation: " + _op + ", expecting " + OpOp3.PLUS_MULT + " or" + OpOp3.MINUS_MULT); - ExecType et = optFindExecType(); - PlusMult plusmult = null; + if( getInput().stream().allMatch(h -> h.getDataType().isScalar()) ) + et = ExecType.CP; //always CP for pure scalar operations + + Ternary plusmult = null; if( et == ExecType.CP || et == ExecType.SPARK || et == ExecType.GPU ) { - plusmult = new PlusMult( - getInput().get(0).constructLops(), - getInput().get(1).constructLops(), - getInput().get(2).constructLops(), - _op, getDataType(),getValueType(), et ); + plusmult = new Ternary(HopsOpOp3Lops.get(_op), + getInput().get(0).constructLops(), + getInput().get(1).constructLops(), + getInput().get(2).constructLops(), + getDataType(),getValueType(), et ); } else { //MR - Hop left = getInput().get(0); - Hop right = getInput().get(2); - boolean requiresRep = BinaryOp.requiresReplication(left, right); + Hop first = getInput().get(0); + Hop second = getInput().get(1); + Hop third = getInput().get(2); - Lop rightLop = right.constructLops(); - if( requiresRep ) { - Lop offset = createOffsetLop(left, (right.getDim2()<=1)); //ncol of left input (determines num replicates) - rightLop = new RepMat(rightLop, offset, (right.getDim2()<=1), right.getDataType(), right.getValueType()); - setOutputDimensions(rightLop); - setLineNumbers(rightLop); + Lop firstLop = first.constructLops(); + if( first.getDataType().isMatrix() ) { + firstLop = new Group(firstLop, Group.OperationTypes.Sort, getDataType(), 
getValueType()); + setLineNumbers(firstLop); + setOutputDimensions(firstLop); + } + Lop secondLop = second.constructLops(); + if( second.getDataType().isMatrix() ) { + secondLop = new Group(secondLop, Group.OperationTypes.Sort, getDataType(), getValueType()); + setLineNumbers(secondLop); + setOutputDimensions(secondLop); + } + Lop thirdLop = third.constructLops(); + if( third.getDataType().isMatrix() ) { + thirdLop = new Group(thirdLop, Group.OperationTypes.Sort, getDataType(), getValueType()); + setLineNumbers(thirdLop); + setOutputDimensions(thirdLop); } - - Group group1 = new Group(left.constructLops(), Group.OperationTypes.Sort, getDataType(), getValueType()); - setLineNumbers(group1); - setOutputDimensions(group1); - - Group group2 = new Group(rightLop, Group.OperationTypes.Sort, getDataType(), getValueType()); - setLineNumbers(group2); - setOutputDimensions(group2); - plusmult = new PlusMult(group1, getInput().get(1).constructLops(), - group2, _op, getDataType(),getValueType(), et ); + plusmult = new Ternary(HopsOpOp3Lops.get(_op), + firstLop, secondLop, thirdLop, + getDataType(),getValueType(), et ); } setOutputDimensions(plusmult); @@ -697,8 +702,7 @@ public class TernaryOp extends Hop } @Override - public boolean allowsAllExecTypes() - { + public boolean allowsAllExecTypes() { return true; } @@ -722,7 +726,8 @@ public class TernaryOp extends Hop // Output is a vector of length = #of quantiles to be computed, and it is likely to be dense. 
return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, 1.0); case PLUS_MULT: - case MINUS_MULT: { + case MINUS_MULT: + case IFELSE: { if (isGPUEnabled()) { // For the GPU, the input is converted to dense sparsity = 1.0; @@ -805,6 +810,11 @@ public class TernaryOp extends Hop if( mc[2].dimsKnown() ) return new long[]{mc[2].getRows(), 1, mc[2].getRows()}; break; + case IFELSE: + for(MatrixCharacteristics lmc : mc) + if( lmc.dimsKnown() && lmc.getRows() > 0 ) //known matrix + return new long[]{lmc.getRows(), lmc.getCols(), -1}; + break; case PLUS_MULT: case MINUS_MULT: //compute back NNz @@ -827,20 +837,17 @@ public class TernaryOp extends Hop ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? ExecType.SPARK : ExecType.MR; - if( _etypeForced != null ) - { + if( _etypeForced != null ) { _etype = _etypeForced; } else - { + { if ( OptimizerUtils.isMemoryBasedOptLevel() ) { _etype = findExecTypeByMemEstimate(); } else if ( (getInput().get(0).areDimsBelowThreshold() && getInput().get(1).areDimsBelowThreshold() - && getInput().get(2).areDimsBelowThreshold()) - //|| (getInput().get(0).isVector() && getInput().get(1).isVector() && getInput().get(1).isVector() ) - ) + && getInput().get(2).areDimsBelowThreshold()) ) _etype = ExecType.CP; else _etype = REMOTE; @@ -908,17 +915,21 @@ public class TernaryOp extends Hop // This part of the code is executed only when a vector of quantiles are computed // Output is a vector of length = #of quantiles to be computed, and it is likely to be dense. 
// TODO qx1 - break; + break; + //default ternary operations + case IFELSE: case PLUS_MULT: case MINUS_MULT: - setDim1( getInput().get(0)._dim1 ); - setDim2( getInput().get(0)._dim2 ); + if( getDataType() == DataType.MATRIX ) { + setDim1( HopRewriteUtils.getMaxNrowInput(this) ); + setDim2( HopRewriteUtils.getMaxNcolInput(this) ); + } break; default: throw new RuntimeException("Size information for operation (" + _op + ") can not be updated."); } - } + } } @Override http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/hops/cost/CostEstimatorStaticRuntime.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/cost/CostEstimatorStaticRuntime.java b/src/main/java/org/apache/sysml/hops/cost/CostEstimatorStaticRuntime.java index aebfa9d..e6ff4be 100644 --- a/src/main/java/org/apache/sysml/hops/cost/CostEstimatorStaticRuntime.java +++ b/src/main/java/org/apache/sysml/hops/cost/CostEstimatorStaticRuntime.java @@ -54,6 +54,7 @@ import org.apache.sysml.runtime.instructions.mr.MRInstruction; import org.apache.sysml.runtime.instructions.mr.MapMultChainInstruction; import org.apache.sysml.runtime.instructions.mr.PickByCountInstruction; import org.apache.sysml.runtime.instructions.mr.RemoveEmptyMRInstruction; +import org.apache.sysml.runtime.instructions.mr.TernaryInstruction; import org.apache.sysml.runtime.instructions.mr.CtableInstruction; import org.apache.sysml.runtime.instructions.mr.UnaryMRInstructionBase; import org.apache.sysml.runtime.instructions.mr.MRInstruction.MRType; @@ -395,7 +396,7 @@ public class CostEstimatorStaticRuntime extends CostEstimator vs[0] = stats[ binst.input1 ]; vs[1] = stats[ binst.input2 ]; vs[2] = stats[ binst.output ]; - + if( vs[0] == null ) //scalar input, vs[0] = _scalarStats; if( vs[1] == null ) //scalar input, @@ -408,6 +409,19 @@ public class CostEstimatorStaticRuntime extends CostEstimator attr = new 
String[]{rbinst.isRemoveRows()?"0":"1"}; } } + else if( mrinst instanceof TernaryInstruction ) { + TernaryInstruction tinst = (TernaryInstruction) mrinst; + byte[] ix = tinst.getAllIndexes(); + for(int i=0; i<ix.length-1; i++) + vs[0] = stats[ ix[i] ]; + vs[2] = stats[ ix[ix.length-1] ]; + if( vs[0] == null ) //scalar input, + vs[0] = _scalarStats; + if( vs[1] == null ) //scalar input, + vs[1] = _scalarStats; + if( vs[2] == null ) //scalar output + vs[2] = _scalarStats; + } else if( mrinst instanceof CtableInstruction ) { CtableInstruction tinst = (CtableInstruction) mrinst; @@ -884,6 +898,9 @@ public class CostEstimatorStaticRuntime extends CostEstimator else return d3m*d3n; + case Ternary: //opcodes: +*, -*, ifelse + return 2 * d1m * d1n; + case Ctable: //opcodes: ctable if( optype.equals("ctable") ){ if( leftSparse ) http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index a7373fb..d9d9120 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -1343,6 +1343,14 @@ public class HopRewriteUtils || sb instanceof ForStatementBlock); //incl parfor } + public static long getMaxNrowInput(Hop hop) { + return getMaxInputDim(hop, true); + } + + public static long getMaxNcolInput(Hop hop) { + return getMaxInputDim(hop, false); + } + public static long getMaxInputDim(Hop hop, boolean dim1) { return hop.getInput().stream().mapToLong( h -> (dim1?h.getDim1():h.getDim2())).max().orElse(-1); http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java 
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java index e07f97c..d60b60f 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java @@ -2174,7 +2174,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule private static Hop fuseAxpyBinaryOperationChain(Hop parent, Hop hi, int pos) { - //patterns: (a) X + s*Y -> X +* sY, (b) s*Y+X -> X +* sY, (c) X - s*Y -> X -* sY + //patterns: (a) X + s*Y -> X +* sY, (b) s*Y+X -> X +* sY, (c) X - s*Y -> X -* sY if( hi instanceof BinaryOp && !((BinaryOp) hi).isOuterVectorOperator() && (((BinaryOp)hi).getOp()==OpOp2.PLUS || ((BinaryOp)hi).getOp()==OpOp2.MINUS) ) { @@ -2186,6 +2186,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule //pattern (a) X + s*Y -> X +* sY if( bop.getOp() == OpOp2.PLUS && left.getDataType()==DataType.MATRIX && HopRewriteUtils.isScalarMatrixBinaryMult(right) + && HopRewriteUtils.isEqualSize(left, right) && right.getParent().size() == 1 ) //single consumer s*Y { Hop smid = right.getInput().get( (right.getInput().get(0).getDataType()==DataType.SCALAR) ? 
0 : 1); @@ -2197,18 +2198,19 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule //pattern (b) s*Y + X -> X +* sY else if( bop.getOp() == OpOp2.PLUS && right.getDataType()==DataType.MATRIX && HopRewriteUtils.isScalarMatrixBinaryMult(left) - && left.getParent().size() == 1 //single consumer s*Y - && HopRewriteUtils.isEqualSize(left, right)) //correctness matrix-vector + && HopRewriteUtils.isEqualSize(left, right) + && left.getParent().size() == 1 ) //single consumer s*Y { Hop smid = left.getInput().get( (left.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1); Hop mright = left.getInput().get( (left.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0); ternop = (smid instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)smid)==0) ? right : HopRewriteUtils.createTernaryOp(right, smid, mright, OpOp3.PLUS_MULT); - LOG.debug("Applied fuseAxpyBinaryOperationChain2. (line " +hi.getBeginLine()+")"); + LOG.debug("Applied fuseAxpyBinaryOperationChain2. (line " +hi.getBeginLine()+")"); } //pattern (c) X - s*Y -> X -* sY else if( bop.getOp() == OpOp2.MINUS && left.getDataType()==DataType.MATRIX && HopRewriteUtils.isScalarMatrixBinaryMult(right) + && HopRewriteUtils.isEqualSize(left, right) && right.getParent().size() == 1 ) //single consumer s*Y { Hop smid = right.getInput().get( (right.getInput().get(0).getDataType()==DataType.SCALAR) ? 
0 : 1); http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/lops/Ctable.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/Ctable.java b/src/main/java/org/apache/sysml/lops/Ctable.java index 7eec1bb..8aac934 100644 --- a/src/main/java/org/apache/sysml/lops/Ctable.java +++ b/src/main/java/org/apache/sysml/lops/Ctable.java @@ -53,7 +53,7 @@ public class Ctable extends Lop } public Ctable(Lop[] inputLops, OperationTypes op, DataType dt, ValueType vt, boolean ignoreZeros, ExecType et) { - super(Lop.Type.Ternary, dt, vt); + super(Lop.Type.Ctable, dt, vt); init(inputLops, op, et); _ignoreZeros = ignoreZeros; } http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/lops/Lop.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/Lop.java b/src/main/java/org/apache/sysml/lops/Lop.java index e105362..8d79224 100644 --- a/src/main/java/org/apache/sysml/lops/Lop.java +++ b/src/main/java/org/apache/sysml/lops/Lop.java @@ -55,7 +55,7 @@ public abstract class Lop FunctionCallCP, FunctionCallCPSingle, //CP function calls CumulativePartialAggregate, CumulativeSplitAggregate, CumulativeOffsetBinary, //MR cumsum/cumprod/cummin/cummax WeightedSquaredLoss, WeightedSigmoid, WeightedDivMM, WeightedCeMM, WeightedUMM, - SortKeys, PickValues, + SortKeys, PickValues, Ctable, Checkpoint, //Spark persist into storage level PlusMult, MinusMult, //CP SpoofFused, //CP/SP generated fused operator http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/lops/PlusMult.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/PlusMult.java b/src/main/java/org/apache/sysml/lops/PlusMult.java deleted file mode 100644 index 0649f49..0000000 --- 
a/src/main/java/org/apache/sysml/lops/PlusMult.java +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.lops; - -import org.apache.sysml.hops.Hop.OpOp3; -import org.apache.sysml.lops.LopProperties.ExecLocation; -import org.apache.sysml.lops.LopProperties.ExecType; -import org.apache.sysml.lops.compile.JobType; -import org.apache.sysml.parser.Expression.DataType; -import org.apache.sysml.parser.Expression.ValueType; - - -/** - * Lop to perform Sum of a matrix with another matrix multiplied by Scalar. 
- */ -public class PlusMult extends Lop -{ - public PlusMult(Lop input1, Lop input2, Lop input3, OpOp3 op, DataType dt, ValueType vt, ExecType et) { - super(Lop.Type.PlusMult, dt, vt); - if(op == OpOp3.MINUS_MULT) - type=Lop.Type.MinusMult; - init(input1, input2, input3, et); - } - - private void init(Lop input1, Lop input2, Lop input3, ExecType et) { - addInput(input1); - addInput(input2); - addInput(input3); - input1.addOutput(this); - input2.addOutput(this); - input3.addOutput(this); - - boolean breaksAlignment = false; - boolean aligner = false; - boolean definesMRJob = false; - - if ( et == ExecType.CP || et == ExecType.SPARK || et == ExecType.GPU ){ - lps.addCompatibility(JobType.INVALID); - lps.setProperties( inputs, et, ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob ); - } - else if( et == ExecType.MR ) { - lps.addCompatibility(JobType.GMR); - lps.addCompatibility(JobType.DATAGEN); - lps.addCompatibility(JobType.REBLOCK); - lps.setProperties( inputs, et, ExecLocation.Reduce, breaksAlignment, aligner, definesMRJob ); - } - } - - @Override - public String toString() { - return "Operation = PlusMult"; - } - - public String getOpString() { - return (type==Lop.Type.PlusMult) ? "+*" : "-*"; - } - - /** - * Function to generate CP/Spark axpy. 
- * - * input1: matrix1 - * input2: Scalar - * input3: matrix2 - */ - @Override - public String getInstructions(String input1, String input2, String input3, String output) - { - StringBuilder sb = new StringBuilder(); - - sb.append( getExecType() ); - - sb.append( OPERAND_DELIMITOR ); - sb.append( getOpString() ); - - //matrix 1 - sb.append( OPERAND_DELIMITOR ); - sb.append( getInputs().get(0).prepInputOperand(input1) ); - - //scalar - sb.append( OPERAND_DELIMITOR ); - if( getExecType()==ExecType.MR ) - sb.append( getInputs().get(1).prepScalarLabel() ); - else - sb.append( getInputs().get(1).prepScalarInputOperand(input2) ); - - //matrix 2 - sb.append( OPERAND_DELIMITOR ); - sb.append( getInputs().get(2).prepInputOperand(input3) ); - - sb.append( OPERAND_DELIMITOR ); - sb.append( prepOutputOperand(output) ); - - return sb.toString(); - } - - @Override - public String getInstructions(int input1, int input2, int input3, int output) { - return getInstructions(String.valueOf(input1), String.valueOf(input2), - String.valueOf(input3), String.valueOf(output)); - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/lops/Ternary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/Ternary.java b/src/main/java/org/apache/sysml/lops/Ternary.java new file mode 100644 index 0000000..be83cbd --- /dev/null +++ b/src/main/java/org/apache/sysml/lops/Ternary.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.lops; + +import org.apache.sysml.lops.LopProperties.ExecLocation; +import org.apache.sysml.lops.LopProperties.ExecType; +import org.apache.sysml.lops.compile.JobType; +import org.apache.sysml.parser.Expression.DataType; +import org.apache.sysml.parser.Expression.ValueType; + + +/** + * Lop to perform Sum of a matrix with another matrix multiplied by Scalar. + */ +public class Ternary extends Lop +{ + public enum OperationType { + PLUS_MULT, + MINUS_MULT, + IFELSE, + } + + private final OperationType _type; + + public Ternary(OperationType op, Lop input1, Lop input2, Lop input3, DataType dt, ValueType vt, ExecType et) { + super(Lop.Type.Ternary, dt, vt); + _type = op; + init(input1, input2, input3, et); + } + + private void init(Lop input1, Lop input2, Lop input3, ExecType et) { + addInput(input1); + addInput(input2); + addInput(input3); + input1.addOutput(this); + input2.addOutput(this); + input3.addOutput(this); + + boolean breaksAlignment = false; + boolean aligner = false; + boolean definesMRJob = false; + + if ( et == ExecType.CP || et == ExecType.SPARK || et == ExecType.GPU ){ + lps.addCompatibility(JobType.INVALID); + lps.setProperties( inputs, et, ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob ); + } + else if( et == ExecType.MR ) { + lps.addCompatibility(JobType.GMR); + lps.setProperties( inputs, et, ExecLocation.Reduce, breaksAlignment, aligner, definesMRJob ); + } + } + + @Override + public String toString() { + return "Operation = t("+_type.name().toLowerCase()+")"; + } + + public 
String getOpString() { + switch( _type ) { + case PLUS_MULT: return "+*"; + case MINUS_MULT: return "-*"; + case IFELSE: return "ifelse"; + } + return null; + } + + @Override + public String getInstructions(String input1, String input2, String input3, String output) + { + StringBuilder sb = new StringBuilder(); + + sb.append( getExecType() ); + + sb.append( OPERAND_DELIMITOR ); + sb.append( getOpString() ); + + //process three operands + String[] inputs = new String[]{input1, input2, input3}; + for( int i=0; i<3; i++ ) { + sb.append( OPERAND_DELIMITOR ); + if( getExecType()==ExecType.MR && getInputs().get(i).getDataType().isScalar() ) + sb.append( getInputs().get(i).prepScalarLabel() ); + else + sb.append( getInputs().get(i).prepInputOperand(inputs[i]) ); + } + + sb.append( OPERAND_DELIMITOR ); + sb.append( prepOutputOperand(output) ); + + return sb.toString(); + } + + @Override + public String getInstructions(int input1, int input2, int input3, int output) { + return getInstructions(String.valueOf(input1), String.valueOf(input2), + String.valueOf(input3), String.valueOf(output)); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/lops/compile/Dag.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/compile/Dag.java b/src/main/java/org/apache/sysml/lops/compile/Dag.java index 28db4ea..d662124 100644 --- a/src/main/java/org/apache/sysml/lops/compile/Dag.java +++ b/src/main/java/org/apache/sysml/lops/compile/Dag.java @@ -1402,7 +1402,7 @@ public class Dag<N extends Lop> node.getInputs().get(1).getOutputParameters().getLabel(), node.getOutputParameters().getLabel()); } - else if (node.getInputs().size() == 3 || node.getType() == Type.Ternary) { + else if (node.getInputs().size() == 3 || node.getType() == Type.Ctable) { inst_string = node.getInstructions( node.getInputs().get(0).getOutputParameters().getLabel(), 
node.getInputs().get(1).getOutputParameters().getLabel(), @@ -3119,11 +3119,11 @@ public class Dag<N extends Lop> } return output_index; - } else if (inputIndices.size() == 3 || node.getType() == Type.Ternary) { + } else if (inputIndices.size() == 3 || node.getType() == Type.Ctable) { int output_index = start_index[0]; start_index[0]++; - if (node.getType() == Type.Ternary ) { + if (node.getType() == Type.Ctable ) { // in case of CTABLE_TRANSFORM_SCALAR_WEIGHT: inputIndices.get(2) would be -1 otherInstructionsReducer.add(node.getInstructions( inputIndices.get(0), inputIndices.get(1), http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java index e0f4e86..3cb882c 100644 --- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java @@ -548,6 +548,11 @@ public class BuiltinFunctionExpression extends DataIdentifier output.setValueType(ValueType.BOOLEAN); break; + case IFELSE: + checkNumParameters(3); + setTernaryOutputProperties(output, conditional); + break; + case CBIND: case RBIND: //scalar string append (string concatenation with \n) @@ -1289,6 +1294,29 @@ public class BuiltinFunctionExpression extends DataIdentifier output.setBlockDimensions (dims[2], dims[3]); } + private void setTernaryOutputProperties(DataIdentifier output, boolean conditional) + throws LanguageException + { + DataType dt1 = getFirstExpr().getOutput().getDataType(); + DataType dt2 = getSecondExpr().getOutput().getDataType(); + DataType dt3 = getThirdExpr().getOutput().getDataType(); + DataType dtOut = (dt1.isMatrix() || dt2.isMatrix() || dt3.isMatrix()) ? 
+ DataType.MATRIX : DataType.SCALAR; + if( dt1==DataType.MATRIX && dt2==DataType.MATRIX ) + checkMatchingDimensions(getFirstExpr(), getSecondExpr(), false, conditional); + if( dt1==DataType.MATRIX && dt3==DataType.MATRIX ) + checkMatchingDimensions(getFirstExpr(), getThirdExpr(), false, conditional); + if( dt2==DataType.MATRIX && dt3==DataType.MATRIX ) + checkMatchingDimensions(getSecondExpr(), getThirdExpr(), false, conditional); + long[] dims1 = getBinaryMatrixCharacteristics(getFirstExpr(), getSecondExpr()); + long[] dims2 = getBinaryMatrixCharacteristics(getSecondExpr(), getThirdExpr()); + output.setDataType(dtOut); + output.setValueType(dtOut==DataType.MATRIX ? ValueType.DOUBLE : + computeValueType(getSecondExpr(), getThirdExpr(), true)); + output.setDimensions(Math.max(dims1[0], dims2[0]), Math.max(dims1[1], dims2[1])); + output.setBlockDimensions (Math.max(dims1[2], dims2[2]), Math.max(dims1[3], dims2[3])); + } + private void expandArguments() { if ( _args == null ) { @@ -1517,13 +1545,15 @@ public class BuiltinFunctionExpression extends DataIdentifier } } - private void checkMatchingDimensions(Expression expr1, Expression expr2) - throws LanguageException - { + private void checkMatchingDimensions(Expression expr1, Expression expr2) throws LanguageException { checkMatchingDimensions(expr1, expr2, false); } - private void checkMatchingDimensions(Expression expr1, Expression expr2, boolean allowsMV) + private void checkMatchingDimensions(Expression expr1, Expression expr2, boolean allowsMV) throws LanguageException { + checkMatchingDimensions(expr1, expr2, allowsMV, false); + } + + private void checkMatchingDimensions(Expression expr1, Expression expr2, boolean allowsMV, boolean conditional) throws LanguageException { if (expr1 != null && expr2 != null) { @@ -1540,7 +1570,7 @@ public class BuiltinFunctionExpression extends DataIdentifier || (allowsMV && expr1.getOutput().getDim2() != expr2.getOutput().getDim2() && expr2.getOutput().getDim2() != 1) ) { 
raiseValidateError("Mismatch in matrix dimensions of parameters for function " - + this.getOpCode(), false, LanguageErrorCodes.INVALID_PARAMETERS); + + this.getOpCode(), conditional, LanguageErrorCodes.INVALID_PARAMETERS); } } } @@ -1757,6 +1787,8 @@ public class BuiltinFunctionExpression extends DataIdentifier bifop = Expression.BuiltinFunctionOp.BITWISE_SHIFTL; else if ( functionName.equals("bitwShiftR") ) bifop = Expression.BuiltinFunctionOp.BITWISE_SHIFTR; + else if ( functionName.equals("ifelse") ) + bifop = Expression.BuiltinFunctionOp.IFELSE; else return null; http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/parser/DMLTranslator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java index 76ea23d..4422147 100644 --- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java @@ -2881,8 +2881,13 @@ public class DMLTranslator currBuiltinOp=new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp2.MEDIAN, expr, expr2); } - break; + break; + case IFELSE: + currBuiltinOp=new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), + Hop.OpOp3.IFELSE, expr, expr2, expr3); + break; + case SEQ: HashMap<String,Hop> randParams = new HashMap<>(); randParams.put(Statement.SEQ_FROM, expr); http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/parser/Expression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java index 19d4d3e..8263206 100644 --- a/src/main/java/org/apache/sysml/parser/Expression.java +++ b/src/main/java/org/apache/sysml/parser/Expression.java @@ -90,6 +90,7 @@ public abstract 
class Expression implements ParseInfo MAX_POOL, AVG_POOL, MAX_POOL_BACKWARD, EXP, FLOOR, + IFELSE, INTERQUANTILE, INVERSE, IQM, http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/runtime/functionobjects/IfElse.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/IfElse.java b/src/main/java/org/apache/sysml/runtime/functionobjects/IfElse.java new file mode 100644 index 0000000..da6e04e --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/functionobjects/IfElse.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.functionobjects; + +import java.io.Serializable; + +public class IfElse extends TernaryValueFunction implements Serializable +{ + private static final long serialVersionUID = -8660124936856173978L; + + private static IfElse singleObj = null; + + private IfElse() { + // nothing to do here + } + + public static IfElse getFnObject() { + if ( singleObj == null ) + singleObj = new IfElse(); + return singleObj; + } + + @Override + public double execute(double in1, double in2, double in3) { + return (in1 != 0) ? 
in2 : in3; + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java b/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java index 9a98194..794571f 100644 --- a/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java +++ b/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java @@ -21,27 +21,24 @@ package org.apache.sysml.runtime.functionobjects; import java.io.Serializable; -public class MinusMultiply extends ValueFunctionWithConstant implements Serializable +public class MinusMultiply extends TernaryValueFunction implements Serializable { private static final long serialVersionUID = 2801982061205871665L; + private static MinusMultiply singleObj = null; + private MinusMultiply() { // nothing to do here } - public static MinusMultiply getMinusMultiplyFnObject() { - //create new object as the constant is modified and hence - //cannot be shared across multiple threads (e.g., in parfor) - return new MinusMultiply(); - } - - @Override - public double execute(double in1, double in2) { - return in1 - _constant * in2; + public static MinusMultiply getFnObject() { + if ( singleObj == null ) + singleObj = new MinusMultiply(); + return singleObj; } @Override - public double execute(long in1, long in2) { - return in1 - _constant * in2; + public double execute(double in1, double in2, double in3) { + return in1 - in2 * in3; } } http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java 
b/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java index 5f9b6e7..cb821f5 100644 --- a/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java +++ b/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java @@ -21,27 +21,24 @@ package org.apache.sysml.runtime.functionobjects; import java.io.Serializable; -public class PlusMultiply extends ValueFunctionWithConstant implements Serializable +public class PlusMultiply extends TernaryValueFunction implements Serializable { private static final long serialVersionUID = 2801982061205871665L; - + + private static PlusMultiply singleObj = null; + private PlusMultiply() { // nothing to do here } - public static PlusMultiply getPlusMultiplyFnObject() { - //create new object as the constant is modified and hence - //cannot be shared across multiple threads (e.g., in parfor) - return new PlusMultiply(); - } - - @Override - public double execute(double in1, double in2) { - return in1 + _constant * in2; + public static PlusMultiply getFnObject() { + if ( singleObj == null ) + singleObj = new PlusMultiply(); + return singleObj; } @Override - public double execute(long in1, long in2) { - return in1 + _constant * in2; + public double execute(double in1, double in2, double in3) { + return in1 + in2 * in3; } } http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/runtime/functionobjects/TernaryValueFunction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/TernaryValueFunction.java b/src/main/java/org/apache/sysml/runtime/functionobjects/TernaryValueFunction.java new file mode 100644 index 0000000..c317010 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/functionobjects/TernaryValueFunction.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.functionobjects; + +import java.io.Serializable; + +import org.apache.sysml.runtime.DMLRuntimeException; + +public abstract class TernaryValueFunction extends ValueFunction implements Serializable +{ + private static final long serialVersionUID = 4837616587192612216L; + + public abstract double execute ( double in1, double in2, double in3 ) + throws DMLRuntimeException; +} http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/runtime/functionobjects/ValueFunctionWithConstant.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/ValueFunctionWithConstant.java b/src/main/java/org/apache/sysml/runtime/functionobjects/ValueFunctionWithConstant.java deleted file mode 100644 index f23c29a..0000000 --- a/src/main/java/org/apache/sysml/runtime/functionobjects/ValueFunctionWithConstant.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.runtime.functionobjects; - -import java.io.Serializable; - -public abstract class ValueFunctionWithConstant extends ValueFunction implements Serializable -{ - private static final long serialVersionUID = -4985988545393861058L; - protected double _constant; - - public void setConstant(double constant) { - _constant = constant; - } - - public double getConstant() { - return _constant; - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java index 6b2e3d6..4928ac3 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java @@ -52,7 +52,7 @@ import org.apache.sysml.runtime.instructions.cp.MultiReturnBuiltinCPInstruction; import org.apache.sysml.runtime.instructions.cp.MultiReturnParameterizedBuiltinCPInstruction; import org.apache.sysml.runtime.instructions.cp.PMMJCPInstruction; import org.apache.sysml.runtime.instructions.cp.ParameterizedBuiltinCPInstruction; -import 
org.apache.sysml.runtime.instructions.cp.PlusMultCPInstruction; +import org.apache.sysml.runtime.instructions.cp.TernaryCPInstruction; import org.apache.sysml.runtime.instructions.cp.QuantilePickCPInstruction; import org.apache.sysml.runtime.instructions.cp.QuantileSortCPInstruction; import org.apache.sysml.runtime.instructions.cp.QuaternaryCPInstruction; @@ -122,8 +122,6 @@ public class CPInstructionParser extends InstructionParser String2CPInstructionType.put( "^2" , CPType.Binary); //special ^ case String2CPInstructionType.put( "*2" , CPType.Binary); //special * case String2CPInstructionType.put( "-nz" , CPType.Binary); //special - case - String2CPInstructionType.put( "+*" , CPType.Binary); - String2CPInstructionType.put( "-*" , CPType.Binary); // Boolean Instruction Opcodes String2CPInstructionType.put( "&&" , CPType.Binary); @@ -198,6 +196,11 @@ public class CPInstructionParser extends InstructionParser String2CPInstructionType.put( "transformmeta",CPType.ParameterizedBuiltin); String2CPInstructionType.put( "transformencode",CPType.MultiReturnParameterizedBuiltin); + // Ternary Instruction Opcodes + String2CPInstructionType.put( "+*", CPType.Ternary); + String2CPInstructionType.put( "-*", CPType.Ternary); + String2CPInstructionType.put( "ifelse", CPType.Ternary); + // Variable Instruction Opcodes String2CPInstructionType.put( "assignvar" , CPType.Variable); String2CPInstructionType.put( "cpvar" , CPType.Variable); @@ -319,26 +322,25 @@ public class CPInstructionParser extends InstructionParser case AggregateTernary: return AggregateTernaryCPInstruction.parseInstruction(str); - + + case Unary: + return UnaryCPInstruction.parseInstruction(str); + case Binary: - String opcode = InstructionUtils.getOpCode(str); - if( opcode.equals("+*") || opcode.equals("-*") ) - return PlusMultCPInstruction.parseInstruction(str); - else - return BinaryCPInstruction.parseInstruction(str); + return BinaryCPInstruction.parseInstruction(str); - case Ctable: - return 
CtableCPInstruction.parseInstruction(str); + case Ternary: + return TernaryCPInstruction.parseInstruction(str); case Quaternary: return QuaternaryCPInstruction.parseInstruction(str); - - case Unary: - return UnaryCPInstruction.parseInstruction(str); case BuiltinNary: return BuiltinNaryCPInstruction.parseInstruction(str); + case Ctable: + return CtableCPInstruction.parseInstruction(str); + case Reorg: return ReorgCPInstruction.parseInstruction(str); http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java index e6affa0..85a45ff 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java @@ -54,6 +54,7 @@ import org.apache.sysml.runtime.functionobjects.Divide; import org.apache.sysml.runtime.functionobjects.Equals; import org.apache.sysml.runtime.functionobjects.GreaterThan; import org.apache.sysml.runtime.functionobjects.GreaterThanEquals; +import org.apache.sysml.runtime.functionobjects.IfElse; import org.apache.sysml.runtime.functionobjects.IndexFunction; import org.apache.sysml.runtime.functionobjects.IntegerDivide; import org.apache.sysml.runtime.functionobjects.KahanPlus; @@ -95,6 +96,7 @@ import org.apache.sysml.runtime.matrix.operators.LeftScalarOperator; import org.apache.sysml.runtime.matrix.operators.Operator; import org.apache.sysml.runtime.matrix.operators.RightScalarOperator; import org.apache.sysml.runtime.matrix.operators.ScalarOperator; +import org.apache.sysml.runtime.matrix.operators.TernaryOperator; import org.apache.sysml.runtime.matrix.operators.UnaryOperator; @@ -580,14 +582,15 @@ public class InstructionUtils return new 
BinaryOperator(Builtin.getBuiltinFnObject("max")); else if ( opcode.equalsIgnoreCase("min") ) return new BinaryOperator(Builtin.getBuiltinFnObject("min")); - else if ( opcode.equalsIgnoreCase("+*") ) - return new BinaryOperator(PlusMultiply.getPlusMultiplyFnObject()); - else if ( opcode.equalsIgnoreCase("-*") ) - return new BinaryOperator(MinusMultiply.getMinusMultiplyFnObject()); throw new DMLRuntimeException("Unknown binary opcode " + opcode); } + public static TernaryOperator parseTernaryOperator(String opcode) { + return new TernaryOperator(opcode.equals("+*") ? PlusMultiply.getFnObject() : + opcode.equals("-*") ? MinusMultiply.getFnObject() : IfElse.getFnObject()); + } + /** * scalar-matrix operator * http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java index fae93f3..6777d51 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java @@ -64,7 +64,7 @@ import org.apache.sysml.runtime.instructions.mr.MatrixReshapeMRInstruction; import org.apache.sysml.runtime.instructions.mr.PMMJMRInstruction; import org.apache.sysml.runtime.instructions.mr.ParameterizedBuiltinMRInstruction; import org.apache.sysml.runtime.instructions.mr.PickByCountInstruction; -import org.apache.sysml.runtime.instructions.mr.PlusMultInstruction; +import org.apache.sysml.runtime.instructions.mr.TernaryInstruction; import org.apache.sysml.runtime.instructions.mr.QuaternaryInstruction; import org.apache.sysml.runtime.instructions.mr.RandInstruction; import org.apache.sysml.runtime.instructions.mr.RangeBasedReIndexInstruction; @@ -192,8 +192,6 @@ public class 
MRInstructionParser extends InstructionParser String2MRInstructionType.put( "bitwXor", MRType.Binary); String2MRInstructionType.put( "bitwShiftL", MRType.Binary); String2MRInstructionType.put( "bitwShiftR", MRType.Binary); - String2MRInstructionType.put( "+*" , MRType.Binary2); - String2MRInstructionType.put( "-*" , MRType.Binary2); String2MRInstructionType.put( "map+" , MRType.Binary); String2MRInstructionType.put( "map-" , MRType.Binary); @@ -220,6 +218,11 @@ public class MRInstructionParser extends InstructionParser String2MRInstructionType.put( "mapbitwShiftL", MRType.Binary); String2MRInstructionType.put( "mapbitwShiftR", MRType.Binary); + // Ternary Instruction Opcodes + String2MRInstructionType.put( "+*", MRType.Ternary); + String2MRInstructionType.put( "-*", MRType.Ternary); + String2MRInstructionType.put( "ifelse", MRType.Ternary); + String2MRInstructionType.put( "uaggouterchain", MRType.UaggOuterChain); // REORG Instruction Opcodes @@ -353,8 +356,8 @@ public class MRInstructionParser extends InstructionParser } } - case Binary2: - return PlusMultInstruction.parseInstruction(str); + case Ternary: + return TernaryInstruction.parseInstruction(str); case AggregateBinary: return AggregateBinaryInstruction.parseInstruction(str); http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java index 3098c10..a9155c9 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java @@ -63,7 +63,7 @@ import org.apache.sysml.runtime.instructions.spark.MatrixReshapeSPInstruction; import 
org.apache.sysml.runtime.instructions.spark.MultiReturnParameterizedBuiltinSPInstruction; import org.apache.sysml.runtime.instructions.spark.PMapmmSPInstruction; import org.apache.sysml.runtime.instructions.spark.ParameterizedBuiltinSPInstruction; -import org.apache.sysml.runtime.instructions.spark.PlusMultSPInstruction; +import org.apache.sysml.runtime.instructions.spark.TernarySPInstruction; import org.apache.sysml.runtime.instructions.spark.PmmSPInstruction; import org.apache.sysml.runtime.instructions.spark.QuantilePickSPInstruction; import org.apache.sysml.runtime.instructions.spark.QuaternarySPInstruction; @@ -162,8 +162,6 @@ public class SPInstructionParser extends InstructionParser String2SPInstructionType.put( "^" , SPType.Binary); String2SPInstructionType.put( "^2" , SPType.Binary); String2SPInstructionType.put( "*2" , SPType.Binary); - String2SPInstructionType.put( "+*" , SPType.Binary); - String2SPInstructionType.put( "-*" , SPType.Binary); String2SPInstructionType.put( "map+" , SPType.Binary); String2SPInstructionType.put( "map-" , SPType.Binary); String2SPInstructionType.put( "map*" , SPType.Binary); @@ -271,6 +269,11 @@ public class SPInstructionParser extends InstructionParser String2SPInstructionType.put( "ctable", SPType.Ctable); String2SPInstructionType.put( "ctableexpand", SPType.Ctable); + //ternary instruction opcodes + String2SPInstructionType.put( "+*", SPType.Ternary); + String2SPInstructionType.put( "-*", SPType.Ternary); + String2SPInstructionType.put( "ifelse", SPType.Ternary); + //quaternary instruction opcodes String2SPInstructionType.put( WeightedSquaredLoss.OPCODE, SPType.Quaternary); String2SPInstructionType.put( WeightedSquaredLossR.OPCODE, SPType.Quaternary); @@ -374,12 +377,11 @@ public class SPInstructionParser extends InstructionParser return ReorgSPInstruction.parseInstruction(str); case Binary: - String opcode = InstructionUtils.getOpCode(str); - if( opcode.equals("+*") || opcode.equals("-*") ) - return 
PlusMultSPInstruction.parseInstruction(str); - else - return BinarySPInstruction.parseInstruction(str); - + return BinarySPInstruction.parseInstruction(str); + + case Ternary: + return TernarySPInstruction.parseInstruction(str); + //ternary instructions case Ctable: return CtableSPInstruction.parseInstruction(str); http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java index 6b86029..7a2053b 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java @@ -31,7 +31,7 @@ public abstract class CPInstruction extends Instruction { public enum CPType { AggregateUnary, AggregateBinary, AggregateTernary, - Unary, Binary, Ctable, Quaternary, BuiltinNary, + Unary, Binary, Ternary, Quaternary, BuiltinNary, Ctable, MultiReturnParameterizedBuiltin, ParameterizedBuiltin, MultiReturnBuiltin, Builtin, Reorg, Variable, External, Append, Rand, QSort, QPick, MatrixIndexing, MMTSJ, PMMJ, MMChain, MatrixReshape, Partition, Compression, SpoofFused, http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java index c39b617..2f9b64d 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java @@ -65,6 +65,10 @@ public class CPOperand return _dataType; } + public boolean isMatrix() { + 
return _dataType.isMatrix(); + } + public boolean isLiteral() { return _isLiteral; } @@ -87,6 +91,13 @@ public class CPOperand _valueType = ValueType.valueOf(opr[2]); _isLiteral = false; } + else if ( opr.length == 1 ) { + //note: for literals in MR instructions + _name = opr[0]; + _dataType = DataType.SCALAR; + _valueType = ValueType.DOUBLE; + _isLiteral = true; + } else { _name = opr[0]; _valueType = ValueType.valueOf(opr[1]); http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java deleted file mode 100644 index c74786c..0000000 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.sysml.runtime.instructions.cp; - -import org.apache.sysml.runtime.DMLRuntimeException; -import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysml.runtime.functionobjects.MinusMultiply; -import org.apache.sysml.runtime.functionobjects.PlusMultiply; -import org.apache.sysml.runtime.functionobjects.ValueFunctionWithConstant; -import org.apache.sysml.runtime.instructions.InstructionUtils; -import org.apache.sysml.runtime.matrix.data.MatrixBlock; -import org.apache.sysml.runtime.matrix.operators.BinaryOperator; - -public class PlusMultCPInstruction extends ComputationCPInstruction { - - private PlusMultCPInstruction(BinaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, - String opcode, String str) { - super(CPType.Binary, op, in1, in2, in3, out, opcode, str); - } - - public static PlusMultCPInstruction parseInstruction(String str) - { - String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); - String opcode=parts[0]; - CPOperand operand1 = new CPOperand(parts[1]); - CPOperand operand2 = new CPOperand(parts[3]); //put the second matrix (parts[3]) in Operand2 to make using Binary matrix operations easier - CPOperand operand3 = new CPOperand(parts[2]); - CPOperand outOperand = new CPOperand(parts[4]); - BinaryOperator bOperator = new BinaryOperator(opcode.equals("+*") ? 
- PlusMultiply.getPlusMultiplyFnObject():MinusMultiply.getMinusMultiplyFnObject()); - return new PlusMultCPInstruction(bOperator,operand1, operand2, operand3, outOperand, opcode,str); - } - - @Override - public void processInstruction( ExecutionContext ec ) - throws DMLRuntimeException - { - String output_name = output.getName(); - - //get all the inputs - MatrixBlock matrix1 = ec.getMatrixInput(input1.getName(), getExtendedOpcode()); - MatrixBlock matrix2 = ec.getMatrixInput(input2.getName(), getExtendedOpcode()); - ScalarObject scalar = ec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()); - - //execution - ((ValueFunctionWithConstant) ((BinaryOperator)_optr).fn).setConstant(scalar.getDoubleValue()); - MatrixBlock out = (MatrixBlock) matrix1.binaryOperations((BinaryOperator) _optr, matrix2, new MatrixBlock()); - - //release the matrices - ec.releaseMatrixInput(input1.getName(), getExtendedOpcode()); - ec.releaseMatrixInput(input2.getName(), getExtendedOpcode()); - - ec.setMatrixOutput(output_name, out, getExtendedOpcode()); - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/runtime/instructions/cp/ScalarObjectFactory.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ScalarObjectFactory.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ScalarObjectFactory.java index ea2c169..ffad8ff 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ScalarObjectFactory.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ScalarObjectFactory.java @@ -46,6 +46,16 @@ public abstract class ScalarObjectFactory } } + public static ScalarObject createScalarObject(ValueType vt, double value) { + switch( vt ) { + case INT: return new IntObject(UtilFunctions.toLong(value)); + case DOUBLE: return new DoubleObject(value); + case BOOLEAN: return new BooleanObject(value != 0); + case 
STRING: return new StringObject(String.valueOf(value)); + default: throw new RuntimeException("Unsupported scalar value type: "+vt.name()); + } + } + public static ScalarObject createScalarObject(ValueType vt, ScalarObject so) { switch( vt ) { case DOUBLE: return new DoubleObject(so.getDoubleValue()); http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/runtime/instructions/cp/TernaryCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/TernaryCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/TernaryCPInstruction.java new file mode 100644 index 0000000..4f8aeb0 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/TernaryCPInstruction.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.runtime.instructions.cp; + +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.instructions.InstructionUtils; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.matrix.operators.TernaryOperator; + +public class TernaryCPInstruction extends ComputationCPInstruction { + + private TernaryCPInstruction(TernaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) { + super(CPType.Ternary, op, in1, in2, in3, out, opcode, str); + } + + public static TernaryCPInstruction parseInstruction(String str) + { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + String opcode=parts[0]; + CPOperand operand1 = new CPOperand(parts[1]); + CPOperand operand2 = new CPOperand(parts[2]); + CPOperand operand3 = new CPOperand(parts[3]); + CPOperand outOperand = new CPOperand(parts[4]); + TernaryOperator op = InstructionUtils.parseTernaryOperator(opcode); + return new TernaryCPInstruction(op, operand1, operand2, operand3, outOperand, opcode,str); + } + + @Override + public void processInstruction( ExecutionContext ec ) + throws DMLRuntimeException + { + if( input1.isMatrix() || input2.isMatrix() || input3.isMatrix() ) + { + //get all inputs as matrix blocks + MatrixBlock m1 = input1.isMatrix() ? ec.getMatrixInput(input1.getName()) : + new MatrixBlock(ec.getScalarInput(input1).getDoubleValue()); + MatrixBlock m2 = input2.isMatrix() ? ec.getMatrixInput(input2.getName()) : + new MatrixBlock(ec.getScalarInput(input2).getDoubleValue()); + MatrixBlock m3 = input3.isMatrix() ? 
ec.getMatrixInput(input3.getName()) : + new MatrixBlock(ec.getScalarInput(input3).getDoubleValue()); + + //execution + MatrixBlock out = m1.ternaryOperations((TernaryOperator)_optr, m2, m3, new MatrixBlock()); + + //release the inputs and output + if( input1.isMatrix() ) + ec.releaseMatrixInput(input1.getName()); + if( input2.isMatrix() ) + ec.releaseMatrixInput(input2.getName()); + if( input3.isMatrix() ) + ec.releaseMatrixInput(input3.getName()); + ec.setMatrixOutput(output.getName(), out); + } + else { //SCALARS + double value = ((TernaryOperator)_optr).fn.execute( + ec.getScalarInput(input1).getDoubleValue(), + ec.getScalarInput(input2).getDoubleValue(), + ec.getScalarInput(input3).getDoubleValue()); + ec.setScalarOutput(output.getName(), ScalarObjectFactory + .createScalarObject(output.getValueType(), value)); + } + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/runtime/instructions/mr/MRInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/mr/MRInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/mr/MRInstruction.java index 355ccb4..0cc105c 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/mr/MRInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/mr/MRInstruction.java @@ -30,7 +30,7 @@ import org.apache.sysml.runtime.matrix.operators.Operator; public abstract class MRInstruction extends Instruction { public enum MRType { - Append, Aggregate, Binary, Binary2, AggregateBinary, AggregateUnary, Rand, + Append, Aggregate, Binary, Ternary, AggregateBinary, AggregateUnary, Rand, Seq, CSVReblock, CSVWrite, Reblock, Reorg, Replicate, Unary, CombineBinary, CombineUnary, CombineTernary, PickByCount, Partition, Ctable, Quaternary, CM_N_COV, MapGroupedAggregate, GroupedAggregate, RightIndex, ZeroOut, MMTSJ, PMMJ, MatrixReshape, ParameterizedBuiltin, Sort, MapMultChain, 
CumsumAggregate, CumsumSplit, http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/runtime/instructions/mr/PlusMultInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/mr/PlusMultInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/mr/PlusMultInstruction.java deleted file mode 100644 index ff9a426..0000000 --- a/src/main/java/org/apache/sysml/runtime/instructions/mr/PlusMultInstruction.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.sysml.runtime.instructions.mr; - -import org.apache.sysml.runtime.DMLRuntimeException; -import org.apache.sysml.runtime.functionobjects.ValueFunctionWithConstant; -import org.apache.sysml.runtime.instructions.InstructionUtils; -import org.apache.sysml.runtime.matrix.data.MatrixValue; -import org.apache.sysml.runtime.matrix.mapred.CachedValueMap; -import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue; -import org.apache.sysml.runtime.matrix.operators.BinaryOperator; -import org.apache.sysml.runtime.matrix.operators.Operator; - -public class PlusMultInstruction extends BinaryInstruction { - private PlusMultInstruction(Operator op, byte in1, byte in2, byte out, String istr) { - super(MRType.Binary2, op, in1, in2, out, istr); - } - - public static PlusMultInstruction parseInstruction ( String str ) - throws DMLRuntimeException - { - InstructionUtils.checkNumFields ( str, 4 ); - - String[] parts = InstructionUtils.getInstructionParts ( str ); - String opcode = parts[0]; - byte in1 = Byte.parseByte(parts[1]); - double scalar = Double.parseDouble(parts[2]); - byte in2 = Byte.parseByte(parts[3]); - byte out = Byte.parseByte(parts[4]); - - BinaryOperator bop = InstructionUtils.parseBinaryOperator(opcode); - ((ValueFunctionWithConstant) bop.fn).setConstant(scalar); - return new PlusMultInstruction(bop, in1, in2, out, str); - } - - @Override - public void processInstruction(Class<? 
extends MatrixValue> valueClass, CachedValueMap cachedValues, - IndexedMatrixValue tempValue, IndexedMatrixValue zeroInput, int blockRowFactor, int blockColFactor) - throws DMLRuntimeException - { - //default binary mr instruction execution (custom logic encoded in operator) - super.processInstruction(valueClass, cachedValues, tempValue, zeroInput, blockRowFactor, blockColFactor); - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/runtime/instructions/mr/TernaryInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/mr/TernaryInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/mr/TernaryInstruction.java new file mode 100644 index 0000000..d23c826 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/instructions/mr/TernaryInstruction.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.runtime.instructions.mr; + +import java.util.Arrays; + +import org.apache.commons.lang3.ArrayUtils; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.instructions.InstructionUtils; +import org.apache.sysml.runtime.instructions.cp.CPOperand; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.matrix.data.MatrixIndexes; +import org.apache.sysml.runtime.matrix.data.MatrixValue; +import org.apache.sysml.runtime.matrix.mapred.CachedValueMap; +import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue; +import org.apache.sysml.runtime.matrix.operators.Operator; +import org.apache.sysml.runtime.matrix.operators.TernaryOperator; + +public class TernaryInstruction extends MRInstruction { + + private final CPOperand input1, input2, input3, output; + private final byte ixinput1, ixinput2, ixinput3, ixoutput; + private final MatrixBlock m1, m2, m3; + + private TernaryInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String istr) { + super(MRType.Ternary, op, Byte.parseByte(out.getName())); + instString = istr; + input1 = in1; input2 = in2; input3 = in3; output = out; + ixinput1 = input1.isMatrix() ? Byte.parseByte(input1.getName()) : -1; + ixinput2 = input2.isMatrix() ? Byte.parseByte(input2.getName()) : -1; + ixinput3 = input3.isMatrix() ? Byte.parseByte(input3.getName()) : -1; + ixoutput = output.isMatrix() ? Byte.parseByte(output.getName()) : -1; + m1 = input1.isMatrix() ? null :new MatrixBlock(Double.parseDouble(input1.getName())); + m2 = input2.isMatrix() ? null :new MatrixBlock(Double.parseDouble(input2.getName())); + m3 = input3.isMatrix() ? 
null :new MatrixBlock(Double.parseDouble(input3.getName())); + } + + public static TernaryInstruction parseInstruction ( String str ) + throws DMLRuntimeException + { + InstructionUtils.checkNumFields ( str, 4 ); + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + String opcode = parts[0]; + CPOperand in1 = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + CPOperand in3 = new CPOperand(parts[3]); + CPOperand out = new CPOperand(parts[4]); + TernaryOperator op = InstructionUtils.parseTernaryOperator(opcode); + return new TernaryInstruction(op, in1, in2, in3, out, str); + } + + @Override + public void processInstruction(Class<? extends MatrixValue> valueClass, CachedValueMap cachedValues, + IndexedMatrixValue tempValue, IndexedMatrixValue zeroInput, int blockRowFactor, int blockColFactor) + throws DMLRuntimeException + { + MatrixBlock lm1 = input1.isMatrix() ? (MatrixBlock) cachedValues.getFirst(ixinput1).getValue() : m1; + MatrixBlock lm2 = input2.isMatrix() ? (MatrixBlock) cachedValues.getFirst(ixinput2).getValue() : m2; + MatrixBlock lm3 = input3.isMatrix() ? (MatrixBlock) cachedValues.getFirst(ixinput3).getValue() : m3; + MatrixIndexes ixin = input1.isMatrix() ? cachedValues.getFirst(ixinput1).getIndexes() : input2.isMatrix() ? 
+ cachedValues.getFirst(ixinput2).getIndexes() : cachedValues.getFirst(ixinput3).getIndexes(); + + //prepare output + IndexedMatrixValue out = new IndexedMatrixValue(new MatrixIndexes(), new MatrixBlock()); + out.getIndexes().setIndexes(ixin); + + //process instruction + TernaryOperator op = (TernaryOperator)optr; + lm1.ternaryOperations(op, lm2, lm3, (MatrixBlock)out.getValue()); + + //put the output value in the cache + cachedValues.add(ixoutput, out); + } + + @Override + public byte[] getInputIndexes() { + byte[] tmp = getAllIndexes(); + return Arrays.copyOfRange(tmp, 0, tmp.length-1); + } + + @Override + public byte[] getAllIndexes() { + return ArrayUtils.toPrimitive( + Arrays.stream(new CPOperand[]{input1, input2, input3, output}) + .filter(in -> in.isMatrix()).map(in -> Byte.parseByte(in.getName())) + .toArray(Byte[]::new)); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java deleted file mode 100644 index e12cd3e..0000000 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.runtime.instructions.spark; - -import org.apache.sysml.runtime.DMLRuntimeException; -import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; -import org.apache.sysml.runtime.functionobjects.MinusMultiply; -import org.apache.sysml.runtime.functionobjects.PlusMultiply; -import org.apache.sysml.runtime.functionobjects.ValueFunctionWithConstant; -import org.apache.sysml.runtime.instructions.InstructionUtils; -import org.apache.sysml.runtime.instructions.cp.CPOperand; -import org.apache.sysml.runtime.instructions.cp.ScalarObject; -import org.apache.sysml.runtime.matrix.operators.BinaryOperator; - -public class PlusMultSPInstruction extends BinarySPInstruction { - private PlusMultSPInstruction(BinaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, - String opcode, String str) throws DMLRuntimeException { - super(SPType.Binary, op, in1, in2, out, opcode, str); - input3 = in3; - - // sanity check opcodes - if (!(opcode.equalsIgnoreCase("+*") || opcode.equalsIgnoreCase("-*"))) { - throw new DMLRuntimeException("Unknown opcode in PlusMultSPInstruction: " + toString()); - } - } - - public static PlusMultSPInstruction parseInstruction(String str) throws DMLRuntimeException - { - String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); - String opcode=parts[0]; - CPOperand operand1 = new CPOperand(parts[1]); - CPOperand operand2 = new CPOperand(parts[3]); //put the second matrix (parts[3]) in 
Operand2 to make using Binary matrix operations easier - CPOperand operand3 = new CPOperand(parts[2]); - CPOperand outOperand = new CPOperand(parts[4]); - BinaryOperator bOperator = new BinaryOperator(opcode.equals("+*") ? - PlusMultiply.getPlusMultiplyFnObject():MinusMultiply.getMinusMultiplyFnObject()); - return new PlusMultSPInstruction(bOperator,operand1, operand2, operand3, outOperand, opcode,str); - } - - @Override - public void processInstruction(ExecutionContext ec) - throws DMLRuntimeException - { - SparkExecutionContext sec = (SparkExecutionContext)ec; - - //pass the scalar - ScalarObject constant = (ScalarObject) ec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()); - ((ValueFunctionWithConstant) ((BinaryOperator)_optr).fn).setConstant(constant.getDoubleValue()); - - super.processMatrixMatrixBinaryInstruction(sec); - - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java index ef7158c..e3d8dfd 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java @@ -31,7 +31,7 @@ public abstract class SPInstruction extends Instruction { public enum SPType { MAPMM, MAPMMCHAIN, CPMM, RMM, TSMM, TSMM2, PMM, ZIPMM, PMAPMM, //matrix multiplication instructions - MatrixIndexing, Reorg, Binary, + MatrixIndexing, Reorg, Binary, Ternary, AggregateUnary, AggregateTernary, Reblock, CSVReblock, Builtin, Unary, BuiltinNary, MultiReturnBuiltin, Checkpoint, Compression, Cast, CentralMoment, Covariance, QSort, QPick,
