This is an automated email from the ASF dual-hosted git repository.

stevenwu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/main by this push:
     new a86e1b3bbd Flink: backport PR #9321 for range partitioner on map 
statistics (#10061)
a86e1b3bbd is described below

commit a86e1b3bbd4101bd1ecb3cc8551590022f0211ac
Author: Steven Zhen Wu <[email protected]>
AuthorDate: Sat Mar 30 13:01:28 2024 -0700

    Flink: backport PR #9321 for range partitioner on map statistics (#10061)
---
 .../sink/shuffle/MapRangePartitionerBenchmark.java | 199 +++++++++
 .../flink/sink/shuffle/MapRangePartitioner.java    | 381 ++++++++++++++++++
 .../sink/shuffle/TestMapRangePartitioner.java      | 448 +++++++++++++++++++++
 .../sink/shuffle/MapRangePartitionerBenchmark.java | 199 +++++++++
 .../flink/sink/shuffle/MapRangePartitioner.java    | 381 ++++++++++++++++++
 .../sink/shuffle/TestMapRangePartitioner.java      | 448 +++++++++++++++++++++
 jmh.gradle                                         |   8 +
 7 files changed, 2064 insertions(+)

diff --git 
a/flink/v1.16/flink/src/jmh/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitionerBenchmark.java
 
b/flink/v1.16/flink/src/jmh/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitionerBenchmark.java
new file mode 100644
index 0000000000..c391716575
--- /dev/null
+++ 
b/flink/v1.16/flink/src/jmh/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitionerBenchmark.java
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import java.util.List;
+import java.util.Map;
+import java.util.NavigableMap;
+import java.util.concurrent.ThreadLocalRandom;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.SortKey;
+import org.apache.iceberg.SortOrder;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+import org.apache.iceberg.types.Types;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.TearDown;
+import org.openjdk.jmh.annotations.Threads;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.infra.Blackhole;
+
+@Fork(1)
+@State(Scope.Benchmark)
+@Warmup(iterations = 3)
+@Measurement(iterations = 5)
+@BenchmarkMode(Mode.SingleShotTime)
+public class MapRangePartitionerBenchmark {
+  private static final String CHARS =
+      "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-.!?";
+  private static final int SAMPLE_SIZE = 100_000;
+  private static final Schema SCHEMA =
+      new Schema(
+          Types.NestedField.required(1, "id", Types.IntegerType.get()),
+          Types.NestedField.required(2, "name2", Types.StringType.get()),
+          Types.NestedField.required(3, "name3", Types.StringType.get()),
+          Types.NestedField.required(4, "name4", Types.StringType.get()),
+          Types.NestedField.required(5, "name5", Types.StringType.get()),
+          Types.NestedField.required(6, "name6", Types.StringType.get()),
+          Types.NestedField.required(7, "name7", Types.StringType.get()),
+          Types.NestedField.required(8, "name8", Types.StringType.get()),
+          Types.NestedField.required(9, "name9", Types.StringType.get()));
+
+  private static final SortOrder SORT_ORDER = SortOrder.builderFor(SCHEMA).asc("id").build();
+  private static final SortKey SORT_KEY = new SortKey(SCHEMA, SORT_ORDER);
+
+  private MapRangePartitioner partitioner;
+  private RowData[] rows;
+
+  /**
+   * Builds a long-tail id distribution, turns it into map statistics for the partitioner under
+   * test, and pre-generates {@link #SAMPLE_SIZE} rows whose ids are drawn in proportion to their
+   * weights, so that sample generation is not part of the measured time.
+   */
+  @Setup
+  public void setupBenchmark() {
+    NavigableMap<Integer, Long> weights = longTailDistribution(100_000, 24, 240, 100, 2.0);
+    Map<SortKey, Long> mapStatistics = Maps.newHashMapWithExpectedSize(weights.size());
+    weights.forEach(
+        (id, weight) -> {
+          SortKey sortKey = SORT_KEY.copy();
+          sortKey.set(0, id);
+          mapStatistics.put(sortKey, weight);
+        });
+
+    MapDataStatistics dataStatistics = new MapDataStatistics(mapStatistics);
+    // last argument is closeFileCostInWeightPercentage
+    this.partitioner = new MapRangePartitioner(SCHEMA, SORT_ORDER, dataStatistics, 2);
+
+    // cumulative weights: weightsCDF[i] = sum of the weights of keys[0..i]
+    List<Integer> keys = Lists.newArrayList(weights.keySet().iterator());
+    long[] weightsCDF = new long[keys.size()];
+    long totalWeight = 0;
+    for (int i = 0; i < keys.size(); ++i) {
+      totalWeight += weights.get(keys.get(i));
+      weightsCDF[i] = totalWeight;
+    }
+
+    // pre-calculate the samples for benchmark run
+    this.rows = new GenericRowData[SAMPLE_SIZE];
+    for (int i = 0; i < SAMPLE_SIZE; ++i) {
+      long weight = ThreadLocalRandom.current().nextLong(totalWeight);
+      int index = binarySearchIndex(weightsCDF, weight);
+      rows[i] =
+          GenericRowData.of(
+              keys.get(index),
+              randomString("name2-"),
+              randomString("name3-"),
+              randomString("name4-"),
+              randomString("name5-"),
+              randomString("name6-"),
+              randomString("name7-"),
+              randomString("name8-"),
+              randomString("name9-"));
+    }
+  }
+
+  @TearDown
+  public void tearDownBenchmark() {}
+
+  /** Routes every pre-generated row through the partitioner with 128 downstream partitions. */
+  @Benchmark
+  @Threads(1)
+  public void testPartitionerLongTailDistribution(Blackhole blackhole) {
+    for (int i = 0; i < SAMPLE_SIZE; ++i) {
+      blackhole.consume(partitioner.partition(rows[i], 128));
+    }
+  }
+
+  /** Returns {@code prefix} followed by 0 to 199 characters picked at random from CHARS. */
+  private static String randomString(String prefix) {
+    int length = ThreadLocalRandom.current().nextInt(200);
+    // build from chars directly instead of decoding a byte[] with the platform default charset
+    StringBuilder builder = new StringBuilder(prefix.length() + length).append(prefix);
+    for (int i = 0; i < length; i += 1) {
+      builder.append(CHARS.charAt(ThreadLocalRandom.current().nextInt(CHARS.length())));
+    }
+
+    return builder.toString();
+  }
+
+  /**
+   * Finds the smallest index such that {@code weightsCDF[index] > target}.
+   *
+   * <p>Index 0 covers targets in {@code [0, weightsCDF[0])} and index i covers targets in {@code
+   * [weightsCDF[i - 1], weightsCDF[i])}. So a target drawn uniformly from {@code [0, totalWeight)}
+   * selects each index with probability proportional to its weight.
+   *
+   * @param weightsCDF cumulative weights; the last element is the total weight
+   * @param target value in {@code [0, totalWeight)}
+   * @return index of the key that owns {@code target}
+   */
+  private static int binarySearchIndex(long[] weightsCDF, long target) {
+    Preconditions.checkArgument(
+        target >= 0 && target < weightsCDF[weightsCDF.length - 1],
+        "weight is out of range: total weight = %s, search target = %s",
+        weightsCDF[weightsCDF.length - 1],
+        target);
+    int low = 0;
+    int high = weightsCDF.length - 1;
+    // invariant: the answer lies in [low, high], and weightsCDF[high] > target
+    while (low < high) {
+      int mid = (low + high) >>> 1;
+      if (weightsCDF[mid] > target) {
+        high = mid;
+      } else {
+        low = mid + 1;
+      }
+    }
+
+    return low;
+  }
+
+  /**
+   * Generates a long-tail weight distribution. The key is the integer id and the value is its
+   * weight (always at least 1).
+   *
+   * @param startingWeight weight of the first id; halved per id while above longTailBaseWeight
+   * @param longTailStartingIndex number of ids in the decaying head part
+   * @param longTailLength number of ids in the tail part
+   * @param longTailBaseWeight base weight of each tail id
+   * @param weightRandomJitterPercentage jitter applied to head weights, as a percentage
+   */
+  private static NavigableMap<Integer, Long> longTailDistribution(
+      long startingWeight,
+      int longTailStartingIndex,
+      int longTailLength,
+      long longTailBaseWeight,
+      double weightRandomJitterPercentage) {
+
+    NavigableMap<Integer, Long> weights = Maps.newTreeMap();
+
+    // first part just decays the weight by half
+    long currentWeight = startingWeight;
+    for (int index = 0; index < longTailStartingIndex; ++index) {
+      double jitter =
+          ThreadLocalRandom.current().nextDouble(weightRandomJitterPercentage / 100);
+      long weight = (long) (currentWeight * (1.0 + jitter));
+      weight = weight > 0 ? weight : 1;
+      weights.put(index, weight);
+      if (currentWeight > longTailBaseWeight) {
+        currentWeight = currentWeight / 2;
+      }
+    }
+
+    // long tail part
+    // NOTE(review): unlike the head, the tail multiplier is drawn from
+    // [0, weightRandomJitterPercentage) rather than treated as a percentage — confirm intent.
+    for (int index = longTailStartingIndex;
+        index < longTailStartingIndex + longTailLength;
+        ++index) {
+      long longTailWeight =
+          (long)
+              (longTailBaseWeight
+                  * ThreadLocalRandom.current().nextDouble(weightRandomJitterPercentage));
+      longTailWeight = longTailWeight > 0 ? longTailWeight : 1;
+      weights.put(index, longTailWeight);
+    }
+
+    return weights;
+  }
+}
diff --git 
a/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitioner.java
 
b/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitioner.java
new file mode 100644
index 0000000000..fb1a8f03a6
--- /dev/null
+++ 
b/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitioner.java
@@ -0,0 +1,381 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.NavigableMap;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.TimeUnit;
+import org.apache.flink.api.common.functions.Partitioner;
+import org.apache.flink.table.data.RowData;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.SortKey;
+import org.apache.iceberg.SortOrder;
+import org.apache.iceberg.SortOrderComparators;
+import org.apache.iceberg.StructLike;
+import org.apache.iceberg.flink.FlinkSchemaUtil;
+import org.apache.iceberg.flink.RowDataWrapper;
+import 
org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting;
+import org.apache.iceberg.relocated.com.google.common.base.MoreObjects;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+import org.apache.iceberg.util.Pair;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Internal partitioner implementation that supports MapDataStatistics, which 
is typically used for
+ * low-cardinality use cases. While MapDataStatistics can keep accurate 
counters, it can't be used
+ * for high-cardinality use cases. Otherwise, the memory footprint is too high.
+ *
+ * <p>It is a greedy algorithm for bin packing. With close file cost, the 
calculation isn't always
+ * precise when calculating close cost for every file, target weight per 
subtask, padding residual
+ * weight, assigned weight without close cost.
+ *
+ * <p>All actions should be executed in a single Flink mailbox thread. So 
there is no need to make
+ * it thread safe.
+ */
+class MapRangePartitioner implements Partitioner<RowData> {
+  private static final Logger LOG = 
LoggerFactory.getLogger(MapRangePartitioner.class);
+
+  private final RowDataWrapper rowDataWrapper;
+  private final SortKey sortKey;
+  private final Comparator<StructLike> comparator;
+  private final Map<SortKey, Long> mapStatistics;
+  private final double closeFileCostInWeightPercentage;
+
+  // Counter that tracks how many times a new key encountered
+  // where there is no traffic statistics learned about it.
+  private long newSortKeyCounter;
+  private long lastNewSortKeyLogTimeMilli;
+
+  // lazily computed due to the need of numPartitions
+  private Map<SortKey, KeyAssignment> assignment;
+  private NavigableMap<SortKey, Long> sortedStatsWithCloseFileCost;
+
+  MapRangePartitioner(
+      Schema schema,
+      SortOrder sortOrder,
+      MapDataStatistics dataStatistics,
+      double closeFileCostInWeightPercentage) {
+    dataStatistics
+        .statistics()
+        .entrySet()
+        .forEach(
+            entry ->
+                Preconditions.checkArgument(
+                    entry.getValue() > 0,
+                    "Invalid statistics: weight is 0 for key %s",
+                    entry.getKey()));
+
+    this.rowDataWrapper = new RowDataWrapper(FlinkSchemaUtil.convert(schema), 
schema.asStruct());
+    this.sortKey = new SortKey(schema, sortOrder);
+    this.comparator = SortOrderComparators.forSchema(schema, sortOrder);
+    this.mapStatistics = dataStatistics.statistics();
+    this.closeFileCostInWeightPercentage = closeFileCostInWeightPercentage;
+    this.newSortKeyCounter = 0;
+    this.lastNewSortKeyLogTimeMilli = System.currentTimeMillis();
+  }
+
+  @Override
+  public int partition(RowData row, int numPartitions) {
+    // assignment table can only be built lazily when first referenced here,
+    // because number of partitions (downstream subtasks) is needed.
+    // the numPartitions is not available in the constructor.
+    Map<SortKey, KeyAssignment> assignmentMap = assignment(numPartitions);
+    // reuse the sortKey and rowDataWrapper
+    sortKey.wrap(rowDataWrapper.wrap(row));
+    KeyAssignment keyAssignment = assignmentMap.get(sortKey);
+    if (keyAssignment == null) {
+      LOG.trace(
+          "Encountered new sort key: {}. Fall back to round robin as 
statistics not learned yet.",
+          sortKey);
+      // Ideally unknownKeyCounter should be published as a counter metric.
+      // It seems difficult to pass in MetricGroup into the partitioner.
+      // Just log an INFO message every minute.
+      newSortKeyCounter += 1;
+      long now = System.currentTimeMillis();
+      if (now - lastNewSortKeyLogTimeMilli > TimeUnit.MINUTES.toMillis(1)) {
+        LOG.info("Encounter new sort keys in total {} times", 
newSortKeyCounter);
+        lastNewSortKeyLogTimeMilli = now;
+      }
+      return (int) (newSortKeyCounter % numPartitions);
+    }
+
+    return keyAssignment.select();
+  }
+
+  @VisibleForTesting
+  Map<SortKey, KeyAssignment> assignment(int numPartitions) {
+    if (assignment == null) {
+      long totalWeight = mapStatistics.values().stream().mapToLong(l -> 
l).sum();
+      double targetWeightPerSubtask = ((double) totalWeight) / numPartitions;
+      long closeFileCostInWeight =
+          (long) Math.ceil(targetWeightPerSubtask * 
closeFileCostInWeightPercentage / 100);
+
+      this.sortedStatsWithCloseFileCost = Maps.newTreeMap(comparator);
+      mapStatistics.forEach(
+          (k, v) -> {
+            int estimatedSplits = (int) Math.ceil(v / targetWeightPerSubtask);
+            long estimatedCloseFileCost = closeFileCostInWeight * 
estimatedSplits;
+            sortedStatsWithCloseFileCost.put(k, v + estimatedCloseFileCost);
+          });
+
+      long totalWeightWithCloseFileCost =
+          sortedStatsWithCloseFileCost.values().stream().mapToLong(l -> 
l).sum();
+      long targetWeightPerSubtaskWithCloseFileCost =
+          (long) Math.ceil(((double) totalWeightWithCloseFileCost) / 
numPartitions);
+      this.assignment =
+          buildAssignment(
+              numPartitions,
+              sortedStatsWithCloseFileCost,
+              targetWeightPerSubtaskWithCloseFileCost,
+              closeFileCostInWeight);
+    }
+
+    return assignment;
+  }
+
+  @VisibleForTesting
+  Map<SortKey, Long> mapStatistics() {
+    return mapStatistics;
+  }
+
+  /**
+   * @return assignment summary for every subtask. Key is subtaskId. Value 
pair is (weight assigned
+   *     to the subtask, number of keys assigned to the subtask)
+   */
+  Map<Integer, Pair<Long, Integer>> assignmentInfo() {
+    Map<Integer, Pair<Long, Integer>> assignmentInfo = Maps.newTreeMap();
+    assignment.forEach(
+        (key, keyAssignment) -> {
+          for (int i = 0; i < keyAssignment.assignedSubtasks.length; ++i) {
+            int subtaskId = keyAssignment.assignedSubtasks[i];
+            long subtaskWeight = 
keyAssignment.subtaskWeightsExcludingCloseCost[i];
+            Pair<Long, Integer> oldValue = 
assignmentInfo.getOrDefault(subtaskId, Pair.of(0L, 0));
+            assignmentInfo.put(
+                subtaskId, Pair.of(oldValue.first() + subtaskWeight, 
oldValue.second() + 1));
+          }
+        });
+
+    return assignmentInfo;
+  }
+
+  private Map<SortKey, KeyAssignment> buildAssignment(
+      int numPartitions,
+      NavigableMap<SortKey, Long> sortedStatistics,
+      long targetWeightPerSubtask,
+      long closeFileCostInWeight) {
+    Map<SortKey, KeyAssignment> assignmentMap =
+        Maps.newHashMapWithExpectedSize(sortedStatistics.size());
+    Iterator<SortKey> mapKeyIterator = sortedStatistics.keySet().iterator();
+    int subtaskId = 0;
+    SortKey currentKey = null;
+    long keyRemainingWeight = 0L;
+    long subtaskRemainingWeight = targetWeightPerSubtask;
+    List<Integer> assignedSubtasks = Lists.newArrayList();
+    List<Long> subtaskWeights = Lists.newArrayList();
+    while (mapKeyIterator.hasNext() || currentKey != null) {
+      // This should never happen because target weight is calculated using 
ceil function.
+      if (subtaskId >= numPartitions) {
+        LOG.error(
+            "Internal algorithm error: exhausted subtasks with unassigned keys 
left. number of partitions: {}, "
+                + "target weight per subtask: {}, close file cost in weight: 
{}, data statistics: {}",
+            numPartitions,
+            targetWeightPerSubtask,
+            closeFileCostInWeight,
+            sortedStatistics);
+        throw new IllegalStateException(
+            "Internal algorithm error: exhausted subtasks with unassigned keys 
left");
+      }
+
+      if (currentKey == null) {
+        currentKey = mapKeyIterator.next();
+        keyRemainingWeight = sortedStatistics.get(currentKey);
+      }
+
+      assignedSubtasks.add(subtaskId);
+      if (keyRemainingWeight < subtaskRemainingWeight) {
+        // assign the remaining weight of the key to the current subtask
+        subtaskWeights.add(keyRemainingWeight);
+        subtaskRemainingWeight -= keyRemainingWeight;
+        keyRemainingWeight = 0L;
+      } else {
+        // filled up the current subtask
+        long assignedWeight = subtaskRemainingWeight;
+        keyRemainingWeight -= subtaskRemainingWeight;
+
+        // If assigned weight is less than close file cost, pad it up with 
close file cost.
+        // This might cause the subtask assigned weight over the target weight.
+        // But it should be no more than one close file cost. Small skew is 
acceptable.
+        if (assignedWeight <= closeFileCostInWeight) {
+          long paddingWeight = Math.min(keyRemainingWeight, 
closeFileCostInWeight);
+          keyRemainingWeight -= paddingWeight;
+          assignedWeight += paddingWeight;
+        }
+
+        subtaskWeights.add(assignedWeight);
+        // move on to the next subtask
+        subtaskId += 1;
+        subtaskRemainingWeight = targetWeightPerSubtask;
+      }
+
+      Preconditions.checkState(
+          assignedSubtasks.size() == subtaskWeights.size(),
+          "List size mismatch: assigned subtasks = %s, subtask weights = %s",
+          assignedSubtasks,
+          subtaskWeights);
+
+      // If the remaining key weight is smaller than the close file cost, 
simply skip the residual
+      // as it doesn't make sense to assign a weight smaller than close file 
cost to a new subtask.
+      // this might lead to some inaccuracy in weight calculation. E.g., 
assuming the key weight is
+      // 2 and close file cost is 2. key weight with close cost is 4. Let's 
assume the previous
+      // task has a weight of 3 available. So weight of 3 for this key is 
assigned to the task and
+      // the residual weight of 1 is dropped. Then the routing weight for this 
key is 1 (minus the
+      // close file cost), which is inaccurate as the true key weight should 
be 2.
+      // Again, this greedy algorithm is not intended to be perfect. Some 
small inaccuracy is
+      // expected and acceptable. Traffic distribution should still be 
balanced.
+      if (keyRemainingWeight > 0 && keyRemainingWeight <= 
closeFileCostInWeight) {
+        keyRemainingWeight = 0;
+      }
+
+      if (keyRemainingWeight == 0) {
+        // finishing up the assignment for the current key
+        KeyAssignment keyAssignment =
+            new KeyAssignment(assignedSubtasks, subtaskWeights, 
closeFileCostInWeight);
+        assignmentMap.put(currentKey, keyAssignment);
+        assignedSubtasks.clear();
+        subtaskWeights.clear();
+        currentKey = null;
+      }
+    }
+
+    return assignmentMap;
+  }
+
+  /** Subtask assignment for a key */
+  @VisibleForTesting
+  static class KeyAssignment {
+    private final int[] assignedSubtasks;
+    private final long[] subtaskWeightsExcludingCloseCost;
+    private final long keyWeight;
+    private final long[] cumulativeWeights;
+
+    /**
+     * @param assignedSubtasks assigned subtasks for this key. It could be a 
single subtask. It
+     *     could also be multiple subtasks if the key has heavy weight that 
should be handled by
+     *     multiple subtasks.
+     * @param subtaskWeightsWithCloseFileCost assigned weight for each 
subtask. E.g., if the
+     *     keyWeight is 27 and the key is assigned to 3 subtasks, 
subtaskWeights could contain
+     *     values as [10, 10, 7] for target weight of 10 per subtask.
+     */
+    KeyAssignment(
+        List<Integer> assignedSubtasks,
+        List<Long> subtaskWeightsWithCloseFileCost,
+        long closeFileCostInWeight) {
+      Preconditions.checkArgument(
+          assignedSubtasks != null && !assignedSubtasks.isEmpty(),
+          "Invalid assigned subtasks: null or empty");
+      Preconditions.checkArgument(
+          subtaskWeightsWithCloseFileCost != null && 
!subtaskWeightsWithCloseFileCost.isEmpty(),
+          "Invalid assigned subtasks weights: null or empty");
+      Preconditions.checkArgument(
+          assignedSubtasks.size() == subtaskWeightsWithCloseFileCost.size(),
+          "Invalid assignment: size mismatch (tasks length = %s, weights 
length = %s)",
+          assignedSubtasks.size(),
+          subtaskWeightsWithCloseFileCost.size());
+      subtaskWeightsWithCloseFileCost.forEach(
+          weight ->
+              Preconditions.checkArgument(
+                  weight > closeFileCostInWeight,
+                  "Invalid weight: should be larger than close file cost: 
weight = %s, close file cost = %s",
+                  weight,
+                  closeFileCostInWeight));
+
+      this.assignedSubtasks = assignedSubtasks.stream().mapToInt(i -> 
i).toArray();
+      // Exclude the close file cost for key routing
+      this.subtaskWeightsExcludingCloseCost =
+          subtaskWeightsWithCloseFileCost.stream()
+              .mapToLong(weightWithCloseFileCost -> weightWithCloseFileCost - 
closeFileCostInWeight)
+              .toArray();
+      this.keyWeight = Arrays.stream(subtaskWeightsExcludingCloseCost).sum();
+      this.cumulativeWeights = new 
long[subtaskWeightsExcludingCloseCost.length];
+      long cumulativeWeight = 0;
+      for (int i = 0; i < subtaskWeightsExcludingCloseCost.length; ++i) {
+        cumulativeWeight += subtaskWeightsExcludingCloseCost[i];
+        cumulativeWeights[i] = cumulativeWeight;
+      }
+    }
+
+    /** @return subtask id */
+    int select() {
+      if (assignedSubtasks.length == 1) {
+        // only choice. no need to run random number generator.
+        return assignedSubtasks[0];
+      } else {
+        long randomNumber = ThreadLocalRandom.current().nextLong(keyWeight);
+        int index = Arrays.binarySearch(cumulativeWeights, randomNumber);
+        // choose the subtask where randomNumber < cumulativeWeights[pos].
+        // this works regardless whether index is negative or not.
+        int position = Math.abs(index + 1);
+        Preconditions.checkState(
+            position < assignedSubtasks.length,
+            "Invalid selected position: out of range. key weight = %s, random 
number = %s, cumulative weights array = %s",
+            keyWeight,
+            randomNumber,
+            cumulativeWeights);
+        return assignedSubtasks[position];
+      }
+    }
+
+    @Override
+    public int hashCode() {
+      return 31 * Arrays.hashCode(assignedSubtasks)
+          + Arrays.hashCode(subtaskWeightsExcludingCloseCost);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+
+      KeyAssignment that = (KeyAssignment) o;
+      return Arrays.equals(assignedSubtasks, that.assignedSubtasks)
+          && Arrays.equals(subtaskWeightsExcludingCloseCost, 
that.subtaskWeightsExcludingCloseCost);
+    }
+
+    @Override
+    public String toString() {
+      return MoreObjects.toStringHelper(this)
+          .add("assignedSubtasks", assignedSubtasks)
+          .add("subtaskWeightsExcludingCloseCost", 
subtaskWeightsExcludingCloseCost)
+          .toString();
+    }
+  }
+}
diff --git 
a/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestMapRangePartitioner.java
 
b/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestMapRangePartitioner.java
new file mode 100644
index 0000000000..92eb71acc8
--- /dev/null
+++ 
b/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestMapRangePartitioner.java
@@ -0,0 +1,448 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicLong;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.StringData;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.iceberg.SortKey;
+import org.apache.iceberg.SortOrder;
+import org.apache.iceberg.flink.FlinkSchemaUtil;
+import org.apache.iceberg.flink.RowDataWrapper;
+import org.apache.iceberg.flink.TestFixtures;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+import org.apache.iceberg.relocated.com.google.common.collect.Sets;
+import org.apache.iceberg.util.Pair;
+import org.assertj.core.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+/**
+ * Tests for {@link MapRangePartitioner}: verifies the computed key-to-subtask assignment, the
+ * per-subtask assignment summary, and the distribution observed when rows are partitioned.
+ */
+public class TestMapRangePartitioner {
+  private static final SortOrder SORT_ORDER =
+      SortOrder.builderFor(TestFixtures.SCHEMA).asc("data").build();
+
+  private static final SortKey SORT_KEY = new SortKey(TestFixtures.SCHEMA, SORT_ORDER);
+  private static final RowType ROW_TYPE = FlinkSchemaUtil.convert(TestFixtures.SCHEMA);
+  private static final SortKey[] SORT_KEYS = initSortKeys();
+
+  /** Builds sort keys for the "data" values k0..k9. */
+  private static SortKey[] initSortKeys() {
+    SortKey[] sortKeys = new SortKey[10];
+    for (int i = 0; i < 10; ++i) {
+      RowData rowData =
+          GenericRowData.of(StringData.fromString("k" + i), i, StringData.fromString("2023-06-20"));
+      RowDataWrapper keyWrapper = new RowDataWrapper(ROW_TYPE, TestFixtures.SCHEMA.asStruct());
+      keyWrapper.wrap(rowData);
+      SortKey sortKey = SORT_KEY.copy();
+      sortKey.wrap(keyWrapper);
+      sortKeys[i] = sortKey;
+    }
+    return sortKeys;
+  }
+
+  // Total weight is 800
+  private final MapDataStatistics mapDataStatistics =
+      new MapDataStatistics(
+          ImmutableMap.of(
+              SORT_KEYS[0],
+              350L,
+              SORT_KEYS[1],
+              230L,
+              SORT_KEYS[2],
+              120L,
+              SORT_KEYS[3],
+              40L,
+              SORT_KEYS[4],
+              10L,
+              SORT_KEYS[5],
+              10L,
+              SORT_KEYS[6],
+              10L,
+              SORT_KEYS[7],
+              10L,
+              SORT_KEYS[8],
+              10L,
+              SORT_KEYS[9],
+              10L));
+
+  @Test
+  public void testEvenlyDividableNoClosingFileCost() {
+    MapRangePartitioner partitioner =
+        new MapRangePartitioner(TestFixtures.SCHEMA, SORT_ORDER, mapDataStatistics, 0.0);
+    int numPartitions = 8;
+
+    // each task should get targeted weight of 100 (=800/8)
+    Map<SortKey, MapRangePartitioner.KeyAssignment> expectedAssignment =
+        ImmutableMap.of(
+            SORT_KEYS[0],
+            new MapRangePartitioner.KeyAssignment(
+                ImmutableList.of(0, 1, 2, 3), ImmutableList.of(100L, 100L, 100L, 50L), 0L),
+            SORT_KEYS[1],
+            new MapRangePartitioner.KeyAssignment(
+                ImmutableList.of(3, 4, 5), ImmutableList.of(50L, 100L, 80L), 0L),
+            SORT_KEYS[2],
+            new MapRangePartitioner.KeyAssignment(
+                ImmutableList.of(5, 6), ImmutableList.of(20L, 100L), 0L),
+            SORT_KEYS[3],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(40L), 0L),
+            SORT_KEYS[4],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L),
+            SORT_KEYS[5],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L),
+            SORT_KEYS[6],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L),
+            SORT_KEYS[7],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L),
+            SORT_KEYS[8],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L),
+            SORT_KEYS[9],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L));
+    Map<SortKey, MapRangePartitioner.KeyAssignment> actualAssignment =
+        partitioner.assignment(numPartitions);
+    Assertions.assertThat(actualAssignment).isEqualTo(expectedAssignment);
+
+    // key: subtask id
+    // value pair: first is the assigned weight, second is the number of assigned keys
+    Map<Integer, Pair<Long, Integer>> expectedAssignmentInfo =
+        ImmutableMap.of(
+            0,
+            Pair.of(100L, 1),
+            1,
+            Pair.of(100L, 1),
+            2,
+            Pair.of(100L, 1),
+            3,
+            Pair.of(100L, 2),
+            4,
+            Pair.of(100L, 1),
+            5,
+            Pair.of(100L, 2),
+            6,
+            Pair.of(100L, 1),
+            7,
+            Pair.of(100L, 7));
+    Map<Integer, Pair<Long, Integer>> actualAssignmentInfo = partitioner.assignmentInfo();
+    Assertions.assertThat(actualAssignmentInfo).isEqualTo(expectedAssignmentInfo);
+
+    Map<Integer, Pair<AtomicLong, Set<RowData>>> partitionResults =
+        runPartitioner(partitioner, numPartitions);
+    validatePartitionResults(expectedAssignmentInfo, partitionResults, 5.0);
+  }
+
+  @Test
+  public void testEvenlyDividableWithClosingFileCost() {
+    MapRangePartitioner partitioner =
+        new MapRangePartitioner(TestFixtures.SCHEMA, SORT_ORDER, mapDataStatistics, 5.0);
+    int numPartitions = 8;
+
+    // target subtask weight is 100 before close file cost factored in.
+    // close file cost is 5 = 5% * 100.
+    // key weights before and after close file cost factored in
+    // before:     350, 230, 120, 40, 10, 10, 10, 10, 10, 10
+    // close-cost:  20,  15,  10,  5,  5,  5,  5,  5,  5,  5
+    // after:      370, 245, 130, 45, 15, 15, 15, 15, 15, 15
+    // target subtask weight with close cost per subtask is 110 (880/8)
+    Map<SortKey, MapRangePartitioner.KeyAssignment> expectedAssignment =
+        ImmutableMap.of(
+            SORT_KEYS[0],
+            new MapRangePartitioner.KeyAssignment(
+                ImmutableList.of(0, 1, 2, 3), ImmutableList.of(110L, 110L, 110L, 40L), 5L),
+            SORT_KEYS[1],
+            new MapRangePartitioner.KeyAssignment(
+                ImmutableList.of(3, 4, 5), ImmutableList.of(70L, 110L, 65L), 5L),
+            SORT_KEYS[2],
+            new MapRangePartitioner.KeyAssignment(
+                ImmutableList.of(5, 6), ImmutableList.of(45L, 85L), 5L),
+            SORT_KEYS[3],
+            new MapRangePartitioner.KeyAssignment(
+                ImmutableList.of(6, 7), ImmutableList.of(25L, 20L), 5L),
+            SORT_KEYS[4],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L),
+            SORT_KEYS[5],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L),
+            SORT_KEYS[6],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L),
+            SORT_KEYS[7],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L),
+            SORT_KEYS[8],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L),
+            SORT_KEYS[9],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L));
+    Map<SortKey, MapRangePartitioner.KeyAssignment> actualAssignment =
+        partitioner.assignment(numPartitions);
+    Assertions.assertThat(actualAssignment).isEqualTo(expectedAssignment);
+
+    // key: subtask id
+    // value pair: first is the assigned weight (excluding close file cost) for the subtask,
+    // second is the number of keys assigned to the subtask
+    Map<Integer, Pair<Long, Integer>> expectedAssignmentInfo =
+        ImmutableMap.of(
+            0,
+            Pair.of(105L, 1),
+            1,
+            Pair.of(105L, 1),
+            2,
+            Pair.of(105L, 1),
+            3,
+            Pair.of(100L, 2),
+            4,
+            Pair.of(105L, 1),
+            5,
+            Pair.of(100L, 2),
+            6,
+            Pair.of(100L, 2),
+            7,
+            Pair.of(75L, 7));
+    Map<Integer, Pair<Long, Integer>> actualAssignmentInfo = partitioner.assignmentInfo();
+    Assertions.assertThat(actualAssignmentInfo).isEqualTo(expectedAssignmentInfo);
+
+    Map<Integer, Pair<AtomicLong, Set<RowData>>> partitionResults =
+        runPartitioner(partitioner, numPartitions);
+    validatePartitionResults(expectedAssignmentInfo, partitionResults, 5.0);
+  }
+
+  @Test
+  public void testNonDividableNoClosingFileCost() {
+    MapRangePartitioner partitioner =
+        new MapRangePartitioner(TestFixtures.SCHEMA, SORT_ORDER, mapDataStatistics, 0.0);
+    int numPartitions = 9;
+
+    // before:     350, 230, 120, 40, 10, 10, 10, 10, 10, 10
+    // each task should get targeted weight of 89 = ceiling(800/9)
+    Map<SortKey, MapRangePartitioner.KeyAssignment> expectedAssignment =
+        ImmutableMap.of(
+            SORT_KEYS[0],
+            new MapRangePartitioner.KeyAssignment(
+                ImmutableList.of(0, 1, 2, 3), ImmutableList.of(89L, 89L, 89L, 83L), 0L),
+            SORT_KEYS[1],
+            new MapRangePartitioner.KeyAssignment(
+                ImmutableList.of(3, 4, 5, 6), ImmutableList.of(6L, 89L, 89L, 46L), 0L),
+            SORT_KEYS[2],
+            new MapRangePartitioner.KeyAssignment(
+                ImmutableList.of(6, 7), ImmutableList.of(43L, 77L), 0L),
+            SORT_KEYS[3],
+            new MapRangePartitioner.KeyAssignment(
+                ImmutableList.of(7, 8), ImmutableList.of(12L, 28L), 0L),
+            SORT_KEYS[4],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L),
+            SORT_KEYS[5],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L),
+            SORT_KEYS[6],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L),
+            SORT_KEYS[7],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L),
+            SORT_KEYS[8],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L),
+            SORT_KEYS[9],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L));
+    Map<SortKey, MapRangePartitioner.KeyAssignment> actualAssignment =
+        partitioner.assignment(numPartitions);
+    Assertions.assertThat(actualAssignment).isEqualTo(expectedAssignment);
+
+    // key: subtask id
+    // value pair: first is the assigned weight, second is the number of assigned keys
+    Map<Integer, Pair<Long, Integer>> expectedAssignmentInfo =
+        ImmutableMap.of(
+            0,
+            Pair.of(89L, 1),
+            1,
+            Pair.of(89L, 1),
+            2,
+            Pair.of(89L, 1),
+            3,
+            Pair.of(89L, 2),
+            4,
+            Pair.of(89L, 1),
+            5,
+            Pair.of(89L, 1),
+            6,
+            Pair.of(89L, 2),
+            7,
+            Pair.of(89L, 2),
+            8,
+            Pair.of(88L, 7));
+    Map<Integer, Pair<Long, Integer>> actualAssignmentInfo = partitioner.assignmentInfo();
+    Assertions.assertThat(actualAssignmentInfo).isEqualTo(expectedAssignmentInfo);
+
+    Map<Integer, Pair<AtomicLong, Set<RowData>>> partitionResults =
+        runPartitioner(partitioner, numPartitions);
+    validatePartitionResults(expectedAssignmentInfo, partitionResults, 5.0);
+  }
+
+  @Test
+  public void testNonDividableWithClosingFileCost() {
+    MapRangePartitioner partitioner =
+        new MapRangePartitioner(TestFixtures.SCHEMA, SORT_ORDER, mapDataStatistics, 5.0);
+    int numPartitions = 9;
+
+    // target subtask weight is 89 before close file cost factored in.
+    // close file cost is 5 (= 5% * 89) per file.
+    // key weights before and after close file cost factored in
+    // before:     350, 230, 120, 40, 10, 10, 10, 10, 10, 10
+    // close-cost:  20,  15,  10,  5,  5,  5,  5,  5,  5,  5
+    // after:      370, 245, 130, 45, 15, 15, 15, 15, 15, 15
+    // target subtask weight per subtask is 98 ceiling(880/9)
+    Map<SortKey, MapRangePartitioner.KeyAssignment> expectedAssignment =
+        ImmutableMap.of(
+            SORT_KEYS[0],
+            new MapRangePartitioner.KeyAssignment(
+                ImmutableList.of(0, 1, 2, 3), ImmutableList.of(98L, 98L, 98L, 76L), 5L),
+            SORT_KEYS[1],
+            new MapRangePartitioner.KeyAssignment(
+                ImmutableList.of(3, 4, 5, 6), ImmutableList.of(22L, 98L, 98L, 27L), 5L),
+            SORT_KEYS[2],
+            new MapRangePartitioner.KeyAssignment(
+                ImmutableList.of(6, 7), ImmutableList.of(71L, 59L), 5L),
+            SORT_KEYS[3],
+            new MapRangePartitioner.KeyAssignment(
+                ImmutableList.of(7, 8), ImmutableList.of(39L, 6L), 5L),
+            SORT_KEYS[4],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L),
+            SORT_KEYS[5],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L),
+            SORT_KEYS[6],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L),
+            SORT_KEYS[7],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L),
+            SORT_KEYS[8],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L),
+            SORT_KEYS[9],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L));
+    Map<SortKey, MapRangePartitioner.KeyAssignment> actualAssignment =
+        partitioner.assignment(numPartitions);
+    Assertions.assertThat(actualAssignment).isEqualTo(expectedAssignment);
+
+    // key: subtask id
+    // value pair: first is the assigned weight for the subtask, second is the number of keys
+    // assigned to the subtask
+    Map<Integer, Pair<Long, Integer>> expectedAssignmentInfo =
+        ImmutableMap.of(
+            0,
+            Pair.of(93L, 1),
+            1,
+            Pair.of(93L, 1),
+            2,
+            Pair.of(93L, 1),
+            3,
+            Pair.of(88L, 2),
+            4,
+            Pair.of(93L, 1),
+            5,
+            Pair.of(93L, 1),
+            6,
+            Pair.of(88L, 2),
+            7,
+            Pair.of(88L, 2),
+            8,
+            Pair.of(61L, 7));
+    Map<Integer, Pair<Long, Integer>> actualAssignmentInfo = partitioner.assignmentInfo();
+    Assertions.assertThat(actualAssignmentInfo).isEqualTo(expectedAssignmentInfo);
+
+    Map<Integer, Pair<AtomicLong, Set<RowData>>> partitionResults =
+        runPartitioner(partitioner, numPartitions);
+    // drift threshold is high for non-dividable scenario with close cost
+    validatePartitionResults(expectedAssignmentInfo, partitionResults, 10.0);
+  }
+
+  /**
+   * Partitions {@code 100 * weight} rows for every key in the partitioner's map statistics.
+   *
+   * @return map keyed by subtask id; the pair holds the number of rows routed to that subtask and
+   *     the set of distinct rows (one per sort key, since all rows of a key are equal) routed to it
+   */
+  private static Map<Integer, Pair<AtomicLong, Set<RowData>>> runPartitioner(
+      MapRangePartitioner partitioner, int numPartitions) {
+    Map<Integer, Pair<AtomicLong, Set<RowData>>> partitionResults = Maps.newHashMap();
+    partitioner
+        .mapStatistics()
+        .forEach(
+            (sortKey, weight) -> {
+              String key = sortKey.get(0, String.class);
+              // run 100x times of the weight
+              long iterations = weight * 100;
+              for (long i = 0; i < iterations; ++i) {
+                RowData rowData =
+                    GenericRowData.of(
+                        StringData.fromString(key), 1, StringData.fromString("2023-06-20"));
+                int subtaskId = partitioner.partition(rowData, numPartitions);
+                // computeIfAbsent returns the existing or newly created value, no second lookup
+                Pair<AtomicLong, Set<RowData>> pair =
+                    partitionResults.computeIfAbsent(
+                        subtaskId, k -> Pair.of(new AtomicLong(0), Sets.newHashSet()));
+                pair.first().incrementAndGet();
+                pair.second().add(rowData);
+              }
+            });
+    return partitionResults;
+  }
+
+  /**
+   * Checks the observed partitioning against the expected per-subtask assignment.
+   *
+   * <p>The set of subtasks and the number of distinct keys per subtask must match exactly. Each
+   * subtask's normalized weight must fall within {@code maxDriftPercentage} percent of the expected
+   * normalized weight.
+   *
+   * @param expectedAssignmentInfo expected weight (excluding closing cost) and key count per subtask
+   * @param partitionResults observed row count and distinct rows per subtask
+   * @param maxDriftPercentage allowed relative deviation of each normalized weight, in percent
+   */
+  private static void validatePartitionResults(
+      Map<Integer, Pair<Long, Integer>> expectedAssignmentInfo,
+      Map<Integer, Pair<AtomicLong, Set<RowData>>> partitionResults,
+      double maxDriftPercentage) {
+
+    Assertions.assertThat(partitionResults.keySet())
+        .as("the set of subtasks receiving records should match the expected assignment")
+        .isEqualTo(expectedAssignmentInfo.keySet());
+
+    long expectedTotalWeight =
+        expectedAssignmentInfo.values().stream().mapToLong(Pair::first).sum();
+    long actualTotalWeight =
+        partitionResults.values().stream().mapToLong(pair -> pair.first().longValue()).sum();
+
+    // Build both lists in the expected map's subtask order and look up the actual result by
+    // subtask id. partitionResults is a HashMap, so its iteration order must not be relied upon.
+    List<Integer> expectedAssignedKeyCounts =
+        Lists.newArrayListWithExpectedSize(expectedAssignmentInfo.size());
+    List<Integer> actualAssignedKeyCounts =
+        Lists.newArrayListWithExpectedSize(expectedAssignmentInfo.size());
+    for (Map.Entry<Integer, Pair<Long, Integer>> entry : expectedAssignmentInfo.entrySet()) {
+      expectedAssignedKeyCounts.add(entry.getValue().second());
+      actualAssignedKeyCounts.add(partitionResults.get(entry.getKey()).second().size());
+    }
+
+    // number of assigned keys should match exactly
+    Assertions.assertThat(actualAssignedKeyCounts)
+        .as("the number of assigned keys should match for every subtask")
+        .isEqualTo(expectedAssignedKeyCounts);
+
+    // weight for every subtask shouldn't differ for more than some threshold relative to the
+    // expected weight
+    for (Map.Entry<Integer, Pair<Long, Integer>> entry : expectedAssignmentInfo.entrySet()) {
+      int subtaskId = entry.getKey();
+      double expectedWeight = entry.getValue().first().doubleValue() / expectedTotalWeight;
+      double actualWeight =
+          partitionResults.get(subtaskId).first().doubleValue() / actualTotalWeight;
+      double min = expectedWeight * (1 - maxDriftPercentage / 100);
+      double max = expectedWeight * (1 + maxDriftPercentage / 100);
+      Assertions.assertThat(actualWeight)
+          .as(
+              "Subtask %d weight should be within %.1f percent of the expected range %s",
+              subtaskId, maxDriftPercentage, expectedWeight)
+          .isBetween(min, max);
+    }
+  }
+}
diff --git a/flink/v1.18/flink/src/jmh/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitionerBenchmark.java b/flink/v1.18/flink/src/jmh/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitionerBenchmark.java
new file mode 100644
index 0000000000..c391716575
--- /dev/null
+++ b/flink/v1.18/flink/src/jmh/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitionerBenchmark.java
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import java.util.List;
+import java.util.Map;
+import java.util.NavigableMap;
+import java.util.concurrent.ThreadLocalRandom;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.SortKey;
+import org.apache.iceberg.SortOrder;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+import org.apache.iceberg.types.Types;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.TearDown;
+import org.openjdk.jmh.annotations.Threads;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.infra.Blackhole;
+
+@Fork(1)
+@State(Scope.Benchmark)
+@Warmup(iterations = 3)
+@Measurement(iterations = 5)
+@BenchmarkMode(Mode.SingleShotTime)
+public class MapRangePartitionerBenchmark {
+  private static final String CHARS =
+      "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-.!?";
+  private static final int SAMPLE_SIZE = 100_000;
+  private static final Schema SCHEMA =
+      new Schema(
+          Types.NestedField.required(1, "id", Types.IntegerType.get()),
+          Types.NestedField.required(2, "name2", Types.StringType.get()),
+          Types.NestedField.required(3, "name3", Types.StringType.get()),
+          Types.NestedField.required(4, "name4", Types.StringType.get()),
+          Types.NestedField.required(5, "name5", Types.StringType.get()),
+          Types.NestedField.required(6, "name6", Types.StringType.get()),
+          Types.NestedField.required(7, "name7", Types.StringType.get()),
+          Types.NestedField.required(8, "name8", Types.StringType.get()),
+          Types.NestedField.required(9, "name9", Types.StringType.get()));
+
+  private static final SortOrder SORT_ORDER = SortOrder.builderFor(SCHEMA).asc("id").build();
+  private static final SortKey SORT_KEY = new SortKey(SCHEMA, SORT_ORDER);
+
+  private MapRangePartitioner partitioner;
+  private RowData[] rows;
+
+  /**
+   * Sets up the benchmark state. It builds a long-tail key weight distribution and a partitioner
+   * from it, then pre-generates {@link #SAMPLE_SIZE} rows whose ids are sampled in proportion to
+   * the key weights.
+   */
+  @Setup
+  public void setupBenchmark() {
+    NavigableMap<Integer, Long> weights = longTailDistribution(100_000, 24, 240, 100, 2.0);
+    Map<SortKey, Long> mapStatistics = Maps.newHashMapWithExpectedSize(weights.size());
+    weights.forEach(
+        (id, weight) -> {
+          SortKey sortKey = SORT_KEY.copy();
+          sortKey.set(0, id);
+          mapStatistics.put(sortKey, weight);
+        });
+
+    MapDataStatistics dataStatistics = new MapDataStatistics(mapStatistics);
+    this.partitioner = new MapRangePartitioner(SCHEMA, SORT_ORDER, dataStatistics, 2);
+
+    List<Integer> keys = Lists.newArrayList(weights.keySet().iterator());
+    long[] weightsCDF = new long[keys.size()];
+    long totalWeight = 0;
+    for (int i = 0; i < keys.size(); ++i) {
+      totalWeight += weights.get(keys.get(i));
+      weightsCDF[i] = totalWeight;
+    }
+
+    // pre-calculate the samples for benchmark run
+    this.rows = new GenericRowData[SAMPLE_SIZE];
+    for (int i = 0; i < SAMPLE_SIZE; ++i) {
+      long weight = ThreadLocalRandom.current().nextLong(totalWeight);
+      int index = binarySearchIndex(weightsCDF, weight);
+      rows[i] =
+          GenericRowData.of(
+              keys.get(index),
+              randomString("name2-"),
+              randomString("name3-"),
+              randomString("name4-"),
+              randomString("name5-"),
+              randomString("name6-"),
+              randomString("name7-"),
+              randomString("name8-"),
+              randomString("name9-"));
+    }
+  }
+
+  @TearDown
+  public void tearDownBenchmark() {}
+
+  @Benchmark
+  @Threads(1)
+  public void testPartitionerLongTailDistribution(Blackhole blackhole) {
+    for (int i = 0; i < SAMPLE_SIZE; ++i) {
+      blackhole.consume(partitioner.partition(rows[i], 128));
+    }
+  }
+
+  /** Returns {@code prefix} followed by 0 to 199 random characters drawn from {@link #CHARS}. */
+  private static String randomString(String prefix) {
+    int length = ThreadLocalRandom.current().nextInt(200);
+    byte[] buffer = new byte[length];
+
+    for (int i = 0; i < length; i += 1) {
+      buffer[i] = (byte) CHARS.charAt(ThreadLocalRandom.current().nextInt(CHARS.length()));
+    }
+
+    return prefix + new String(buffer);
+  }
+
+  /**
+   * Finds the smallest index such that {@code weightsCDF[index] > target}. That is the key whose
+   * cumulative weight range {@code [weightsCDF[index - 1], weightsCDF[index])} contains {@code
+   * target}. Drawing {@code target} uniformly from {@code [0, totalWeight)} therefore selects each
+   * key with probability proportional to its weight.
+   *
+   * @param weightsCDF non-decreasing cumulative weights; the last element is the total weight
+   * @param target a value in {@code [0, weightsCDF[weightsCDF.length - 1])}
+   * @return index of the sampled key
+   */
+  private static int binarySearchIndex(long[] weightsCDF, long target) {
+    Preconditions.checkArgument(
+        target < weightsCDF[weightsCDF.length - 1],
+        "weight is out of range: total weight = %s, search target = %s",
+        weightsCDF[weightsCDF.length - 1],
+        target);
+    int low = 0;
+    int high = weightsCDF.length - 1;
+    // invariant: weightsCDF[high] > target, and no index below low satisfies it
+    while (low < high) {
+      int mid = (low + high) >>> 1;
+      if (weightsCDF[mid] > target) {
+        high = mid;
+      } else {
+        low = mid + 1;
+      }
+    }
+    return low;
+  }
+
+  /**
+   * Builds a weight distribution keyed by integer id; the value is the key's weight.
+   *
+   * <p>The first {@code longTailStartingIndex} keys start at {@code startingWeight} and halve while
+   * above {@code longTailBaseWeight}. Each of these is jittered upward by up to {@code
+   * weightRandomJitterPercentage} percent. The next {@code longTailLength} keys form the long tail.
+   * Every weight is at least 1.
+   */
+  private static NavigableMap<Integer, Long> longTailDistribution(
+      long startingWeight,
+      int longTailStartingIndex,
+      int longTailLength,
+      long longTailBaseWeight,
+      double weightRandomJitterPercentage) {
+
+    NavigableMap<Integer, Long> weights = Maps.newTreeMap();
+
+    // first part just decays the weight by half
+    long currentWeight = startingWeight;
+    for (int index = 0; index < longTailStartingIndex; ++index) {
+      double jitter = ThreadLocalRandom.current().nextDouble(weightRandomJitterPercentage / 100);
+      long weight = (long) (currentWeight * (1.0 + jitter));
+      weight = weight > 0 ? weight : 1;
+      weights.put(index, weight);
+      if (currentWeight > longTailBaseWeight) {
+        currentWeight = currentWeight / 2;
+      }
+    }
+
+    // long tail part
+    // NOTE(review): unlike the head, the tail multiplies the base weight by a random factor in
+    // [0, weightRandomJitterPercentage) instead of applying a percentage jitter; confirm intent.
+    for (int index = longTailStartingIndex;
+        index < longTailStartingIndex + longTailLength;
+        ++index) {
+      long longTailWeight =
+          (long)
+              (longTailBaseWeight
+                  * ThreadLocalRandom.current().nextDouble(weightRandomJitterPercentage));
+      longTailWeight = longTailWeight > 0 ? longTailWeight : 1;
+      weights.put(index, longTailWeight);
+    }
+
+    return weights;
+  }
+}
diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitioner.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitioner.java
new file mode 100644
index 0000000000..fb1a8f03a6
--- /dev/null
+++ b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitioner.java
@@ -0,0 +1,381 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.NavigableMap;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.TimeUnit;
+import org.apache.flink.api.common.functions.Partitioner;
+import org.apache.flink.table.data.RowData;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.SortKey;
+import org.apache.iceberg.SortOrder;
+import org.apache.iceberg.SortOrderComparators;
+import org.apache.iceberg.StructLike;
+import org.apache.iceberg.flink.FlinkSchemaUtil;
+import org.apache.iceberg.flink.RowDataWrapper;
+import 
org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting;
+import org.apache.iceberg.relocated.com.google.common.base.MoreObjects;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+import org.apache.iceberg.util.Pair;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Internal partitioner implementation that supports MapDataStatistics, which 
is typically used for
+ * low-cardinality use cases. While MapDataStatistics can keep accurate 
counters, it can't be used
+ * for high-cardinality use cases. Otherwise, the memory footprint is too high.
+ *
+ * <p>It is a greedy algorithm for bin packing. With close file cost, the 
calculation isn't always
+ * precise when calculating close cost for every file, target weight per 
subtask, padding residual
+ * weight, assigned weight without close cost.
+ *
+ * <p>All actions should be executed in a single Flink mailbox thread. So 
there is no need to make
+ * it thread safe.
+ */
+class MapRangePartitioner implements Partitioner<RowData> {
+  private static final Logger LOG = 
LoggerFactory.getLogger(MapRangePartitioner.class);
+
+  private final RowDataWrapper rowDataWrapper;
+  private final SortKey sortKey;
+  private final Comparator<StructLike> comparator;
+  private final Map<SortKey, Long> mapStatistics;
+  private final double closeFileCostInWeightPercentage;
+
+  // Counter that tracks how many times a new key encountered
+  // where there is no traffic statistics learned about it.
+  private long newSortKeyCounter;
+  private long lastNewSortKeyLogTimeMilli;
+
+  // lazily computed due to the need of numPartitions
+  private Map<SortKey, KeyAssignment> assignment;
+  private NavigableMap<SortKey, Long> sortedStatsWithCloseFileCost;
+
+  MapRangePartitioner(
+      Schema schema,
+      SortOrder sortOrder,
+      MapDataStatistics dataStatistics,
+      double closeFileCostInWeightPercentage) {
+    dataStatistics
+        .statistics()
+        .entrySet()
+        .forEach(
+            entry ->
+                Preconditions.checkArgument(
+                    entry.getValue() > 0,
+                    "Invalid statistics: weight is 0 for key %s",
+                    entry.getKey()));
+
+    this.rowDataWrapper = new RowDataWrapper(FlinkSchemaUtil.convert(schema), 
schema.asStruct());
+    this.sortKey = new SortKey(schema, sortOrder);
+    this.comparator = SortOrderComparators.forSchema(schema, sortOrder);
+    this.mapStatistics = dataStatistics.statistics();
+    this.closeFileCostInWeightPercentage = closeFileCostInWeightPercentage;
+    this.newSortKeyCounter = 0;
+    this.lastNewSortKeyLogTimeMilli = System.currentTimeMillis();
+  }
+
+  @Override
+  public int partition(RowData row, int numPartitions) {
+    // assignment table can only be built lazily when first referenced here,
+    // because number of partitions (downstream subtasks) is needed.
+    // the numPartitions is not available in the constructor.
+    Map<SortKey, KeyAssignment> assignmentMap = assignment(numPartitions);
+    // reuse the sortKey and rowDataWrapper
+    sortKey.wrap(rowDataWrapper.wrap(row));
+    KeyAssignment keyAssignment = assignmentMap.get(sortKey);
+    if (keyAssignment == null) {
+      LOG.trace(
+          "Encountered new sort key: {}. Fall back to round robin as 
statistics not learned yet.",
+          sortKey);
+      // Ideally unknownKeyCounter should be published as a counter metric.
+      // It seems difficult to pass in MetricGroup into the partitioner.
+      // Just log an INFO message every minute.
+      newSortKeyCounter += 1;
+      long now = System.currentTimeMillis();
+      if (now - lastNewSortKeyLogTimeMilli > TimeUnit.MINUTES.toMillis(1)) {
+        LOG.info("Encounter new sort keys in total {} times", 
newSortKeyCounter);
+        lastNewSortKeyLogTimeMilli = now;
+      }
+      return (int) (newSortKeyCounter % numPartitions);
+    }
+
+    return keyAssignment.select();
+  }
+
+  @VisibleForTesting
+  Map<SortKey, KeyAssignment> assignment(int numPartitions) {
+    if (assignment == null) {
+      long totalWeight = mapStatistics.values().stream().mapToLong(l -> 
l).sum();
+      double targetWeightPerSubtask = ((double) totalWeight) / numPartitions;
+      long closeFileCostInWeight =
+          (long) Math.ceil(targetWeightPerSubtask * 
closeFileCostInWeightPercentage / 100);
+
+      this.sortedStatsWithCloseFileCost = Maps.newTreeMap(comparator);
+      mapStatistics.forEach(
+          (k, v) -> {
+            int estimatedSplits = (int) Math.ceil(v / targetWeightPerSubtask);
+            long estimatedCloseFileCost = closeFileCostInWeight * 
estimatedSplits;
+            sortedStatsWithCloseFileCost.put(k, v + estimatedCloseFileCost);
+          });
+
+      long totalWeightWithCloseFileCost =
+          sortedStatsWithCloseFileCost.values().stream().mapToLong(l -> 
l).sum();
+      long targetWeightPerSubtaskWithCloseFileCost =
+          (long) Math.ceil(((double) totalWeightWithCloseFileCost) / 
numPartitions);
+      this.assignment =
+          buildAssignment(
+              numPartitions,
+              sortedStatsWithCloseFileCost,
+              targetWeightPerSubtaskWithCloseFileCost,
+              closeFileCostInWeight);
+    }
+
+    return assignment;
+  }
+
+  @VisibleForTesting
+  Map<SortKey, Long> mapStatistics() {
+    return mapStatistics;
+  }
+
+  /**
+   * @return assignment summary for every subtask. Key is subtaskId. Value 
pair is (weight assigned
+   *     to the subtask, number of keys assigned to the subtask)
+   */
+  Map<Integer, Pair<Long, Integer>> assignmentInfo() {
+    Map<Integer, Pair<Long, Integer>> assignmentInfo = Maps.newTreeMap();
+    assignment.forEach(
+        (key, keyAssignment) -> {
+          for (int i = 0; i < keyAssignment.assignedSubtasks.length; ++i) {
+            int subtaskId = keyAssignment.assignedSubtasks[i];
+            long subtaskWeight = 
keyAssignment.subtaskWeightsExcludingCloseCost[i];
+            Pair<Long, Integer> oldValue = 
assignmentInfo.getOrDefault(subtaskId, Pair.of(0L, 0));
+            assignmentInfo.put(
+                subtaskId, Pair.of(oldValue.first() + subtaskWeight, 
oldValue.second() + 1));
+          }
+        });
+
+    return assignmentInfo;
+  }
+
+  private Map<SortKey, KeyAssignment> buildAssignment(
+      int numPartitions,
+      NavigableMap<SortKey, Long> sortedStatistics,
+      long targetWeightPerSubtask,
+      long closeFileCostInWeight) {
+    Map<SortKey, KeyAssignment> assignmentMap =
+        Maps.newHashMapWithExpectedSize(sortedStatistics.size());
+    Iterator<SortKey> mapKeyIterator = sortedStatistics.keySet().iterator();
+    int subtaskId = 0;
+    SortKey currentKey = null;
+    long keyRemainingWeight = 0L;
+    long subtaskRemainingWeight = targetWeightPerSubtask;
+    List<Integer> assignedSubtasks = Lists.newArrayList();
+    List<Long> subtaskWeights = Lists.newArrayList();
+    while (mapKeyIterator.hasNext() || currentKey != null) {
+      // This should never happen because target weight is calculated using 
ceil function.
+      if (subtaskId >= numPartitions) {
+        LOG.error(
+            "Internal algorithm error: exhausted subtasks with unassigned keys 
left. number of partitions: {}, "
+                + "target weight per subtask: {}, close file cost in weight: 
{}, data statistics: {}",
+            numPartitions,
+            targetWeightPerSubtask,
+            closeFileCostInWeight,
+            sortedStatistics);
+        throw new IllegalStateException(
+            "Internal algorithm error: exhausted subtasks with unassigned keys 
left");
+      }
+
+      if (currentKey == null) {
+        currentKey = mapKeyIterator.next();
+        keyRemainingWeight = sortedStatistics.get(currentKey);
+      }
+
+      assignedSubtasks.add(subtaskId);
+      if (keyRemainingWeight < subtaskRemainingWeight) {
+        // assign the remaining weight of the key to the current subtask
+        subtaskWeights.add(keyRemainingWeight);
+        subtaskRemainingWeight -= keyRemainingWeight;
+        keyRemainingWeight = 0L;
+      } else {
+        // filled up the current subtask
+        long assignedWeight = subtaskRemainingWeight;
+        keyRemainingWeight -= subtaskRemainingWeight;
+
+        // If assigned weight is less than close file cost, pad it up with 
close file cost.
+        // This might cause the subtask assigned weight over the target weight.
+        // But it should be no more than one close file cost. Small skew is 
acceptable.
+        if (assignedWeight <= closeFileCostInWeight) {
+          long paddingWeight = Math.min(keyRemainingWeight, 
closeFileCostInWeight);
+          keyRemainingWeight -= paddingWeight;
+          assignedWeight += paddingWeight;
+        }
+
+        subtaskWeights.add(assignedWeight);
+        // move on to the next subtask
+        subtaskId += 1;
+        subtaskRemainingWeight = targetWeightPerSubtask;
+      }
+
+      Preconditions.checkState(
+          assignedSubtasks.size() == subtaskWeights.size(),
+          "List size mismatch: assigned subtasks = %s, subtask weights = %s",
+          assignedSubtasks,
+          subtaskWeights);
+
+      // If the remaining key weight is smaller than the close file cost, 
simply skip the residual
+      // as it doesn't make sense to assign a weight smaller than close file 
cost to a new subtask.
+      // this might lead to some inaccuracy in weight calculation. E.g., 
assuming the key weight is
+      // 2 and close file cost is 2. key weight with close cost is 4. Let's 
assume the previous
+      // task has a weight of 3 available. So weight of 3 for this key is 
assigned to the task and
+      // the residual weight of 1 is dropped. Then the routing weight for this 
key is 1 (minus the
+      // close file cost), which is inaccurate as the true key weight should 
be 2.
+      // Again, this greedy algorithm is not intended to be perfect. Some 
small inaccuracy is
+      // expected and acceptable. Traffic distribution should still be 
balanced.
+      if (keyRemainingWeight > 0 && keyRemainingWeight <= 
closeFileCostInWeight) {
+        keyRemainingWeight = 0;
+      }
+
+      if (keyRemainingWeight == 0) {
+        // finishing up the assignment for the current key
+        KeyAssignment keyAssignment =
+            new KeyAssignment(assignedSubtasks, subtaskWeights, 
closeFileCostInWeight);
+        assignmentMap.put(currentKey, keyAssignment);
+        assignedSubtasks.clear();
+        subtaskWeights.clear();
+        currentKey = null;
+      }
+    }
+
+    return assignmentMap;
+  }
+
+  /** Subtask assignment for a key */
+  @VisibleForTesting
+  static class KeyAssignment {
+    private final int[] assignedSubtasks;
+    private final long[] subtaskWeightsExcludingCloseCost;
+    private final long keyWeight;
+    private final long[] cumulativeWeights;
+
+    /**
+     * @param assignedSubtasks assigned subtasks for this key. It could be a 
single subtask. It
+     *     could also be multiple subtasks if the key has heavy weight that 
should be handled by
+     *     multiple subtasks.
+     * @param subtaskWeightsWithCloseFileCost assigned weight for each 
subtask. E.g., if the
+     *     keyWeight is 27 and the key is assigned to 3 subtasks, 
subtaskWeights could contain
+     *     values as [10, 10, 7] for target weight of 10 per subtask.
+     */
+    KeyAssignment(
+        List<Integer> assignedSubtasks,
+        List<Long> subtaskWeightsWithCloseFileCost,
+        long closeFileCostInWeight) {
+      Preconditions.checkArgument(
+          assignedSubtasks != null && !assignedSubtasks.isEmpty(),
+          "Invalid assigned subtasks: null or empty");
+      Preconditions.checkArgument(
+          subtaskWeightsWithCloseFileCost != null && 
!subtaskWeightsWithCloseFileCost.isEmpty(),
+          "Invalid assigned subtasks weights: null or empty");
+      Preconditions.checkArgument(
+          assignedSubtasks.size() == subtaskWeightsWithCloseFileCost.size(),
+          "Invalid assignment: size mismatch (tasks length = %s, weights 
length = %s)",
+          assignedSubtasks.size(),
+          subtaskWeightsWithCloseFileCost.size());
+      subtaskWeightsWithCloseFileCost.forEach(
+          weight ->
+              Preconditions.checkArgument(
+                  weight > closeFileCostInWeight,
+                  "Invalid weight: should be larger than close file cost: 
weight = %s, close file cost = %s",
+                  weight,
+                  closeFileCostInWeight));
+
+      this.assignedSubtasks = assignedSubtasks.stream().mapToInt(i -> 
i).toArray();
+      // Exclude the close file cost for key routing
+      this.subtaskWeightsExcludingCloseCost =
+          subtaskWeightsWithCloseFileCost.stream()
+              .mapToLong(weightWithCloseFileCost -> weightWithCloseFileCost - 
closeFileCostInWeight)
+              .toArray();
+      this.keyWeight = Arrays.stream(subtaskWeightsExcludingCloseCost).sum();
+      this.cumulativeWeights = new 
long[subtaskWeightsExcludingCloseCost.length];
+      long cumulativeWeight = 0;
+      for (int i = 0; i < subtaskWeightsExcludingCloseCost.length; ++i) {
+        cumulativeWeight += subtaskWeightsExcludingCloseCost[i];
+        cumulativeWeights[i] = cumulativeWeight;
+      }
+    }
+
+    /** @return subtask id */
+    int select() {
+      if (assignedSubtasks.length == 1) {
+        // only choice. no need to run random number generator.
+        return assignedSubtasks[0];
+      } else {
+        long randomNumber = ThreadLocalRandom.current().nextLong(keyWeight);
+        int index = Arrays.binarySearch(cumulativeWeights, randomNumber);
+        // choose the subtask where randomNumber < cumulativeWeights[pos].
+        // this works regardless whether index is negative or not.
+        int position = Math.abs(index + 1);
+        Preconditions.checkState(
+            position < assignedSubtasks.length,
+            "Invalid selected position: out of range. key weight = %s, random 
number = %s, cumulative weights array = %s",
+            keyWeight,
+            randomNumber,
+            cumulativeWeights);
+        return assignedSubtasks[position];
+      }
+    }
+
+    @Override
+    public int hashCode() {
+      return 31 * Arrays.hashCode(assignedSubtasks)
+          + Arrays.hashCode(subtaskWeightsExcludingCloseCost);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+
+      KeyAssignment that = (KeyAssignment) o;
+      return Arrays.equals(assignedSubtasks, that.assignedSubtasks)
+          && Arrays.equals(subtaskWeightsExcludingCloseCost, 
that.subtaskWeightsExcludingCloseCost);
+    }
+
+    @Override
+    public String toString() {
+      return MoreObjects.toStringHelper(this)
+          .add("assignedSubtasks", assignedSubtasks)
+          .add("subtaskWeightsExcludingCloseCost", 
subtaskWeightsExcludingCloseCost)
+          .toString();
+    }
+  }
+}
diff --git 
a/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestMapRangePartitioner.java
 
b/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestMapRangePartitioner.java
new file mode 100644
index 0000000000..92eb71acc8
--- /dev/null
+++ 
b/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestMapRangePartitioner.java
@@ -0,0 +1,448 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicLong;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.StringData;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.iceberg.SortKey;
+import org.apache.iceberg.SortOrder;
+import org.apache.iceberg.flink.FlinkSchemaUtil;
+import org.apache.iceberg.flink.RowDataWrapper;
+import org.apache.iceberg.flink.TestFixtures;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+import org.apache.iceberg.relocated.com.google.common.collect.Sets;
+import org.apache.iceberg.util.Pair;
+import org.assertj.core.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+public class TestMapRangePartitioner {
+  private static final SortOrder SORT_ORDER =
+      SortOrder.builderFor(TestFixtures.SCHEMA).asc("data").build();
+
+  private static final SortKey SORT_KEY = new SortKey(TestFixtures.SCHEMA, 
SORT_ORDER);
+  private static final RowType ROW_TYPE = 
FlinkSchemaUtil.convert(TestFixtures.SCHEMA);
+  private static final SortKey[] SORT_KEYS = initSortKeys();
+
+  private static SortKey[] initSortKeys() {
+    SortKey[] sortKeys = new SortKey[10];
+    for (int i = 0; i < 10; ++i) {
+      RowData rowData =
+          GenericRowData.of(StringData.fromString("k" + i), i, 
StringData.fromString("2023-06-20"));
+      RowDataWrapper keyWrapper = new RowDataWrapper(ROW_TYPE, 
TestFixtures.SCHEMA.asStruct());
+      keyWrapper.wrap(rowData);
+      SortKey sortKey = SORT_KEY.copy();
+      sortKey.wrap(keyWrapper);
+      sortKeys[i] = sortKey;
+    }
+    return sortKeys;
+  }
+
+  // Total weight is 800
+  private final MapDataStatistics mapDataStatistics =
+      new MapDataStatistics(
+          ImmutableMap.of(
+              SORT_KEYS[0],
+              350L,
+              SORT_KEYS[1],
+              230L,
+              SORT_KEYS[2],
+              120L,
+              SORT_KEYS[3],
+              40L,
+              SORT_KEYS[4],
+              10L,
+              SORT_KEYS[5],
+              10L,
+              SORT_KEYS[6],
+              10L,
+              SORT_KEYS[7],
+              10L,
+              SORT_KEYS[8],
+              10L,
+              SORT_KEYS[9],
+              10L));
+
+  @Test
+  public void testEvenlyDividableNoClosingFileCost() {
+    MapRangePartitioner partitioner =
+        new MapRangePartitioner(TestFixtures.SCHEMA, SORT_ORDER, 
mapDataStatistics, 0.0);
+    int numPartitions = 8;
+
+    // each task should get targeted weight of 100 (=800/8)
+    Map<SortKey, MapRangePartitioner.KeyAssignment> expectedAssignment =
+        ImmutableMap.of(
+            SORT_KEYS[0],
+            new MapRangePartitioner.KeyAssignment(
+                ImmutableList.of(0, 1, 2, 3), ImmutableList.of(100L, 100L, 
100L, 50L), 0L),
+            SORT_KEYS[1],
+            new MapRangePartitioner.KeyAssignment(
+                ImmutableList.of(3, 4, 5), ImmutableList.of(50L, 100L, 80L), 
0L),
+            SORT_KEYS[2],
+            new MapRangePartitioner.KeyAssignment(
+                ImmutableList.of(5, 6), ImmutableList.of(20L, 100L), 0L),
+            SORT_KEYS[3],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), 
ImmutableList.of(40L), 0L),
+            SORT_KEYS[4],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), 
ImmutableList.of(10L), 0L),
+            SORT_KEYS[5],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), 
ImmutableList.of(10L), 0L),
+            SORT_KEYS[6],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), 
ImmutableList.of(10L), 0L),
+            SORT_KEYS[7],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), 
ImmutableList.of(10L), 0L),
+            SORT_KEYS[8],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), 
ImmutableList.of(10L), 0L),
+            SORT_KEYS[9],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), 
ImmutableList.of(10L), 0L));
+    Map<SortKey, MapRangePartitioner.KeyAssignment> actualAssignment =
+        partitioner.assignment(numPartitions);
+    Assertions.assertThat(actualAssignment).isEqualTo(expectedAssignment);
+
+    // key: subtask id
+    // value pair: first is the assigned weight, second is the number of 
assigned keys
+    Map<Integer, Pair<Long, Integer>> expectedAssignmentInfo =
+        ImmutableMap.of(
+            0,
+            Pair.of(100L, 1),
+            1,
+            Pair.of(100L, 1),
+            2,
+            Pair.of(100L, 1),
+            3,
+            Pair.of(100L, 2),
+            4,
+            Pair.of(100L, 1),
+            5,
+            Pair.of(100L, 2),
+            6,
+            Pair.of(100L, 1),
+            7,
+            Pair.of(100L, 7));
+    Map<Integer, Pair<Long, Integer>> actualAssignmentInfo = 
partitioner.assignmentInfo();
+    
Assertions.assertThat(actualAssignmentInfo).isEqualTo(expectedAssignmentInfo);
+
+    Map<Integer, Pair<AtomicLong, Set<RowData>>> partitionResults =
+        runPartitioner(partitioner, numPartitions);
+    validatePartitionResults(expectedAssignmentInfo, partitionResults, 5.0);
+  }
+
+  @Test
+  public void testEvenlyDividableWithClosingFileCost() {
+    MapRangePartitioner partitioner =
+        new MapRangePartitioner(TestFixtures.SCHEMA, SORT_ORDER, 
mapDataStatistics, 5.0);
+    int numPartitions = 8;
+
+    // target subtask weight is 100 before close file cost factored in.
+    // close file cost is 5 = 5% * 100.
+    // key weights before and after close file cost factored in
+    // before:     350, 230, 120, 40, 10, 10, 10, 10, 10, 10
+    // close-cost:  20,  15,  10,  5,  5,  5,  5,  5,  5,  5
+    // after:      370, 245, 130, 45, 15, 15, 15, 15, 15, 15
+    // target subtask weight with close cost per subtask is 110 (880/8)
+    Map<SortKey, MapRangePartitioner.KeyAssignment> expectedAssignment =
+        ImmutableMap.of(
+            SORT_KEYS[0],
+            new MapRangePartitioner.KeyAssignment(
+                ImmutableList.of(0, 1, 2, 3), ImmutableList.of(110L, 110L, 
110L, 40L), 5L),
+            SORT_KEYS[1],
+            new MapRangePartitioner.KeyAssignment(
+                ImmutableList.of(3, 4, 5), ImmutableList.of(70L, 110L, 65L), 
5L),
+            SORT_KEYS[2],
+            new MapRangePartitioner.KeyAssignment(
+                ImmutableList.of(5, 6), ImmutableList.of(45L, 85L), 5L),
+            SORT_KEYS[3],
+            new MapRangePartitioner.KeyAssignment(
+                ImmutableList.of(6, 7), ImmutableList.of(25L, 20L), 5L),
+            SORT_KEYS[4],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), 
ImmutableList.of(15L), 5L),
+            SORT_KEYS[5],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), 
ImmutableList.of(15L), 5L),
+            SORT_KEYS[6],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), 
ImmutableList.of(15L), 5L),
+            SORT_KEYS[7],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), 
ImmutableList.of(15L), 5L),
+            SORT_KEYS[8],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), 
ImmutableList.of(15L), 5L),
+            SORT_KEYS[9],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), 
ImmutableList.of(15L), 5L));
+    Map<SortKey, MapRangePartitioner.KeyAssignment> actualAssignment =
+        partitioner.assignment(numPartitions);
+    Assertions.assertThat(actualAssignment).isEqualTo(expectedAssignment);
+
+    // key: subtask id
+    // value pair: first is the assigned weight (excluding close file cost) 
for the subtask,
+    // second is the number of keys assigned to the subtask
+    Map<Integer, Pair<Long, Integer>> expectedAssignmentInfo =
+        ImmutableMap.of(
+            0,
+            Pair.of(105L, 1),
+            1,
+            Pair.of(105L, 1),
+            2,
+            Pair.of(105L, 1),
+            3,
+            Pair.of(100L, 2),
+            4,
+            Pair.of(105L, 1),
+            5,
+            Pair.of(100L, 2),
+            6,
+            Pair.of(100L, 2),
+            7,
+            Pair.of(75L, 7));
+    Map<Integer, Pair<Long, Integer>> actualAssignmentInfo = 
partitioner.assignmentInfo();
+    
Assertions.assertThat(actualAssignmentInfo).isEqualTo(expectedAssignmentInfo);
+
+    Map<Integer, Pair<AtomicLong, Set<RowData>>> partitionResults =
+        runPartitioner(partitioner, numPartitions);
+    validatePartitionResults(expectedAssignmentInfo, partitionResults, 5.0);
+  }
+
+  @Test
+  public void testNonDividableNoClosingFileCost() {
+    MapRangePartitioner partitioner =
+        new MapRangePartitioner(TestFixtures.SCHEMA, SORT_ORDER, 
mapDataStatistics, 0.0);
+    int numPartitions = 9;
+
+    // before:     350, 230, 120, 40, 10, 10, 10, 10, 10, 10
+    // each task should get targeted weight of 89 = ceiling(800/9)
+    Map<SortKey, MapRangePartitioner.KeyAssignment> expectedAssignment =
+        ImmutableMap.of(
+            SORT_KEYS[0],
+            new MapRangePartitioner.KeyAssignment(
+                ImmutableList.of(0, 1, 2, 3), ImmutableList.of(89L, 89L, 89L, 
83L), 0L),
+            SORT_KEYS[1],
+            new MapRangePartitioner.KeyAssignment(
+                ImmutableList.of(3, 4, 5, 6), ImmutableList.of(6L, 89L, 89L, 
46L), 0L),
+            SORT_KEYS[2],
+            new MapRangePartitioner.KeyAssignment(
+                ImmutableList.of(6, 7), ImmutableList.of(43L, 77L), 0L),
+            SORT_KEYS[3],
+            new MapRangePartitioner.KeyAssignment(
+                ImmutableList.of(7, 8), ImmutableList.of(12L, 28L), 0L),
+            SORT_KEYS[4],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), 
ImmutableList.of(10L), 0L),
+            SORT_KEYS[5],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), 
ImmutableList.of(10L), 0L),
+            SORT_KEYS[6],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), 
ImmutableList.of(10L), 0L),
+            SORT_KEYS[7],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), 
ImmutableList.of(10L), 0L),
+            SORT_KEYS[8],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), 
ImmutableList.of(10L), 0L),
+            SORT_KEYS[9],
+            new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), 
ImmutableList.of(10L), 0L));
+    Map<SortKey, MapRangePartitioner.KeyAssignment> actualAssignment =
+        partitioner.assignment(numPartitions);
+    Assertions.assertThat(actualAssignment).isEqualTo(expectedAssignment);
+
+    // key: subtask id
+    // value pair: first is the assigned weight, second is the number of 
assigned keys
+    Map<Integer, Pair<Long, Integer>> expectedAssignmentInfo =
+        ImmutableMap.of(
+            0,
+            Pair.of(89L, 1),
+            1,
+            Pair.of(89L, 1),
+            2,
+            Pair.of(89L, 1),
+            3,
+            Pair.of(89L, 2),
+            4,
+            Pair.of(89L, 1),
+            5,
+            Pair.of(89L, 1),
+            6,
+            Pair.of(89L, 2),
+            7,
+            Pair.of(89L, 2),
+            8,
+            Pair.of(88L, 7));
+    Map<Integer, Pair<Long, Integer>> actualAssignmentInfo = 
partitioner.assignmentInfo();
+    
Assertions.assertThat(actualAssignmentInfo).isEqualTo(expectedAssignmentInfo);
+
+    Map<Integer, Pair<AtomicLong, Set<RowData>>> partitionResults =
+        runPartitioner(partitioner, numPartitions);
+    validatePartitionResults(expectedAssignmentInfo, partitionResults, 5.0);
+  }
+
+  @Test
+  public void testNonDividableWithClosingFileCost() {
+    int partitionCount = 9;
+    MapRangePartitioner rangePartitioner =
+        new MapRangePartitioner(TestFixtures.SCHEMA, SORT_ORDER, mapDataStatistics, 5.0);
+
+    // Per-key weights, in SORT_KEYS order:
+    //   raw weight:          350, 230, 120, 40, 10, 10, 10, 10, 10, 10
+    //   closing file cost:    20,  15,  10,  5,  5,  5,  5,  5,  5,  5
+    //   weight incl. cost:   370, 245, 130, 45, 15, 15, 15, 15, 15, 15
+    // Before closing cost the per-subtask target is 89 (= ceiling(800/9)), which makes the
+    // closing cost 5 per file (5% of 89 = 4.45, rounded up). After closing cost is included the
+    // per-subtask target becomes 98 (= ceiling(880/9)).
+    ImmutableMap.Builder<SortKey, MapRangePartitioner.KeyAssignment> keyAssignments =
+        ImmutableMap.builder();
+    keyAssignments.put(
+        SORT_KEYS[0],
+        new MapRangePartitioner.KeyAssignment(
+            ImmutableList.of(0, 1, 2, 3), ImmutableList.of(98L, 98L, 98L, 76L), 5L));
+    keyAssignments.put(
+        SORT_KEYS[1],
+        new MapRangePartitioner.KeyAssignment(
+            ImmutableList.of(3, 4, 5, 6), ImmutableList.of(22L, 98L, 98L, 27L), 5L));
+    keyAssignments.put(
+        SORT_KEYS[2],
+        new MapRangePartitioner.KeyAssignment(
+            ImmutableList.of(6, 7), ImmutableList.of(71L, 59L), 5L));
+    keyAssignments.put(
+        SORT_KEYS[3],
+        new MapRangePartitioner.KeyAssignment(
+            ImmutableList.of(7, 8), ImmutableList.of(39L, 6L), 5L));
+    // SORT_KEYS[4] through SORT_KEYS[9] are each expected entirely on subtask 8
+    for (int keyIndex = 4; keyIndex <= 9; ++keyIndex) {
+      keyAssignments.put(
+          SORT_KEYS[keyIndex],
+          new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L));
+    }
+    Map<SortKey, MapRangePartitioner.KeyAssignment> expectedKeyAssignments = keyAssignments.build();
+    Assertions.assertThat(rangePartitioner.assignment(partitionCount))
+        .isEqualTo(expectedKeyAssignments);
+
+    // Per subtask id: (assigned weight without closing cost, number of assigned keys).
+    // Each assigned key contributes its split weight minus the closing cost of 5, e.g.
+    // subtask 3 = (76 - 5) + (22 - 5) = 88 and subtask 8 = (6 - 5) + 6 * (15 - 5) = 61.
+    Map<Integer, Pair<Long, Integer>> expectedSubtaskInfo =
+        ImmutableMap.<Integer, Pair<Long, Integer>>builder()
+            .put(0, Pair.of(93L, 1))
+            .put(1, Pair.of(93L, 1))
+            .put(2, Pair.of(93L, 1))
+            .put(3, Pair.of(88L, 2))
+            .put(4, Pair.of(93L, 1))
+            .put(5, Pair.of(93L, 1))
+            .put(6, Pair.of(88L, 2))
+            .put(7, Pair.of(88L, 2))
+            .put(8, Pair.of(61L, 7))
+            .build();
+    Assertions.assertThat(rangePartitioner.assignmentInfo()).isEqualTo(expectedSubtaskInfo);
+
+    // a non-dividable split combined with closing cost needs a looser drift threshold
+    Map<Integer, Pair<AtomicLong, Set<RowData>>> routedRows =
+        runPartitioner(rangePartitioner, partitionCount);
+    validatePartitionResults(expectedSubtaskInfo, routedRows, 10.0);
+  }
+
+  /**
+   * Routes synthetic rows through the partitioner and records which subtask each row lands on.
+   * For every sort key in the partitioner's map statistics, 100 rows per unit of weight are
+   * partitioned.
+   *
+   * @param partitioner partitioner under test
+   * @param numPartitions number of downstream subtasks passed to {@code partition}
+   * @return map from subtask id to a pair of (number of rows routed to that subtask, set of
+   *     distinct rows routed to that subtask)
+   */
+  private static Map<Integer, Pair<AtomicLong, Set<RowData>>> runPartitioner(
+      MapRangePartitioner partitioner, int numPartitions) {
+    Map<Integer, Pair<AtomicLong, Set<RowData>>> partitionResults = Maps.newHashMap();
+    partitioner
+        .mapStatistics()
+        .forEach(
+            (sortKey, weight) -> {
+              String key = sortKey.get(0, String.class);
+              // every row generated for this key is identical, so build it once
+              RowData rowData =
+                  GenericRowData.of(
+                      StringData.fromString(key), 1, StringData.fromString("2023-06-20"));
+              // run 100x times of the weight; long counter to match the long bound
+              long iterations = weight * 100;
+              for (long i = 0; i < iterations; ++i) {
+                int subtaskId = partitioner.partition(rowData, numPartitions);
+                Pair<AtomicLong, Set<RowData>> pair =
+                    partitionResults.computeIfAbsent(
+                        subtaskId, k -> Pair.of(new AtomicLong(0), Sets.newHashSet()));
+                pair.first().incrementAndGet();
+                pair.second().add(rowData);
+              }
+            });
+    return partitionResults;
+  }
+
+  /**
+   * Validates the rows routed by {@link #runPartitioner} against the expected assignment.
+   *
+   * @param expectedAssignmentInfo per subtask id: (expected weight excluding closing cost,
+   *     expected number of assigned keys)
+   * @param partitionResults per subtask id: (number of routed rows, distinct routed rows), as
+   *     produced by {@link #runPartitioner}
+   * @param maxDriftPercentage maximum allowed deviation, in percent, of each subtask's normalized
+   *     actual weight from its normalized expected weight
+   */
+  private void validatePartitionResults(
+      Map<Integer, Pair<Long, Integer>> expectedAssignmentInfo,
+      Map<Integer, Pair<AtomicLong, Set<RowData>>> partitionResults,
+      double maxDriftPercentage) {
+
+    // every expected subtask, and no other, must have received rows
+    Assertions.assertThat(partitionResults.keySet())
+        .as("the subtasks that received rows should match the expected subtasks")
+        .isEqualTo(expectedAssignmentInfo.keySet());
+
+    // Walk the subtask ids in the expected map's order and look up BOTH maps by id, so that
+    // index i of every list below refers to the same subtask. Relying on each map's own
+    // iteration order would only line them up by coincidence.
+    List<Integer> subtaskIds = Lists.newArrayList(expectedAssignmentInfo.keySet());
+
+    long expectedTotalWeight =
+        expectedAssignmentInfo.values().stream().mapToLong(Pair::first).sum();
+    long actualTotalWeight =
+        partitionResults.values().stream().mapToLong(pair -> pair.first().longValue()).sum();
+
+    List<Integer> expectedAssignedKeyCounts =
+        Lists.newArrayListWithExpectedSize(subtaskIds.size());
+    List<Integer> actualAssignedKeyCounts = Lists.newArrayListWithExpectedSize(subtaskIds.size());
+    List<Double> expectedNormalizedWeights =
+        Lists.newArrayListWithExpectedSize(subtaskIds.size());
+    List<Double> actualNormalizedWeights = Lists.newArrayListWithExpectedSize(subtaskIds.size());
+
+    for (int subtaskId : subtaskIds) {
+      Pair<Long, Integer> expected = expectedAssignmentInfo.get(subtaskId);
+      Pair<AtomicLong, Set<RowData>> actual = partitionResults.get(subtaskId);
+      expectedAssignedKeyCounts.add(expected.second());
+      actualAssignedKeyCounts.add(actual.second().size());
+      expectedNormalizedWeights.add(expected.first().doubleValue() / expectedTotalWeight);
+      actualNormalizedWeights.add(actual.first().doubleValue() / actualTotalWeight);
+    }
+
+    // number of assigned keys should match exactly
+    Assertions.assertThat(actualAssignedKeyCounts)
+        .as("the number of assigned keys should match for every subtask")
+        .isEqualTo(expectedAssignedKeyCounts);
+
+    // the normalized weight of every subtask shouldn't drift from its expected normalized weight
+    // by more than maxDriftPercentage percent
+    for (int i = 0; i < subtaskIds.size(); ++i) {
+      int subtaskId = subtaskIds.get(i);
+      double expectedWeight = expectedNormalizedWeights.get(i);
+      double min = expectedWeight * (1 - maxDriftPercentage / 100);
+      double max = expectedWeight * (1 + maxDriftPercentage / 100);
+      Assertions.assertThat(actualNormalizedWeights.get(i))
+          .as(
+              "Subtask %d weight should be within %.1f percent of the expected weight %s",
+              subtaskId, maxDriftPercentage, expectedWeight)
+          .isBetween(min, max);
+    }
+  }
+}
diff --git a/jmh.gradle b/jmh.gradle
index ea317cc2ee..de50162cb0 100644
--- a/jmh.gradle
+++ b/jmh.gradle
@@ -26,10 +26,18 @@ def sparkVersions = (System.getProperty("sparkVersions") != null ? System.getPro
 def scalaVersion = System.getProperty("scalaVersion") != null ? 
System.getProperty("scalaVersion") : System.getProperty("defaultScalaVersion")
 def jmhProjects = [project(":iceberg-core"), project(":iceberg-data")]
 
+// Add the iceberg-flink module of each Flink version listed in flinkVersions to jmhProjects.
+if (flinkVersions.contains("1.16")) {
+  jmhProjects.add(project(":iceberg-flink:iceberg-flink-1.16"))
+}
+
 if (flinkVersions.contains("1.17")) {
   jmhProjects.add(project(":iceberg-flink:iceberg-flink-1.17"))
 }
 
+if (flinkVersions.contains("1.18")) {
+  jmhProjects.add(project(":iceberg-flink:iceberg-flink-1.18"))
+}
+
 if (sparkVersions.contains("3.3")) {
   jmhProjects.add(project(":iceberg-spark:iceberg-spark-3.3_${scalaVersion}"))
   
jmhProjects.add(project(":iceberg-spark:iceberg-spark-extensions-3.3_${scalaVersion}"))


Reply via email to