zentol commented on a change in pull request #9038: 
[FLINK-13169][tests][coordination] IT test for fine-grained recovery (task 
executor failures)
URL: https://github.com/apache/flink/pull/9038#discussion_r304869641
 
 

 ##########
 File path: 
flink-tests/src/test/java/org/apache/flink/test/recovery/BatchFineGrainedRecoveryITCase.java
 ##########
 @@ -101,101 +166,276 @@ public void teardown() throws Exception {
        public void testProgram() throws Exception {
                ExecutionEnvironment env = createExecutionEnvironment();
 
-               StaticFailureCounter.reset();
-               StaticMapFailureTracker.reset();
-
-               FailureStrategy failureStrategy = new 
RandomExceptionFailureStrategy(1, EMITTED_RECORD_NUMBER);
+               FailureStrategy failureStrategy = createFailureStrategy();
 
                DataSet<Long> input = env.generateSequence(0, 
EMITTED_RECORD_NUMBER - 1);
-               for (int i = 0; i < MAP_NUMBER; i++) {
+               for (int trackingIndex = 0; trackingIndex < MAP_NUMBER; 
trackingIndex++) {
                        input = input
                                .mapPartition(new 
TestPartitionMapper(StaticMapFailureTracker.addNewMap(), failureStrategy))
-                               .name("Test partition mapper " + i);
+                               .name(TASK_NAME_PREFIX + trackingIndex);
                }
-               assertThat(input.collect(), is(EXPECTED_JOB_OUTPUT));
 
+               assertThat(input.collect(), is(EXPECTED_JOB_OUTPUT));
                StaticMapFailureTracker.verify();
        }
 
-       private ExecutionEnvironment createExecutionEnvironment() {
+       private static FailureStrategy createFailureStrategy() {
+               int failWithExceptionAfterNumberOfProcessedRecords = 
rnd.nextInt(EMITTED_RECORD_NUMBER) + 1;
+               int failTaskExecutorAfterNumberOfProcessedRecords = 
rnd.nextInt(EMITTED_RECORD_NUMBER) + 1;
+               // it has to fail only once during one mapper run so that 
different failure strategies do not mess up each other stats
+               FailureStrategy failureStrategy = new OneTimeFailureStrategy(
+                       new JoinedFailureStrategy(
+                               new GloballyTrackingFailureStrategy(
+                                       new 
ExceptionFailureStrategy(failWithExceptionAfterNumberOfProcessedRecords)),
+                               new GloballyTrackingFailureStrategy(
+                                       new 
TaskExecutorFailureStrategy(failTaskExecutorAfterNumberOfProcessedRecords))));
+               LOG.info("FailureStrategy: {}", failureStrategy);
+               return failureStrategy;
+       }
+
+       private static ExecutionEnvironment createExecutionEnvironment() {
                @SuppressWarnings("StaticVariableUsedBeforeInitialization")
                ExecutionEnvironment env = new TestEnvironment(miniCluster, 1, 
true);
-               
env.setRestartStrategy(RestartStrategies.fixedDelayRestart(MAX_FAILURE_NUMBER, 
Time.milliseconds(10)));
+               
env.setRestartStrategy(RestartStrategies.fixedDelayRestart(MAX_JOB_RESTART_ATTEMPTS,
 Time.milliseconds(10)));
                env.getConfig().setExecutionMode(ExecutionMode.BATCH_FORCED); 
// forces all partitions to be blocking
                return env;
        }
 
-       private enum StaticMapFailureTracker {
-               ;
+       @SuppressWarnings({"StaticVariableUsedBeforeInitialization", 
"OverlyBroadThrowsClause"})
+       private static void restartTaskManager() throws Exception {
+               int tmi = lastTaskManagerIndexInMiniCluster.getAndIncrement();
+               try {
+                       miniCluster.terminateTaskExecutor(tmi).get();
+               } finally {
+                       miniCluster.startTaskExecutor();
+               }
+       }
 
-               private static final List<AtomicInteger> mapRestarts = new 
ArrayList<>(10);
-               private static final List<AtomicInteger> expectedMapRestarts = 
new ArrayList<>(10);
+       @FunctionalInterface
+       private interface FailureStrategy extends Serializable {
+               /**
+                * Decides whether to fail and fails the task implicitly or by 
throwing an exception.
+                *
+                * @param trackingIndex index of the mapper task in the sequence
+                * @return {@code true} if task is failed implicitly or {@code 
false} if task is not failed
+                * @throws Exception To fail the task explicitly
+                */
+               boolean failOrNot(int trackingIndex) throws Exception;
+       }
 
-               private static void reset() {
-                       mapRestarts.clear();
-                       expectedMapRestarts.clear();
+       private static class OneTimeFailureStrategy implements FailureStrategy {
+               private static final long serialVersionUID = 1L;
+
+               private final FailureStrategy wrappedFailureStrategy;
+               private transient boolean failed;
+
+               private OneTimeFailureStrategy(FailureStrategy 
wrappedFailureStrategy) {
+                       this.wrappedFailureStrategy = wrappedFailureStrategy;
                }
 
-               private static int addNewMap() {
-                       mapRestarts.add(new AtomicInteger(0));
-                       expectedMapRestarts.add(new AtomicInteger(1));
-                       return mapRestarts.size() - 1;
+               @Override
+               public boolean failOrNot(int trackingIndex) throws Exception {
+                       if (!failed) {
+                               try {
+                                       boolean failedNow = 
wrappedFailureStrategy.failOrNot(trackingIndex);
+                                       failed = failedNow;
+                                       return failedNow;
+                               } catch (Exception e) {
+                                       failed = true;
+                                       throw e;
+                               }
+                       }
+                       return false;
                }
 
-               private static void mapRestart(int index) {
-                       mapRestarts.get(index).incrementAndGet();
+               @Override
+               public String toString() {
+                       return "FailingOnce{" + wrappedFailureStrategy + '}';
                }
+       }
 
-               private static void mapFailure(int index) {
-                       expectedMapRestarts.get(index).incrementAndGet();
+       private static class JoinedFailureStrategy implements FailureStrategy {
+               private static final long serialVersionUID = 1L;
+
+               private final FailureStrategy[] failureStrategies;
+
+               private JoinedFailureStrategy(FailureStrategy ... 
failureStrategies) {
+                       this.failureStrategies = failureStrategies;
                }
 
-               private static void verify() {
-                       assertThat(collect(mapRestarts), 
is(collect(expectedMapRestarts)));
+               @Override
+               public boolean failOrNot(int trackingIndex) throws Exception {
+                       for (FailureStrategy failureStrategy : 
failureStrategies) {
+                               if (failureStrategy.failOrNot(trackingIndex)) {
+                                       return true;
+                               }
+                       }
+                       return false;
                }
 
-               private static int[] collect(Collection<AtomicInteger> list) {
-                       return 
list.stream().mapToInt(AtomicInteger::get).toArray();
+               @Override
+               public String toString() {
+                       return String.join(
+                               " or ",
+                               (Iterable<String>) () -> 
Arrays.stream(failureStrategies).map(Object::toString).iterator());
                }
        }
 
-       @FunctionalInterface
-       private interface FailureStrategy extends Serializable {
-               void failOrNot();
+       private static class GloballyTrackingFailureStrategy implements 
FailureStrategy {
+               private static final long serialVersionUID = 1L;
+
+               private final FailureStrategy wrappedFailureStrategy;
+
+               private GloballyTrackingFailureStrategy(FailureStrategy 
wrappedFailureStrategy) {
+                       this.wrappedFailureStrategy = wrappedFailureStrategy;
+               }
+
+               @Override
+               public boolean failOrNot(int trackingIndex) throws Exception {
+                       return StaticMapFailureTracker.failOrNot(
+                               trackingIndex,
+                               wrappedFailureStrategy);
+               }
+
+               @Override
+               public String toString() {
+                       return "Tracked{" + wrappedFailureStrategy + '}';
+               }
        }
 
-       private static class RandomExceptionFailureStrategy implements 
FailureStrategy {
+       private static class ExceptionFailureStrategy extends 
AbstractOnceAfterCallNumberFailureStrategy {
                private static final long serialVersionUID = 1L;
 
-               private final CoinToss coin;
+               private ExceptionFailureStrategy(int failAfterCallNumber) {
+                       super(failAfterCallNumber);
+               }
 
-               private RandomExceptionFailureStrategy(int probFraction, int 
probBase) {
-                       this.coin = new CoinToss(probFraction, probBase);
+               @Override
+               void fail(int trackingIndex) throws FlinkException {
+                       throw new FlinkException("BAGA-BOOM!!! The user 
function generated test failure.");
+               }
+       }
+
+       private static class TaskExecutorFailureStrategy extends 
AbstractOnceAfterCallNumberFailureStrategy {
+               private static final long serialVersionUID = 1L;
+
+               private TaskExecutorFailureStrategy(int failAfterCallNumber) {
+                       super(failAfterCallNumber);
                }
 
                @Override
-               public void failOrNot() {
-                       if (coin.toss() && StaticFailureCounter.failOrNot()) {
-                               throw new FlinkRuntimeException("BAGA-BOOM!!! 
The user function generated test failure.");
+               void fail(int trackingIndex) throws Exception {
+                       //noinspection OverlyBroadCatchBlock
+                       try {
+                               restartTaskManager();
+                       } catch (InterruptedException e) {
+                               // ignore the exception, task should have been 
failed while stopping TM
+                               Thread.currentThread().interrupt();
+                       } catch (Throwable t) {
+                               StaticMapFailureTracker.unrelatedFailure(t);
+                               throw t;
                        }
                }
        }
 
-       private static class CoinToss implements Serializable {
+       private abstract static class 
AbstractOnceAfterCallNumberFailureStrategy implements FailureStrategy {
                private static final long serialVersionUID = 1L;
-               private static final Random rnd = new Random();
 
-               private final int probFraction;
-               private final int probBase;
+               private final UUID id;
+               private final int failAfterCallNumber;
+               private transient int callCounter;
+
+               private AbstractOnceAfterCallNumberFailureStrategy(int 
failAfterCallNumber) {
+                       this.failAfterCallNumber = failAfterCallNumber;
+                       id = UUID.randomUUID();
+               }
+
+               @Override
+               public boolean failOrNot(int trackingIndex) throws Exception {
+                       callCounter++;
+                       boolean generateFailure = callCounter == 
failAfterCallNumber;
+                       if (generateFailure) {
+                               fail(trackingIndex);
+                       }
+                       return generateFailure;
+               }
+
+               abstract void fail(int trackingIndex) throws Exception;
+
+               @Override
+               public String toString() {
+                       return this.getClass().getSimpleName() + " (fail after 
" + failAfterCallNumber + " calls)";
+               }
+
+               @Override
+               public boolean equals(Object o) {
+                       if (this == o) {
+                               return true;
+                       }
+                       if (o == null || getClass() != o.getClass()) {
+                               return false;
+                       }
+                       return Objects.equals(id, 
((AbstractOnceAfterCallNumberFailureStrategy) o).id);
+               }
+
+               @Override
+               public int hashCode() {
+                       return id.hashCode();
+               }
+       }
+
+       @SuppressWarnings("SynchronizationOnStaticField")
+       private enum StaticMapFailureTracker {
+               ;
+
+               private static final List<AtomicInteger> mapRestarts = new 
ArrayList<>(MAP_NUMBER);
+               private static final List<Map<FailureStrategy, Boolean>> 
mapFailures = new ArrayList<>(MAP_NUMBER);
+
+               private static final Object classLock = new Object();
+               @GuardedBy("classLock")
+               private static Throwable unrelatedFailure;
+
+               private static void reset() {
+                       mapRestarts.clear();
+                       mapFailures.clear();
+               }
+
+               private static int addNewMap() {
+                       mapRestarts.add(new AtomicInteger(0));
+                       mapFailures.add(new HashMap<>(2));
+                       return mapRestarts.size() - 1;
+               }
+
+               private static boolean failOrNot(int index, FailureStrategy 
failureStrategy) throws Exception {
+                       Boolean prevFailed = 
mapFailures.get(index).get(failureStrategy);
+                       boolean alreadyFailed = prevFailed != null && 
prevFailed;
 
 Review comment:
   This could be simplified using `Map#getOrDefault`, e.g. `boolean alreadyFailed = mapFailures.get(index).getOrDefault(failureStrategy, false);`

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to