Change the comparator order of row and column vectors (column vectors now rank above row vectors) so that the element-wise multiplication chain creates inner products rather than outer products.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/6c3e1c5b Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/6c3e1c5b Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/6c3e1c5b Branch: refs/heads/master Commit: 6c3e1c5bad30dc8f11ff9d3f412ce68873c37202 Parents: 04f692d Author: Dylan Hutchison <dhutc...@cs.washington.edu> Authored: Tue Jul 11 20:08:22 2017 -0700 Committer: Dylan Hutchison <dhutc...@cs.washington.edu> Committed: Tue Jul 11 20:08:22 2017 -0700 ---------------------------------------------------------------------- ...RewriteElementwiseMultChainOptimization.java | 24 ++++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/6c3e1c5b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java index 9ca0932..9cc8fcd 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java @@ -162,7 +162,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { /** * A Comparator that orders Hops by their data type, dimention, and sparsity. * The order is as follows: - * scalars > row vectors > col vectors > + * scalars > col vectors > row vectors > * non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) > * other data types. * Disambiguate by Hop ID. 
@@ -181,23 +181,23 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { } @Override - public final int compare(Hop o1, Hop o2) { - int c = Integer.compare(orderDataType[o1.getDataType().ordinal()], orderDataType[o2.getDataType().ordinal()]); + public final int compare(final Hop o1, final Hop o2) { + final int c = Integer.compare(orderDataType[o1.getDataType().ordinal()], orderDataType[o2.getDataType().ordinal()]); if (c != 0) return c; // o1 and o2 have the same data type switch (o1.getDataType()) { case MATRIX: // two matrices; check for vectors - if (o1.getDim1() == 1) { // row vector - if (o2.getDim1() != 1) return 1; // row vectors are greatest of matrices - return compareBySparsityThenId(o1, o2); // both row vectors - } else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not - return -1; // row vectors are the greatest matrices - } else if (o1.getDim2() == 1) { // col vector - if (o2.getDim2() != 1) return 1; // col vectors greater than non-vectors + if (o1.getDim2() == 1) { // col vector + if (o2.getDim2() != 1) return 1; // col vectors are greatest of matrices return compareBySparsityThenId(o1, o2); // both col vectors } else if (o2.getDim2() == 1) { // 2 is col vector; 1 is not + return -1; // col vectors are the greatest matrices + } else if (o1.getDim1() == 1) { // row vector + if (o2.getDim1() != 1) return 1; // row vectors greater than non-vectors + return compareBySparsityThenId(o1, o2); // both row vectors + } else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not return 1; // col vectors greater than non-vectors } else { // both non-vectors return compareBySparsityThenId(o1, o2); @@ -206,9 +206,9 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { return Long.compare(o1.getHopID(), o2.getHopID()); } } - private int compareBySparsityThenId(Hop o1, Hop o2) { + private int compareBySparsityThenId(final Hop o1, final Hop o2) { // the hop with more nnz is first; unknown nnz (-1) last - int c 
= Long.compare(o1.getNnz(), o2.getNnz()); + final int c = Long.compare(o1.getNnz(), o2.getNnz()); if (c != 0) return c; return Long.compare(o1.getHopID(), o2.getHopID()); }