Review comments, part 1
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/b94557fd Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/b94557fd Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/b94557fd Branch: refs/heads/master Commit: b94557fd2c90c591179cdbf05a32242fadc36448 Parents: d88f867 Author: Dylan Hutchison <dhutc...@cs.washington.edu> Authored: Sun Jun 11 00:35:52 2017 -0700 Committer: Dylan Hutchison <dhutc...@cs.washington.edu> Committed: Sun Jun 18 17:43:37 2017 -0700 ---------------------------------------------------------------------- .../java/org/apache/sysml/hops/AggUnaryOp.java | 21 +- .../org/apache/sysml/hops/OptimizerUtils.java | 5 - .../sysml/hops/rewrite/HopRewriteUtils.java | 3 +- .../sysml/hops/rewrite/ProgramRewriter.java | 4 +- .../RewriteAlgebraicSimplificationDynamic.java | 16 +- .../apache/sysml/hops/rewrite/RewriteEMult.java | 284 ------------------- ...RewriteElementwiseMultChainOptimization.java | 281 ++++++++++++++++++ .../functions/misc/RewriteEMultChainTest.java | 127 --------- ...ementwiseMultChainOptimizationChainTest.java | 127 +++++++++ .../ternary/ABATernaryAggregateTest.java | 9 +- .../functions/misc/ZPackageSuite.java | 1 + .../functions/ternary/ZPackageSuite.java | 3 +- 12 files changed, 436 insertions(+), 445 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/AggUnaryOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java index 300a20c..a207831 100644 --- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java +++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java @@ -497,7 +497,7 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop if (binput1.getOp() == OpOp2.POW && binput1.getInput().get(1) instanceof LiteralOp) { LiteralOp lit = (LiteralOp)binput1.getInput().get(1); - ret = lit.getLongValue() == 3; + ret = HopRewriteUtils.getIntValueSafe(lit) == 3; } else if (binput1.getOp() == OpOp2.MULT // As unary agg instruction is not implemented in MR and since MR is in maintenance mode, postponed it. @@ -640,15 +640,10 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop boolean handled = false; if (input1.getOp() == OpOp2.POW) { - switch ((int)((LiteralOp)input12).getLongValue()) { - case 3: - in1 = input11.constructLops(); - in2 = in1; - in3 = in1; - break; - default: - throw new AssertionError("unreachable; only applies to power 3"); - } + assert(HopRewriteUtils.isLiteralOfValue(input12, 3)) : "this case can only occur with a power of 3"; + in1 = input11.constructLops(); + in2 = in1; + in3 = in1; handled = true; } else if (input11 instanceof BinaryOp ) { BinaryOp b11 = (BinaryOp)input11; @@ -662,8 +657,7 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop case POW: // A*A*B case Hop b112 = b11.getInput().get(1); if ( !(input12 instanceof BinaryOp && ((BinaryOp)input12).getOp()==OpOp2.MULT) - && b112 instanceof LiteralOp - && ((LiteralOp)b112).getLongValue() == 2) { + && HopRewriteUtils.isLiteralOfValue(b112, 2) ) { in1 = b11.getInput().get(0).constructLops(); in2 = in1; in3 = input12.constructLops(); @@ -682,8 +676,7 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop break; case POW: // A*B*B case Hop b112 = b12.getInput().get(1); - if ( b112 instanceof LiteralOp - && ((LiteralOp)b112).getLongValue() == 2) { + if ( HopRewriteUtils.isLiteralOfValue(b112, 2) ) { in1 = b12.getInput().get(0).constructLops(); in2 = in1; in3 = input11.constructLops(); http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/OptimizerUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java index 2a76d07..79b7ee6 100644 --- a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java +++ b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java @@ -111,11 +111,6 @@ public class OptimizerUtils public static boolean ALLOW_CONSTANT_FOLDING = true; public static boolean ALLOW_ALGEBRAIC_SIMPLIFICATION = true; - /** - * Enables rewriting chains of element-wise multiplies that contain the same multiplicand more than once, as in - * `A*B*A ==> (A^2)*B`. - */ - public static boolean ALLOW_EMULT_CHAIN_REWRITE = true; public static boolean ALLOW_OPERATOR_FUSION = true; /** http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index 17ac4ec..8f71359 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -251,8 +251,7 @@ public class HopRewriteUtils * @return replacement */ public static Hop replaceHop(final Hop old, final Hop replacement) { - final ArrayList<Hop> rootParents = old.getParent(); - if (rootParents.isEmpty()) + if (old.getParent().isEmpty()) return replacement; // new old! HopRewriteUtils.rewireAllParentChildReferences(old, replacement); return replacement; http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java index b6aab38..1053850 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java @@ -96,8 +96,8 @@ public class ProgramRewriter _dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() ); if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) _dagRuleSet.add( new RewriteCommonSubexpressionElimination() ); - if ( OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE ) - _dagRuleSet.add( new RewriteEMult() ); //dependency: cse + if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES) + _dagRuleSet.add( new RewriteElementwiseMultChainOptimization() ); //dependency: cse if( OptimizerUtils.ALLOW_CONSTANT_FOLDING ) _dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION ) http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java index 166af2f..91c5972 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java @@ -53,6 +53,8 @@ import org.apache.sysml.parser.DataExpression; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.parser.Expression.ValueType; +import static org.apache.sysml.hops.OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES; + /** * Rule: Algebraic Simplifications. Simplifies binary expressions * in terms of two major purposes: (1) rewrite binary operations @@ -2051,12 +2053,14 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule && hi2.getInput().get(0).getDim2()==1 && hi2.getInput().get(1).getDim2()==1 && !HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.MULT) && !HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.MULT) - && !(HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.POW) // do not rewrite (A^2)*B - && hi2.getInput().get(0).getInput().get(1) instanceof LiteralOp // let tak+* handle it - && ((LiteralOp)hi2.getInput().get(0).getInput().get(1)).getLongValue() == 2) - && !(HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.POW) // do not rewrite B*(A^2) - && hi2.getInput().get(1).getInput().get(1) instanceof LiteralOp // let tak+* handle it - && ((LiteralOp)hi2.getInput().get(1).getInput().get(1)).getLongValue() == 2) + && ( !ALLOW_SUM_PRODUCT_REWRITES + || !( HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.POW) // do not rewrite (A^2)*B + && hi2.getInput().get(0).getInput().get(1) instanceof LiteralOp // let tak+* handle it + && ((LiteralOp)hi2.getInput().get(0).getInput().get(1)).getLongValue() == 2 )) + && ( !ALLOW_SUM_PRODUCT_REWRITES + || !( HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.POW) // do not rewrite B*(A^2) + && hi2.getInput().get(1).getInput().get(1) instanceof LiteralOp // let tak+* handle it + && ((LiteralOp)hi2.getInput().get(1).getInput().get(1)).getLongValue() == 2 )) ) { baLeft = hi2.getInput().get(0); http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java deleted file mode 100644 index 5cd1471..0000000 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java +++ /dev/null @@ -1,284 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.hops.rewrite; - -import java.util.ArrayList; -import java.util.Comparator; -import java.util.HashSet; -import java.util.Iterator; -import java.util.Map; -import java.util.Set; -import java.util.SortedMap; -import java.util.TreeMap; - -import org.apache.sysml.hops.BinaryOp; -import org.apache.sysml.hops.Hop; -import org.apache.sysml.hops.HopsException; -import org.apache.sysml.hops.LiteralOp; -import org.apache.sysml.parser.Expression; - -import com.google.common.collect.HashMultiset; -import com.google.common.collect.Multiset; - -/** - * Prerequisite: RewriteCommonSubexpressionElimination must run before this rule. - * - * Rewrite a chain of element-wise multiply hops that contain identical elements. - * For example `(B * A) * B` is rewritten to `A * (B^2)` (or `(B^2) * A`), where `^` is element-wise power. - * The order of the multiplicands depends on their data types, dimentions (matrix or vector), and sparsity. - * - * Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain, - * since foreign parents may rely on the individual results. - */ -public class RewriteEMult extends HopRewriteRule { - @Override - public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException { - if( roots == null ) - return null; - for( int i=0; i<roots.size(); i++ ) { - Hop h = roots.get(i); - roots.set(i, rule_RewriteEMult(h)); - } - return roots; - } - - @Override - public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) throws HopsException { - if( root == null ) - return null; - return rule_RewriteEMult(root); - } - - private static boolean isBinaryMult(final Hop hop) { - return hop instanceof BinaryOp && ((BinaryOp)hop).getOp() == Hop.OpOp2.MULT; - } - - private static Hop rule_RewriteEMult(final Hop root) { - if (root.isVisited()) - return root; - root.setVisited(); - - // 1. Find immediate subtree of EMults. - if (isBinaryMult(root)) { - final Hop left = root.getInput().get(0), right = root.getInput().get(1); - final Set<BinaryOp> emults = new HashSet<>(); - final Multiset<Hop> leaves = HashMultiset.create(); - findEMultsAndLeaves((BinaryOp)root, emults, leaves); - - // 2. Ensure it is profitable to do a rewrite. - if (isOptimizable(emults, leaves)) { - // 3. Check for foreign parents. - // A foreign parent is a parent of some EMult that is not in the set. - // Foreign parents destroy correctness of this rewrite. - final boolean okay = (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) && - (!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right)); - if (okay) { - // 4. Construct replacement EMults for the leaves - final Hop replacement = constructReplacement(leaves); - if (LOG.isDebugEnabled()) - LOG.debug(String.format( - "Element-wise multiply chain rewrite of %d e-mults at sub-dag %d to new sub-dag %d", - emults.size(), root.getHopID(), replacement.getHopID())); - - // 5. Replace root with replacement - final Hop newRoot = HopRewriteUtils.replaceHop(root, replacement); - - // 6. Recurse at leaves (no need to repeat the interior emults) - for (final Hop leaf : leaves.elementSet()) { - recurseInputs(leaf); - } - return newRoot; - } - } - } - // This rewrite is not applicable to the current root. - // Try the root's children. - recurseInputs(root); - return root; - } - - private static void recurseInputs(final Hop parent) { - final ArrayList<Hop> inputs = parent.getInput(); - for (int i = 0; i < inputs.size(); i++) { - final Hop input = inputs.get(i); - final Hop newInput = rule_RewriteEMult(input); - inputs.set(i, newInput); - } - } - - private static Hop constructReplacement(final Multiset<Hop> leaves) { - // Sort by data type - final SortedMap<Hop,Integer> sorted = new TreeMap<>(compareByDataType); - for (final Multiset.Entry<Hop> entry : leaves.entrySet()) { - final Hop h = entry.getElement(); - // unlink parents (the EMults, which we are throwing away) - h.getParent().clear(); - sorted.put(h, entry.getCount()); - } - // sorted contains all leaves, sorted by data type, stripped from their parents - - // Construct right-deep EMult tree - final Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator(); - Hop first = constructPower(iterator.next()); - - for (int i = 1; i < sorted.size(); i++) { - final Hop second = constructPower(iterator.next()); - first = HopRewriteUtils.createBinary(second, first, Hop.OpOp2.MULT); - first.setVisited(); - } - return first; - } - - private static Hop constructPower(Map.Entry<Hop, Integer> entry) { - final Hop hop = entry.getKey(); - final int cnt = entry.getValue(); - assert(cnt >= 1); - if (cnt == 1) - return hop; // don't set this visited... we will visit this next - Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW); - pow.setVisited(); - return pow; - } - - /** - * A Comparator that orders Hops by their data type, dimention, and sparsity. - * The order is as follows: - * scalars > row vectors > col vectors > - * non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) > - * other data types. - * Disambiguate by Hop ID. - */ - private static final Comparator<Hop> compareByDataType = new Comparator<Hop>() { - @Override - public final int compare(Hop o1, Hop o2) { - int c = Integer.compare(orderDataType[o1.getDataType().ordinal()], orderDataType[o2.getDataType().ordinal()]); - if (c != 0) return c; - - // o1 and o2 have the same data type - switch (o1.getDataType()) { - case SCALAR: return Long.compare(o1.getHopID(), o2.getHopID()); - case MATRIX: - // two matrices; check for vectors - if (o1.getDim1() == 1) { // row vector - if (o2.getDim1() != 1) return 1; // row vectors are greatest of matrices - return compareBySparsityThenId(o1, o2); // both row vectors - } else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not - return -1; // row vectors are the greatest matrices - } else if (o1.getDim2() == 1) { // col vector - if (o2.getDim2() != 1) return 1; // col vectors greater than non-vectors - return compareBySparsityThenId(o1, o2); // both col vectors - } else if (o2.getDim2() == 1) { // 2 is col vector; 1 is not - return 1; // col vectors greater than non-vectors - } else { // both non-vectors - return compareBySparsityThenId(o1, o2); - } - default: - return Long.compare(o1.getHopID(), o2.getHopID()); - } - } - private int compareBySparsityThenId(Hop o1, Hop o2) { - // the hop with more nnz is first; unknown nnz (-1) last - int c = Long.compare(o1.getNnz(), o2.getNnz()); - if (c != 0) return c; - return Long.compare(o1.getHopID(), o2.getHopID()); - } - private final int[] orderDataType; - { - Expression.DataType[] dtValues = Expression.DataType.values(); - orderDataType = new int[dtValues.length]; - for (int i = 0, valuesLength = dtValues.length; i < valuesLength; i++) { - switch(dtValues[i]) { - case SCALAR: - orderDataType[i] = 4; - break; - case MATRIX: - orderDataType[i] = 3; - break; - case FRAME: - orderDataType[i] = 2; - break; - case OBJECT: - orderDataType[i] = 1; - break; - case UNKNOWN: - orderDataType[i] = 0; - break; - } - } - } - }; - - /** - * Check if a node has a parent that is not in the set of emults. Recursively check children who are also emults. - * @param emults The set of BinaryOp element-wise multiply hops in the emult chain. - * @param child An interior emult hop in the emult chain dag. - * @return Whether this interior emult or any child emult has a foreign parent. - */ - private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) { - final ArrayList<Hop> parents = child.getParent(); - if (parents.size() > 1) - for (final Hop parent : parents) - //noinspection SuspiciousMethodCalls - if (!emults.contains(parent)) - return false; - // child does not have foreign parents - - final ArrayList<Hop> inputs = child.getInput(); - final Hop left = inputs.get(0), right = inputs.get(1); - return (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) && - (!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right)); - } - - /** - * Create a set of the counts of all BinaryOp MULTs in the immediate subtree, starting with root, recursively. - * @param root Root of sub-dag - * @param emults Out parameter. The set of BinaryOp element-wise multiply hops in the emult chain (including root). - * @param leaves Out parameter. The multiset of multiplicands in the emult chain. - */ - private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, final Multiset<Hop> leaves) { - // Because RewriteCommonSubexpressionElimination already ran, it is safe to compare by equality. - emults.add(root); - - final ArrayList<Hop> inputs = root.getInput(); - final Hop left = inputs.get(0), right = inputs.get(1); - - if (isBinaryMult(left)) findEMultsAndLeaves((BinaryOp) left, emults, leaves); - else leaves.add(left); - - if (isBinaryMult(right)) findEMultsAndLeaves((BinaryOp) right, emults, leaves); - else leaves.add(right); - } - - /** - * Only optimize a subtree of emults if there are at least two emults. - * @param emults The set of BinaryOp element-wise multiply hops in the emult chain. - * @param leaves The multiset of multiplicands in the emult chain. - * @return If the multiset is worth optimizing. - */ - private static boolean isOptimizable(Set<BinaryOp> emults, final Multiset<Hop> leaves) { - // Old criterion: there should be at least one repeated leaf -// for (Multiset.Entry<Hop> hopEntry : leaves.entrySet()) { -// if (hopEntry.getCount() > 1) -// return true; -// } -// return false; - return emults.size() >= 2; - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java new file mode 100644 index 0000000..bd873ff --- /dev/null +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java @@ -0,0 +1,281 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.hops.rewrite; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Map; +import java.util.Set; +import java.util.SortedMap; +import java.util.TreeMap; + +import org.apache.sysml.hops.BinaryOp; +import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.HopsException; +import org.apache.sysml.hops.LiteralOp; +import org.apache.sysml.parser.Expression; + +import com.google.common.collect.HashMultiset; +import com.google.common.collect.Multiset; + +/** + * Prerequisite: RewriteCommonSubexpressionElimination must run before this rule. + * + * Rewrite a chain of element-wise multiply hops that contain identical elements. + * For example `(B * A) * B` is rewritten to `A * (B^2)` (or `(B^2) * A`), where `^` is element-wise power. + * The order of the multiplicands depends on their data types, dimentions (matrix or vector), and sparsity. + * + * Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain, + * since foreign parents may rely on the individual results. + */ +public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { + @Override + public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException { + if( roots == null ) + return null; + for( int i=0; i<roots.size(); i++ ) { + Hop h = roots.get(i); + roots.set(i, rule_RewriteEMult(h)); + } + return roots; + } + + @Override + public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) throws HopsException { + if( root == null ) + return null; + return rule_RewriteEMult(root); + } + + private static boolean isBinaryMult(final Hop hop) { + return hop instanceof BinaryOp && ((BinaryOp)hop).getOp() == Hop.OpOp2.MULT; + } + + private static Hop rule_RewriteEMult(final Hop root) { + if (root.isVisited()) + return root; + root.setVisited(); + + // 1. Find immediate subtree of EMults. + if (isBinaryMult(root)) { + final Hop left = root.getInput().get(0), right = root.getInput().get(1); + final Set<BinaryOp> emults = new HashSet<>(); + final Multiset<Hop> leaves = HashMultiset.create(); + findEMultsAndLeaves((BinaryOp)root, emults, leaves); + + // 2. Ensure it is profitable to do a rewrite. + if (isOptimizable(emults, leaves)) { + // 3. Check for foreign parents. + // A foreign parent is a parent of some EMult that is not in the set. + // Foreign parents destroy correctness of this rewrite. + final boolean okay = (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) && + (!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right)); + if (okay) { + // 4. Construct replacement EMults for the leaves + final Hop replacement = constructReplacement(leaves); + if (LOG.isDebugEnabled()) + LOG.debug(String.format( + "Element-wise multiply chain rewrite of %d e-mults at sub-dag %d to new sub-dag %d", + emults.size(), root.getHopID(), replacement.getHopID())); + + // 5. Replace root with replacement + final Hop newRoot = HopRewriteUtils.replaceHop(root, replacement); + + // 6. Recurse at leaves (no need to repeat the interior emults) + for (final Hop leaf : leaves.elementSet()) { + recurseInputs(leaf); + } + return newRoot; + } + } + } + // This rewrite is not applicable to the current root. + // Try the root's children. + recurseInputs(root); + return root; + } + + private static void recurseInputs(final Hop parent) { + final ArrayList<Hop> inputs = parent.getInput(); + for (int i = 0; i < inputs.size(); i++) { + final Hop input = inputs.get(i); + final Hop newInput = rule_RewriteEMult(input); + inputs.set(i, newInput); + } + } + + private static Hop constructReplacement(final Multiset<Hop> leaves) { + // Sort by data type + final SortedMap<Hop,Integer> sorted = new TreeMap<>(compareByDataType); + for (final Multiset.Entry<Hop> entry : leaves.entrySet()) { + final Hop h = entry.getElement(); + // unlink parents (the EMults, which we are throwing away) + h.getParent().clear(); + sorted.put(h, entry.getCount()); + } + // sorted contains all leaves, sorted by data type, stripped from their parents + + // Construct right-deep EMult tree + final Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator(); + Hop first = constructPower(iterator.next()); + + for (int i = 1; i < sorted.size(); i++) { + final Hop second = constructPower(iterator.next()); + first = HopRewriteUtils.createBinary(second, first, Hop.OpOp2.MULT); + first.setVisited(); + } + return first; + } + + private static Hop constructPower(Map.Entry<Hop, Integer> entry) { + final Hop hop = entry.getKey(); + final int cnt = entry.getValue(); + assert(cnt >= 1); + if (cnt == 1) + return hop; // don't set this visited... we will visit this next + Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW); + pow.setVisited(); + return pow; + } + + /** + * A Comparator that orders Hops by their data type, dimention, and sparsity. + * The order is as follows: + * scalars > row vectors > col vectors > + * non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) > + * other data types. + * Disambiguate by Hop ID. + */ + private static final Comparator<Hop> compareByDataType = new Comparator<Hop>() { + @Override + public final int compare(Hop o1, Hop o2) { + int c = Integer.compare(orderDataType[o1.getDataType().ordinal()], orderDataType[o2.getDataType().ordinal()]); + if (c != 0) return c; + + // o1 and o2 have the same data type + switch (o1.getDataType()) { + case MATRIX: + // two matrices; check for vectors + if (o1.getDim1() == 1) { // row vector + if (o2.getDim1() != 1) return 1; // row vectors are greatest of matrices + return compareBySparsityThenId(o1, o2); // both row vectors + } else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not + return -1; // row vectors are the greatest matrices + } else if (o1.getDim2() == 1) { // col vector + if (o2.getDim2() != 1) return 1; // col vectors greater than non-vectors + return compareBySparsityThenId(o1, o2); // both col vectors + } else if (o2.getDim2() == 1) { // 2 is col vector; 1 is not + return 1; // col vectors greater than non-vectors + } else { // both non-vectors + return compareBySparsityThenId(o1, o2); + } + default: + return Long.compare(o1.getHopID(), o2.getHopID()); + } + } + private int compareBySparsityThenId(Hop o1, Hop o2) { + // the hop with more nnz is first; unknown nnz (-1) last + int c = Long.compare(o1.getNnz(), o2.getNnz()); + if (c != 0) return c; + return Long.compare(o1.getHopID(), o2.getHopID()); + } + private final int[] orderDataType; + { + Expression.DataType[] dtValues = Expression.DataType.values(); + orderDataType = new int[dtValues.length]; + for (int i = 0, valuesLength = dtValues.length; i < valuesLength; i++) { + switch(dtValues[i]) { + case SCALAR: + orderDataType[i] = 4; + break; + case MATRIX: + orderDataType[i] = 3; + break; + case FRAME: + orderDataType[i] = 2; + break; + case OBJECT: + orderDataType[i] = 1; + break; + case UNKNOWN: + orderDataType[i] = 0; + break; + } + } + } + }; + + /** + * Check if a node has a parent that is not in the set of emults. Recursively check children who are also emults. + * @param emults The set of BinaryOp element-wise multiply hops in the emult chain. + * @param child An interior emult hop in the emult chain dag. + * @return Whether this interior emult or any child emult has a foreign parent. + */ + private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) { + final ArrayList<Hop> parents = child.getParent(); + if (parents.size() > 1) + for (final Hop parent : parents) + //noinspection SuspiciousMethodCalls (for Intellij, which checks when + if (!emults.contains(parent)) + return false; + // child does not have foreign parents + + final ArrayList<Hop> inputs = child.getInput(); + final Hop left = inputs.get(0), right = inputs.get(1); + return (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) && + (!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right)); + } + + /** + * Create a set of the counts of all BinaryOp MULTs in the immediate subtree, starting with root, recursively. + * @param root Root of sub-dag + * @param emults Out parameter. The set of BinaryOp element-wise multiply hops in the emult chain (including root). + * @param leaves Out parameter. The multiset of multiplicands in the emult chain. + */ + private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, final Multiset<Hop> leaves) { + // Because RewriteCommonSubexpressionElimination already ran, it is safe to compare by equality. + emults.add(root); + + final ArrayList<Hop> inputs = root.getInput(); + final Hop left = inputs.get(0), right = inputs.get(1); + + if (isBinaryMult(left)) + findEMultsAndLeaves((BinaryOp) left, emults, leaves); + else + leaves.add(left); + + if (isBinaryMult(right)) + findEMultsAndLeaves((BinaryOp) right, emults, leaves); + else + leaves.add(right); + } + + /** + * Only optimize a subtree of emults if there are at least two emults (meaning, at least 3 multiplicands). + * @param emults The set of BinaryOp element-wise multiply hops in the emult chain. + * @param leaves The multiset of multiplicands in the emult chain. + * @return If the multiset is worth optimizing. + */ + private static boolean isOptimizable(Set<BinaryOp> emults, final Multiset<Hop> leaves) { + return emults.size() >= 2; + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java deleted file mode 100644 index 85dbea4..0000000 --- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.test.integration.functions.misc; - -import java.util.HashMap; - -import org.apache.sysml.api.DMLScript; -import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; -import org.apache.sysml.hops.OptimizerUtils; -import org.apache.sysml.lops.LopProperties.ExecType; -import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; -import org.apache.sysml.test.integration.AutomatedTestBase; -import org.apache.sysml.test.integration.TestConfiguration; -import org.apache.sysml.test.utils.TestUtils; -import org.junit.Assert; -import org.junit.Test; - -/** - * Test whether `A*B*A` successfully rewrites to `(A^2)*B`. - */ -public class RewriteEMultChainTest extends AutomatedTestBase -{ - private static final String TEST_NAME1 = "RewriteEMultChainOpXYX"; - private static final String TEST_DIR = "functions/misc/"; - private static final String TEST_CLASS_DIR = TEST_DIR + RewriteEMultChainTest.class.getSimpleName() + "/"; - - private static final int rows = 123; - private static final int cols = 321; - private static final double eps = Math.pow(10, -10); - - @Override - public void setUp() { - TestUtils.clearAssertionInformation(); - addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) ); - } - - @Test - public void testMatrixMultChainOptNoRewritesCP() { - testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP); - } - - @Test - public void testMatrixMultChainOptNoRewritesSP() { - testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK); - } - - @Test - public void testMatrixMultChainOptRewritesCP() { - testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP); - } - - @Test - public void testMatrixMultChainOptRewritesSP() { - testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK); - } - - private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et) - { - RUNTIME_PLATFORM platformOld = rtplatform; - switch( et ){ - case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break; - case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break; - default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break; - } - - boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; - if( rtplatform == RUNTIME_PLATFORM.SPARK ) - DMLScript.USE_LOCAL_SPARK_CONFIG = true; - - boolean rewritesOld = OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE; - OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewrites; - - try - { - TestConfiguration config = getTestConfiguration(testname); - loadTestConfiguration(config); - - String HOME = SCRIPT_DIR + TEST_DIR; - fullDMLScriptName = HOME + testname + ".dml"; - programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), output("R") }; - fullRScriptName = HOME + testname + ".R"; - rCmd = getRCmd(inputDir(), expectedDir()); - - double Xsparsity = 0.8, Ysparsity = 0.6; - double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7); - double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3); - writeInputMatrixWithMTD("X", X, true); - writeInputMatrixWithMTD("Y", Y, true); - - //execute tests - runTest(true, false, null, -1); - runRScript(true); - - //compare matrices - HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R"); - HashMap<CellIndex, Double> rfile = readRMatrixFromFS("R"); - TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); - - //check for presence of power operator, if we did a rewrite - if( rewrites ) { - Assert.assertTrue(heavyHittersContainsSubString("^2")); - } - } - finally { - OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewritesOld; - rtplatform = platformOld; - DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; - } - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java new file mode 100644 index 0000000..47b2f0e --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.test.integration.functions.misc; + +import java.util.HashMap; + +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.lops.LopProperties.ExecType; +import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.utils.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +/** + * Test whether `A*B*A` successfully rewrites to `(A^2)*B`. + */ +public class RewriteElementwiseMultChainOptimizationChainTest extends AutomatedTestBase +{ + private static final String TEST_NAME1 = "RewriteEMultChainOpXYX"; + private static final String TEST_DIR = "functions/misc/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteElementwiseMultChainOptimizationChainTest.class.getSimpleName() + "/"; + + private static final int rows = 123; + private static final int cols = 321; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) ); + } + + @Test + public void testMatrixMultChainOptNoRewritesCP() { + testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP); + } + + @Test + public void testMatrixMultChainOptNoRewritesSP() { + testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK); + } + + @Test + public void testMatrixMultChainOptRewritesCP() { + testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP); + } + + @Test + public void testMatrixMultChainOptRewritesSP() { + testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK); + } + + private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et) + { + RUNTIME_PLATFORM platformOld = rtplatform; + switch( et ){ + case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break; + case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break; + default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break; + } + + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + if( rtplatform == RUNTIME_PLATFORM.SPARK ) + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + + boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES; + OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites; + + try + { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), output("R") }; + fullRScriptName = HOME + testname + ".R"; + rCmd = getRCmd(inputDir(), expectedDir()); + + double Xsparsity = 0.8, Ysparsity = 0.6; + double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7); + double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3); + writeInputMatrixWithMTD("X", X, true); + writeInputMatrixWithMTD("Y", Y, true); + + //execute tests + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R"); + HashMap<CellIndex, Double> rfile = readRMatrixFromFS("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + //check for presence of power operator, if we did a rewrite + if( rewrites ) { + Assert.assertTrue(heavyHittersContainsSubString("^2")); + } + } + finally { + OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld; + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java index 1829bf0..460829d 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java @@ -24,6 +24,7 @@ import java.util.HashMap; import org.apache.sysml.api.DMLScript; import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.hops.rewrite.RewriteElementwiseMultChainOptimization; import org.apache.sysml.lops.LopProperties.ExecType; import org.apache.sysml.runtime.instructions.Instruction; import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; @@ -36,7 +37,7 @@ import org.junit.Test; /** * Similar to {@link TernaryAggregateTest} except that it tests `sum(A*B*A)`. - * Checks compatibility with {@link org.apache.sysml.hops.rewrite.RewriteEMult}. + * Checks compatibility with {@link RewriteElementwiseMultChainOptimization}. */ public class ABATernaryAggregateTest extends AutomatedTestBase { @@ -368,14 +369,14 @@ public class ABATernaryAggregateTest extends AutomatedTestBase DMLScript.USE_LOCAL_SPARK_CONFIG = true; boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES, - rewritesOldEmult = OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE; + rewritesOldEmult = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES; try { TestConfiguration config = getTestConfiguration(testname); loadTestConfiguration(config); OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites; - OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewrites; + OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites; String HOME = SCRIPT_DIR + TEST_DIR; fullDMLScriptName = HOME + testname + ".dml"; @@ -411,7 +412,7 @@ public class ABATernaryAggregateTest extends AutomatedTestBase rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld; - OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewritesOldEmult; + OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOldEmult; } } } http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java ---------------------------------------------------------------------- diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java index e352e6d..deea784 100644 --- a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java +++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java @@ -50,6 +50,7 @@ import org.junit.runners.Suite; ReadAfterWriteTest.class, RewriteCSETransposeScalarTest.class, RewriteCTableToRExpandTest.class, + RewriteElementwiseMultChainOptimizationChainTest.class, RewriteEliminateAggregatesTest.class, RewriteFuseBinaryOpChainTest.class, RewriteFusedRandTest.class, http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java ---------------------------------------------------------------------- diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java index 784177d..ee14359 100644 --- a/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java +++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java @@ -22,10 +22,11 @@ package org.apache.sysml.test.integration.functions.ternary; import org.junit.runner.RunWith; import org.junit.runners.Suite; -/** Group together the tests in this package into a single suite so that the Maven build +/* Group together the tests in this package into a single suite so that the Maven build * won't run two of them at once. */ @RunWith(Suite.class) @Suite.SuiteClasses({ + ABATernaryAggregateTest.class, CentralMomentWeightsTest.class, CovarianceWeightsTest.class, CTableMatrixIgnoreZerosTest.class,