Repository: incubator-systemml Updated Branches: refs/heads/master 1385cf1ca -> 201238fd3
[SYSTEMML-1254] New rewrite 'pushdown CSE transpose-scalar', incl. tests

This new rewrite pushes a transpose down below a matrix-scalar binary
operation (except quantile and centralMoment) in order to reuse an
existing transpose as a common subexpression, e.g.,
a=t(X), b=t(X^2) -> a=t(X), b=t(X)^2.

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/0e6411da
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/0e6411da
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/0e6411da

Branch: refs/heads/master
Commit: 0e6411dada77870ae29049288b1789313a35a9f6
Parents: 1385cf1
Author: Matthias Boehm <[email protected]>
Authored: Tue Feb 14 17:11:59 2017 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Wed Feb 15 10:49:19 2017 -0800

----------------------------------------------------------------------
 .../sysml/hops/rewrite/HopRewriteUtils.java     |  30 +++++-
 .../RewriteAlgebraicSimplificationStatic.java   |  36 +++++++
 .../java/org/apache/sysml/utils/Statistics.java |  12 ++-
 .../misc/RewriteCSETransposeScalarTest.java     | 104 +++++++++++++++++++
 .../misc/RewriteCSETransposeScalarMult.dml      |  31 ++++++
 .../misc/RewriteCSETransposeScalarPow.dml       |  31 ++++++
 .../functions/misc/ZPackageSuite.java           |   1 +
 7 files changed, 239 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0e6411da/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 50501dc..d3be09d 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -711,17 +711,39 @@ public class HopRewriteUtils
 		return ret;
 	}
 	
-	public static boolean isTransposeOperation(Hop hop)
-	{
+	public static boolean isTransposeOperation(Hop hop) {
 		return (hop instanceof ReorgOp && ((ReorgOp)hop).getOp()==ReOrgOp.TRANSPOSE);
 	}
 	
-	public static boolean isTransposeOfItself(Hop hop1, Hop hop2)
-	{
+	/**
+	 * Indicates if at least one of the given hops is a transpose operation.
+	 * @param hops list of hops, e.g., the parents of a hop
+	 * @return true if any hop is a transpose, false otherwise (incl. empty list)
+	 */
+	public static boolean containsTransposeOperation(ArrayList<Hop> hops) {
+		for( Hop hop : hops )
+			if( isTransposeOperation(hop) )
+				return true; //early abort on first match
+		return false;
+	}
+	
+	public static boolean isTransposeOfItself(Hop hop1, Hop hop2) {
 		return hop1 instanceof ReorgOp && ((ReorgOp)hop1).getOp()==ReOrgOp.TRANSPOSE && hop1.getInput().get(0) == hop2
 			|| hop2 instanceof ReorgOp && ((ReorgOp)hop2).getOp()==ReOrgOp.TRANSPOSE && hop2.getInput().get(0) == hop1;	
 	}
 	
+	/**
+	 * Indicates if the given hop is a binary operation over exactly one matrix
+	 * and one scalar input, in either order (i.e., X op s or s op X).
+	 * @param hop high-level operator
+	 * @return true if hop is a matrix-scalar or scalar-matrix binary operation
+	 */
+	public static boolean isBinaryMatrixScalarOperation(Hop hop) {
+		return hop instanceof BinaryOp && 
+			((hop.getInput().get(0).getDataType().isMatrix() && hop.getInput().get(1).getDataType().isScalar())
+			||(hop.getInput().get(1).getDataType().isMatrix() && hop.getInput().get(0).getDataType().isScalar()));
+	}
+	
 	public static boolean isNonZeroIndicator(Hop pred, Hop hop )
 	{
 		if( pred instanceof BinaryOp && ((BinaryOp)pred).getOp()==OpOp2.NOTEQUAL


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0e6411da/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index ad6a4da..41459b4 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -148,6 +148,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 			hi = simplifyUnaryAggReorgOperation(hop, hi, i);  //e.g., sum(t(X)) -> sum(X)
 			hi = simplifyBinaryMatrixScalarOperation(hop, hi, i);//e.g., as.scalar(X*s) -> as.scalar(X)*s;
 			hi = pushdownUnaryAggTransposeOperation(hop, hi, i); //e.g., colSums(t(X)) -> t(rowSums(X))
+			hi = pushdownCSETransposeScalarOperation(hop, hi, i);//e.g., a=t(X), b=t(X^2) -> a=t(X), b=t(X)^2 for CSE t(X)
 			hi = pushdownSumBinaryMult(hop, hi, i);           //e.g., sum(lamda*X) -> lamda*sum(X)
 			hi = simplifyUnaryPPredOperation(hop, hi, i);     //e.g., abs(ppred()) -> ppred(), others: round, ceil, floor
 			hi = simplifyTransposedAppend(hop, hi, i);        //e.g., t(cbind(t(A),t(B))) -> rbind(A,B);
@@ -942,6 +943,41 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 		
 		return hi;
 	}
+	
+	private Hop pushdownCSETransposeScalarOperation( Hop parent, Hop hi, int pos )
+	{
+		// a=t(X), b=t(X^2) -> a=t(X), b=t(X)^2 for CSE t(X)
+		// probed at root node of b in above example
+		// (with support for left or right scalar operations)
+		if( HopRewriteUtils.isTransposeOperation(hi) && hi.getParent().size()==1 
+			&& HopRewriteUtils.isBinaryMatrixScalarOperation(hi.getInput().get(0))
+			&& hi.getInput().get(0).getParent().size()==1)
+		{
+			BinaryOp binary = (BinaryOp) hi.getInput().get(0);
+			int Xpos = binary.getInput().get(0).getDataType().isMatrix() ? 0 : 1;
+			Hop X = binary.getInput().get(Xpos);
+			
+			if( HopRewriteUtils.containsTransposeOperation(X.getParent())
+				&& !HopRewriteUtils.isValidOp(binary.getOp(), new OpOp2[]{OpOp2.CENTRALMOMENT, OpOp2.QUANTILE}))
+			{
+				//clear existing wiring
+				HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
+				HopRewriteUtils.removeChildReference(hi, binary);
+				HopRewriteUtils.removeChildReference(binary, X);
+				
+				//re-wire operators (transpose takes the former position of X)
+				HopRewriteUtils.addChildReference(parent, binary, pos);
+				HopRewriteUtils.addChildReference(binary, hi, Xpos);
+				HopRewriteUtils.addChildReference(hi, X);
+				//note: common subexpression later eliminated by dedicated rewrite
+				
+				hi = binary;
+				LOG.debug("Applied pushdownCSETransposeScalarOperation (line "+hi.getBeginLine()+").");
+			}
+		}
+		
+		return hi;
+	}
 
 	private Hop pushdownSumBinaryMult(Hop parent, Hop hi, int pos ) throws HopsException {
 		//pattern: sum(lamda*X) -> lamda*sum(X)


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0e6411da/src/main/java/org/apache/sysml/utils/Statistics.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/utils/Statistics.java b/src/main/java/org/apache/sysml/utils/Statistics.java
index cf9b5fb..e371d9c 100644
--- a/src/main/java/org/apache/sysml/utils/Statistics.java
+++ b/src/main/java/org/apache/sysml/utils/Statistics.java
@@ -464,10 +464,18 @@ public class Statistics
 		_cpInstCounts.put(key, newCnt);
 	}
 	
-	public static Set<String> getCPHeavyHitterOpCodes()
-	{
+	public static Set<String> getCPHeavyHitterOpCodes() {
 		return _cpInstTime.keySet();
 	}
+	
+	/**
+	 * Returns the count recorded in the CP heavy-hitter statistics for the given opcode.
+	 * @param opcode instruction opcode, e.g., "r'" for transpose
+	 * @return recorded count, or 0 if no count was recorded for this opcode
+	 */
+	public static long getCPHeavyHitterCount(String opcode) {
+		return _cpInstCounts.containsKey(opcode) ? _cpInstCounts.get(opcode) : 0;
+	}
 
 	@SuppressWarnings("unchecked")
 	public static String getHeavyHitters( int num )
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0e6411da/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCSETransposeScalarTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCSETransposeScalarTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCSETransposeScalarTest.java
new file mode 100644
index 0000000..61daf38
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCSETransposeScalarTest.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.apache.sysml.utils.Statistics;
+
+/**
+ * Tests the rewrite 'pushdown CSE transpose-scalar', i.e., a=t(X), b=t(X^2)
+ * -> a=t(X), b=t(X)^2, for both right (X^2) and left (2*X) scalar operands.
+ */
+public class RewriteCSETransposeScalarTest extends AutomatedTestBase 
+{
+	private static final String TEST_NAME1 = "RewriteCSETransposeScalarPow"; //right scalar
+	private static final String TEST_NAME2 = "RewriteCSETransposeScalarMult"; //left scalar
+	
+	private static final String TEST_DIR = "functions/misc/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + RewriteCSETransposeScalarTest.class.getSimpleName() + "/";
+	
+	private static final int rows = 1932;
+	private static final int cols = 14;
+	
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+		addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
+	}
+
+	@Test
+	public void testRewriteCSETransposePow() {
+		testRewriteCSETransposeScalar( TEST_NAME1, true );
+	}
+	
+	@Test
+	public void testRewriteCSETransposePowNoRewrite() {
+		testRewriteCSETransposeScalar( TEST_NAME1, false );
+	}
+	
+	@Test
+	public void testRewriteCSETransposeMult() {
+		testRewriteCSETransposeScalar( TEST_NAME2, true );
+	}
+	
+	@Test
+	public void testRewriteCSETransposeMultNoRewrite() {
+		testRewriteCSETransposeScalar( TEST_NAME2, false );
+	}
+	
+	/**
+	 * Runs the given dml test script with algebraic simplification rewrites enabled or disabled.
+	 * @param testname name of the dml test script (without the .dml extension)
+	 * @param rewrites true to enable algebraic simplification rewrites, false to disable them
+	 */
+	private void testRewriteCSETransposeScalar( String testname, boolean rewrites )
+	{	
+		boolean rewritesOld = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+		
+		try {
+			TestConfiguration config = getTestConfiguration(testname);
+			loadTestConfiguration(config);
+			
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + testname + ".dml";
+			programArgs = new String[]{ "-stats", "-args", String.valueOf(rows), 
+					String.valueOf(cols), output("R") };
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
+
+			//run test script (-stats is passed for the instruction count check below)
+			runTest(true, false, null, -1); 
+			
+			//compare output: all rows*cols cells of both expressions must match
+			double ret = TestUtils.readDMLScalar(output("R"));
+			Assert.assertEquals("Wrong result, expected: "+(rows*cols), rows*cols, ret, 0);
+			//with rewrites, the transpose t(X) is shared (1 r' instruction), otherwise 2
+			Assert.assertEquals("Wrong number of r' instructions", rewrites?1:2, Statistics.getCPHeavyHitterCount("r'"));
+		}
+		finally {
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewritesOld;
+		}
+	}
+}


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0e6411da/src/test/scripts/functions/misc/RewriteCSETransposeScalarMult.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteCSETransposeScalarMult.dml b/src/test/scripts/functions/misc/RewriteCSETransposeScalarMult.dml
new file mode 100644
index 0000000..07e67dc
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteCSETransposeScalarMult.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rand(rows=$1, cols=$2, min=1, max=10);
+if(1==1){}
+
+a = t(X);
+b = t(2*X);
+
+if(1==1){}
+
+R = sum(2*a == b);
+write(R, $3);


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0e6411da/src/test/scripts/functions/misc/RewriteCSETransposeScalarPow.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteCSETransposeScalarPow.dml b/src/test/scripts/functions/misc/RewriteCSETransposeScalarPow.dml
new file mode 100644
index 0000000..f47c227
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteCSETransposeScalarPow.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rand(rows=$1, cols=$2, min=1, max=10);
+if(1==1){}
+
+a = t(X);
+b = t(X^2);
+
+if(1==1){}
+
+R = sum(a^2 == b);
+write(R, $3);


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0e6411da/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
----------------------------------------------------------------------
diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
index 32b5f7b..1b3478d 100644
--- a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
+++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
@@ -46,6 +46,7 @@ import org.junit.runners.Suite;
 	PrintExpressionTest.class,
 	PrintMatrixTest.class,
 	ReadAfterWriteTest.class,
+	RewriteCSETransposeScalarTest.class,
 	RewriteFusedRandTest.class,
 	RewriteLoopVectorization.class,
 	RewritePushdownSumBinaryMult.class,
