[HOTFIX][SYSTEMML-1663] Fix and disable element-wise mult chain rewrite

This patch fixes the custom hop comparator that is used to find an ordering of the operands in element-wise multiplication chains (scalars, vectors, matrices), which resolves the test issue of PR549. Due to additional issues that could cause incorrect results or runtime errors, I'm temporarily disabling this rewrite and the related tests.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/a5c834b2 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/a5c834b2 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/a5c834b2 Branch: refs/heads/master Commit: a5c834b27da9cfeffe0ad6e606c43fe3246831d2 Parents: 9e7ce7b Author: Matthias Boehm <[email protected]> Authored: Wed Jun 21 00:05:32 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Wed Jun 21 00:05:32 2017 -0700 ---------------------------------------------------------------------- .../sysml/hops/rewrite/ProgramRewriter.java | 6 ++--- ...RewriteElementwiseMultChainOptimization.java | 27 ++++++++++++-------- ...iteElementwiseMultChainOptimizationTest.java | 4 ++- 3 files changed, 23 insertions(+), 14 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/a5c834b2/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java index 7ee3ccb..92d31c2 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java @@ -96,8 +96,8 @@ public class ProgramRewriter _dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() ); if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) _dagRuleSet.add( new RewriteCommonSubexpressionElimination() ); - if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES) - _dagRuleSet.add( new RewriteElementwiseMultChainOptimization() ); //dependency: cse + //if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES) + // _dagRuleSet.add( new RewriteElementwiseMultChainOptimization() ); //dependency: cse if( OptimizerUtils.ALLOW_CONSTANT_FOLDING ) _dagRuleSet.add( new 
RewriteConstantFolding() ); //dependency: cse if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION ) @@ -108,7 +108,7 @@ public class ProgramRewriter _dagRuleSet.add( new RewriteIndexingVectorization() ); //dependency: cse, simplifications _dagRuleSet.add( new RewriteInjectSparkPReadCheckpointing() ); //dependency: reblock - //add statment block rewrite rules + //add statement block rewrite rules if( OptimizerUtils.ALLOW_BRANCH_REMOVAL ) _sbRuleSet.add( new RewriteRemoveUnnecessaryBranches() ); //dependency: constant folding if( OptimizerUtils.ALLOW_SPLIT_HOP_DAGS ) http://git-wip-us.apache.org/repos/asf/systemml/blob/a5c834b2/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java index 9ca0932..fe2a5d0 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java @@ -42,7 +42,7 @@ import com.google.common.collect.Multiset; * * Rewrite a chain of element-wise multiply hops that contain identical elements. * For example `(B * A) * B` is rewritten to `A * (B^2)` (or `(B^2) * A`), where `^` is element-wise power. - * The order of the multiplicands depends on their data types, dimentions (matrix or vector), and sparsity. + * The order of the multiplicands depends on their data types, dimensions (matrix or vector), and sparsity. * * Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain, * since foreign parents may rely on the individual results. 
@@ -136,6 +136,8 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { // sorted contains all leaves, sorted by data type, stripped from their parents // Construct right-deep EMult tree + // TODO compile binary outer mult for transition from row and column vectors to matrices + // TODO compile subtree for column vectors to avoid blow-up of intermediates on row-col vector transition final Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator(); Hop first = constructPower(iterator.next()); @@ -160,13 +162,15 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { } /** - * A Comparator that orders Hops by their data type, dimention, and sparsity. + * A Comparator that orders Hops by their data type, dimension, and sparsity. * The order is as follows: * scalars > row vectors > col vectors > * non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) > * other data types. * Disambiguate by Hop ID. */ + //TODO replace by ComparableHop wrapper around hop that implements equals and compareTo + //in order to ensure comparisons that are 'consistent with equals' private static final Comparator<Hop> compareByDataType = new Comparator<Hop>() { private final int[] orderDataType = new int[Expression.DataType.values().length]; { @@ -190,17 +194,17 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { case MATRIX: // two matrices; check for vectors if (o1.getDim1() == 1) { // row vector - if (o2.getDim1() != 1) return 1; // row vectors are greatest of matrices - return compareBySparsityThenId(o1, o2); // both row vectors + if (o2.getDim1() != 1) return 1; // row vectors are greatest of matrices + return compareBySparsityThenId(o1, o2); // both row vectors } else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not - return -1; // row vectors are the greatest matrices + return -1; // row vectors are the greatest matrices } else if (o1.getDim2() == 1) { // col 
vector - if (o2.getDim2() != 1) return 1; // col vectors greater than non-vectors - return compareBySparsityThenId(o1, o2); // both col vectors + if (o2.getDim2() != 1) return 1; // col vectors greater than non-vectors + return compareBySparsityThenId(o1, o2); // both col vectors } else if (o2.getDim2() == 1) { // 2 is col vector; 1 is not - return 1; // col vectors greater than non-vectors + return -1; // col vectors greater than non-vectors } else { // both non-vectors - return compareBySparsityThenId(o1, o2); + return compareBySparsityThenId(o1, o2); } default: return Long.compare(o1.getHopID(), o2.getHopID()); @@ -243,7 +247,10 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, final Multiset<Hop> leaves) { // Because RewriteCommonSubexpressionElimination already ran, it is safe to compare by equality. emults.add(root); - + + // TODO proper handling of DAGs (avoid collecting the same leaf multiple times) + // TODO exclude hops with unknown dimensions and move rewrites to dynamic rewrites + final ArrayList<Hop> inputs = root.getInput(); final Hop left = inputs.get(0), right = inputs.get(1); http://git-wip-us.apache.org/repos/asf/systemml/blob/a5c834b2/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java index 91cb4e0..b16fa3e 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java @@ -50,7 +50,7 @@ 
public class RewriteElementwiseMultChainOptimizationTest extends AutomatedTestBa TestUtils.clearAssertionInformation(); addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) ); } - + @Test public void testMatrixMultChainOptNoRewritesCP() { testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP); @@ -61,6 +61,7 @@ public class RewriteElementwiseMultChainOptimizationTest extends AutomatedTestBa testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK); } + /* TODO enable together with RewriteElementwiseMultChainOptimization @Test public void testMatrixMultChainOptRewritesCP() { testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP); @@ -70,6 +71,7 @@ public class RewriteElementwiseMultChainOptimizationTest extends AutomatedTestBa public void testMatrixMultChainOptRewritesSP() { testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK); } + */ private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et) {
