Repository: systemml
Updated Branches:
  refs/heads/master 4bc1fea87 -> 5df6ab6dd


[SYSTEMML-2007] New spark order operations w/ multiple order-by cols

This patch adds runtime support for distributed Spark operations
for the recently added order w/ multiple order-by columns. We now
also enable the related automatic rewrite of consecutive order calls for
the CP and Spark execution types.
 

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/27cabbc4
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/27cabbc4
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/27cabbc4

Branch: refs/heads/master
Commit: 27cabbc4730377d9e8e34d06855106687123c240
Parents: 4bc1fea
Author: Matthias Boehm <[email protected]>
Authored: Tue Nov 14 17:32:38 2017 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Tue Nov 14 18:45:00 2017 -0800

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/ReorgOp.java     |   3 +-
 .../RewriteAlgebraicSimplificationStatic.java   |   3 +-
 .../instructions/spark/ReorgSPInstruction.java  |  90 +++--
 .../spark/functions/IsBlockInList.java          |  53 +++
 .../spark/functions/IsBlockInRange.java         |   1 -
 .../instructions/spark/utils/RDDSortUtils.java  | 354 ++++++++++++++++---
 .../apache/sysml/runtime/util/SortUtils.java    |  16 +-
 .../reorg/MultipleOrderByColsTest.java          |  83 +++--
 8 files changed, 485 insertions(+), 118 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/27cabbc4/src/main/java/org/apache/sysml/hops/ReorgOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ReorgOp.java 
b/src/main/java/org/apache/sysml/hops/ReorgOp.java
index 4b55c9b..4d29336 100644
--- a/src/main/java/org/apache/sysml/hops/ReorgOp.java
+++ b/src/main/java/org/apache/sysml/hops/ReorgOp.java
@@ -372,7 +372,8 @@ public class ReorgOp extends Hop implements MultiThreadedHop
                                        }
                                }
                                else if( et==ExecType.SPARK ) {
-                                       boolean sortRewrite = 
!FORCE_DIST_SORT_INDEXES && isSortSPRewriteApplicable();
+                                       boolean sortRewrite = 
!FORCE_DIST_SORT_INDEXES 
+                                               && isSortSPRewriteApplicable() 
&& by.getDataType().isScalar();
                                        Lop transform1 = 
constructCPOrSparkSortLop(input, by, desc, ixret, et, sortRewrite);
                                        setOutputDimensions(transform1);
                                        setLineNumbers(transform1);

http://git-wip-us.apache.org/repos/asf/systemml/blob/27cabbc4/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index cbfb527..d71c4e0 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -1509,7 +1509,8 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                if( HopRewriteUtils.isReorg(hi, ReOrgOp.SORT)
                        && hi.getInput().get(1) instanceof LiteralOp //scalar by
                        && hi.getInput().get(2) instanceof LiteralOp //scalar 
desc
-                       && 
HopRewriteUtils.isLiteralOfValue(hi.getInput().get(3), false) ) //not ixret 
+                       && 
HopRewriteUtils.isLiteralOfValue(hi.getInput().get(3), false) //not ixret
+                       && !OptimizerUtils.isHadoopExecutionMode() )
                { 
                        LiteralOp by = (LiteralOp) hi.getInput().get(1);
                        boolean desc = 
HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(2));

http://git-wip-us.apache.org/repos/asf/systemml/blob/27cabbc4/src/main/java/org/apache/sysml/runtime/instructions/spark/ReorgSPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/ReorgSPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/ReorgSPInstruction.java
index c742a0b..8e11a55 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/ReorgSPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/ReorgSPInstruction.java
@@ -25,12 +25,14 @@ import java.util.Iterator;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.function.PairFlatMapFunction;
+import org.apache.spark.api.java.function.PairFunction;
 
 import scala.Tuple2;
 
 import org.apache.sysml.parser.Expression.DataType;
 import org.apache.sysml.parser.Expression.ValueType;
 import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysml.runtime.functionobjects.DiagIndex;
@@ -40,6 +42,7 @@ import org.apache.sysml.runtime.functionobjects.SwapIndex;
 import org.apache.sysml.runtime.instructions.InstructionUtils;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
 import 
org.apache.sysml.runtime.instructions.spark.functions.FilterDiagBlocksFunction;
+import org.apache.sysml.runtime.instructions.spark.functions.IsBlockInList;
 import org.apache.sysml.runtime.instructions.spark.functions.IsBlockInRange;
 import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
 import org.apache.sysml.runtime.instructions.spark.utils.RDDSortUtils;
@@ -53,6 +56,7 @@ import 
org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
 import org.apache.sysml.runtime.matrix.operators.Operator;
 import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
 import org.apache.sysml.runtime.util.DataConverter;
+import org.apache.sysml.runtime.util.IndexRange;
 import org.apache.sysml.runtime.util.UtilFunctions;
 
 public class ReorgSPInstruction extends UnarySPInstruction {
@@ -162,33 +166,46 @@ public class ReorgSPInstruction extends 
UnarySPInstruction {
                        boolean desc = ec.getScalarInput(_desc.getName(), 
_desc.getValueType(), _desc.isLiteral()).getBooleanValue();
                        boolean ixret = ec.getScalarInput(_ixret.getName(), 
_ixret.getValueType(), _ixret.isLiteral()).getBooleanValue();
                        boolean singleCol = (mcIn.getCols() == 1);
-                       
-                       //error handling unsupported operations
-                       //TODO additional spark sort runtime with multiple 
order columns
-                       if( cols.length > 1 ) 
-                               LOG.warn("Unsupported sort with multiple 
order-by columns. Falling back first sort column.");
-                       long col = cols[0];
-                       
-                       // extract column (if necessary) and sort 
                        out = in1;
-                       if( !singleCol ){
-                               out = out.filter(new IsBlockInRange(1, 
mcIn.getRows(), col, col, mcIn))
-                                       .mapValues(new 
ExtractColumn((int)UtilFunctions.computeCellInBlock(col, 
mcIn.getColsPerBlock())));
-                       }
                        
-                       //actual index/data sort operation
-                       if( ixret ) { //sort indexes 
-                               out = RDDSortUtils.sortIndexesByVal(out, !desc, 
mcIn.getRows(), mcIn.getRowsPerBlock());
-                       }       
-                       else if( singleCol && !desc) { //sort single-column 
matrix
-                               out = RDDSortUtils.sortByVal(out, 
mcIn.getRows(), mcIn.getRowsPerBlock());
-                       }
-                       else { //sort multi-column matrix
-                               if (! _bSortIndInMem)
+                       if( cols.length > mcIn.getColsPerBlock() ) 
+                               LOG.warn("Unsupported sort with number of 
order-by columns large than blocksize: "+cols.length);
+                       
+                       if( singleCol || cols.length==1 ) {
+                               // extract column (if necessary) and sort 
+                               if( !singleCol )
+                                       out = out.filter(new IsBlockInRange(1, 
mcIn.getRows(), cols[0], cols[0], mcIn))
+                                               .mapValues(new 
ExtractColumn((int)UtilFunctions.computeCellInBlock(cols[0], 
mcIn.getColsPerBlock())));
+                               
+                               //actual index/data sort operation
+                               if( ixret ) //sort indexes 
+                                       out = 
RDDSortUtils.sortIndexesByVal(out, !desc, mcIn.getRows(), 
mcIn.getRowsPerBlock());
+                               else if( singleCol && !desc) //sort 
single-column matrix
+                                       out = RDDSortUtils.sortByVal(out, 
mcIn.getRows(), mcIn.getRowsPerBlock());
+                               else if( !_bSortIndInMem ) //sort multi-column 
matrix w/ rewrite
                                        out = RDDSortUtils.sortDataByVal(out, 
in1, !desc, mcIn.getRows(), mcIn.getCols(), mcIn.getRowsPerBlock(), 
mcIn.getColsPerBlock());
-                               else
+                               else //sort multi-column matrix
                                        out = 
RDDSortUtils.sortDataByValMemSort(out, in1, !desc, mcIn.getRows(), 
mcIn.getCols(), mcIn.getRowsPerBlock(), mcIn.getColsPerBlock(), sec, 
(ReorgOperator) _optr);
                        }
+                       else { //general case: multi-column sort
+                               // extract columns (if necessary)
+                               if( cols.length < mcIn.getCols() )
+                                       out = out.filter(new 
IsBlockInList(cols, mcIn))
+                                               .mapToPair(new 
ExtractColumns(cols, mcIn));
+                               
+                               // append extracted columns (if necessary)
+                               if( mcIn.getCols() > mcIn.getColsPerBlock() )
+                                       out = RDDAggregateUtils.mergeByKey(out);
+                               
+                               //actual index/data sort operation
+                               if( ixret ) //sort indexes 
+                                       out = 
RDDSortUtils.sortIndexesByVals(out, !desc, mcIn.getRows(), (long)cols.length, 
mcIn.getRowsPerBlock());
+                               else if( cols.length==mcIn.getCols() && !desc) 
//sort single-column matrix
+                                       out = RDDSortUtils.sortByVals(out, 
mcIn.getRows(), cols.length, mcIn.getRowsPerBlock());
+                               else //sort multi-column matrix
+                                       out = RDDSortUtils.sortDataByVals(out, 
in1, !desc, mcIn.getRows(), mcIn.getCols(),
+                                               cols.length, 
mcIn.getRowsPerBlock(), mcIn.getColsPerBlock());
+                       }
                }
                else {
                        throw new DMLRuntimeException("Error: Incorrect opcode 
in ReorgSPInstruction:" + opcode);
@@ -323,5 +340,34 @@ public class ReorgSPInstruction extends UnarySPInstruction 
{
                        return arg0.sliceOperations(0, arg0.getNumRows()-1, 
_col, _col, new MatrixBlock());
                }
        }
+       
+       private static class ExtractColumns implements 
PairFunction<Tuple2<MatrixIndexes, MatrixBlock>,MatrixIndexes,MatrixBlock>
+       {
+               private static final long serialVersionUID = 
2902729186431711506L;
+               
+               private final long[] _cols;
+               private final int _brlen, _bclen;
+               
+               public ExtractColumns(long[] cols, MatrixCharacteristics mc) {
+                       _cols = cols;
+                       _brlen = mc.getRowsPerBlock();
+                       _bclen = mc.getColsPerBlock();
+               }
+               
+               public Tuple2<MatrixIndexes, MatrixBlock> 
call(Tuple2<MatrixIndexes, MatrixBlock> arg0)
+                       throws Exception 
+               {
+                       MatrixIndexes ix = arg0._1();
+                       MatrixBlock in = arg0._2();
+                       MatrixBlock out = new MatrixBlock(in.getNumRows(), 
_cols.length, true);
+                       for(int i=0; i<_cols.length; i++)
+                               if( UtilFunctions.isInBlockRange(ix, _brlen, 
_bclen, new IndexRange(1, Long.MAX_VALUE, _cols[i], _cols[i])) ) {
+                                       int index = 
UtilFunctions.computeCellInBlock(_cols[i], _bclen);
+                                       
out.leftIndexingOperations(in.sliceOperations(0, in.getNumRows()-1, index, 
index, new MatrixBlock()),
+                                               0, in.getNumRows()-1, i, i, 
out, UpdateType.INPLACE);
+                               }
+                       return new Tuple2<>(new MatrixIndexes(ix.getRowIndex(), 
1), out);
+               }
+       }
 }
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/27cabbc4/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/IsBlockInList.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/IsBlockInList.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/IsBlockInList.java
new file mode 100644
index 0000000..d0f879e
--- /dev/null
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/IsBlockInList.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.instructions.spark.functions;
+
+import org.apache.spark.api.java.function.Function;
+
+import scala.Tuple2;
+
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysml.runtime.util.UtilFunctions;
+
+public class IsBlockInList implements 
Function<Tuple2<MatrixIndexes,MatrixBlock>, Boolean> 
+{
+       private static final long serialVersionUID = -1956151588590369875L;
+       
+       private final long[] _cols;
+       private final int _brlen, _bclen;
+       
+       public IsBlockInList(long[] cols, MatrixCharacteristics mc) {
+               _cols = cols;
+               _brlen = mc.getRowsPerBlock();
+               _bclen = mc.getColsPerBlock();
+       }
+
+       @Override
+       public Boolean call(Tuple2<MatrixIndexes, MatrixBlock> kv) 
+               throws Exception 
+       {
+               for( int i=0; i<_cols.length; i++ )
+                       if( UtilFunctions.isInBlockRange(kv._1(), _brlen, 
_bclen, 1, Long.MAX_VALUE, _cols[i], _cols[i]) )
+                               return true;
+               return false;
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/27cabbc4/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/IsBlockInRange.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/IsBlockInRange.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/IsBlockInRange.java
index 2ad5fdf..8a0fdd2 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/IsBlockInRange.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/IsBlockInRange.java
@@ -30,7 +30,6 @@ import org.apache.sysml.runtime.util.UtilFunctions;
 
 public class IsBlockInRange implements 
Function<Tuple2<MatrixIndexes,MatrixBlock>, Boolean> 
 {
-       
        private static final long serialVersionUID = 5849687296021280540L;
        
        private long _rl; long _ru; long _cl; long _cu;

http://git-wip-us.apache.org/repos/asf/systemml/blob/27cabbc4/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDSortUtils.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDSortUtils.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDSortUtils.java
index bf63f0d..9232374 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDSortUtils.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDSortUtils.java
@@ -35,16 +35,19 @@ import org.apache.spark.broadcast.Broadcast;
 import scala.Tuple2;
 
 import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType;
 import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import 
org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysml.runtime.functionobjects.SortIndex;
 import org.apache.sysml.runtime.instructions.spark.data.PartitionedBlock;
 import org.apache.sysml.runtime.instructions.spark.data.RowMatrixBlock;
 import 
org.apache.sysml.runtime.instructions.spark.functions.ReplicateVectorFunction;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
 import org.apache.sysml.runtime.util.DataConverter;
+import org.apache.sysml.runtime.util.SortUtils;
 import org.apache.sysml.runtime.util.UtilFunctions;
 
 public class RDDSortUtils 
@@ -65,88 +68,165 @@ public class RDDSortUtils
                //create binary block output
                JavaPairRDD<MatrixIndexes, MatrixBlock> ret = sdvals
                                .zipWithIndex()
-                       .mapPartitionsToPair(new 
ConvertToBinaryBlockFunction(rlen, brlen));
+                               .mapPartitionsToPair(new 
ConvertToBinaryBlockFunction(rlen, brlen));
                ret = RDDAggregateUtils.mergeByKey(ret, false);
                
                return ret;
        }
-
+       
        public static JavaPairRDD<MatrixIndexes, MatrixBlock> sortByVal( 
JavaPairRDD<MatrixIndexes, MatrixBlock> in, 
                        JavaPairRDD<MatrixIndexes, MatrixBlock> in2, long rlen, 
int brlen )
        {
                //create value-index rdd from inputs
                JavaRDD<DoublePair> dvals = in.join(in2).values()
-                               .flatMap(new ExtractDoubleValuesFunction2());
+                       .flatMap(new ExtractDoubleValuesFunction2());
        
                //sort (creates sorted range per partition)
                long hdfsBlocksize = InfrastructureAnalyzer.getHDFSBlockSize();
                int numPartitions = 
(int)Math.ceil(((double)rlen*8)/hdfsBlocksize);
                JavaRDD<DoublePair> sdvals = dvals
-                               .sortBy(new CreateDoubleKeyFunction2(), true, 
numPartitions);
+                       .sortBy(new CreateDoubleKeyFunction2(), true, 
numPartitions);
 
                //create binary block output
                JavaPairRDD<MatrixIndexes, MatrixBlock> ret = sdvals
-                               .zipWithIndex()
-                       .mapPartitionsToPair(new 
ConvertToBinaryBlockFunction2(rlen, brlen));
+                       .zipWithIndex()
+                       .mapPartitionsToPair(new 
ConvertToBinaryBlockFunction2(rlen, brlen));
                ret = RDDAggregateUtils.mergeByKey(ret, false);         
                
                return ret;
        }
 
+       public static JavaPairRDD<MatrixIndexes, MatrixBlock> sortByVals(
+               JavaPairRDD<MatrixIndexes, MatrixBlock> in, long rlen, long 
clen, int brlen )
+       {
+               //create value-index rdd from inputs
+               JavaRDD<MatrixBlock> dvals = in.values()
+                       .flatMap(new ExtractRowsFunction());
+               
+               //sort (creates sorted range per partition)
+               int numPartitions = SparkUtils.getNumPreferredPartitions(
+                       new MatrixCharacteristics(rlen, clen, brlen, brlen), 
in);
+               JavaRDD<MatrixBlock> sdvals = dvals
+                       .sortBy(new CreateDoubleKeysFunction(), true, 
numPartitions);
+               
+               //create binary block output
+               JavaPairRDD<MatrixIndexes, MatrixBlock> ret = sdvals
+                       .zipWithIndex()
+                       .mapPartitionsToPair(new 
ConvertToBinaryBlockFunction5(rlen, brlen));
+               ret = RDDAggregateUtils.mergeByKey(ret, false);
+               
+               return ret;
+       }
+       
        public static JavaPairRDD<MatrixIndexes, MatrixBlock> sortIndexesByVal( 
JavaPairRDD<MatrixIndexes, MatrixBlock> val, 
                        boolean asc, long rlen, int brlen )
        {
                //create value-index rdd from inputs
                JavaPairRDD<ValueIndexPair, Double> dvals = val
-                               .flatMapToPair(new 
ExtractDoubleValuesWithIndexFunction(brlen));
-       
+                       .flatMapToPair(new 
ExtractDoubleValuesWithIndexFunction(brlen));
+               
                //sort (creates sorted range per partition)
                long hdfsBlocksize = InfrastructureAnalyzer.getHDFSBlockSize();
                int numPartitions = 
(int)Math.ceil(((double)rlen*16)/hdfsBlocksize);
                JavaRDD<ValueIndexPair> sdvals = dvals
-                               .sortByKey(new IndexComparator(asc), true, 
numPartitions)
-                               .keys(); //workaround for index comparator
-        
+                       .sortByKey(new IndexComparator(asc), true, 
numPartitions)
+                       .keys(); //workaround for index comparator
+               
                //create binary block output
                JavaPairRDD<MatrixIndexes, MatrixBlock> ret = sdvals
-                               .zipWithIndex()
-                       .mapPartitionsToPair(new 
ConvertToBinaryBlockFunction3(rlen, brlen));
-               ret = RDDAggregateUtils.mergeByKey(ret, false);         
+                       .zipWithIndex()
+                       .mapPartitionsToPair(new 
ConvertToBinaryBlockFunction3(rlen, brlen));
+               ret = RDDAggregateUtils.mergeByKey(ret, false);
                
-               return ret;     
+               return ret;
+       }
+       
+       public static JavaPairRDD<MatrixIndexes, MatrixBlock> 
sortIndexesByVals( JavaPairRDD<MatrixIndexes, MatrixBlock> in,
+                       boolean asc, long rlen, long clen, int brlen )
+       {
+               //create value-index rdd from inputs
+               JavaPairRDD<ValuesIndexPair, double[]> dvals = in
+                       .flatMapToPair(new 
ExtractDoubleValuesWithIndexFunction2(brlen));
+               
+               //sort (creates sorted range per partition)
+               int numPartitions = SparkUtils.getNumPreferredPartitions(
+                       new MatrixCharacteristics(rlen, clen+1, brlen, brlen));
+               JavaRDD<ValuesIndexPair> sdvals = dvals
+                       .sortByKey(new IndexComparator2(asc), true, 
numPartitions)
+                       .keys(); //workaround for index comparator
+               
+               //create binary block output
+               JavaPairRDD<MatrixIndexes, MatrixBlock> ret = sdvals
+                       .zipWithIndex()
+                       .mapPartitionsToPair(new 
ConvertToBinaryBlockFunction6(rlen, brlen));
+               ret = RDDAggregateUtils.mergeByKey(ret, false);
+               
+               return ret;
        }
 
        public static JavaPairRDD<MatrixIndexes, MatrixBlock> sortDataByVal( 
JavaPairRDD<MatrixIndexes, MatrixBlock> val, 
-                       JavaPairRDD<MatrixIndexes, MatrixBlock> data, boolean 
asc, long rlen, long clen, int brlen, int bclen )
+               JavaPairRDD<MatrixIndexes, MatrixBlock> data, boolean asc, long 
rlen, long clen, int brlen, int bclen )
        {
                //create value-index rdd from inputs
                JavaPairRDD<ValueIndexPair, Double> dvals = val
-                               .flatMapToPair(new 
ExtractDoubleValuesWithIndexFunction(brlen));
-       
+                       .flatMapToPair(new 
ExtractDoubleValuesWithIndexFunction(brlen));
+               
                //sort (creates sorted range per partition)
                long hdfsBlocksize = InfrastructureAnalyzer.getHDFSBlockSize();
                int numPartitions = 
(int)Math.ceil(((double)rlen*16)/hdfsBlocksize);
                JavaRDD<ValueIndexPair> sdvals = dvals
-                               .sortByKey(new IndexComparator(asc), true, 
numPartitions)
-                               .keys(); //workaround for index comparator
-        
+                       .sortByKey(new IndexComparator(asc), true, 
numPartitions)
+                       .keys(); //workaround for index comparator
+               
                //create target indexes by original index
-               long numRep = (long)Math.ceil((double)clen/bclen);
                JavaPairRDD<MatrixIndexes, MatrixBlock> ixmap = sdvals
-                               .zipWithIndex()
-                               .mapToPair(new ExtractIndexFunction())
-                               .sortByKey()
-                       .mapPartitionsToPair(new 
ConvertToBinaryBlockFunction4(rlen, brlen));
-               ixmap = RDDAggregateUtils.mergeByKey(ixmap, false);             
+                       .zipWithIndex()
+                       .mapToPair(new ExtractIndexFunction())
+                       .sortByKey()
+                       .mapPartitionsToPair(new 
ConvertToBinaryBlockFunction4(rlen, brlen));
+               ixmap = RDDAggregateUtils.mergeByKey(ixmap, false);
+               
+               //actual data sort
+               return sortDataByIx(data, ixmap, rlen, clen, brlen, bclen);
+       }
+       
+       public static JavaPairRDD<MatrixIndexes, MatrixBlock> sortDataByVals( 
JavaPairRDD<MatrixIndexes, MatrixBlock> val, 
+               JavaPairRDD<MatrixIndexes, MatrixBlock> data, boolean asc, long 
rlen, long clen, long clen2, int brlen, int bclen )
+       {
+               //create value-index rdd from inputs
+               JavaPairRDD<ValuesIndexPair, double[]> dvals = val
+                       .flatMapToPair(new 
ExtractDoubleValuesWithIndexFunction2(brlen));
                
+               //sort (creates sorted range per partition)
+               int numPartitions = SparkUtils.getNumPreferredPartitions(
+                       new MatrixCharacteristics(rlen, clen2+1, brlen, brlen));
+               JavaRDD<ValuesIndexPair> sdvals = dvals
+                       .sortByKey(new IndexComparator2(asc), true, 
numPartitions)
+                       .keys(); //workaround for index comparator
+               
+               //create target indexes by original index
+               JavaPairRDD<MatrixIndexes, MatrixBlock> ixmap = sdvals
+                       .zipWithIndex()
+                       .mapToPair(new ExtractIndexFunction2())
+                       .sortByKey()
+                       .mapPartitionsToPair(new 
ConvertToBinaryBlockFunction4(rlen, brlen));
+               ixmap = RDDAggregateUtils.mergeByKey(ixmap, false);
+               
+               //actual data sort
+               return sortDataByIx(data, ixmap, rlen, clen, brlen, bclen);
+       }
+       
+       public static JavaPairRDD<MatrixIndexes, MatrixBlock> 
sortDataByIx(JavaPairRDD<MatrixIndexes,MatrixBlock> data,
+               JavaPairRDD<MatrixIndexes,MatrixBlock> ixmap, long rlen, long 
clen, int brlen, int bclen) {
                //replicate indexes for all column blocks
+               long numRep = (long)Math.ceil((double)clen/bclen);
                JavaPairRDD<MatrixIndexes, MatrixBlock> rixmap = ixmap
-                               .flatMapToPair(new 
ReplicateVectorFunction(false, numRep));      
+                       .flatMapToPair(new ReplicateVectorFunction(false, 
numRep));
                
                //create binary block output
                JavaPairRDD<MatrixIndexes, RowMatrixBlock> ret = data
-                               .join(rixmap)
-                               .mapPartitionsToPair(new 
ShuffleMatrixBlockRowsFunction(rlen, brlen));
+                       .join(rixmap)
+                       .mapPartitionsToPair(new 
ShuffleMatrixBlockRowsFunction(rlen, brlen));
                return RDDAggregateUtils.mergeRowsByKey(ret);
        }
        
@@ -200,10 +280,23 @@ public class RDDSortUtils
 
                @Override
                public Iterator<Double> call(MatrixBlock arg0) 
-                       throws Exception 
-               {
+                       throws Exception {
                        return 
DataConverter.convertToDoubleList(arg0).iterator();
-               }               
+               }
+       }
+       
+       private static class ExtractRowsFunction implements 
FlatMapFunction<MatrixBlock,MatrixBlock> 
+       {
+               private static final long serialVersionUID = 
-2786968469468554974L;
+
+               @Override
+               public Iterator<MatrixBlock> call(MatrixBlock arg0) 
+                       throws Exception {
+                       ArrayList<MatrixBlock> rows = new ArrayList<>();
+                       for(int i=0; i<arg0.getNumRows(); i++)
+                               rows.add(arg0.sliceOperations(i, i, 0, 
arg0.getNumColumns()-1, new MatrixBlock()));
+                       return rows.iterator();
+               }
        }
 
        private static class ExtractDoubleValuesFunction2 implements 
FlatMapFunction<Tuple2<MatrixBlock,MatrixBlock>,DoublePair> 
@@ -256,6 +349,35 @@ public class RDDSortUtils
                        return ret.iterator();
                }
        }
+       
+       private static class ExtractDoubleValuesWithIndexFunction2 implements 
PairFlatMapFunction<Tuple2<MatrixIndexes,MatrixBlock>,ValuesIndexPair,double[]> 
+       {
+               private static final long serialVersionUID = 
8358254634903633283L;
+               
+               private final int _brlen;
+               
+               public ExtractDoubleValuesWithIndexFunction2(int brlen) {
+                       _brlen = brlen;
+               }
+               
+               @Override
+               public Iterator<Tuple2<ValuesIndexPair,double[]>> 
call(Tuple2<MatrixIndexes,MatrixBlock> arg0) 
+                       throws Exception 
+               {
+                       ArrayList<Tuple2<ValuesIndexPair,double[]>> ret = new 
ArrayList<>(); 
+                       MatrixIndexes ix = arg0._1();
+                       MatrixBlock mb = arg0._2();
+                       
+                       long ixoffset = (ix.getRowIndex()-1)*_brlen;
+                       for( int i=0; i<mb.getNumRows(); i++) {
+                               double[] vals = 
DataConverter.convertToDoubleVector(
+                                       mb.sliceOperations(i, i, 0, 
mb.getNumColumns()-1, new MatrixBlock()));
+                               ret.add(new Tuple2<>(new 
ValuesIndexPair(vals,ixoffset+i+1), vals));
+                       }
+                       
+                       return ret.iterator();
+               }
+       }
 
        private static class CreateDoubleKeyFunction implements 
Function<Double,Double> 
        {
@@ -266,7 +388,7 @@ public class RDDSortUtils
                        throws Exception 
                {
                        return arg0;
-               }               
+               }
        }
 
        private static class CreateDoubleKeyFunction2 implements 
Function<DoublePair,Double> 
@@ -278,20 +400,35 @@ public class RDDSortUtils
                        throws Exception 
                {
                        return arg0.val1;
-               }               
+               }
        }
 
-       private static class ExtractIndexFunction implements 
PairFunction<Tuple2<ValueIndexPair,Long>,Long,Long> 
+       private static class CreateDoubleKeysFunction implements 
Function<MatrixBlock,double[]> 
        {
+               private static final long serialVersionUID = 
4316858496746520340L;
+
+               @Override
+               public double[] call(MatrixBlock row) throws Exception {
+                       return DataConverter.convertToDoubleVector(row);
+               }
+       }
+       
+       private static class ExtractIndexFunction implements 
PairFunction<Tuple2<ValueIndexPair,Long>,Long,Long> {
                private static final long serialVersionUID = 
-4553468724131249535L;
 
                @Override
-               public Tuple2<Long, Long> call(Tuple2<ValueIndexPair,Long> arg0)
-                               throws Exception 
-               {
+               public Tuple2<Long, Long> call(Tuple2<ValueIndexPair,Long> 
arg0) throws Exception {
                        return new Tuple2<>(arg0._1().ix, arg0._2());
                }
+       }
+       
+       private static class ExtractIndexFunction2 implements 
PairFunction<Tuple2<ValuesIndexPair,Long>,Long,Long> {
+               private static final long serialVersionUID = 
-1366455446597907270L;
 
+               @Override
+               public Tuple2<Long, Long> call(Tuple2<ValuesIndexPair,Long> 
arg0) throws Exception {
+                       return new Tuple2<>(arg0._1().ix, arg0._2());
+               }
        }
 
        private static class ConvertToBinaryBlockFunction implements 
PairFlatMapFunction<Iterator<Tuple2<Double,Long>>,MatrixIndexes,MatrixBlock> 
@@ -485,6 +622,98 @@ public class RDDSortUtils
                }
        }
        
+       private static class ConvertToBinaryBlockFunction5 implements 
PairFlatMapFunction<Iterator<Tuple2<MatrixBlock,Long>>,MatrixIndexes,MatrixBlock>
 
+       {
+               private static final long serialVersionUID = 
6357994683868091724L;
+               
+               private long _rlen = -1;
+               private int _brlen = -1;
+               
+               public ConvertToBinaryBlockFunction5(long rlen, int brlen)
+               {
+                       _rlen = rlen;
+                       _brlen = brlen;
+               }
+               
+               public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> 
call(Iterator<Tuple2<MatrixBlock,Long>> arg0) 
+                       throws Exception 
+               {
+                       ArrayList<Tuple2<MatrixIndexes,MatrixBlock>> ret = new 
ArrayList<>();
+                       MatrixIndexes ix = null;
+                       MatrixBlock mb = null;
+                       
+                       while( arg0.hasNext() ) 
+                       {
+                               Tuple2<MatrixBlock,Long> val = arg0.next();
+                               long valix = val._2 + 1;
+                               long rix = 
UtilFunctions.computeBlockIndex(valix, _brlen);
+                               int pos = 
UtilFunctions.computeCellInBlock(valix, _brlen);
+                               
+                               if( ix == null || ix.getRowIndex() != rix ) {
+                                       if( ix !=null )
+                                               ret.add(new Tuple2<>(ix,mb));
+                                       long len = 
UtilFunctions.computeBlockSize(_rlen, rix, _brlen);
+                                       ix = new MatrixIndexes(rix,1);
+                                       mb = new MatrixBlock((int)len, 
val._1.getNumColumns(), false);
+                               }
+                               
+                               mb.leftIndexingOperations(val._1, pos, pos, 0, 
val._1.getNumColumns()-1, mb, UpdateType.INPLACE);
+                       }
+                       
+                       //flush last block
+                       if( mb!=null && mb.getNonZeros() != 0 )
+                               ret.add(new Tuple2<>(ix,mb));
+                       return ret.iterator();
+               }
+       }
+       
+       private static class ConvertToBinaryBlockFunction6 implements 
PairFlatMapFunction<Iterator<Tuple2<ValuesIndexPair,Long>>,MatrixIndexes,MatrixBlock>
 
+       {
+               private static final long serialVersionUID = 
5351649694631911694L;
+               
+               private long _rlen = -1;
+               private int _brlen = -1;
+               
+               public ConvertToBinaryBlockFunction6(long rlen, int brlen)
+               {
+                       _rlen = rlen;
+                       _brlen = brlen;
+               }
+               
+               public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> 
call(Iterator<Tuple2<ValuesIndexPair,Long>> arg0) 
+                       throws Exception
+               {
+                       ArrayList<Tuple2<MatrixIndexes,MatrixBlock>> ret = new 
ArrayList<>();
+                       
+                       MatrixIndexes ix = null;
+                       MatrixBlock mb = null;
+                       
+                       while( arg0.hasNext() ) 
+                       {
+                               Tuple2<ValuesIndexPair,Long> val = arg0.next();
+                               long valix = val._2 + 1;
+                               long rix = 
UtilFunctions.computeBlockIndex(valix, _brlen);
+                               int pos = 
UtilFunctions.computeCellInBlock(valix, _brlen);
+                               
+                               if( ix == null || ix.getRowIndex() != rix ) {
+                                       if( ix !=null )
+                                               ret.add(new Tuple2<>(ix,mb));
+                                       long len = 
UtilFunctions.computeBlockSize(_rlen, rix, _brlen);
+                                       ix = new MatrixIndexes(rix,1);
+                                       mb = new MatrixBlock((int)len, 1, 
false);
+                               }
+                               
+                               mb.quickSetValue(pos, 0, val._1.ix);
+                       }
+                       
+                       //flush last block
+                       if( mb!=null && mb.getNonZeros() != 0 )
+                               ret.add(new Tuple2<>(ix,mb));
+                       
+                       return ret.iterator();
+               }
+       }
+       
        private static class ShuffleMatrixBlockRowsFunction implements 
PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes,Tuple2<MatrixBlock,MatrixBlock>>>,MatrixIndexes,RowMatrixBlock>
 
        {       
                private static final long serialVersionUID = 
6885207719329119646L;
@@ -690,6 +919,19 @@ public class RDDSortUtils
                }
        }
        
+       private static class ValuesIndexPair implements Serializable 
+       {
+               private static final long serialVersionUID = 
4297433409147784971L;
+               
+               public double[] vals;
+               public long ix; 
+
+               public ValuesIndexPair(double[] dvals, long lix) {
+                       vals = dvals;
+                       ix = lix;
+               }
+       }
+       
        public static class IndexComparator implements 
Comparator<ValueIndexPair>, Serializable 
        {
                private static final long serialVersionUID = 
5154839870549241343L;
@@ -700,18 +942,32 @@ public class RDDSortUtils
                }
                        
                @Override
-               public int compare(ValueIndexPair o1, ValueIndexPair o2) 
+               public int compare(ValueIndexPair o1, ValueIndexPair o2) {
+                       int retVal = Double.compare(o1.val, o2.val);
+                       if(retVal != 0)
+                               return (_asc ? retVal : -1*retVal);
+                       else //for stable sort
+                               return Long.compare(o1.ix, o2.ix);
+               }
+       }
+       
+       public static class IndexComparator2 implements 
Comparator<ValuesIndexPair>, Serializable 
+       {
+               private static final long serialVersionUID = 
5531987863790922691L;
+               
+               private boolean _asc;
+               public IndexComparator2(boolean asc) {
+                       _asc = asc;
+               }
+               
+               @Override
+               public int compare(ValuesIndexPair o1, ValuesIndexPair o2) 
                {
-                       //note: use conversion to Double and Long instead of 
native
-                       //compare for compatibility with jdk 6
-                       int retVal = Double.valueOf(o1.val).compareTo(o2.val);
-                       if(retVal != 0) {
+                       int retVal = SortUtils.compare(o1.vals, o2.vals);
+                       if(retVal != 0)
                                return (_asc ? retVal : -1*retVal);
-                       }
-                       else {
-                               //for stable sort
-                               return Long.valueOf(o1.ix).compareTo(o2.ix);
-                       }
+                       else //for stable sort
+                               return Long.compare(o1.ix, o2.ix);
                }
                
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/27cabbc4/src/main/java/org/apache/sysml/runtime/util/SortUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/util/SortUtils.java 
b/src/main/java/org/apache/sysml/runtime/util/SortUtils.java
index c41f3ac..ff90784 100644
--- a/src/main/java/org/apache/sysml/runtime/util/SortUtils.java
+++ b/src/main/java/org/apache/sysml/runtime/util/SortUtils.java
@@ -35,14 +35,14 @@ public class SortUtils
        public static boolean isSorted(int start, int end, int[] indexes) {
                boolean ret = true;
                for( int i=start+1; i<end && ret; i++ )
-               ret &= (indexes[i]<indexes[i-1]);
+                       ret &= (indexes[i]<indexes[i-1]);
                return ret;
        }
 
        public static boolean isSorted(int start, int end, double[] values) {
                boolean ret = true;
                for( int i=start+1; i<end && ret; i++ )
-               ret &= (values[i]<values[i-1]);
+                       ret &= (values[i]<values[i-1]);
                return ret;
        }
        
@@ -51,6 +51,18 @@ public class SortUtils
                        isSorted(0, in.getNumRows()*in.getNumColumns(), 
in.getDenseBlock());
        }
        
+       public static int compare(double[] d1, double[] d2) {
+               if( d1 == null || d2 == null )
+                       throw new RuntimeException("Invalid invocation w/ null 
parameter.");
+               int ret = Long.compare(d1.length, d2.length);
+               if( ret != 0 ) return ret;
+               for(int i=0; i<d1.length; i++) {
+                       ret = Double.compare(d1[i], d2[i]);
+                       if( ret != 0 ) return ret;
+               }
+               return 0;
+       }
+       
        /**
         * In-place sort of two arrays, only indexes is used for comparison and 
values
         * of same position are sorted accordingly. 

http://git-wip-us.apache.org/repos/asf/systemml/blob/27cabbc4/src/test/java/org/apache/sysml/test/integration/functions/reorg/MultipleOrderByColsTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/reorg/MultipleOrderByColsTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/reorg/MultipleOrderByColsTest.java
index 10dc1a4..5f11038 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/reorg/MultipleOrderByColsTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/reorg/MultipleOrderByColsTest.java
@@ -114,46 +114,45 @@ public class MultipleOrderByColsTest extends 
AutomatedTestBase
                runOrderTest(TEST_NAME2, true, true, false, ExecType.CP);
        }
        
-//TODO enable together with additional spark sort runtime
-//     @Test
-//     public void testOrderDenseAscDataSP() {
-//             runOrderTest(TEST_NAME1, false, false, false, ExecType.SPARK);
-//     }
-//     
-//     @Test
-//     public void testOrderDenseAscIxSP() {
-//             runOrderTest(TEST_NAME1, false, false, true, ExecType.SPARK);
-//     }
-//     
-//     @Test
-//     public void testOrderDenseDescDataSP() {
-//             runOrderTest(TEST_NAME1, false, true, false, ExecType.SPARK);
-//     }
-//     
-//     @Test
-//     public void testOrderDenseDescIxSP() {
-//             runOrderTest(TEST_NAME1, false, true, true, ExecType.SPARK);
-//     }
-//     
-//     @Test
-//     public void testOrderSparseAscDataSP() {
-//             runOrderTest(TEST_NAME1, true, false, false, ExecType.SPARK);
-//     }
-//     
-//     @Test
-//     public void testOrderSparseAscIxSP() {
-//             runOrderTest(TEST_NAME1, true, false, true, ExecType.SPARK);
-//     }
-//     
-//     @Test
-//     public void testOrderSparseDescDataSP() {
-//             runOrderTest(TEST_NAME1, true, true, false, ExecType.SPARK);
-//     }
-//     
-//     @Test
-//     public void testOrderSparseDescIxSP() {
-//             runOrderTest(TEST_NAME1, true, true, true, ExecType.SPARK);
-//     }
+       @Test
+       public void testOrderDenseAscDataSP() {
+               runOrderTest(TEST_NAME1, false, false, false, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testOrderDenseAscIxSP() {
+               runOrderTest(TEST_NAME1, false, false, true, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testOrderDenseDescDataSP() {
+               runOrderTest(TEST_NAME1, false, true, false, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testOrderDenseDescIxSP() {
+               runOrderTest(TEST_NAME1, false, true, true, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testOrderSparseAscDataSP() {
+               runOrderTest(TEST_NAME1, true, false, false, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testOrderSparseAscIxSP() {
+               runOrderTest(TEST_NAME1, true, false, true, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testOrderSparseDescDataSP() {
+               runOrderTest(TEST_NAME1, true, true, false, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testOrderSparseDescIxSP() {
+               runOrderTest(TEST_NAME1, true, true, true, ExecType.SPARK);
+       }
        
        private void runOrderTest( String testname, boolean sparse, boolean 
desc, boolean ixret, ExecType et)
        {
@@ -161,11 +160,11 @@ public class MultipleOrderByColsTest extends 
AutomatedTestBase
                switch( et ){
                        case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
                        case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
-                       default: rtplatform = RUNTIME_PLATFORM.HYBRID; break;
+                       default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; 
break;
                }
        
                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
-               if( rtplatform == RUNTIME_PLATFORM.SPARK )
+               if( rtplatform == RUNTIME_PLATFORM.SPARK || rtplatform == 
RUNTIME_PLATFORM.HYBRID_SPARK )
                        DMLScript.USE_LOCAL_SPARK_CONFIG = true;
                
                try

Reply via email to