http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
index acf2e48..53d368b 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
@@ -20,811 +20,288 @@ package org.apache.sysml.hops.rewrite;
 import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
+import java.util.function.Function;
-import org.apache.sysml.conf.ConfigurationManager;
-import org.apache.sysml.hops.AggUnaryOp;
-import org.apache.sysml.hops.BinaryOp;
-import org.apache.sysml.hops.FunctionOp;
 import org.apache.sysml.hops.Hop;
-import org.apache.sysml.hops.FunctionOp.FunctionType;
-import org.apache.sysml.hops.Hop.AggOp;
-import org.apache.sysml.hops.Hop.DataOpTypes;
-import org.apache.sysml.hops.Hop.Direction;
-import org.apache.sysml.hops.Hop.OpOp1;
-import org.apache.sysml.hops.Hop.OpOp2;
 import org.apache.sysml.hops.Hop.OpOpDnn;
-import org.apache.sysml.hops.Hop.ReOrgOp;
-import org.apache.sysml.hops.DataOp;
-import org.apache.sysml.hops.LiteralOp;
-import org.apache.sysml.hops.DnnOp;
-import org.apache.sysml.hops.OptimizerUtils;
-import org.apache.sysml.hops.ReorgOp;
-import org.apache.sysml.hops.UnaryOp;
-import org.apache.sysml.parser.DMLProgram;
-import org.apache.sysml.parser.Expression.DataType;
-import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
+import static org.apache.sysml.hops.rewrite.HopDagPatternMatcher.*;
+import static org.apache.sysml.parser.Expression.DataType.MATRIX;
+import static org.apache.sysml.parser.Expression.DataType.SCALAR;
+
 /*
- * This class contains GPU-specific rewrites for following patterns:
+ * -------------------------------------------------------------------------
+ * Design documentation for hop rewrite rules that use HopDagPatternMatcher:
+ * -------------------------------------------------------------------------
+ *
+ * Typical (but not all) hop rewrite rules have the following structure:
+ * 1. Rules are grouped together in different Java classes and added in org.apache.sysml.hops.rewrite.ProgramRewriter.
+ *
+ * 2. Each rule class inherits from HopRewriteRule and implements the rewriteHopDAG method. The other class of rewrite rules, StatementBlockRewriteRule, is not covered by this approach.
+ *
+ * 3. The structure of rewriteHopDAG is common across HopRewriteRule subclasses and usually has the following pattern:
+ *    if(root of the given HOP DAG matches certain pattern) {
+ *      HopRewriteUtils.rewireAllParentChildReferences(root, newRoot)
+ *    }
+ *    else root
+ *
+ * 4. To avoid redundancy, the above logic is implemented in the abstract class HopRewriteRuleWithPatternMatcher:
+ *    ArrayList<HopPatternRewriter> patternRewriters = getPatternRewriter();
+ *    for(HopPatternRewriter patternRewriter : patternRewriters) {
+ *      hi = patternRewriter.rewrite(hi);
+ *    }
+ *
+ * 5. The developer has to inherit from HopRewriteRuleWithPatternMatcher, which implements the above logic,
+ *    and implement getPatternRewriter(), which returns an ArrayList<HopPatternRewriter>.
 *
- * 1. batchNormTest: applied when mode="test" in batch normalization nn layer.
- * norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps))
- * hi = bias_add(bias_multiply(norm, gamma), beta)
+ * 6. Since the HOP patterns do not change during execution, it is convenient to store them in a static variable:
+ *    ArrayList<HopPatternRewriter> _rewriters
 *
- * 2. channelSum:
- * output = rowSums(matrix(colSums(x), rows=numChannels, cols=imgSize*imgSize))
+ * 7. The replacement part in each entry of patternMatcher invokes the helper methods in HopRewriteUtils to create a newRoot. For example: HopRewriteUtils.createDnnOp
 *
- * 3. batchNormTrain: applied when mode="train" in batch normalization nn layer.
- * This rewrite is only enabled if none of the outputs are persistent writes as it assumes that
- * FunctionOp will introduce a transient writes. This rewrite replaces the existing outputs of the matched pattern with transient reads.
+ * 8. The below DSL would be more readable if implemented with Scala's operator overloading, but that adds a dependency on the Scala library
+ *    (specifically, Scala uses scala.Function1 to implement operator overloading).
+ *    Hence, to minimize dependencies, the DSL is implemented using static methods in the HopDagPatternMatcher class.
+ *    We can revisit this if we plan to add Scala as a hard dependency of SystemML.
+ *
+ * 9. The matcher part in each entry of patternMatcher uses the DSL implemented in HopDagPatternMatcher to improve readability.
+ *    - The DSL mentioned above follows DML syntax, which makes it convenient for an external contributor to understand and modify the HOP rewrites.
+ *    - It is important to note that the developer has to apply the same scoping rules as SystemML.
+ *    - To create a newRoot HOP, there must be a mechanism to extract the leaves of the matched pattern. This is implemented
+ *      via the leaf() method.
+ *    - Often, a new HOP should be created only if it can fit in memory. For GPU, one can use the fitsOnGPU(multiplier) helper method.
* */ -public class RewriteGPUSpecificOps extends HopRewriteRule { - - private static int _seq = 1; - - @Override - public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) { - if( roots == null ) - return roots; - - //one pass rewrite-descend (rewrite created pattern) - for( int i = 0; i < roots.size(); i++ ) - rule_GPUKernels(roots, roots.get(i), false ); - Hop.resetVisitStatus(roots, true); - - //one pass descend-rewrite (for rollup) - for( int i = 0; i < roots.size(); i++ ) - rule_GPUKernels(roots, roots.get(i), true ); - Hop.resetVisitStatus(roots, true); +public class RewriteGPUSpecificOps extends HopRewriteRuleWithPatternMatcher { + // ------------------------------------------------------------------------------------------- + + private static HopDagPatternMatcher util_channel_sums(HopDagPatternMatcher X, HopDagPatternMatcher C, HopDagPatternMatcher HW) { + // rowSums(matrix(colSums(X), rows=C, cols=HW)) + return rowSums(matrix( colSums(X), C, HW)); + } + + // Pattern 1: + private static final HopDagPatternMatcher _batchNormdX; + static { + HopDagPatternMatcher C = leaf("C", SCALAR); + HopDagPatternMatcher HW = leaf("HW", SCALAR); + HopDagPatternMatcher CHW = leaf("CHW", SCALAR); + HopDagPatternMatcher cache_inv_var = leaf("cache_inv_var", MATRIX); + HopDagPatternMatcher dout = leaf("dout", MATRIX); + HopDagPatternMatcher gamma = leaf("gamma", MATRIX); + HopDagPatternMatcher X = leaf("X", MATRIX); + HopDagPatternMatcher mean = leaf("mean", MATRIX); - return roots; - } - - @Override - public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) { - if( root == null ) - return root; - - //one pass rewrite-descend (rewrite created pattern) - rule_GPUKernels(null, root, false ); - - root.resetVisitStatus(); - - //one pass descend-rewrite (for rollup) - rule_GPUKernels(null, root, true ); + HopDagPatternMatcher centered = bias_add(X, unaryMinus(mean)); - return root; - } - - /** - * Fuse the kernel - * - * @param roots root operators - * @param hop high-level operator - * @param descendFirst true if recursively process children first - */ - private void rule_GPUKernels(ArrayList<Hop> roots, Hop hop, boolean descendFirst) - { - if(hop.isVisited()) - return; - - //recursively process children - for( int i=0; i<hop.getInput().size(); i++) { - Hop hi = hop.getInput().get(i); - - //process childs recursively first (to allow roll-up) - if( descendFirst ) - rule_GPUKernels(roots, hi, descendFirst); //see below - - if(roots != null) { - //hi = batchNormTrain(roots, hop, hi, i); - } - hi = batchNormTest(hop, hi, i); - hi = channelSums(hop, hi, i); - hi = updateNesterovX(hop, hi, i); - - if( !descendFirst ) - rule_GPUKernels(roots, hi, descendFirst); - } - - hop.setVisited(); - } - - private static boolean isBiasAdd(Hop h) { - return HopRewriteUtils.isDnn(h, OpOpDnn.BIASADD); - } - - private static boolean isBiasMultiply(Hop h) { - return HopRewriteUtils.isDnn(h, OpOpDnn.BIASMULT); - } - - private static boolean fitsOnGPU(Hop h, double multiplier) { - double memEst = multiplier*h.getMemEstimate(); - return ConfigurationManager.isGPU() && h.dimsKnown() && OptimizerUtils.isMemoryBasedOptLevel() && - memEst < OptimizerUtils.getLocalMemBudget() && memEst < GPUContextPool.initialGPUMemBudget(); - } - - private static boolean fitsOnGPU(ArrayList<Hop> inputHops, boolean isFirstSameSizeAsOutput) { - return fitsOnGPU(inputHops, isFirstSameSizeAsOutput, 0); - } - - private static boolean fitsOnGPU(ArrayList<Hop> inputHops, boolean isFirstSameSizeAsOutput, long 
additionalBytes) { - double memEst = additionalBytes; - boolean isFirst = true; - for(Hop h : inputHops) { - double est = h.getMemEstimate(); - if(est == OptimizerUtils.INVALID_SIZE) { - return false; - } - else if(isFirst && isFirstSameSizeAsOutput) { - isFirst = false; - memEst += 2*est; - } - else { - memEst += est; - } - } - return ConfigurationManager.isGPU() && OptimizerUtils.isMemoryBasedOptLevel() && - memEst < OptimizerUtils.getLocalMemBudget() && memEst < GPUContextPool.initialGPUMemBudget(); - } - - private static boolean hasFirstInput(Hop h) { - return !(h == null || h.getInput() == null || h.getInput().size() < 1); - } - - private static Hop getFirstInput(Hop h) { - if(h == null || h.getInput() == null || h.getInput().size() < 1) { - throw new RuntimeException("No input available for " + h); - } - return h.getInput().get(0); - } - - private static boolean hasSecondInput(Hop h) { - return !(h == null || h.getInput() == null || h.getInput().size() < 2); - } - - private static Hop getSecondInput(Hop h) { - if(h == null || h.getInput() == null || h.getInput().size() < 2) { - throw new RuntimeException("Expected atleast two inputs for " + h); + // dnorm = bias_multiply(dout, gamma) # shape (N, C*Hin*Win) + HopDagPatternMatcher dnorm = bias_multiply(dout, gamma); + // dmean_norm_branch = util::channel_sums(bias_multiply(dnorm, -cache_inv_var), C, Hin, Win) + HopDagPatternMatcher dmean_norm_branch = util_channel_sums(bias_multiply(dnorm, unaryMinus(cache_inv_var)), C, HW) ; + // dvar = util::channel_sums((-1/2) * bias_multiply(centered, cache_inv_var^3) * dnorm, + // C, Hin, Win) # shape (C, 1) + HopDagPatternMatcher dvar = util_channel_sums(mult(mult(-0.5, bias_multiply(centered, pow(cache_inv_var, 3))), dnorm), C, HW); + // dmean_var_branch = util::channel_sums((-2*oneByN*oneByHW) * centered, C, Hin, Win) * dvar + HopDagPatternMatcher dmean_var_branch = + mult(util_channel_sums(mult(leaf("const3", SCALAR), centered), C, HW), dvar); + // dX_norm_branch = bias_multiply(dnorm, cache_inv_var) + HopDagPatternMatcher dX_norm_branch = bias_multiply(dnorm, cache_inv_var); + // dX_mean_branch = (oneByN*oneByHW) * bias_add(matrix(0, rows=1, cols=C*Hin*Win), dmean) + HopDagPatternMatcher dX_mean_branch = mult(leaf("const1", SCALAR), bias_add(matrix(0, 1, CHW), + plus(dmean_norm_branch, dmean_var_branch) )); + // dX_var_branch = (2*oneByN*oneByHW) * bias_multiply(centered, dvar) + HopDagPatternMatcher dX_var_branch = mult(leaf("const2", SCALAR), bias_multiply(centered, dvar)); + _batchNormdX = plus(plus(dX_norm_branch, dX_mean_branch), dX_var_branch).fitsOnGPU(2); + } + private static final Function<Hop, Hop> _batchNormdXReplacer = hi -> { + // double CHW = _batchNormdX.getLiteralValue("CHW"); + double HW = _batchNormdX.getLiteralValue("HW"); + double C = _batchNormdX.getLiteralValue("C"); + double const1 = _batchNormdX.getLiteralValue("const1"); // (oneByN*oneByHW) + double const2 = _batchNormdX.getLiteralValue("const2"); // (2*oneByN*oneByHW) + double const3 = _batchNormdX.getLiteralValue("const3"); // (-2*oneByN*oneByHW) + if(2*const1 == const2 && const3 == -const2 && + hasSameDimensions(_batchNormdX.getMatchedHop("gamma"), _batchNormdX.getMatchedHop("mean")) && + hasSameDimensions(_batchNormdX.getMatchedHop("gamma"), _batchNormdX.getMatchedHop("cache_inv_var")) && + _batchNormdX.getMatchedHop("X").getDim2() == C*HW && + checkDimensions(_batchNormdX.getMatchedHop("gamma"), (long)C, 1)) { + LOG.debug("Applied batchNormdX rewrite."); + Hop newHop = HopRewriteUtils.createDnnOp(_batchNormdX, 
OpOpDnn.BATCH_NORM2D_BACKWARD_DX, + "X", "dout", "gamma", "mean", "cache_inv_var"); + return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop); } - return h.getInput().get(1); - } - - private static Hop getThirdInput(Hop h) { - if(h == null || h.getInput() == null || h.getInput().size() < 3) { - throw new RuntimeException("Expected atleast three inputs for " + h); - } - return h.getInput().get(2); - } - - private static boolean isUnaryMinus(Hop h) { - return HopRewriteUtils.isBinary(h, OpOp2.MINUS) - && HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), 0); - } - - private static boolean isOneDivideBySqrt(Hop h) { - return HopRewriteUtils.isBinary(h, OpOp2.DIV) - && HopRewriteUtils.isUnary(h.getInput().get(1), OpOp1.SQRT) - && HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), 1); - } - - private static Hop channelSums(Hop parent, Hop hi, int pos) { - if(hi instanceof AggUnaryOp) { - AggUnaryOp hop = (AggUnaryOp) hi; - // output = rowSums(matrix(colSums(x), rows=numChannels, cols=imgSize*imgSize)) - if( hop.getOp() == AggOp.SUM && hop.getDirection() == Direction.Row - && HopRewriteUtils.isReorg(hop.getInput().get(0), ReOrgOp.RESHAPE) ) { - Hop colSumsInput = hop.getInput().get(0).getInput().get(0); - if(colSumsInput instanceof AggUnaryOp && ((AggUnaryOp)colSumsInput).getOp() == AggOp.SUM && ((AggUnaryOp)colSumsInput).getDirection() == Direction.Col) { - ArrayList<Hop> inHops = new ArrayList<Hop>(); - inHops.add(colSumsInput.getInput().get(0)); - long numChannels = Hop.computeSizeInformation(hop.getInput().get(0).getInput().get(1)); - long HW = Hop.computeSizeInformation(hop.getInput().get(0).getInput().get(2)); - if(numChannels > 0 && HW > 0 && fitsOnGPU(inHops, false, numChannels*8)) { - inHops.add(new LiteralOp(numChannels)); - inHops.add(new LiteralOp(HW)); - LOG.debug("Applied channelSums rewrite."); - Hop newHop = new DnnOp(hi.getName(), hi.getDataType(), hi.getValueType(), - OpOpDnn.CHANNEL_SUMS, inHops); - return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop); - } - } - } + else if(DEBUG_REWRITES) { + System.out.println("Couldnot apply batchNormdX rewrite."); + System.out.println((2*const1) + " == " + const2 + " && " + const3 + "== -" + const2 + + " && " + hasSameDimensions(_batchNormdX.getMatchedHop("gamma"), _batchNormdX.getMatchedHop("mean")) + " && " + + hasSameDimensions(_batchNormdX.getMatchedHop("gamma"), _batchNormdX.getMatchedHop("cache_inv_var")) + " && " + + _batchNormdX.getMatchedHop("X").getDim2() + " == " + C + "*" + HW + " && " + + checkDimensions(_batchNormdX.getMatchedHop("gamma"), (long)C, 1)); } return hi; - } - - private static boolean isRowMeans(Hop h) { - return h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.MEAN && ((AggUnaryOp)h).getDirection() == Direction.Row; - } - - private static boolean isRowVars(Hop h) { - return h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.VAR && ((AggUnaryOp)h).getDirection() == Direction.Row; - } - - private static boolean isRowVars(Hop h, Hop childHop) { - return isRowVars(h) && getFirstInput(h) == childHop; - } - - private static boolean isColMeans(Hop h) { - return h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.MEAN && ((AggUnaryOp)h).getDirection() == Direction.Col; - } - - private static boolean isColVars(Hop h) { - return h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.VAR && ((AggUnaryOp)h).getDirection() == Direction.Col; - } - - private static boolean isReshape(Hop h) { - return h instanceof ReorgOp && ((ReorgOp)h).getOp() == ReOrgOp.RESHAPE; - 
} - - private static boolean isReshape(Hop h, long expectedRows, long expectedCols) { - return h instanceof ReorgOp && ((ReorgOp)h).getOp() == ReOrgOp.RESHAPE && - Hop.computeSizeInformation(getSecondInput(h)) == expectedRows && - Hop.computeSizeInformation(getThirdInput(h)) == expectedCols; - } - - private static boolean isBinaryAdd(Hop h) { - return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.PLUS; - } - - private static boolean isBinaryMSAdd(Hop h, double expectedValue) { - return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.PLUS - && getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.SCALAR - && OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(h), new HashMap<>()) == expectedValue; - } - - private static boolean isBinaryMMAdd(Hop h) { - return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.PLUS - && getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.MATRIX; - } - - private static boolean isBinaryMMMinus(Hop h) { - return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MINUS - && getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.MATRIX; - } - - private static boolean isBinaryMSMult(Hop h, double expectedValue) { - return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MULT - && getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.SCALAR - && OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(h), new HashMap<>()) == expectedValue; - } - - private static boolean isBinarySSMinus(Hop h) { - return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MINUS - && getFirstInput(h).getDataType() == DataType.SCALAR && getSecondInput(h).getDataType() == DataType.SCALAR; - } - - private static boolean isBinarySSDiv(Hop h) { - return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.DIV - && getFirstInput(h).getDataType() == DataType.SCALAR && getSecondInput(h).getDataType() == DataType.SCALAR; - } - - private static boolean isBinarySMDiv(Hop h, double expectedValue) { - return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.DIV - && getFirstInput(h).getDataType() == DataType.SCALAR && getSecondInput(h).getDataType() == DataType.MATRIX - && OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(h), new HashMap<>()) == expectedValue; - } - - private static boolean isAnyBinaryAdd(ArrayList<Hop> hops) { - if(hops != null) { - for(Hop h : hops) { - if(h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.PLUS) - return true; - } - } - return false; - } - - private static boolean isBinaryMSMult(Hop h) { - return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MULT - && getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.SCALAR; - } - - private static boolean isBinarySMMult(Hop h) { - return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MULT - && getSecondInput(h).getDataType() == DataType.MATRIX && getFirstInput(h).getDataType() == DataType.SCALAR; - } - - private static boolean isBinarySMMult(Hop h, double expectedVal) { - return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MULT - && getSecondInput(h).getDataType() == DataType.MATRIX && getFirstInput(h).getDataType() == DataType.SCALAR - && getValue(getFirstInput(h)) == expectedVal; - } - - private static double getValue(Hop h) { - return OptimizerUtils.rEvalSimpleDoubleExpression(h, new HashMap<>()); - } - - /** - * 
Checks if the "mean" hop is a moving average of mean in batch normalization layer. - * - * @param mean hop to check against - * @param X input data - * @return true if the "mean" hop is a moving average of mean in batch normalization layer. - */ - private static boolean isBatchNormTrainMean(Hop mean, Hop X) { - // subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win) - // mean = rowMeans(subgrp_means) - return isRowMeans(mean) && isReshape(getFirstInput(mean)) && isColMeans(getFirstInput(getFirstInput(mean))) - && getFirstInput(getFirstInput(getFirstInput(mean))) == X; - } - - /** - * Checks for nrow(X) pattern - * - * @param expr hop to be matched - * @param X input X - * @return true if expr is nrow(X) else false - */ - private static boolean isNrowOfX(Hop expr, Hop X) { - return expr instanceof UnaryOp && ((UnaryOp)expr).getOp() == OpOp1.NROW && getFirstInput(expr) == X; - } - - /** - * Checks for the colVars(X) * ((N-1)/N) pattern - * - * @param expr hop to be matched - * @param X input X - * @param ignoreCorrectionTerm whether to ignore the correction term ((N-1)/N). - * @return true if expr is colVars(X) * ((N-1)/N) else false - */ - private static boolean isCorrectedColVars(Hop expr, Hop X, boolean ignoreCorrectionTerm) { - // colVars(X) * ((N-1)/N) - if(isColVars(expr) && getFirstInput(expr) == X) { - // Support no correction as well in this rewrite - return true; - } - else if(X.rowsKnown()) { - return isBinaryMSMult(expr, ((double)X.getDim1()-1)/X.getDim1()) && - isColVars(getFirstInput(expr)) && getFirstInput(getFirstInput(expr)) == X; + }; + + + + // Pattern 2: + // subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win) + // var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win)) + private static final HopDagPatternMatcher _batchNormUpdatedVar; + static { + HopDagPatternMatcher subgrp_vars = + matrix( + mult(colVars(leaf("X", MATRIX).fitsOnGPU(2)), leaf("varConst1", SCALAR)), // colVars(X) * ((N-1)/N) + leaf("C", SCALAR), // rows=C + leaf("HW", SCALAR)); // cols=Hin*Win + _batchNormUpdatedVar = + mm_plus( + rowMeans(subgrp_vars), + mult(rowVars(leaf("subgrp_means", MATRIX)), leaf("varConst2", SCALAR))); // rowVars(subgrp_means)*varConst2 + } + private static final Function<Hop, Hop> _batchNormUpdatedVarReplacer = hi -> { + double HW = _batchNormUpdatedVar.getLiteralValue("HW"); + if(_batchNormUpdatedVar.getLiteralValue("varConst2") == ((HW-1)/HW)) { + LOG.debug("Applied batchNormUpdatedVar rewrite."); + Hop newHop = HopRewriteUtils.createDnnOp(_batchNormUpdatedVar, OpOpDnn.UPDATE_EMA_VAR, + // varConst1 => ((N-1)/N) + "subgrp_means", "X", "C", "HW", "varConst1"); + return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop); } - else if(isBinaryMSMult(expr) && - isColVars(getFirstInput(expr)) && getFirstInput(getFirstInput(expr)) == X) { - if(ignoreCorrectionTerm) { - return true; - } - Hop tmp = getSecondInput(expr); - // ((N-1)/N) - boolean isNMinus1Pattern = isBinarySSDiv(tmp) && isBinarySSMinus(getFirstInput(tmp)) && - getFirstInput(getFirstInput(tmp)) == getSecondInput(tmp) && - OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(getFirstInput(tmp)), new HashMap<>()) == 1; - boolean ret = isNMinus1Pattern && isNrowOfX(getSecondInput(tmp), X); - if(LOG.isDebugEnabled()) { - LOG.debug("Is the corrected column variance pattern for batch_norm_train rewrite when number of rows of X unknown matched:" + ret); - } - return ret; - } - return false; - } - - /** - * Checks if the "var" hop is a moving average of variance in batch 
normalization layer. - * - * @param mean previously matched mean hop - * @param var the hop to check against - * @param X input data hop - * @param subgrpMeans mean for subgroup mean - * @param ignoreCorrectionTerm whether to incore the correct term (see isCorrectedColVars method in this class) - * @return true if the "var" hop is a moving average of variance in batch normalization layer. - */ - private static boolean isBatchNormTrainVar(Hop mean, Hop var, Hop X, Hop subgrpMeans, boolean ignoreCorrectionTerm) { - long numChannels = Hop.computeSizeInformation(getSecondInput(getFirstInput(mean))); - long HW = Hop.computeSizeInformation(getThirdInput(getFirstInput(mean))); - // subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win) - // var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win)) - return numChannels > 0 && HW > 0 && isBinaryMMAdd(var) && isRowMeans(getFirstInput(var)) && - // matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win) - isReshape(getFirstInput(getFirstInput(var)), numChannels, HW) && - isCorrectedColVars(getFirstInput(getFirstInput(getFirstInput(var))), X, ignoreCorrectionTerm) && - // rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win)) - isBinaryMSMult(getSecondInput(var), ((((double)HW)-1)/HW)) && - isRowVars(getFirstInput(getSecondInput(var)), subgrpMeans); - } - - /** - * Checks and returns the matched hops for expression ema_mean_upd = mu*ema_mean + (1-mu)*mean - * - * @param rhsTimesOps hop representing BinaryOp of expression (1-mu)*mean - * @param mu value of mu - * @return an array [ema_mean_upd, ema_mean] if expression matched, else null - */ - private static Hop [] getUpdatedMovingAverageExpressions(Hop rhsTimesOp, double mu) { - if(rhsTimesOp == null || rhsTimesOp.getParent() == null || rhsTimesOp.getParent().size() != 1 || - !isBinarySMMult(rhsTimesOp) || !isBinaryAdd(rhsTimesOp.getParent().get(0))) - return null; - - // Check (1-mu)*mean - double expectedOneMinusMu = OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(rhsTimesOp), new HashMap<>()); - Hop plusOp = rhsTimesOp.getParent().get(0); - Hop lhsTimesOp = null; - if(plusOp.getInput().get(0) == rhsTimesOp) { - lhsTimesOp = plusOp.getInput().get(1); - } - else { - lhsTimesOp = plusOp.getInput().get(0); - } - - if(expectedOneMinusMu == (1-mu) && plusOp.getParent() != null && plusOp.getParent().size() == 1 && - isBinarySMMult(lhsTimesOp) && OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(lhsTimesOp), new HashMap<>()) == mu) { - return new Hop[] { - plusOp.getParent().get(0), - getSecondInput(lhsTimesOp), - getSecondInput(rhsTimesOp) - }; - } - return null; - } - - /** - * Checks (if exactly one of rhsTimesOps) and returns the matched hops for expression ema_mean_upd = mu*ema_mean + (1-mu)*mean - * - * @param rhsTimesOps array list of hop representing BinaryOp of expression (1-mu)*mean - * @param mu value of mu - * @return an array [ema_mean_upd, ema_mean] if any of the expression matched, else null - */ - private static Hop [] getUpdatedMovingAverageExpressions(ArrayList<Hop> rhsTimesOps, double mu) { - if(rhsTimesOps == null || rhsTimesOps.size() == 0) - return null; - - Hop [] ret = null; - for(Hop h : rhsTimesOps) { - boolean matched = isUpdatedMovingAverageExpression(h, mu); - if(matched && ret != null) { - return null; // Multiple matches, cannot decide which one to fuse - } - else if(matched) { - ret = getUpdatedMovingAverageExpressions(h, mu); - } - } - - return ret; - } - - /** - * Checks and returns the mu in the expression ema_mean_upd = mu*ema_mean + 
(1-mu)*mean - * - * @param rhsTimesOps hop representing BinaryOp of expression (1-mu)*mean - * @return value of mu if the expression matched else null - */ - private static Double getMuFromUpdatedMovingAverageExpressions(ArrayList<Hop> rhsTimesOps) { - if(rhsTimesOps == null || rhsTimesOps.size() == 0) - return null; - - Double ret = null; - for(Hop h : rhsTimesOps) { - boolean matched = isUpdatedMovingAverageExpression(h); - if(matched && ret != null) { - return null; // Multiple matches, cannot decide which one to fuse - } - else if(matched) { - ret = -(OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(h), new HashMap<>())-1); - } - } - return ret; - } - - /** - * Checks for the expression ema_mean_upd = mu*ema_mean + (1-mu)*mean - * - * @param rhsTimesOps hop representing BinaryOp of expression (1-mu)*mean - * @return true if expression matched - */ - private static boolean isUpdatedMovingAverageExpression(Hop rhsTimesOp) { - if(rhsTimesOp == null || rhsTimesOp.getParent() == null || rhsTimesOp.getParent().size() != 1 || - !isBinarySMMult(rhsTimesOp) || !isBinaryAdd(rhsTimesOp.getParent().get(0))) - return false; - - // Check (1-mu)*mean - Hop plusOp = rhsTimesOp.getParent().get(0); - Hop lhsTimesOp = null; - if(plusOp.getInput().get(0) == rhsTimesOp) { - lhsTimesOp = plusOp.getInput().get(1); - } - else { - lhsTimesOp = plusOp.getInput().get(0); - } - - if(plusOp.getParent() != null && plusOp.getParent().size() == 1 && isBinarySMMult(lhsTimesOp)) { - return true; - } - return false; - } + return hi; + }; - // ema_mean_upd = mu*ema_mean + (1-mu)*mean - // Returns true if expression matched, else false - private static boolean isUpdatedMovingAverageExpression(Hop rhsTimesOp, double mu) { - if(rhsTimesOp == null || rhsTimesOp.getParent() == null || rhsTimesOp.getParent().size() != 1 || - !isBinarySMMult(rhsTimesOp) || !isBinaryAdd(rhsTimesOp.getParent().get(0))) - return false; - - // Check (1-mu)*mean - double expectedOneMinusMu = OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(rhsTimesOp), new HashMap<>()); - Hop plusOp = rhsTimesOp.getParent().get(0); - Hop lhsTimesOp = null; - if(plusOp.getInput().get(0) == rhsTimesOp) { - lhsTimesOp = plusOp.getInput().get(1); - } - else { - lhsTimesOp = plusOp.getInput().get(0); - } - if(expectedOneMinusMu == (1-mu) && plusOp.getParent() != null && plusOp.getParent().size() == 1 && - isBinarySMMult(lhsTimesOp) && OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(lhsTimesOp), new HashMap<>()) == mu) { - return true; - } - return false; - } - - /** - * Checks for the expression 1/sqrt(denom) - * - * @param denom denominator of the expression to be matched - * @return true if the expression 1/sqrt(denom) matched else false - */ - private static boolean isOneBySqrt(Hop denom) { - return denom.getParent() != null && denom.getParent().get(0) instanceof UnaryOp && - ((UnaryOp)denom.getParent().get(0)).getOp() == OpOp1.SQRT && - denom.getParent().get(0).getParent() != null && denom.getParent().get(0).getParent().size() == 1 && - isBinarySMDiv(denom.getParent().get(0).getParent().get(0), 1); - } - - /** - * Checks for the batch norm (mode="train") pattern using the helper isBatchNormTrainMean and isBatchNormTrainVar - * and returns a new FunctionOp if matched - * - * @param roots root hops of the given statement block - * @param parent parent of the input - * @param hi input to be matched - * @param pos position - * @return a new FunctionOp or hi - */ - @SuppressWarnings("unused") - private static Hop batchNormTrain(ArrayList<Hop> 
roots, Hop parent, Hop hi, int pos) - { + // Pattern 3: + private static final HopDagPatternMatcher _batchNormTest; + static { // norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps)) + HopDagPatternMatcher norm = + bias_multiply( + bias_add(leaf("X", MATRIX), unaryMinus(leaf("mean", MATRIX))), // bias_add(X, -mean) + div(1, sqrt(plus(leaf("var", MATRIX), leaf("eps", SCALAR))))); // 1/sqrt(var+eps) // hi = bias_add(bias_multiply(norm, gamma), beta) - // 2x for input and output and 1x for overhead - // fitsOnGPU(hi, 3) - if( hasFirstInput(hi) && isBiasAdd(hi) && isBiasMultiply(getFirstInput(hi)) ) { - Hop norm = getFirstInput(getFirstInput(hi)); - if(hasSecondInput(norm) && isBiasMultiply(norm) && isBiasAdd(getFirstInput(norm)) - && hasSecondInput(getFirstInput(norm)) && isUnaryMinus(getSecondInput(getFirstInput(norm))) - && isOneDivideBySqrt(getSecondInput(norm))) { - double eps = 0; - Hop var = getFirstInput(getSecondInput(getSecondInput(norm))); - if(isBinaryAdd(var) && (getFirstInput(var) instanceof LiteralOp || getSecondInput(var) instanceof LiteralOp)) { - // eps + ema_var - if(getFirstInput(var) instanceof LiteralOp) { - eps = OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(var), new HashMap<>()); - var = getSecondInput(var); - } - else { - eps = OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(var), new HashMap<>()); - var = getFirstInput(var); - } - } - // Generate batch norm test op - Hop X = getFirstInput(getFirstInput(norm)); - Hop mean = getSecondInput(getSecondInput(getFirstInput(norm))); - - if(hasFirstInput(mean) && isBatchNormTrainMean(mean , X) && isBatchNormTrainVar(mean, var, X, getFirstInput(mean), false) && - mean.getParent() != null && mean.getParent().size() >= 2 && - var.getParent() != null && var.getParent().size() == 2) { - Hop gamma = getSecondInput(getFirstInput(hi)); - Hop beta = getSecondInput(hi); - - // Always get mu from variance as it will have exactly one match of fusion pattern - Double potentialMu = getMuFromUpdatedMovingAverageExpressions(var.getParent()); - if(potentialMu == null) - return hi; - double mu = potentialMu; - - Hop [] means = getUpdatedMovingAverageExpressions(mean.getParent(), mu); - Hop [] vars = getUpdatedMovingAverageExpressions(var.getParent(), mu); - if(means == null || vars == null) - return hi; - - Hop varPlusEps = null; - boolean isFirstBinaryAddOp = isAnyBinaryAdd(var.getParent().get(0).getParent()); - boolean isSecondBinaryAddOp = isAnyBinaryAdd(var.getParent().get(1).getParent()); - if(isFirstBinaryAddOp && !isSecondBinaryAddOp) { - varPlusEps = var.getParent().get(1); - } - else if(!isFirstBinaryAddOp && isSecondBinaryAddOp) { - varPlusEps = var.getParent().get(0); - } - if(varPlusEps != null && isBinaryMSAdd(varPlusEps, eps) && isOneBySqrt(varPlusEps)) { - - Hop cache_var = varPlusEps.getParent().get(0).getParent().get(0); - Hop ema_mean_upd = means[0]; - Hop ema_var_upd = vars[0]; - Hop ema_mean = means[1]; - Hop ema_var = vars[1]; - Hop cache_mean = means[2]; - - - ArrayList<Hop> inHops = new ArrayList<Hop>(); - inHops.add(X); - inHops.add(gamma); - inHops.add(beta); - inHops.add(ema_mean); - inHops.add(ema_var); - inHops.add(new LiteralOp(eps)); - inHops.add(new LiteralOp(mu)); - Hop [] oldHops = {hi, ema_mean_upd, ema_var_upd, cache_mean, cache_var}; - - // Since FunctionOp adds transientwrite explicitly, persistent writes are not supported - if(!isAnyPersistentWrite(oldHops)) { - LOG.debug("Applied batchNormTrain rewrite."); - ArrayList<Hop> outputs = getMultiOutputHops(roots, oldHops); - 
FunctionOp ret = new FunctionOp(FunctionType.MULTIRETURN_BUILTIN, DMLProgram.INTERNAL_NAMESPACE, "batch_norm2d_train", - null, inHops, outputs.stream().map(h -> h.getName()).toArray(String[]::new), outputs); - Collections.reverse(roots); - roots.add(ret); - Collections.reverse(roots); - return ret; - } - } - - } + _batchNormTest = + bias_add( + bias_multiply(norm, leaf("gamma", MATRIX)), + leaf("beta", MATRIX)) + .fitsOnGPU(3); + } + private static final Function<Hop, Hop> _batchNormTestReplacer = hi -> { + LOG.debug("Applied batchNormTest rewrite."); + Hop newHop = HopRewriteUtils.createDnnOp(_batchNormTest, OpOpDnn.BATCH_NORM2D_TEST, "X", "gamma", "beta", "mean", "var", "eps"); + return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop); + }; + + // Pattern 4: + // rowSums(matrix(colSums(X), rows=C, cols=HW)) + private static final HopDagPatternMatcher _channelSums = util_channel_sums(leaf("X", MATRIX).fitsOnGPU(2), leaf("C", SCALAR), leaf("HW", SCALAR));; + private static final Function<Hop, Hop> _channelSumsReplacer = hi -> { + LOG.debug("Applied channelSums rewrite."); + Hop newHop = HopRewriteUtils.createDnnOp(_channelSums, OpOpDnn.CHANNEL_SUMS, "X", "C", "HW"); + return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop); + }; + + // Pattern 5: + // (X - mu*v_prev) + (1+mu)*v + private static final HopDagPatternMatcher _updateNesterovX = + mm_plus( + minus( // X - mu*v_prev + leaf("X", MATRIX), + mult( // mu*v_prev + leaf("mu", SCALAR), + leaf("v_prev", MATRIX))), + mult( // (1+mu)*v + leaf("onePlusMu", SCALAR), + leaf("v", MATRIX))) + .fitsOnGPU(3); + private static final Function<Hop, Hop> _updateNesterovXReplacer = hi -> { + if((1+_updateNesterovX.getLiteralValue("mu")) == _updateNesterovX.getLiteralValue("onePlusMu")) { + Hop X = _updateNesterovX.getMatchedHop("X"); + Hop v = _updateNesterovX.getMatchedHop("v"); + Hop v_prev = _updateNesterovX.getMatchedHop("v_prev"); + if(hasSameDimensions(X, v) && hasSameDimensions(X, v_prev)) { + LOG.debug("Applied updateNesterovX rewrite."); + Hop newHop = HopRewriteUtils.createDnnOp(_updateNesterovX, OpOpDnn.UPDATE_NESTEROV_X, "X", "v", "v_prev", "mu"); + return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop); } } - return hi; - } - - // ------------------------------------------------------------ - /** - * Checks if any of the given output hop is a persistent write. - * - * @param outputHops output hops to check - * @return true if any of the hop is a persistent write else false. 
- */ - private static boolean isAnyPersistentWrite(Hop [] outputHops) { - for(Hop outHop : outputHops) { - if(HopRewriteUtils.isData(outHop, DataOpTypes.PERSISTENTWRITE)) - return true; + }; + + // Pattern 6: + // matrix(colMeans(X), rows=C, cols=Hin*Win) + // This avoids unnecessary copy by the reshape operator + private static final HopDagPatternMatcher _reshapeColMeans = + matrix( + colMeans(leaf("X", MATRIX).fitsOnGPU(2)), // colMeans(X) + leaf("C", SCALAR), + leaf("HW", SCALAR)); + private static final Function<Hop, Hop> _reshapeColMeansReplacer = hi -> { + LOG.debug("Applied reshapeColMeans rewrite."); + Hop newHop = HopRewriteUtils.createDnnOp(_reshapeColMeans, OpOpDnn.RESHAPE_COLMEANS, "X", "C", "HW"); + return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop); + }; + + // Pattern 7: + // mu*ema_mean + (1-mu)*mean + private static final HopDagPatternMatcher _updateEMA = + mm_plus( + mult( // mu*ema_mean + leaf("mu", SCALAR), + leaf("ema_mean", MATRIX)), + mult( // (1-mu)*mean + leaf("oneMinusMu", SCALAR), + leaf("mean", MATRIX))) + .fitsOnGPU(3); + private static final Function<Hop, Hop> _updateEMAReplacer = hi -> { + if((1-_updateEMA.getLiteralValue("mu")) == _updateEMA.getLiteralValue("oneMinusMu")) { + LOG.debug("Applied updateEMA rewrite."); + Hop newHop = HopRewriteUtils.createDnnOp(_updateEMA, OpOpDnn.UPDATE_EMA, "ema_mean", "mean", "mu"); + return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop); } - return false; - } - - /** - * Returns output hop for a multi-output FunctionOp to be created by rewrite. - * - * @param roots root hops of statement block - * @param oldHops old output hops of the pattern - * @return new output hops that should be passed to FunctionOp - */ - private static ArrayList<Hop> getMultiOutputHops(ArrayList<Hop> roots, Hop [] oldHops) { - ArrayList<Hop> ret = new ArrayList<>(); - for(int i = 0; i < oldHops.length; i++) { - // Create a transient read as FunctionOp will add a transient write. - if(HopRewriteUtils.isData(oldHops[i], DataOpTypes.PERSISTENTWRITE)) - throw new RuntimeException("Persistent write is not supported as output for the given rewrite." + oldHops[i]); - // Generate a new name if the old output was not transient write. - String name = HopRewriteUtils.isData(oldHops[i], DataOpTypes.TRANSIENTWRITE) ? oldHops[i].getName() : "_genGPU" + (_seq++); - DataOp tRead = HopRewriteUtils.createTransientRead(name, oldHops[i]); - HopRewriteUtils.rewireAllParentChildReferences(oldHops[i], tRead); - ret.add(tRead); - // Remove old output from roots to avoid unnecessary computation. 
- if(roots.contains(oldHops[i])) { - roots.remove(oldHops[i]); - } + return hi; + }; + + // Pattern 8: + // 1/sqrt(var+epsilon) + private static final HopDagPatternMatcher _invVar = + div(1, + sqrt( // var+epsilon + plus( leaf("var", MATRIX), + leaf("eps", SCALAR) ))) + .fitsOnGPU(2); + private static final Function<Hop, Hop> _invVarReplacer = hi -> { + LOG.debug("Applied computeInverseVariance rewrite."); + Hop newHop = HopRewriteUtils.createDnnOp(_invVar, OpOpDnn.INV_VAR, "var", "eps"); + return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop); + }; + + + private static ArrayList<HopPatternRewriter> _rewriters = null; + public ArrayList<HopPatternRewriter> getPatternRewriter() { + if(_rewriters == null) { + ArrayList<HopPatternRewriter> rewriters = new ArrayList<>(); + rewriters.add(new HopPatternRewriter("batchNormdX", _batchNormdX, _batchNormdXReplacer)); + rewriters.add(new HopPatternRewriter("batchNormUpdatedVar", _batchNormUpdatedVar, _batchNormUpdatedVarReplacer)); + rewriters.add(new HopPatternRewriter("batchNormTest", _batchNormTest, _batchNormTestReplacer)); + rewriters.add(new HopPatternRewriter("channelSums", _channelSums, _channelSumsReplacer)); + rewriters.add(new HopPatternRewriter("updateNesterovX", _updateNesterovX, _updateNesterovXReplacer)); + rewriters.add(new HopPatternRewriter("reshapeColMeans", _reshapeColMeans, _reshapeColMeansReplacer)); + rewriters.add(new HopPatternRewriter("updateEMA", _updateEMA, _updateEMAReplacer)); + rewriters.add(new HopPatternRewriter("invVar", _invVar, _invVarReplacer)); + _rewriters = rewriters; } - return ret; + return _rewriters; } - // ------------------------------------------------------------ - /** - * Checks for the nesterov_update_x pattern (X = X - mu*v_prev + (1+mu)*v) - * and returns a new DnnOp if matched - * - * @param parent parent of the input - * @param hi input to be matched - * @param pos position - * @return a new DnnOp or hi - */ - private static Hop updateNesterovX(Hop parent, Hop hi, int pos) { - if(fitsOnGPU(hi, 4) && isBinaryMMAdd(hi) && isBinaryMMMinus(getFirstInput(hi)) - && isBinarySMMult(getSecondInput(getFirstInput(hi))) - && isBinarySMMult(getSecondInput(hi))) { - Hop onePlusMu = getFirstInput(getSecondInput(hi)); - Hop tmp = getSecondInput(getFirstInput(hi)); - Hop mu = getFirstInput(tmp); - if(isOnePlusMu(onePlusMu, mu)) { - Hop v_prev = getSecondInput(tmp); - Hop v = getSecondInput(getSecondInput(hi)); - Hop X = getFirstInput(getFirstInput(hi)); - if(hasSameDimensions(X, v) && hasSameDimensions(X, v_prev)) { - ArrayList<Hop> inHops = new ArrayList<Hop>(); - inHops.add(X); - inHops.add(v); - inHops.add(v_prev); - inHops.add(mu); - LOG.debug("Applied updateNesterovX rewrite."); - Hop newHop = new DnnOp(hi.getName(), hi.getDataType(), hi.getValueType(), - OpOpDnn.UPDATE_NESTEROV_X, inHops); - return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop); - } - } - } - return hi; - } + + // ------------------------------------------------------------------------------------------- private static boolean hasSameDimensions(Hop x, Hop y) { return x.dimsKnown() && y.dimsKnown() && (x.getDim1() == y.getDim1()) && (x.getDim2() == y.getDim2()); } - private static boolean isOnePlusMu(Hop onePlusMu, Hop mu) { - return (isBinarySMMult(onePlusMu, 1.0) && getSecondInput(onePlusMu) == mu) || - getValue(onePlusMu) == getValue(mu) + 1; - } - - /** - * Checks for the batch norm (mode="test") pattern using the helper isBatchNormTrainMean and isBatchNormTrainVar - * and returns a new DnnOp if matched - * - * 
@param parent parent of the input - * @param hi input to be matched - * @param pos position - * @return a new DnnOp or hi - */ - private static Hop batchNormTest(Hop parent, Hop hi, int pos) { - // norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps)) - // hi = bias_add(bias_multiply(norm, gamma), beta) - // 2x for input and output and 1x for overhead - if(hasFirstInput(hi) && isBiasAdd(hi) && isBiasMultiply(getFirstInput(hi)) && fitsOnGPU(hi, 3) ) { - Hop norm = getFirstInput(getFirstInput(hi)); - if(hasSecondInput(norm) && isBiasMultiply(norm) && isBiasAdd(getFirstInput(norm)) - && isUnaryMinus(getSecondInput(getFirstInput(norm))) - && isOneDivideBySqrt(getSecondInput(norm))) { - double eps = 0; - Hop var = getFirstInput(getSecondInput(getSecondInput(norm))); - if( HopRewriteUtils.isBinary(var, OpOp2.PLUS) && - (getFirstInput(var) instanceof LiteralOp || getSecondInput(var) instanceof LiteralOp)) { - // eps + ema_var - if(getFirstInput(var) instanceof LiteralOp) { - eps = OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(var), new HashMap<>()); - var = getSecondInput(var); - } - else { - eps = OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(var), new HashMap<>()); - var = getFirstInput(var); - } - } - // Generate batch norm test op - Hop X = getFirstInput(getFirstInput(norm)); - Hop mean = getSecondInput(getSecondInput(getFirstInput(norm))); - - // This guard disallows eager fusion of train batch normalization into test batch normalization - boolean potentialForBatchNormTrain = !X.rowsKnown() && isBatchNormTrainMean(mean , X) && isBatchNormTrainVar(mean, var, X, getFirstInput(mean), true); - if(!potentialForBatchNormTrain) { - Hop gamma = getSecondInput(getFirstInput(hi)); - Hop beta = getSecondInput(hi); - ArrayList<Hop> inHops = new ArrayList<Hop>(); - inHops.add(X); - inHops.add(gamma); - inHops.add(beta); - inHops.add(mean); - inHops.add(var); - inHops.add(new LiteralOp(eps)); - if(fitsOnGPU(inHops, true)) { - LOG.debug("Applied batchNormTest rewrite."); - Hop newHop = new DnnOp(hi.getName(), hi.getDataType(), hi.getValueType(), - OpOpDnn.BATCH_NORM2D_TEST, inHops); - return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop); - } - } - else { - LOG.debug("Skipping batchNormTest rewrite as there is potential for batch normalization train rewrite after recompilation."); - } - } - } - - return hi; + private static boolean checkDimensions(Hop x, long dim1, long dim2) { + return x.dimsKnown() && (x.getDim1() == dim1) && (x.getDim2() == dim2); } -} +} \ No newline at end of file
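
For reference, a new rule built on the infrastructure described in the design notes above would look roughly like the sketch below. It is condensed from the updateEMA pattern in this file; the class name NewGPURewrite is purely illustrative, and such a class would still need to be registered in org.apache.sysml.hops.rewrite.ProgramRewriter as described in the design notes.

package org.apache.sysml.hops.rewrite;

import java.util.ArrayList;
import java.util.function.Function;

import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.Hop.OpOpDnn;

import static org.apache.sysml.hops.rewrite.HopDagPatternMatcher.*;
import static org.apache.sysml.parser.Expression.DataType.MATRIX;
import static org.apache.sysml.parser.Expression.DataType.SCALAR;

public class NewGPURewrite extends HopRewriteRuleWithPatternMatcher {
	// Matcher, written in the DML-like DSL: mu*ema_mean + (1-mu)*mean,
	// matched only if the pattern fits into the GPU memory budget (3x memory estimate).
	private static final HopDagPatternMatcher _updateEMA =
		mm_plus(
			mult(leaf("mu", SCALAR), leaf("ema_mean", MATRIX)),
			mult(leaf("oneMinusMu", SCALAR), leaf("mean", MATRIX)))
		.fitsOnGPU(3);

	// Replacer: validate the matched literals, then rewire the matched sub-DAG
	// to a fused DnnOp built from the matched leaves.
	private static final Function<Hop, Hop> _updateEMAReplacer = hi -> {
		if((1 - _updateEMA.getLiteralValue("mu")) == _updateEMA.getLiteralValue("oneMinusMu")) {
			Hop newHop = HopRewriteUtils.createDnnOp(_updateEMA, OpOpDnn.UPDATE_EMA, "ema_mean", "mean", "mu");
			return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
		}
		return hi;
	};

	public ArrayList<HopPatternRewriter> getPatternRewriter() {
		ArrayList<HopPatternRewriter> rewriters = new ArrayList<>();
		rewriters.add(new HopPatternRewriter("updateEMA", _updateEMA, _updateEMAReplacer));
		return rewriters;
	}
}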
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/lops/DnnTransform.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/DnnTransform.java b/src/main/java/org/apache/sysml/lops/DnnTransform.java index 3183b5f..2d2d5f1 100644 --- a/src/main/java/org/apache/sysml/lops/DnnTransform.java +++ b/src/main/java/org/apache/sysml/lops/DnnTransform.java @@ -32,7 +32,8 @@ public class DnnTransform extends Lop RELU_MAX_POOLING, RELU_MAX_POOLING_BACKWARD, RELU_BACKWARD, CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA, BIAS_ADD, CONV2D_BIAS_ADD, BIAS_MULTIPLY, CHANNEL_SUMS, BATCH_NORM2D_TEST, - UPDATE_NESTEROV_X + UPDATE_NESTEROV_X, RESHAPE_COLMEANS, UPDATE_EMA_VAR, UPDATE_EMA, INV_VAR, + BATCH_NORM2D_BACKWARD_DX } private OperationTypes operation; @@ -167,11 +168,26 @@ public class DnnTransform extends Lop case CHANNEL_SUMS: return "channel_sums"; + case INV_VAR: + return "inv_var"; + case UPDATE_NESTEROV_X: return "update_nesterov_x"; case BATCH_NORM2D_TEST: return "batch_norm2d_test"; + + case BATCH_NORM2D_BACKWARD_DX: + return "batch_norm2d_bwd_dx"; + + case UPDATE_EMA_VAR: + return "update_ema_var"; + + case UPDATE_EMA: + return "update_ema"; + + case RESHAPE_COLMEANS: + return "reshape_colmeans"; default: throw new UnsupportedOperationException(this.printErrorLocation() + "Instruction is not defined for Transform operation " + operation); @@ -181,7 +197,8 @@ public class DnnTransform extends Lop @Override public String getInstructions(String input, String bias, String output) { - if(operation == OperationTypes.BIAS_ADD || operation == OperationTypes.BIAS_MULTIPLY || operation == OperationTypes.RELU_BACKWARD) { + if(operation == OperationTypes.BIAS_ADD || operation == OperationTypes.BIAS_MULTIPLY || operation == OperationTypes.RELU_BACKWARD + || operation == OperationTypes.INV_VAR) { StringBuilder sb = new StringBuilder(); sb.append( getExecType() ); @@ -190,7 +207,7 @@ public class DnnTransform extends Lop sb.append( OPERAND_DELIMITOR ); sb.append( getInputs().get(0).prepInputOperand(input)); sb.append( OPERAND_DELIMITOR ); - sb.append( getInputs().get(0).prepInputOperand(bias)); + sb.append( getInputs().get(1).prepInputOperand(bias)); //output sb.append( OPERAND_DELIMITOR ); sb.append( this.prepOutputOperand(output)); @@ -212,7 +229,7 @@ public class DnnTransform extends Lop @Override public String getInstructions(String input, String C, String HW, String output) { - if(operation == OperationTypes.CHANNEL_SUMS) { + if(operation == OperationTypes.CHANNEL_SUMS || operation == OperationTypes.RESHAPE_COLMEANS || operation == OperationTypes.UPDATE_EMA) { StringBuilder sb = new StringBuilder(); sb.append( getExecType() ); @@ -306,6 +323,34 @@ public class DnnTransform extends Lop throw new LopsException("The operation is not supported with six operands:" + operation.name()); } } + + public String getInstructions(String input1, String input2, String input3, String input4, String input5, String output) { + if(operation == OperationTypes.UPDATE_EMA_VAR || operation == OperationTypes.BATCH_NORM2D_BACKWARD_DX) { + StringBuilder sb = new StringBuilder(); + sb.append( getExecType() ); + + sb.append( OPERAND_DELIMITOR ); + sb.append( getOpcode() ); + sb.append( OPERAND_DELIMITOR ); + sb.append( getInputs().get(0).prepInputOperand(input1)); + sb.append( OPERAND_DELIMITOR ); + sb.append( getInputs().get(1).prepInputOperand(input2)); + sb.append( OPERAND_DELIMITOR ); + sb.append( 
getInputs().get(2).prepInputOperand(input3)); + sb.append( OPERAND_DELIMITOR ); + sb.append( getInputs().get(3).prepInputOperand(input4)); + sb.append( OPERAND_DELIMITOR ); + sb.append( getInputs().get(4).prepInputOperand(input5)); + //output + sb.append( OPERAND_DELIMITOR ); + sb.append( this.prepOutputOperand(output)); + + return sb.toString(); + } + else { + throw new LopsException("The operation is not supported with six operands:" + operation.name()); + } + } public void appendOpcode(StringBuilder sb) { sb.append( getExecType() ); http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java index 9f3a1e2..fe86dc8 100644 --- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java @@ -270,58 +270,6 @@ public class BuiltinFunctionExpression extends DataIdentifier setDimensions(dc0, getFifthExpr()); break; } - case BATCH_NORM2D: - { - // Input: image, scale, bias, runningMean, runningVar, mode, epsilon, exponentialAverageFactor - checkNumParameters(8); - checkMatrixParam(getFirstExpr()); - checkMatrixParam(getSecondExpr()); - checkMatrixParam(getThirdExpr()); - checkMatrixParam(getFourthExpr()); - checkMatrixParam(getFifthExpr()); - - // Output: ret, retRunningMean, retRunningVar, resultSaveMean, resultSaveInvVariance - // setup output properties - if(getOutputs().length != 5) - raiseValidateError("batch_norm2d has 5 outputs", false); - - DataIdentifier ret = (DataIdentifier) getOutputs()[0]; - DataIdentifier retRunningMean = (DataIdentifier) getOutputs()[1]; - DataIdentifier retRunningVar = (DataIdentifier) getOutputs()[2]; - DataIdentifier resultSaveMean = (DataIdentifier) getOutputs()[3]; - DataIdentifier resultSaveInvVariance = (DataIdentifier) getOutputs()[4]; - - setDimensions(ret, getFirstExpr()); - setDimensions(retRunningMean, getFourthExpr()); - setDimensions(retRunningVar, getFourthExpr()); - setDimensions(resultSaveMean, getFourthExpr()); - setDimensions(resultSaveInvVariance, getFourthExpr()); - break; - } - case BATCH_NORM2D_BACKWARD: - { - // Input: image, dout, scale, epsilon, savedMean, savedInvVariance - checkNumParameters(6); - checkMatrixParam(getFirstExpr()); - checkMatrixParam(getSecondExpr()); - checkMatrixParam(getThirdExpr()); - checkMatrixParam(getFifthExpr()); - checkMatrixParam(getSixthExpr()); - - // Output: dX, dScale, dBias - // setup output properties - if(getOutputs().length != 3) - raiseValidateError("batch_norm2d_backward has 3 outputs", false); - - DataIdentifier dX = (DataIdentifier) getOutputs()[0]; - DataIdentifier dScale = (DataIdentifier) getOutputs()[1]; - DataIdentifier dBias = (DataIdentifier) getOutputs()[2]; - - setDimensions(dX, getFirstExpr()); - setDimensions(dScale, getThirdExpr()); - setDimensions(dBias, getThirdExpr()); - break; - } case EIGEN: checkNumParameters(1); checkMatrixParam(getFirstExpr()); @@ -1451,8 +1399,7 @@ public class BuiltinFunctionExpression extends DataIdentifier // always unconditional (because unsupported operation) BuiltinFunctionOp op = getOpCode(); if( op==BuiltinFunctionOp.EIGEN || op==BuiltinFunctionOp.LU || op==BuiltinFunctionOp.QR || op==BuiltinFunctionOp.SVD - || op==BuiltinFunctionOp.LSTM || op==BuiltinFunctionOp.LSTM_BACKWARD - 
|| op==BuiltinFunctionOp.BATCH_NORM2D || op==BuiltinFunctionOp.BATCH_NORM2D_BACKWARD) + || op==BuiltinFunctionOp.LSTM || op==BuiltinFunctionOp.LSTM_BACKWARD) raiseValidateError("Function "+op+" needs to be called with multi-return assignment.", false, LanguageErrorCodes.INVALID_PARAMETERS); else raiseValidateError("Unsupported function "+op, false, LanguageErrorCodes.INVALID_PARAMETERS); @@ -1535,8 +1482,6 @@ public class BuiltinFunctionExpression extends DataIdentifier case EIGEN: case LSTM: case LSTM_BACKWARD: - case BATCH_NORM2D: - case BATCH_NORM2D_BACKWARD: case SVD: return true; default: @@ -1956,10 +1901,6 @@ public class BuiltinFunctionExpression extends DataIdentifier bifop = Expression.BuiltinFunctionOp.LSTM; else if (functionName.equals("lstm_backward")) bifop = Expression.BuiltinFunctionOp.LSTM_BACKWARD; - else if (functionName.equals("batch_norm2d")) - bifop = Expression.BuiltinFunctionOp.BATCH_NORM2D; - else if (functionName.equals("batch_norm2d_backward")) - bifop = Expression.BuiltinFunctionOp.BATCH_NORM2D_BACKWARD; else if (functionName.equals("conv2d")) bifop = Expression.BuiltinFunctionOp.CONV2D; else if (functionName.equals("bias_add")) http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/parser/DMLTranslator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java index e9b643e..e3db435 100644 --- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java @@ -2271,8 +2271,6 @@ public class DMLTranslator case EIGEN: case LSTM: case LSTM_BACKWARD: - case BATCH_NORM2D: - case BATCH_NORM2D_BACKWARD: case SVD: // Number of outputs = size of targetList = #of identifiers in source.getOutputs http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/parser/Expression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java index 46e6442..33fca66 100644 --- a/src/main/java/org/apache/sysml/parser/Expression.java +++ b/src/main/java/org/apache/sysml/parser/Expression.java @@ -93,7 +93,7 @@ public abstract class Expression implements ParseInfo EXISTS, CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA, BIASADD, BIASMULT, MAX_POOL, AVG_POOL, MAX_POOL_BACKWARD, AVG_POOL_BACKWARD, - LSTM, LSTM_BACKWARD, BATCH_NORM2D, BATCH_NORM2D_BACKWARD, + LSTM, LSTM_BACKWARD, EXP, FLOOR, IFELSE, http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java index f4122d9..3480504 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java @@ -59,11 +59,13 @@ public class GPUInstructionParser extends InstructionParser String2GPUInstructionType.put( "channel_sums", GPUINSTRUCTION_TYPE.Dnn); String2GPUInstructionType.put( "lstm", GPUINSTRUCTION_TYPE.Dnn); String2GPUInstructionType.put( "lstm_backward", GPUINSTRUCTION_TYPE.Dnn); - 
String2GPUInstructionType.put( "batch_norm2d", GPUINSTRUCTION_TYPE.Dnn); - String2GPUInstructionType.put( "batch_norm2d_backward", GPUINSTRUCTION_TYPE.Dnn); String2GPUInstructionType.put( "batch_norm2d_test", GPUINSTRUCTION_TYPE.Dnn); - String2GPUInstructionType.put( "batch_norm2d_train", GPUINSTRUCTION_TYPE.Dnn); - String2GPUInstructionType.put( "update_nesterov_x", GPUINSTRUCTION_TYPE.Dnn); + String2GPUInstructionType.put( "update_nesterov_x", GPUINSTRUCTION_TYPE.Dnn); + String2GPUInstructionType.put( "update_ema_var", GPUINSTRUCTION_TYPE.Dnn); + String2GPUInstructionType.put( "update_ema", GPUINSTRUCTION_TYPE.Dnn); + String2GPUInstructionType.put( "reshape_colmeans", GPUINSTRUCTION_TYPE.Dnn); + String2GPUInstructionType.put( "inv_var", GPUINSTRUCTION_TYPE.Dnn); + String2GPUInstructionType.put( "batch_norm2d_bwd_dx", GPUINSTRUCTION_TYPE.Dnn); // Matrix Multiply Operators String2GPUInstructionType.put( "ba+*", GPUINSTRUCTION_TYPE.AggregateBinary);
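
Since the opcode strings returned by DnnTransform.getOpcode() have to stay in sync with the keys registered above, a small sanity test along the following lines could guard the mapping. This is a sketch only and not part of this patch; it assumes JUnit 4 and places the test in the same package so the String2GPUInstructionType map is visible.

package org.apache.sysml.runtime.instructions;

import org.junit.Assert;
import org.junit.Test;

import org.apache.sysml.runtime.instructions.gpu.GPUInstruction.GPUINSTRUCTION_TYPE;

public class GPUDnnOpcodeRegistrationCheck {
	@Test
	public void newDnnOpcodesAreRegistered() {
		// Opcodes introduced by this patch (see DnnTransform.getOpcode() and the map above).
		String[] opcodes = {"batch_norm2d_bwd_dx", "update_ema_var", "update_ema",
			"reshape_colmeans", "inv_var"};
		for(String opcode : opcodes)
			Assert.assertEquals(GPUINSTRUCTION_TYPE.Dnn,
				GPUInstructionParser.String2GPUInstructionType.get(opcode));
	}
}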
