Repository: systemml
Updated Branches:
  refs/heads/master a03065299 -> a66126d49


[SYSTEMML-1990] New rewrite for order operation chains

This patch introduces a new rewrite that merges a chain of subsequent
order operations into a single order operation with multiple order-by
columns. The rewrite applies to order operations that return data (not
indexes), use a scalar order-by column, share a consistent descending
configuration, and whose intermediates have a single consumer.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/f366c469
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/f366c469
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/f366c469

Branch: refs/heads/master
Commit: f366c46960aac412a862c20e07e5f844b58b05a7
Parents: a030652
Author: Matthias Boehm <[email protected]>
Authored: Wed Nov 8 17:41:35 2017 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Wed Nov 8 17:41:35 2017 -0800

----------------------------------------------------------------------
 .../sysml/hops/rewrite/HopRewriteUtils.java     | 72 +++++++++++++++++---
 .../RewriteAlgebraicSimplificationStatic.java   | 60 +++++++++++++++-
 .../cp/StringInitCPInstruction.java             |  2 +-
 .../reorg/MultipleOrderByColsTest.java          | 30 +++++++-
 .../scripts/functions/reorg/OrderMultiBy.dml    |  5 --
 .../scripts/functions/reorg/OrderMultiBy2.R     | 42 ++++++++++++
 .../scripts/functions/reorg/OrderMultiBy2.dml   | 29 ++++++++
 7 files changed, 223 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/f366c469/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 15cc2cb..28b2189 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -67,6 +67,7 @@ import org.apache.sysml.parser.Expression.DataType;
 import org.apache.sysml.parser.Expression.ValueType;
 import org.apache.sysml.runtime.instructions.cp.ScalarObject;
 import org.apache.sysml.runtime.instructions.cp.ScalarObjectFactory;
+import org.apache.sysml.runtime.instructions.cp.StringInitCPInstruction;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.util.UtilFunctions;
 
@@ -199,6 +200,17 @@ public class HopRewriteUtils
                        && getDoubleValueSafe((LiteralOp)hop)==val);
        }
        
+       /**
+        * Checks whether the given hop is a boolean literal of the given value.
+        * 
+        * @param hop hop to test
+        * @param val expected boolean value
+        * @return true if hop is a LiteralOp with value type BOOLEAN whose boolean value equals val
+        */
+       public static boolean isLiteralOfValue( Hop hop, boolean val ) {
+               try {
+                       return (hop instanceof LiteralOp 
+                               && (hop.getValueType()==ValueType.BOOLEAN)
+                               && ((LiteralOp)hop).getBooleanValue()==val);
+               }
+               catch(HopsException ex) {
+                       //wrap as unchecked exception (this method declares no checked exceptions)
+                       throw new RuntimeException(ex);
+               }
+       }
+       
        public static ScalarObject getScalarObject( LiteralOp op )
        {
                try {
@@ -481,6 +493,32 @@ public class HopRewriteUtils
                return datagen;
        }
        
+       /**
+        * Creates a string-initialized (SINIT) data generator hop from a list of literals.
+        * The string values of the literals are concatenated in list order, separated by
+        * StringInitCPInstruction.DELIM, and passed as both the min and max parameter.
+        * 
+        * @param values literals providing the cell values; must be non-empty because
+        *               the first literal is used as the source of the line numbers
+        * @param rows number of rows passed as the rows parameter
+        * @param cols number of columns passed as the cols parameter
+        * @return the new DataGenOp with output block sizes taken from the configuration
+        * @throws HopsException declared for the hop construction used here
+        */
+       public static Hop createDataGenOpByVal( ArrayList<LiteralOp> values, 
long rows, long cols ) 
+               throws HopsException
+       {
+               //build delimited value string, e.g., "1 2" for the literals 1 and 2
+               StringBuilder sb = new StringBuilder();
+               for(LiteralOp lit : values) {
+                       if(sb.length()>0)
+                               sb.append(StringInitCPInstruction.DELIM);
+                       sb.append(lit.getStringValue());
+               }
+               LiteralOp str = new LiteralOp(sb.toString());
+               
+               //the same value string is used for min and max
+               HashMap<String, Hop> params = new HashMap<>();
+               params.put(DataExpression.RAND_ROWS, new LiteralOp(rows));
+               params.put(DataExpression.RAND_COLS, new LiteralOp(cols));
+               params.put(DataExpression.RAND_MIN, str);
+               params.put(DataExpression.RAND_MAX, str);
+               params.put(DataExpression.RAND_SEED, new 
LiteralOp(DataGenOp.UNSPECIFIED_SEED));
+               
+               Hop datagen = new DataGenOp(DataGenMethod.SINIT, new 
DataIdentifier("tmp"), params);
+               int blksz = ConfigurationManager.getBlocksize();
+               datagen.setOutputBlocksizes(blksz, blksz);
+               copyLineNumbers(values.get(0), datagen);
+               
+               return datagen;
+       }
+       
        public static boolean isDataGenOp(Hop hop, DataGenMethod... ops) {
                return (hop instanceof DataGenOp 
                        && ArrayUtils.contains(ops, ((DataGenOp)hop).getOp()));
@@ -506,14 +544,21 @@ public class HopRewriteUtils
                return createReorg(input, ReOrgOp.TRANSPOSE);
        }
        
-       public static ReorgOp createReorg(Hop input, ReOrgOp rop)
-       {
-               ReorgOp transpose = new ReorgOp(input.getName(), 
input.getDataType(), input.getValueType(), rop, input);
-               transpose.setOutputBlocksizes(input.getRowsInBlock(), 
input.getColsInBlock());
-               copyLineNumbers(input, transpose);
-               transpose.refreshSizeInformation();     
-               
-               return transpose;
+       /**
+        * Creates a reorg hop of the given type over a single input. Name, data type,
+        * value type, block sizes, and line numbers are taken from the input.
+        * 
+        * @param input the single input hop
+        * @param rop reorg operation type
+        * @return the new ReorgOp after refreshing its size information
+        */
+       public static ReorgOp createReorg(Hop input, ReOrgOp rop) {
+               ReorgOp reorg = new ReorgOp(input.getName(), 
input.getDataType(), input.getValueType(), rop, input);
+               reorg.setOutputBlocksizes(input.getRowsInBlock(), 
input.getColsInBlock());
+               copyLineNumbers(input, reorg);
+               reorg.refreshSizeInformation();
+               return reorg;
+       }
+       
+       /**
+        * Creates a reorg hop of the given type over multiple inputs. The first input
+        * is treated as the main input: its name, data type, value type, block sizes,
+        * and line numbers are used for the new hop, while the full list is passed on
+        * as the hop inputs.
+        * 
+        * @param inputs input hops; must be non-empty because inputs.get(0) is accessed
+        * @param rop reorg operation type
+        * @return the new ReorgOp after refreshing its size information
+        */
+       public static ReorgOp createReorg(ArrayList<Hop> inputs, ReOrgOp rop) {
+               Hop main = inputs.get(0);
+               ReorgOp reorg = new ReorgOp(main.getName(), main.getDataType(), 
main.getValueType(), rop, inputs);
+               reorg.setOutputBlocksizes(main.getRowsInBlock(), 
main.getColsInBlock());
+               copyLineNumbers(main, reorg);
+               reorg.refreshSizeInformation();
+               return reorg;
        }
        
        public static UnaryOp createUnary(Hop input, OpOp1 type) 
@@ -831,8 +876,17 @@ public class HopRewriteUtils
                return ret;
        }
 
+       /**
+        * Checks whether the given hop is a reorg operation of the given type.
+        * 
+        * @param hop hop to test
+        * @param type expected reorg operation type
+        * @return true if hop is a ReorgOp whose operation equals type
+        */
+       public static boolean isReorg(Hop hop, ReOrgOp type) {
+               return hop instanceof ReorgOp && ((ReorgOp)hop).getOp()==type;
+       }
+       
+       /**
+        * Checks whether the given hop is a reorg operation of any of the given types.
+        * 
+        * @param hop hop to test
+        * @param types accepted reorg operation types
+        * @return true if hop is a ReorgOp whose operation is contained in types
+        */
+       public static boolean isReorg(Hop hop, ReOrgOp... types) {
+               return ( hop instanceof ReorgOp 
+                       && ArrayUtils.contains(types, ((ReorgOp) hop).getOp()));
+       }
+       
        public static boolean isTransposeOperation(Hop hop) {
-               return (hop instanceof ReorgOp && 
((ReorgOp)hop).getOp()==ReOrgOp.TRANSPOSE);
+               return isReorg(hop, ReOrgOp.TRANSPOSE);
        }
        
        public static boolean isTransposeOperation(Hop hop, int maxParents) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/f366c469/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 2d5d881..4c68fe2 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -171,6 +171,7 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                        hi = simplifySlicedMatrixMult(hop, hi, i);           
//e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1];
                        hi = simplifyConstantSort(hop, hi, i);               
//e.g., order(matrix())->matrix/seq; 
                        hi = simplifyOrderedSort(hop, hi, i);                
//e.g., order(matrix())->seq; 
+                       hi = fuseOrderOperationChain(hi);                    
//e.g., order(order(X,2),1) -> order(X,(12))
                        hi = removeUnnecessaryReorgOperation(hop, hi, i);    
//e.g., t(t(X))->X; rev(rev(X))->X potentially introduced by other rewrites
                        hi = simplifyTransposeAggBinBinaryChains(hop, hi, 
i);//e.g., t(t(A)%*%t(B)+C) -> B%*%A+t(C)
                        hi = removeUnnecessaryMinus(hop, hi, i);             
//e.g., -(-X)->X; potentially introduced by simplify binary or dyn rewrites
@@ -1475,12 +1476,69 @@ public class RewriteAlgebraicSimplificationStatic 
extends HopRewriteRule
                                                LOG.debug("Applied 
simplifyOrderedSort2.");
                                        }
                                }
-                       }          
+                       }
                }
                
                return hi;
        }
 
+       /**
+        * Pattern: order(order(X,2),1) -> order(X,(12)), i.e., a chain of
+        * single-column order operations is fused into one order operation
+        * with a row vector of order-by columns (outermost column first).
+        * 
+        * The chain is only extended over order operations with a literal
+        * order-by column, the same descending flag, disabled index return,
+        * and a single parent.
+        * 
+        * @param hi current hop, potentially the last order operation of a chain
+        * @return the fused order hop if at least two order operations were merged, hi otherwise
+        * @throws HopsException declared by the hop construction utilities used here
+        */
+       private static Hop fuseOrderOperationChain(Hop hi) 
+               throws HopsException
+       {
+               //order(order(X,2),1) -> order(X, (12)), 
+               if( HopRewriteUtils.isReorg(hi, ReOrgOp.SORT)
+                       && hi.getInput().get(1) instanceof LiteralOp //scalar by
+                       && hi.getInput().get(2) instanceof LiteralOp //scalar desc
+                       && HopRewriteUtils.isLiteralOfValue(hi.getInput().get(3), false) ) //not ixret 
+               { 
+                       LiteralOp by = (LiteralOp) hi.getInput().get(1);
+                       boolean desc = HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(2));
+                       
+                       //find chain of order operations with same desc/ixret configuration and single consumers
+                       ArrayList<LiteralOp> byList = new ArrayList<>();
+                       byList.add(by);
+                       Hop input = hi.getInput().get(0);
+                       while( HopRewriteUtils.isReorg(input, ReOrgOp.SORT)
+                               && input.getInput().get(1) instanceof LiteralOp //scalar by
+                               && HopRewriteUtils.isLiteralOfValue(input.getInput().get(2), desc)
+                               && HopRewriteUtils.isLiteralOfValue(input.getInput().get(3), false) //not ixret of the inner order
+                               && input.getParent().size() == 1 ) 
+                       {
+                               byList.add((LiteralOp)input.getInput().get(1));
+                               input = input.getInput().get(0);
+                       }
+                       
+                       //merge order chain if at least two instances
+                       if( byList.size() >= 2 ) {
+                               //create new order operations
+                               ArrayList<Hop> inputs = new ArrayList<>();
+                               inputs.add(input);
+                               inputs.add(HopRewriteUtils.createDataGenOpByVal(byList, 1, byList.size()));
+                               inputs.add(new LiteralOp(desc));
+                               inputs.add(new LiteralOp(false));
+                               Hop hnew = HopRewriteUtils.createReorg(inputs, ReOrgOp.SORT);
+                               
+                               //cleanup child references of all fused order operations along the chain
+                               Hop current = hi;
+                               while(current != input ) {
+                                       Hop tmp = current.getInput().get(0);
+                                       HopRewriteUtils.removeAllChildReferences(current);
+                                       current = tmp;
+                               }
+                               
+                               //rewire all parents (avoid anomalies with replicated datagen)
+                               List<Hop> parents = new ArrayList<>(hi.getParent());
+                               for( Hop p : parents )
+                                       HopRewriteUtils.replaceChildReference(p, hi, hnew);
+                               
+                               hi = hnew;
+                               LOG.debug("Applied fuseOrderOperationChain (line "+hi.getBeginLine()+").");
+                       }
+               }
+               
+               return hi;
+       }
+       
        /**
         * Patterns: t(t(A)%*%t(B)+C) -> B%*%A+t(C)
         * 

http://git-wip-us.apache.org/repos/asf/systemml/blob/f366c469/src/main/java/org/apache/sysml/runtime/instructions/cp/StringInitCPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/StringInitCPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/StringInitCPInstruction.java
index 4b89573..93e02b9 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/StringInitCPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/StringInitCPInstruction.java
@@ -30,7 +30,7 @@ import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.operators.Operator;
 
 public class StringInitCPInstruction extends UnaryCPInstruction {
-       private static final String DELIM = " ";
+       public static final String DELIM = " ";
 
        private final long _rlen;
        private final long _clen;

http://git-wip-us.apache.org/repos/asf/systemml/blob/f366c469/src/test/java/org/apache/sysml/test/integration/functions/reorg/MultipleOrderByColsTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/reorg/MultipleOrderByColsTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/reorg/MultipleOrderByColsTest.java
index 67c6487..10dc1a4 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/reorg/MultipleOrderByColsTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/reorg/MultipleOrderByColsTest.java
@@ -21,6 +21,7 @@ package org.apache.sysml.test.integration.functions.reorg;
 
 import java.util.HashMap;
 
+import org.junit.Assert;
 import org.junit.Test;
 
 import org.apache.sysml.api.DMLScript;
@@ -30,10 +31,12 @@ import 
org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
 import org.apache.sysml.test.utils.TestUtils;
+import org.apache.sysml.utils.Statistics;
 
 public class MultipleOrderByColsTest extends AutomatedTestBase 
 {
        private final static String TEST_NAME1 = "OrderMultiBy";
+       private final static String TEST_NAME2 = "OrderMultiBy2";
        
        private final static String TEST_DIR = "functions/reorg/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
MultipleOrderByColsTest.class.getSimpleName() + "/";
@@ -48,6 +51,7 @@ public class MultipleOrderByColsTest extends AutomatedTestBase
        public void setUp() {
                TestUtils.clearAssertionInformation();
                addTestConfiguration(TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1,new String[]{"B"}));
+               addTestConfiguration(TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2,new String[]{"B"}));
        }
        
        @Test
@@ -90,6 +94,26 @@ public class MultipleOrderByColsTest extends 
AutomatedTestBase
                runOrderTest(TEST_NAME1, true, true, true, ExecType.CP);
        }
 
+       //chained-order script (TEST_NAME2) with flags (false, false, false) in CP;
+       //per the test name presumably dense input, ascending order, data return
+       @Test
+       public void testOrder2DenseAscDataCP() {
+               runOrderTest(TEST_NAME2, false, false, false, ExecType.CP);
+       }
+       
+       //chained-order script (TEST_NAME2) with flags (false, true, false) in CP;
+       //per the test name presumably dense input, descending order, data return
+       @Test
+       public void testOrder2DenseDescDataCP() {
+               runOrderTest(TEST_NAME2, false, true, false, ExecType.CP);
+       }
+       
+       //chained-order script (TEST_NAME2) with flags (true, false, false) in CP;
+       //per the test name presumably sparse input, ascending order, data return
+       @Test
+       public void testOrder2SparseAscDataCP() {
+               runOrderTest(TEST_NAME2, true, false, false, ExecType.CP);
+       }
+       
+       //chained-order script (TEST_NAME2) with flags (true, true, false) in CP;
+       //per the test name presumably sparse input, descending order, data return
+       @Test
+       public void testOrder2SparseDescDataCP() {
+               runOrderTest(TEST_NAME2, true, true, false, ExecType.CP);
+       }
+       
 //TODO enable together with additional spark sort runtime
 //     @Test
 //     public void testOrderDenseAscDataSP() {
@@ -152,7 +176,7 @@ public class MultipleOrderByColsTest extends 
AutomatedTestBase
                        
                        String HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = HOME + TEST_NAME + ".dml";
-                       programArgs = new String[]{"-explain","-args", 
input("A"), 
+                       programArgs = new String[]{"-stats","-args", 
input("A"), 
                                String.valueOf(desc).toUpperCase(), 
String.valueOf(ixret).toUpperCase(), output("B") };
                        
                        fullRScriptName = HOME + TEST_NAME + ".R";
@@ -170,6 +194,10 @@ public class MultipleOrderByColsTest extends 
AutomatedTestBase
                        HashMap<CellIndex, Double> dmlfile = 
readDMLMatrixFromHDFS("B");
                        HashMap<CellIndex, Double> rfile  = 
readRMatrixFromFS("B");
                        TestUtils.compareMatrices(dmlfile, rfile, eps, 
"Stat-DML", "Stat-R");
+                       
+                       //check for applied rewrite
+                       if( testname.equals(TEST_NAME2) && !ixret )
+                               
Assert.assertTrue(Statistics.getCPHeavyHitterCount("rsort")==1);
                }
                finally {
                        rtplatform = platformOld;

http://git-wip-us.apache.org/repos/asf/systemml/blob/f366c469/src/test/scripts/functions/reorg/OrderMultiBy.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/reorg/OrderMultiBy.dml 
b/src/test/scripts/functions/reorg/OrderMultiBy.dml
index f6d2246..78cf84e 100644
--- a/src/test/scripts/functions/reorg/OrderMultiBy.dml
+++ b/src/test/scripts/functions/reorg/OrderMultiBy.dml
@@ -23,11 +23,6 @@
 A = read($1);
 
 ix = matrix("3 7 14", rows=1, cols=3)
-
-#B = order(target=A, by=14, decreasing=$2, index.return=$3);
-#B = order(target=B, by=7, decreasing=$2, index.return=$3);
-#B = order(target=B, by=3, decreasing=$2, index.return=$3);
-
 B = order(target=A, by=ix, decreasing=$2, index.return=$3);
 
 write(B, $4, format="text");  
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/f366c469/src/test/scripts/functions/reorg/OrderMultiBy2.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/reorg/OrderMultiBy2.R 
b/src/test/scripts/functions/reorg/OrderMultiBy2.R
new file mode 100644
index 0000000..374dad0
--- /dev/null
+++ b/src/test/scripts/functions/reorg/OrderMultiBy2.R
@@ -0,0 +1,42 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# Reference script: orders the rows of A by columns 3, 7, and 14 in a single
+# multi-key order call.
+# args[1]: input directory of A.mtx, args[2]: descending flag,
+# args[3]: index-return flag, args[4]: output directory of B
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+A = readMM(paste(args[1], "A.mtx", sep=""))
+desc = as.logical(args[2]);
+ixret = as.logical(args[3]);
+
+# sort keys in order of precedence (column 3 is the primary key)
+col1 = A[,3];
+col2 = A[,7];
+col3 = A[,14];
+
+
+# return either the row permutation or the reordered matrix
+if( ixret ) {
+  B = order(col1, col2, col3, decreasing=desc);
+} else {
+  B = A[order(col1, col2, col3, decreasing=desc),];
+}
+
+writeMM(as(B,"CsparseMatrix"), paste(args[4], "B", sep=""))

http://git-wip-us.apache.org/repos/asf/systemml/blob/f366c469/src/test/scripts/functions/reorg/OrderMultiBy2.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/reorg/OrderMultiBy2.dml 
b/src/test/scripts/functions/reorg/OrderMultiBy2.dml
new file mode 100644
index 0000000..0c301ae
--- /dev/null
+++ b/src/test/scripts/functions/reorg/OrderMultiBy2.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# $1: input matrix, $2: decreasing flag, $3: index.return flag, $4: output
+A = read($1);
+
+# chain of three single-column order operations by columns 14, 7, and 3,
+# each applied to the result of the previous one; the R reference script
+# sorts by the keys (3, 7, 14) in one call
+B = order(target=A, by=14, decreasing=$2, index.return=$3);
+B = order(target=B, by=7, decreasing=$2, index.return=$3);
+B = order(target=B, by=3, decreasing=$2, index.return=$3);
+
+write(B, $4, format="text");

Reply via email to