Repository: systemml
Updated Branches:
  refs/heads/master 0a0a40370 -> d03396f20


[SYSTEMML-1570] Removed unnecessary sel+ operator, improved max/min

This patch removes the unnecessary sel+ (select positive values)
operator and instead compiles expressions like (X>0)*X and (X>=0)*X to
max(X,0). For sparse inputs, this required additional improvements to
make max and min more efficient in these cases: specifically, we now
(1) mark max and min as conditionally sparse-safe depending on the given
scalar constant, and (2) preallocate output sparse rows for min and max
as well.

In a scenario of 20 iterations over a 10K x 10K input matrix, this patch
performs as follows compared to the removed sel+ operator:
a) dense: 22.9s -> 10.7s
b) sparse sp=0.1: 4.0s -> 3.7s (11.1s without the max/min modifications)


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/3737b9ad
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/3737b9ad
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/3737b9ad

Branch: refs/heads/master
Commit: 3737b9ad0e71a0b8f772fb75c7b2258932f4d8d4
Parents: 0a0a403
Author: Matthias Boehm <[email protected]>
Authored: Tue Jan 16 13:21:49 2018 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Tue Jan 16 16:08:09 2018 -0800

----------------------------------------------------------------------
 .../org/apache/sysml/hops/ConvolutionOp.java    |  5 +-
 src/main/java/org/apache/sysml/hops/Hop.java    |  4 +-
 .../java/org/apache/sysml/hops/UnaryOp.java     | 16 ++----
 .../sysml/hops/codegen/cplan/CNodeUnary.java    |  7 +--
 .../codegen/opt/PlanSelectionFuseCostBased.java |  9 ++-
 .../opt/PlanSelectionFuseCostBasedV2.java       |  7 +--
 .../RewriteAlgebraicSimplificationStatic.java   | 59 ++++++--------------
 src/main/java/org/apache/sysml/lops/Unary.java  |  5 +-
 .../sysml/runtime/functionobjects/Builtin.java  | 18 ++----
 .../instructions/CPInstructionParser.java       |  1 -
 .../instructions/GPUInstructionParser.java      |  1 -
 .../instructions/MRInstructionParser.java       |  1 -
 .../instructions/SPInstructionParser.java       |  1 -
 .../gpu/MatrixBuiltinGPUInstruction.java        |  2 -
 .../runtime/matrix/data/LibMatrixBincell.java   |  9 ++-
 .../matrix/operators/LeftScalarOperator.java    |  6 +-
 .../matrix/operators/RightScalarOperator.java   |  6 +-
 .../runtime/matrix/operators/UnaryOperator.java |  3 +-
 .../unary/matrix/FullSelectPosTest.java         | 10 ++--
 19 files changed, 62 insertions(+), 108 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java 
b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
index bdd3fa8..4c7525d 100644
--- a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
+++ b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
@@ -21,6 +21,7 @@ package org.apache.sysml.hops;
 
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.hops.Hop.MultiThreadedHop;
+import org.apache.sysml.hops.rewrite.HopRewriteUtils;
 import org.apache.sysml.lops.ConvolutionTransform;
 import org.apache.sysml.lops.ConvolutionTransform.OperationTypes;
 import org.apache.sysml.lops.Lop;
@@ -164,7 +165,9 @@ public class ConvolutionOp extends Hop  implements 
MultiThreadedHop
        }
        
        private static boolean isInputReLU(Hop input) {
-               return input instanceof UnaryOp && ((UnaryOp) input).getOp() == 
OpOp1.SELP;
+               return HopRewriteUtils.isBinary(input, OpOp2.MAX)
+                       && 
(HopRewriteUtils.isLiteralOfValue(input.getInput().get(0), 0)
+                       || 
HopRewriteUtils.isLiteralOfValue(input.getInput().get(1), 0));
        }
        
        private static boolean isInputConv2d(Hop input) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java 
b/src/main/java/org/apache/sysml/hops/Hop.java
index a23008e..ae8c559 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -1056,8 +1056,7 @@ public abstract class Hop implements ParseInfo
                CUMSUM, CUMPROD, CUMMIN, CUMMAX,
                //fused ML-specific operators for performance 
                SPROP, //sample proportion: P * (1 - P)
-               SIGMOID, //sigmoid function: 1 / (1 + exp(-X)) 
-               SELP, //select positive: X * (X>0)
+               SIGMOID, //sigmoid function: 1 / (1 + exp(-X))
                LOG_NZ, //sparse-safe log; ppred(X,0,"!=")*log(X)
        }
 
@@ -1312,7 +1311,6 @@ public abstract class Hop implements ParseInfo
                HopsOpOp1LopsU.put(OpOp1.CAST_AS_MATRIX, 
org.apache.sysml.lops.Unary.OperationTypes.NOTSUPPORTED);
                HopsOpOp1LopsU.put(OpOp1.SPROP, 
org.apache.sysml.lops.Unary.OperationTypes.SPROP);
                HopsOpOp1LopsU.put(OpOp1.SIGMOID, 
org.apache.sysml.lops.Unary.OperationTypes.SIGMOID);
-               HopsOpOp1LopsU.put(OpOp1.SELP, 
org.apache.sysml.lops.Unary.OperationTypes.SELP);
                HopsOpOp1LopsU.put(OpOp1.LOG_NZ, 
org.apache.sysml.lops.Unary.OperationTypes.LOG_NZ);
                HopsOpOp1LopsU.put(OpOp1.CAST_AS_MATRIX, 
org.apache.sysml.lops.Unary.OperationTypes.CAST_AS_MATRIX);
                HopsOpOp1LopsU.put(OpOp1.CAST_AS_FRAME, 
org.apache.sysml.lops.Unary.OperationTypes.CAST_AS_FRAME);

http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/hops/UnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/UnaryOp.java 
b/src/main/java/org/apache/sysml/hops/UnaryOp.java
index 2e03a2b..ff1b954 100644
--- a/src/main/java/org/apache/sysml/hops/UnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/UnaryOp.java
@@ -107,7 +107,7 @@ public class UnaryOp extends Hop implements MultiThreadedHop
                                || (_op == OpOp1.CAST_AS_FRAME && 
getInput().get(0).getDataType()==DataType.SCALAR));
                if(!isScalar) {
                        switch(_op) {
-                               case SELP:case EXP:case SQRT:case LOG:case ABS:
+                               case EXP:case SQRT:case LOG:case ABS:
                                case ROUND:case FLOOR:case CEIL:
                                case SIN:case COS: case TAN:
                                case ASIN:case ACOS:case ATAN:
@@ -155,15 +155,7 @@ public class UnaryOp extends Hop implements 
MultiThreadedHop
                                        if( optype == null )
                                                throw new 
HopsException("Unknown UnaryCP lop type for UnaryOp operation type '"+_op+"'");
                                        
-                                       UnaryCP unary1 = null;
-                                       if((_op == Hop.OpOp1.NROW || _op == 
Hop.OpOp1.NCOL || _op == Hop.OpOp1.LENGTH) &&
-                                               input instanceof UnaryOp && 
((UnaryOp) input).getOp() == OpOp1.SELP) {
-                                               // Dimensions does not change 
during sel+ operation.
-                                               // This case is helpful to 
avoid unnecessary sel+ operation for fused maxpooling.
-                                               unary1 = new 
UnaryCP(input.getInput().get(0).constructLops(), optype, getDataType(), 
getValueType());
-                                       }
-                                       else
-                                               unary1 = new 
UnaryCP(input.constructLops(), optype, getDataType(), getValueType());
+                                       UnaryCP unary1 = new 
UnaryCP(input.constructLops(), optype, getDataType(), getValueType());
                                        setOutputDimensions(unary1);
                                        setLineNumbers(unary1);
 
@@ -606,12 +598,12 @@ public class UnaryOp extends Hop implements 
MultiThreadedHop
                                || _op==OpOp1.ACOS || _op==OpOp1.ASIN || 
_op==OpOp1.ATAN  
                                || _op==OpOp1.COSH || _op==OpOp1.SINH || 
_op==OpOp1.TANH 
                                || _op==OpOp1.SQRT || _op==OpOp1.ROUND  
-                               || _op==OpOp1.SPROP || _op==OpOp1.SELP ) 
//sparsity preserving
+                               || _op==OpOp1.SPROP ) //sparsity preserving
                        {
                                ret = new long[]{mc.getRows(), mc.getCols(), 
mc.getNonZeros()};
                        }
                        else 
-                               ret = new long[]{mc.getRows(), mc.getCols(), 
-1};       
+                               ret = new long[]{mc.getRows(), mc.getCols(), 
-1};
                }
                
                return ret;

http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
index 891bfb9..83ddd28 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
@@ -39,7 +39,7 @@ public class CNodeUnary extends CNode
                EXP, POW2, MULT2, SQRT, LOG, LOG_NZ,
                ABS, ROUND, CEIL, FLOOR, SIGN, 
                SIN, COS, TAN, ASIN, ACOS, ATAN, SINH, COSH, TANH,
-               SELP, SPROP, SIGMOID; 
+               SPROP, SIGMOID; 
                
                public static boolean contains(String value) {
                        for( UnaryType ut : values()  )
@@ -135,8 +135,6 @@ public class CNodeUnary extends CNode
                                        return "    double %TMP% = 
FastMath.ceil(%IN1%);\n";
                                case FLOOR:
                                        return "    double %TMP% = 
FastMath.floor(%IN1%);\n";
-                               case SELP:
-                                       return "    double %TMP% = (%IN1%>0) ? 
%IN1% : 0;\n";
                                case SPROP:
                                        return "    double %TMP% = %IN1% * (1 - 
%IN1%);\n";
                                case SIGMOID:
@@ -174,7 +172,7 @@ public class CNodeUnary extends CNode
                public boolean isSparseSafeScalar() {
                        return ArrayUtils.contains(new UnaryType[]{
                                POW2, MULT2, ABS, ROUND, CEIL, FLOOR, SIGN, 
-                               SIN, TAN, SELP, SPROP}, this);
+                               SIN, TAN, SPROP}, this);
                }
        }
        
@@ -337,7 +335,6 @@ public class CNodeUnary extends CNode
                        case ROUND:
                        case CEIL:
                        case FLOOR:
-                       case SELP:
                        case SPROP:
                        case SIGMOID:
                        case LOG_NZ:

http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java
 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java
index 521ef61..9cddf57 100644
--- 
a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java
+++ 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java
@@ -632,17 +632,16 @@ public class PlanSelectionFuseCostBased extends 
PlanSelection
                double costs = 1;
                if( current instanceof UnaryOp ) {
                        switch( ((UnaryOp)current).getOp() ) {
-                               case ABS:   
+                               case ABS:
                                case ROUND:
                                case CEIL:
                                case FLOOR:
-                               case SIGN:
-                               case SELP:    costs = 1; break; 
+                               case SIGN:    costs = 1; break; 
                                case SPROP:
                                case SQRT:    costs = 2; break;
                                case EXP:     costs = 18; break;
                                case SIGMOID: costs = 21; break;
-                               case LOG:    
+                               case LOG:
                                case LOG_NZ:  costs = 32; break;
                                case NCOL:
                                case NROW:
@@ -684,7 +683,7 @@ public class PlanSelectionFuseCostBased extends 
PlanSelection
                                case LESS:
                                case LESSEQUAL:
                                case GREATER:
-                               case GREATEREQUAL: 
+                               case GREATEREQUAL:
                                case CBIND:
                                case RBIND:   costs = 1; break;
                                case INTDIV:  costs = 6; break;

http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
index 00b1543..68718d3 100644
--- 
a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
+++ 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
@@ -924,17 +924,16 @@ public class PlanSelectionFuseCostBasedV2 extends 
PlanSelection
                double costs = 1;
                if( current instanceof UnaryOp ) {
                        switch( ((UnaryOp)current).getOp() ) {
-                               case ABS:   
+                               case ABS:
                                case ROUND:
                                case CEIL:
                                case FLOOR:
-                               case SIGN:
-                               case SELP:    costs = 1; break; 
+                               case SIGN:    costs = 1; break; 
                                case SPROP:
                                case SQRT:    costs = 2; break;
                                case EXP:     costs = 18; break;
                                case SIGMOID: costs = 21; break;
-                               case LOG:    
+                               case LOG:
                                case LOG_NZ:  costs = 32; break;
                                case NCOL:
                                case NROW:

http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 963e578..d72560b 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -1116,8 +1116,7 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                {
                        UnaryOp uop = (UnaryOp) hi; //valid unary op
                        if( uop.getOp()==OpOp1.ABS || uop.getOp()==OpOp1.SIGN
-                               || uop.getOp()==OpOp1.SELP || 
uop.getOp()==OpOp1.CEIL
-                               || uop.getOp()==OpOp1.FLOOR || 
uop.getOp()==OpOp1.ROUND )
+                               || uop.getOp()==OpOp1.CEIL || 
uop.getOp()==OpOp1.FLOOR || uop.getOp()==OpOp1.ROUND )
                        {
                                //clear link unary-binary
                                Hop input = uop.getInput().get(0);
@@ -1293,67 +1292,41 @@ public class RewriteAlgebraicSimplificationStatic 
extends HopRewriteRule
                                {
                                        BinaryOp bleft = (BinaryOp)left;
                                        Hop left1 = bleft.getInput().get(0);
-                                       Hop left2 = bleft.getInput().get(1);    
        
+                                       Hop left2 = bleft.getInput().get(1);
                                
                                        if( left2 instanceof LiteralOp &&
-                                               
HopRewriteUtils.getDoubleValue((LiteralOp)left2)==0 &&  
-                                               left1 == right && bleft.getOp() 
== OpOp2.GREATER  ) 
+                                               
HopRewriteUtils.getDoubleValue((LiteralOp)left2)==0 &&
+                                               left1 == right && 
(bleft.getOp() == OpOp2.GREATER || bleft.getOp() == OpOp2.GREATEREQUAL ) )
                                        {
-                                               UnaryOp unary = 
HopRewriteUtils.createUnary(right, OpOp1.SELP);
-                                               
HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
+                                               BinaryOp binary = 
HopRewriteUtils.createBinary(right, new LiteralOp(0), OpOp2.MAX);
+                                               
HopRewriteUtils.replaceChildReference(parent, bop, binary, pos);
                                                
HopRewriteUtils.cleanupUnreferenced(bop, left);
-                                               hi = unary;
+                                               hi = binary;
                                                applied = true;
                                                
-                                               LOG.debug("Applied 
fuseBinarySubDAGToUnaryOperation-selp1");
+                                               LOG.debug("Applied 
fuseBinarySubDAGToUnaryOperation-max0a");
                                        }
-                               }                               
+                               }
                                if( !applied && right instanceof BinaryOp ) 
//X*(X>0)
                                {
                                        BinaryOp bright = (BinaryOp)right;
                                        Hop right1 = bright.getInput().get(0);
-                                       Hop right2 = bright.getInput().get(1);  
        
+                                       Hop right2 = bright.getInput().get(1);
                                
                                        if( right2 instanceof LiteralOp &&
-                                               
HopRewriteUtils.getDoubleValue((LiteralOp)right2)==0 && 
-                                               right1 == left && 
bright.getOp() == OpOp2.GREATER )
+                                               
HopRewriteUtils.getDoubleValue((LiteralOp)right2)==0 &&
+                                               right1 == left && 
(bright.getOp() == OpOp2.GREATER || bright.getOp() == OpOp2.GREATEREQUAL ) )
                                        {
-                                               UnaryOp unary = 
HopRewriteUtils.createUnary(left, OpOp1.SELP);
-                                               
HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
+                                               BinaryOp binary = 
HopRewriteUtils.createBinary(left, new LiteralOp(0), OpOp2.MAX);
+                                               
HopRewriteUtils.replaceChildReference(parent, bop, binary, pos);
                                                
HopRewriteUtils.cleanupUnreferenced(bop, left);
-                                               hi = unary;
+                                               hi = binary;
                                                applied= true;
                                                
-                                               LOG.debug("Applied 
fuseBinarySubDAGToUnaryOperation-selp2");
+                                               LOG.debug("Applied 
fuseBinarySubDAGToUnaryOperation-max0b");
                                        }
                                }
                        }
-                       
-                       //select positive (selp) operator; pattern: max(X,0) -> 
selp+
-                       if( !applied && bop.getOp() == OpOp2.MAX && 
left.getDataType()==DataType.MATRIX 
-                                       && right instanceof LiteralOp && 
HopRewriteUtils.getDoubleValue((LiteralOp)right)==0 )
-                       {
-                               UnaryOp unary = 
HopRewriteUtils.createUnary(left, OpOp1.SELP);
-                               HopRewriteUtils.replaceChildReference(parent, 
bop, unary, pos);
-                               HopRewriteUtils.cleanupUnreferenced(bop);
-                               hi = unary;
-                               applied = true;
-                               
-                               LOG.debug("Applied 
fuseBinarySubDAGToUnaryOperation-selp3");
-                       }
-                       
-                       //select positive (selp) operator; pattern: max(0,X) -> 
selp+
-                       if( !applied && bop.getOp() == OpOp2.MAX && 
right.getDataType()==DataType.MATRIX 
-                                       && left instanceof LiteralOp && 
HopRewriteUtils.getDoubleValue((LiteralOp)left)==0 )
-                       {
-                               UnaryOp unary = 
HopRewriteUtils.createUnary(right, OpOp1.SELP);
-                               HopRewriteUtils.replaceChildReference(parent, 
bop, unary, pos);
-                               HopRewriteUtils.cleanupUnreferenced(bop);
-                               hi = unary;
-                               applied = true;
-                               
-                               LOG.debug("Applied 
fuseBinarySubDAGToUnaryOperation-selp4");
-                       }
                }
                
                return hi;

http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/lops/Unary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/Unary.java 
b/src/main/java/org/apache/sysml/lops/Unary.java
index 03096c2..880181c 100644
--- a/src/main/java/org/apache/sysml/lops/Unary.java
+++ b/src/main/java/org/apache/sysml/lops/Unary.java
@@ -43,7 +43,7 @@ public class Unary extends Lop
                AND, OR, XOR, BW_AND, BW_OR, BW_XOR, BW_SHIFTL, BW_SHIFTR,
                ROUND, CEIL, FLOOR, MR_IQM, INVERSE, CHOLESKY,
                CUMSUM, CUMPROD, CUMMIN, CUMMAX,
-               SPROP, SIGMOID, SELP, SUBTRACT_NZ, LOG_NZ,
+               SPROP, SIGMOID, SUBTRACT_NZ, LOG_NZ,
                CAST_AS_MATRIX, CAST_AS_FRAME,
                NOTSUPPORTED
        }
@@ -313,9 +313,6 @@ public class Unary extends Lop
                case SIGMOID:
                        return "sigmoid";
                
-               case SELP:
-                       return "sel+";
-               
                case CAST_AS_MATRIX:
                        return UnaryCP.CAST_AS_MATRIX_OPCODE;
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java 
b/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java
index 12da2bd..41cb709 100644
--- a/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java
+++ b/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java
@@ -51,7 +51,7 @@ public class Builtin extends ValueFunction
        
        public enum BuiltinCode { SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS, 
ATAN, LOG, LOG_NZ, MIN,
                MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL, 
LENGTH, ROUND, MAXINDEX, MININDEX,
-               STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, INVERSE, 
SPROP, SIGMOID, SELP }
+               STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, INVERSE, 
SPROP, SIGMOID }
        public BuiltinCode bFunc;
        
        private static final boolean FASTMATH = true;
@@ -95,7 +95,6 @@ public class Builtin extends ValueFunction
                String2BuiltinCode.put( "inverse", BuiltinCode.INVERSE);
                String2BuiltinCode.put( "sprop",   BuiltinCode.SPROP);
                String2BuiltinCode.put( "sigmoid", BuiltinCode.SIGMOID);
-               String2BuiltinCode.put( "sel+",    BuiltinCode.SELP);
        }
        
        // We should create one object for every builtin function that we 
support
@@ -104,7 +103,7 @@ public class Builtin extends ValueFunction
        private static Builtin absObj = null, signObj = null, sqrtObj = null, 
expObj = null, plogpObj = null, printObj = null, printfObj;
        private static Builtin nrowObj = null, ncolObj = null, lengthObj = 
null, roundObj = null, ceilObj=null, floorObj=null; 
        private static Builtin inverseObj=null, cumsumObj=null, 
cumprodObj=null, cumminObj=null, cummaxObj=null;
-       private static Builtin stopObj = null, spropObj = null, sigmoidObj = 
null, selpObj = null;
+       private static Builtin stopObj = null, spropObj = null, sigmoidObj = 
null;
        
        private Builtin(BuiltinCode bf) {
                bFunc = bf;
@@ -113,6 +112,10 @@ public class Builtin extends ValueFunction
        public BuiltinCode getBuiltinCode() {
                return bFunc;
        }
+       
+       public static boolean isBuiltinCode(ValueFunction fn, BuiltinCode code) 
{
+               return (fn instanceof Builtin && ((Builtin)fn).getBuiltinCode() 
== code);
+       }
 
        public static boolean isBuiltinFnObject(String str) {
                return String2BuiltinCode.containsKey(str);
@@ -279,11 +282,6 @@ public class Builtin extends ValueFunction
                        if ( sigmoidObj == null )
                                sigmoidObj = new Builtin(BuiltinCode.SIGMOID);
                        return sigmoidObj;
-               
-               case SELP:
-                       if ( selpObj == null )
-                               selpObj = new Builtin(BuiltinCode.SELP);
-                       return selpObj;
 
                default:
                        // Unknown code --> return null
@@ -332,10 +330,6 @@ public class Builtin extends ValueFunction
                                //sigmoid: 1/(1+exp(-x))
                                return FASTMATH ? 1 / (1 + FastMath.exp(-in))  
: 1 / (1 + Math.exp(-in));
                        
-                       case SELP:
-                               //select positive: x*(x>0)
-                               return (in > 0) ? in : 0;
-                               
                        default:
                                throw new 
DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
                }

http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java 
b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
index a7b9d6b..e2bb20a 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
@@ -179,7 +179,6 @@ public class CPInstructionParser extends InstructionParser
                String2CPInstructionType.put( "cholesky",CPType.Unary);
                String2CPInstructionType.put( "sprop", CPType.Unary);
                String2CPInstructionType.put( "sigmoid", CPType.Unary);
-               String2CPInstructionType.put( "sel+", CPType.Unary);
                
                String2CPInstructionType.put( "printf" , CPType.BuiltinNary);
                String2CPInstructionType.put( "cbind" , CPType.BuiltinNary);

http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java 
b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
index e234d52..138b4f5 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
@@ -81,7 +81,6 @@ public class GPUInstructionParser  extends InstructionParser
                String2GPUInstructionType.put( "-*",   
GPUINSTRUCTION_TYPE.ArithmeticBinary);
                
                // Unary Builtin functions
-               String2GPUInstructionType.put( "sel+",  
GPUINSTRUCTION_TYPE.BuiltinUnary);
                String2GPUInstructionType.put( "exp",   
GPUINSTRUCTION_TYPE.BuiltinUnary);
                String2GPUInstructionType.put( "log",   
GPUINSTRUCTION_TYPE.BuiltinUnary);
                String2GPUInstructionType.put( "abs",   
GPUINSTRUCTION_TYPE.BuiltinUnary);

http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java 
b/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java
index 0e25d83..debafe1 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java
@@ -153,7 +153,6 @@ public class MRInstructionParser extends InstructionParser
                String2MRInstructionType.put( "floor", MRType.Unary);
                String2MRInstructionType.put( "sprop", MRType.Unary);
                String2MRInstructionType.put( "sigmoid", MRType.Unary);
-               String2MRInstructionType.put( "sel+", MRType.Unary);
                String2MRInstructionType.put( "!", MRType.Unary);
                
                // Specific UNARY Instruction Opcodes

http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java 
b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
index d625853..1ba7776 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
@@ -245,7 +245,6 @@ public class SPInstructionParser extends InstructionParser
                String2SPInstructionType.put( "floor" , SPType.Unary);
                String2SPInstructionType.put( "sprop", SPType.Unary);
                String2SPInstructionType.put( "sigmoid", SPType.Unary);
-               String2SPInstructionType.put( "sel+", SPType.Unary);
                
                // Parameterized Builtin Functions
                String2SPInstructionType.put( "groupedagg"       , 
SPType.ParameterizedBuiltin);

http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixBuiltinGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixBuiltinGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixBuiltinGPUInstruction.java
index 04760ee..f54bc80 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixBuiltinGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixBuiltinGPUInstruction.java
@@ -44,8 +44,6 @@ public class MatrixBuiltinGPUInstruction extends BuiltinUnaryGPUInstruction {
                ec.setMetaData(_output.getName(), mat.getNumRows(), 
mat.getNumColumns());
 
                switch(opcode) {
-                       case "sel+":
-                               LibMatrixCuDNN.relu(ec, ec.getGPUContext(0), 
getExtendedOpcode(), mat, _output.getName()); break;
                        case "exp":
                                LibMatrixCUDA.exp(ec, ec.getGPUContext(0), 
getExtendedOpcode(), mat, _output.getName()); break;
                        case "sqrt":

http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java
index 16c8d2d..3ed12ab 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java
@@ -22,6 +22,7 @@ package org.apache.sysml.runtime.matrix.data;
 import java.util.Arrays;
 
 import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.functionobjects.Builtin;
 import org.apache.sysml.runtime.functionobjects.Divide;
 import org.apache.sysml.runtime.functionobjects.Equals;
 import org.apache.sysml.runtime.functionobjects.GreaterThan;
@@ -38,6 +39,7 @@ import org.apache.sysml.runtime.functionobjects.Plus;
 import org.apache.sysml.runtime.functionobjects.PlusMultiply;
 import org.apache.sysml.runtime.functionobjects.Power2;
 import org.apache.sysml.runtime.functionobjects.ValueFunction;
+import org.apache.sysml.runtime.functionobjects.Builtin.BuiltinCode;
 import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
 import org.apache.sysml.runtime.matrix.operators.ScalarOperator;
 import org.apache.sysml.runtime.util.DataConverter;
@@ -935,11 +937,12 @@ public class LibMatrixBincell
                        throw new DMLRuntimeException("Unsupported safe binary 
scalar operations over different input/output representation: "+m1.sparse+" 
"+ret.sparse);
                
                boolean copyOnes = (op.fn instanceof NotEquals && 
op.getConstant()==0);
-               boolean allocExact = (op.fn instanceof Multiply 
-                       || op.fn instanceof Multiply2 || op.fn instanceof 
Power2);
+               boolean allocExact = (op.fn instanceof Multiply || op.fn 
instanceof Multiply2 
+                       || op.fn instanceof Power2 || 
Builtin.isBuiltinCode(op.fn, BuiltinCode.MAX)
+                       || Builtin.isBuiltinCode(op.fn, BuiltinCode.MIN));
                
                if( m1.sparse ) //SPARSE <- SPARSE
-               {       
+               {
                        //allocate sparse row structure
                        ret.allocateSparseRowsBlock();
                        SparseBlock a = m1.sparseBlock;

http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/runtime/matrix/operators/LeftScalarOperator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/operators/LeftScalarOperator.java b/src/main/java/org/apache/sysml/runtime/matrix/operators/LeftScalarOperator.java
index b47fe12..64cd914 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/operators/LeftScalarOperator.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/operators/LeftScalarOperator.java
@@ -21,11 +21,13 @@
 package org.apache.sysml.runtime.matrix.operators;
 
 import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.functionobjects.Builtin;
 import org.apache.sysml.runtime.functionobjects.GreaterThan;
 import org.apache.sysml.runtime.functionobjects.GreaterThanEquals;
 import org.apache.sysml.runtime.functionobjects.LessThan;
 import org.apache.sysml.runtime.functionobjects.LessThanEquals;
 import org.apache.sysml.runtime.functionobjects.ValueFunction;
+import org.apache.sysml.runtime.functionobjects.Builtin.BuiltinCode;
 
 /**
  * Scalar operator for scalar-matrix operations with scalar 
@@ -40,7 +42,9 @@ public class LeftScalarOperator extends ScalarOperator
                super(p, cst, (p instanceof GreaterThan && cst<=0)
                        || (p instanceof GreaterThanEquals && cst<0)
                        || (p instanceof LessThan && cst>=0)
-                       || (p instanceof LessThanEquals && cst>0));
+                       || (p instanceof LessThanEquals && cst>0)
+                       || (Builtin.isBuiltinCode(p, BuiltinCode.MAX) && cst<=0)
+                       || (Builtin.isBuiltinCode(p, BuiltinCode.MIN) && 
cst>=0));
        }
        
        @Override

http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/runtime/matrix/operators/RightScalarOperator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/operators/RightScalarOperator.java b/src/main/java/org/apache/sysml/runtime/matrix/operators/RightScalarOperator.java
index 9458270..fbd478e 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/operators/RightScalarOperator.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/operators/RightScalarOperator.java
@@ -21,6 +21,7 @@
 package org.apache.sysml.runtime.matrix.operators;
 
 import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.functionobjects.Builtin;
 import org.apache.sysml.runtime.functionobjects.Divide;
 import org.apache.sysml.runtime.functionobjects.GreaterThan;
 import org.apache.sysml.runtime.functionobjects.GreaterThanEquals;
@@ -28,6 +29,7 @@ import org.apache.sysml.runtime.functionobjects.LessThan;
 import org.apache.sysml.runtime.functionobjects.LessThanEquals;
 import org.apache.sysml.runtime.functionobjects.Power;
 import org.apache.sysml.runtime.functionobjects.ValueFunction;
+import org.apache.sysml.runtime.functionobjects.Builtin.BuiltinCode;
 
 /**
  * Scalar operator for scalar-matrix operations with scalar 
@@ -43,7 +45,9 @@ public class RightScalarOperator extends ScalarOperator
                        || (p instanceof LessThan && cst<=0)
                        || (p instanceof LessThanEquals && cst<0)
                        || (p instanceof Divide && cst!=0)
-                       || (p instanceof Power && cst!=0));
+                       || (p instanceof Power && cst!=0)
+                       || (Builtin.isBuiltinCode(p, BuiltinCode.MAX) && cst<=0)
+                       || (Builtin.isBuiltinCode(p, BuiltinCode.MIN) && 
cst>=0));
        }
 
        @Override

http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java b/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java
index 7f48252..8b3888e 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java
@@ -41,8 +41,7 @@ public class UnaryOperator extends Operator
                        || ((Builtin)p).bFunc==Builtin.BuiltinCode.SINH || 
((Builtin)p).bFunc==Builtin.BuiltinCode.TANH
                        || ((Builtin)p).bFunc==Builtin.BuiltinCode.ROUND || 
((Builtin)p).bFunc==Builtin.BuiltinCode.ABS
                        || ((Builtin)p).bFunc==Builtin.BuiltinCode.SQRT || 
((Builtin)p).bFunc==Builtin.BuiltinCode.SPROP
-                       || ((Builtin)p).bFunc==Builtin.BuiltinCode.SELP || 
((Builtin)p).bFunc==Builtin.BuiltinCode.LOG_NZ
-                       || ((Builtin)p).bFunc==Builtin.BuiltinCode.SIGN) );
+                       || ((Builtin)p).bFunc==Builtin.BuiltinCode.LOG_NZ || 
((Builtin)p).bFunc==Builtin.BuiltinCode.SIGN) );
                fn = p;
                k = numThreads;
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/3737b9ad/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullSelectPosTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullSelectPosTest.java b/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullSelectPosTest.java
index 6f01d8b..eaf714b 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullSelectPosTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullSelectPosTest.java
@@ -199,13 +199,11 @@ public class FullSelectPosTest extends AutomatedTestBase
                        TestUtils.compareMatrices(dmlfile, rfile, eps, 
"Stat-DML", "Stat-R");
                        
                        //check generated opcode
-                       if( rewrites ){
-                               String expected_op = "sel+";
-                               if ( instType == ExecType.SPARK )
-                                       expected_op = 
Instruction.SP_INST_PREFIX+"sel+";
-                               
+                       if( rewrites ) {
+                               String expected_op = (instType == 
ExecType.SPARK) ? Instruction.SP_INST_PREFIX+"max" : "max";
                                if(instType == ExecType.CP || instType == 
ExecType.SPARK)
-                                       Assert.assertTrue("Missing opcode: " + 
expected_op , Statistics.getCPHeavyHitterOpCodes().contains(expected_op) || 
Statistics.getCPHeavyHitterOpCodes().contains("gpu_sel+"));
+                                       Assert.assertTrue("Missing opcode: " + 
expected_op , Statistics.getCPHeavyHitterOpCodes().contains(expected_op)
+                                               || 
Statistics.getCPHeavyHitterOpCodes().contains("gpu_max"));
                        }
                }
                finally

Reply via email to