Correct the ordering of e-mult chain rewrites: sort the operands by data type (scalars, row vectors, column vectors, then other matrices) and, within each matrix group, by sparsity when nnz information is available.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/8b832f62 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/8b832f62 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/8b832f62 Branch: refs/heads/master Commit: 8b832f624dd23ba0006672c444cf6f0649a6e753 Parents: ff8c836 Author: Dylan Hutchison <[email protected]> Authored: Fri Jun 9 20:48:57 2017 -0700 Committer: Dylan Hutchison <[email protected]> Committed: Sun Jun 18 17:43:21 2017 -0700 ---------------------------------------------------------------------- .../apache/sysml/hops/rewrite/RewriteEMult.java | 78 ++++++++++++++++++-- .../functions/misc/RewriteEMultChainTest.java | 7 +- 2 files changed, 74 insertions(+), 11 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/8b832f62/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java index 66da6fa..d483a08 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java @@ -32,6 +32,7 @@ import org.apache.sysml.hops.BinaryOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.HopsException; import org.apache.sysml.hops.LiteralOp; +import org.apache.sysml.parser.Expression; import com.google.common.collect.HashMultiset; import com.google.common.collect.Multiset; @@ -125,13 +126,13 @@ public class RewriteEMult extends HopRewriteRule { } // sorted contains all leaves, sorted by data type, stripped from their parents - // Construct left-deep EMult tree - Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator(); + // Construct right-deep EMult tree + final 
Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator(); Hop first = constructPower(iterator.next()); for (int i = 1; i < sorted.size(); i++) { final Hop second = constructPower(iterator.next()); - first = HopRewriteUtils.createBinary(first, second, Hop.OpOp2.MULT); + first = HopRewriteUtils.createBinary(second, first, Hop.OpOp2.MULT); } return first; } @@ -140,14 +141,75 @@ public class RewriteEMult extends HopRewriteRule { final Hop hop = entry.getKey(); final int cnt = entry.getValue(); assert(cnt >= 1); - if (cnt == 1) - return hop; + if (cnt == 1) return hop; return HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW); } - private static Comparator<Hop> compareByDataType = Comparator.comparing(Hop::getDataType) - .thenComparing(Hop::getName) - .thenComparingInt(Object::hashCode); + + + // Order: scalars > row vectors > col vectors > + // non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) > + // other data types + // disambiguate by Hop ID + private static final Comparator<Hop> compareByDataType = new Comparator<Hop>() { + @Override + public final int compare(Hop o1, Hop o2) { + int c = Integer.compare(orderDataType[o1.getDataType().ordinal()], orderDataType[o2.getDataType().ordinal()]); + if (c != 0) return c; + + // o1 and o2 have the same data type + switch (o1.getDataType()) { + case SCALAR: return Long.compare(o1.getHopID(), o2.getHopID()); + case MATRIX: + // two matrices; check for vectors + if (o1.getDim1() == 1) { // row vector + if (o2.getDim1() != 1) return 1; // row vectors are greatest of matrices + return compareBySparsityThenId(o1, o2); // both row vectors + } else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not + return -1; // row vectors are the greatest matrices + } else if (o1.getDim2() == 1) { // col vector + if (o2.getDim2() != 1) return 1; // col vectors greater than non-vectors + return compareBySparsityThenId(o1, o2); // both col vectors + } else if (o2.getDim2() == 
1) { // 2 is col vector; 1 is not + return 1; // col vectors greater than non-vectors + } else { // both non-vectors + return compareBySparsityThenId(o1, o2); + } + default: + return Long.compare(o1.getHopID(), o2.getHopID()); + } + } + private int compareBySparsityThenId(Hop o1, Hop o2) { + // the hop with more nnz is first; unknown nnz (-1) last + int c = Long.compare(o1.getNnz(), o2.getNnz()); + if (c != 0) return c; + return Long.compare(o1.getHopID(), o2.getHopID()); + } + private final int[] orderDataType; + { + Expression.DataType[] dtValues = Expression.DataType.values(); + orderDataType = new int[dtValues.length]; + for (int i = 0, valuesLength = dtValues.length; i < valuesLength; i++) { + switch(dtValues[i]) { + case SCALAR: + orderDataType[i] = 4; + break; + case MATRIX: + orderDataType[i] = 3; + break; + case FRAME: + orderDataType[i] = 2; + break; + case OBJECT: + orderDataType[i] = 1; + break; + case UNKNOWN: + orderDataType[i] = 0; + break; + } + } + } + }; private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) { final ArrayList<Hop> parents = child.getParent(); http://git-wip-us.apache.org/repos/asf/systemml/blob/8b832f62/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java index e076c95..18ed55d 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java @@ -99,8 +99,9 @@ public class RewriteEMultChainTest extends AutomatedTestBase fullRScriptName = HOME + testname + ".R"; rCmd = getRCmd(inputDir(), expectedDir()); - double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.97d, 7); - 
double[][] Y = getRandomMatrix(rows, cols, -1, 1, 0.9d, 3); + double Xsparsity = 0.8, Ysparsity = 0.6; + double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7); + double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3); writeInputMatrixWithMTD("X", X, true); writeInputMatrixWithMTD("Y", Y, true); @@ -123,5 +124,5 @@ public class RewriteEMultChainTest extends AutomatedTestBase rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } - } + } }
