This is an automated email from the ASF dual-hosted git repository.

szetszwo pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ratis.git

commit 904a75e9f587d39cffebf3f94f65173312790a0b
Author: Swaminathan Balachandran <47532440+swamiri...@users.noreply.github.com>
AuthorDate: Wed Feb 19 15:06:25 2025 -0800

    RATIS-2245. Ratis should wait for all apply transaction futures before taking snapshot and group remove (#1218)
---
 .../ratis/server/impl/StateMachineUpdater.java     |  38 +++---
 .../server/impl/StateMachineShutdownTests.java     | 128 ++++++++++++++++-----
 2 files changed, 119 insertions(+), 47 deletions(-)

diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/StateMachineUpdater.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/StateMachineUpdater.java
index a919ca732..3dfe5e0aa 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/impl/StateMachineUpdater.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/StateMachineUpdater.java
@@ -37,8 +37,6 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
 import java.util.concurrent.CompletableFuture;
@@ -182,19 +180,20 @@ class StateMachineUpdater implements Runnable {
 
   @Override
   public void run() {
+    CompletableFuture<Void> applyLogFutures = 
CompletableFuture.completedFuture(null);
     for(; state != State.STOP; ) {
       try {
-        waitForCommit();
+        waitForCommit(applyLogFutures);
 
         if (state == State.RELOAD) {
           reload();
         }
 
-        final MemoizedSupplier<List<CompletableFuture<Message>>> futures = 
applyLog();
-        checkAndTakeSnapshot(futures);
+        applyLogFutures = applyLog(applyLogFutures);
+        checkAndTakeSnapshot(applyLogFutures);
 
         if (shouldStop()) {
-          checkAndTakeSnapshot(futures);
+          applyLogFutures.get();
           stop();
         }
       } catch (Throwable t) {
@@ -210,14 +209,14 @@ class StateMachineUpdater implements Runnable {
     }
   }
 
-  private void waitForCommit() throws InterruptedException {
+  private void waitForCommit(CompletableFuture<?> applyLogFutures) throws 
InterruptedException, ExecutionException {
     // When a peer starts, the committed is initialized to 0.
     // It will be updated only after the leader contacts other peers.
     // Thus it is possible to have applied > committed initially.
     final long applied = getLastAppliedIndex();
     for(; applied >= raftLog.getLastCommittedIndex() && state == State.RUNNING 
&& !shouldStop(); ) {
       if (server.getSnapshotRequestHandler().shouldTriggerTakingSnapshot()) {
-        takeSnapshot();
+        takeSnapshot(applyLogFutures);
       }
       if (awaitForSignal.await(100, TimeUnit.MILLISECONDS)) {
         return;
@@ -239,8 +238,7 @@ class StateMachineUpdater implements Runnable {
     state = State.RUNNING;
   }
 
-  private MemoizedSupplier<List<CompletableFuture<Message>>> applyLog() throws 
RaftLogIOException {
-    final MemoizedSupplier<List<CompletableFuture<Message>>> futures = 
MemoizedSupplier.valueOf(ArrayList::new);
+  private CompletableFuture<Void> applyLog(CompletableFuture<Void> 
applyLogFutures) throws RaftLogIOException {
     final long committed = raftLog.getLastCommittedIndex();
     for(long applied; (applied = getLastAppliedIndex()) < committed && state 
== State.RUNNING && !shouldStop(); ) {
       final long nextIndex = applied + 1;
@@ -256,7 +254,12 @@ class StateMachineUpdater implements Runnable {
         final long incremented = 
appliedIndex.incrementAndGet(debugIndexChange);
         Preconditions.assertTrue(incremented == nextIndex);
         if (f != null) {
-          futures.get().add(f);
+          CompletableFuture<Message> exceptionHandledFuture = 
f.exceptionally(ex -> {
+            LOG.error("Exception while {}: applying txn index={}, nextLog={}", 
this, nextIndex,
+                    LogProtoUtils.toLogEntryString(next), ex);
+            return null;
+          });
+          applyLogFutures = 
applyLogFutures.thenCombine(exceptionHandledFuture, (v, message) -> null);
           f.thenAccept(m -> notifyAppliedIndex(incremented));
         } else {
           notifyAppliedIndex(incremented);
@@ -267,23 +270,20 @@ class StateMachineUpdater implements Runnable {
         break;
       }
     }
-    return futures;
+    return applyLogFutures;
   }
 
-  private void 
checkAndTakeSnapshot(MemoizedSupplier<List<CompletableFuture<Message>>> futures)
+  private void checkAndTakeSnapshot(CompletableFuture<?> futures)
       throws ExecutionException, InterruptedException {
     // check if need to trigger a snapshot
     if (shouldTakeSnapshot()) {
-      if (futures.isInitialized()) {
-        JavaUtils.allOf(futures.get()).get();
-      }
-
-      takeSnapshot();
+      takeSnapshot(futures);
     }
   }
 
-  private void takeSnapshot() {
+  private void takeSnapshot(CompletableFuture<?> applyLogFutures) throws 
ExecutionException, InterruptedException {
     final long i;
+    applyLogFutures.get();
     try {
       try(UncheckedAutoCloseable ignored = 
Timekeeper.start(stateMachineMetrics.get().getTakeSnapshotTimer())) {
         i = stateMachine.takeSnapshot();
diff --git a/ratis-server/src/test/java/org/apache/ratis/server/impl/StateMachineShutdownTests.java b/ratis-server/src/test/java/org/apache/ratis/server/impl/StateMachineShutdownTests.java
index 28f8e6ace..c70464a18 100644
--- a/ratis-server/src/test/java/org/apache/ratis/server/impl/StateMachineShutdownTests.java
+++ b/ratis-server/src/test/java/org/apache/ratis/server/impl/StateMachineShutdownTests.java
@@ -28,47 +28,106 @@ import org.apache.ratis.server.RaftServer;
 import org.apache.ratis.statemachine.impl.SimpleStateMachine4Testing;
 import org.apache.ratis.statemachine.StateMachine;
 import org.apache.ratis.statemachine.TransactionContext;
-import org.junit.Assert;
-import org.junit.Test;
-
-import java.util.concurrent.CompletableFuture;
+import org.junit.*;
+import org.mockito.MockedStatic;
+import org.mockito.Mockito;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
+import java.util.*;
+import java.util.concurrent.*;
+import java.util.concurrent.atomic.AtomicLong;
 
 public abstract class StateMachineShutdownTests<CLUSTER extends 
MiniRaftCluster>
     extends BaseTest
     implements MiniRaftCluster.Factory.Get<CLUSTER> {
-
+  public static Logger LOG = 
LoggerFactory.getLogger(StateMachineUpdater.class);
+  private static MockedStatic<CompletableFuture> mocked;
   protected static class StateMachineWithConditionalWait extends
       SimpleStateMachine4Testing {
+    boolean unblockAllTxns = false;
+    final Set<Long> blockTxns = ConcurrentHashMap.newKeySet();
+    private final ExecutorService executor = Executors.newFixedThreadPool(10);
+    public static Map<Long, Set<CompletableFuture<Message>>> futures = new 
ConcurrentHashMap<>();
+    public static Map<RaftPeerId, AtomicLong> numTxns = new 
ConcurrentHashMap<>();
+    private final Map<Long, Long> appliedTxns = new ConcurrentHashMap<>();
+
+    private synchronized void updateTxns() {
+      long appliedIndex = this.getLastAppliedTermIndex().getIndex() + 1;
+      Long appliedTerm = null;
+      while (appliedTxns.containsKey(appliedIndex)) {
+        appliedTerm = appliedTxns.remove(appliedIndex);
+        appliedIndex += 1;
+      }
+      if (appliedTerm != null) {
+        updateLastAppliedTermIndex(appliedTerm, appliedIndex - 1);
+      }
+    }
 
-    private final Long objectToWait = 0L;
-    volatile boolean blockOnApply = true;
+    @Override
+    public void notifyTermIndexUpdated(long term, long index) {
+      appliedTxns.put(index, term);
+      updateTxns();
+    }
 
     @Override
     public CompletableFuture<Message> applyTransaction(TransactionContext trx) 
{
-      if (blockOnApply) {
-        synchronized (objectToWait) {
-          try {
-            objectToWait.wait();
-          } catch (InterruptedException e) {
-            Thread.currentThread().interrupt();
-            throw new RuntimeException();
+      final RaftProtos.LogEntryProto entry = trx.getLogEntry();
+
+      CompletableFuture<Message> future = new CompletableFuture<>();
+      futures.computeIfAbsent(Thread.currentThread().getId(), k -> new 
HashSet<>()).add(future);
+      executor.submit(() -> {
+        synchronized (blockTxns) {
+          if (!unblockAllTxns) {
+            blockTxns.add(entry.getIndex());
+          }
+          while (!unblockAllTxns && blockTxns.contains(entry.getIndex())) {
+            try {
+              blockTxns.wait(10000);
+            } catch (InterruptedException e) {
+              throw new RuntimeException(e);
+            }
           }
         }
+        numTxns.computeIfAbsent(getId(), (k) -> new 
AtomicLong()).incrementAndGet();
+        appliedTxns.put(entry.getIndex(), entry.getTerm());
+        updateTxns();
+        future.complete(new RaftTestUtil.SimpleMessage("done"));
+      });
+      return future;
+    }
+
+    public void unBlockApplyTxn(long txnId) {
+      synchronized (blockTxns) {
+        blockTxns.remove(txnId);
+        blockTxns.notifyAll();
       }
-      RaftProtos.LogEntryProto entry = trx.getLogEntry();
-      updateLastAppliedTermIndex(entry.getTerm(), entry.getIndex());
-      return CompletableFuture.completedFuture(new 
RaftTestUtil.SimpleMessage("done"));
     }
 
-    public void unBlockApplyTxn() {
-      blockOnApply = false;
-      synchronized (objectToWait) {
-        objectToWait.notifyAll();
+    public void unblockAllTxns() {
+      unblockAllTxns = true;
+      synchronized (blockTxns) {
+        for (Long txnId : blockTxns) {
+          blockTxns.remove(txnId);
+        }
+        blockTxns.notifyAll();
       }
     }
   }
 
+  @Before
+  public void setup() {
+    mocked = Mockito.mockStatic(CompletableFuture.class, 
Mockito.CALLS_REAL_METHODS);
+  }
+
+  @After
+  public void tearDownClass() {
+    if (mocked != null) {
+      mocked.close();
+    }
+
+  }
+
   @Test
   public void testStateMachineShutdownWaitsForApplyTxn() throws Exception {
     final RaftProperties prop = getProperties();
@@ -82,10 +141,9 @@ public abstract class StateMachineShutdownTests<CLUSTER extends MiniRaftCluster>
 
     //Unblock leader and one follower
     ((StateMachineWithConditionalWait)leader.getStateMachine())
-        .unBlockApplyTxn();
+            .unblockAllTxns();
     ((StateMachineWithConditionalWait)cluster.
-        getFollowers().get(0).getStateMachine()).unBlockApplyTxn();
-
+            getFollowers().get(0).getStateMachine()).unblockAllTxns();
     cluster.getLeaderAndSendFirstMessage(true);
 
     try (final RaftClient client = cluster.createClient(leaderId)) {
@@ -107,16 +165,30 @@ public abstract class StateMachineShutdownTests<CLUSTER extends MiniRaftCluster>
       final Thread t = new Thread(secondFollower::close);
       t.start();
 
-      // The second follower should still be blocked in apply transaction
-      Assert.assertTrue(secondFollower.getInfo().getLastAppliedIndex() < 
logIndex);
+
 
       // Now unblock the second follower
-      ((StateMachineWithConditionalWait) secondFollower.getStateMachine())
-              .unBlockApplyTxn();
+      long minIndex = ((StateMachineWithConditionalWait) 
secondFollower.getStateMachine()).blockTxns.stream()
+              .min(Comparator.naturalOrder()).get();
+      Assert.assertEquals(2, 
StateMachineWithConditionalWait.numTxns.values().stream()
+                      .filter(val -> val.get() == 3).count());
+      // The second follower should still be blocked in apply transaction
+      Assert.assertTrue(secondFollower.getInfo().getLastAppliedIndex() < 
minIndex);
+      for (long index : ((StateMachineWithConditionalWait) 
secondFollower.getStateMachine()).blockTxns) {
+        if (minIndex != index) {
+          ((StateMachineWithConditionalWait) 
secondFollower.getStateMachine()).unBlockApplyTxn(index);
+        }
+      }
+      Assert.assertEquals(2, 
StateMachineWithConditionalWait.numTxns.values().stream()
+              .filter(val -> val.get() == 3).count());
+      Assert.assertTrue(secondFollower.getInfo().getLastAppliedIndex() < 
minIndex);
+      ((StateMachineWithConditionalWait) 
secondFollower.getStateMachine()).unBlockApplyTxn(minIndex);
 
       // Now wait for the thread
       t.join(5000);
       Assert.assertEquals(logIndex, 
secondFollower.getInfo().getLastAppliedIndex());
+      Assert.assertEquals(3, 
StateMachineWithConditionalWait.numTxns.values().stream()
+              .filter(val -> val.get() == 3).count());
 
       cluster.shutdown();
     }

Reply via email to