This is an automated email from the ASF dual-hosted git repository.

zhuzh pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit b3589afff6477676c4fe2201aad9d37fab59ae29
Author: ifndef-SleePy <[email protected]>
AuthorDate: Mon Jan 23 01:12:35 2023 +0800

    [FLINK-30755][runtime] Support SupportsConcurrentExecutionAttempts property 
of SinkV2
    
    This closes #21765.
---
 .../translators/SinkTransformationTranslator.java  |  45 ++++---
 .../api/graph/StreamingJobGraphGeneratorTest.java  | 129 ++++++++++++++++++
 .../scheduling/SpeculativeSchedulerITCase.java     | 147 +++++++++++++++++++++
 3 files changed, 304 insertions(+), 17 deletions(-)

diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/translators/SinkTransformationTranslator.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/translators/SinkTransformationTranslator.java
index e4ba29b4512..fb456efda3a 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/translators/SinkTransformationTranslator.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/translators/SinkTransformationTranslator.java
@@ -20,6 +20,7 @@ package org.apache.flink.streaming.runtime.translators;
 
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.SupportsConcurrentExecutionAttempts;
 import org.apache.flink.api.common.operators.SlotSharingGroup;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.connector.sink2.Sink;
@@ -48,10 +49,8 @@ import javax.annotation.Nullable;
 
 import java.util.Collection;
 import java.util.Collections;
-import java.util.HashSet;
 import java.util.List;
 import java.util.Optional;
-import java.util.Set;
 import java.util.function.BiConsumer;
 import java.util.function.Function;
 
@@ -141,7 +140,8 @@ public class SinkTransformationTranslator<Input, Output>
                         adjustTransformations(
                                 prewritten,
                                 ((WithPreWriteTopology<T>) 
sink)::addPreWriteTopology,
-                                true);
+                                true,
+                                sink instanceof 
SupportsConcurrentExecutionAttempts);
             }
 
             if (sink instanceof TwoPhaseCommittingSink) {
@@ -154,16 +154,15 @@ public class SinkTransformationTranslator<Input, Output>
                                         WRITER_NAME,
                                         CommittableMessageTypeInfo.noOutput(),
                                         new SinkWriterOperatorFactory<>(sink)),
-                        false);
+                        false,
+                        sink instanceof SupportsConcurrentExecutionAttempts);
             }
 
-            final Set<Integer> expandedSinks = new HashSet<>();
             final List<Transformation<?>> sinkTransformations =
                     executionEnvironment
                             .getTransformations()
                             .subList(sizeBefore, 
executionEnvironment.getTransformations().size());
-            sinkTransformations.forEach(t -> 
expandedSinks.addAll(context.transform(t)));
-            context.getStreamGraph().registerExpandedSinks(expandedSinks);
+            sinkTransformations.forEach(context::transform);
 
             // Remove all added sink subtransformations to avoid duplications 
and allow additional
             // expansions
@@ -188,7 +187,8 @@ public class SinkTransformationTranslator<Input, Output>
                                             WRITER_NAME,
                                             typeInformation,
                                             new 
SinkWriterOperatorFactory<>(sink)),
-                            false);
+                            false,
+                            sink instanceof 
SupportsConcurrentExecutionAttempts);
 
             DataStream<CommittableMessage<CommT>> precommitted = 
addFailOverRegion(written);
 
@@ -197,7 +197,8 @@ public class SinkTransformationTranslator<Input, Output>
                         adjustTransformations(
                                 precommitted,
                                 ((WithPreCommitTopology<T, CommT>) 
sink)::addPreCommitTopology,
-                                true);
+                                true,
+                                false);
             }
 
             DataStream<CommittableMessage<CommT>> committed =
@@ -211,6 +212,7 @@ public class SinkTransformationTranslator<Input, Output>
                                                     committingSink,
                                                     isBatchMode,
                                                     isCheckpointingEnabled)),
+                            false,
                             false);
 
             if (sink instanceof WithPostCommitTopology) {
@@ -221,7 +223,8 @@ public class SinkTransformationTranslator<Input, Output>
                             ((WithPostCommitTopology<T, CommT>) 
sink).addPostCommitTopology(pc);
                             return null;
                         },
-                        true);
+                        true,
+                        false);
             }
         }
 
@@ -254,7 +257,8 @@ public class SinkTransformationTranslator<Input, Output>
         private <I, R> R adjustTransformations(
                 DataStream<I> inputStream,
                 Function<DataStream<I>, R> action,
-                boolean isExpandedTopology) {
+                boolean isExpandedTopology,
+                boolean supportsConcurrentExecutionAttempts) {
 
             // Reset the environment parallelism temporarily before adjusting 
transformations,
             // we can therefore be aware of any customized parallelism of the 
sub topology
@@ -333,13 +337,20 @@ public class SinkTransformationTranslator<Input, Output>
                     
subTransformation.setMaxParallelism(transformation.getMaxParallelism());
                 }
 
-                if (transformation.getChainingStrategy() == null
-                        || !(subTransformation instanceof 
PhysicalTransformation)) {
-                    continue;
-                }
+                if (subTransformation instanceof PhysicalTransformation) {
+                    PhysicalTransformation<?> physicalSubTransformation =
+                            (PhysicalTransformation<?>) subTransformation;
+
+                    if (transformation.getChainingStrategy() != null) {
+                        physicalSubTransformation.setChainingStrategy(
+                                transformation.getChainingStrategy());
+                    }
 
-                ((PhysicalTransformation<?>) subTransformation)
-                        
.setChainingStrategy(transformation.getChainingStrategy());
+                    // overrides the supportsConcurrentExecutionAttempts of 
transformation because
+                    // it's not allowed to specify fine-grained concurrent 
execution attempts yet
+                    
physicalSubTransformation.setSupportsConcurrentExecutionAttempts(
+                            supportsConcurrentExecutionAttempts);
+                }
             }
 
             // Restore the previous parallelism of the environment before 
adjusting transformations
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java
index 16f16084803..a46b200b31d 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java
@@ -20,6 +20,7 @@ package org.apache.flink.streaming.api.graph;
 import org.apache.flink.api.common.BatchShuffleMode;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.RuntimeExecutionMode;
+import org.apache.flink.api.common.SupportsConcurrentExecutionAttempts;
 import org.apache.flink.api.common.eventtime.WatermarkStrategy;
 import org.apache.flink.api.common.functions.FilterFunction;
 import org.apache.flink.api.common.functions.FlatMapFunction;
@@ -33,6 +34,8 @@ import 
org.apache.flink.api.common.operators.util.UserCodeWrapper;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.connector.sink2.Committer;
+import org.apache.flink.api.connector.sink2.TwoPhaseCommittingSink;
 import org.apache.flink.api.connector.source.Boundedness;
 import org.apache.flink.api.connector.source.lib.NumberSequenceSource;
 import org.apache.flink.api.connector.source.mocks.MockSource;
@@ -44,6 +47,7 @@ import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.ExecutionOptions;
 import org.apache.flink.configuration.PipelineOptions;
 import org.apache.flink.configuration.TaskManagerOptions;
+import org.apache.flink.core.io.SimpleVersionedSerializer;
 import org.apache.flink.core.memory.ManagedMemoryUseCase;
 import org.apache.flink.runtime.clusterframework.types.ResourceProfile;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
@@ -63,6 +67,11 @@ import 
org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup;
 import org.apache.flink.runtime.operators.coordination.OperatorCoordinator;
 import org.apache.flink.runtime.operators.util.TaskConfig;
 import org.apache.flink.streaming.api.CheckpointingMode;
+import org.apache.flink.streaming.api.connector.sink2.CommittableMessage;
+import 
org.apache.flink.streaming.api.connector.sink2.CommittableMessageTypeInfo;
+import org.apache.flink.streaming.api.connector.sink2.WithPostCommitTopology;
+import org.apache.flink.streaming.api.connector.sink2.WithPreCommitTopology;
+import org.apache.flink.streaming.api.connector.sink2.WithPreWriteTopology;
 import org.apache.flink.streaming.api.datastream.CachedDataStream;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.datastream.DataStreamSink;
@@ -118,6 +127,7 @@ import java.io.ObjectOutputStream;
 import java.lang.reflect.Method;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashMap;
@@ -1786,6 +1796,125 @@ class StreamingJobGraphGeneratorTest {
         }
     }
 
+    @Test
+    void testSinkSupportConcurrentExecutionAttempts() {
+        final StreamExecutionEnvironment env =
+                StreamExecutionEnvironment.getExecutionEnvironment(new 
Configuration());
+        env.setRuntimeMode(RuntimeExecutionMode.BATCH);
+
+        final DataStream<Integer> source = env.fromElements(1, 2, 
3).name("source");
+        // source -> pre-writer -> Writer -> pre-committer -> Committer -> post-committer
+        source.rebalance()
+                .sinkTo(new TestSinkWithSupportsConcurrentExecutionAttempts())
+                .name("sink");
+
+        final StreamGraph streamGraph = env.getStreamGraph();
+        final JobGraph jobGraph = 
StreamingJobGraphGenerator.createJobGraph(streamGraph);
+        assertThat(jobGraph.getNumberOfVertices()).isEqualTo(6);
+        for (JobVertex jobVertex : jobGraph.getVertices()) {
+            if (jobVertex.getName().contains("source")) {
+                
assertThat(jobVertex.isSupportsConcurrentExecutionAttempts()).isTrue();
+            } else if (jobVertex.getName().contains("pre-writer")) {
+                
assertThat(jobVertex.isSupportsConcurrentExecutionAttempts()).isTrue();
+            } else if (jobVertex.getName().contains("Writer")) {
+                
assertThat(jobVertex.isSupportsConcurrentExecutionAttempts()).isTrue();
+            } else if (jobVertex.getName().contains("pre-committer")) {
+                
assertThat(jobVertex.isSupportsConcurrentExecutionAttempts()).isFalse();
+            } else if (jobVertex.getName().contains("post-committer")) {
+                
assertThat(jobVertex.isSupportsConcurrentExecutionAttempts()).isFalse();
+            } else if (jobVertex.getName().contains("Committer")) {
+                
assertThat(jobVertex.isSupportsConcurrentExecutionAttempts()).isFalse();
+            } else {
+                Assertions.fail("Unexpected job vertex " + 
jobVertex.getName());
+            }
+        }
+    }
+
+    private static class TestSinkWithSupportsConcurrentExecutionAttempts
+            implements SupportsConcurrentExecutionAttempts,
+                    TwoPhaseCommittingSink<Integer, Void>,
+                    WithPreWriteTopology<Integer>,
+                    WithPreCommitTopology<Integer, Void>,
+                    WithPostCommitTopology<Integer, Void> {
+
+        @Override
+        public PrecommittingSinkWriter<Integer, Void> createWriter(InitContext 
context)
+                throws IOException {
+            return new PrecommittingSinkWriter<Integer, Void>() {
+                @Override
+                public Collection<Void> prepareCommit() throws IOException, 
InterruptedException {
+                    return null;
+                }
+
+                @Override
+                public void write(Integer element, Context context)
+                        throws IOException, InterruptedException {}
+
+                @Override
+                public void flush(boolean endOfInput) throws IOException, 
InterruptedException {}
+
+                @Override
+                public void close() throws Exception {}
+            };
+        }
+
+        @Override
+        public Committer<Void> createCommitter() throws IOException {
+            return new Committer<Void>() {
+                @Override
+                public void commit(Collection<CommitRequest<Void>> 
committables)
+                        throws IOException, InterruptedException {}
+
+                @Override
+                public void close() throws Exception {}
+            };
+        }
+
+        @Override
+        public SimpleVersionedSerializer<Void> getCommittableSerializer() {
+            return new SimpleVersionedSerializer<Void>() {
+                @Override
+                public int getVersion() {
+                    return 0;
+                }
+
+                @Override
+                public byte[] serialize(Void obj) throws IOException {
+                    return new byte[0];
+                }
+
+                @Override
+                public Void deserialize(int version, byte[] serialized) throws 
IOException {
+                    return null;
+                }
+            };
+        }
+
+        @Override
+        public void addPostCommitTopology(DataStream<CommittableMessage<Void>> 
committables) {
+            committables
+                    .map(v -> v)
+                    .name("post-committer")
+                    .returns(CommittableMessageTypeInfo.noOutput())
+                    .rebalance();
+        }
+
+        @Override
+        public DataStream<CommittableMessage<Void>> addPreCommitTopology(
+                DataStream<CommittableMessage<Void>> committables) {
+            return committables
+                    .map(v -> v)
+                    .name("pre-committer")
+                    .returns(CommittableMessageTypeInfo.noOutput())
+                    .rebalance();
+        }
+
+        @Override
+        public DataStream<Integer> addPreWriteTopology(DataStream<Integer> 
inputDataStream) {
+            return inputDataStream.map(v -> v).name("pre-writer").rebalance();
+        }
+    }
+
     private static class SerializationTestOperatorFactory
             extends AbstractStreamOperatorFactory<Integer>
             implements CoordinatedOperatorFactory<Integer> {
diff --git 
a/flink-tests/src/test/java/org/apache/flink/test/scheduling/SpeculativeSchedulerITCase.java
 
b/flink-tests/src/test/java/org/apache/flink/test/scheduling/SpeculativeSchedulerITCase.java
index b81d83eaeeb..6efae31ecf3 100644
--- 
a/flink-tests/src/test/java/org/apache/flink/test/scheduling/SpeculativeSchedulerITCase.java
+++ 
b/flink-tests/src/test/java/org/apache/flink/test/scheduling/SpeculativeSchedulerITCase.java
@@ -19,11 +19,15 @@
 package org.apache.flink.test.scheduling;
 
 import org.apache.flink.api.common.RuntimeExecutionMode;
+import org.apache.flink.api.common.SupportsConcurrentExecutionAttempts;
 import org.apache.flink.api.common.eventtime.WatermarkStrategy;
 import org.apache.flink.api.common.functions.RichMapFunction;
 import org.apache.flink.api.common.io.GenericInputFormat;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.connector.sink2.Committer;
+import org.apache.flink.api.connector.sink2.TwoPhaseCommittingSink;
+import 
org.apache.flink.api.connector.sink2.TwoPhaseCommittingSink.PrecommittingSinkWriter;
 import org.apache.flink.api.connector.source.Boundedness;
 import org.apache.flink.api.connector.source.ReaderOutput;
 import org.apache.flink.api.connector.source.SourceReader;
@@ -31,6 +35,7 @@ import 
org.apache.flink.api.connector.source.SourceReaderContext;
 import org.apache.flink.api.connector.source.lib.NumberSequenceSource;
 import org.apache.flink.api.connector.source.lib.util.IteratorSourceReader;
 import org.apache.flink.api.connector.source.lib.util.IteratorSourceSplit;
+import org.apache.flink.api.java.tuple.Tuple3;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.JobManagerOptions;
 import org.apache.flink.configuration.MemorySize;
@@ -41,6 +46,7 @@ import org.apache.flink.configuration.TaskManagerOptions;
 import org.apache.flink.core.execution.JobClient;
 import org.apache.flink.core.io.GenericInputSplit;
 import org.apache.flink.core.io.InputStatus;
+import org.apache.flink.core.io.SimpleVersionedSerializer;
 import org.apache.flink.runtime.scheduler.adaptivebatch.SpeculativeScheduler;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.datastream.DataStreamSource;
@@ -49,6 +55,7 @@ import 
org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
 import 
org.apache.flink.streaming.api.functions.source.InputFormatSourceFunction;
 import 
org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
 import org.apache.flink.streaming.api.operators.StreamSource;
+import org.apache.flink.util.InstantiationUtil;
 
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
@@ -57,6 +64,8 @@ import org.junit.jupiter.api.io.TempDir;
 import java.io.IOException;
 import java.nio.file.Path;
 import java.time.Duration;
+import java.util.Collection;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.Map;
@@ -64,6 +73,7 @@ import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Consumer;
 import java.util.function.Function;
@@ -161,6 +171,19 @@ class SpeculativeSchedulerITCase {
         checkResults();
     }
 
+    @Test
+    public void testSpeculativeSlowSink() throws Exception {
+        executeJob(this::setupSpeculativeSlowSink);
+        waitUntilJobArchived();
+
+        checkResults();
+
+        // no speculative executions for committer
+        assertThat(DummyCommitter.attempts.get()).isEqualTo(parallelism);
+        // there is a speculative execution for writer
+        assertThat(DummyCommitter.foundSpeculativeWriter).isTrue();
+    }
+
     private void checkResults() {
         final Map<Long, Long> numberCountResultMap =
                 numberCountResults.values().stream()
@@ -291,6 +314,18 @@ class SpeculativeSchedulerITCase {
         addSink(source);
     }
 
+    private void setupSpeculativeSlowSink(StreamExecutionEnvironment env) {
+        final DataStream<Long> source =
+                env.fromSequence(0, NUMBERS_TO_PRODUCE - 1)
+                        .setParallelism(parallelism)
+                        .name("source")
+                        .slotSharingGroup("sourceGroup");
+        source.sinkTo(new SpeculativeSink())
+                .setParallelism(parallelism)
+                .name("sink")
+                .slotSharingGroup("sinkGroup");
+    }
+
     private void addSink(DataStream<Long> dataStream) {
         dataStream
                 .rebalance()
@@ -421,6 +456,118 @@ class SpeculativeSchedulerITCase {
         }
     }
 
+    private static class SpeculativeSink
+            implements TwoPhaseCommittingSink<Long, Tuple3<Integer, Integer, 
Map<Long, Long>>>,
+                    SupportsConcurrentExecutionAttempts {
+
+        @Override
+        public PrecommittingSinkWriter<Long, Tuple3<Integer, Integer, 
Map<Long, Long>>>
+                createWriter(InitContext context) {
+            return new DummyPrecommittingSinkWriter(
+                    context.getSubtaskId(), context.getAttemptNumber());
+        }
+
+        @Override
+        public Committer<Tuple3<Integer, Integer, Map<Long, Long>>> 
createCommitter() {
+            return new DummyCommitter();
+        }
+
+        @Override
+        public SimpleVersionedSerializer<Tuple3<Integer, Integer, Map<Long, 
Long>>>
+                getCommittableSerializer() {
+            return new SimpleVersionedSerializer<Tuple3<Integer, Integer, 
Map<Long, Long>>>() {
+                @Override
+                public int getVersion() {
+                    return 0;
+                }
+
+                @Override
+                public byte[] serialize(Tuple3<Integer, Integer, Map<Long, 
Long>> obj)
+                        throws IOException {
+                    return InstantiationUtil.serializeObject(obj);
+                }
+
+                @Override
+                public Tuple3<Integer, Integer, Map<Long, Long>> deserialize(
+                        int version, byte[] serialized) throws IOException {
+                    try {
+                        return InstantiationUtil.deserializeObject(
+                                serialized, 
Thread.currentThread().getContextClassLoader());
+                    } catch (ClassNotFoundException e) {
+                        throw new RuntimeException(e);
+                    }
+                }
+            };
+        }
+    }
+
+    private static class DummyPrecommittingSinkWriter
+            implements PrecommittingSinkWriter<Long, Tuple3<Integer, Integer, 
Map<Long, Long>>> {
+
+        private final int subTaskIndex;
+
+        private final int attemptNumber;
+
+        public DummyPrecommittingSinkWriter(int subTaskIndex, int 
attemptNumber) {
+            this.subTaskIndex = subTaskIndex;
+            this.attemptNumber = attemptNumber;
+        }
+
+        private final Map<Long, Long> numberCountResult = new HashMap<>();
+
+        @Override
+        public void write(Long value, Context context) throws IOException, 
InterruptedException {
+            numberCountResult.merge(value, 1L, Long::sum);
+            maybeSleep();
+        }
+
+        @Override
+        public void flush(boolean endOfInput) {}
+
+        @Override
+        public Collection<Tuple3<Integer, Integer, Map<Long, Long>>> 
prepareCommit() {
+            return Collections.singleton(Tuple3.of(subTaskIndex, 
attemptNumber, numberCountResult));
+        }
+
+        @Override
+        public void close() throws Exception {}
+    }
+
+    private static class DummyCommitter
+            implements Committer<Tuple3<Integer, Integer, Map<Long, Long>>> {
+
+        private static AtomicBoolean blocked = new AtomicBoolean(false);
+        private static AtomicInteger attempts = new AtomicInteger(0);
+
+        private static volatile boolean foundSpeculativeWriter;
+
+        public DummyCommitter() {
+            attempts.incrementAndGet();
+        }
+
+        @Override
+        public void commit(
+                Collection<CommitRequest<Tuple3<Integer, Integer, Map<Long, 
Long>>>> committables)
+                throws InterruptedException {
+
+            for (CommitRequest<Tuple3<Integer, Integer, Map<Long, Long>>> 
request : committables) {
+                Tuple3<Integer, Integer, Map<Long, Long>> committable = 
request.getCommittable();
+                numberCountResults.put(committable.f0, committable.f2);
+                // attempt number > 0 means a speculative writer attempt ran
+                if (committable.f1 > 0) {
+                    foundSpeculativeWriter = true;
+                }
+            }
+
+            if (!blocked.getAndSet(true)) {
+                Thread.sleep(5000);
+            }
+        }
+
+        @Override
+        public void close() throws Exception {}
+    }
+
     private static void maybeSleep() {
         if (slowTaskCounter.getAndDecrement() > 0) {
             try {

Reply via email to