This is an automated email from the ASF dual-hosted git repository. szetszwo pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/ratis.git
commit 904a75e9f587d39cffebf3f94f65173312790a0b Author: Swaminathan Balachandran <47532440+swamiri...@users.noreply.github.com> AuthorDate: Wed Feb 19 15:06:25 2025 -0800 RATIS-2245. Ratis should wait for all apply transaction futures before taking snapshot and group remove (#1218) --- .../ratis/server/impl/StateMachineUpdater.java | 38 +++--- .../server/impl/StateMachineShutdownTests.java | 128 ++++++++++++++++----- 2 files changed, 119 insertions(+), 47 deletions(-) diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/StateMachineUpdater.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/StateMachineUpdater.java index a919ca732..3dfe5e0aa 100644 --- a/ratis-server/src/main/java/org/apache/ratis/server/impl/StateMachineUpdater.java +++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/StateMachineUpdater.java @@ -37,8 +37,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; -import java.util.ArrayList; -import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.concurrent.CompletableFuture; @@ -182,19 +180,20 @@ class StateMachineUpdater implements Runnable { @Override public void run() { + CompletableFuture<Void> applyLogFutures = CompletableFuture.completedFuture(null); for(; state != State.STOP; ) { try { - waitForCommit(); + waitForCommit(applyLogFutures); if (state == State.RELOAD) { reload(); } - final MemoizedSupplier<List<CompletableFuture<Message>>> futures = applyLog(); - checkAndTakeSnapshot(futures); + applyLogFutures = applyLog(applyLogFutures); + checkAndTakeSnapshot(applyLogFutures); if (shouldStop()) { - checkAndTakeSnapshot(futures); + applyLogFutures.get(); stop(); } } catch (Throwable t) { @@ -210,14 +209,14 @@ class StateMachineUpdater implements Runnable { } } - private void waitForCommit() throws InterruptedException { + private void waitForCommit(CompletableFuture<?> applyLogFutures) throws InterruptedException, ExecutionException { // When a peer starts, the committed is initialized to 0. // It will be updated only after the leader contacts other peers. // Thus it is possible to have applied > committed initially. final long applied = getLastAppliedIndex(); for(; applied >= raftLog.getLastCommittedIndex() && state == State.RUNNING && !shouldStop(); ) { if (server.getSnapshotRequestHandler().shouldTriggerTakingSnapshot()) { - takeSnapshot(); + takeSnapshot(applyLogFutures); } if (awaitForSignal.await(100, TimeUnit.MILLISECONDS)) { return; @@ -239,8 +238,7 @@ class StateMachineUpdater implements Runnable { state = State.RUNNING; } - private MemoizedSupplier<List<CompletableFuture<Message>>> applyLog() throws RaftLogIOException { - final MemoizedSupplier<List<CompletableFuture<Message>>> futures = MemoizedSupplier.valueOf(ArrayList::new); + private CompletableFuture<Void> applyLog(CompletableFuture<Void> applyLogFutures) throws RaftLogIOException { final long committed = raftLog.getLastCommittedIndex(); for(long applied; (applied = getLastAppliedIndex()) < committed && state == State.RUNNING && !shouldStop(); ) { final long nextIndex = applied + 1; @@ -256,7 +254,12 @@ class StateMachineUpdater implements Runnable { final long incremented = appliedIndex.incrementAndGet(debugIndexChange); Preconditions.assertTrue(incremented == nextIndex); if (f != null) { - futures.get().add(f); + CompletableFuture<Message> exceptionHandledFuture = f.exceptionally(ex -> { + LOG.error("Exception while {}: applying txn index={}, nextLog={}", this, nextIndex, + LogProtoUtils.toLogEntryString(next), ex); + return null; + }); + applyLogFutures = applyLogFutures.thenCombine(exceptionHandledFuture, (v, message) -> null); f.thenAccept(m -> notifyAppliedIndex(incremented)); } else { notifyAppliedIndex(incremented); @@ -267,23 +270,20 @@ class StateMachineUpdater implements Runnable { break; } } - return futures; + return applyLogFutures; } - private void checkAndTakeSnapshot(MemoizedSupplier<List<CompletableFuture<Message>>> futures) + private void checkAndTakeSnapshot(CompletableFuture<?> futures) throws ExecutionException, InterruptedException { // check if need to trigger a snapshot if (shouldTakeSnapshot()) { - if (futures.isInitialized()) { - JavaUtils.allOf(futures.get()).get(); - } - - takeSnapshot(); + takeSnapshot(futures); } } - private void takeSnapshot() { + private void takeSnapshot(CompletableFuture<?> applyLogFutures) throws ExecutionException, InterruptedException { final long i; + applyLogFutures.get(); try { try(UncheckedAutoCloseable ignored = Timekeeper.start(stateMachineMetrics.get().getTakeSnapshotTimer())) { i = stateMachine.takeSnapshot(); diff --git a/ratis-server/src/test/java/org/apache/ratis/server/impl/StateMachineShutdownTests.java b/ratis-server/src/test/java/org/apache/ratis/server/impl/StateMachineShutdownTests.java index 28f8e6ace..c70464a18 100644 --- a/ratis-server/src/test/java/org/apache/ratis/server/impl/StateMachineShutdownTests.java +++ b/ratis-server/src/test/java/org/apache/ratis/server/impl/StateMachineShutdownTests.java @@ -28,47 +28,106 @@ import org.apache.ratis.server.RaftServer; import org.apache.ratis.statemachine.impl.SimpleStateMachine4Testing; import org.apache.ratis.statemachine.StateMachine; import org.apache.ratis.statemachine.TransactionContext; -import org.junit.Assert; -import org.junit.Test; - -import java.util.concurrent.CompletableFuture; +import org.junit.*; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import java.util.*; +import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicLong; public abstract class StateMachineShutdownTests<CLUSTER extends MiniRaftCluster> extends BaseTest implements MiniRaftCluster.Factory.Get<CLUSTER> { - + public static Logger LOG = LoggerFactory.getLogger(StateMachineUpdater.class); + private static MockedStatic<CompletableFuture> mocked; protected static class StateMachineWithConditionalWait extends SimpleStateMachine4Testing { + boolean unblockAllTxns = false; + final Set<Long> blockTxns = ConcurrentHashMap.newKeySet(); + private final ExecutorService executor = Executors.newFixedThreadPool(10); + public static Map<Long, Set<CompletableFuture<Message>>> futures = new ConcurrentHashMap<>(); + public static Map<RaftPeerId, AtomicLong> numTxns = new ConcurrentHashMap<>(); + private final Map<Long, Long> appliedTxns = new ConcurrentHashMap<>(); + + private synchronized void updateTxns() { + long appliedIndex = this.getLastAppliedTermIndex().getIndex() + 1; + Long appliedTerm = null; + while (appliedTxns.containsKey(appliedIndex)) { + appliedTerm = appliedTxns.remove(appliedIndex); + appliedIndex += 1; + } + if (appliedTerm != null) { + updateLastAppliedTermIndex(appliedTerm, appliedIndex - 1); + } + } - private final Long objectToWait = 0L; - volatile boolean blockOnApply = true; + @Override + public void notifyTermIndexUpdated(long term, long index) { + appliedTxns.put(index, term); + updateTxns(); + } @Override public CompletableFuture<Message> applyTransaction(TransactionContext trx) { - if (blockOnApply) { - synchronized (objectToWait) { - try { - objectToWait.wait(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException(); + final RaftProtos.LogEntryProto entry = trx.getLogEntry(); + + CompletableFuture<Message> future = new CompletableFuture<>(); + futures.computeIfAbsent(Thread.currentThread().getId(), k -> new HashSet<>()).add(future); + executor.submit(() -> { + synchronized (blockTxns) { + if (!unblockAllTxns) { + blockTxns.add(entry.getIndex()); + } + while (!unblockAllTxns && blockTxns.contains(entry.getIndex())) { + try { + blockTxns.wait(10000); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } } } + numTxns.computeIfAbsent(getId(), (k) -> new AtomicLong()).incrementAndGet(); + appliedTxns.put(entry.getIndex(), entry.getTerm()); + updateTxns(); + future.complete(new RaftTestUtil.SimpleMessage("done")); + }); + return future; + } + + public void unBlockApplyTxn(long txnId) { + synchronized (blockTxns) { + blockTxns.remove(txnId); + blockTxns.notifyAll(); } - RaftProtos.LogEntryProto entry = trx.getLogEntry(); - updateLastAppliedTermIndex(entry.getTerm(), entry.getIndex()); - return CompletableFuture.completedFuture(new RaftTestUtil.SimpleMessage("done")); } - public void unBlockApplyTxn() { - blockOnApply = false; - synchronized (objectToWait) { - objectToWait.notifyAll(); + public void unblockAllTxns() { + unblockAllTxns = true; + synchronized (blockTxns) { + for (Long txnId : blockTxns) { + blockTxns.remove(txnId); + } + blockTxns.notifyAll(); } } } + @Before + public void setup() { + mocked = Mockito.mockStatic(CompletableFuture.class, Mockito.CALLS_REAL_METHODS); + } + + @After + public void tearDownClass() { + if (mocked != null) { + mocked.close(); + } + + } + @Test public void testStateMachineShutdownWaitsForApplyTxn() throws Exception { final RaftProperties prop = getProperties(); @@ -82,10 +141,9 @@ public abstract class StateMachineShutdownTests<CLUSTER extends MiniRaftCluster> //Unblock leader and one follower ((StateMachineWithConditionalWait)leader.getStateMachine()) - .unBlockApplyTxn(); + .unblockAllTxns(); ((StateMachineWithConditionalWait)cluster. - getFollowers().get(0).getStateMachine()).unBlockApplyTxn(); - + getFollowers().get(0).getStateMachine()).unblockAllTxns(); cluster.getLeaderAndSendFirstMessage(true); try (final RaftClient client = cluster.createClient(leaderId)) { @@ -107,16 +165,30 @@ public abstract class StateMachineShutdownTests<CLUSTER extends MiniRaftCluster> final Thread t = new Thread(secondFollower::close); t.start(); - // The second follower should still be blocked in apply transaction - Assert.assertTrue(secondFollower.getInfo().getLastAppliedIndex() < logIndex); + // Now unblock the second follower - ((StateMachineWithConditionalWait) secondFollower.getStateMachine()) - .unBlockApplyTxn(); + long minIndex = ((StateMachineWithConditionalWait) secondFollower.getStateMachine()).blockTxns.stream() + .min(Comparator.naturalOrder()).get(); + Assert.assertEquals(2, StateMachineWithConditionalWait.numTxns.values().stream() + .filter(val -> val.get() == 3).count()); + // The second follower should still be blocked in apply transaction + Assert.assertTrue(secondFollower.getInfo().getLastAppliedIndex() < minIndex); + for (long index : ((StateMachineWithConditionalWait) secondFollower.getStateMachine()).blockTxns) { + if (minIndex != index) { + ((StateMachineWithConditionalWait) secondFollower.getStateMachine()).unBlockApplyTxn(index); + } + } + Assert.assertEquals(2, StateMachineWithConditionalWait.numTxns.values().stream() + .filter(val -> val.get() == 3).count()); + Assert.assertTrue(secondFollower.getInfo().getLastAppliedIndex() < minIndex); + ((StateMachineWithConditionalWait) secondFollower.getStateMachine()).unBlockApplyTxn(minIndex); // Now wait for the thread t.join(5000); Assert.assertEquals(logIndex, secondFollower.getInfo().getLastAppliedIndex()); + Assert.assertEquals(3, StateMachineWithConditionalWait.numTxns.values().stream() + .filter(val -> val.get() == 3).count()); cluster.shutdown(); }