Repository: systemml Updated Branches: refs/heads/master 85e3a9631 -> eca9dbbb8 (forced update)
[SYSTEMML-1663] Fix and enable rewrite element-wise multiply chains Groups together types of e-wise multiply inputs. Comprehensive test on all different kinds of objects multiplied together. The new order of element-wise multiply chains is as follows: <pre> (((unknown * object * frame) * ([least-nnz-matrix * matrix] * most-nnz-matrix)) * ([least-nnz-row-vector * row-vector] * most-nnz-row-vector)) * ([[scalars * least-nnz-col-vector] * col-vector] * most-nnz-col-vector) </pre> Moves to dynamic rewrites. Do not rewrite if top-level dims unknown. Adds new 'wumm' pattern to pick up element-wise multiply rewrite. The new pattern recognizes when there is a '*2' or '2*' outside 'W*(U%*%t(V))'. Closes #567. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/eca9dbbb Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/eca9dbbb Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/eca9dbbb Branch: refs/heads/master Commit: eca9dbbb85971af688e81c9254538c53fc429b30 Parents: 1b3dff0 Author: Dylan Hutchison <dhutc...@cs.washington.edu> Authored: Fri Jul 14 23:08:46 2017 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Fri Jul 14 23:08:46 2017 -0700 ---------------------------------------------------------------------- .../sysml/hops/rewrite/ProgramRewriter.java | 8 +- .../RewriteAlgebraicSimplificationDynamic.java | 45 ++++- ...RewriteElementwiseMultChainOptimization.java | 180 +++++++++++++------ ...ElementwiseMultChainOptimizationAllTest.java | 134 ++++++++++++++ .../functions/misc/RewriteEMultChainOpAll.R | 37 ++++ .../functions/misc/RewriteEMultChainOpAll.dml | 31 ++++ 6 files changed, 376 insertions(+), 59 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/eca9dbbb/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java index 92d31c2..7c4f861 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java @@ -54,7 +54,7 @@ public class ProgramRewriter private static final Log LOG = LogFactory.getLog(ProgramRewriter.class.getName()); //internal local debug level - private static final boolean LDEBUG = false; + private static final boolean LDEBUG = false; private static final boolean CHECK = false; private ArrayList<HopRewriteRule> _dagRuleSet = null; @@ -96,8 +96,6 @@ public class ProgramRewriter _dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() ); if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) _dagRuleSet.add( new RewriteCommonSubexpressionElimination() ); - //if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES) - // _dagRuleSet.add( new RewriteElementwiseMultChainOptimization() ); //dependency: cse if( OptimizerUtils.ALLOW_CONSTANT_FOLDING ) _dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION ) @@ -125,7 +123,9 @@ public class ProgramRewriter // DYNAMIC REWRITES (which do require size information) if( dynamicRewrites ) { - _dagRuleSet.add( new RewriteMatrixMultChainOptimization() ); //dependency: cse + _dagRuleSet.add( new RewriteMatrixMultChainOptimization() ); //dependency: cse + if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES) + _dagRuleSet.add( new RewriteElementwiseMultChainOptimization() ); //dependency: cse if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION ) { http://git-wip-us.apache.org/repos/asf/systemml/blob/eca9dbbb/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java index 8cd71f4..09b66de 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java @@ -29,11 +29,11 @@ import org.apache.sysml.hops.AggUnaryOp; import org.apache.sysml.hops.BinaryOp; import org.apache.sysml.hops.DataGenOp; import org.apache.sysml.hops.Hop; -import org.apache.sysml.hops.QuaternaryOp; import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.DataGenMethod; import org.apache.sysml.hops.Hop.Direction; import org.apache.sysml.hops.Hop.OpOp1; +import org.apache.sysml.hops.Hop.OpOp2; import org.apache.sysml.hops.Hop.OpOp3; import org.apache.sysml.hops.Hop.OpOp4; import org.apache.sysml.hops.Hop.ParamBuiltinOp; @@ -44,7 +44,7 @@ import org.apache.sysml.hops.LeftIndexingOp; import org.apache.sysml.hops.LiteralOp; import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.hops.ParameterizedBuiltinOp; -import org.apache.sysml.hops.Hop.OpOp2; +import org.apache.sysml.hops.QuaternaryOp; import org.apache.sysml.hops.ReorgOp; import org.apache.sysml.hops.TernaryOp; import org.apache.sysml.hops.UnaryOp; @@ -1959,6 +1959,47 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule appliedPattern = true; LOG.debug("Applied simplifyWeightedUnaryMM1 (line "+hi.getBeginLine()+")"); } + + //Pattern 2.7) (W*(U%*%t(V))*2 or 2*(W*(U%*%t(V)) + if( !appliedPattern + && hi instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)hi).getOp(), OpOp2.MULT) + && (HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0), 2) + || HopRewriteUtils.isLiteralOfValue(hi.getInput().get(1), 2))) + { + final Hop nl; // non-literal + if( hi.getInput().get(0) instanceof LiteralOp ) { + nl = hi.getInput().get(1); + } else { + nl = hi.getInput().get(0); + } + + if ( HopRewriteUtils.isBinary(nl, OpOp2.MULT) + && nl.getParent().size()==1 // ensure no foreign parents + && HopRewriteUtils.isEqualSize(nl.getInput().get(0), nl.getInput().get(1)) //prevent mv + && nl.getDim2() > 1 //not applied for vector-vector mult + && nl.getInput().get(0).getDataType() == DataType.MATRIX + && nl.getInput().get(0).getDim2() > nl.getInput().get(0).getColsInBlock() + && HopRewriteUtils.isOuterProductLikeMM(nl.getInput().get(1)) + && (((AggBinaryOp) nl.getInput().get(1)).checkMapMultChain() == ChainType.NONE || nl.getInput().get(1).getInput().get(1).getDim2() > 1) //no mmchain + && HopRewriteUtils.isSingleBlock(nl.getInput().get(1).getInput().get(0),true) ) + { + final Hop W = nl.getInput().get(0); + final Hop U = nl.getInput().get(1).getInput().get(0); + Hop V = nl.getInput().get(1).getInput().get(1); + if( !HopRewriteUtils.isTransposeOperation(V) ) + V = HopRewriteUtils.createTranspose(V); + else + V = V.getInput().get(0); + + hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, + OpOp4.WUMM, W, U, V, true, null, OpOp2.MULT); + hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock()); + hnew.refreshSizeInformation(); + + appliedPattern = true; + LOG.debug("Applied simplifyWeightedUnaryMM2.7 (line "+hi.getBeginLine()+")"); + } + } //Pattern 2) (W*sop(U%*%t(V),c)) for known sop translating to unary ops if( !appliedPattern http://git-wip-us.apache.org/repos/asf/systemml/blob/eca9dbbb/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java index c2c3b11..2e411f6 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java @@ -26,8 +26,8 @@ import java.util.HashSet; import java.util.Iterator; import java.util.Map; import java.util.Set; -import java.util.SortedMap; -import java.util.TreeMap; +import java.util.SortedSet; +import java.util.TreeSet; import org.apache.sysml.hops.BinaryOp; import org.apache.sysml.hops.Hop; @@ -44,6 +44,15 @@ import org.apache.sysml.parser.Expression; * * Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain, * since foreign parents may rely on the individual results. + * Does not perform rewrites on an element-wise multiply if its dimensions are unknown. + * + * The new order of element-wise multiply chains is as follows: + * <pre> + * (((unknown * object * frame) * ([least-nnz-matrix * matrix] * most-nnz-matrix)) + * * ([least-nnz-row-vector * row-vector] * most-nnz-row-vector)) + * * ([[scalars * least-nnz-col-vector] * col-vector] * most-nnz-col-vector) + * </pre> + * Identical elements are replaced with powers. */ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { @Override @@ -73,15 +82,18 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { return root; root.setVisited(); - // 1. Find immediate subtree of EMults. - if (isBinaryMult(root)) { + // 1. Find immediate subtree of EMults. Check dimsKnown. + if (isBinaryMult(root) && root.dimsKnown()) { final Hop left = root.getInput().get(0), right = root.getInput().get(1); + // The set of BinaryOp element-wise multiply hops in the emult chain. final Set<BinaryOp> emults = new HashSet<>(); + // The multiset of multiplicands in the emult chain. final Map<Hop, Integer> leaves = new HashMap<>(); // poor man's HashMultiset findEMultsAndLeaves((BinaryOp)root, emults, leaves); // 2. Ensure it is profitable to do a rewrite. - if (isOptimizable(emults, leaves)) { + // Only optimize a subtree of emults if there are at least two emults (meaning, at least 3 multiplicands). + if (emults.size() >= 2) { // 3. Check for foreign parents. // A foreign parent is a parent of some EMult that is not in the set. // Foreign parents destroy correctness of this rewrite. @@ -123,38 +135,110 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { private static Hop constructReplacement(final Set<BinaryOp> emults, final Map<Hop, Integer> leaves) { // Sort by data type - final SortedMap<Hop,Integer> sorted = new TreeMap<>(compareByDataType); + final SortedSet<Hop> sorted = new TreeSet<>(compareByDataType); for (final Map.Entry<Hop, Integer> entry : leaves.entrySet()) { final Hop h = entry.getKey(); // unlink parents that are in the emult set(we are throwing them away) // keep other parents h.getParent().removeIf(parent -> parent instanceof BinaryOp && emults.contains(parent)); - sorted.put(h, entry.getValue()); + sorted.add(constructPower(h, entry.getValue())); } // sorted contains all leaves, sorted by data type, stripped from their parents // Construct right-deep EMult tree - // TODO compile binary outer mult for transition from row and column vectors to matrices - // TODO compile subtree for column vectors to avoid blow-up of intermediates on row-col vector transition - final Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator(); - Hop first = constructPower(iterator.next()); - - for (int i = 1; i < sorted.size(); i++) { - final Hop second = constructPower(iterator.next()); - first = HopRewriteUtils.createBinary(second, first, Hop.OpOp2.MULT); - first.setVisited(); + final Iterator<Hop> iterator = sorted.iterator(); + + Hop next = iterator.hasNext() ? iterator.next() : null; + Hop colVectorsScalars = null; + while(next != null && + (next.getDataType() == Expression.DataType.SCALAR + || next.getDataType() == Expression.DataType.MATRIX && next.getDim2() == 1)) + { + if( colVectorsScalars == null ) + colVectorsScalars = next; + else { + colVectorsScalars = HopRewriteUtils.createBinary(next, colVectorsScalars, Hop.OpOp2.MULT); + colVectorsScalars.setVisited(); + } + next = iterator.hasNext() ? iterator.next() : null; + } + // next is not processed and is either null or past col vectors + + Hop rowVectors = null; + while(next != null && + (next.getDataType() == Expression.DataType.MATRIX && next.getDim1() == 1)) + { + if( rowVectors == null ) + rowVectors = next; + else { + rowVectors = HopRewriteUtils.createBinary(rowVectors, next, Hop.OpOp2.MULT); + rowVectors.setVisited(); + } + next = iterator.hasNext() ? iterator.next() : null; + } + // next is not processed and is either null or past row vectors + + Hop matrices = null; + while(next != null && + (next.getDataType() == Expression.DataType.MATRIX)) + { + if( matrices == null ) + matrices = next; + else { + matrices = HopRewriteUtils.createBinary(matrices, next, Hop.OpOp2.MULT); + matrices.setVisited(); + } + next = iterator.hasNext() ? iterator.next() : null; + } + // next is not processed and is either null or past matrices + + Hop other = null; + while(next != null) + { + if( other == null ) + other = next; + else { + other = HopRewriteUtils.createBinary(other, next, Hop.OpOp2.MULT); + other.setVisited(); + } + next = iterator.hasNext() ? iterator.next() : null; + } + // finished + + // ((other * matrices) * rowVectors) * colVectorsScalars + Hop top = null; + if( other == null && matrices != null ) + top = matrices; + else if( other != null && matrices == null ) + top = other; + else if( other != null ) { //matrices != null + top = HopRewriteUtils.createBinary(other, matrices, Hop.OpOp2.MULT); + top.setVisited(); + } + + if( top == null && rowVectors != null ) + top = rowVectors; + else if( rowVectors != null ) { //top != null + top = HopRewriteUtils.createBinary(top, rowVectors, Hop.OpOp2.MULT); + top.setVisited(); + } + + if( top == null && colVectorsScalars != null ) + top = colVectorsScalars; + else if( colVectorsScalars != null ) { //top != null + top = HopRewriteUtils.createBinary(top, colVectorsScalars, Hop.OpOp2.MULT); + top.setVisited(); } - return first; + + return top; } - private static Hop constructPower(final Map.Entry<Hop, Integer> entry) { - final Hop hop = entry.getKey(); - final int cnt = entry.getValue(); + private static Hop constructPower(final Hop hop, final int cnt) { assert(cnt >= 1); hop.setVisited(); // we will visit the leaves' children next if (cnt == 1) return hop; - Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW); + final Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW); pow.setVisited(); return pow; } @@ -162,8 +246,8 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { /** * A Comparator that orders Hops by their data type, dimension, and sparsity. * The order is as follows: - * scalars > row vectors > col vectors > - * non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) > + * scalars < col vectors < row vectors < + * non-vector matrices ordered by sparsity (higher nnz last, unknown sparsity last) > * other data types. * Disambiguate by Hop ID. */ @@ -174,33 +258,33 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { { for (int i = 0, valuesLength = Expression.DataType.values().length; i < valuesLength; i++) switch(Expression.DataType.values()[i]) { - case SCALAR: orderDataType[i] = 4; break; - case MATRIX: orderDataType[i] = 3; break; + case SCALAR: orderDataType[i] = 0; break; + case MATRIX: orderDataType[i] = 1; break; case FRAME: orderDataType[i] = 2; break; - case OBJECT: orderDataType[i] = 1; break; - case UNKNOWN:orderDataType[i] = 0; break; + case OBJECT: orderDataType[i] = 3; break; + case UNKNOWN:orderDataType[i] = 4; break; } } @Override - public final int compare(Hop o1, Hop o2) { - int c = Integer.compare(orderDataType[o1.getDataType().ordinal()], orderDataType[o2.getDataType().ordinal()]); + public final int compare(final Hop o1, final Hop o2) { + final int c = Integer.compare(orderDataType[o1.getDataType().ordinal()], orderDataType[o2.getDataType().ordinal()]); if (c != 0) return c; // o1 and o2 have the same data type switch (o1.getDataType()) { case MATRIX: // two matrices; check for vectors - if (o1.getDim1() == 1) { // row vector - if (o2.getDim1() != 1) return 1; // row vectors are greatest of matrices - return compareBySparsityThenId(o1, o2); // both row vectors - } else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not - return -1; // row vectors are the greatest matrices - } else if (o1.getDim2() == 1) { // col vector - if (o2.getDim2() != 1) return 1; // col vectors greater than non-vectors - return compareBySparsityThenId(o1, o2); // both col vectors + if (o1.getDim2() == 1) { // col vector + if (o2.getDim2() != 1) return -1; // col vectors are greatest of matrices + return compareBySparsityThenId(o1, o2); // both col vectors } else if (o2.getDim2() == 1) { // 2 is col vector; 1 is not - return -1; // col vectors greater than non-vectors + return 1; // col vectors are the greatest matrices + } else if (o1.getDim1() == 1) { // row vector + if (o2.getDim1() != 1) return -1; // row vectors greater than non-vectors + return compareBySparsityThenId(o1, o2); // both row vectors + } else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not + return 1; // row vectors greater than non-vectors } else { // both non-vectors return compareBySparsityThenId(o1, o2); } @@ -208,13 +292,13 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { return Long.compare(o1.getHopID(), o2.getHopID()); } } - private int compareBySparsityThenId(Hop o1, Hop o2) { + private int compareBySparsityThenId(final Hop o1, final Hop o2) { // the hop with more nnz is first; unknown nnz (-1) last - int c = Long.compare(o1.getNnz(), o2.getNnz()); - if (c != 0) return c; + final int c = Long.compare(o1.getNnz(), o2.getNnz()); + if (c != 0) return -c; return Long.compare(o1.getHopID(), o2.getHopID()); } - }.reversed(); + }; /** * Check if a node has a parent that is not in the set of emults. Recursively check children who are also emults. @@ -242,8 +326,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { * @param emults Out parameter. The set of BinaryOp element-wise multiply hops in the emult chain (including root). * @param leaves Out parameter. The multiset of multiplicands in the emult chain. */ - private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, - final Map<Hop, Integer> leaves) { + private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, final Map<Hop, Integer> leaves) { // Because RewriteCommonSubexpressionElimination already ran, it is safe to compare by equality. emults.add(root); @@ -268,13 +351,4 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { map.put(k, map.getOrDefault(k, 0) + 1); } - /** - * Only optimize a subtree of emults if there are at least two emults (meaning, at least 3 multiplicands). - * @param emults The set of BinaryOp element-wise multiply hops in the emult chain. - * @param leaves The multiset of multiplicands in the emult chain. - * @return If the multiset is worth optimizing. - */ - private static boolean isOptimizable(final Set<BinaryOp> emults, final Map<Hop, Integer> leaves) { - return emults.size() >= 2; - } } http://git-wip-us.apache.org/repos/asf/systemml/blob/eca9dbbb/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationAllTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationAllTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationAllTest.java new file mode 100644 index 0000000..ba5c78d --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationAllTest.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.test.integration.functions.misc; + +import java.util.HashMap; + +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.lops.LopProperties.ExecType; +import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.utils.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +/** + * Test rewriting `2*X*3*v*5*w*4*z*5*Y*2*v*2*X`, where `v` and `z` are row vectors and `w` is a column vector, + * successfully rewrites to `Y*(X^2)*z*(v^2)*w*2400`. + */ +public class RewriteElementwiseMultChainOptimizationAllTest extends AutomatedTestBase +{ + private static final String TEST_NAME1 = "RewriteEMultChainOpAll"; + private static final String TEST_DIR = "functions/misc/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteElementwiseMultChainOptimizationAllTest.class.getSimpleName() + "/"; + + private static final int rows = 123; + private static final int cols = 321; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) ); + } + + @Test + public void testMatrixMultChainOptNoRewritesCP() { + testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP); + } + + @Test + public void testMatrixMultChainOptNoRewritesSP() { + testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK); + } + + @Test + public void testMatrixMultChainOptRewritesCP() { + testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP); + } + + @Test + public void testMatrixMultChainOptRewritesSP() { + testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK); + } + + private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et) + { + RUNTIME_PLATFORM platformOld = rtplatform; + switch( et ){ + case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break; + case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break; + default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break; + } + + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + if( rtplatform == RUNTIME_PLATFORM.SPARK ) + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + + boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES; + OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites; + + try + { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), input("v"), input("z"), input("w"), output("R") }; + fullRScriptName = HOME + testname + ".R"; + rCmd = getRCmd(inputDir(), expectedDir()); + + double Xsparsity = 0.8, Ysparsity = 0.6; + double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7); + double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3); + double[][] z = getRandomMatrix(1, cols, -1, 1, Ysparsity, 5); + double[][] v = getRandomMatrix(1, cols, -1, 1, Xsparsity, 4); + double[][] w = getRandomMatrix(rows, 1, -1, 1, Ysparsity, 6); + writeInputMatrixWithMTD("X", X, true); + writeInputMatrixWithMTD("Y", Y, true); + writeInputMatrixWithMTD("z", z, true); + writeInputMatrixWithMTD("v", v, true); + writeInputMatrixWithMTD("w", w, true); + + //execute tests + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R"); + HashMap<CellIndex, Double> rfile = readRMatrixFromFS("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + //check for presence of power operator, if we did a rewrite + if( rewrites ) { + Assert.assertTrue(heavyHittersContainsSubString("^2")); + } + } + finally { + OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld; + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/eca9dbbb/src/test/scripts/functions/misc/RewriteEMultChainOpAll.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpAll.R b/src/test/scripts/functions/misc/RewriteEMultChainOpAll.R new file mode 100644 index 0000000..20f76c2 --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteEMultChainOpAll.R @@ -0,0 +1,37 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +args <- commandArgs(TRUE) +# args[1]="" +options(digits=22) +library("Matrix") +library("matrixStats") + +X = as.matrix(readMM(paste(args[1], "X.mtx", sep=""))) +Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep=""))) +v = as.matrix(readMM(paste(args[1], "v.mtx", sep=""))) +z = as.matrix(readMM(paste(args[1], "z.mtx", sep=""))) +w = as.matrix(readMM(paste(args[1], "w.mtx", sep=""))) + +R = 2* X *3* X *5* Y *4*5*2*2* (matrix(1,length(w),1)%*%z) * (matrix(1,length(w),1)%*%v)^2 * (w%*%matrix(1,1,length(v))) + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/eca9dbbb/src/test/scripts/functions/misc/RewriteEMultChainOpAll.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpAll.dml b/src/test/scripts/functions/misc/RewriteEMultChainOpAll.dml new file mode 100644 index 0000000..90f9242 --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteEMultChainOpAll.dml @@ -0,0 +1,31 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +X = read($1); +Y = read($2); +v = read($3); +z = read($4); +w = read($5); + +R = 2* X *3* v *5* w *4* z *5* Y *2* v *2* X + +write(R, $6); \ No newline at end of file