[SYSTEMML-1424] Extended codegen operations and cost model Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/69d8b7c4 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/69d8b7c4 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/69d8b7c4
Branch: refs/heads/master Commit: 69d8b7c4b53deb3a1d3e4eba99b8718366df1a86 Parents: 3547619 Author: Matthias Boehm <[email protected]> Authored: Mon Apr 3 18:25:44 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Mon Apr 3 18:25:44 2017 -0700 ---------------------------------------------------------------------- .../sysml/hops/codegen/cplan/CNodeBinary.java | 54 ++++++++++++-------- .../sysml/hops/codegen/cplan/CNodeUnary.java | 6 +-- .../template/PlanSelectionFuseCostBased.java | 33 +++++++++--- 3 files changed, 60 insertions(+), 33 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/69d8b7c4/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java index 5ec7231..b6b6ce5 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java @@ -35,9 +35,9 @@ public class CNodeBinary extends CNode VECT_LESS_SCALAR, VECT_LESSEQUAL_SCALAR, VECT_GREATER_SCALAR, VECT_GREATEREQUAL_SCALAR, MULT, DIV, PLUS, MINUS, MODULUS, INTDIV, LESS, LESSEQUAL, GREATER, GREATEREQUAL, EQUAL,NOTEQUAL, - MIN, MAX, AND, OR, LOG, POW, - MINUS1_MULT; - + MIN, MAX, AND, OR, LOG, LOG_NZ, POW, + MINUS1_MULT, MINUS_NZ; + public static boolean contains(String value) { for( BinType bt : values() ) if( bt.name().equals(value) ) @@ -85,41 +85,45 @@ public class CNodeBinary extends CNode /*Can be replaced by function objects*/ case MULT: - return " double %TMP% = %IN1% * %IN2%;\n" ; + return " double %TMP% = %IN1% * %IN2%;\n"; case DIV: - return " double %TMP% = %IN1% / %IN2%;\n" ; + return " double %TMP% = %IN1% / %IN2%;\n"; case PLUS: - return " double %TMP% = %IN1% + %IN2%;\n" ; + return " double %TMP% = %IN1% + %IN2%;\n"; case MINUS: - return " double %TMP% = %IN1% - %IN2%;\n" ; + return " double %TMP% = %IN1% - %IN2%;\n"; case MODULUS: - return " double %TMP% = LibSpoofPrimitives.mod(%IN1%, %IN2%);\n" ; + return " double %TMP% = LibSpoofPrimitives.mod(%IN1%, %IN2%);\n"; case INTDIV: - return " double %TMP% = LibSpoofPrimitives.intDiv(%IN1%, %IN2%);\n" ; + return " double %TMP% = LibSpoofPrimitives.intDiv(%IN1%, %IN2%);\n"; case LESS: - return " double %TMP% = (%IN1% < %IN2%) ? 1 : 0;\n" ; + return " double %TMP% = (%IN1% < %IN2%) ? 1 : 0;\n"; case LESSEQUAL: - return " double %TMP% = (%IN1% <= %IN2%) ? 1 : 0;\n" ; + return " double %TMP% = (%IN1% <= %IN2%) ? 1 : 0;\n"; case GREATER: - return " double %TMP% = (%IN1% > %IN2%) ? 1 : 0;\n" ; + return " double %TMP% = (%IN1% > %IN2%) ? 1 : 0;\n"; case GREATEREQUAL: - return " double %TMP% = (%IN1% >= %IN2%) ? 1 : 0;\n" ; + return " double %TMP% = (%IN1% >= %IN2%) ? 1 : 0;\n"; case EQUAL: - return " double %TMP% = (%IN1% == %IN2%) ? 1 : 0;\n" ; + return " double %TMP% = (%IN1% == %IN2%) ? 1 : 0;\n"; case NOTEQUAL: - return " double %TMP% = (%IN1% != %IN2%) ? 1 : 0;\n" ; + return " double %TMP% = (%IN1% != %IN2%) ? 1 : 0;\n"; case MIN: - return " double %TMP% = (%IN1% <= %IN2%) ? %IN1% : %IN2%;\n" ; + return " double %TMP% = (%IN1% <= %IN2%) ? %IN1% : %IN2%;\n"; case MAX: - return " double %TMP% = (%IN1% >= %IN2%) ? %IN1% : %IN2%;\n" ; + return " double %TMP% = (%IN1% >= %IN2%) ? %IN1% : %IN2%;\n"; case LOG: - return " double %TMP% = FastMath.log(%IN1%)/FastMath.log(%IN2%);\n" ; + return " double %TMP% = FastMath.log(%IN1%)/FastMath.log(%IN2%);\n"; + case LOG_NZ: + return " double %TMP% = (%IN1% == 0) ? 0 : FastMath.log(%IN1%)/FastMath.log(%IN2%);\n"; case POW: - return " double %TMP% = Math.pow(%IN1%, %IN2%);\n" ; + return " double %TMP% = Math.pow(%IN1%, %IN2%);\n"; case MINUS1_MULT: - return " double %TMP% = 1 - %IN1% * %IN2%;\n" ; + return " double %TMP% = 1 - %IN1% * %IN2%;\n"; + case MINUS_NZ: + return " double %TMP% = (%IN1% != 0) ? %IN1% - %IN2% : 0;\n"; default: throw new RuntimeException("Invalid binary type: "+this.toString()); @@ -225,6 +229,7 @@ public class CNodeBinary extends CNode case DIV: return "b(/)"; case PLUS: return "b(+)"; case MINUS: return "b(-)"; + case POW: return "b(^)"; case MODULUS: return "b(%%)"; case INTDIV: return "b(%/%)"; case LESS: return "b(<)"; @@ -233,8 +238,11 @@ public class CNodeBinary extends CNode case GREATEREQUAL: return "b(>=)"; case EQUAL: return "b(==)"; case NOTEQUAL: return "b(!=)"; + case OR: return "b(|)"; + case AND: return "b(&)"; case MINUS1_MULT: return "b(1-*)"; - default: return "b("+_type.name()+")"; + case MINUS_NZ: return "b(-nz)"; + default: return "b("+_type.name().toLowerCase()+")"; } } @@ -277,7 +285,8 @@ public class CNodeBinary extends CNode case DIV: case PLUS: case MINUS: - case MINUS1_MULT: + case MINUS1_MULT: + case MINUS_NZ: case MODULUS: case INTDIV: //SCALAR Comparison @@ -293,6 +302,7 @@ public class CNodeBinary extends CNode case AND: case OR: case LOG: + case LOG_NZ: case POW: _rows = 0; _cols = 0; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/69d8b7c4/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java index 75b2630..119dc8c 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java @@ -28,10 +28,10 @@ public class CNodeUnary extends CNode { public enum UnaryType { ROW_SUMS, LOOKUP_R, LOOKUP_RC, LOOKUP0, //codegen specific - EXP, POW2, MULT2, SQRT, LOG, + EXP, POW2, MULT2, SQRT, LOG, LOG_NZ, ABS, ROUND, CEIL, FLOOR, SIGN, SIN, COS, TAN, ASIN, ACOS, ATAN, - SELP, SPROP, SIGMOID, LOG_NZ; + SELP, SPROP, SIGMOID; public static boolean contains(String value) { for( UnaryType ut : values() ) @@ -156,7 +156,7 @@ public class CNodeUnary extends CNode case LOOKUP_R: return "u(ixr)"; case LOOKUP_RC: return "u(ixrc)"; case LOOKUP0: return "u(ix0)"; - default: return "u("+_type.name()+")"; + default: return "u("+_type.name().toLowerCase()+")"; } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/69d8b7c4/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java index 653f43b..47717c2 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java @@ -453,11 +453,13 @@ public class PlanSelectionFuseCostBased extends PlanSelection case CEIL: case FLOOR: case SIGN: - case SELP: costs = 1; break; + case SELP: costs = 1; break; case SPROP: - case SQRT: costs = 2; break; - case EXP: costs = 18; break; - case LOG: costs = 32; break; + case SQRT: costs = 2; break; + case EXP: costs = 18; break; + case SIGMOID: costs = 21; break; + case LOG: + case LOG_NZ: costs = 32; break; case NCOL: case NROW: case PRINT: @@ -466,6 +468,12 @@ public class PlanSelectionFuseCostBased extends PlanSelection case CAST_AS_INT: case CAST_AS_MATRIX: case CAST_AS_SCALAR: costs = 1; break; + case SIN: costs = 18; break; + case COS: costs = 22; break; + case TAN: costs = 42; break; + case ASIN: costs = 93; break; + case ACOS: costs = 103; break; + case ATAN: costs = 40; break; case CUMSUM: case CUMMIN: case CUMMAX: @@ -480,6 +488,10 @@ public class PlanSelectionFuseCostBased extends PlanSelection case MULT: case PLUS: case MINUS: + case MIN: + case MAX: + case AND: + case OR: case EQUAL: case NOTEQUAL: case LESS: @@ -487,11 +499,16 @@ public class PlanSelectionFuseCostBased extends PlanSelection case GREATER: case GREATEREQUAL: case CBIND: - case RBIND: costs = 1; break; - case DIV: costs = 22; break; - case LOG: costs = 32; break; - case POW: costs = (HopRewriteUtils.isLiteralOfValue( + case RBIND: costs = 1; break; + case INTDIV: costs = 6; break; + case MODULUS: costs = 8; break; + case DIV: costs = 22; break; + case LOG: + case LOG_NZ: costs = 32; break; + case POW: costs = (HopRewriteUtils.isLiteralOfValue( current.getInput().get(1), 2) ? 1 : 16); break; + case MINUS_NZ: + case MINUS1_MULT: costs = 2; break; default: throw new RuntimeException("Cost model not " + "implemented yet for: "+((BinaryOp)current).getOp());
