Repository: incubator-systemml Updated Branches: refs/heads/master dea42de1f -> cf4e5ab6e
[SYSTEMML-765] New rewrite 'pushdown sum binary mult', tests Closes #173 Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/cf4e5ab6 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/cf4e5ab6 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/cf4e5ab6 Branch: refs/heads/master Commit: cf4e5ab6e11273a8a468b669c22d76f55fc43f12 Parents: dea42de Author: tgamal <[email protected]> Authored: Wed Jun 8 19:22:01 2016 -0700 Committer: Matthias Boehm <[email protected]> Committed: Wed Jun 8 19:22:01 2016 -0700 ---------------------------------------------------------------------- .../RewriteAlgebraicSimplificationStatic.java | 45 ++++++- .../misc/RewritePushdownSumBinaryMult.java | 126 +++++++++++++++++++ .../misc/RewritePushdownSumBinaryMult.R | 23 ++++ .../misc/RewritePushdownSumBinaryMult.dml | 26 ++++ .../misc/RewritePushdownSumBinaryMult2.R | 24 ++++ .../misc/RewritePushdownSumBinaryMult2.dml | 26 ++++ 6 files changed, 266 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cf4e5ab6/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index fff9310..c36c01f 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -145,7 +145,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule hi = simplifyBushyBinaryOperation(hop, hi, i); //e.g., (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v) hi = 
simplifyUnaryAggReorgOperation(hop, hi, i); //e.g., sum(t(X)) -> sum(X) hi = pushdownUnaryAggTransposeOperation(hop, hi, i); //e.g., colSums(t(X)) -> t(rowSums(X)) - hi = simplifyUnaryPPredOperation(hop, hi, i); //e.g., abs(ppred()) -> ppred(), others: round, ceil, floor + hi = pushdownSumBinaryMult(hop, hi, i); //e.g., sum(lamda*X) -> lamda*sum(X) + hi = simplifyUnaryPPredOperation(hop, hi, i); //e.g., abs(ppred()) -> ppred(), others: round, ceil, floor hi = simplifyTransposedAppend(hop, hi, i); //e.g., t(cbind(t(A),t(B))) -> rbind(A,B); hi = fuseBinarySubDAGToUnaryOperation(hop, hi, i); //e.g., X*(1-X)-> sprop(X) || 1/(1+exp(-X)) -> sigmoid(X) || X*(X>0) -> selp(X) hi = simplifyTraceMatrixMult(hop, hi, i); //e.g., trace(X%*%Y)->sum(X*t(Y)); @@ -161,8 +162,10 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule hi = fuseLogNzBinaryOperation(hop, hi, i); //e.g., ppred(X,0,"!=")*log(X,0.5) -> log_nz(X,0.5) hi = simplifyOuterSeqExpand(hop, hi, i); //e.g., outer(v, seq(1,m), "==") -> rexpand(v, max=m, dir=row, ignore=true, cast=false) hi = simplifyTableSeqExpand(hop, hi, i); //e.g., table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v, max=m, dir=row, ignore=false, cast=true) - //hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X)) + + //hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X)) + //process childs recursively after rewrites (to investigate pattern newly created by rewrites) if( !descendFirst ) rule_AlgebraicSimplification(hi, descendFirst); @@ -928,7 +931,42 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule return hi; } - + /** + * Pushes a full sum below a scalar-matrix multiply, i.e., sum(lamda*X) or sum(X*lamda) -> lamda*sum(X); applies only if hi is a RowCol SUM whose input is a MULT of exactly one SCALAR and one MATRIX operand, and that MULT has no parent other than hi. + * @param parent hop that references hi as its input at position pos + * @param hi candidate hop to rewrite + * @param pos input position of hi within parent + * @return the new lamda*sum(X) hop (which replaces hi in parent at position pos) if the rewrite applied, otherwise hi unchanged + * @throws HopsException propagated from the HopRewriteUtils hop construction and child-reference updates + */ + private Hop pushdownSumBinaryMult(Hop parent, Hop hi, int pos ) throws HopsException { + //pattern: sum(lamda*X) -> lamda*sum(X), also matches sum(X*lamda); lamda scalar, X matrix + if( hi instanceof AggUnaryOp && 
((AggUnaryOp)hi).getDirection()==Direction.RowCol + && ((AggUnaryOp)hi).getOp()==Hop.AggOp.SUM + && ((AggUnaryOp)hi).getInput().get(0) instanceof BinaryOp + && ((BinaryOp)hi.getInput().get(0)).getOp()==OpOp2.MULT + && hi.getInput().get(0).getParent().size() == 1 // only one parent which is the sum + && ((hi.getInput().get(0).getInput().get(0).getDataType()==DataType.SCALAR && hi.getInput().get(0).getInput().get(1).getDataType()==DataType.MATRIX) + ||(hi.getInput().get(0).getInput().get(0).getDataType()==DataType.MATRIX && hi.getInput().get(0).getInput().get(1).getDataType()==DataType.SCALAR))) + { + Hop operand1 = hi.getInput().get(0).getInput().get(0); + Hop operand2 = hi.getInput().get(0).getInput().get(1); + + //the condition above guarantees exactly one SCALAR and one MATRIX operand; pick each by data type + Hop lamda = (operand1.getDataType()==DataType.SCALAR) ? operand1 : operand2; + Hop matrix = (operand1.getDataType()==DataType.MATRIX) ? operand1 : operand2; + + AggUnaryOp aggOp=HopRewriteUtils.createAggUnaryOp(matrix, AggOp.SUM, Direction.RowCol); + Hop bop = HopRewriteUtils.createBinary(lamda, aggOp, OpOp2.MULT); //NOTE(review): bop's value type is decided by createBinary; confirm it is DOUBLE when lamda is an INT scalar + + HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); + HopRewriteUtils.addChildReference(parent, bop, pos); //NOTE(review): hi and its MULT input are not detached from their own inputs here; confirm no cleanup is required + + LOG.debug("Applied pushdownSumBinaryMult."); + return bop; + } + return hi; + } /** * * @param parent @@ -1870,5 +1908,4 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule return hi; } - } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cf4e5ab6/src/test/java/org/apache/sysml/test/integration/functions/misc/RewritePushdownSumBinaryMult.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewritePushdownSumBinaryMult.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewritePushdownSumBinaryMult.java new file mode 100644 index 0000000..9724d1d --- /dev/null +++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewritePushdownSumBinaryMult.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.test.integration.functions.misc; + +import java.util.HashMap; + +import org.junit.Assert; +import org.junit.Test; +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.utils.TestUtils; +import org.apache.sysml.utils.Statistics; + +/** + * Result test for the rewrite sum(lamda*X) -> lamda*sum(X): runs sum(lamda*X) (script 1) and sum(X*lamda) (script 2) with ALLOW_ALGEBRAIC_SIMPLIFICATION on and off and compares DML against R.
+ * + */ +public class RewritePushdownSumBinaryMult extends AutomatedTestBase +{ + + private static final String TEST_NAME1 = "RewritePushdownSumBinaryMult"; + private static final String TEST_NAME2 = "RewritePushdownSumBinaryMult2"; + + private static final String TEST_DIR = "functions/misc/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewritePushdownSumBinaryMult.class.getSimpleName() + "/"; + + //private static final int rows = 1234; + //private static final int cols = 567; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() + { + TestUtils.clearAssertionInformation(); + addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) ); + addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) ); + } + + @Test + public void testPushdownSumBinaryMultNoRewrite() + { + testRewritePushdownSumBinaryMult( TEST_NAME1, false ); + } + + + @Test + public void testPushdownSumBinaryMultRewrite() + { + testRewritePushdownSumBinaryMult( TEST_NAME1, true ); + } + + + @Test + public void testPushdownSumBinaryMultNoRewrite2() + { + testRewritePushdownSumBinaryMult( TEST_NAME2, false ); + } + + @Test + public void testPushdownSumBinaryMultRewrite2() + { + testRewritePushdownSumBinaryMult( TEST_NAME2, true ); + } + + + /** + * Runs the DML script of the given test with algebraic simplification set as requested, runs the R reference script, and compares both scalar results within eps. + * + * @param testname name of the test configuration and of the .dml/.R scripts + * @param rewrites value of OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION during the run (the previous value is restored afterwards) + */ + private void testRewritePushdownSumBinaryMult( String testname, boolean rewrites ) + { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + + try + { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new String[]{ "-stats","-args", output("Scalar") }; + + fullRScriptName = HOME + testname + ".R"; + rCmd = getRCmd(inputDir(), expectedDir()); 
+ OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + runTest(true, false, null, -1); + runRScript(true); + + //compare DML and R scalars within tolerance eps (instead of exact string equality of the maps) + HashMap<CellIndex, Double> dmlfile = readDMLScalarFromHDFS("Scalar"); + HashMap<CellIndex, Double> rfile = readRScalarFromFS("Scalar"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + System.out.println("Test case passed"); + + } + finally + { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cf4e5ab6/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult.R b/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult.R new file mode 100644 index 0000000..48a000b --- /dev/null +++ b/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult.R @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- X<-matrix(1,10,10) lamda<-sum(X) args<-commandArgs(TRUE) write(sum(lamda*X),paste0(args[2],"Scalar")) http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cf4e5ab6/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult.dml b/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult.dml new file mode 100644 index 0000000..9850242 --- /dev/null +++ b/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult.dml @@ -0,0 +1,26 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+# +#------------------------------------------------------------- X=matrix(1,rows=10,cols=10) if(1==1){} lamda=sum(X) y=sum(lamda*X) write(y,$1) http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cf4e5ab6/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult2.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult2.R b/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult2.R new file mode 100644 index 0000000..09a0910 --- /dev/null +++ b/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult2.R @@ -0,0 +1,24 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+# +#------------------------------------------------------------- X<-matrix(1,10,10) lamda<-sum(X) args<-commandArgs(TRUE) write(sum(X*lamda),paste0(args[2],"Scalar")) http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cf4e5ab6/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult2.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult2.dml b/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult2.dml new file mode 100644 index 0000000..07e0e54 --- /dev/null +++ b/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult2.dml @@ -0,0 +1,26 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- X=matrix(1,rows=10,cols=10) if(1==1){} lamda=sum(X) y=sum(X*lamda) write(y,$1)
