Repository: systemml Updated Branches: refs/heads/master 58ab12761 -> f1bf97baf
[MINOR] Simplify and cleanup GPU-specific rewrites (rewrite utils) Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/b429551d Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/b429551d Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/b429551d Branch: refs/heads/master Commit: b429551dbd9917746f0001c74c16afbdb8231592 Parents: 58ab127 Author: Matthias Boehm <[email protected]> Authored: Thu Jul 12 15:54:29 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Thu Jul 12 15:54:29 2018 -0700 ---------------------------------------------------------------------- .../hops/rewrite/RewriteGPUSpecificOps.java | 38 ++++++++------------ 1 file changed, 15 insertions(+), 23 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/b429551d/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java index 987d9cd..1c00c6f 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java @@ -24,7 +24,6 @@ import java.util.HashMap; import org.apache.sysml.api.DMLScript; import org.apache.sysml.hops.AggUnaryOp; -import org.apache.sysml.hops.BinaryOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.Direction; @@ -35,8 +34,6 @@ import org.apache.sysml.hops.Hop.ReOrgOp; import org.apache.sysml.hops.LiteralOp; import org.apache.sysml.hops.DnnOp; import org.apache.sysml.hops.OptimizerUtils; -import org.apache.sysml.hops.ReorgOp; -import org.apache.sysml.hops.UnaryOp; import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool; /* @@ -97,8 +94,7 @@ public class RewriteGPUSpecificOps extends HopRewriteRule { return; //recursively process children - for( int i=0; i<hop.getInput().size(); i++) - { + for( int i=0; i<hop.getInput().size(); i++) { Hop hi = hop.getInput().get(i); //process childs recursively first (to allow roll-up) @@ -116,11 +112,11 @@ public class RewriteGPUSpecificOps extends HopRewriteRule { } private static boolean isBiasAdd(Hop h) { - return h instanceof DnnOp && ((DnnOp) h).getOp() == OpOpDnn.BIASADD; + return HopRewriteUtils.isDnn(h, OpOpDnn.BIASADD); } private static boolean isBiasMultiply(Hop h) { - return h instanceof DnnOp && ((DnnOp) h).getOp() == OpOpDnn.BIASMULT; + return HopRewriteUtils.isDnn(h, OpOpDnn.BIASMULT); } private static boolean fitsOnGPU(Hop h, double multiplier) { @@ -168,24 +164,22 @@ public class RewriteGPUSpecificOps extends HopRewriteRule { } private static boolean isUnaryMinus(Hop h) { - return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MINUS - && Hop.computeSizeInformation(h.getInput().get(0)) == 0; + return HopRewriteUtils.isBinary(h, OpOp2.MINUS) + && HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), 0); } private static boolean isOneDivideBySqrt(Hop h) { - return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.DIV - && h.getInput().get(1) instanceof UnaryOp - && ((UnaryOp)h.getInput().get(1)).getOp() == OpOp1.SQRT - && Hop.computeSizeInformation(h.getInput().get(0)) == 1; + return HopRewriteUtils.isBinary(h, OpOp2.DIV) + && HopRewriteUtils.isUnary(h.getInput().get(1), OpOp1.SQRT) + && HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), 1); } - private static Hop channelSums(Hop parent, Hop hi, int pos) - { + private static Hop channelSums(Hop parent, Hop hi, int pos) { if(hi instanceof AggUnaryOp) { AggUnaryOp hop = (AggUnaryOp) hi; // output = rowSums(matrix(colSums(x), rows=numChannels, cols=imgSize*imgSize)) - if(hop.getOp() == AggOp.SUM && hop.getDirection() == Direction.Row - && hop.getInput().get(0) instanceof ReorgOp && ((ReorgOp)hop.getInput().get(0)).getOp() == ReOrgOp.RESHAPE) { + if( hop.getOp() == AggOp.SUM && hop.getDirection() == Direction.Row + && HopRewriteUtils.isReorg(hop.getInput().get(0), ReOrgOp.RESHAPE) ) { Hop colSumsInput = hop.getInput().get(0).getInput().get(0); if(colSumsInput instanceof AggUnaryOp && ((AggUnaryOp)colSumsInput).getOp() == AggOp.SUM && ((AggUnaryOp)colSumsInput).getDirection() == Direction.Col) { ArrayList<Hop> inHops = new ArrayList<Hop>(); @@ -206,19 +200,18 @@ public class RewriteGPUSpecificOps extends HopRewriteRule { return hi; } - private static Hop batchNormTest(Hop parent, Hop hi, int pos) - { + private static Hop batchNormTest(Hop parent, Hop hi, int pos) { // norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps)) // hi = bias_add(bias_multiply(norm, gamma), beta) // 2x for input and output and 1x for overhead - if( isBiasAdd(hi) && isBiasMultiply(getFirstInput(hi)) && fitsOnGPU(hi, 3) ) { + if( isBiasAdd(hi) && isBiasMultiply(getFirstInput(hi)) && fitsOnGPU(hi, 3) ) { Hop norm = getFirstInput(getFirstInput(hi)); if(isBiasMultiply(norm) && isBiasAdd(getFirstInput(norm)) && isUnaryMinus(getSecondInput(getFirstInput(norm))) && isOneDivideBySqrt(getSecondInput(norm))) { double eps = 0; Hop var = getFirstInput(getSecondInput(getSecondInput(norm))); - if(var instanceof BinaryOp && ((BinaryOp) var).getOp() == OpOp2.PLUS && + if( HopRewriteUtils.isBinary(var, OpOp2.PLUS) && (getFirstInput(var) instanceof LiteralOp || getSecondInput(var) instanceof LiteralOp)) { // eps + ema_var if(getFirstInput(var) instanceof LiteralOp) { @@ -248,10 +241,9 @@ public class RewriteGPUSpecificOps extends HopRewriteRule { OpOpDnn.BATCH_NORM2D_TEST, inHops); return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop); } - } + } } return hi; } - }
