Repository: incubator-systemml Updated Branches: refs/heads/master 513fde4c7 -> 6d95c9f5e
[SYSTEMML-714] Fix rewrite 'pushdown sum on binary+' (m. parents), tests Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/6d95c9f5 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/6d95c9f5 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/6d95c9f5 Branch: refs/heads/master Commit: 6d95c9f5e9df281b6903aaffcde451b25f4bce98 Parents: 513fde4 Author: Matthias Boehm <[email protected]> Authored: Thu May 26 23:58:38 2016 -0700 Committer: Matthias Boehm <[email protected]> Committed: Thu May 26 23:58:38 2016 -0700 ---------------------------------------------------------------------- .../RewriteAlgebraicSimplificationDynamic.java | 7 +- .../misc/RewritePushdownSumOnBinaryTest.java | 96 ++++++++++++++++++++ .../misc/RewritePushdownSumOnBinary.dml | 33 +++++++ .../functions/misc/ZPackageSuite.java | 2 + 4 files changed, 136 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/6d95c9f5/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java index 5375539..817c839 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java @@ -1443,9 +1443,12 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule //rewire new subdag HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); - HopRewriteUtils.removeAllChildReferences(hi); - HopRewriteUtils.removeAllChildReferences(bop); HopRewriteUtils.addChildReference(parent, newBin, pos); + if( hi.getParent().isEmpty() ) + HopRewriteUtils.removeAllChildReferences(hi); + if( bop.getParent().isEmpty() ) + HopRewriteUtils.removeAllChildReferences(bop); + hi = newBin; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/6d95c9f5/src/test/java/org/apache/sysml/test/integration/functions/misc/RewritePushdownSumOnBinaryTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewritePushdownSumOnBinaryTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewritePushdownSumOnBinaryTest.java new file mode 100644 index 0000000..dec4670 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewritePushdownSumOnBinaryTest.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.test.integration.functions.misc; + +import java.util.HashMap; + +import org.junit.Assert; +import org.junit.Test; +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.utils.TestUtils; + +/** + * + * + */ +public class RewritePushdownSumOnBinaryTest extends AutomatedTestBase +{ + private static final String TEST_NAME1 = "RewritePushdownSumOnBinary"; + private static final String TEST_DIR = "functions/misc/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewritePushdownSumOnBinaryTest.class.getSimpleName() + "/"; + + private static final int rows = 1000; + private static final int cols = 1; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R1", "R2" }) ); + } + + @Test + public void testRewritePushdownSumOnBinaryNoRewrite() { + testRewritePushdownSumOnBinary( TEST_NAME1, false ); + } + + @Test + public void testRewritePushdownSumOnBinary() { + testRewritePushdownSumOnBinary( TEST_NAME1, true ); + } + + + /** + * + * @param condition + * @param branchRemoval + * @param IPA + */ + private void testRewritePushdownSumOnBinary( String testname, boolean rewrites ) + { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + + try { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new String[]{ "-args", String.valueOf(rows), + String.valueOf(cols), output("R1"), output("R2") }; + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + //run performance tests + runTest(true, false, null, -1); + + //compare matrices + long expect = Math.round(0.5*rows); + HashMap<CellIndex, Double> dmlfile1 = readDMLScalarFromHDFS("R1"); + Assert.assertEquals("Wrong result R1, expected: "+expect, expect, Math.round(dmlfile1.get(new CellIndex(1,1)))); + HashMap<CellIndex, Double> dmlfile2 = readDMLScalarFromHDFS("R2"); + Assert.assertEquals("Wrong result R2, expected: "+expect, expect, Math.round(dmlfile2.get(new CellIndex(1,1)))); + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/6d95c9f5/src/test/scripts/functions/misc/RewritePushdownSumOnBinary.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewritePushdownSumOnBinary.dml b/src/test/scripts/functions/misc/RewritePushdownSumOnBinary.dml new file mode 100644 index 0000000..d48ac0a --- /dev/null +++ b/src/test/scripts/functions/misc/RewritePushdownSumOnBinary.dml @@ -0,0 +1,33 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +A = rand(rows=$1, cols=$2, seed=1); +B = rand(rows=$1, cols=$2, seed=2); +C = rand(rows=$1, cols=$2, seed=3); +D = rand(rows=$1, cols=$2, seed=4); + +r1 = sum(A*B + C*D); +r2 = r1; + +print("r1="+r1+", r2="+r2); +write(r1, $3); +write(r2, $4); + http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/6d95c9f5/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java ---------------------------------------------------------------------- diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java index f24660f..4720595 100644 --- a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java +++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java @@ -47,6 +47,8 @@ import org.junit.runners.Suite; PrintMatrixTest.class, ReadAfterWriteTest.class, RewriteFusedRandTest.class, + RewritePushdownSumOnBinaryTest.class, + RewritePushdownUaggTest.class, RewriteSimplifyRowColSumMVMultTest.class, RewriteSlicedMatrixMultTest.class, ScalarAssignmentTest.class,
