Add scalars to Rewrite Emult test. Not sure how to check this in an assert statement.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/737f93b1 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/737f93b1 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/737f93b1 Branch: refs/heads/master Commit: 737f93b15a96aba31bc6c6da3651be309e3b8b0c Parents: b94557f Author: Dylan Hutchison <dhutc...@cs.washington.edu> Authored: Sun Jun 11 01:56:07 2017 -0700 Committer: Dylan Hutchison <dhutc...@cs.washington.edu> Committed: Sun Jun 18 17:43:41 2017 -0700 ---------------------------------------------------------------------- .../hops/rewrite/RewriteElementwiseMultChainOptimization.java | 2 +- src/main/java/org/apache/sysml/utils/Explain.java | 4 ++-- .../misc/RewriteElementwiseMultChainOptimizationChainTest.java | 4 ++-- .../integration/functions/ternary/ABATernaryAggregateTest.java | 5 +---- src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R | 4 ++-- src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml | 2 +- 6 files changed, 9 insertions(+), 12 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/737f93b1/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java index bd873ff..1dd5813 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java @@ -222,7 +222,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { } } } - }; + }.reversed(); /** * Check if a node has a parent that is not in the set of emults. 
Recursively check children who are also emults. http://git-wip-us.apache.org/repos/asf/systemml/blob/737f93b1/src/main/java/org/apache/sysml/utils/Explain.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/utils/Explain.java b/src/main/java/org/apache/sysml/utils/Explain.java index 5cf0548..450c6e5 100644 --- a/src/main/java/org/apache/sysml/utils/Explain.java +++ b/src/main/java/org/apache/sysml/utils/Explain.java @@ -76,7 +76,7 @@ public class Explain //internal configuration parameters private static final boolean REPLACE_SPECIAL_CHARACTERS = true; private static final boolean SHOW_MEM_ABOVE_BUDGET = true; - private static final boolean SHOW_LITERAL_HOPS = false; + private static final boolean SHOW_LITERAL_HOPS = true; private static final boolean SHOW_DATA_DEPENDENCIES = true; private static final boolean SHOW_DATA_FLOW_PROPERTIES = true; @@ -566,7 +566,7 @@ public class Explain childs.append(" ("); boolean childAdded = false; for( Hop input : hop.getInput() ) - if( !(input instanceof LiteralOp) ){ + if( SHOW_LITERAL_HOPS || !(input instanceof LiteralOp) ){ childs.append(childAdded?",":""); childs.append(input.getHopID()); childAdded = true; http://git-wip-us.apache.org/repos/asf/systemml/blob/737f93b1/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java index 47b2f0e..e490750 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java +++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java @@ -33,7 +33,7 @@ import org.junit.Assert; import org.junit.Test; /** - * Test whether `A*B*A` successfully rewrites to `(A^2)*B`. + * Test whether `2*X*3*Y*4*X` successfully rewrites to `Y*(X^2)*24`. */ public class RewriteElementwiseMultChainOptimizationChainTest extends AutomatedTestBase { @@ -96,7 +96,7 @@ public class RewriteElementwiseMultChainOptimizationChainTest extends AutomatedT fullDMLScriptName = HOME + testname + ".dml"; programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), output("R") }; fullRScriptName = HOME + testname + ".R"; - rCmd = getRCmd(inputDir(), expectedDir()); + rCmd = getRCmd(inputDir(), expectedDir()); double Xsparsity = 0.8, Ysparsity = 0.6; double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7); http://git-wip-us.apache.org/repos/asf/systemml/blob/737f93b1/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java index 460829d..12525c9 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java @@ -368,15 +368,13 @@ public class ABATernaryAggregateTest extends AutomatedTestBase if( rtplatform == RUNTIME_PLATFORM.SPARK ) DMLScript.USE_LOCAL_SPARK_CONFIG = true; - boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES, - rewritesOldEmult = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES; + boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES; try { TestConfiguration config = getTestConfiguration(testname); 
loadTestConfiguration(config); OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites; - OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites; String HOME = SCRIPT_DIR + TEST_DIR; fullDMLScriptName = HOME + testname + ".dml"; @@ -412,7 +410,6 @@ public class ABATernaryAggregateTest extends AutomatedTestBase rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld; - OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOldEmult; } } } http://git-wip-us.apache.org/repos/asf/systemml/blob/737f93b1/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R index 6d94cc8..fec61ae 100644 --- a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R +++ b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R @@ -28,6 +28,6 @@ library("matrixStats") X = as.matrix(readMM(paste(args[1], "X.mtx", sep=""))) Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep=""))) -R = X * Y * X; +R = 2 * X * 3 * Y * 4 * X; -writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); +writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/737f93b1/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml index 3992403..88f252f 100644 --- a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml +++ b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml @@ -23,6 +23,6 @@ X = read($1); Y = read($2); -R = X * Y * X; +R = 2 * X * 3 * Y * 4 * X; write(R, $3); \ No newline at end of file