Repository: systemml
Updated Branches:
  refs/heads/master fab31fd1f -> 512fb9e11


[SYSTEMML-445] Extend coverage for GPU batchnorm test rewrite

- Previously, if the inv_var rewrite had already been applied, the GPU
batchnorm test rewrite was skipped, so the CuDNN batchnorm kernel was not
used. This commit fixes that performance regression.
- This commit also forces the GPU rewrites to be applied when GPU mode is
forced.

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/512fb9e1
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/512fb9e1
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/512fb9e1

Branch: refs/heads/master
Commit: 512fb9e119541ae9d7dae58c0812a89d569d1ca0
Parents: fab31fd
Author: Niketan Pansare <npan...@us.ibm.com>
Authored: Tue Oct 9 13:56:47 2018 -0700
Committer: Niketan Pansare <npan...@us.ibm.com>
Committed: Tue Oct 9 13:56:47 2018 -0700

----------------------------------------------------------------------
 .../sysml/hops/rewrite/HopDagPatternMatcher.java     | 15 +++++++++++++++
 .../sysml/hops/rewrite/RewriteGPUSpecificOps.java    |  2 +-
 2 files changed, 16 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/512fb9e1/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java b/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java
index 51b1812..33cd5ed 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java
@@ -294,6 +294,19 @@ public class HopDagPatternMatcher {
                return new HopDagPatternMatcher().addPredicate("sqrt", h -> 
HopRewriteUtils.isUnary(h, OpOp1.SQRT))
                                .addChildMatcher(child);
        }
+       public static HopDagPatternMatcher inv_var(HopDagPatternMatcher var, 
HopDagPatternMatcher eps) {
+               return new HopDagPatternMatcher().addPredicate("sqrt", h ->  {
+                       if(HopRewriteUtils.isDnn(h, OpOpDnn.INV_VAR)) {
+                               return true;
+                       }
+                       else {
+                               return HopRewriteUtils.isBinary(h, OpOp2.DIV) 
&& HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), 1.0) &&
+                               HopRewriteUtils.isUnary(h.getInput().get(1), 
OpOp1.SQRT) && 
+                               
HopRewriteUtils.isBinary(h.getInput().get(1).getInput().get(0), OpOp2.PLUS);
+                       }
+               })
+                               .addChildMatcher(var, eps);
+       }
        public static HopDagPatternMatcher div(HopDagPatternMatcher child1, 
HopDagPatternMatcher child2) {
                return new HopDagPatternMatcher().addPredicate("div", h -> 
HopRewriteUtils.isBinary(h, OpOp2.DIV))
                                .addChildMatcher(child1, child2);
@@ -370,6 +383,8 @@ public class HopDagPatternMatcher {
                                .addChildMatcher(child1, dummy);
        }
        private static boolean _fitsOnGPU(Hop h, double multiplier) {
+               if(ConfigurationManager.isForcedGPU())
+                       return true;
                double memEst = multiplier*h.getMemEstimate();
                return ConfigurationManager.isGPU() && h.dimsKnown() && 
OptimizerUtils.isMemoryBasedOptLevel() &&
                                memEst < OptimizerUtils.getLocalMemBudget() && 
memEst < GPUContextPool.initialGPUMemBudget();

http://git-wip-us.apache.org/repos/asf/systemml/blob/512fb9e1/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
index 53d368b..577adc3 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
@@ -178,7 +178,7 @@ public class RewriteGPUSpecificOps extends 
HopRewriteRuleWithPatternMatcher {
                HopDagPatternMatcher norm = 
                        bias_multiply(
                                        bias_add(leaf("X", MATRIX), 
unaryMinus(leaf("mean", MATRIX))), // bias_add(X, -mean)
-                                       div(1, sqrt(plus(leaf("var", MATRIX), 
leaf("eps", SCALAR))))); // 1/sqrt(var+eps)
+                                       inv_var(leaf("var", MATRIX), 
leaf("eps", SCALAR))); // 1/sqrt(var+eps)
                // hi = bias_add(bias_multiply(norm, gamma), beta)
                _batchNormTest = 
                        bias_add(

Reply via email to