This is an automated email from the ASF dual-hosted git repository.

stevenwu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/main by this push:
     new 87cade9cfc Flink: backport PR #13900 for adding unit test of skewness for range partitioner (#13943)
87cade9cfc is described below

commit 87cade9cfcf5256248e525256118a865bb1923c5
Author: Steven Zhen Wu <stevenz...@gmail.com>
AuthorDate: Thu Aug 28 21:14:04 2025 -0700

    Flink: backport PR #13900 for adding unit test of skewness for range partitioner (#13943)
---
 .../sink/shuffle/MapRangePartitionerBenchmark.java | 125 +++-----------
 .../shuffle/SketchRangePartitionerBenchmark.java   | 114 +++++++++++++
 .../flink/sink/shuffle/DataDistributionUtil.java   | 178 ++++++++++++++++++++
 .../sink/shuffle/TestDataDistributionUtil.java     |  49 ++++++
 .../sink/shuffle/TestRangePartitionerSkew.java     | 183 +++++++++++++++++++++
 .../sink/shuffle/MapRangePartitionerBenchmark.java | 125 +++-----------
 .../shuffle/SketchRangePartitionerBenchmark.java   | 114 +++++++++++++
 .../flink/sink/shuffle/DataDistributionUtil.java   | 178 ++++++++++++++++++++
 .../sink/shuffle/TestDataDistributionUtil.java     |  49 ++++++
 .../sink/shuffle/TestRangePartitionerSkew.java     | 183 +++++++++++++++++++++
 10 files changed, 1088 insertions(+), 210 deletions(-)

diff --git a/flink/v1.19/flink/src/jmh/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitionerBenchmark.java b/flink/v1.19/flink/src/jmh/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitionerBenchmark.java
index 24cad2669d..80a46ac530 100644
--- a/flink/v1.19/flink/src/jmh/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitionerBenchmark.java
+++ b/flink/v1.19/flink/src/jmh/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitionerBenchmark.java
@@ -18,7 +18,6 @@
  */
 package org.apache.iceberg.flink.sink.shuffle;
 
-import java.nio.charset.StandardCharsets;
 import java.util.Comparator;
 import java.util.List;
 import java.util.Map;
@@ -31,9 +30,7 @@ import org.apache.iceberg.SortKey;
 import org.apache.iceberg.SortOrder;
 import org.apache.iceberg.SortOrderComparators;
 import org.apache.iceberg.StructLike;
-import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
-import org.apache.iceberg.relocated.com.google.common.collect.Maps;
 import org.apache.iceberg.types.Types;
 import org.openjdk.jmh.annotations.Benchmark;
 import org.openjdk.jmh.annotations.BenchmarkMode;
@@ -54,8 +51,7 @@ import org.openjdk.jmh.infra.Blackhole;
 @Measurement(iterations = 5)
 @BenchmarkMode(Mode.SingleShotTime)
 public class MapRangePartitionerBenchmark {
-  private static final String CHARS =
-      "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-.!?";
+
   private static final int SAMPLE_SIZE = 100_000;
   private static final Schema SCHEMA =
       new Schema(
@@ -73,51 +69,42 @@ public class MapRangePartitionerBenchmark {
   private static final Comparator<StructLike> SORT_ORDER_COMPARTOR =
       SortOrderComparators.forSchema(SCHEMA, SORT_ORDER);
   private static final SortKey SORT_KEY = new SortKey(SCHEMA, SORT_ORDER);
+  private static final int PARALLELISM = 100;
 
   private MapRangePartitioner partitioner;
   private RowData[] rows;
 
   @Setup
   public void setupBenchmark() {
-    NavigableMap<Integer, Long> weights = longTailDistribution(100_000, 24, 
240, 100, 2.0);
-    Map<SortKey, Long> mapStatistics = 
Maps.newHashMapWithExpectedSize(weights.size());
-    weights.forEach(
-        (id, weight) -> {
-          SortKey sortKey = SORT_KEY.copy();
-          sortKey.set(0, id);
-          mapStatistics.put(sortKey, weight);
-        });
+    NavigableMap<Integer, Long> weights =
+        DataDistributionUtil.longTailDistribution(100_000, 24, 240, 100, 2.0, 
0.7);
+    Map<SortKey, Long> mapStatistics =
+        DataDistributionUtil.mapStatisticsWithLongTailDistribution(weights, 
SORT_KEY);
 
     MapAssignment mapAssignment =
-        MapAssignment.fromKeyFrequency(2, mapStatistics, 0.0, 
SORT_ORDER_COMPARTOR);
-    this.partitioner =
-        new MapRangePartitioner(
-            SCHEMA, SortOrder.builderFor(SCHEMA).asc("id").build(), 
mapAssignment);
+        MapAssignment.fromKeyFrequency(PARALLELISM, mapStatistics, 0.0, 
SORT_ORDER_COMPARTOR);
+    this.partitioner = new MapRangePartitioner(SCHEMA, SORT_ORDER, 
mapAssignment);
 
     List<Integer> keys = Lists.newArrayList(weights.keySet().iterator());
-    long[] weightsCDF = new long[keys.size()];
-    long totalWeight = 0;
-    for (int i = 0; i < keys.size(); ++i) {
-      totalWeight += weights.get(keys.get(i));
-      weightsCDF[i] = totalWeight;
-    }
+    long[] weightsCDF = DataDistributionUtil.computeCumulativeWeights(keys, 
weights);
+    long totalWeight = weightsCDF[weightsCDF.length - 1];
 
     // pre-calculate the samples for benchmark run
     this.rows = new GenericRowData[SAMPLE_SIZE];
     for (int i = 0; i < SAMPLE_SIZE; ++i) {
       long weight = ThreadLocalRandom.current().nextLong(totalWeight);
-      int index = binarySearchIndex(weightsCDF, weight);
+      int index = DataDistributionUtil.binarySearchIndex(weightsCDF, weight);
       rows[i] =
           GenericRowData.of(
               keys.get(index),
-              randomString("name2-"),
-              randomString("name3-"),
-              randomString("name4-"),
-              randomString("name5-"),
-              randomString("name6-"),
-              randomString("name7-"),
-              randomString("name8-"),
-              randomString("name9-"));
+              DataDistributionUtil.randomString("name2-", 200),
+              DataDistributionUtil.randomString("name3-", 200),
+              DataDistributionUtil.randomString("name4-", 200),
+              DataDistributionUtil.randomString("name5-", 200),
+              DataDistributionUtil.randomString("name6-", 200),
+              DataDistributionUtil.randomString("name7-", 200),
+              DataDistributionUtil.randomString("name8-", 200),
+              DataDistributionUtil.randomString("name9-", 200));
     }
   }
 
@@ -128,79 +115,7 @@ public class MapRangePartitionerBenchmark {
   @Threads(1)
   public void testPartitionerLongTailDistribution(Blackhole blackhole) {
     for (int i = 0; i < SAMPLE_SIZE; ++i) {
-      blackhole.consume(partitioner.partition(rows[i], 128));
-    }
-  }
-
-  private static String randomString(String prefix) {
-    int length = ThreadLocalRandom.current().nextInt(200);
-    byte[] buffer = new byte[length];
-
-    for (int i = 0; i < length; i += 1) {
-      buffer[i] = (byte) 
CHARS.charAt(ThreadLocalRandom.current().nextInt(CHARS.length()));
+      blackhole.consume(partitioner.partition(rows[i], PARALLELISM));
     }
-
-    return prefix + new String(buffer, StandardCharsets.US_ASCII);
-  }
-
-  /** find the index where weightsUDF[index] < weight && weightsUDF[index+1] 
>= weight */
-  private static int binarySearchIndex(long[] weightsUDF, long target) {
-    Preconditions.checkArgument(
-        target < weightsUDF[weightsUDF.length - 1],
-        "weight is out of range: total weight = %s, search target = %s",
-        weightsUDF[weightsUDF.length - 1],
-        target);
-    int start = 0;
-    int end = weightsUDF.length - 1;
-    while (start < end) {
-      int mid = (start + end) / 2;
-      if (weightsUDF[mid] < target && weightsUDF[mid + 1] >= target) {
-        return mid;
-      }
-
-      if (weightsUDF[mid] >= target) {
-        end = mid - 1;
-      } else if (weightsUDF[mid + 1] < target) {
-        start = mid + 1;
-      }
-    }
-    return start;
-  }
-
-  /** Key is the id string and value is the weight in long value. */
-  private static NavigableMap<Integer, Long> longTailDistribution(
-      long startingWeight,
-      int longTailStartingIndex,
-      int longTailLength,
-      long longTailBaseWeight,
-      double weightRandomJitterPercentage) {
-
-    NavigableMap<Integer, Long> weights = Maps.newTreeMap();
-
-    // first part just decays the weight by half
-    long currentWeight = startingWeight;
-    for (int index = 0; index < longTailStartingIndex; ++index) {
-      double jitter = 
ThreadLocalRandom.current().nextDouble(weightRandomJitterPercentage / 100);
-      long weight = (long) (currentWeight * (1.0 + jitter));
-      weight = weight > 0 ? weight : 1;
-      weights.put(index, weight);
-      if (currentWeight > longTailBaseWeight) {
-        currentWeight = currentWeight / 2;
-      }
-    }
-
-    // long tail part
-    for (int index = longTailStartingIndex;
-        index < longTailStartingIndex + longTailLength;
-        ++index) {
-      long longTailWeight =
-          (long)
-              (longTailBaseWeight
-                  * 
ThreadLocalRandom.current().nextDouble(weightRandomJitterPercentage));
-      longTailWeight = longTailWeight > 0 ? longTailWeight : 1;
-      weights.put(index, longTailWeight);
-    }
-
-    return weights;
   }
 }
diff --git a/flink/v1.19/flink/src/jmh/java/org/apache/iceberg/flink/sink/shuffle/SketchRangePartitionerBenchmark.java b/flink/v1.19/flink/src/jmh/java/org/apache/iceberg/flink/sink/shuffle/SketchRangePartitionerBenchmark.java
new file mode 100644
index 0000000000..53a24cd896
--- /dev/null
+++ b/flink/v1.19/flink/src/jmh/java/org/apache/iceberg/flink/sink/shuffle/SketchRangePartitionerBenchmark.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import java.util.Arrays;
+import java.util.UUID;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.SortKey;
+import org.apache.iceberg.SortOrder;
+import org.apache.iceberg.types.Types;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.TearDown;
+import org.openjdk.jmh.annotations.Threads;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.infra.Blackhole;
+
+@Fork(1)
+@State(Scope.Benchmark)
+@Warmup(iterations = 3)
+@Measurement(iterations = 5)
+@BenchmarkMode(Mode.SingleShotTime)
+public class SketchRangePartitionerBenchmark {
+
+  private static final int SAMPLE_SIZE = 100_000;
+  private static final Schema SCHEMA =
+      new Schema(
+          Types.NestedField.required(1, "id", Types.UUIDType.get()),
+          Types.NestedField.required(2, "name2", Types.StringType.get()),
+          Types.NestedField.required(3, "name3", Types.StringType.get()),
+          Types.NestedField.required(4, "name4", Types.StringType.get()),
+          Types.NestedField.required(5, "name5", Types.StringType.get()),
+          Types.NestedField.required(6, "name6", Types.StringType.get()),
+          Types.NestedField.required(7, "name7", Types.StringType.get()),
+          Types.NestedField.required(8, "name8", Types.StringType.get()),
+          Types.NestedField.required(9, "name9", Types.StringType.get()));
+
+  private static final SortOrder SORT_ORDER = 
SortOrder.builderFor(SCHEMA).asc("id").build();
+  private static final SortKey SORT_KEY = new SortKey(SCHEMA, SORT_ORDER);
+  private static final int PARALLELISM = 100;
+
+  private SketchRangePartitioner partitioner;
+  private RowData[] rows;
+
+  @Setup
+  public void setupBenchmark() {
+    UUID[] reservoir = DataDistributionUtil.reservoirSampleUUIDs(1_000_000, 
100_000);
+    UUID[] rangeBound = DataDistributionUtil.rangeBoundSampleUUIDs(reservoir, 
PARALLELISM);
+    SortKey[] rangeBoundSortKeys =
+        Arrays.stream(rangeBound)
+            .map(
+                uuid -> {
+                  SortKey sortKeyCopy = SORT_KEY.copy();
+                  sortKeyCopy.set(0, uuid);
+                  return sortKeyCopy;
+                })
+            .toArray(SortKey[]::new);
+
+    this.partitioner = new SketchRangePartitioner(SCHEMA, SORT_ORDER, 
rangeBoundSortKeys);
+
+    // pre-calculate the samples for benchmark run
+    this.rows = new GenericRowData[SAMPLE_SIZE];
+    for (int i = 0; i < SAMPLE_SIZE; ++i) {
+      UUID uuid = UUID.randomUUID();
+      Object uuidBytes = DataDistributionUtil.uuidBytes(uuid);
+      rows[i] =
+          GenericRowData.of(
+              uuidBytes,
+              DataDistributionUtil.randomString("name2-", 200),
+              DataDistributionUtil.randomString("name3-", 200),
+              DataDistributionUtil.randomString("name4-", 200),
+              DataDistributionUtil.randomString("name5-", 200),
+              DataDistributionUtil.randomString("name6-", 200),
+              DataDistributionUtil.randomString("name7-", 200),
+              DataDistributionUtil.randomString("name8-", 200),
+              DataDistributionUtil.randomString("name9-", 200));
+    }
+  }
+
+  @TearDown
+  public void tearDownBenchmark() {}
+
+  @Benchmark
+  @Threads(1)
+  public void testPartitionerLongTailDistribution(Blackhole blackhole) {
+    for (int i = 0; i < SAMPLE_SIZE; ++i) {
+      blackhole.consume(partitioner.partition(rows[i], PARALLELISM));
+    }
+  }
+}
diff --git a/flink/v1.19/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/DataDistributionUtil.java b/flink/v1.19/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/DataDistributionUtil.java
new file mode 100644
index 0000000000..b0d98b358b
--- /dev/null
+++ b/flink/v1.19/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/DataDistributionUtil.java
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.NavigableMap;
+import java.util.UUID;
+import java.util.concurrent.ThreadLocalRandom;
+import org.apache.iceberg.SortKey;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+
+public class DataDistributionUtil {
+  private DataDistributionUtil() {}
+
+  private static final String CHARS =
+      "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-.!?";
+
+  /** Generate a random string with a given prefix and a random length up to maxLength. */
+  public static String randomString(String prefix, int maxLength) {
+    int length = ThreadLocalRandom.current().nextInt(maxLength);
+    byte[] buffer = new byte[length];
+
+    for (int i = 0; i < length; i += 1) {
+      buffer[i] = (byte) 
CHARS.charAt(ThreadLocalRandom.current().nextInt(CHARS.length()));
+    }
+
+    return prefix + new String(buffer, StandardCharsets.UTF_8);
+  }
+
+  /**
+   * Return the index such that weightsCDF[index] > target and, unless index == 0,
+   * weightsCDF[index - 1] <= target.
+   */
+  public static int binarySearchIndex(long[] weightsCDF, long target) {
+    Preconditions.checkArgument(
+        target >= 0, "target weight must be non-negative: search target = %s", 
target);
+    Preconditions.checkArgument(
+        target < weightsCDF[weightsCDF.length - 1],
+        "target weight is out of range: total weight = %s, search target = %s",
+        weightsCDF[weightsCDF.length - 1],
+        target);
+
+    int start = 0;
+    int end = weightsCDF.length - 1;
+    while (start <= end) {
+      int mid = (start + end) / 2;
+      boolean leftOk = (mid == 0) || (weightsCDF[mid - 1] <= target);
+      boolean rightOk = weightsCDF[mid] > target;
+      if (leftOk && rightOk) {
+        return mid;
+      } else if (weightsCDF[mid] <= target) {
+        start = mid + 1;
+      } else {
+        end = mid - 1;
+      }
+    }
+
+    throw new IllegalStateException("should never reach here");
+  }
+
+  /** Key is the integer id and value is the weight as a long. */
+  public static NavigableMap<Integer, Long> longTailDistribution(
+      long startingWeight,
+      int longTailStartingIndex,
+      int longTailLength,
+      long longTailBaseWeight,
+      double weightRandomJitterPercentage,
+      double decayFactor) {
+
+    NavigableMap<Integer, Long> weights = Maps.newTreeMap();
+
+    // decay part
+    long currentWeight = startingWeight;
+    for (int index = 0; index < longTailStartingIndex; ++index) {
+      double jitter = 
ThreadLocalRandom.current().nextDouble(weightRandomJitterPercentage / 100);
+      long weight = (long) (currentWeight * (1.0 + jitter));
+      weight = weight > 0 ? weight : 1;
+      weights.put(index, weight);
+      if (currentWeight > longTailBaseWeight) {
+        currentWeight = (long) (currentWeight * decayFactor); // decay the weight by the decayFactor
+      }
+    }
+
+    // long tail part (flat with some random jitter)
+    for (int index = longTailStartingIndex;
+        index < longTailStartingIndex + longTailLength;
+        ++index) {
+      long longTailWeight =
+          (long)
+              (longTailBaseWeight
+                  * 
ThreadLocalRandom.current().nextDouble(weightRandomJitterPercentage));
+      longTailWeight = longTailWeight > 0 ? longTailWeight : 1;
+      weights.put(index, longTailWeight);
+    }
+
+    return weights;
+  }
+
+  public static Map<SortKey, Long> mapStatisticsWithLongTailDistribution(
+      NavigableMap<Integer, Long> weights, SortKey sortKey) {
+    Map<SortKey, Long> mapStatistics = 
Maps.newHashMapWithExpectedSize(weights.size());
+    weights.forEach(
+        (id, weight) -> {
+          SortKey sortKeyCopy = sortKey.copy();
+          sortKeyCopy.set(0, id);
+          mapStatistics.put(sortKeyCopy, weight);
+        });
+
+    return mapStatistics;
+  }
+
+  public static long[] computeCumulativeWeights(List<Integer> keys, 
Map<Integer, Long> weights) {
+    long[] weightsCDF = new long[keys.size()];
+    long totalWeight = 0;
+    for (int i = 0; i < keys.size(); ++i) {
+      totalWeight += weights.get(keys.get(i));
+      weightsCDF[i] = totalWeight;
+    }
+
+    return weightsCDF;
+  }
+
+  public static byte[] uuidBytes(UUID uuid) {
+    ByteBuffer bb = ByteBuffer.wrap(new byte[16]);
+    bb.putLong(uuid.getMostSignificantBits());
+    bb.putLong(uuid.getLeastSignificantBits());
+    return bb.array();
+  }
+
+  public static UUID[] reservoirSampleUUIDs(int sampleSize, int reservoirSize) 
{
+    UUID[] reservoir = new UUID[reservoirSize];
+    for (int i = 0; i < reservoirSize; ++i) {
+      reservoir[i] = UUID.randomUUID();
+    }
+
+    ThreadLocalRandom random = ThreadLocalRandom.current();
+    for (int i = reservoirSize; i < sampleSize; ++i) {
+      int rand = random.nextInt(i + 1);
+      if (rand < reservoirSize) {
+        reservoir[rand] = UUID.randomUUID();
+      }
+    }
+
+    Arrays.sort(reservoir);
+    return reservoir;
+  }
+
+  public static UUID[] rangeBoundSampleUUIDs(UUID[] sampledUUIDs, int 
rangeBoundSize) {
+    UUID[] rangeBounds = new UUID[rangeBoundSize];
+    int step = sampledUUIDs.length / rangeBoundSize;
+    for (int i = 0; i < rangeBoundSize; ++i) {
+      rangeBounds[i] = sampledUUIDs[i * step];
+    }
+    Arrays.sort(rangeBounds);
+    return rangeBounds;
+  }
+}
diff --git a/flink/v1.19/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataDistributionUtil.java b/flink/v1.19/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataDistributionUtil.java
new file mode 100644
index 0000000000..a9dd1b5d81
--- /dev/null
+++ b/flink/v1.19/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataDistributionUtil.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import static 
org.apache.iceberg.flink.sink.shuffle.DataDistributionUtil.binarySearchIndex;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+import org.junit.jupiter.api.Test;
+
+public class TestDataDistributionUtil {
+  @Test
+  public void testBinarySearchIndex() {
+    long[] weightsUDF = {10, 20, 30, 40, 50};
+    assertThat(binarySearchIndex(weightsUDF, 0)).isEqualTo(0);
+    assertThat(binarySearchIndex(weightsUDF, 9)).isEqualTo(0);
+    assertThat(binarySearchIndex(weightsUDF, 10)).isEqualTo(1);
+    assertThat(binarySearchIndex(weightsUDF, 15)).isEqualTo(1);
+    assertThat(binarySearchIndex(weightsUDF, 20)).isEqualTo(2);
+    assertThat(binarySearchIndex(weightsUDF, 29)).isEqualTo(2);
+    assertThat(binarySearchIndex(weightsUDF, 30)).isEqualTo(3);
+    assertThat(binarySearchIndex(weightsUDF, 31)).isEqualTo(3);
+    assertThat(binarySearchIndex(weightsUDF, 40)).isEqualTo(4);
+
+    // Test with a target that is out of range
+    assertThatThrownBy(() -> binarySearchIndex(weightsUDF, -1))
+        .isInstanceOf(IllegalArgumentException.class)
+        .hasMessageContaining("target weight must be non-negative");
+    assertThatThrownBy(() -> binarySearchIndex(weightsUDF, 50))
+        .isInstanceOf(IllegalArgumentException.class)
+        .hasMessageContaining("target weight is out of range");
+  }
+}
diff --git a/flink/v1.19/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestRangePartitionerSkew.java b/flink/v1.19/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestRangePartitionerSkew.java
new file mode 100644
index 0000000000..d6d8aebc63
--- /dev/null
+++ b/flink/v1.19/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestRangePartitionerSkew.java
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import static java.lang.String.format;
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.DoubleSummaryStatistics;
+import java.util.IntSummaryStatistics;
+import java.util.List;
+import java.util.Map;
+import java.util.NavigableMap;
+import java.util.UUID;
+import java.util.concurrent.ThreadLocalRandom;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.SortKey;
+import org.apache.iceberg.SortOrder;
+import org.apache.iceberg.SortOrderComparators;
+import org.apache.iceberg.StructLike;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.types.Types;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.CsvSource;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class TestRangePartitionerSkew {
+  private static final Logger LOG = 
LoggerFactory.getLogger(TestRangePartitionerSkew.class);
+
+  // change the iterations to a larger number (like 100) to see statistics of the max skew,
+  // such as min, max, avg, and stddev.
+  private static final int ITERATIONS = 1;
+
+  /**
+   * @param parallelism number of partitions
+   * @param maxSkewUpperBound the upper bound of max skew. maxSkewUpperBound is set to a loose bound
+   *     (~5x of the max value) to avoid flakiness.
+   *     <p>
+   *     <li>Map parallelism 8: max skew statistics over 100 iterations: mean = 0.0124, min =
+   *         0.0046, max = 0.0213
+   *     <li>Map parallelism 32: max skew statistics over 100 iterations: mean = 0.0183, min =
+   *         0.0100, max = 0.0261
+   */
+  @ParameterizedTest
+  @CsvSource({"8, 100_000, 0.1", "32, 400_000, 0.15"})
+  public void testMapStatisticsSkewWithLongTailDistribution(
+      int parallelism, int sampleSize, double maxSkewUpperBound) {
+    Schema schema =
+        new Schema(Types.NestedField.optional(1, "event_hour", 
Types.IntegerType.get()));
+    SortOrder sortOrder = 
SortOrder.builderFor(schema).asc("event_hour").build();
+    Comparator<StructLike> comparator = SortOrderComparators.forSchema(schema, 
sortOrder);
+    SortKey sortKey = new SortKey(schema, sortOrder);
+
+    NavigableMap<Integer, Long> weights =
+        DataDistributionUtil.longTailDistribution(100_000, 24, 240, 100, 2.0, 
0.7);
+    Map<SortKey, Long> mapStatistics =
+        DataDistributionUtil.mapStatisticsWithLongTailDistribution(weights, 
sortKey);
+    MapAssignment mapAssignment =
+        MapAssignment.fromKeyFrequency(parallelism, mapStatistics, 0.0, 
comparator);
+    MapRangePartitioner partitioner = new MapRangePartitioner(schema, 
sortOrder, mapAssignment);
+
+    List<Integer> keys = Lists.newArrayList(weights.keySet().iterator());
+    long[] weightsCDF = DataDistributionUtil.computeCumulativeWeights(keys, 
weights);
+    long totalWeight = weightsCDF[weightsCDF.length - 1];
+
+    // change the iterations to a larger number (like 100) to see statistics of the max skew,
+    // such as min, max, avg, and stddev.
+    double[] maxSkews = new double[ITERATIONS];
+    for (int iteration = 0; iteration < ITERATIONS; ++iteration) {
+      int[] recordsPerTask = new int[parallelism];
+      for (int i = 0; i < sampleSize; ++i) {
+        // randomly pick a key according to the weight distribution
+        long weight = ThreadLocalRandom.current().nextLong(totalWeight);
+        int index = DataDistributionUtil.binarySearchIndex(weightsCDF, weight);
+        RowData row = GenericRowData.of(keys.get(index));
+        int subtaskId = partitioner.partition(row, parallelism);
+        recordsPerTask[subtaskId] += 1;
+      }
+
+      IntSummaryStatistics recordsPerTaskStats = 
Arrays.stream(recordsPerTask).summaryStatistics();
+      LOG.debug("Map parallelism {}: records per task stats: {}", parallelism, 
recordsPerTaskStats);
+      double maxSkew =
+          (recordsPerTaskStats.getMax() - recordsPerTaskStats.getAverage())
+              / recordsPerTaskStats.getAverage();
+      LOG.debug("Map parallelism {}: max skew: {}", parallelism, 
format("%.03f", maxSkew));
+      assertThat(maxSkew).isLessThan(maxSkewUpperBound);
+      maxSkews[iteration] = maxSkew;
+    }
+
+    DoubleSummaryStatistics maxSkewStats = 
Arrays.stream(maxSkews).summaryStatistics();
+    LOG.info(
+        "Map parallelism {}: max skew statistics over {} iterations: mean = 
{}, min = {}, max = {}",
+        parallelism,
+        ITERATIONS,
+        format("%.4f", maxSkewStats.getAverage()),
+        format("%.4f", maxSkewStats.getMin()),
+        format("%.4f", maxSkewStats.getMax()));
+  }
+
+  /**
+   * @param parallelism number of partitions
+   * @param maxSkewUpperBound the upper bound of max skew. maxSkewUpperBound is set to a loose bound
+   *     (~5x of the max value) to avoid flakiness.
+   *     <p>
+   *     <li>Map parallelism 8: max skew statistics over 100 iterations: mean = 0.0192, min =
+   *         0.0073, max = 0.0437
+   *     <li>Map parallelism 32: max skew statistics over 100 iterations: mean = 0.0426, min =
+   *         0.0262, max = 0.0613
+   */
+  @ParameterizedTest
+  @CsvSource({"8, 100_000, 0.20", "32, 400_000, 0.25"})
+  public void testSketchStatisticsSkewWithLongTailDistribution(
+      int parallelism, int sampleSize, double maxSkewUpperBound) {
+    Schema schema = new Schema(Types.NestedField.optional(1, "uuid", 
Types.UUIDType.get()));
+    SortOrder sortOrder = SortOrder.builderFor(schema).asc("uuid").build();
+    SortKey sortKey = new SortKey(schema, sortOrder);
+
+    UUID[] reservoir = DataDistributionUtil.reservoirSampleUUIDs(1_000_000, 
100_000);
+    UUID[] rangeBound = DataDistributionUtil.rangeBoundSampleUUIDs(reservoir, 
parallelism);
+    SortKey[] rangeBoundSortKeys =
+        Arrays.stream(rangeBound)
+            .map(
+                uuid -> {
+                  SortKey sortKeyCopy = sortKey.copy();
+                  sortKeyCopy.set(0, uuid);
+                  return sortKeyCopy;
+                })
+            .toArray(SortKey[]::new);
+
+    SketchRangePartitioner partitioner =
+        new SketchRangePartitioner(schema, sortOrder, rangeBoundSortKeys);
+
+    double[] maxSkews = new double[ITERATIONS];
+    for (int iteration = 0; iteration < ITERATIONS; ++iteration) {
+      int[] recordsPerTask = new int[parallelism];
+      for (int i = 0; i < sampleSize; ++i) {
+        UUID uuid = UUID.randomUUID();
+        Object uuidBytes = DataDistributionUtil.uuidBytes(uuid);
+        RowData row = GenericRowData.of(uuidBytes);
+        int subtaskId = partitioner.partition(row, parallelism);
+        recordsPerTask[subtaskId] += 1;
+      }
+
+      IntSummaryStatistics recordsPerTaskStats = 
Arrays.stream(recordsPerTask).summaryStatistics();
+      LOG.debug("Map parallelism {}: records per task stats: {}", parallelism, 
recordsPerTaskStats);
+      double maxSkew =
+          (recordsPerTaskStats.getMax() - recordsPerTaskStats.getAverage())
+              / recordsPerTaskStats.getAverage();
+      LOG.debug("Map parallelism {}: max skew: {}", parallelism, 
format("%.03f", maxSkew));
+      assertThat(maxSkew).isLessThan(maxSkewUpperBound);
+      maxSkews[iteration] = maxSkew;
+    }
+
+    DoubleSummaryStatistics maxSkewStats = 
Arrays.stream(maxSkews).summaryStatistics();
+    LOG.info(
+        "Map parallelism {}: max skew statistics over {} iterations: mean = 
{}, min = {}, max = {}",
+        parallelism,
+        ITERATIONS,
+        format("%.4f", maxSkewStats.getAverage()),
+        format("%.4f", maxSkewStats.getMin()),
+        format("%.4f", maxSkewStats.getMax()));
+  }
+}
diff --git a/flink/v1.20/flink/src/jmh/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitionerBenchmark.java b/flink/v1.20/flink/src/jmh/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitionerBenchmark.java
index 592e7ff162..80a46ac530 100644
--- a/flink/v1.20/flink/src/jmh/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitionerBenchmark.java
+++ b/flink/v1.20/flink/src/jmh/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitionerBenchmark.java
@@ -18,7 +18,6 @@
  */
 package org.apache.iceberg.flink.sink.shuffle;
 
-import java.nio.charset.StandardCharsets;
 import java.util.Comparator;
 import java.util.List;
 import java.util.Map;
@@ -31,9 +30,7 @@ import org.apache.iceberg.SortKey;
 import org.apache.iceberg.SortOrder;
 import org.apache.iceberg.SortOrderComparators;
 import org.apache.iceberg.StructLike;
-import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
-import org.apache.iceberg.relocated.com.google.common.collect.Maps;
 import org.apache.iceberg.types.Types;
 import org.openjdk.jmh.annotations.Benchmark;
 import org.openjdk.jmh.annotations.BenchmarkMode;
@@ -54,8 +51,7 @@ import org.openjdk.jmh.infra.Blackhole;
 @Measurement(iterations = 5)
 @BenchmarkMode(Mode.SingleShotTime)
 public class MapRangePartitionerBenchmark {
-  private static final String CHARS =
-      "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-.!?";
+
   private static final int SAMPLE_SIZE = 100_000;
   private static final Schema SCHEMA =
       new Schema(
@@ -73,51 +69,42 @@ public class MapRangePartitionerBenchmark {
   private static final Comparator<StructLike> SORT_ORDER_COMPARTOR =
       SortOrderComparators.forSchema(SCHEMA, SORT_ORDER);
   private static final SortKey SORT_KEY = new SortKey(SCHEMA, SORT_ORDER);
+  private static final int PARALLELISM = 100;
 
   private MapRangePartitioner partitioner;
   private RowData[] rows;
 
   @Setup
   public void setupBenchmark() {
-    NavigableMap<Integer, Long> weights = longTailDistribution(100_000, 24, 
240, 100, 2.0);
-    Map<SortKey, Long> mapStatistics = 
Maps.newHashMapWithExpectedSize(weights.size());
-    weights.forEach(
-        (id, weight) -> {
-          SortKey sortKey = SORT_KEY.copy();
-          sortKey.set(0, id);
-          mapStatistics.put(sortKey, weight);
-        });
+    NavigableMap<Integer, Long> weights =
+        DataDistributionUtil.longTailDistribution(100_000, 24, 240, 100, 2.0, 
0.7);
+    Map<SortKey, Long> mapStatistics =
+        DataDistributionUtil.mapStatisticsWithLongTailDistribution(weights, 
SORT_KEY);
 
     MapAssignment mapAssignment =
-        MapAssignment.fromKeyFrequency(2, mapStatistics, 0.0, 
SORT_ORDER_COMPARTOR);
-    this.partitioner =
-        new MapRangePartitioner(
-            SCHEMA, SortOrder.builderFor(SCHEMA).asc("id").build(), 
mapAssignment);
+        MapAssignment.fromKeyFrequency(PARALLELISM, mapStatistics, 0.0, 
SORT_ORDER_COMPARTOR);
+    this.partitioner = new MapRangePartitioner(SCHEMA, SORT_ORDER, 
mapAssignment);
 
     List<Integer> keys = Lists.newArrayList(weights.keySet().iterator());
-    long[] weightsCDF = new long[keys.size()];
-    long totalWeight = 0;
-    for (int i = 0; i < keys.size(); ++i) {
-      totalWeight += weights.get(keys.get(i));
-      weightsCDF[i] = totalWeight;
-    }
+    long[] weightsCDF = DataDistributionUtil.computeCumulativeWeights(keys, 
weights);
+    long totalWeight = weightsCDF[weightsCDF.length - 1];
 
     // pre-calculate the samples for benchmark run
     this.rows = new GenericRowData[SAMPLE_SIZE];
     for (int i = 0; i < SAMPLE_SIZE; ++i) {
       long weight = ThreadLocalRandom.current().nextLong(totalWeight);
-      int index = binarySearchIndex(weightsCDF, weight);
+      int index = DataDistributionUtil.binarySearchIndex(weightsCDF, weight);
       rows[i] =
           GenericRowData.of(
               keys.get(index),
-              randomString("name2-"),
-              randomString("name3-"),
-              randomString("name4-"),
-              randomString("name5-"),
-              randomString("name6-"),
-              randomString("name7-"),
-              randomString("name8-"),
-              randomString("name9-"));
+              DataDistributionUtil.randomString("name2-", 200),
+              DataDistributionUtil.randomString("name3-", 200),
+              DataDistributionUtil.randomString("name4-", 200),
+              DataDistributionUtil.randomString("name5-", 200),
+              DataDistributionUtil.randomString("name6-", 200),
+              DataDistributionUtil.randomString("name7-", 200),
+              DataDistributionUtil.randomString("name8-", 200),
+              DataDistributionUtil.randomString("name9-", 200));
     }
   }
 
@@ -128,79 +115,7 @@ public class MapRangePartitionerBenchmark {
   @Threads(1)
   public void testPartitionerLongTailDistribution(Blackhole blackhole) {
     for (int i = 0; i < SAMPLE_SIZE; ++i) {
-      blackhole.consume(partitioner.partition(rows[i], 128));
-    }
-  }
-
-  private static String randomString(String prefix) {
-    int length = ThreadLocalRandom.current().nextInt(200);
-    byte[] buffer = new byte[length];
-
-    for (int i = 0; i < length; i += 1) {
-      buffer[i] = (byte) 
CHARS.charAt(ThreadLocalRandom.current().nextInt(CHARS.length()));
+      blackhole.consume(partitioner.partition(rows[i], PARALLELISM));
     }
-
-    return prefix + new String(buffer, StandardCharsets.UTF_8);
-  }
-
-  /** find the index where weightsUDF[index] < weight && weightsUDF[index+1] 
>= weight */
-  private static int binarySearchIndex(long[] weightsUDF, long target) {
-    Preconditions.checkArgument(
-        target < weightsUDF[weightsUDF.length - 1],
-        "weight is out of range: total weight = %s, search target = %s",
-        weightsUDF[weightsUDF.length - 1],
-        target);
-    int start = 0;
-    int end = weightsUDF.length - 1;
-    while (start < end) {
-      int mid = (start + end) / 2;
-      if (weightsUDF[mid] < target && weightsUDF[mid + 1] >= target) {
-        return mid;
-      }
-
-      if (weightsUDF[mid] >= target) {
-        end = mid - 1;
-      } else if (weightsUDF[mid + 1] < target) {
-        start = mid + 1;
-      }
-    }
-    return start;
-  }
-
-  /** Key is the id string and value is the weight in long value. */
-  private static NavigableMap<Integer, Long> longTailDistribution(
-      long startingWeight,
-      int longTailStartingIndex,
-      int longTailLength,
-      long longTailBaseWeight,
-      double weightRandomJitterPercentage) {
-
-    NavigableMap<Integer, Long> weights = Maps.newTreeMap();
-
-    // first part just decays the weight by half
-    long currentWeight = startingWeight;
-    for (int index = 0; index < longTailStartingIndex; ++index) {
-      double jitter = 
ThreadLocalRandom.current().nextDouble(weightRandomJitterPercentage / 100);
-      long weight = (long) (currentWeight * (1.0 + jitter));
-      weight = weight > 0 ? weight : 1;
-      weights.put(index, weight);
-      if (currentWeight > longTailBaseWeight) {
-        currentWeight = currentWeight / 2;
-      }
-    }
-
-    // long tail part
-    for (int index = longTailStartingIndex;
-        index < longTailStartingIndex + longTailLength;
-        ++index) {
-      long longTailWeight =
-          (long)
-              (longTailBaseWeight
-                  * 
ThreadLocalRandom.current().nextDouble(weightRandomJitterPercentage));
-      longTailWeight = longTailWeight > 0 ? longTailWeight : 1;
-      weights.put(index, longTailWeight);
-    }
-
-    return weights;
   }
 }
diff --git a/flink/v1.20/flink/src/jmh/java/org/apache/iceberg/flink/sink/shuffle/SketchRangePartitionerBenchmark.java b/flink/v1.20/flink/src/jmh/java/org/apache/iceberg/flink/sink/shuffle/SketchRangePartitionerBenchmark.java
new file mode 100644
index 0000000000..53a24cd896
--- /dev/null
+++ b/flink/v1.20/flink/src/jmh/java/org/apache/iceberg/flink/sink/shuffle/SketchRangePartitionerBenchmark.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import java.util.Arrays;
+import java.util.UUID;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.SortKey;
+import org.apache.iceberg.SortOrder;
+import org.apache.iceberg.types.Types;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.TearDown;
+import org.openjdk.jmh.annotations.Threads;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.infra.Blackhole;
+
+@Fork(1)
+@State(Scope.Benchmark)
+@Warmup(iterations = 3)
+@Measurement(iterations = 5)
+@BenchmarkMode(Mode.SingleShotTime)
+public class SketchRangePartitionerBenchmark {
+
+  private static final int SAMPLE_SIZE = 100_000;
+  private static final Schema SCHEMA =
+      new Schema(
+          Types.NestedField.required(1, "id", Types.UUIDType.get()),
+          Types.NestedField.required(2, "name2", Types.StringType.get()),
+          Types.NestedField.required(3, "name3", Types.StringType.get()),
+          Types.NestedField.required(4, "name4", Types.StringType.get()),
+          Types.NestedField.required(5, "name5", Types.StringType.get()),
+          Types.NestedField.required(6, "name6", Types.StringType.get()),
+          Types.NestedField.required(7, "name7", Types.StringType.get()),
+          Types.NestedField.required(8, "name8", Types.StringType.get()),
+          Types.NestedField.required(9, "name9", Types.StringType.get()));
+
+  private static final SortOrder SORT_ORDER = 
SortOrder.builderFor(SCHEMA).asc("id").build();
+  private static final SortKey SORT_KEY = new SortKey(SCHEMA, SORT_ORDER);
+  private static final int PARALLELISM = 100;
+
+  private SketchRangePartitioner partitioner;
+  private RowData[] rows;
+
+  @Setup
+  public void setupBenchmark() {
+    UUID[] reservoir = DataDistributionUtil.reservoirSampleUUIDs(1_000_000, 
100_000);
+    UUID[] rangeBound = DataDistributionUtil.rangeBoundSampleUUIDs(reservoir, 
PARALLELISM);
+    SortKey[] rangeBoundSortKeys =
+        Arrays.stream(rangeBound)
+            .map(
+                uuid -> {
+                  SortKey sortKeyCopy = SORT_KEY.copy();
+                  sortKeyCopy.set(0, uuid);
+                  return sortKeyCopy;
+                })
+            .toArray(SortKey[]::new);
+
+    this.partitioner = new SketchRangePartitioner(SCHEMA, SORT_ORDER, 
rangeBoundSortKeys);
+
+    // pre-calculate the samples for benchmark run
+    this.rows = new GenericRowData[SAMPLE_SIZE];
+    for (int i = 0; i < SAMPLE_SIZE; ++i) {
+      UUID uuid = UUID.randomUUID();
+      Object uuidBytes = DataDistributionUtil.uuidBytes(uuid);
+      rows[i] =
+          GenericRowData.of(
+              uuidBytes,
+              DataDistributionUtil.randomString("name2-", 200),
+              DataDistributionUtil.randomString("name3-", 200),
+              DataDistributionUtil.randomString("name4-", 200),
+              DataDistributionUtil.randomString("name5-", 200),
+              DataDistributionUtil.randomString("name6-", 200),
+              DataDistributionUtil.randomString("name7-", 200),
+              DataDistributionUtil.randomString("name8-", 200),
+              DataDistributionUtil.randomString("name9-", 200));
+    }
+  }
+
+  @TearDown
+  public void tearDownBenchmark() {}
+
+  @Benchmark
+  @Threads(1)
+  public void testPartitionerLongTailDistribution(Blackhole blackhole) {
+    for (int i = 0; i < SAMPLE_SIZE; ++i) {
+      blackhole.consume(partitioner.partition(rows[i], PARALLELISM));
+    }
+  }
+}
diff --git a/flink/v1.20/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/DataDistributionUtil.java b/flink/v1.20/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/DataDistributionUtil.java
new file mode 100644
index 0000000000..b0d98b358b
--- /dev/null
+++ b/flink/v1.20/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/DataDistributionUtil.java
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.NavigableMap;
+import java.util.UUID;
+import java.util.concurrent.ThreadLocalRandom;
+import org.apache.iceberg.SortKey;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+
+public class DataDistributionUtil {
+  private DataDistributionUtil() {}
+
+  private static final String CHARS =
+      "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-.!?";
+
+  /** Generate a random string with a given prefix and a random length up to maxLength. */
+  public static String randomString(String prefix, int maxLength) {
+    int length = ThreadLocalRandom.current().nextInt(maxLength);
+    byte[] buffer = new byte[length];
+
+    for (int i = 0; i < length; i += 1) {
+      buffer[i] = (byte) 
CHARS.charAt(ThreadLocalRandom.current().nextInt(CHARS.length()));
+    }
+
+    return prefix + new String(buffer, StandardCharsets.UTF_8);
+  }
+
+  /**
+   * Return the index such that weightsCDF[index] > target and, unless index == 0,
+   * weightsCDF[index - 1] <= target.
+   */
+  public static int binarySearchIndex(long[] weightsCDF, long target) {
+    Preconditions.checkArgument(
+        target >= 0, "target weight must be non-negative: search target = %s", 
target);
+    Preconditions.checkArgument(
+        target < weightsCDF[weightsCDF.length - 1],
+        "target weight is out of range: total weight = %s, search target = %s",
+        weightsCDF[weightsCDF.length - 1],
+        target);
+
+    int start = 0;
+    int end = weightsCDF.length - 1;
+    while (start <= end) {
+      int mid = (start + end) / 2;
+      boolean leftOk = (mid == 0) || (weightsCDF[mid - 1] <= target);
+      boolean rightOk = weightsCDF[mid] > target;
+      if (leftOk && rightOk) {
+        return mid;
+      } else if (weightsCDF[mid] <= target) {
+        start = mid + 1;
+      } else {
+        end = mid - 1;
+      }
+    }
+
+    throw new IllegalStateException("should never reach here");
+  }
+
+  /** Key is the integer id and value is the weight as a long. */
+  public static NavigableMap<Integer, Long> longTailDistribution(
+      long startingWeight,
+      int longTailStartingIndex,
+      int longTailLength,
+      long longTailBaseWeight,
+      double weightRandomJitterPercentage,
+      double decayFactor) {
+
+    NavigableMap<Integer, Long> weights = Maps.newTreeMap();
+
+    // decay part
+    long currentWeight = startingWeight;
+    for (int index = 0; index < longTailStartingIndex; ++index) {
+      double jitter = 
ThreadLocalRandom.current().nextDouble(weightRandomJitterPercentage / 100);
+      long weight = (long) (currentWeight * (1.0 + jitter));
+      weight = weight > 0 ? weight : 1;
+      weights.put(index, weight);
+      if (currentWeight > longTailBaseWeight) {
+        currentWeight = (long) (currentWeight * decayFactor); // decay the weight by the decayFactor
+      }
+    }
+
+    // long tail part (flat with some random jitter)
+    for (int index = longTailStartingIndex;
+        index < longTailStartingIndex + longTailLength;
+        ++index) {
+      long longTailWeight =
+          (long)
+              (longTailBaseWeight
+                  * 
ThreadLocalRandom.current().nextDouble(weightRandomJitterPercentage));
+      longTailWeight = longTailWeight > 0 ? longTailWeight : 1;
+      weights.put(index, longTailWeight);
+    }
+
+    return weights;
+  }
+
+  public static Map<SortKey, Long> mapStatisticsWithLongTailDistribution(
+      NavigableMap<Integer, Long> weights, SortKey sortKey) {
+    Map<SortKey, Long> mapStatistics = 
Maps.newHashMapWithExpectedSize(weights.size());
+    weights.forEach(
+        (id, weight) -> {
+          SortKey sortKeyCopy = sortKey.copy();
+          sortKeyCopy.set(0, id);
+          mapStatistics.put(sortKeyCopy, weight);
+        });
+
+    return mapStatistics;
+  }
+
+  public static long[] computeCumulativeWeights(List<Integer> keys, 
Map<Integer, Long> weights) {
+    long[] weightsCDF = new long[keys.size()];
+    long totalWeight = 0;
+    for (int i = 0; i < keys.size(); ++i) {
+      totalWeight += weights.get(keys.get(i));
+      weightsCDF[i] = totalWeight;
+    }
+
+    return weightsCDF;
+  }
+
+  public static byte[] uuidBytes(UUID uuid) {
+    ByteBuffer bb = ByteBuffer.wrap(new byte[16]);
+    bb.putLong(uuid.getMostSignificantBits());
+    bb.putLong(uuid.getLeastSignificantBits());
+    return bb.array();
+  }
+
+  public static UUID[] reservoirSampleUUIDs(int sampleSize, int reservoirSize) 
{
+    UUID[] reservoir = new UUID[reservoirSize];
+    for (int i = 0; i < reservoirSize; ++i) {
+      reservoir[i] = UUID.randomUUID();
+    }
+
+    ThreadLocalRandom random = ThreadLocalRandom.current();
+    for (int i = reservoirSize; i < sampleSize; ++i) {
+      int rand = random.nextInt(i + 1);
+      if (rand < reservoirSize) {
+        reservoir[rand] = UUID.randomUUID();
+      }
+    }
+
+    Arrays.sort(reservoir);
+    return reservoir;
+  }
+
+  public static UUID[] rangeBoundSampleUUIDs(UUID[] sampledUUIDs, int 
rangeBoundSize) {
+    UUID[] rangeBounds = new UUID[rangeBoundSize];
+    int step = sampledUUIDs.length / rangeBoundSize;
+    for (int i = 0; i < rangeBoundSize; ++i) {
+      rangeBounds[i] = sampledUUIDs[i * step];
+    }
+    Arrays.sort(rangeBounds);
+    return rangeBounds;
+  }
+}
diff --git a/flink/v1.20/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataDistributionUtil.java b/flink/v1.20/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataDistributionUtil.java
new file mode 100644
index 0000000000..a9dd1b5d81
--- /dev/null
+++ b/flink/v1.20/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataDistributionUtil.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import static 
org.apache.iceberg.flink.sink.shuffle.DataDistributionUtil.binarySearchIndex;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+import org.junit.jupiter.api.Test;
+
+public class TestDataDistributionUtil {
+  @Test
+  public void testBinarySearchIndex() {
+    long[] weightsUDF = {10, 20, 30, 40, 50};
+    assertThat(binarySearchIndex(weightsUDF, 0)).isEqualTo(0);
+    assertThat(binarySearchIndex(weightsUDF, 9)).isEqualTo(0);
+    assertThat(binarySearchIndex(weightsUDF, 10)).isEqualTo(1);
+    assertThat(binarySearchIndex(weightsUDF, 15)).isEqualTo(1);
+    assertThat(binarySearchIndex(weightsUDF, 20)).isEqualTo(2);
+    assertThat(binarySearchIndex(weightsUDF, 29)).isEqualTo(2);
+    assertThat(binarySearchIndex(weightsUDF, 30)).isEqualTo(3);
+    assertThat(binarySearchIndex(weightsUDF, 31)).isEqualTo(3);
+    assertThat(binarySearchIndex(weightsUDF, 40)).isEqualTo(4);
+
+    // Test with a target that is out of range
+    assertThatThrownBy(() -> binarySearchIndex(weightsUDF, -1))
+        .isInstanceOf(IllegalArgumentException.class)
+        .hasMessageContaining("target weight must be non-negative");
+    assertThatThrownBy(() -> binarySearchIndex(weightsUDF, 50))
+        .isInstanceOf(IllegalArgumentException.class)
+        .hasMessageContaining("target weight is out of range");
+  }
+}
diff --git a/flink/v1.20/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestRangePartitionerSkew.java b/flink/v1.20/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestRangePartitionerSkew.java
new file mode 100644
index 0000000000..d6d8aebc63
--- /dev/null
+++ b/flink/v1.20/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestRangePartitionerSkew.java
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import static java.lang.String.format;
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.DoubleSummaryStatistics;
+import java.util.IntSummaryStatistics;
+import java.util.List;
+import java.util.Map;
+import java.util.NavigableMap;
+import java.util.UUID;
+import java.util.concurrent.ThreadLocalRandom;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.SortKey;
+import org.apache.iceberg.SortOrder;
+import org.apache.iceberg.SortOrderComparators;
+import org.apache.iceberg.StructLike;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.types.Types;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.CsvSource;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class TestRangePartitionerSkew {
+  private static final Logger LOG = LoggerFactory.getLogger(TestRangePartitionerSkew.class);
+
+  // change the iterations to a larger number (like 100) to see the statistics of max skew.
+  // like min, max, avg, stddev of max skew.
+  private static final int ITERATIONS = 1;
+
+  /**
+   * @param parallelism number of partitions
+   * @param maxSkewUpperBound the upper bound of max skew. maxSkewUpperBound is set to a loose bound
+   *     (~5x of the max value) to avoid flakiness.
+   *     <p>
+   *     <li>Map parallelism 8: max skew statistics over 100 iterations: mean = 0.0124, min =
+   *         0.0046, max = 0.0213
+   *     <li>Map parallelism 32: max skew statistics over 100 iterations: mean = 0.0183, min =
+   *         0.0100, max = 0.0261
+   */
+  @ParameterizedTest
+  @CsvSource({"8, 100_000, 0.1", "32, 400_000, 0.15"})
+  public void testMapStatisticsSkewWithLongTailDistribution(
+      int parallelism, int sampleSize, double maxSkewUpperBound) {
+    Schema schema =
+        new Schema(Types.NestedField.optional(1, "event_hour", Types.IntegerType.get()));
+    SortOrder sortOrder = SortOrder.builderFor(schema).asc("event_hour").build();
+    Comparator<StructLike> comparator = SortOrderComparators.forSchema(schema, sortOrder);
+    SortKey sortKey = new SortKey(schema, sortOrder);
+
+    NavigableMap<Integer, Long> weights =
+        DataDistributionUtil.longTailDistribution(100_000, 24, 240, 100, 2.0, 0.7);
+    Map<SortKey, Long> mapStatistics =
+        DataDistributionUtil.mapStatisticsWithLongTailDistribution(weights, sortKey);
+    MapAssignment mapAssignment =
+        MapAssignment.fromKeyFrequency(parallelism, mapStatistics, 0.0, comparator);
+    MapRangePartitioner partitioner = new MapRangePartitioner(schema, sortOrder, mapAssignment);
+
+    List<Integer> keys = Lists.newArrayList(weights.keySet().iterator());
+    long[] weightsCDF = DataDistributionUtil.computeCumulativeWeights(keys, weights);
+    long totalWeight = weightsCDF[weightsCDF.length - 1];
+
+    // change the iterations to a larger number (like 100) to see the statistics of max skew.
+    // like min, max, avg, stddev of max skew.
+    double[] maxSkews = new double[ITERATIONS];
+    for (int iteration = 0; iteration < ITERATIONS; ++iteration) {
+      int[] recordsPerTask = new int[parallelism];
+      for (int i = 0; i < sampleSize; ++i) {
+        // randomly pick a key according to the weight distribution
+        long weight = ThreadLocalRandom.current().nextLong(totalWeight);
+        int index = DataDistributionUtil.binarySearchIndex(weightsCDF, weight);
+        RowData row = GenericRowData.of(keys.get(index));
+        int subtaskId = partitioner.partition(row, parallelism);
+        recordsPerTask[subtaskId] += 1;
+      }
+
+      IntSummaryStatistics recordsPerTaskStats = Arrays.stream(recordsPerTask).summaryStatistics();
+      LOG.debug("Map parallelism {}: records per task stats: {}", parallelism, recordsPerTaskStats);
+      double maxSkew =
+          (recordsPerTaskStats.getMax() - recordsPerTaskStats.getAverage())
+              / recordsPerTaskStats.getAverage();
+      LOG.debug("Map parallelism {}: max skew: {}", parallelism, 
format("%.03f", maxSkew));
+      assertThat(maxSkew).isLessThan(maxSkewUpperBound);
+      maxSkews[iteration] = maxSkew;
+    }
+
+    DoubleSummaryStatistics maxSkewStats = Arrays.stream(maxSkews).summaryStatistics();
+    LOG.info(
+        "Map parallelism {}: max skew statistics over {} iterations: mean = 
{}, min = {}, max = {}",
+        parallelism,
+        ITERATIONS,
+        format("%.4f", maxSkewStats.getAverage()),
+        format("%.4f", maxSkewStats.getMin()),
+        format("%.4f", maxSkewStats.getMax()));
+  }
+
+  /**
+   * @param parallelism number of partitions
+   * @param maxSkewUpperBound the upper bound of max skew. maxSkewUpperBound is set to a loose bound
+   *     (~5x of the max value) to avoid flakiness.
+   *     <p>
+   *     <li>Map parallelism 8: max skew statistics over 100 iterations: mean = 0.0192, min =
+   *         0.0073, max = 0.0437
+   *     <li>Map parallelism 32: max skew statistics over 100 iterations: mean = 0.0426, min =
+   *         0.0262, max = 0.0613
+   */
+  @ParameterizedTest
+  @CsvSource({"8, 100_000, 0.20", "32, 400_000, 0.25"})
+  public void testSketchStatisticsSkewWithLongTailDistribution(
+      int parallelism, int sampleSize, double maxSkewUpperBound) {
+    Schema schema = new Schema(Types.NestedField.optional(1, "uuid", Types.UUIDType.get()));
+    SortOrder sortOrder = SortOrder.builderFor(schema).asc("uuid").build();
+    SortKey sortKey = new SortKey(schema, sortOrder);
+
+    UUID[] reservoir = DataDistributionUtil.reservoirSampleUUIDs(1_000_000, 100_000);
+    UUID[] rangeBound = DataDistributionUtil.rangeBoundSampleUUIDs(reservoir, parallelism);
+    SortKey[] rangeBoundSortKeys =
+        Arrays.stream(rangeBound)
+            .map(
+                uuid -> {
+                  SortKey sortKeyCopy = sortKey.copy();
+                  sortKeyCopy.set(0, uuid);
+                  return sortKeyCopy;
+                })
+            .toArray(SortKey[]::new);
+
+    SketchRangePartitioner partitioner =
+        new SketchRangePartitioner(schema, sortOrder, rangeBoundSortKeys);
+
+    double[] maxSkews = new double[ITERATIONS];
+    for (int iteration = 0; iteration < ITERATIONS; ++iteration) {
+      int[] recordsPerTask = new int[parallelism];
+      for (int i = 0; i < sampleSize; ++i) {
+        UUID uuid = UUID.randomUUID();
+        Object uuidBytes = DataDistributionUtil.uuidBytes(uuid);
+        RowData row = GenericRowData.of(uuidBytes);
+        int subtaskId = partitioner.partition(row, parallelism);
+        recordsPerTask[subtaskId] += 1;
+      }
+
+      IntSummaryStatistics recordsPerTaskStats = Arrays.stream(recordsPerTask).summaryStatistics();
+      LOG.debug("Map parallelism {}: records per task stats: {}", parallelism, recordsPerTaskStats);
+      double maxSkew =
+          (recordsPerTaskStats.getMax() - recordsPerTaskStats.getAverage())
+              / recordsPerTaskStats.getAverage();
+      LOG.debug("Map parallelism {}: max skew: {}", parallelism, 
format("%.03f", maxSkew));
+      assertThat(maxSkew).isLessThan(maxSkewUpperBound);
+      maxSkews[iteration] = maxSkew;
+    }
+
+    DoubleSummaryStatistics maxSkewStats = Arrays.stream(maxSkews).summaryStatistics();
+    LOG.info(
+        "Map parallelism {}: max skew statistics over {} iterations: mean = 
{}, min = {}, max = {}",
+        parallelism,
+        ITERATIONS,
+        format("%.4f", maxSkewStats.getAverage()),
+        format("%.4f", maxSkewStats.getMin()),
+        format("%.4f", maxSkewStats.getMax()));
+  }
+}
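
A note on the helpers behind the map-statistics test: computeCumulativeWeights and the weighted
key sampling built on top of it are also part of DataDistributionUtil and are not quoted in full
here. The signature below is taken from the call site in the test; the body is an assumption
about how such a cumulative-weight (CDF) array could be built, not the committed code.

    import java.util.List;
    import java.util.NavigableMap;

    // Hypothetical sketch only -- the committed DataDistributionUtil helper may differ.
    class CumulativeWeightsSketch {
      // cumulative[i] = total weight of keys[0..i]; the last entry is the overall total.
      static long[] computeCumulativeWeights(List<Integer> keys, NavigableMap<Integer, Long> weights) {
        long[] cumulative = new long[keys.size()];
        long runningTotal = 0L;
        for (int i = 0; i < keys.size(); ++i) {
          runningTotal += weights.get(keys.get(i));
          cumulative[i] = runningTotal;
        }
        return cumulative;
      }
    }

With an array like this, the test loop's draw of ThreadLocalRandom.current().nextLong(totalWeight)
followed by binarySearchIndex selects keys.get(index) with probability proportional to its weight,
which is how the long-tail distribution is reproduced at the partitioner input.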

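Both tests score balance with the same metric: the relative excess of the busiest subtask over a
perfectly even split, maxSkew = (max - avg) / avg of the per-subtask record counts. The documented
runs land roughly between 0.005 and 0.06, and the CsvSource bounds (0.1 to 0.25) are deliberately
about 5x looser to avoid flakiness. A small worked example of the metric, standalone rather than
taken from the patch:

    import java.util.Arrays;
    import java.util.IntSummaryStatistics;

    // Restates the skew computation used in both tests above on a hand-made distribution.
    class MaxSkewExample {
      static double maxSkew(int[] recordsPerTask) {
        IntSummaryStatistics stats = Arrays.stream(recordsPerTask).summaryStatistics();
        return (stats.getMax() - stats.getAverage()) / stats.getAverage();
      }

      public static void main(String[] args) {
        // Parallelism 8 with 100_000 records: an even split is 12_500 per subtask. The busiest
        // subtask below holds 12_750, so maxSkew = (12_750 - 12_500) / 12_500 = 0.02, well under
        // the 0.1 bound used for the map partitioner at parallelism 8.
        int[] recordsPerTask = {12_750, 12_450, 12_500, 12_480, 12_520, 12_470, 12_430, 12_400};
        System.out.println(maxSkew(recordsPerTask));  // prints 0.02
      }
    }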