Repository: spark Updated Branches: refs/heads/branch-1.6 69a885a71 -> 017b73e69
[SPARK-12662][SQL] Fix DataFrame.randomSplit to avoid creating overlapping splits https://issues.apache.org/jira/browse/SPARK-12662 cc yhuai Author: Sameer Agarwal <[email protected]> Closes #10626 from sameeragarwal/randomsplit. (cherry picked from commit f194d9911a93fc3a78be820096d4836f22d09976) Signed-off-by: Reynold Xin <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/017b73e6 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/017b73e6 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/017b73e6 Branch: refs/heads/branch-1.6 Commit: 017b73e69693cd151516f92640a95a4a66e02dff Parents: 69a885a Author: Sameer Agarwal <[email protected]> Authored: Thu Jan 7 10:37:15 2016 -0800 Committer: Reynold Xin <[email protected]> Committed: Thu Jan 7 10:37:24 2016 -0800 ---------------------------------------------------------------------- .../scala/org/apache/spark/sql/DataFrame.scala | 7 ++++++- .../apache/spark/sql/DataFrameStatSuite.scala | 22 ++++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/017b73e6/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 74f9370..3180049 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1107,10 +1107,15 @@ class DataFrame private[sql]( * @since 1.4.0 */ def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame] = { + // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its + // constituent partitions each time a split is materialized which could 
result in + // overlapping splits. To prevent this, we explicitly sort each input partition to make the + // ordering deterministic. + val sorted = Sort(logicalPlan.output.map(SortOrder(_, Ascending)), global = false, logicalPlan) val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => - new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, logicalPlan)) + new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted)) }.toArray } http://git-wip-us.apache.org/repos/asf/spark/blob/017b73e6/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index b15af42..63ad6c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -62,6 +62,28 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { } } + test("randomSplit on reordered partitions") { + // This test ensures that randomSplit does not create overlapping splits even when the + // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of + // rows in each partition. 
+ val data = + sparkContext.parallelize(1 to 600, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id") + val splits = data.randomSplit(Array[Double](2, 3), seed = 1) + + assert(splits.length == 2, "wrong number of splits") + + // Verify that the splits span the entire dataset + assert(splits.flatMap(_.collect()).toSet == data.collect().toSet) + + // Verify that the splits don't overlap + assert(splits(0).intersect(splits(1)).collect().isEmpty) + + // Verify that the results are deterministic across multiple runs + val firstRun = splits.toSeq.map(_.collect().toSeq) + val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq) + assert(firstRun == secondRun) + } + test("pearson correlation") { val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c") val corr1 = df.stat.corr("a", "b", "pearson") --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
