This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 18e83f9ab23 [SPARK-40407][SQL] Fix the potential data skew caused by df.repartition
18e83f9ab23 is described below

commit 18e83f9ab234e2e475842c70f191d90e4dd6e00d
Author: Bobby Wang <[email protected]>
AuthorDate: Thu Sep 22 20:59:00 2022 +0800

    [SPARK-40407][SQL] Fix the potential data skew caused by df.repartition
    
    ### What changes were proposed in this pull request?
    
    ``` scala
    val df = spark.range(0, 100, 1, 50).repartition(4)
    val v = df.rdd.mapPartitions { iter =>
      Iterator.single(iter.length)
    }.collect()
    println(v.mkString(","))
    ```
    
    The above simple code outputs `50,0,0,50`, which means there is no data in partition 1 and partition 2.
    
    RoundRobin partitioning only ensures that the records *of the same input partition* are distributed evenly; it does not guarantee an even distribution across input partitions.
    
    Below is the code that generates the partition key:
    
    ``` scala
          case RoundRobinPartitioning(numPartitions) =>
            // Distributes elements evenly across output partitions, starting from a random partition.
            var position = new Random(TaskContext.get().partitionId()).nextInt(numPartitions)
            (row: InternalRow) => {
              // The HashPartitioner will handle the `mod` by the number of partitions
              position += 1
              position
            }
    ```
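    To make the captured-state behavior concrete, here is a hypothetical standalone model of the closure above, outside Spark. `RoundRobinKeyModel`, `keyExtractor` and `toPartition` are illustrative names; `taskPartitionId` stands in for `TaskContext.get().partitionId()`, and `toPartition` assumes the non-negative modulo that `HashPartitioner` applies to an `Int` key.

    ``` scala
    import java.util.Random

    object RoundRobinKeyModel {
      // Builds a key extractor for one map task, like the RoundRobin case above.
      def keyExtractor(taskPartitionId: Int, numPartitions: Int): Any => Int = {
        // Pre-fix starting position: the generator is seeded with the raw partition id.
        var position = new Random(taskPartitionId).nextInt(numPartitions)
        // The returned closure captures `position` and advances it on every call;
        // the row itself is ignored.
        (_: Any) => {
          position += 1
          position
        }
      }

      // Non-negative modulo, as HashPartitioner applies to an Int key's hashCode.
      def toPartition(key: Int, numPartitions: Int): Int = {
        val raw = key % numPartitions
        if (raw < 0) raw + numPartitions else raw
      }
    }
    ```

    Because `position` lives in the closure, consecutive rows of one map task get consecutive keys, so within a task the rows cycle evenly over the output partitions. Nothing coordinates the starting point between tasks.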
    
    In this case there are 50 input partitions, and each one holds only 2 elements. The problem is that RoundRobin starts at position 2 in every one of them.
    
    See the output of `Random`:
    ``` scala
    scala> (1 to 200).foreach(partitionId => print(new Random(partitionId).nextInt(4) + " "))  // the position is always 2.
    2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
    2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
    2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
    2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
    2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
    ```
    
    Similarly, each of the following lines prints a single repeated value for every `partitionId`:
    
    ``` scala
    (1 to 200).foreach(partitionId => print(new Random(partitionId).nextInt(2) + " "))
    (1 to 200).foreach(partitionId => print(new Random(partitionId).nextInt(4) + " "))
    (1 to 200).foreach(partitionId => print(new Random(partitionId).nextInt(8) + " "))
    (1 to 200).foreach(partitionId => print(new Random(partitionId).nextInt(16) + " "))
    (1 to 200).foreach(partitionId => print(new Random(partitionId).nextInt(32) + " "))
    ```
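    The code comment added by this patch attributes this to the power-of-two branch of `nextInt`. As a sketch of that mechanism, the following re-implements the documented `java.util.Random` algorithm for the first `nextInt(bound)` call on a freshly seeded generator, restricted to power-of-two bounds; `PowerOfTwoNextInt` and `firstNextInt` are hypothetical names. A small seed only flips low bits of the 48-bit state, a single LCG step moves the top bits very little, and a power-of-two bound reads exactly those top bits, so nearby seeds land on the same value.

    ``` scala
    // Sketch of java.util.Random's documented algorithm, limited to the first
    // nextInt(bound) call after construction, for power-of-two bounds only.
    object PowerOfTwoNextInt {
      private val Multiplier = 0x5DEECE66DL
      private val Addend = 0xBL
      private val Mask = (1L << 48) - 1

      def firstNextInt(seed: Long, bound: Int): Int = {
        require(bound > 0 && (bound & (bound - 1)) == 0, "bound must be a power of two")
        // Constructor: the seed is XORed with the multiplier (the only initial scrambling).
        val scrambled = (seed ^ Multiplier) & Mask
        // next(31): one linear congruential step, then keep the top 31 of the 48 state bits.
        val state = (scrambled * Multiplier + Addend) & Mask
        val r = (state >>> (48 - 31)).toInt
        // Power-of-two bound: keep only the highest log2(bound) bits of r.
        ((bound * r.toLong) >> 31).toInt
      }
    }
    ```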
    
    Consider input partition 0, whose elements are [0, 1]. During the shuffle write, element 0 gets key position + 1 = 2 + 1 = 3, and 3 % 4 = 3; element 1 gets key 3 + 1 = 4, and 4 % 4 = 0.
    Consider input partition 1, whose elements are [2, 3]. Element 2 gets key 2 + 1 = 3, and 3 % 4 = 3; element 3 gets key 3 + 1 = 4, and 4 % 4 = 0.
    
    The same calculation applies to the remaining input partitions, since the starting position is always 2 in this case.
    
    As a result, every input partition writes its elements only to output partitions 0 and 3, leaving output partitions 1 and 2 without any data.
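    The per-partition arithmetic above can be rolled up over all 50 input partitions with a small simulation. This is a hypothetical sketch outside Spark (`RoundRobinSkewSim` and `outputCounts` are illustrative names); it assumes, as in the example, 2 rows per input partition and the non-negative modulo that `HashPartitioner` applies to an `Int` key.

    ``` scala
    // Hypothetical simulation of the pre-fix RoundRobin shuffle write for
    // spark.range(0, 100, 1, 50).repartition(4): 50 input partitions with 2 rows each.
    object RoundRobinSkewSim {
      // Mirrors the non-negative mod HashPartitioner applies to an Int key.
      def nonNegativeMod(x: Int, mod: Int): Int = {
        val raw = x % mod
        if (raw < 0) raw + mod else raw
      }

      // Number of rows each output partition receives.
      def outputCounts(numInputPartitions: Int, rowsPerPartition: Int, numPartitions: Int): Seq[Int] = {
        val counts = Array.fill(numPartitions)(0)
        for (pid <- 0 until numInputPartitions) {
          // Pre-fix starting position: the generator is seeded with the raw partition id.
          var position = new java.util.Random(pid).nextInt(numPartitions)
          for (_ <- 0 until rowsPerPartition) {
            position += 1
            counts(nonNegativeMod(position, numPartitions)) += 1
          }
        }
        counts.toSeq
      }
    }
    ```

    Calling `RoundRobinSkewSim.outputCounts(50, 2, 4)` reproduces the `50,0,0,50` distribution from the example at the top.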
    
    This PR changes how the starting position of RoundRobin is computed. The default position calculated by `new Random(partitionId).nextInt(numPartitions)` can be the same for every input partition, so all input partitions write their data to the same keys during the shuffle write, and in some special cases some keys receive no data at all. The fix scrambles the partition id with `hashing.byteswap32` before using it as the seed.
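    A minimal sketch comparing the starting positions before and after the change, assuming only `java.util.Random` and `scala.util.hashing.byteswap32` (both used by the patch); `StartPositionComparison`, `oldStart`, `newStart` and `histogram` are hypothetical names for illustration.

    ``` scala
    import java.util.Random
    import scala.util.hashing

    object StartPositionComparison {
      // Pre-fix: seed the generator with the raw partition id.
      def oldStart(partitionId: Int, numPartitions: Int): Int =
        new Random(partitionId).nextInt(numPartitions)

      // Post-fix: scramble the partition id with byteswap32 first, as in the patch.
      def newStart(partitionId: Int, numPartitions: Int): Int =
        new Random(hashing.byteswap32(partitionId)).nextInt(numPartitions)

      // Histogram: starting position -> number of input partitions that use it.
      def histogram(numInputPartitions: Int, numPartitions: Int, start: (Int, Int) => Int): Map[Int, Int] =
        (0 until numInputPartitions)
          .map(pid => start(pid, numPartitions))
          .groupBy(identity)
          .map { case (pos, pids) => pos -> pids.size }
    }
    ```

    For the example's 50 input partitions and 4 output partitions, `histogram(50, 4, oldStart)` puts all 50 input partitions at position 2, while `histogram(50, 4, newStart)` spreads them over more than one position. The commit's own test expects the resulting output sizes, sorted, to be 19, 25, 25, 31.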
    
    ### Why are the changes needed?
    
    This PR fixes the data skew in these special cases.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added a unit test to `DatasetSuite` and updated the expected partition numbers in `AdaptiveQueryExecSuite`.
    
    Closes #37855 from wbo4958/roundrobin-data-skew.
    
    Authored-by: Bobby Wang <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
    (cherry picked from commit f6c4e58b85d7486c70cd6d58aae208f037e657fa)
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/execution/exchange/ShuffleExchangeExec.scala     | 10 +++++++++-
 .../src/test/scala/org/apache/spark/sql/DatasetSuite.scala     |  6 ++++++
 .../spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala  |  4 ++--
 3 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index c033aedc778..bc8416c5b2e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -21,6 +21,7 @@ import java.util.Random
 import java.util.function.Supplier
 
 import scala.concurrent.Future
+import scala.util.hashing
 
 import org.apache.spark._
 import org.apache.spark.internal.config
@@ -306,7 +307,14 @@ object ShuffleExchangeExec {
     def getPartitionKeyExtractor(): InternalRow => Any = newPartitioning match {
       case RoundRobinPartitioning(numPartitions) =>
         // Distributes elements evenly across output partitions, starting from a random partition.
-        var position = new Random(TaskContext.get().partitionId()).nextInt(numPartitions)
+        // nextInt(numPartitions) implementation has a special case when bound is a power of 2,
+        // which is basically taking several highest bits from the initial seed, with only a
+        // minimal scrambling. Due to deterministic seed, using the generator only once,
+        // and lack of scrambling, the position values for power-of-two numPartitions always
+        // end up being almost the same regardless of the index. substantially scrambling the
+        // seed by hashing will help. Refer to SPARK-21782 for more details.
+        val partitionId = TaskContext.get().partitionId()
+        var position = new Random(hashing.byteswap32(partitionId)).nextInt(numPartitions)
         (row: InternalRow) => {
           // The HashPartitioner will handle the `mod` by the number of partitions
           position += 1
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 347e9fc08af..617314eff4e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -2060,6 +2060,12 @@ class DatasetSuite extends QueryTest
       (2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12),
       (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13))
   }
+
+  test("SPARK-40407: repartition should not result in severe data skew") {
+    val df = spark.range(0, 100, 1, 50).repartition(4)
+    val result = df.mapPartitions(iter => Iterator.single(iter.length)).collect()
+    assert(result.sorted.toSeq === Seq(19, 25, 25, 31))
+  }
 }
 
 case class Bar(a: Int)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 2635c55dedf..f711aac0253 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -1895,8 +1895,8 @@ class AdaptiveQueryExecSuite
         withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "150") {
           // partition size [0,258,72,72,72]
           checkPartitionNumber("SELECT /*+ REBALANCE(c1) */ * FROM v", 2, 4)
-          // partition size [72,216,216,144,72]
-          checkPartitionNumber("SELECT /*+ REBALANCE */ * FROM v", 4, 7)
+          // partition size [144,72,144,216,144]
+          checkPartitionNumber("SELECT /*+ REBALANCE */ * FROM v", 2, 6)
         }
 
         // no skewed partition should be optimized

