This is an automated email from the ASF dual-hosted git repository.

szetszwo pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ratis.git


The following commit(s) were added to refs/heads/master by this push:
     new 81c714dde RATIS-2350. Fix readAfterWrite bugs. (#1311)
81c714dde is described below

commit 81c714dde6632d82fb2f10cc5118de309d77c92a
Author: Tsz-Wo Nicholas Sze <[email protected]>
AuthorDate: Fri Nov 14 11:12:51 2025 -0800

    RATIS-2350. Fix readAfterWrite bugs. (#1311)
---
 .../apache/ratis/server/impl/LeaderStateImpl.java  |  8 ++--
 .../org/apache/ratis/server/impl/ReadRequests.java | 50 ++++++++++++++--------
 .../org/apache/ratis/server/impl/ServerState.java  |  2 +-
 .../ratis/server/impl/StateMachineUpdater.java     |  5 ++-
 .../apache/ratis/server/impl/WriteIndexCache.java  |  5 ++-
 .../org/apache/ratis/LinearizableReadTests.java    | 23 +++++-----
 6 files changed, 56 insertions(+), 37 deletions(-)

diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/LeaderStateImpl.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/LeaderStateImpl.java
index 0835802bd..836b15bcd 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/impl/LeaderStateImpl.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/LeaderStateImpl.java
@@ -1140,13 +1140,15 @@ class LeaderStateImpl implements LeaderState {
    * @return current readIndex.
    */
   CompletableFuture<Long> getReadIndex(Long readAfterWriteConsistentIndex) {
+    final long commitIndex = server.getRaftLog().getLastCommittedIndex();
     final long readIndex;
-    if (readAfterWriteConsistentIndex != null) {
+    if (readAfterWriteConsistentIndex != null && readAfterWriteConsistentIndex > commitIndex) {
       readIndex = readAfterWriteConsistentIndex;
     } else {
-      readIndex = server.getRaftLog().getLastCommittedIndex();
+      readIndex = commitIndex;
     }
-    LOG.debug("readIndex={}, readAfterWriteConsistentIndex={}", readIndex, readAfterWriteConsistentIndex);
+    LOG.debug("readIndex={} (commitIndex={}, readAfterWriteConsistentIndex={})",
+        readIndex, commitIndex, readAfterWriteConsistentIndex);
 
     // if group contains only one member, fast path
     if (server.getRaftConf().isSingleton()) {
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/ReadRequests.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/ReadRequests.java
index e63a23a0b..6112a4600 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/impl/ReadRequests.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/ReadRequests.java
@@ -20,7 +20,7 @@ package org.apache.ratis.server.impl;
 import org.apache.ratis.conf.RaftProperties;
 import org.apache.ratis.protocol.exceptions.ReadException;
 import org.apache.ratis.server.RaftServerConfigKeys;
-import org.apache.ratis.statemachine.StateMachine;
+import org.apache.ratis.util.Preconditions;
 import org.apache.ratis.util.TimeDuration;
 import org.apache.ratis.util.TimeoutExecutor;
 import org.slf4j.Logger;
@@ -29,7 +29,7 @@ import org.slf4j.LoggerFactory;
 import java.util.NavigableMap;
 import java.util.TreeMap;
 import java.util.concurrent.CompletableFuture;
-import java.util.function.Consumer;
+import java.util.function.LongConsumer;
 
 /** For supporting linearizable read. */
 class ReadRequests {
@@ -37,10 +37,18 @@ class ReadRequests {
 
   static class ReadIndexQueue {
     private final TimeoutExecutor scheduler = TimeoutExecutor.getInstance();
+    /** The log index known to be applied. */
+    private long lastAppliedIndex;
+    /**
+     * Map      : readIndex -> appliedIndexFuture (when completes, readIndex <= appliedIndex).
+     * Invariant: all keys > lastAppliedIndex.
+     */
     private final NavigableMap<Long, CompletableFuture<Long>> sorted = new TreeMap<>();
+
     private final TimeDuration readTimeout;
 
-    ReadIndexQueue(TimeDuration readTimeout) {
+    ReadIndexQueue(long lastAppliedIndex, TimeDuration readTimeout) {
+      this.lastAppliedIndex = lastAppliedIndex;
       this.readTimeout = readTimeout;
     }
 
@@ -48,6 +56,9 @@ class ReadRequests {
       final CompletableFuture<Long> returned;
       final boolean create;
       synchronized (this) {
+        if (readIndex <= lastAppliedIndex) {
+          return CompletableFuture.completedFuture(lastAppliedIndex);
+        }
         // The same as computeIfAbsent except that it also tells if a new value is created.
         final CompletableFuture<Long> existing = sorted.get(readIndex);
         create = existing == null;
@@ -79,7 +90,19 @@ class ReadRequests {
 
 
     /** Complete all the entries less than or equal to the given applied index. */
-    synchronized void complete(Long appliedIndex) {
+    synchronized void complete(long appliedIndex) {
+      if (appliedIndex > lastAppliedIndex) {
+        lastAppliedIndex = appliedIndex;
+      } else {
+        // appliedIndex <= lastAppliedIndex: nothing to do
+        if (!sorted.isEmpty()) {
+          // Assert: all keys > lastAppliedIndex.
+          final long first = sorted.firstKey();
+          Preconditions.assertTrue(first > lastAppliedIndex,
+              () -> "first = " + first + " <= lastAppliedIndex = " + lastAppliedIndex);
+        }
+        return;
+      }
       final NavigableMap<Long, CompletableFuture<Long>> headMap = sorted.headMap(appliedIndex, true);
       headMap.values().forEach(f -> f.complete(appliedIndex));
       headMap.clear();
@@ -87,27 +110,16 @@ class ReadRequests {
   }
 
   private final ReadIndexQueue readIndexQueue;
-  private final StateMachine stateMachine;
 
-  ReadRequests(RaftProperties properties, StateMachine stateMachine) {
-    this.readIndexQueue = new ReadIndexQueue(RaftServerConfigKeys.Read.timeout(properties));
-    this.stateMachine = stateMachine;
+  ReadRequests(long appliedIndex, RaftProperties properties) {
+    this.readIndexQueue = new ReadIndexQueue(appliedIndex, RaftServerConfigKeys.Read.timeout(properties));
   }
 
-  Consumer<Long> getAppliedIndexConsumer() {
+  LongConsumer getAppliedIndexConsumer() {
     return readIndexQueue::complete;
   }
 
   CompletableFuture<Long> waitToAdvance(long readIndex) {
-    final long lastApplied = stateMachine.getLastAppliedTermIndex().getIndex();
-    if (lastApplied >= readIndex) {
-      return CompletableFuture.completedFuture(lastApplied);
-    }
-    final CompletableFuture<Long> f = readIndexQueue.add(readIndex);
-    final long current = stateMachine.getLastAppliedTermIndex().getIndex();
-    if (current > lastApplied) {
-      readIndexQueue.complete(current);
-    }
-    return f;
+    return readIndexQueue.add(readIndex);
   }
 }
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/ServerState.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/ServerState.java
index 05afc0975..ee1b7d37b 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/impl/ServerState.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/ServerState.java
@@ -125,7 +125,7 @@ class ServerState {
     this.lastNoLeaderTime = new AtomicReference<>(Timestamp.currentTime());
     this.noLeaderTimeout = RaftServerConfigKeys.Notification.noLeaderTimeout(prop);
     this.log = JavaUtils.memoize(() -> initRaftLog(() -> getSnapshotIndexFromStateMachine(stateMachine), prop));
-    this.readRequests = new ReadRequests(prop, stateMachine);
+    this.readRequests = new ReadRequests(stateMachine.getLastAppliedTermIndex().getIndex(), prop);
     this.stateMachineUpdater = JavaUtils.memoize(() -> new StateMachineUpdater(
         stateMachine, server, this, getLog().getSnapshotIndex(), prop,
         this.readRequests.getAppliedIndexConsumer()));
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/StateMachineUpdater.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/StateMachineUpdater.java
index bd7f26a8a..041693195 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/impl/StateMachineUpdater.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/StateMachineUpdater.java
@@ -45,6 +45,7 @@ import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Consumer;
+import java.util.function.LongConsumer;
 import java.util.stream.LongStream;
 
 /**
@@ -90,12 +91,12 @@ class StateMachineUpdater implements Runnable {
 
   private final MemoizedSupplier<StateMachineMetrics> stateMachineMetrics;
 
-  private final Consumer<Long> appliedIndexConsumer;
+  private final LongConsumer appliedIndexConsumer;
 
   private volatile boolean isRemoving;
 
   StateMachineUpdater(StateMachine stateMachine, RaftServerImpl server,
-      ServerState serverState, long lastAppliedIndex, RaftProperties properties, Consumer<Long> appliedIndexConsumer) {
+      ServerState serverState, long lastAppliedIndex, RaftProperties properties, LongConsumer appliedIndexConsumer) {
     this.name = ServerStringUtils.generateUnifiedName(serverState.getMemberId(), getClass());
     this.appliedIndexConsumer = appliedIndexConsumer;
     this.infoIndexChange = s -> LOG.info("{}: {}", name, s);
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/WriteIndexCache.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/WriteIndexCache.java
index df4448622..98250ca22 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/impl/WriteIndexCache.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/WriteIndexCache.java
@@ -46,14 +46,15 @@ class WriteIndexCache {
         .build();
   }
 
-  void add(ClientId key, CompletableFuture<Long> future) {
+  void add(ClientId key, CompletableFuture<Long> current) {
     final AtomicReference<CompletableFuture<Long>> ref;
     try {
       ref = cache.get(key, AtomicReference::new);
     } catch (ExecutionException e) {
       throw new IllegalStateException(e);
     }
-    ref.set(future);
+    ref.updateAndGet(previous -> previous == null ? current
+        : previous.thenCombine(current, Math::max));
   }
 
   CompletableFuture<Long> getWriteIndexFuture(RaftClientRequest request) {
diff --git a/ratis-server/src/test/java/org/apache/ratis/LinearizableReadTests.java b/ratis-server/src/test/java/org/apache/ratis/LinearizableReadTests.java
index 49176b18a..91bd2f28d 100644
--- a/ratis-server/src/test/java/org/apache/ratis/LinearizableReadTests.java
+++ b/ratis-server/src/test/java/org/apache/ratis/LinearizableReadTests.java
@@ -47,7 +47,6 @@ import static org.apache.ratis.ReadOnlyRequestTests.QUERY;
 import static org.apache.ratis.ReadOnlyRequestTests.WAIT_AND_INCREMENT;
 import static org.apache.ratis.ReadOnlyRequestTests.assertReplyAtLeast;
 import static org.apache.ratis.ReadOnlyRequestTests.assertReplyExact;
-import static org.apache.ratis.ReadOnlyRequestTests.retrieve;
 import static org.apache.ratis.server.RaftServerConfigKeys.Read.Option.LINEARIZABLE;
 
 /** Test for the {@link RaftServerConfigKeys.Read.Option#LINEARIZABLE} feature. */
@@ -233,20 +232,24 @@ public abstract class LinearizableReadTests<CLUSTER extends MiniRaftCluster>
       assertReplyExact(1, client.io().sendReadAfterWrite(QUERY));
 
       // test asynchronous read-after-write
-      client.async().send(INCREMENT);
+      final CompletableFuture<RaftClientReply> writeReply = client.async().send(INCREMENT);
       final CompletableFuture<RaftClientReply> asyncReply = client.async().sendReadAfterWrite(QUERY);
 
-      for (int i = 0; i < 20; i++) {
-        client.async().send(INCREMENT);
+      final int n = 100;
+      final List<Reply> writeReplies = new ArrayList<>(n);
+      final List<Reply> readAfterWriteReplies = new ArrayList<>(n);
+      for (int i = 0; i < n; i++) {
+        final int count = i + 3;
+        writeReplies.add(new Reply(count, client.async().send(INCREMENT)));
+        readAfterWriteReplies.add(new Reply(count, client.async().sendReadAfterWrite(QUERY)));
       }
 
-      // read-after-write is more consistent than linearizable read
-      final CompletableFuture<RaftClientReply> linearizable = client.async().sendReadOnly(QUERY);
-      final CompletableFuture<RaftClientReply> readAfterWrite = client.async().sendReadAfterWrite(QUERY);
-      final int r = retrieve(readAfterWrite.get());
-      final int l = retrieve(linearizable.get());
-      Assertions.assertTrue(r >= l, () -> "readAfterWrite = " + r + " < linearizable = " + l);
+      for (int i = 0; i < n; i++) {
+        writeReplies.get(i).assertExact();
+        readAfterWriteReplies.get(i).assertAtLeast();
+      }
 
+      assertReplyAtLeast(2, writeReply.join());
       assertReplyAtLeast(2, asyncReply.join());
     }
   }

Reply via email to