[SYSTEMML-2032] Fix rewrite fuse-datagen-binary-op (pdf awareness)

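The rewrite that fuses rand(min=0, max=1) with a subsequent multiplication
by a scalar variable was missing a check that the rand pdf (probability
density function) is uniform, so it could also be applied to non-uniform
distributions such as normal. This patch adds the missing uniform-pdf check
to the rewrite and adds related negative tests.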


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/d47414ed
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/d47414ed
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/d47414ed

Branch: refs/heads/master
Commit: d47414ed0728700776c19585533b5dfc0eb835e1
Parents: b2387b7
Author: Matthias Boehm <[email protected]>
Authored: Thu Nov 30 19:41:47 2017 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Thu Nov 30 19:41:47 2017 -0800

----------------------------------------------------------------------
 .../RewriteAlgebraicSimplificationStatic.java   | 29 +++++++++++++-------
 .../functions/misc/RewriteFusedRandTest.java    | 25 +++++++++++++----
 .../functions/misc/RewriteFusedRandVar3.dml     | 28 +++++++++++++++++++
 3 files changed, 66 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/d47414ed/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index cc2fe88..963e578 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -371,10 +371,11 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                                Hop min = 
inputGen.getInput(DataExpression.RAND_MIN);
                                Hop max = 
inputGen.getInput(DataExpression.RAND_MAX);
                                double sval = 
((LiteralOp)right).getDoubleValue();
+                               boolean pdfUniform = pdf instanceof LiteralOp 
+                                       && 
DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue());
                                
                                if( HopRewriteUtils.isBinary(bop, OpOp2.MULT, 
OpOp2.PLUS, OpOp2.MINUS, OpOp2.DIV)
-                                       && min instanceof LiteralOp && max 
instanceof LiteralOp && pdf instanceof LiteralOp 
-                                       && 
DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue()) )
+                                       && min instanceof LiteralOp && max 
instanceof LiteralOp && pdfUniform )
                                {
                                        //create fused data gen operator
                                        DataGenOp gen = null;
@@ -392,7 +393,8 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                                                
HopRewriteUtils.replaceChildReference(p, bop, gen);
                                        
                                        hi = gen;
-                                       LOG.debug("Applied 
fuseDatagenAndBinaryOperation1 (line "+bop.getBeginLine()+").");
+                                       LOG.debug("Applied 
fuseDatagenAndBinaryOperation1 "
+                                               + "("+bop.getFilename()+", line 
"+bop.getBeginLine()+").");
                                }
                        }
                        //right input rand and hence output matrix double, left 
scalar literal
@@ -404,10 +406,11 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                                Hop min = 
inputGen.getInput(DataExpression.RAND_MIN);
                                Hop max = 
inputGen.getInput(DataExpression.RAND_MAX);
                                double sval = 
((LiteralOp)left).getDoubleValue();
+                               boolean pdfUniform = pdf instanceof LiteralOp 
+                                       && 
DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue());
                                
                                if( (bop.getOp()==OpOp2.MULT || 
bop.getOp()==OpOp2.PLUS)
-                                       && min instanceof LiteralOp && max 
instanceof LiteralOp && pdf instanceof LiteralOp 
-                                       && 
DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue()) )
+                                       && min instanceof LiteralOp && max 
instanceof LiteralOp && pdfUniform )
                                {
                                        //create fused data gen operator
                                        DataGenOp gen = null;
@@ -423,7 +426,8 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                                                
HopRewriteUtils.replaceChildReference(p, bop, gen);
                                        
                                        hi = gen;
-                                       LOG.debug("Applied 
fuseDatagenAndBinaryOperation2 (line "+bop.getBeginLine()+").");
+                                       LOG.debug("Applied 
fuseDatagenAndBinaryOperation2 "
+                                               + "("+bop.getFilename()+", line 
"+bop.getBeginLine()+").");
                                }
                        }
                        //left input rand and hence output matrix double, right 
scalar variable
@@ -433,6 +437,10 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                                DataGenOp gen = (DataGenOp)left;
                                Hop min = gen.getInput(DataExpression.RAND_MIN);
                                Hop max = gen.getInput(DataExpression.RAND_MAX);
+                               Hop pdf = gen.getInput(DataExpression.RAND_PDF);
+                               boolean pdfUniform = pdf instanceof LiteralOp 
+                                       && 
DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue());
+                                       
                                
                                if( HopRewriteUtils.isBinary(bop, OpOp2.PLUS)
                                        && 
HopRewriteUtils.isLiteralOfValue(min, 0)
@@ -445,10 +453,11 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                                        for( Hop p : parents )
                                                
HopRewriteUtils.replaceChildReference(p, bop, gen);
                                        hi = gen;
-                                       LOG.debug("Applied 
fuseDatagenAndBinaryOperation3a (line "+bop.getBeginLine()+").");
+                                       LOG.debug("Applied 
fuseDatagenAndBinaryOperation3a "
+                                               + "("+bop.getFilename()+", line 
"+bop.getBeginLine()+").");
                                }
                                else if( HopRewriteUtils.isBinary(bop, 
OpOp2.MULT)
-                                       && 
(HopRewriteUtils.isLiteralOfValue(min, 0)
+                                       && 
((HopRewriteUtils.isLiteralOfValue(min, 0) && pdfUniform)
                                                || 
HopRewriteUtils.isLiteralOfValue(min, 1))
                                        && 
HopRewriteUtils.isLiteralOfValue(max, 1) )
                                {
@@ -460,10 +469,10 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                                        for( Hop p : parents )
                                                
HopRewriteUtils.replaceChildReference(p, bop, gen);
                                        hi = gen;
-                                       LOG.debug("Applied 
fuseDatagenAndBinaryOperation3b (line "+bop.getBeginLine()+").");
+                                       LOG.debug("Applied 
fuseDatagenAndBinaryOperation3b "
+                                               + "("+bop.getFilename()+", line 
"+bop.getBeginLine()+").");
                                }
                        }
-                       
                }
                
                return hi;

http://git-wip-us.apache.org/repos/asf/systemml/blob/d47414ed/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFusedRandTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFusedRandTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFusedRandTest.java
index d7fe902..9257538 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFusedRandTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFusedRandTest.java
@@ -32,6 +32,7 @@ public class RewriteFusedRandTest extends AutomatedTestBase
        private static final String TEST_NAME1 = "RewriteFusedRandLit";
        private static final String TEST_NAME2 = "RewriteFusedRandVar1";
        private static final String TEST_NAME3 = "RewriteFusedRandVar2";
+       private static final String TEST_NAME4 = "RewriteFusedRandVar3";
        
        private static final String TEST_DIR = "functions/misc/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
RewriteFusedRandTest.class.getSimpleName() + "/";
@@ -46,6 +47,7 @@ public class RewriteFusedRandTest extends AutomatedTestBase
                addTestConfiguration( TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
                addTestConfiguration( TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
                addTestConfiguration( TEST_NAME3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) );
+               addTestConfiguration( TEST_NAME4, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) );
        }
 
        @Test
@@ -79,15 +81,25 @@ public class RewriteFusedRandTest extends AutomatedTestBase
        }
        
        @Test
-       public void testRewriteFusedZerosPlusVar() {
+       public void testRewriteFusedZerosPlusVarUniform() {
                testRewriteFusedRand( TEST_NAME2, "uniform", true );
        }
        
        @Test
-       public void testRewriteFusedOnesMultVar() {
+       public void testRewriteFusedOnesMultVarUniform() {
                testRewriteFusedRand( TEST_NAME3, "uniform", true );
        }
        
+       @Test
+       public void testRewriteFusedOnesMult2VarUniform() {
+               testRewriteFusedRand( TEST_NAME4, "uniform", true );
+       }
+       
+       @Test
+       public void testRewriteFusedOnesMult2VarNormal() {
+               testRewriteFusedRand( TEST_NAME4, "normal", true );
+       }
+       
        private void testRewriteFusedRand( String testname, String pdf, boolean 
rewrites )
        {       
                boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
@@ -115,13 +127,14 @@ public class RewriteFusedRandTest extends 
AutomatedTestBase
                                Assert.assertEquals("Wrong result", new 
Double(Math.pow(rows*cols, 2)), ret);
                        
                        //check for applied rewrites
-                       if( rewrites && pdf.equals("uniform") ) {
-                               
Assert.assertTrue(!heavyHittersContainsString("+")
-                                       && !heavyHittersContainsString("*"));
+                       if( rewrites ) {
+                               boolean expected = testname.equals(TEST_NAME2) 
|| pdf.equals("uniform");
+                               Assert.assertTrue(expected == 
(!heavyHittersContainsString("+")
+                                       && !heavyHittersContainsString("*")));
                        }
                }
                finally {
                        OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
                }
-       }       
+       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/d47414ed/src/test/scripts/functions/misc/RewriteFusedRandVar3.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteFusedRandVar3.dml 
b/src/test/scripts/functions/misc/RewriteFusedRandVar3.dml
new file mode 100644
index 0000000..0eb1125
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteFusedRandVar3.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(1, $1, $2)
+while(FALSE){}
+Y = rand(rows=$1, cols=$2, min=0, max=1, pdf=$3) * sum(X);
+while(FALSE){}
+R = as.matrix(sum(Y))
+
+write(R, $5);

Reply via email to