Repository: spark
Updated Branches:
  refs/heads/branch-2.1 09f70f5fd -> 4d2d3d47e


[SPARK-23207][SPARK-22905][SPARK-24564][SPARK-25114][SQL][BACKPORT-2.1] 
Shuffle+Repartition on a DataFrame could lead to incorrect answers

## What changes were proposed in this pull request?

    Backport of #20393 and #22079.

    Currently, shuffle repartition uses RoundRobinPartitioning, and the generated
result is nondeterministic because the order of the input rows is not fixed.

    The bug can be triggered when there is a repartition call following a
shuffle (which would lead to non-deterministic row ordering), as in the pattern
below:
    upstream stage -> repartition stage -> result stage
    (-> indicates a shuffle)
    When an executor process goes down, some tasks of the repartition stage
will be retried and produce the rows in a different order, so the retried tasks
of the result stage will generate different data.

    The following code returns 931532, instead of 1000000:
    ```
    import scala.sys.process._

    import org.apache.spark.TaskContext
    val res = spark.range(0, 1000 * 1000, 1).repartition(200).map { x =>
      x
    }.repartition(200).map { x =>
      if (TaskContext.get.attemptNumber == 0 && TaskContext.get.partitionId < 2) {
        throw new Exception("pkill -f java".!!)
      }
      x
    }
    res.distinct().count()
    ```

    In this PR, we propose the most straightforward way to fix this problem:
perform a local sort before partitioning. Once the input row ordering is made
deterministic, the function from rows to partitions is fully deterministic too.
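
    The idea can be sketched as follows (illustrative only, not the actual
implementation in this patch): once the rows of a partition have a
deterministic local order, assigning them to output partitions round-robin
produces the same assignment on every task retry.
    ```
    // Hypothetical helper, for illustration: sort the rows of one input
    // partition, then assign them to output partitions round-robin. The
    // sorted order is identical across retries, so each row always lands in
    // the same output partition.
    def roundRobinDeterministic[T](rows: Seq[T], numPartitions: Int)
        (implicit ord: Ordering[T]): Seq[(Int, T)] = {
      rows.sorted.zipWithIndex.map { case (row, i) => (i % numPartitions, row) }
    }
    ```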

    The downside of this approach is that, with the extra local sort inserted,
the performance of repartition() will go down, so we add a new config named
`spark.sql.execution.sortBeforeRepartition` to control whether this patch is
applied. The patch is enabled by default so that behavior is correct by
default, but users may choose to turn it off manually to avoid the performance
regression.
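
    For example, assuming a running SparkSession named `spark`, the extra sort
can be disabled at runtime (a sketch using the config key introduced by this
patch):
    ```
    spark.conf.set("spark.sql.execution.sortBeforeRepartition", "false")
    ```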

    This patch also changes the output row ordering of repartition(), which
causes a number of existing test cases to fail because they compare the results
directly.

    Added a unit test in ExchangeSuite.

    With this patch (and `spark.sql.execution.sortBeforeRepartition` set to
true), the following query returns 1000000:
    ```
    import scala.sys.process._

    import org.apache.spark.TaskContext

    spark.conf.set("spark.sql.execution.sortBeforeRepartition", "true")

    val res = spark.range(0, 1000 * 1000, 1).repartition(200).map { x =>
      x
    }.repartition(200).map { x =>
      if (TaskContext.get.attemptNumber == 0 && TaskContext.get.partitionId < 2) {
        throw new Exception("pkill -f java".!!)
      }
      x
    }
    res.distinct().count()

    res7: Long = 1000000
    ```

    Author: Xingbo Jiang <xingbo.jiang@databricks.com>

Author: Xingbo Jiang <xingbo.ji...@databricks.com>
Author: Henry Robinson <he...@apache.org>

Closes #22211 from henryr/spark-23207-branch-2.1.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4d2d3d47
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4d2d3d47
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4d2d3d47

Branch: refs/heads/branch-2.1
Commit: 4d2d3d47e00e78893b1ecd5a9a9070adc5243ac9
Parents: 09f70f5
Author: Xingbo Jiang <xingbo.ji...@databricks.com>
Authored: Mon Aug 27 16:20:19 2018 -0700
Committer: Xiao Li <gatorsm...@gmail.com>
Committed: Mon Aug 27 16:20:19 2018 -0700

----------------------------------------------------------------------
 .../unsafe/sort/RecordComparator.java           |   4 +-
 .../unsafe/sort/UnsafeInMemorySorter.java       |   7 +-
 .../unsafe/sort/UnsafeSorterSpillMerger.java    |   4 +-
 .../main/scala/org/apache/spark/rdd/RDD.scala   |   2 +
 .../apache/spark/memory/TestMemoryConsumer.java |  10 +
 .../unsafe/sort/UnsafeExternalSorterSuite.java  |   4 +-
 .../unsafe/sort/UnsafeInMemorySorterSuite.java  |   4 +-
 .../mllib/clustering/GaussianMixtureModel.scala |   2 +-
 .../spark/mllib/feature/ChiSqSelector.scala     |   2 +-
 .../apache/spark/ml/feature/Word2VecSuite.scala |   3 +-
 .../sql/execution/RecordBinaryComparator.java   |  74 +++++
 .../sql/execution/UnsafeExternalRowSorter.java  |  52 ++-
 .../org/apache/spark/sql/internal/SQLConf.scala |  14 +
 .../sql/execution/UnsafeKVExternalSorter.java   |  10 +-
 .../apache/spark/sql/execution/SortExec.scala   |   2 +-
 .../execution/exchange/ShuffleExchange.scala    |  54 +++-
 .../sort/RecordBinaryComparatorSuite.java       | 322 +++++++++++++++++++
 .../spark/sql/execution/ExchangeSuite.scala     |  26 +-
 .../datasources/parquet/ParquetIOSuite.scala    |   6 +-
 .../execution/streaming/ForeachSinkSuite.scala  |   4 +-
 20 files changed, 576 insertions(+), 30 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4d2d3d47/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java
----------------------------------------------------------------------
diff --git 
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java
 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java
index 09e4258..02b5de8 100644
--- 
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java
+++ 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java
@@ -32,6 +32,8 @@ public abstract class RecordComparator {
   public abstract int compare(
     Object leftBaseObject,
     long leftBaseOffset,
+    int leftBaseLength,
     Object rightBaseObject,
-    long rightBaseOffset);
+    long rightBaseOffset,
+    int rightBaseLength);
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/4d2d3d47/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
----------------------------------------------------------------------
diff --git 
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index 5b42843..87bd186 100644
--- 
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -62,12 +62,13 @@ public final class UnsafeInMemorySorter {
       int uaoSize = UnsafeAlignedOffset.getUaoSize();
       if (prefixComparisonResult == 0) {
         final Object baseObject1 = memoryManager.getPage(r1.recordPointer);
-        // skip length
         final long baseOffset1 = 
memoryManager.getOffsetInPage(r1.recordPointer) + uaoSize;
+        final int baseLength1 = UnsafeAlignedOffset.getSize(baseObject1, 
baseOffset1 - uaoSize);
         final Object baseObject2 = memoryManager.getPage(r2.recordPointer);
-        // skip length
         final long baseOffset2 = 
memoryManager.getOffsetInPage(r2.recordPointer) + uaoSize;
-        return recordComparator.compare(baseObject1, baseOffset1, baseObject2, 
baseOffset2);
+        final int baseLength2 = UnsafeAlignedOffset.getSize(baseObject2, 
baseOffset2 - uaoSize);
+        return recordComparator.compare(baseObject1, baseOffset1, baseLength1, 
baseObject2,
+          baseOffset2, baseLength2);
       } else {
         return prefixComparisonResult;
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/4d2d3d47/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
----------------------------------------------------------------------
diff --git 
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
index 01aed95..f17d79e 100644
--- 
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
+++ 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
@@ -38,8 +38,8 @@ final class UnsafeSorterSpillMerger {
           prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix());
         if (prefixComparisonResult == 0) {
           return recordComparator.compare(
-            left.getBaseObject(), left.getBaseOffset(),
-            right.getBaseObject(), right.getBaseOffset());
+            left.getBaseObject(), left.getBaseOffset(), left.getRecordLength(),
+            right.getBaseObject(), right.getBaseOffset(), 
right.getRecordLength());
         } else {
           return prefixComparisonResult;
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/4d2d3d47/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala 
b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 199a377..283545b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -413,6 +413,8 @@ abstract class RDD[T: ClassTag](
    *
    * If you are decreasing the number of partitions in this RDD, consider 
using `coalesce`,
    * which can avoid performing a shuffle.
+   *
+   * TODO Fix the Shuffle+Repartition data loss issue described in SPARK-23207.
    */
   def repartition(numPartitions: Int)(implicit ord: Ordering[T] = null): 
RDD[T] = withScope {
     coalesce(numPartitions, shuffle = true)

http://git-wip-us.apache.org/repos/asf/spark/blob/4d2d3d47/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java 
b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java
index db91329..0bbaea6 100644
--- a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java
+++ b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java
@@ -17,6 +17,10 @@
 
 package org.apache.spark.memory;
 
+import com.google.common.annotations.VisibleForTesting;
+
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
 import java.io.IOException;
 
 public class TestMemoryConsumer extends MemoryConsumer {
@@ -43,6 +47,12 @@ public class TestMemoryConsumer extends MemoryConsumer {
     used -= size;
     taskMemoryManager.releaseExecutionMemory(size, this);
   }
+
+  @VisibleForTesting
+  public void freePage(MemoryBlock page) {
+    used -= page.size();
+    taskMemoryManager.freePage(page, this);
+  }
 }
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/4d2d3d47/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
----------------------------------------------------------------------
diff --git 
a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
 
b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index fbbe530..ec890d1 100644
--- 
a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ 
b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -73,8 +73,10 @@ public class UnsafeExternalSorterSuite {
     public int compare(
       Object leftBaseObject,
       long leftBaseOffset,
+      int leftBaseLength,
       Object rightBaseObject,
-      long rightBaseOffset) {
+      long rightBaseOffset,
+      int rightBaseLength) {
       return 0;
     }
   };

http://git-wip-us.apache.org/repos/asf/spark/blob/4d2d3d47/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
----------------------------------------------------------------------
diff --git 
a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
 
b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
index bd89085..377d4a3 100644
--- 
a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
+++ 
b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
@@ -97,8 +97,10 @@ public class UnsafeInMemorySorterSuite {
       public int compare(
         Object leftBaseObject,
         long leftBaseOffset,
+        int leftBaseLength,
         Object rightBaseObject,
-        long rightBaseOffset) {
+        long rightBaseOffset,
+        int rightBaseLength) {
         return 0;
       }
     };

http://git-wip-us.apache.org/repos/asf/spark/blob/4d2d3d47/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
 
b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index afbe4f9..1933d54 100644
--- 
a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -154,7 +154,7 @@ object GaussianMixtureModel extends 
Loader[GaussianMixtureModel] {
       val dataArray = Array.tabulate(weights.length) { i =>
         Data(weights(i), gaussians(i).mu, gaussians(i).sigma)
       }
-      
spark.createDataFrame(dataArray).repartition(1).write.parquet(Loader.dataPath(path))
+      spark.createDataFrame(sc.makeRDD(dataArray, 
1)).write.parquet(Loader.dataPath(path))
     }
 
     def load(sc: SparkContext, path: String): GaussianMixtureModel = {

http://git-wip-us.apache.org/repos/asf/spark/blob/4d2d3d47/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index 7ef2a95..95ae154 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -144,7 +144,7 @@ object ChiSqSelectorModel extends 
Loader[ChiSqSelectorModel] {
       val dataArray = Array.tabulate(model.selectedFeatures.length) { i =>
         Data(model.selectedFeatures(i))
       }
-      
spark.createDataFrame(dataArray).repartition(1).write.parquet(Loader.dataPath(path))
+      spark.createDataFrame(sc.makeRDD(dataArray, 
1)).write.parquet(Loader.dataPath(path))
     }
 
     def load(sc: SparkContext, path: String): ChiSqSelectorModel = {

http://git-wip-us.apache.org/repos/asf/spark/blob/4d2d3d47/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index 613cc3d..78efa0d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -205,7 +205,8 @@ class Word2VecSuite extends SparkFunSuite with 
MLlibTestSparkContext with Defaul
     val oldModel = new OldWord2VecModel(word2VecMap)
     val instance = new Word2VecModel("myWord2VecModel", oldModel)
     val newInstance = testDefaultReadWrite(instance)
-    assert(newInstance.getVectors.collect() === instance.getVectors.collect())
+    assert(newInstance.getVectors.collect().sortBy(_.getString(0)) ===
+      instance.getVectors.collect().sortBy(_.getString(0)))
   }
 
   test("Word2Vec works with input that is non-nullable (NGram)") {

http://git-wip-us.apache.org/repos/asf/spark/blob/4d2d3d47/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java
new file mode 100644
index 0000000..40c2cc8
--- /dev/null
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution;
+
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
+
+public final class RecordBinaryComparator extends RecordComparator {
+
+  @Override
+  public int compare(
+      Object leftObj, long leftOff, int leftLen, Object rightObj, long 
rightOff, int rightLen) {
+    int i = 0;
+
+    // If the arrays have different length, the longer one is larger.
+    if (leftLen != rightLen) {
+      return leftLen - rightLen;
+    }
+
+    // The following logic uses `leftLen` as the length for both `leftObj` and 
`rightObj`, since
+    // we have guaranteed `leftLen` == `rightLen`.
+
+    // check if stars align and we can get both offsets to be aligned
+    if ((leftOff % 8) == (rightOff % 8)) {
+      while ((leftOff + i) % 8 != 0 && i < leftLen) {
+        final int v1 = Platform.getByte(leftObj, leftOff + i) & 0xff;
+        final int v2 = Platform.getByte(rightObj, rightOff + i) & 0xff;
+        if (v1 != v2) {
+          return v1 > v2 ? 1 : -1;
+        }
+        i += 1;
+      }
+    }
+    // for architectures that support unaligned accesses, chew it up 8 bytes 
at a time
+    if (Platform.unaligned() || (((leftOff + i) % 8 == 0) && ((rightOff + i) % 
8 == 0))) {
+      while (i <= leftLen - 8) {
+        final long v1 = Platform.getLong(leftObj, leftOff + i);
+        final long v2 = Platform.getLong(rightObj, rightOff + i);
+        if (v1 != v2) {
+          return v1 > v2 ? 1 : -1;
+        }
+        i += 8;
+      }
+    }
+    // this will finish off the unaligned comparisons, or do the entire 
aligned comparison
+    // whichever is needed.
+    while (i < leftLen) {
+      final int v1 = Platform.getByte(leftObj, leftOff + i) & 0xff;
+      final int v2 = Platform.getByte(rightObj, rightOff + i) & 0xff;
+      if (v1 != v2) {
+        return v1 > v2 ? 1 : -1;
+      }
+      i += 1;
+    }
+
+    // The two arrays are equal.
+    return 0;
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/4d2d3d47/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index c29b002..402f7a7 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -51,26 +51,55 @@ public final class UnsafeExternalRowSorter {
   private final PrefixComputer prefixComputer;
   private final UnsafeExternalSorter sorter;
 
+  public static interface RecordComparatorSupplier {
+    public RecordComparator get();
+  }
+
   public abstract static class PrefixComputer {
 
     public static class Prefix {
       /** Key prefix value, or the null prefix value if isNull = true. **/
-      long value;
+      public long value;
 
       /** Whether the key is null. */
-      boolean isNull;
+      public boolean isNull;
     }
 
     /**
      * Computes prefix for the given row. For efficiency, the returned object 
may be reused in
      * further calls to a given PrefixComputer.
      */
-    abstract Prefix computePrefix(InternalRow row);
+    public abstract Prefix computePrefix(InternalRow row);
+  }
+
+  public static UnsafeExternalRowSorter createWithRecordComparator(
+      StructType schema,
+      RecordComparatorSupplier recordComparatorSupplier,
+      PrefixComparator prefixComparator,
+      PrefixComputer prefixComputer,
+      long pageSizeBytes,
+      boolean canUseRadixSort) throws IOException {
+    return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, 
prefixComparator,
+      prefixComputer, pageSizeBytes, canUseRadixSort);
+  }
+
+  public static UnsafeExternalRowSorter create(
+      final StructType schema,
+      final Ordering<InternalRow> ordering,
+      PrefixComparator prefixComparator,
+      PrefixComputer prefixComputer,
+      long pageSizeBytes,
+      boolean canUseRadixSort) throws IOException {
+    RecordComparatorSupplier recordComparatorSupplier = new 
RecordComparatorSupplier() {
+        public RecordComparator get() { return new RowComparator(ordering, 
schema.length()); }
+      };
+    return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, 
prefixComparator,
+      prefixComputer, pageSizeBytes, canUseRadixSort);
   }
 
-  public UnsafeExternalRowSorter(
+  private UnsafeExternalRowSorter(
       StructType schema,
-      Ordering<InternalRow> ordering,
+      RecordComparatorSupplier recordComparatorSupplier,
       PrefixComparator prefixComparator,
       PrefixComputer prefixComputer,
       long pageSizeBytes,
@@ -84,7 +113,7 @@ public final class UnsafeExternalRowSorter {
       sparkEnv.blockManager(),
       sparkEnv.serializerManager(),
       taskContext,
-      new RowComparator(ordering, schema.length()),
+      recordComparatorSupplier.get(),
       prefixComparator,
       sparkEnv.conf().getInt("spark.shuffle.sort.initialBufferSize",
                              DEFAULT_INITIAL_SORT_BUFFER_SIZE),
@@ -207,8 +236,15 @@ public final class UnsafeExternalRowSorter {
     }
 
     @Override
-    public int compare(Object baseObj1, long baseOff1, Object baseObj2, long 
baseOff2) {
-      // TODO: Why are the sizes -1?
+    public int compare(
+        Object baseObj1,
+        long baseOff1,
+        int baseLen1,
+        Object baseObj2,
+        long baseOff2,
+        int baseLen2) {
+      // Note that since ordering doesn't need the total length of the record, 
we just pass -1
+      // into the row.
       row1.pointTo(baseObj1, baseOff1, -1);
       row2.pointTo(baseObj2, baseOff2, -1);
       return ordering.compare(row1, row2);

http://git-wip-us.apache.org/repos/asf/spark/blob/4d2d3d47/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index ed61ff0..b8a32f2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -641,6 +641,18 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val SORT_BEFORE_REPARTITION =
+    SQLConfigBuilder("spark.sql.execution.sortBeforeRepartition")
+      .internal()
+      .doc("When perform a repartition following a shuffle, the output row 
ordering would be " +
+        "nondeterministic. If some downstream stages fail and some tasks of 
the repartition " +
+        "stage retry, these tasks may generate different data, and that can 
lead to correctness " +
+        "issues. Turn on this config to insert a local sort before actually 
doing repartition " +
+        "to generate consistent repartition results. The performance of 
repartition() may go " +
+        "down since we insert extra local sort before it.")
+      .booleanConf
+      .createWithDefault(true)
+
   object Deprecated {
     val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
   }
@@ -758,6 +770,8 @@ class SQLConf extends Serializable with Logging {
 
   def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE)
 
+  def sortBeforeRepartition: Boolean = getConf(SORT_BEFORE_REPARTITION)
+
   /**
    * Returns the [[Resolver]] for the current configuration, which can be used 
to determine if two
    * identifiers are equal.

http://git-wip-us.apache.org/repos/asf/spark/blob/4d2d3d47/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
 
b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 0d51dc9..2d97a2a 100644
--- 
a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ 
b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -235,8 +235,14 @@ public final class UnsafeKVExternalSorter {
     }
 
     @Override
-    public int compare(Object baseObj1, long baseOff1, Object baseObj2, long 
baseOff2) {
-      // Note that since ordering doesn't need the total length of the record, 
we just pass -1
+    public int compare(
+        Object baseObj1,
+        long baseOff1,
+        int baseLen1,
+        Object baseObj2,
+        long baseOff2,
+        int baseLen2) {
+      // Note that since ordering doesn't need the total length of the record, 
we just pass -1 
       // into the row.
       row1.pointTo(baseObj1, baseOff1 + 4, -1);
       row2.pointTo(baseObj2, baseOff2 + 4, -1);

http://git-wip-us.apache.org/repos/asf/spark/blob/4d2d3d47/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index cc576bb..197ab24 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -84,7 +84,7 @@ case class SortExec(
     }
 
     val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
-    val sorter = new UnsafeExternalRowSorter(
+    val sorter = UnsafeExternalRowSorter.create(
       schema, ordering, prefixComparator, prefixComputer, pageSize, 
canUseRadixSort)
 
     if (testSpillFrequency > 0) {

http://git-wip-us.apache.org/repos/asf/spark/blob/4d2d3d47/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
index 125a493..5d3fdf2 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
@@ -30,7 +30,10 @@ import 
org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.MutablePair
+import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, 
RecordComparator}
 
 /**
  * Performs a shuffle that will result in the desired `newPartitioning`.
@@ -239,14 +242,61 @@ object ShuffleExchange {
       case RangePartitioning(_, _) | SinglePartition => identity
       case _ => sys.error(s"Exchange not implemented for $newPartitioning")
     }
+
     val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = {
-      if (needToCopyObjectsBeforeShuffle(part, serializer)) {
+      // [SPARK-23207] Have to make sure the generated RoundRobinPartitioning 
is deterministic,
+      // otherwise a retry task may output different rows and thus lead to 
data loss.
+      //
+      // Currently we following the most straight-forward way that perform a 
local sort before
+      // partitioning.
+      //
+      // Note that we don't perform local sort if the new partitioning has 
only 1 partition, under
+      // that case all output rows go to the same partition.
+      val newRdd = if (SparkEnv.get.conf.get(SQLConf.SORT_BEFORE_REPARTITION) 
&&
+          newPartitioning.numPartitions > 1 &&
+          newPartitioning.isInstanceOf[RoundRobinPartitioning]) {
         rdd.mapPartitionsInternal { iter =>
+          val recordComparatorSupplier = new 
UnsafeExternalRowSorter.RecordComparatorSupplier {
+            override def get: RecordComparator = new RecordBinaryComparator()
+          }
+          // The comparator for comparing row hashcode, which should always be 
Integer.
+          val prefixComparator = PrefixComparators.LONG
+          val canUseRadixSort = 
SparkEnv.get.conf.get(SQLConf.RADIX_SORT_ENABLED)
+          // The prefix computer generates row hashcode as the prefix, so we 
may decrease the
+          // probability that the prefixes are equal when input rows choose 
column values from a
+          // limited range.
+          val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
+            private val result = new 
UnsafeExternalRowSorter.PrefixComputer.Prefix
+            override def computePrefix(row: InternalRow):
+            UnsafeExternalRowSorter.PrefixComputer.Prefix = {
+              // The hashcode generated from the binary form of a 
[[UnsafeRow]] should not be null.
+              result.isNull = false
+              result.value = row.hashCode()
+              result
+            }
+          }
+          val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
+
+          val sorter = UnsafeExternalRowSorter.createWithRecordComparator(
+            StructType.fromAttributes(outputAttributes),
+            recordComparatorSupplier,
+            prefixComparator,
+            prefixComputer,
+            pageSize,
+            canUseRadixSort)
+          sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
+        }
+      } else {
+        rdd
+      }
+
+      if (needToCopyObjectsBeforeShuffle(part, serializer)) {
+        newRdd.mapPartitionsInternal { iter =>
           val getPartitionKey = getPartitionKeyExtractor()
           iter.map { row => (part.getPartition(getPartitionKey(row)), 
row.copy()) }
         }
       } else {
-        rdd.mapPartitionsInternal { iter =>
+        newRdd.mapPartitionsInternal { iter =>
           val getPartitionKey = getPartitionKeyExtractor()
           val mutablePair = new MutablePair[Int, InternalRow]()
           iter.map { row => 
mutablePair.update(part.getPartition(getPartitionKey(row)), row) }

http://git-wip-us.apache.org/repos/asf/spark/blob/4d2d3d47/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java
 
b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java
new file mode 100644
index 0000000..97f3dc5
--- /dev/null
+++ 
b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java
@@ -0,0 +1,322 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.execution.sort;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.memory.TestMemoryConsumer;
+import org.apache.spark.memory.TestMemoryManager;
+import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.execution.RecordBinaryComparator;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.UnsafeAlignedOffset;
+import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.types.UTF8String;
+import org.apache.spark.util.collection.unsafe.sort.*;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+/**
+ * Test the RecordBinaryComparator, which compares two UnsafeRows by their 
binary form.
+ */
+public class RecordBinaryComparatorSuite {
+
+  private final TaskMemoryManager memoryManager = new TaskMemoryManager(
+      new TestMemoryManager(new 
SparkConf().set("spark.memory.offHeap.enabled", "false")), 0);
+  private final TestMemoryConsumer consumer = new 
TestMemoryConsumer(memoryManager);
+
+  private final int uaoSize = UnsafeAlignedOffset.getUaoSize();
+
+  private MemoryBlock dataPage;
+  private long pageCursor;
+
+  private LongArray array;
+  private int pos;
+
+  @Before
+  public void beforeEach() {
+    // Only compare between two input rows.
+    array = consumer.allocateArray(2);
+    pos = 0;
+
+    dataPage = memoryManager.allocatePage(4096, consumer);
+    pageCursor = dataPage.getBaseOffset();
+  }
+
+  @After
+  public void afterEach() {
+    consumer.freePage(dataPage);
+    dataPage = null;
+    pageCursor = 0;
+
+    consumer.freeArray(array);
+    array = null;
+    pos = 0;
+  }
+
+  private void insertRow(UnsafeRow row) {
+    Object recordBase = row.getBaseObject();
+    long recordOffset = row.getBaseOffset();
+    int recordLength = row.getSizeInBytes();
+
+    Object baseObject = dataPage.getBaseObject();
+    assert(pageCursor + recordLength <= dataPage.getBaseOffset() + 
dataPage.size());
+    long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, 
pageCursor);
+    UnsafeAlignedOffset.putSize(baseObject, pageCursor, recordLength);
+    pageCursor += uaoSize;
+    Platform.copyMemory(recordBase, recordOffset, baseObject, pageCursor, 
recordLength);
+    pageCursor += recordLength;
+
+    assert(pos < 2);
+    array.set(pos, recordAddress);
+    pos++;
+  }
+
+  private int compare(int index1, int index2) {
+    Object baseObject = dataPage.getBaseObject();
+
+    long recordAddress1 = array.get(index1);
+    long baseOffset1 = memoryManager.getOffsetInPage(recordAddress1) + uaoSize;
+    int recordLength1 = UnsafeAlignedOffset.getSize(baseObject, baseOffset1 - 
uaoSize);
+
+    long recordAddress2 = array.get(index2);
+    long baseOffset2 = memoryManager.getOffsetInPage(recordAddress2) + uaoSize;
+    int recordLength2 = UnsafeAlignedOffset.getSize(baseObject, baseOffset2 - 
uaoSize);
+
+    return binaryComparator.compare(baseObject, baseOffset1, recordLength1, 
baseObject,
+        baseOffset2, recordLength2);
+  }
+
+  private final RecordComparator binaryComparator = new 
RecordBinaryComparator();
+
+  // Compute the most compact size for UnsafeRow's backing data.
+  private int computeSizeInBytes(int originalSize) {
+    // All the UnsafeRows in this suite contains less than 64 columns, so the 
bitSetSize shall
+    // always be 8.
+    return 8 + (originalSize + 7) / 8 * 8;
+  }
+
+  // Compute the relative offset of variable-length values.
+  private long relativeOffset(int numFields) {
+    // All the UnsafeRows in this suite contains less than 64 columns, so the 
bitSetSize shall
+    // always be 8.
+    return 8 + numFields * 8L;
+  }
+
+  @Test
+  public void testBinaryComparatorForSingleColumnRow() throws Exception {
+    int numFields = 1;
+
+    UnsafeRow row1 = new UnsafeRow(numFields);
+    byte[] data1 = new byte[100];
+    row1.pointTo(data1, computeSizeInBytes(numFields * 8));
+    row1.setInt(0, 11);
+
+    UnsafeRow row2 = new UnsafeRow(numFields);
+    byte[] data2 = new byte[100];
+    row2.pointTo(data2, computeSizeInBytes(numFields * 8));
+    row2.setInt(0, 42);
+
+    insertRow(row1);
+    insertRow(row2);
+
+    assert(compare(0, 0) == 0);
+    assert(compare(0, 1) < 0);
+  }
+
+  @Test
+  public void testBinaryComparatorForMultipleColumnRow() throws Exception {
+    int numFields = 5;
+
+    UnsafeRow row1 = new UnsafeRow(numFields);
+    byte[] data1 = new byte[100];
+    row1.pointTo(data1, computeSizeInBytes(numFields * 8));
+    for (int i = 0; i < numFields; i++) {
+      row1.setDouble(i, i * 3.14);
+    }
+
+    UnsafeRow row2 = new UnsafeRow(numFields);
+    byte[] data2 = new byte[100];
+    row2.pointTo(data2, computeSizeInBytes(numFields * 8));
+    for (int i = 0; i < numFields; i++) {
+      row2.setDouble(i, 198.7 / (i + 1));
+    }
+
+    insertRow(row1);
+    insertRow(row2);
+
+    assert(compare(0, 0) == 0);
+    assert(compare(0, 1) < 0);
+  }
+
+  @Test
+  public void testBinaryComparatorForArrayColumn() throws Exception {
+    int numFields = 1;
+
+    UnsafeRow row1 = new UnsafeRow(numFields);
+    byte[] data1 = new byte[100];
+    UnsafeArrayData arrayData1 = UnsafeArrayData.fromPrimitiveArray(new 
int[]{11, 42, -1});
+    row1.pointTo(data1, computeSizeInBytes(numFields * 8 + 
arrayData1.getSizeInBytes()));
+    row1.setLong(0, (relativeOffset(numFields) << 32) | (long) 
arrayData1.getSizeInBytes());
+    Platform.copyMemory(arrayData1.getBaseObject(), 
arrayData1.getBaseOffset(), data1,
+        row1.getBaseOffset() + relativeOffset(numFields), 
arrayData1.getSizeInBytes());
+
+    UnsafeRow row2 = new UnsafeRow(numFields);
+    byte[] data2 = new byte[100];
+    UnsafeArrayData arrayData2 = UnsafeArrayData.fromPrimitiveArray(new 
int[]{22});
+    row2.pointTo(data2, computeSizeInBytes(numFields * 8 + 
arrayData2.getSizeInBytes()));
+    row2.setLong(0, (relativeOffset(numFields) << 32) | (long) 
arrayData2.getSizeInBytes());
+    Platform.copyMemory(arrayData2.getBaseObject(), 
arrayData2.getBaseOffset(), data2,
+        row2.getBaseOffset() + relativeOffset(numFields), 
arrayData2.getSizeInBytes());
+
+    insertRow(row1);
+    insertRow(row2);
+
+    assert(compare(0, 0) == 0);
+    assert(compare(0, 1) > 0);
+  }
+
+  @Test
+  public void testBinaryComparatorForMixedColumns() throws Exception {
+    int numFields = 4;
+
+    UnsafeRow row1 = new UnsafeRow(numFields);
+    byte[] data1 = new byte[100];
+    UTF8String str1 = UTF8String.fromString("Milk tea");
+    row1.pointTo(data1, computeSizeInBytes(numFields * 8 + str1.numBytes()));
+    row1.setInt(0, 11);
+    row1.setDouble(1, 3.14);
+    row1.setInt(2, -1);
+    row1.setLong(3, (relativeOffset(numFields) << 32) | (long) 
str1.numBytes());
+    Platform.copyMemory(str1.getBaseObject(), str1.getBaseOffset(), data1,
+        row1.getBaseOffset() + relativeOffset(numFields), str1.numBytes());
+
+    UnsafeRow row2 = new UnsafeRow(numFields);
+    byte[] data2 = new byte[100];
+    UTF8String str2 = UTF8String.fromString("Java");
+    row2.pointTo(data2, computeSizeInBytes(numFields * 8 + str2.numBytes()));
+    row2.setInt(0, 11);
+    row2.setDouble(1, 3.14);
+    row2.setInt(2, -1);
+    row2.setLong(3, (relativeOffset(numFields) << 32) | (long) 
str2.numBytes());
+    Platform.copyMemory(str2.getBaseObject(), str2.getBaseOffset(), data2,
+        row2.getBaseOffset() + relativeOffset(numFields), str2.numBytes());
+
+    insertRow(row1);
+    insertRow(row2);
+
+    assert(compare(0, 0) == 0);
+    assert(compare(0, 1) > 0);
+  }
+
+  @Test
+  public void testBinaryComparatorForNullColumns() throws Exception {
+    int numFields = 3;
+
+    UnsafeRow row1 = new UnsafeRow(numFields);
+    byte[] data1 = new byte[100];
+    row1.pointTo(data1, computeSizeInBytes(numFields * 8));
+    for (int i = 0; i < numFields; i++) {
+      row1.setNullAt(i);
+    }
+
+    UnsafeRow row2 = new UnsafeRow(numFields);
+    byte[] data2 = new byte[100];
+    row2.pointTo(data2, computeSizeInBytes(numFields * 8));
+    for (int i = 0; i < numFields - 1; i++) {
+      row2.setNullAt(i);
+    }
+    row2.setDouble(numFields - 1, 3.14);
+
+    insertRow(row1);
+    insertRow(row2);
+
+    assert(compare(0, 0) == 0);
+    assert(compare(0, 1) > 0);
+  }
+
+  @Test
+  public void testBinaryComparatorWhenSubtractionIsDivisibleByMaxIntValue() 
throws Exception {
+    int numFields = 1;
+
+    UnsafeRow row1 = new UnsafeRow(numFields);
+    byte[] data1 = new byte[100];
+    row1.pointTo(data1, computeSizeInBytes(numFields * 8));
+    row1.setLong(0, 11);
+
+    UnsafeRow row2 = new UnsafeRow(numFields);
+    byte[] data2 = new byte[100];
+    row2.pointTo(data2, computeSizeInBytes(numFields * 8));
+    row2.setLong(0, 11L + Integer.MAX_VALUE);
+
+    insertRow(row1);
+    insertRow(row2);
+
+    assert(compare(0, 1) < 0);
+  }
+
+  @Test
+  public void testBinaryComparatorWhenSubtractionCanOverflowLongValue() throws 
Exception {
+    int numFields = 1;
+
+    UnsafeRow row1 = new UnsafeRow(numFields);
+    byte[] data1 = new byte[100];
+    row1.pointTo(data1, computeSizeInBytes(numFields * 8));
+    row1.setLong(0, Long.MIN_VALUE);
+
+    UnsafeRow row2 = new UnsafeRow(numFields);
+    byte[] data2 = new byte[100];
+    row2.pointTo(data2, computeSizeInBytes(numFields * 8));
+    row2.setLong(0, 1);
+
+    insertRow(row1);
+    insertRow(row2);
+
+    assert(compare(0, 1) < 0);
+  }
+
+  @Test
+  public void testBinaryComparatorWhenOnlyTheLastColumnDiffers() throws 
Exception {
+    int numFields = 4;
+
+    UnsafeRow row1 = new UnsafeRow(numFields);
+    byte[] data1 = new byte[100];
+    row1.pointTo(data1, computeSizeInBytes(numFields * 8));
+    row1.setInt(0, 11);
+    row1.setDouble(1, 3.14);
+    row1.setInt(2, -1);
+    row1.setLong(3, 0);
+
+    UnsafeRow row2 = new UnsafeRow(numFields);
+    byte[] data2 = new byte[100];
+    row2.pointTo(data2, computeSizeInBytes(numFields * 8));
+    row2.setInt(0, 11);
+    row2.setDouble(1, 3.14);
+    row2.setInt(2, -1);
+    row2.setLong(3, 1);
+
+    insertRow(row1);
+    insertRow(row2);
+
+    assert(compare(0, 1) < 0);
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/4d2d3d47/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
index 36cde32..2fdd802 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
@@ -17,11 +17,14 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.sql.Row
+import scala.util.Random
+
+import org.apache.spark.sql.{Dataset, Row}
 import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, 
IdentityBroadcastMode, SinglePartition}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, 
ReusedExchangeExec, ShuffleExchange}
 import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 
 class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
@@ -101,4 +104,25 @@ class ExchangeSuite extends SparkPlanTest with 
SharedSQLContext {
     assert(!exchange4.sameResult(exchange5))
     assert(exchange5 sameResult exchange4)
   }
+
+  test("SPARK-23207: Make repartition() generate consistent output") {
+    def assertConsistency(ds: Dataset[java.lang.Long]): Unit = {
+      ds.persist()
+
+      val exchange = ds.mapPartitions { iter =>
+        Random.shuffle(iter)
+      }.repartition(111)
+      val exchange2 = ds.repartition(111)
+
+      assert(exchange.rdd.collectPartitions() === 
exchange2.rdd.collectPartitions())
+    }
+
+    withSQLConf(SQLConf.SORT_BEFORE_REPARTITION.key -> "true") {
+      // repartition() should generate consistent output.
+      assertConsistency(spark.range(10000))
+
+      // case when input contains duplicated rows.
+      assertConsistency(spark.range(10000).map(i => 
Random.nextInt(1000).toLong))
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/4d2d3d47/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index acdadb3..44139f6 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -635,7 +635,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest 
with SharedSQLContext {
             val v = (row.getInt(0), row.getString(1))
             result += v
           }
-          assert(data == result)
+          assert(data.toSet == result.toSet)
         } finally {
           reader.close()
         }
@@ -651,7 +651,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest 
with SharedSQLContext {
             val row = reader.getCurrentValue.asInstanceOf[InternalRow]
             result += row.getString(0)
           }
-          assert(data.map(_._2) == result)
+          assert(data.map(_._2).toSet == result.toSet)
         } finally {
           reader.close()
         }
@@ -668,7 +668,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest 
with SharedSQLContext {
             val v = (row.getString(0), row.getInt(1))
             result += v
           }
-          assert(data.map { x => (x._2, x._1) } == result)
+          assert(data.map { x => (x._2, x._1) }.toSet == result.toSet)
         } finally {
           reader.close()
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/4d2d3d47/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
index 9137d65..41434e6 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
@@ -52,13 +52,13 @@ class ForeachSinkSuite extends StreamTest with 
SharedSQLContext with BeforeAndAf
 
       var expectedEventsForPartition0 = Seq(
         ForeachSinkSuite.Open(partition = 0, version = 0),
-        ForeachSinkSuite.Process(value = 1),
+        ForeachSinkSuite.Process(value = 2),
         ForeachSinkSuite.Process(value = 3),
         ForeachSinkSuite.Close(None)
       )
       var expectedEventsForPartition1 = Seq(
         ForeachSinkSuite.Open(partition = 1, version = 0),
-        ForeachSinkSuite.Process(value = 2),
+        ForeachSinkSuite.Process(value = 1),
         ForeachSinkSuite.Process(value = 4),
         ForeachSinkSuite.Close(None)
       )

