Repository: systemml Updated Branches: refs/heads/master 4bc1fea87 -> 5df6ab6dd
[SYSTEMML-2007] New spark order operations w/ multiple order-by cols This patch adds runtime support for distributed spark operations regarding the recently added order w/ multiple order-by columns. We now also enable the related automatic rewrite of consecutive order calls for CP and Spark execution types. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/27cabbc4 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/27cabbc4 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/27cabbc4 Branch: refs/heads/master Commit: 27cabbc4730377d9e8e34d06855106687123c240 Parents: 4bc1fea Author: Matthias Boehm <[email protected]> Authored: Tue Nov 14 17:32:38 2017 -0800 Committer: Matthias Boehm <[email protected]> Committed: Tue Nov 14 18:45:00 2017 -0800 ---------------------------------------------------------------------- .../java/org/apache/sysml/hops/ReorgOp.java | 3 +- .../RewriteAlgebraicSimplificationStatic.java | 3 +- .../instructions/spark/ReorgSPInstruction.java | 90 +++-- .../spark/functions/IsBlockInList.java | 53 +++ .../spark/functions/IsBlockInRange.java | 1 - .../instructions/spark/utils/RDDSortUtils.java | 354 ++++++++++++++++--- .../apache/sysml/runtime/util/SortUtils.java | 16 +- .../reorg/MultipleOrderByColsTest.java | 83 +++-- 8 files changed, 485 insertions(+), 118 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/27cabbc4/src/main/java/org/apache/sysml/hops/ReorgOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ReorgOp.java b/src/main/java/org/apache/sysml/hops/ReorgOp.java index 4b55c9b..4d29336 100644 --- a/src/main/java/org/apache/sysml/hops/ReorgOp.java +++ b/src/main/java/org/apache/sysml/hops/ReorgOp.java @@ -372,7 +372,8 @@ public class ReorgOp extends Hop implements MultiThreadedHop } } else if( 
et==ExecType.SPARK ) { - boolean sortRewrite = !FORCE_DIST_SORT_INDEXES && isSortSPRewriteApplicable(); + boolean sortRewrite = !FORCE_DIST_SORT_INDEXES + && isSortSPRewriteApplicable() && by.getDataType().isScalar(); Lop transform1 = constructCPOrSparkSortLop(input, by, desc, ixret, et, sortRewrite); setOutputDimensions(transform1); setLineNumbers(transform1); http://git-wip-us.apache.org/repos/asf/systemml/blob/27cabbc4/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index cbfb527..d71c4e0 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -1509,7 +1509,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule if( HopRewriteUtils.isReorg(hi, ReOrgOp.SORT) && hi.getInput().get(1) instanceof LiteralOp //scalar by && hi.getInput().get(2) instanceof LiteralOp //scalar desc - && HopRewriteUtils.isLiteralOfValue(hi.getInput().get(3), false) ) //not ixret + && HopRewriteUtils.isLiteralOfValue(hi.getInput().get(3), false) //not ixret + && !OptimizerUtils.isHadoopExecutionMode() ) { LiteralOp by = (LiteralOp) hi.getInput().get(1); boolean desc = HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(2)); http://git-wip-us.apache.org/repos/asf/systemml/blob/27cabbc4/src/main/java/org/apache/sysml/runtime/instructions/spark/ReorgSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/ReorgSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/ReorgSPInstruction.java index c742a0b..8e11a55 
100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/ReorgSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/ReorgSPInstruction.java @@ -25,12 +25,14 @@ import java.util.Iterator; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.PairFlatMapFunction; +import org.apache.spark.api.java.function.PairFunction; import scala.Tuple2; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysml.runtime.functionobjects.DiagIndex; @@ -40,6 +42,7 @@ import org.apache.sysml.runtime.functionobjects.SwapIndex; import org.apache.sysml.runtime.instructions.InstructionUtils; import org.apache.sysml.runtime.instructions.cp.CPOperand; import org.apache.sysml.runtime.instructions.spark.functions.FilterDiagBlocksFunction; +import org.apache.sysml.runtime.instructions.spark.functions.IsBlockInList; import org.apache.sysml.runtime.instructions.spark.functions.IsBlockInRange; import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils; import org.apache.sysml.runtime.instructions.spark.utils.RDDSortUtils; @@ -53,6 +56,7 @@ import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue; import org.apache.sysml.runtime.matrix.operators.Operator; import org.apache.sysml.runtime.matrix.operators.ReorgOperator; import org.apache.sysml.runtime.util.DataConverter; +import org.apache.sysml.runtime.util.IndexRange; import org.apache.sysml.runtime.util.UtilFunctions; public class ReorgSPInstruction extends UnarySPInstruction { @@ -162,33 +166,46 @@ public class ReorgSPInstruction extends 
UnarySPInstruction { boolean desc = ec.getScalarInput(_desc.getName(), _desc.getValueType(), _desc.isLiteral()).getBooleanValue(); boolean ixret = ec.getScalarInput(_ixret.getName(), _ixret.getValueType(), _ixret.isLiteral()).getBooleanValue(); boolean singleCol = (mcIn.getCols() == 1); - - //error handling unsupported operations - //TODO additional spark sort runtime with multiple order columns - if( cols.length > 1 ) - LOG.warn("Unsupported sort with multiple order-by columns. Falling back first sort column."); - long col = cols[0]; - - // extract column (if necessary) and sort out = in1; - if( !singleCol ){ - out = out.filter(new IsBlockInRange(1, mcIn.getRows(), col, col, mcIn)) - .mapValues(new ExtractColumn((int)UtilFunctions.computeCellInBlock(col, mcIn.getColsPerBlock()))); - } - //actual index/data sort operation - if( ixret ) { //sort indexes - out = RDDSortUtils.sortIndexesByVal(out, !desc, mcIn.getRows(), mcIn.getRowsPerBlock()); - } - else if( singleCol && !desc) { //sort single-column matrix - out = RDDSortUtils.sortByVal(out, mcIn.getRows(), mcIn.getRowsPerBlock()); - } - else { //sort multi-column matrix - if (! 
_bSortIndInMem) + if( cols.length > mcIn.getColsPerBlock() ) + LOG.warn("Unsupported sort with number of order-by columns larger than blocksize: "+cols.length); + + if( singleCol || cols.length==1 ) { + // extract column (if necessary) and sort + if( !singleCol ) + out = out.filter(new IsBlockInRange(1, mcIn.getRows(), cols[0], cols[0], mcIn)) + .mapValues(new ExtractColumn((int)UtilFunctions.computeCellInBlock(cols[0], mcIn.getColsPerBlock()))); + + //actual index/data sort operation + if( ixret ) //sort indexes + out = RDDSortUtils.sortIndexesByVal(out, !desc, mcIn.getRows(), mcIn.getRowsPerBlock()); + else if( singleCol && !desc) //sort single-column matrix + out = RDDSortUtils.sortByVal(out, mcIn.getRows(), mcIn.getRowsPerBlock()); + else if( !_bSortIndInMem ) //sort multi-column matrix w/ rewrite out = RDDSortUtils.sortDataByVal(out, in1, !desc, mcIn.getRows(), mcIn.getCols(), mcIn.getRowsPerBlock(), mcIn.getColsPerBlock()); - else + else //sort multi-column matrix out = RDDSortUtils.sortDataByValMemSort(out, in1, !desc, mcIn.getRows(), mcIn.getCols(), mcIn.getRowsPerBlock(), mcIn.getColsPerBlock(), sec, (ReorgOperator) _optr); } + else { //general case: multi-column sort + // extract columns (if necessary) + if( cols.length < mcIn.getCols() ) + out = out.filter(new IsBlockInList(cols, mcIn)) + .mapToPair(new ExtractColumns(cols, mcIn)); + + // append extracted columns (if necessary) + if( mcIn.getCols() > mcIn.getColsPerBlock() ) + out = RDDAggregateUtils.mergeByKey(out); + + //actual index/data sort operation + if( ixret ) //sort indexes + out = RDDSortUtils.sortIndexesByVals(out, !desc, mcIn.getRows(), (long)cols.length, mcIn.getRowsPerBlock()); + else if( cols.length==mcIn.getCols() && !desc) //sort single-column matrix + out = RDDSortUtils.sortByVals(out, mcIn.getRows(), cols.length, mcIn.getRowsPerBlock()); + else //sort multi-column matrix + out = RDDSortUtils.sortDataByVals(out, in1, !desc, mcIn.getRows(), mcIn.getCols(), + cols.length, 
mcIn.getRowsPerBlock(), mcIn.getColsPerBlock()); + } } else { throw new DMLRuntimeException("Error: Incorrect opcode in ReorgSPInstruction:" + opcode); @@ -323,5 +340,34 @@ public class ReorgSPInstruction extends UnarySPInstruction { return arg0.sliceOperations(0, arg0.getNumRows()-1, _col, _col, new MatrixBlock()); } } + + private static class ExtractColumns implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>,MatrixIndexes,MatrixBlock> + { + private static final long serialVersionUID = 2902729186431711506L; + + private final long[] _cols; + private final int _brlen, _bclen; + + public ExtractColumns(long[] cols, MatrixCharacteristics mc) { + _cols = cols; + _brlen = mc.getRowsPerBlock(); + _bclen = mc.getColsPerBlock(); + } + + public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) + throws Exception + { + MatrixIndexes ix = arg0._1(); + MatrixBlock in = arg0._2(); + MatrixBlock out = new MatrixBlock(in.getNumRows(), _cols.length, true); + for(int i=0; i<_cols.length; i++) + if( UtilFunctions.isInBlockRange(ix, _brlen, _bclen, new IndexRange(1, Long.MAX_VALUE, _cols[i], _cols[i])) ) { + int index = UtilFunctions.computeCellInBlock(_cols[i], _bclen); + out.leftIndexingOperations(in.sliceOperations(0, in.getNumRows()-1, index, index, new MatrixBlock()), + 0, in.getNumRows()-1, i, i, out, UpdateType.INPLACE); + } + return new Tuple2<>(new MatrixIndexes(ix.getRowIndex(), 1), out); + } + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/27cabbc4/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/IsBlockInList.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/IsBlockInList.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/IsBlockInList.java new file mode 100644 index 0000000..d0f879e --- /dev/null +++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/IsBlockInList.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.instructions.spark.functions; + +import org.apache.spark.api.java.function.Function; + +import scala.Tuple2; + +import org.apache.sysml.runtime.matrix.MatrixCharacteristics; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.matrix.data.MatrixIndexes; +import org.apache.sysml.runtime.util.UtilFunctions; + +public class IsBlockInList implements Function<Tuple2<MatrixIndexes,MatrixBlock>, Boolean> +{ + private static final long serialVersionUID = -1956151588590369875L; + + private final long[] _cols; + private final int _brlen, _bclen; + + public IsBlockInList(long[] cols, MatrixCharacteristics mc) { + _cols = cols; + _brlen = mc.getRowsPerBlock(); + _bclen = mc.getColsPerBlock(); + } + + @Override + public Boolean call(Tuple2<MatrixIndexes, MatrixBlock> kv) + throws Exception + { + for( int i=0; i<_cols.length; i++ ) + if( UtilFunctions.isInBlockRange(kv._1(), _brlen, _bclen, 1, Long.MAX_VALUE, _cols[i], _cols[i]) ) + return true; + return false; + } +} 
http://git-wip-us.apache.org/repos/asf/systemml/blob/27cabbc4/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/IsBlockInRange.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/IsBlockInRange.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/IsBlockInRange.java index 2ad5fdf..8a0fdd2 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/IsBlockInRange.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/IsBlockInRange.java @@ -30,7 +30,6 @@ import org.apache.sysml.runtime.util.UtilFunctions; public class IsBlockInRange implements Function<Tuple2<MatrixIndexes,MatrixBlock>, Boolean> { - private static final long serialVersionUID = 5849687296021280540L; private long _rl; long _ru; long _cl; long _cu; http://git-wip-us.apache.org/repos/asf/systemml/blob/27cabbc4/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDSortUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDSortUtils.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDSortUtils.java index bf63f0d..9232374 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDSortUtils.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDSortUtils.java @@ -35,16 +35,19 @@ import org.apache.spark.broadcast.Broadcast; import scala.Tuple2; import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysml.runtime.functionobjects.SortIndex; import 
org.apache.sysml.runtime.instructions.spark.data.PartitionedBlock; import org.apache.sysml.runtime.instructions.spark.data.RowMatrixBlock; import org.apache.sysml.runtime.instructions.spark.functions.ReplicateVectorFunction; +import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; import org.apache.sysml.runtime.matrix.operators.ReorgOperator; import org.apache.sysml.runtime.util.DataConverter; +import org.apache.sysml.runtime.util.SortUtils; import org.apache.sysml.runtime.util.UtilFunctions; public class RDDSortUtils @@ -65,88 +68,165 @@ public class RDDSortUtils //create binary block output JavaPairRDD<MatrixIndexes, MatrixBlock> ret = sdvals .zipWithIndex() - .mapPartitionsToPair(new ConvertToBinaryBlockFunction(rlen, brlen)); + .mapPartitionsToPair(new ConvertToBinaryBlockFunction(rlen, brlen)); ret = RDDAggregateUtils.mergeByKey(ret, false); return ret; } - + public static JavaPairRDD<MatrixIndexes, MatrixBlock> sortByVal( JavaPairRDD<MatrixIndexes, MatrixBlock> in, JavaPairRDD<MatrixIndexes, MatrixBlock> in2, long rlen, int brlen ) { //create value-index rdd from inputs JavaRDD<DoublePair> dvals = in.join(in2).values() - .flatMap(new ExtractDoubleValuesFunction2()); + .flatMap(new ExtractDoubleValuesFunction2()); //sort (creates sorted range per partition) long hdfsBlocksize = InfrastructureAnalyzer.getHDFSBlockSize(); int numPartitions = (int)Math.ceil(((double)rlen*8)/hdfsBlocksize); JavaRDD<DoublePair> sdvals = dvals - .sortBy(new CreateDoubleKeyFunction2(), true, numPartitions); + .sortBy(new CreateDoubleKeyFunction2(), true, numPartitions); //create binary block output JavaPairRDD<MatrixIndexes, MatrixBlock> ret = sdvals - .zipWithIndex() - .mapPartitionsToPair(new ConvertToBinaryBlockFunction2(rlen, brlen)); + .zipWithIndex() + .mapPartitionsToPair(new ConvertToBinaryBlockFunction2(rlen, brlen)); ret = 
RDDAggregateUtils.mergeByKey(ret, false); return ret; } + public static JavaPairRDD<MatrixIndexes, MatrixBlock> sortByVals( + JavaPairRDD<MatrixIndexes, MatrixBlock> in, long rlen, long clen, int brlen ) + { + //create value-index rdd from inputs + JavaRDD<MatrixBlock> dvals = in.values() + .flatMap(new ExtractRowsFunction()); + + //sort (creates sorted range per partition) + int numPartitions = SparkUtils.getNumPreferredPartitions( + new MatrixCharacteristics(rlen, clen, brlen, brlen), in); + JavaRDD<MatrixBlock> sdvals = dvals + .sortBy(new CreateDoubleKeysFunction(), true, numPartitions); + + //create binary block output + JavaPairRDD<MatrixIndexes, MatrixBlock> ret = sdvals + .zipWithIndex() + .mapPartitionsToPair(new ConvertToBinaryBlockFunction5(rlen, brlen)); + ret = RDDAggregateUtils.mergeByKey(ret, false); + + return ret; + } + public static JavaPairRDD<MatrixIndexes, MatrixBlock> sortIndexesByVal( JavaPairRDD<MatrixIndexes, MatrixBlock> val, boolean asc, long rlen, int brlen ) { //create value-index rdd from inputs JavaPairRDD<ValueIndexPair, Double> dvals = val - .flatMapToPair(new ExtractDoubleValuesWithIndexFunction(brlen)); - + .flatMapToPair(new ExtractDoubleValuesWithIndexFunction(brlen)); + //sort (creates sorted range per partition) long hdfsBlocksize = InfrastructureAnalyzer.getHDFSBlockSize(); int numPartitions = (int)Math.ceil(((double)rlen*16)/hdfsBlocksize); JavaRDD<ValueIndexPair> sdvals = dvals - .sortByKey(new IndexComparator(asc), true, numPartitions) - .keys(); //workaround for index comparator - + .sortByKey(new IndexComparator(asc), true, numPartitions) + .keys(); //workaround for index comparator + //create binary block output JavaPairRDD<MatrixIndexes, MatrixBlock> ret = sdvals - .zipWithIndex() - .mapPartitionsToPair(new ConvertToBinaryBlockFunction3(rlen, brlen)); - ret = RDDAggregateUtils.mergeByKey(ret, false); + .zipWithIndex() + .mapPartitionsToPair(new ConvertToBinaryBlockFunction3(rlen, brlen)); + ret = 
RDDAggregateUtils.mergeByKey(ret, false); - return ret; + return ret; + } + + public static JavaPairRDD<MatrixIndexes, MatrixBlock> sortIndexesByVals( JavaPairRDD<MatrixIndexes, MatrixBlock> in, + boolean asc, long rlen, long clen, int brlen ) + { + //create value-index rdd from inputs + JavaPairRDD<ValuesIndexPair, double[]> dvals = in + .flatMapToPair(new ExtractDoubleValuesWithIndexFunction2(brlen)); + + //sort (creates sorted range per partition) + int numPartitions = SparkUtils.getNumPreferredPartitions( + new MatrixCharacteristics(rlen, clen+1, brlen, brlen)); + JavaRDD<ValuesIndexPair> sdvals = dvals + .sortByKey(new IndexComparator2(asc), true, numPartitions) + .keys(); //workaround for index comparator + + //create binary block output + JavaPairRDD<MatrixIndexes, MatrixBlock> ret = sdvals + .zipWithIndex() + .mapPartitionsToPair(new ConvertToBinaryBlockFunction6(rlen, brlen)); + ret = RDDAggregateUtils.mergeByKey(ret, false); + + return ret; } public static JavaPairRDD<MatrixIndexes, MatrixBlock> sortDataByVal( JavaPairRDD<MatrixIndexes, MatrixBlock> val, - JavaPairRDD<MatrixIndexes, MatrixBlock> data, boolean asc, long rlen, long clen, int brlen, int bclen ) + JavaPairRDD<MatrixIndexes, MatrixBlock> data, boolean asc, long rlen, long clen, int brlen, int bclen ) { //create value-index rdd from inputs JavaPairRDD<ValueIndexPair, Double> dvals = val - .flatMapToPair(new ExtractDoubleValuesWithIndexFunction(brlen)); - + .flatMapToPair(new ExtractDoubleValuesWithIndexFunction(brlen)); + //sort (creates sorted range per partition) long hdfsBlocksize = InfrastructureAnalyzer.getHDFSBlockSize(); int numPartitions = (int)Math.ceil(((double)rlen*16)/hdfsBlocksize); JavaRDD<ValueIndexPair> sdvals = dvals - .sortByKey(new IndexComparator(asc), true, numPartitions) - .keys(); //workaround for index comparator - + .sortByKey(new IndexComparator(asc), true, numPartitions) + .keys(); //workaround for index comparator + //create target indexes by original index - long 
numRep = (long)Math.ceil((double)clen/bclen); JavaPairRDD<MatrixIndexes, MatrixBlock> ixmap = sdvals - .zipWithIndex() - .mapToPair(new ExtractIndexFunction()) - .sortByKey() - .mapPartitionsToPair(new ConvertToBinaryBlockFunction4(rlen, brlen)); - ixmap = RDDAggregateUtils.mergeByKey(ixmap, false); + .zipWithIndex() + .mapToPair(new ExtractIndexFunction()) + .sortByKey() + .mapPartitionsToPair(new ConvertToBinaryBlockFunction4(rlen, brlen)); + ixmap = RDDAggregateUtils.mergeByKey(ixmap, false); + + //actual data sort + return sortDataByIx(data, ixmap, rlen, clen, brlen, bclen); + } + + public static JavaPairRDD<MatrixIndexes, MatrixBlock> sortDataByVals( JavaPairRDD<MatrixIndexes, MatrixBlock> val, + JavaPairRDD<MatrixIndexes, MatrixBlock> data, boolean asc, long rlen, long clen, long clen2, int brlen, int bclen ) + { + //create value-index rdd from inputs + JavaPairRDD<ValuesIndexPair, double[]> dvals = val + .flatMapToPair(new ExtractDoubleValuesWithIndexFunction2(brlen)); + //sort (creates sorted range per partition) + int numPartitions = SparkUtils.getNumPreferredPartitions( + new MatrixCharacteristics(rlen, clen2+1, brlen, brlen)); + JavaRDD<ValuesIndexPair> sdvals = dvals + .sortByKey(new IndexComparator2(asc), true, numPartitions) + .keys(); //workaround for index comparator + + //create target indexes by original index + JavaPairRDD<MatrixIndexes, MatrixBlock> ixmap = sdvals + .zipWithIndex() + .mapToPair(new ExtractIndexFunction2()) + .sortByKey() + .mapPartitionsToPair(new ConvertToBinaryBlockFunction4(rlen, brlen)); + ixmap = RDDAggregateUtils.mergeByKey(ixmap, false); + + //actual data sort + return sortDataByIx(data, ixmap, rlen, clen, brlen, bclen); + } + + public static JavaPairRDD<MatrixIndexes, MatrixBlock> sortDataByIx(JavaPairRDD<MatrixIndexes,MatrixBlock> data, + JavaPairRDD<MatrixIndexes,MatrixBlock> ixmap, long rlen, long clen, int brlen, int bclen) { //replicate indexes for all column blocks + long numRep = 
(long)Math.ceil((double)clen/bclen); JavaPairRDD<MatrixIndexes, MatrixBlock> rixmap = ixmap - .flatMapToPair(new ReplicateVectorFunction(false, numRep)); + .flatMapToPair(new ReplicateVectorFunction(false, numRep)); //create binary block output JavaPairRDD<MatrixIndexes, RowMatrixBlock> ret = data - .join(rixmap) - .mapPartitionsToPair(new ShuffleMatrixBlockRowsFunction(rlen, brlen)); + .join(rixmap) + .mapPartitionsToPair(new ShuffleMatrixBlockRowsFunction(rlen, brlen)); return RDDAggregateUtils.mergeRowsByKey(ret); } @@ -200,10 +280,23 @@ public class RDDSortUtils @Override public Iterator<Double> call(MatrixBlock arg0) - throws Exception - { + throws Exception { return DataConverter.convertToDoubleList(arg0).iterator(); - } + } + } + + private static class ExtractRowsFunction implements FlatMapFunction<MatrixBlock,MatrixBlock> + { + private static final long serialVersionUID = -2786968469468554974L; + + @Override + public Iterator<MatrixBlock> call(MatrixBlock arg0) + throws Exception { + ArrayList<MatrixBlock> rows = new ArrayList<>(); + for(int i=0; i<arg0.getNumRows(); i++) + rows.add(arg0.sliceOperations(i, i, 0, arg0.getNumColumns()-1, new MatrixBlock())); + return rows.iterator(); + } } private static class ExtractDoubleValuesFunction2 implements FlatMapFunction<Tuple2<MatrixBlock,MatrixBlock>,DoublePair> @@ -256,6 +349,35 @@ public class RDDSortUtils return ret.iterator(); } } + + private static class ExtractDoubleValuesWithIndexFunction2 implements PairFlatMapFunction<Tuple2<MatrixIndexes,MatrixBlock>,ValuesIndexPair,double[]> + { + private static final long serialVersionUID = 8358254634903633283L; + + private final int _brlen; + + public ExtractDoubleValuesWithIndexFunction2(int brlen) { + _brlen = brlen; + } + + @Override + public Iterator<Tuple2<ValuesIndexPair,double[]>> call(Tuple2<MatrixIndexes,MatrixBlock> arg0) + throws Exception + { + ArrayList<Tuple2<ValuesIndexPair,double[]>> ret = new ArrayList<>(); + MatrixIndexes ix = arg0._1(); + 
MatrixBlock mb = arg0._2(); + + long ixoffset = (ix.getRowIndex()-1)*_brlen; + for( int i=0; i<mb.getNumRows(); i++) { + double[] vals = DataConverter.convertToDoubleVector( + mb.sliceOperations(i, i, 0, mb.getNumColumns()-1, new MatrixBlock())); + ret.add(new Tuple2<>(new ValuesIndexPair(vals,ixoffset+i+1), vals)); + } + + return ret.iterator(); + } + } private static class CreateDoubleKeyFunction implements Function<Double,Double> { @@ -266,7 +388,7 @@ public class RDDSortUtils throws Exception { return arg0; - } + } } private static class CreateDoubleKeyFunction2 implements Function<DoublePair,Double> @@ -278,20 +400,35 @@ public class RDDSortUtils throws Exception { return arg0.val1; - } + } } - private static class ExtractIndexFunction implements PairFunction<Tuple2<ValueIndexPair,Long>,Long,Long> + private static class CreateDoubleKeysFunction implements Function<MatrixBlock,double[]> { + private static final long serialVersionUID = 4316858496746520340L; + + @Override + public double[] call(MatrixBlock row) throws Exception { + return DataConverter.convertToDoubleVector(row); + } + } + + private static class ExtractIndexFunction implements PairFunction<Tuple2<ValueIndexPair,Long>,Long,Long> { private static final long serialVersionUID = -4553468724131249535L; @Override - public Tuple2<Long, Long> call(Tuple2<ValueIndexPair,Long> arg0) - throws Exception - { + public Tuple2<Long, Long> call(Tuple2<ValueIndexPair,Long> arg0) throws Exception { return new Tuple2<>(arg0._1().ix, arg0._2()); } + } + + private static class ExtractIndexFunction2 implements PairFunction<Tuple2<ValuesIndexPair,Long>,Long,Long> { + private static final long serialVersionUID = -1366455446597907270L; + @Override + public Tuple2<Long, Long> call(Tuple2<ValuesIndexPair,Long> arg0) throws Exception { + return new Tuple2<>(arg0._1().ix, arg0._2()); + } } private static class ConvertToBinaryBlockFunction implements PairFlatMapFunction<Iterator<Tuple2<Double,Long>>,MatrixIndexes,MatrixBlock> 
@@ -485,6 +622,98 @@ public class RDDSortUtils } } + private static class ConvertToBinaryBlockFunction5 implements PairFlatMapFunction<Iterator<Tuple2<MatrixBlock,Long>>,MatrixIndexes,MatrixBlock> + { + private static final long serialVersionUID = 6357994683868091724L; + + private long _rlen = -1; + private int _brlen = -1; + + public ConvertToBinaryBlockFunction5(long rlen, int brlen) + { + _rlen = rlen; + _brlen = brlen; + } + + public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixBlock,Long>> arg0) + throws Exception + { + ArrayList<Tuple2<MatrixIndexes,MatrixBlock>> ret = new ArrayList<>(); + MatrixIndexes ix = null; + MatrixBlock mb = null; + + while( arg0.hasNext() ) + { + Tuple2<MatrixBlock,Long> val = arg0.next(); + long valix = val._2 + 1; + long rix = UtilFunctions.computeBlockIndex(valix, _brlen); + int pos = UtilFunctions.computeCellInBlock(valix, _brlen); + + if( ix == null || ix.getRowIndex() != rix ) { + if( ix !=null ) + ret.add(new Tuple2<>(ix,mb)); + long len = UtilFunctions.computeBlockSize(_rlen, rix, _brlen); + ix = new MatrixIndexes(rix,1); + mb = new MatrixBlock((int)len, val._1.getNumColumns(), false); + } + + mb.leftIndexingOperations(val._1, pos, pos, 0, val._1.getNumColumns()-1, mb, UpdateType.INPLACE); + } + + //flush last block + if( mb!=null && mb.getNonZeros() != 0 ) + ret.add(new Tuple2<>(ix,mb)); + return ret.iterator(); + } + } + + private static class ConvertToBinaryBlockFunction6 implements PairFlatMapFunction<Iterator<Tuple2<ValuesIndexPair,Long>>,MatrixIndexes,MatrixBlock> + { + private static final long serialVersionUID = 5351649694631911694L; + + private long _rlen = -1; + private int _brlen = -1; + + public ConvertToBinaryBlockFunction6(long rlen, int brlen) + { + _rlen = rlen; + _brlen = brlen; + } + + public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<ValuesIndexPair,Long>> arg0) + throws Exception + { + ArrayList<Tuple2<MatrixIndexes,MatrixBlock>> ret = new ArrayList<>(); + 
+ MatrixIndexes ix = null; + MatrixBlock mb = null; + + while( arg0.hasNext() ) + { + Tuple2<ValuesIndexPair,Long> val = arg0.next(); + long valix = val._2 + 1; + long rix = UtilFunctions.computeBlockIndex(valix, _brlen); + int pos = UtilFunctions.computeCellInBlock(valix, _brlen); + + if( ix == null || ix.getRowIndex() != rix ) { + if( ix !=null ) + ret.add(new Tuple2<>(ix,mb)); + long len = UtilFunctions.computeBlockSize(_rlen, rix, _brlen); + ix = new MatrixIndexes(rix,1); + mb = new MatrixBlock((int)len, 1, false); + } + + mb.quickSetValue(pos, 0, val._1.ix); + } + + //flush last block + if( mb!=null && mb.getNonZeros() != 0 ) + ret.add(new Tuple2<>(ix,mb)); + + return ret.iterator(); + } + } + private static class ShuffleMatrixBlockRowsFunction implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes,Tuple2<MatrixBlock,MatrixBlock>>>,MatrixIndexes,RowMatrixBlock> { private static final long serialVersionUID = 6885207719329119646L; @@ -690,6 +919,19 @@ public class RDDSortUtils } } + private static class ValuesIndexPair implements Serializable + { + private static final long serialVersionUID = 4297433409147784971L; + + public double[] vals; + public long ix; + + public ValuesIndexPair(double[] dvals, long lix) { + vals = dvals; + ix = lix; + } + } + public static class IndexComparator implements Comparator<ValueIndexPair>, Serializable { private static final long serialVersionUID = 5154839870549241343L; @@ -700,18 +942,32 @@ public class RDDSortUtils } @Override - public int compare(ValueIndexPair o1, ValueIndexPair o2) + public int compare(ValueIndexPair o1, ValueIndexPair o2) { + int retVal = Double.compare(o1.val, o2.val); + if(retVal != 0) + return (_asc ? 
retVal : -1*retVal); + else //for stable sort + return Long.compare(o1.ix, o2.ix); + } + } + + public static class IndexComparator2 implements Comparator<ValuesIndexPair>, Serializable + { + private static final long serialVersionUID = 5531987863790922691L; + + private boolean _asc; + public IndexComparator2(boolean asc) { + _asc = asc; + } + + @Override + public int compare(ValuesIndexPair o1, ValuesIndexPair o2) { - //note: use conversion to Double and Long instead of native - //compare for compatibility with jdk 6 - int retVal = Double.valueOf(o1.val).compareTo(o2.val); - if(retVal != 0) { + int retVal = SortUtils.compare(o1.vals, o2.vals); + if(retVal != 0) return (_asc ? retVal : -1*retVal); - } - else { - //for stable sort - return Long.valueOf(o1.ix).compareTo(o2.ix); - } + else //for stable sort + return Long.compare(o1.ix, o2.ix); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/27cabbc4/src/main/java/org/apache/sysml/runtime/util/SortUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/util/SortUtils.java b/src/main/java/org/apache/sysml/runtime/util/SortUtils.java index c41f3ac..ff90784 100644 --- a/src/main/java/org/apache/sysml/runtime/util/SortUtils.java +++ b/src/main/java/org/apache/sysml/runtime/util/SortUtils.java @@ -35,14 +35,14 @@ public class SortUtils public static boolean isSorted(int start, int end, int[] indexes) { boolean ret = true; for( int i=start+1; i<end && ret; i++ ) - ret &= (indexes[i]<indexes[i-1]); + ret &= (indexes[i]<indexes[i-1]); return ret; } public static boolean isSorted(int start, int end, double[] values) { boolean ret = true; for( int i=start+1; i<end && ret; i++ ) - ret &= (values[i]<values[i-1]); + ret &= (values[i]<values[i-1]); return ret; } @@ -51,6 +51,18 @@ public class SortUtils isSorted(0, in.getNumRows()*in.getNumColumns(), in.getDenseBlock()); } + public static int compare(double[] d1, double[] d2) { + if( d1 == null || 
d2 == null ) + throw new RuntimeException("Invalid invocation w/ null parameter."); + int ret = Long.compare(d1.length, d2.length); + if( ret != 0 ) return ret; + for(int i=0; i<d1.length; i++) { + ret = Double.compare(d1[i], d2[i]); + if( ret != 0 ) return ret; + } + return 0; + } + /** * In-place sort of two arrays, only indexes is used for comparison and values * of same position are sorted accordingly. http://git-wip-us.apache.org/repos/asf/systemml/blob/27cabbc4/src/test/java/org/apache/sysml/test/integration/functions/reorg/MultipleOrderByColsTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/reorg/MultipleOrderByColsTest.java b/src/test/java/org/apache/sysml/test/integration/functions/reorg/MultipleOrderByColsTest.java index 10dc1a4..5f11038 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/reorg/MultipleOrderByColsTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/reorg/MultipleOrderByColsTest.java @@ -114,46 +114,45 @@ public class MultipleOrderByColsTest extends AutomatedTestBase runOrderTest(TEST_NAME2, true, true, false, ExecType.CP); } -//TODO enable together with additional spark sort runtime -// @Test -// public void testOrderDenseAscDataSP() { -// runOrderTest(TEST_NAME1, false, false, false, ExecType.SPARK); -// } -// -// @Test -// public void testOrderDenseAscIxSP() { -// runOrderTest(TEST_NAME1, false, false, true, ExecType.SPARK); -// } -// -// @Test -// public void testOrderDenseDescDataSP() { -// runOrderTest(TEST_NAME1, false, true, false, ExecType.SPARK); -// } -// -// @Test -// public void testOrderDenseDescIxSP() { -// runOrderTest(TEST_NAME1, false, true, true, ExecType.SPARK); -// } -// -// @Test -// public void testOrderSparseAscDataSP() { -// runOrderTest(TEST_NAME1, true, false, false, ExecType.SPARK); -// } -// -// @Test -// public void testOrderSparseAscIxSP() { -// runOrderTest(TEST_NAME1, 
true, false, true, ExecType.SPARK); -// } -// -// @Test -// public void testOrderSparseDescDataSP() { -// runOrderTest(TEST_NAME1, true, true, false, ExecType.SPARK); -// } -// -// @Test -// public void testOrderSparseDescIxSP() { -// runOrderTest(TEST_NAME1, true, true, true, ExecType.SPARK); -// } + @Test + public void testOrderDenseAscDataSP() { + runOrderTest(TEST_NAME1, false, false, false, ExecType.SPARK); + } + + @Test + public void testOrderDenseAscIxSP() { + runOrderTest(TEST_NAME1, false, false, true, ExecType.SPARK); + } + + @Test + public void testOrderDenseDescDataSP() { + runOrderTest(TEST_NAME1, false, true, false, ExecType.SPARK); + } + + @Test + public void testOrderDenseDescIxSP() { + runOrderTest(TEST_NAME1, false, true, true, ExecType.SPARK); + } + + @Test + public void testOrderSparseAscDataSP() { + runOrderTest(TEST_NAME1, true, false, false, ExecType.SPARK); + } + + @Test + public void testOrderSparseAscIxSP() { + runOrderTest(TEST_NAME1, true, false, true, ExecType.SPARK); + } + + @Test + public void testOrderSparseDescDataSP() { + runOrderTest(TEST_NAME1, true, true, false, ExecType.SPARK); + } + + @Test + public void testOrderSparseDescIxSP() { + runOrderTest(TEST_NAME1, true, true, true, ExecType.SPARK); + } private void runOrderTest( String testname, boolean sparse, boolean desc, boolean ixret, ExecType et) { @@ -161,11 +160,11 @@ public class MultipleOrderByColsTest extends AutomatedTestBase switch( et ){ case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break; case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break; - default: rtplatform = RUNTIME_PLATFORM.HYBRID; break; + default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break; } boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; - if( rtplatform == RUNTIME_PLATFORM.SPARK ) + if( rtplatform == RUNTIME_PLATFORM.SPARK || rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK ) DMLScript.USE_LOCAL_SPARK_CONFIG = true; try
