Repository: spark Updated Branches: refs/heads/branch-2.0 7e2bfff20 -> 796dd1514
Revert "[SPARK-14851][CORE] Support radix sort with nullable longs" This reverts commit beb75300455a4f92000b69e740256102d9f2d472. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/796dd151 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/796dd151 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/796dd151 Branch: refs/heads/branch-2.0 Commit: 796dd15142c00e96d2d7180f7909055a3eb1dfdf Parents: 7e2bfff Author: Reynold Xin <[email protected]> Authored: Sat Jun 11 15:49:39 2016 -0700 Committer: Reynold Xin <[email protected]> Committed: Sat Jun 11 15:49:39 2016 -0700 ---------------------------------------------------------------------- .../util/collection/unsafe/sort/RadixSort.java | 24 ++++----- .../unsafe/sort/UnsafeExternalSorter.java | 11 ++-- .../unsafe/sort/UnsafeInMemorySorter.java | 56 ++++---------------- .../unsafe/sort/UnsafeExternalSorterSuite.java | 26 ++++----- .../unsafe/sort/UnsafeInMemorySorterSuite.java | 2 +- .../collection/unsafe/sort/RadixSortSuite.scala | 4 +- .../sql/execution/UnsafeExternalRowSorter.java | 20 ++----- .../sql/catalyst/expressions/SortOrder.scala | 40 ++++++-------- .../sql/execution/UnsafeKVExternalSorter.java | 11 ++-- .../apache/spark/sql/execution/SortExec.scala | 12 ++--- .../spark/sql/execution/SortPrefixUtils.scala | 32 ++++------- .../apache/spark/sql/execution/WindowExec.scala | 4 +- .../execution/joins/CartesianProductExec.scala | 2 +- .../apache/spark/sql/execution/SortSuite.scala | 11 ---- .../sql/execution/benchmark/SortBenchmark.scala | 2 +- 15 files changed, 79 insertions(+), 178 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/796dd151/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java index 4043617..4f3f0de 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java @@ -170,13 +170,9 @@ public class RadixSort { /** * Specialization of sort() for key-prefix arrays. In this type of array, each record consists * of two longs, only the second of which is sorted on. - * - * @param startIndex starting index in the array to sort from. This parameter is not supported - * in the plain sort() implementation. */ public static int sortKeyPrefixArray( LongArray array, - int startIndex, int numRecords, int startByteIndex, int endByteIndex, @@ -186,11 +182,10 @@ public class RadixSort { assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7"; assert endByteIndex > startByteIndex; assert numRecords * 4 <= array.size(); - int inIndex = startIndex; - int outIndex = startIndex + numRecords * 2; + int inIndex = 0; + int outIndex = numRecords * 2; if (numRecords > 0) { - long[][] counts = getKeyPrefixArrayCounts( - array, startIndex, numRecords, startByteIndex, endByteIndex); + long[][] counts = getKeyPrefixArrayCounts(array, numRecords, startByteIndex, endByteIndex); for (int i = startByteIndex; i <= endByteIndex; i++) { if (counts[i] != null) { sortKeyPrefixArrayAtByte( @@ -210,14 +205,13 @@ public class RadixSort { * getCounts with some added parameters but that seems to hurt in benchmarks. */ private static long[][] getKeyPrefixArrayCounts( - LongArray array, int startIndex, int numRecords, int startByteIndex, int endByteIndex) { + LongArray array, int numRecords, int startByteIndex, int endByteIndex) { long[][] counts = new long[8][]; long bitwiseMax = 0; long bitwiseMin = -1L; - long baseOffset = array.getBaseOffset() + startIndex * 8L; - long limit = baseOffset + numRecords * 16L; + long limit = array.getBaseOffset() + numRecords * 16; Object baseObject = array.getBaseObject(); - for (long offset = baseOffset; offset < limit; offset += 16) { + for (long offset = array.getBaseOffset(); offset < limit; offset += 16) { long value = Platform.getLong(baseObject, offset + 8); bitwiseMax |= value; bitwiseMin &= value; @@ -226,7 +220,7 @@ public class RadixSort { for (int i = startByteIndex; i <= endByteIndex; i++) { if (((bitsChanged >>> (i * 8)) & 0xff) != 0) { counts[i] = new long[256]; - for (long offset = baseOffset; offset < limit; offset += 16) { + for (long offset = array.getBaseOffset(); offset < limit; offset += 16) { counts[i][(int)((Platform.getLong(baseObject, offset + 8) >>> (i * 8)) & 0xff)]++; } } @@ -244,8 +238,8 @@ public class RadixSort { long[] offsets = transformCountsToOffsets( counts, numRecords, array.getBaseOffset() + outIndex * 8, 16, desc, signed); Object baseObject = array.getBaseObject(); - long baseOffset = array.getBaseOffset() + inIndex * 8L; - long maxOffset = baseOffset + numRecords * 16L; + long baseOffset = array.getBaseOffset() + inIndex * 8; + long maxOffset = baseOffset + numRecords * 16; for (long offset = baseOffset; offset < maxOffset; offset += 16) { long key = Platform.getLong(baseObject, offset); long prefix = Platform.getLong(baseObject, offset + 8); http://git-wip-us.apache.org/repos/asf/spark/blob/796dd151/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index ec15f0b..e14a23f 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -369,8 +369,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer { /** * Write a record to the sorter. */ - public void insertRecord( - Object recordBase, long recordOffset, int length, long prefix, boolean prefixIsNull) + public void insertRecord(Object recordBase, long recordOffset, int length, long prefix) throws IOException { growPointerArrayIfNecessary(); @@ -385,7 +384,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer { Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length); pageCursor += length; assert(inMemSorter != null); - inMemSorter.insertRecord(recordAddress, prefix, prefixIsNull); + inMemSorter.insertRecord(recordAddress, prefix); } /** @@ -397,7 +396,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer { * record length = key length + value length + 4 */ public void insertKVRecord(Object keyBase, long keyOffset, int keyLen, - Object valueBase, long valueOffset, int valueLen, long prefix, boolean prefixIsNull) + Object valueBase, long valueOffset, int valueLen, long prefix) throws IOException { growPointerArrayIfNecessary(); @@ -416,7 +415,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer { pageCursor += valueLen; assert(inMemSorter != null); - inMemSorter.insertRecord(recordAddress, prefix, prefixIsNull); + inMemSorter.insertRecord(recordAddress, prefix); } /** @@ -466,7 +465,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer { private boolean loaded = false; private int numRecords = 0; - SpillableIterator(UnsafeSorterIterator inMemIterator) { + SpillableIterator(UnsafeInMemorySorter.SortedIterator inMemIterator) { this.upstream = inMemIterator; this.numRecords = inMemIterator.getNumRecords(); } http://git-wip-us.apache.org/repos/asf/spark/blob/796dd151/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 78da389..c7b070f 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -18,7 +18,6 @@ package org.apache.spark.util.collection.unsafe.sort; import java.util.Comparator; -import java.util.LinkedList; import org.apache.avro.reflect.Nullable; @@ -94,14 +93,6 @@ public final class UnsafeInMemorySorter { private int pos = 0; /** - * If sorting with radix sort, specifies the starting position in the sort buffer where records - * with non-null prefixes are kept. Positions [0..nullBoundaryPos) will contain null-prefixed - * records, and positions [nullBoundaryPos..pos) non-null prefixed records. This lets us avoid - * radix sorting over null values. - */ - private int nullBoundaryPos = 0; - - /* * How many records could be inserted, because part of the array should be left for sorting. */ private int usableCapacity = 0; @@ -169,7 +160,6 @@ public final class UnsafeInMemorySorter { usableCapacity = getUsableCapacity(); } pos = 0; - nullBoundaryPos = 0; } /** @@ -216,27 +206,14 @@ public final class UnsafeInMemorySorter { * @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}. * @param keyPrefix a user-defined key prefix */ - public void insertRecord(long recordPointer, long keyPrefix, boolean prefixIsNull) { + public void insertRecord(long recordPointer, long keyPrefix) { if (!hasSpaceForAnotherRecord()) { throw new IllegalStateException("There is no space for new record"); } - if (prefixIsNull && radixSortSupport != null) { - // Swap forward a non-null record to make room for this one at the beginning of the array. - array.set(pos, array.get(nullBoundaryPos)); - pos++; - array.set(pos, array.get(nullBoundaryPos + 1)); - pos++; - // Place this record in the vacated position. - array.set(nullBoundaryPos, recordPointer); - nullBoundaryPos++; - array.set(nullBoundaryPos, keyPrefix); - nullBoundaryPos++; - } else { - array.set(pos, recordPointer); - pos++; - array.set(pos, keyPrefix); - pos++; - } + array.set(pos, recordPointer); + pos++; + array.set(pos, keyPrefix); + pos++; } public final class SortedIterator extends UnsafeSorterIterator implements Cloneable { @@ -303,14 +280,15 @@ public final class UnsafeInMemorySorter { * Return an iterator over record pointers in sorted order. For efficiency, all calls to * {@code next()} will return the same mutable object. */ - public UnsafeSorterIterator getSortedIterator() { + public SortedIterator getSortedIterator() { int offset = 0; long start = System.nanoTime(); if (sortComparator != null) { if (this.radixSortSupport != null) { + // TODO(ekl) we should handle NULL values before radix sort for efficiency, since they + // force a full-width sort (and we cannot radix-sort nullable long fields at all). offset = RadixSort.sortKeyPrefixArray( - array, nullBoundaryPos, (pos - nullBoundaryPos) / 2, 0, 7, - radixSortSupport.sortDescending(), radixSortSupport.sortSigned()); + array, pos / 2, 0, 7, radixSortSupport.sortDescending(), radixSortSupport.sortSigned()); } else { MemoryBlock unused = new MemoryBlock( array.getBaseObject(), @@ -323,20 +301,6 @@ public final class UnsafeInMemorySorter { } } totalSortTimeNanos += System.nanoTime() - start; - if (nullBoundaryPos > 0) { - assert radixSortSupport != null : "Nulls are only stored separately with radix sort"; - LinkedList<UnsafeSorterIterator> queue = new LinkedList<>(); - if (radixSortSupport.sortDescending()) { - // Nulls are smaller than non-nulls - queue.add(new SortedIterator((pos - nullBoundaryPos) / 2, offset)); - queue.add(new SortedIterator(nullBoundaryPos / 2, 0)); - } else { - queue.add(new SortedIterator(nullBoundaryPos / 2, 0)); - queue.add(new SortedIterator((pos - nullBoundaryPos) / 2, offset)); - } - return new UnsafeExternalSorter.ChainedIterator(queue); - } else { - return new SortedIterator(pos / 2, offset); - } + return new SortedIterator(pos / 2, offset); } } http://git-wip-us.apache.org/repos/asf/spark/blob/796dd151/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index bce958c..2cae4be 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -156,14 +156,14 @@ public class UnsafeExternalSorterSuite { private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception { final int[] arr = new int[]{ value }; - sorter.insertRecord(arr, Platform.INT_ARRAY_OFFSET, 4, value, false); + sorter.insertRecord(arr, Platform.INT_ARRAY_OFFSET, 4, value); } private static void insertRecord( UnsafeExternalSorter sorter, int[] record, long prefix) throws IOException { - sorter.insertRecord(record, Platform.INT_ARRAY_OFFSET, record.length * 4, prefix, false); + sorter.insertRecord(record, Platform.INT_ARRAY_OFFSET, record.length * 4, prefix); } private UnsafeExternalSorter newSorter() throws IOException { @@ -206,13 +206,13 @@ public class UnsafeExternalSorterSuite { @Test public void testSortingEmptyArrays() throws Exception { final UnsafeExternalSorter sorter = newSorter(); - sorter.insertRecord(null, 0, 0, 0, false); - sorter.insertRecord(null, 0, 0, 0, false); + sorter.insertRecord(null, 0, 0, 0); + sorter.insertRecord(null, 0, 0, 0); sorter.spill(); - sorter.insertRecord(null, 0, 0, 0, false); + sorter.insertRecord(null, 0, 0, 0); sorter.spill(); - sorter.insertRecord(null, 0, 0, 0, false); - sorter.insertRecord(null, 0, 0, 0, false); + sorter.insertRecord(null, 0, 0, 0); + sorter.insertRecord(null, 0, 0, 0); UnsafeSorterIterator iter = sorter.getSortedIterator(); @@ -232,7 +232,7 @@ public class UnsafeExternalSorterSuite { long prevSortTime = sorter.getSortTimeNanos(); assertEquals(prevSortTime, 0); - sorter.insertRecord(null, 0, 0, 0, false); + sorter.insertRecord(null, 0, 0, 0); sorter.spill(); assertThat(sorter.getSortTimeNanos(), greaterThan(prevSortTime)); prevSortTime = sorter.getSortTimeNanos(); @@ -240,7 +240,7 @@ public class UnsafeExternalSorterSuite { sorter.spill(); // no sort needed assertEquals(sorter.getSortTimeNanos(), prevSortTime); - sorter.insertRecord(null, 0, 0, 0, false); + sorter.insertRecord(null, 0, 0, 0); UnsafeSorterIterator iter = sorter.getSortedIterator(); assertThat(sorter.getSortTimeNanos(), greaterThan(prevSortTime)); } @@ -280,7 +280,7 @@ public class UnsafeExternalSorterSuite { final UnsafeExternalSorter sorter = newSorter(); byte[] record = new byte[16]; while (sorter.getNumberOfAllocatedPages() < 2) { - sorter.insertRecord(record, Platform.BYTE_ARRAY_OFFSET, record.length, 0, false); + sorter.insertRecord(record, Platform.BYTE_ARRAY_OFFSET, record.length, 0); } sorter.cleanupResources(); assertSpillFilesWereCleanedUp(); @@ -340,7 +340,7 @@ public class UnsafeExternalSorterSuite { int n = (int) pageSizeBytes / recordSize * 3; for (int i = 0; i < n; i++) { record[0] = (long) i; - sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0, false); + sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0); } assertTrue(sorter.getNumberOfAllocatedPages() >= 2); UnsafeExternalSorter.SpillableIterator iter = @@ -372,7 +372,7 @@ public class UnsafeExternalSorterSuite { int n = (int) pageSizeBytes / recordSize * 3; for (int i = 0; i < n; i++) { record[0] = (long) i; - sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0, false); + sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0); } assertTrue(sorter.getNumberOfAllocatedPages() >= 2); UnsafeExternalSorter.SpillableIterator iter = @@ -406,7 +406,7 @@ public class UnsafeExternalSorterSuite { int batch = n / 4; for (int i = 0; i < n; i++) { record[0] = (long) i; - sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0, false); + sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0); if (i % batch == batch - 1) { sorter.spill(); } http://git-wip-us.apache.org/repos/asf/spark/blob/796dd151/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index bd89085..383c5b3 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -120,7 +120,7 @@ public class UnsafeInMemorySorterSuite { final long address = memoryManager.encodePageNumberAndOffset(dataPage, position); final String str = getStringFromDataPage(baseObject, position + 4, recordLength); final int partitionId = hashPartitioner.getPartition(str); - sorter.insertRecord(address, partitionId, false); + sorter.insertRecord(address, partitionId); position += 4 + recordLength; } final UnsafeSorterIterator iter = sorter.getSortedIterator(); http://git-wip-us.apache.org/repos/asf/spark/blob/796dd151/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala index 2c13806..1d26d4a 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala @@ -152,7 +152,7 @@ class RadixSortSuite extends SparkFunSuite with Logging { val (buf1, buf2) = generateKeyPrefixTestData(N, rand.nextLong & 0xff) referenceKeyPrefixSort(buf1, 0, N, sortType.referenceComparator) val outOffset = RadixSort.sortKeyPrefixArray( - buf2, 0, N, sortType.startByteIdx, sortType.endByteIdx, + buf2, N, sortType.startByteIdx, sortType.endByteIdx, sortType.descending, sortType.signed) val res1 = collectToArray(buf1, 0, N * 2) val res2 = collectToArray(buf2, outOffset, N * 2) @@ -177,7 +177,7 @@ class RadixSortSuite extends SparkFunSuite with Logging { val (buf1, buf2) = generateKeyPrefixTestData(N, rand.nextLong & mask) referenceKeyPrefixSort(buf1, 0, N, sortType.referenceComparator) val outOffset = RadixSort.sortKeyPrefixArray( - buf2, 0, N, sortType.startByteIdx, sortType.endByteIdx, + buf2, N, sortType.startByteIdx, sortType.endByteIdx, sortType.descending, sortType.signed) val res1 = collectToArray(buf1, 0, N * 2) val res2 = collectToArray(buf2, outOffset, N * 2) http://git-wip-us.apache.org/repos/asf/spark/blob/796dd151/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index ad76bf5..37fbad4 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -51,20 +51,7 @@ public final class UnsafeExternalRowSorter { private final UnsafeExternalSorter sorter; public abstract static class PrefixComputer { - - public static class Prefix { - /** Key prefix value, or the null prefix value if isNull = true. **/ - long value; - - /** Whether the key is null. */ - boolean isNull; - } - - /** - * Computes prefix for the given row. For efficiency, the returned object may be reused in - * further calls to a given PrefixComputer. - */ - abstract Prefix computePrefix(InternalRow row); + abstract long computePrefix(InternalRow row); } public UnsafeExternalRowSorter( @@ -101,13 +88,12 @@ public final class UnsafeExternalRowSorter { } public void insertRow(UnsafeRow row) throws IOException { - final PrefixComputer.Prefix prefix = prefixComputer.computePrefix(row); + final long prefix = prefixComputer.computePrefix(row); sorter.insertRecord( row.getBaseObject(), row.getBaseOffset(), row.getSizeInBytes(), - prefix.value, - prefix.isNull + prefix ); numRowsInserted++; if (testSpillFrequency > 0 && (numRowsInserted % testSpillFrequency) == 0) { http://git-wip-us.apache.org/repos/asf/spark/blob/796dd151/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index de779ed..42a8be6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -64,21 +64,10 @@ case class SortOrder(child: Expression, direction: SortDirection) } /** - * An expression to generate a 64-bit long prefix used in sorting. If the sort must operate over - * null keys as well, this.nullValue can be used in place of emitted null prefixes in the sort. + * An expression to generate a 64-bit long prefix used in sorting. */ case class SortPrefix(child: SortOrder) extends UnaryExpression { - val nullValue = child.child.dataType match { - case BooleanType | DateType | TimestampType | _: IntegralType => - Long.MinValue - case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS => - Long.MinValue - case _: DecimalType => - DoublePrefixComparator.computePrefix(Double.NegativeInfinity) - case _ => 0L - } - override def eval(input: InternalRow): Any = throw new UnsupportedOperationException override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -86,19 +75,20 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { val input = childCode.value val BinaryPrefixCmp = classOf[BinaryPrefixComparator].getName val DoublePrefixCmp = classOf[DoublePrefixComparator].getName - val prefixCode = child.child.dataType match { + + val (nullValue: Long, prefixCode: String) = child.child.dataType match { case BooleanType => - s"$input ? 1L : 0L" + (Long.MinValue, s"$input ? 1L : 0L") case _: IntegralType => - s"(long) $input" + (Long.MinValue, s"(long) $input") case DateType | TimestampType => - s"(long) $input" + (Long.MinValue, s"(long) $input") case FloatType | DoubleType => - s"$DoublePrefixCmp.computePrefix((double)$input)" - case StringType => s"$input.getPrefix()" - case BinaryType => s"$BinaryPrefixCmp.computePrefix($input)" + (0L, s"$DoublePrefixCmp.computePrefix((double)$input)") + case StringType => (0L, s"$input.getPrefix()") + case BinaryType => (0L, s"$BinaryPrefixCmp.computePrefix($input)") case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS => - if (dt.precision <= Decimal.MAX_LONG_DIGITS) { + val prefix = if (dt.precision <= Decimal.MAX_LONG_DIGITS) { s"$input.toUnscaledLong()" } else { // reduce the scale to fit in a long @@ -106,15 +96,17 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { val s = p - (dt.precision - dt.scale) s"$input.changePrecision($p, $s) ? $input.toUnscaledLong() : ${Long.MinValue}L" } + (Long.MinValue, prefix) case dt: DecimalType => - s"$DoublePrefixCmp.computePrefix($input.toDouble())" - case _ => "0L" + (DoublePrefixComparator.computePrefix(Double.NegativeInfinity), + s"$DoublePrefixCmp.computePrefix($input.toDouble())") + case _ => (0L, "0L") } ev.copy(code = childCode.code + s""" - |long ${ev.value} = 0L; - |boolean ${ev.isNull} = ${childCode.isNull}; + |long ${ev.value} = ${nullValue}L; + |boolean ${ev.isNull} = false; |if (!${childCode.isNull}) { | ${ev.value} = $prefixCode; |} http://git-wip-us.apache.org/repos/asf/spark/blob/796dd151/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java ---------------------------------------------------------------------- diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 99fe51d..bb823cd 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -118,10 +118,9 @@ public final class UnsafeKVExternalSorter { // Compute prefix row.pointTo(baseObject, baseOffset, loc.getKeyLength()); - final UnsafeExternalRowSorter.PrefixComputer.Prefix prefix = - prefixComputer.computePrefix(row); + final long prefix = prefixComputer.computePrefix(row); - inMemSorter.insertRecord(address, prefix.value, prefix.isNull); + inMemSorter.insertRecord(address, prefix); } sorter = UnsafeExternalSorter.createWithExistingInMemorySorter( @@ -147,12 +146,10 @@ public final class UnsafeKVExternalSorter { * sorted runs, and then reallocates memory to hold the new record. */ public void insertKV(UnsafeRow key, UnsafeRow value) throws IOException { - final UnsafeExternalRowSorter.PrefixComputer.Prefix prefix = - prefixComputer.computePrefix(key); + final long prefix = prefixComputer.computePrefix(key); sorter.insertKVRecord( key.getBaseObject(), key.getBaseOffset(), key.getSizeInBytes(), - value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes(), - prefix.value, prefix.isNull); + value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes(), prefix); } /** http://git-wip-us.apache.org/repos/asf/spark/blob/796dd151/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index 6db7f45..66a16ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -68,16 +68,10 @@ case class SortExec( SortPrefixUtils.canSortFullyWithPrefix(boundSortExpression) // The generator for prefix - val prefixExpr = SortPrefix(boundSortExpression) - val prefixProjection = UnsafeProjection.create(Seq(prefixExpr)) + val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression))) val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { - private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix - override def computePrefix(row: InternalRow): - UnsafeExternalRowSorter.PrefixComputer.Prefix = { - val prefix = prefixProjection.apply(row) - result.isNull = prefix.isNullAt(0) - result.value = if (result.isNull) prefixExpr.nullValue else prefix.getLong(0) - result + override def computePrefix(row: InternalRow): Long = { + prefixProjection.apply(row).getLong(0) } } http://git-wip-us.apache.org/repos/asf/spark/blob/796dd151/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala index 940467e..1a5ff5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -33,11 +33,6 @@ object SortPrefixUtils { override def compare(prefix1: Long, prefix2: Long): Int = 0 } - /** - * Dummy sort prefix result to use for empty rows. - */ - private val emptyPrefix = new UnsafeExternalRowSorter.PrefixComputer.Prefix - def getPrefixComparator(sortOrder: SortOrder): PrefixComparator = { sortOrder.dataType match { case StringType => @@ -75,6 +70,10 @@ object SortPrefixUtils { */ def canSortFullyWithPrefix(sortOrder: SortOrder): Boolean = { sortOrder.dataType match { + // TODO(ekl) long-type is problematic because it's null prefix representation collides with + // the lowest possible long value. Handle this special case outside radix sort. + case LongType if sortOrder.nullable => + false case BooleanType | ByteType | ShortType | IntegerType | LongType | DateType | TimestampType | FloatType | DoubleType => true @@ -98,29 +97,16 @@ object SortPrefixUtils { def createPrefixGenerator(schema: StructType): UnsafeExternalRowSorter.PrefixComputer = { if (schema.nonEmpty) { val boundReference = BoundReference(0, schema.head.dataType, nullable = true) - val prefixExpr = SortPrefix(SortOrder(boundReference, Ascending)) - val prefixProjection = UnsafeProjection.create(prefixExpr) + val prefixProjection = UnsafeProjection.create( + SortPrefix(SortOrder(boundReference, Ascending))) new UnsafeExternalRowSorter.PrefixComputer { - private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix - override def computePrefix(row: InternalRow): - UnsafeExternalRowSorter.PrefixComputer.Prefix = { - val prefix = prefixProjection.apply(row) - if (prefix.isNullAt(0)) { - result.isNull = true - result.value = prefixExpr.nullValue - } else { - result.isNull = false - result.value = prefix.getLong(0) - } - result + override def computePrefix(row: InternalRow): Long = { + prefixProjection.apply(row).getLong(0) } } } else { new UnsafeExternalRowSorter.PrefixComputer { - override def computePrefix(row: InternalRow): - UnsafeExternalRowSorter.PrefixComputer.Prefix = { - emptyPrefix - } + override def computePrefix(row: InternalRow): Long = 0 } } } http://git-wip-us.apache.org/repos/asf/spark/blob/796dd151/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala index 1b9634c..97bbab6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala @@ -347,13 +347,13 @@ case class WindowExec( SparkEnv.get.memoryManager.pageSizeBytes, false) rows.foreach { r => - sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0, false) + sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0) } rows.clear() } } else { sorter.insertRecord(nextRow.getBaseObject, nextRow.getBaseOffset, - nextRow.getSizeInBytes, 0, false) + nextRow.getSizeInBytes, 0) } fetchNextRow() } http://git-wip-us.apache.org/repos/asf/spark/blob/796dd151/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala index d870d91..88f78a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala @@ -53,7 +53,7 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField val partition = split.asInstanceOf[CartesianPartition] for (y <- rdd2.iterator(partition.s2, context)) { - sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0, false) + sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0) } // Create an iterator from sorter and wrapper it as Iterator[UnsafeRow] http://git-wip-us.apache.org/repos/asf/spark/blob/796dd151/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index ba3fa37..c3acf29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -54,17 +54,6 @@ class SortSuite extends SparkPlanTest with SharedSQLContext { sortAnswers = false) } - test("sorting all nulls") { - checkThatPlansAgree( - (1 to 100).map(v => Tuple1(v)).toDF().selectExpr("NULL as a"), - (child: SparkPlan) => - GlobalLimitExec(10, SortExec('a.asc :: Nil, global = true, child = child)), - (child: SparkPlan) => - GlobalLimitExec(10, ReferenceSort('a.asc :: Nil, global = true, child)), - sortAnswers = false - ) - } - test("sort followed by limit") { checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), http://git-wip-us.apache.org/repos/asf/spark/blob/796dd151/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala index 50ae26a..9964b73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala @@ -110,7 +110,7 @@ class SortBenchmark extends BenchmarkBase { benchmark.addTimerCase("radix sort key prefix array") { timer => val (_, buf2) = generateKeyPrefixTestData(size, rand.nextLong) timer.startTiming() - RadixSort.sortKeyPrefixArray(buf2, 0, size, 0, 7, false, false) + RadixSort.sortKeyPrefixArray(buf2, size, 0, 7, false, false) timer.stopTiming() } benchmark.run() --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
