Add scalar factors to the element-wise multiplication chain rewrite test

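I am not sure how to check the rewritten result (that `2*X*3*Y*4*X` becomes `Y*(X^2)*24`) in an assert statement. To make the scalar inputs visible in the `-explain hops` output, this commit also enables SHOW_LITERAL_HOPS in Explain. In addition, it reverses the ordering used in RewriteElementwiseMultChainOptimization and removes a duplicated ALLOW_SUM_PRODUCT_REWRITES save/restore in ABATernaryAggregateTest.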


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/737f93b1
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/737f93b1
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/737f93b1

Branch: refs/heads/master
Commit: 737f93b15a96aba31bc6c6da3651be309e3b8b0c
Parents: b94557f
Author: Dylan Hutchison <dhutc...@cs.washington.edu>
Authored: Sun Jun 11 01:56:07 2017 -0700
Committer: Dylan Hutchison <dhutc...@cs.washington.edu>
Committed: Sun Jun 18 17:43:41 2017 -0700

----------------------------------------------------------------------
 .../hops/rewrite/RewriteElementwiseMultChainOptimization.java   | 2 +-
 src/main/java/org/apache/sysml/utils/Explain.java               | 4 ++--
 .../misc/RewriteElementwiseMultChainOptimizationChainTest.java  | 4 ++--
 .../integration/functions/ternary/ABATernaryAggregateTest.java  | 5 +----
 src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R        | 4 ++--
 src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml      | 2 +-
 6 files changed, 9 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/737f93b1/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
index bd873ff..1dd5813 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -222,7 +222,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
                                }
                        }
                }
-       };
+       }.reversed();
 
        /**
         * Check if a node has a parent that is not in the set of emults. 
Recursively check children who are also emults.

http://git-wip-us.apache.org/repos/asf/systemml/blob/737f93b1/src/main/java/org/apache/sysml/utils/Explain.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/utils/Explain.java b/src/main/java/org/apache/sysml/utils/Explain.java
index 5cf0548..450c6e5 100644
--- a/src/main/java/org/apache/sysml/utils/Explain.java
+++ b/src/main/java/org/apache/sysml/utils/Explain.java
@@ -76,7 +76,7 @@ public class Explain
        //internal configuration parameters
        private static final boolean REPLACE_SPECIAL_CHARACTERS = true; 
        private static final boolean SHOW_MEM_ABOVE_BUDGET      = true;
-       private static final boolean SHOW_LITERAL_HOPS          = false;
+       private static final boolean SHOW_LITERAL_HOPS          = true;
        private static final boolean SHOW_DATA_DEPENDENCIES     = true;
        private static final boolean SHOW_DATA_FLOW_PROPERTIES  = true;
 
@@ -566,7 +566,7 @@ public class Explain
                        childs.append(" (");
                        boolean childAdded = false;
                        for( Hop input : hop.getInput() )
-                               if( !(input instanceof LiteralOp) ){
+                               if( SHOW_LITERAL_HOPS || !(input instanceof 
LiteralOp) ){
                                        childs.append(childAdded?",":"");
                                        childs.append(input.getHopID());
                                        childAdded = true;

http://git-wip-us.apache.org/repos/asf/systemml/blob/737f93b1/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
index 47b2f0e..e490750 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
@@ -33,7 +33,7 @@ import org.junit.Assert;
 import org.junit.Test;
 
 /**
- * Test whether `A*B*A` successfully rewrites to `(A^2)*B`.
+ * Test whether `2*X*3*Y*4*X` successfully rewrites to `Y*(X^2)*24`.
  */
 public class RewriteElementwiseMultChainOptimizationChainTest extends 
AutomatedTestBase
 {
@@ -96,7 +96,7 @@ public class RewriteElementwiseMultChainOptimizationChainTest extends AutomatedT
                        fullDMLScriptName = HOME + testname + ".dml";
                        programArgs = new String[] { "-explain", "hops", 
"-stats", "-args", input("X"), input("Y"), output("R") };
                        fullRScriptName = HOME + testname + ".R";
-                       rCmd = getRCmd(inputDir(), expectedDir());              
        
+                       rCmd = getRCmd(inputDir(), expectedDir());
 
                        double Xsparsity = 0.8, Ysparsity = 0.6;
                        double[][] X = getRandomMatrix(rows, cols, -1, 1, 
Xsparsity, 7);

http://git-wip-us.apache.org/repos/asf/systemml/blob/737f93b1/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
index 460829d..12525c9 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
@@ -368,15 +368,13 @@ public class ABATernaryAggregateTest extends AutomatedTestBase
                if( rtplatform == RUNTIME_PLATFORM.SPARK )
                        DMLScript.USE_LOCAL_SPARK_CONFIG = true;
        
-               boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES,
-                               rewritesOldEmult = 
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
+               boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
                
                try {
                        TestConfiguration config = 
getTestConfiguration(testname);
                        loadTestConfiguration(config);
                        
                        OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
-                       OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
 
                        String HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = HOME + testname + ".dml";
@@ -412,7 +410,6 @@ public class ABATernaryAggregateTest extends AutomatedTestBase
                        rtplatform = platformOld;
                        DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
                        OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
-                       OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = 
rewritesOldEmult;
                }
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/737f93b1/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
index 6d94cc8..fec61ae 100644
--- a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
@@ -28,6 +28,6 @@ library("matrixStats")
 X = as.matrix(readMM(paste(args[1], "X.mtx", sep="")))
 Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep="")))
 
-R = X * Y * X;
+R = 2 * X * 3 * Y * 4 * X;
 
-writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); 
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""));

http://git-wip-us.apache.org/repos/asf/systemml/blob/737f93b1/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
index 3992403..88f252f 100644
--- a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
@@ -23,6 +23,6 @@
 X = read($1);
 Y = read($2);
 
-R = X * Y * X;
+R = 2 * X * 3 * Y * 4 * X;
 
 write(R, $3);
\ No newline at end of file

Reply via email to