[SYSTEMML-1361] Extended code generator for existing cellwise fused ops

This patch adds the existing cell-wise fused unary operators selp,
sprop, sigmoid, log_nz, as well as the ternary operators +* and -*, to
the code generator in order to prevent these operators from breaking
fusion boundaries.

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/9b69f36a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/9b69f36a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/9b69f36a

Branch: refs/heads/master
Commit: 9b69f36a9aad831b8e78c7e45be3d7c80386ec01
Parents: 6aab005
Author: Matthias Boehm <[email protected]>
Authored: Wed Mar 1 22:01:41 2017 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Wed Mar 1 22:01:41 2017 -0800

----------------------------------------------------------------------
 .../sysml/hops/codegen/cplan/CNodeBinary.java   |   5 +-
 .../sysml/hops/codegen/cplan/CNodeTernary.java  | 136 +++++++++++++++++++
 .../sysml/hops/codegen/cplan/CNodeUnary.java    |  33 +++--
 .../sysml/hops/codegen/template/CellTpl.java    |  34 ++++-
 .../hops/codegen/template/TemplateUtils.java    |  11 +-
 .../sysml/hops/rewrite/HopRewriteUtils.java     |   9 +-
 6 files changed, 205 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9b69f36a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
index 1bfaab4..d11d0ac 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
@@ -34,10 +34,10 @@ public class CNodeBinary extends CNode
                LESS, LESSEQUAL, GREATER, GREATEREQUAL, EQUAL,NOTEQUAL,
                MIN, MAX, AND, OR, LOG, POW,
                MINUS1_MULT;
-               
+
                public static boolean contains(String value) {
                        for( BinType bt : values()  )
-                               if( bt.toString().equals(value) )
+                               if( bt.name().equals(value) )
                                        return true;
                        return false;
                }
@@ -188,6 +188,7 @@ public class CNodeBinary extends CNode
                }
        }
        
+       @Override
        public void setOutputDims()
        {
                switch(_type) {

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9b69f36a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
new file mode 100644
index 0000000..9fdae3b
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.cplan;
+
+import java.util.Arrays;
+
+import org.apache.sysml.parser.Expression.DataType;
+
+
+public class CNodeTernary extends CNode
+{
+       public enum TernaryType {
+               PLUS_MULT, MINUS_MULT;
+               
+               public static boolean contains(String value) {
+                       for( TernaryType tt : values()  )
+                               if( tt.name().equals(value) )
+                                       return true;
+                       return false;
+               }
+               
+               public String getTemplate(boolean sparse) {
+                       switch (this) {
+                               case PLUS_MULT:
+                                       return "    double %TMP% = %IN1% + 
%IN2% * %IN3%;\n" ;
+                               
+                               case MINUS_MULT:
+                                       return "    double %TMP% = %IN1% - 
%IN2% * %IN3%;\n;\n" ;
+                                       
+                               default: 
+                                       throw new RuntimeException("Invalid 
ternary type: "+this.toString());
+                       }
+               }
+       }
+       
+       private final TernaryType _type;
+       
+       public CNodeTernary( CNode in1, CNode in2, CNode in3, TernaryType type 
) {
+               _inputs.add(in1);
+               _inputs.add(in2);
+               _inputs.add(in3);
+               _type = type;
+               setOutputDims();
+       }
+
+       public TernaryType getType() {
+               return _type;
+       }
+       
+       @Override
+       public String codegen(boolean sparse) {
+               if( _generated )
+                       return "";
+                       
+               StringBuilder sb = new StringBuilder();
+               
+               //generate children
+               sb.append(_inputs.get(0).codegen(sparse));
+               sb.append(_inputs.get(1).codegen(sparse));
+               sb.append(_inputs.get(2).codegen(sparse));
+               
+               //generate binary operation
+               String var = createVarname();
+               String tmp = _type.getTemplate(sparse);
+               tmp = tmp.replaceAll("%TMP%", var);
+               for( int j=1; j<=3; j++ ) {
+                       String varj = _inputs.get(j-1).getVarname();
+                       tmp = tmp.replaceAll("%IN"+j+"%", varj );
+               }
+               sb.append(tmp);
+               
+               //mark as generated
+               _generated = true;
+               
+               return sb.toString();
+       }
+       
+       @Override
+       public String toString() {
+               switch(_type) {
+                       case PLUS_MULT: return "t(+*)";
+                       case MINUS_MULT: return "t(-*)";
+                       default:
+                               return super.toString();        
+               }
+       }
+       
+       @Override
+       public void setOutputDims() {
+               switch(_type) {
+                       case PLUS_MULT: 
+                       case MINUS_MULT:
+                               _rows = 0;
+                               _cols = 0;
+                               _dataType= DataType.SCALAR;
+                               break;
+               }
+       }
+       
+       @Override
+       public int hashCode() {
+               if( _hash == 0 ) {
+                       int h1 = super.hashCode();
+                       int h2 = _type.hashCode();
+                       _hash = Arrays.hashCode(new int[]{h1,h2});
+               }
+               return _hash;
+       }
+       
+       @Override 
+       public boolean equals(Object o) {
+               if( !(o instanceof CNodeTernary) )
+                       return false;
+               
+               CNodeTernary that = (CNodeTernary) o;
+               return super.equals(that)
+                       && _type == that._type;
+       }
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9b69f36a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
index f08769e..dd8431b 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
@@ -27,22 +27,21 @@ import org.apache.sysml.parser.Expression.DataType;
 public class CNodeUnary extends CNode
 {
        public enum UnaryType {
-               ROW_SUMS, LOOKUP, LOOKUP0,
+               ROW_SUMS, LOOKUP, LOOKUP0, //codegen specific
                EXP, POW2, MULT2, SQRT, LOG,
-               ABS, ROUND, CEIL,FLOOR, SIGN, 
+               ABS, ROUND, CEIL, FLOOR, SIGN, 
                SIN, COS, TAN, ASIN, ACOS, ATAN,
-               IQM, STOP,
-               DOTPRODUCT_ROW_SUMS; //row sums via dot product for debugging 
purposes
+               SELP, SPROP, SIGMOID, LOG_NZ; 
                
                public static boolean contains(String value) {
                        for( UnaryType ut : values()  )
-                               if( ut.toString().equals(value) )
+                               if( ut.name().equals(value) )
                                        return true;
                        return false;
                }
                
                public String getTemplate(boolean sparse) {
-                       switch (this) {
+                       switch( this ) {
                                case ROW_SUMS:
                                        return sparse ? "    double %TMP% = 
LibSpoofPrimitives.vectSum( %IN1v%, %IN1i%, %POS1%, %LEN%);\n": 
                                                                        "    
double %TMP% = LibSpoofPrimitives.vectSum( %IN1%, %POS1%,  %LEN%);\n"; 
@@ -82,8 +81,17 @@ public class CNodeUnary extends CNode
                                        return "    double %TMP% = 
Math.ceil(%IN1%);\n";
                                case FLOOR:
                                        return "    double %TMP% = 
Math.floor(%IN1%);\n";
+                               case SELP:
+                                       return "    double %TMP% = (%IN1%>0) ? 
%IN1% : 0;\n";
+                               case SPROP:
+                                       return "    double %TMP% = %IN1% * (1 - 
%IN1%);\n";
+                               case SIGMOID:
+                                       return "    double %TMP% = 1 / (1 + 
FastMath.exp(-%IN1%));\n";
+                               case LOG_NZ:
+                                       return "    double %TMP% = (%IN1%==0) ? 
0 : FastMath.log(%IN1%);\n";
+                                       
                                default: 
-                                       throw new RuntimeException("Invalid 
binary type: "+this.toString());
+                                       throw new RuntimeException("Invalid 
unary type: "+this.toString());
                        }
                }
        }
@@ -150,15 +158,14 @@ public class CNodeUnary extends CNode
 
        @Override
        public void setOutputDims() {
-               switch(_type)
-               {
+               switch(_type) {
                        case ROW_SUMS:
                        case EXP:
                        case LOOKUP:
                        case LOOKUP0:   
                        case POW2:
                        case MULT2:     
-                       case ABS:
+                       case ABS:  
                        case SIN:
                        case COS: 
                        case TAN:
@@ -169,10 +176,12 @@ public class CNodeUnary extends CNode
                        case SQRT:
                        case LOG:
                        case ROUND: 
-                       case IQM:
-                       case STOP:
                        case CEIL:
                        case FLOOR:
+                       case SELP:      
+                       case SPROP:
+                       case SIGMOID:
+                       case LOG_NZ:
                                _rows = 0;
                                _cols = 0;
                                _dataType= DataType.SCALAR;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9b69f36a/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java
index 0c841e8..4d95686 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java
@@ -33,6 +33,7 @@ import org.apache.sysml.hops.UnaryOp;
 import org.apache.sysml.hops.Hop.AggOp;
 import org.apache.sysml.hops.Hop.Direction;
 import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.TernaryOp;
 import org.apache.sysml.hops.codegen.cplan.CNode;
 import org.apache.sysml.hops.codegen.cplan.CNodeBinary;
 import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType;
@@ -41,6 +42,9 @@ import org.apache.sysml.hops.codegen.cplan.CNodeData;
 import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
 import org.apache.sysml.hops.codegen.cplan.CNodeUnary;
 import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType;
+import org.apache.sysml.hops.codegen.cplan.CNodeTernary;
+import org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType;
+import org.apache.sysml.hops.rewrite.HopRewriteUtils;
 import org.apache.sysml.parser.Expression.DataType;
 import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType;
 import org.apache.sysml.runtime.matrix.data.Pair;
@@ -67,16 +71,17 @@ public class CellTpl extends BaseTpl
                        return false;
                        
                //re-assign initialHop to fuse the sum/rowsums (before checking 
for chains)
-               for (Hop h : _initialHop.getParent())
+               //TODO add aggbinary (vector tsmm) as potential head for 
cellwise operation
+               for (Hop h : _initialHop.getParent()) {
                        if( h instanceof AggUnaryOp && ((AggUnaryOp) h).getOp() 
== AggOp.SUM 
                                && ((AggUnaryOp) h).getDirection()!= 
Direction.Col ) {
                                _initialHop = h;  
                        }
+               }
                
                //unary matrix && endHop found && endHop is not direct child of 
the initialHop (i.e., chain of operators)
                if(_endHop != null && _endHop != _initialHop)
                {
-                       
                        // if final hop is unary add its child to the input 
                        if(_endHop instanceof UnaryOp)
                                _matrixInputs.add(_endHop.getInput().get(0));
@@ -199,7 +204,6 @@ public class CellTpl extends BaseTpl
                                        if( TemplateUtils.isColVector(cdata2) )
                                                cdata2 = new CNodeUnary(cdata2, 
UnaryType.LOOKUP);
                                        
-                                       
                                        if( bop.getOp()==OpOp2.POW && 
cdata2.isLiteral() && cdata2.getVarname().equals("2") )
                                                out = new CNodeUnary(cdata1, 
UnaryType.POW2);
                                        else if( bop.getOp()==OpOp2.MULT && 
cdata2.isLiteral() && cdata2.getVarname().equals("2") )
@@ -207,6 +211,24 @@ public class CellTpl extends BaseTpl
                                        else //default binary   
                                                out = new CNodeBinary(cdata1, 
cdata2, BinType.valueOf(primitiveOpName));
                                }
+                               else if(hop instanceof TernaryOp) 
+                               {
+                                       TernaryOp top = (TernaryOp) hop;
+                                       CNode cdata1 = cnodeData.get(0);
+                                       CNode cdata2 = cnodeData.get(1);
+                                       CNode cdata3 = cnodeData.get(2);
+                                       
+                                       //cdata1 is vector
+                                       if( TemplateUtils.isColVector(cdata1) )
+                                               cdata1 = new CNodeUnary(cdata1, 
UnaryType.LOOKUP);
+                                       //cdata3 is vector
+                                       if( TemplateUtils.isColVector(cdata3) )
+                                               cdata3 = new CNodeUnary(cdata3, 
UnaryType.LOOKUP);
+                                       
+                                       //construct ternary cnode, primitive 
operation derived from OpOp3
+                                       out = new CNodeTernary(cdata1, cdata2, 
cdata3, 
+                                                       
TernaryType.valueOf(top.getOp().toString()));
+                               }
                                else if (hop instanceof AggUnaryOp && 
((AggUnaryOp)hop).getOp() == AggOp.SUM
                                        && (((AggUnaryOp) hop).getDirection() 
== Direction.RowCol 
                                        || ((AggUnaryOp) hop).getDirection() == 
Direction.Row) && root == hop)
@@ -283,7 +305,11 @@ public class CellTpl extends BaseTpl
                                && 
TemplateUtils.isVectorOrScalar(hop.getInput().get(1)) && 
!TemplateUtils.isBinaryMatrixRowVector(hop)) 
                        ||(TemplateUtils.isVectorOrScalar( 
hop.getInput().get(0))  
                                && hop.getInput().get(1).getDataType() == 
DataType.MATRIX && !TemplateUtils.isBinaryMatrixRowVector(hop)) );
+               boolean isTernaryVectorScalarVector = hop instanceof TernaryOp 
&& hop.getInput().size()==3 && hop.dimsKnown()
+                               && HopRewriteUtils.checkInputDataTypes(hop, 
DataType.MATRIX, DataType.SCALAR, DataType.MATRIX)
+                               && 
TemplateUtils.isVector(hop.getInput().get(0)) && 
TemplateUtils.isVector(hop.getInput().get(2));
+               
                return hop.getDataType() == DataType.MATRIX && 
TemplateUtils.isOperationSupported(hop)
-                       && (hop instanceof UnaryOp || isBinaryMatrixScalar || 
isBinaryMatrixVector);    
+                       && (hop instanceof UnaryOp || isBinaryMatrixScalar || 
isBinaryMatrixVector || isTernaryVectorScalarVector);     
        }
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9b69f36a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
index fd8a960..32a4d80 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
@@ -33,6 +33,7 @@ import org.apache.sysml.hops.DataOp;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.LiteralOp;
 import org.apache.sysml.hops.ReorgOp;
+import org.apache.sysml.hops.TernaryOp;
 import org.apache.sysml.hops.Hop.AggOp;
 import org.apache.sysml.hops.Hop.Direction;
 import org.apache.sysml.hops.Hop.ReOrgOp;
@@ -45,6 +46,7 @@ import org.apache.sysml.hops.codegen.cplan.CNodeOuterProduct;
 import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
 import org.apache.sysml.hops.codegen.cplan.CNodeUnary;
 import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType;
+import org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType;
 import org.apache.sysml.parser.Expression.DataType;
 import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType;
 import org.apache.sysml.runtime.matrix.data.Pair;
@@ -231,11 +233,12 @@ public class TemplateUtils
 
        public static boolean isOperationSupported(Hop h) {
                if(h instanceof  UnaryOp)
-                       return 
UnaryType.contains(((UnaryOp)h).getOp().toString());
+                       return UnaryType.contains(((UnaryOp)h).getOp().name());
                else if(h instanceof BinaryOp)
-                       return 
BinType.contains(((BinaryOp)h).getOp().toString());
-               else
-                       return false;
+                       return BinType.contains(((BinaryOp)h).getOp().name());
+               else if(h instanceof TernaryOp)
+                       return 
TernaryType.contains(((TernaryOp)h).getOp().name());
+               return false;
        }
 
        private static void rfindChildren(Hop hop, HashSet<Hop> children ) {    
        

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9b69f36a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 802a382..e69ecbf 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -719,7 +719,7 @@ public class HopRewriteUtils
                
                return ret;
        }
-       
+
        public static boolean isTransposeOperation(Hop hop) {
                return (hop instanceof ReorgOp && 
((ReorgOp)hop).getOp()==ReOrgOp.TRANSPOSE);
        }
@@ -778,6 +778,13 @@ public class HopRewriteUtils
                
                return false;
        }
+
+       /**
+        * Checks if the data types of the inputs of the given hop match the
+        * given data types, position by position.
+        *
+        * @param hop hop whose inputs are checked
+        * @param dt expected data types, one per input in input order
+        * @return true if the number of inputs equals the number of given
+        *         data types and every input has the expected data type
+        */
+       public static boolean checkInputDataTypes(Hop hop, DataType... dt) {
+               //guard against arity mismatch (prevents index-out-of-bounds on dt
+               //and false positives if more types than inputs are given)
+               if( hop.getInput().size() != dt.length )
+                       return false;
+               for( int i=0; i<dt.length; i++ )
+                       if( hop.getInput().get(i).getDataType() != dt[i] )
+                               return false;
+               return true;
+       }
        
        public static boolean isFullColumnIndexing(LeftIndexingOp hop)
        {

Reply via email to