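[SYSTEMML-2032] Fix rewrite fuse-datagen-binary-op (pdf awareness) The rewrite that fuses rand(min=0, max=1) with a multiplication by a scalar variable was missing a check that the pdf is uniform. For non-uniform pdfs (e.g., normal), the fused rand operator is not equivalent to the original multiplication, so the rewrite must not be applied. This patch adds the missing uniform-pdf check, adds related negative tests, and includes the script filename in the debug messages of the fuse-datagen-binary-op rewrites.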
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/d47414ed Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/d47414ed Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/d47414ed Branch: refs/heads/master Commit: d47414ed0728700776c19585533b5dfc0eb835e1 Parents: b2387b7 Author: Matthias Boehm <[email protected]> Authored: Thu Nov 30 19:41:47 2017 -0800 Committer: Matthias Boehm <[email protected]> Committed: Thu Nov 30 19:41:47 2017 -0800 ---------------------------------------------------------------------- .../RewriteAlgebraicSimplificationStatic.java | 29 +++++++++++++------- .../functions/misc/RewriteFusedRandTest.java | 25 +++++++++++++---- .../functions/misc/RewriteFusedRandVar3.dml | 28 +++++++++++++++++++ 3 files changed, 66 insertions(+), 16 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/d47414ed/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index cc2fe88..963e578 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -371,10 +371,11 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule Hop min = inputGen.getInput(DataExpression.RAND_MIN); Hop max = inputGen.getInput(DataExpression.RAND_MAX); double sval = ((LiteralOp)right).getDoubleValue(); + boolean pdfUniform = pdf instanceof LiteralOp + && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue()); if( HopRewriteUtils.isBinary(bop, OpOp2.MULT, OpOp2.PLUS, 
OpOp2.MINUS, OpOp2.DIV) - && min instanceof LiteralOp && max instanceof LiteralOp && pdf instanceof LiteralOp - && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue()) ) + && min instanceof LiteralOp && max instanceof LiteralOp && pdfUniform ) { //create fused data gen operator DataGenOp gen = null; @@ -392,7 +393,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule HopRewriteUtils.replaceChildReference(p, bop, gen); hi = gen; - LOG.debug("Applied fuseDatagenAndBinaryOperation1 (line "+bop.getBeginLine()+")."); + LOG.debug("Applied fuseDatagenAndBinaryOperation1 " + + "("+bop.getFilename()+", line "+bop.getBeginLine()+")."); } } //right input rand and hence output matrix double, left scalar literal @@ -404,10 +406,11 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule Hop min = inputGen.getInput(DataExpression.RAND_MIN); Hop max = inputGen.getInput(DataExpression.RAND_MAX); double sval = ((LiteralOp)left).getDoubleValue(); + boolean pdfUniform = pdf instanceof LiteralOp + && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue()); if( (bop.getOp()==OpOp2.MULT || bop.getOp()==OpOp2.PLUS) - && min instanceof LiteralOp && max instanceof LiteralOp && pdf instanceof LiteralOp - && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue()) ) + && min instanceof LiteralOp && max instanceof LiteralOp && pdfUniform ) { //create fused data gen operator DataGenOp gen = null; @@ -423,7 +426,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule HopRewriteUtils.replaceChildReference(p, bop, gen); hi = gen; - LOG.debug("Applied fuseDatagenAndBinaryOperation2 (line "+bop.getBeginLine()+")."); + LOG.debug("Applied fuseDatagenAndBinaryOperation2 " + + "("+bop.getFilename()+", line "+bop.getBeginLine()+")."); } } //left input rand and hence output matrix double, right scalar variable @@ -433,6 +437,10 @@ public class RewriteAlgebraicSimplificationStatic 
extends HopRewriteRule DataGenOp gen = (DataGenOp)left; Hop min = gen.getInput(DataExpression.RAND_MIN); Hop max = gen.getInput(DataExpression.RAND_MAX); + Hop pdf = gen.getInput(DataExpression.RAND_PDF); + boolean pdfUniform = pdf instanceof LiteralOp + && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue()); + if( HopRewriteUtils.isBinary(bop, OpOp2.PLUS) && HopRewriteUtils.isLiteralOfValue(min, 0) @@ -445,10 +453,11 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule for( Hop p : parents ) HopRewriteUtils.replaceChildReference(p, bop, gen); hi = gen; - LOG.debug("Applied fuseDatagenAndBinaryOperation3a (line "+bop.getBeginLine()+")."); + LOG.debug("Applied fuseDatagenAndBinaryOperation3a " + + "("+bop.getFilename()+", line "+bop.getBeginLine()+")."); } else if( HopRewriteUtils.isBinary(bop, OpOp2.MULT) - && (HopRewriteUtils.isLiteralOfValue(min, 0) + && ((HopRewriteUtils.isLiteralOfValue(min, 0) && pdfUniform) || HopRewriteUtils.isLiteralOfValue(min, 1)) && HopRewriteUtils.isLiteralOfValue(max, 1) ) { @@ -460,10 +469,10 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule for( Hop p : parents ) HopRewriteUtils.replaceChildReference(p, bop, gen); hi = gen; - LOG.debug("Applied fuseDatagenAndBinaryOperation3b (line "+bop.getBeginLine()+")."); + LOG.debug("Applied fuseDatagenAndBinaryOperation3b " + + "("+bop.getFilename()+", line "+bop.getBeginLine()+")."); } } - } return hi; http://git-wip-us.apache.org/repos/asf/systemml/blob/d47414ed/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFusedRandTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFusedRandTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFusedRandTest.java index d7fe902..9257538 100644 --- 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFusedRandTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFusedRandTest.java @@ -32,6 +32,7 @@ public class RewriteFusedRandTest extends AutomatedTestBase private static final String TEST_NAME1 = "RewriteFusedRandLit"; private static final String TEST_NAME2 = "RewriteFusedRandVar1"; private static final String TEST_NAME3 = "RewriteFusedRandVar2"; + private static final String TEST_NAME4 = "RewriteFusedRandVar3"; private static final String TEST_DIR = "functions/misc/"; private static final String TEST_CLASS_DIR = TEST_DIR + RewriteFusedRandTest.class.getSimpleName() + "/"; @@ -46,6 +47,7 @@ public class RewriteFusedRandTest extends AutomatedTestBase addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) ); addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) ); addTestConfiguration( TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) ); + addTestConfiguration( TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) ); } @Test @@ -79,15 +81,25 @@ public class RewriteFusedRandTest extends AutomatedTestBase } @Test - public void testRewriteFusedZerosPlusVar() { + public void testRewriteFusedZerosPlusVarUniform() { testRewriteFusedRand( TEST_NAME2, "uniform", true ); } @Test - public void testRewriteFusedOnesMultVar() { + public void testRewriteFusedOnesMultVarUniform() { testRewriteFusedRand( TEST_NAME3, "uniform", true ); } + @Test + public void testRewriteFusedOnesMult2VarUniform() { + testRewriteFusedRand( TEST_NAME4, "uniform", true ); + } + + @Test + public void testRewriteFusedOnesMult2VarNormal() { + testRewriteFusedRand( TEST_NAME4, "normal", true ); + } + private void testRewriteFusedRand( String testname, String pdf, boolean rewrites ) { boolean oldFlag = 
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; @@ -115,13 +127,14 @@ public class RewriteFusedRandTest extends AutomatedTestBase Assert.assertEquals("Wrong result", new Double(Math.pow(rows*cols, 2)), ret); //check for applied rewrites - if( rewrites && pdf.equals("uniform") ) { - Assert.assertTrue(!heavyHittersContainsString("+") - && !heavyHittersContainsString("*")); + if( rewrites ) { + boolean expected = testname.equals(TEST_NAME2) || pdf.equals("uniform"); + Assert.assertTrue(expected == (!heavyHittersContainsString("+") + && !heavyHittersContainsString("*"))); } } finally { OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; } - } + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/d47414ed/src/test/scripts/functions/misc/RewriteFusedRandVar3.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteFusedRandVar3.dml b/src/test/scripts/functions/misc/RewriteFusedRandVar3.dml new file mode 100644 index 0000000..0eb1125 --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteFusedRandVar3.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +X = matrix(1, $1, $2) +while(FALSE){} +Y = rand(rows=$1, cols=$2, min=0, max=1, pdf=$3) * sum(X); +while(FALSE){} +R = as.matrix(sum(Y)) + +write(R, $5);
