This is an automated email from the ASF dual-hosted git repository.
JingsongLi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new 7a0ad33a28 [flink] Add PARTITION_DYNAMIC strategy for partitioned
append-only tables (#7809)
7a0ad33a28 is described below
commit 7a0ad33a2852ab25852094893f260ecbc85cb074
Author: Jingsong Lee <[email protected]>
AuthorDate: Tue May 12 12:12:40 2026 +0800
[flink] Add PARTITION_DYNAMIC strategy for partitioned append-only tables
(#7809)
This adds a new partition sink strategy that dynamically adjusts shuffle
based on per-partition traffic statistics collected at each checkpoint.
Unlike HASH, which statically maps each partition to one subtask,
PARTITION_DYNAMIC monitors data distribution and rebalances load across
downstream subtasks proportionally to partition traffic weight.
This PR is based on the implementation in
https://github.com/apache/fluss/pull/1784
---
.../shortcodes/generated/core_configuration.html | 2 +-
.../main/java/org/apache/paimon/CoreOptions.java | 10 +-
.../apache/paimon/flink/sink/FlinkSinkBuilder.java | 52 +++-
.../partition/AggregatedStatisticsTracker.java | 130 ++++++++
.../flink/sink/partition/DataStatistics.java | 70 +++++
.../sink/partition/DataStatisticsCoordinator.java | 328 +++++++++++++++++++++
.../DataStatisticsCoordinatorProvider.java | 41 +++
.../sink/partition/DataStatisticsOperator.java | 127 ++++++++
.../partition/DataStatisticsOperatorFactory.java | 72 +++++
.../sink/partition/DataStatisticsSerializer.java | 157 ++++++++++
.../flink/sink/partition/StatisticsEvent.java | 53 ++++
.../flink/sink/partition/StatisticsOrRecord.java | 83 ++++++
.../StatisticsOrRecordChannelComputer.java | 202 +++++++++++++
.../partition/StatisticsOrRecordSerializer.java | 184 ++++++++++++
.../sink/partition/StatisticsOrRecordTypeInfo.java | 114 +++++++
.../flink/sink/partition/StatisticsUtil.java | 57 ++++
.../sink/partition/WeightedRandomAssignment.java | 99 +++++++
.../org/apache/paimon/flink/AppendTableITCase.java | 208 ++++++++++++-
.../partition/AggregatedStatisticsTrackerTest.java | 187 ++++++++++++
.../sink/partition/DataStatisticsOperatorTest.java | 164 +++++++++++
.../partition/DataStatisticsSerializerTest.java | 73 +++++
.../paimon/flink/sink/partition/MockRandom.java | 44 +++
.../StatisticsOrRecordChannelComputerTest.java | 233 +++++++++++++++
.../partition/WeightedRandomAssignmentTest.java | 63 ++++
24 files changed, 2740 insertions(+), 13 deletions(-)
diff --git a/docs/layouts/shortcodes/generated/core_configuration.html
b/docs/layouts/shortcodes/generated/core_configuration.html
index 05cc17f34e..5faf8276da 100644
--- a/docs/layouts/shortcodes/generated/core_configuration.html
+++ b/docs/layouts/shortcodes/generated/core_configuration.html
@@ -1089,7 +1089,7 @@ This config option does not affect the default filesystem
metastore.</td>
<td><h5>partition.sink-strategy</h5></td>
<td style="word-wrap: break-word;">NONE</td>
<td><p>Enum</p></td>
- <td>This is only for partitioned append table or postpone pk
table, and the purpose is to reduce small files and improve write performance.
Through this repartitioning strategy to reduce the number of partitions written
by each task to as few as possible.<ul><li>none: Rebalanced or Forward
partitioning, this is the default behavior, this strategy is suitable for the
number of partitions you write in a batch is much smaller than write
parallelism.</li><li>hash: Hash the partit [...]
+ <td>This is only for partitioned append table or postpone pk
table, and the purpose is to reduce small files and improve write performance.
Through this repartitioning strategy to reduce the number of partitions written
by each task to as few as possible.<ul><li>none: Rebalanced or Forward
partitioning, this is the default behavior, this strategy is suitable for the
number of partitions you write in a batch is much smaller than write
parallelism.</li><li>hash: Hash the partit [...]
</tr>
<tr>
<td><h5>partition.timestamp-format.strict</h5></td>
diff --git a/paimon-api/src/main/java/org/apache/paimon/CoreOptions.java
b/paimon-api/src/main/java/org/apache/paimon/CoreOptions.java
index 75eb4744db..a420110828 100644
--- a/paimon-api/src/main/java/org/apache/paimon/CoreOptions.java
+++ b/paimon-api/src/main/java/org/apache/paimon/CoreOptions.java
@@ -1660,7 +1660,11 @@ public class CoreOptions implements Serializable {
+ " this strategy
is suitable for the number of partitions you write in a batch is much smaller
than write parallelism."),
text(
"hash: Hash the partitions
value,"
- + " this strategy
is suitable for the number of partitions you write in a batch is greater equals
than write parallelism."))
+ + " this strategy
is suitable for the number of partitions you write in a batch is greater equals
than write parallelism."),
+ text(
+ "partition_dynamic:
Dynamically adjusts shuffle strategy based on partition key traffic patterns."
+ + " This mode
monitors data distribution across partitions and rebalances load across
downstream subtasks."
+ + " Suitable when
partition traffic is skewed and you want balanced write throughput."))
.build());
public static final ConfigOption<Boolean> METASTORE_PARTITIONED_TABLE =
@@ -4528,8 +4532,8 @@ public class CoreOptions implements Serializable {
/** Partition strategy for unaware bucket partitioned append only table. */
public enum PartitionSinkStrategy {
NONE,
- HASH
- // TODO : Supports range-partition strategy.
+ HASH,
+ PARTITION_DYNAMIC
}
/** Specifies the implementation of format table. */
diff --git
a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/FlinkSinkBuilder.java
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/FlinkSinkBuilder.java
index 87d1292453..74de8f847c 100644
---
a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/FlinkSinkBuilder.java
+++
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/FlinkSinkBuilder.java
@@ -27,6 +27,10 @@ import org.apache.paimon.data.InternalRow;
import org.apache.paimon.flink.FlinkConnectorOptions;
import org.apache.paimon.flink.FlinkRowWrapper;
import org.apache.paimon.flink.sink.index.GlobalDynamicBucketSink;
+import org.apache.paimon.flink.sink.partition.DataStatisticsOperatorFactory;
+import org.apache.paimon.flink.sink.partition.StatisticsOrRecord;
+import
org.apache.paimon.flink.sink.partition.StatisticsOrRecordChannelComputer;
+import org.apache.paimon.flink.sink.partition.StatisticsOrRecordTypeInfo;
import org.apache.paimon.flink.sorter.TableSortInfo;
import org.apache.paimon.flink.sorter.TableSorter;
import org.apache.paimon.table.BucketMode;
@@ -319,18 +323,52 @@ public class FlinkSinkBuilder {
table.primaryKeys().isEmpty(),
"Unaware bucket mode only works with append-only table for
now.");
- if (!table.partitionKeys().isEmpty()
- && table.coreOptions().partitionSinkStrategy() ==
PartitionSinkStrategy.HASH) {
- input =
- partition(
- input,
- new
RowDataHashPartitionChannelComputer(table.schema()),
- parallelism);
+ if (!table.partitionKeys().isEmpty()) {
+ PartitionSinkStrategy strategy =
table.coreOptions().partitionSinkStrategy();
+ if (strategy == PartitionSinkStrategy.HASH) {
+ input =
+ partition(
+ input,
+ new
RowDataHashPartitionChannelComputer(table.schema()),
+ parallelism);
+ } else if (strategy == PartitionSinkStrategy.PARTITION_DYNAMIC) {
+ input = applyDynamicPartitionShuffle(input);
+ }
}
return new RowAppendTableSink(table, overwritePartition,
parallelism).sinkFrom(input);
}
+ private DataStream<InternalRow>
applyDynamicPartitionShuffle(DataStream<InternalRow> input) {
+ StatisticsOrRecordTypeInfo typeInfo =
+ new
StatisticsOrRecordTypeInfo(table.schema().logicalRowType());
+ SingleOutputStreamOperator<StatisticsOrRecord> statsStream =
+ input.transform(
+ "Collect Statistics: " + table.name(),
+ typeInfo,
+ new
DataStatisticsOperatorFactory(table.schema()))
+ .setParallelism(input.getParallelism());
+
+ DataStream<StatisticsOrRecord> partitioned =
+ partition(
+ statsStream,
+ new StatisticsOrRecordChannelComputer(table.schema()),
+ parallelism);
+
+ return partitioned
+ .flatMap(
+ (org.apache.flink.api.common.functions.FlatMapFunction<
+ StatisticsOrRecord, InternalRow>)
+ (statisticsOrRecord, out) -> {
+ if (statisticsOrRecord.isRecord()) {
+
out.collect(statisticsOrRecord.record());
+ }
+ })
+ .name("Strip Statistics")
+ .setParallelism(parallelism != null ? parallelism :
input.getParallelism())
+ .returns(input.getType());
+ }
+
private DataStream<RowData> trySortInput(DataStream<RowData> input) {
if (tableSortInfo != null) {
TableSorter sorter =
diff --git
a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/AggregatedStatisticsTracker.java
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/AggregatedStatisticsTracker.java
new file mode 100644
index 0000000000..383d8a48eb
--- /dev/null
+++
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/AggregatedStatisticsTracker.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.flink.sink.partition;
+
+import org.apache.paimon.data.BinaryRow;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.NavigableMap;
+import java.util.Set;
+import java.util.TreeMap;
+
+/**
+ * Tracks statistics aggregation from {@link DataStatisticsOperator} subtasks
for every checkpoint.
+ */
+class AggregatedStatisticsTracker {
+
+ private static final Logger LOG =
LoggerFactory.getLogger(AggregatedStatisticsTracker.class);
+
+ private final String operatorName;
+ private final int parallelism;
+ private final TypeSerializer<DataStatistics> statisticsSerializer;
+ private final NavigableMap<Long, Aggregation> aggregationsPerCheckpoint;
+
+ private long completedCheckpointId;
+ private DataStatistics completedStatistics;
+
+ AggregatedStatisticsTracker(String operatorName, int parallelism) {
+ this.operatorName = operatorName;
+ this.parallelism = parallelism;
+ this.statisticsSerializer = new DataStatisticsSerializer();
+ this.aggregationsPerCheckpoint = new TreeMap<>();
+ this.completedCheckpointId = -1;
+ }
+
+ DataStatistics updateAndCheckCompletion(int subtask, StatisticsEvent
event) {
+ long checkpointId = event.getCheckpointId();
+ LOG.debug(
+ "Handling statistics event from subtask {} of operator {} for
checkpoint {}",
+ subtask,
+ operatorName,
+ checkpointId);
+
+ if (completedStatistics != null && completedCheckpointId >
checkpointId) {
+ LOG.debug(
+ "Ignore stale statistics event from operator {} subtask {}
for older checkpoint {}. "
+ + "Was expecting checkpoint higher than {}",
+ operatorName,
+ subtask,
+ checkpointId,
+ completedCheckpointId);
+ return null;
+ }
+
+ Aggregation aggregation =
+ aggregationsPerCheckpoint.computeIfAbsent(
+ checkpointId, ignored -> new Aggregation(parallelism));
+ DataStatistics dataStatistics =
+ StatisticsUtil.deserializeDataStatistics(
+ event.getStatisticsBytes(), statisticsSerializer);
+ if (!aggregation.merge(subtask, dataStatistics)) {
+ LOG.debug(
+ "Ignore duplicate data statistics from operator {} subtask
{} for checkpoint {}.",
+ operatorName,
+ subtask,
+ checkpointId);
+ }
+
+ if (aggregation.isComplete()) {
+ this.completedStatistics = aggregation.completedStatistics();
+ this.completedCheckpointId = checkpointId;
+ aggregationsPerCheckpoint.headMap(checkpointId, true).clear();
+ return completedStatistics;
+ }
+
+ return null;
+ }
+
+ static class Aggregation {
+ private final int parallelism;
+ private final Set<Integer> subtaskSet;
+ private final Map<BinaryRow, Long> partitionStatistics;
+
+ Aggregation(int parallelism) {
+ this.parallelism = parallelism;
+ this.subtaskSet = new HashSet<>();
+ this.partitionStatistics = new HashMap<>();
+ }
+
+ boolean isComplete() {
+ return subtaskSet.size() == parallelism;
+ }
+
+ boolean merge(int subtask, DataStatistics taskStatistics) {
+ if (subtaskSet.contains(subtask)) {
+ return false;
+ }
+ subtaskSet.add(subtask);
+ Map<BinaryRow, Long> result = taskStatistics.result();
+ result.forEach(
+ (partition, count) -> partitionStatistics.merge(partition,
count, Long::sum));
+ return true;
+ }
+
+ DataStatistics completedStatistics() {
+ return new DataStatistics(partitionStatistics);
+ }
+ }
+}
diff --git
a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/DataStatistics.java
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/DataStatistics.java
new file mode 100644
index 0000000000..ef29274b43
--- /dev/null
+++
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/DataStatistics.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.flink.sink.partition;
+
+import org.apache.paimon.data.BinaryRow;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+/** Data statistics tracking partition key to frequency/weight. */
+public class DataStatistics {
+
+ private final Map<BinaryRow, Long> partitionFrequency;
+
+ public DataStatistics() {
+ this.partitionFrequency = new HashMap<>();
+ }
+
+ public DataStatistics(Map<BinaryRow, Long> partitionFrequency) {
+ this.partitionFrequency = partitionFrequency;
+ }
+
+ public boolean isEmpty() {
+ return partitionFrequency.isEmpty();
+ }
+
+ public void add(BinaryRow partition, long value) {
+ partitionFrequency.merge(partition, value, Long::sum);
+ }
+
+ public Map<BinaryRow, Long> result() {
+ return partitionFrequency;
+ }
+
+ @Override
+ public String toString() {
+ return "DataStatistics{" + "partitionFrequency=" + partitionFrequency
+ '}';
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (!(o instanceof DataStatistics)) {
+ return false;
+ }
+ DataStatistics that = (DataStatistics) o;
+ return Objects.equals(partitionFrequency, that.partitionFrequency);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(partitionFrequency);
+ }
+}
diff --git
a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/DataStatisticsCoordinator.java
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/DataStatisticsCoordinator.java
new file mode 100644
index 0000000000..cd863481a3
--- /dev/null
+++
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/DataStatisticsCoordinator.java
@@ -0,0 +1,328 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.flink.sink.partition;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.operators.coordination.OperatorCoordinator;
+import org.apache.flink.runtime.operators.coordination.OperatorEvent;
+import org.apache.flink.util.ExceptionUtils;
+import org.apache.flink.util.FatalExitExceptionHandler;
+import org.apache.flink.util.ThrowableCatchingRunnable;
+import org.apache.flink.util.function.ThrowingRunnable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import java.util.HashMap;
+import java.util.Locale;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ThreadFactory;
+
+import static org.apache.paimon.utils.Preconditions.checkArgument;
+import static org.apache.paimon.utils.Preconditions.checkState;
+
+/** Coordinator for collecting and broadcasting global partition data
statistics. */
+class DataStatisticsCoordinator implements OperatorCoordinator {
+
+ private static final Logger LOG =
LoggerFactory.getLogger(DataStatisticsCoordinator.class);
+
+ private final String operatorName;
+ private final OperatorCoordinator.Context context;
+ private final ExecutorService coordinatorExecutor;
+ private final SubtaskGateways subtaskGateways;
+ private final CoordinatorExecutorThreadFactory coordinatorThreadFactory;
+ private final TypeSerializer<DataStatistics> statisticsSerializer;
+
+ private transient boolean started;
+ private transient AggregatedStatisticsTracker aggregatedStatisticsTracker;
+
+ DataStatisticsCoordinator(String operatorName, OperatorCoordinator.Context
context) {
+ this.operatorName = operatorName;
+ this.context = context;
+ this.coordinatorThreadFactory =
+ new CoordinatorExecutorThreadFactory(
+ "DataStatisticsCoordinator-" + operatorName,
+ context.getUserCodeClassloader());
+ this.coordinatorExecutor =
Executors.newSingleThreadExecutor(coordinatorThreadFactory);
+ this.subtaskGateways = new SubtaskGateways(operatorName,
context.currentParallelism());
+ this.statisticsSerializer = new DataStatisticsSerializer();
+ }
+
+ @Override
+ public void start() throws Exception {
+ LOG.debug("Starting data statistics coordinator: {}.", operatorName);
+ this.started = true;
+ this.aggregatedStatisticsTracker =
+ new AggregatedStatisticsTracker(operatorName,
context.currentParallelism());
+ }
+
+ @Override
+ public void handleEventFromOperator(int subtask, int attemptNumber,
OperatorEvent event) {
+ runInCoordinatorThread(
+ () -> {
+ LOG.debug(
+ "Handling event from subtask {} (#{}) of {}: {}",
+ subtask,
+ attemptNumber,
+ operatorName,
+ event);
+ if (event instanceof StatisticsEvent) {
+ handleDataStatisticRequest(subtask, (StatisticsEvent)
event);
+ } else {
+ throw new IllegalArgumentException(
+ "Invalid operator event type: "
+ + event.getClass().getCanonicalName());
+ }
+ },
+ String.format(
+ Locale.ROOT,
+ "handling operator event %s from subtask %d (#%d)",
+ event.getClass(),
+ subtask,
+ attemptNumber));
+ }
+
+ @Override
+ public void checkpointCoordinator(long checkpointId,
CompletableFuture<byte[]> resultFuture)
+ throws Exception {
+ resultFuture.complete(new byte[0]);
+ }
+
+ @Override
+ public void notifyCheckpointComplete(long checkpointId) {}
+
+ @Override
+ public void resetToCheckpoint(long checkpointId, byte[] checkpointData) {
+ checkState(
+ !started,
+ "The coordinator %s can only be reset if it was not yet
started",
+ operatorName);
+ }
+
+ @Override
+ public void subtaskReset(int subtask, long checkpointId) {
+ runInCoordinatorThread(
+ () -> {
+ LOG.info(
+ "Operator {} subtask {} is reset to checkpoint {}",
+ operatorName,
+ subtask,
+ checkpointId);
+ subtaskGateways.reset(subtask);
+ },
+ String.format(
+ Locale.ROOT,
+ "handling subtask %d recovery to checkpoint %d",
+ subtask,
+ checkpointId));
+ }
+
+ @Override
+ public void executionAttemptFailed(int subtask, int attemptNumber,
@Nullable Throwable reason) {
+ runInCoordinatorThread(
+ () -> {
+ LOG.info(
+ "Unregistering gateway after failure for subtask
{} (#{}) of {}",
+ subtask,
+ attemptNumber,
+ operatorName);
+ subtaskGateways.unregisterSubtaskGateway(subtask,
attemptNumber);
+ },
+ String.format(
+ Locale.ROOT, "handling subtask %d (#%d) failure",
subtask, attemptNumber));
+ }
+
+ @Override
+ public void executionAttemptReady(int subtask, int attemptNumber,
SubtaskGateway gateway) {
+ checkArgument(subtask == gateway.getSubtask());
+ checkArgument(attemptNumber ==
gateway.getExecution().getAttemptNumber());
+ runInCoordinatorThread(
+ () -> subtaskGateways.registerSubtaskGateway(gateway),
+ String.format(
+ Locale.ROOT,
+ "making event gateway to subtask %d (#%d) available",
+ subtask,
+ attemptNumber));
+ }
+
+ @Override
+ public void close() throws Exception {
+ coordinatorExecutor.shutdown();
+ this.aggregatedStatisticsTracker = null;
+ this.started = false;
+ LOG.info("Closed data statistics coordinator: {}.", operatorName);
+ }
+
+ private void runInCoordinatorThread(Runnable runnable) {
+ this.coordinatorExecutor.execute(
+ new ThrowableCatchingRunnable(
+ throwable ->
+
this.coordinatorThreadFactory.uncaughtException(
+ Thread.currentThread(), throwable),
+ runnable));
+ }
+
+ private void runInCoordinatorThread(ThrowingRunnable<Throwable> action,
String actionString) {
+ ensureStarted();
+ runInCoordinatorThread(
+ () -> {
+ try {
+ action.run();
+ } catch (Throwable t) {
+ ExceptionUtils.rethrowIfFatalErrorOrOOM(t);
+ LOG.error(
+ "Uncaught exception in the data statistics
coordinator: {} while {}. Triggering job failover",
+ operatorName,
+ actionString,
+ t);
+ context.failJob(t);
+ }
+ });
+ }
+
+ private void ensureStarted() {
+ checkState(started, "The coordinator of %s has not started yet.",
operatorName);
+ }
+
+ private void handleDataStatisticRequest(int subtask, StatisticsEvent
event) {
+ DataStatistics maybeCompleted =
+ aggregatedStatisticsTracker.updateAndCheckCompletion(subtask,
event);
+ if (maybeCompleted != null) {
+ if (maybeCompleted.isEmpty()) {
+ LOG.debug(
+ "Skip aggregated statistics for checkpoint {} as it is
empty.",
+ event.getCheckpointId());
+ } else {
+ LOG.debug(
+ "Completed statistics aggregation for checkpoint {}",
+ event.getCheckpointId());
+ sendGlobalStatisticsToSubtasks(maybeCompleted,
event.getCheckpointId());
+ }
+ }
+ }
+
+ private void sendGlobalStatisticsToSubtasks(DataStatistics statistics,
long checkpointId) {
+ LOG.info(
+ "Broadcast latest global statistics from checkpoint {} to all
subtasks",
+ checkpointId);
+ StatisticsEvent statisticsEvent =
+ StatisticsEvent.createStatisticsEvent(
+ checkpointId, statistics, statisticsSerializer);
+ for (int i = 0; i < context.currentParallelism(); ++i) {
+ final int subtaskIndex = i;
+ subtaskGateways
+ .getSubtaskGateway(subtaskIndex)
+ .sendEvent(statisticsEvent)
+ .whenComplete(
+ (ack, error) -> {
+ if (error != null) {
+ LOG.warn(
+ "Failed to send global statistics
to subtask {}",
+ subtaskIndex,
+ error);
+ }
+ });
+ }
+ }
+
+ static class SubtaskGateways {
+ private final String operatorName;
+ private final Map<Integer, SubtaskGateway>[] gateways;
+
+ @SuppressWarnings("unchecked")
+ SubtaskGateways(String operatorName, int parallelism) {
+ this.operatorName = operatorName;
+ gateways = new Map[parallelism];
+ for (int i = 0; i < parallelism; ++i) {
+ gateways[i] = new HashMap<>();
+ }
+ }
+
+ void registerSubtaskGateway(OperatorCoordinator.SubtaskGateway
gateway) {
+ int subtaskIndex = gateway.getSubtask();
+ int attemptNumber = gateway.getExecution().getAttemptNumber();
+ checkState(
+ !gateways[subtaskIndex].containsKey(attemptNumber),
+ "Coordinator of %s already has a subtask gateway for %d
(#%d)",
+ operatorName,
+ subtaskIndex,
+ attemptNumber);
+ gateways[subtaskIndex].put(attemptNumber, gateway);
+ }
+
+ void unregisterSubtaskGateway(int subtaskIndex, int attemptNumber) {
+ gateways[subtaskIndex].remove(attemptNumber);
+ }
+
+ OperatorCoordinator.SubtaskGateway getSubtaskGateway(int subtaskIndex)
{
+ checkState(
+ !gateways[subtaskIndex].isEmpty(),
+ "Coordinator of %s subtask %d is not ready yet to receive
events",
+ operatorName,
+ subtaskIndex);
+ return gateways[subtaskIndex].values().iterator().next();
+ }
+
+ void reset(int subtaskIndex) {
+ gateways[subtaskIndex].clear();
+ }
+ }
+
+ private static class CoordinatorExecutorThreadFactory
+ implements ThreadFactory, Thread.UncaughtExceptionHandler {
+
+ private final String coordinatorThreadName;
+ private final ClassLoader classLoader;
+ private final Thread.UncaughtExceptionHandler errorHandler;
+
+ private Thread thread;
+
+ CoordinatorExecutorThreadFactory(
+ String coordinatorThreadName, ClassLoader contextClassLoader) {
+ this(coordinatorThreadName, contextClassLoader,
FatalExitExceptionHandler.INSTANCE);
+ }
+
+ CoordinatorExecutorThreadFactory(
+ String coordinatorThreadName,
+ ClassLoader contextClassLoader,
+ Thread.UncaughtExceptionHandler errorHandler) {
+ this.coordinatorThreadName = coordinatorThreadName;
+ this.classLoader = contextClassLoader;
+ this.errorHandler = errorHandler;
+ }
+
+ @Override
+ public synchronized Thread newThread(@Nonnull Runnable runnable) {
+ thread = new Thread(runnable, coordinatorThreadName);
+ thread.setContextClassLoader(classLoader);
+ thread.setUncaughtExceptionHandler(this);
+ return thread;
+ }
+
+ @Override
+ public synchronized void uncaughtException(Thread t, Throwable e) {
+ errorHandler.uncaughtException(t, e);
+ }
+ }
+}
diff --git
a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/DataStatisticsCoordinatorProvider.java
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/DataStatisticsCoordinatorProvider.java
new file mode 100644
index 0000000000..64ec234ef4
--- /dev/null
+++
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/DataStatisticsCoordinatorProvider.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.flink.sink.partition;
+
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.operators.coordination.OperatorCoordinator;
+import
org.apache.flink.runtime.operators.coordination.RecreateOnResetOperatorCoordinator;
+
+/** Coordinator provider for {@link DataStatisticsCoordinator}. */
+class DataStatisticsCoordinatorProvider extends
RecreateOnResetOperatorCoordinator.Provider {
+
+ private static final long serialVersionUID = 1L;
+
+ private final String operatorName;
+
+ DataStatisticsCoordinatorProvider(String operatorName, OperatorID
operatorID) {
+ super(operatorID);
+ this.operatorName = operatorName;
+ }
+
+ @Override
+ public OperatorCoordinator getCoordinator(OperatorCoordinator.Context
context) {
+ return new DataStatisticsCoordinator(operatorName, context);
+ }
+}
diff --git
a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/DataStatisticsOperator.java
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/DataStatisticsOperator.java
new file mode 100644
index 0000000000..5e5f197c27
--- /dev/null
+++
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/DataStatisticsOperator.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.flink.sink.partition;
+
+import org.apache.paimon.data.BinaryRow;
+import org.apache.paimon.data.InternalRow;
+import org.apache.paimon.flink.utils.RuntimeContextUtils;
+import org.apache.paimon.schema.TableSchema;
+import org.apache.paimon.table.sink.RowPartitionKeyExtractor;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.operators.coordination.OperatorEvent;
+import org.apache.flink.runtime.operators.coordination.OperatorEventGateway;
+import org.apache.flink.runtime.operators.coordination.OperatorEventHandler;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import static org.apache.paimon.utils.Preconditions.checkArgument;
+
+/**
+ * Operator that collects local partition data statistics and sends them to
coordinator for global
+ * aggregation, then forwards global statistics to the downstream partitioner.
+ */
+public class DataStatisticsOperator extends
AbstractStreamOperator<StatisticsOrRecord>
+ implements OneInputStreamOperator<InternalRow, StatisticsOrRecord>,
OperatorEventHandler {
+
+ private static final long serialVersionUID = 1L;
+
+ private final String operatorName;
+ private final TableSchema schema;
+ private final OperatorEventGateway operatorEventGateway;
+
+ private transient int subtaskIndex;
+ private transient DataStatistics localStatistics;
+ private transient RowPartitionKeyExtractor extractor;
+ private transient TypeSerializer<DataStatistics> statisticsSerializer;
+
+ DataStatisticsOperator(
+ StreamOperatorParameters<StatisticsOrRecord> parameters,
+ String operatorName,
+ TableSchema schema,
+ OperatorEventGateway operatorEventGateway) {
+ super();
+ this.operatorName = operatorName;
+ this.schema = schema;
+ this.operatorEventGateway = operatorEventGateway;
+ this.setup(
+ parameters.getContainingTask(),
+ parameters.getStreamConfig(),
+ parameters.getOutput());
+ }
+
+ @Override
+ public void open() throws Exception {
+ this.extractor = new RowPartitionKeyExtractor(schema);
+ this.statisticsSerializer = new DataStatisticsSerializer();
+ }
+
+ @Override
+ public void initializeState(StateInitializationContext context) throws
Exception {
+ this.subtaskIndex =
RuntimeContextUtils.getIndexOfThisSubtask(getRuntimeContext());
+ this.localStatistics = StatisticsUtil.createDataStatistics();
+ }
+
+ @Override
+ public void handleOperatorEvent(OperatorEvent event) {
+ checkArgument(
+ event instanceof StatisticsEvent,
+ String.format(
+ "Operator %s subtask %s received unexpected operator
event %s",
+ operatorName, subtaskIndex, event.getClass()));
+ StatisticsEvent statisticsEvent = (StatisticsEvent) event;
+ LOG.debug(
+ "Operator {} subtask {} received global data event from
coordinator checkpoint {}",
+ operatorName,
+ subtaskIndex,
+ statisticsEvent.getCheckpointId());
+ DataStatistics globalStatistics =
+ StatisticsUtil.deserializeDataStatistics(
+ statisticsEvent.getStatisticsBytes(),
statisticsSerializer);
+ if (globalStatistics != null) {
+ output.collect(new
StreamRecord<>(StatisticsOrRecord.fromStatistics(globalStatistics)));
+ }
+ }
+
+ @Override
+ public void processElement(StreamRecord<InternalRow> streamRecord) throws
Exception {
+ InternalRow row = streamRecord.getValue();
+ BinaryRow partition = extractor.partition(row).copy();
+ localStatistics.add(partition, 1L);
+ output.collect(new StreamRecord<>(StatisticsOrRecord.fromRecord(row)));
+ }
+
+ @Override
+ public void snapshotState(StateSnapshotContext context) throws Exception {
+ long checkpointId = context.getCheckpointId();
+ LOG.debug(
+ "Operator {} subtask {} sending local statistics to
coordinator for checkpoint {}",
+ operatorName,
+ subtaskIndex,
+ checkpointId);
+ operatorEventGateway.sendEventToCoordinator(
+ StatisticsEvent.createStatisticsEvent(
+ checkpointId, localStatistics, statisticsSerializer));
+ localStatistics = StatisticsUtil.createDataStatistics();
+ }
+}
diff --git
a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/DataStatisticsOperatorFactory.java
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/DataStatisticsOperatorFactory.java
new file mode 100644
index 0000000000..047297375b
--- /dev/null
+++
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/DataStatisticsOperatorFactory.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.flink.sink.partition;
+
+import org.apache.paimon.data.InternalRow;
+import org.apache.paimon.schema.TableSchema;
+
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.operators.coordination.OperatorCoordinator;
+import org.apache.flink.runtime.operators.coordination.OperatorEventGateway;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.CoordinatedOperatorFactory;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+
+/**
+ * Factory for {@link DataStatisticsOperator}.
+ *
+ * <p>Besides creating the operator, it supplies the matching coordinator provider and registers
+ * the created operator as the handler for events addressed to its operator id.
+ */
+public class DataStatisticsOperatorFactory extends AbstractStreamOperatorFactory<StatisticsOrRecord>
+        implements CoordinatedOperatorFactory<StatisticsOrRecord>,
+                OneInputStreamOperatorFactory<InternalRow, StatisticsOrRecord> {
+
+    private static final long serialVersionUID = 1L;
+
+    /** Table schema handed to every operator created by this factory. */
+    private final TableSchema schema;
+
+    /**
+     * Creates the factory.
+     *
+     * @param schema table schema passed on to the created operators
+     */
+    public DataStatisticsOperatorFactory(TableSchema schema) {
+        this.schema = schema;
+    }
+
+    /** Returns a {@link DataStatisticsCoordinatorProvider} for the given operator. */
+    @Override
+    public OperatorCoordinator.Provider getCoordinatorProvider(
+            String operatorName, OperatorID operatorID) {
+        return new DataStatisticsCoordinatorProvider(operatorName, operatorID);
+    }
+
+    /**
+     * Creates a {@link DataStatisticsOperator}, giving it the event gateway for its operator id
+     * and registering it as the event handler for that id.
+     */
+    @SuppressWarnings("unchecked")
+    @Override
+    public <T extends StreamOperator<StatisticsOrRecord>> T createStreamOperator(
+            StreamOperatorParameters<StatisticsOrRecord> parameters) {
+        String name = parameters.getStreamConfig().getOperatorName();
+        OperatorID id = parameters.getStreamConfig().getOperatorID();
+        OperatorEventGateway eventGateway =
+                parameters.getOperatorEventDispatcher().getOperatorEventGateway(id);
+
+        DataStatisticsOperator created =
+                new DataStatisticsOperator(parameters, name, schema, eventGateway);
+        parameters.getOperatorEventDispatcher().registerEventHandler(id, created);
+        return (T) created;
+    }
+
+    /** Returns the class of the operators created by this factory. */
+    @Override
+    @SuppressWarnings("rawtypes")
+    public Class<? extends StreamOperator> getStreamOperatorClass(ClassLoader classLoader) {
+        return DataStatisticsOperator.class;
+    }
+}
diff --git
a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/DataStatisticsSerializer.java
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/DataStatisticsSerializer.java
new file mode 100644
index 0000000000..94d850d82c
--- /dev/null
+++
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/DataStatisticsSerializer.java
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.flink.sink.partition;
+
+import org.apache.paimon.data.BinaryRow;
+import org.apache.paimon.utils.SerializationUtils;
+
+import org.apache.flink.api.common.typeutils.CompositeTypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Serializer for {@link DataStatistics}.
+ *
+ * <p>Wire format: an {@code int} entry count, then for each entry an {@code int} byte length, the
+ * serialized {@link BinaryRow} key of that length, and a {@code long} value.
+ */
+public class DataStatisticsSerializer extends TypeSerializer<DataStatistics> {
+
+    private static final long serialVersionUID = 1L;
+
+    public DataStatisticsSerializer() {}
+
+    @Override
+    public boolean isImmutableType() {
+        return false;
+    }
+
+    /** Returns a new instance; this serializer has no fields to share. */
+    @Override
+    public TypeSerializer<DataStatistics> duplicate() {
+        return new DataStatisticsSerializer();
+    }
+
+    @Override
+    public DataStatistics createInstance() {
+        return new DataStatistics();
+    }
+
+    /** Copies the statistics into a new map, copying every key row as well. */
+    @Override
+    public DataStatistics copy(DataStatistics from) {
+        Map<BinaryRow, Long> duplicated = new HashMap<>(from.result().size());
+        from.result().forEach((row, count) -> duplicated.put(row.copy(), count));
+        return new DataStatistics(duplicated);
+    }
+
+    /** Ignores {@code reuse} and delegates to {@link #copy(DataStatistics)}. */
+    @Override
+    public DataStatistics copy(DataStatistics from, DataStatistics reuse) {
+        return copy(from);
+    }
+
+    /** Returns {@code -1}: the serialized form is variable-length. */
+    @Override
+    public int getLength() {
+        return -1;
+    }
+
+    /** Writes the entry count followed by each (length, key bytes, value) triple. */
+    @Override
+    public void serialize(DataStatistics record, DataOutputView target) throws IOException {
+        Map<BinaryRow, Long> entries = record.result();
+        target.writeInt(entries.size());
+        for (Map.Entry<BinaryRow, Long> e : entries.entrySet()) {
+            byte[] keyBytes = SerializationUtils.serializeBinaryRow(e.getKey());
+            target.writeInt(keyBytes.length);
+            target.write(keyBytes);
+            target.writeLong(e.getValue());
+        }
+    }
+
+    /** Reads back what {@link #serialize(DataStatistics, DataOutputView)} wrote. */
+    @Override
+    public DataStatistics deserialize(DataInputView source) throws IOException {
+        int entryCount = source.readInt();
+        Map<BinaryRow, Long> entries = new HashMap<>(entryCount);
+        for (int remaining = entryCount; remaining > 0; remaining--) {
+            byte[] keyBytes = new byte[source.readInt()];
+            source.readFully(keyBytes);
+            BinaryRow key = SerializationUtils.deserializeBinaryRow(keyBytes);
+            long count = source.readLong();
+            entries.put(key, count);
+        }
+        return new DataStatistics(entries);
+    }
+
+    /** Ignores {@code reuse} and delegates to {@link #deserialize(DataInputView)}. */
+    @Override
+    public DataStatistics deserialize(DataStatistics reuse, DataInputView source)
+            throws IOException {
+        return deserialize(source);
+    }
+
+    /** Copies one serialized value by deserializing it and serializing it again. */
+    @Override
+    public void copy(DataInputView source, DataOutputView target) throws IOException {
+        DataStatistics value = deserialize(source);
+        serialize(value, target);
+    }
+
+    /** All instances of exactly this class are equal. */
+    @Override
+    public boolean equals(Object obj) {
+        if (obj == null) {
+            return false;
+        }
+        return getClass() == obj.getClass();
+    }
+
+    @Override
+    public int hashCode() {
+        return getClass().hashCode();
+    }
+
+    @Override
+    public TypeSerializerSnapshot<DataStatistics> snapshotConfiguration() {
+        return new DataStatisticsSerializerSnapshot(this);
+    }
+
+    /** Snapshot class for the {@link DataStatisticsSerializer}. */
+    public static class DataStatisticsSerializerSnapshot
+            extends CompositeTypeSerializerSnapshot<DataStatistics, DataStatisticsSerializer> {
+        private static final int CURRENT_VERSION = 1;
+
+        /** No-argument constructor. */
+        @SuppressWarnings("unused")
+        public DataStatisticsSerializerSnapshot() {}
+
+        public DataStatisticsSerializerSnapshot(DataStatisticsSerializer serializer) {
+            super(serializer);
+        }
+
+        @Override
+        protected int getCurrentOuterSnapshotVersion() {
+            return CURRENT_VERSION;
+        }
+
+        /** The outer serializer has no nested serializers. */
+        @Override
+        protected TypeSerializer<?>[] getNestedSerializers(
+                DataStatisticsSerializer outerSerializer) {
+            return new TypeSerializer<?>[0];
+        }
+
+        @Override
+        protected DataStatisticsSerializer createOuterSerializerWithNestedSerializers(
+                TypeSerializer<?>[] nestedSerializers) {
+            return new DataStatisticsSerializer();
+        }
+    }
+}
diff --git
a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/StatisticsEvent.java
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/StatisticsEvent.java
new file mode 100644
index 0000000000..f0378e9486
--- /dev/null
+++
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/StatisticsEvent.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.flink.sink.partition;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.operators.coordination.OperatorEvent;
+
+/** Event to send statistics between operator and coordinator. */
+class StatisticsEvent implements OperatorEvent {
+
+ private static final long serialVersionUID = 1L;
+
+ private final long checkpointId;
+ private final byte[] statisticsBytes;
+
+ private StatisticsEvent(long checkpointId, byte[] statisticsBytes) {
+ this.checkpointId = checkpointId;
+ this.statisticsBytes = statisticsBytes;
+ }
+
+ static StatisticsEvent createStatisticsEvent(
+ long checkpointId,
+ DataStatistics statistics,
+ TypeSerializer<DataStatistics> statisticsSerializer) {
+ return new StatisticsEvent(
+ checkpointId,
+ StatisticsUtil.serializeDataStatistics(statistics,
statisticsSerializer));
+ }
+
+ long getCheckpointId() {
+ return checkpointId;
+ }
+
+ byte[] getStatisticsBytes() {
+ return statisticsBytes;
+ }
+}
diff --git
a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/StatisticsOrRecord.java
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/StatisticsOrRecord.java
new file mode 100644
index 0000000000..254e58360a
--- /dev/null
+++
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/StatisticsOrRecord.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.flink.sink.partition;
+
+import org.apache.paimon.data.InternalRow;
+
+import java.util.Objects;
+
+import static org.apache.paimon.utils.Preconditions.checkArgument;
+
+/** Either a record or a statistics. */
+public class StatisticsOrRecord {
+
+ private DataStatistics statistics;
+ private InternalRow record;
+
+ private StatisticsOrRecord(DataStatistics statistics, InternalRow record) {
+ checkArgument(
+ record != null ^ statistics != null,
+ "DataStatistics or record, not neither or both");
+ this.statistics = statistics;
+ this.record = record;
+ }
+
+ public static StatisticsOrRecord fromRecord(InternalRow record) {
+ return new StatisticsOrRecord(null, record);
+ }
+
+ public static StatisticsOrRecord fromStatistics(DataStatistics statistics)
{
+ return new StatisticsOrRecord(statistics, null);
+ }
+
+ public boolean isStatistics() {
+ return statistics != null;
+ }
+
+ public boolean isRecord() {
+ return record != null;
+ }
+
+ public DataStatistics statistics() {
+ return statistics;
+ }
+
+ public InternalRow record() {
+ return record;
+ }
+
+ @Override
+ public String toString() {
+ return "StatisticsOrRecord{" + "statistics=" + statistics + ",
record=" + record + '}';
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (!(o instanceof StatisticsOrRecord)) {
+ return false;
+ }
+ StatisticsOrRecord that = (StatisticsOrRecord) o;
+ return Objects.equals(statistics, that.statistics) &&
Objects.equals(record, that.record);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(statistics, record);
+ }
+}
diff --git
a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/StatisticsOrRecordChannelComputer.java
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/StatisticsOrRecordChannelComputer.java
new file mode 100644
index 0000000000..8a19f95e0e
--- /dev/null
+++
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/StatisticsOrRecordChannelComputer.java
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.flink.sink.partition;
+
+import org.apache.paimon.data.BinaryRow;
+import org.apache.paimon.schema.TableSchema;
+import org.apache.paimon.table.sink.ChannelComputer;
+import org.apache.paimon.table.sink.RowPartitionKeyExtractor;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nullable;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.concurrent.ThreadLocalRandom;
+
+import static org.apache.paimon.utils.Preconditions.checkState;
+
+/**
+ * {@link ChannelComputer} for {@link StatisticsOrRecord} which dynamically adjusts shuffle based on
+ * partition statistics.
+ *
+ * <p>A statistics element replaces the current partition-to-subtask assignment and is itself sent
+ * to a random channel. A record element is routed by looking up its partition in the current
+ * assignment; partitions missing from the assignment get a small default set of channels derived
+ * from the partition's hash code.
+ */
+public class StatisticsOrRecordChannelComputer implements ChannelComputer<StatisticsOrRecord> {
+
+    private static final long serialVersionUID = 1L;
+
+    private static final Logger LOG =
+            LoggerFactory.getLogger(StatisticsOrRecordChannelComputer.class);
+
+    /** Upper bound on the number of channels given to a partition without statistics. */
+    private static final int DEFAULT_SUBTASK_COUNT_FOR_UNKNOWN_PARTITION = 4;
+
+    private final TableSchema schema;
+
+    private transient int numChannels;
+    private transient RowPartitionKeyExtractor extractor;
+    private transient MapPartitioner delegatePartitioner;
+    private transient Random random;
+
+    public StatisticsOrRecordChannelComputer(TableSchema schema) {
+        this.schema = schema;
+    }
+
+    /** Stores the channel count and creates the partition extractor and random source. */
+    @Override
+    public void setup(int numChannels) {
+        this.numChannels = numChannels;
+        this.extractor = new RowPartitionKeyExtractor(schema);
+        this.random = ThreadLocalRandom.current();
+    }
+
+    /**
+     * Picks the output channel for one element.
+     *
+     * @param wrapper a statistics element (rebuilds the assignment, then goes to a random
+     *     channel) or a record element (routed by its partition)
+     * @return the selected channel
+     */
+    @Override
+    public int channel(StatisticsOrRecord wrapper) {
+        if (wrapper.isStatistics()) {
+            this.delegatePartitioner = buildPartitioner(wrapper.statistics());
+            return ThreadLocalRandom.current().nextInt(numChannels);
+        } else {
+            if (delegatePartitioner == null) {
+                // No statistics seen yet: start with an empty assignment.
+                delegatePartitioner = buildPartitioner(null);
+            }
+            BinaryRow partition = extractor.partition(wrapper.record());
+            return delegatePartitioner.select(partition, numChannels);
+        }
+    }
+
+    /** Builds a partitioner from the statistics, or an empty one when none are given. */
+    private MapPartitioner buildPartitioner(@Nullable DataStatistics statistics) {
+        if (statistics == null) {
+            return new MapPartitioner(new HashMap<>());
+        }
+        return new MapPartitioner(buildAssignment(numChannels, statistics.result()));
+    }
+
+    /**
+     * Distributes partitions over subtasks so that every subtask receives at most {@code
+     * ceil(totalWeight / downstreamParallelism)} weight. Partitions are walked in hash-code order
+     * and packed into subtasks one after another; a partition that does not fit into the
+     * remaining capacity of a subtask is split across consecutive subtasks.
+     *
+     * @param downstreamParallelism number of subtasks available
+     * @param statistics weight per partition
+     * @return per-partition assignment; empty when there are no statistics or the total weight is
+     *     not positive
+     * @throws IllegalStateException if the subtasks are exhausted while keys are still unassigned
+     */
+    Map<BinaryRow, WeightedRandomAssignment> buildAssignment(
+            int downstreamParallelism, Map<BinaryRow, Long> statistics) {
+        if (statistics.isEmpty()) {
+            return new HashMap<>();
+        }
+
+        long totalWeight = statistics.values().stream().mapToLong(l -> l).sum();
+        if (totalWeight <= 0) {
+            return new HashMap<>();
+        }
+        long targetWeightPerSubtask =
+                (long) Math.ceil(((double) totalWeight) / downstreamParallelism);
+
+        // Sort keys for deterministic assignment across JVMs
+        // NOTE(review): keys with equal hash codes keep their incoming relative order.
+        List<Map.Entry<BinaryRow, Long>> sortedEntries = new ArrayList<>(statistics.entrySet());
+        sortedEntries.sort(Comparator.comparingInt(e -> e.getKey().hashCode()));
+
+        Map<BinaryRow, WeightedRandomAssignment> assignmentMap = new HashMap<>(statistics.size());
+        Iterator<Map.Entry<BinaryRow, Long>> entryIterator = sortedEntries.iterator();
+        int subtaskId = 0;
+        BinaryRow currentKey = null;
+        long keyRemainingWeight = 0L;
+        long subtaskRemainingWeight = targetWeightPerSubtask;
+        List<Integer> assignedSubtasks = new ArrayList<>();
+        List<Long> subtaskWeights = new ArrayList<>();
+
+        // Each iteration places (part of) the current key on the current subtask.
+        while (entryIterator.hasNext() || currentKey != null) {
+            if (subtaskId >= downstreamParallelism) {
+                LOG.error(
+                        "Internal algorithm error: exhausted subtasks. parallelism: {}, "
+                                + "target weight per subtask: {}, statistics: {}",
+                        downstreamParallelism,
+                        targetWeightPerSubtask,
+                        statistics);
+                throw new IllegalStateException(
+                        "Internal algorithm error: exhausted subtasks with unassigned keys left");
+            }
+
+            if (currentKey == null) {
+                Map.Entry<BinaryRow, Long> entry = entryIterator.next();
+                currentKey = entry.getKey();
+                keyRemainingWeight = entry.getValue();
+            }
+
+            assignedSubtasks.add(subtaskId);
+            if (keyRemainingWeight < subtaskRemainingWeight) {
+                // The rest of the key fits; the subtask keeps spare capacity for the next key.
+                subtaskWeights.add(keyRemainingWeight);
+                subtaskRemainingWeight -= keyRemainingWeight;
+                keyRemainingWeight = 0L;
+            } else {
+                // The key fills the subtask; any leftover continues on the next subtask.
+                long assignedWeight = subtaskRemainingWeight;
+                keyRemainingWeight -= subtaskRemainingWeight;
+                subtaskWeights.add(assignedWeight);
+                subtaskId += 1;
+                subtaskRemainingWeight = targetWeightPerSubtask;
+            }
+
+            checkState(
+                    assignedSubtasks.size() == subtaskWeights.size(),
+                    "List size mismatch: assigned subtasks = %s, subtask weights = %s",
+                    assignedSubtasks,
+                    subtaskWeights);
+
+            if (keyRemainingWeight == 0) {
+                // Key fully placed: record its assignment and start fresh lists.
+                WeightedRandomAssignment assignment =
+                        new WeightedRandomAssignment(assignedSubtasks, subtaskWeights, random);
+                assignmentMap.put(currentKey, assignment);
+                assignedSubtasks = new ArrayList<>();
+                subtaskWeights = new ArrayList<>();
+                currentKey = null;
+            }
+        }
+
+        LOG.debug("Assignment map: {}", assignmentMap);
+        return assignmentMap;
+    }
+
+    @Override
+    public String toString() {
+        return "PARTITION_DYNAMIC";
+    }
+
+    /** Routes partitions using a partition-to-assignment map, filling in unknown partitions. */
+    private class MapPartitioner {
+
+        private final Map<BinaryRow, WeightedRandomAssignment> assignments;
+
+        MapPartitioner(Map<BinaryRow, WeightedRandomAssignment> assignments) {
+            this.assignments = assignments;
+        }
+
+        /**
+         * Selects a channel for the partition. A partition without an assignment gets up to
+         * {@link #DEFAULT_SUBTASK_COUNT_FOR_UNKNOWN_PARTITION} consecutive channels (wrapping
+         * around) with equal weight, starting at a channel derived from its hash code; that
+         * assignment is cached under a copy of the key.
+         *
+         * @param partitionKey partition of the record being routed
+         * @param numChannels number of output channels
+         * @return the selected channel
+         */
+        int select(BinaryRow partitionKey, int numChannels) {
+            WeightedRandomAssignment assignment = assignments.get(partitionKey);
+            if (assignment == null) {
+                int defaultSubtaskCount =
+                        Math.min(numChannels, DEFAULT_SUBTASK_COUNT_FOR_UNKNOWN_PARTITION);
+                // Take the remainder before abs(): Math.abs(Integer.MIN_VALUE) is still negative
+                // and would yield a negative start channel. The remainder's magnitude is below
+                // numChannels, so abs() cannot overflow here, and the result is unchanged for
+                // every other hash code.
+                int startChannel = Math.abs(partitionKey.hashCode() % numChannels);
+                List<Integer> subtasks = new ArrayList<>(defaultSubtaskCount);
+                List<Long> weights = new ArrayList<>(defaultSubtaskCount);
+                for (int i = 0; i < defaultSubtaskCount; i++) {
+                    subtasks.add((startChannel + i) % numChannels);
+                    weights.add(1L);
+                }
+                assignment = new WeightedRandomAssignment(subtasks, weights, random);
+                assignments.put(partitionKey.copy(), assignment);
+            }
+            return assignment.select();
+        }
+    }
+}
diff --git
a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/StatisticsOrRecordSerializer.java
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/StatisticsOrRecordSerializer.java
new file mode 100644
index 0000000000..981ff9c0af
--- /dev/null
+++
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/StatisticsOrRecordSerializer.java
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.flink.sink.partition;
+
+import org.apache.paimon.data.InternalRow;
+
+import org.apache.flink.api.common.typeutils.CompositeTypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * Serializer for {@link StatisticsOrRecord}.
+ *
+ * <p>Wire format: one boolean tag ({@code true} for a record, {@code false} for statistics)
+ * followed by the payload written by the matching nested serializer.
+ */
+public class StatisticsOrRecordSerializer extends TypeSerializer<StatisticsOrRecord> {
+
+    private static final long serialVersionUID = 1L;
+
+    private final TypeSerializer<DataStatistics> statisticsSerializer;
+    private final TypeSerializer<InternalRow> recordSerializer;
+
+    /**
+     * Creates the serializer from its two nested serializers.
+     *
+     * @param statisticsSerializer serializer used for statistics payloads
+     * @param recordSerializer serializer used for record payloads
+     */
+    public StatisticsOrRecordSerializer(
+            TypeSerializer<DataStatistics> statisticsSerializer,
+            TypeSerializer<InternalRow> recordSerializer) {
+        this.statisticsSerializer = statisticsSerializer;
+        this.recordSerializer = recordSerializer;
+    }
+
+    @Override
+    public boolean isImmutableType() {
+        return false;
+    }
+
+    /** Returns {@code this} when both nested serializers duplicate to themselves. */
+    @SuppressWarnings("ReferenceEquality")
+    @Override
+    public TypeSerializer<StatisticsOrRecord> duplicate() {
+        TypeSerializer<DataStatistics> statsCopy = statisticsSerializer.duplicate();
+        TypeSerializer<InternalRow> recordCopy = recordSerializer.duplicate();
+        boolean unchanged = statisticsSerializer == statsCopy && recordSerializer == recordCopy;
+        return unchanged ? this : new StatisticsOrRecordSerializer(statsCopy, recordCopy);
+    }
+
+    /** Creates a record-holding instance from the record serializer's instance. */
+    @Override
+    public StatisticsOrRecord createInstance() {
+        InternalRow emptyRecord = recordSerializer.createInstance();
+        return StatisticsOrRecord.fromRecord(emptyRecord);
+    }
+
+    /** Copies the held record or statistics with the matching nested serializer. */
+    @Override
+    public StatisticsOrRecord copy(StatisticsOrRecord from) {
+        return from.isRecord()
+                ? StatisticsOrRecord.fromRecord(recordSerializer.copy(from.record()))
+                : StatisticsOrRecord.fromStatistics(statisticsSerializer.copy(from.statistics()));
+    }
+
+    /** Ignores {@code reuse} and delegates to {@link #copy(StatisticsOrRecord)}. */
+    @Override
+    public StatisticsOrRecord copy(StatisticsOrRecord from, StatisticsOrRecord reuse) {
+        return copy(from);
+    }
+
+    /** Returns {@code -1}: the serialized form is variable-length. */
+    @Override
+    public int getLength() {
+        return -1;
+    }
+
+    /** Writes the boolean tag and then the payload. */
+    @Override
+    public void serialize(StatisticsOrRecord statisticsOrRecord, DataOutputView target)
+            throws IOException {
+        boolean holdsRecord = statisticsOrRecord.isRecord();
+        target.writeBoolean(holdsRecord);
+        if (holdsRecord) {
+            recordSerializer.serialize(statisticsOrRecord.record(), target);
+        } else {
+            statisticsSerializer.serialize(statisticsOrRecord.statistics(), target);
+        }
+    }
+
+    /** Reads the boolean tag and then the payload it announces. */
+    @Override
+    public StatisticsOrRecord deserialize(DataInputView source) throws IOException {
+        return source.readBoolean()
+                ? StatisticsOrRecord.fromRecord(recordSerializer.deserialize(source))
+                : StatisticsOrRecord.fromStatistics(statisticsSerializer.deserialize(source));
+    }
+
+    /** Ignores {@code reuse} and delegates to {@link #deserialize(DataInputView)}. */
+    @Override
+    public StatisticsOrRecord deserialize(StatisticsOrRecord reuse, DataInputView source)
+            throws IOException {
+        return deserialize(source);
+    }
+
+    /** Copies the tag and lets the matching nested serializer copy the payload. */
+    @Override
+    public void copy(DataInputView source, DataOutputView target) throws IOException {
+        boolean holdsRecord = source.readBoolean();
+        target.writeBoolean(holdsRecord);
+        if (holdsRecord) {
+            recordSerializer.copy(source, target);
+            return;
+        }
+        statisticsSerializer.copy(source, target);
+    }
+
+    /** Two serializers are equal when both nested serializers are equal. */
+    @Override
+    public boolean equals(Object obj) {
+        if (obj instanceof StatisticsOrRecordSerializer) {
+            StatisticsOrRecordSerializer that = (StatisticsOrRecordSerializer) obj;
+            return Objects.equals(statisticsSerializer, that.statisticsSerializer)
+                    && Objects.equals(recordSerializer, that.recordSerializer);
+        }
+        return false;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(statisticsSerializer, recordSerializer);
+    }
+
+    @Override
+    public TypeSerializerSnapshot<StatisticsOrRecord> snapshotConfiguration() {
+        return new StatisticsOrRecordSerializerSnapshot(this);
+    }
+
+    /** Snapshot for {@link StatisticsOrRecordSerializer}. */
+    public static class StatisticsOrRecordSerializerSnapshot
+            extends CompositeTypeSerializerSnapshot<
+                    StatisticsOrRecord, StatisticsOrRecordSerializer> {
+        private static final int CURRENT_VERSION = 1;
+
+        /** No-argument constructor. */
+        @SuppressWarnings("unused")
+        public StatisticsOrRecordSerializerSnapshot() {}
+
+        public StatisticsOrRecordSerializerSnapshot(StatisticsOrRecordSerializer serializer) {
+            super(serializer);
+        }
+
+        @Override
+        protected int getCurrentOuterSnapshotVersion() {
+            return CURRENT_VERSION;
+        }
+
+        /** Nested serializers in fixed order: statistics first, record second. */
+        @Override
+        @SuppressWarnings("rawtypes")
+        protected TypeSerializer<?>[] getNestedSerializers(
+                StatisticsOrRecordSerializer outerSerializer) {
+            return new TypeSerializer<?>[] {
+                outerSerializer.statisticsSerializer, outerSerializer.recordSerializer
+            };
+        }
+
+        /** Rebuilds the outer serializer; expects the order used by getNestedSerializers. */
+        @SuppressWarnings("unchecked")
+        @Override
+        protected StatisticsOrRecordSerializer createOuterSerializerWithNestedSerializers(
+                TypeSerializer<?>[] nestedSerializers) {
+            TypeSerializer<DataStatistics> nestedStats =
+                    (TypeSerializer<DataStatistics>) nestedSerializers[0];
+            TypeSerializer<InternalRow> nestedRecord =
+                    (TypeSerializer<InternalRow>) nestedSerializers[1];
+            return new StatisticsOrRecordSerializer(nestedStats, nestedRecord);
+        }
+    }
+}
diff --git
a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/StatisticsOrRecordTypeInfo.java
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/StatisticsOrRecordTypeInfo.java
new file mode 100644
index 0000000000..6de6bc4581
--- /dev/null
+++
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/StatisticsOrRecordTypeInfo.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.flink.sink.partition;
+
+import org.apache.paimon.flink.utils.InternalRowTypeSerializer;
+import org.apache.paimon.types.RowType;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.serialization.SerializerConfig;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+
+import java.util.Objects;
+
+/**
+ * TypeInformation for {@link StatisticsOrRecord}.
+ *
+ * <p>Holds the {@link RowType} of the record side so that {@code createSerializer} can build a
+ * {@link StatisticsOrRecordSerializer} from a {@link DataStatisticsSerializer} and an {@link
+ * InternalRowTypeSerializer} for that row type. Equality and hashing are based on the row type
+ * only.
+ */
+public class StatisticsOrRecordTypeInfo extends
TypeInformation<StatisticsOrRecord> {
+
+ private static final long serialVersionUID = 1L;
+
+ private final RowType rowType;
+
+ public StatisticsOrRecordTypeInfo(RowType rowType) {
+ this.rowType = rowType;
+ }
+
+ @Override
+ public boolean isBasicType() {
+ return false;
+ }
+
+ @Override
+ public boolean isTupleType() {
+ return false;
+ }
+
+ /** The type is reported as a single, non-decomposed field. */
+ @Override
+ public int getArity() {
+ return 1;
+ }
+
+ /** The type is reported as a single, non-decomposed field. */
+ @Override
+ public int getTotalFields() {
+ return 1;
+ }
+
+ @Override
+ @SuppressWarnings("unchecked")
+ public Class<StatisticsOrRecord> getTypeClass() {
+ return StatisticsOrRecord.class;
+ }
+
+ @Override
+ public boolean isKeyType() {
+ return false;
+ }
+
+ /**
+ * Delegates to the {@link ExecutionConfig} overload; the passed config is not used.
+ *
+ * <p>Do not annotate with {@code @Override} here to maintain compatibility with Flink 1.18-.
+ */
+ public TypeSerializer<StatisticsOrRecord>
createSerializer(SerializerConfig config) {
+ // The cast selects the ExecutionConfig overload; that overload ignores its argument.
+ return this.createSerializer((ExecutionConfig) null);
+ }
+
+ /**
+ * Builds the serializer from a new {@link DataStatisticsSerializer} and an {@link
+ * InternalRowTypeSerializer} for {@code rowType}; the passed config is not used.
+ *
+ * <p>Do not annotate with {@code @Override} here to maintain compatibility with Flink 2.0+.
+ */
+ public TypeSerializer<StatisticsOrRecord> createSerializer(ExecutionConfig
config) {
+ return new StatisticsOrRecordSerializer(
+ new DataStatisticsSerializer(), new
InternalRowTypeSerializer(rowType));
+ }
+
+ @Override
+ public String toString() {
+ return "StatisticsOrRecord";
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ StatisticsOrRecordTypeInfo that = (StatisticsOrRecordTypeInfo) o;
+ return Objects.equals(rowType, that.rowType);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(rowType);
+ }
+
+ @Override
+ public boolean canEqual(Object obj) {
+ return obj instanceof StatisticsOrRecordTypeInfo;
+ }
+}
diff --git
a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/StatisticsUtil.java
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/StatisticsUtil.java
new file mode 100644
index 0000000000..5bcaa8d020
--- /dev/null
+++
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/StatisticsUtil.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.flink.sink.partition;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataInputDeserializer;
+import org.apache.flink.core.memory.DataOutputSerializer;
+
+import java.io.IOException;
+import java.io.UncheckedIOException;
+
+/** Static helpers that create {@link DataStatistics} and convert it to and from byte arrays. */
+class StatisticsUtil {
+
+    private StatisticsUtil() {}
+
+    /** Returns a {@link DataStatistics} created with its no-arg constructor. */
+    static DataStatistics createDataStatistics() {
+        return new DataStatistics();
+    }
+
+    /**
+     * Serializes the statistics with the given serializer and returns a copy of the written bytes.
+     *
+     * @throws UncheckedIOException if the serializer fails with an {@link IOException}
+     */
+    static byte[] serializeDataStatistics(
+            DataStatistics dataStatistics, TypeSerializer<DataStatistics> statisticsSerializer) {
+        try {
+            DataOutputSerializer buffer = new DataOutputSerializer(64);
+            statisticsSerializer.serialize(dataStatistics, buffer);
+            return buffer.getCopyOfBuffer();
+        } catch (IOException ioe) {
+            throw new UncheckedIOException("Fail to serialize data statistics", ioe);
+        }
+    }
+
+    /**
+     * Reads a {@link DataStatistics} back from the whole byte array with the given serializer.
+     *
+     * @throws UncheckedIOException if the serializer fails with an {@link IOException}
+     */
+    static DataStatistics deserializeDataStatistics(
+            byte[] bytes, TypeSerializer<DataStatistics> statisticsSerializer) {
+        try {
+            DataInputDeserializer source = new DataInputDeserializer(bytes, 0, bytes.length);
+            return statisticsSerializer.deserialize(source);
+        } catch (IOException ioe) {
+            throw new UncheckedIOException("Fail to deserialize data statistics", ioe);
+        }
+    }
+}
diff --git
a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/WeightedRandomAssignment.java
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/WeightedRandomAssignment.java
new file mode 100644
index 0000000000..1e3e4fad31
--- /dev/null
+++
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/partition/WeightedRandomAssignment.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.flink.sink.partition;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Random;
+
+import static org.apache.paimon.utils.Preconditions.checkArgument;
+
+/**
+ * Partition assignment strategy that randomly distributes records to subtasks based on configured
+ * weights.
+ *
+ * <p>{@link #select()} returns one of the assigned subtasks, chosen with probability proportional
+ * to that subtask's weight. If the weights do not add up to a positive total there is no
+ * proportion to follow, so the choice is uniform over the assigned subtasks.
+ */
+public class WeightedRandomAssignment {
+
+    private final List<Integer> assignedSubtasks;
+    private final List<Long> subtaskWeights;
+    /** Sum of all subtask weights; the upper bound of the random draw in {@link #select()}. */
+    private final long keyWeight;
+    /** Running totals of the weights: entry {@code i} is the sum of weights {@code 0..i}. */
+    private final double[] cumulativeWeights;
+    private final Random random;
+
+    /**
+     * @param assignedSubtasks subtask indexes that may be selected; must not be null or empty
+     * @param subtaskWeights weight of each subtask, in the same order as {@code assignedSubtasks};
+     *     must have the same size
+     * @param random source of randomness used by {@link #select()}
+     * @throws IllegalArgumentException if a list is null or empty, or the sizes differ
+     */
+    public WeightedRandomAssignment(
+            List<Integer> assignedSubtasks, List<Long> subtaskWeights, Random random) {
+        checkArgument(
+                assignedSubtasks != null && !assignedSubtasks.isEmpty(),
+                "Invalid assigned subtasks: null or empty");
+        checkArgument(
+                subtaskWeights != null && !subtaskWeights.isEmpty(),
+                "Invalid assigned subtask weights: null or empty");
+        checkArgument(
+                assignedSubtasks.size() == subtaskWeights.size(),
+                "Invalid assignment: size mismatch (tasks length = %s, weights length = %s)",
+                assignedSubtasks.size(),
+                subtaskWeights.size());
+
+        this.assignedSubtasks = assignedSubtasks;
+        this.subtaskWeights = subtaskWeights;
+        this.keyWeight = subtaskWeights.stream().mapToLong(Long::longValue).sum();
+        this.cumulativeWeights = new double[subtaskWeights.size()];
+        long cumulativeWeight = 0;
+        for (int i = 0; i < subtaskWeights.size(); ++i) {
+            cumulativeWeight += subtaskWeights.get(i);
+            cumulativeWeights[i] = cumulativeWeight;
+        }
+        this.random = random;
+    }
+
+    /**
+     * Picks a subtask for the next record.
+     *
+     * @return the only subtask when a single one is assigned; otherwise a subtask drawn at random
+     *     in proportion to its weight, or uniformly when the total weight is not positive
+     */
+    public int select() {
+        int subtaskCount = assignedSubtasks.size();
+        if (subtaskCount == 1) {
+            return assignedSubtasks.get(0);
+        }
+        if (keyWeight <= 0) {
+            // With a zero total the scaled draw below becomes NaN, which the binary search places
+            // after every entry, so the last subtask would receive every record. Spread evenly
+            // instead.
+            return assignedSubtasks.get(random.nextInt(subtaskCount));
+        }
+        double randomNumber = nextDouble(0, keyWeight);
+        int index = Arrays.binarySearch(cumulativeWeights, randomNumber);
+        // binarySearch returns -(insertionPoint) - 1 on a miss and the matching index on a hit, so
+        // abs(index + 1) is the insertion point on a miss and the entry after the match on a hit.
+        int position = Math.abs(index + 1);
+        if (position >= subtaskCount) {
+            position = subtaskCount - 1;
+        }
+        return assignedSubtasks.get(position);
+    }
+
+    /** Returns a random double in {@code [origin, bound)}; callers pass a positive bound. */
+    private double nextDouble(double origin, double bound) {
+        double r = random.nextDouble();
+        r = r * (bound - origin) + origin;
+        if (r >= bound) {
+            // Rounding can land exactly on the bound; step down to the adjacent smaller double.
+            r = Double.longBitsToDouble(Double.doubleToLongBits(bound) - 1);
+        }
+        return r;
+    }
+
+    @Override
+    public String toString() {
+        return "WeightedRandomAssignment{"
+                + "assignedSubtasks="
+                + assignedSubtasks
+                + ", subtaskWeights="
+                + subtaskWeights
+                + ", keyWeight="
+                + keyWeight
+                + '}';
+    }
+}
diff --git
a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/AppendTableITCase.java
b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/AppendTableITCase.java
index f307eb47c2..4c535b6798 100644
---
a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/AppendTableITCase.java
+++
b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/AppendTableITCase.java
@@ -552,7 +552,12 @@ public class AppendTableITCase extends CatalogITCaseBase {
strategy == CoreOptions.PartitionSinkStrategy.HASH
? hashStrategyResultFileCount
: largerSinkParallelism;
- partitionEntriesLarger.forEach(x ->
assertThat(x.fileCount()).isEqualTo(fileCountLarger));
+ if (strategy == CoreOptions.PartitionSinkStrategy.PARTITION_DYNAMIC) {
+ fileCountLarger = Math.min(largerSinkParallelism, 4);
+ }
+ final int expectedFileCountLarger = fileCountLarger;
+ partitionEntriesLarger.forEach(
+ x ->
assertThat(x.fileCount()).isEqualTo(expectedFileCountLarger));
FileStoreTable fileStoreTableLess =
paimonTable("partition_strategy_table_less");
List<PartitionEntry> partitionEntriesLess =
@@ -562,7 +567,206 @@ public class AppendTableITCase extends CatalogITCaseBase {
strategy == CoreOptions.PartitionSinkStrategy.HASH
? hashStrategyResultFileCount
: lessSinkParallelism;
- partitionEntriesLess.forEach(x ->
assertThat(x.fileCount()).isEqualTo(fileCountLess));
+ if (strategy == CoreOptions.PartitionSinkStrategy.PARTITION_DYNAMIC) {
+ fileCountLess = Math.min(lessSinkParallelism, 4);
+ }
+ final int expectedFileCountLess = fileCountLess;
+ partitionEntriesLess.forEach(
+ x ->
assertThat(x.fileCount()).isEqualTo(expectedFileCountLess));
+ }
+
+ /**
+ * Inserts five rows spread over three partitions into a table that uses the partition_dynamic
+ * sink strategy with sink parallelism 4, then checks that exactly those rows are read back.
+ */
+ @Test
+ public void testPartitionDynamicDataCorrectness() {
+ batchSql(
+ "CREATE TABLE IF NOT EXISTS dynamic_correctness ("
+ + "id INT, data STRING, dt STRING) PARTITIONED BY (dt)"
+ + " WITH ("
+ + "'bucket' = '-1',"
+ + "'partition.sink-strategy' = 'partition_dynamic',"
+ + "'sink.parallelism' = '4')");
+
+ batchSql(
+ "INSERT INTO dynamic_correctness VALUES "
+ + "(1, 'a', '20250301'), (2, 'b', '20250301'), "
+ + "(3, 'c', '20250302'), (4, 'd', '20250302'), "
+ + "(5, 'e', '20250303')");
+
+ List<Row> result = batchSql("SELECT * FROM dynamic_correctness ORDER
BY id");
+ assertThat(result).hasSize(5);
+ // The comparison below ignores row order even though the query sorts by id.
+ assertThat(result)
+ .containsExactlyInAnyOrder(
+ Row.of(1, "a", "20250301"),
+ Row.of(2, "b", "20250301"),
+ Row.of(3, "c", "20250302"),
+ Row.of(4, "d", "20250302"),
+ Row.of(5, "e", "20250303"));
+ }
+
+ /**
+ * Writes 115 rows in one batch insert with a 100 / 10 / 5 split across three partitions and
+ * checks the total and per-partition row counts. Only row counts are asserted; file layout is
+ * not inspected.
+ */
+ @Test
+ public void testPartitionDynamicWithSkewedData() {
+ batchSql(
+ "CREATE TABLE IF NOT EXISTS dynamic_skewed ("
+ + "id INT, data STRING, dt STRING) PARTITIONED BY (dt)"
+ + " WITH ("
+ + "'bucket' = '-1',"
+ + "'partition.sink-strategy' = 'partition_dynamic',"
+ + "'sink.parallelism' = '4')");
+
+ // Heavily skewed: partition '20250301' gets most data
+ StringBuilder values = new StringBuilder();
+ for (int i = 1; i <= 100; i++) {
+ values.append(String.format("(%d, 'data%d', '20250301'),", i, i));
+ }
+ for (int i = 101; i <= 110; i++) {
+ values.append(String.format("(%d, 'data%d', '20250302'),", i, i));
+ }
+ for (int i = 111; i <= 115; i++) {
+ values.append(String.format("(%d, 'data%d', '20250303'),", i, i));
+ }
+
+ // substring drops the trailing comma left by the last appended tuple.
+ batchSql("INSERT INTO dynamic_skewed VALUES " + values.substring(0,
values.length() - 1));
+
+ assertThat(batchSql("SELECT * FROM dynamic_skewed")).hasSize(115);
+ assertThat(batchSql("SELECT * FROM dynamic_skewed WHERE dt =
'20250301'")).hasSize(100);
+ assertThat(batchSql("SELECT * FROM dynamic_skewed WHERE dt =
'20250302'")).hasSize(10);
+ assertThat(batchSql("SELECT * FROM dynamic_skewed WHERE dt =
'20250303'")).hasSize(5);
+ }
+
+    /**
+     * Writes 20 partitions of 5 rows each through a sink with parallelism 4, i.e. more partitions
+     * than sink subtasks, and checks that every row and every partition is read back.
+     */
+    @Test
+    public void testPartitionDynamicWithManyPartitions() {
+        int partitionCount = 20;
+        int sinkParallelism = 4;
+        batchSql(
+                "CREATE TABLE IF NOT EXISTS dynamic_many_partitions ("
+                        + "id INT, data STRING, dt STRING) PARTITIONED BY (dt)"
+                        + " WITH ("
+                        + "'bucket' = '-1',"
+                        + "'partition.sink-strategy' = 'partition_dynamic',"
+                        + String.format("'sink.parallelism' = '%d')", sinkParallelism));
+
+        StringBuilder values = new StringBuilder();
+        int totalRows = 0;
+        for (int p = 1; p <= partitionCount; p++) {
+            for (int i = 1; i <= 5; i++) {
+                // Partition values are 20250301..20250320: eight-character yyyyMMdd strings, one
+                // per partition index.
+                values.append(String.format("(%d, 'data', '202503%02d'),", (p - 1) * 5 + i, p));
+                totalRows++;
+            }
+        }
+
+        // substring drops the trailing comma left by the last appended tuple.
+        batchSql(
+                "INSERT INTO dynamic_many_partitions VALUES "
+                        + values.substring(0, values.length() - 1));
+
+        List<Row> result = batchSql("SELECT * FROM dynamic_many_partitions");
+        assertThat(result).hasSize(totalRows);
+
+        // Verify all partitions are present
+        List<Row> partitions =
+                batchSql("SELECT DISTINCT dt FROM dynamic_many_partitions ORDER BY dt");
+        assertThat(partitions).hasSize(partitionCount);
+    }
+
+ /**
+ * Runs two separate batch inserts that both touch the same partitions and checks that the rows
+ * of both inserts are visible, in total and per partition.
+ */
+ @Test
+ public void testPartitionDynamicMultipleInserts() {
+ batchSql(
+ "CREATE TABLE IF NOT EXISTS dynamic_multi_insert ("
+ + "id INT, data STRING, dt STRING) PARTITIONED BY (dt)"
+ + " WITH ("
+ + "'bucket' = '-1',"
+ + "'partition.sink-strategy' = 'partition_dynamic',"
+ + "'sink.parallelism' = '4')");
+
+ batchSql(
+ "INSERT INTO dynamic_multi_insert VALUES "
+ + "(1, 'a', '20250301'), (2, 'b', '20250301'), (3,
'c', '20250302')");
+ batchSql(
+ "INSERT INTO dynamic_multi_insert VALUES "
+ + "(4, 'd', '20250302'), (5, 'e', '20250303'), (6,
'f', '20250301')");
+
+ List<Row> result = batchSql("SELECT * FROM dynamic_multi_insert ORDER
BY id");
+ assertThat(result).hasSize(6);
+ // Expected per-partition counts: 20250301 -> ids 1, 2, 6; 20250302 -> ids 3, 4;
+ // 20250303 -> id 5.
+ assertThat(batchSql("SELECT * FROM dynamic_multi_insert WHERE dt =
'20250301'")).hasSize(3);
+ assertThat(batchSql("SELECT * FROM dynamic_multi_insert WHERE dt =
'20250302'")).hasSize(2);
+ assertThat(batchSql("SELECT * FROM dynamic_multi_insert WHERE dt =
'20250303'")).hasSize(1);
+ }
+
+ /**
+ * Runs the skewed 100 / 10 / 5 insert through the streaming environment ({@code sEnv}) and
+ * checks row counts per partition plus the presence of at least one file per partition.
+ */
+ @Timeout(120)
+ @Test
+ public void testPartitionDynamicStreaming() throws Exception {
+ int sinkParallelism = 4;
+ batchSql(
+ "CREATE TABLE IF NOT EXISTS dynamic_streaming ("
+ + "id INT, data STRING, dt STRING) PARTITIONED BY (dt)"
+ + " WITH ("
+ + "'bucket' = '-1',"
+ + "'partition.sink-strategy' = 'partition_dynamic',"
+ + "'sink.parallelism' = '%d')",
+ sinkParallelism);
+
+ // Write heavily skewed data: partition '20250301' gets most records.
+ // With streaming mode (sEnv has checkpoint interval 100ms), the
+ // DataStatisticsOperator sends local stats at checkpoint -> coordinator
+ // aggregates -> sends global stats back -> partitioner updates assignment.
+ // NOTE(review): the assertions below do not observe whether that round trip happened
+ // before the bounded insert finished; they only check the written data.
+ StringBuilder values = new StringBuilder();
+ for (int i = 1; i <= 100; i++) {
+ values.append(String.format("(%d, 'data%d', '20250301'),", i, i));
+ }
+ for (int i = 101; i <= 110; i++) {
+ values.append(String.format("(%d, 'data%d', '20250302'),", i, i));
+ }
+ for (int i = 111; i <= 115; i++) {
+ values.append(String.format("(%d, 'data%d', '20250303'),", i, i));
+ }
+
+ sEnv.executeSql(
+ "INSERT INTO dynamic_streaming VALUES "
+ + values.substring(0, values.length() - 1))
+ .await();
+
+ // Verify data correctness
+ assertThat(batchSql("SELECT * FROM dynamic_streaming")).hasSize(115);
+ assertThat(batchSql("SELECT * FROM dynamic_streaming WHERE dt =
'20250301'")).hasSize(100);
+ assertThat(batchSql("SELECT * FROM dynamic_streaming WHERE dt =
'20250302'")).hasSize(10);
+ assertThat(batchSql("SELECT * FROM dynamic_streaming WHERE dt =
'20250303'")).hasSize(5);
+
+ // Sanity-check the file layout: three partitions, each with at least one file.
+ // NOTE(review): this does NOT verify that the hot partition ('20250301') was spread
+ // across several subtasks; the bound is >= 1, which holds for any partition that has
+ // data. Asserting a file count > 1 for the hot partition would need the statistics to
+ // be applied deterministically before the job ends - confirm before tightening.
+ FileStoreTable table = paimonTable("dynamic_streaming");
+ List<PartitionEntry> partitionEntries =
+ table.newReadBuilder().newScan().listPartitionEntries();
+ assertThat(partitionEntries).hasSize(3);
+
+ for (PartitionEntry entry : partitionEntries) {
+ assertThat(entry.fileCount()).isGreaterThanOrEqualTo(1);
+ }
+ }
+
+ /**
+ * Writes 50 rows that all belong to one partition through a sink with parallelism 4 and checks
+ * that all rows are read back and that only that partition exists.
+ */
+ @Test
+ public void testPartitionDynamicSinglePartition() {
+ batchSql(
+ "CREATE TABLE IF NOT EXISTS dynamic_single_partition ("
+ + "id INT, data STRING, dt STRING) PARTITIONED BY (dt)"
+ + " WITH ("
+ + "'bucket' = '-1',"
+ + "'partition.sink-strategy' = 'partition_dynamic',"
+ + "'sink.parallelism' = '4')");
+
+ StringBuilder values = new StringBuilder();
+ for (int i = 1; i <= 50; i++) {
+ values.append(String.format("(%d, 'data%d', '20250301'),", i, i));
+ }
+
+ // substring drops the trailing comma left by the last appended tuple.
+ batchSql(
+ "INSERT INTO dynamic_single_partition VALUES "
+ + values.substring(0, values.length() - 1));
+
+ List<Row> result = batchSql("SELECT * FROM dynamic_single_partition");
+ assertThat(result).hasSize(50);
+ assertThat(batchSql("SELECT DISTINCT dt FROM
dynamic_single_partition"))
+ .containsExactly(Row.of("20250301"));
}
private static class TestStatelessWriterSource extends
AbstractNonCoordinatedSource<Integer> {
diff --git
a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/sink/partition/AggregatedStatisticsTrackerTest.java
b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/sink/partition/AggregatedStatisticsTrackerTest.java
new file mode 100644
index 0000000000..ac26e7b885
--- /dev/null
+++
b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/sink/partition/AggregatedStatisticsTrackerTest.java
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.flink.sink.partition;
+
+import org.apache.paimon.data.BinaryRow;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Test for {@link AggregatedStatisticsTracker}.
+ *
+ * <p>Each test feeds {@link StatisticsEvent}s for given subtask indexes and checkpoint ids into
+ * {@code updateAndCheckCompletion} and checks whether it returns {@code null} or the aggregated
+ * per-partition statistics.
+ */
+class AggregatedStatisticsTrackerTest {
+
+ private static final DataStatisticsSerializer SERIALIZER = new
DataStatisticsSerializer();
+
+ /** With two subtasks, the result appears on the second report and sums values per partition. */
+ @Test
+ void testAggregationCompletesWhenAllSubtasksReport() {
+ AggregatedStatisticsTracker tracker = new
AggregatedStatisticsTracker("test-op", 2);
+
+ BinaryRow p1 = BinaryRow.singleColumn("p1");
+ BinaryRow p2 = BinaryRow.singleColumn("p2");
+ BinaryRow p3 = BinaryRow.singleColumn("p3");
+
+ DataStatistics stats0 = new DataStatistics();
+ stats0.add(p1, 100L);
+ stats0.add(p2, 50L);
+
+ DataStatistics stats1 = new DataStatistics();
+ stats1.add(p1, 200L);
+ stats1.add(p3, 75L);
+
+ StatisticsEvent event0 = StatisticsEvent.createStatisticsEvent(1L,
stats0, SERIALIZER);
+ StatisticsEvent event1 = StatisticsEvent.createStatisticsEvent(1L,
stats1, SERIALIZER);
+
+ // First subtask reports - not complete yet
+ DataStatistics result = tracker.updateAndCheckCompletion(0, event0);
+ assertThat(result).isNull();
+
+ // Second subtask reports - now complete
+ result = tracker.updateAndCheckCompletion(1, event1);
+ assertThat(result).isNotNull();
+ // p1 was reported by both subtasks (100 + 200); p2 and p3 by one subtask each.
+ assertThat(result.result().get(p1)).isEqualTo(300L);
+ assertThat(result.result().get(p2)).isEqualTo(50L);
+ assertThat(result.result().get(p3)).isEqualTo(75L);
+ }
+
+ /** A second report from the same subtask for the same checkpoint does not complete it. */
+ @Test
+ void testIgnoresDuplicateSubtaskReport() {
+ AggregatedStatisticsTracker tracker = new
AggregatedStatisticsTracker("test-op", 2);
+
+ BinaryRow p1 = BinaryRow.singleColumn("p1");
+ Map<BinaryRow, Long> map = new HashMap<>();
+ map.put(p1, 100L);
+ DataStatistics stats = new DataStatistics(map);
+ StatisticsEvent event = StatisticsEvent.createStatisticsEvent(1L,
stats, SERIALIZER);
+
+ // First report from subtask 0
+ DataStatistics result = tracker.updateAndCheckCompletion(0, event);
+ assertThat(result).isNull();
+
+ // Duplicate report from subtask 0 - should be ignored, still not
complete
+ result = tracker.updateAndCheckCompletion(0, event);
+ assertThat(result).isNull();
+ }
+
+ /** An event for checkpoint 1 that arrives after checkpoint 2 completed yields no result. */
+ @Test
+ void testIgnoresStaleCheckpoint() {
+ AggregatedStatisticsTracker tracker = new
AggregatedStatisticsTracker("test-op", 1);
+
+ BinaryRow p1 = BinaryRow.singleColumn("p1");
+ Map<BinaryRow, Long> map = new HashMap<>();
+ map.put(p1, 100L);
+ DataStatistics stats = new DataStatistics(map);
+
+ // Complete checkpoint 2
+ StatisticsEvent event2 = StatisticsEvent.createStatisticsEvent(2L,
stats, SERIALIZER);
+ DataStatistics result = tracker.updateAndCheckCompletion(0, event2);
+ assertThat(result).isNotNull();
+
+ // Now send event for older checkpoint 1 - should be ignored
+ StatisticsEvent event1 = StatisticsEvent.createStatisticsEvent(1L,
stats, SERIALIZER);
+ result = tracker.updateAndCheckCompletion(0, event1);
+ assertThat(result).isNull();
+ }
+
+ /**
+ * Reports for checkpoints 1 and 2 are interleaved; checkpoint 1 completes with only its own
+ * values (10 + 15), without the value 20 reported for checkpoint 2.
+ */
+ @Test
+ void testMultipleCheckpoints() {
+ AggregatedStatisticsTracker tracker = new
AggregatedStatisticsTracker("test-op", 2);
+
+ BinaryRow p1 = BinaryRow.singleColumn("p1");
+
+ // Checkpoint 1: subtask 0 reports
+ Map<BinaryRow, Long> map1 = new HashMap<>();
+ map1.put(p1, 10L);
+ DataStatistics statsChk1Sub0 = new DataStatistics(map1);
+ StatisticsEvent eventChk1Sub0 =
+ StatisticsEvent.createStatisticsEvent(1L, statsChk1Sub0,
SERIALIZER);
+ assertThat(tracker.updateAndCheckCompletion(0,
eventChk1Sub0)).isNull();
+
+ // Checkpoint 2: subtask 0 reports (before checkpoint 1 completes)
+ Map<BinaryRow, Long> map2 = new HashMap<>();
+ map2.put(p1, 20L);
+ DataStatistics statsChk2Sub0 = new DataStatistics(map2);
+ StatisticsEvent eventChk2Sub0 =
+ StatisticsEvent.createStatisticsEvent(2L, statsChk2Sub0,
SERIALIZER);
+ assertThat(tracker.updateAndCheckCompletion(0,
eventChk2Sub0)).isNull();
+
+ // Checkpoint 1: subtask 1 reports - completes checkpoint 1
+ Map<BinaryRow, Long> map3 = new HashMap<>();
+ map3.put(p1, 15L);
+ DataStatistics statsChk1Sub1 = new DataStatistics(map3);
+ StatisticsEvent eventChk1Sub1 =
+ StatisticsEvent.createStatisticsEvent(1L, statsChk1Sub1,
SERIALIZER);
+ DataStatistics result = tracker.updateAndCheckCompletion(1,
eventChk1Sub1);
+ assertThat(result).isNotNull();
+ assertThat(result.result().get(p1)).isEqualTo(25L);
+ }
+
+ /**
+ * With a single subtask, an empty report still completes the checkpoint: the tracker returns
+ * a non-null, empty result rather than {@code null}.
+ */
+ @Test
+ void testEmptyStatisticsSkipped() {
+ AggregatedStatisticsTracker tracker = new
AggregatedStatisticsTracker("test-op", 1);
+
+ DataStatistics emptyStats = new DataStatistics();
+ StatisticsEvent event = StatisticsEvent.createStatisticsEvent(1L,
emptyStats, SERIALIZER);
+ DataStatistics result = tracker.updateAndCheckCompletion(0, event);
+ // Returns the completed (empty) statistics - caller decides to skip
+ assertThat(result).isNotNull();
+ assertThat(result.isEmpty()).isTrue();
+ }
+
+ /** With three subtasks, the result only appears once the third subtask has reported. */
+ @Test
+ void testThreeSubtasks() {
+ AggregatedStatisticsTracker tracker = new
AggregatedStatisticsTracker("test-op", 3);
+
+ BinaryRow p1 = BinaryRow.singleColumn("p1");
+ BinaryRow p2 = BinaryRow.singleColumn("p2");
+
+ Map<BinaryRow, Long> freq0 = new HashMap<>();
+ freq0.put(p1, 100L);
+ Map<BinaryRow, Long> freq1 = new HashMap<>();
+ freq1.put(p1, 200L);
+ freq1.put(p2, 50L);
+ Map<BinaryRow, Long> freq2 = new HashMap<>();
+ freq2.put(p2, 150L);
+
+ assertThat(
+ tracker.updateAndCheckCompletion(
+ 0,
+ StatisticsEvent.createStatisticsEvent(
+ 1L, new DataStatistics(freq0),
SERIALIZER)))
+ .isNull();
+ assertThat(
+ tracker.updateAndCheckCompletion(
+ 1,
+ StatisticsEvent.createStatisticsEvent(
+ 1L, new DataStatistics(freq1),
SERIALIZER)))
+ .isNull();
+
+ DataStatistics result =
+ tracker.updateAndCheckCompletion(
+ 2,
+ StatisticsEvent.createStatisticsEvent(
+ 1L, new DataStatistics(freq2), SERIALIZER));
+ assertThat(result).isNotNull();
+ // p1 = 100 + 200, p2 = 50 + 150.
+ assertThat(result.result().get(p1)).isEqualTo(300L);
+ assertThat(result.result().get(p2)).isEqualTo(200L);
+ }
+}
diff --git
a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/sink/partition/DataStatisticsOperatorTest.java
b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/sink/partition/DataStatisticsOperatorTest.java
new file mode 100644
index 0000000000..fc1ff0f7e9
--- /dev/null
+++
b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/sink/partition/DataStatisticsOperatorTest.java
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.flink.sink.partition;
+
+import org.apache.paimon.data.BinaryRow;
+import org.apache.paimon.data.BinaryString;
+import org.apache.paimon.data.GenericRow;
+import org.apache.paimon.data.InternalRow;
+import org.apache.paimon.schema.TableSchema;
+import org.apache.paimon.types.DataTypes;
+import org.apache.paimon.types.RowType;
+
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Test for {@link DataStatisticsOperator}.
+ *
+ * <p>Runs the operator in a {@link OneInputStreamOperatorTestHarness} with parallelism 1 and
+ * inspects the {@link StatisticsOrRecord} values it emits for input rows and for statistics
+ * events passed to {@code handleOperatorEvent}.
+ */
+class DataStatisticsOperatorTest {
+
+ private static final RowType ROW_TYPE =
+ RowType.builder()
+ .field("id", DataTypes.INT())
+ .field("pt", DataTypes.STRING())
+ .field("data", DataTypes.STRING())
+ .build();
+
+ // NOTE(review): the singleton list "pt" is presumably the partition-key argument and the empty
+ // list the primary keys - confirm against the TableSchema constructor.
+ private static final TableSchema SCHEMA =
+ new TableSchema(
+ 0L,
+ ROW_TYPE.getFields(),
+ ROW_TYPE.getFieldCount(),
+ Collections.singletonList("pt"),
+ Collections.emptyList(),
+ new HashMap<>(),
+ null);
+
+ /**
+ * Collects the values of all {@link StreamRecord}s in the harness output, in output order;
+ * output elements that are not stream records are skipped.
+ */
+ @SuppressWarnings("unchecked")
+ private List<StatisticsOrRecord> extractRecordOutput(
+ OneInputStreamOperatorTestHarness<InternalRow, StatisticsOrRecord>
testHarness) {
+ List<StatisticsOrRecord> result = new ArrayList<>();
+ for (Object o : testHarness.getOutput()) {
+ if (o instanceof StreamRecord) {
+ result.add(((StreamRecord<StatisticsOrRecord>) o).getValue());
+ }
+ }
+ return result;
+ }
+
+ /** Three input rows produce three outputs, each flagged as a record. */
+ @Test
+ void testProcessElement() throws Exception {
+ DataStatisticsOperatorFactory factory = new
DataStatisticsOperatorFactory(SCHEMA);
+
+ try (OneInputStreamOperatorTestHarness<InternalRow,
StatisticsOrRecord> testHarness =
+ new OneInputStreamOperatorTestHarness<>(factory, 1, 1, 0)) {
+ testHarness.open();
+
+ InternalRow row1 =
+ GenericRow.of(1, BinaryString.fromString("pt1"),
BinaryString.fromString("a"));
+ InternalRow row2 =
+ GenericRow.of(2, BinaryString.fromString("pt1"),
BinaryString.fromString("b"));
+ InternalRow row3 =
+ GenericRow.of(3, BinaryString.fromString("pt2"),
BinaryString.fromString("c"));
+
+ testHarness.processElement(new StreamRecord<>(row1));
+ testHarness.processElement(new StreamRecord<>(row2));
+ testHarness.processElement(new StreamRecord<>(row3));
+
+ List<StatisticsOrRecord> output = extractRecordOutput(testHarness);
+ assertThat(output).hasSize(3);
+ assertThat(output.get(0).isRecord()).isTrue();
+ assertThat(output.get(1).isRecord()).isTrue();
+ assertThat(output.get(2).isRecord()).isTrue();
+ }
+ }
+
+ /** A statistics event is emitted as one statistics element carrying the same map. */
+ @Test
+ void testHandleOperatorEvent() throws Exception {
+ DataStatisticsOperatorFactory factory = new
DataStatisticsOperatorFactory(SCHEMA);
+
+ try (OneInputStreamOperatorTestHarness<InternalRow,
StatisticsOrRecord> testHarness =
+ new OneInputStreamOperatorTestHarness<>(factory, 1, 1, 0)) {
+ testHarness.open();
+
+ // Simulate receiving global statistics from coordinator
+ Map<BinaryRow, Long> globalStats = new HashMap<>();
+ globalStats.put(BinaryRow.singleColumn("pt1"), 1000L);
+ globalStats.put(BinaryRow.singleColumn("pt2"), 2000L);
+ DataStatistics globalStatistics = new DataStatistics(globalStats);
+ StatisticsEvent event =
+ StatisticsEvent.createStatisticsEvent(
+ 1L, globalStatistics, new
DataStatisticsSerializer());
+
+ // The event is handed to the operator directly, without a coordinator.
+ ((DataStatisticsOperator)
testHarness.getOperator()).handleOperatorEvent(event);
+
+ List<StatisticsOrRecord> output = extractRecordOutput(testHarness);
+ assertThat(output).hasSize(1);
+ assertThat(output.get(0).isStatistics()).isTrue();
+
assertThat(output.get(0).statistics().result()).isEqualTo(globalStats);
+ }
+ }
+
+ /**
+ * Records and a statistics event are emitted in arrival order: the statistics element appears
+ * between the rows processed before it and the row processed after it.
+ */
+ @Test
+ void testProcessAndHandleEvent() throws Exception {
+ DataStatisticsOperatorFactory factory = new
DataStatisticsOperatorFactory(SCHEMA);
+
+ try (OneInputStreamOperatorTestHarness<InternalRow,
StatisticsOrRecord> testHarness =
+ new OneInputStreamOperatorTestHarness<>(factory, 1, 1, 0)) {
+ testHarness.open();
+
+ // Process some records
+ InternalRow row1 =
+ GenericRow.of(1, BinaryString.fromString("pt1"),
BinaryString.fromString("a"));
+ InternalRow row2 =
+ GenericRow.of(2, BinaryString.fromString("pt2"),
BinaryString.fromString("b"));
+ testHarness.processElement(new StreamRecord<>(row1));
+ testHarness.processElement(new StreamRecord<>(row2));
+
+ // Receive global statistics
+ Map<BinaryRow, Long> globalStats = new HashMap<>();
+ globalStats.put(BinaryRow.singleColumn("pt1"), 500L);
+ StatisticsEvent event =
+ StatisticsEvent.createStatisticsEvent(
+ 1L, new DataStatistics(globalStats), new
DataStatisticsSerializer());
+ ((DataStatisticsOperator)
testHarness.getOperator()).handleOperatorEvent(event);
+
+ // Process more records
+ InternalRow row3 =
+ GenericRow.of(3, BinaryString.fromString("pt1"),
BinaryString.fromString("c"));
+ testHarness.processElement(new StreamRecord<>(row3));
+
+ List<StatisticsOrRecord> output = extractRecordOutput(testHarness);
+ // 2 records + 1 statistics + 1 record = 4
+ assertThat(output).hasSize(4);
+ assertThat(output.get(0).isRecord()).isTrue();
+ assertThat(output.get(1).isRecord()).isTrue();
+ assertThat(output.get(2).isStatistics()).isTrue();
+ assertThat(output.get(3).isRecord()).isTrue();
+ }
+ }
+}
diff --git
a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/sink/partition/DataStatisticsSerializerTest.java
b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/sink/partition/DataStatisticsSerializerTest.java
new file mode 100644
index 0000000000..53d18b53c7
--- /dev/null
+++
b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/sink/partition/DataStatisticsSerializerTest.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.flink.sink.partition;
+
+import org.apache.paimon.data.BinaryRow;
+
+import org.junit.jupiter.api.Test;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Tests for {@link DataStatisticsSerializer}. */
+class DataStatisticsSerializerTest {
+
+    /** Round-trips statistics through {@link StatisticsUtil}: first empty, then populated. */
+    @Test
+    void testSerializeAndDeserialize() {
+        BinaryRow firstPartition = BinaryRow.singleColumn("p1");
+        BinaryRow secondPartition = BinaryRow.singleColumn("p2");
+        DataStatisticsSerializer statsSerializer = new DataStatisticsSerializer();
+        DataStatistics stats = StatisticsUtil.createDataStatistics();
+
+        // Freshly created statistics must still be empty after a serialize/deserialize cycle.
+        assertThat(
+                        StatisticsUtil.deserializeDataStatistics(
+                                        StatisticsUtil.serializeDataStatistics(
+                                                stats, statsSerializer),
+                                        statsSerializer)
+                                .isEmpty())
+                .isTrue();
+
+        // "p1" is added twice on purpose, so the round trip also covers a repeated key.
+        stats.add(firstPartition, 1);
+        stats.add(secondPartition, 2);
+        stats.add(firstPartition, 3);
+        assertThat(
+                        StatisticsUtil.deserializeDataStatistics(
+                                        StatisticsUtil.serializeDataStatistics(
+                                                stats, statsSerializer),
+                                        statsSerializer)
+                                .result())
+                .isEqualTo(stats.result());
+    }
+
+    /** A copy must equal its source and must not share mutable state with it. */
+    @Test
+    void testCopy() {
+        BinaryRow firstPartition = BinaryRow.singleColumn("p1");
+        BinaryRow secondPartition = BinaryRow.singleColumn("p2");
+        BinaryRow extraPartition = BinaryRow.singleColumn("p3");
+
+        DataStatistics source = new DataStatistics();
+        source.add(firstPartition, 100L);
+        source.add(secondPartition, 200L);
+
+        DataStatistics duplicate = new DataStatisticsSerializer().copy(source);
+        assertThat(duplicate).isEqualTo(source);
+
+        // Adding to the copy must leave the source untouched.
+        duplicate.add(extraPartition, 300L);
+        assertThat(source.result()).doesNotContainKey(extraPartition);
+    }
+}
diff --git
a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/sink/partition/MockRandom.java
b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/sink/partition/MockRandom.java
new file mode 100644
index 0000000000..3f53e8bc03
--- /dev/null
+++
b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/sink/partition/MockRandom.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.flink.sink.partition;
+
+import java.util.Random;
+
+/**
+ * A {@link Random} whose bits come from a small linear congruential generator instead of the JDK
+ * algorithm. Starting from the same state (17 by default) always yields the same sequence of
+ * pseudorandom numbers, which makes randomized logic reproducible in tests.
+ */
+public class MockRandom extends Random {
+
+    private static final long MULTIPLIER = 2862933555777941757L;
+    private static final long INCREMENT = 3037000493L;
+
+    private long current;
+
+    public MockRandom() {
+        this(17);
+    }
+
+    public MockRandom(long state) {
+        current = state;
+    }
+
+    @Override
+    protected int next(int bits) {
+        // Advance the generator, then expose the top {@code bits} bits of the new 64-bit state.
+        long advanced = (current * MULTIPLIER) + INCREMENT;
+        current = advanced;
+        return (int) (advanced >>> (64 - bits));
+    }
+}
diff --git
a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/sink/partition/StatisticsOrRecordChannelComputerTest.java
b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/sink/partition/StatisticsOrRecordChannelComputerTest.java
new file mode 100644
index 0000000000..deedae4c29
--- /dev/null
+++
b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/sink/partition/StatisticsOrRecordChannelComputerTest.java
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.flink.sink.partition;
+
+import org.apache.paimon.data.BinaryRow;
+import org.apache.paimon.data.BinaryString;
+import org.apache.paimon.data.GenericRow;
+import org.apache.paimon.data.InternalRow;
+import org.apache.paimon.schema.TableSchema;
+import org.apache.paimon.table.sink.RowPartitionKeyExtractor;
+import org.apache.paimon.types.DataTypes;
+import org.apache.paimon.types.RowType;
+
+import org.assertj.core.data.Percentage;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Tests for {@link StatisticsOrRecordChannelComputer}. */
+class StatisticsOrRecordChannelComputerTest {
+
+    private static TableSchema schema;
+
+    @BeforeAll
+    static void init() {
+        // Columns: id INT, pt STRING, data STRING; "pt" is the only entry of the first key list.
+        RowType rowType =
+                RowType.builder()
+                        .field("id", DataTypes.INT())
+                        .field("pt", DataTypes.STRING())
+                        .field("data", DataTypes.STRING())
+                        .build();
+        schema =
+                new TableSchema(
+                        0L,
+                        rowType.getFields(),
+                        rowType.getFieldCount(),
+                        Collections.singletonList("pt"),
+                        Collections.emptyList(),
+                        new HashMap<>(),
+                        null);
+    }
+
+    @Test
+    void testShuffleWithoutStatistics() {
+        int downstreamParallelism = 8;
+        StatisticsOrRecordChannelComputer computer = newChannelComputer(downstreamParallelism);
+
+        Map<Integer, Double> shares = channelShares(computer, "pt1", 50000);
+
+        // Without statistics, fallback assigns min(numChannels, 4) subtasks
+        int targetParallelism = Math.min(downstreamParallelism, 4);
+        assertThat(shares.size()).isEqualTo(targetParallelism);
+        for (Double share : shares.values()) {
+            assertThat(share).isCloseTo(1.0 / targetParallelism, Percentage.withPercentage(5));
+        }
+    }
+
+    @Test
+    void testShuffleWithSinglePartitionStatistics() {
+        int downstreamParallelism = 8;
+        StatisticsOrRecordChannelComputer computer = newChannelComputer(downstreamParallelism);
+
+        // Feed statistics: single partition gets all weight -> spread across all subtasks
+        Map<BinaryRow, Long> weights = new HashMap<>();
+        weights.put(getPartitionKey(rowOf(0, "pt1")), 10000L);
+        computer.channel(StatisticsOrRecord.fromStatistics(new DataStatistics(weights)));
+
+        Map<Integer, Double> shares = channelShares(computer, "pt1", 50000);
+
+        assertThat(shares.size()).isEqualTo(downstreamParallelism);
+        for (Double share : shares.values()) {
+            assertThat(share)
+                    .isCloseTo(1.0 / downstreamParallelism, Percentage.withPercentage(5));
+        }
+    }
+
+    @Test
+    void testShuffleWithMultiplePartitionStatistics() {
+        int downstreamParallelism = 8;
+        StatisticsOrRecordChannelComputer computer = newChannelComputer(downstreamParallelism);
+
+        BinaryRow lightPartition = getPartitionKey(rowOf(0, "pt1"));
+        BinaryRow heavyPartition = getPartitionKey(rowOf(0, "pt2"));
+
+        // partition 1 has 1/4 of the weight, partition 2 has 3/4
+        Map<BinaryRow, Long> weights = new HashMap<>();
+        weights.put(lightPartition, 10000L);
+        weights.put(heavyPartition, 30000L);
+        computer.channel(StatisticsOrRecord.fromStatistics(new DataStatistics(weights)));
+
+        // Only records of partition 1 are routed below.
+        Map<Integer, Double> shares = channelShares(computer, "pt1", 50000);
+
+        // partition 1 is 1/4 of total, so it should be assigned to ~2 subtasks (8/4)
+        assertThat(shares.size()).isEqualTo(downstreamParallelism / 4);
+    }
+
+    @Test
+    void testMultipleUnknownPartitionsWithoutStatistics() {
+        StatisticsOrRecordChannelComputer computer = newChannelComputer(8);
+
+        // Interleave records of several partitions for which no statistics exist. This targets
+        // the fallback cache in MapPartitioner.select(): every partition key has to be cached as
+        // a copy, because the BinaryRow handed out by RowPartitionKeyExtractor is reused and
+        // would otherwise corrupt earlier cache entries.
+        String[] partitions = {"pt_a", "pt_b", "pt_c"};
+        int recordsPerPartition = 10000;
+        Map<String, Map<Integer, Integer>> hitsByPartition = new HashMap<>();
+        for (int id = 0; id < recordsPerPartition; id++) {
+            for (String partition : partitions) {
+                int channel =
+                        computer.channel(StatisticsOrRecord.fromRecord(rowOf(id, partition)));
+                hitsByPartition
+                        .computeIfAbsent(partition, ignored -> new HashMap<>())
+                        .merge(channel, 1, Integer::sum);
+            }
+        }
+
+        // Every partition must have been routed somewhere, and no record may be lost or
+        // attributed to another partition.
+        for (String partition : partitions) {
+            Map<Integer, Integer> hits = hitsByPartition.get(partition);
+            assertThat(hits).isNotEmpty();
+            int routed = 0;
+            for (int count : hits.values()) {
+                routed += count;
+            }
+            assertThat(routed).isEqualTo(recordsPerPartition);
+        }
+
+        // If the mutable key were shared, all cache entries would point at the last partition's
+        // hash and the three partitions would end up with exactly the same channel set.
+        Map<Integer, Integer> countsA = hitsByPartition.get("pt_a");
+        Map<Integer, Integer> countsB = hitsByPartition.get("pt_b");
+        Map<Integer, Integer> countsC = hitsByPartition.get("pt_c");
+        boolean allSame = countsA.keySet().equals(countsB.keySet());
+        if (allSame) {
+            allSame = countsB.keySet().equals(countsC.keySet());
+        }
+        assertThat(allSame)
+                .as(
+                        "Different partitions should generally get different channel assignments, "
+                                + "but all three got the same channels: %s",
+                        countsA.keySet())
+                .isFalse();
+    }
+
+    @Test
+    void testBuildAssignment() {
+        StatisticsOrRecordChannelComputer computer = newChannelComputer(4);
+
+        BinaryRow p1 = getPartitionKey(rowOf(0, "p1"));
+        BinaryRow p2 = getPartitionKey(rowOf(0, "p2"));
+
+        Map<BinaryRow, Long> statistics = new HashMap<>();
+        statistics.put(p1, 100L);
+        statistics.put(p2, 300L);
+
+        Map<BinaryRow, WeightedRandomAssignment> assignment =
+                computer.buildAssignment(4, statistics);
+
+        // Both partitions present in the statistics must receive an assignment.
+        assertThat(assignment).containsKey(p1);
+        assertThat(assignment).containsKey(p2);
+    }
+
+    /** Creates a channel computer for the test schema, set up with the given channel count. */
+    private static StatisticsOrRecordChannelComputer newChannelComputer(int numChannels) {
+        StatisticsOrRecordChannelComputer computer =
+                new StatisticsOrRecordChannelComputer(schema);
+        computer.setup(numChannels);
+        return computer;
+    }
+
+    /** Builds a row (id, partition, "d") matching the test schema. */
+    private static InternalRow rowOf(int id, String partition) {
+        return GenericRow.of(id, BinaryString.fromString(partition), BinaryString.fromString("d"));
+    }
+
+    /**
+     * Routes {@code totalRowNum} records of one partition (ids 0..totalRowNum-1) through the
+     * computer and returns, per selected channel, the fraction of records it received.
+     */
+    private static Map<Integer, Double> channelShares(
+            StatisticsOrRecordChannelComputer computer, String partition, int totalRowNum) {
+        Map<Integer, Double> shares = new HashMap<>();
+        double perRecord = 1.0 / totalRowNum;
+        for (int id = 0; id < totalRowNum; id++) {
+            int channel = computer.channel(StatisticsOrRecord.fromRecord(rowOf(id, partition)));
+            shares.merge(channel, perRecord, Double::sum);
+        }
+        return shares;
+    }
+
+    /** Extracts the partition key of a row; the result is copied so it can be used as a map key. */
+    private BinaryRow getPartitionKey(InternalRow row) {
+        RowPartitionKeyExtractor keyExtractor = new RowPartitionKeyExtractor(schema);
+        return keyExtractor.partition(row).copy();
+    }
+}
diff --git
a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/sink/partition/WeightedRandomAssignmentTest.java
b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/sink/partition/WeightedRandomAssignmentTest.java
new file mode 100644
index 0000000000..94478a38d5
--- /dev/null
+++
b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/sink/partition/WeightedRandomAssignmentTest.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.flink.sink.partition;
+
+import org.assertj.core.data.Percentage;
+import org.junit.jupiter.api.Test;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Tests for {@link WeightedRandomAssignment}. */
+class WeightedRandomAssignmentTest {
+
+    /** Selection frequencies must follow the configured weights (1 : 3 : 2). */
+    @Test
+    void testWeightedRandomAssignment() {
+        List<Integer> subtasks = Arrays.asList(0, 1, 2);
+        List<Long> weights = Arrays.asList(1L, 3L, 2L);
+        // MockRandom makes the sampled sequence, and therefore the observed shares, reproducible.
+        WeightedRandomAssignment assignment =
+                new WeightedRandomAssignment(subtasks, weights, new MockRandom());
+
+        int totalRowNum = 200000;
+        double perSelection = 1.0 / totalRowNum;
+        Map<Integer, Double> observedShares = new HashMap<>();
+        for (int round = 0; round < totalRowNum; round++) {
+            observedShares.merge(assignment.select(), perSelection, Double::sum);
+        }
+
+        // Total weight is 6, so the expected shares are 1/6, 3/6 and 2/6.
+        assertThat(observedShares.get(0)).isCloseTo(1.0 / 6, Percentage.withPercentage(1));
+        assertThat(observedShares.get(1)).isCloseTo(0.5, Percentage.withPercentage(1));
+        assertThat(observedShares.get(2)).isCloseTo(2.0 / 6, Percentage.withPercentage(1));
+    }
+
+    /** With a single candidate, every selection must return that subtask. */
+    @Test
+    void testSingleSubtask() {
+        List<Integer> subtasks = Arrays.asList(3);
+        List<Long> weights = Arrays.asList(100L);
+        WeightedRandomAssignment assignment =
+                new WeightedRandomAssignment(subtasks, weights, new MockRandom());
+
+        int remaining = 100;
+        while (remaining-- > 0) {
+            assertThat(assignment.select()).isEqualTo(3);
+        }
+    }
+}