Repository: systemml Updated Branches: refs/heads/master 5f580f02e -> 4d370a8a6
[SYSTEMML-2374] New simplification rewrite 'fold nary min/max ops'. This patch adds a new dynamic rewrite for folding nested binary or nary min/max operations into a single nary min/max operation. Due to limited support for broadcasting, this is a dynamic rewrite that is only applied if the dimensions of all involved matrix inputs match. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/8d320791 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/8d320791 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/8d320791 Branch: refs/heads/master Commit: 8d320791265321de38050e741308e3243ce89a7b Parents: 5f580f0 Author: Matthias Boehm <[email protected]> Authored: Thu Jun 7 22:23:05 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Thu Jun 7 22:23:05 2018 -0700 ---------------------------------------------------------------------- .../RewriteAlgebraicSimplificationDynamic.java | 57 ++++++++- .../functions/misc/RewriteFoldMinMaxTest.java | 118 +++++++++++++++++++ .../scripts/functions/misc/RewriteFoldMax.dml | 28 +++++ .../scripts/functions/misc/RewriteFoldMin.dml | 28 +++++ .../functions/misc/ZPackageSuite.java | 1 + 5 files changed, 231 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/8d320791/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java index 81c20e0..062da2f 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java @@ -22,6 +22,7 @@ 
package org.apache.sysml.hops.rewrite; import java.util.ArrayList; import java.util.HashMap; import java.util.LinkedHashMap; +import java.util.List; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -37,11 +38,13 @@ import org.apache.sysml.hops.Hop.OpOp1; import org.apache.sysml.hops.Hop.OpOp2; import org.apache.sysml.hops.Hop.OpOp3; import org.apache.sysml.hops.Hop.OpOp4; +import org.apache.sysml.hops.Hop.OpOpN; import org.apache.sysml.hops.Hop.ParamBuiltinOp; import org.apache.sysml.hops.Hop.ReOrgOp; import org.apache.sysml.hops.IndexingOp; import org.apache.sysml.hops.LeftIndexingOp; import org.apache.sysml.hops.LiteralOp; +import org.apache.sysml.hops.NaryOp; import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.hops.ParameterizedBuiltinOp; import org.apache.sysml.hops.QuaternaryOp; @@ -182,7 +185,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule hi = simplifyWeightedUnaryMM(hop, hi, i); //e.g., X*exp(U%*%t(V)) -> wumm(X, U, t(V), exp) hi = simplifyDotProductSum(hop, hi, i); //e.g., sum(v^2) -> t(v)%*%v if ncol(v)==1 hi = fuseSumSquared(hop, hi, i); //e.g., sum(X^2) -> sumSq(X), if ncol(X)>1 - hi = fuseAxpyBinaryOperationChain(hop, hi, i); //e.g., (X+s*Y) -> (X+*s Y), (X-s*Y) -> (X-*s Y) + hi = fuseAxpyBinaryOperationChain(hop, hi, i); //e.g., (X+s*Y) -> (X+*s Y), (X-s*Y) -> (X-*s Y) } hi = reorderMinusMatrixMult(hop, hi, i); //e.g., (-t(X))%*%y->-(t(X)%*%y), TODO size hi = simplifySumMatrixMult(hop, hi, i); //e.g., sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), if not dot product / wsloss @@ -191,6 +194,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule hi = simplifyNnzComputation(hop, hi, i); //e.g., sum(ppred(X,0,"!=")) -> literal(nnz(X)), if nnz known hi = simplifyNrowNcolComputation(hop, hi, i); //e.g., nrow(X) -> literal(nrow(X)), if nrow known to remove data dependency hi = simplifyTableSeqExpand(hop, hi, i); //e.g., table(seq(1,nrow(v)), v, 
nrow(v), m) -> rexpand(v, max=m, dir=row, ignore=false, cast=true) + if( OptimizerUtils.ALLOW_OPERATOR_FUSION ) + foldMultipleMinMaxOperations(hi); //e.g., min(X,min(min(3,7),Y)) -> min(X,3,7,Y) //process childs recursively after rewrites (to investigate pattern newly created by rewrites) if( !descendFirst ) @@ -2584,4 +2589,54 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule return hi; } + + private static Hop foldMultipleMinMaxOperations(Hop hi) + { + if( (HopRewriteUtils.isBinary(hi, OpOp2.MIN, OpOp2.MAX) + || HopRewriteUtils.isNary(hi, OpOpN.MIN, OpOpN.MAX)) + && !OptimizerUtils.isHadoopExecutionMode() ) + { + OpOp2 bop = (hi instanceof BinaryOp) ? ((BinaryOp)hi).getOp() : + OpOp2.valueOf(((NaryOp)hi).getOp().name()); + OpOpN nop = (hi instanceof NaryOp) ? ((NaryOp)hi).getOp() : + OpOpN.valueOf(((BinaryOp)hi).getOp().name()); + + boolean converged = false; + while( !converged ) { + //get first matching min/max + Hop first = hi.getInput().stream() + .filter(h -> HopRewriteUtils.isBinary(h, bop) || HopRewriteUtils.isNary(h, nop)) + .findFirst().orElse(null); + + //replace current op with new nary min/max + final Hop lhi = hi; + if( first != null && first.getParent().size()==1 + && first.getInput().stream().allMatch(c -> c.getDataType()==DataType.SCALAR + || HopRewriteUtils.isEqualSize(lhi, c))) { + //construct new list of inputs (in original order) + ArrayList<Hop> linputs = new ArrayList<>(); + for(Hop in : hi.getInput()) + if( in == first ) + linputs.addAll(first.getInput()); + else + linputs.add(in); + Hop hnew = HopRewriteUtils.createNary(nop, linputs.toArray(new Hop[0])); + //clear dangling references + HopRewriteUtils.removeAllChildReferences(hi); + HopRewriteUtils.removeAllChildReferences(first); + //rewire all parents (avoid anomalies with refs to hi) + List<Hop> parents = new ArrayList<>(hi.getParent()); + for( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, hnew); + hi = hnew; + LOG.debug("Applied 
foldMultipleMinMaxOperations (line "+hi.getBeginLine()+")."); + } + else { + converged = true; + } + } + } + + return hi; + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/8d320791/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFoldMinMaxTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFoldMinMaxTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFoldMinMaxTest.java new file mode 100644 index 0000000..65c2a3e --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFoldMinMaxTest.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.test.integration.functions.misc; + +import org.junit.Assert; +import org.junit.Test; +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.lops.LopProperties.ExecType; +import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.utils.TestUtils; +import org.apache.sysml.utils.Statistics; + +public class RewriteFoldMinMaxTest extends AutomatedTestBase +{ + private static final String TEST_NAME1 = "RewriteFoldMin"; + private static final String TEST_NAME2 = "RewriteFoldMax"; + + private static final String TEST_DIR = "functions/misc/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteFoldMinMaxTest.class.getSimpleName() + "/"; + + private static final int rows = 1932; + private static final int cols = 14; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) ); + addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) ); + } + + @Test + public void testRewriteFoldMinNoRewrite() { + testRewriteFoldMinMax( TEST_NAME1, false, ExecType.CP ); + } + + @Test + public void testRewriteFoldMinRewrite() { + testRewriteFoldMinMax( TEST_NAME1, true, ExecType.CP ); + } + + @Test + public void testRewriteFoldMaxNoRewrite() { + testRewriteFoldMinMax( TEST_NAME2, false, ExecType.CP ); + } + + @Test + public void testRewriteFoldMaxRewrite() { + testRewriteFoldMinMax( TEST_NAME2, true, ExecType.CP ); + } + + private void testRewriteFoldMinMax( String testname, boolean rewrites, ExecType et ) + { + RUNTIME_PLATFORM platformOld = rtplatform; + switch( et ){ + case MR: rtplatform = 
RUNTIME_PLATFORM.HADOOP; break; + case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break; + default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break; + } + + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + if( rtplatform == RUNTIME_PLATFORM.SPARK || rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK ) + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + + try { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new String[]{ "-stats", "-args", String.valueOf(rows), + String.valueOf(cols), output("R") }; + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + //run performance tests + runTest(true, false, null, -1); + + //compare matrices + Double ret = readDMLMatrixFromHDFS("R").get(new CellIndex(1,1)); + Assert.assertEquals("Wrong result", new Double(5*rows*cols), ret); + + //check for applied rewrites + if( rewrites ) { + Assert.assertTrue(!heavyHittersContainsString("min") && !heavyHittersContainsString("max") + && (!testname.equals(TEST_NAME1) || Statistics.getCPHeavyHitterCount("nmin") == 1) + && (!testname.equals(TEST_NAME2) || Statistics.getCPHeavyHitterCount("nmax") == 1)); + } + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + rtplatform = platformOld; + } + } +} + http://git-wip-us.apache.org/repos/asf/systemml/blob/8d320791/src/test/scripts/functions/misc/RewriteFoldMax.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteFoldMax.dml b/src/test/scripts/functions/misc/RewriteFoldMax.dml new file mode 100644 index 0000000..c5117c8 --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteFoldMax.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# 
Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(1, $1, $2) +while(FALSE){} +Y = max(X-7,max(max(X-5,-7),5)) +while(FALSE){} +R = as.matrix(sum(Y)) + +write(R, $3); http://git-wip-us.apache.org/repos/asf/systemml/blob/8d320791/src/test/scripts/functions/misc/RewriteFoldMin.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteFoldMin.dml b/src/test/scripts/functions/misc/RewriteFoldMin.dml new file mode 100644 index 0000000..7919d9b --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteFoldMin.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(1, $1, $2) +while(FALSE){} +Y = min(X+7,min(min(X+5,7),5)) +while(FALSE){} +R = as.matrix(sum(Y)) + +write(R, $3); http://git-wip-us.apache.org/repos/asf/systemml/blob/8d320791/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java ---------------------------------------------------------------------- diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java index e2c7bf1..75e9970 100644 --- a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java +++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java @@ -61,6 +61,7 @@ import org.junit.runners.Suite; RewriteCTableToRExpandTest.class, RewriteElementwiseMultChainOptimizationTest.class, RewriteEliminateAggregatesTest.class, + RewriteFoldMinMaxTest.class, RewriteFoldRCBindTest.class, RewriteFuseBinaryOpChainTest.class, RewriteFusedRandTest.class,
