Repository: systemml
Updated Branches:
  refs/heads/master 0a0a40370 -> d03396f20
[SYSTEMML-1570] Removed unnecessary sel+ operator, improved max/min This patch removes the unnecessary sel+ (select positive values) operator and now compiles expressions like (X>0)*X to max(X,0) instead. For sparse, this required additional improvements to make max and min more efficient for these cases: specifically, we now (1) mark max and min as conditionally sparse-safe according to the given constant, and (2) we preallocate output sparse rows for min and max as well. On a scenario of 20 iterations over a 10K x 10K input matrix, this patch performs as follows compared to the removed sel+ operator: a) dense: 22.9s -> 10.7s b) sparse sp=0.1: 4.0s -> 3.7s (11.1s w/o max/min modifications) Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/3737b9ad Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/3737b9ad Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/3737b9ad Branch: refs/heads/master Commit: 3737b9ad0e71a0b8f772fb75c7b2258932f4d8d4 Parents: 0a0a403 Author: Matthias Boehm <[email protected]> Authored: Tue Jan 16 13:21:49 2018 -0800 Committer: Matthias Boehm <[email protected]> Committed: Tue Jan 16 16:08:09 2018 -0800 ---------------------------------------------------------------------- .../org/apache/sysml/hops/ConvolutionOp.java | 5 +- src/main/java/org/apache/sysml/hops/Hop.java | 4 +- .../java/org/apache/sysml/hops/UnaryOp.java | 16 ++---- .../sysml/hops/codegen/cplan/CNodeUnary.java | 7 +-- .../codegen/opt/PlanSelectionFuseCostBased.java | 9 ++- .../opt/PlanSelectionFuseCostBasedV2.java | 7 +-- .../RewriteAlgebraicSimplificationStatic.java | 59 ++++++-------------- src/main/java/org/apache/sysml/lops/Unary.java | 5 +- .../sysml/runtime/functionobjects/Builtin.java | 18 ++---- .../instructions/CPInstructionParser.java | 1 - .../instructions/GPUInstructionParser.java | 1 - .../instructions/MRInstructionParser.java | 1 - .../instructions/SPInstructionParser.java | 1 
- .../gpu/MatrixBuiltinGPUInstruction.java | 2 - .../runtime/matrix/data/LibMatrixBincell.java | 9 ++- .../matrix/operators/LeftScalarOperator.java | 6 +- .../matrix/operators/RightScalarOperator.java | 6 +- .../runtime/matrix/operators/UnaryOperator.java | 3 +- .../unary/matrix/FullSelectPosTest.java | 10 ++-- 19 files changed, 62 insertions(+), 108 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/hops/ConvolutionOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java index bdd3fa8..4c7525d 100644 --- a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java +++ b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java @@ -21,6 +21,7 @@ package org.apache.sysml.hops; import org.apache.sysml.api.DMLScript; import org.apache.sysml.hops.Hop.MultiThreadedHop; +import org.apache.sysml.hops.rewrite.HopRewriteUtils; import org.apache.sysml.lops.ConvolutionTransform; import org.apache.sysml.lops.ConvolutionTransform.OperationTypes; import org.apache.sysml.lops.Lop; @@ -164,7 +165,9 @@ public class ConvolutionOp extends Hop implements MultiThreadedHop } private static boolean isInputReLU(Hop input) { - return input instanceof UnaryOp && ((UnaryOp) input).getOp() == OpOp1.SELP; + return HopRewriteUtils.isBinary(input, OpOp2.MAX) + && (HopRewriteUtils.isLiteralOfValue(input.getInput().get(0), 0) + || HopRewriteUtils.isLiteralOfValue(input.getInput().get(1), 0)); } private static boolean isInputConv2d(Hop input) { http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/hops/Hop.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java index 
a23008e..ae8c559 100644 --- a/src/main/java/org/apache/sysml/hops/Hop.java +++ b/src/main/java/org/apache/sysml/hops/Hop.java @@ -1056,8 +1056,7 @@ public abstract class Hop implements ParseInfo CUMSUM, CUMPROD, CUMMIN, CUMMAX, //fused ML-specific operators for performance SPROP, //sample proportion: P * (1 - P) - SIGMOID, //sigmoid function: 1 / (1 + exp(-X)) - SELP, //select positive: X * (X>0) + SIGMOID, //sigmoid function: 1 / (1 + exp(-X)) LOG_NZ, //sparse-safe log; ppred(X,0,"!=")*log(X) } @@ -1312,7 +1311,6 @@ public abstract class Hop implements ParseInfo HopsOpOp1LopsU.put(OpOp1.CAST_AS_MATRIX, org.apache.sysml.lops.Unary.OperationTypes.NOTSUPPORTED); HopsOpOp1LopsU.put(OpOp1.SPROP, org.apache.sysml.lops.Unary.OperationTypes.SPROP); HopsOpOp1LopsU.put(OpOp1.SIGMOID, org.apache.sysml.lops.Unary.OperationTypes.SIGMOID); - HopsOpOp1LopsU.put(OpOp1.SELP, org.apache.sysml.lops.Unary.OperationTypes.SELP); HopsOpOp1LopsU.put(OpOp1.LOG_NZ, org.apache.sysml.lops.Unary.OperationTypes.LOG_NZ); HopsOpOp1LopsU.put(OpOp1.CAST_AS_MATRIX, org.apache.sysml.lops.Unary.OperationTypes.CAST_AS_MATRIX); HopsOpOp1LopsU.put(OpOp1.CAST_AS_FRAME, org.apache.sysml.lops.Unary.OperationTypes.CAST_AS_FRAME); http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/hops/UnaryOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/UnaryOp.java b/src/main/java/org/apache/sysml/hops/UnaryOp.java index 2e03a2b..ff1b954 100644 --- a/src/main/java/org/apache/sysml/hops/UnaryOp.java +++ b/src/main/java/org/apache/sysml/hops/UnaryOp.java @@ -107,7 +107,7 @@ public class UnaryOp extends Hop implements MultiThreadedHop || (_op == OpOp1.CAST_AS_FRAME && getInput().get(0).getDataType()==DataType.SCALAR)); if(!isScalar) { switch(_op) { - case SELP:case EXP:case SQRT:case LOG:case ABS: + case EXP:case SQRT:case LOG:case ABS: case ROUND:case FLOOR:case CEIL: case SIN:case COS: case TAN: case 
ASIN:case ACOS:case ATAN: @@ -155,15 +155,7 @@ public class UnaryOp extends Hop implements MultiThreadedHop if( optype == null ) throw new HopsException("Unknown UnaryCP lop type for UnaryOp operation type '"+_op+"'"); - UnaryCP unary1 = null; - if((_op == Hop.OpOp1.NROW || _op == Hop.OpOp1.NCOL || _op == Hop.OpOp1.LENGTH) && - input instanceof UnaryOp && ((UnaryOp) input).getOp() == OpOp1.SELP) { - // Dimensions does not change during sel+ operation. - // This case is helpful to avoid unnecessary sel+ operation for fused maxpooling. - unary1 = new UnaryCP(input.getInput().get(0).constructLops(), optype, getDataType(), getValueType()); - } - else - unary1 = new UnaryCP(input.constructLops(), optype, getDataType(), getValueType()); + UnaryCP unary1 = new UnaryCP(input.constructLops(), optype, getDataType(), getValueType()); setOutputDimensions(unary1); setLineNumbers(unary1); @@ -606,12 +598,12 @@ public class UnaryOp extends Hop implements MultiThreadedHop || _op==OpOp1.ACOS || _op==OpOp1.ASIN || _op==OpOp1.ATAN || _op==OpOp1.COSH || _op==OpOp1.SINH || _op==OpOp1.TANH || _op==OpOp1.SQRT || _op==OpOp1.ROUND - || _op==OpOp1.SPROP || _op==OpOp1.SELP ) //sparsity preserving + || _op==OpOp1.SPROP ) //sparsity preserving { ret = new long[]{mc.getRows(), mc.getCols(), mc.getNonZeros()}; } else - ret = new long[]{mc.getRows(), mc.getCols(), -1}; + ret = new long[]{mc.getRows(), mc.getCols(), -1}; } return ret; http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java index 891bfb9..83ddd28 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java @@ -39,7 +39,7 @@ public class CNodeUnary extends 
CNode EXP, POW2, MULT2, SQRT, LOG, LOG_NZ, ABS, ROUND, CEIL, FLOOR, SIGN, SIN, COS, TAN, ASIN, ACOS, ATAN, SINH, COSH, TANH, - SELP, SPROP, SIGMOID; + SPROP, SIGMOID; public static boolean contains(String value) { for( UnaryType ut : values() ) @@ -135,8 +135,6 @@ public class CNodeUnary extends CNode return " double %TMP% = FastMath.ceil(%IN1%);\n"; case FLOOR: return " double %TMP% = FastMath.floor(%IN1%);\n"; - case SELP: - return " double %TMP% = (%IN1%>0) ? %IN1% : 0;\n"; case SPROP: return " double %TMP% = %IN1% * (1 - %IN1%);\n"; case SIGMOID: @@ -174,7 +172,7 @@ public class CNodeUnary extends CNode public boolean isSparseSafeScalar() { return ArrayUtils.contains(new UnaryType[]{ POW2, MULT2, ABS, ROUND, CEIL, FLOOR, SIGN, - SIN, TAN, SELP, SPROP}, this); + SIN, TAN, SPROP}, this); } } @@ -337,7 +335,6 @@ public class CNodeUnary extends CNode case ROUND: case CEIL: case FLOOR: - case SELP: case SPROP: case SIGMOID: case LOG_NZ: http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java index 521ef61..9cddf57 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java @@ -632,17 +632,16 @@ public class PlanSelectionFuseCostBased extends PlanSelection double costs = 1; if( current instanceof UnaryOp ) { switch( ((UnaryOp)current).getOp() ) { - case ABS: + case ABS: case ROUND: case CEIL: case FLOOR: - case SIGN: - case SELP: costs = 1; break; + case SIGN: costs = 1; break; case SPROP: case SQRT: costs = 2; break; case EXP: costs = 18; break; case SIGMOID: costs = 21; break; - case LOG: + case LOG: case LOG_NZ: costs = 
32; break; case NCOL: case NROW: @@ -684,7 +683,7 @@ public class PlanSelectionFuseCostBased extends PlanSelection case LESS: case LESSEQUAL: case GREATER: - case GREATEREQUAL: + case GREATEREQUAL: case CBIND: case RBIND: costs = 1; break; case INTDIV: costs = 6; break; http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java index 00b1543..68718d3 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java @@ -924,17 +924,16 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection double costs = 1; if( current instanceof UnaryOp ) { switch( ((UnaryOp)current).getOp() ) { - case ABS: + case ABS: case ROUND: case CEIL: case FLOOR: - case SIGN: - case SELP: costs = 1; break; + case SIGN: costs = 1; break; case SPROP: case SQRT: costs = 2; break; case EXP: costs = 18; break; case SIGMOID: costs = 21; break; - case LOG: + case LOG: case LOG_NZ: costs = 32; break; case NCOL: case NROW: http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index 963e578..d72560b 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -1116,8 
+1116,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule { UnaryOp uop = (UnaryOp) hi; //valid unary op if( uop.getOp()==OpOp1.ABS || uop.getOp()==OpOp1.SIGN - || uop.getOp()==OpOp1.SELP || uop.getOp()==OpOp1.CEIL - || uop.getOp()==OpOp1.FLOOR || uop.getOp()==OpOp1.ROUND ) + || uop.getOp()==OpOp1.CEIL || uop.getOp()==OpOp1.FLOOR || uop.getOp()==OpOp1.ROUND ) { //clear link unary-binary Hop input = uop.getInput().get(0); @@ -1293,67 +1292,41 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule { BinaryOp bleft = (BinaryOp)left; Hop left1 = bleft.getInput().get(0); - Hop left2 = bleft.getInput().get(1); + Hop left2 = bleft.getInput().get(1); if( left2 instanceof LiteralOp && - HopRewriteUtils.getDoubleValue((LiteralOp)left2)==0 && - left1 == right && bleft.getOp() == OpOp2.GREATER ) + HopRewriteUtils.getDoubleValue((LiteralOp)left2)==0 && + left1 == right && (bleft.getOp() == OpOp2.GREATER || bleft.getOp() == OpOp2.GREATEREQUAL ) ) { - UnaryOp unary = HopRewriteUtils.createUnary(right, OpOp1.SELP); - HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); + BinaryOp binary = HopRewriteUtils.createBinary(right, new LiteralOp(0), OpOp2.MAX); + HopRewriteUtils.replaceChildReference(parent, bop, binary, pos); HopRewriteUtils.cleanupUnreferenced(bop, left); - hi = unary; + hi = binary; applied = true; - LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-selp1"); + LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-max0a"); } - } + } if( !applied && right instanceof BinaryOp ) //X*(X>0) { BinaryOp bright = (BinaryOp)right; Hop right1 = bright.getInput().get(0); - Hop right2 = bright.getInput().get(1); + Hop right2 = bright.getInput().get(1); if( right2 instanceof LiteralOp && - HopRewriteUtils.getDoubleValue((LiteralOp)right2)==0 && - right1 == left && bright.getOp() == OpOp2.GREATER ) + HopRewriteUtils.getDoubleValue((LiteralOp)right2)==0 && + right1 == left && bright.getOp() == OpOp2.GREATER || bright.getOp() == 
OpOp2.GREATEREQUAL ) { - UnaryOp unary = HopRewriteUtils.createUnary(left, OpOp1.SELP); - HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); + BinaryOp binary = HopRewriteUtils.createBinary(left, new LiteralOp(0), OpOp2.MAX); + HopRewriteUtils.replaceChildReference(parent, bop, binary, pos); HopRewriteUtils.cleanupUnreferenced(bop, left); - hi = unary; + hi = binary; applied= true; - LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-selp2"); + LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-max0b"); } } } - - //select positive (selp) operator; pattern: max(X,0) -> selp+ - if( !applied && bop.getOp() == OpOp2.MAX && left.getDataType()==DataType.MATRIX - && right instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)right)==0 ) - { - UnaryOp unary = HopRewriteUtils.createUnary(left, OpOp1.SELP); - HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); - HopRewriteUtils.cleanupUnreferenced(bop); - hi = unary; - applied = true; - - LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-selp3"); - } - - //select positive (selp) operator; pattern: max(0,X) -> selp+ - if( !applied && bop.getOp() == OpOp2.MAX && right.getDataType()==DataType.MATRIX - && left instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left)==0 ) - { - UnaryOp unary = HopRewriteUtils.createUnary(right, OpOp1.SELP); - HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); - HopRewriteUtils.cleanupUnreferenced(bop); - hi = unary; - applied = true; - - LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-selp4"); - } } return hi; http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/lops/Unary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/Unary.java b/src/main/java/org/apache/sysml/lops/Unary.java index 03096c2..880181c 100644 --- a/src/main/java/org/apache/sysml/lops/Unary.java +++ 
b/src/main/java/org/apache/sysml/lops/Unary.java @@ -43,7 +43,7 @@ public class Unary extends Lop AND, OR, XOR, BW_AND, BW_OR, BW_XOR, BW_SHIFTL, BW_SHIFTR, ROUND, CEIL, FLOOR, MR_IQM, INVERSE, CHOLESKY, CUMSUM, CUMPROD, CUMMIN, CUMMAX, - SPROP, SIGMOID, SELP, SUBTRACT_NZ, LOG_NZ, + SPROP, SIGMOID, SUBTRACT_NZ, LOG_NZ, CAST_AS_MATRIX, CAST_AS_FRAME, NOTSUPPORTED } @@ -313,9 +313,6 @@ public class Unary extends Lop case SIGMOID: return "sigmoid"; - case SELP: - return "sel+"; - case CAST_AS_MATRIX: return UnaryCP.CAST_AS_MATRIX_OPCODE; http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java b/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java index 12da2bd..41cb709 100644 --- a/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java +++ b/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java @@ -51,7 +51,7 @@ public class Builtin extends ValueFunction public enum BuiltinCode { SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS, ATAN, LOG, LOG_NZ, MIN, MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL, LENGTH, ROUND, MAXINDEX, MININDEX, - STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, INVERSE, SPROP, SIGMOID, SELP } + STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, INVERSE, SPROP, SIGMOID } public BuiltinCode bFunc; private static final boolean FASTMATH = true; @@ -95,7 +95,6 @@ public class Builtin extends ValueFunction String2BuiltinCode.put( "inverse", BuiltinCode.INVERSE); String2BuiltinCode.put( "sprop", BuiltinCode.SPROP); String2BuiltinCode.put( "sigmoid", BuiltinCode.SIGMOID); - String2BuiltinCode.put( "sel+", BuiltinCode.SELP); } // We should create one object for every builtin function that we support @@ -104,7 +103,7 @@ public class Builtin extends ValueFunction private static 
Builtin absObj = null, signObj = null, sqrtObj = null, expObj = null, plogpObj = null, printObj = null, printfObj; private static Builtin nrowObj = null, ncolObj = null, lengthObj = null, roundObj = null, ceilObj=null, floorObj=null; private static Builtin inverseObj=null, cumsumObj=null, cumprodObj=null, cumminObj=null, cummaxObj=null; - private static Builtin stopObj = null, spropObj = null, sigmoidObj = null, selpObj = null; + private static Builtin stopObj = null, spropObj = null, sigmoidObj = null; private Builtin(BuiltinCode bf) { bFunc = bf; @@ -113,6 +112,10 @@ public class Builtin extends ValueFunction public BuiltinCode getBuiltinCode() { return bFunc; } + + public static boolean isBuiltinCode(ValueFunction fn, BuiltinCode code) { + return (fn instanceof Builtin && ((Builtin)fn).getBuiltinCode() == code); + } public static boolean isBuiltinFnObject(String str) { return String2BuiltinCode.containsKey(str); @@ -279,11 +282,6 @@ public class Builtin extends ValueFunction if ( sigmoidObj == null ) sigmoidObj = new Builtin(BuiltinCode.SIGMOID); return sigmoidObj; - - case SELP: - if ( selpObj == null ) - selpObj = new Builtin(BuiltinCode.SELP); - return selpObj; default: // Unknown code --> return null @@ -332,10 +330,6 @@ public class Builtin extends ValueFunction //sigmoid: 1/(1+exp(-x)) return FASTMATH ? 1 / (1 + FastMath.exp(-in)) : 1 / (1 + Math.exp(-in)); - case SELP: - //select positive: x*(x>0) - return (in > 0) ? 
in : 0; - default: throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc); } http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java index a7b9d6b..e2bb20a 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java @@ -179,7 +179,6 @@ public class CPInstructionParser extends InstructionParser String2CPInstructionType.put( "cholesky",CPType.Unary); String2CPInstructionType.put( "sprop", CPType.Unary); String2CPInstructionType.put( "sigmoid", CPType.Unary); - String2CPInstructionType.put( "sel+", CPType.Unary); String2CPInstructionType.put( "printf" , CPType.BuiltinNary); String2CPInstructionType.put( "cbind" , CPType.BuiltinNary); http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java index e234d52..138b4f5 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java @@ -81,7 +81,6 @@ public class GPUInstructionParser extends InstructionParser String2GPUInstructionType.put( "-*", GPUINSTRUCTION_TYPE.ArithmeticBinary); // Unary Builtin functions - String2GPUInstructionType.put( "sel+", GPUINSTRUCTION_TYPE.BuiltinUnary); String2GPUInstructionType.put( "exp", 
GPUINSTRUCTION_TYPE.BuiltinUnary); String2GPUInstructionType.put( "log", GPUINSTRUCTION_TYPE.BuiltinUnary); String2GPUInstructionType.put( "abs", GPUINSTRUCTION_TYPE.BuiltinUnary); http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java index 0e25d83..debafe1 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java @@ -153,7 +153,6 @@ public class MRInstructionParser extends InstructionParser String2MRInstructionType.put( "floor", MRType.Unary); String2MRInstructionType.put( "sprop", MRType.Unary); String2MRInstructionType.put( "sigmoid", MRType.Unary); - String2MRInstructionType.put( "sel+", MRType.Unary); String2MRInstructionType.put( "!", MRType.Unary); // Specific UNARY Instruction Opcodes http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java index d625853..1ba7776 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java @@ -245,7 +245,6 @@ public class SPInstructionParser extends InstructionParser String2SPInstructionType.put( "floor" , SPType.Unary); String2SPInstructionType.put( "sprop", SPType.Unary); String2SPInstructionType.put( "sigmoid", SPType.Unary); - String2SPInstructionType.put( 
"sel+", SPType.Unary); // Parameterized Builtin Functions String2SPInstructionType.put( "groupedagg" , SPType.ParameterizedBuiltin); http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixBuiltinGPUInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixBuiltinGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixBuiltinGPUInstruction.java index 04760ee..f54bc80 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixBuiltinGPUInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixBuiltinGPUInstruction.java @@ -44,8 +44,6 @@ public class MatrixBuiltinGPUInstruction extends BuiltinUnaryGPUInstruction { ec.setMetaData(_output.getName(), mat.getNumRows(), mat.getNumColumns()); switch(opcode) { - case "sel+": - LibMatrixCuDNN.relu(ec, ec.getGPUContext(0), getExtendedOpcode(), mat, _output.getName()); break; case "exp": LibMatrixCUDA.exp(ec, ec.getGPUContext(0), getExtendedOpcode(), mat, _output.getName()); break; case "sqrt": http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java index 16c8d2d..3ed12ab 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java @@ -22,6 +22,7 @@ package org.apache.sysml.runtime.matrix.data; import java.util.Arrays; import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.functionobjects.Builtin; import 
org.apache.sysml.runtime.functionobjects.Divide; import org.apache.sysml.runtime.functionobjects.Equals; import org.apache.sysml.runtime.functionobjects.GreaterThan; @@ -38,6 +39,7 @@ import org.apache.sysml.runtime.functionobjects.Plus; import org.apache.sysml.runtime.functionobjects.PlusMultiply; import org.apache.sysml.runtime.functionobjects.Power2; import org.apache.sysml.runtime.functionobjects.ValueFunction; +import org.apache.sysml.runtime.functionobjects.Builtin.BuiltinCode; import org.apache.sysml.runtime.matrix.operators.BinaryOperator; import org.apache.sysml.runtime.matrix.operators.ScalarOperator; import org.apache.sysml.runtime.util.DataConverter; @@ -935,11 +937,12 @@ public class LibMatrixBincell throw new DMLRuntimeException("Unsupported safe binary scalar operations over different input/output representation: "+m1.sparse+" "+ret.sparse); boolean copyOnes = (op.fn instanceof NotEquals && op.getConstant()==0); - boolean allocExact = (op.fn instanceof Multiply - || op.fn instanceof Multiply2 || op.fn instanceof Power2); + boolean allocExact = (op.fn instanceof Multiply || op.fn instanceof Multiply2 + || op.fn instanceof Power2 || Builtin.isBuiltinCode(op.fn, BuiltinCode.MAX) + || Builtin.isBuiltinCode(op.fn, BuiltinCode.MIN)); if( m1.sparse ) //SPARSE <- SPARSE - { + { //allocate sparse row structure ret.allocateSparseRowsBlock(); SparseBlock a = m1.sparseBlock; http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/runtime/matrix/operators/LeftScalarOperator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/operators/LeftScalarOperator.java b/src/main/java/org/apache/sysml/runtime/matrix/operators/LeftScalarOperator.java index b47fe12..64cd914 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/operators/LeftScalarOperator.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/operators/LeftScalarOperator.java @@ 
-21,11 +21,13 @@ package org.apache.sysml.runtime.matrix.operators; import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.functionobjects.Builtin; import org.apache.sysml.runtime.functionobjects.GreaterThan; import org.apache.sysml.runtime.functionobjects.GreaterThanEquals; import org.apache.sysml.runtime.functionobjects.LessThan; import org.apache.sysml.runtime.functionobjects.LessThanEquals; import org.apache.sysml.runtime.functionobjects.ValueFunction; +import org.apache.sysml.runtime.functionobjects.Builtin.BuiltinCode; /** * Scalar operator for scalar-matrix operations with scalar @@ -40,7 +42,9 @@ public class LeftScalarOperator extends ScalarOperator super(p, cst, (p instanceof GreaterThan && cst<=0) || (p instanceof GreaterThanEquals && cst<0) || (p instanceof LessThan && cst>=0) - || (p instanceof LessThanEquals && cst>0)); + || (p instanceof LessThanEquals && cst>0) + || (Builtin.isBuiltinCode(p, BuiltinCode.MAX) && cst<=0) + || (Builtin.isBuiltinCode(p, BuiltinCode.MIN) && cst>=0)); } @Override http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/runtime/matrix/operators/RightScalarOperator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/operators/RightScalarOperator.java b/src/main/java/org/apache/sysml/runtime/matrix/operators/RightScalarOperator.java index 9458270..fbd478e 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/operators/RightScalarOperator.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/operators/RightScalarOperator.java @@ -21,6 +21,7 @@ package org.apache.sysml.runtime.matrix.operators; import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.functionobjects.Builtin; import org.apache.sysml.runtime.functionobjects.Divide; import org.apache.sysml.runtime.functionobjects.GreaterThan; import 
org.apache.sysml.runtime.functionobjects.GreaterThanEquals; @@ -28,6 +29,7 @@ import org.apache.sysml.runtime.functionobjects.LessThan; import org.apache.sysml.runtime.functionobjects.LessThanEquals; import org.apache.sysml.runtime.functionobjects.Power; import org.apache.sysml.runtime.functionobjects.ValueFunction; +import org.apache.sysml.runtime.functionobjects.Builtin.BuiltinCode; /** * Scalar operator for scalar-matrix operations with scalar @@ -43,7 +45,9 @@ public class RightScalarOperator extends ScalarOperator || (p instanceof LessThan && cst<=0) || (p instanceof LessThanEquals && cst<0) || (p instanceof Divide && cst!=0) - || (p instanceof Power && cst!=0)); + || (p instanceof Power && cst!=0) + || (Builtin.isBuiltinCode(p, BuiltinCode.MAX) && cst<=0) + || (Builtin.isBuiltinCode(p, BuiltinCode.MIN) && cst>=0)); } @Override http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java b/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java index 7f48252..8b3888e 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java @@ -41,8 +41,7 @@ public class UnaryOperator extends Operator || ((Builtin)p).bFunc==Builtin.BuiltinCode.SINH || ((Builtin)p).bFunc==Builtin.BuiltinCode.TANH || ((Builtin)p).bFunc==Builtin.BuiltinCode.ROUND || ((Builtin)p).bFunc==Builtin.BuiltinCode.ABS || ((Builtin)p).bFunc==Builtin.BuiltinCode.SQRT || ((Builtin)p).bFunc==Builtin.BuiltinCode.SPROP - || ((Builtin)p).bFunc==Builtin.BuiltinCode.SELP || ((Builtin)p).bFunc==Builtin.BuiltinCode.LOG_NZ - || ((Builtin)p).bFunc==Builtin.BuiltinCode.SIGN) ); + || ((Builtin)p).bFunc==Builtin.BuiltinCode.LOG_NZ || 
((Builtin)p).bFunc==Builtin.BuiltinCode.SIGN) ); fn = p; k = numThreads; } http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullSelectPosTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullSelectPosTest.java b/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullSelectPosTest.java index 6f01d8b..eaf714b 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullSelectPosTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullSelectPosTest.java @@ -199,13 +199,11 @@ public class FullSelectPosTest extends AutomatedTestBase TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); //check generated opcode - if( rewrites ){ - String expected_op = "sel+"; - if ( instType == ExecType.SPARK ) - expected_op = Instruction.SP_INST_PREFIX+"sel+"; - + if( rewrites ) { + String expected_op = (instType == ExecType.SPARK) ? Instruction.SP_INST_PREFIX+"max" : "max"; if(instType == ExecType.CP || instType == ExecType.SPARK) - Assert.assertTrue("Missing opcode: " + expected_op , Statistics.getCPHeavyHitterOpCodes().contains(expected_op) || Statistics.getCPHeavyHitterOpCodes().contains("gpu_sel+")); + Assert.assertTrue("Missing opcode: " + expected_op , Statistics.getCPHeavyHitterOpCodes().contains(expected_op) + || Statistics.getCPHeavyHitterOpCodes().contains("gpu_max")); } } finally
