[SYSTEMML-1424] Extended codegen operations and cost model

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/69d8b7c4
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/69d8b7c4
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/69d8b7c4

Branch: refs/heads/master
Commit: 69d8b7c4b53deb3a1d3e4eba99b8718366df1a86
Parents: 3547619
Author: Matthias Boehm <[email protected]>
Authored: Mon Apr 3 18:25:44 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Mon Apr 3 18:25:44 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/codegen/cplan/CNodeBinary.java   | 54 ++++++++++++--------
 .../sysml/hops/codegen/cplan/CNodeUnary.java    |  6 +--
 .../template/PlanSelectionFuseCostBased.java    | 33 +++++++++---
 3 files changed, 60 insertions(+), 33 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/69d8b7c4/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
index 5ec7231..b6b6ce5 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
@@ -35,9 +35,9 @@ public class CNodeBinary extends CNode
                VECT_LESS_SCALAR, VECT_LESSEQUAL_SCALAR, VECT_GREATER_SCALAR, VECT_GREATEREQUAL_SCALAR,
                MULT, DIV, PLUS, MINUS, MODULUS, INTDIV, 
                LESS, LESSEQUAL, GREATER, GREATEREQUAL, EQUAL,NOTEQUAL,
-               MIN, MAX, AND, OR, LOG, POW,
-               MINUS1_MULT;
-
+               MIN, MAX, AND, OR, LOG, LOG_NZ, POW,
+               MINUS1_MULT, MINUS_NZ;
+               
                public static boolean contains(String value) {
                        for( BinType bt : values()  )
                                if( bt.name().equals(value) )
@@ -85,41 +85,45 @@ public class CNodeBinary extends CNode
                                
                                /*Can be replaced by function objects*/
                                case MULT:
-                                       return "    double %TMP% = %IN1% * %IN2%;\n" ;
+                                       return "    double %TMP% = %IN1% * %IN2%;\n";
                                
                                case DIV:
-                                       return "    double %TMP% = %IN1% / %IN2%;\n" ;
+                                       return "    double %TMP% = %IN1% / %IN2%;\n";
                                case PLUS:
-                                       return "    double %TMP% = %IN1% + %IN2%;\n" ;
+                                       return "    double %TMP% = %IN1% + %IN2%;\n";
                                case MINUS:
-                                       return "    double %TMP% = %IN1% - %IN2%;\n" ;
+                                       return "    double %TMP% = %IN1% - %IN2%;\n";
                                case MODULUS:
-                                       return "    double %TMP% = LibSpoofPrimitives.mod(%IN1%, %IN2%);\n" ;
+                                       return "    double %TMP% = LibSpoofPrimitives.mod(%IN1%, %IN2%);\n";
                                case INTDIV: 
-                                       return "    double %TMP% = LibSpoofPrimitives.intDiv(%IN1%, %IN2%);\n" ;
+                                       return "    double %TMP% = LibSpoofPrimitives.intDiv(%IN1%, %IN2%);\n";
                                case LESS:
-                                       return "    double %TMP% = (%IN1% < %IN2%) ? 1 : 0;\n" ;
+                                       return "    double %TMP% = (%IN1% < %IN2%) ? 1 : 0;\n";
                                case LESSEQUAL:
-                                       return "    double %TMP% = (%IN1% <= %IN2%) ? 1 : 0;\n" ;
+                                       return "    double %TMP% = (%IN1% <= %IN2%) ? 1 : 0;\n";
                                case GREATER:
-                                       return "    double %TMP% = (%IN1% > %IN2%) ? 1 : 0;\n" ;
+                                       return "    double %TMP% = (%IN1% > %IN2%) ? 1 : 0;\n";
                                case GREATEREQUAL: 
-                                       return "    double %TMP% = (%IN1% >= %IN2%) ? 1 : 0;\n" ;
+                                       return "    double %TMP% = (%IN1% >= %IN2%) ? 1 : 0;\n";
                                case EQUAL:
-                                       return "    double %TMP% = (%IN1% == %IN2%) ? 1 : 0;\n" ;
+                                       return "    double %TMP% = (%IN1% == %IN2%) ? 1 : 0;\n";
                                case NOTEQUAL: 
-                                       return "    double %TMP% = (%IN1% != %IN2%) ? 1 : 0;\n" ;
+                                       return "    double %TMP% = (%IN1% != %IN2%) ? 1 : 0;\n";
                                
                                case MIN:
-                                       return "    double %TMP% = (%IN1% <= %IN2%) ? %IN1% : %IN2%;\n" ;
+                                       return "    double %TMP% = (%IN1% <= %IN2%) ? %IN1% : %IN2%;\n";
                                case MAX:
-                                       return "    double %TMP% = (%IN1% >= %IN2%) ? %IN1% : %IN2%;\n" ;
+                                       return "    double %TMP% = (%IN1% >= %IN2%) ? %IN1% : %IN2%;\n";
                                case LOG:
-                                       return "    double %TMP% = FastMath.log(%IN1%)/FastMath.log(%IN2%);\n" ;
+                                       return "    double %TMP% = FastMath.log(%IN1%)/FastMath.log(%IN2%);\n";
+                               case LOG_NZ:
+                                       return "    double %TMP% = (%IN1% == 0) ? 0 : FastMath.log(%IN1%)/FastMath.log(%IN2%);\n";
                                case POW:
-                                       return "    double %TMP% = Math.pow(%IN1%, %IN2%);\n" ;
+                                       return "    double %TMP% = Math.pow(%IN1%, %IN2%);\n";
                                case MINUS1_MULT:
-                                       return "    double %TMP% = 1 - %IN1% * %IN2%;\n" ;
+                                       return "    double %TMP% = 1 - %IN1% * %IN2%;\n";
+                               case MINUS_NZ:
+                                       return "    double %TMP% = (%IN1% != 0) ? %IN1% - %IN2% : 0;\n";
                                        
                                default: 
                                        throw new RuntimeException("Invalid binary type: "+this.toString());
@@ -225,6 +229,7 @@ public class CNodeBinary extends CNode
                        case DIV: return "b(/)";
                        case PLUS: return "b(+)";
                        case MINUS: return "b(-)";
+                       case POW: return "b(^)";
                        case MODULUS: return "b(%%)";
                        case INTDIV: return "b(%/%)";
                        case LESS: return "b(<)";
@@ -233,8 +238,11 @@ public class CNodeBinary extends CNode
                        case GREATEREQUAL: return "b(>=)";
                        case EQUAL: return "b(==)";
                        case NOTEQUAL: return "b(!=)";
+                       case OR: return "b(|)";
+                       case AND: return "b(&)";
                        case MINUS1_MULT: return "b(1-*)";
-                       default: return "b("+_type.name()+")";
+                       case MINUS_NZ: return "b(-nz)";
+                       default: return "b("+_type.name().toLowerCase()+")";
                }
        }
        
@@ -277,7 +285,8 @@ public class CNodeBinary extends CNode
                        case DIV: 
                        case PLUS: 
                        case MINUS: 
-                       case MINUS1_MULT:       
+                       case MINUS1_MULT:
+                       case MINUS_NZ:
                        case MODULUS: 
                        case INTDIV:    
                        //SCALAR Comparison
@@ -293,6 +302,7 @@ public class CNodeBinary extends CNode
                        case AND: 
                        case OR:                        
                        case LOG: 
+                       case LOG_NZ:    
                        case POW: 
                                _rows = 0;
                                _cols = 0;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/69d8b7c4/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
index 75b2630..119dc8c 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
@@ -28,10 +28,10 @@ public class CNodeUnary extends CNode
 {
        public enum UnaryType {
                ROW_SUMS, LOOKUP_R, LOOKUP_RC, LOOKUP0, //codegen specific
-               EXP, POW2, MULT2, SQRT, LOG,
+               EXP, POW2, MULT2, SQRT, LOG, LOG_NZ,
                ABS, ROUND, CEIL, FLOOR, SIGN, 
                SIN, COS, TAN, ASIN, ACOS, ATAN,
-               SELP, SPROP, SIGMOID, LOG_NZ; 
+               SELP, SPROP, SIGMOID; 
                
                public static boolean contains(String value) {
                        for( UnaryType ut : values()  )
@@ -156,7 +156,7 @@ public class CNodeUnary extends CNode
                        case LOOKUP_R:  return "u(ixr)";
                        case LOOKUP_RC: return "u(ixrc)";
                        case LOOKUP0:   return "u(ix0)";
-                       default:                return "u("+_type.name()+")";
+                       default:                return "u("+_type.name().toLowerCase()+")";
                }
        }
 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/69d8b7c4/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
index 653f43b..47717c2 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
@@ -453,11 +453,13 @@ public class PlanSelectionFuseCostBased extends PlanSelection
                                case CEIL:
                                case FLOOR:
                                case SIGN:
-                               case SELP:   costs = 1; break; 
+                               case SELP:    costs = 1; break; 
                                case SPROP:
-                               case SQRT:   costs = 2; break;
-                               case EXP:    costs = 18; break;
-                               case LOG:    costs = 32; break;
+                               case SQRT:    costs = 2; break;
+                               case EXP:     costs = 18; break;
+                               case SIGMOID: costs = 21; break;
+                               case LOG:    
+                               case LOG_NZ:  costs = 32; break;
                                case NCOL:
                                case NROW:
                                case PRINT:
@@ -466,6 +468,12 @@ public class PlanSelectionFuseCostBased extends PlanSelection
                                case CAST_AS_INT:
                                case CAST_AS_MATRIX:
                                case CAST_AS_SCALAR: costs = 1; break;
+                               case SIN:     costs = 18; break;
+                               case COS:     costs = 22; break;
+                               case TAN:     costs = 42; break;
+                               case ASIN:    costs = 93; break;
+                               case ACOS:    costs = 103; break;
+                               case ATAN:    costs = 40; break;
                                case CUMSUM:
                                case CUMMIN:
                                case CUMMAX:
@@ -480,6 +488,10 @@ public class PlanSelectionFuseCostBased extends PlanSelection
                                case MULT: 
                                case PLUS:
                                case MINUS:
+                               case MIN:
+                               case MAX: 
+                               case AND:
+                               case OR:
                                case EQUAL:
                                case NOTEQUAL:
                                case LESS:
@@ -487,11 +499,16 @@ public class PlanSelectionFuseCostBased extends PlanSelection
                                case GREATER:
                                case GREATEREQUAL: 
                                case CBIND:
-                               case RBIND: costs = 1; break;
-                               case DIV:   costs = 22; break;
-                               case LOG:   costs = 32; break;
-                               case POW:   costs = (HopRewriteUtils.isLiteralOfValue(
+                               case RBIND:   costs = 1; break;
+                               case INTDIV:  costs = 6; break;
+                               case MODULUS: costs = 8; break;
+                               case DIV:    costs = 22; break;
+                               case LOG:
+                               case LOG_NZ: costs = 32; break;
+                               case POW:    costs = (HopRewriteUtils.isLiteralOfValue(
                                                current.getInput().get(1), 2) ? 1 : 16); break;
+                               case MINUS_NZ:
+                               case MINUS1_MULT: costs = 2; break;
                                default:
                                        throw new RuntimeException("Cost model not "
                                                + "implemented yet for: "+((BinaryOp)current).getOp());

Reply via email to