[SYSTEMML-1361] Extended code generator for existing cellwise fused ops

This patch extends the code generator with the existing cell-wise fused unary operators selp, sprop, sigmoid, and log_nz, as well as the ternary operators +* and -*, in order to prevent these operators from breaking fusion boundaries.
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/9b69f36a Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/9b69f36a Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/9b69f36a Branch: refs/heads/master Commit: 9b69f36a9aad831b8e78c7e45be3d7c80386ec01 Parents: 6aab005 Author: Matthias Boehm <[email protected]> Authored: Wed Mar 1 22:01:41 2017 -0800 Committer: Matthias Boehm <[email protected]> Committed: Wed Mar 1 22:01:41 2017 -0800 ---------------------------------------------------------------------- .../sysml/hops/codegen/cplan/CNodeBinary.java | 5 +- .../sysml/hops/codegen/cplan/CNodeTernary.java | 136 +++++++++++++++++++ .../sysml/hops/codegen/cplan/CNodeUnary.java | 33 +++-- .../sysml/hops/codegen/template/CellTpl.java | 34 ++++- .../hops/codegen/template/TemplateUtils.java | 11 +- .../sysml/hops/rewrite/HopRewriteUtils.java | 9 +- 6 files changed, 205 insertions(+), 23 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9b69f36a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java index 1bfaab4..d11d0ac 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java @@ -34,10 +34,10 @@ public class CNodeBinary extends CNode LESS, LESSEQUAL, GREATER, GREATEREQUAL, EQUAL,NOTEQUAL, MIN, MAX, AND, OR, LOG, POW, MINUS1_MULT; - + public static boolean contains(String value) { for( BinType bt : values() ) - if( bt.toString().equals(value) ) + if( bt.name().equals(value) ) return true; return false; } @@ -188,6 +188,7 @@ 
public class CNodeBinary extends CNode } } + @Override public void setOutputDims() { switch(_type) { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9b69f36a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java new file mode 100644 index 0000000..9fdae3b --- /dev/null +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.hops.codegen.cplan; + +import java.util.Arrays; + +import org.apache.sysml.parser.Expression.DataType; + + +public class CNodeTernary extends CNode +{ + public enum TernaryType { + PLUS_MULT, MINUS_MULT; + + public static boolean contains(String value) { + for( TernaryType tt : values() ) + if( tt.name().equals(value) ) + return true; + return false; + } + + public String getTemplate(boolean sparse) { + switch (this) { + case PLUS_MULT: + return " double %TMP% = %IN1% + %IN2% * %IN3%;\n" ; + + case MINUS_MULT: + return " double %TMP% = %IN1% - %IN2% * %IN3%;\n" ; + + default: + throw new RuntimeException("Invalid ternary type: "+this.toString()); + } + } + } + + private final TernaryType _type; + + public CNodeTernary( CNode in1, CNode in2, CNode in3, TernaryType type ) { + _inputs.add(in1); + _inputs.add(in2); + _inputs.add(in3); + _type = type; + setOutputDims(); + } + + public TernaryType getType() { + return _type; + } + + @Override + public String codegen(boolean sparse) { + if( _generated ) + return ""; + + StringBuilder sb = new StringBuilder(); + + //generate children + sb.append(_inputs.get(0).codegen(sparse)); + sb.append(_inputs.get(1).codegen(sparse)); + sb.append(_inputs.get(2).codegen(sparse)); + + //generate ternary operation + String var = createVarname(); + String tmp = _type.getTemplate(sparse); + tmp = tmp.replaceAll("%TMP%", var); + for( int j=1; j<=3; j++ ) { + String varj = _inputs.get(j-1).getVarname(); + tmp = tmp.replaceAll("%IN"+j+"%", varj ); + } + sb.append(tmp); + + //mark as generated + _generated = true; + + return sb.toString(); + } + + @Override + public String toString() { + switch(_type) { + case PLUS_MULT: return "t(+*)"; + case MINUS_MULT: return "t(-*)"; + default: + return super.toString(); + } + } + + @Override + public void setOutputDims() { + switch(_type) { + case PLUS_MULT: + case MINUS_MULT: + _rows = 0; + _cols = 0; + _dataType= DataType.SCALAR; + break; + } + } + + @Override + public
int hashCode() { + if( _hash == 0 ) { + int h1 = super.hashCode(); + int h2 = _type.hashCode(); + _hash = Arrays.hashCode(new int[]{h1,h2}); + } + return _hash; + } + + @Override + public boolean equals(Object o) { + if( !(o instanceof CNodeTernary) ) + return false; + + CNodeTernary that = (CNodeTernary) o; + return super.equals(that) + && _type == that._type; + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9b69f36a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java index f08769e..dd8431b 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java @@ -27,22 +27,21 @@ import org.apache.sysml.parser.Expression.DataType; public class CNodeUnary extends CNode { public enum UnaryType { - ROW_SUMS, LOOKUP, LOOKUP0, + ROW_SUMS, LOOKUP, LOOKUP0, //codegen specific EXP, POW2, MULT2, SQRT, LOG, - ABS, ROUND, CEIL,FLOOR, SIGN, + ABS, ROUND, CEIL, FLOOR, SIGN, SIN, COS, TAN, ASIN, ACOS, ATAN, - IQM, STOP, - DOTPRODUCT_ROW_SUMS; //row sums via dot product for debugging purposes + SELP, SPROP, SIGMOID, LOG_NZ; public static boolean contains(String value) { for( UnaryType ut : values() ) - if( ut.toString().equals(value) ) + if( ut.name().equals(value) ) return true; return false; } public String getTemplate(boolean sparse) { - switch (this) { + switch( this ) { case ROW_SUMS: return sparse ? 
" double %TMP% = LibSpoofPrimitives.vectSum( %IN1v%, %IN1i%, %POS1%, %LEN%);\n": " double %TMP% = LibSpoofPrimitives.vectSum( %IN1%, %POS1%, %LEN%);\n"; @@ -82,8 +81,17 @@ public class CNodeUnary extends CNode return " double %TMP% = Math.ceil(%IN1%);\n"; case FLOOR: return " double %TMP% = Math.floor(%IN1%);\n"; + case SELP: + return " double %TMP% = (%IN1%>0) ? %IN1% : 0;\n"; + case SPROP: + return " double %TMP% = %IN1% * (1 - %IN1%);\n"; + case SIGMOID: + return " double %TMP% = 1 / (1 + FastMath.exp(-%IN1%));\n"; + case LOG_NZ: + return " double %TMP% = (%IN1%==0) ? 0 : FastMath.log(%IN1%);\n"; + default: - throw new RuntimeException("Invalid binary type: "+this.toString()); + throw new RuntimeException("Invalid unary type: "+this.toString()); } } } @@ -150,15 +158,14 @@ public class CNodeUnary extends CNode @Override public void setOutputDims() { - switch(_type) - { + switch(_type) { case ROW_SUMS: case EXP: case LOOKUP: case LOOKUP0: case POW2: case MULT2: - case ABS: + case ABS: case SIN: case COS: case TAN: @@ -169,10 +176,12 @@ public class CNodeUnary extends CNode case SQRT: case LOG: case ROUND: - case IQM: - case STOP: case CEIL: case FLOOR: + case SELP: + case SPROP: + case SIGMOID: + case LOG_NZ: _rows = 0; _cols = 0; _dataType= DataType.SCALAR; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9b69f36a/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java b/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java index 0c841e8..4d95686 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java @@ -33,6 +33,7 @@ import org.apache.sysml.hops.UnaryOp; import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.Direction; import org.apache.sysml.hops.Hop.OpOp2; +import 
org.apache.sysml.hops.TernaryOp; import org.apache.sysml.hops.codegen.cplan.CNode; import org.apache.sysml.hops.codegen.cplan.CNodeBinary; import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType; @@ -41,6 +42,9 @@ import org.apache.sysml.hops.codegen.cplan.CNodeData; import org.apache.sysml.hops.codegen.cplan.CNodeTpl; import org.apache.sysml.hops.codegen.cplan.CNodeUnary; import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType; +import org.apache.sysml.hops.codegen.cplan.CNodeTernary; +import org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType; +import org.apache.sysml.hops.rewrite.HopRewriteUtils; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType; import org.apache.sysml.runtime.matrix.data.Pair; @@ -67,16 +71,17 @@ public class CellTpl extends BaseTpl return false; //re-assign initialHop to fuse the sum/rowsums (before checking for chains) - for (Hop h : _initialHop.getParent()) + //TODO add aggbinary (vector tsmm) as potential head for cellwise operation + for (Hop h : _initialHop.getParent()) { if( h instanceof AggUnaryOp && ((AggUnaryOp) h).getOp() == AggOp.SUM && ((AggUnaryOp) h).getDirection()!= Direction.Col ) { _initialHop = h; } + } //unary matrix && endHop found && endHop is not direct child of the initialHop (i.e., chain of operators) if(_endHop != null && _endHop != _initialHop) { - // if final hop is unary add its child to the input if(_endHop instanceof UnaryOp) _matrixInputs.add(_endHop.getInput().get(0)); @@ -199,7 +204,6 @@ public class CellTpl extends BaseTpl if( TemplateUtils.isColVector(cdata2) ) cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP); - if( bop.getOp()==OpOp2.POW && cdata2.isLiteral() && cdata2.getVarname().equals("2") ) out = new CNodeUnary(cdata1, UnaryType.POW2); else if( bop.getOp()==OpOp2.MULT && cdata2.isLiteral() && cdata2.getVarname().equals("2") ) @@ -207,6 +211,24 @@ public class CellTpl extends BaseTpl else //default binary out = new 
CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName)); } + else if(hop instanceof TernaryOp) + { + TernaryOp top = (TernaryOp) hop; + CNode cdata1 = cnodeData.get(0); + CNode cdata2 = cnodeData.get(1); + CNode cdata3 = cnodeData.get(2); + + //cdata1 is vector + if( TemplateUtils.isColVector(cdata1) ) + cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP); + //cdata3 is vector + if( TemplateUtils.isColVector(cdata3) ) + cdata3 = new CNodeUnary(cdata3, UnaryType.LOOKUP); + + //construct ternary cnode, primitive operation derived from OpOp3 enum name + out = new CNodeTernary(cdata1, cdata2, cdata3, + TernaryType.valueOf(top.getOp().name())); + } else if (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getOp() == AggOp.SUM && (((AggUnaryOp) hop).getDirection() == Direction.RowCol || ((AggUnaryOp) hop).getDirection() == Direction.Row) && root == hop) @@ -283,7 +305,11 @@ public class CellTpl extends BaseTpl && TemplateUtils.isVectorOrScalar(hop.getInput().get(1)) && !TemplateUtils.isBinaryMatrixRowVector(hop)) ||(TemplateUtils.isVectorOrScalar( hop.getInput().get(0)) && hop.getInput().get(1).getDataType() == DataType.MATRIX && !TemplateUtils.isBinaryMatrixRowVector(hop)) ); + boolean isTernaryVectorScalarVector = hop instanceof TernaryOp && hop.getInput().size()==3 && hop.dimsKnown() + && HopRewriteUtils.checkInputDataTypes(hop, DataType.MATRIX, DataType.SCALAR, DataType.MATRIX) + && TemplateUtils.isVector(hop.getInput().get(0)) && TemplateUtils.isVector(hop.getInput().get(2)); + return hop.getDataType() == DataType.MATRIX && TemplateUtils.isOperationSupported(hop) - && (hop instanceof UnaryOp || isBinaryMatrixScalar || isBinaryMatrixVector); + && (hop instanceof UnaryOp || isBinaryMatrixScalar || isBinaryMatrixVector || isTernaryVectorScalarVector); } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9b69f36a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java ---------------------------------------------------------------------- diff
--git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java index fd8a960..32a4d80 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java @@ -33,6 +33,7 @@ import org.apache.sysml.hops.DataOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.LiteralOp; import org.apache.sysml.hops.ReorgOp; +import org.apache.sysml.hops.TernaryOp; import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.Direction; import org.apache.sysml.hops.Hop.ReOrgOp; @@ -45,6 +46,7 @@ import org.apache.sysml.hops.codegen.cplan.CNodeOuterProduct; import org.apache.sysml.hops.codegen.cplan.CNodeTpl; import org.apache.sysml.hops.codegen.cplan.CNodeUnary; import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType; +import org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType; import org.apache.sysml.runtime.matrix.data.Pair; @@ -231,11 +233,12 @@ public class TemplateUtils public static boolean isOperationSupported(Hop h) { if(h instanceof UnaryOp) - return UnaryType.contains(((UnaryOp)h).getOp().toString()); + return UnaryType.contains(((UnaryOp)h).getOp().name()); else if(h instanceof BinaryOp) - return BinType.contains(((BinaryOp)h).getOp().toString()); - else - return false; + return BinType.contains(((BinaryOp)h).getOp().name()); + else if(h instanceof TernaryOp) + return TernaryType.contains(((TernaryOp)h).getOp().name()); + return false; } private static void rfindChildren(Hop hop, HashSet<Hop> children ) { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9b69f36a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index 802a382..e69ecbf 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -719,7 +719,7 @@ public class HopRewriteUtils return ret; } - + public static boolean isTransposeOperation(Hop hop) { return (hop instanceof ReorgOp && ((ReorgOp)hop).getOp()==ReOrgOp.TRANSPOSE); } @@ -778,6 +778,13 @@ public class HopRewriteUtils return false; } + + public static boolean checkInputDataTypes(Hop hop, DataType... dt) { + for( int i=0; i<Math.min(hop.getInput().size(), dt.length); i++ ) + if( hop.getInput().get(i).getDataType() != dt[i] ) + return false; + return hop.getInput().size() == dt.length; + } public static boolean isFullColumnIndexing(LeftIndexingOp hop) {
