[SYSTEMML-1659] New simplification rewrite 'aggregate elimination'. This new static algebraic simplification rewrite removes unnecessary row-wise or column-wise aggregates that are fed directly into a compatible full (row-and-column) aggregate. For example, we now rewrite sum(rowSums(X)), as it appears in nn-cross_entropy_loss::forward, to sum(X).
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/50d211ba Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/50d211ba Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/50d211ba Branch: refs/heads/master Commit: 50d211baa91e6a74b32cd8c1780758608d33c7c8 Parents: a68648d Author: Matthias Boehm <[email protected]> Authored: Fri Jun 2 21:47:59 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sat Jun 3 10:48:31 2017 -0700 ---------------------------------------------------------------------- .../RewriteAlgebraicSimplificationStatic.java | 28 ++++ .../test/integration/AutomatedTestBase.java | 2 +- .../misc/RewriteEliminateAggregatesTest.java | 136 +++++++++++++++++++ .../functions/misc/RewriteEliminateAggregate.R | 41 ++++++ .../misc/RewriteEliminateAggregate.dml | 41 ++++++ .../functions/misc/ZPackageSuite.java | 1 + 6 files changed, 248 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/50d211ba/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index a3db317..74f5488 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -148,6 +148,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule hi = simplifyDistributiveBinaryOperation(hop, hi, i);//e.g., (X-Y*X) -> (1-Y)*X hi = simplifyBushyBinaryOperation(hop, hi, i); //e.g., (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v) hi = 
simplifyUnaryAggReorgOperation(hop, hi, i); //e.g., sum(t(X)) -> sum(X) + hi = removeUnnecessaryAggregates(hi); //e.g., sum(rowSums(X)) -> sum(X) hi = simplifyBinaryMatrixScalarOperation(hop, hi, i);//e.g., as.scalar(X*s) -> as.scalar(X)*s; hi = pushdownUnaryAggTransposeOperation(hop, hi, i); //e.g., colSums(t(X)) -> t(rowSums(X)) hi = pushdownCSETransposeScalarOperation(hop, hi, i);//e.g., a=t(X), b=t(X^2) -> a=t(X), b=t(X)^2 for CSE t(X) @@ -817,6 +818,33 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule return hi; } + private Hop removeUnnecessaryAggregates(Hop hi) + { + //sum(rowSums(X)) -> sum(X), sum(colSums(X)) -> sum(X) + //min(rowMins(X)) -> min(X), min(colMins(X)) -> min(X) + //max(rowMaxs(X)) -> max(X), max(colMaxs(X)) -> max(X) + //sum(rowSums(X^2)) -> sum(X), sum(colSums(X^2)) -> sum(X) + if( hi instanceof AggUnaryOp && hi.getInput().get(0) instanceof AggUnaryOp + && ((AggUnaryOp)hi).getDirection()==Direction.RowCol + && hi.getInput().get(0).getParent().size()==1 ) + { + AggUnaryOp au1 = (AggUnaryOp) hi; + AggUnaryOp au2 = (AggUnaryOp) hi.getInput().get(0); + if( (au1.getOp()==AggOp.SUM && (au2.getOp()==AggOp.SUM || au2.getOp()==AggOp.SUM_SQ)) + || (au1.getOp()==AggOp.MIN && au2.getOp()==AggOp.MIN) + || (au1.getOp()==AggOp.MAX && au2.getOp()==AggOp.MAX) ) + { + Hop input = au2.getInput().get(0); + HopRewriteUtils.removeAllChildReferences(au2); + HopRewriteUtils.replaceChildReference(au1, au2, input); + + LOG.debug("Applied removeUnnecessaryAggregates (line "+hi.getBeginLine()+")."); + } + } + + return hi; + } + private Hop simplifyBinaryMatrixScalarOperation( Hop parent, Hop hi, int pos ) throws HopsException { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/50d211ba/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java 
b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java index 0e56655..7b93211 100644 --- a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java +++ b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java @@ -1818,7 +1818,7 @@ public abstract class AutomatedTestBase for( String opcode : Statistics.getCPHeavyHitterOpCodes()) for( String s : str ) if(opcode.contains(s)) - return true; + return true; return false; } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/50d211ba/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEliminateAggregatesTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEliminateAggregatesTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEliminateAggregatesTest.java new file mode 100644 index 0000000..741ef31 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEliminateAggregatesTest.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.test.integration.functions.misc; + +import java.util.HashMap; + +import org.junit.Assert; +import org.junit.Test; +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.utils.TestUtils; + +public class RewriteEliminateAggregatesTest extends AutomatedTestBase +{ + private static final String TEST_NAME = "RewriteEliminateAggregate"; + private static final String TEST_DIR = "functions/misc/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteEliminateAggregatesTest.class.getSimpleName() + "/"; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration( TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] { "R" }) ); + } + + @Test + public void testEliminateSumSumNoRewrite() { + testRewriteEliminateAggregate(1, false); + } + + @Test + public void testEliminateMinMinNoRewrite() { + testRewriteEliminateAggregate(2, false); + } + + @Test + public void testEliminateMaxMaxNoRewrite() { + testRewriteEliminateAggregate(3, false); + } + + @Test + public void testEliminateSumSqSumNoRewrite() { + testRewriteEliminateAggregate(4, false); + } + + @Test + public void testEliminateMinSumNoRewrite() { + testRewriteEliminateAggregate(5, false); + } + + @Test + public void testEliminateSumSumRewrite() { + testRewriteEliminateAggregate(1, true); + } + + @Test + public void testEliminateMinMinRewrite() { + testRewriteEliminateAggregate(2, true); + } + + @Test + public void testEliminateMaxMaxRewrite() { + testRewriteEliminateAggregate(3, true); + } + + @Test + public void testEliminateSumSqSumRewrite() { + testRewriteEliminateAggregate(4, true); + } + + @Test + public void testEliminateMinSumRewrite() { + testRewriteEliminateAggregate(5, true); + } + + private void 
testRewriteEliminateAggregate(int type, boolean rewrites) + { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + + try + { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[]{ "-stats","-args", + input("A"), String.valueOf(type), output("Scalar") }; + + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(inputDir(), String.valueOf(type), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + //generate actual dataset + double[][] A = getRandomMatrix(123, 12, 0, 1, 0.9, 7); + writeInputMatrixWithMTD("A", A, true); + + //run test + runTest(true, false, null, -1); + runRScript(true); + + //compare scalars + HashMap<CellIndex, Double> dmlfile = readDMLScalarFromHDFS("Scalar"); + HashMap<CellIndex, Double> rfile = readRScalarFromFS("Scalar"); + TestUtils.compareScalars(dmlfile.toString(), rfile.toString()); + + //check for applied rewrites + if( rewrites ) { + Assert.assertEquals(type==5, + heavyHittersContainsSubString("uar", "uac")); + } + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/50d211ba/src/test/scripts/functions/misc/RewriteEliminateAggregate.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteEliminateAggregate.R b/src/test/scripts/functions/misc/RewriteEliminateAggregate.R new file mode 100644 index 0000000..6848443 --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteEliminateAggregate.R @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") + +A = as.matrix(readMM(paste(args[1], "A.mtx", sep=""))) +type = args[2] + +if( type==1 ) { + agg = sum(rowSums(A)); +} else if( type==2 ) { + agg = min(rowMins(A)); +} else if( type==3 ) { + agg = max(rowMaxs(A)); +} else if( type==4 ) { + agg = sum(rowSums(A^2)); +} else if( type==5 ) { + agg = sum(rowMins(A)); +} + +write(agg, paste(args[3], "Scalar",sep="")) + http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/50d211ba/src/test/scripts/functions/misc/RewriteEliminateAggregate.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteEliminateAggregate.dml b/src/test/scripts/functions/misc/RewriteEliminateAggregate.dml new file mode 100644 index 0000000..e00199d --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteEliminateAggregate.dml @@ -0,0 +1,41 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +A = read($1); +type = $2; + +if( type==1 ) { + agg = sum(rowSums(A)); +} +else if( type==2 ) { + agg = min(rowMins(A)); +} +else if( type==3 ) { + agg = max(rowMaxs(A)); +} +else if( type==4 ) { + agg = sum(rowSums(A^2)); +} +else if( type==5 ) { + agg = sum(rowMins(A)); +} + +write(agg, $3); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/50d211ba/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java ---------------------------------------------------------------------- diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java index 8a06322..7da786d 100644 --- a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java +++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java @@ -49,6 +49,7 @@ import org.junit.runners.Suite; ReadAfterWriteTest.class, RewriteCSETransposeScalarTest.class, RewriteCTableToRExpandTest.class, + RewriteEliminateAggregatesTest.class, RewriteFusedRandTest.class, RewriteLoopVectorization.class, RewriteMatrixMultChainOptTest.class,
