Repository: systemml Updated Branches: refs/heads/master 6a11413b1 -> 223066eeb
[SYSTEMML-2010] Improved rewrite for merging statement blocks So far we did not merge basic blocks that contain functions with any subsequent blocks because functions directly bind to output variables. However, this is too conservative because blocks with non-conflicting inputs and outputs can still be merged, which creates more opportunities for rewrites and common subexpression elimination. This patch generalizes the existing rewrite for merging sequences of statement blocks accordingly. Furthermore, this patch also includes additional tests and a fix for a potential result correctness issue due to the invalid merge of statement blocks, where the second contains a function and there is an output conflict. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/6b4eaa6b Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/6b4eaa6b Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/6b4eaa6b Branch: refs/heads/master Commit: 6b4eaa6bd8ba77a18de1c9aebcb9ef3047ada89e Parents: 6a11413 Author: Matthias Boehm <[email protected]> Authored: Fri Nov 10 15:01:24 2017 -0800 Committer: Matthias Boehm <[email protected]> Committed: Fri Nov 10 15:01:24 2017 -0800 ---------------------------------------------------------------------- .../hops/rewrite/RewriteMergeBlockSequence.java | 23 ++++++++++- .../functions/misc/RewriteMergeBlocksTest.java | 34 +++++++++++++--- .../functions/misc/RewriteMergeFunctionCut.dml | 2 +- .../functions/misc/RewriteMergeFunctionCut2.R | 36 ++++++++++++++++ .../functions/misc/RewriteMergeFunctionCut2.dml | 38 +++++++++++++++++ .../functions/misc/RewriteMergeFunctionCut3.R | 36 ++++++++++++++++ .../functions/misc/RewriteMergeFunctionCut3.dml | 39 ++++++++++++++++++ .../functions/misc/RewriteMergeFunctionCut4.R | 36 ++++++++++++++++ .../functions/misc/RewriteMergeFunctionCut4.dml | 43 ++++++++++++++++++++ 9 files changed, 280 insertions(+), 7 deletions(-) 
---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/6b4eaa6b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java index e1c2630..9593f5f 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java @@ -22,6 +22,7 @@ package org.apache.sysml.hops.rewrite; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import org.apache.sysml.hops.FunctionOp; @@ -64,7 +65,9 @@ public class RewriteMergeBlockSequence extends StatementBlockRewriteRule StatementBlock sb2 = tmpList.get(i+1); if( HopRewriteUtils.isLastLevelStatementBlock(sb1) && HopRewriteUtils.isLastLevelStatementBlock(sb2) - && !hasFunctionOpRoot(sb1) && !sb1.isSplitDag() && !sb2.isSplitDag() ) + && !sb1.isSplitDag() && !sb2.isSplitDag() + && (!hasFunctionOpRoot(sb1) || !hasFunctionIOConflict(sb1,sb2)) + && (!hasFunctionOpRoot(sb2) || !hasFunctionIOConflict(sb2,sb1)) ) { ArrayList<Hop> sb1Hops = sb1.get_hops(); ArrayList<Hop> sb2Hops = sb2.get_hops(); @@ -163,4 +166,22 @@ public class RewriteMergeBlockSequence extends StatementBlockRewriteRule ret |= (root instanceof FunctionOp); return ret; } + + private static boolean hasFunctionIOConflict(StatementBlock sb1, StatementBlock sb2) + throws HopsException + { + //semantics: a function op root in sb1 conflicts with sb2 if this function op writes + //to a variable that is read or written by sb2, where the write might be either + //a traditional transient write or another function op. 
+ + //collect all function output variables of sb1 + HashSet<String> outSb1 = new HashSet<>(); + for( Hop root : sb1.get_hops() ) + if( root instanceof FunctionOp ) + outSb1.addAll(Arrays.asList(((FunctionOp)root).getOutputVariableNames())); + + //check all output variables against read/updated sets + return sb2.variablesRead().containsAnyName(outSb1) + || sb2.variablesUpdated().containsAnyName(outSb1); + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/6b4eaa6b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteMergeBlocksTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteMergeBlocksTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteMergeBlocksTest.java index 4a14ba6..6f84ea7 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteMergeBlocksTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteMergeBlocksTest.java @@ -30,7 +30,11 @@ import org.apache.sysml.test.utils.TestUtils; public class RewriteMergeBlocksTest extends AutomatedTestBase { private static final String TEST_NAME1 = "RewriteMergeIfCut"; //full merge - private static final String TEST_NAME2 = "RewriteMergeFunctionCut"; //only input merge + private static final String TEST_NAME2 = "RewriteMergeFunctionCut"; //full merge + private static final String TEST_NAME3 = "RewriteMergeFunctionCut2"; //only input merge + private static final String TEST_NAME4 = "RewriteMergeFunctionCut3"; //only input merge + private static final String TEST_NAME5 = "RewriteMergeFunctionCut4"; //only input merge + private static final String TEST_DIR = "functions/misc/"; private static final String TEST_CLASS_DIR = TEST_DIR + RewriteMergeBlocksTest.class.getSimpleName() + "/"; @@ -42,19 +46,39 @@ public class RewriteMergeBlocksTest extends AutomatedTestBase TestUtils.clearAssertionInformation(); 
addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"R"})); addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[]{"R"})); + addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[]{"R"})); + addTestConfiguration(TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[]{"R"})); + addTestConfiguration(TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[]{"R"})); } @Test public void testIfCutMerge() { - testRewriteMerge(TEST_NAME1); + testRewriteMerge(TEST_NAME1, true); } @Test public void testFunctionCutMerge() { - testRewriteMerge(TEST_NAME2); + testRewriteMerge(TEST_NAME2, true); + } + + @Test + public void testFunctionCutMerge2() { + testRewriteMerge(TEST_NAME3, false); + } + + @Test + public void testFunctionCutMerge3() { + testRewriteMerge(TEST_NAME4, false); + } + + @Test + public void testFunctionCutMerge4() { + //note: this test primarily checks for result correctness + //(prevent too eager merge of functions) + testRewriteMerge(TEST_NAME5, true); } - private void testRewriteMerge(String testname) + private void testRewriteMerge(String testname, boolean expectedMerge) { TestConfiguration config = getTestConfiguration(testname); loadTestConfiguration(config); @@ -73,7 +97,7 @@ public class RewriteMergeBlocksTest extends AutomatedTestBase HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R"); HashMap<CellIndex, Double> rfile = readRMatrixFromFS("R"); TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); - Assert.assertTrue(testname.equals(TEST_NAME1) == + Assert.assertTrue(expectedMerge == heavyHittersContainsSubString("mmchain")); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/6b4eaa6b/src/test/scripts/functions/misc/RewriteMergeFunctionCut.dml ---------------------------------------------------------------------- diff --git 
a/src/test/scripts/functions/misc/RewriteMergeFunctionCut.dml b/src/test/scripts/functions/misc/RewriteMergeFunctionCut.dml index 9e247a9..372896e 100644 --- a/src/test/scripts/functions/misc/RewriteMergeFunctionCut.dml +++ b/src/test/scripts/functions/misc/RewriteMergeFunctionCut.dml @@ -32,7 +32,7 @@ P = matrix(0.7, 600, 2); Q = P[,1:1] * (X %*% ssX_V); Y = X + 2; -Y = printAndAssign(X); +Y2 = printAndAssign(X); R = t(X) %*% (Q - P[,1:1] * rowSums(Q)); write(R, $1); http://git-wip-us.apache.org/repos/asf/systemml/blob/6b4eaa6b/src/test/scripts/functions/misc/RewriteMergeFunctionCut2.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteMergeFunctionCut2.R b/src/test/scripts/functions/misc/RewriteMergeFunctionCut2.R new file mode 100644 index 0000000..cd91748 --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteMergeFunctionCut2.R @@ -0,0 +1,36 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + + +args <- commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") + +X = matrix(0.5, 600, 10); +ssX_V = matrix(0.9, 10, 1); +P = matrix(0.7, 600, 2); + +Q = P[,1:1] * (X %*% ssX_V); +X = X; +R = t(X) %*% (Q - P[,1:1] * rowSums(Q)); + +writeMM(as(R, "CsparseMatrix"), paste(args[1], "R", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/6b4eaa6b/src/test/scripts/functions/misc/RewriteMergeFunctionCut2.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteMergeFunctionCut2.dml b/src/test/scripts/functions/misc/RewriteMergeFunctionCut2.dml new file mode 100644 index 0000000..db12015 --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteMergeFunctionCut2.dml @@ -0,0 +1,38 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +printAndAssign = function(Matrix[Double] X) return (Matrix[Double] Y) { + if( sum(X) > 0 ) + print("sum(X) = " + sum(X)); + Y = X; +} + + +X = matrix(0.5, 600, 10); +ssX_V = matrix(0.9, 10, 1); +P = matrix(0.7, 600, 2); + +Q = P[,1:1] * (X %*% ssX_V); +Y = X + 2; +X = printAndAssign(X); +R = t(X) %*% (Q - P[,1:1] * rowSums(Q)); + +write(R, $1); http://git-wip-us.apache.org/repos/asf/systemml/blob/6b4eaa6b/src/test/scripts/functions/misc/RewriteMergeFunctionCut3.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteMergeFunctionCut3.R b/src/test/scripts/functions/misc/RewriteMergeFunctionCut3.R new file mode 100644 index 0000000..cd91748 --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteMergeFunctionCut3.R @@ -0,0 +1,36 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + + +args <- commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") + +X = matrix(0.5, 600, 10); +ssX_V = matrix(0.9, 10, 1); +P = matrix(0.7, 600, 2); + +Q = P[,1:1] * (X %*% ssX_V); +X = X; +R = t(X) %*% (Q - P[,1:1] * rowSums(Q)); + +writeMM(as(R, "CsparseMatrix"), paste(args[1], "R", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/6b4eaa6b/src/test/scripts/functions/misc/RewriteMergeFunctionCut3.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteMergeFunctionCut3.dml b/src/test/scripts/functions/misc/RewriteMergeFunctionCut3.dml new file mode 100644 index 0000000..a48c23a --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteMergeFunctionCut3.dml @@ -0,0 +1,39 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +printAndAssign = function(Matrix[Double] X) return (Matrix[Double] Y) { + if( sum(X) > 0 ) + print("sum(X) = " + sum(X)); + Y = X; +} + + +X = matrix(0.5, 600, 10); +ssX_V = matrix(0.9, 10, 1); +P = matrix(0.7, 600, 2); + +Q = P[,1:1] * (X %*% ssX_V); +Y = X + 2; +Y = printAndAssign(X); +Y = printAndAssign(X); +R = t(X) %*% (Q - P[,1:1] * rowSums(Q)); + +write(R, $1); http://git-wip-us.apache.org/repos/asf/systemml/blob/6b4eaa6b/src/test/scripts/functions/misc/RewriteMergeFunctionCut4.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteMergeFunctionCut4.R b/src/test/scripts/functions/misc/RewriteMergeFunctionCut4.R new file mode 100644 index 0000000..cd91748 --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteMergeFunctionCut4.R @@ -0,0 +1,36 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + + +args <- commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") + +X = matrix(0.5, 600, 10); +ssX_V = matrix(0.9, 10, 1); +P = matrix(0.7, 600, 2); + +Q = P[,1:1] * (X %*% ssX_V); +X = X; +R = t(X) %*% (Q - P[,1:1] * rowSums(Q)); + +writeMM(as(R, "CsparseMatrix"), paste(args[1], "R", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/6b4eaa6b/src/test/scripts/functions/misc/RewriteMergeFunctionCut4.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteMergeFunctionCut4.dml b/src/test/scripts/functions/misc/RewriteMergeFunctionCut4.dml new file mode 100644 index 0000000..99d6f39 --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteMergeFunctionCut4.dml @@ -0,0 +1,43 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +printAndAssign = function(Matrix[Double] X) return (Matrix[Double] Y) { + if( sum(X) > 0 ) + print("sum(X) = " + sum(X)); + Y = X/2; +} + + +X = matrix(0.5, 600, 10); +while(FALSE){} + +Y = X; +# if the following function is mistakenly merged with +# the previous block, this would create incorrect results +X = printAndAssign(X); +X = Y; + +ssX_V = matrix(0.9, 10, 1); +P = matrix(0.7, 600, 2); +Q = P[,1:1] * (X %*% ssX_V); +R = t(X) %*% (Q - P[,1:1] * rowSums(Q)); + +write(R, $1);
