Repository: systemml Updated Branches: refs/heads/master c43420855 -> 7fec7fa57
[SYSTEMML-1663] New rewrite for chains of element-wise multiply Closes #540. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/7fec7fa5 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/7fec7fa5 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/7fec7fa5 Branch: refs/heads/master Commit: 7fec7fa571c1722b2c56c0498bdfe6bada10624e Parents: c434208 Author: Dylan Hutchison <dhutc...@cs.washington.edu> Authored: Sun Jun 18 21:47:30 2017 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Sun Jun 18 21:47:30 2017 -0700 ---------------------------------------------------------------------- .../java/org/apache/sysml/hops/AggUnaryOp.java | 120 ++++-- .../org/apache/sysml/hops/OptimizerUtils.java | 4 +- .../sysml/hops/rewrite/HopDagValidator.java | 5 +- .../sysml/hops/rewrite/HopRewriteUtils.java | 19 +- .../sysml/hops/rewrite/ProgramRewriter.java | 2 + .../RewriteAlgebraicSimplificationDynamic.java | 13 +- ...RewriteElementwiseMultChainOptimization.java | 270 ++++++++++++ .../java/org/apache/sysml/utils/Explain.java | 2 +- .../ElementwiseAdditionMultiplicationTest.java | 2 +- ...iteElementwiseMultChainOptimizationTest.java | 127 ++++++ .../ternary/ABATernaryAggregateTest.java | 415 +++++++++++++++++++ .../functions/misc/RewriteEMultChainOpXYX.R | 33 ++ .../functions/misc/RewriteEMultChainOpXYX.dml | 28 ++ .../functions/ternary/AAATernaryAggregateC.R | 31 ++ .../functions/ternary/AAATernaryAggregateC.dml | 28 ++ .../functions/ternary/AAATernaryAggregateRC.R | 32 ++ .../functions/ternary/AAATernaryAggregateRC.dml | 29 ++ .../functions/ternary/ABATernaryAggregateC.R | 32 ++ .../functions/ternary/ABATernaryAggregateC.dml | 29 ++ .../functions/ternary/ABATernaryAggregateRC.R | 33 ++ .../functions/ternary/ABATernaryAggregateRC.dml | 30 ++ .../functions/misc/ZPackageSuite.java | 1 + .../functions/ternary/ZPackageSuite.java | 3 +- 23 files changed, 1236 insertions(+), 52 
deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/7fec7fa5/src/main/java/org/apache/sysml/hops/AggUnaryOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java index ee4ded2..8e681c1 100644 --- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java +++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java @@ -490,33 +490,38 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop (_direction == Direction.RowCol || _direction == Direction.Col) ) { Hop input1 = getInput().get(0); - if( input1.getParent().size() == 1 && //sum single consumer - input1 instanceof BinaryOp && ((BinaryOp)input1).getOp()==OpOp2.MULT - // As unary agg instruction is not implemented in MR and since MR is in maintenance mode, postponed it. - && input1.optFindExecType() != ExecType.MR) - { - Hop input11 = input1.getInput().get(0); - Hop input12 = input1.getInput().get(1); - - if( input11 instanceof BinaryOp && ((BinaryOp)input11).getOp()==OpOp2.MULT ) { - //ternary, arbitrary matrices but no mv/outer operations. - ret = HopRewriteUtils.isEqualSize(input11.getInput().get(0), input1) - && HopRewriteUtils.isEqualSize(input11.getInput().get(1), input1) - && HopRewriteUtils.isEqualSize(input12, input1); - } - else if( input12 instanceof BinaryOp && ((BinaryOp)input12).getOp()==OpOp2.MULT ) { - //ternary, arbitrary matrices but no mv/outer operations. 
- ret = HopRewriteUtils.isEqualSize(input12.getInput().get(0), input1) - && HopRewriteUtils.isEqualSize(input12.getInput().get(1), input1) - && HopRewriteUtils.isEqualSize(input11, input1); + if (input1.getParent().size() == 1 + && input1 instanceof BinaryOp) { //sum single consumer + BinaryOp binput1 = (BinaryOp)input1; + + if (binput1.getOp() == OpOp2.POW + && binput1.getInput().get(1) instanceof LiteralOp) { + LiteralOp lit = (LiteralOp)binput1.getInput().get(1); + ret = HopRewriteUtils.getIntValueSafe(lit) == 3; } - else { - //binary, arbitrary matrices but no mv/outer operations. - ret = HopRewriteUtils.isEqualSize(input11, input12); + else if (binput1.getOp() == OpOp2.MULT + // As unary agg instruction is not implemented in MR and since MR is in maintenance mode, postponed it. + && input1.optFindExecType() != ExecType.MR) { + Hop input11 = input1.getInput().get(0); + Hop input12 = input1.getInput().get(1); + + if (input11 instanceof BinaryOp && ((BinaryOp) input11).getOp() == OpOp2.MULT) { + //ternary, arbitrary matrices but no mv/outer operations. + ret = HopRewriteUtils.isEqualSize(input11.getInput().get(0), input1) && HopRewriteUtils + .isEqualSize(input11.getInput().get(1), input1) && HopRewriteUtils + .isEqualSize(input12, input1); + } else if (input12 instanceof BinaryOp && ((BinaryOp) input12).getOp() == OpOp2.MULT) { + //ternary, arbitrary matrices but no mv/outer operations. + ret = HopRewriteUtils.isEqualSize(input12.getInput().get(0), input1) && HopRewriteUtils + .isEqualSize(input12.getInput().get(1), input1) && HopRewriteUtils + .isEqualSize(input11, input1); + } else { + //binary, arbitrary matrices but no mv/outer operations. 
+ ret = HopRewriteUtils.isEqualSize(input11, input12); + } } } } - return ret; } @@ -627,28 +632,63 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop private Lop constructLopsTernaryAggregateRewrite(ExecType et) throws HopsException, LopsException { - Hop input1 = getInput().get(0); + BinaryOp input1 = (BinaryOp)getInput().get(0); Hop input11 = input1.getInput().get(0); Hop input12 = input1.getInput().get(1); - Lop in1 = null; - Lop in2 = null; - Lop in3 = null; - - if( input11 instanceof BinaryOp && ((BinaryOp)input11).getOp()==OpOp2.MULT ) - { - in1 = input11.getInput().get(0).constructLops(); - in2 = input11.getInput().get(1).constructLops(); - in3 = input12.constructLops(); - } - else if( input12 instanceof BinaryOp && ((BinaryOp)input12).getOp()==OpOp2.MULT ) - { + Lop in1 = null, in2 = null, in3 = null; + boolean handled = false; + + if (input1.getOp() == OpOp2.POW) { + assert(HopRewriteUtils.isLiteralOfValue(input12, 3)) : "this case can only occur with a power of 3"; in1 = input11.constructLops(); - in2 = input12.getInput().get(0).constructLops(); - in3 = input12.getInput().get(1).constructLops(); + in2 = in1; + in3 = in1; + handled = true; + } else if (input11 instanceof BinaryOp ) { + BinaryOp b11 = (BinaryOp)input11; + switch( b11.getOp() ) { + case MULT: // A*B*C case + in1 = input11.getInput().get(0).constructLops(); + in2 = input11.getInput().get(1).constructLops(); + in3 = input12.constructLops(); + handled = true; + break; + case POW: // A*A*B case + Hop b112 = b11.getInput().get(1); + if ( !(input12 instanceof BinaryOp && ((BinaryOp)input12).getOp()==OpOp2.MULT) + && HopRewriteUtils.isLiteralOfValue(b112, 2) ) { + in1 = b11.getInput().get(0).constructLops(); + in2 = in1; + in3 = input12.constructLops(); + handled = true; + } + break; + default: break; + } + } else if( input12 instanceof BinaryOp ) { + BinaryOp b12 = (BinaryOp)input12; + switch (b12.getOp()) { + case MULT: // A*B*C case + in1 = input11.constructLops(); + in2 = 
input12.getInput().get(0).constructLops(); + in3 = input12.getInput().get(1).constructLops(); + handled = true; + break; + case POW: // A*B*B case + Hop b112 = b12.getInput().get(1); + if ( HopRewriteUtils.isLiteralOfValue(b112, 2) ) { + in1 = b12.getInput().get(0).constructLops(); + in2 = in1; + in3 = input11.constructLops(); + handled = true; + } + break; + default: break; + } } - else - { + + if (!handled) { in1 = input11.constructLops(); in2 = input12.constructLops(); in3 = new LiteralOp(1).constructLops(); http://git-wip-us.apache.org/repos/asf/systemml/blob/7fec7fa5/src/main/java/org/apache/sysml/hops/OptimizerUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java index a40e36c..79b7ee6 100644 --- a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java +++ b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java @@ -110,8 +110,8 @@ public class OptimizerUtils */ public static boolean ALLOW_CONSTANT_FOLDING = true; - public static boolean ALLOW_ALGEBRAIC_SIMPLIFICATION = true; - public static boolean ALLOW_OPERATOR_FUSION = true; + public static boolean ALLOW_ALGEBRAIC_SIMPLIFICATION = true; + public static boolean ALLOW_OPERATOR_FUSION = true; /** * Enables if-else branch removal for constant predicates (original literals or http://git-wip-us.apache.org/repos/asf/systemml/blob/7fec7fa5/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java b/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java index 8cb5e1e..9ac21fc 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java @@ -35,6 +35,8 @@ import org.apache.sysml.parser.Expression; import 
org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.utils.Explain; +import com.google.common.collect.Lists; + import static org.apache.sysml.hops.HopsException.check; /** @@ -89,7 +91,8 @@ public class HopDagValidator { //check visit status final boolean seen = !state.seen.add(id); check(seen == hop.isVisited(), hop, - "seen previously is %b but does not match hop visit status", seen); + "(parents: %s) seen previously is %b but does not match hop visit status", + Lists.transform(hop.getParent(), Hop::getHopID), seen); if (seen) return; // we saw the Hop previously, no need to re-validate //check parent linking http://git-wip-us.apache.org/repos/asf/systemml/blob/7fec7fa5/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index cf6081b..b98901a 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -241,11 +241,20 @@ public class HopRewriteUtils parent.getInput().add( pos, child ); child.getParent().add( parent ); } - - public static void rewireAllParentChildReferences( Hop hold, Hop hnew ) { - ArrayList<Hop> parents = new ArrayList<Hop>(hold.getParent()); - for( Hop lparent : parents ) - HopRewriteUtils.replaceChildReference(lparent, hold, hnew); + + /** + * Replace an old Hop with a replacement Hop. + * If the old Hop has no parents, then return the replacement. + * Otherwise rewire each of the Hop's parents into the replacement and return the replacement. 
+ * @param hold To be replaced + * @param hnew The replacement + * @return hnew (if hold and hnew are the same hop, nothing is rewired and hnew is returned as-is) + */ + public static Hop rewireAllParentChildReferences( Hop hold, Hop hnew ) { + ArrayList<Hop> parents = hold.getParent(); // live list: each replaceChildReference call is expected to shrink it + while (hold != hnew && !parents.isEmpty()) // self-replacement would move each link back onto the same list and never terminate + HopRewriteUtils.replaceChildReference(parents.get(0), hold, hnew); + return hnew; } public static void replaceChildReference( Hop parent, Hop inOld, Hop inNew ) { http://git-wip-us.apache.org/repos/asf/systemml/blob/7fec7fa5/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java index 0e65f3f..7ee3ccb 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java @@ -96,6 +96,8 @@ public class ProgramRewriter _dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() ); if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) _dagRuleSet.add( new RewriteCommonSubexpressionElimination() ); + if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES) + _dagRuleSet.add( new RewriteElementwiseMultChainOptimization() ); //dependency: cse if( OptimizerUtils.ALLOW_CONSTANT_FOLDING ) _dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION ) http://git-wip-us.apache.org/repos/asf/systemml/blob/7fec7fa5/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java index ad80c05..91c5972 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java +++
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java @@ -53,6 +53,8 @@ import org.apache.sysml.parser.DataExpression; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.parser.Expression.ValueType; +import static org.apache.sysml.hops.OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES; + /** * Rule: Algebraic Simplifications. Simplifies binary expressions * in terms of two major purposes: (1) rewrite binary operations @@ -2050,7 +2052,16 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule else if( HopRewriteUtils.isBinary(hi2, OpOp2.MULT, 1) //no other consumer than sum && hi2.getInput().get(0).getDim2()==1 && hi2.getInput().get(1).getDim2()==1 && !HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.MULT) - && !HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.MULT) ) + && !HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.MULT) + && ( !ALLOW_SUM_PRODUCT_REWRITES + || !( HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.POW) // do not rewrite (A^2)*B + && hi2.getInput().get(0).getInput().get(1) instanceof LiteralOp // let tak+* handle it + && ((LiteralOp)hi2.getInput().get(0).getInput().get(1)).getLongValue() == 2 )) + && ( !ALLOW_SUM_PRODUCT_REWRITES + || !( HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.POW) // do not rewrite B*(A^2) + && hi2.getInput().get(1).getInput().get(1) instanceof LiteralOp // let tak+* handle it + && ((LiteralOp)hi2.getInput().get(1).getInput().get(1)).getLongValue() == 2 )) + ) { baLeft = hi2.getInput().get(0); baRight = hi2.getInput().get(1); http://git-wip-us.apache.org/repos/asf/systemml/blob/7fec7fa5/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java new file mode 100644 index 0000000..9ca0932 --- /dev/null +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.hops.rewrite; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Map; +import java.util.Set; +import java.util.SortedMap; +import java.util.TreeMap; + +import org.apache.sysml.hops.BinaryOp; +import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.HopsException; +import org.apache.sysml.hops.LiteralOp; +import org.apache.sysml.parser.Expression; + +import com.google.common.collect.HashMultiset; +import com.google.common.collect.Multiset; + +/** + * Prerequisite: RewriteCommonSubexpressionElimination must run before this rule. + * + * Rewrite a chain of element-wise multiply hops that contain identical elements. + * For example `(B * A) * B` is rewritten to `A * (B^2)` (or `(B^2) * A`), where `^` is element-wise power.
+ * The order of the multiplicands depends on their data types, dimensions (matrix or vector), and sparsity. + * + * Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain, + * since foreign parents may rely on the individual results. + */ +public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { + @Override + public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException { + if( roots == null ) + return null; + for( int i=0; i<roots.size(); i++ ) { + Hop h = roots.get(i); + roots.set(i, rule_RewriteEMult(h)); + } + return roots; + } + + @Override + public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) throws HopsException { + if( root == null ) + return null; + return rule_RewriteEMult(root); + } + + private static boolean isBinaryMult(final Hop hop) { + return hop instanceof BinaryOp && ((BinaryOp)hop).getOp() == Hop.OpOp2.MULT; + } + + private static Hop rule_RewriteEMult(final Hop root) { + if (root.isVisited()) + return root; + root.setVisited(); + + // 1. Find immediate subtree of EMults. + if (isBinaryMult(root)) { + final Hop left = root.getInput().get(0), right = root.getInput().get(1); + final Set<BinaryOp> emults = new HashSet<>(); + final Multiset<Hop> leaves = HashMultiset.create(); + findEMultsAndLeaves((BinaryOp)root, emults, leaves); + + // 2. Ensure it is profitable to do a rewrite. + if (isOptimizable(emults, leaves)) { + // 3. Check for foreign parents. + // A foreign parent is a parent of some EMult that is not in the set. + // Foreign parents destroy correctness of this rewrite. + final boolean okay = (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) && + (!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right)); + if (okay) { + // 4.
Construct replacement EMults for the leaves + final Hop replacement = constructReplacement(emults, leaves); + if (LOG.isDebugEnabled()) + LOG.debug(String.format( + "Element-wise multiply chain rewrite of %d e-mults at sub-dag %d to new sub-dag %d", + emults.size(), root.getHopID(), replacement.getHopID())); + + // 5. Replace root with replacement + final Hop newRoot = HopRewriteUtils.rewireAllParentChildReferences(root, replacement); + + // 6. Recurse at leaves (no need to repeat the interior emults) + for (final Hop leaf : leaves.elementSet()) { + recurseInputs(leaf); + } + return newRoot; + } + } + } + // This rewrite is not applicable to the current root. + // Try the root's children. + recurseInputs(root); + return root; + } + + private static void recurseInputs(final Hop parent) { + final ArrayList<Hop> inputs = parent.getInput(); + for (int i = 0; i < inputs.size(); i++) { + final Hop input = inputs.get(i); + final Hop newInput = rule_RewriteEMult(input); + inputs.set(i, newInput); + } + } + + private static Hop constructReplacement(final Set<BinaryOp> emults, final Multiset<Hop> leaves) { + // Sort by data type + final SortedMap<Hop,Integer> sorted = new TreeMap<>(compareByDataType); + for (final Multiset.Entry<Hop> entry : leaves.entrySet()) { + final Hop h = entry.getElement(); + // unlink parents that are in the emult set (we are throwing them away) + // keep other parents + h.getParent().removeIf(parent -> parent instanceof BinaryOp && emults.contains(parent)); + sorted.put(h, entry.getCount()); + } + // sorted contains all leaves, sorted by data type, stripped from their parents + + // Construct right-deep EMult tree + final Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator(); + Hop first = constructPower(iterator.next()); + + for (int i = 1; i < sorted.size(); i++) { + final Hop second = constructPower(iterator.next()); + first = HopRewriteUtils.createBinary(second, first, Hop.OpOp2.MULT); + first.setVisited(); + } + return first;
+ } + + private static Hop constructPower(final Map.Entry<Hop, Integer> entry) { + final Hop hop = entry.getKey(); + final int cnt = entry.getValue(); + assert(cnt >= 1); + hop.setVisited(); // we will visit the leaves' children next + if (cnt == 1) + return hop; + Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW); + pow.setVisited(); + return pow; + } + + /** + * A Comparator that orders Hops by their data type, dimension, and sparsity. + * The order is as follows: + * scalars > row vectors > col vectors > + * non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) > + * other data types. + * Disambiguate by Hop ID. + */ + private static final Comparator<Hop> compareByDataType = new Comparator<Hop>() { + private final int[] orderDataType = new int[Expression.DataType.values().length]; + { + for (int i = 0, valuesLength = Expression.DataType.values().length; i < valuesLength; i++) + switch(Expression.DataType.values()[i]) { + case SCALAR: orderDataType[i] = 4; break; + case MATRIX: orderDataType[i] = 3; break; + case FRAME: orderDataType[i] = 2; break; + case OBJECT: orderDataType[i] = 1; break; + case UNKNOWN:orderDataType[i] = 0; break; + } + } + + @Override + public final int compare(Hop o1, Hop o2) { + int c = Integer.compare(orderDataType[o1.getDataType().ordinal()], orderDataType[o2.getDataType().ordinal()]); + if (c != 0) return c; + + // o1 and o2 have the same data type + switch (o1.getDataType()) { + case MATRIX: + // two matrices; check for vectors + if (o1.getDim1() == 1) { // row vector + if (o2.getDim1() != 1) return 1; // row vectors are greatest of matrices + return compareBySparsityThenId(o1, o2); // both row vectors + } else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not + return -1; // row vectors are the greatest matrices + } else if (o1.getDim2() == 1) { // col vector + if (o2.getDim2() != 1) return 1; // col vectors greater than non-vectors + return compareBySparsityThenId(o1, o2);
// both col vectors + } else if (o2.getDim2() == 1) { // 2 is col vector; 1 is not + return -1; // col vectors greater than non-vectors (mirrors the o1-col-vector case; keeps compare antisymmetric) + } else { // both non-vectors + return compareBySparsityThenId(o1, o2); + } + default: + return Long.compare(o1.getHopID(), o2.getHopID()); + } + } + private int compareBySparsityThenId(Hop o1, Hop o2) { + // the hop with more nnz is first; unknown nnz (-1) last + int c = Long.compare(o1.getNnz(), o2.getNnz()); + if (c != 0) return c; + return Long.compare(o1.getHopID(), o2.getHopID()); + } + }.reversed(); + + /** + * Check that no interior emult has a parent outside the set of emults (a foreign parent). Recursively check children who are also emults. + * @param emults The set of BinaryOp element-wise multiply hops in the emult chain. + * @param child An interior emult hop in the emult chain dag. + * @return True if neither this interior emult nor any descendant emult has a foreign parent (the rewrite is safe); false otherwise. + */ + private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) { + final ArrayList<Hop> parents = child.getParent(); + if (parents.size() > 1) + for (final Hop parent : parents) + if (!emults.contains(parent)) // any parent outside the chain is foreign, whatever its hop type + return false; + // child does not have foreign parents + + final ArrayList<Hop> inputs = child.getInput(); + final Hop left = inputs.get(0), right = inputs.get(1); + return (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) && + (!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right)); + } + + /** + * Collect all BinaryOp MULTs in the immediate subtree, starting with root, recursively, together with the multiset of their non-MULT inputs. + * @param root Root of sub-dag + * @param emults Out parameter. The set of BinaryOp element-wise multiply hops in the emult chain (including root). + * @param leaves Out parameter. The multiset of multiplicands in the emult chain.
+ */ + private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, final Multiset<Hop> leaves) { + // Because RewriteCommonSubexpressionElimination already ran, it is safe to compare by equality. + emults.add(root); + + final ArrayList<Hop> inputs = root.getInput(); + final Hop left = inputs.get(0), right = inputs.get(1); + + if (isBinaryMult(left)) + findEMultsAndLeaves((BinaryOp) left, emults, leaves); + else + leaves.add(left); + + if (isBinaryMult(right)) + findEMultsAndLeaves((BinaryOp) right, emults, leaves); + else + leaves.add(right); + } + + /** + * Only optimize a subtree of emults if there are at least two emults (meaning, at least 3 multiplicands). + * @param emults The set of BinaryOp element-wise multiply hops in the emult chain. + * @param leaves The multiset of multiplicands in the emult chain. + * @return If the multiset is worth optimizing. + */ + private static boolean isOptimizable(Set<BinaryOp> emults, final Multiset<Hop> leaves) { + return emults.size() >= 2; + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/7fec7fa5/src/main/java/org/apache/sysml/utils/Explain.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/utils/Explain.java b/src/main/java/org/apache/sysml/utils/Explain.java index 5cf0548..6451396 100644 --- a/src/main/java/org/apache/sysml/utils/Explain.java +++ b/src/main/java/org/apache/sysml/utils/Explain.java @@ -566,7 +566,7 @@ public class Explain childs.append(" ("); boolean childAdded = false; for( Hop input : hop.getInput() ) - if( !(input instanceof LiteralOp) ){ + if( SHOW_LITERAL_HOPS || !(input instanceof LiteralOp) ){ childs.append(childAdded?",":""); childs.append(input.getHopID()); childAdded = true; http://git-wip-us.apache.org/repos/asf/systemml/blob/7fec7fa5/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/ElementwiseAdditionMultiplicationTest.java
---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/ElementwiseAdditionMultiplicationTest.java b/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/ElementwiseAdditionMultiplicationTest.java index 523a648..f78e598 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/ElementwiseAdditionMultiplicationTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/ElementwiseAdditionMultiplicationTest.java @@ -134,6 +134,6 @@ public class ElementwiseAdditionMultiplicationTest extends AutomatedTestBase runTest(); - compareResults(); + compareResults(1e-10); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/7fec7fa5/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java new file mode 100644 index 0000000..91cb4e0 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
 You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.test.integration.functions.misc; + +import java.util.HashMap; + +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.lops.LopProperties.ExecType; +import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.utils.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +/** + * Test whether `2*X*3*Y*4*X` successfully rewrites to `Y*(X^2)*24`. Each case runs the DML script with the sum-product rewrites enabled or disabled, in CP or Spark, and compares the result cell-wise against the R script.
+ */ +public class RewriteElementwiseMultChainOptimizationTest extends AutomatedTestBase +{ + private static final String TEST_NAME1 = "RewriteEMultChainOpXYX"; + private static final String TEST_DIR = "functions/misc/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteElementwiseMultChainOptimizationTest.class.getSimpleName() + "/"; + + private static final int rows = 123; + private static final int cols = 321; + private static final double eps = Math.pow(10, -10); // tolerance for the cell-wise DML-vs-R comparison + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) ); + } + + @Test // NOTE(review): despite the 'MatrixMultChain' names, these cases exercise the element-wise multiply chain rewrite (TEST_NAME1) + public void testMatrixMultChainOptNoRewritesCP() { + testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP); + } + + @Test + public void testMatrixMultChainOptNoRewritesSP() { + testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK); + } + + @Test + public void testMatrixMultChainOptRewritesCP() { + testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP); + } + + @Test + public void testMatrixMultChainOptRewritesSP() { + testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK); + } + + private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et) + { // runs testname with the given rewrite setting and exec type, then compares against R; global settings are restored in finally + RUNTIME_PLATFORM platformOld = rtplatform; + switch( et ){ // map the exec type to a runtime platform + case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break; + case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break; + default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break; + } + + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + if( rtplatform == RUNTIME_PLATFORM.SPARK ) + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + + boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES; + OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites; // this flag gates the element-wise mult chain rewrite under test + + try + { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME +
 testname + ".dml"; + programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), output("R") }; + fullRScriptName = HOME + testname + ".R"; + rCmd = getRCmd(inputDir(), expectedDir()); + + double Xsparsity = 0.8, Ysparsity = 0.6; + double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7); + double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3); + writeInputMatrixWithMTD("X", X, true); + writeInputMatrixWithMTD("Y", Y, true); + + //execute tests + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R"); + HashMap<CellIndex, Double> rfile = readRMatrixFromFS("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + //check for presence of power operator, if we did a rewrite + if( rewrites ) { + Assert.assertTrue(heavyHittersContainsSubString("^2")); // expects a ^2 (element-wise power) instruction among the heavy hitters + } + } + finally { + OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld; + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/7fec7fa5/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java new file mode 100644 index 0000000..12525c9 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java @@ -0,0 +1,415 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.ternary;
+
+import java.util.HashMap;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.hops.rewrite.RewriteElementwiseMultChainOptimization;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.instructions.Instruction;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.apache.sysml.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Similar to {@link TernaryAggregateTest} except that it tests {@code sum/colSums(A*B*A)} with {@code B=A*2}
+ * and {@code sum/colSums(A*A*A)}; checks compatibility with {@link RewriteElementwiseMultChainOptimization}.
+ */
+public class ABATernaryAggregateTest extends AutomatedTestBase
+{
+	//scripts: sum(A*B*A), colSums(A*B*A), sum(A*A*A), colSums(A*A*A)
+	private final static String TEST_NAME1 = "ABATernaryAggregateRC";
+	private final static String TEST_NAME2 = "ABATernaryAggregateC";
+	private final static String TEST_NAME3 = "AAATernaryAggregateRC";
+	private final static String TEST_NAME4 = "AAATernaryAggregateC";
+
+	private final static String TEST_DIR = "functions/ternary/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + ABATernaryAggregateTest.class.getSimpleName() + "/";
+	private final static double eps = 1e-8;
+
+	private final static int rows = 111;
+	private final static int cols = 101;
+
+	private final static double sparsity1 = 0.7;
+	private final static double sparsity2 = 0.3;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+		addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
+		addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) );
+		addTestConfiguration(TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) );
+	}
+
+	@Test
+	public void testTernaryAggregateRCDenseVectorCP() {
+		runTernaryAggregateTest(TEST_NAME1, false, true, true, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateRCSparseVectorCP() {
+		runTernaryAggregateTest(TEST_NAME1, true, true, true, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateRCDenseMatrixCP() {
+		runTernaryAggregateTest(TEST_NAME1, false, false, true, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateRCSparseMatrixCP() {
+		runTernaryAggregateTest(TEST_NAME1, true, false, true, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateRCDenseVectorSP() {
+		runTernaryAggregateTest(TEST_NAME1, false, true, true, ExecType.SPARK);
+	}
+
+	@Test
+	public void testTernaryAggregateRCSparseVectorSP() {
+		runTernaryAggregateTest(TEST_NAME1, true, true, true, ExecType.SPARK);
+	}
+
+	@Test
+	public void testTernaryAggregateRCDenseMatrixSP() {
+		runTernaryAggregateTest(TEST_NAME1, false, false, true, ExecType.SPARK);
+	}
+
+	@Test
+	public void testTernaryAggregateRCSparseMatrixSP() {
+		runTernaryAggregateTest(TEST_NAME1, true, false, true, ExecType.SPARK);
+	}
+
+	@Test
+	public void testTernaryAggregateRCDenseVectorMR() {
+		runTernaryAggregateTest(TEST_NAME1, false, true, true, ExecType.MR);
+	}
+
+	@Test
+	public void testTernaryAggregateRCSparseVectorMR() {
+		runTernaryAggregateTest(TEST_NAME1, true, true, true, ExecType.MR);
+	}
+
+	@Test
+	public void testTernaryAggregateRCDenseMatrixMR() {
+		runTernaryAggregateTest(TEST_NAME1, false, false, true, ExecType.MR);
+	}
+
+	@Test
+	public void testTernaryAggregateRCSparseMatrixMR() {
+		runTernaryAggregateTest(TEST_NAME1, true, false, true, ExecType.MR);
+	}
+
+	@Test
+	public void testTernaryAggregateCDenseVectorCP() {
+		runTernaryAggregateTest(TEST_NAME2, false, true, true, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateCSparseVectorCP() {
+		runTernaryAggregateTest(TEST_NAME2, true, true, true, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateCDenseMatrixCP() {
+		runTernaryAggregateTest(TEST_NAME2, false, false, true, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateCSparseMatrixCP() {
+		runTernaryAggregateTest(TEST_NAME2, true, false, true, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateCDenseVectorSP() {
+		runTernaryAggregateTest(TEST_NAME2, false, true, true, ExecType.SPARK);
+	}
+
+	@Test
+	public void testTernaryAggregateCSparseVectorSP() {
+		runTernaryAggregateTest(TEST_NAME2, true, true, true, ExecType.SPARK);
+	}
+
+	@Test
+	public void testTernaryAggregateCDenseMatrixSP() {
+		runTernaryAggregateTest(TEST_NAME2, false, false, true, ExecType.SPARK);
+	}
+
+	@Test
+	public void testTernaryAggregateCSparseMatrixSP() {
+		runTernaryAggregateTest(TEST_NAME2, true, false, true, ExecType.SPARK);
+	}
+
+	//additional tests to check default without rewrites
+
+	@Test
+	public void testTernaryAggregateRCDenseVectorCPNoRewrite() {
+		runTernaryAggregateTest(TEST_NAME1, false, true, false, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateRCSparseVectorCPNoRewrite() {
+		runTernaryAggregateTest(TEST_NAME1, true, true, false, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateRCDenseMatrixCPNoRewrite() {
+		runTernaryAggregateTest(TEST_NAME1, false, false, false, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateRCSparseMatrixCPNoRewrite() {
+		runTernaryAggregateTest(TEST_NAME1, true, false, false, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateCDenseVectorCPNoRewrite() {
+		runTernaryAggregateTest(TEST_NAME2, false, true, false, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateCSparseVectorCPNoRewrite() {
+		runTernaryAggregateTest(TEST_NAME2, true, true, false, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateCDenseMatrixCPNoRewrite() {
+		runTernaryAggregateTest(TEST_NAME2, false, false, false, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateCSparseMatrixCPNoRewrite() {
+		runTernaryAggregateTest(TEST_NAME2, true, false, false, ExecType.CP);
+	}
+
+
+	// another set of tests for the cases of sum(A*A*A) and colSums(A*A*A)
+
+	@Test
+	public void testTernaryAggregateRCDenseVectorCP_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, false, true, true, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateRCSparseVectorCP_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, true, true, true, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateRCDenseMatrixCP_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, false, false, true, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateRCSparseMatrixCP_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, true, false, true, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateRCDenseVectorSP_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, false, true, true, ExecType.SPARK);
+	}
+
+	@Test
+	public void testTernaryAggregateRCSparseVectorSP_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, true, true, true, ExecType.SPARK);
+	}
+
+	@Test
+	public void testTernaryAggregateRCDenseMatrixSP_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, false, false, true, ExecType.SPARK);
+	}
+
+	@Test
+	public void testTernaryAggregateRCSparseMatrixSP_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, true, false, true, ExecType.SPARK);
+	}
+
+	@Test
+	public void testTernaryAggregateRCDenseVectorMR_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, false, true, true, ExecType.MR);
+	}
+
+	@Test
+	public void testTernaryAggregateRCSparseVectorMR_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, true, true, true, ExecType.MR);
+	}
+
+	@Test
+	public void testTernaryAggregateRCDenseMatrixMR_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, false, false, true, ExecType.MR);
+	}
+
+	@Test
+	public void testTernaryAggregateRCSparseMatrixMR_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, true, false, true, ExecType.MR);
+	}
+
+	@Test
+	public void testTernaryAggregateCDenseVectorCP_AAA() {
+		runTernaryAggregateTest(TEST_NAME4, false, true, true, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateCSparseVectorCP_AAA() {
+		runTernaryAggregateTest(TEST_NAME4, true, true, true, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateCDenseMatrixCP_AAA() {
+		runTernaryAggregateTest(TEST_NAME4, false, false, true, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateCSparseMatrixCP_AAA() {
+		runTernaryAggregateTest(TEST_NAME4, true, false, true, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateCDenseVectorSP_AAA() {
+		runTernaryAggregateTest(TEST_NAME4, false, true, true, ExecType.SPARK);
+	}
+
+	@Test
+	public void testTernaryAggregateCSparseVectorSP_AAA() {
+		runTernaryAggregateTest(TEST_NAME4, true, true, true, ExecType.SPARK);
+	}
+
+	@Test
+	public void testTernaryAggregateCDenseMatrixSP_AAA() {
+		runTernaryAggregateTest(TEST_NAME4, false, false, true, ExecType.SPARK);
+	}
+
+	@Test
+	public void testTernaryAggregateCSparseMatrixSP_AAA() {
+		runTernaryAggregateTest(TEST_NAME4, true, false, true, ExecType.SPARK);
+	}
+
+	//additional tests to check default without rewrites
+
+	@Test
+	public void testTernaryAggregateRCDenseVectorCPNoRewrite_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, false, true, false, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateRCSparseVectorCPNoRewrite_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, true, true, false, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateRCDenseMatrixCPNoRewrite_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, false, false, false, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateRCSparseMatrixCPNoRewrite_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, true, false, false, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateCDenseVectorCPNoRewrite_AAA() {
+		runTernaryAggregateTest(TEST_NAME4, false, true, false, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateCSparseVectorCPNoRewrite_AAA() {
+		runTernaryAggregateTest(TEST_NAME4, true, true, false, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateCDenseMatrixCPNoRewrite_AAA() {
+		runTernaryAggregateTest(TEST_NAME4, false, false, false, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateCSparseMatrixCPNoRewrite_AAA() {
+		runTernaryAggregateTest(TEST_NAME4, true, false, false, ExecType.CP);
+	}
+
+	/**
+	 * Runs one DML script and its R reference on a random input A and compares the results.
+	 * When rewrites are enabled and the execution type is not MR, it also checks that the
+	 * expected fused ternary-aggregate opcode appears among the heavy hitters.
+	 *
+	 * @param testname one of TEST_NAME1..TEST_NAME4 (script base name)
+	 * @param sparse use sparsity2 instead of sparsity1 for the input
+	 * @param vectors generate a (rows*cols) x 1 column vector instead of a rows x cols matrix
+	 * @param rewrites value of OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES during the run
+	 * @param et execution type, mapped to a runtime platform
+	 */
+	private void runTernaryAggregateTest(String testname, boolean sparse, boolean vectors, boolean rewrites, ExecType et)
+	{
+		//map the execution type to a runtime platform (CP tests run in HYBRID mode)
+		RUNTIME_PLATFORM platformOld = rtplatform;
+		switch( et ){
+			case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+			case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+			default: rtplatform = RUNTIME_PLATFORM.HYBRID; break;
+		}
+
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		if( rtplatform == RUNTIME_PLATFORM.SPARK )
+			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+		boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
+
+		try {
+			TestConfiguration config = getTestConfiguration(testname);
+			loadTestConfiguration(config);
+
+			OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
+
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + testname + ".dml";
+			programArgs = new String[]{"-explain","-stats","-args", input("A"), output("R")};
+
+			//build the R command with the shared helper, like the other integration tests
+			fullRScriptName = HOME + testname + ".R";
+			rCmd = getRCmd(inputDir(), expectedDir());
+
+			//generate actual dataset
+			double sparsity = sparse ? sparsity2 : sparsity1;
+			double[][] A = getRandomMatrix(vectors ? rows*cols : rows,
+				vectors ? 1 : cols, 0, 1, sparsity, 17);
+			writeInputMatrixWithMTD("A", A, true);
+
+			//run test cases
+			runTest(true, false, null, -1);
+			runRScript(true);
+
+			//compare output matrices
+			HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R");
+			HashMap<CellIndex, Double> rfile = readRMatrixFromFS("R");
+			TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+
+			//check for rewritten patterns in statistics output (skipped for MR)
+			if( rewrites && et != ExecType.MR ) {
+				//expected opcode: tak+* for the RC scripts or vector inputs, tack+* otherwise
+				boolean fullAgg = testname.equals(TEST_NAME1) || testname.equals(TEST_NAME3) || vectors;
+				String opcode = ((et == ExecType.SPARK) ? Instruction.SP_INST_PREFIX : "")
+					+ (fullAgg ? "tak+*" : "tack+*");
+				Assert.assertTrue("expected opcode " + opcode + " in heavy hitters",
+					Statistics.getCPHeavyHitterOpCodes().contains(opcode));
+			}
+		}
+		finally {
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+			OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
+		}
+	}
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/7fec7fa5/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
new file mode 100644
index 0000000..fec61ae
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# R reference for RewriteEMultChainOpXYX.dml: element-wise chain 2*X*3*Y*4*X.
+# args[1]: directory containing X.mtx and Y.mtx; args[2]: output directory for R.
+# Only the Matrix package is needed (readMM/writeMM, CsparseMatrix coercion).
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+X = as.matrix(readMM(paste(args[1], "X.mtx", sep="")))
+Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep="")))
+
+R = 2 * X * 3 * Y * 4 * X;
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""));
http://git-wip-us.apache.org/repos/asf/systemml/blob/7fec7fa5/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
new file mode 100644
index 0000000..88f252f
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# Element-wise multiply chain with scalar factors interleaved between X, Y and X.
+# The exact expression form is the input of the rewrite under test: with sum-product
+# rewrites enabled, RewriteElementwiseMultChainOptimizationTest expects a power
+# operator (^2) in the executed plan (presumably from grouping the two X factors).
+X = read($1);
+Y = read($2);
+
+R = 2 * X * 3 * Y * 4 * X;
+
+write(R, $3);
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/7fec7fa5/src/test/scripts/functions/ternary/AAATernaryAggregateC.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/AAATernaryAggregateC.R b/src/test/scripts/functions/ternary/AAATernaryAggregateC.R
new file mode 100644
index 0000000..a096c2b
--- /dev/null
+++ b/src/test/scripts/functions/ternary/AAATernaryAggregateC.R
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Expected result for AAATernaryAggregateC.dml: colSums(A * A * A) as a 1 x ncol(A) matrix.
+# args[1]: directory containing A.mtx; args[2]: output directory for R.
+args <- commandArgs(trailingOnly = TRUE)
+options(digits = 22)
+
+library("Matrix")
+
+A <- as.matrix(readMM(paste0(args[1], "A.mtx")))
+
+colTotals <- colSums(A * A * A)
+R <- t(as.matrix(colTotals))
+
+writeMM(as(R, "CsparseMatrix"), paste0(args[2], "R"))
http://git-wip-us.apache.org/repos/asf/systemml/blob/7fec7fa5/src/test/scripts/functions/ternary/AAATernaryAggregateC.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/AAATernaryAggregateC.dml b/src/test/scripts/functions/ternary/AAATernaryAggregateC.dml
new file mode 100644
index 0000000..b576a4d
--- /dev/null
+++ b/src/test/scripts/functions/ternary/AAATernaryAggregateC.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Column sums of the element-wise cube A*A*A. With rewrites enabled (non-MR),
+# ABATernaryAggregateTest expects tak+* for vector inputs and tack+* otherwise.
+A = read($1);
+
+# NOTE(review): presumably an empty block that forces a statement-block cut, so the
+# aggregate below is compiled separately from the read -- confirm before removing.
+if(1==1){}
+
+R = colSums(A * A * A);
+
+write(R, $2);
http://git-wip-us.apache.org/repos/asf/systemml/blob/7fec7fa5/src/test/scripts/functions/ternary/AAATernaryAggregateRC.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/AAATernaryAggregateRC.R b/src/test/scripts/functions/ternary/AAATernaryAggregateRC.R
new file mode 100644
index 0000000..776ddd0
--- /dev/null
+++ b/src/test/scripts/functions/ternary/AAATernaryAggregateRC.R
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Expected result for AAATernaryAggregateRC.dml: sum(A * A * A) as a 1 x 1 matrix.
+# args[1]: directory containing A.mtx; args[2]: output directory for R.
+args <- commandArgs(trailingOnly = TRUE)
+options(digits = 22)
+
+library("Matrix")
+
+A <- as.matrix(readMM(paste0(args[1], "A.mtx")))
+
+R <- as.matrix(sum(A * A * A))
+
+writeMM(as(R, "CsparseMatrix"), paste0(args[2], "R"))
http://git-wip-us.apache.org/repos/asf/systemml/blob/7fec7fa5/src/test/scripts/functions/ternary/AAATernaryAggregateRC.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/AAATernaryAggregateRC.dml b/src/test/scripts/functions/ternary/AAATernaryAggregateRC.dml
new file mode 100644
index 0000000..7283703
--- /dev/null
+++ b/src/test/scripts/functions/ternary/AAATernaryAggregateRC.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Full sum of the element-wise cube A*A*A, returned as a 1x1 matrix. With rewrites
+# enabled (non-MR), ABATernaryAggregateTest expects the tak+* opcode.
+A = read($1);
+
+# NOTE(review): presumably an empty block that forces a statement-block cut, so the
+# aggregate below is compiled separately from the read -- confirm before removing.
+if(1==1){}
+
+s = sum(A * A * A);
+R = as.matrix(s);
+
+write(R, $2);
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/7fec7fa5/src/test/scripts/functions/ternary/ABATernaryAggregateC.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/ABATernaryAggregateC.R b/src/test/scripts/functions/ternary/ABATernaryAggregateC.R
new file mode 100644
index 0000000..9601089
--- /dev/null
+++ b/src/test/scripts/functions/ternary/ABATernaryAggregateC.R
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Expected result for ABATernaryAggregateC.dml: colSums(A * B * A) with B = A * 2,
+# as a 1 x ncol(A) matrix. args[1]: directory containing A.mtx; args[2]: output directory.
+args <- commandArgs(trailingOnly = TRUE)
+options(digits = 22)
+
+library("Matrix")
+
+A <- as.matrix(readMM(paste0(args[1], "A.mtx")))
+B <- A * 2
+
+colTotals <- colSums(A * B * A)
+R <- t(as.matrix(colTotals))
+
+writeMM(as(R, "CsparseMatrix"), paste0(args[2], "R"))
http://git-wip-us.apache.org/repos/asf/systemml/blob/7fec7fa5/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml b/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml
new file mode 100644
index 0000000..737b409
--- /dev/null
+++ b/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Column sums of A*B*A, where B is derived from A, so the chain repeats the operand A
+# around a different middle factor. With rewrites enabled (non-MR), ABATernaryAggregateTest
+# expects tak+* for vector inputs and tack+* otherwise.
+A = read($1);
+B = A * 2;
+
+# NOTE(review): presumably an empty block that forces a statement-block cut, so A*B*A
+# below is compiled separately from the definition of B -- confirm before removing.
+if(1==1){}
+
+R = colSums(A * B * A);
+
+write(R, $2);
http://git-wip-us.apache.org/repos/asf/systemml/blob/7fec7fa5/src/test/scripts/functions/ternary/ABATernaryAggregateRC.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/ABATernaryAggregateRC.R b/src/test/scripts/functions/ternary/ABATernaryAggregateRC.R
new file mode 100644
index 0000000..6552c7e
--- /dev/null
+++ b/src/test/scripts/functions/ternary/ABATernaryAggregateRC.R
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Expected result for ABATernaryAggregateRC.dml: sum(A * B * A) with B = A * 2,
+# as a 1 x 1 matrix. args[1]: directory containing A.mtx; args[2]: output directory.
+args <- commandArgs(trailingOnly = TRUE)
+options(digits = 22)
+
+library("Matrix")
+
+A <- as.matrix(readMM(paste0(args[1], "A.mtx")))
+B <- A * 2
+
+R <- as.matrix(sum(A * B * A))
+
+writeMM(as(R, "CsparseMatrix"), paste0(args[2], "R"))
http://git-wip-us.apache.org/repos/asf/systemml/blob/7fec7fa5/src/test/scripts/functions/ternary/ABATernaryAggregateRC.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/ABATernaryAggregateRC.dml b/src/test/scripts/functions/ternary/ABATernaryAggregateRC.dml
new file mode 100644
index 0000000..965c8d3
--- /dev/null
+++ b/src/test/scripts/functions/ternary/ABATernaryAggregateRC.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Full sum of A*B*A, where B is derived from A, returned as a 1x1 matrix. With rewrites
+# enabled (non-MR), ABATernaryAggregateTest expects the tak+* opcode.
+A = read($1);
+B = A * 2;
+
+# NOTE(review): presumably an empty block that forces a statement-block cut, so A*B*A
+# below is compiled separately from the definition of B -- confirm before removing.
+if(1==1){}
+
+s = sum(A * B * A);
+R = as.matrix(s);
+
+write(R, $2);
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/7fec7fa5/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
----------------------------------------------------------------------
diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
index e352e6d..860cdbe 100644
--- a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
+++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
@@ -50,6 +50,7 @@ import org.junit.runners.Suite;
 	ReadAfterWriteTest.class,
 	RewriteCSETransposeScalarTest.class,
 	RewriteCTableToRExpandTest.class,
+	RewriteElementwiseMultChainOptimizationTest.class,
 	RewriteEliminateAggregatesTest.class,
 	RewriteFuseBinaryOpChainTest.class,
 	RewriteFusedRandTest.class,
http://git-wip-us.apache.org/repos/asf/systemml/blob/7fec7fa5/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java
----------------------------------------------------------------------
diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java
index 784177d..ee14359 100644
--- a/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java
+++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java
@@ -22,10 +22,11 @@ package org.apache.sysml.test.integration.functions.ternary;
 import org.junit.runner.RunWith;
 import org.junit.runners.Suite;
 
-/** Group together the tests in this package into a single suite so
that the Maven build +/* Group together the tests in this package into a single suite so that the Maven build * won't run two of them at once. */ @RunWith(Suite.class) @Suite.SuiteClasses({ + ABATernaryAggregateTest.class, CentralMomentWeightsTest.class, CovarianceWeightsTest.class, CTableMatrixIgnoreZerosTest.class,