Repository: systemml Updated Branches: refs/heads/master fab31fd1f -> 512fb9e11
[SYSTEMML-445] Extend coverage for GPU batchnorm test rewrite - Previously, if the inv_var rewrite had already been applied, the GPU batchnorm test rewrite (and hence the CuDNN batchnorm kernel) was skipped. This commit fixes that performance regression. - Also, when GPU mode is forced, this commit applies the GPU rewrites regardless of memory estimates. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/512fb9e1 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/512fb9e1 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/512fb9e1 Branch: refs/heads/master Commit: 512fb9e119541ae9d7dae58c0812a89d569d1ca0 Parents: fab31fd Author: Niketan Pansare <npan...@us.ibm.com> Authored: Tue Oct 9 13:56:47 2018 -0700 Committer: Niketan Pansare <npan...@us.ibm.com> Committed: Tue Oct 9 13:56:47 2018 -0700 ---------------------------------------------------------------------- .../sysml/hops/rewrite/HopDagPatternMatcher.java | 15 +++++++++++++++ .../sysml/hops/rewrite/RewriteGPUSpecificOps.java | 2 +- 2 files changed, 16 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/512fb9e1/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java b/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java index 51b1812..33cd5ed 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java @@ -294,6 +294,19 @@ public class HopDagPatternMatcher { return new HopDagPatternMatcher().addPredicate("sqrt", h -> HopRewriteUtils.isUnary(h, OpOp1.SQRT)) .addChildMatcher(child); } + public static HopDagPatternMatcher inv_var(HopDagPatternMatcher var, HopDagPatternMatcher 
eps) { + return new HopDagPatternMatcher().addPredicate("sqrt", h -> { + if(HopRewriteUtils.isDnn(h, OpOpDnn.INV_VAR)) { + return true; + } + else { + return HopRewriteUtils.isBinary(h, OpOp2.DIV) && HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), 1.0) && + HopRewriteUtils.isUnary(h.getInput().get(1), OpOp1.SQRT) && + HopRewriteUtils.isBinary(h.getInput().get(1).getInput().get(0), OpOp2.PLUS); + } + }) + .addChildMatcher(var, eps); + } public static HopDagPatternMatcher div(HopDagPatternMatcher child1, HopDagPatternMatcher child2) { return new HopDagPatternMatcher().addPredicate("div", h -> HopRewriteUtils.isBinary(h, OpOp2.DIV)) .addChildMatcher(child1, child2); @@ -370,6 +383,8 @@ public class HopDagPatternMatcher { .addChildMatcher(child1, dummy); } private static boolean _fitsOnGPU(Hop h, double multiplier) { + if(ConfigurationManager.isForcedGPU()) + return true; double memEst = multiplier*h.getMemEstimate(); return ConfigurationManager.isGPU() && h.dimsKnown() && OptimizerUtils.isMemoryBasedOptLevel() && memEst < OptimizerUtils.getLocalMemBudget() && memEst < GPUContextPool.initialGPUMemBudget(); http://git-wip-us.apache.org/repos/asf/systemml/blob/512fb9e1/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java index 53d368b..577adc3 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java @@ -178,7 +178,7 @@ public class RewriteGPUSpecificOps extends HopRewriteRuleWithPatternMatcher { HopDagPatternMatcher norm = bias_multiply( bias_add(leaf("X", MATRIX), unaryMinus(leaf("mean", MATRIX))), // bias_add(X, -mean) - div(1, sqrt(plus(leaf("var", MATRIX), leaf("eps", SCALAR))))); // 1/sqrt(var+eps) + 
inv_var(leaf("var", MATRIX), leaf("eps", SCALAR))); // 1/sqrt(var+eps) // hi = bias_add(bias_multiply(norm, gamma), beta) _batchNormTest = bias_add(