stevenzwu commented on code in PR #7360: URL: https://github.com/apache/iceberg/pull/7360#discussion_r1289143500
########## flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinator.java: ########## @@ -0,0 +1,389 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Map; +import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.operators.coordination.OperatorCoordinator; +import org.apache.flink.runtime.operators.coordination.OperatorEvent; +import org.apache.flink.util.ExceptionUtils; +import org.apache.flink.util.FatalExitExceptionHandler; +import org.apache.flink.util.FlinkRuntimeException; +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.ThrowableCatchingRunnable; +import org.apache.flink.util.function.ThrowingRunnable; +import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * DataStatisticsCoordinator receives {@link DataStatisticsEvent} from {@link + * DataStatisticsOperator} every subtask and then merge them together. Once aggregation for all + * subtasks data statistics completes, DataStatisticsCoordinator will send the aggregated + * result(global data statistics) back to {@link DataStatisticsOperator}. In the end a custom + * partitioner will distribute traffic based on the global data statistics to improve data + * clustering. + */ +@Internal +class DataStatisticsCoordinator<D extends DataStatistics<D, S>, S> implements OperatorCoordinator { + private static final Logger LOG = LoggerFactory.getLogger(DataStatisticsCoordinator.class); + + private final String operatorName; + private final ExecutorService coordinatorExecutor; + private final OperatorCoordinator.Context operatorCoordinatorContext; + private final SubtaskGateways subtaskGateways; + private final CoordinatorExecutorThreadFactory coordinatorThreadFactory; + private final TypeSerializer<DataStatistics<D, S>> statisticsSerializer; + private final transient GlobalStatisticsTracker<D, S> globalStatisticsTracker; + private volatile GlobalStatistics<D, S> completedStatistics; + private volatile boolean started; + + DataStatisticsCoordinator( + String operatorName, + OperatorCoordinator.Context context, + TypeSerializer<DataStatistics<D, S>> statisticsSerializer) { + this.operatorName = operatorName; + this.coordinatorThreadFactory = + new CoordinatorExecutorThreadFactory( + "DataStatisticsCoordinator-" + operatorName, context.getUserCodeClassloader()); + this.coordinatorExecutor = Executors.newSingleThreadExecutor(coordinatorThreadFactory); + this.operatorCoordinatorContext = context; + this.subtaskGateways = new SubtaskGateways(operatorName, parallelism()); + this.statisticsSerializer = statisticsSerializer; + this.globalStatisticsTracker = + new GlobalStatisticsTracker<>(operatorName, statisticsSerializer, parallelism()); + } + + @Override + public void start() throws Exception { + LOG.info("Starting data statistics coordinator: {}.", operatorName); + started = true; + } + + @Override + public void close() throws Exception { + coordinatorExecutor.shutdown(); + LOG.info("Closed data statistics coordinator: {}.", operatorName); + } + + void callInCoordinatorThread(Callable<Void> callable, String errorMessage) { + ensureStarted(); + // Ensure the task is done by the coordinator executor. + if (!coordinatorThreadFactory.isCurrentThreadCoordinatorThread()) { + try { + final Callable<Void> guardedCallable = Review Comment: nit: Iceberg coding style doesn't use `final` for local var ########## flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinator.java: ########## @@ -0,0 +1,389 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Map; +import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.operators.coordination.OperatorCoordinator; +import org.apache.flink.runtime.operators.coordination.OperatorEvent; +import org.apache.flink.util.ExceptionUtils; +import org.apache.flink.util.FatalExitExceptionHandler; +import org.apache.flink.util.FlinkRuntimeException; +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.ThrowableCatchingRunnable; +import org.apache.flink.util.function.ThrowingRunnable; +import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * DataStatisticsCoordinator receives {@link DataStatisticsEvent} from {@link + * DataStatisticsOperator} every subtask and then merge them together. Once aggregation for all + * subtasks data statistics completes, DataStatisticsCoordinator will send the aggregated + * result(global data statistics) back to {@link DataStatisticsOperator}. In the end a custom + * partitioner will distribute traffic based on the global data statistics to improve data + * clustering. + */ +@Internal +class DataStatisticsCoordinator<D extends DataStatistics<D, S>, S> implements OperatorCoordinator { + private static final Logger LOG = LoggerFactory.getLogger(DataStatisticsCoordinator.class); + + private final String operatorName; + private final ExecutorService coordinatorExecutor; + private final OperatorCoordinator.Context operatorCoordinatorContext; + private final SubtaskGateways subtaskGateways; + private final CoordinatorExecutorThreadFactory coordinatorThreadFactory; + private final TypeSerializer<DataStatistics<D, S>> statisticsSerializer; + private final transient GlobalStatisticsTracker<D, S> globalStatisticsTracker; + private volatile GlobalStatistics<D, S> completedStatistics; + private volatile boolean started; + + DataStatisticsCoordinator( + String operatorName, + OperatorCoordinator.Context context, + TypeSerializer<DataStatistics<D, S>> statisticsSerializer) { + this.operatorName = operatorName; + this.coordinatorThreadFactory = + new CoordinatorExecutorThreadFactory( + "DataStatisticsCoordinator-" + operatorName, context.getUserCodeClassloader()); + this.coordinatorExecutor = Executors.newSingleThreadExecutor(coordinatorThreadFactory); + this.operatorCoordinatorContext = context; + this.subtaskGateways = new SubtaskGateways(operatorName, parallelism()); + this.statisticsSerializer = statisticsSerializer; + this.globalStatisticsTracker = + new GlobalStatisticsTracker<>(operatorName, statisticsSerializer, parallelism()); + } + + @Override + public void start() throws Exception { + LOG.info("Starting data statistics coordinator: {}.", operatorName); + started = true; + } + + @Override + public void close() throws Exception { + coordinatorExecutor.shutdown(); + LOG.info("Closed data statistics coordinator: {}.", operatorName); + } + + void callInCoordinatorThread(Callable<Void> callable, String errorMessage) { + ensureStarted(); + // Ensure the task is done by the coordinator executor. + if (!coordinatorThreadFactory.isCurrentThreadCoordinatorThread()) { + try { + final Callable<Void> guardedCallable = + () -> { + try { + return callable.call(); + } catch (Throwable t) { + LOG.error( + "Uncaught Exception in data statistics coordinator: {} executor", + operatorName, + t); + ExceptionUtils.rethrowException(t); + return null; + } + }; + + coordinatorExecutor.submit(guardedCallable).get(); + } catch (InterruptedException | ExecutionException e) { + throw new FlinkRuntimeException(errorMessage, e); + } + } else { + try { + callable.call(); + } catch (Throwable t) { + LOG.error( + "Uncaught Exception in data statistics coordinator: {} executor", operatorName, t); + throw new FlinkRuntimeException(errorMessage, t); + } + } + } + + public void runInCoordinatorThread(Runnable runnable) { + this.coordinatorExecutor.execute( + new ThrowableCatchingRunnable( + throwable -> + this.coordinatorThreadFactory.uncaughtException(Thread.currentThread(), throwable), + runnable)); + } + + private void runInCoordinatorThread(ThrowingRunnable<Throwable> action, String actionString) { + ensureStarted(); + runInCoordinatorThread( + () -> { + try { + action.run(); + } catch (Throwable t) { + ExceptionUtils.rethrowIfFatalErrorOrOOM(t); + LOG.error( + "Uncaught exception in the data statistics coordinator: {} while {}. Triggering job failover.", + operatorName, + actionString, + t); + operatorCoordinatorContext.failJob(t); + } + }); + } + + private void ensureStarted() { + Preconditions.checkState(started, "The coordinator of %s has not started yet.", operatorName); + } + + private int parallelism() { + return operatorCoordinatorContext.currentParallelism(); + } + + private void handleDataStatisticRequest(int subtask, DataStatisticsEvent<D, S> event) { + GlobalStatistics<D, S> globalStatistics = + globalStatisticsTracker.receiveDataStatisticEventAndCheckCompletion(subtask, event); + + if (globalStatistics != null) { + completedStatistics = globalStatistics; + sendDataStatisticsToSubtasks( + completedStatistics.checkpointId(), completedStatistics.dataStatistics()); + } + } + + private void sendDataStatisticsToSubtasks( + long checkpointId, DataStatistics<D, S> globalDataStatistics) { + callInCoordinatorThread( + () -> { + DataStatisticsEvent<D, S> dataStatisticsEvent = + DataStatisticsEvent.create(checkpointId, globalDataStatistics, statisticsSerializer); + int parallelism = parallelism(); + for (int i = 0; i < parallelism; ++i) { + subtaskGateways.getSubtaskGateway(i).sendEvent(dataStatisticsEvent); + } + return null; Review Comment: nit: Iceberg coding style adds an empty line after each control block `}`. ########## flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinator.java: ########## @@ -0,0 +1,389 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Map; +import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.operators.coordination.OperatorCoordinator; +import org.apache.flink.runtime.operators.coordination.OperatorEvent; +import org.apache.flink.util.ExceptionUtils; +import org.apache.flink.util.FatalExitExceptionHandler; +import org.apache.flink.util.FlinkRuntimeException; +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.ThrowableCatchingRunnable; +import org.apache.flink.util.function.ThrowingRunnable; +import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * DataStatisticsCoordinator receives {@link DataStatisticsEvent} from {@link + * DataStatisticsOperator} every subtask and then merge them together. Once aggregation for all + * subtasks data statistics completes, DataStatisticsCoordinator will send the aggregated + * result(global data statistics) back to {@link DataStatisticsOperator}. In the end a custom + * partitioner will distribute traffic based on the global data statistics to improve data + * clustering. + */ +@Internal +class DataStatisticsCoordinator<D extends DataStatistics<D, S>, S> implements OperatorCoordinator { + private static final Logger LOG = LoggerFactory.getLogger(DataStatisticsCoordinator.class); + + private final String operatorName; + private final ExecutorService coordinatorExecutor; + private final OperatorCoordinator.Context operatorCoordinatorContext; + private final SubtaskGateways subtaskGateways; + private final CoordinatorExecutorThreadFactory coordinatorThreadFactory; + private final TypeSerializer<DataStatistics<D, S>> statisticsSerializer; + private final transient GlobalStatisticsTracker<D, S> globalStatisticsTracker; + private volatile GlobalStatistics<D, S> completedStatistics; + private volatile boolean started; + + DataStatisticsCoordinator( + String operatorName, + OperatorCoordinator.Context context, + TypeSerializer<DataStatistics<D, S>> statisticsSerializer) { + this.operatorName = operatorName; + this.coordinatorThreadFactory = + new CoordinatorExecutorThreadFactory( + "DataStatisticsCoordinator-" + operatorName, context.getUserCodeClassloader()); + this.coordinatorExecutor = Executors.newSingleThreadExecutor(coordinatorThreadFactory); + this.operatorCoordinatorContext = context; + this.subtaskGateways = new SubtaskGateways(operatorName, parallelism()); + this.statisticsSerializer = statisticsSerializer; + this.globalStatisticsTracker = + new GlobalStatisticsTracker<>(operatorName, statisticsSerializer, parallelism()); + } + + @Override + public void start() throws Exception { + LOG.info("Starting data statistics coordinator: {}.", operatorName); + started = true; + } + + @Override + public void close() throws Exception { + coordinatorExecutor.shutdown(); + LOG.info("Closed data statistics coordinator: {}.", operatorName); + } + + void callInCoordinatorThread(Callable<Void> callable, String errorMessage) { + ensureStarted(); + // Ensure the task is done by the coordinator executor. + if (!coordinatorThreadFactory.isCurrentThreadCoordinatorThread()) { + try { + final Callable<Void> guardedCallable = + () -> { + try { + return callable.call(); + } catch (Throwable t) { + LOG.error( + "Uncaught Exception in data statistics coordinator: {} executor", + operatorName, + t); + ExceptionUtils.rethrowException(t); + return null; + } + }; + + coordinatorExecutor.submit(guardedCallable).get(); + } catch (InterruptedException | ExecutionException e) { + throw new FlinkRuntimeException(errorMessage, e); + } + } else { + try { + callable.call(); + } catch (Throwable t) { + LOG.error( + "Uncaught Exception in data statistics coordinator: {} executor", operatorName, t); + throw new FlinkRuntimeException(errorMessage, t); + } + } + } + + public void runInCoordinatorThread(Runnable runnable) { + this.coordinatorExecutor.execute( + new ThrowableCatchingRunnable( + throwable -> + this.coordinatorThreadFactory.uncaughtException(Thread.currentThread(), throwable), + runnable)); + } + + private void runInCoordinatorThread(ThrowingRunnable<Throwable> action, String actionString) { + ensureStarted(); + runInCoordinatorThread( + () -> { + try { + action.run(); + } catch (Throwable t) { + ExceptionUtils.rethrowIfFatalErrorOrOOM(t); + LOG.error( + "Uncaught exception in the data statistics coordinator: {} while {}. Triggering job failover.", + operatorName, + actionString, + t); + operatorCoordinatorContext.failJob(t); + } + }); + } + + private void ensureStarted() { + Preconditions.checkState(started, "The coordinator of %s has not started yet.", operatorName); + } + + private int parallelism() { + return operatorCoordinatorContext.currentParallelism(); + } + + private void handleDataStatisticRequest(int subtask, DataStatisticsEvent<D, S> event) { + GlobalStatistics<D, S> globalStatistics = + globalStatisticsTracker.receiveDataStatisticEventAndCheckCompletion(subtask, event); + + if (globalStatistics != null) { + completedStatistics = globalStatistics; + sendDataStatisticsToSubtasks( + completedStatistics.checkpointId(), completedStatistics.dataStatistics()); + } + } + + private void sendDataStatisticsToSubtasks( + long checkpointId, DataStatistics<D, S> globalDataStatistics) { + callInCoordinatorThread( + () -> { + DataStatisticsEvent<D, S> dataStatisticsEvent = + DataStatisticsEvent.create(checkpointId, globalDataStatistics, statisticsSerializer); + int parallelism = parallelism(); + for (int i = 0; i < parallelism; ++i) { + subtaskGateways.getSubtaskGateway(i).sendEvent(dataStatisticsEvent); + } + return null; + }, + String.format("Failed to send global data statistics for checkpoint %d", checkpointId)); + } + + @Override + @SuppressWarnings("unchecked") + public void handleEventFromOperator(int subtask, int attemptNumber, OperatorEvent event) { + runInCoordinatorThread( + () -> { + LOG.debug( + "Handling event from subtask {} (#{}) of {}: {}", + subtask, + attemptNumber, + operatorName, + event); + Preconditions.checkArgument(event instanceof DataStatisticsEvent); + handleDataStatisticRequest(subtask, ((DataStatisticsEvent<D, S>) event)); + }, + String.format( + "handling operator event %s from subtask %d (#%d)", + event.getClass(), subtask, attemptNumber)); + } + + @Override + public void checkpointCoordinator(long checkpointId, CompletableFuture<byte[]> resultFuture) { + runInCoordinatorThread( + () -> { + LOG.debug( + "Taking a state snapshot on data statistics coordinator {} for checkpoint {}", + operatorName, + checkpointId); + resultFuture.complete( + DataStatisticsUtil.serializeGlobalStatistics( + completedStatistics, statisticsSerializer)); + }, + String.format("taking checkpoint %d", checkpointId)); + } + + @Override + public void notifyCheckpointComplete(long checkpointId) {} + + @Override + public void resetToCheckpoint(long checkpointId, @Nullable byte[] checkpointData) + throws Exception { + Preconditions.checkState( + !started, "The coordinator %s can only be reset if it was not yet started", operatorName); + + if (checkpointData == null) { + return; Review Comment: maybe add an INFO log ########## flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java: ########## @@ -117,21 +134,36 @@ public void snapshotState(StateSnapshotContext context) throws Exception { long checkpointId = context.getCheckpointId(); int subTaskId = getRuntimeContext().getIndexOfThisSubtask(); LOG.info( - "Taking data statistics operator snapshot for checkpoint {} in subtask {}", + "Taking data statistics operator {} snapshot for checkpoint {} in subtask {}", Review Comment: nit: `Snapshotting data statistics operator {} for ...` ########## flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinator.java: ########## @@ -0,0 +1,389 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Map; +import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.operators.coordination.OperatorCoordinator; +import org.apache.flink.runtime.operators.coordination.OperatorEvent; +import org.apache.flink.util.ExceptionUtils; +import org.apache.flink.util.FatalExitExceptionHandler; +import org.apache.flink.util.FlinkRuntimeException; +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.ThrowableCatchingRunnable; +import org.apache.flink.util.function.ThrowingRunnable; +import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * DataStatisticsCoordinator receives {@link DataStatisticsEvent} from {@link + * DataStatisticsOperator} every subtask and then merge them together. Once aggregation for all + * subtasks data statistics completes, DataStatisticsCoordinator will send the aggregated + * result(global data statistics) back to {@link DataStatisticsOperator}. In the end a custom + * partitioner will distribute traffic based on the global data statistics to improve data + * clustering. + */ +@Internal +class DataStatisticsCoordinator<D extends DataStatistics<D, S>, S> implements OperatorCoordinator { + private static final Logger LOG = LoggerFactory.getLogger(DataStatisticsCoordinator.class); + + private final String operatorName; + private final ExecutorService coordinatorExecutor; + private final OperatorCoordinator.Context operatorCoordinatorContext; + private final SubtaskGateways subtaskGateways; + private final CoordinatorExecutorThreadFactory coordinatorThreadFactory; + private final TypeSerializer<DataStatistics<D, S>> statisticsSerializer; + private final transient GlobalStatisticsTracker<D, S> globalStatisticsTracker; + private volatile GlobalStatistics<D, S> completedStatistics; + private volatile boolean started; + + DataStatisticsCoordinator( + String operatorName, + OperatorCoordinator.Context context, + TypeSerializer<DataStatistics<D, S>> statisticsSerializer) { + this.operatorName = operatorName; + this.coordinatorThreadFactory = + new CoordinatorExecutorThreadFactory( + "DataStatisticsCoordinator-" + operatorName, context.getUserCodeClassloader()); + this.coordinatorExecutor = Executors.newSingleThreadExecutor(coordinatorThreadFactory); + this.operatorCoordinatorContext = context; + this.subtaskGateways = new SubtaskGateways(operatorName, parallelism()); + this.statisticsSerializer = statisticsSerializer; + this.globalStatisticsTracker = + new GlobalStatisticsTracker<>(operatorName, statisticsSerializer, parallelism()); + } + + @Override + public void start() throws Exception { + LOG.info("Starting data statistics coordinator: {}.", operatorName); + started = true; + } + + @Override + public void close() throws Exception { + coordinatorExecutor.shutdown(); + LOG.info("Closed data statistics coordinator: {}.", operatorName); + } + + void callInCoordinatorThread(Callable<Void> callable, String errorMessage) { + ensureStarted(); + // Ensure the task is done by the coordinator executor. + if (!coordinatorThreadFactory.isCurrentThreadCoordinatorThread()) { + try { + final Callable<Void> guardedCallable = + () -> { + try { + return callable.call(); + } catch (Throwable t) { + LOG.error( + "Uncaught Exception in data statistics coordinator: {} executor", + operatorName, + t); + ExceptionUtils.rethrowException(t); + return null; + } + }; + + coordinatorExecutor.submit(guardedCallable).get(); + } catch (InterruptedException | ExecutionException e) { + throw new FlinkRuntimeException(errorMessage, e); + } + } else { + try { + callable.call(); + } catch (Throwable t) { + LOG.error( + "Uncaught Exception in data statistics coordinator: {} executor", operatorName, t); + throw new FlinkRuntimeException(errorMessage, t); + } + } + } + + public void runInCoordinatorThread(Runnable runnable) { + this.coordinatorExecutor.execute( + new ThrowableCatchingRunnable( + throwable -> + this.coordinatorThreadFactory.uncaughtException(Thread.currentThread(), throwable), + runnable)); + } + + private void runInCoordinatorThread(ThrowingRunnable<Throwable> action, String actionString) { + ensureStarted(); + runInCoordinatorThread( + () -> { + try { + action.run(); + } catch (Throwable t) { + ExceptionUtils.rethrowIfFatalErrorOrOOM(t); + LOG.error( + "Uncaught exception in the data statistics coordinator: {} while {}. Triggering job failover.", + operatorName, + actionString, + t); + operatorCoordinatorContext.failJob(t); + } + }); + } + + private void ensureStarted() { + Preconditions.checkState(started, "The coordinator of %s has not started yet.", operatorName); + } + + private int parallelism() { + return operatorCoordinatorContext.currentParallelism(); + } + + private void handleDataStatisticRequest(int subtask, DataStatisticsEvent<D, S> event) { + GlobalStatistics<D, S> globalStatistics = + globalStatisticsTracker.receiveDataStatisticEventAndCheckCompletion(subtask, event); + + if (globalStatistics != null) { + completedStatistics = globalStatistics; + sendDataStatisticsToSubtasks( + completedStatistics.checkpointId(), completedStatistics.dataStatistics()); + } + } + + private void sendDataStatisticsToSubtasks( + long checkpointId, DataStatistics<D, S> globalDataStatistics) { + callInCoordinatorThread( + () -> { + DataStatisticsEvent<D, S> dataStatisticsEvent = + DataStatisticsEvent.create(checkpointId, globalDataStatistics, statisticsSerializer); + int parallelism = parallelism(); + for (int i = 0; i < parallelism; ++i) { + subtaskGateways.getSubtaskGateway(i).sendEvent(dataStatisticsEvent); + } + return null; + }, + String.format("Failed to send global data statistics for checkpoint %d", checkpointId)); + } + + @Override + @SuppressWarnings("unchecked") + public void handleEventFromOperator(int subtask, int attemptNumber, OperatorEvent event) { + runInCoordinatorThread( + () -> { + LOG.debug( + "Handling event from subtask {} (#{}) of {}: {}", + subtask, + attemptNumber, + operatorName, + event); + Preconditions.checkArgument(event instanceof DataStatisticsEvent); + handleDataStatisticRequest(subtask, ((DataStatisticsEvent<D, S>) event)); + }, + String.format( + "handling operator event %s from subtask %d (#%d)", + event.getClass(), subtask, attemptNumber)); + } + + @Override + public void checkpointCoordinator(long checkpointId, CompletableFuture<byte[]> resultFuture) { + runInCoordinatorThread( + () -> { + LOG.debug( + "Taking a state snapshot on data statistics coordinator {} for checkpoint {}", Review Comment: nit: the error msg seems a little verbose. maybe can be simplified. e.g. `Snapshotting data statistics coordinator {} for checkpoint {}` ########## flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java: ########## @@ -117,21 +134,36 @@ public void snapshotState(StateSnapshotContext context) throws Exception { long checkpointId = context.getCheckpointId(); int subTaskId = getRuntimeContext().getIndexOfThisSubtask(); LOG.info( - "Taking data statistics operator snapshot for checkpoint {} in subtask {}", + "Taking data statistics operator {} snapshot for checkpoint {} in subtask {}", + operatorName, checkpointId, subTaskId); + // Pass global statistics to partitioners so that all the operators refresh statistics + // at same checkpoint barrier + output.collect(new StreamRecord<>(DataStatisticsOrRecord.fromDataStatistics(globalStatistics))); Review Comment: do we need to check if `globalStatistics` is null or empty? ########## flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java: ########## @@ -117,21 +134,36 @@ public void snapshotState(StateSnapshotContext context) throws Exception { long checkpointId = context.getCheckpointId(); int subTaskId = getRuntimeContext().getIndexOfThisSubtask(); LOG.info( - "Taking data statistics operator snapshot for checkpoint {} in subtask {}", + "Taking data statistics operator {} snapshot for checkpoint {} in subtask {}", + operatorName, checkpointId, subTaskId); + // Pass global statistics to partitioners so that all the operators refresh statistics + // at same checkpoint barrier + output.collect(new StreamRecord<>(DataStatisticsOrRecord.fromDataStatistics(globalStatistics))); + // Only subtask 0 saves the state so that globalStatisticsState(UnionListState) stores // an exact copy of globalStatistics if (!globalStatistics.isEmpty() && getRuntimeContext().getIndexOfThisSubtask() == 0) { globalStatisticsState.clear(); - LOG.info("Saving global statistics {} to state in subtask {}", globalStatistics, subTaskId); + LOG.info( + "Saving operator {} global statistics {} to state in subtask {}", + operatorName, + globalStatistics, + subTaskId); globalStatisticsState.add(globalStatistics); } - // For now, we make it simple to send globalStatisticsState at checkpoint + // For now, we make it simple to send localStatistics at checkpoint Review Comment: nit: avoid nouns like `I` or `We`. maybe `For now, local statistics are sent to coordinator at checkpoint` ########## flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsUtil.java: ########## @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.util.Set; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.core.memory.DataInputDeserializer; +import org.apache.flink.core.memory.DataOutputSerializer; + +/** + * DataStatisticsUtil is the utility to serialize and deserialize {@link DataStatistics} and {@link + * GlobalStatistics} + */ +class DataStatisticsUtil { + + private DataStatisticsUtil() {} + + static <D extends DataStatistics<D, S>, S> byte[] serializeDataStatistics( + DataStatistics<D, S> dataStatistics, + TypeSerializer<DataStatistics<D, S>> statisticsSerializer) { + DataOutputSerializer out = new DataOutputSerializer(64); + try { + statisticsSerializer.serialize(dataStatistics, out); + return out.getCopyOfBuffer(); + } catch (IOException e) { + throw new IllegalStateException("Fail to serialize data statistics", e); + } + } + + @SuppressWarnings("unchecked") + static <D extends DataStatistics<D, S>, S> D deserializeDataStatistics( + byte[] bytes, TypeSerializer<DataStatistics<D, S>> statisticsSerializer) { + DataInputDeserializer input = new DataInputDeserializer(bytes, 0, bytes.length); + D dataStatistics; + + try { + dataStatistics = (D) statisticsSerializer.deserialize(input); Review Comment: can directly return here and avoid the need of local var `dataStatistics` ########## flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/GlobalStatisticsTracker.java: ########## @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * GlobalStatisticsTracker is used by {@link DataStatisticsCoordinator} to track the in progress + * {@link GlobalStatistics} received from {@link DataStatisticsOperator} subtasks for specific + * checkpoint. + */ +@Internal +class GlobalStatisticsTracker<D extends DataStatistics<D, S>, S> { + private static final Logger LOG = LoggerFactory.getLogger(GlobalStatisticsTracker.class); + private static final double EXPECTED_DATA_STATISTICS_RECEIVED_PERCENTAGE = 90; + private final String operatorName; + private final TypeSerializer<DataStatistics<D, S>> statisticsSerializer; + private final int parallelism; + private volatile GlobalStatistics<D, S> inProgressStatistics; + + GlobalStatisticsTracker( + String operatorName, + TypeSerializer<DataStatistics<D, S>> statisticsSerializer, + int parallelism) { + this.operatorName = operatorName; + this.statisticsSerializer = statisticsSerializer; + this.parallelism = parallelism; + } + + GlobalStatistics<D, S> receiveDataStatisticEventAndCheckCompletion( + int subtask, DataStatisticsEvent<D, S> event) { + long checkpointId = event.checkpointId(); + + if (inProgressStatistics != null && inProgressStatistics.checkpointId() > checkpointId) { + LOG.debug( + "Expect data statistics for operator {} checkpoint {}, but receive event from older checkpoint {}. Ignore it.", + operatorName, + inProgressStatistics.checkpointId(), + checkpointId); + return null; + } + + GlobalStatistics<D, S> completedStatistics = null; + if (inProgressStatistics != null && inProgressStatistics.checkpointId() < checkpointId) { + if ((double) inProgressStatistics.aggregatedSubtasksCount() / parallelism * 100 + >= EXPECTED_DATA_STATISTICS_RECEIVED_PERCENTAGE) { + completedStatistics = inProgressStatistics; + LOG.info( + "Received data statistics from {} operators {} out of total {} for checkpoint {}. " + + "It's more than the expected percentage {}. Complete data statistics aggregation {}", Review Comment: nit: maybe simplify a little `Complete global aggregation as it is more than the threshold of {} percentage` ########## flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/GlobalStatisticsTracker.java: ########## @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * GlobalStatisticsTracker is used by {@link DataStatisticsCoordinator} to track the in progress + * {@link GlobalStatistics} received from {@link DataStatisticsOperator} subtasks for specific + * checkpoint. + */ +@Internal +class GlobalStatisticsTracker<D extends DataStatistics<D, S>, S> { + private static final Logger LOG = LoggerFactory.getLogger(GlobalStatisticsTracker.class); + private static final double EXPECTED_DATA_STATISTICS_RECEIVED_PERCENTAGE = 90; + private final String operatorName; + private final TypeSerializer<DataStatistics<D, S>> statisticsSerializer; + private final int parallelism; + private volatile GlobalStatistics<D, S> inProgressStatistics; + + GlobalStatisticsTracker( + String operatorName, + TypeSerializer<DataStatistics<D, S>> statisticsSerializer, + int parallelism) { + this.operatorName = operatorName; + this.statisticsSerializer = statisticsSerializer; + this.parallelism = parallelism; + } + + GlobalStatistics<D, S> receiveDataStatisticEventAndCheckCompletion( + int subtask, DataStatisticsEvent<D, S> event) { + long checkpointId = event.checkpointId(); + + if (inProgressStatistics != null && inProgressStatistics.checkpointId() > checkpointId) { + LOG.debug( + "Expect data statistics for operator {} checkpoint {}, but receive event from older checkpoint {}. Ignore it.", + operatorName, + inProgressStatistics.checkpointId(), + checkpointId); + return null; + } + + GlobalStatistics<D, S> completedStatistics = null; + if (inProgressStatistics != null && inProgressStatistics.checkpointId() < checkpointId) { + if ((double) inProgressStatistics.aggregatedSubtasksCount() / parallelism * 100 + >= EXPECTED_DATA_STATISTICS_RECEIVED_PERCENTAGE) { + completedStatistics = inProgressStatistics; + LOG.info( + "Received data statistics from {} operators {} out of total {} for checkpoint {}. " Review Comment: nit: `Received data statistics from {} subtasks out of total {} for operator {} at checkpoint {}` ########## flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinator.java: ########## @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.operators.coordination.EventReceivingTasks; +import org.apache.flink.runtime.operators.coordination.MockOperatorCoordinatorContext; +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.StringData; +import org.apache.flink.table.data.binary.BinaryRowData; +import org.apache.flink.table.runtime.typeutils.RowDataSerializer; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.table.types.logical.VarCharType; +import org.apache.flink.util.ExceptionUtils; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class TestDataStatisticsCoordinator { + private static final String OPERATOR_NAME = "TestCoordinator"; + private static final OperatorID TEST_OPERATOR_ID = new OperatorID(1234L, 5678L); + private static final int NUM_SUBTASKS = 2; + private TypeSerializer<DataStatistics<MapDataStatistics, Map<RowData, Long>>> + statisticsSerializer; + + private EventReceivingTasks receivingTasks; + private DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>> + dataStatisticsCoordinator; + + @Before + public void before() throws Exception { + receivingTasks = EventReceivingTasks.createForRunningTasks(); + statisticsSerializer = + MapDataStatisticsSerializer.fromKeySerializer( + new RowDataSerializer(RowType.of(new VarCharType()))); + + dataStatisticsCoordinator = + new DataStatisticsCoordinator<>( + OPERATOR_NAME, + new MockOperatorCoordinatorContext(TEST_OPERATOR_ID, NUM_SUBTASKS), + statisticsSerializer); + } + + private void tasksReady() throws Exception { + dataStatisticsCoordinator.start(); + setAllTasksReady(NUM_SUBTASKS, dataStatisticsCoordinator, receivingTasks); + } + + @Test + public void testThrowExceptionWhenNotStarted() { + String failureMessage = + "Call should fail when data statistics coordinator has not started yet."; + + Assert.assertThrows( Review Comment: use assertj instead ``` Assertions.assertThatThrownBy(() -> AwsClientFactories.from(properties)) .isInstanceOf(ValidationException.class) .hasMessage("S3 client access key ID and secret access key must be set at the same time"); ``` ########## flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/GlobalStatisticsTracker.java: ########## @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * GlobalStatisticsTracker is used by {@link DataStatisticsCoordinator} to track the in progress + * {@link GlobalStatistics} received from {@link DataStatisticsOperator} subtasks for specific + * checkpoint. + */ +@Internal +class GlobalStatisticsTracker<D extends DataStatistics<D, S>, S> { + private static final Logger LOG = LoggerFactory.getLogger(GlobalStatisticsTracker.class); + private static final double EXPECTED_DATA_STATISTICS_RECEIVED_PERCENTAGE = 90; + private final String operatorName; + private final TypeSerializer<DataStatistics<D, S>> statisticsSerializer; + private final int parallelism; + private volatile GlobalStatistics<D, S> inProgressStatistics; + + GlobalStatisticsTracker( + String operatorName, + TypeSerializer<DataStatistics<D, S>> statisticsSerializer, + int parallelism) { + this.operatorName = operatorName; + this.statisticsSerializer = statisticsSerializer; + this.parallelism = parallelism; + } + + GlobalStatistics<D, S> receiveDataStatisticEventAndCheckCompletion( + int subtask, DataStatisticsEvent<D, S> event) { + long checkpointId = event.checkpointId(); + + if (inProgressStatistics != null && inProgressStatistics.checkpointId() > checkpointId) { + LOG.debug( + "Expect data statistics for operator {} checkpoint {}, but receive event from older checkpoint {}. Ignore it.", + operatorName, + inProgressStatistics.checkpointId(), + checkpointId); + return null; + } + + GlobalStatistics<D, S> completedStatistics = null; + if (inProgressStatistics != null && inProgressStatistics.checkpointId() < checkpointId) { + if ((double) inProgressStatistics.aggregatedSubtasksCount() / parallelism * 100 + >= EXPECTED_DATA_STATISTICS_RECEIVED_PERCENTAGE) { + completedStatistics = inProgressStatistics; + LOG.info( + "Received data statistics from {} operators {} out of total {} for checkpoint {}. " + + "It's more than the expected percentage {}. Complete data statistics aggregation {}", + inProgressStatistics.aggregatedSubtasksCount(), + operatorName, + parallelism, + inProgressStatistics.checkpointId(), + EXPECTED_DATA_STATISTICS_RECEIVED_PERCENTAGE, + completedStatistics); + } else { + LOG.info( + "Received data statistics from {} operators {} out of total {} for checkpoint {}. " + + "It's less than the expected percentage {}. Dropping the incomplete aggregate " + + "data statistics and starting collecting data statistics from new checkpoint {}", + inProgressStatistics.aggregatedSubtasksCount(), + operatorName, + parallelism, + inProgressStatistics.checkpointId(), + EXPECTED_DATA_STATISTICS_RECEIVED_PERCENTAGE, + checkpointId); + } + inProgressStatistics = null; Review Comment: nit: Iceberg coding style as an empty line after control block `}` ########## flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsUtil.java: ########## @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.util.Set; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.core.memory.DataInputDeserializer; +import org.apache.flink.core.memory.DataOutputSerializer; + +/** + * DataStatisticsUtil is the utility to serialize and deserialize {@link DataStatistics} and {@link + * GlobalStatistics} + */ +class DataStatisticsUtil { + + private DataStatisticsUtil() {} + + static <D extends DataStatistics<D, S>, S> byte[] serializeDataStatistics( + DataStatistics<D, S> dataStatistics, + TypeSerializer<DataStatistics<D, S>> statisticsSerializer) { + DataOutputSerializer out = new DataOutputSerializer(64); + try { + statisticsSerializer.serialize(dataStatistics, out); + return out.getCopyOfBuffer(); + } catch (IOException e) { + throw new IllegalStateException("Fail to serialize data statistics", e); + } + } + + @SuppressWarnings("unchecked") + static <D extends DataStatistics<D, S>, S> D deserializeDataStatistics( + byte[] bytes, TypeSerializer<DataStatistics<D, S>> statisticsSerializer) { + DataInputDeserializer input = new DataInputDeserializer(bytes, 0, bytes.length); + D dataStatistics; + + try { + dataStatistics = (D) statisticsSerializer.deserialize(input); + } catch (IOException e) { + throw new IllegalStateException("Fail to serialize data statistics", e); + } + + return dataStatistics; + } + + static <D extends DataStatistics<D, S>, S> byte[] serializeGlobalStatistics( + GlobalStatistics<D, S> globalStatistics, + TypeSerializer<DataStatistics<D, S>> statisticsSerializer) + throws IOException { + ByteArrayOutputStream bytes = new ByteArrayOutputStream(); + ObjectOutputStream out = new ObjectOutputStream(bytes); + + DataOutputSerializer outSerializer = new DataOutputSerializer(64); + out.writeLong(globalStatistics.checkpointId()); + statisticsSerializer.serialize(globalStatistics.dataStatistics(), outSerializer); + byte[] statisticsBytes = outSerializer.getCopyOfBuffer(); + out.writeInt(statisticsBytes.length); + out.write(statisticsBytes); + out.writeObject(globalStatistics.subtaskSet()); Review Comment: I know we need the subtask set to check completion. but do we need to serialize it for completed statistics? subtask set can be part of the `GlobalStatisticsTracker` state. but maybe don't need to be part of `GlobalStatistics`? ########## flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java: ########## @@ -117,21 +134,36 @@ public void snapshotState(StateSnapshotContext context) throws Exception { long checkpointId = context.getCheckpointId(); int subTaskId = getRuntimeContext().getIndexOfThisSubtask(); LOG.info( - "Taking data statistics operator snapshot for checkpoint {} in subtask {}", + "Taking data statistics operator {} snapshot for checkpoint {} in subtask {}", + operatorName, checkpointId, subTaskId); + // Pass global statistics to partitioners so that all the operators refresh statistics + // at same checkpoint barrier + output.collect(new StreamRecord<>(DataStatisticsOrRecord.fromDataStatistics(globalStatistics))); + // Only subtask 0 saves the state so that globalStatisticsState(UnionListState) stores // an exact copy of globalStatistics if (!globalStatistics.isEmpty() && getRuntimeContext().getIndexOfThisSubtask() == 0) { globalStatisticsState.clear(); - LOG.info("Saving global statistics {} to state in subtask {}", globalStatistics, subTaskId); + LOG.info( + "Saving operator {} global statistics {} to state in subtask {}", + operatorName, + globalStatistics, + subTaskId); globalStatisticsState.add(globalStatistics); } - // For now, we make it simple to send globalStatisticsState at checkpoint + // For now, we make it simple to send localStatistics at checkpoint operatorEventGateway.sendEventToCoordinator( - new DataStatisticsEvent<>(checkpointId, localStatistics)); + DataStatisticsEvent.create(checkpointId, localStatistics, statisticsSerializer)); + LOG.debug( + "Send operator {} local statistics {} from subtask {} at checkpoint {} to coordinator", Review Comment: maybe ``` "Subtask {} of operator {} sent local statistics to coordinator at checkpoint{}: {}", subTaskId, operatorName, checkpointId, localStatistics) ``` because `localStatistics` can be large, we want to keep it printed out in the end ########## flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinatorProvider.java: ########## @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Map; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.operators.coordination.EventReceivingTasks; +import org.apache.flink.runtime.operators.coordination.MockOperatorCoordinatorContext; +import org.apache.flink.runtime.operators.coordination.RecreateOnResetOperatorCoordinator; +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.StringData; +import org.apache.flink.table.data.binary.BinaryRowData; +import org.apache.flink.table.runtime.typeutils.RowDataSerializer; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.table.types.logical.VarCharType; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class TestDataStatisticsCoordinatorProvider { Review Comment: do we need to test `DataStatisticsCoordinatorProvider`? it seems like a trivial class. the test here seems almost identical as `Test DataStatisticsCoordinator` above. ########## flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinatorProvider.java: ########## @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Map; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.operators.coordination.EventReceivingTasks; +import org.apache.flink.runtime.operators.coordination.MockOperatorCoordinatorContext; +import org.apache.flink.runtime.operators.coordination.RecreateOnResetOperatorCoordinator; +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.StringData; +import org.apache.flink.table.data.binary.BinaryRowData; +import org.apache.flink.table.runtime.typeutils.RowDataSerializer; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.table.types.logical.VarCharType; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class TestDataStatisticsCoordinatorProvider { + private static final OperatorID OPERATOR_ID = new OperatorID(); + private static final int NUM_SUBTASKS = 1; + + private DataStatisticsCoordinatorProvider<MapDataStatistics, Map<RowData, Long>> provider; + private EventReceivingTasks receivingTasks; + private TypeSerializer<DataStatistics<MapDataStatistics, Map<RowData, Long>>> + statisticsSerializer; + + @Before + public void before() { + statisticsSerializer = + MapDataStatisticsSerializer.fromKeySerializer( + new RowDataSerializer(RowType.of(new VarCharType()))); + provider = + new DataStatisticsCoordinatorProvider<>( + "DataStatisticsCoordinatorProviderTest", OPERATOR_ID, statisticsSerializer); Review Comment: nit: `DataStatisticsCoordinatorProviderTest`. maybe just use class name? ########## flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/GlobalStatisticsTracker.java: ########## @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * GlobalStatisticsTracker is used by {@link DataStatisticsCoordinator} to track the in progress + * {@link GlobalStatistics} received from {@link DataStatisticsOperator} subtasks for specific + * checkpoint. + */ +@Internal +class GlobalStatisticsTracker<D extends DataStatistics<D, S>, S> { + private static final Logger LOG = LoggerFactory.getLogger(GlobalStatisticsTracker.class); + private static final double EXPECTED_DATA_STATISTICS_RECEIVED_PERCENTAGE = 90; + private final String operatorName; + private final TypeSerializer<DataStatistics<D, S>> statisticsSerializer; + private final int parallelism; + private volatile GlobalStatistics<D, S> inProgressStatistics; + + GlobalStatisticsTracker( + String operatorName, + TypeSerializer<DataStatistics<D, S>> statisticsSerializer, + int parallelism) { + this.operatorName = operatorName; + this.statisticsSerializer = statisticsSerializer; + this.parallelism = parallelism; + } + + GlobalStatistics<D, S> receiveDataStatisticEventAndCheckCompletion( + int subtask, DataStatisticsEvent<D, S> event) { + long checkpointId = event.checkpointId(); + + if (inProgressStatistics != null && inProgressStatistics.checkpointId() > checkpointId) { + LOG.debug( + "Expect data statistics for operator {} checkpoint {}, but receive event from older checkpoint {}. Ignore it.", + operatorName, + inProgressStatistics.checkpointId(), + checkpointId); + return null; + } + + GlobalStatistics<D, S> completedStatistics = null; + if (inProgressStatistics != null && inProgressStatistics.checkpointId() < checkpointId) { + if ((double) inProgressStatistics.aggregatedSubtasksCount() / parallelism * 100 + >= EXPECTED_DATA_STATISTICS_RECEIVED_PERCENTAGE) { + completedStatistics = inProgressStatistics; + LOG.info( + "Received data statistics from {} operators {} out of total {} for checkpoint {}. " + + "It's more than the expected percentage {}. Complete data statistics aggregation {}", + inProgressStatistics.aggregatedSubtasksCount(), + operatorName, + parallelism, + inProgressStatistics.checkpointId(), + EXPECTED_DATA_STATISTICS_RECEIVED_PERCENTAGE, + completedStatistics); + } else { + LOG.info( + "Received data statistics from {} operators {} out of total {} for checkpoint {}. " + + "It's less than the expected percentage {}. Dropping the incomplete aggregate " Review Comment: similar to above `Aborting the incomplete aggregation for checkpoint {} and starting a new aggregation for checkpoint {}` ########## flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/GlobalStatisticsTracker.java: ########## @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * GlobalStatisticsTracker is used by {@link DataStatisticsCoordinator} to track the in progress + * {@link GlobalStatistics} received from {@link DataStatisticsOperator} subtasks for specific + * checkpoint. + */ +@Internal +class GlobalStatisticsTracker<D extends DataStatistics<D, S>, S> { + private static final Logger LOG = LoggerFactory.getLogger(GlobalStatisticsTracker.class); + private static final double EXPECTED_DATA_STATISTICS_RECEIVED_PERCENTAGE = 90; + private final String operatorName; + private final TypeSerializer<DataStatistics<D, S>> statisticsSerializer; + private final int parallelism; + private volatile GlobalStatistics<D, S> inProgressStatistics; + + GlobalStatisticsTracker( + String operatorName, + TypeSerializer<DataStatistics<D, S>> statisticsSerializer, + int parallelism) { + this.operatorName = operatorName; + this.statisticsSerializer = statisticsSerializer; + this.parallelism = parallelism; + } + + GlobalStatistics<D, S> receiveDataStatisticEventAndCheckCompletion( + int subtask, DataStatisticsEvent<D, S> event) { + long checkpointId = event.checkpointId(); + + if (inProgressStatistics != null && inProgressStatistics.checkpointId() > checkpointId) { + LOG.debug( Review Comment: I am debating with myself whether this should be debug or info ########## flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinator.java: ########## @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.operators.coordination.EventReceivingTasks; +import org.apache.flink.runtime.operators.coordination.MockOperatorCoordinatorContext; +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.StringData; +import org.apache.flink.table.data.binary.BinaryRowData; +import org.apache.flink.table.runtime.typeutils.RowDataSerializer; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.table.types.logical.VarCharType; +import org.apache.flink.util.ExceptionUtils; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class TestDataStatisticsCoordinator { + private static final String OPERATOR_NAME = "TestCoordinator"; + private static final OperatorID TEST_OPERATOR_ID = new OperatorID(1234L, 5678L); + private static final int NUM_SUBTASKS = 2; + private TypeSerializer<DataStatistics<MapDataStatistics, Map<RowData, Long>>> + statisticsSerializer; + + private EventReceivingTasks receivingTasks; + private DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>> + dataStatisticsCoordinator; + + @Before + public void before() throws Exception { + receivingTasks = EventReceivingTasks.createForRunningTasks(); + statisticsSerializer = + MapDataStatisticsSerializer.fromKeySerializer( + new RowDataSerializer(RowType.of(new VarCharType()))); + + dataStatisticsCoordinator = + new DataStatisticsCoordinator<>( + OPERATOR_NAME, + new MockOperatorCoordinatorContext(TEST_OPERATOR_ID, NUM_SUBTASKS), + statisticsSerializer); + } + + private void tasksReady() throws Exception { + dataStatisticsCoordinator.start(); + setAllTasksReady(NUM_SUBTASKS, dataStatisticsCoordinator, receivingTasks); + } + + @Test + public void testThrowExceptionWhenNotStarted() { + String failureMessage = + "Call should fail when data statistics coordinator has not started yet."; + + Assert.assertThrows( + failureMessage, + IllegalStateException.class, + () -> + dataStatisticsCoordinator.handleEventFromOperator( + 0, + 0, + DataStatisticsEvent.create(0, new MapDataStatistics(), statisticsSerializer))); + Assert.assertThrows( + failureMessage, + IllegalStateException.class, + () -> dataStatisticsCoordinator.executionAttemptFailed(0, 0, null)); + Assert.assertThrows( + failureMessage, + IllegalStateException.class, + () -> dataStatisticsCoordinator.checkpointCoordinator(0, null)); + } + + @Test + public void testDataStatisticsEventHandling() throws Exception { + tasksReady(); + RowType rowType = RowType.of(new VarCharType()); + BinaryRowData binaryRowDataA = + new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("a"))); + BinaryRowData binaryRowDataB = + new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("b"))); + BinaryRowData binaryRowDataC = + new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("c"))); + + MapDataStatistics checkpoint1Subtask0DataStatistic = new MapDataStatistics(); + checkpoint1Subtask0DataStatistic.add(binaryRowDataA); + checkpoint1Subtask0DataStatistic.add(binaryRowDataB); + checkpoint1Subtask0DataStatistic.add(binaryRowDataB); + checkpoint1Subtask0DataStatistic.add(binaryRowDataC); + checkpoint1Subtask0DataStatistic.add(binaryRowDataC); + checkpoint1Subtask0DataStatistic.add(binaryRowDataC); + DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>> + checkpoint1Subtask0DataStatisticEvent = + DataStatisticsEvent.create(1, checkpoint1Subtask0DataStatistic, statisticsSerializer); + MapDataStatistics checkpoint1Subtask1DataStatistic = new MapDataStatistics(); + checkpoint1Subtask1DataStatistic.add(binaryRowDataA); + checkpoint1Subtask1DataStatistic.add(binaryRowDataB); + checkpoint1Subtask1DataStatistic.add(binaryRowDataC); + checkpoint1Subtask1DataStatistic.add(binaryRowDataC); + DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>> + checkpoint1Subtask1DataStatisticEvent = + DataStatisticsEvent.create(1, checkpoint1Subtask1DataStatistic, statisticsSerializer); + // Handle events from operators for checkpoint 1 + dataStatisticsCoordinator.handleEventFromOperator(0, 0, checkpoint1Subtask0DataStatisticEvent); + dataStatisticsCoordinator.handleEventFromOperator(1, 0, checkpoint1Subtask1DataStatisticEvent); + + waitForCoordinatorToProcessActions(dataStatisticsCoordinator); + // Verify global data statistics is the aggregation of all subtasks data statistics + MapDataStatistics globalDataStatistics = + (MapDataStatistics) dataStatisticsCoordinator.completedStatistics().dataStatistics(); + assertThat(globalDataStatistics.statistics()) + .containsExactlyInAnyOrderEntriesOf( + ImmutableMap.of( + binaryRowDataA, + checkpoint1Subtask0DataStatistic.statistics().get(binaryRowDataA) + + (long) checkpoint1Subtask1DataStatistic.statistics().get(binaryRowDataA), + binaryRowDataB, + checkpoint1Subtask0DataStatistic.statistics().get(binaryRowDataB) + + (long) checkpoint1Subtask1DataStatistic.statistics().get(binaryRowDataB), + binaryRowDataC, + checkpoint1Subtask0DataStatistic.statistics().get(binaryRowDataC) + + (long) checkpoint1Subtask1DataStatistic.statistics().get(binaryRowDataC))); + } + + static void setAllTasksReady( + int subtasks, + DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>> dataStatisticsCoordinator, + EventReceivingTasks receivingTasks) { + for (int i = 0; i < subtasks; i++) { + dataStatisticsCoordinator.executionAttemptReady( + i, 0, receivingTasks.createGatewayForSubtask(i, 0)); + } + } + + static void waitForCoordinatorToProcessActions( + DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>> coordinator) { + CompletableFuture<Void> future = new CompletableFuture<>(); + coordinator.callInCoordinatorThread( + () -> { + future.complete(null); + return null; + }, + "Coordinator fails to process action"); + + try { + future.get(); + } catch (InterruptedException e) { + throw new AssertionError("test interrupted"); + } catch (ExecutionException e) { + ExceptionUtils.rethrow(ExceptionUtils.stripExecutionException(e)); + } + } + + static byte[] waitForCheckpoint( Review Comment: it doesn't seem that this util method was used by this class? should it be moved to the other class where it is actually used ########## flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinator.java: ########## @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.operators.coordination.EventReceivingTasks; +import org.apache.flink.runtime.operators.coordination.MockOperatorCoordinatorContext; +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.StringData; +import org.apache.flink.table.data.binary.BinaryRowData; +import org.apache.flink.table.runtime.typeutils.RowDataSerializer; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.table.types.logical.VarCharType; +import org.apache.flink.util.ExceptionUtils; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class TestDataStatisticsCoordinator { + private static final String OPERATOR_NAME = "TestCoordinator"; + private static final OperatorID TEST_OPERATOR_ID = new OperatorID(1234L, 5678L); + private static final int NUM_SUBTASKS = 2; + private TypeSerializer<DataStatistics<MapDataStatistics, Map<RowData, Long>>> + statisticsSerializer; + + private EventReceivingTasks receivingTasks; + private DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>> + dataStatisticsCoordinator; + + @Before + public void before() throws Exception { + receivingTasks = EventReceivingTasks.createForRunningTasks(); + statisticsSerializer = + MapDataStatisticsSerializer.fromKeySerializer( + new RowDataSerializer(RowType.of(new VarCharType()))); + + dataStatisticsCoordinator = + new DataStatisticsCoordinator<>( + OPERATOR_NAME, + new MockOperatorCoordinatorContext(TEST_OPERATOR_ID, NUM_SUBTASKS), + statisticsSerializer); + } + + private void tasksReady() throws Exception { + dataStatisticsCoordinator.start(); + setAllTasksReady(NUM_SUBTASKS, dataStatisticsCoordinator, receivingTasks); + } + + @Test + public void testThrowExceptionWhenNotStarted() { + String failureMessage = + "Call should fail when data statistics coordinator has not started yet."; + + Assert.assertThrows( + failureMessage, + IllegalStateException.class, + () -> + dataStatisticsCoordinator.handleEventFromOperator( + 0, + 0, + DataStatisticsEvent.create(0, new MapDataStatistics(), statisticsSerializer))); + Assert.assertThrows( + failureMessage, + IllegalStateException.class, + () -> dataStatisticsCoordinator.executionAttemptFailed(0, 0, null)); + Assert.assertThrows( + failureMessage, + IllegalStateException.class, + () -> dataStatisticsCoordinator.checkpointCoordinator(0, null)); + } + + @Test + public void testDataStatisticsEventHandling() throws Exception { + tasksReady(); + RowType rowType = RowType.of(new VarCharType()); + BinaryRowData binaryRowDataA = Review Comment: why do we need to convert `GenericRowData` to `BinaryRowData`? at least we should document the reason in comment. This question applies to all usage of `BinaryRowData` in other test classes. ########## flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestGlobalStatisticsTracker.java: ########## @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Map; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.StringData; +import org.apache.flink.table.data.binary.BinaryRowData; +import org.apache.flink.table.runtime.typeutils.RowDataSerializer; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.table.types.logical.VarCharType; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class TestGlobalStatisticsTracker { + private static final int NUM_SUBTASKS = 2; + private final RowType rowType = RowType.of(new VarCharType()); + private final BinaryRowData binaryRowDataA = + new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("a"))); + private final BinaryRowData binaryRowDataB = + new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("b"))); + private final TypeSerializer<RowData> rowSerializer = new RowDataSerializer(rowType); + private final TypeSerializer<DataStatistics<MapDataStatistics, Map<RowData, Long>>> + statisticsSerializer = MapDataStatisticsSerializer.fromKeySerializer(rowSerializer); + private GlobalStatisticsTracker<MapDataStatistics, Map<RowData, Long>> globalStatisticsTracker; + + @Before + public void before() throws Exception { + globalStatisticsTracker = + new GlobalStatisticsTracker<>("testOperator", statisticsSerializer, NUM_SUBTASKS); + } + + @Test + public void receiveDataStatisticEventAndCheckCompletionTest() { Review Comment: this test method should be broken into multiple methods: one for each scenario -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
