[SYSTEMML-2507] New rewrites for cumulative aggregate patterns. This patch adds the following simplification rewrites as well as related tests: (a) X * cumsum(diag(matrix(1,nrow(X),1))) -> lower.tri(X), if X is square; (b) colSums(cumsum(X)) -> colSums(X*seq(nrow(X),1)), and likewise for sum(cumsum(X)); (c) rev(cumsum(rev(X))) -> X + colSums(X) - cumsum(X).
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/9a1f64b4 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/9a1f64b4 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/9a1f64b4 Branch: refs/heads/master Commit: 9a1f64b42c177a82a98716ad9ef34d4d266178d2 Parents: b96807b Author: Matthias Boehm <[email protected]> Authored: Tue Dec 11 20:10:23 2018 +0100 Committer: Matthias Boehm <[email protected]> Committed: Tue Dec 11 20:10:46 2018 +0100 ---------------------------------------------------------------------- .../RewriteAlgebraicSimplificationDynamic.java | 33 ++++- .../RewriteAlgebraicSimplificationStatic.java | 45 +++++++ .../hops/rewrite/RewriteGPUSpecificOps.java | 26 ++-- .../misc/RewriteCumulativeAggregatesTest.java | 126 +++++++++++++++++++ .../misc/RewriteCumulativeAggregates.R | 43 +++++++ .../misc/RewriteCumulativeAggregates.dml | 49 ++++++++ 6 files changed, 306 insertions(+), 16 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/9a1f64b4/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java index 36864aa..9556181 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java @@ -175,6 +175,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule hi = simplifyMatrixMultDiag(hop, hi, i); //e.g., diag(X)%*%Y -> X*Y, if ncol(Y)==1 / -> Y*X if ncol(Y)>1 hi = simplifyDiagMatrixMult(hop, hi, i); //e.g., diag(X%*%Y)->rowSums(X*t(Y)); if col vector hi = 
simplifySumDiagToTrace(hi); //e.g., sum(diag(X)) -> trace(X); if col vector + hi = simplifyLowerTriExtraction(hop, hi, i); //e.g., X * cumsum(diag(matrix(1,nrow(X),1))) -> lower.tri hi = pushdownBinaryOperationOnDiag(hop, hi, i); //e.g., diag(X)*7 -> diag(X*7); if col vector hi = pushdownSumOnAdditiveBinary(hop, hi, i); //e.g., sum(A+B) -> sum(A)+sum(B); if dims(A)==dims(B) if(OptimizerUtils.ALLOW_OPERATOR_FUSION) { @@ -1046,7 +1047,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule if( hi instanceof AggUnaryOp ) { AggUnaryOp au = (AggUnaryOp) hi; - if( au.getOp()==AggOp.SUM && au.getDirection()==Direction.RowCol ) //sum + if( au.getOp()==AggOp.SUM && au.getDirection()==Direction.RowCol ) //sum { Hop hi2 = au.getInput().get(0); if( hi2 instanceof ReorgOp && ((ReorgOp)hi2).getOp()==ReOrgOp.DIAG && hi2.getDim2()==1 ) //diagM2V @@ -1054,7 +1055,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule Hop hi3 = hi2.getInput().get(0); //remove diag operator - HopRewriteUtils.replaceChildReference(au, hi2, hi3, 0); + HopRewriteUtils.replaceChildReference(au, hi2, hi3, 0); HopRewriteUtils.cleanupUnreferenced(hi2); //change sum to trace @@ -1063,12 +1064,38 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule LOG.debug("Applied simplifySumDiagToTrace"); } } - } return hi; } + private static Hop simplifyLowerTriExtraction(Hop parent, Hop hi, int pos) { + //pattern: X * cumsum(diag(matrix(1,nrow(X),1))) -> lower.tri(X, diag=TRUE, values=TRUE), if X is square (mask only on the right) + if( HopRewriteUtils.isBinary(hi, OpOp2.MULT) + && hi.getDim1() == hi.getDim2() && hi.getDim1() > 1 ) { + Hop left = hi.getInput().get(0); + Hop right = hi.getInput().get(1); + //X must have the mask's exact size: rules out scalar or broadcast vector X, and diag(M) masks that yield a column vector + if( HopRewriteUtils.isEqualSize(left, right) && HopRewriteUtils.isUnary(right, OpOp1.CUMSUM) && right.getParent().size()==1 + && HopRewriteUtils.isReorg(right.getInput().get(0), ReOrgOp.DIAG) + && HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput().get(0).getInput().get(0), 1d)) + { + LinkedHashMap<String,Hop> args = new
LinkedHashMap<>(); + args.put("target", left); + args.put("diag", new LiteralOp(true)); + args.put("values", new LiteralOp(true)); + Hop hnew = HopRewriteUtils.createParameterizedBuiltinOp( + left, args, ParamBuiltinOp.LOWER_TRI); + HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); + HopRewriteUtils.cleanupUnreferenced(hi, right); + //detach hi and the cumsum mask only once unreferenced, so other consumers of hi keep a valid DAG + hi = hnew; + LOG.debug("Applied simplifyLowerTriExtraction"); + } + } + return hi; + } + @SuppressWarnings("unchecked") private static Hop pushdownBinaryOperationOnDiag(Hop parent, Hop hi, int pos) { http://git-wip-us.apache.org/repos/asf/systemml/blob/9a1f64b4/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index 62a5d4f..9a3956c 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -183,6 +183,9 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule } hi = simplifyOuterSeqExpand(hop, hi, i); //e.g., outer(v, seq(1,m), "==") -> rexpand(v, max=m, dir=row, ignore=true, cast=false) hi = simplifyBinaryComparisonChain(hop, hi, i); //e.g., outer(v1,v2,"==")==1 -> outer(v1,v2,"=="), outer(v1,v2,"==")==0 -> outer(v1,v2,"!="), + hi = simplifyCumsumColOrFullAggregates(hi); //e.g., colSums(cumsum(X)) -> colSums(X*seq(nrow(X),1)) + hi = simplifyCumsumReverse(hop, hi, i); //e.g., rev(cumsum(rev(X))) -> X + colSums(X) - cumsum(X) + //hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X)) @@ -1844,6 +1847,48 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule return hi; } + private static Hop
simplifyCumsumColOrFullAggregates(Hop hi) { + //pattern: colSums(cumsum(X)) -> colSums(X*seq(nrow(X),1)), and likewise sum(cumsum(X)) -> sum(X*seq(nrow(X),1)) + if( (HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM, Direction.Col) + || HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM, Direction.RowCol)) + && HopRewriteUtils.isUnary(hi.getInput().get(0), OpOp1.CUMSUM) + && hi.getInput().get(0).getParent().size()==1) + { + Hop cumsumX = hi.getInput().get(0); + Hop X = cumsumX.getInput().get(0); + Hop mult = HopRewriteUtils.createBinary(X, + HopRewriteUtils.createSeqDataGenOp(X, false), OpOp2.MULT); + HopRewriteUtils.replaceChildReference(hi, cumsumX, mult); + HopRewriteUtils.removeAllChildReferences(cumsumX); + LOG.debug("Applied simplifyCumsumColOrFullAggregates (line "+hi.getBeginLine()+")"); + } + return hi; + } + + private static Hop simplifyCumsumReverse(Hop parent, Hop hi, int pos) { + //pattern: rev(cumsum(rev(X))) -> X + colSums(X) - cumsum(X) + if( HopRewriteUtils.isReorg(hi, ReOrgOp.REV) + && HopRewriteUtils.isUnary(hi.getInput().get(0), OpOp1.CUMSUM) + && hi.getInput().get(0).getParent().size()==1 + && HopRewriteUtils.isReorg(hi.getInput().get(0).getInput().get(0), ReOrgOp.REV) + && hi.getInput().get(0).getInput().get(0).getParent().size()==1) + { + Hop cumsumX = hi.getInput().get(0); + Hop revX = cumsumX.getInput().get(0); + Hop X = revX.getInput().get(0); + Hop plus = HopRewriteUtils.createBinary(X, HopRewriteUtils + .createAggUnaryOp(X, AggOp.SUM, Direction.Col), OpOp2.PLUS); + Hop minus = HopRewriteUtils.createBinary(plus, + HopRewriteUtils.createUnary(X, OpOp1.CUMSUM), OpOp2.MINUS); + HopRewriteUtils.replaceChildReference(parent, hi, minus, pos); + HopRewriteUtils.cleanupUnreferenced(hi, cumsumX, revX); + + hi = minus; + LOG.debug("Applied simplifyCumsumReverse (line "+hi.getBeginLine()+")"); + } + return hi; + } + /** * NOTE: currently disabled since this rewrite is INVALID in the * presence of NaNs (because (NaN!=NaN) is true).
http://git-wip-us.apache.org/repos/asf/systemml/blob/9a1f64b4/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java index ab40d7b..1d87c09 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java @@ -176,19 +176,19 @@ public class RewriteGPUSpecificOps extends HopRewriteRuleWithPatternMatcher { // norm = bias_multiply(centered, cache_inv_var) # shape (N, C*Hin*Win) // # Compute gradients during training // dgamma = util::channel_sums(dout*norm, C, Hin, Win) - private static final HopDagPatternMatcher _batchNormDGamma; - static { - _batchNormDGamma = util_channel_sums( - mult( leaf("dout", MATRIX).fitsOnGPU(3), - bias_multiply(bias_add(leaf("X", MATRIX), unaryMinus(leaf("ema_mean", MATRIX))), - leaf("ema_var", MATRIX))), leaf("C", SCALAR), leaf("HW", SCALAR)); - } - private static final Function<Hop, Hop> _batchNormDGammaReplacer = hi -> { - LOG.debug("Applied batchNormDGamma rewrite."); - Hop newHop = HopRewriteUtils.createDnnOp(_batchNormDGamma, OpOpDnn.BATCH_NORM2D_BACKWARD_DGAMMA, - "ema_mean", "dout", "X", "ema_var"); - return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop); - }; +// private static final HopDagPatternMatcher _batchNormDGamma; +// static { +// _batchNormDGamma = util_channel_sums( +// mult( leaf("dout", MATRIX).fitsOnGPU(3), +// bias_multiply(bias_add(leaf("X", MATRIX), unaryMinus(leaf("ema_mean", MATRIX))), +// leaf("ema_var", MATRIX))), leaf("C", SCALAR), leaf("HW", SCALAR)); +// } +// private static final Function<Hop, Hop> _batchNormDGammaReplacer = hi -> { +// LOG.debug("Applied batchNormDGamma rewrite."); +// Hop newHop = HopRewriteUtils.createDnnOp(_batchNormDGamma, 
OpOpDnn.BATCH_NORM2D_BACKWARD_DGAMMA, +// "ema_mean", "dout", "X", "ema_var"); +// return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop); +// }; // Pattern 3: private static final HopDagPatternMatcher _batchNormTest; http://git-wip-us.apache.org/repos/asf/systemml/blob/9a1f64b4/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCumulativeAggregatesTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCumulativeAggregatesTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCumulativeAggregatesTest.java new file mode 100644 index 0000000..da13502 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCumulativeAggregatesTest.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.test.integration.functions.misc; + +import java.util.HashMap; + +import org.junit.Assert; +import org.junit.Test; +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.utils.TestUtils; + +public class RewriteCumulativeAggregatesTest extends AutomatedTestBase +{ + private static final String TEST_NAME = "RewriteCumulativeAggregates"; + private static final String TEST_DIR = "functions/misc/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteCumulativeAggregatesTest.class.getSimpleName() + "/"; + + private static final int rows = 1234; + private static final int cols = 7; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration( TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] { "R" }) ); + } + + @Test + public void testCumAggRewrite1False() { + testCumAggRewrite(1, false); + } + + @Test + public void testCumAggRewrite1True() { + testCumAggRewrite(1, true); + } + + @Test + public void testCumAggRewrite2False() { + testCumAggRewrite(2, false); + } + + @Test + public void testCumAggRewrite2True() { + testCumAggRewrite(2, true); + } + + @Test + public void testCumAggRewrite3False() { + testCumAggRewrite(3, false); + } + + @Test + public void testCumAggRewrite3True() { + testCumAggRewrite(3, true); + } + + @Test + public void testCumAggRewrite4False() { + testCumAggRewrite(4, false); + } + + @Test + public void testCumAggRewrite4True() { + testCumAggRewrite(4, true); + } + + private void testCumAggRewrite(int num, boolean rewrites) + { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + 
String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[]{ "-stats", "-args", + input("A"), String.valueOf(num), output("R") }; + rCmd = getRCmd(inputDir(), String.valueOf(num), expectedDir()); + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + + //generate input data + double[][] A = getRandomMatrix((num==4)?1:rows, + (num==1)?rows:cols, -1, 1, 0.9, 7); + writeInputMatrixWithMTD("A", A, true); + + //run performance tests + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R"); + HashMap<CellIndex, Double> rfile = readRMatrixFromFS("R"); + TestUtils.compareMatrices(dmlfile, rfile, 1e-7, "Stat-DML", "Stat-R"); + + //check applied rewrites + if( rewrites ) + Assert.assertTrue(!heavyHittersContainsString((num==2) ? "rev" : "ucumk+")); + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/9a1f64b4/src/test/scripts/functions/misc/RewriteCumulativeAggregates.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteCumulativeAggregates.R b/src/test/scripts/functions/misc/RewriteCumulativeAggregates.R new file mode 100644 index 0000000..f8a8576 --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteCumulativeAggregates.R @@ -0,0 +1,43 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +args <- commandArgs(TRUE) +options(digits=22) +library("Matrix") + +X = as.matrix(readMM(paste(args[1], "A.mtx", sep=""))) +num = as.integer(args[2]); + +#note: cumsum and rev only over vectors +if( num == 1 ) { + R = lower.tri(X,diag=TRUE) * X; +} else if( num == 2 ) { + A = X[seq(nrow(X),1),] + R = apply(A, 2, cumsum); + R = R[seq(nrow(X),1),] +} else if( num == 3 ) { + R = t(as.matrix(colSums(apply(X, 2, cumsum)))); +} else if( num == 4 ) { + R = X; +} + +writeMM(as(R, "CsparseMatrix"), paste(args[3], "R", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/9a1f64b4/src/test/scripts/functions/misc/RewriteCumulativeAggregates.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteCumulativeAggregates.dml b/src/test/scripts/functions/misc/RewriteCumulativeAggregates.dml new file mode 100644 index 0000000..f4c3486 --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteCumulativeAggregates.dml @@ -0,0 +1,49 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +foo = function( Matrix[Double] A ) return( Matrix[Double] B ) +{ + for( i in 1:1 ) { + continue = TRUE; + if( sum(A)<0 ) { + continue = FALSE; + } + iter = 0; + if( continue ) { + iter = iter+1; + } + B = A+iter; + } +} + +X = read($1); + +if( $2 == 1 ) + R = X * cumsum(diag(matrix(1,nrow(X),1))); +else if( $2 == 2 ) + R = rev(cumsum(rev(X))); +else if( $2 == 3 ) + R = colSums(cumsum(X)); +else if( $2 == 4 ) + R = cumsum(X); + +write(R, $3); \ No newline at end of file
