This is an automated email from the ASF dual-hosted git repository.
stevenwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/master by this push:
new e8a28affb8 Flink: Implement data statistics coordinator to aggregate
data statistics from operator subtasks (#7360)
e8a28affb8 is described below
commit e8a28affb8bc8eb023352da2556be9c1b5e50141
Author: gangy <[email protected]>
AuthorDate: Sun Sep 24 20:57:16 2023 -0700
Flink: Implement data statistics coordinator to aggregate data statistics
from operator subtasks (#7360)
---
...tisticsEvent.java => AggregatedStatistics.java} | 34 +-
.../sink/shuffle/AggregatedStatisticsTracker.java | 133 +++++++
.../sink/shuffle/DataStatisticsCoordinator.java | 395 +++++++++++++++++++++
.../shuffle/DataStatisticsCoordinatorProvider.java | 51 +++
.../flink/sink/shuffle/DataStatisticsEvent.java | 32 +-
.../flink/sink/shuffle/DataStatisticsOperator.java | 55 ++-
.../flink/sink/shuffle/DataStatisticsOrRecord.java | 3 +-
.../flink/sink/shuffle/DataStatisticsUtil.java | 97 +++++
.../sink/shuffle/TestAggregatedStatistics.java | 61 ++++
.../shuffle/TestAggregatedStatisticsTracker.java | 177 +++++++++
.../shuffle/TestDataStatisticsCoordinator.java | 174 +++++++++
.../TestDataStatisticsCoordinatorProvider.java | 147 ++++++++
.../sink/shuffle/TestDataStatisticsOperator.java | 84 ++---
13 files changed, 1357 insertions(+), 86 deletions(-)
diff --git
a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsEvent.java
b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatistics.java
similarity index 53%
copy from
flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsEvent.java
copy to
flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatistics.java
index 3aba66fd42..157f04b8b0 100644
---
a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsEvent.java
+++
b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatistics.java
@@ -18,24 +18,28 @@
*/
package org.apache.iceberg.flink.sink.shuffle;
-import org.apache.flink.annotation.Internal;
-import org.apache.flink.runtime.operators.coordination.OperatorEvent;
+import java.io.Serializable;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.iceberg.relocated.com.google.common.base.MoreObjects;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
/**
- * DataStatisticsEvent is sent between data statistics coordinator and
operator to transmit data
- * statistics
+ * AggregatedStatistics is used by {@link DataStatisticsCoordinator} to
collect {@link
+ * DataStatistics} from {@link DataStatisticsOperator} subtasks for specific
checkpoint. It stores
+ * the merged {@link DataStatistics} result from all reported subtasks.
*/
-@Internal
-class DataStatisticsEvent<D extends DataStatistics<D, S>, S> implements
OperatorEvent {
-
- private static final long serialVersionUID = 1L;
+class AggregatedStatistics<D extends DataStatistics<D, S>, S> implements
Serializable {
+  // Explicit UID: this class implements Serializable, so without it any compatible change to the
+  // class alters the JVM-computed UID and breaks deserialization of previously written instances
+  // (NOTE(review): relevant if DataStatisticsUtil Java-serializes it into checkpoints — confirm).
+  private static final long serialVersionUID = 1L;
+
   private final long checkpointId;
   private final DataStatistics<D, S> dataStatistics;
-  DataStatisticsEvent(long checkpointId, DataStatistics<D, S> dataStatistics) {
-    this.checkpointId = checkpointId;
+  /**
+   * Starts an aggregation for {@code checkpoint} backed by a new instance obtained from {@code
+   * statisticsSerializer.createInstance()}.
+   */
+  AggregatedStatistics(long checkpoint, TypeSerializer<DataStatistics<D, S>> statisticsSerializer) {
+    this(checkpoint, statisticsSerializer.createInstance());
+  }
+
+  /** Wraps already-built {@code dataStatistics} as the aggregation for {@code checkpoint}. */
+  AggregatedStatistics(long checkpoint, DataStatistics<D, S> dataStatistics) {
+    this.checkpointId = checkpoint;
     this.dataStatistics = dataStatistics;
   }
@@ -47,6 +51,16 @@ class DataStatisticsEvent<D extends DataStatistics<D, S>, S>
implements Operator
return dataStatistics;
}
+  /**
+   * Folds the statistics reported by one operator subtask into this aggregation.
+   *
+   * @param operatorName name of the reporting operator; only used in the failure message
+   * @param eventCheckpointId checkpoint id the reported statistics belong to
+   * @param eventDataStatistics statistics to merge into {@code dataStatistics}
+   * @throws IllegalArgumentException if {@code eventCheckpointId} is not this aggregation's
+   *     checkpoint id
+   */
+  void mergeDataStatistic(String operatorName, long eventCheckpointId, D eventDataStatistics) {
+    boolean sameCheckpoint = eventCheckpointId == checkpointId;
+    Preconditions.checkArgument(
+        sameCheckpoint,
+        "Received unexpected event from operator %s checkpoint %s. Expected checkpoint %s",
+        operatorName,
+        eventCheckpointId,
+        checkpointId);
+    dataStatistics.merge(eventDataStatistics);
+  }
+
@Override
public String toString() {
return MoreObjects.toStringHelper(this)
diff --git
a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatisticsTracker.java
b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatisticsTracker.java
new file mode 100644
index 0000000000..e8ff61dbeb
--- /dev/null
+++
b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatisticsTracker.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import java.util.Set;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import
org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting;
+import org.apache.iceberg.relocated.com.google.common.collect.Sets;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * AggregatedStatisticsTracker is used by {@link DataStatisticsCoordinator} to track the in-progress
+ * {@link AggregatedStatistics} received from {@link DataStatisticsOperator} subtasks for a specific
+ * checkpoint.
+ *
+ * <p>Only one checkpoint is aggregated at a time. Events from a checkpoint older than the one in
+ * progress are ignored. An event from a newer checkpoint closes the in-progress aggregation: it is
+ * returned as completed if at least {@code ACCEPT_PARTIAL_AGGR_THRESHOLD} percent of the subtasks
+ * reported, and discarded otherwise.
+ *
+ * <p>This class keeps its state in a plain {@link java.util.HashSet} and is not thread-safe; {@link
+ * DataStatisticsCoordinator} only calls it from its coordinator executor thread.
+ */
+class AggregatedStatisticsTracker<D extends DataStatistics<D, S>, S> {
+  private static final Logger LOG = LoggerFactory.getLogger(AggregatedStatisticsTracker.class);
+  // Percentage of subtasks that must have reported before a superseded aggregation is accepted.
+  private static final double ACCEPT_PARTIAL_AGGR_THRESHOLD = 90;
+  private final String operatorName;
+  private final TypeSerializer<DataStatistics<D, S>> statisticsSerializer;
+  private final int parallelism;
+  // Subtasks whose statistics have already been merged into inProgressStatistics.
+  private final Set<Integer> inProgressSubtaskSet;
+  private volatile AggregatedStatistics<D, S> inProgressStatistics;
+
+  AggregatedStatisticsTracker(
+      String operatorName,
+      TypeSerializer<DataStatistics<D, S>> statisticsSerializer,
+      int parallelism) {
+    this.operatorName = operatorName;
+    this.statisticsSerializer = statisticsSerializer;
+    this.parallelism = parallelism;
+    this.inProgressSubtaskSet = Sets.newHashSet();
+  }
+
+  /**
+   * Merges the statistics carried by {@code event} into the in-progress aggregation and reports
+   * whether an aggregation completed.
+   *
+   * @param subtask index of the operator subtask that sent the event
+   * @param event serialized statistics reported by that subtask
+   * @return the completed aggregation, or {@code null} if none completed with this event. An
+   *     aggregation completes when all {@code parallelism} subtasks reported for its checkpoint, or
+   *     when an event from a newer checkpoint arrives after the partial threshold was reached.
+   */
+  AggregatedStatistics<D, S> updateAndCheckCompletion(
+      int subtask, DataStatisticsEvent<D, S> event) {
+    long checkpointId = event.checkpointId();
+
+    if (inProgressStatistics != null && inProgressStatistics.checkpointId() > checkpointId) {
+      LOG.info(
+          "Expect data statistics for operator {} checkpoint {}, but receive event from older checkpoint {}. Ignore it.",
+          operatorName,
+          inProgressStatistics.checkpointId(),
+          checkpointId);
+      return null;
+    }
+
+    AggregatedStatistics<D, S> completedStatistics = null;
+    if (inProgressStatistics != null && inProgressStatistics.checkpointId() < checkpointId) {
+      // The subtask counts below belong to the superseded checkpoint, not to the new event's one.
+      long supersededCheckpointId = inProgressStatistics.checkpointId();
+      if ((double) inProgressSubtaskSet.size() / parallelism * 100
+          >= ACCEPT_PARTIAL_AGGR_THRESHOLD) {
+        completedStatistics = inProgressStatistics;
+        LOG.info(
+            "Received data statistics from {} subtasks out of total {} for operator {} at checkpoint {} "
+                + "when an event for newer checkpoint {} arrived. Complete data statistics aggregation "
+                + "at checkpoint {} as it reached the threshold of {} percentage",
+            inProgressSubtaskSet.size(),
+            parallelism,
+            operatorName,
+            supersededCheckpointId,
+            checkpointId,
+            supersededCheckpointId,
+            ACCEPT_PARTIAL_AGGR_THRESHOLD);
+      } else {
+        LOG.info(
+            "Received data statistics from {} subtasks out of total {} for operator {} at checkpoint {} "
+                + "when an event for newer checkpoint {} arrived. "
+                + "Aborting the incomplete aggregation for checkpoint {}",
+            inProgressSubtaskSet.size(),
+            parallelism,
+            operatorName,
+            supersededCheckpointId,
+            checkpointId,
+            supersededCheckpointId);
+      }
+
+      inProgressStatistics = null;
+      inProgressSubtaskSet.clear();
+    }
+
+    // Whenever inProgressStatistics is null the subtask set is already empty, so no clear() needed.
+    if (inProgressStatistics == null) {
+      LOG.info("Starting a new data statistics for checkpoint {}", checkpointId);
+      inProgressStatistics = new AggregatedStatistics<>(checkpointId, statisticsSerializer);
+    }
+
+    if (inProgressSubtaskSet.contains(subtask)) {
+      LOG.debug(
+          "Ignore duplicated data statistics from operator {} subtask {} for checkpoint {}.",
+          operatorName,
+          subtask,
+          checkpointId);
+    } else {
+      // Merge before recording the subtask: if deserialization or merging throws, the subtask must
+      // not be marked as reported while its statistics are missing from the aggregation.
+      inProgressStatistics.mergeDataStatistic(
+          operatorName,
+          checkpointId,
+          DataStatisticsUtil.deserializeDataStatistics(
+              event.statisticsBytes(), statisticsSerializer));
+      inProgressSubtaskSet.add(subtask);
+    }
+
+    if (inProgressSubtaskSet.size() == parallelism) {
+      completedStatistics = inProgressStatistics;
+      LOG.info(
+          "Received data statistics from all {} operators {} for checkpoint {}. Return last completed aggregator {}.",
+          parallelism,
+          operatorName,
+          inProgressStatistics.checkpointId(),
+          completedStatistics.dataStatistics());
+      // Pre-create the next checkpoint's aggregation so that late or duplicate events for the
+      // just-completed checkpoint hit the "older checkpoint" branch above and are ignored.
+      inProgressStatistics = new AggregatedStatistics<>(checkpointId + 1, statisticsSerializer);
+      inProgressSubtaskSet.clear();
+    }
+
+    return completedStatistics;
+  }
+
+  @VisibleForTesting
+  AggregatedStatistics<D, S> inProgressStatistics() {
+    return inProgressStatistics;
+  }
+}
diff --git
a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinator.java
b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinator.java
new file mode 100644
index 0000000000..fcfd798842
--- /dev/null
+++
b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinator.java
@@ -0,0 +1,395 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import java.util.Map;
+import java.util.concurrent.Callable;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ThreadFactory;
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.operators.coordination.OperatorCoordinator;
+import org.apache.flink.runtime.operators.coordination.OperatorEvent;
+import org.apache.flink.util.ExceptionUtils;
+import org.apache.flink.util.FatalExitExceptionHandler;
+import org.apache.flink.util.FlinkRuntimeException;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.ThrowableCatchingRunnable;
+import org.apache.flink.util.function.ThrowingRunnable;
+import
org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting;
+import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+import org.jetbrains.annotations.NotNull;
+import org.jetbrains.annotations.Nullable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * DataStatisticsCoordinator receives {@link DataStatisticsEvent}s from every {@link
+ * DataStatisticsOperator} subtask and merges them together. Once an aggregation completes,
+ * DataStatisticsCoordinator sends the aggregated data statistics back to every {@link
+ * DataStatisticsOperator} subtask. In the end a custom partitioner will distribute traffic based on
+ * the aggregated data statistics to improve data clustering.
+ *
+ * <p>Operator events, checkpoints and subtask gateway changes are handled on a single coordinator
+ * executor thread, so the statistics tracker and the gateway bookkeeping are only mutated there.
+ */
+@Internal
+class DataStatisticsCoordinator<D extends DataStatistics<D, S>, S> implements OperatorCoordinator {
+  private static final Logger LOG = LoggerFactory.getLogger(DataStatisticsCoordinator.class);
+
+  private final String operatorName;
+  private final ExecutorService coordinatorExecutor;
+  private final OperatorCoordinator.Context operatorCoordinatorContext;
+  private final SubtaskGateways subtaskGateways;
+  private final CoordinatorExecutorThreadFactory coordinatorThreadFactory;
+  private final TypeSerializer<DataStatistics<D, S>> statisticsSerializer;
+  private final AggregatedStatisticsTracker<D, S> aggregatedStatisticsTracker;
+  private volatile AggregatedStatistics<D, S> completedStatistics;
+  private volatile boolean started;
+
+  DataStatisticsCoordinator(
+      String operatorName,
+      OperatorCoordinator.Context context,
+      TypeSerializer<DataStatistics<D, S>> statisticsSerializer) {
+    this.operatorName = operatorName;
+    this.coordinatorThreadFactory =
+        new CoordinatorExecutorThreadFactory(
+            "DataStatisticsCoordinator-" + operatorName, context.getUserCodeClassloader());
+    this.coordinatorExecutor = Executors.newSingleThreadExecutor(coordinatorThreadFactory);
+    this.operatorCoordinatorContext = context;
+    // parallelism() reads operatorCoordinatorContext, so the context must be assigned first.
+    this.subtaskGateways = new SubtaskGateways(operatorName, parallelism());
+    this.statisticsSerializer = statisticsSerializer;
+    this.aggregatedStatisticsTracker =
+        new AggregatedStatisticsTracker<>(operatorName, statisticsSerializer, parallelism());
+  }
+
+  @Override
+  public void start() throws Exception {
+    LOG.info("Starting data statistics coordinator: {}.", operatorName);
+    started = true;
+  }
+
+  @Override
+  public void close() throws Exception {
+    // NOTE(review): shutdown() stops accepting new tasks but does not wait for queued ones.
+    coordinatorExecutor.shutdown();
+    LOG.info("Closed data statistics coordinator: {}.", operatorName);
+  }
+
+  /**
+   * Runs {@code callable} on the coordinator thread and waits for it to finish.
+   *
+   * @throws FlinkRuntimeException wrapping the failure (with {@code errorMessage}) if the callable
+   *     throws or the waiting thread is interrupted
+   */
+  @VisibleForTesting
+  void callInCoordinatorThread(Callable<Void> callable, String errorMessage) {
+    ensureStarted();
+    // Ensure the task is done by the coordinator executor.
+    if (!coordinatorThreadFactory.isCurrentThreadCoordinatorThread()) {
+      try {
+        Callable<Void> guardedCallable =
+            () -> {
+              try {
+                return callable.call();
+              } catch (Throwable t) {
+                LOG.error(
+                    "Uncaught Exception in data statistics coordinator: {} executor",
+                    operatorName,
+                    t);
+                ExceptionUtils.rethrowException(t);
+                return null;
+              }
+            };
+
+        coordinatorExecutor.submit(guardedCallable).get();
+      } catch (InterruptedException e) {
+        // Restore the interrupt status so callers further up the stack can still observe it.
+        Thread.currentThread().interrupt();
+        throw new FlinkRuntimeException(errorMessage, e);
+      } catch (ExecutionException e) {
+        throw new FlinkRuntimeException(errorMessage, e);
+      }
+    } else {
+      try {
+        callable.call();
+      } catch (Throwable t) {
+        LOG.error(
+            "Uncaught Exception in data statistics coordinator: {} executor", operatorName, t);
+        throw new FlinkRuntimeException(errorMessage, t);
+      }
+    }
+  }
+
+  /**
+   * Schedules {@code runnable} on the coordinator thread without waiting. Throwables escaping it are
+   * forwarded to the thread factory's uncaught-exception handler.
+   */
+  public void runInCoordinatorThread(Runnable runnable) {
+    this.coordinatorExecutor.execute(
+        new ThrowableCatchingRunnable(
+            throwable ->
+                this.coordinatorThreadFactory.uncaughtException(Thread.currentThread(), throwable),
+            runnable));
+  }
+
+  /**
+   * Schedules {@code action} on the coordinator thread. Any non-fatal failure is logged and fails
+   * the job through the coordinator context.
+   */
+  private void runInCoordinatorThread(ThrowingRunnable<Throwable> action, String actionString) {
+    ensureStarted();
+    runInCoordinatorThread(
+        () -> {
+          try {
+            action.run();
+          } catch (Throwable t) {
+            ExceptionUtils.rethrowIfFatalErrorOrOOM(t);
+            LOG.error(
+                "Uncaught exception in the data statistics coordinator: {} while {}. Triggering job failover",
+                operatorName,
+                actionString,
+                t);
+            operatorCoordinatorContext.failJob(t);
+          }
+        });
+  }
+
+  private void ensureStarted() {
+    Preconditions.checkState(started, "The coordinator of %s has not started yet.", operatorName);
+  }
+
+  private int parallelism() {
+    return operatorCoordinatorContext.currentParallelism();
+  }
+
+  /** Feeds one subtask event into the tracker and broadcasts any aggregation that completes. */
+  private void handleDataStatisticRequest(int subtask, DataStatisticsEvent<D, S> event) {
+    AggregatedStatistics<D, S> aggregatedStatistics =
+        aggregatedStatisticsTracker.updateAndCheckCompletion(subtask, event);
+
+    if (aggregatedStatistics != null) {
+      completedStatistics = aggregatedStatistics;
+      sendDataStatisticsToSubtasks(
+          completedStatistics.checkpointId(), completedStatistics.dataStatistics());
+    }
+  }
+
+  /** Serializes the global statistics once and sends the same event to every subtask gateway. */
+  private void sendDataStatisticsToSubtasks(
+      long checkpointId, DataStatistics<D, S> globalDataStatistics) {
+    callInCoordinatorThread(
+        () -> {
+          DataStatisticsEvent<D, S> dataStatisticsEvent =
+              DataStatisticsEvent.create(checkpointId, globalDataStatistics, statisticsSerializer);
+          int parallelism = parallelism();
+          for (int i = 0; i < parallelism; ++i) {
+            subtaskGateways.getSubtaskGateway(i).sendEvent(dataStatisticsEvent);
+          }
+
+          return null;
+        },
+        String.format(
+            "Failed to send operator %s coordinator global data statistics for checkpoint %d",
+            operatorName, checkpointId));
+  }
+
+  @Override
+  @SuppressWarnings("unchecked")
+  public void handleEventFromOperator(int subtask, int attemptNumber, OperatorEvent event) {
+    runInCoordinatorThread(
+        () -> {
+          LOG.debug(
+              "Handling event from subtask {} (#{}) of {}: {}",
+              subtask,
+              attemptNumber,
+              operatorName,
+              event);
+          Preconditions.checkArgument(
+              event instanceof DataStatisticsEvent,
+              "The coordinator of %s received an unexpected event %s from subtask %s",
+              operatorName,
+              event.getClass(),
+              subtask);
+          handleDataStatisticRequest(subtask, ((DataStatisticsEvent<D, S>) event));
+        },
+        String.format(
+            "handling operator event %s from subtask %d (#%d)",
+            event.getClass(), subtask, attemptNumber));
+  }
+
+  @Override
+  public void checkpointCoordinator(long checkpointId, CompletableFuture<byte[]> resultFuture) {
+    runInCoordinatorThread(
+        () -> {
+          LOG.debug(
+              "Snapshotting data statistics coordinator {} for checkpoint {}",
+              operatorName,
+              checkpointId);
+          try {
+            resultFuture.complete(
+                DataStatisticsUtil.serializeAggregatedStatistics(
+                    completedStatistics, statisticsSerializer));
+          } catch (Throwable t) {
+            // Fail the checkpoint explicitly instead of leaving the future pending until it times
+            // out, then rethrow so the enclosing handler still logs the error and fails the job.
+            resultFuture.completeExceptionally(t);
+            throw t;
+          }
+        },
+        String.format("taking checkpoint %d", checkpointId));
+  }
+
+  @Override
+  public void notifyCheckpointComplete(long checkpointId) {}
+
+  /** Restores the last completed aggregation; only legal before {@link #start()}. */
+  @Override
+  public void resetToCheckpoint(long checkpointId, @Nullable byte[] checkpointData)
+      throws Exception {
+    Preconditions.checkState(
+        !started, "The coordinator %s can only be reset if it was not yet started", operatorName);
+
+    if (checkpointData == null) {
+      LOG.info(
+          "Data statistic coordinator {} has nothing to restore from checkpoint {}",
+          operatorName,
+          checkpointId);
+      return;
+    }
+
+    LOG.info(
+        "Restoring data statistic coordinator {} from checkpoint {}", operatorName, checkpointId);
+    completedStatistics =
+        DataStatisticsUtil.deserializeAggregatedStatistics(checkpointData, statisticsSerializer);
+  }
+
+  @Override
+  public void subtaskReset(int subtask, long checkpointId) {
+    runInCoordinatorThread(
+        () -> {
+          LOG.info(
+              "Operator {} subtask {} is reset to checkpoint {}",
+              operatorName,
+              subtask,
+              checkpointId);
+          Preconditions.checkState(
+              this.coordinatorThreadFactory.isCurrentThreadCoordinatorThread());
+          subtaskGateways.reset(subtask);
+        },
+        String.format("handling subtask %d recovery to checkpoint %d", subtask, checkpointId));
+  }
+
+  @Override
+  public void executionAttemptFailed(int subtask, int attemptNumber, @Nullable Throwable reason) {
+    runInCoordinatorThread(
+        () -> {
+          LOG.info(
+              "Unregistering gateway after failure for subtask {} (#{}) of data statistic {}",
+              subtask,
+              attemptNumber,
+              operatorName);
+          Preconditions.checkState(
+              this.coordinatorThreadFactory.isCurrentThreadCoordinatorThread());
+          subtaskGateways.unregisterSubtaskGateway(subtask, attemptNumber);
+        },
+        String.format("handling subtask %d (#%d) failure", subtask, attemptNumber));
+  }
+
+  @Override
+  public void executionAttemptReady(int subtask, int attemptNumber, SubtaskGateway gateway) {
+    Preconditions.checkArgument(subtask == gateway.getSubtask());
+    Preconditions.checkArgument(attemptNumber == gateway.getExecution().getAttemptNumber());
+    runInCoordinatorThread(
+        () -> {
+          Preconditions.checkState(
+              this.coordinatorThreadFactory.isCurrentThreadCoordinatorThread());
+          subtaskGateways.registerSubtaskGateway(gateway);
+        },
+        String.format(
+            "making event gateway to subtask %d (#%d) available", subtask, attemptNumber));
+  }
+
+  @VisibleForTesting
+  AggregatedStatistics<D, S> completedStatistics() {
+    return completedStatistics;
+  }
+
+  /** Per-subtask map from execution attempt number to the event gateway registered for it. */
+  private static class SubtaskGateways {
+    private final String operatorName;
+    private final Map<Integer, SubtaskGateway>[] gateways;
+
+    @SuppressWarnings("unchecked")
+    private SubtaskGateways(String operatorName, int parallelism) {
+      this.operatorName = operatorName;
+      // Generic arrays cannot be created directly; every slot is filled with a typed map below.
+      gateways = new Map[parallelism];
+
+      for (int i = 0; i < parallelism; ++i) {
+        gateways[i] = Maps.newHashMap();
+      }
+    }
+
+    private void registerSubtaskGateway(OperatorCoordinator.SubtaskGateway gateway) {
+      int subtaskIndex = gateway.getSubtask();
+      int attemptNumber = gateway.getExecution().getAttemptNumber();
+      Preconditions.checkState(
+          !gateways[subtaskIndex].containsKey(attemptNumber),
+          "Coordinator of %s already has a subtask gateway for %d (#%d)",
+          operatorName,
+          subtaskIndex,
+          attemptNumber);
+      LOG.debug(
+          "Coordinator of {} registers gateway for subtask {} attempt {}",
+          operatorName,
+          subtaskIndex,
+          attemptNumber);
+      gateways[subtaskIndex].put(attemptNumber, gateway);
+    }
+
+    private void unregisterSubtaskGateway(int subtaskIndex, int attemptNumber) {
+      LOG.debug(
+          "Coordinator of {} unregisters gateway for subtask {} attempt {}",
+          operatorName,
+          subtaskIndex,
+          attemptNumber);
+      gateways[subtaskIndex].remove(attemptNumber);
+    }
+
+    /** Returns the single registered gateway; fails if none or more than one is registered. */
+    private OperatorCoordinator.SubtaskGateway getSubtaskGateway(int subtaskIndex) {
+      Preconditions.checkState(
+          gateways[subtaskIndex].size() > 0,
+          "Coordinator of %s subtask %d is not ready yet to receive events",
+          operatorName,
+          subtaskIndex);
+      return Iterables.getOnlyElement(gateways[subtaskIndex].values());
+    }
+
+    private void reset(int subtaskIndex) {
+      gateways[subtaskIndex].clear();
+    }
+  }
+
+  /**
+   * Creates the coordinator thread (with the user-code class loader) and remembers it so callers
+   * can check whether they already run on it.
+   */
+  private static class CoordinatorExecutorThreadFactory
+      implements ThreadFactory, Thread.UncaughtExceptionHandler {
+
+    private final String coordinatorThreadName;
+    private final ClassLoader classLoader;
+    private final Thread.UncaughtExceptionHandler errorHandler;
+
+    @Nullable private Thread thread;
+
+    CoordinatorExecutorThreadFactory(
+        final String coordinatorThreadName, final ClassLoader contextClassLoader) {
+      this(coordinatorThreadName, contextClassLoader, FatalExitExceptionHandler.INSTANCE);
+    }
+
+    @VisibleForTesting
+    CoordinatorExecutorThreadFactory(
+        final String coordinatorThreadName,
+        final ClassLoader contextClassLoader,
+        final Thread.UncaughtExceptionHandler errorHandler) {
+      this.coordinatorThreadName = coordinatorThreadName;
+      this.classLoader = contextClassLoader;
+      this.errorHandler = errorHandler;
+    }
+
+    @Override
+    public synchronized Thread newThread(@NotNull Runnable runnable) {
+      thread = new Thread(runnable, coordinatorThreadName);
+      thread.setContextClassLoader(classLoader);
+      thread.setUncaughtExceptionHandler(this);
+      return thread;
+    }
+
+    @Override
+    public synchronized void uncaughtException(Thread t, Throwable e) {
+      errorHandler.uncaughtException(t, e);
+    }
+
+    boolean isCurrentThreadCoordinatorThread() {
+      return Thread.currentThread() == thread;
+    }
+  }
+}
diff --git
a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinatorProvider.java
b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinatorProvider.java
new file mode 100644
index 0000000000..47dbfc3cfb
--- /dev/null
+++
b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinatorProvider.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.operators.coordination.OperatorCoordinator;
+import
org.apache.flink.runtime.operators.coordination.RecreateOnResetOperatorCoordinator;
+
+/**
+ * DataStatisticsCoordinatorProvider creates {@link DataStatisticsCoordinator} instances bound to one
+ * operator name and statistics serializer.
+ */
+@Internal
+public class DataStatisticsCoordinatorProvider<D extends DataStatistics<D, S>, S>
+    extends RecreateOnResetOperatorCoordinator.Provider {
+
+  // NOTE(review): Flink's OperatorCoordinator.Provider extends Serializable; an explicit UID keeps
+  // the serialized form stable across compatible changes to this class.
+  private static final long serialVersionUID = 1L;
+
+  private final String operatorName;
+  private final TypeSerializer<DataStatistics<D, S>> statisticsSerializer;
+
+  public DataStatisticsCoordinatorProvider(
+      String operatorName,
+      OperatorID operatorID,
+      TypeSerializer<DataStatistics<D, S>> statisticsSerializer) {
+    super(operatorID);
+    this.operatorName = operatorName;
+    this.statisticsSerializer = statisticsSerializer;
+  }
+
+  /** Builds a new coordinator for {@code context} using this provider's name and serializer. */
+  @Override
+  public OperatorCoordinator getCoordinator(OperatorCoordinator.Context context) {
+    return new DataStatisticsCoordinator<>(operatorName, context, statisticsSerializer);
+  }
+}
diff --git
a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsEvent.java
b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsEvent.java
index 3aba66fd42..852d2157b8 100644
---
a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsEvent.java
+++
b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsEvent.java
@@ -19,39 +19,39 @@
package org.apache.iceberg.flink.sink.shuffle;
import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.runtime.operators.coordination.OperatorEvent;
-import org.apache.iceberg.relocated.com.google.common.base.MoreObjects;
/**
* DataStatisticsEvent is sent between data statistics coordinator and
operator to transmit data
- * statistics
+ * statistics in bytes
*/
@Internal
class DataStatisticsEvent<D extends DataStatistics<D, S>, S> implements
OperatorEvent {
private static final long serialVersionUID = 1L;
-
private final long checkpointId;
- private final DataStatistics<D, S> dataStatistics;
+ private final byte[] statisticsBytes;
- DataStatisticsEvent(long checkpointId, DataStatistics<D, S> dataStatistics) {
+ private DataStatisticsEvent(long checkpointId, byte[] statisticsBytes) {
this.checkpointId = checkpointId;
- this.dataStatistics = dataStatistics;
+ this.statisticsBytes = statisticsBytes;
}
- long checkpointId() {
- return checkpointId;
+ static <D extends DataStatistics<D, S>, S> DataStatisticsEvent<D, S> create(
+ long checkpointId,
+ DataStatistics<D, S> dataStatistics,
+ TypeSerializer<DataStatistics<D, S>> statisticsSerializer) {
+ return new DataStatisticsEvent<>(
+ checkpointId,
+ DataStatisticsUtil.serializeDataStatistics(dataStatistics,
statisticsSerializer));
}
- DataStatistics<D, S> dataStatistics() {
- return dataStatistics;
+ long checkpointId() {
+ return checkpointId;
}
- @Override
- public String toString() {
- return MoreObjects.toStringHelper(this)
- .add("checkpointId", checkpointId)
- .add("dataStatistics", dataStatistics)
- .toString();
+ byte[] statisticsBytes() {
+ return statisticsBytes;
}
}
diff --git
a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java
b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java
index 6d4209b02a..d00d5d2e5a 100644
---
a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java
+++
b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java
@@ -18,6 +18,7 @@
*/
package org.apache.iceberg.flink.sink.shuffle;
+import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeutils.TypeSerializer;
@@ -40,11 +41,13 @@ import
org.apache.iceberg.relocated.com.google.common.base.Preconditions;
* shuffle record to improve data clustering while maintaining relative
balanced traffic
* distribution to downstream subtasks.
*/
+@Internal
class DataStatisticsOperator<D extends DataStatistics<D, S>, S>
extends AbstractStreamOperator<DataStatisticsOrRecord<D, S>>
implements OneInputStreamOperator<RowData, DataStatisticsOrRecord<D, S>>,
OperatorEventHandler {
private static final long serialVersionUID = 1L;
+ private final String operatorName;
// keySelector will be used to generate key from data for collecting data
statistics
private final KeySelector<RowData, RowData> keySelector;
private final OperatorEventGateway operatorEventGateway;
@@ -54,9 +57,11 @@ class DataStatisticsOperator<D extends DataStatistics<D, S>,
S>
private transient ListState<DataStatistics<D, S>> globalStatisticsState;
DataStatisticsOperator(
+ String operatorName,
KeySelector<RowData, RowData> keySelector,
OperatorEventGateway operatorEventGateway,
TypeSerializer<DataStatistics<D, S>> statisticsSerializer) {
+ this.operatorName = operatorName;
this.keySelector = keySelector;
this.operatorEventGateway = operatorEventGateway;
this.statisticsSerializer = statisticsSerializer;
@@ -75,10 +80,16 @@ class DataStatisticsOperator<D extends DataStatistics<D,
S>, S>
int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
if (globalStatisticsState.get() == null
|| !globalStatisticsState.get().iterator().hasNext()) {
- LOG.warn("Subtask {} doesn't have global statistics state to restore",
subtaskIndex);
+ LOG.warn(
+ "Operator {} subtask {} doesn't have global statistics state to
restore",
+ operatorName,
+ subtaskIndex);
globalStatistics = statisticsSerializer.createInstance();
} else {
- LOG.info("Restoring global statistics state for subtask {}",
subtaskIndex);
+ LOG.info(
+ "Restoring operator {} global statistics state for subtask {}",
+ operatorName,
+ subtaskIndex);
globalStatistics = globalStatisticsState.get().iterator().next();
}
} else {
@@ -95,12 +106,22 @@ class DataStatisticsOperator<D extends DataStatistics<D,
S>, S>
}
@Override
+ @SuppressWarnings("unchecked")
public void handleOperatorEvent(OperatorEvent event) {
+ int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
Preconditions.checkArgument(
event instanceof DataStatisticsEvent,
- "Received unexpected operator event " + event.getClass());
+ String.format(
+ "Operator %s subtask %s received unexpected operator event %s",
+ operatorName, subtaskIndex, event.getClass()));
DataStatisticsEvent<D, S> statisticsEvent = (DataStatisticsEvent<D, S>)
event;
- globalStatistics = statisticsEvent.dataStatistics();
+ LOG.info(
+ "Operator {} received global data event from coordinator checkpoint
{}",
+ operatorName,
+ statisticsEvent.checkpointId());
+ globalStatistics =
+ DataStatisticsUtil.deserializeDataStatistics(
+ statisticsEvent.statisticsBytes(), statisticsSerializer);
output.collect(new
StreamRecord<>(DataStatisticsOrRecord.fromDataStatistics(globalStatistics)));
}
@@ -117,21 +138,39 @@ class DataStatisticsOperator<D extends DataStatistics<D,
S>, S>
long checkpointId = context.getCheckpointId();
int subTaskId = getRuntimeContext().getIndexOfThisSubtask();
LOG.info(
- "Taking data statistics operator snapshot for checkpoint {} in subtask
{}",
+ "Snapshotting data statistics operator {} for checkpoint {} in subtask
{}",
+ operatorName,
checkpointId,
subTaskId);
+ // Pass global statistics to partitioners so that all the operators
refresh statistics
+ // at same checkpoint barrier
+ if (!globalStatistics.isEmpty()) {
+ output.collect(
+ new
StreamRecord<>(DataStatisticsOrRecord.fromDataStatistics(globalStatistics)));
+ }
+
// Only subtask 0 saves the state so that
globalStatisticsState(UnionListState) stores
// an exact copy of globalStatistics
if (!globalStatistics.isEmpty() &&
getRuntimeContext().getIndexOfThisSubtask() == 0) {
globalStatisticsState.clear();
- LOG.info("Saving global statistics {} to state in subtask {}",
globalStatistics, subTaskId);
+ LOG.info(
+ "Saving operator {} global statistics {} to state in subtask {}",
+ operatorName,
+ globalStatistics,
+ subTaskId);
globalStatisticsState.add(globalStatistics);
}
- // For now, we make it simple to send globalStatisticsState at checkpoint
+ // For now, local statistics are sent to coordinator at checkpoint
operatorEventGateway.sendEventToCoordinator(
- new DataStatisticsEvent<>(checkpointId, localStatistics));
+ DataStatisticsEvent.create(checkpointId, localStatistics,
statisticsSerializer));
+ LOG.debug(
+ "Subtask {} of operator {} sent local statistics to coordinator at
checkpoint{}: {}",
+ subTaskId,
+ operatorName,
+ checkpointId,
+ localStatistics);
// Recreate the local statistics
localStatistics = statisticsSerializer.createInstance();
diff --git
a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOrRecord.java
b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOrRecord.java
index b8ced3a47d..889e85112e 100644
---
a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOrRecord.java
+++
b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOrRecord.java
@@ -43,8 +43,9 @@ class DataStatisticsOrRecord<D extends DataStatistics<D, S>, S> implements Seria
 
   private DataStatisticsOrRecord(DataStatistics<D, S> statistics, RowData record) {
+    // Exactly one of the two payloads must be non-null (XOR of the two null checks).
     Preconditions.checkArgument(
         record != null ^ statistics != null,
-        "A DataStatisticsOrRecord contain either statistics or record, not neither or both");
+        "A DataStatisticsOrRecord must contain either statistics or record, not neither or both");
     this.statistics = statistics;
     this.record = record;
   }
diff --git
a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsUtil.java
b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsUtil.java
new file mode 100644
index 0000000000..2737b1346f
--- /dev/null
+++
b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsUtil.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataInputDeserializer;
+import org.apache.flink.core.memory.DataOutputSerializer;
+
+/**
+ * DataStatisticsUtil is the utility to serialize and deserialize {@link DataStatistics} and {@link
+ * AggregatedStatistics}
+ */
+final class DataStatisticsUtil {
+
+  private DataStatisticsUtil() {}
+
+  /**
+   * Serializes a single {@link DataStatistics} with the given Flink serializer.
+   *
+   * @param dataStatistics statistics to serialize
+   * @param statisticsSerializer serializer used to write {@code dataStatistics}
+   * @return a copy of the serialized bytes
+   * @throws IllegalStateException if the serializer fails with an {@link IOException}
+   */
+  static <D extends DataStatistics<D, S>, S> byte[] serializeDataStatistics(
+      DataStatistics<D, S> dataStatistics,
+      TypeSerializer<DataStatistics<D, S>> statisticsSerializer) {
+    DataOutputSerializer out = new DataOutputSerializer(64);
+    try {
+      statisticsSerializer.serialize(dataStatistics, out);
+      return out.getCopyOfBuffer();
+    } catch (IOException e) {
+      throw new IllegalStateException("Fail to serialize data statistics", e);
+    }
+  }
+
+  /**
+   * Deserializes bytes produced by {@link #serializeDataStatistics}.
+   *
+   * @param bytes serialized statistics
+   * @param statisticsSerializer serializer used to read the statistics
+   * @return the deserialized statistics, cast to {@code D}
+   * @throws IllegalStateException if the serializer fails with an {@link IOException}
+   */
+  @SuppressWarnings("unchecked")
+  static <D extends DataStatistics<D, S>, S> D deserializeDataStatistics(
+      byte[] bytes, TypeSerializer<DataStatistics<D, S>> statisticsSerializer) {
+    DataInputDeserializer input = new DataInputDeserializer(bytes, 0, bytes.length);
+    try {
+      return (D) statisticsSerializer.deserialize(input);
+    } catch (IOException e) {
+      throw new IllegalStateException("Fail to deserialize data statistics", e);
+    }
+  }
+
+  /**
+   * Serializes {@link AggregatedStatistics} as: checkpoint id (long), length of the serialized
+   * statistics (int), then the serialized statistics bytes, all framed by an {@link
+   * ObjectOutputStream}.
+   *
+   * @param aggregatedStatistics statistics to serialize
+   * @param statisticsSerializer serializer used for the nested {@link DataStatistics}
+   * @return the serialized bytes
+   * @throws IOException if writing fails
+   */
+  static <D extends DataStatistics<D, S>, S> byte[] serializeAggregatedStatistics(
+      AggregatedStatistics<D, S> aggregatedStatistics,
+      TypeSerializer<DataStatistics<D, S>> statisticsSerializer)
+      throws IOException {
+    DataOutputSerializer statisticsOut = new DataOutputSerializer(64);
+    statisticsSerializer.serialize(aggregatedStatistics.dataStatistics(), statisticsOut);
+    byte[] statisticsBytes = statisticsOut.getCopyOfBuffer();
+
+    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
+    // close() flushes buffered block data, so the output bytes match an explicit flush()
+    try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
+      out.writeLong(aggregatedStatistics.checkpointId());
+      out.writeInt(statisticsBytes.length);
+      out.write(statisticsBytes);
+    }
+
+    return bytes.toByteArray();
+  }
+
+  /**
+   * Deserializes bytes produced by {@link #serializeAggregatedStatistics}.
+   *
+   * @param bytes serialized aggregated statistics
+   * @param statisticsSerializer serializer used for the nested {@link DataStatistics}
+   * @return the restored aggregated statistics
+   * @throws IOException if the input is truncated or carries a negative statistics length
+   */
+  static <D extends DataStatistics<D, S>, S>
+      AggregatedStatistics<D, S> deserializeAggregatedStatistics(
+          byte[] bytes, TypeSerializer<DataStatistics<D, S>> statisticsSerializer)
+          throws IOException {
+    long completedCheckpointId;
+    byte[] statisticsBytes;
+    // Only primitives and raw bytes are read here; readObject() is never called, so no
+    // arbitrary Java object deserialization takes place.
+    try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
+      completedCheckpointId = in.readLong();
+      int statisticsBytesLength = in.readInt();
+      if (statisticsBytesLength < 0) {
+        throw new IOException("Invalid data statistics length: " + statisticsBytesLength);
+      }
+
+      statisticsBytes = new byte[statisticsBytesLength];
+      in.readFully(statisticsBytes);
+    }
+
+    DataInputDeserializer input =
+        new DataInputDeserializer(statisticsBytes, 0, statisticsBytes.length);
+    DataStatistics<D, S> dataStatistics = statisticsSerializer.deserialize(input);
+
+    return new AggregatedStatistics<>(completedCheckpointId, dataStatistics);
+  }
+}
diff --git
a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatistics.java
b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatistics.java
new file mode 100644
index 0000000000..dd7fcafe53
--- /dev/null
+++
b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatistics.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.util.Map;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.StringData;
+import org.apache.flink.table.runtime.typeutils.RowDataSerializer;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.table.types.logical.VarCharType;
+import org.junit.Test;
+
+public class TestAggregatedStatistics {
+
+  @Test
+  public void mergeDataStatisticTest() {
+    GenericRowData rowDataA = GenericRowData.of(StringData.fromString("a"));
+    GenericRowData rowDataB = GenericRowData.of(StringData.fromString("b"));
+
+    AggregatedStatistics<MapDataStatistics, Map<RowData, Long>> aggregatedStatistics =
+        new AggregatedStatistics<>(
+            1,
+            MapDataStatisticsSerializer.fromKeySerializer(
+                new RowDataSerializer(RowType.of(new VarCharType()))));
+    // First contribution: "a" added twice, "b" once
+    MapDataStatistics mapDataStatistics1 = new MapDataStatistics();
+    mapDataStatistics1.add(rowDataA);
+    mapDataStatistics1.add(rowDataA);
+    mapDataStatistics1.add(rowDataB);
+    aggregatedStatistics.mergeDataStatistic("testOperator", 1, mapDataStatistics1);
+    // Second contribution for the same checkpoint: "a" added once, no "b"
+    MapDataStatistics mapDataStatistics2 = new MapDataStatistics();
+    mapDataStatistics2.add(rowDataA);
+    aggregatedStatistics.mergeDataStatistic("testOperator", 1, mapDataStatistics2);
+
+    // Literal expectations instead of sums re-derived from the inputs, so that a defect in
+    // MapDataStatistics#add cannot cancel out on both sides of the assertion.
+    assertThat(aggregatedStatistics.dataStatistics().statistics().get(rowDataA)).isEqualTo(3L);
+    assertThat(aggregatedStatistics.dataStatistics().statistics().get(rowDataB)).isEqualTo(1L);
+  }
+}
diff --git
a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatisticsTracker.java
b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatisticsTracker.java
new file mode 100644
index 0000000000..48e4e4d8f9
--- /dev/null
+++
b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatisticsTracker.java
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.util.Map;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.StringData;
+import org.apache.flink.table.data.binary.BinaryRowData;
+import org.apache.flink.table.runtime.typeutils.RowDataSerializer;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.table.types.logical.VarCharType;
+import org.junit.Before;
+import org.junit.Test;
+
+public class TestAggregatedStatisticsTracker {
+  private static final int NUM_SUBTASKS = 2;
+  private final RowType rowType = RowType.of(new VarCharType());
+  // When the coordinator handles events from operators,
+  // DataStatisticsUtil#deserializeDataStatistics deserializes bytes into BinaryRowData
+  private final BinaryRowData binaryRowDataA =
+      new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("a")));
+  private final BinaryRowData binaryRowDataB =
+      new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("b")));
+  private final TypeSerializer<RowData> rowSerializer = new RowDataSerializer(rowType);
+  private final TypeSerializer<DataStatistics<MapDataStatistics, Map<RowData, Long>>>
+      statisticsSerializer = MapDataStatisticsSerializer.fromKeySerializer(rowSerializer);
+  private AggregatedStatisticsTracker<MapDataStatistics, Map<RowData, Long>>
+      aggregatedStatisticsTracker;
+
+  @Before
+  public void before() throws Exception {
+    aggregatedStatisticsTracker =
+        new AggregatedStatisticsTracker<>("testOperator", statisticsSerializer, NUM_SUBTASKS);
+  }
+
+  @Test
+  public void receiveNewerDataStatisticEvent() {
+    MapDataStatistics checkpoint1Subtask0DataStatistic = new MapDataStatistics();
+    checkpoint1Subtask0DataStatistic.add(binaryRowDataA);
+    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+        checkpoint1Subtask0DataStatisticEvent =
+            DataStatisticsEvent.create(1, checkpoint1Subtask0DataStatistic, statisticsSerializer);
+    assertThat(
+            aggregatedStatisticsTracker.updateAndCheckCompletion(
+                0, checkpoint1Subtask0DataStatisticEvent))
+        .isNull();
+    assertThat(aggregatedStatisticsTracker.inProgressStatistics().checkpointId()).isEqualTo(1);
+
+    MapDataStatistics checkpoint2Subtask0DataStatistic = new MapDataStatistics();
+    checkpoint2Subtask0DataStatistic.add(binaryRowDataA);
+    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+        checkpoint2Subtask0DataStatisticEvent =
+            DataStatisticsEvent.create(2, checkpoint2Subtask0DataStatistic, statisticsSerializer);
+    assertThat(
+            aggregatedStatisticsTracker.updateAndCheckCompletion(
+                0, checkpoint2Subtask0DataStatisticEvent))
+        .isNull();
+    // Checkpoint 2 is newer than checkpoint 1, thus dropping in-progress statistics for
+    // checkpoint 1
+    assertThat(aggregatedStatisticsTracker.inProgressStatistics().checkpointId()).isEqualTo(2);
+  }
+
+  @Test
+  public void receiveOlderDataStatisticEventTest() {
+    MapDataStatistics checkpoint2Subtask0DataStatistic = new MapDataStatistics();
+    checkpoint2Subtask0DataStatistic.add(binaryRowDataA);
+    checkpoint2Subtask0DataStatistic.add(binaryRowDataB);
+    checkpoint2Subtask0DataStatistic.add(binaryRowDataB);
+    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+        checkpoint2Subtask0DataStatisticEvent =
+            DataStatisticsEvent.create(2, checkpoint2Subtask0DataStatistic, statisticsSerializer);
+    assertThat(
+            aggregatedStatisticsTracker.updateAndCheckCompletion(
+                0, checkpoint2Subtask0DataStatisticEvent))
+        .isNull();
+
+    MapDataStatistics checkpoint1Subtask1DataStatistic = new MapDataStatistics();
+    checkpoint1Subtask1DataStatistic.add(binaryRowDataB);
+    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+        checkpoint1Subtask1DataStatisticEvent =
+            DataStatisticsEvent.create(1, checkpoint1Subtask1DataStatistic, statisticsSerializer);
+    // Receive an event from an older checkpoint: aggregatedStatisticsTracker won't return
+    // completed statistics and the in-progress statistics won't be updated
+    assertThat(
+            aggregatedStatisticsTracker.updateAndCheckCompletion(
+                1, checkpoint1Subtask1DataStatisticEvent))
+        .isNull();
+    assertThat(aggregatedStatisticsTracker.inProgressStatistics().checkpointId()).isEqualTo(2);
+  }
+
+  @Test
+  public void receiveCompletedDataStatisticEvent() {
+    MapDataStatistics checkpoint1Subtask0DataStatistic = new MapDataStatistics();
+    checkpoint1Subtask0DataStatistic.add(binaryRowDataA);
+    checkpoint1Subtask0DataStatistic.add(binaryRowDataB);
+    checkpoint1Subtask0DataStatistic.add(binaryRowDataB);
+    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+        checkpoint1Subtask0DataStatisticEvent =
+            DataStatisticsEvent.create(1, checkpoint1Subtask0DataStatistic, statisticsSerializer);
+    assertThat(
+            aggregatedStatisticsTracker.updateAndCheckCompletion(
+                0, checkpoint1Subtask0DataStatisticEvent))
+        .isNull();
+
+    MapDataStatistics checkpoint1Subtask1DataStatistic = new MapDataStatistics();
+    checkpoint1Subtask1DataStatistic.add(binaryRowDataA);
+    checkpoint1Subtask1DataStatistic.add(binaryRowDataA);
+    checkpoint1Subtask1DataStatistic.add(binaryRowDataB);
+    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+        checkpoint1Subtask1DataStatisticEvent =
+            DataStatisticsEvent.create(1, checkpoint1Subtask1DataStatistic, statisticsSerializer);
+    // Receive data statistics from all subtasks at checkpoint 1
+    AggregatedStatistics<MapDataStatistics, Map<RowData, Long>> completedStatistics =
+        aggregatedStatisticsTracker.updateAndCheckCompletion(
+            1, checkpoint1Subtask1DataStatisticEvent);
+
+    assertThat(completedStatistics).isNotNull();
+    assertThat(completedStatistics.checkpointId()).isEqualTo(1);
+    MapDataStatistics globalDataStatistics =
+        (MapDataStatistics) completedStatistics.dataStatistics();
+    assertThat((long) globalDataStatistics.statistics().get(binaryRowDataA))
+        .isEqualTo(
+            checkpoint1Subtask0DataStatistic.statistics().get(binaryRowDataA)
+                + checkpoint1Subtask1DataStatistic.statistics().get(binaryRowDataA));
+    assertThat((long) globalDataStatistics.statistics().get(binaryRowDataB))
+        .isEqualTo(
+            checkpoint1Subtask0DataStatistic.statistics().get(binaryRowDataB)
+                + checkpoint1Subtask1DataStatistic.statistics().get(binaryRowDataB));
+    assertThat(aggregatedStatisticsTracker.inProgressStatistics().checkpointId())
+        .isEqualTo(completedStatistics.checkpointId() + 1);
+
+    MapDataStatistics checkpoint2Subtask0DataStatistic = new MapDataStatistics();
+    checkpoint2Subtask0DataStatistic.add(binaryRowDataA);
+    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+        checkpoint2Subtask0DataStatisticEvent =
+            DataStatisticsEvent.create(2, checkpoint2Subtask0DataStatistic, statisticsSerializer);
+    assertThat(
+            aggregatedStatisticsTracker.updateAndCheckCompletion(
+                0, checkpoint2Subtask0DataStatisticEvent))
+        .isNull();
+    // Checkpoint 2 still waits for subtask 1, so it remains the in-progress checkpoint
+    assertThat(aggregatedStatisticsTracker.inProgressStatistics().checkpointId()).isEqualTo(2);
+
+    MapDataStatistics checkpoint2Subtask1DataStatistic = new MapDataStatistics();
+    checkpoint2Subtask1DataStatistic.add(binaryRowDataB);
+    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+        checkpoint2Subtask1DataStatisticEvent =
+            DataStatisticsEvent.create(2, checkpoint2Subtask1DataStatistic, statisticsSerializer);
+    // Receive data statistics from all subtasks at checkpoint 2
+    completedStatistics =
+        aggregatedStatisticsTracker.updateAndCheckCompletion(
+            1, checkpoint2Subtask1DataStatisticEvent);
+
+    assertThat(completedStatistics).isNotNull();
+    assertThat(completedStatistics.checkpointId()).isEqualTo(2);
+    assertThat(aggregatedStatisticsTracker.inProgressStatistics().checkpointId())
+        .isEqualTo(completedStatistics.checkpointId() + 1);
+  }
+}
diff --git
a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinator.java
b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinator.java
new file mode 100644
index 0000000000..9ec2606e10
--- /dev/null
+++
b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinator.java
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.operators.coordination.EventReceivingTasks;
+import
org.apache.flink.runtime.operators.coordination.MockOperatorCoordinatorContext;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.StringData;
+import org.apache.flink.table.data.binary.BinaryRowData;
+import org.apache.flink.table.runtime.typeutils.RowDataSerializer;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.table.types.logical.VarCharType;
+import org.apache.flink.util.ExceptionUtils;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.junit.Before;
+import org.junit.Test;
+
+public class TestDataStatisticsCoordinator {
+  private static final String OPERATOR_NAME = "TestCoordinator";
+  private static final OperatorID TEST_OPERATOR_ID = new OperatorID(1234L, 5678L);
+  private static final int NUM_SUBTASKS = 2;
+  private TypeSerializer<DataStatistics<MapDataStatistics, Map<RowData, Long>>>
+      statisticsSerializer;
+
+  private EventReceivingTasks receivingTasks;
+  private DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>>
+      dataStatisticsCoordinator;
+
+  @Before
+  public void before() throws Exception {
+    receivingTasks = EventReceivingTasks.createForRunningTasks();
+    statisticsSerializer =
+        MapDataStatisticsSerializer.fromKeySerializer(
+            new RowDataSerializer(RowType.of(new VarCharType())));
+
+    dataStatisticsCoordinator =
+        new DataStatisticsCoordinator<>(
+            OPERATOR_NAME,
+            new MockOperatorCoordinatorContext(TEST_OPERATOR_ID, NUM_SUBTASKS),
+            statisticsSerializer);
+  }
+
+  private void tasksReady() throws Exception {
+    dataStatisticsCoordinator.start();
+    setAllTasksReady(NUM_SUBTASKS, dataStatisticsCoordinator, receivingTasks);
+  }
+
+  @Test
+  public void testThrowExceptionWhenNotStarted() {
+    String failureMessage = "The coordinator of TestCoordinator has not started yet.";
+
+    assertThatThrownBy(
+            () ->
+                dataStatisticsCoordinator.handleEventFromOperator(
+                    0,
+                    0,
+                    DataStatisticsEvent.create(0, new MapDataStatistics(), statisticsSerializer)))
+        .isInstanceOf(IllegalStateException.class)
+        .hasMessage(failureMessage);
+    assertThatThrownBy(() -> dataStatisticsCoordinator.executionAttemptFailed(0, 0, null))
+        .isInstanceOf(IllegalStateException.class)
+        .hasMessage(failureMessage);
+    assertThatThrownBy(() -> dataStatisticsCoordinator.checkpointCoordinator(0, null))
+        .isInstanceOf(IllegalStateException.class)
+        .hasMessage(failureMessage);
+  }
+
+  @Test
+  public void testDataStatisticsEventHandling() throws Exception {
+    tasksReady();
+    // When the coordinator handles events from operators,
+    // DataStatisticsUtil#deserializeDataStatistics deserializes bytes into BinaryRowData
+    RowType rowType = RowType.of(new VarCharType());
+    BinaryRowData binaryRowDataA =
+        new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("a")));
+    BinaryRowData binaryRowDataB =
+        new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("b")));
+    BinaryRowData binaryRowDataC =
+        new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("c")));
+
+    MapDataStatistics checkpoint1Subtask0DataStatistic = new MapDataStatistics();
+    checkpoint1Subtask0DataStatistic.add(binaryRowDataA);
+    checkpoint1Subtask0DataStatistic.add(binaryRowDataB);
+    checkpoint1Subtask0DataStatistic.add(binaryRowDataB);
+    checkpoint1Subtask0DataStatistic.add(binaryRowDataC);
+    checkpoint1Subtask0DataStatistic.add(binaryRowDataC);
+    checkpoint1Subtask0DataStatistic.add(binaryRowDataC);
+    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+        checkpoint1Subtask0DataStatisticEvent =
+            DataStatisticsEvent.create(1, checkpoint1Subtask0DataStatistic, statisticsSerializer);
+    MapDataStatistics checkpoint1Subtask1DataStatistic = new MapDataStatistics();
+    checkpoint1Subtask1DataStatistic.add(binaryRowDataA);
+    checkpoint1Subtask1DataStatistic.add(binaryRowDataB);
+    checkpoint1Subtask1DataStatistic.add(binaryRowDataC);
+    checkpoint1Subtask1DataStatistic.add(binaryRowDataC);
+    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+        checkpoint1Subtask1DataStatisticEvent =
+            DataStatisticsEvent.create(1, checkpoint1Subtask1DataStatistic, statisticsSerializer);
+    // Handle events from operators for checkpoint 1
+    dataStatisticsCoordinator.handleEventFromOperator(0, 0, checkpoint1Subtask0DataStatisticEvent);
+    dataStatisticsCoordinator.handleEventFromOperator(1, 0, checkpoint1Subtask1DataStatisticEvent);
+
+    waitForCoordinatorToProcessActions(dataStatisticsCoordinator);
+    // Fail with a clear assertion (rather than an NPE below) if aggregation did not complete
+    assertThat(dataStatisticsCoordinator.completedStatistics()).isNotNull();
+    assertThat(dataStatisticsCoordinator.completedStatistics().checkpointId()).isEqualTo(1);
+    // Verify global data statistics is the aggregation of all subtasks data statistics
+    MapDataStatistics globalDataStatistics =
+        (MapDataStatistics) dataStatisticsCoordinator.completedStatistics().dataStatistics();
+    assertThat(globalDataStatistics.statistics())
+        .containsExactlyInAnyOrderEntriesOf(
+            ImmutableMap.of(
+                binaryRowDataA,
+                checkpoint1Subtask0DataStatistic.statistics().get(binaryRowDataA)
+                    + (long) checkpoint1Subtask1DataStatistic.statistics().get(binaryRowDataA),
+                binaryRowDataB,
+                checkpoint1Subtask0DataStatistic.statistics().get(binaryRowDataB)
+                    + (long) checkpoint1Subtask1DataStatistic.statistics().get(binaryRowDataB),
+                binaryRowDataC,
+                checkpoint1Subtask0DataStatistic.statistics().get(binaryRowDataC)
+                    + (long) checkpoint1Subtask1DataStatistic.statistics().get(binaryRowDataC)));
+  }
+
+  /**
+   * Marks subtasks {@code 0..subtasks-1} (attempt 0) as ready by registering a gateway created by
+   * {@code receivingTasks} with the coordinator.
+   */
+  static void setAllTasksReady(
+      int subtasks,
+      DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>> dataStatisticsCoordinator,
+      EventReceivingTasks receivingTasks) {
+    for (int i = 0; i < subtasks; i++) {
+      dataStatisticsCoordinator.executionAttemptReady(
+          i, 0, receivingTasks.createGatewayForSubtask(i, 0));
+    }
+  }
+
+  /**
+   * Blocks until a no-op action submitted through {@code callInCoordinatorThread} has run.
+   * Presumably the coordinator thread runs actions in submission order, so earlier actions have
+   * completed by then; confirm against DataStatisticsCoordinator.
+   */
+  static void waitForCoordinatorToProcessActions(
+      DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>> coordinator) {
+    CompletableFuture<Void> future = new CompletableFuture<>();
+    coordinator.callInCoordinatorThread(
+        () -> {
+          future.complete(null);
+          return null;
+        },
+        "Coordinator fails to process action");
+
+    try {
+      future.get();
+    } catch (InterruptedException e) {
+      // Restore the interrupt flag for the test runner and keep the original cause
+      Thread.currentThread().interrupt();
+      throw new AssertionError("test interrupted", e);
+    } catch (ExecutionException e) {
+      ExceptionUtils.rethrow(ExceptionUtils.stripExecutionException(e));
+    }
+  }
+}
diff --git
a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinatorProvider.java
b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinatorProvider.java
new file mode 100644
index 0000000000..cb9d3f48ff
--- /dev/null
+++
b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinatorProvider.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.operators.coordination.EventReceivingTasks;
+import
org.apache.flink.runtime.operators.coordination.MockOperatorCoordinatorContext;
+import
org.apache.flink.runtime.operators.coordination.RecreateOnResetOperatorCoordinator;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.StringData;
+import org.apache.flink.table.data.binary.BinaryRowData;
+import org.apache.flink.table.runtime.typeutils.RowDataSerializer;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.table.types.logical.VarCharType;
+import org.junit.Before;
+import org.junit.Test;
+
+public class TestDataStatisticsCoordinatorProvider {
+  private static final OperatorID OPERATOR_ID = new OperatorID();
+  private static final int NUM_SUBTASKS = 1;
+
+  private DataStatisticsCoordinatorProvider<MapDataStatistics, Map<RowData, Long>> provider;
+  private EventReceivingTasks receivingTasks;
+  private TypeSerializer<DataStatistics<MapDataStatistics, Map<RowData, Long>>>
+      statisticsSerializer;
+
+  @Before
+  public void before() {
+    statisticsSerializer =
+        MapDataStatisticsSerializer.fromKeySerializer(
+            new RowDataSerializer(RowType.of(new VarCharType())));
+    provider =
+        new DataStatisticsCoordinatorProvider<>(
+            "DataStatisticsCoordinatorProvider", OPERATOR_ID, statisticsSerializer);
+    receivingTasks = EventReceivingTasks.createForRunningTasks();
+  }
+
+  @Test
+  @SuppressWarnings("unchecked")
+  public void testCheckpointAndReset() throws Exception {
+    RowType rowType = RowType.of(new VarCharType());
+    // When the coordinator handles events from operators,
+    // DataStatisticsUtil#deserializeDataStatistics deserializes bytes into BinaryRowData
+    BinaryRowData binaryRowDataA =
+        new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("a")));
+    BinaryRowData binaryRowDataB =
+        new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("b")));
+    BinaryRowData binaryRowDataC =
+        new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("c")));
+    BinaryRowData binaryRowDataD =
+        new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("d")));
+    BinaryRowData binaryRowDataE =
+        new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("e")));
+
+    // OperatorCoordinator is AutoCloseable: close it even when an assertion fails so that the
+    // resources it holds are released
+    try (RecreateOnResetOperatorCoordinator coordinator =
+        (RecreateOnResetOperatorCoordinator)
+            provider.create(new MockOperatorCoordinatorContext(OPERATOR_ID, NUM_SUBTASKS))) {
+      DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>> dataStatisticsCoordinator =
+          (DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>>)
+              coordinator.getInternalCoordinator();
+
+      // Start the coordinator
+      coordinator.start();
+      TestDataStatisticsCoordinator.setAllTasksReady(
+          NUM_SUBTASKS, dataStatisticsCoordinator, receivingTasks);
+      MapDataStatistics checkpoint1Subtask0DataStatistic = new MapDataStatistics();
+      checkpoint1Subtask0DataStatistic.add(binaryRowDataA);
+      checkpoint1Subtask0DataStatistic.add(binaryRowDataB);
+      checkpoint1Subtask0DataStatistic.add(binaryRowDataC);
+      DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+          checkpoint1Subtask0DataStatisticEvent =
+              DataStatisticsEvent.create(
+                  1, checkpoint1Subtask0DataStatistic, statisticsSerializer);
+
+      // Handle events from operators for checkpoint 1
+      coordinator.handleEventFromOperator(0, 0, checkpoint1Subtask0DataStatisticEvent);
+      TestDataStatisticsCoordinator.waitForCoordinatorToProcessActions(dataStatisticsCoordinator);
+      // Verify checkpoint 1 global data statistics
+      MapDataStatistics checkpoint1GlobalDataStatistics =
+          (MapDataStatistics) dataStatisticsCoordinator.completedStatistics().dataStatistics();
+      assertThat(checkpoint1GlobalDataStatistics.statistics())
+          .isEqualTo(checkpoint1Subtask0DataStatistic.statistics());
+      byte[] checkpoint1Bytes = waitForCheckpoint(1L, dataStatisticsCoordinator);
+
+      MapDataStatistics checkpoint2Subtask0DataStatistic = new MapDataStatistics();
+      checkpoint2Subtask0DataStatistic.add(binaryRowDataD);
+      checkpoint2Subtask0DataStatistic.add(binaryRowDataE);
+      checkpoint2Subtask0DataStatistic.add(binaryRowDataE);
+      DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+          checkpoint2Subtask0DataStatisticEvent =
+              DataStatisticsEvent.create(
+                  2, checkpoint2Subtask0DataStatistic, statisticsSerializer);
+      // Handle events from operators for checkpoint 2
+      coordinator.handleEventFromOperator(0, 0, checkpoint2Subtask0DataStatisticEvent);
+      TestDataStatisticsCoordinator.waitForCoordinatorToProcessActions(dataStatisticsCoordinator);
+      // Verify checkpoint 2 global data statistics
+      MapDataStatistics checkpoint2GlobalDataStatistics =
+          (MapDataStatistics) dataStatisticsCoordinator.completedStatistics().dataStatistics();
+      assertThat(checkpoint2GlobalDataStatistics.statistics())
+          .isEqualTo(checkpoint2Subtask0DataStatistic.statistics());
+      waitForCheckpoint(2L, dataStatisticsCoordinator);
+
+      // Reset coordinator to checkpoint 1
+      coordinator.resetToCheckpoint(1L, checkpoint1Bytes);
+      DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>>
+          restoredDataStatisticsCoordinator =
+              (DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>>)
+                  coordinator.getInternalCoordinator();
+      // Reset must produce a new internal coordinator instance (identity, not equals())
+      assertThat(dataStatisticsCoordinator).isNotSameAs(restoredDataStatisticsCoordinator);
+      // Verify restored data statistics
+      MapDataStatistics restoredAggregateDataStatistics =
+          (MapDataStatistics)
+              restoredDataStatisticsCoordinator.completedStatistics().dataStatistics();
+      assertThat(restoredAggregateDataStatistics.statistics())
+          .isEqualTo(checkpoint1GlobalDataStatistics.statistics());
+    }
+  }
+
+  /** Triggers a coordinator checkpoint and blocks until its serialized state is available. */
+  private byte[] waitForCheckpoint(
+      long checkpointId,
+      DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>> coordinator)
+      throws InterruptedException, ExecutionException {
+    CompletableFuture<byte[]> future = new CompletableFuture<>();
+    coordinator.checkpointCoordinator(checkpointId, future);
+    return future.get();
+  }
+}
diff --git
a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java
b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java
index 039d70a69d..880cb3d551 100644
---
a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java
+++
b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java
@@ -19,7 +19,6 @@
package org.apache.iceberg.flink.sink.shuffle;
import static org.assertj.core.api.Assertions.assertThat;
-import static org.junit.Assert.assertTrue;
import java.util.Collections;
import java.util.List;
@@ -51,13 +50,13 @@ import
org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.StringData;
+import org.apache.flink.table.data.binary.BinaryRowData;
import org.apache.flink.table.runtime.typeutils.RowDataSerializer;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.VarCharType;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
-import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -65,7 +64,18 @@ import org.junit.Test;
public class TestDataStatisticsOperator {
private final RowType rowType = RowType.of(new VarCharType());
private final TypeSerializer<RowData> rowSerializer = new
RowDataSerializer(rowType);
-
+ private final GenericRowData genericRowDataA =
GenericRowData.of(StringData.fromString("a"));
+ private final GenericRowData genericRowDataB =
GenericRowData.of(StringData.fromString("b"));
+ // When the operator handles events from the coordinator,
DataStatisticsUtil#deserializeDataStatistics
+ // deserializes bytes into BinaryRowData
+ private final BinaryRowData binaryRowDataA =
+ new
RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("a")));
+ private final BinaryRowData binaryRowDataB =
+ new
RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("b")));
+ private final BinaryRowData binaryRowDataC =
+ new
RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("c")));
+ private final TypeSerializer<DataStatistics<MapDataStatistics, Map<RowData,
Long>>>
+ statisticsSerializer =
MapDataStatisticsSerializer.fromKeySerializer(rowSerializer);
private DataStatisticsOperator<MapDataStatistics, Map<RowData, Long>>
operator;
private Environment getTestingEnvironment() {
@@ -101,9 +111,8 @@ public class TestDataStatisticsOperator {
}
};
- TypeSerializer<DataStatistics<MapDataStatistics, Map<RowData, Long>>>
statisticsSerializer =
- MapDataStatisticsSerializer.fromKeySerializer(rowSerializer);
- return new DataStatisticsOperator<>(keySelector, mockGateway,
statisticsSerializer);
+ return new DataStatisticsOperator<>(
+ "testOperator", keySelector, mockGateway, statisticsSerializer);
}
@After
@@ -118,20 +127,16 @@ public class TestDataStatisticsOperator {
testHarness = createHarness(this.operator)) {
StateInitializationContext stateContext = getStateContext();
operator.initializeState(stateContext);
- operator.processElement(new
StreamRecord<>(GenericRowData.of(StringData.fromString("a"))));
- operator.processElement(new
StreamRecord<>(GenericRowData.of(StringData.fromString("a"))));
- operator.processElement(new
StreamRecord<>(GenericRowData.of(StringData.fromString("b"))));
- assertTrue(operator.localDataStatistics() instanceof MapDataStatistics);
+ operator.processElement(new StreamRecord<>(genericRowDataA));
+ operator.processElement(new StreamRecord<>(genericRowDataA));
+ operator.processElement(new StreamRecord<>(genericRowDataB));
+
assertThat(operator.localDataStatistics()).isInstanceOf(MapDataStatistics.class);
MapDataStatistics mapDataStatistics = (MapDataStatistics)
operator.localDataStatistics();
Map<RowData, Long> statsMap = mapDataStatistics.statistics();
assertThat(statsMap).hasSize(2);
assertThat(statsMap)
.containsExactlyInAnyOrderEntriesOf(
- ImmutableMap.of(
- GenericRowData.of(StringData.fromString("a")),
- 2L,
- GenericRowData.of(StringData.fromString("b")),
- 1L));
+ ImmutableMap.of(genericRowDataA, 2L, genericRowDataB, 1L));
testHarness.endInput();
}
}
@@ -141,9 +146,9 @@ public class TestDataStatisticsOperator {
try (OneInputStreamOperatorTestHarness<
RowData, DataStatisticsOrRecord<MapDataStatistics, Map<RowData,
Long>>>
testHarness = createHarness(this.operator)) {
- testHarness.processElement(new
StreamRecord<>(GenericRowData.of(StringData.fromString("a"))));
- testHarness.processElement(new
StreamRecord<>(GenericRowData.of(StringData.fromString("b"))));
- testHarness.processElement(new
StreamRecord<>(GenericRowData.of(StringData.fromString("b"))));
+ testHarness.processElement(new StreamRecord<>(genericRowDataA));
+ testHarness.processElement(new StreamRecord<>(genericRowDataB));
+ testHarness.processElement(new StreamRecord<>(genericRowDataB));
List<RowData> recordsOutput =
testHarness.extractOutputValues().stream()
@@ -152,10 +157,7 @@ public class TestDataStatisticsOperator {
.collect(Collectors.toList());
assertThat(recordsOutput)
.containsExactlyInAnyOrderElementsOf(
- ImmutableList.of(
- GenericRowData.of(StringData.fromString("a")),
- GenericRowData.of(StringData.fromString("b")),
- GenericRowData.of(StringData.fromString("b"))));
+ ImmutableList.of(genericRowDataA, genericRowDataB,
genericRowDataB));
}
}
@@ -167,21 +169,16 @@ public class TestDataStatisticsOperator {
testHarness1 = createHarness(this.operator)) {
DataStatistics<MapDataStatistics, Map<RowData, Long>> mapDataStatistics =
new MapDataStatistics();
- mapDataStatistics.add(GenericRowData.of(StringData.fromString("a")));
- mapDataStatistics.add(GenericRowData.of(StringData.fromString("a")));
- mapDataStatistics.add(GenericRowData.of(StringData.fromString("b")));
- mapDataStatistics.add(GenericRowData.of(StringData.fromString("c")));
- operator.handleOperatorEvent(new DataStatisticsEvent(0,
mapDataStatistics));
+ mapDataStatistics.add(binaryRowDataA);
+ mapDataStatistics.add(binaryRowDataA);
+ mapDataStatistics.add(binaryRowDataB);
+ mapDataStatistics.add(binaryRowDataC);
+ operator.handleOperatorEvent(
+ DataStatisticsEvent.create(0, mapDataStatistics,
statisticsSerializer));
assertThat(operator.globalDataStatistics()).isInstanceOf(MapDataStatistics.class);
- assertThat(((MapDataStatistics)
operator.globalDataStatistics()).statistics())
+ assertThat(operator.globalDataStatistics().statistics())
.containsExactlyInAnyOrderEntriesOf(
- ImmutableMap.of(
- GenericRowData.of(StringData.fromString("a")),
- 2L,
- GenericRowData.of(StringData.fromString("b")),
- 1L,
- GenericRowData.of(StringData.fromString("c")),
- 1L));
+ ImmutableMap.of(binaryRowDataA, 2L, binaryRowDataB, 1L,
binaryRowDataC, 1L));
snapshot = testHarness1.snapshot(1L, 0);
}
@@ -195,22 +192,9 @@ public class TestDataStatisticsOperator {
testHarness2.setup();
testHarness2.initializeState(snapshot);
assertThat(restoredOperator.globalDataStatistics()).isInstanceOf(MapDataStatistics.class);
- // restored RowData is BinaryRowData. convert to GenericRowData for
comparison
- Map<RowData, Long> restoredStatistics = Maps.newHashMap();
- ((MapDataStatistics) restoredOperator.globalDataStatistics())
- .statistics()
- .forEach(
- (rowData, count) ->
-
restoredStatistics.put(GenericRowData.of(rowData.getString(0)), count));
- assertThat(restoredStatistics)
+ assertThat(restoredOperator.globalDataStatistics().statistics())
.containsExactlyInAnyOrderEntriesOf(
- ImmutableMap.of(
- GenericRowData.of(StringData.fromString("a")),
- 2L,
- GenericRowData.of(StringData.fromString("b")),
- 1L,
- GenericRowData.of(StringData.fromString("c")),
- 1L));
+ ImmutableMap.of(binaryRowDataA, 2L, binaryRowDataB, 1L,
binaryRowDataC, 1L));
}
}