This is an automated email from the ASF dual-hosted git repository.

zhuzh pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 6e445064f6a20caed16b99b3d06398a51c404e84
Author: ifndef-SleePy <[email protected]>
AuthorDate: Sun Jan 29 01:35:10 2023 +0800

    [FLINK-30799][runtime] Make SinkFunction support speculative execution through implementing SupportsConcurrentExecutionAttempts interface
    
    This closes #21773.
---
 .../transformations/LegacySinkTransformation.java  | 20 ++++++
 .../api/graph/StreamingJobGraphGeneratorTest.java  | 51 ++++++++++++++
 .../scheduling/SpeculativeSchedulerITCase.java     | 78 ++++++++++++++++++++++
 3 files changed, 149 insertions(+)

diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/LegacySinkTransformation.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/LegacySinkTransformation.java
index 99e0124ece1..523bec12c3c 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/LegacySinkTransformation.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/LegacySinkTransformation.java
@@ -20,13 +20,17 @@ package org.apache.flink.streaming.api.transformations;
 
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.SupportsConcurrentExecutionAttempts;
+import org.apache.flink.api.common.functions.Function;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.dag.Transformation;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.streaming.api.operators.ChainingStrategy;
 import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
 import org.apache.flink.streaming.api.operators.StreamSink;
+import org.apache.flink.streaming.api.operators.UserFunctionProvider;
 
 import org.apache.flink.shaded.guava30.com.google.common.collect.Lists;
 
@@ -131,6 +135,22 @@ public class LegacySinkTransformation<T> extends PhysicalTransformation<T> {
 
     @Override
     public boolean isSupportsConcurrentExecutionAttempts() {
+        // first, check if the feature is disabled in physical transformation
+        if (!super.isSupportsConcurrentExecutionAttempts()) {
+            return false;
+        }
+        // second, check if the feature can be supported
+        if (operatorFactory instanceof SimpleOperatorFactory) {
+            final StreamOperator<Object> operator =
+                    ((SimpleOperatorFactory<Object>) operatorFactory).getOperator();
+            if (operator instanceof UserFunctionProvider) {
+                final Function userFunction =
+                        ((UserFunctionProvider<?>) operator).getUserFunction();
+                if (userFunction instanceof SupportsConcurrentExecutionAttempts) {
+                    return true;
+                }
+            }
+        }
         return false;
     }
 }
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java
index 0046d9055a5..ed8160bb7e3 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java
@@ -1828,6 +1828,57 @@ class StreamingJobGraphGeneratorTest {
         }
     }
 
+    @Test
+    void testSinkFunctionNotSupportConcurrentExecutionAttempts() {
+        testWhetherSinkFunctionSupportsConcurrentExecutionAttempts(
+                new TestingSinkFunctionNotSupportConcurrentExecutionAttempts<>(), false);
+    }
+
+    @Test
+    void testSinkFunctionSupportConcurrentExecutionAttempts() {
+        testWhetherSinkFunctionSupportsConcurrentExecutionAttempts(
+                new TestingSinkFunctionSupportConcurrentExecutionAttempts<>(), true);
+    }
+
+    private static void testWhetherSinkFunctionSupportsConcurrentExecutionAttempts(
+            SinkFunction<Integer> function, boolean isSupported) {
+        final StreamExecutionEnvironment env =
+                StreamExecutionEnvironment.getExecutionEnvironment(new Configuration());
+        env.setRuntimeMode(RuntimeExecutionMode.BATCH);
+
+        final DataStream<Integer> source = env.fromElements(1, 2, 3).name("source");
+        source.rebalance().addSink(function).name("sink");
+
+        final StreamGraph streamGraph = env.getStreamGraph();
+        final JobGraph jobGraph = StreamingJobGraphGenerator.createJobGraph(streamGraph);
+        assertThat(jobGraph.getNumberOfVertices()).isEqualTo(2);
+        for (JobVertex jobVertex : jobGraph.getVertices()) {
+            if (jobVertex.getName().contains("source")) {
+                assertThat(jobVertex.isSupportsConcurrentExecutionAttempts()).isTrue();
+            } else if (jobVertex.getName().contains("sink")) {
+                if (isSupported) {
+                    assertThat(jobVertex.isSupportsConcurrentExecutionAttempts()).isTrue();
+                } else {
+                    assertThat(jobVertex.isSupportsConcurrentExecutionAttempts()).isFalse();
+                }
+            } else {
+                Assertions.fail("Unexpected job vertex " + jobVertex.getName());
+            }
+        }
+    }
+
+    private static class TestingSinkFunctionNotSupportConcurrentExecutionAttempts<T>
+            implements SinkFunction<T> {
+        @Override
+        public void invoke(T value, Context context) throws Exception {}
+    }
+
+    private static class TestingSinkFunctionSupportConcurrentExecutionAttempts<T>
+            implements SinkFunction<T>, SupportsConcurrentExecutionAttempts {
+        @Override
+        public void invoke(T value, Context context) throws Exception {}
+    }
+
     private static class TestSinkWithSupportsConcurrentExecutionAttempts
             implements SupportsConcurrentExecutionAttempts,
                     TwoPhaseCommittingSink<Integer, Void>,
diff --git a/flink-tests/src/test/java/org/apache/flink/test/scheduling/SpeculativeSchedulerITCase.java b/flink-tests/src/test/java/org/apache/flink/test/scheduling/SpeculativeSchedulerITCase.java
index 6efae31ecf3..3d316b3ecad 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/scheduling/SpeculativeSchedulerITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/scheduling/SpeculativeSchedulerITCase.java
@@ -184,6 +184,22 @@ class SpeculativeSchedulerITCase {
         assertThat(DummyCommitter.foundSpeculativeWriter).isTrue();
     }
 
+    @Test
+    public void testNonSpeculativeSlowSinkFunction() throws Exception {
+        executeJob(this::setupNonSpeculativeSlowSinkFunction);
+        waitUntilJobArchived();
+
+        checkResults();
+    }
+
+    @Test
+    public void testSpeculativeSlowSinkFunction() throws Exception {
+        executeJob(this::setupSpeculativeSlowSinkFunction);
+        waitUntilJobArchived();
+
+        checkResults();
+    }
+
     private void checkResults() {
         final Map<Long, Long> numberCountResultMap =
                 numberCountResults.values().stream()
@@ -326,6 +342,30 @@ class SpeculativeSchedulerITCase {
                 .slotSharingGroup("sinkGroup");
     }
 
+    private void setupNonSpeculativeSlowSinkFunction(StreamExecutionEnvironment env) {
+        final DataStream<Long> source =
+                env.fromSequence(0, NUMBERS_TO_PRODUCE - 1)
+                        .setParallelism(parallelism)
+                        .name("source")
+                        .slotSharingGroup("sourceGroup");
+        source.addSink(new NonSpeculativeSinkFunction())
+                .setParallelism(parallelism)
+                .name("sink")
+                .slotSharingGroup("sinkGroup");
+    }
+
+    private void setupSpeculativeSlowSinkFunction(StreamExecutionEnvironment env) {
+        final DataStream<Long> source =
+                env.fromSequence(0, NUMBERS_TO_PRODUCE - 1)
+                        .setParallelism(parallelism)
+                        .name("source")
+                        .slotSharingGroup("sourceGroup");
+        source.addSink(new SpeculativeSinkFunction())
+                .setParallelism(parallelism)
+                .name("sink")
+                .slotSharingGroup("sinkGroup");
+    }
+
     private void addSink(DataStream<Long> dataStream) {
         dataStream
                 .rebalance()
@@ -568,6 +608,44 @@ class SpeculativeSchedulerITCase {
         public void close() throws Exception {}
     }
 
+    private static class NonSpeculativeSinkFunction extends RichSinkFunction<Long> {
+
+        private final Map<Long, Long> numberCountResult = new HashMap<>();
+
+        @Override
+        public void invoke(Long value, Context context) throws Exception {
+            if (slowTaskCounter.getAndDecrement() > 0) {
+                Thread.sleep(5000);
+            }
+            numberCountResult.merge(value, 1L, Long::sum);
+        }
+
+        @Override
+        public void finish() {
+            if (getRuntimeContext().getAttemptNumber() == 0) {
+                numberCountResults.put(
+                        getRuntimeContext().getIndexOfThisSubtask(), numberCountResult);
+            }
+        }
+    }
+
+    private static class SpeculativeSinkFunction extends RichSinkFunction<Long>
+            implements SupportsConcurrentExecutionAttempts {
+
+        private final Map<Long, Long> numberCountResult = new HashMap<>();
+
+        @Override
+        public void invoke(Long value, Context context) throws Exception {
+            numberCountResult.merge(value, 1L, Long::sum);
+            maybeSleep();
+        }
+
+        @Override
+        public void finish() {
+            numberCountResults.put(getRuntimeContext().getIndexOfThisSubtask(), numberCountResult);
+        }
+    }
+
     private static void maybeSleep() {
         if (slowTaskCounter.getAndDecrement() > 0) {
             try {

Reply via email to