This is an automated email from the ASF dual-hosted git repository.

stevenwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/master by this push:
     new e8a28affb8 Flink: Implement data statistics coordinator to aggregate data statistics from operator subtasks (#7360)
e8a28affb8 is described below

commit e8a28affb8bc8eb023352da2556be9c1b5e50141
Author: gangy <[email protected]>
AuthorDate: Sun Sep 24 20:57:16 2023 -0700

    Flink: Implement data statistics coordinator to aggregate data statistics from operator subtasks (#7360)
---
 ...tisticsEvent.java => AggregatedStatistics.java} |  34 +-
 .../sink/shuffle/AggregatedStatisticsTracker.java  | 133 +++++++
 .../sink/shuffle/DataStatisticsCoordinator.java    | 395 +++++++++++++++++++++
 .../shuffle/DataStatisticsCoordinatorProvider.java |  51 +++
 .../flink/sink/shuffle/DataStatisticsEvent.java    |  32 +-
 .../flink/sink/shuffle/DataStatisticsOperator.java |  55 ++-
 .../flink/sink/shuffle/DataStatisticsOrRecord.java |   3 +-
 .../flink/sink/shuffle/DataStatisticsUtil.java     |  97 +++++
 .../sink/shuffle/TestAggregatedStatistics.java     |  61 ++++
 .../shuffle/TestAggregatedStatisticsTracker.java   | 177 +++++++++
 .../shuffle/TestDataStatisticsCoordinator.java     | 174 +++++++++
 .../TestDataStatisticsCoordinatorProvider.java     | 147 ++++++++
 .../sink/shuffle/TestDataStatisticsOperator.java   |  84 ++---
 13 files changed, 1357 insertions(+), 86 deletions(-)

diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsEvent.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatistics.java
similarity index 53%
copy from flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsEvent.java
copy to flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatistics.java
index 3aba66fd42..157f04b8b0 100644
--- a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsEvent.java
+++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatistics.java
@@ -18,24 +18,28 @@
  */
 package org.apache.iceberg.flink.sink.shuffle;
 
-import org.apache.flink.annotation.Internal;
-import org.apache.flink.runtime.operators.coordination.OperatorEvent;
+import java.io.Serializable;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.iceberg.relocated.com.google.common.base.MoreObjects;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 
 /**
- * DataStatisticsEvent is sent between data statistics coordinator and 
operator to transmit data
- * statistics
+ * AggregatedStatistics is used by {@link DataStatisticsCoordinator} to 
collect {@link
+ * DataStatistics} from {@link DataStatisticsOperator} subtasks for specific 
checkpoint. It stores
+ * the merged {@link DataStatistics} result from all reported subtasks.
  */
-@Internal
-class DataStatisticsEvent<D extends DataStatistics<D, S>, S> implements 
OperatorEvent {
-
-  private static final long serialVersionUID = 1L;
+class AggregatedStatistics<D extends DataStatistics<D, S>, S> implements 
Serializable {
 
   private final long checkpointId;
   private final DataStatistics<D, S> dataStatistics;
 
-  DataStatisticsEvent(long checkpointId, DataStatistics<D, S> dataStatistics) {
-    this.checkpointId = checkpointId;
+  AggregatedStatistics(long checkpoint, TypeSerializer<DataStatistics<D, S>> 
statisticsSerializer) {
+    this.checkpointId = checkpoint;
+    this.dataStatistics = statisticsSerializer.createInstance();
+  }
+
+  AggregatedStatistics(long checkpoint, DataStatistics<D, S> dataStatistics) {
+    this.checkpointId = checkpoint;
     this.dataStatistics = dataStatistics;
   }
 
@@ -47,6 +51,16 @@ class DataStatisticsEvent<D extends DataStatistics<D, S>, S> implements Operator
     return dataStatistics;
   }
 
+  void mergeDataStatistic(String operatorName, long eventCheckpointId, D 
eventDataStatistics) {
+    Preconditions.checkArgument(
+        checkpointId == eventCheckpointId,
+        "Received unexpected event from operator %s checkpoint %s. Expected 
checkpoint %s",
+        operatorName,
+        eventCheckpointId,
+        checkpointId);
+    dataStatistics.merge(eventDataStatistics);
+  }
+
   @Override
   public String toString() {
     return MoreObjects.toStringHelper(this)
diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatisticsTracker.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatisticsTracker.java
new file mode 100644
index 0000000000..e8ff61dbeb
--- /dev/null
+++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatisticsTracker.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import java.util.Set;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import 
org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting;
+import org.apache.iceberg.relocated.com.google.common.collect.Sets;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * AggregatedStatisticsTracker is used by {@link DataStatisticsCoordinator} to 
track the in progress
+ * {@link AggregatedStatistics} received from {@link DataStatisticsOperator} 
subtasks for specific
+ * checkpoint.
+ */
+class AggregatedStatisticsTracker<D extends DataStatistics<D, S>, S> {
+  private static final Logger LOG = 
LoggerFactory.getLogger(AggregatedStatisticsTracker.class);
+  private static final double ACCEPT_PARTIAL_AGGR_THRESHOLD = 90;
+  private final String operatorName;
+  private final TypeSerializer<DataStatistics<D, S>> statisticsSerializer;
+  private final int parallelism;
+  private final Set<Integer> inProgressSubtaskSet;
+  private volatile AggregatedStatistics<D, S> inProgressStatistics;
+
+  AggregatedStatisticsTracker(
+      String operatorName,
+      TypeSerializer<DataStatistics<D, S>> statisticsSerializer,
+      int parallelism) {
+    this.operatorName = operatorName;
+    this.statisticsSerializer = statisticsSerializer;
+    this.parallelism = parallelism;
+    this.inProgressSubtaskSet = Sets.newHashSet();
+  }
+
+  AggregatedStatistics<D, S> updateAndCheckCompletion(
+      int subtask, DataStatisticsEvent<D, S> event) {
+    long checkpointId = event.checkpointId();
+
+    if (inProgressStatistics != null && inProgressStatistics.checkpointId() > 
checkpointId) {
+      LOG.info(
+          "Expect data statistics for operator {} checkpoint {}, but receive 
event from older checkpoint {}. Ignore it.",
+          operatorName,
+          inProgressStatistics.checkpointId(),
+          checkpointId);
+      return null;
+    }
+
+    AggregatedStatistics<D, S> completedStatistics = null;
+    if (inProgressStatistics != null && inProgressStatistics.checkpointId() < 
checkpointId) {
+      if ((double) inProgressSubtaskSet.size() / parallelism * 100
+          >= ACCEPT_PARTIAL_AGGR_THRESHOLD) {
+        completedStatistics = inProgressStatistics;
+        LOG.info(
+            "Received data statistics from {} subtasks out of total {} for 
operator {} at checkpoint {}. "
+                + "Complete data statistics aggregation at checkpoint {} as it 
is more than the threshold of {} percentage",
+            inProgressSubtaskSet.size(),
+            parallelism,
+            operatorName,
+            checkpointId,
+            inProgressStatistics.checkpointId(),
+            ACCEPT_PARTIAL_AGGR_THRESHOLD);
+      } else {
+        LOG.info(
+            "Received data statistics from {} subtasks out of total {} for 
operator {} at checkpoint {}. "
+                + "Aborting the incomplete aggregation for checkpoint {}",
+            inProgressSubtaskSet.size(),
+            parallelism,
+            operatorName,
+            checkpointId,
+            inProgressStatistics.checkpointId());
+      }
+
+      inProgressStatistics = null;
+      inProgressSubtaskSet.clear();
+    }
+
+    if (inProgressStatistics == null) {
+      LOG.info("Starting a new data statistics for checkpoint {}", 
checkpointId);
+      inProgressStatistics = new AggregatedStatistics<>(checkpointId, 
statisticsSerializer);
+      inProgressSubtaskSet.clear();
+    }
+
+    if (!inProgressSubtaskSet.add(subtask)) {
+      LOG.debug(
+          "Ignore duplicated data statistics from operator {} subtask {} for 
checkpoint {}.",
+          operatorName,
+          subtask,
+          checkpointId);
+    } else {
+      inProgressStatistics.mergeDataStatistic(
+          operatorName,
+          event.checkpointId(),
+          DataStatisticsUtil.deserializeDataStatistics(
+              event.statisticsBytes(), statisticsSerializer));
+    }
+
+    if (inProgressSubtaskSet.size() == parallelism) {
+      completedStatistics = inProgressStatistics;
+      LOG.info(
+          "Received data statistics from all {} operators {} for checkpoint 
{}. Return last completed aggregator {}.",
+          parallelism,
+          operatorName,
+          inProgressStatistics.checkpointId(),
+          completedStatistics.dataStatistics());
+      inProgressStatistics = new AggregatedStatistics<>(checkpointId + 1, 
statisticsSerializer);
+      inProgressSubtaskSet.clear();
+    }
+
+    return completedStatistics;
+  }
+
+  @VisibleForTesting
+  AggregatedStatistics<D, S> inProgressStatistics() {
+    return inProgressStatistics;
+  }
+}
diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinator.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinator.java
new file mode 100644
index 0000000000..fcfd798842
--- /dev/null
+++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinator.java
@@ -0,0 +1,395 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import java.util.Map;
+import java.util.concurrent.Callable;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ThreadFactory;
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.operators.coordination.OperatorCoordinator;
+import org.apache.flink.runtime.operators.coordination.OperatorEvent;
+import org.apache.flink.util.ExceptionUtils;
+import org.apache.flink.util.FatalExitExceptionHandler;
+import org.apache.flink.util.FlinkRuntimeException;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.ThrowableCatchingRunnable;
+import org.apache.flink.util.function.ThrowingRunnable;
+import 
org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting;
+import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+import org.jetbrains.annotations.NotNull;
+import org.jetbrains.annotations.Nullable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * DataStatisticsCoordinator receives {@link DataStatisticsEvent} from {@link
+ * DataStatisticsOperator} every subtask and then merge them together. Once 
aggregation for all
+ * subtasks data statistics completes, DataStatisticsCoordinator will send the 
aggregated data
+ * statistics back to {@link DataStatisticsOperator}. In the end a custom 
partitioner will
+ * distribute traffic based on the aggregated data statistics to improve data 
clustering.
+ */
+@Internal
+class DataStatisticsCoordinator<D extends DataStatistics<D, S>, S> implements 
OperatorCoordinator {
+  private static final Logger LOG = 
LoggerFactory.getLogger(DataStatisticsCoordinator.class);
+
+  private final String operatorName;
+  private final ExecutorService coordinatorExecutor;
+  private final OperatorCoordinator.Context operatorCoordinatorContext;
+  private final SubtaskGateways subtaskGateways;
+  private final CoordinatorExecutorThreadFactory coordinatorThreadFactory;
+  private final TypeSerializer<DataStatistics<D, S>> statisticsSerializer;
+  private final transient AggregatedStatisticsTracker<D, S> 
aggregatedStatisticsTracker;
+  private volatile AggregatedStatistics<D, S> completedStatistics;
+  private volatile boolean started;
+
+  DataStatisticsCoordinator(
+      String operatorName,
+      OperatorCoordinator.Context context,
+      TypeSerializer<DataStatistics<D, S>> statisticsSerializer) {
+    this.operatorName = operatorName;
+    this.coordinatorThreadFactory =
+        new CoordinatorExecutorThreadFactory(
+            "DataStatisticsCoordinator-" + operatorName, 
context.getUserCodeClassloader());
+    this.coordinatorExecutor = 
Executors.newSingleThreadExecutor(coordinatorThreadFactory);
+    this.operatorCoordinatorContext = context;
+    this.subtaskGateways = new SubtaskGateways(operatorName, parallelism());
+    this.statisticsSerializer = statisticsSerializer;
+    this.aggregatedStatisticsTracker =
+        new AggregatedStatisticsTracker<>(operatorName, statisticsSerializer, 
parallelism());
+  }
+
+  @Override
+  public void start() throws Exception {
+    LOG.info("Starting data statistics coordinator: {}.", operatorName);
+    started = true;
+  }
+
+  @Override
+  public void close() throws Exception {
+    coordinatorExecutor.shutdown();
+    LOG.info("Closed data statistics coordinator: {}.", operatorName);
+  }
+
+  @VisibleForTesting
+  void callInCoordinatorThread(Callable<Void> callable, String errorMessage) {
+    ensureStarted();
+    // Ensure the task is done by the coordinator executor.
+    if (!coordinatorThreadFactory.isCurrentThreadCoordinatorThread()) {
+      try {
+        Callable<Void> guardedCallable =
+            () -> {
+              try {
+                return callable.call();
+              } catch (Throwable t) {
+                LOG.error(
+                    "Uncaught Exception in data statistics coordinator: {} 
executor",
+                    operatorName,
+                    t);
+                ExceptionUtils.rethrowException(t);
+                return null;
+              }
+            };
+
+        coordinatorExecutor.submit(guardedCallable).get();
+      } catch (InterruptedException | ExecutionException e) {
+        throw new FlinkRuntimeException(errorMessage, e);
+      }
+    } else {
+      try {
+        callable.call();
+      } catch (Throwable t) {
+        LOG.error(
+            "Uncaught Exception in data statistics coordinator: {} executor", 
operatorName, t);
+        throw new FlinkRuntimeException(errorMessage, t);
+      }
+    }
+  }
+
+  public void runInCoordinatorThread(Runnable runnable) {
+    this.coordinatorExecutor.execute(
+        new ThrowableCatchingRunnable(
+            throwable ->
+                
this.coordinatorThreadFactory.uncaughtException(Thread.currentThread(), 
throwable),
+            runnable));
+  }
+
+  private void runInCoordinatorThread(ThrowingRunnable<Throwable> action, 
String actionString) {
+    ensureStarted();
+    runInCoordinatorThread(
+        () -> {
+          try {
+            action.run();
+          } catch (Throwable t) {
+            ExceptionUtils.rethrowIfFatalErrorOrOOM(t);
+            LOG.error(
+                "Uncaught exception in the data statistics coordinator: {} 
while {}. Triggering job failover",
+                operatorName,
+                actionString,
+                t);
+            operatorCoordinatorContext.failJob(t);
+          }
+        });
+  }
+
+  private void ensureStarted() {
+    Preconditions.checkState(started, "The coordinator of %s has not started 
yet.", operatorName);
+  }
+
+  private int parallelism() {
+    return operatorCoordinatorContext.currentParallelism();
+  }
+
+  private void handleDataStatisticRequest(int subtask, DataStatisticsEvent<D, 
S> event) {
+    AggregatedStatistics<D, S> aggregatedStatistics =
+        aggregatedStatisticsTracker.updateAndCheckCompletion(subtask, event);
+
+    if (aggregatedStatistics != null) {
+      completedStatistics = aggregatedStatistics;
+      sendDataStatisticsToSubtasks(
+          completedStatistics.checkpointId(), 
completedStatistics.dataStatistics());
+    }
+  }
+
+  private void sendDataStatisticsToSubtasks(
+      long checkpointId, DataStatistics<D, S> globalDataStatistics) {
+    callInCoordinatorThread(
+        () -> {
+          DataStatisticsEvent<D, S> dataStatisticsEvent =
+              DataStatisticsEvent.create(checkpointId, globalDataStatistics, 
statisticsSerializer);
+          int parallelism = parallelism();
+          for (int i = 0; i < parallelism; ++i) {
+            
subtaskGateways.getSubtaskGateway(i).sendEvent(dataStatisticsEvent);
+          }
+
+          return null;
+        },
+        String.format(
+            "Failed to send operator %s coordinator global data statistics for 
checkpoint %d",
+            operatorName, checkpointId));
+  }
+
+  @Override
+  @SuppressWarnings("unchecked")
+  public void handleEventFromOperator(int subtask, int attemptNumber, 
OperatorEvent event) {
+    runInCoordinatorThread(
+        () -> {
+          LOG.debug(
+              "Handling event from subtask {} (#{}) of {}: {}",
+              subtask,
+              attemptNumber,
+              operatorName,
+              event);
+          Preconditions.checkArgument(event instanceof DataStatisticsEvent);
+          handleDataStatisticRequest(subtask, ((DataStatisticsEvent<D, S>) 
event));
+        },
+        String.format(
+            "handling operator event %s from subtask %d (#%d)",
+            event.getClass(), subtask, attemptNumber));
+  }
+
+  @Override
+  public void checkpointCoordinator(long checkpointId, 
CompletableFuture<byte[]> resultFuture) {
+    runInCoordinatorThread(
+        () -> {
+          LOG.debug(
+              "Snapshotting data statistics coordinator {} for checkpoint {}",
+              operatorName,
+              checkpointId);
+          resultFuture.complete(
+              DataStatisticsUtil.serializeAggregatedStatistics(
+                  completedStatistics, statisticsSerializer));
+        },
+        String.format("taking checkpoint %d", checkpointId));
+  }
+
+  @Override
+  public void notifyCheckpointComplete(long checkpointId) {}
+
+  @Override
+  public void resetToCheckpoint(long checkpointId, @Nullable byte[] 
checkpointData)
+      throws Exception {
+    Preconditions.checkState(
+        !started, "The coordinator %s can only be reset if it was not yet 
started", operatorName);
+
+    if (checkpointData == null) {
+      LOG.info(
+          "Data statistic coordinator {} has nothing to restore from 
checkpoint {}",
+          operatorName,
+          checkpointId);
+      return;
+    }
+
+    LOG.info(
+        "Restoring data statistic coordinator {} from checkpoint {}", 
operatorName, checkpointId);
+    completedStatistics =
+        DataStatisticsUtil.deserializeAggregatedStatistics(checkpointData, 
statisticsSerializer);
+  }
+
+  @Override
+  public void subtaskReset(int subtask, long checkpointId) {
+    runInCoordinatorThread(
+        () -> {
+          LOG.info(
+              "Operator {} subtask {} is reset to checkpoint {}",
+              operatorName,
+              subtask,
+              checkpointId);
+          Preconditions.checkState(
+              
this.coordinatorThreadFactory.isCurrentThreadCoordinatorThread());
+          subtaskGateways.reset(subtask);
+        },
+        String.format("handling subtask %d recovery to checkpoint %d", 
subtask, checkpointId));
+  }
+
+  @Override
+  public void executionAttemptFailed(int subtask, int attemptNumber, @Nullable 
Throwable reason) {
+    runInCoordinatorThread(
+        () -> {
+          LOG.info(
+              "Unregistering gateway after failure for subtask {} (#{}) of 
data statistic {}",
+              subtask,
+              attemptNumber,
+              operatorName);
+          Preconditions.checkState(
+              
this.coordinatorThreadFactory.isCurrentThreadCoordinatorThread());
+          subtaskGateways.unregisterSubtaskGateway(subtask, attemptNumber);
+        },
+        String.format("handling subtask %d (#%d) failure", subtask, 
attemptNumber));
+  }
+
+  @Override
+  public void executionAttemptReady(int subtask, int attemptNumber, 
SubtaskGateway gateway) {
+    Preconditions.checkArgument(subtask == gateway.getSubtask());
+    Preconditions.checkArgument(attemptNumber == 
gateway.getExecution().getAttemptNumber());
+    runInCoordinatorThread(
+        () -> {
+          Preconditions.checkState(
+              
this.coordinatorThreadFactory.isCurrentThreadCoordinatorThread());
+          subtaskGateways.registerSubtaskGateway(gateway);
+        },
+        String.format(
+            "making event gateway to subtask %d (#%d) available", subtask, 
attemptNumber));
+  }
+
+  @VisibleForTesting
+  AggregatedStatistics<D, S> completedStatistics() {
+    return completedStatistics;
+  }
+
+  private static class SubtaskGateways {
+    private final String operatorName;
+    private final Map<Integer, SubtaskGateway>[] gateways;
+
+    private SubtaskGateways(String operatorName, int parallelism) {
+      this.operatorName = operatorName;
+      gateways = new Map[parallelism];
+
+      for (int i = 0; i < parallelism; ++i) {
+        gateways[i] = Maps.newHashMap();
+      }
+    }
+
+    private void registerSubtaskGateway(OperatorCoordinator.SubtaskGateway 
gateway) {
+      int subtaskIndex = gateway.getSubtask();
+      int attemptNumber = gateway.getExecution().getAttemptNumber();
+      Preconditions.checkState(
+          !gateways[subtaskIndex].containsKey(attemptNumber),
+          "Coordinator of %s already has a subtask gateway for %d (#%d)",
+          operatorName,
+          subtaskIndex,
+          attemptNumber);
+      LOG.debug(
+          "Coordinator of {} registers gateway for subtask {} attempt {}",
+          operatorName,
+          subtaskIndex,
+          attemptNumber);
+      gateways[subtaskIndex].put(attemptNumber, gateway);
+    }
+
+    private void unregisterSubtaskGateway(int subtaskIndex, int attemptNumber) 
{
+      LOG.debug(
+          "Coordinator of {} unregisters gateway for subtask {} attempt {}",
+          operatorName,
+          subtaskIndex,
+          attemptNumber);
+      gateways[subtaskIndex].remove(attemptNumber);
+    }
+
+    private OperatorCoordinator.SubtaskGateway getSubtaskGateway(int 
subtaskIndex) {
+      Preconditions.checkState(
+          gateways[subtaskIndex].size() > 0,
+          "Coordinator of %s subtask %d is not ready yet to receive events",
+          operatorName,
+          subtaskIndex);
+      return Iterables.getOnlyElement(gateways[subtaskIndex].values());
+    }
+
+    private void reset(int subtaskIndex) {
+      gateways[subtaskIndex].clear();
+    }
+  }
+
+  private static class CoordinatorExecutorThreadFactory
+      implements ThreadFactory, Thread.UncaughtExceptionHandler {
+
+    private final String coordinatorThreadName;
+    private final ClassLoader classLoader;
+    private final Thread.UncaughtExceptionHandler errorHandler;
+
+    @javax.annotation.Nullable private Thread thread;
+
+    CoordinatorExecutorThreadFactory(
+        final String coordinatorThreadName, final ClassLoader 
contextClassLoader) {
+      this(coordinatorThreadName, contextClassLoader, 
FatalExitExceptionHandler.INSTANCE);
+    }
+
+    @org.apache.flink.annotation.VisibleForTesting
+    CoordinatorExecutorThreadFactory(
+        final String coordinatorThreadName,
+        final ClassLoader contextClassLoader,
+        final Thread.UncaughtExceptionHandler errorHandler) {
+      this.coordinatorThreadName = coordinatorThreadName;
+      this.classLoader = contextClassLoader;
+      this.errorHandler = errorHandler;
+    }
+
+    @Override
+    public synchronized Thread newThread(@NotNull Runnable runnable) {
+      thread = new Thread(runnable, coordinatorThreadName);
+      thread.setContextClassLoader(classLoader);
+      thread.setUncaughtExceptionHandler(this);
+      return thread;
+    }
+
+    @Override
+    public synchronized void uncaughtException(Thread t, Throwable e) {
+      errorHandler.uncaughtException(t, e);
+    }
+
+    boolean isCurrentThreadCoordinatorThread() {
+      return Thread.currentThread() == thread;
+    }
+  }
+}
diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinatorProvider.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinatorProvider.java
new file mode 100644
index 0000000000..47dbfc3cfb
--- /dev/null
+++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinatorProvider.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.operators.coordination.OperatorCoordinator;
+import 
org.apache.flink.runtime.operators.coordination.RecreateOnResetOperatorCoordinator;
+
+/**
+ * DataStatisticsCoordinatorProvider provides the method to create new {@link
+ * DataStatisticsCoordinator}
+ */
+@Internal
+public class DataStatisticsCoordinatorProvider<D extends DataStatistics<D, S>, 
S>
+    extends RecreateOnResetOperatorCoordinator.Provider {
+
+  private final String operatorName;
+  private final TypeSerializer<DataStatistics<D, S>> statisticsSerializer;
+
+  public DataStatisticsCoordinatorProvider(
+      String operatorName,
+      OperatorID operatorID,
+      TypeSerializer<DataStatistics<D, S>> statisticsSerializer) {
+    super(operatorID);
+    this.operatorName = operatorName;
+    this.statisticsSerializer = statisticsSerializer;
+  }
+
+  @Override
+  public OperatorCoordinator getCoordinator(OperatorCoordinator.Context 
context) {
+    return new DataStatisticsCoordinator<>(operatorName, context, 
statisticsSerializer);
+  }
+}
diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsEvent.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsEvent.java
index 3aba66fd42..852d2157b8 100644
--- a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsEvent.java
+++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsEvent.java
@@ -19,39 +19,39 @@
 package org.apache.iceberg.flink.sink.shuffle;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.runtime.operators.coordination.OperatorEvent;
-import org.apache.iceberg.relocated.com.google.common.base.MoreObjects;
 
 /**
  * DataStatisticsEvent is sent between data statistics coordinator and 
operator to transmit data
- * statistics
+ * statistics in bytes
  */
 @Internal
 class DataStatisticsEvent<D extends DataStatistics<D, S>, S> implements 
OperatorEvent {
 
   private static final long serialVersionUID = 1L;
-
   private final long checkpointId;
-  private final DataStatistics<D, S> dataStatistics;
+  private final byte[] statisticsBytes;
 
-  DataStatisticsEvent(long checkpointId, DataStatistics<D, S> dataStatistics) {
+  private DataStatisticsEvent(long checkpointId, byte[] statisticsBytes) {
     this.checkpointId = checkpointId;
-    this.dataStatistics = dataStatistics;
+    this.statisticsBytes = statisticsBytes;
   }
 
-  long checkpointId() {
-    return checkpointId;
+  static <D extends DataStatistics<D, S>, S> DataStatisticsEvent<D, S> create(
+      long checkpointId,
+      DataStatistics<D, S> dataStatistics,
+      TypeSerializer<DataStatistics<D, S>> statisticsSerializer) {
+    return new DataStatisticsEvent<>(
+        checkpointId,
+        DataStatisticsUtil.serializeDataStatistics(dataStatistics, 
statisticsSerializer));
   }
 
-  DataStatistics<D, S> dataStatistics() {
-    return dataStatistics;
+  long checkpointId() {
+    return checkpointId;
   }
 
-  @Override
-  public String toString() {
-    return MoreObjects.toStringHelper(this)
-        .add("checkpointId", checkpointId)
-        .add("dataStatistics", dataStatistics)
-        .toString();
+  byte[] statisticsBytes() {
+    return statisticsBytes;
   }
 }
diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java
index 6d4209b02a..d00d5d2e5a 100644
--- a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java
+++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java
@@ -18,6 +18,7 @@
  */
 package org.apache.iceberg.flink.sink.shuffle;
 
+import org.apache.flink.annotation.Internal;
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
@@ -40,11 +41,13 @@ import 
org.apache.iceberg.relocated.com.google.common.base.Preconditions;
  * shuffle record to improve data clustering while maintaining relative 
balanced traffic
  * distribution to downstream subtasks.
  */
+@Internal
 class DataStatisticsOperator<D extends DataStatistics<D, S>, S>
     extends AbstractStreamOperator<DataStatisticsOrRecord<D, S>>
     implements OneInputStreamOperator<RowData, DataStatisticsOrRecord<D, S>>, 
OperatorEventHandler {
   private static final long serialVersionUID = 1L;
 
+  private final String operatorName;
   // keySelector will be used to generate key from data for collecting data 
statistics
   private final KeySelector<RowData, RowData> keySelector;
   private final OperatorEventGateway operatorEventGateway;
@@ -54,9 +57,11 @@ class DataStatisticsOperator<D extends DataStatistics<D, S>, 
S>
   private transient ListState<DataStatistics<D, S>> globalStatisticsState;
 
   DataStatisticsOperator(
+      String operatorName,
       KeySelector<RowData, RowData> keySelector,
       OperatorEventGateway operatorEventGateway,
       TypeSerializer<DataStatistics<D, S>> statisticsSerializer) {
+    this.operatorName = operatorName;
     this.keySelector = keySelector;
     this.operatorEventGateway = operatorEventGateway;
     this.statisticsSerializer = statisticsSerializer;
@@ -75,10 +80,16 @@ class DataStatisticsOperator<D extends DataStatistics<D, 
S>, S>
       int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
       if (globalStatisticsState.get() == null
           || !globalStatisticsState.get().iterator().hasNext()) {
-        LOG.warn("Subtask {} doesn't have global statistics state to restore", 
subtaskIndex);
+        LOG.warn(
+            "Operator {} subtask {} doesn't have global statistics state to 
restore",
+            operatorName,
+            subtaskIndex);
         globalStatistics = statisticsSerializer.createInstance();
       } else {
-        LOG.info("Restoring global statistics state for subtask {}", 
subtaskIndex);
+        LOG.info(
+            "Restoring operator {} global statistics state for subtask {}",
+            operatorName,
+            subtaskIndex);
         globalStatistics = globalStatisticsState.get().iterator().next();
       }
     } else {
@@ -95,12 +106,22 @@ class DataStatisticsOperator<D extends DataStatistics<D, 
S>, S>
   }
 
   @Override
+  @SuppressWarnings("unchecked")
   public void handleOperatorEvent(OperatorEvent event) {
+    int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
     Preconditions.checkArgument(
         event instanceof DataStatisticsEvent,
-        "Received unexpected operator event " + event.getClass());
+        String.format(
+            "Operator %s subtask %s received unexpected operator event %s",
+            operatorName, subtaskIndex, event.getClass()));
     DataStatisticsEvent<D, S> statisticsEvent = (DataStatisticsEvent<D, S>) 
event;
-    globalStatistics = statisticsEvent.dataStatistics();
+    LOG.info(
+        "Operator {} received global data event from coordinator checkpoint 
{}",
+        operatorName,
+        statisticsEvent.checkpointId());
+    globalStatistics =
+        DataStatisticsUtil.deserializeDataStatistics(
+            statisticsEvent.statisticsBytes(), statisticsSerializer);
     output.collect(new 
StreamRecord<>(DataStatisticsOrRecord.fromDataStatistics(globalStatistics)));
   }
 
@@ -117,21 +138,39 @@ class DataStatisticsOperator<D extends DataStatistics<D, 
S>, S>
     long checkpointId = context.getCheckpointId();
     int subTaskId = getRuntimeContext().getIndexOfThisSubtask();
     LOG.info(
-        "Taking data statistics operator snapshot for checkpoint {} in subtask 
{}",
+        "Snapshotting data statistics operator {} for checkpoint {} in subtask 
{}",
+        operatorName,
         checkpointId,
         subTaskId);
 
+    // Pass global statistics to partitioners so that all the operators 
refresh statistics
+    // at the same checkpoint barrier
+    if (!globalStatistics.isEmpty()) {
+      output.collect(
+          new 
StreamRecord<>(DataStatisticsOrRecord.fromDataStatistics(globalStatistics)));
+    }
+
     // Only subtask 0 saves the state so that 
globalStatisticsState(UnionListState) stores
     // an exact copy of globalStatistics
     if (!globalStatistics.isEmpty() && 
getRuntimeContext().getIndexOfThisSubtask() == 0) {
       globalStatisticsState.clear();
-      LOG.info("Saving global statistics {} to state in subtask {}", 
globalStatistics, subTaskId);
+      LOG.info(
+          "Saving operator {} global statistics {} to state in subtask {}",
+          operatorName,
+          globalStatistics,
+          subTaskId);
       globalStatisticsState.add(globalStatistics);
     }
 
-    // For now, we make it simple to send globalStatisticsState at checkpoint
+    // For now, local statistics are sent to coordinator at checkpoint
     operatorEventGateway.sendEventToCoordinator(
-        new DataStatisticsEvent<>(checkpointId, localStatistics));
+        DataStatisticsEvent.create(checkpointId, localStatistics, 
statisticsSerializer));
+    LOG.debug(
+        "Subtask {} of operator {} sent local statistics to coordinator at 
checkpoint{}: {}",
+        subTaskId,
+        operatorName,
+        checkpointId,
+        localStatistics);
 
     // Recreate the local statistics
     localStatistics = statisticsSerializer.createInstance();
diff --git 
a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOrRecord.java
 
b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOrRecord.java
index b8ced3a47d..889e85112e 100644
--- 
a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOrRecord.java
+++ 
b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOrRecord.java
@@ -43,8 +43,7 @@ class DataStatisticsOrRecord<D extends DataStatistics<D, S>, 
S> implements Seria
 
   private DataStatisticsOrRecord(DataStatistics<D, S> statistics, RowData 
record) {
     Preconditions.checkArgument(
-        record != null ^ statistics != null,
-        "A DataStatisticsOrRecord contain either statistics or record, not 
neither or both");
+        record != null ^ statistics != null, "DataStatistics or record, not 
neither or both");
     this.statistics = statistics;
     this.record = record;
   }
diff --git 
a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsUtil.java
 
b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsUtil.java
new file mode 100644
index 0000000000..2737b1346f
--- /dev/null
+++ 
b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsUtil.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataInputDeserializer;
+import org.apache.flink.core.memory.DataOutputSerializer;
+
+/**
+ * Utility for serializing and deserializing {@link 
DataStatistics} and {@link
+ * AggregatedStatistics}.
+ */
+class DataStatisticsUtil {
+
+  private DataStatisticsUtil() {}
+
+  static <D extends DataStatistics<D, S>, S> byte[] serializeDataStatistics(
+      DataStatistics<D, S> dataStatistics,
+      TypeSerializer<DataStatistics<D, S>> statisticsSerializer) {
+    DataOutputSerializer out = new DataOutputSerializer(64);
+    try {
+      statisticsSerializer.serialize(dataStatistics, out);
+      return out.getCopyOfBuffer();
+    } catch (IOException e) {
+      throw new IllegalStateException("Fail to serialize data statistics", e);
+    }
+  }
+
+  @SuppressWarnings("unchecked")
+  static <D extends DataStatistics<D, S>, S> D deserializeDataStatistics(
+      byte[] bytes, TypeSerializer<DataStatistics<D, S>> statisticsSerializer) 
{
+    DataInputDeserializer input = new DataInputDeserializer(bytes, 0, 
bytes.length);
+    try {
+      return (D) statisticsSerializer.deserialize(input);
+    } catch (IOException e) {
+      throw new IllegalStateException("Fail to deserialize data statistics", 
e);
+    }
+  }
+
+  static <D extends DataStatistics<D, S>, S> byte[] 
serializeAggregatedStatistics(
+      AggregatedStatistics<D, S> aggregatedStatistics,
+      TypeSerializer<DataStatistics<D, S>> statisticsSerializer)
+      throws IOException {
+    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
+    ObjectOutputStream out = new ObjectOutputStream(bytes);
+
+    DataOutputSerializer outSerializer = new DataOutputSerializer(64);
+    out.writeLong(aggregatedStatistics.checkpointId());
+    statisticsSerializer.serialize(aggregatedStatistics.dataStatistics(), 
outSerializer);
+    byte[] statisticsBytes = outSerializer.getCopyOfBuffer();
+    out.writeInt(statisticsBytes.length);
+    out.write(statisticsBytes);
+    out.flush();
+
+    return bytes.toByteArray();
+  }
+
+  @SuppressWarnings("unchecked")
+  static <D extends DataStatistics<D, S>, S>
+      AggregatedStatistics<D, S> deserializeAggregatedStatistics(
+          byte[] bytes, TypeSerializer<DataStatistics<D, S>> 
statisticsSerializer)
+          throws IOException {
+    ByteArrayInputStream bytesIn = new ByteArrayInputStream(bytes);
+    ObjectInputStream in = new ObjectInputStream(bytesIn);
+
+    long completedCheckpointId = in.readLong();
+    int statisticsBytesLength = in.readInt();
+    byte[] statisticsBytes = new byte[statisticsBytesLength];
+    in.readFully(statisticsBytes);
+    DataInputDeserializer input =
+        new DataInputDeserializer(statisticsBytes, 0, statisticsBytesLength);
+    DataStatistics<D, S> dataStatistics = 
statisticsSerializer.deserialize(input);
+
+    return new AggregatedStatistics<>(completedCheckpointId, dataStatistics);
+  }
+}
diff --git 
a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatistics.java
 
b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatistics.java
new file mode 100644
index 0000000000..dd7fcafe53
--- /dev/null
+++ 
b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatistics.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.util.Map;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.StringData;
+import org.apache.flink.table.runtime.typeutils.RowDataSerializer;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.table.types.logical.VarCharType;
+import org.junit.Test;
+
+public class TestAggregatedStatistics {
+
+  @Test
+  public void mergeDataStatisticTest() {
+    GenericRowData rowDataA = GenericRowData.of(StringData.fromString("a"));
+    GenericRowData rowDataB = GenericRowData.of(StringData.fromString("b"));
+
+    AggregatedStatistics<MapDataStatistics, Map<RowData, Long>> 
aggregatedStatistics =
+        new AggregatedStatistics<>(
+            1,
+            MapDataStatisticsSerializer.fromKeySerializer(
+                new RowDataSerializer(RowType.of(new VarCharType()))));
+    MapDataStatistics mapDataStatistics1 = new MapDataStatistics();
+    mapDataStatistics1.add(rowDataA);
+    mapDataStatistics1.add(rowDataA);
+    mapDataStatistics1.add(rowDataB);
+    aggregatedStatistics.mergeDataStatistic("testOperator", 1, 
mapDataStatistics1);
+    MapDataStatistics mapDataStatistics2 = new MapDataStatistics();
+    mapDataStatistics2.add(rowDataA);
+    aggregatedStatistics.mergeDataStatistic("testOperator", 1, 
mapDataStatistics2);
+    
assertThat(aggregatedStatistics.dataStatistics().statistics().get(rowDataA))
+        .isEqualTo(
+            mapDataStatistics1.statistics().get(rowDataA)
+                + mapDataStatistics2.statistics().get(rowDataA));
+    
assertThat(aggregatedStatistics.dataStatistics().statistics().get(rowDataB))
+        .isEqualTo(
+            mapDataStatistics1.statistics().get(rowDataB)
+                + mapDataStatistics2.statistics().getOrDefault(rowDataB, 0L));
+  }
+}
diff --git 
a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatisticsTracker.java
 
b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatisticsTracker.java
new file mode 100644
index 0000000000..48e4e4d8f9
--- /dev/null
+++ 
b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatisticsTracker.java
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.util.Map;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.StringData;
+import org.apache.flink.table.data.binary.BinaryRowData;
+import org.apache.flink.table.runtime.typeutils.RowDataSerializer;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.table.types.logical.VarCharType;
+import org.junit.Before;
+import org.junit.Test;
+
+public class TestAggregatedStatisticsTracker {
+  private static final int NUM_SUBTASKS = 2;
+  private final RowType rowType = RowType.of(new VarCharType());
+  // When coordinator handles events from operator, 
DataStatisticsUtil#deserializeDataStatistics
+  // deserializes bytes into BinaryRowData
+  private final BinaryRowData binaryRowDataA =
+      new 
RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("a")));
+  private final BinaryRowData binaryRowDataB =
+      new 
RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("b")));
+  private final TypeSerializer<RowData> rowSerializer = new 
RowDataSerializer(rowType);
+  private final TypeSerializer<DataStatistics<MapDataStatistics, Map<RowData, 
Long>>>
+      statisticsSerializer = 
MapDataStatisticsSerializer.fromKeySerializer(rowSerializer);
+  private AggregatedStatisticsTracker<MapDataStatistics, Map<RowData, Long>>
+      aggregatedStatisticsTracker;
+
+  @Before
+  public void before() throws Exception {
+    aggregatedStatisticsTracker =
+        new AggregatedStatisticsTracker<>("testOperator", 
statisticsSerializer, NUM_SUBTASKS);
+  }
+
+  @Test
+  public void receiveNewerDataStatisticEvent() {
+    MapDataStatistics checkpoint1Subtask0DataStatistic = new 
MapDataStatistics();
+    checkpoint1Subtask0DataStatistic.add(binaryRowDataA);
+    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+        checkpoint1Subtask0DataStatisticEvent =
+            DataStatisticsEvent.create(1, checkpoint1Subtask0DataStatistic, 
statisticsSerializer);
+    assertThat(
+            aggregatedStatisticsTracker.updateAndCheckCompletion(
+                0, checkpoint1Subtask0DataStatisticEvent))
+        .isNull();
+    
assertThat(aggregatedStatisticsTracker.inProgressStatistics().checkpointId()).isEqualTo(1);
+
+    MapDataStatistics checkpoint2Subtask0DataStatistic = new 
MapDataStatistics();
+    checkpoint2Subtask0DataStatistic.add(binaryRowDataA);
+    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+        checkpoint2Subtask0DataStatisticEvent =
+            DataStatisticsEvent.create(2, checkpoint2Subtask0DataStatistic, 
statisticsSerializer);
+    assertThat(
+            aggregatedStatisticsTracker.updateAndCheckCompletion(
+                0, checkpoint2Subtask0DataStatisticEvent))
+        .isNull();
+    // Checkpoint 2 is newer than checkpoint 1, so the in-progress 
statistics for checkpoint 1 are dropped
+    
assertThat(aggregatedStatisticsTracker.inProgressStatistics().checkpointId()).isEqualTo(2);
+  }
+
+  @Test
+  public void receiveOlderDataStatisticEventTest() {
+    MapDataStatistics checkpoint2Subtask0DataStatistic = new 
MapDataStatistics();
+    checkpoint2Subtask0DataStatistic.add(binaryRowDataA);
+    checkpoint2Subtask0DataStatistic.add(binaryRowDataB);
+    checkpoint2Subtask0DataStatistic.add(binaryRowDataB);
+    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+        checkpoint3Subtask0DataStatisticEvent =
+            DataStatisticsEvent.create(2, checkpoint2Subtask0DataStatistic, 
statisticsSerializer);
+    assertThat(
+            aggregatedStatisticsTracker.updateAndCheckCompletion(
+                0, checkpoint3Subtask0DataStatisticEvent))
+        .isNull();
+
+    MapDataStatistics checkpoint1Subtask1DataStatistic = new 
MapDataStatistics();
+    checkpoint1Subtask1DataStatistic.add(binaryRowDataB);
+    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+        checkpoint1Subtask1DataStatisticEvent =
+            DataStatisticsEvent.create(1, checkpoint1Subtask1DataStatistic, 
statisticsSerializer);
+    // Receive event from old checkpoint, 
aggregatedStatisticsTracker won't return
+    // completed statistics and in progress statistics won't be updated
+    assertThat(
+            aggregatedStatisticsTracker.updateAndCheckCompletion(
+                1, checkpoint1Subtask1DataStatisticEvent))
+        .isNull();
+    
assertThat(aggregatedStatisticsTracker.inProgressStatistics().checkpointId()).isEqualTo(2);
+  }
+
+  @Test
+  public void receiveCompletedDataStatisticEvent() {
+    MapDataStatistics checkpoint1Subtask0DataStatistic = new 
MapDataStatistics();
+    checkpoint1Subtask0DataStatistic.add(binaryRowDataA);
+    checkpoint1Subtask0DataStatistic.add(binaryRowDataB);
+    checkpoint1Subtask0DataStatistic.add(binaryRowDataB);
+    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+        checkpoint1Subtask0DataStatisticEvent =
+            DataStatisticsEvent.create(1, checkpoint1Subtask0DataStatistic, 
statisticsSerializer);
+    assertThat(
+            aggregatedStatisticsTracker.updateAndCheckCompletion(
+                0, checkpoint1Subtask0DataStatisticEvent))
+        .isNull();
+
+    MapDataStatistics checkpoint1Subtask1DataStatistic = new 
MapDataStatistics();
+    checkpoint1Subtask1DataStatistic.add(binaryRowDataA);
+    checkpoint1Subtask1DataStatistic.add(binaryRowDataA);
+    checkpoint1Subtask1DataStatistic.add(binaryRowDataB);
+    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+        checkpoint1Subtask1DataStatisticEvent =
+            DataStatisticsEvent.create(1, checkpoint1Subtask1DataStatistic, 
statisticsSerializer);
+    // Receive data statistics from all subtasks at checkpoint 1
+    AggregatedStatistics<MapDataStatistics, Map<RowData, Long>> 
completedStatistics =
+        aggregatedStatisticsTracker.updateAndCheckCompletion(
+            1, checkpoint1Subtask1DataStatisticEvent);
+
+    assertThat(completedStatistics).isNotNull();
+    assertThat(completedStatistics.checkpointId()).isEqualTo(1);
+    MapDataStatistics globalDataStatistics =
+        (MapDataStatistics) completedStatistics.dataStatistics();
+    assertThat((long) globalDataStatistics.statistics().get(binaryRowDataA))
+        .isEqualTo(
+            checkpoint1Subtask0DataStatistic.statistics().get(binaryRowDataA)
+                + 
checkpoint1Subtask1DataStatistic.statistics().get(binaryRowDataA));
+    assertThat((long) globalDataStatistics.statistics().get(binaryRowDataB))
+        .isEqualTo(
+            checkpoint1Subtask0DataStatistic.statistics().get(binaryRowDataB)
+                + 
checkpoint1Subtask1DataStatistic.statistics().get(binaryRowDataB));
+    
assertThat(aggregatedStatisticsTracker.inProgressStatistics().checkpointId())
+        .isEqualTo(completedStatistics.checkpointId() + 1);
+
+    MapDataStatistics checkpoint2Subtask0DataStatistic = new 
MapDataStatistics();
+    checkpoint2Subtask0DataStatistic.add(binaryRowDataA);
+    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+        checkpoint2Subtask0DataStatisticEvent =
+            DataStatisticsEvent.create(2, checkpoint2Subtask0DataStatistic, 
statisticsSerializer);
+    assertThat(
+            aggregatedStatisticsTracker.updateAndCheckCompletion(
+                0, checkpoint2Subtask0DataStatisticEvent))
+        .isNull();
+    assertThat(completedStatistics.checkpointId()).isEqualTo(1);
+
+    MapDataStatistics checkpoint2Subtask1DataStatistic = new 
MapDataStatistics();
+    checkpoint2Subtask1DataStatistic.add(binaryRowDataB);
+    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+        checkpoint2Subtask1DataStatisticEvent =
+            DataStatisticsEvent.create(2, checkpoint2Subtask1DataStatistic, 
statisticsSerializer);
+    // Receive data statistics from all subtasks at checkpoint 2
+    completedStatistics =
+        aggregatedStatisticsTracker.updateAndCheckCompletion(
+            1, checkpoint2Subtask1DataStatisticEvent);
+
+    assertThat(completedStatistics).isNotNull();
+    assertThat(completedStatistics.checkpointId()).isEqualTo(2);
+    
assertThat(aggregatedStatisticsTracker.inProgressStatistics().checkpointId())
+        .isEqualTo(completedStatistics.checkpointId() + 1);
+  }
+}
diff --git 
a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinator.java
 
b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinator.java
new file mode 100644
index 0000000000..9ec2606e10
--- /dev/null
+++ 
b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinator.java
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.operators.coordination.EventReceivingTasks;
+import 
org.apache.flink.runtime.operators.coordination.MockOperatorCoordinatorContext;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.StringData;
+import org.apache.flink.table.data.binary.BinaryRowData;
+import org.apache.flink.table.runtime.typeutils.RowDataSerializer;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.table.types.logical.VarCharType;
+import org.apache.flink.util.ExceptionUtils;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.junit.Before;
+import org.junit.Test;
+
+public class TestDataStatisticsCoordinator {
+  private static final String OPERATOR_NAME = "TestCoordinator";
+  private static final OperatorID TEST_OPERATOR_ID = new OperatorID(1234L, 
5678L);
+  private static final int NUM_SUBTASKS = 2;
+  private TypeSerializer<DataStatistics<MapDataStatistics, Map<RowData, Long>>>
+      statisticsSerializer;
+
+  private EventReceivingTasks receivingTasks;
+  private DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>>
+      dataStatisticsCoordinator;
+
+  @Before
+  public void before() throws Exception {
+    receivingTasks = EventReceivingTasks.createForRunningTasks();
+    statisticsSerializer =
+        MapDataStatisticsSerializer.fromKeySerializer(
+            new RowDataSerializer(RowType.of(new VarCharType())));
+
+    dataStatisticsCoordinator =
+        new DataStatisticsCoordinator<>(
+            OPERATOR_NAME,
+            new MockOperatorCoordinatorContext(TEST_OPERATOR_ID, NUM_SUBTASKS),
+            statisticsSerializer);
+  }
+
+  private void tasksReady() throws Exception {
+    dataStatisticsCoordinator.start();
+    setAllTasksReady(NUM_SUBTASKS, dataStatisticsCoordinator, receivingTasks);
+  }
+
+  @Test
+  public void testThrowExceptionWhenNotStarted() {
+    String failureMessage = "The coordinator of TestCoordinator has not 
started yet.";
+
+    assertThatThrownBy(
+            () ->
+                dataStatisticsCoordinator.handleEventFromOperator(
+                    0,
+                    0,
+                    DataStatisticsEvent.create(0, new MapDataStatistics(), 
statisticsSerializer)))
+        .isInstanceOf(IllegalStateException.class)
+        .hasMessage(failureMessage);
+    assertThatThrownBy(() -> 
dataStatisticsCoordinator.executionAttemptFailed(0, 0, null))
+        .isInstanceOf(IllegalStateException.class)
+        .hasMessage(failureMessage);
+    assertThatThrownBy(() -> 
dataStatisticsCoordinator.checkpointCoordinator(0, null))
+        .isInstanceOf(IllegalStateException.class)
+        .hasMessage(failureMessage);
+  }
+
+  @Test
+  public void testDataStatisticsEventHandling() throws Exception {
+    tasksReady();
+    // When coordinator handles events from operator, 
DataStatisticsUtil#deserializeDataStatistics
+    // deserializes bytes into BinaryRowData
+    RowType rowType = RowType.of(new VarCharType());
+    BinaryRowData binaryRowDataA =
+        new 
RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("a")));
+    BinaryRowData binaryRowDataB =
+        new 
RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("b")));
+    BinaryRowData binaryRowDataC =
+        new 
RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("c")));
+
+    MapDataStatistics checkpoint1Subtask0DataStatistic = new 
MapDataStatistics();
+    checkpoint1Subtask0DataStatistic.add(binaryRowDataA);
+    checkpoint1Subtask0DataStatistic.add(binaryRowDataB);
+    checkpoint1Subtask0DataStatistic.add(binaryRowDataB);
+    checkpoint1Subtask0DataStatistic.add(binaryRowDataC);
+    checkpoint1Subtask0DataStatistic.add(binaryRowDataC);
+    checkpoint1Subtask0DataStatistic.add(binaryRowDataC);
+    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+        checkpoint1Subtask0DataStatisticEvent =
+            DataStatisticsEvent.create(1, checkpoint1Subtask0DataStatistic, 
statisticsSerializer);
+    MapDataStatistics checkpoint1Subtask1DataStatistic = new 
MapDataStatistics();
+    checkpoint1Subtask1DataStatistic.add(binaryRowDataA);
+    checkpoint1Subtask1DataStatistic.add(binaryRowDataB);
+    checkpoint1Subtask1DataStatistic.add(binaryRowDataC);
+    checkpoint1Subtask1DataStatistic.add(binaryRowDataC);
+    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+        checkpoint1Subtask1DataStatisticEvent =
+            DataStatisticsEvent.create(1, checkpoint1Subtask1DataStatistic, 
statisticsSerializer);
+    // Handle events from operators for checkpoint 1
+    dataStatisticsCoordinator.handleEventFromOperator(0, 0, 
checkpoint1Subtask0DataStatisticEvent);
+    dataStatisticsCoordinator.handleEventFromOperator(1, 0, 
checkpoint1Subtask1DataStatisticEvent);
+
+    waitForCoordinatorToProcessActions(dataStatisticsCoordinator);
+    // Verify global data statistics is the aggregation of all subtasks data 
statistics
+    MapDataStatistics globalDataStatistics =
+        (MapDataStatistics) 
dataStatisticsCoordinator.completedStatistics().dataStatistics();
+    assertThat(globalDataStatistics.statistics())
+        .containsExactlyInAnyOrderEntriesOf(
+            ImmutableMap.of(
+                binaryRowDataA,
+                
checkpoint1Subtask0DataStatistic.statistics().get(binaryRowDataA)
+                    + (long) 
checkpoint1Subtask1DataStatistic.statistics().get(binaryRowDataA),
+                binaryRowDataB,
+                
checkpoint1Subtask0DataStatistic.statistics().get(binaryRowDataB)
+                    + (long) 
checkpoint1Subtask1DataStatistic.statistics().get(binaryRowDataB),
+                binaryRowDataC,
+                
checkpoint1Subtask0DataStatistic.statistics().get(binaryRowDataC)
+                    + (long) 
checkpoint1Subtask1DataStatistic.statistics().get(binaryRowDataC)));
+  }
+
+  static void setAllTasksReady(
+      int subtasks,
+      DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>> 
dataStatisticsCoordinator,
+      EventReceivingTasks receivingTasks) {
+    for (int i = 0; i < subtasks; i++) {
+      dataStatisticsCoordinator.executionAttemptReady(
+          i, 0, receivingTasks.createGatewayForSubtask(i, 0));
+    }
+  }
+
+  static void waitForCoordinatorToProcessActions(
+      DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>> 
coordinator) {
+    CompletableFuture<Void> future = new CompletableFuture<>();
+    coordinator.callInCoordinatorThread(
+        () -> {
+          future.complete(null);
+          return null;
+        },
+        "Coordinator fails to process action");
+
+    try {
+      future.get();
+    } catch (InterruptedException e) {
+      throw new AssertionError("test interrupted");
+    } catch (ExecutionException e) {
+      ExceptionUtils.rethrow(ExceptionUtils.stripExecutionException(e));
+    }
+  }
+}
diff --git 
a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinatorProvider.java
 
b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinatorProvider.java
new file mode 100644
index 0000000000..cb9d3f48ff
--- /dev/null
+++ 
b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinatorProvider.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.operators.coordination.EventReceivingTasks;
+import 
org.apache.flink.runtime.operators.coordination.MockOperatorCoordinatorContext;
+import 
org.apache.flink.runtime.operators.coordination.RecreateOnResetOperatorCoordinator;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.StringData;
+import org.apache.flink.table.data.binary.BinaryRowData;
+import org.apache.flink.table.runtime.typeutils.RowDataSerializer;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.table.types.logical.VarCharType;
+import org.junit.Before;
+import org.junit.Test;
+
+/**
+ * Verifies that a coordinator created through {@link DataStatisticsCoordinatorProvider} can be
+ * checkpointed and afterwards reset to an earlier checkpoint.
+ */
+public class TestDataStatisticsCoordinatorProvider {
+  private static final OperatorID OPERATOR_ID = new OperatorID();
+  private static final int NUM_SUBTASKS = 1;
+  private static final RowType ROW_TYPE = RowType.of(new VarCharType());
+
+  private DataStatisticsCoordinatorProvider<MapDataStatistics, Map<RowData, Long>> provider;
+  private EventReceivingTasks receivingTasks;
+  private TypeSerializer<DataStatistics<MapDataStatistics, Map<RowData, Long>>>
+      statisticsSerializer;
+
+  /** Builds a fresh statistics serializer, provider and task registry for each test. */
+  @Before
+  public void before() {
+    RowDataSerializer keySerializer = new RowDataSerializer(ROW_TYPE);
+    statisticsSerializer = MapDataStatisticsSerializer.fromKeySerializer(keySerializer);
+    provider =
+        new DataStatisticsCoordinatorProvider<>(
+            "DataStatisticsCoordinatorProvider", OPERATOR_ID, statisticsSerializer);
+    receivingTasks = EventReceivingTasks.createForRunningTasks();
+  }
+
+  /**
+   * Sends statistics for checkpoints 1 and 2, checkpoints the coordinator after each, then resets
+   * to checkpoint 1 and checks that the recreated coordinator exposes the checkpoint 1 statistics.
+   */
+  @Test
+  @SuppressWarnings("unchecked")
+  public void testCheckpointAndReset() throws Exception {
+    // When the coordinator handles events from the operator,
+    // DataStatisticsUtil#deserializeDataStatistics deserializes bytes into BinaryRowData, so the
+    // keys used here are BinaryRowData as well.
+    BinaryRowData rowA = binaryRow("a");
+    BinaryRowData rowB = binaryRow("b");
+    BinaryRowData rowC = binaryRow("c");
+    BinaryRowData rowD = binaryRow("d");
+    BinaryRowData rowE = binaryRow("e");
+
+    RecreateOnResetOperatorCoordinator recreatingCoordinator =
+        (RecreateOnResetOperatorCoordinator)
+            provider.create(new MockOperatorCoordinatorContext(OPERATOR_ID, NUM_SUBTASKS));
+    DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>> initialCoordinator =
+        (DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>>)
+            recreatingCoordinator.getInternalCoordinator();
+
+    // Start the coordinator and mark its subtasks as ready
+    recreatingCoordinator.start();
+    TestDataStatisticsCoordinator.setAllTasksReady(
+        NUM_SUBTASKS, initialCoordinator, receivingTasks);
+
+    // Checkpoint 1: keys a, b and c, one occurrence each
+    MapDataStatistics firstSubtaskStatistics = new MapDataStatistics();
+    firstSubtaskStatistics.add(rowA);
+    firstSubtaskStatistics.add(rowB);
+    firstSubtaskStatistics.add(rowC);
+    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>> firstEvent =
+        DataStatisticsEvent.create(1, firstSubtaskStatistics, statisticsSerializer);
+    recreatingCoordinator.handleEventFromOperator(0, 0, firstEvent);
+    TestDataStatisticsCoordinator.waitForCoordinatorToProcessActions(initialCoordinator);
+
+    // With a single subtask the global statistics must equal that subtask's statistics
+    MapDataStatistics firstGlobalStatistics =
+        (MapDataStatistics) initialCoordinator.completedStatistics().dataStatistics();
+    assertThat(firstGlobalStatistics.statistics())
+        .isEqualTo(firstSubtaskStatistics.statistics());
+    byte[] firstCheckpointBytes = waitForCheckpoint(1L, initialCoordinator);
+
+    // Checkpoint 2: key d once, key e twice
+    MapDataStatistics secondSubtaskStatistics = new MapDataStatistics();
+    secondSubtaskStatistics.add(rowD);
+    secondSubtaskStatistics.add(rowE);
+    secondSubtaskStatistics.add(rowE);
+    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>> secondEvent =
+        DataStatisticsEvent.create(2, secondSubtaskStatistics, statisticsSerializer);
+    recreatingCoordinator.handleEventFromOperator(0, 0, secondEvent);
+    TestDataStatisticsCoordinator.waitForCoordinatorToProcessActions(initialCoordinator);
+
+    MapDataStatistics secondGlobalStatistics =
+        (MapDataStatistics) initialCoordinator.completedStatistics().dataStatistics();
+    assertThat(secondGlobalStatistics.statistics())
+        .isEqualTo(secondSubtaskStatistics.statistics());
+    waitForCheckpoint(2L, initialCoordinator);
+
+    // Reset coordinator to checkpoint 1; the wrapper is expected to hand out a new internal
+    // coordinator afterwards
+    recreatingCoordinator.resetToCheckpoint(1L, firstCheckpointBytes);
+    DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>> restoredCoordinator =
+        (DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>>)
+            recreatingCoordinator.getInternalCoordinator();
+    assertThat(initialCoordinator).isNotEqualTo(restoredCoordinator);
+
+    // The restored coordinator must expose the statistics captured at checkpoint 1
+    MapDataStatistics restoredStatistics =
+        (MapDataStatistics) restoredCoordinator.completedStatistics().dataStatistics();
+    assertThat(restoredStatistics.statistics()).isEqualTo(firstGlobalStatistics.statistics());
+  }
+
+  /**
+   * Converts a single string value into a one-column {@link BinaryRowData}.
+   *
+   * <p>A dedicated serializer is created for every row. NOTE(review): toBinaryRow presumably may
+   * return a buffer owned by the serializer, so sharing one serializer across rows is avoided;
+   * confirm against Flink's RowDataSerializer.
+   */
+  private static BinaryRowData binaryRow(String value) {
+    GenericRowData genericRow = GenericRowData.of(StringData.fromString(value));
+    return new RowDataSerializer(ROW_TYPE).toBinaryRow(genericRow);
+  }
+
+  /**
+   * Asks the coordinator to checkpoint and blocks until the checkpoint bytes are available.
+   *
+   * @param checkpointId id passed to the coordinator for this checkpoint
+   * @param coordinator coordinator to checkpoint
+   * @return the bytes the coordinator completed the checkpoint future with
+   */
+  private byte[] waitForCheckpoint(
+      long checkpointId,
+      DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>> coordinator)
+      throws InterruptedException, ExecutionException {
+    CompletableFuture<byte[]> checkpointResult = new CompletableFuture<>();
+    coordinator.checkpointCoordinator(checkpointId, checkpointResult);
+    return checkpointResult.get();
+  }
+}
diff --git 
a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java
 
b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java
index 039d70a69d..880cb3d551 100644
--- 
a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java
+++ 
b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java
@@ -19,7 +19,6 @@
 package org.apache.iceberg.flink.sink.shuffle;
 
 import static org.assertj.core.api.Assertions.assertThat;
-import static org.junit.Assert.assertTrue;
 
 import java.util.Collections;
 import java.util.List;
@@ -51,13 +50,13 @@ import 
org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
 import org.apache.flink.table.data.GenericRowData;
 import org.apache.flink.table.data.RowData;
 import org.apache.flink.table.data.StringData;
+import org.apache.flink.table.data.binary.BinaryRowData;
 import org.apache.flink.table.runtime.typeutils.RowDataSerializer;
 import org.apache.flink.table.types.logical.RowType;
 import org.apache.flink.table.types.logical.VarCharType;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
-import org.apache.iceberg.relocated.com.google.common.collect.Maps;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
@@ -65,7 +64,18 @@ import org.junit.Test;
 public class TestDataStatisticsOperator {
   private final RowType rowType = RowType.of(new VarCharType());
   private final TypeSerializer<RowData> rowSerializer = new 
RowDataSerializer(rowType);
-
+  private final GenericRowData genericRowDataA = 
GenericRowData.of(StringData.fromString("a"));
+  private final GenericRowData genericRowDataB = 
GenericRowData.of(StringData.fromString("b"));
+  // When operator handles events from coordinator, 
DataStatisticsUtil#deserializeDataStatistics
+  // deserializes bytes into BinaryRowData
+  private final BinaryRowData binaryRowDataA =
+      new 
RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("a")));
+  private final BinaryRowData binaryRowDataB =
+      new 
RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("b")));
+  private final BinaryRowData binaryRowDataC =
+      new 
RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("c")));
+  private final TypeSerializer<DataStatistics<MapDataStatistics, Map<RowData, 
Long>>>
+      statisticsSerializer = 
MapDataStatisticsSerializer.fromKeySerializer(rowSerializer);
   private DataStatisticsOperator<MapDataStatistics, Map<RowData, Long>> 
operator;
 
   private Environment getTestingEnvironment() {
@@ -101,9 +111,8 @@ public class TestDataStatisticsOperator {
           }
         };
 
-    TypeSerializer<DataStatistics<MapDataStatistics, Map<RowData, Long>>> 
statisticsSerializer =
-        MapDataStatisticsSerializer.fromKeySerializer(rowSerializer);
-    return new DataStatisticsOperator<>(keySelector, mockGateway, 
statisticsSerializer);
+    return new DataStatisticsOperator<>(
+        "testOperator", keySelector, mockGateway, statisticsSerializer);
   }
 
   @After
@@ -118,20 +127,16 @@ public class TestDataStatisticsOperator {
         testHarness = createHarness(this.operator)) {
       StateInitializationContext stateContext = getStateContext();
       operator.initializeState(stateContext);
-      operator.processElement(new 
StreamRecord<>(GenericRowData.of(StringData.fromString("a"))));
-      operator.processElement(new 
StreamRecord<>(GenericRowData.of(StringData.fromString("a"))));
-      operator.processElement(new 
StreamRecord<>(GenericRowData.of(StringData.fromString("b"))));
-      assertTrue(operator.localDataStatistics() instanceof MapDataStatistics);
+      operator.processElement(new StreamRecord<>(genericRowDataA));
+      operator.processElement(new StreamRecord<>(genericRowDataA));
+      operator.processElement(new StreamRecord<>(genericRowDataB));
+      
assertThat(operator.localDataStatistics()).isInstanceOf(MapDataStatistics.class);
       MapDataStatistics mapDataStatistics = (MapDataStatistics) 
operator.localDataStatistics();
       Map<RowData, Long> statsMap = mapDataStatistics.statistics();
       assertThat(statsMap).hasSize(2);
       assertThat(statsMap)
           .containsExactlyInAnyOrderEntriesOf(
-              ImmutableMap.of(
-                  GenericRowData.of(StringData.fromString("a")),
-                  2L,
-                  GenericRowData.of(StringData.fromString("b")),
-                  1L));
+              ImmutableMap.of(genericRowDataA, 2L, genericRowDataB, 1L));
       testHarness.endInput();
     }
   }
@@ -141,9 +146,9 @@ public class TestDataStatisticsOperator {
     try (OneInputStreamOperatorTestHarness<
             RowData, DataStatisticsOrRecord<MapDataStatistics, Map<RowData, 
Long>>>
         testHarness = createHarness(this.operator)) {
-      testHarness.processElement(new 
StreamRecord<>(GenericRowData.of(StringData.fromString("a"))));
-      testHarness.processElement(new 
StreamRecord<>(GenericRowData.of(StringData.fromString("b"))));
-      testHarness.processElement(new 
StreamRecord<>(GenericRowData.of(StringData.fromString("b"))));
+      testHarness.processElement(new StreamRecord<>(genericRowDataA));
+      testHarness.processElement(new StreamRecord<>(genericRowDataB));
+      testHarness.processElement(new StreamRecord<>(genericRowDataB));
 
       List<RowData> recordsOutput =
           testHarness.extractOutputValues().stream()
@@ -152,10 +157,7 @@ public class TestDataStatisticsOperator {
               .collect(Collectors.toList());
       assertThat(recordsOutput)
           .containsExactlyInAnyOrderElementsOf(
-              ImmutableList.of(
-                  GenericRowData.of(StringData.fromString("a")),
-                  GenericRowData.of(StringData.fromString("b")),
-                  GenericRowData.of(StringData.fromString("b"))));
+              ImmutableList.of(genericRowDataA, genericRowDataB, 
genericRowDataB));
     }
   }
 
@@ -167,21 +169,16 @@ public class TestDataStatisticsOperator {
         testHarness1 = createHarness(this.operator)) {
       DataStatistics<MapDataStatistics, Map<RowData, Long>> mapDataStatistics =
           new MapDataStatistics();
-      mapDataStatistics.add(GenericRowData.of(StringData.fromString("a")));
-      mapDataStatistics.add(GenericRowData.of(StringData.fromString("a")));
-      mapDataStatistics.add(GenericRowData.of(StringData.fromString("b")));
-      mapDataStatistics.add(GenericRowData.of(StringData.fromString("c")));
-      operator.handleOperatorEvent(new DataStatisticsEvent(0, 
mapDataStatistics));
+      mapDataStatistics.add(binaryRowDataA);
+      mapDataStatistics.add(binaryRowDataA);
+      mapDataStatistics.add(binaryRowDataB);
+      mapDataStatistics.add(binaryRowDataC);
+      operator.handleOperatorEvent(
+          DataStatisticsEvent.create(0, mapDataStatistics, 
statisticsSerializer));
       
assertThat(operator.globalDataStatistics()).isInstanceOf(MapDataStatistics.class);
-      assertThat(((MapDataStatistics) 
operator.globalDataStatistics()).statistics())
+      assertThat(operator.globalDataStatistics().statistics())
           .containsExactlyInAnyOrderEntriesOf(
-              ImmutableMap.of(
-                  GenericRowData.of(StringData.fromString("a")),
-                  2L,
-                  GenericRowData.of(StringData.fromString("b")),
-                  1L,
-                  GenericRowData.of(StringData.fromString("c")),
-                  1L));
+              ImmutableMap.of(binaryRowDataA, 2L, binaryRowDataB, 1L, 
binaryRowDataC, 1L));
       snapshot = testHarness1.snapshot(1L, 0);
     }
 
@@ -195,22 +192,9 @@ public class TestDataStatisticsOperator {
       testHarness2.setup();
       testHarness2.initializeState(snapshot);
       
assertThat(restoredOperator.globalDataStatistics()).isInstanceOf(MapDataStatistics.class);
-      // restored RowData is BinaryRowData. convert to GenericRowData for 
comparison
-      Map<RowData, Long> restoredStatistics = Maps.newHashMap();
-      ((MapDataStatistics) restoredOperator.globalDataStatistics())
-          .statistics()
-          .forEach(
-              (rowData, count) ->
-                  
restoredStatistics.put(GenericRowData.of(rowData.getString(0)), count));
-      assertThat(restoredStatistics)
+      assertThat(restoredOperator.globalDataStatistics().statistics())
           .containsExactlyInAnyOrderEntriesOf(
-              ImmutableMap.of(
-                  GenericRowData.of(StringData.fromString("a")),
-                  2L,
-                  GenericRowData.of(StringData.fromString("b")),
-                  1L,
-                  GenericRowData.of(StringData.fromString("c")),
-                  1L));
+              ImmutableMap.of(binaryRowDataA, 2L, binaryRowDataB, 1L, 
binaryRowDataC, 1L));
     }
   }
 

Reply via email to