Repository: incubator-systemml Updated Branches: refs/heads/master 4288da9ee -> 7ed36a98f
[HOTFIX][SYSTEMML-1659] Fix sum-sumSq aggregate elimination rewrite This patch fixes the rewrite 'remove unnecessary aggregate' for the case of sum-sumsq. As it turned out, the newly introduced tests did not check for it because sum(rowSums(X^2)) is first rewritten (with static rewrites) to sum(X^2) and only subsequently rewritten (with dynamic rewrites) to sumSq(X). For the NN library, the sum is exposed after dynamic rewrites, and hence the issue of the wrong aggregation type assignment showed up. This issue did not show up in local tests because they were run with a custom test suite that only included application, conversion, functions, and mlcontext tests (due to issues with the recently introduced GPU tests). Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/7ed36a98 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/7ed36a98 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/7ed36a98 Branch: refs/heads/master Commit: 7ed36a98fedb54403bfb6349417782d9b88361f8 Parents: 4288da9 Author: Matthias Boehm <[email protected]> Authored: Sat Jun 3 15:59:08 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sat Jun 3 15:59:08 2017 -0700 ---------------------------------------------------------------------- .../rewrite/RewriteAlgebraicSimplificationStatic.java | 2 ++ .../functions/misc/RewriteEliminateAggregatesTest.java | 12 ++++++------ 2 files changed, 8 insertions(+), 6 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/7ed36a98/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index 74f5488..17f1ace 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -837,6 +837,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule Hop input = au2.getInput().get(0); HopRewriteUtils.removeAllChildReferences(au2); HopRewriteUtils.replaceChildReference(au1, au2, input); + if( au2.getOp() == AggOp.SUM_SQ ) + au1.setOp(AggOp.SUM_SQ); LOG.debug("Applied removeUnnecessaryAggregates (line "+hi.getBeginLine()+")."); } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/7ed36a98/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEliminateAggregatesTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEliminateAggregatesTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEliminateAggregatesTest.java index 741ef31..3092867 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEliminateAggregatesTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEliminateAggregatesTest.java @@ -19,8 +19,6 @@ package org.apache.sysml.test.integration.functions.misc; -import java.util.HashMap; - import org.junit.Assert; import org.junit.Test; import org.apache.sysml.hops.OptimizerUtils; @@ -35,6 +33,8 @@ public class RewriteEliminateAggregatesTest extends AutomatedTestBase private static final String TEST_DIR = "functions/misc/"; private static final String TEST_CLASS_DIR = TEST_DIR + RewriteEliminateAggregatesTest.class.getSimpleName() + "/"; + private double tol = Math.pow(10, -10); + @Override public void setUp() { TestUtils.clearAssertionInformation(); @@ -111,7 +111,7 @@ public class 
RewriteEliminateAggregatesTest extends AutomatedTestBase OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; //generate actual dataset - double[][] A = getRandomMatrix(123, 12, 0, 1, 0.9, 7); + double[][] A = getRandomMatrix(123, 12, -5, 5, 0.9, 7); writeInputMatrixWithMTD("A", A, true); //run test @@ -119,9 +119,9 @@ public class RewriteEliminateAggregatesTest extends AutomatedTestBase runRScript(true); //compare scalars - HashMap<CellIndex, Double> dmlfile = readDMLScalarFromHDFS("Scalar"); - HashMap<CellIndex, Double> rfile = readRScalarFromFS("Scalar"); - TestUtils.compareScalars(dmlfile.toString(), rfile.toString()); + double ret1 = readDMLScalarFromHDFS("Scalar").get(new CellIndex(1,1)); + double ret2 = readRScalarFromFS("Scalar").get(new CellIndex(1,1)); + TestUtils.compareScalars(ret1, ret2, tol); //check for applied rewrites if( rewrites ) {
