Repository: spark
Updated Branches:
  refs/heads/branch-2.2 30149d54c -> fb59a1954


[SPARK-20451] Filter out nested mapType datatypes from sort order in randomSplit

## What changes were proposed in this pull request?

In `randomSplit`, it is possible that the underlying dataset doesn't guarantee 
the ordering of rows in its constituent partitions each time a split is 
materialized, which can result in overlapping splits.
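
As a minimal sketch (not part of this patch), the construction below mirrors the 
test suite's input: each partition is reshuffled on every materialization, so 
without a deterministic per-partition sort the two splits are sampled from 
differently ordered inputs and may overlap or drop rows. The local SparkSession 
setup is assumed only so the snippet runs standalone.

```scala
import org.apache.spark.sql.SparkSession

object RandomSplitOverlapSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    import spark.implicits._

    // Each partition is reshuffled every time it is computed.
    val df = spark.sparkContext
      .parallelize(1 to 600, 2)
      .mapPartitions(scala.util.Random.shuffle(_))
      .toDF("id")

    val Array(a, b) = df.randomSplit(Array[Double](2, 3), seed = 1)
    // With a deterministic per-partition ordering this prints 0; without one,
    // the two splits could share rows.
    println(s"rows in both splits: ${a.intersect(b).count()}")

    spark.stop()
  }
}
```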

To prevent this, as part of SPARK-12662, we explicitly sort each input 
partition to make the ordering deterministic. Given that `MapType` columns 
(including maps nested inside arrays and structs) cannot be sorted, this patch 
explicitly prunes them out of the sort order. Additionally, if the resulting 
sort order is empty, this patch materializes the dataset to guarantee 
determinism.
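
The core of the change in `Dataset.scala` (shown in full in the diff below) 
follows this shape: only attributes whose types `RowOrdering.isOrderable` 
accepts contribute to the sort order, and if nothing is orderable the dataset 
is cached instead.

```scala
val sortOrder = logicalPlan.output
  .filter(attr => RowOrdering.isOrderable(attr.dataType))
  .map(SortOrder(_, Ascending))
val plan = if (sortOrder.nonEmpty) {
  Sort(sortOrder, global = false, logicalPlan)
} else {
  // SPARK-12662: if the sort order is empty, materialize the dataset to
  // guarantee determinism
  cache()
  logicalPlan
}
```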

## How was this patch tested?

Extended `randomSplit on reordered partitions` in `DataFrameStatSuite` to also 
test dataframes with mapTypes and nested mapTypes.
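
Condensed from the extended test (the helper name here is just for 
illustration), each of the three dataframes is checked for full coverage, no 
overlap, and determinism across runs:

```scala
import org.apache.spark.sql.DataFrame

def checkSplits(df: DataFrame): Unit = {
  val splits = df.randomSplit(Array[Double](2, 3), seed = 1)
  assert(splits.length == 2, "wrong number of splits")
  // The splits together cover exactly the original rows...
  assert(splits.flatMap(_.collect()).toSet == df.collect().toSet)
  // ...and share no rows.
  assert(splits(0).collect().toSeq.intersect(splits(1).collect().toSeq).isEmpty)
  // The same seed yields the same splits on a second run.
  val rerun = df.randomSplit(Array[Double](2, 3), seed = 1).map(_.collect().toSeq)
  assert(splits.map(_.collect().toSeq).toSeq == rerun.toSeq)
}
```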

Author: Sameer Agarwal <samee...@cs.berkeley.edu>

Closes #17751 from sameeragarwal/randomsplit2.

(cherry picked from commit 31345fde82ada1f8bb12807b250b04726a1f6aa6)
Signed-off-by: Wenchen Fan <wenc...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fb59a195
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fb59a195
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fb59a195

Branch: refs/heads/branch-2.2
Commit: fb59a195428597f50c599fff0c6521604a454400
Parents: 30149d5
Author: Sameer Agarwal <samee...@cs.berkeley.edu>
Authored: Tue Apr 25 13:05:20 2017 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Tue Apr 25 13:05:41 2017 +0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/Dataset.scala    | 18 +++++---
 .../apache/spark/sql/DataFrameStatSuite.scala   | 43 +++++++++++++-------
 2 files changed, 41 insertions(+), 20 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/fb59a195/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index c6dcd93..06dd550 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1726,15 +1726,23 @@ class Dataset[T] private[sql](
     // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its
     // constituent partitions each time a split is materialized which could result in
     // overlapping splits. To prevent this, we explicitly sort each input partition to make the
-    // ordering deterministic.
-    // MapType cannot be sorted.
-    val sorted = Sort(logicalPlan.output.filterNot(_.dataType.isInstanceOf[MapType])
-      .map(SortOrder(_, Ascending)), global = false, logicalPlan)
+    // ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out
+    // from the sort order.
+    val sortOrder = logicalPlan.output
+      .filter(attr => RowOrdering.isOrderable(attr.dataType))
+      .map(SortOrder(_, Ascending))
+    val plan = if (sortOrder.nonEmpty) {
+      Sort(sortOrder, global = false, logicalPlan)
+    } else {
+      // SPARK-12662: If sort order is empty, we materialize the dataset to guarantee determinism
+      cache()
+      logicalPlan
+    }
     val sum = weights.sum
     val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
     normalizedCumWeights.sliding(2).map { x =>
       new Dataset[T](
-        sparkSession, Sample(x(0), x(1), withReplacement = false, seed, sorted)(), encoder)
+        sparkSession, Sample(x(0), x(1), withReplacement = false, seed, plan)(), encoder)
     }.toArray
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/fb59a195/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 97890a0..dd118f8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -68,25 +68,38 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
   }
 
   test("randomSplit on reordered partitions") {
-    // This test ensures that randomSplit does not create overlapping splits even when the
-    // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of
-    // rows in each partition.
-    val data =
-      sparkContext.parallelize(1 to 600, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id")
-    val splits = data.randomSplit(Array[Double](2, 3), seed = 1)
 
-    assert(splits.length == 2, "wrong number of splits")
+    def testNonOverlappingSplits(data: DataFrame): Unit = {
+      val splits = data.randomSplit(Array[Double](2, 3), seed = 1)
+      assert(splits.length == 2, "wrong number of splits")
+
+      // Verify that the splits span the entire dataset
+      assert(splits.flatMap(_.collect()).toSet == data.collect().toSet)
 
-    // Verify that the splits span the entire dataset
-    assert(splits.flatMap(_.collect()).toSet == data.collect().toSet)
+      // Verify that the splits don't overlap
+      assert(splits(0).collect().toSeq.intersect(splits(1).collect().toSeq).isEmpty)
 
-    // Verify that the splits don't overlap
-    assert(splits(0).intersect(splits(1)).collect().isEmpty)
+      // Verify that the results are deterministic across multiple runs
+      val firstRun = splits.toSeq.map(_.collect().toSeq)
+      val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq)
+      assert(firstRun == secondRun)
+    }
 
-    // Verify that the results are deterministic across multiple runs
-    val firstRun = splits.toSeq.map(_.collect().toSeq)
-    val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq)
-    assert(firstRun == secondRun)
+    // This test ensures that randomSplit does not create overlapping splits even when the
+    // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of
+    // rows in each partition.
+    val dataWithInts = sparkContext.parallelize(1 to 600, 2)
+      .mapPartitions(scala.util.Random.shuffle(_)).toDF("int")
+    val dataWithMaps = sparkContext.parallelize(1 to 600, 2)
+      .map(i => (i, Map(i -> i.toString)))
+      .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "map")
+    val dataWithArrayOfMaps = sparkContext.parallelize(1 to 600, 2)
+      .map(i => (i, Array(Map(i -> i.toString))))
+      .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "arrayOfMaps")
+
+    testNonOverlappingSplits(dataWithInts)
+    testNonOverlappingSplits(dataWithMaps)
+    testNonOverlappingSplits(dataWithArrayOfMaps)
   }
 
   test("pearson correlation") {

