[SYSTEMML-766] Improved 'fuse axpy' rewrite (more patterns, no overlap) Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/973b8635 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/973b8635 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/973b8635
Branch: refs/heads/master Commit: 973b863579d7bf82505933d3d67fef4517c53eb3 Parents: b233b59 Author: Matthias Boehm <[email protected]> Authored: Wed Jul 20 22:34:46 2016 -0700 Committer: Matthias Boehm <[email protected]> Committed: Thu Jul 21 12:54:15 2016 -0700 ---------------------------------------------------------------------- .../sysml/hops/rewrite/HopRewriteUtils.java | 29 +++++++++ .../RewriteAlgebraicSimplificationDynamic.java | 65 ++++++++++++++++++++ .../RewriteAlgebraicSimplificationStatic.java | 41 ------------ .../misc/RewriteFuseBinaryOpChainTest.java | 40 ++++++++++-- .../misc/RewriteFuseBinaryOpChainTest3.R | 28 +++++++++ .../misc/RewriteFuseBinaryOpChainTest3.dml | 27 ++++++++ 6 files changed, 184 insertions(+), 46 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/973b8635/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index a5432f1..385a888 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -36,6 +36,7 @@ import org.apache.sysml.hops.Hop.DataOpTypes; import org.apache.sysml.hops.Hop.Direction; import org.apache.sysml.hops.Hop.FileFormatTypes; import org.apache.sysml.hops.Hop.OpOp2; +import org.apache.sysml.hops.Hop.OpOp3; import org.apache.sysml.hops.Hop.ParamBuiltinOp; import org.apache.sysml.hops.Hop.ReOrgOp; import org.apache.sysml.hops.Hop.VisitStatus; @@ -45,6 +46,7 @@ import org.apache.sysml.hops.LiteralOp; import org.apache.sysml.hops.MemoTable; import org.apache.sysml.hops.ParameterizedBuiltinOp; import org.apache.sysml.hops.ReorgOp; +import org.apache.sysml.hops.TernaryOp; import org.apache.sysml.hops.UnaryOp; import org.apache.sysml.hops.Hop.OpOp1; import org.apache.sysml.parser.DataExpression; @@ -644,6 +646,22 @@ public class HopRewriteUtils return datagen; } + /** + * + * @param mleft + * @param smid + * @param mright + * @param op + * @return + */ + public static TernaryOp createTernaryOp(Hop mleft, Hop smid, Hop mright, OpOp3 op) { + TernaryOp ternOp = new TernaryOp("tmp", DataType.MATRIX, ValueType.DOUBLE, op, mleft, smid, mright); + ternOp.setRowsInBlock(mleft.getRowsInBlock()); + ternOp.setColsInBlock(mleft.getColsInBlock()); + ternOp.refreshSizeInformation(); + return ternOp; + } + public static void setOutputBlocksizes( Hop hop, long brlen, long bclen ) { hop.setRowsInBlock( brlen ); @@ -878,6 +896,17 @@ public class HopRewriteUtils * @param hop * @return */ + public static boolean isScalarMatrixBinaryMult( Hop hop ) { + return hop instanceof BinaryOp && ((BinaryOp)hop).getOp()==OpOp2.MULT + && ((hop.getInput().get(0).getDataType()==DataType.SCALAR && hop.getInput().get(1).getDataType()==DataType.MATRIX) + || (hop.getInput().get(0).getDataType()==DataType.MATRIX && hop.getInput().get(1).getDataType()==DataType.SCALAR)); + } + + /** + * + * @param hop + * @return + */ public static boolean isBasic1NSequence(Hop hop) { boolean ret = false; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/973b8635/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java index 8205e83..dbde506 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java @@ -35,6 +35,7 @@ import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.DataGenMethod; import org.apache.sysml.hops.Hop.Direction; import org.apache.sysml.hops.Hop.OpOp1; +import org.apache.sysml.hops.Hop.OpOp3; import org.apache.sysml.hops.Hop.OpOp4; import org.apache.sysml.hops.Hop.ReOrgOp; import org.apache.sysml.hops.HopsException; @@ -174,6 +175,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule hi = simplifyWeightedUnaryMM(hop, hi, i); //e.g., X*exp(U%*%t(V)) -> wumm(X, U, t(V), exp) hi = simplifyDotProductSum(hop, hi, i); //e.g., sum(v^2) -> t(v)%*%v if ncol(v)==1 hi = fuseSumSquared(hop, hi, i); //e.g., sum(X^2) -> sumSq(X), if ncol(X)>1 + hi = fuseAxpyBinaryOperationChain(hop, hi, i); //e.g., (X+s*Y) -> (X+*s Y), (X-s*Y) -> (X-*s Y) hi = reorderMinusMatrixMult(hop, hi, i); //e.g., (-t(X))%*%y->-(t(X)%*%y), TODO size hi = simplifySumMatrixMult(hop, hi, i); //e.g., sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), if not dot product / wsloss hi = simplifyEmptyBinaryOperation(hop, hi, i); //e.g., X*Y -> matrix(0,nrow(X), ncol(X)) / X+Y->X / X-Y -> X @@ -2458,6 +2460,69 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule return hi; } + + /** + * + * @param parent + * @param hi + * @param pos + * @return + * @throws HopsException + */ + private Hop fuseAxpyBinaryOperationChain(Hop parent, Hop hi, int pos) + { + //patterns: (a) X + s*Y -> X +* sY, (b) s*Y+X -> X +* sY, (c) X - s*Y -> X -* sY + if( hi instanceof BinaryOp + && (((BinaryOp)hi).getOp()==OpOp2.PLUS || ((BinaryOp)hi).getOp()==OpOp2.MINUS) ) + { + BinaryOp bop = (BinaryOp) hi; + Hop left = bop.getInput().get(0); + Hop right = bop.getInput().get(1); + Hop ternop = null; + + //pattern (a) X + s*Y -> X +* sY + if( bop.getOp() == OpOp2.PLUS && left.getDataType()==DataType.MATRIX + && HopRewriteUtils.isScalarMatrixBinaryMult(right) + && right.getParent().size() == 1 ) //single consumer s*Y + { + Hop smid = right.getInput().get( (right.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1); + Hop mright = right.getInput().get( (right.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0); + ternop = HopRewriteUtils.createTernaryOp(left, smid, mright, OpOp3.PLUS_MULT); + LOG.debug("Applied fuseAxpyBinaryOperationChain1. (line " +hi.getBeginLine()+")"); + } + //pattern (b) s*Y + X -> X +* sY + else if( bop.getOp() == OpOp2.PLUS && right.getDataType()==DataType.MATRIX + && HopRewriteUtils.isScalarMatrixBinaryMult(left) + && left.getParent().size() == 1 //single consumer s*Y + && HopRewriteUtils.isEqualSize(left, right)) //correctness matrix-vector + { + Hop smid = left.getInput().get( (left.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1); + Hop mright = left.getInput().get( (left.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0); + ternop = HopRewriteUtils.createTernaryOp(right, smid, mright, OpOp3.PLUS_MULT); + LOG.debug("Applied fuseAxpyBinaryOperationChain2. (line " +hi.getBeginLine()+")"); + } + //pattern (c) X - s*Y -> X -* sY + else if( bop.getOp() == OpOp2.MINUS && left.getDataType()==DataType.MATRIX + && HopRewriteUtils.isScalarMatrixBinaryMult(right) + && right.getParent().size() == 1 ) //single consumer s*Y + { + Hop smid = right.getInput().get( (right.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1); + Hop mright = right.getInput().get( (right.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0); + ternop = HopRewriteUtils.createTernaryOp(left, smid, mright, OpOp3.MINUS_MULT); + LOG.debug("Applied fuseAxpyBinaryOperationChain3. (line " +hi.getBeginLine()+")"); + } + + //rewire parent-child operators if rewrite applied + if( ternop != null ) { + HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); + HopRewriteUtils.addChildReference(parent, ternop, pos); + hi = ternop; + } + } + + return hi; + } + /** * * @param parent http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/973b8635/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index 9ef2c05..ae9c073 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -162,7 +162,6 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule hi = fuseLogNzBinaryOperation(hop, hi, i); //e.g., ppred(X,0,"!=")*log(X,0.5) -> log_nz(X,0.5) hi = simplifyOuterSeqExpand(hop, hi, i); //e.g., outer(v, seq(1,m), "==") -> rexpand(v, max=m, dir=row, ignore=true, cast=false) hi = simplifyTableSeqExpand(hop, hi, i); //e.g., table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v, max=m, dir=row, ignore=false, cast=true) - hi = fuseBinaryOperationChain(hop, hi, i); //e.g., (X+s*Y) -> (X+*s Y), (X-s*Y) -> (X-*s Y) //hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X)) //process childs recursively after rewrites (to investigate pattern newly created by rewrites) @@ -1906,44 +1905,4 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule return hi; } - - /** - * - * @param parent - * @param hi - * @param pos - * @return - * @throws HopsException - */ - private Hop fuseBinaryOperationChain(Hop parent, Hop hi, int pos) { - //pattern: X + lamda*Y -> X +* lambda Y - if( hi instanceof BinaryOp - && (((BinaryOp)hi).getOp()==OpOp2.PLUS || ((BinaryOp)hi).getOp()==OpOp2.MINUS) - && hi.getInput().get(0).getDataType()==DataType.MATRIX - && hi.getInput().get(1) instanceof BinaryOp - && ((BinaryOp)hi.getInput().get(1)).getOp()==OpOp2.MULT ) - { - //Check that the inner binary Op is a product of Scalar times Matrix or viceversa - Hop innerBinaryOp = hi.getInput().get(1); - if ( (innerBinaryOp.getInput().get(0).getDataType()==DataType.SCALAR && innerBinaryOp.getInput().get(1).getDataType()==DataType.MATRIX) - || (innerBinaryOp.getInput().get(0).getDataType()==DataType.MATRIX && innerBinaryOp.getInput().get(1).getDataType()==DataType.SCALAR)) - { - //check which operand is the Scalar and which is the matrix - Hop lamda = (innerBinaryOp.getInput().get(0).getDataType()==DataType.SCALAR) ? innerBinaryOp.getInput().get(0) : innerBinaryOp.getInput().get(1); - Hop matrix = (innerBinaryOp.getInput().get(0).getDataType()==DataType.MATRIX) ? innerBinaryOp.getInput().get(0) : innerBinaryOp.getInput().get(1); - - OpOp3 op = (((BinaryOp)hi).getOp()==OpOp2.PLUS) ? OpOp3.PLUS_MULT : OpOp3.MINUS_MULT; - TernaryOp ternOp = new TernaryOp("tmp", DataType.MATRIX, ValueType.DOUBLE, op, hi.getInput().get(0), lamda, matrix); - HopRewriteUtils.refreshOutputParameters(ternOp, hi.getInput().get(0)); - - HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); - HopRewriteUtils.addChildReference(parent, ternOp, pos); - - LOG.debug("Applied fuseBinaryOperationChain. (line " +hi.getBeginLine()+")"); - return ternOp; - } - } - - return hi; - } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/973b8635/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java index 890a3b2..ff85ebc 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java @@ -40,8 +40,9 @@ import org.apache.sysml.utils.Statistics; */ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase { - private static final String TEST_NAME1 = "RewriteFuseBinaryOpChainTest1"; - private static final String TEST_NAME2 = "RewriteFuseBinaryOpChainTest2"; + private static final String TEST_NAME1 = "RewriteFuseBinaryOpChainTest1"; //+* (X+s*Y) + private static final String TEST_NAME2 = "RewriteFuseBinaryOpChainTest2"; //-* (X-s*Y) + private static final String TEST_NAME3 = "RewriteFuseBinaryOpChainTest3"; //+* (s*Y+X) private static final String TEST_DIR = "functions/misc/"; private static final String TEST_CLASS_DIR = TEST_DIR + RewriteFuseBinaryOpChainTest.class.getSimpleName() + "/"; @@ -53,6 +54,7 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase TestUtils.clearAssertionInformation(); addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) ); addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) ); + addTestConfiguration( TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) ); } @Test @@ -60,7 +62,6 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase testFuseBinaryChain( TEST_NAME1, false, ExecType.CP ); } - @Test public void testFuseBinaryPlusRewriteCP() { testFuseBinaryChain( TEST_NAME1, true, ExecType.CP); @@ -77,6 +78,16 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase } @Test + public void testFuseBinaryPlus2NoRewriteCP() { + testFuseBinaryChain( TEST_NAME3, false, ExecType.CP ); + } + + @Test + public void testFuseBinaryPlus2RewriteCP() { + testFuseBinaryChain( TEST_NAME3, true, ExecType.CP ); + } + + @Test public void testFuseBinaryPlusNoRewriteSP() { testFuseBinaryChain( TEST_NAME1, false, ExecType.SPARK ); } @@ -97,6 +108,16 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase } @Test + public void testFuseBinaryPlus2NoRewriteSP() { + testFuseBinaryChain( TEST_NAME3, false, ExecType.SPARK ); + } + + @Test + public void testFuseBinaryPlus2RewriteSP() { + testFuseBinaryChain( TEST_NAME3, true, ExecType.SPARK ); + } + + @Test public void testFuseBinaryPlusNoRewriteMR() { testFuseBinaryChain( TEST_NAME1, false, ExecType.MR ); } @@ -116,6 +137,15 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase testFuseBinaryChain( TEST_NAME2, true, ExecType.MR ); } + @Test + public void testFuseBinaryPlus2NoRewriteMR() { + testFuseBinaryChain( TEST_NAME3, false, ExecType.MR ); + } + + @Test + public void testFuseBinaryPlus2RewriteMR() { + testFuseBinaryChain( TEST_NAME3, true, ExecType.MR ); + } /** * @@ -162,8 +192,8 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase //check for applies rewrites if( rewrites && instType!=ExecType.MR ) { String prefix = (instType==ExecType.SPARK) ? Instruction.SP_INST_PREFIX : ""; - Assert.assertTrue("Rewrite not applied.",Statistics.getCPHeavyHitterOpCodes() - .contains(testname.equals(TEST_NAME1) ? prefix+"+*" : prefix+"-*" )); + String opcode = (testname.equals(TEST_NAME1)||testname.equals(TEST_NAME3)) ? prefix+"+*" : prefix+"-*"; + Assert.assertTrue("Rewrite not applied.",Statistics.getCPHeavyHitterOpCodes().contains(opcode)); } } finally http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/973b8635/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.R b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.R new file mode 100644 index 0000000..5ae1642 --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.R @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") + +X=matrix(1,10,10) +Y=matrix(1,10,10) +lamda=7 +S=lamda*Y+X +writeMM(as(S, "CsparseMatrix"), paste(args[2], "S", sep="")); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/973b8635/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.dml b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.dml new file mode 100644 index 0000000..af84884 --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.dml @@ -0,0 +1,27 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X=matrix(1,rows=10,cols=10) +Y=matrix(1,rows=10,cols=10) +if(1==1){} +lamda=7 +S=lamda*Y+X +write(S,$1) \ No newline at end of file
