Repository: systemml Updated Branches: refs/heads/master 7d936cf0c -> 2fed357e2
[HOTFIX] Fix additional size propagation issues and batch_norm2d runtime This patch fixes remaining issues introduced by overlapping previous commits. Due to some other side effects, the NN tests require constant propagation into rand, and the batch_norm2d runtime did not correctly handle empty inputs. In a subsequent patch (after we stabilized the build), we need to revisit both aspects in a principled way. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/2fed357e Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/2fed357e Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/2fed357e Branch: refs/heads/master Commit: 2fed357e285c50c5c9403686e37a5d43ce7fc05b Parents: 7d936cf Author: Matthias Boehm <[email protected]> Authored: Sat Jun 2 00:26:48 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sat Jun 2 01:08:02 2018 -0700 ---------------------------------------------------------------------- .../apache/sysml/hops/rewrite/RewriteConstantFolding.java | 8 -------- src/main/java/org/apache/sysml/parser/DataExpression.java | 3 +-- .../org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java | 4 ++-- 3 files changed, 3 insertions(+), 12 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/2fed357e/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java index a153fb1..0099029 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java @@ -105,14 +105,6 @@ public class RewriteConstantFolding extends HopRewriteRule LOG.error("Failed to execute 
constant folding instructions. No abort.", ex); } } - //fold nrow as precondition for further constant folding - else if( HopRewriteUtils.isUnary(root, OpOp1.NROW) && root.getInput().get(0).rowsKnown() ) { - literal = new LiteralOp(root.getInput().get(0).getDim1()); - } - //fold ncol as precondition for further constant folding - else if( HopRewriteUtils.isUnary(root, OpOp1.NCOL) && root.getInput().get(0).colsKnown() ) { - literal = new LiteralOp(root.getInput().get(0).getDim2()); - } //fold conjunctive predicate if at least one input is literal 'false' else if( isApplicableFalseConjunctivePredicate(root) ) { literal = new LiteralOp(false); http://git-wip-us.apache.org/repos/asf/systemml/blob/2fed357e/src/main/java/org/apache/sysml/parser/DataExpression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DataExpression.java b/src/main/java/org/apache/sysml/parser/DataExpression.java index 93379fa..eef9c08 100644 --- a/src/main/java/org/apache/sysml/parser/DataExpression.java +++ b/src/main/java/org/apache/sysml/parser/DataExpression.java @@ -535,8 +535,7 @@ public class DataExpression extends DataIdentifier } //general data expression constant propagation - if( !conditional ) - performConstantPropagationRand( currConstVars ); + performConstantPropagationRand( currConstVars ); performConstantPropagationReadWrite( currConstVars ); // check if data parameter of matrix is scalar or matrix -- if scalar, use Rand instead http://git-wip-us.apache.org/repos/asf/systemml/blob/2fed357e/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java index f8318a5..8a59414 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java +++ 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java @@ -509,8 +509,8 @@ public class LibMatrixDNN { double var = resultSaveInvVarianceArr[k]/NPQ - Math.pow(mean, 2.0); resultSaveMeanArr[k] = mean; resultSaveInvVarianceArr[k] = Math.pow(Math.sqrt(var + epsilon), -1.0); - retRunningMeanArr[k] = mu*runningMeanArr[k] + (1-mu)*mean; - retRunningVarArr[k] = mu*runningVarArr[k] + (1-mu)*mean; + retRunningMeanArr[k] = mu*((runningMeanArr!=null)?runningMeanArr[k]:0) + (1-mu)*mean; + retRunningVarArr[k] = mu*((runningVarArr!=null)?runningVarArr[k]:0) + (1-mu)*mean; } } else if(phase.equalsIgnoreCase("test")) {
