Repository: systemml
Updated Branches:
  refs/heads/master 58ab12761 -> f1bf97baf


[MINOR] Simplify and cleanup GPU-specific rewrites (rewrite utils)

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/b429551d
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/b429551d
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/b429551d

Branch: refs/heads/master
Commit: b429551dbd9917746f0001c74c16afbdb8231592
Parents: 58ab127
Author: Matthias Boehm <[email protected]>
Authored: Thu Jul 12 15:54:29 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Thu Jul 12 15:54:29 2018 -0700

----------------------------------------------------------------------
 .../hops/rewrite/RewriteGPUSpecificOps.java     | 38 ++++++++------------
 1 file changed, 15 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/b429551d/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
index 987d9cd..1c00c6f 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
@@ -24,7 +24,6 @@ import java.util.HashMap;
 
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.hops.AggUnaryOp;
-import org.apache.sysml.hops.BinaryOp;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.Hop.AggOp;
 import org.apache.sysml.hops.Hop.Direction;
@@ -35,8 +34,6 @@ import org.apache.sysml.hops.Hop.ReOrgOp;
 import org.apache.sysml.hops.LiteralOp;
 import org.apache.sysml.hops.DnnOp;
 import org.apache.sysml.hops.OptimizerUtils;
-import org.apache.sysml.hops.ReorgOp;
-import org.apache.sysml.hops.UnaryOp;
 import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
 
 /*
@@ -97,8 +94,7 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
                        return;
                
                //recursively process children
-               for( int i=0; i<hop.getInput().size(); i++)
-               {
+               for( int i=0; i<hop.getInput().size(); i++) {
                        Hop hi = hop.getInput().get(i);
                        
                        //process childs recursively first (to allow roll-up)
@@ -116,11 +112,11 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
        }
        
        private static boolean isBiasAdd(Hop h) {
-               return h instanceof DnnOp && ((DnnOp) h).getOp() == 
OpOpDnn.BIASADD;
+               return HopRewriteUtils.isDnn(h, OpOpDnn.BIASADD);
        }
        
        private static boolean isBiasMultiply(Hop h) {
-               return h instanceof DnnOp && ((DnnOp) h).getOp() == 
OpOpDnn.BIASMULT;
+               return HopRewriteUtils.isDnn(h, OpOpDnn.BIASMULT);
        }
        
        private static boolean fitsOnGPU(Hop h, double multiplier) {
@@ -168,24 +164,22 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
        }
        
        private static boolean isUnaryMinus(Hop h) {
-               return h instanceof BinaryOp && ((BinaryOp)h).getOp() == 
OpOp2.MINUS 
-                               && 
Hop.computeSizeInformation(h.getInput().get(0)) == 0;
+               return HopRewriteUtils.isBinary(h, OpOp2.MINUS)
+                       && 
HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), 0);
        }
        
        private static boolean isOneDivideBySqrt(Hop h) {
-               return h instanceof BinaryOp && ((BinaryOp)h).getOp() == 
OpOp2.DIV 
-                               && h.getInput().get(1) instanceof UnaryOp
-                               && ((UnaryOp)h.getInput().get(1)).getOp() == 
OpOp1.SQRT
-                               && 
Hop.computeSizeInformation(h.getInput().get(0)) == 1;
+               return HopRewriteUtils.isBinary(h, OpOp2.DIV)
+                       && HopRewriteUtils.isUnary(h.getInput().get(1), 
OpOp1.SQRT)
+                       && 
HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), 1);
        }
        
-       private static Hop channelSums(Hop parent, Hop hi, int pos) 
-       {
+       private static Hop channelSums(Hop parent, Hop hi, int pos) {
                if(hi instanceof AggUnaryOp) {
                        AggUnaryOp hop = (AggUnaryOp) hi;
                        // output = rowSums(matrix(colSums(x), 
rows=numChannels, cols=imgSize*imgSize))
-                       if(hop.getOp() == AggOp.SUM && hop.getDirection() == 
Direction.Row
-                               && hop.getInput().get(0) instanceof ReorgOp && 
((ReorgOp)hop.getInput().get(0)).getOp() == ReOrgOp.RESHAPE) {
+                       if( hop.getOp() == AggOp.SUM && hop.getDirection() == 
Direction.Row
+                               && 
HopRewriteUtils.isReorg(hop.getInput().get(0), ReOrgOp.RESHAPE) ) {
                                Hop colSumsInput = 
hop.getInput().get(0).getInput().get(0);
                                if(colSumsInput instanceof AggUnaryOp && 
((AggUnaryOp)colSumsInput).getOp() == AggOp.SUM && 
((AggUnaryOp)colSumsInput).getDirection() == Direction.Col) {
                                        ArrayList<Hop> inHops = new 
ArrayList<Hop>();
@@ -206,19 +200,18 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
                return hi;
        }
        
-       private static Hop batchNormTest(Hop parent, Hop hi, int pos) 
-       {               
+       private static Hop batchNormTest(Hop parent, Hop hi, int pos) {
                // norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps))
                // hi = bias_add(bias_multiply(norm, gamma), beta)
                // 2x for input and output and 1x for overhead
-               if( isBiasAdd(hi) && isBiasMultiply(getFirstInput(hi)) && 
fitsOnGPU(hi, 3) ) {  
+               if( isBiasAdd(hi) && isBiasMultiply(getFirstInput(hi)) && 
fitsOnGPU(hi, 3) ) {
                        Hop norm = getFirstInput(getFirstInput(hi));
                        if(isBiasMultiply(norm) && 
isBiasAdd(getFirstInput(norm)) 
                                        && 
isUnaryMinus(getSecondInput(getFirstInput(norm)))
                                        && 
isOneDivideBySqrt(getSecondInput(norm))) {
                                double eps = 0;
                                Hop var = 
getFirstInput(getSecondInput(getSecondInput(norm)));
-                               if(var instanceof BinaryOp && ((BinaryOp) 
var).getOp() == OpOp2.PLUS &&
+                               if( HopRewriteUtils.isBinary(var, OpOp2.PLUS) &&
                                        (getFirstInput(var) instanceof 
LiteralOp || getSecondInput(var) instanceof LiteralOp)) {
                                        // eps + ema_var
                                        if(getFirstInput(var) instanceof 
LiteralOp) {
@@ -248,10 +241,9 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
                                                        
OpOpDnn.BATCH_NORM2D_TEST, inHops);
                                        return 
HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
                                }
-                       }                       
+                       }
                }
                
                return hi;
        }
-
 }

Reply via email to