This is an automated email from the ASF dual-hosted git repository.

fanrui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 9f20bc41db85b0928cd98c172a85cb43076e0ec0
Author: Rui Fan <[email protected]>
AuthorDate: Sun Dec 17 17:12:16 2023 +0800

    [FLINK-33565][Scheduler] ConcurrentExceptions works with exception merging
---
 .../failover/ExecutionFailureHandler.java          | 11 ++--
 .../failover/FailureHandlingResult.java            | 34 +++++++++--
 .../flink/runtime/scheduler/SchedulerBase.java     | 29 +++++++--
 .../FailureHandlingResultSnapshot.java             | 20 ++++++-
 .../RootExceptionHistoryEntry.java                 | 36 ++++++++----
 .../failover/ExecutionFailureHandlerTest.java      | 58 +++++++++++++++++-
 .../failover/FailureHandlingResultTest.java        | 28 ++++++---
 .../failover/TestRestartBackoffTimeStrategy.java   | 16 ++++-
 .../runtime/scheduler/DefaultSchedulerTest.java    | 21 +++----
 .../FailureHandlingResultSnapshotTest.java         | 16 +++--
 .../RootExceptionHistoryEntryTest.java             | 68 +++++++++++++---------
 11 files changed, 251 insertions(+), 86 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/ExecutionFailureHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/ExecutionFailureHandler.java
index af1829de41f..0ff7f673a55 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/ExecutionFailureHandler.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/ExecutionFailureHandler.java
@@ -157,10 +157,11 @@ public class ExecutionFailureHandler {
                     new JobException("The failure is not recoverable", cause),
                     timestamp,
                     failureLabels,
-                    globalFailure);
+                    globalFailure,
+                    true);
         }
 
-        restartBackoffTimeStrategy.notifyFailure(cause);
+        boolean isNewAttempt = restartBackoffTimeStrategy.notifyFailure(cause);
         if (restartBackoffTimeStrategy.canRestart()) {
             numberOfRestarts++;
 
@@ -171,7 +172,8 @@ public class ExecutionFailureHandler {
                     failureLabels,
                     verticesToRestart,
                     restartBackoffTimeStrategy.getBackoffTime(),
-                    globalFailure);
+                    globalFailure,
+                    isNewAttempt);
         } else {
             return FailureHandlingResult.unrecoverable(
                     failedExecution,
@@ -179,7 +181,8 @@ public class ExecutionFailureHandler {
                             "Recovery is suppressed by " + 
restartBackoffTimeStrategy, cause),
                     timestamp,
                     failureLabels,
-                    globalFailure);
+                    globalFailure,
+                    isNewAttempt);
         }
     }
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/FailureHandlingResult.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/FailureHandlingResult.java
index f3a43c4f471..68c1983b0fe 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/FailureHandlingResult.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/FailureHandlingResult.java
@@ -66,6 +66,9 @@ public class FailureHandlingResult {
     /** True if the original failure was a global failure. */
     private final boolean globalFailure;
 
+    /** True if current failure is the root cause instead of concurrent 
exceptions. */
+    private final boolean isRootCause;
+
     /**
      * Creates a result of a set of tasks to restart to recover from the 
failure.
      *
@@ -77,6 +80,7 @@ public class FailureHandlingResult {
      * @param verticesToRestart containing task vertices to restart to recover 
from the failure.
      *     {@code null} indicates that the failure is not restartable.
      * @param restartDelayMS indicate a delay before conducting the restart
+     * @param isRootCause indicate whether current failure is a new attempt.
      */
     private FailureHandlingResult(
             @Nullable Execution failedExecution,
@@ -85,7 +89,8 @@ public class FailureHandlingResult {
             CompletableFuture<Map<String, String>> failureLabels,
             @Nullable Set<ExecutionVertexID> verticesToRestart,
             long restartDelayMS,
-            boolean globalFailure) {
+            boolean globalFailure,
+            boolean isRootCause) {
         checkState(restartDelayMS >= 0);
 
         this.verticesToRestart = 
Collections.unmodifiableSet(checkNotNull(verticesToRestart));
@@ -95,6 +100,7 @@ public class FailureHandlingResult {
         this.failureLabels = failureLabels;
         this.timestamp = timestamp;
         this.globalFailure = globalFailure;
+        this.isRootCause = isRootCause;
     }
 
     /**
@@ -106,13 +112,16 @@ public class FailureHandlingResult {
      * @param timestamp the time the failure was handled.
      * @param failureLabels collection of tags characterizing the failure as 
produced by the
      *     FailureEnrichers
+     * @param isRootCause indicate whether current failure is a new attempt.
      */
     private FailureHandlingResult(
             @Nullable Execution failedExecution,
             @Nonnull Throwable error,
             long timestamp,
             CompletableFuture<Map<String, String>> failureLabels,
-            boolean globalFailure) {
+            boolean globalFailure,
+            boolean isRootCause) {
+        this.isRootCause = isRootCause;
         this.verticesToRestart = null;
         this.restartDelayMS = -1;
         this.failedExecution = failedExecution;
@@ -206,6 +215,16 @@ public class FailureHandlingResult {
         return globalFailure;
     }
 
+    /**
+     * @return True means that the current failure is a new attempt, false 
means that there has been
+     *     a failure before and has not been tried yet, and the current 
failure will be merged into
+     *     the previous attempt, and these merged exceptions will be 
considered as the concurrent
+     *     exceptions.
+     */
+    public boolean isRootCause() {
+        return isRootCause;
+    }
+
     /**
      * Creates a result of a set of tasks to restart to recover from the 
failure.
      *
@@ -230,7 +249,8 @@ public class FailureHandlingResult {
             CompletableFuture<Map<String, String>> failureLabels,
             @Nullable Set<ExecutionVertexID> verticesToRestart,
             long restartDelayMS,
-            boolean globalFailure) {
+            boolean globalFailure,
+            boolean isRootCause) {
         return new FailureHandlingResult(
                 failedExecution,
                 cause,
@@ -238,7 +258,8 @@ public class FailureHandlingResult {
                 failureLabels,
                 verticesToRestart,
                 restartDelayMS,
-                globalFailure);
+                globalFailure,
+                isRootCause);
     }
 
     /**
@@ -260,8 +281,9 @@ public class FailureHandlingResult {
             @Nonnull Throwable error,
             long timestamp,
             CompletableFuture<Map<String, String>> failureLabels,
-            boolean globalFailure) {
+            boolean globalFailure,
+            boolean isRootCause) {
         return new FailureHandlingResult(
-                failedExecution, error, timestamp, failureLabels, 
globalFailure);
+                failedExecution, error, timestamp, failureLabels, 
globalFailure, isRootCause);
     }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SchedulerBase.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SchedulerBase.java
index 547daccca23..3c9241c6b41 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SchedulerBase.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SchedulerBase.java
@@ -111,6 +111,7 @@ import javax.annotation.Nullable;
 
 import java.io.IOException;
 import java.net.InetSocketAddress;
+import java.util.ArrayList;
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -128,6 +129,7 @@ import java.util.stream.Collectors;
 import java.util.stream.StreamSupport;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
 
 /** Base class which can be used to implement {@link SchedulerNG}. */
 public abstract class SchedulerBase implements SchedulerNG, 
CheckpointScheduling {
@@ -166,6 +168,8 @@ public abstract class SchedulerBase implements SchedulerNG, CheckpointScheduling
 
     private final BoundedFIFOQueue<RootExceptionHistoryEntry> exceptionHistory;
 
+    private RootExceptionHistoryEntry latestRootExceptionEntry;
+
     private final ExecutionGraphFactory executionGraphFactory;
 
     private final MetricOptions.JobStatusMetricsSettings 
jobStatusMetricsSettings;
@@ -707,27 +711,40 @@ public abstract class SchedulerBase implements SchedulerNG, CheckpointScheduling
             long timestamp,
             CompletableFuture<Map<String, String>> failureLabels,
             Iterable<Execution> executions) {
-        exceptionHistory.add(
+        latestRootExceptionEntry =
                 RootExceptionHistoryEntry.fromGlobalFailure(
-                        failure, timestamp, failureLabels, executions));
+                        failure, timestamp, failureLabels, executions);
+        exceptionHistory.add(latestRootExceptionEntry);
         log.debug("Archive global failure.", failure);
     }
 
     protected final void archiveFromFailureHandlingResult(
             FailureHandlingResultSnapshot failureHandlingResult) {
-        if (failureHandlingResult.getRootCauseExecution().isPresent()) {
+        if (!failureHandlingResult.isRootCause()) {
+            // Handle all subsequent exceptions as the concurrent exceptions 
when it's not a new
+            // attempt.
+            checkState(
+                    latestRootExceptionEntry != null,
+                    "A root exception entry should exist if 
failureHandlingResult wasn't "
+                            + "generated as part of a new error handling 
cycle.");
+            List<Execution> concurrentlyExecutions = new ArrayList<>();
+            
failureHandlingResult.getRootCauseExecution().ifPresent(concurrentlyExecutions::add);
+            
concurrentlyExecutions.addAll(failureHandlingResult.getConcurrentlyFailedExecution());
+
+            
latestRootExceptionEntry.addConcurrentExceptions(concurrentlyExecutions);
+        } else if (failureHandlingResult.getRootCauseExecution().isPresent()) {
             final Execution rootCauseExecution =
                     failureHandlingResult.getRootCauseExecution().get();
 
-            final RootExceptionHistoryEntry rootEntry =
+            latestRootExceptionEntry =
                     
RootExceptionHistoryEntry.fromFailureHandlingResultSnapshot(
                             failureHandlingResult);
-            exceptionHistory.add(rootEntry);
+            exceptionHistory.add(latestRootExceptionEntry);
 
             log.debug(
                     "Archive local failure causing attempt {} to fail: {}",
                     rootCauseExecution.getAttemptId(),
-                    rootEntry.getExceptionAsString());
+                    latestRootExceptionEntry.getExceptionAsString());
         } else {
             archiveGlobalFailure(
                     failureHandlingResult.getRootCause(),
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/exceptionhistory/FailureHandlingResultSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/exceptionhistory/FailureHandlingResultSnapshot.java
index f402fbb86df..e68d09236d1 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/exceptionhistory/FailureHandlingResultSnapshot.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/exceptionhistory/FailureHandlingResultSnapshot.java
@@ -48,6 +48,7 @@ public class FailureHandlingResultSnapshot {
     private final CompletableFuture<Map<String, String>> failureLabels;
     private final long timestamp;
     private final Set<Execution> concurrentlyFailedExecutions;
+    private final boolean isRootCause;
 
     /**
      * Creates a {@code FailureHandlingResultSnapshot} based on the passed 
{@link
@@ -84,7 +85,8 @@ public class FailureHandlingResultSnapshot {
                 
ErrorInfo.handleMissingThrowable(failureHandlingResult.getError()),
                 failureHandlingResult.getTimestamp(),
                 failureHandlingResult.getFailureLabels(),
-                concurrentlyFailedExecutions);
+                concurrentlyFailedExecutions,
+                failureHandlingResult.isRootCause());
     }
 
     @VisibleForTesting
@@ -93,7 +95,8 @@ public class FailureHandlingResultSnapshot {
             Throwable rootCause,
             long timestamp,
             CompletableFuture<Map<String, String>> failureLabels,
-            Set<Execution> concurrentlyFailedExecutions) {
+            Set<Execution> concurrentlyFailedExecutions,
+            boolean isRootCause) {
         Preconditions.checkArgument(
                 rootCauseExecution == null
                         || 
!concurrentlyFailedExecutions.contains(rootCauseExecution),
@@ -105,6 +108,7 @@ public class FailureHandlingResultSnapshot {
         this.timestamp = timestamp;
         this.concurrentlyFailedExecutions =
                 Preconditions.checkNotNull(concurrentlyFailedExecutions);
+        this.isRootCause = isRootCause;
     }
 
     /**
@@ -150,7 +154,17 @@ public class FailureHandlingResultSnapshot {
      *
      * @return The concurrently failed {@code Executions}.
      */
-    public Iterable<Execution> getConcurrentlyFailedExecution() {
+    public Set<Execution> getConcurrentlyFailedExecution() {
         return Collections.unmodifiableSet(concurrentlyFailedExecutions);
     }
+
+    /**
+     * @return True means that the current failure is a new attempt, false 
means that there has been
+     *     a failure before and has not been tried yet, and the current 
failure will be merged into
+     *     the previous attempt, and these merged exceptions will be 
considered as the concurrent
+     *     exceptions.
+     */
+    public boolean isRootCause() {
+        return isRootCause;
+    }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/exceptionhistory/RootExceptionHistoryEntry.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/exceptionhistory/RootExceptionHistoryEntry.java
index cfbad29716a..bd4d62f0400 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/exceptionhistory/RootExceptionHistoryEntry.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/exceptionhistory/RootExceptionHistoryEntry.java
@@ -26,7 +26,9 @@ import org.apache.flink.runtime.taskmanager.TaskManagerLocation;
 import org.apache.flink.util.Preconditions;
 
 import javax.annotation.Nullable;
+import javax.annotation.concurrent.NotThreadSafe;
 
+import java.util.Collection;
 import java.util.Collections;
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
@@ -37,11 +39,12 @@ import java.util.stream.StreamSupport;
  * {@code RootExceptionHistoryEntry} extending {@link ExceptionHistoryEntry} 
by providing a list of
  * {@code ExceptionHistoryEntry} instances to store concurrently caught 
failures.
  */
+@NotThreadSafe
 public class RootExceptionHistoryEntry extends ExceptionHistoryEntry {
 
     private static final long serialVersionUID = -7647332765867297434L;
 
-    private final Iterable<ExceptionHistoryEntry> concurrentExceptions;
+    private final Collection<ExceptionHistoryEntry> concurrentExceptions;
 
     /**
      * Creates a {@code RootExceptionHistoryEntry} based on the passed {@link
@@ -96,7 +99,7 @@ public class RootExceptionHistoryEntry extends ExceptionHistoryEntry {
     }
 
     public static RootExceptionHistoryEntry fromExceptionHistoryEntry(
-            ExceptionHistoryEntry entry, Iterable<ExceptionHistoryEntry> 
entries) {
+            ExceptionHistoryEntry entry, Collection<ExceptionHistoryEntry> 
entries) {
         return new RootExceptionHistoryEntry(
                 entry.getException(),
                 entry.getTimestamp(),
@@ -140,15 +143,20 @@ public class RootExceptionHistoryEntry extends ExceptionHistoryEntry {
                 failureLabels,
                 failingTaskName,
                 taskManagerLocation,
-                StreamSupport.stream(executions.spliterator(), false)
-                        .filter(execution -> 
execution.getFailureInfo().isPresent())
-                        .map(
-                                execution ->
-                                        ExceptionHistoryEntry.create(
-                                                execution,
-                                                
execution.getVertexWithAttempt(),
-                                                
FailureEnricherUtils.EMPTY_FAILURE_LABELS))
-                        .collect(Collectors.toList()));
+                createExceptionHistoryEntries(executions));
+    }
+
+    private static Collection<ExceptionHistoryEntry> 
createExceptionHistoryEntries(
+            Iterable<Execution> executions) {
+        return StreamSupport.stream(executions.spliterator(), false)
+                .filter(execution -> execution.getFailureInfo().isPresent())
+                .map(
+                        execution ->
+                                ExceptionHistoryEntry.create(
+                                        execution,
+                                        execution.getVertexWithAttempt(),
+                                        
FailureEnricherUtils.EMPTY_FAILURE_LABELS))
+                .collect(Collectors.toList());
     }
 
     /**
@@ -170,11 +178,15 @@ public class RootExceptionHistoryEntry extends ExceptionHistoryEntry {
             CompletableFuture<Map<String, String>> failureLabels,
             @Nullable String failingTaskName,
             @Nullable TaskManagerLocation taskManagerLocation,
-            Iterable<ExceptionHistoryEntry> concurrentExceptions) {
+            Collection<ExceptionHistoryEntry> concurrentExceptions) {
         super(cause, timestamp, failureLabels, failingTaskName, 
taskManagerLocation);
         this.concurrentExceptions = concurrentExceptions;
     }
 
+    public void addConcurrentExceptions(Iterable<Execution> 
concurrentlyExecutions) {
+        
this.concurrentExceptions.addAll(createExceptionHistoryEntries(concurrentlyExecutions));
+    }
+
     public Iterable<ExceptionHistoryEntry> getConcurrentExceptions() {
         return concurrentExceptions;
     }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/failover/ExecutionFailureHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/failover/ExecutionFailureHandlerTest.java
index 2b250af3e71..3102694be28 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/failover/ExecutionFailureHandlerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/failover/ExecutionFailureHandlerTest.java
@@ -39,6 +39,7 @@ import java.util.Collections;
 import java.util.Set;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.stream.Collectors;
 
 import static org.assertj.core.api.Assertions.assertThat;
@@ -57,6 +58,8 @@ class ExecutionFailureHandlerTest {
 
     private TestFailoverStrategy failoverStrategy;
 
+    private AtomicBoolean isNewAttempt;
+
     private TestRestartBackoffTimeStrategy backoffTimeStrategy;
 
     private ExecutionFailureHandler executionFailureHandler;
@@ -71,7 +74,9 @@ class ExecutionFailureHandlerTest {
 
         failoverStrategy = new TestFailoverStrategy();
         testingFailureEnricher = new TestingFailureEnricher();
-        backoffTimeStrategy = new TestRestartBackoffTimeStrategy(true, 
RESTART_DELAY_MS);
+        isNewAttempt = new AtomicBoolean(true);
+        backoffTimeStrategy =
+                new TestRestartBackoffTimeStrategy(true, RESTART_DELAY_MS, 
isNewAttempt::get);
         executionFailureHandler =
                 new ExecutionFailureHandler(
                         schedulingTopology,
@@ -158,6 +163,8 @@ class ExecutionFailureHandlerTest {
         final Throwable error =
                 new Exception(new SuppressRestartsException(new 
Exception("test failure")));
         final long timestamp = System.currentTimeMillis();
+
+        isNewAttempt.set(false);
         final FailureHandlingResult result =
                 executionFailureHandler.getFailureHandlingResult(execution, 
error, timestamp);
 
@@ -171,6 +178,10 @@ class ExecutionFailureHandlerTest {
         assertThat(result.getFailureLabels().get())
                 .isEqualTo(testingFailureEnricher.getFailureLabels());
         assertThat(result.getTimestamp()).isEqualTo(timestamp);
+        assertThat(result.isRootCause())
+                .as(
+                        "A NonRecoverableFailure should be new attempt even if 
RestartBackoffTimeStrategy consider it's not new attempt.")
+                .isTrue();
 
         assertThatThrownBy(result::getVerticesToRestart)
                 .as("getVerticesToRestart is not allowed when restarting is 
suppressed")
@@ -183,6 +194,51 @@ class ExecutionFailureHandlerTest {
         assertThat(executionFailureHandler.getNumberOfRestarts()).isZero();
     }
 
+    @Test
+    void testNewAttempt() throws Exception {
+        final Set<ExecutionVertexID> tasksToRestart =
+                Collections.singleton(new ExecutionVertexID(new JobVertexID(), 
0));
+        failoverStrategy.setTasksToRestart(tasksToRestart);
+
+        Execution execution =
+                
FailureHandlingResultTest.createExecution(EXECUTOR_RESOURCE.getExecutor());
+        final Throwable error = new Exception("expected test failure");
+
+        testHandlingRootException(execution, error);
+
+        isNewAttempt.set(false);
+        testHandlingConcurrentException(execution, error);
+        testHandlingConcurrentException(execution, error);
+
+        isNewAttempt.set(true);
+        testHandlingRootException(execution, error);
+        testHandlingRootException(execution, error);
+
+        isNewAttempt.set(false);
+        testHandlingConcurrentException(execution, error);
+        testHandlingConcurrentException(execution, error);
+    }
+
+    private void testHandlingRootException(Execution execution, Throwable 
error) {
+        FailureHandlingResult result =
+                executionFailureHandler.getFailureHandlingResult(
+                        execution, error, System.currentTimeMillis());
+        assertThat(result.isRootCause())
+                .as(
+                        "The FailureHandlingResult should be the root cause if 
exception is new attempt.")
+                .isTrue();
+    }
+
+    private void testHandlingConcurrentException(Execution execution, 
Throwable error) {
+        FailureHandlingResult result =
+                executionFailureHandler.getFailureHandlingResult(
+                        execution, error, System.currentTimeMillis());
+        assertThat(result.isRootCause())
+                .as(
+                        "The FailureHandlingResult shouldn't be the root cause 
if exception isn't new attempt.")
+                .isFalse();
+    }
+
     /** Tests the check for unrecoverable error. */
     @Test
     void testUnrecoverableErrorCheck() {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/failover/FailureHandlingResultTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/failover/FailureHandlingResultTest.java
index 55f20890e6a..a590c99809f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/failover/FailureHandlingResultTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/failover/FailureHandlingResultTest.java
@@ -24,8 +24,9 @@ import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
 import org.apache.flink.testutils.TestingUtils;
 import org.apache.flink.testutils.executor.TestExecutorExtension;
 
-import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.RegisterExtension;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
 
 import java.util.Collections;
 import java.util.HashSet;
@@ -47,8 +48,9 @@ class FailureHandlingResultTest {
             TestingUtils.defaultExecutorExtension();
 
     /** Tests normal FailureHandlingResult. */
-    @Test
-    void testNormalFailureHandlingResult() throws Exception {
+    @ParameterizedTest
+    @ValueSource(booleans = {true, false})
+    void testNormalFailureHandlingResult(boolean isNewAttempt) throws 
Exception {
 
         // create a normal FailureHandlingResult
         Execution execution = createExecution(EXECUTOR_RESOURCE.getExecutor());
@@ -63,7 +65,14 @@ class FailureHandlingResultTest {
                 
CompletableFuture.completedFuture(Collections.singletonMap("key", "value"));
         FailureHandlingResult result =
                 FailureHandlingResult.restartable(
-                        execution, error, timestamp, failureLabels, tasks, 
delay, false);
+                        execution,
+                        error,
+                        timestamp,
+                        failureLabels,
+                        tasks,
+                        delay,
+                        false,
+                        isNewAttempt);
 
         assertThat(result.canRestart()).isTrue();
         assertThat(delay).isEqualTo(result.getRestartDelayMS());
@@ -73,24 +82,29 @@ class FailureHandlingResultTest {
         assertThat(result.getTimestamp()).isEqualTo(timestamp);
         assertThat(result.getFailedExecution()).isPresent();
         assertThat(result.getFailedExecution().get()).isSameAs(execution);
+        assertThat(result.isRootCause()).isEqualTo(isNewAttempt);
     }
 
     /** Tests FailureHandlingResult which suppresses restarts. */
-    @Test
-    void 
testRestartingSuppressedFailureHandlingResultWithNoCausingExecutionVertexId() {
+    @ParameterizedTest
+    @ValueSource(booleans = {true, false})
+    void 
testRestartingSuppressedFailureHandlingResultWithNoCausingExecutionVertexId(
+            boolean isNewAttempt) {
         // create a FailureHandlingResult with error
         final Throwable error = new Exception("test error");
         final long timestamp = System.currentTimeMillis();
         final CompletableFuture<Map<String, String>> failureLabels =
                 
CompletableFuture.completedFuture(Collections.singletonMap("key", "value"));
         FailureHandlingResult result =
-                FailureHandlingResult.unrecoverable(null, error, timestamp, 
failureLabels, false);
+                FailureHandlingResult.unrecoverable(
+                        null, error, timestamp, failureLabels, false, 
isNewAttempt);
 
         assertThat(result.canRestart()).isFalse();
         assertThat(result.getError()).isSameAs(error);
         assertThat(result.getTimestamp()).isEqualTo(timestamp);
         assertThat(result.getFailureLabels()).isEqualTo(failureLabels);
         assertThat(result.getFailedExecution()).isNotPresent();
+        assertThat(result.isRootCause()).isEqualTo(isNewAttempt);
 
         assertThatThrownBy(result::getVerticesToRestart)
                 .as("getVerticesToRestart is not allowed when restarting is 
suppressed")
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/failover/TestRestartBackoffTimeStrategy.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/failover/TestRestartBackoffTimeStrategy.java
index 44830a6423f..1dd725f8e19 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/failover/TestRestartBackoffTimeStrategy.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/failover/TestRestartBackoffTimeStrategy.java
@@ -19,6 +19,8 @@
 
 package org.apache.flink.runtime.executiongraph.failover;
 
+import java.util.function.Supplier;
+
 /** A RestartBackoffTimeStrategy implementation for tests. */
 public class TestRestartBackoffTimeStrategy implements 
RestartBackoffTimeStrategy {
 
@@ -26,9 +28,17 @@ public class TestRestartBackoffTimeStrategy implements RestartBackoffTimeStrateg
 
     private long backoffTime;
 
+    private Supplier<Boolean> isNewAttempt;
+
     public TestRestartBackoffTimeStrategy(boolean canRestart, long 
backoffTime) {
+        this(canRestart, backoffTime, () -> true);
+    }
+
+    public TestRestartBackoffTimeStrategy(
+            boolean canRestart, long backoffTime, Supplier<Boolean> 
isNewAttempt) {
         this.canRestart = canRestart;
         this.backoffTime = backoffTime;
+        this.isNewAttempt = isNewAttempt;
     }
 
     @Override
@@ -44,7 +54,7 @@ public class TestRestartBackoffTimeStrategy implements RestartBackoffTimeStrateg
     @Override
     public boolean notifyFailure(Throwable cause) {
         // ignore
-        return true;
+        return isNewAttempt.get();
     }
 
     public void setCanRestart(final boolean canRestart) {
@@ -54,4 +64,8 @@ public class TestRestartBackoffTimeStrategy implements RestartBackoffTimeStrateg
     public void setBackoffTime(final long backoffTime) {
         this.backoffTime = backoffTime;
     }
+
+    public void setIsNewAttempt(Supplier<Boolean> isNewAttempt) {
+        this.isNewAttempt = isNewAttempt;
+    }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSchedulerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSchedulerTest.java
index 46c56281492..24c530c53c6 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSchedulerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSchedulerTest.java
@@ -128,6 +128,7 @@ import java.util.concurrent.Executors;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.ScheduledFuture;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.BiFunction;
 import java.util.function.Consumer;
 import java.util.stream.Collectors;
@@ -1440,6 +1441,9 @@ public class DefaultSchedulerTest {
 
     @Test
     void testExceptionHistoryConcurrentRestart() throws Exception {
+        AtomicBoolean isNewAttempt = new AtomicBoolean(true);
+        testRestartBackoffTimeStrategy.setIsNewAttempt(isNewAttempt::get);
+
         final JobGraph jobGraph = singleJobVertexJobGraph(2);
 
         final TaskManagerLocation taskManagerLocation = new 
LocalTaskManagerLocation();
@@ -1476,9 +1480,9 @@ public class DefaultSchedulerTest {
                         exception0);
 
         // multi-ExecutionVertex failure
+        isNewAttempt.set(false);
         final RuntimeException exception1 = new RuntimeException("failure #1");
-        failoverStrategyFactory.setTasksToRestart(
-                executionVertex1.getID(), executionVertex0.getID());
+        failoverStrategyFactory.setTasksToRestart(executionVertex1.getID());
         final long updateStateTriggeringRestartTimestamp1 =
                 initiateFailure(
                         scheduler,
@@ -1491,7 +1495,7 @@ public class DefaultSchedulerTest {
 
         delayExecutor.triggerNonPeriodicScheduledTasks();
 
-        assertThat(scheduler.getExceptionHistory()).hasSize(2);
+        assertThat(scheduler.getExceptionHistory()).hasSize(1);
         final Iterator<RootExceptionHistoryEntry> actualExceptionHistory =
                 scheduler.getExceptionHistory().iterator();
 
@@ -1513,17 +1517,6 @@ public class DefaultSchedulerTest {
                                         updateStateTriggeringRestartTimestamp1,
                                         
executionVertex1.getTaskNameWithSubtaskIndex(),
                                         
executionVertex1.getCurrentAssignedResourceLocation()));
-
-        final RootExceptionHistoryEntry entry1 = actualExceptionHistory.next();
-        assertThat(
-                        ExceptionHistoryEntryTestingUtils.matchesFailure(
-                                entry1,
-                                exception1,
-                                updateStateTriggeringRestartTimestamp1,
-                                executionVertex1.getTaskNameWithSubtaskIndex(),
-                                
executionVertex1.getCurrentAssignedResourceLocation()))
-                .isTrue();
-        assertThat(entry1.getConcurrentExceptions()).isEmpty();
     }
 
     @Test
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/exceptionhistory/FailureHandlingResultSnapshotTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/exceptionhistory/FailureHandlingResultSnapshotTest.java
index 69ad4095ef1..21ed10c0357 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/exceptionhistory/FailureHandlingResultSnapshotTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/exceptionhistory/FailureHandlingResultSnapshotTest.java
@@ -96,7 +96,8 @@ class FailureHandlingResultSnapshotTest {
                                 .map(ExecutionVertex::getID)
                                 .collect(Collectors.toSet()),
                         0L,
-                        false);
+                        false,
+                        true);
 
         assertThatThrownBy(
                         () ->
@@ -124,7 +125,8 @@ class FailureHandlingResultSnapshotTest {
                                 .map(ExecutionVertex::getID)
                                 .collect(Collectors.toSet()),
                         0L,
-                        false);
+                        false,
+                        true);
 
         // FailedExecution with failure labels
         assertThat(failureHandlingResult.getFailureLabels().get())
@@ -145,6 +147,7 @@ class FailureHandlingResultSnapshotTest {
         assertThat(testInstance.getRootCauseExecution()).isPresent();
         assertThat(testInstance.getRootCauseExecution().get())
                 
.isSameAs(rootCauseExecutionVertex.getCurrentExecutionAttempt());
+        assertThat(testInstance.isRootCause()).isTrue();
     }
 
     @Test
@@ -170,7 +173,8 @@ class FailureHandlingResultSnapshotTest {
                                 .map(ExecutionVertex::getID)
                                 .collect(Collectors.toSet()),
                         0L,
-                        false);
+                        false,
+                        true);
 
         final FailureHandlingResultSnapshot testInstance =
                 FailureHandlingResultSnapshot.create(
@@ -183,6 +187,7 @@ class FailureHandlingResultSnapshotTest {
                 
.isSameAs(rootCauseExecutionVertex.getCurrentExecutionAttempt());
         assertThat(testInstance.getConcurrentlyFailedExecution())
                 
.containsExactly(otherFailedExecutionVertex.getCurrentExecutionAttempt());
+        assertThat(testInstance.isRootCause()).isTrue();
     }
 
     @Test
@@ -196,7 +201,8 @@ class FailureHandlingResultSnapshotTest {
                                         new RuntimeException("Expected 
exception"),
                                         System.currentTimeMillis(),
                                         
FailureEnricherUtils.EMPTY_FAILURE_LABELS,
-                                        
Collections.singleton(rootCauseExecution)))
+                                        
Collections.singleton(rootCauseExecution),
+                                        true))
                 .isInstanceOf(IllegalArgumentException.class);
     }
 
@@ -227,6 +233,7 @@ class FailureHandlingResultSnapshotTest {
                                 .map(ExecutionVertex::getID)
                                 .collect(Collectors.toSet()),
                         0L,
+                        true,
                         true);
 
         // FailedExecution with failure labels
@@ -244,6 +251,7 @@ class FailureHandlingResultSnapshotTest {
                 .containsExactlyInAnyOrder(
                         failedExecutionVertex0.getCurrentExecutionAttempt(),
                         failedExecutionVertex1.getCurrentExecutionAttempt());
+        assertThat(testInstance.isRootCause()).isTrue();
     }
 
     private Collection<Execution> getCurrentExecutions(ExecutionVertexID 
executionVertexId) {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/exceptionhistory/RootExceptionHistoryEntryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/exceptionhistory/RootExceptionHistoryEntryTest.java
index 76539d21143..61c1a156cf2 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/exceptionhistory/RootExceptionHistoryEntryTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/exceptionhistory/RootExceptionHistoryEntryTest.java
@@ -81,10 +81,11 @@ class RootExceptionHistoryEntryTest {
         final CompletableFuture<Map<String, String>> rootFailureLabels =
                 
CompletableFuture.completedFuture(Collections.singletonMap("key", "value"));
 
-        final Throwable concurrentException = new 
IllegalStateException("Expected other failure");
-        final ExecutionVertex concurrentlyFailedExecutionVertex = 
extractExecutionVertex(1);
-        final long concurrentExceptionTimestamp =
-                triggerFailure(concurrentlyFailedExecutionVertex, 
concurrentException);
+        final Throwable concurrentException1 = new 
IllegalStateException("Expected other failure1");
+        final ExecutionVertex concurrentlyFailedExecutionVertex1 = 
extractExecutionVertex(1);
+        Predicate<ExceptionHistoryEntry> exception1Predicate =
+                triggerFailureAndCreateEntryMatcher(
+                        concurrentException1, 
concurrentlyFailedExecutionVertex1);
 
         final FailureHandlingResultSnapshot snapshot =
                 new FailureHandlingResultSnapshot(
@@ -93,7 +94,8 @@ class RootExceptionHistoryEntryTest {
                         rootTimestamp,
                         rootFailureLabels,
                         Collections.singleton(
-                                
concurrentlyFailedExecutionVertex.getCurrentExecutionAttempt()));
+                                
concurrentlyFailedExecutionVertex1.getCurrentExecutionAttempt()),
+                        true);
         final RootExceptionHistoryEntry actualEntry =
                 
RootExceptionHistoryEntry.fromFailureHandlingResultSnapshot(snapshot);
 
@@ -105,15 +107,24 @@ class RootExceptionHistoryEntryTest {
                                 rootFailureLabels.get(),
                                 
rootExecutionVertex.getTaskNameWithSubtaskIndex(),
                                 
rootExecutionVertex.getCurrentAssignedResourceLocation()));
+
+        
assertThat(actualEntry.getConcurrentExceptions()).hasSize(1).allMatch(exception1Predicate);
+
+        // Test for addConcurrentExceptions
+        final Throwable concurrentException2 = new 
IllegalStateException("Expected other failure2");
+        final ExecutionVertex concurrentlyFailedExecutionVertex2 = 
extractExecutionVertex(2);
+        Predicate<ExceptionHistoryEntry> exception2Predicate =
+                triggerFailureAndCreateEntryMatcher(
+                        concurrentException2, 
concurrentlyFailedExecutionVertex2);
+
+        actualEntry.addConcurrentExceptions(
+                concurrentlyFailedExecutionVertex2.getCurrentExecutions());
         assertThat(actualEntry.getConcurrentExceptions())
-                .hasSize(1)
+                .hasSize(2)
                 .allMatch(
-                        ExceptionHistoryEntryMatcher.matchesFailure(
-                                concurrentException,
-                                concurrentExceptionTimestamp,
-                                
concurrentlyFailedExecutionVertex.getTaskNameWithSubtaskIndex(),
-                                concurrentlyFailedExecutionVertex
-                                        
.getCurrentAssignedResourceLocation()));
+                        exceptionHistoryEntry ->
+                                exception1Predicate.test(exceptionHistoryEntry)
+                                        || 
exception2Predicate.test(exceptionHistoryEntry));
     }
 
     @Test
@@ -121,14 +132,16 @@ class RootExceptionHistoryEntryTest {
         final Throwable concurrentException0 =
                 new RuntimeException("Expected concurrent failure #0");
         final ExecutionVertex concurrentlyFailedExecutionVertex0 = 
extractExecutionVertex(0);
-        final long concurrentExceptionTimestamp0 =
-                triggerFailure(concurrentlyFailedExecutionVertex0, 
concurrentException0);
+        final Predicate<ExceptionHistoryEntry> exception0Predicate =
+                triggerFailureAndCreateEntryMatcher(
+                        concurrentException0, 
concurrentlyFailedExecutionVertex0);
 
         final Throwable concurrentException1 =
                 new IllegalStateException("Expected concurrent failure #1");
         final ExecutionVertex concurrentlyFailedExecutionVertex1 = 
extractExecutionVertex(1);
-        final long concurrentExceptionTimestamp1 =
-                triggerFailure(concurrentlyFailedExecutionVertex1, 
concurrentException1);
+        final Predicate<ExceptionHistoryEntry> exception1Predicate =
+                triggerFailureAndCreateEntryMatcher(
+                        concurrentException1, 
concurrentlyFailedExecutionVertex1);
 
         final Throwable rootCause = new Exception("Expected root failure");
         final long rootTimestamp = System.currentTimeMillis();
@@ -150,18 +163,6 @@ class RootExceptionHistoryEntryTest {
                         ExceptionHistoryEntryMatcher.matchesGlobalFailure(
                                 rootCause, rootTimestamp, 
rootFailureLabels.get()));
 
-        final Predicate<ExceptionHistoryEntry> exception0Predicate =
-                ExceptionHistoryEntryMatcher.matchesFailure(
-                        concurrentException0,
-                        concurrentExceptionTimestamp0,
-                        
concurrentlyFailedExecutionVertex0.getTaskNameWithSubtaskIndex(),
-                        
concurrentlyFailedExecutionVertex0.getCurrentAssignedResourceLocation());
-        final Predicate<ExceptionHistoryEntry> exception1Predicate =
-                ExceptionHistoryEntryMatcher.matchesFailure(
-                        concurrentException1,
-                        concurrentExceptionTimestamp1,
-                        
concurrentlyFailedExecutionVertex1.getTaskNameWithSubtaskIndex(),
-                        
concurrentlyFailedExecutionVertex1.getCurrentAssignedResourceLocation());
         assertThat(actualEntry.getConcurrentExceptions())
                 .allMatch(
                         exceptionHistoryEntry ->
@@ -169,6 +170,17 @@ class RootExceptionHistoryEntryTest {
                                         || 
exception1Predicate.test(exceptionHistoryEntry));
     }
 
+    private Predicate<ExceptionHistoryEntry> triggerFailureAndCreateEntryMatcher(
+            Throwable failure, ExecutionVertex failedExecutionVertex) {
+        final long failureTimestamp = triggerFailure(failedExecutionVertex, failure);
+        final String taskName = failedExecutionVertex.getTaskNameWithSubtaskIndex();
+        return ExceptionHistoryEntryMatcher.matchesFailure(
+                failure,
+                failureTimestamp,
+                taskName,
+                failedExecutionVertex.getCurrentAssignedResourceLocation());
+    }
+
     private long triggerFailure(ExecutionVertex executionVertex, Throwable 
throwable) {
         executionGraph.updateState(
                 new TaskExecutionStateTransition(


Reply via email to