http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
index acf2e48..53d368b 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
@@ -20,811 +20,288 @@
 package org.apache.sysml.hops.rewrite;
 
 import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
+import java.util.function.Function;
 
-import org.apache.sysml.conf.ConfigurationManager;
-import org.apache.sysml.hops.AggUnaryOp;
-import org.apache.sysml.hops.BinaryOp;
-import org.apache.sysml.hops.FunctionOp;
 import org.apache.sysml.hops.Hop;
-import org.apache.sysml.hops.FunctionOp.FunctionType;
-import org.apache.sysml.hops.Hop.AggOp;
-import org.apache.sysml.hops.Hop.DataOpTypes;
-import org.apache.sysml.hops.Hop.Direction;
-import org.apache.sysml.hops.Hop.OpOp1;
-import org.apache.sysml.hops.Hop.OpOp2;
 import org.apache.sysml.hops.Hop.OpOpDnn;
-import org.apache.sysml.hops.Hop.ReOrgOp;
-import org.apache.sysml.hops.DataOp;
-import org.apache.sysml.hops.LiteralOp;
-import org.apache.sysml.hops.DnnOp;
-import org.apache.sysml.hops.OptimizerUtils;
-import org.apache.sysml.hops.ReorgOp;
-import org.apache.sysml.hops.UnaryOp;
-import org.apache.sysml.parser.DMLProgram;
-import org.apache.sysml.parser.Expression.DataType;
-import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
+import static org.apache.sysml.hops.rewrite.HopDagPatternMatcher.*;
+import static org.apache.sysml.parser.Expression.DataType.MATRIX;
+import static org.apache.sysml.parser.Expression.DataType.SCALAR;
+
 
 /*
- * This class contains GPU-specific rewrites for following patterns:
+ * -------------------------------------------------------------------------
+ * Design documentation for hop rewrite rules that use HopDagPatternMatcher:
+ * -------------------------------------------------------------------------
+ * 
+ * Typical (but not all) hop rewrite rules have the following structure:
+ * 1. Rules are grouped together in different Java classes and registered in org.apache.sysml.hops.rewrite.ProgramRewriter.
+ * 
+ * 2. Each rule class inherits from HopRewriteRule and implements the rewriteHopDAG method. The other class of rewrite rules, StatementBlockRewriteRule, is not covered by this approach.
+ * 
+ * 3. The structure of rewriteHopDAG is common across HopRewriteRule subclasses and usually has the following pattern:
+ *  if(root of the given HOP DAG matches certain pattern) {
+ *    HopRewriteUtils.rewireAllParentChildReferences(root, newRoot)
+ *  }
+ *  else root
+ * 
+ * 4. To avoid redundancy, the above logic is implemented in the abstract class HopRewriteRuleWithPatternMatcher:
+ *  ArrayList<HopPatternRewriter> patternRewriters = getPatternRewriter();
+ *  for(HopPatternRewriter patternRewriter : patternRewriters) {
+ *    hi = patternRewriter.rewrite(hi);
+ *  }
+ * 
+ * 5. The developer has to inherit from HopRewriteRuleWithPatternMatcher, which implements the above logic,
+ * and implement getPatternRewriter() to return an ArrayList<HopPatternRewriter> (see the illustrative sketch after this comment block).
  * 
- * 1. batchNormTest: applied when mode="test" in batch normalization nn layer.
- * norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps))
- * hi = bias_add(bias_multiply(norm, gamma), beta)
+ * 6. Since the HOP patterns do not change during execution, it is convenient to store them in a static variable:
+ * ArrayList<HopPatternRewriter> _rewriters
  * 
- * 2. channelSum:
- * output = rowSums(matrix(colSums(x), rows=numChannels, cols=imgSize*imgSize))
+ * 7. The replacement part of each entry in the pattern matcher invokes helper methods in HopRewriteUtils to create a newRoot, for example HopRewriteUtils.createDnnOp.
  * 
- * 3. batchNormTrain: applied when mode="train" in batch normalization nn 
layer.
- * This rewrite is only enabled if none of the outputs are persistent writes 
as it assumes that 
- * FunctionOp will introduce a transient writes. This rewrite replaces the 
existing outputs of the matched pattern with transient reads.
+ * 8. The DSL below would be more readable if implemented with Scala's operator overloading, but that adds a dependency on the Scala library
+ * (specifically, Scala uses scala.Function1 to implement operator overloading).
+ * Hence, to minimize dependencies, the DSL is implemented using static methods of the HopDagPatternMatcher class.
+ * We can revisit this if we plan to add Scala as a hard dependency of SystemML.
+ * 
+ * 9. The matcher part of each entry in the pattern matcher uses the DSL implemented in HopDagPatternMatcher to improve readability.
+ * - The DSL follows DML syntax, which makes it convenient for an external contributor to understand and modify the HOP rewrites.
+ * - Note that the developer has to apply the same scoping rules as SystemML.
+ * - To create a newRoot HOP, there must be a mechanism to extract the leaves of the matched pattern. This is implemented by the leaf() method.
+ * - Often, a new HOP should be created only if it can fit into memory. For GPUs, one can use the fitsOnGPU(multiplier) helper method.
  * 
  */
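
To make the above concrete, a minimal one-pattern rewrite rule built on this infrastructure could look roughly like the sketch below. This is an illustrative example, not part of the patch: the class name ExampleChannelSumsRule is hypothetical, the matcher/replacer mirror the channelSums pattern added further below in this change, and the sketch assumes the HopDagPatternMatcher DSL, HopPatternRewriter, and HopRewriteRuleWithPatternMatcher introduced here. Such a rule would also still need to be registered in org.apache.sysml.hops.rewrite.ProgramRewriter (item 1 above).

    package org.apache.sysml.hops.rewrite;

    import java.util.ArrayList;
    import java.util.function.Function;

    import org.apache.sysml.hops.Hop;
    import org.apache.sysml.hops.Hop.OpOpDnn;
    import static org.apache.sysml.hops.rewrite.HopDagPatternMatcher.*;
    import static org.apache.sysml.parser.Expression.DataType.MATRIX;
    import static org.apache.sysml.parser.Expression.DataType.SCALAR;

    public class ExampleChannelSumsRule extends HopRewriteRuleWithPatternMatcher {
        // Matcher: rowSums(matrix(colSums(X), rows=C, cols=HW)), applied only if the
        // intermediates fit into the GPU memory budget (2x the input estimate).
        private static final HopDagPatternMatcher _pattern =
            rowSums(matrix(colSums(leaf("X", MATRIX).fitsOnGPU(2)), leaf("C", SCALAR), leaf("HW", SCALAR)));

        // Replacer: build the fused CHANNEL_SUMS DnnOp from the matched leaves and
        // rewire all parents of the matched root to the new HOP.
        private static final Function<Hop, Hop> _replacer = hi -> {
            Hop newHop = HopRewriteUtils.createDnnOp(_pattern, OpOpDnn.CHANNEL_SUMS, "X", "C", "HW");
            return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
        };

        // Patterns do not change during execution, hence the static cache (item 6 above).
        private static ArrayList<HopPatternRewriter> _rewriters = null;
        public ArrayList<HopPatternRewriter> getPatternRewriter() {
            if(_rewriters == null) {
                _rewriters = new ArrayList<>();
                _rewriters.add(new HopPatternRewriter("exampleChannelSums", _pattern, _replacer));
            }
            return _rewriters;
        }
    }
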
-public class RewriteGPUSpecificOps extends HopRewriteRule {
-
-       private static int _seq = 1;
-       
-       @Override
-       public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, 
ProgramRewriteStatus state) {
-               if( roots == null )
-                       return roots;
-
-               //one pass rewrite-descend (rewrite created pattern)
-               for( int i = 0; i < roots.size(); i++ )
-                       rule_GPUKernels(roots, roots.get(i), false );
-               Hop.resetVisitStatus(roots, true);
-
-               //one pass descend-rewrite (for rollup) 
-               for( int i = 0; i < roots.size(); i++ )
-                       rule_GPUKernels(roots, roots.get(i), true );
-               Hop.resetVisitStatus(roots, true);
+public class RewriteGPUSpecificOps extends HopRewriteRuleWithPatternMatcher {
+       // 
-------------------------------------------------------------------------------------------
+       
+       private static HopDagPatternMatcher 
util_channel_sums(HopDagPatternMatcher X, HopDagPatternMatcher C, 
HopDagPatternMatcher HW) {
+               // rowSums(matrix(colSums(X), rows=C, cols=HW))
+               return rowSums(matrix(  colSums(X), C, HW));
+       }
+       
+       // Pattern 1:
+       private static final HopDagPatternMatcher _batchNormdX;
+       static {
+               HopDagPatternMatcher C = leaf("C",  SCALAR);
+               HopDagPatternMatcher HW = leaf("HW",  SCALAR);
+               HopDagPatternMatcher CHW = leaf("CHW",  SCALAR);
+               HopDagPatternMatcher cache_inv_var = leaf("cache_inv_var", 
MATRIX);
+               HopDagPatternMatcher dout = leaf("dout", MATRIX);
+               HopDagPatternMatcher gamma = leaf("gamma", MATRIX);
+               HopDagPatternMatcher X = leaf("X", MATRIX);
+               HopDagPatternMatcher mean = leaf("mean", MATRIX);
                
-               return roots;
-       }
-
-       @Override
-       public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
-               if( root == null )
-                       return root;
-               
-               //one pass rewrite-descend (rewrite created pattern)
-               rule_GPUKernels(null, root, false );
-               
-               root.resetVisitStatus();
-               
-               //one pass descend-rewrite (for rollup) 
-               rule_GPUKernels(null, root, true );
+               HopDagPatternMatcher centered = bias_add(X, unaryMinus(mean));
                
-               return root;
-       }
-       
-       /**
-        * Fuse the kernel
-        * 
-        * @param roots root operators
-        * @param hop high-level operator
-        * @param descendFirst true if recursively process children first
-        */
-       private void rule_GPUKernels(ArrayList<Hop> roots, Hop hop, boolean 
descendFirst) 
-       {
-               if(hop.isVisited())
-                       return;
-               
-               //recursively process children
-               for( int i=0; i<hop.getInput().size(); i++) {
-                       Hop hi = hop.getInput().get(i);
-                       
-                       //process childs recursively first (to allow roll-up)
-                       if( descendFirst )
-                               rule_GPUKernels(roots, hi, descendFirst); //see 
below
-                       
-                       if(roots != null) {
-                               //hi = batchNormTrain(roots, hop, hi, i);
-                       }
-                       hi = batchNormTest(hop, hi, i); 
-                       hi = channelSums(hop, hi, i); 
-                       hi = updateNesterovX(hop, hi, i);
-       
-                       if( !descendFirst )
-                               rule_GPUKernels(roots, hi, descendFirst);
-               }
-
-               hop.setVisited();
-       }
-       
-       private static boolean isBiasAdd(Hop h) {
-               return HopRewriteUtils.isDnn(h, OpOpDnn.BIASADD);
-       }
-       
-       private static boolean isBiasMultiply(Hop h) {
-               return HopRewriteUtils.isDnn(h, OpOpDnn.BIASMULT);
-       }
-       
-       private static boolean fitsOnGPU(Hop h, double multiplier) {
-               double memEst = multiplier*h.getMemEstimate();
-               return ConfigurationManager.isGPU() && h.dimsKnown() && 
OptimizerUtils.isMemoryBasedOptLevel() &&
-                               memEst < OptimizerUtils.getLocalMemBudget() && 
memEst < GPUContextPool.initialGPUMemBudget();
-       }
-       
-       private static boolean fitsOnGPU(ArrayList<Hop> inputHops, boolean 
isFirstSameSizeAsOutput) {
-               return fitsOnGPU(inputHops, isFirstSameSizeAsOutput, 0);
-       }
-       
-       private static boolean fitsOnGPU(ArrayList<Hop> inputHops, boolean 
isFirstSameSizeAsOutput, long additionalBytes) {
-               double memEst = additionalBytes;
-               boolean isFirst = true;
-               for(Hop h : inputHops) {
-                       double est = h.getMemEstimate();
-                       if(est == OptimizerUtils.INVALID_SIZE) {
-                               return false;
-                       }
-                       else if(isFirst && isFirstSameSizeAsOutput) {
-                               isFirst = false;
-                               memEst += 2*est;
-                       }
-                       else {
-                               memEst += est;
-                       }
-               }
-               return ConfigurationManager.isGPU() && 
OptimizerUtils.isMemoryBasedOptLevel() &&
-                               memEst < OptimizerUtils.getLocalMemBudget() && 
memEst < GPUContextPool.initialGPUMemBudget();
-       }
-       
-       private static boolean hasFirstInput(Hop h) {
-               return !(h == null || h.getInput() == null || 
h.getInput().size() < 1);
-       }
-       
-       private static Hop getFirstInput(Hop h) {
-               if(h == null || h.getInput() == null || h.getInput().size() < 
1) {
-                       throw new RuntimeException("No input available for " + 
h);
-               }
-               return h.getInput().get(0);
-       }
-       
-       private static boolean hasSecondInput(Hop h) {
-               return !(h == null || h.getInput() == null || 
h.getInput().size() < 2);
-       }
-       
-       private static Hop getSecondInput(Hop h) {
-               if(h == null || h.getInput() == null || h.getInput().size() < 
2) {
-                       throw new RuntimeException("Expected atleast two inputs 
for " + h);
+               // dnorm = bias_multiply(dout, gamma)  # shape (N, C*Hin*Win)
+               HopDagPatternMatcher dnorm = bias_multiply(dout, gamma);
+               // dmean_norm_branch = util::channel_sums(bias_multiply(dnorm, 
-cache_inv_var), C, Hin, Win)
+               HopDagPatternMatcher dmean_norm_branch = 
util_channel_sums(bias_multiply(dnorm, unaryMinus(cache_inv_var)), C, HW) ;
+               // dvar = util::channel_sums((-1/2) * bias_multiply(centered, 
cache_inv_var^3) * dnorm,
+               //      C, Hin, Win)  # shape (C, 1)
+               HopDagPatternMatcher dvar = util_channel_sums(mult(mult(-0.5, 
bias_multiply(centered, pow(cache_inv_var, 3))), dnorm),  C, HW);
+               // dmean_var_branch = util::channel_sums((-2*oneByN*oneByHW) * 
centered, C, Hin, Win) * dvar
+               HopDagPatternMatcher dmean_var_branch =
+                       mult(util_channel_sums(mult(leaf("const3", SCALAR), 
centered), C, HW), dvar);
+               // dX_norm_branch = bias_multiply(dnorm, cache_inv_var)
+               HopDagPatternMatcher dX_norm_branch = bias_multiply(dnorm, 
cache_inv_var);
+               // dX_mean_branch = (oneByN*oneByHW) * bias_add(matrix(0, 
rows=1, cols=C*Hin*Win), dmean)
+               HopDagPatternMatcher dX_mean_branch = mult(leaf("const1", 
SCALAR), bias_add(matrix(0, 1, CHW), 
+                               plus(dmean_norm_branch, dmean_var_branch) ));
+               // dX_var_branch = (2*oneByN*oneByHW) * bias_multiply(centered, 
dvar)
+               HopDagPatternMatcher dX_var_branch = mult(leaf("const2", 
SCALAR), bias_multiply(centered, dvar));
+               _batchNormdX = plus(plus(dX_norm_branch, dX_mean_branch), 
dX_var_branch).fitsOnGPU(2);
+       }
+       private static final Function<Hop, Hop> _batchNormdXReplacer = hi -> {
+               // double CHW = _batchNormdX.getLiteralValue("CHW");
+               double HW = _batchNormdX.getLiteralValue("HW");
+               double C = _batchNormdX.getLiteralValue("C");
+               double const1 = _batchNormdX.getLiteralValue("const1"); // 
(oneByN*oneByHW)
+               double const2 = _batchNormdX.getLiteralValue("const2"); // 
(2*oneByN*oneByHW)
+               double const3 = _batchNormdX.getLiteralValue("const3"); // 
(-2*oneByN*oneByHW)
+               if(2*const1 == const2 && const3 == -const2 && 
+                       hasSameDimensions(_batchNormdX.getMatchedHop("gamma"), 
_batchNormdX.getMatchedHop("mean")) &&
+                       hasSameDimensions(_batchNormdX.getMatchedHop("gamma"), 
_batchNormdX.getMatchedHop("cache_inv_var")) &&
+                       _batchNormdX.getMatchedHop("X").getDim2() == C*HW &&
+                       checkDimensions(_batchNormdX.getMatchedHop("gamma"), 
(long)C, 1)) {
+                       LOG.debug("Applied batchNormdX rewrite.");
+                       Hop newHop = HopRewriteUtils.createDnnOp(_batchNormdX, 
OpOpDnn.BATCH_NORM2D_BACKWARD_DX, 
+                                       "X", "dout", "gamma", "mean", 
"cache_inv_var");
+                       return 
HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
                }
-               return h.getInput().get(1);
-       }
-       
-       private static Hop getThirdInput(Hop h) {
-               if(h == null || h.getInput() == null || h.getInput().size() < 
3) {
-                       throw new RuntimeException("Expected atleast three 
inputs for " + h);
-               }
-               return h.getInput().get(2);
-       }
-       
-       private static boolean isUnaryMinus(Hop h) {
-               return HopRewriteUtils.isBinary(h, OpOp2.MINUS)
-                       && 
HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), 0);
-       }
-       
-       private static boolean isOneDivideBySqrt(Hop h) {
-               return HopRewriteUtils.isBinary(h, OpOp2.DIV)
-                       && HopRewriteUtils.isUnary(h.getInput().get(1), 
OpOp1.SQRT)
-                       && 
HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), 1);
-       }
-       
-       private static Hop channelSums(Hop parent, Hop hi, int pos) {
-               if(hi instanceof AggUnaryOp) {
-                       AggUnaryOp hop = (AggUnaryOp) hi;
-                       // output = rowSums(matrix(colSums(x), 
rows=numChannels, cols=imgSize*imgSize))
-                       if( hop.getOp() == AggOp.SUM && hop.getDirection() == 
Direction.Row
-                               && 
HopRewriteUtils.isReorg(hop.getInput().get(0), ReOrgOp.RESHAPE) ) {
-                               Hop colSumsInput = 
hop.getInput().get(0).getInput().get(0);
-                               if(colSumsInput instanceof AggUnaryOp && 
((AggUnaryOp)colSumsInput).getOp() == AggOp.SUM && 
((AggUnaryOp)colSumsInput).getDirection() == Direction.Col) {
-                                       ArrayList<Hop> inHops = new 
ArrayList<Hop>();
-                                       
inHops.add(colSumsInput.getInput().get(0));
-                                       long numChannels = 
Hop.computeSizeInformation(hop.getInput().get(0).getInput().get(1));
-                                       long HW = 
Hop.computeSizeInformation(hop.getInput().get(0).getInput().get(2));
-                                       if(numChannels > 0 && HW > 0 && 
fitsOnGPU(inHops, false, numChannels*8)) {
-                                               inHops.add(new 
LiteralOp(numChannels));
-                                               inHops.add(new LiteralOp(HW));
-                                               LOG.debug("Applied channelSums 
rewrite.");
-                                               Hop newHop = new 
DnnOp(hi.getName(), hi.getDataType(), hi.getValueType(),
-                                                               
OpOpDnn.CHANNEL_SUMS, inHops);
-                                               return 
HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
-                                       }
-                               }
-                       }
+               else if(DEBUG_REWRITES) {
+                       System.out.println("Could not apply batchNormdX rewrite.");
+                       System.out.println((2*const1) + " == " + const2 + " && 
" + const3 + "== -" + const2 
+                       + " && " + 
hasSameDimensions(_batchNormdX.getMatchedHop("gamma"), 
_batchNormdX.getMatchedHop("mean")) +  " && " + 
+                       hasSameDimensions(_batchNormdX.getMatchedHop("gamma"), 
_batchNormdX.getMatchedHop("cache_inv_var")) + " && " +
+                       _batchNormdX.getMatchedHop("X").getDim2() + " == " + C 
+ "*" + HW  + " && " +
+                       checkDimensions(_batchNormdX.getMatchedHop("gamma"), 
(long)C, 1));
                }
                return hi;
-       }
-       
-       private static boolean isRowMeans(Hop h) {
-               return h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == 
AggOp.MEAN && ((AggUnaryOp)h).getDirection() == Direction.Row; 
-       }
-       
-       private static boolean isRowVars(Hop h) {
-               return h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == 
AggOp.VAR && ((AggUnaryOp)h).getDirection() == Direction.Row; 
-       }
-       
-       private static boolean isRowVars(Hop h, Hop childHop) {
-               return isRowVars(h) && getFirstInput(h) == childHop; 
-       }
-       
-       private static boolean isColMeans(Hop h) {
-               return h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == 
AggOp.MEAN && ((AggUnaryOp)h).getDirection() == Direction.Col; 
-       }
-       
-       private static boolean isColVars(Hop h) {
-               return h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == 
AggOp.VAR && ((AggUnaryOp)h).getDirection() == Direction.Col; 
-       }
-       
-       private static boolean isReshape(Hop h) {
-               return h instanceof ReorgOp && ((ReorgOp)h).getOp() == 
ReOrgOp.RESHAPE;
-       }
-       
-       private static boolean isReshape(Hop h, long expectedRows, long 
expectedCols) {
-               return h instanceof ReorgOp && ((ReorgOp)h).getOp() == 
ReOrgOp.RESHAPE &&
-                               Hop.computeSizeInformation(getSecondInput(h)) 
== expectedRows && 
-                               Hop.computeSizeInformation(getThirdInput(h)) == 
expectedCols;
-       }
-       
-       private static boolean isBinaryAdd(Hop h) {
-               return h instanceof BinaryOp && ((BinaryOp)h).getOp() == 
OpOp2.PLUS;
-       }
-       
-       private static boolean isBinaryMSAdd(Hop h, double expectedValue) {
-               return h instanceof BinaryOp && ((BinaryOp)h).getOp() == 
OpOp2.PLUS 
-                               && getFirstInput(h).getDataType() == 
DataType.MATRIX && getSecondInput(h).getDataType() == DataType.SCALAR
-                               && 
OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(h), new HashMap<>()) 
== expectedValue;
-       }
-       
-       private static boolean isBinaryMMAdd(Hop h) {
-               return h instanceof BinaryOp && ((BinaryOp)h).getOp() == 
OpOp2.PLUS 
-                               && getFirstInput(h).getDataType() == 
DataType.MATRIX && getSecondInput(h).getDataType() == DataType.MATRIX;
-       }
-       
-       private static boolean isBinaryMMMinus(Hop h) {
-               return h instanceof BinaryOp && ((BinaryOp)h).getOp() == 
OpOp2.MINUS 
-                               && getFirstInput(h).getDataType() == 
DataType.MATRIX && getSecondInput(h).getDataType() == DataType.MATRIX;
-       }
-       
-       private static boolean isBinaryMSMult(Hop h, double expectedValue) {
-               return h instanceof BinaryOp && ((BinaryOp)h).getOp() == 
OpOp2.MULT 
-                               && getFirstInput(h).getDataType() == 
DataType.MATRIX && getSecondInput(h).getDataType() == DataType.SCALAR
-                               && 
OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(h), new HashMap<>()) 
== expectedValue;
-       }
-       
-       private static boolean isBinarySSMinus(Hop h) {
-               return h instanceof BinaryOp && ((BinaryOp)h).getOp() == 
OpOp2.MINUS 
-                               && getFirstInput(h).getDataType() == 
DataType.SCALAR && getSecondInput(h).getDataType() == DataType.SCALAR;
-       }
-       
-       private static boolean isBinarySSDiv(Hop h) {
-               return h instanceof BinaryOp && ((BinaryOp)h).getOp() == 
OpOp2.DIV 
-                               && getFirstInput(h).getDataType() == 
DataType.SCALAR && getSecondInput(h).getDataType() == DataType.SCALAR;
-       }
-       
-       private static boolean isBinarySMDiv(Hop h, double expectedValue) {
-               return h instanceof BinaryOp && ((BinaryOp)h).getOp() == 
OpOp2.DIV 
-                               && getFirstInput(h).getDataType() == 
DataType.SCALAR && getSecondInput(h).getDataType() == DataType.MATRIX 
-                               && 
OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(h), new HashMap<>()) 
== expectedValue;
-       }
-       
-       private static boolean isAnyBinaryAdd(ArrayList<Hop> hops) {
-               if(hops != null) {
-                       for(Hop h : hops) {
-                               if(h instanceof BinaryOp && 
((BinaryOp)h).getOp() == OpOp2.PLUS)
-                                       return true;
-                       }
-               }
-               return false;
-       }
-       
-       private static boolean isBinaryMSMult(Hop h) {
-               return h instanceof BinaryOp && ((BinaryOp)h).getOp() == 
OpOp2.MULT 
-                               && getFirstInput(h).getDataType() == 
DataType.MATRIX && getSecondInput(h).getDataType() == DataType.SCALAR;
-       }
-       
-       private static boolean isBinarySMMult(Hop h) {
-               return h instanceof BinaryOp && ((BinaryOp)h).getOp() == 
OpOp2.MULT 
-                               && getSecondInput(h).getDataType() == 
DataType.MATRIX && getFirstInput(h).getDataType() == DataType.SCALAR;
-       }
-       
-       private static boolean isBinarySMMult(Hop h, double expectedVal) {
-               return h instanceof BinaryOp && ((BinaryOp)h).getOp() == 
OpOp2.MULT 
-                               && getSecondInput(h).getDataType() == 
DataType.MATRIX && getFirstInput(h).getDataType() == DataType.SCALAR
-                               && getValue(getFirstInput(h)) == expectedVal;
-       }
-       
-       private static double getValue(Hop h) {
-               return OptimizerUtils.rEvalSimpleDoubleExpression(h, new 
HashMap<>());
-       }
-       
-       /**
-        * Checks if the "mean" hop is a moving average of mean in batch 
normalization layer.
-        *  
-        * @param mean hop to check against
-        * @param X input data
-        * @return true if the "mean" hop is a moving average of mean in batch 
normalization layer.
-        */
-       private static boolean isBatchNormTrainMean(Hop mean, Hop X) {
-               // subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win)
-               // mean = rowMeans(subgrp_means)
-               return isRowMeans(mean) && isReshape(getFirstInput(mean)) && 
isColMeans(getFirstInput(getFirstInput(mean)))
-                               && 
getFirstInput(getFirstInput(getFirstInput(mean))) == X;
-       }
-       
-       /**
-        * Checks for nrow(X) pattern
-        * 
-        * @param expr hop to be matched
-        * @param X input X
-        * @return true if expr is nrow(X) else false
-        */
-       private static boolean isNrowOfX(Hop expr, Hop X) {
-               return expr instanceof UnaryOp && ((UnaryOp)expr).getOp() == 
OpOp1.NROW && getFirstInput(expr) == X;
-       }
-       
-       /**
-        * Checks for the colVars(X) * ((N-1)/N) pattern
-        * 
-        * @param expr hop to be matched
-        * @param X input X
-        * @param ignoreCorrectionTerm whether to ignore the correction term 
((N-1)/N).
-        * @return true if expr is colVars(X) * ((N-1)/N) else false
-        */
-       private static boolean isCorrectedColVars(Hop expr, Hop X, boolean 
ignoreCorrectionTerm) {
-               // colVars(X) * ((N-1)/N)
-               if(isColVars(expr) && getFirstInput(expr) == X) {
-                       // Support no correction as well in this rewrite
-                       return true;
-               }
-               else if(X.rowsKnown()) {
-                       return isBinaryMSMult(expr, 
((double)X.getDim1()-1)/X.getDim1()) && 
-                                       isColVars(getFirstInput(expr)) && 
getFirstInput(getFirstInput(expr)) == X;
+       };
+       
+       
+       
+       // Pattern 2:
+       // subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win)
+       // var = rowMeans(subgrp_vars) + 
rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))
+       private static final HopDagPatternMatcher _batchNormUpdatedVar; 
+       static {
+               HopDagPatternMatcher subgrp_vars = 
+                       matrix( 
+                                       mult(colVars(leaf("X", 
MATRIX).fitsOnGPU(2)), leaf("varConst1", SCALAR)), // colVars(X) * ((N-1)/N)
+                                       leaf("C", SCALAR),      // rows=C
+                                       leaf("HW", SCALAR)); // cols=Hin*Win
+               _batchNormUpdatedVar = 
+                       mm_plus( 
+                                       rowMeans(subgrp_vars), 
+                                       mult(rowVars(leaf("subgrp_means", 
MATRIX)),  leaf("varConst2", SCALAR))); // rowVars(subgrp_means)*varConst2
+       }
+       private static final Function<Hop, Hop> _batchNormUpdatedVarReplacer = 
hi -> {
+               double HW = _batchNormUpdatedVar.getLiteralValue("HW");
+               if(_batchNormUpdatedVar.getLiteralValue("varConst2") == 
((HW-1)/HW)) {
+                       LOG.debug("Applied batchNormUpdatedVar rewrite.");
+                       Hop newHop = 
HopRewriteUtils.createDnnOp(_batchNormUpdatedVar, OpOpDnn.UPDATE_EMA_VAR, 
+                                       // varConst1 => ((N-1)/N)
+                                       "subgrp_means", "X", "C", "HW", 
"varConst1");
+                       return 
HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
                }
-               else if(isBinaryMSMult(expr) && 
-                               isColVars(getFirstInput(expr)) && 
getFirstInput(getFirstInput(expr)) == X) {
-                       if(ignoreCorrectionTerm) {
-                               return true;
-                       }
-                       Hop tmp = getSecondInput(expr);
-                       // ((N-1)/N)
-                       boolean isNMinus1Pattern = isBinarySSDiv(tmp) && 
isBinarySSMinus(getFirstInput(tmp)) &&
-                                       getFirstInput(getFirstInput(tmp)) == 
getSecondInput(tmp) && 
-                                       
OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(getFirstInput(tmp)), 
new HashMap<>()) == 1;
-                       boolean ret = isNMinus1Pattern && 
isNrowOfX(getSecondInput(tmp), X);
-                       if(LOG.isDebugEnabled()) {
-                               LOG.debug("Is the corrected column variance 
pattern for batch_norm_train rewrite when number of rows of X unknown matched:" 
+ ret);
-                       }
-                       return ret;
-               }
-               return false;
-       }
-       
-       /**
-        * Checks if the "var" hop is a moving average of variance in batch 
normalization layer.
-        *  
-        * @param mean previously matched mean hop
-        * @param var the hop to check against
-        * @param X input data hop
-        * @param subgrpMeans mean for subgroup mean
-        * @param ignoreCorrectionTerm whether to incore the correct term  (see 
isCorrectedColVars method in this class)
-        * @return true if the "var" hop is a moving average of variance in 
batch normalization layer.
-        */
-       private static boolean isBatchNormTrainVar(Hop mean, Hop var, Hop  X, 
Hop subgrpMeans, boolean ignoreCorrectionTerm) {
-               long numChannels = 
Hop.computeSizeInformation(getSecondInput(getFirstInput(mean)));
-               long HW = 
Hop.computeSizeInformation(getThirdInput(getFirstInput(mean)));
-               // subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, 
cols=Hin*Win)
-               // var = rowMeans(subgrp_vars) + 
rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))
-               return numChannels > 0 && HW > 0 && isBinaryMMAdd(var) && 
isRowMeans(getFirstInput(var)) &&  
-                               // matrix(colVars(X) * ((N-1)/N), rows=C, 
cols=Hin*Win)
-                               isReshape(getFirstInput(getFirstInput(var)), 
numChannels, HW) &&
-                               
isCorrectedColVars(getFirstInput(getFirstInput(getFirstInput(var))), X, 
ignoreCorrectionTerm) &&
-                               // 
rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))
-                               isBinaryMSMult(getSecondInput(var), 
((((double)HW)-1)/HW)) && 
-                               isRowVars(getFirstInput(getSecondInput(var)), 
subgrpMeans);
-       }
-       
-       /**
-        * Checks and returns the matched hops for expression ema_mean_upd = 
mu*ema_mean + (1-mu)*mean  
-        * 
-        * @param rhsTimesOps hop representing BinaryOp of expression 
(1-mu)*mean 
-        * @param mu value of mu
-        * @return an array [ema_mean_upd, ema_mean] if expression matched, 
else null
-        */
-       private static Hop [] getUpdatedMovingAverageExpressions(Hop 
rhsTimesOp, double mu) {
-               if(rhsTimesOp == null || rhsTimesOp.getParent() == null || 
rhsTimesOp.getParent().size() != 1 || 
-                               !isBinarySMMult(rhsTimesOp) || 
!isBinaryAdd(rhsTimesOp.getParent().get(0)))
-                       return null;
-               
-               // Check (1-mu)*mean
-               double expectedOneMinusMu = 
OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(rhsTimesOp), new 
HashMap<>());
-               Hop plusOp = rhsTimesOp.getParent().get(0); 
-               Hop lhsTimesOp = null;
-               if(plusOp.getInput().get(0) == rhsTimesOp) {
-                       lhsTimesOp = plusOp.getInput().get(1); 
-               }
-               else {
-                       lhsTimesOp = plusOp.getInput().get(0);
-               }
-               
-               if(expectedOneMinusMu == (1-mu) && plusOp.getParent() != null 
&& plusOp.getParent().size() == 1 &&  
-                       isBinarySMMult(lhsTimesOp) && 
OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(lhsTimesOp), new 
HashMap<>()) == mu) {
-                       return new Hop[] {
-                               plusOp.getParent().get(0),
-                               getSecondInput(lhsTimesOp), 
-                               getSecondInput(rhsTimesOp)
-                       };
-               }
-               return null;
-       }
-       
-       /**
-        * Checks (if exactly one of rhsTimesOps) and returns the matched hops 
for expression ema_mean_upd = mu*ema_mean + (1-mu)*mean  
-        * 
-        * @param rhsTimesOps array list of hop representing BinaryOp of 
expression (1-mu)*mean 
-        * @param mu value of mu
-        * @return an array [ema_mean_upd, ema_mean] if any of the expression 
matched, else null
-        */
-       private static Hop [] getUpdatedMovingAverageExpressions(ArrayList<Hop> 
rhsTimesOps, double mu) {
-               if(rhsTimesOps == null || rhsTimesOps.size() == 0)
-                       return null;
-               
-               Hop [] ret = null;
-               for(Hop h : rhsTimesOps) {
-                       boolean matched = isUpdatedMovingAverageExpression(h, 
mu);
-                       if(matched && ret != null) {
-                               return null; // Multiple matches, cannot decide 
which one to fuse
-                       }
-                       else if(matched) {
-                               ret = getUpdatedMovingAverageExpressions(h, mu);
-                       }
-               }
-               
-               return ret;
-       }
-       
-       /**
-        * Checks and returns the mu in the expression ema_mean_upd = 
mu*ema_mean + (1-mu)*mean
-        * 
-        * @param rhsTimesOps hop representing BinaryOp of expression 
(1-mu)*mean
-        * @return value of mu if the expression matched else null 
-        */
-       private static Double 
getMuFromUpdatedMovingAverageExpressions(ArrayList<Hop> rhsTimesOps) {
-               if(rhsTimesOps == null || rhsTimesOps.size() == 0)
-                       return null;
-               
-               Double ret = null; 
-               for(Hop h : rhsTimesOps) {
-                       boolean matched = isUpdatedMovingAverageExpression(h);
-                       if(matched && ret != null) {
-                               return null; // Multiple matches, cannot decide 
which one to fuse
-                       }
-                       else if(matched) {
-                               ret = 
-(OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(h), new 
HashMap<>())-1);
-                       }
-               }
-               return ret;
-       }
-       
-       /**
-        * Checks for the expression ema_mean_upd = mu*ema_mean + (1-mu)*mean
-        * 
-        * @param rhsTimesOps hop representing BinaryOp of expression 
(1-mu)*mean
-        * @return true if expression matched
-        */
-       private static boolean isUpdatedMovingAverageExpression(Hop rhsTimesOp) 
{
-               if(rhsTimesOp == null || rhsTimesOp.getParent() == null || 
rhsTimesOp.getParent().size() != 1 || 
-                               !isBinarySMMult(rhsTimesOp) || 
!isBinaryAdd(rhsTimesOp.getParent().get(0)))
-                       return false;
-               
-               // Check (1-mu)*mean
-               Hop plusOp = rhsTimesOp.getParent().get(0); 
-               Hop lhsTimesOp = null;
-               if(plusOp.getInput().get(0) == rhsTimesOp) {
-                       lhsTimesOp = plusOp.getInput().get(1); 
-               }
-               else {
-                       lhsTimesOp = plusOp.getInput().get(0);
-               }
-               
-               if(plusOp.getParent() != null && plusOp.getParent().size() == 1 
&& isBinarySMMult(lhsTimesOp)) {
-                       return true;
-               }
-               return false;
-       }
+               return hi;
+       };
        
-       // ema_mean_upd = mu*ema_mean + (1-mu)*mean
-       // Returns true if expression matched, else false
-       private static boolean isUpdatedMovingAverageExpression(Hop rhsTimesOp, 
double mu) {
-               if(rhsTimesOp == null || rhsTimesOp.getParent() == null || 
rhsTimesOp.getParent().size() != 1 || 
-                               !isBinarySMMult(rhsTimesOp) || 
!isBinaryAdd(rhsTimesOp.getParent().get(0)))
-                       return false;
-               
-               // Check (1-mu)*mean
-               double expectedOneMinusMu = 
OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(rhsTimesOp), new 
HashMap<>());
-               Hop plusOp = rhsTimesOp.getParent().get(0); 
-               Hop lhsTimesOp = null;
-               if(plusOp.getInput().get(0) == rhsTimesOp) {
-                       lhsTimesOp = plusOp.getInput().get(1); 
-               }
-               else {
-                       lhsTimesOp = plusOp.getInput().get(0);
-               }
                
-               if(expectedOneMinusMu == (1-mu) && plusOp.getParent() != null 
&& plusOp.getParent().size() == 1 &&  
-                       isBinarySMMult(lhsTimesOp) && 
OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(lhsTimesOp), new 
HashMap<>()) == mu) {
-                       return true;
-               }
-               return false;
-       }
-       
-       /**
-        * Checks for the expression 1/sqrt(denom)
-        * 
-        * @param denom denominator of the expression to be matched
-        * @return true if the expression 1/sqrt(denom) matched else false
-        */
-       private static boolean isOneBySqrt(Hop denom) {
-               return denom.getParent() != null && denom.getParent().get(0) 
instanceof UnaryOp &&
-                               ((UnaryOp)denom.getParent().get(0)).getOp() == 
OpOp1.SQRT &&
-                               denom.getParent().get(0).getParent() != null && 
denom.getParent().get(0).getParent().size() == 1 &&
-                               
isBinarySMDiv(denom.getParent().get(0).getParent().get(0), 1);
-       }
-       
-       /**
-        * Checks for the batch norm (mode="train") pattern using the helper 
isBatchNormTrainMean and isBatchNormTrainVar
-        * and returns a new FunctionOp if matched
-        * 
-        * @param roots root hops of the given statement block
-        * @param parent parent of the input
-        * @param hi input to be matched
-        * @param pos position
-        * @return a new FunctionOp or hi
-        */
-       @SuppressWarnings("unused")
-       private static Hop batchNormTrain(ArrayList<Hop> roots, Hop parent, Hop 
hi, int pos) 
-       {               
+       // Pattern 3:
+       private static final HopDagPatternMatcher _batchNormTest;
+       static {
                // norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps))
+               HopDagPatternMatcher norm = 
+                       bias_multiply(
+                                       bias_add(leaf("X", MATRIX), 
unaryMinus(leaf("mean", MATRIX))), // bias_add(X, -mean)
+                                       div(1, sqrt(plus(leaf("var", MATRIX), 
leaf("eps", SCALAR))))); // 1/sqrt(var+eps)
                // hi = bias_add(bias_multiply(norm, gamma), beta)
-               // 2x for input and output and 1x for overhead
-               // fitsOnGPU(hi, 3)
-               if( hasFirstInput(hi) && isBiasAdd(hi) && 
isBiasMultiply(getFirstInput(hi)) ) { 
-                       Hop norm = getFirstInput(getFirstInput(hi));
-                       if(hasSecondInput(norm) && isBiasMultiply(norm) && 
isBiasAdd(getFirstInput(norm)) 
-                                       && hasSecondInput(getFirstInput(norm)) 
&& isUnaryMinus(getSecondInput(getFirstInput(norm)))
-                                       && 
isOneDivideBySqrt(getSecondInput(norm))) {
-                               double eps = 0;
-                               Hop var = 
getFirstInput(getSecondInput(getSecondInput(norm)));
-                               if(isBinaryAdd(var) && (getFirstInput(var) 
instanceof LiteralOp || getSecondInput(var) instanceof LiteralOp)) {
-                                       // eps + ema_var
-                                       if(getFirstInput(var) instanceof 
LiteralOp) {
-                                               eps = 
OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(var), new HashMap<>());
-                                               var = getSecondInput(var);
-                                       }
-                                       else {
-                                               eps = 
OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(var), new 
HashMap<>());
-                                               var = getFirstInput(var);
-                                       }
-                               }
-                               // Generate batch norm test op
-                               Hop X = getFirstInput(getFirstInput(norm));
-                               Hop mean = 
getSecondInput(getSecondInput(getFirstInput(norm)));
-                               
-                               if(hasFirstInput(mean) && 
isBatchNormTrainMean(mean , X) && isBatchNormTrainVar(mean, var, X, 
getFirstInput(mean), false) &&
-                                       mean.getParent() != null && 
mean.getParent().size() >= 2 && 
-                                       var.getParent() != null && 
var.getParent().size() == 2) {
-                                       Hop gamma = 
getSecondInput(getFirstInput(hi));
-                                       Hop beta = getSecondInput(hi);
-                                       
-                                       // Always get mu from variance as it 
will have exactly one match of fusion pattern
-                                       Double potentialMu = 
getMuFromUpdatedMovingAverageExpressions(var.getParent());
-                                       if(potentialMu == null)
-                                               return hi;
-                                       double mu = potentialMu;
-                                       
-                                       Hop [] means = 
getUpdatedMovingAverageExpressions(mean.getParent(), mu);
-                                       Hop [] vars = 
getUpdatedMovingAverageExpressions(var.getParent(), mu);
-                                       if(means == null || vars == null)
-                                               return hi;
-                                       
-                                       Hop varPlusEps = null;
-                                       boolean isFirstBinaryAddOp = 
isAnyBinaryAdd(var.getParent().get(0).getParent());
-                    boolean isSecondBinaryAddOp = 
isAnyBinaryAdd(var.getParent().get(1).getParent());
-                    if(isFirstBinaryAddOp && !isSecondBinaryAddOp) {
-                            varPlusEps = var.getParent().get(1);
-                    }
-                    else if(!isFirstBinaryAddOp && isSecondBinaryAddOp) {
-                            varPlusEps = var.getParent().get(0);
-                    }
-                                       if(varPlusEps != null && 
isBinaryMSAdd(varPlusEps, eps) && isOneBySqrt(varPlusEps)) {
-                                               
-                                               Hop cache_var = 
varPlusEps.getParent().get(0).getParent().get(0);
-                                               Hop ema_mean_upd = means[0];
-                                               Hop ema_var_upd = vars[0];
-                                               Hop ema_mean = means[1];
-                                               Hop ema_var = vars[1];
-                                               Hop cache_mean = means[2];
-                                               
-                                               
-                                               ArrayList<Hop> inHops = new 
ArrayList<Hop>();
-                                               inHops.add(X);
-                                               inHops.add(gamma);
-                                               inHops.add(beta);
-                                               inHops.add(ema_mean);
-                                               inHops.add(ema_var);
-                                               inHops.add(new LiteralOp(eps));
-                                               inHops.add(new LiteralOp(mu));
-                                               Hop [] oldHops = {hi, 
ema_mean_upd, ema_var_upd, cache_mean, cache_var};
-                                               
-                                               // Since FunctionOp adds 
transientwrite explicitly, persistent writes are not supported
-                                               
if(!isAnyPersistentWrite(oldHops)) {
-                                                       LOG.debug("Applied 
batchNormTrain rewrite.");
-                                                       ArrayList<Hop> outputs 
= getMultiOutputHops(roots, oldHops);
-                                                       FunctionOp ret = new 
FunctionOp(FunctionType.MULTIRETURN_BUILTIN, DMLProgram.INTERNAL_NAMESPACE, 
"batch_norm2d_train", 
-                                                               null, inHops, 
outputs.stream().map(h -> h.getName()).toArray(String[]::new), outputs);
-                                                       
Collections.reverse(roots);
-                                                       roots.add(ret);
-                                                       
Collections.reverse(roots);
-                                                       return ret;
-                                               }
-                                       }
-                                       
-                               }
+               _batchNormTest = 
+                       bias_add(
+                                       bias_multiply(norm, leaf("gamma", 
MATRIX)), 
+                                       leaf("beta", MATRIX))
+                       .fitsOnGPU(3);
+       }
+       private static final Function<Hop, Hop> _batchNormTestReplacer = hi -> {
+               LOG.debug("Applied batchNormTest rewrite.");
+               Hop newHop = HopRewriteUtils.createDnnOp(_batchNormTest, 
OpOpDnn.BATCH_NORM2D_TEST, "X", "gamma", "beta", "mean", "var", "eps");
+               return HopRewriteUtils.rewireAllParentChildReferences(hi, 
newHop);
+       };
+       
+       // Pattern 4:
+       // rowSums(matrix(colSums(X), rows=C, cols=HW))
+       private static final HopDagPatternMatcher _channelSums = 
util_channel_sums(leaf("X", MATRIX).fitsOnGPU(2), leaf("C", SCALAR), leaf("HW", 
SCALAR));;
+       private static final Function<Hop, Hop> _channelSumsReplacer = hi -> {
+               LOG.debug("Applied channelSums rewrite.");
+               Hop newHop = HopRewriteUtils.createDnnOp(_channelSums, 
OpOpDnn.CHANNEL_SUMS, "X", "C", "HW");
+               return HopRewriteUtils.rewireAllParentChildReferences(hi, 
newHop);
+       };
+       
+       // Pattern 5:
+       // (X - mu*v_prev) + (1+mu)*v
+       private static final HopDagPatternMatcher _updateNesterovX = 
+               mm_plus(
+                               minus(  // X - mu*v_prev
+                                               leaf("X", MATRIX), 
+                                               mult(   // mu*v_prev
+                                                               leaf("mu", 
SCALAR), 
+                                                               leaf("v_prev", 
MATRIX))),
+                               mult(   // (1+mu)*v
+                                               leaf("onePlusMu", SCALAR), 
+                                               leaf("v", MATRIX)))             
                                
+               .fitsOnGPU(3);
+       private static final Function<Hop, Hop> _updateNesterovXReplacer = hi 
-> {
+               if((1+_updateNesterovX.getLiteralValue("mu")) == 
_updateNesterovX.getLiteralValue("onePlusMu")) {
+                       Hop X = _updateNesterovX.getMatchedHop("X");
+                       Hop v = _updateNesterovX.getMatchedHop("v");
+                       Hop v_prev = _updateNesterovX.getMatchedHop("v_prev");
+                       if(hasSameDimensions(X, v) && hasSameDimensions(X, 
v_prev)) {
+                               LOG.debug("Applied updateNesterovX rewrite.");
+                               Hop newHop = 
HopRewriteUtils.createDnnOp(_updateNesterovX, OpOpDnn.UPDATE_NESTEROV_X, "X", 
"v", "v_prev", "mu");
+                               return 
HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
                        }
                }
-               
                return hi;
-       }
-       
-       // ------------------------------------------------------------
-       /**
-        * Checks if any of the given output hop is a persistent write.
-        * 
-        * @param outputHops output hops to check
-        * @return true if any of the hop is a persistent write else false.
-        */
-       private static boolean isAnyPersistentWrite(Hop [] outputHops) {
-               for(Hop outHop : outputHops) {
-                       if(HopRewriteUtils.isData(outHop, 
DataOpTypes.PERSISTENTWRITE))
-                               return true;
+       };
+       
+       // Pattern 6:
+       // matrix(colMeans(X), rows=C, cols=Hin*Win)
+       // This avoids unnecessary copy by the reshape operator
+       private static final HopDagPatternMatcher _reshapeColMeans = 
+               matrix(
+                               colMeans(leaf("X", MATRIX).fitsOnGPU(2)), // 
colMeans(X)
+                               leaf("C", SCALAR), 
+                               leaf("HW", SCALAR));
+       private static final Function<Hop, Hop> _reshapeColMeansReplacer = hi 
-> {
+               LOG.debug("Applied reshapeColMeans rewrite.");
+               Hop newHop = HopRewriteUtils.createDnnOp(_reshapeColMeans, 
OpOpDnn.RESHAPE_COLMEANS, "X", "C", "HW");
+               return HopRewriteUtils.rewireAllParentChildReferences(hi, 
newHop);
+       };
+       
+       // Pattern 7:
+       // mu*ema_mean + (1-mu)*mean
+       private static final HopDagPatternMatcher _updateEMA = 
+               mm_plus(        
+                               mult(   // mu*ema_mean
+                                               leaf("mu", SCALAR), 
+                                               leaf("ema_mean", MATRIX)), 
+                               mult(   // (1-mu)*mean
+                                               leaf("oneMinusMu", SCALAR), 
+                                               leaf("mean", MATRIX)))
+               .fitsOnGPU(3);
+       private static final Function<Hop, Hop> _updateEMAReplacer = hi -> {
+               if((1-_updateEMA.getLiteralValue("mu")) == 
_updateEMA.getLiteralValue("oneMinusMu")) {
+                       LOG.debug("Applied updateEMA rewrite.");
+                       Hop newHop = HopRewriteUtils.createDnnOp(_updateEMA, 
OpOpDnn.UPDATE_EMA, "ema_mean", "mean", "mu");
+                       return 
HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
                }
-               return false;
-       }
-       
-       /**
-        * Returns output hop for a multi-output FunctionOp to be created by 
rewrite.
-        * 
-        * @param roots root hops of statement block
-        * @param oldHops old output hops of the pattern
-        * @return new output hops that should be passed to FunctionOp
-        */
-       private static ArrayList<Hop> getMultiOutputHops(ArrayList<Hop> roots, 
Hop [] oldHops) {
-               ArrayList<Hop> ret = new ArrayList<>();
-               for(int i = 0; i < oldHops.length; i++) {
-                       // Create a transient read as FunctionOp will add a 
transient write.
-                       if(HopRewriteUtils.isData(oldHops[i], 
DataOpTypes.PERSISTENTWRITE))
-                               throw new RuntimeException("Persistent write is 
not supported as output for the given rewrite." + oldHops[i]);
-                       // Generate a new name if the old output was not 
transient write.
-                       String name = HopRewriteUtils.isData(oldHops[i], 
DataOpTypes.TRANSIENTWRITE) ? oldHops[i].getName() : "_genGPU" + (_seq++);
-                       DataOp tRead = 
HopRewriteUtils.createTransientRead(name, oldHops[i]);
-                       
HopRewriteUtils.rewireAllParentChildReferences(oldHops[i], tRead);
-                       ret.add(tRead);
-                       // Remove old output from roots to avoid unnecessary 
computation.
-                       if(roots.contains(oldHops[i])) {
-                               roots.remove(oldHops[i]);
-                       }
+               return hi;
+       };
+       
+       // Pattern 8:
+       // 1/sqrt(var+epsilon)
+       private static final HopDagPatternMatcher _invVar = 
+               div(1, 
+                               sqrt(   // var+epsilon
+                                               plus(   leaf("var", MATRIX), 
+                                                               leaf("eps", 
SCALAR) )))
+               .fitsOnGPU(2);
+       private static final Function<Hop, Hop> _invVarReplacer = hi -> {
+               LOG.debug("Applied computeInverseVariance rewrite.");
+               Hop newHop = HopRewriteUtils.createDnnOp(_invVar, 
OpOpDnn.INV_VAR, "var", "eps");
+               return HopRewriteUtils.rewireAllParentChildReferences(hi, 
newHop);
+       };
+       
+       
+       private static ArrayList<HopPatternRewriter> _rewriters = null;
+       public ArrayList<HopPatternRewriter> getPatternRewriter() {
+               if(_rewriters == null) {
+                       ArrayList<HopPatternRewriter> rewriters = new 
ArrayList<>();
+                       rewriters.add(new HopPatternRewriter("batchNormdX", 
_batchNormdX, _batchNormdXReplacer));
+                       rewriters.add(new 
HopPatternRewriter("batchNormUpdatedVar", _batchNormUpdatedVar, 
_batchNormUpdatedVarReplacer));
+                       rewriters.add(new HopPatternRewriter("batchNormTest", 
_batchNormTest, _batchNormTestReplacer));
+                       rewriters.add(new HopPatternRewriter("channelSums", 
_channelSums, _channelSumsReplacer));
+                       rewriters.add(new HopPatternRewriter("updateNesterovX", 
_updateNesterovX, _updateNesterovXReplacer));
+                       rewriters.add(new HopPatternRewriter("reshapeColMeans", 
_reshapeColMeans, _reshapeColMeansReplacer));
+                       rewriters.add(new HopPatternRewriter("updateEMA", 
_updateEMA, _updateEMAReplacer));
+                       rewriters.add(new HopPatternRewriter("invVar", _invVar, 
_invVarReplacer));
+                       _rewriters = rewriters;
                }
-               return ret;
+               return _rewriters;
        }
-       // ------------------------------------------------------------
        
-       /**
-        * Checks for the nesterov_update_x pattern (X = X - mu*v_prev + 
(1+mu)*v)
-        * and returns a new DnnOp if matched
-        * 
-        * @param parent parent of the input
-        * @param hi input to be matched
-        * @param pos position
-        * @return a new DnnOp or hi
-        */
-       private static Hop updateNesterovX(Hop parent, Hop hi, int pos) {
-               if(fitsOnGPU(hi, 4) && isBinaryMMAdd(hi) && 
isBinaryMMMinus(getFirstInput(hi))
-                       && isBinarySMMult(getSecondInput(getFirstInput(hi))) 
-                       && isBinarySMMult(getSecondInput(hi))) {
-                       Hop onePlusMu = getFirstInput(getSecondInput(hi));
-                       Hop tmp = getSecondInput(getFirstInput(hi));
-                       Hop mu = getFirstInput(tmp);
-                       if(isOnePlusMu(onePlusMu, mu)) {
-                               Hop v_prev = getSecondInput(tmp);
-                               Hop v = getSecondInput(getSecondInput(hi));
-                               Hop X = getFirstInput(getFirstInput(hi));
-                               if(hasSameDimensions(X, v) && hasSameDimensions(X, v_prev)) {
-                                       ArrayList<Hop> inHops = new ArrayList<Hop>();
-                                       inHops.add(X);
-                                       inHops.add(v);
-                                       inHops.add(v_prev);
-                                       inHops.add(mu);
-                                       LOG.debug("Applied updateNesterovX rewrite.");
-                                       Hop newHop = new DnnOp(hi.getName(), hi.getDataType(), hi.getValueType(),
-                                                       OpOpDnn.UPDATE_NESTEROV_X, inHops);
-                                       return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
-                               }
-                       }
-               }
-               return hi;
-       }
+       
+       // -------------------------------------------------------------------------------------------
        
        private static boolean hasSameDimensions(Hop x, Hop y) {
                return x.dimsKnown() && y.dimsKnown() && (x.getDim1() == y.getDim1()) && (x.getDim2() == y.getDim2());
        }
        
-       private static boolean isOnePlusMu(Hop onePlusMu, Hop mu) {
-               return (isBinarySMMult(onePlusMu, 1.0) && getSecondInput(onePlusMu) == mu) ||
-                               getValue(onePlusMu) == getValue(mu) + 1;
-       }
-       
-       /**
-        * Checks for the batch norm (mode="test") pattern using the helper isBatchNormTrainMean and isBatchNormTrainVar
-        * and returns a new DnnOp if matched
-        * 
-        * @param parent parent of the input
-        * @param hi input to be matched
-        * @param pos position
-        * @return a new DnnOp or hi
-        */
-       private static Hop batchNormTest(Hop parent, Hop hi, int pos) {
-               // norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps))
-               // hi = bias_add(bias_multiply(norm, gamma), beta)
-               // 2x for input and output and 1x for overhead
-               if(hasFirstInput(hi) && isBiasAdd(hi) && isBiasMultiply(getFirstInput(hi)) && fitsOnGPU(hi, 3) ) {
-                       Hop norm = getFirstInput(getFirstInput(hi));
-                       if(hasSecondInput(norm) && isBiasMultiply(norm) && isBiasAdd(getFirstInput(norm)) 
-                                       && isUnaryMinus(getSecondInput(getFirstInput(norm)))
-                                       && isOneDivideBySqrt(getSecondInput(norm))) {
-                               double eps = 0;
-                               Hop var = getFirstInput(getSecondInput(getSecondInput(norm)));
-                               if( HopRewriteUtils.isBinary(var, OpOp2.PLUS) &&
-                                       (getFirstInput(var) instanceof LiteralOp || getSecondInput(var) instanceof LiteralOp)) {
-                                       // eps + ema_var
-                                       if(getFirstInput(var) instanceof LiteralOp) {
-                                               eps = OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(var), new HashMap<>());
-                                               var = getSecondInput(var);
-                                       }
-                                       else {
-                                               eps = OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(var), new HashMap<>());
-                                               var = getFirstInput(var);
-                                       }
-                               }
-                               // Generate batch norm test op
-                               Hop X = getFirstInput(getFirstInput(norm));
-                               Hop mean = getSecondInput(getSecondInput(getFirstInput(norm)));
-                               
-                               // This guard disallows eager fusion of train batch normalization into test batch normalization
-                               boolean potentialForBatchNormTrain = !X.rowsKnown() && isBatchNormTrainMean(mean , X) && isBatchNormTrainVar(mean, var, X, getFirstInput(mean), true);
-                               if(!potentialForBatchNormTrain) {
-                                       Hop gamma = getSecondInput(getFirstInput(hi));
-                                       Hop beta = getSecondInput(hi);
-                                       ArrayList<Hop> inHops = new ArrayList<Hop>();
-                                       inHops.add(X);
-                                       inHops.add(gamma);
-                                       inHops.add(beta);
-                                       inHops.add(mean);
-                                       inHops.add(var);
-                                       inHops.add(new LiteralOp(eps));
-                                       if(fitsOnGPU(inHops, true)) {
-                                               LOG.debug("Applied batchNormTest rewrite.");
-                                               Hop newHop = new DnnOp(hi.getName(), hi.getDataType(), hi.getValueType(),
-                                                               OpOpDnn.BATCH_NORM2D_TEST, inHops);
-                                               return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
-                                       }
-                               }
-                               else {
-                                       LOG.debug("Skipping batchNormTest rewrite as there is potential for batch normalization train rewrite after recompilation.");
-                               }
-                       }
-               }
-               
-               return hi;
+       private static boolean checkDimensions(Hop x, long dim1, long dim2) {
+               return x.dimsKnown() && (x.getDim1() == dim1) && (x.getDim2() == dim2);
        }
-}
+}
\ No newline at end of file

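For illustration, a minimal sketch of how a pattern/replacer pair such as Pattern 8 above can be declared and registered through getPatternRewriter(). The class name ExampleGPURewrite is hypothetical and not part of this patch; the sketch assumes the HopDagPatternMatcher DSL (div, sqrt, plus, leaf, fitsOnGPU), HopPatternRewriter, and HopRewriteRuleWithPatternMatcher referenced in this patch are on the classpath.

    package org.apache.sysml.hops.rewrite;

    import java.util.ArrayList;
    import java.util.function.Function;
    import org.apache.sysml.hops.Hop;
    import org.apache.sysml.hops.Hop.OpOpDnn;
    import static org.apache.sysml.hops.rewrite.HopDagPatternMatcher.*;
    import static org.apache.sysml.parser.Expression.DataType.MATRIX;
    import static org.apache.sysml.parser.Expression.DataType.SCALAR;

    // Hypothetical example class, modeled on the inv_var pattern above.
    public class ExampleGPURewrite extends HopRewriteRuleWithPatternMatcher {
        // Declarative pattern: 1/sqrt(var + eps), restricted to HOPs whose inputs fit on the GPU.
        private static final HopDagPatternMatcher pattern =
            div(1, sqrt(plus(leaf("var", MATRIX), leaf("eps", SCALAR)))).fitsOnGPU(2);

        // Replacer: build the fused DnnOp from the matched leaves ("var", "eps") and
        // rewire all parents of the matched root HOP to the new operator.
        private static final Function<Hop, Hop> replacer = hi ->
            HopRewriteUtils.rewireAllParentChildReferences(hi,
                HopRewriteUtils.createDnnOp(pattern, OpOpDnn.INV_VAR, "var", "eps"));

        public ArrayList<HopPatternRewriter> getPatternRewriter() {
            ArrayList<HopPatternRewriter> rewriters = new ArrayList<>();
            rewriters.add(new HopPatternRewriter("invVar", pattern, replacer));
            return rewriters;
        }
    }
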
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/lops/DnnTransform.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/DnnTransform.java b/src/main/java/org/apache/sysml/lops/DnnTransform.java
index 3183b5f..2d2d5f1 100644
--- a/src/main/java/org/apache/sysml/lops/DnnTransform.java
+++ b/src/main/java/org/apache/sysml/lops/DnnTransform.java
@@ -32,7 +32,8 @@ public class DnnTransform extends Lop
                RELU_MAX_POOLING, RELU_MAX_POOLING_BACKWARD, RELU_BACKWARD,
                CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA,
                BIAS_ADD, CONV2D_BIAS_ADD, BIAS_MULTIPLY, CHANNEL_SUMS, BATCH_NORM2D_TEST, 
-               UPDATE_NESTEROV_X
+               UPDATE_NESTEROV_X, RESHAPE_COLMEANS, UPDATE_EMA_VAR, UPDATE_EMA, INV_VAR,
+               BATCH_NORM2D_BACKWARD_DX
        }
        
        private OperationTypes operation;
@@ -167,11 +168,26 @@ public class DnnTransform extends Lop
                case CHANNEL_SUMS:
                        return "channel_sums";
                
+               case INV_VAR:
+                       return "inv_var";
+               
                case UPDATE_NESTEROV_X:
                        return "update_nesterov_x";
                        
                case BATCH_NORM2D_TEST:
                        return "batch_norm2d_test";
+               
+               case BATCH_NORM2D_BACKWARD_DX:
+                       return "batch_norm2d_bwd_dx";
+                       
+               case UPDATE_EMA_VAR:
+                       return "update_ema_var";
+                       
+               case UPDATE_EMA:
+                       return "update_ema";
+                       
+               case RESHAPE_COLMEANS:
+                       return "reshape_colmeans";
                        
                default:
                        throw new UnsupportedOperationException(this.printErrorLocation() + "Instruction is not defined for Transform operation " + operation);
@@ -181,7 +197,8 @@ public class DnnTransform extends Lop
        
        @Override
        public String getInstructions(String input, String bias, String output) {
-               if(operation == OperationTypes.BIAS_ADD || operation == OperationTypes.BIAS_MULTIPLY || operation == OperationTypes.RELU_BACKWARD) {
+               if(operation == OperationTypes.BIAS_ADD || operation == OperationTypes.BIAS_MULTIPLY || operation == OperationTypes.RELU_BACKWARD
+                               || operation == OperationTypes.INV_VAR) {
                        StringBuilder sb = new StringBuilder();
                        sb.append( getExecType() );
                        
@@ -190,7 +207,7 @@ public class DnnTransform extends Lop
                        sb.append( OPERAND_DELIMITOR );
                        sb.append( getInputs().get(0).prepInputOperand(input));
                        sb.append( OPERAND_DELIMITOR );
-                       sb.append( getInputs().get(0).prepInputOperand(bias));
+                       sb.append( getInputs().get(1).prepInputOperand(bias));
                        //output
                        sb.append( OPERAND_DELIMITOR );
                        sb.append( this.prepOutputOperand(output));
@@ -212,7 +229,7 @@ public class DnnTransform extends Lop
        
        @Override
        public String getInstructions(String input, String C, String HW, String output) {
-               if(operation == OperationTypes.CHANNEL_SUMS) {
+               if(operation == OperationTypes.CHANNEL_SUMS || operation == OperationTypes.RESHAPE_COLMEANS || operation == OperationTypes.UPDATE_EMA) {
                        StringBuilder sb = new StringBuilder();
                        sb.append( getExecType() );
                        
@@ -306,6 +323,34 @@ public class DnnTransform extends Lop
                        throw new LopsException("The operation is not supported with six operands:" + operation.name());
                }
        }
+       
+       public String getInstructions(String input1, String input2, String input3, String input4, String input5, String output) {
+               if(operation == OperationTypes.UPDATE_EMA_VAR || operation == OperationTypes.BATCH_NORM2D_BACKWARD_DX) {
+                       StringBuilder sb = new StringBuilder();
+                       sb.append( getExecType() );
+                       
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append( getOpcode() );
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append( getInputs().get(0).prepInputOperand(input1));
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append( getInputs().get(1).prepInputOperand(input2));
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append( getInputs().get(2).prepInputOperand(input3));
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append( getInputs().get(3).prepInputOperand(input4));
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append( getInputs().get(4).prepInputOperand(input5));
+                       //output
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append( this.prepOutputOperand(output));
+                       
+                       return sb.toString();
+               }
+               else {
+                       throw new LopsException("The operation is not supported with six operands:" + operation.name());
+               }
+       }
 
        public void appendOpcode(StringBuilder sb) {
                sb.append( getExecType() );

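For intuition, the new five-input getInstructions overload above emits fields in a fixed order: execution type, opcode, the five prepared input operands, then the prepared output operand, each separated by Lop.OPERAND_DELIMITOR. A rough, hypothetical sketch of that ordering follows; the placeholder strings stand in for the values produced by prepInputOperand/prepOutputOperand and do not reflect the actual operand names.

    import org.apache.sysml.lops.Lop;

    class OperandLayoutSketch {
        // Hypothetical illustration of operand ordering only; the real instruction string is
        // assembled by DnnTransform.getInstructions(...) from the actual input/output Lops.
        static String updateEmaVarLayout() {
            String d = String.valueOf(Lop.OPERAND_DELIMITOR);
            return "GPU" + d + "update_ema_var" + d + "in1" + d + "in2"
                + d + "in3" + d + "in4" + d + "in5" + d + "out";
        }
    }
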
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
index 9f3a1e2..fe86dc8 100644
--- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
@@ -270,58 +270,6 @@ public class BuiltinFunctionExpression extends DataIdentifier
                        setDimensions(dc0, getFifthExpr());
                        break;
                }
-               case BATCH_NORM2D:
-               {
-                       // Input: image, scale, bias, runningMean, runningVar, mode, epsilon, exponentialAverageFactor
-                       checkNumParameters(8);
-                       checkMatrixParam(getFirstExpr());
-                       checkMatrixParam(getSecondExpr());
-                       checkMatrixParam(getThirdExpr());
-                       checkMatrixParam(getFourthExpr());
-                       checkMatrixParam(getFifthExpr());
-                       
-                       // Output: ret, retRunningMean, retRunningVar, resultSaveMean, resultSaveInvVariance
-                       // setup output properties
-                       if(getOutputs().length != 5)
-                               raiseValidateError("batch_norm2d has 5 outputs", false);
-                        
-                       DataIdentifier ret = (DataIdentifier) getOutputs()[0];
-                       DataIdentifier retRunningMean = (DataIdentifier) getOutputs()[1];
-                       DataIdentifier retRunningVar = (DataIdentifier) getOutputs()[2];
-                       DataIdentifier resultSaveMean = (DataIdentifier) getOutputs()[3];
-                       DataIdentifier resultSaveInvVariance = (DataIdentifier) getOutputs()[4];
-                       
-                       setDimensions(ret, getFirstExpr());
-                       setDimensions(retRunningMean, getFourthExpr());
-                       setDimensions(retRunningVar, getFourthExpr());
-                       setDimensions(resultSaveMean, getFourthExpr());
-                       setDimensions(resultSaveInvVariance, getFourthExpr());
-                       break;
-               }
-               case BATCH_NORM2D_BACKWARD:
-               {
-                       // Input: image, dout, scale, epsilon, savedMean, savedInvVariance
-                       checkNumParameters(6);
-                       checkMatrixParam(getFirstExpr());
-                       checkMatrixParam(getSecondExpr());
-                       checkMatrixParam(getThirdExpr());
-                       checkMatrixParam(getFifthExpr());
-                       checkMatrixParam(getSixthExpr());
-                       
-                       // Output: dX, dScale, dBias 
-                       // setup output properties
-                       if(getOutputs().length != 3)
-                               raiseValidateError("batch_norm2d_backward has 3 outputs", false);
-                       
-                       DataIdentifier dX = (DataIdentifier) getOutputs()[0];
-                       DataIdentifier dScale = (DataIdentifier) getOutputs()[1];
-                       DataIdentifier dBias = (DataIdentifier) getOutputs()[2];
-                       
-                       setDimensions(dX, getFirstExpr());
-                       setDimensions(dScale, getThirdExpr());
-                       setDimensions(dBias, getThirdExpr());
-                       break;
-               }
                case EIGEN:
                        checkNumParameters(1);
                        checkMatrixParam(getFirstExpr());
@@ -1451,8 +1399,7 @@ public class BuiltinFunctionExpression extends DataIdentifier
                                // always unconditional (because unsupported operation)
                                BuiltinFunctionOp op = getOpCode();
                                if( op==BuiltinFunctionOp.EIGEN || op==BuiltinFunctionOp.LU || op==BuiltinFunctionOp.QR || op==BuiltinFunctionOp.SVD 
-                                               || op==BuiltinFunctionOp.LSTM || op==BuiltinFunctionOp.LSTM_BACKWARD
-                                               || op==BuiltinFunctionOp.BATCH_NORM2D || op==BuiltinFunctionOp.BATCH_NORM2D_BACKWARD)
+                                               || op==BuiltinFunctionOp.LSTM || op==BuiltinFunctionOp.LSTM_BACKWARD)
                                        raiseValidateError("Function "+op+" needs to be called with multi-return assignment.", false, LanguageErrorCodes.INVALID_PARAMETERS);
                                else
                                        raiseValidateError("Unsupported function "+op, false, LanguageErrorCodes.INVALID_PARAMETERS);
@@ -1535,8 +1482,6 @@ public class BuiltinFunctionExpression extends DataIdentifier
                case EIGEN:
                case LSTM:
                case LSTM_BACKWARD:
-               case BATCH_NORM2D:
-               case BATCH_NORM2D_BACKWARD:
                case SVD:
                        return true;
                default:
@@ -1956,10 +1901,6 @@ public class BuiltinFunctionExpression extends DataIdentifier
                        bifop = Expression.BuiltinFunctionOp.LSTM;
                else if (functionName.equals("lstm_backward"))
                        bifop = Expression.BuiltinFunctionOp.LSTM_BACKWARD;
-               else if (functionName.equals("batch_norm2d"))
-                       bifop = Expression.BuiltinFunctionOp.BATCH_NORM2D;
-               else if (functionName.equals("batch_norm2d_backward"))
-                       bifop = Expression.BuiltinFunctionOp.BATCH_NORM2D_BACKWARD;
                else if (functionName.equals("conv2d"))
                         bifop = Expression.BuiltinFunctionOp.CONV2D;
                else if (functionName.equals("bias_add"))

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
index e9b643e..e3db435 100644
--- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
@@ -2271,8 +2271,6 @@ public class DMLTranslator
                        case EIGEN:
                        case LSTM:
                        case LSTM_BACKWARD:
-                       case BATCH_NORM2D:
-                       case BATCH_NORM2D_BACKWARD:
                        case SVD:
                                
                                // Number of outputs = size of targetList = #of identifiers in source.getOutputs

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/parser/Expression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java
index 46e6442..33fca66 100644
--- a/src/main/java/org/apache/sysml/parser/Expression.java
+++ b/src/main/java/org/apache/sysml/parser/Expression.java
@@ -93,7 +93,7 @@ public abstract class Expression implements ParseInfo
                EXISTS,
                CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA, BIASADD, BIASMULT,
                MAX_POOL, AVG_POOL, MAX_POOL_BACKWARD, AVG_POOL_BACKWARD,
-               LSTM, LSTM_BACKWARD, BATCH_NORM2D, BATCH_NORM2D_BACKWARD,
+               LSTM, LSTM_BACKWARD,
                EXP,
                FLOOR,
                IFELSE,

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
index f4122d9..3480504 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
@@ -59,11 +59,13 @@ public class GPUInstructionParser  extends InstructionParser
                String2GPUInstructionType.put( "channel_sums",          GPUINSTRUCTION_TYPE.Dnn);
                String2GPUInstructionType.put( "lstm",                  GPUINSTRUCTION_TYPE.Dnn);
                String2GPUInstructionType.put( "lstm_backward",         GPUINSTRUCTION_TYPE.Dnn);
-               String2GPUInstructionType.put( "batch_norm2d",           GPUINSTRUCTION_TYPE.Dnn);
-               String2GPUInstructionType.put( "batch_norm2d_backward",  GPUINSTRUCTION_TYPE.Dnn);
                String2GPUInstructionType.put( "batch_norm2d_test",      GPUINSTRUCTION_TYPE.Dnn);
-               String2GPUInstructionType.put( "batch_norm2d_train",      GPUINSTRUCTION_TYPE.Dnn);
-               String2GPUInstructionType.put( "update_nesterov_x",      GPUINSTRUCTION_TYPE.Dnn);
+               String2GPUInstructionType.put( "update_nesterov_x",     GPUINSTRUCTION_TYPE.Dnn);
+               String2GPUInstructionType.put( "update_ema_var",        GPUINSTRUCTION_TYPE.Dnn);
+               String2GPUInstructionType.put( "update_ema",            GPUINSTRUCTION_TYPE.Dnn);
+               String2GPUInstructionType.put( "reshape_colmeans",      GPUINSTRUCTION_TYPE.Dnn);
+               String2GPUInstructionType.put( "inv_var",               GPUINSTRUCTION_TYPE.Dnn);
+               String2GPUInstructionType.put( "batch_norm2d_bwd_dx",   GPUINSTRUCTION_TYPE.Dnn);
                
                // Matrix Multiply Operators
                String2GPUInstructionType.put( "ba+*",  GPUINSTRUCTION_TYPE.AggregateBinary);
