Correct ordering of e-mult chain rewrites.

Sort scalars, vectors, and matrices appropriately, and additionally by sparsity
(when nnz information is available).


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/8b832f62
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/8b832f62
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/8b832f62

Branch: refs/heads/master
Commit: 8b832f624dd23ba0006672c444cf6f0649a6e753
Parents: ff8c836
Author: Dylan Hutchison <[email protected]>
Authored: Fri Jun 9 20:48:57 2017 -0700
Committer: Dylan Hutchison <[email protected]>
Committed: Sun Jun 18 17:43:21 2017 -0700

----------------------------------------------------------------------
 .../apache/sysml/hops/rewrite/RewriteEMult.java | 78 ++++++++++++++++++--
 .../functions/misc/RewriteEMultChainTest.java   |  7 +-
 2 files changed, 74 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/8b832f62/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
index 66da6fa..d483a08 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
@@ -32,6 +32,7 @@ import org.apache.sysml.hops.BinaryOp;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.HopsException;
 import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.parser.Expression;
 
 import com.google.common.collect.HashMultiset;
 import com.google.common.collect.Multiset;
@@ -125,13 +126,13 @@ public class RewriteEMult extends HopRewriteRule {
                }
                // sorted contains all leaves, sorted by data type, stripped 
from their parents
 
-               // Construct left-deep EMult tree
-               Iterator<Map.Entry<Hop, Integer>> iterator = 
sorted.entrySet().iterator();
+               // Construct right-deep EMult tree
+               final Iterator<Map.Entry<Hop, Integer>> iterator = 
sorted.entrySet().iterator();
                Hop first = constructPower(iterator.next());
 
                for (int i = 1; i < sorted.size(); i++) {
                        final Hop second = constructPower(iterator.next());
-                       first = HopRewriteUtils.createBinary(first, second, 
Hop.OpOp2.MULT);
+                       first = HopRewriteUtils.createBinary(second, first, 
Hop.OpOp2.MULT);
                }
                return first;
        }
@@ -140,14 +141,75 @@ public class RewriteEMult extends HopRewriteRule {
                final Hop hop = entry.getKey();
                final int cnt = entry.getValue();
                assert(cnt >= 1);
-               if (cnt == 1)
-                       return hop;
+               if (cnt == 1) return hop;
                return HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), 
Hop.OpOp2.POW);
        }
 
-       private static Comparator<Hop> compareByDataType = 
Comparator.comparing(Hop::getDataType)
-                       .thenComparing(Hop::getName)
-                       .thenComparingInt(Object::hashCode);
+
+
+       // Order: scalars > row vectors > col vectors >
+       //        non-vector matrices ordered by sparsity (higher nnz first, 
unknown sparsity last) >
+       //        other data types
+       // disambiguate by Hop ID
+       private static final Comparator<Hop> compareByDataType = new 
Comparator<Hop>() {
+               @Override
+               public final int compare(Hop o1, Hop o2) {
+                       int c = 
Integer.compare(orderDataType[o1.getDataType().ordinal()], 
orderDataType[o2.getDataType().ordinal()]);
+                       if (c != 0) return c;
+
+                       // o1 and o2 have the same data type
+                       switch (o1.getDataType()) {
+                       case SCALAR: return Long.compare(o1.getHopID(), 
o2.getHopID());
+                       case MATRIX:
+                               // two matrices; check for vectors
+                               if (o1.getDim1() == 1) { // row vector
+                                               if (o2.getDim1() != 1) return 
1; // row vectors are greatest of matrices
+                                               return 
compareBySparsityThenId(o1, o2); // both row vectors
+                               } else if (o2.getDim1() == 1) { // 2 is row 
vector; 1 is not
+                                               return -1; // row vectors are 
the greatest matrices
+                               } else if (o1.getDim2() == 1) { // col vector
+                                               if (o2.getDim2() != 1) return 
1; // col vectors greater than non-vectors
+                                               return 
compareBySparsityThenId(o1, o2); // both col vectors
+                               } else if (o2.getDim2() == 1) { // 2 is col 
vector; 1 is not
+                                               return 1; // col vectors 
greater than non-vectors
+                               } else { // both non-vectors
+                                               return 
compareBySparsityThenId(o1, o2);
+                               }
+                       default:
+                               return Long.compare(o1.getHopID(), 
o2.getHopID());
+                       }
+               }
+               private int compareBySparsityThenId(Hop o1, Hop o2) {
+                       // the hop with more nnz is first; unknown nnz (-1) last
+                       int c = Long.compare(o1.getNnz(), o2.getNnz());
+                       if (c != 0) return c;
+                       return Long.compare(o1.getHopID(), o2.getHopID());
+               }
+               private final int[] orderDataType;
+               {
+                       Expression.DataType[] dtValues = 
Expression.DataType.values();
+                       orderDataType = new int[dtValues.length];
+                       for (int i = 0, valuesLength = dtValues.length; i < 
valuesLength; i++) {
+                               switch(dtValues[i]) {
+                               case SCALAR:
+                                       orderDataType[i] = 4;
+                                       break;
+                               case MATRIX:
+                                       orderDataType[i] = 3;
+                                       break;
+                               case FRAME:
+                                       orderDataType[i] = 2;
+                                       break;
+                               case OBJECT:
+                                       orderDataType[i] = 1;
+                                       break;
+                               case UNKNOWN:
+                                       orderDataType[i] = 0;
+                                       break;
+                               }
+                       }
+               }
+       };
 
        private static boolean checkForeignParent(final Set<BinaryOp> emults, 
final BinaryOp child) {
                final ArrayList<Hop> parents = child.getParent();

http://git-wip-us.apache.org/repos/asf/systemml/blob/8b832f62/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
index e076c95..18ed55d 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
@@ -99,8 +99,9 @@ public class RewriteEMultChainTest extends AutomatedTestBase
                        fullRScriptName = HOME + testname + ".R";
                        rCmd = getRCmd(inputDir(), expectedDir());              
        
 
-                       double[][] X = getRandomMatrix(rows, cols, -1, 1, 
0.97d, 7);
-                       double[][] Y = getRandomMatrix(rows, cols, -1, 1, 0.9d, 
3);
+                       double Xsparsity = 0.8, Ysparsity = 0.6;
+                       double[][] X = getRandomMatrix(rows, cols, -1, 1, 
Xsparsity, 7);
+                       double[][] Y = getRandomMatrix(rows, cols, -1, 1, 
Ysparsity, 3);
                        writeInputMatrixWithMTD("X", X, true);
                        writeInputMatrixWithMTD("Y", Y, true);
                        
@@ -123,5 +124,5 @@ public class RewriteEMultChainTest extends AutomatedTestBase
                        rtplatform = platformOld;
                        DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
                }
-       }       
+       }
 }

Reply via email to