[HOTFIX][SYSTEMML-1663] Fix and disable element-wise mult chain rewrite

This patch fixes the custom hop comparator so that it yields a consistent
ordering of the operands of element-wise multiplication chains (scalars,
vectors, matrices), which resolves the test failure reported in PR 549.
Because of additional issues that could cause incorrect results or runtime
errors, I'm temporarily disabling this rewrite and its related tests.

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/a5c834b2
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/a5c834b2
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/a5c834b2

Branch: refs/heads/master
Commit: a5c834b27da9cfeffe0ad6e606c43fe3246831d2
Parents: 9e7ce7b
Author: Matthias Boehm <[email protected]>
Authored: Wed Jun 21 00:05:32 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Wed Jun 21 00:05:32 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/rewrite/ProgramRewriter.java     |  6 ++---
 ...RewriteElementwiseMultChainOptimization.java | 27 ++++++++++++--------
 ...iteElementwiseMultChainOptimizationTest.java |  4 ++-
 3 files changed, 23 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/a5c834b2/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java 
b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index 7ee3ccb..92d31c2 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -96,8 +96,8 @@ public class ProgramRewriter
                        _dagRuleSet.add(     new 
RewriteRemoveUnnecessaryCasts()             );         
                        if( 
OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
                                _dagRuleSet.add( new 
RewriteCommonSubexpressionElimination()     );
-                       if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES)
-                               _dagRuleSet.add( new 
RewriteElementwiseMultChainOptimization()   ); //dependency: cse
+                       //if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES)
+                       //      _dagRuleSet.add( new 
RewriteElementwiseMultChainOptimization()   ); //dependency: cse
                        if( OptimizerUtils.ALLOW_CONSTANT_FOLDING )
                                _dagRuleSet.add( new RewriteConstantFolding()   
                 ); //dependency: cse
                        if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
@@ -108,7 +108,7 @@ public class ProgramRewriter
                                _dagRuleSet.add( new 
RewriteIndexingVectorization()              ); //dependency: cse, 
simplifications
                        _dagRuleSet.add( new 
RewriteInjectSparkPReadCheckpointing()          ); //dependency: reblock
                        
-                       //add statment block rewrite rules
+                       //add statement block rewrite rules
                        if( OptimizerUtils.ALLOW_BRANCH_REMOVAL )               
        
                                _sbRuleSet.add(  new 
RewriteRemoveUnnecessaryBranches()          ); //dependency: constant folding   
           
                        if( OptimizerUtils.ALLOW_SPLIT_HOP_DAGS )

http://git-wip-us.apache.org/repos/asf/systemml/blob/a5c834b2/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
index 9ca0932..fe2a5d0 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -42,7 +42,7 @@ import com.google.common.collect.Multiset;
  *
  * Rewrite a chain of element-wise multiply hops that contain identical 
elements.
  * For example `(B * A) * B` is rewritten to `A * (B^2)` (or `(B^2) * A`), 
where `^` is element-wise power.
- * The order of the multiplicands depends on their data types, dimentions 
(matrix or vector), and sparsity.
+ * The order of the multiplicands depends on their data types, dimensions 
(matrix or vector), and sparsity.
  *
  * Does not rewrite in the presence of foreign parents in the middle of the 
e-wise multiply chain,
  * since foreign parents may rely on the individual results.
@@ -136,6 +136,8 @@ public class RewriteElementwiseMultChainOptimization 
extends HopRewriteRule {
                // sorted contains all leaves, sorted by data type, stripped 
from their parents
 
                // Construct right-deep EMult tree
+               // TODO compile binary outer mult for transition from row and 
column vectors to matrices
+               // TODO compile subtree for column vectors to avoid blow-up of 
intermediates on row-col vector transition
                final Iterator<Map.Entry<Hop, Integer>> iterator = 
sorted.entrySet().iterator();
                Hop first = constructPower(iterator.next());
 
@@ -160,13 +162,15 @@ public class RewriteElementwiseMultChainOptimization 
extends HopRewriteRule {
        }
 
        /**
-        * A Comparator that orders Hops by their data type, dimention, and 
sparsity.
+        * A Comparator that orders Hops by their data type, dimension, and 
sparsity.
         * The order is as follows:
         *              scalars > row vectors > col vectors >
         *      non-vector matrices ordered by sparsity (higher nnz first, 
unknown sparsity last) >
         *      other data types.
         * Disambiguate by Hop ID.
         */
+       //TODO replace by ComparableHop wrapper around hop that implements 
equals and compareTo
+       //in order to ensure comparisons that are 'consistent with equals'
        private static final Comparator<Hop> compareByDataType = new 
Comparator<Hop>() {
                private final int[] orderDataType = new 
int[Expression.DataType.values().length];
                {
@@ -190,17 +194,17 @@ public class RewriteElementwiseMultChainOptimization 
extends HopRewriteRule {
                        case MATRIX:
                                // two matrices; check for vectors
                                if (o1.getDim1() == 1) { // row vector
-                                               if (o2.getDim1() != 1) return 
1; // row vectors are greatest of matrices
-                                               return 
compareBySparsityThenId(o1, o2); // both row vectors
+                                       if (o2.getDim1() != 1) return 1; // row 
vectors are greatest of matrices
+                                       return compareBySparsityThenId(o1, o2); 
// both row vectors
                                } else if (o2.getDim1() == 1) { // 2 is row 
vector; 1 is not
-                                               return -1; // row vectors are 
the greatest matrices
+                                       return -1; // row vectors are the 
greatest matrices
                                } else if (o1.getDim2() == 1) { // col vector
-                                               if (o2.getDim2() != 1) return 
1; // col vectors greater than non-vectors
-                                               return 
compareBySparsityThenId(o1, o2); // both col vectors
+                                       if (o2.getDim2() != 1) return 1; // col 
vectors greater than non-vectors
+                                       return compareBySparsityThenId(o1, o2); 
// both col vectors
                                } else if (o2.getDim2() == 1) { // 2 is col 
vector; 1 is not
-                                               return 1; // col vectors 
greater than non-vectors
+                                       return -1; // col vectors greater than 
non-vectors
                                } else { // both non-vectors
-                                               return 
compareBySparsityThenId(o1, o2);
+                                       return compareBySparsityThenId(o1, o2);
                                }
                        default:
                                return Long.compare(o1.getHopID(), 
o2.getHopID());
@@ -243,7 +247,10 @@ public class RewriteElementwiseMultChainOptimization 
extends HopRewriteRule {
        private static void findEMultsAndLeaves(final BinaryOp root, final 
Set<BinaryOp> emults, final Multiset<Hop> leaves) {
                // Because RewriteCommonSubexpressionElimination already ran, 
it is safe to compare by equality.
                emults.add(root);
-
+               
+               // TODO proper handling of DAGs (avoid collecting the same leaf 
multiple times)
+               // TODO exclude hops with unknown dimensions and move rewrites 
to dynamic rewrites 
+               
                final ArrayList<Hop> inputs = root.getInput();
                final Hop left = inputs.get(0), right = inputs.get(1);
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/a5c834b2/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java
index 91cb4e0..b16fa3e 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java
@@ -50,7 +50,7 @@ public class RewriteElementwiseMultChainOptimizationTest 
extends AutomatedTestBa
                TestUtils.clearAssertionInformation();
                addTestConfiguration( TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
        }
-
+       
        @Test
        public void testMatrixMultChainOptNoRewritesCP() {
                testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP);
@@ -61,6 +61,7 @@ public class RewriteElementwiseMultChainOptimizationTest 
extends AutomatedTestBa
                testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK);
        }
        
+       /* TODO enable together with RewriteElementwiseMultChainOptimization
        @Test
        public void testMatrixMultChainOptRewritesCP() {
                testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP);
@@ -70,6 +71,7 @@ public class RewriteElementwiseMultChainOptimizationTest 
extends AutomatedTestBa
        public void testMatrixMultChainOptRewritesSP() {
                testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK);
        }
+       */
 
        private void testRewriteMatrixMultChainOp(String testname, boolean 
rewrites, ExecType et)
        {       

Reply via email to