This is an automated email from the ASF dual-hosted git repository.

szetszwo pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ratis.git


The following commit(s) were added to refs/heads/master by this push:
     new 88405b037 RATIS-1882. Support read-after-write consistency (#913)
88405b037 is described below

commit 88405b037ac9db4471603dfdd8ddc7c542f0229c
Author: William Song <[email protected]>
AuthorDate: Tue Sep 5 22:32:38 2023 +0800

    RATIS-1882. Support read-after-write consistency (#913)
---
 .../java/org/apache/ratis/client/api/AsyncApi.java |   9 ++
 .../org/apache/ratis/client/api/BlockingApi.java   |   8 ++
 .../org/apache/ratis/client/impl/AsyncImpl.java    |   5 +
 .../org/apache/ratis/client/impl/BlockingImpl.java |   5 +
 .../apache/ratis/protocol/RaftClientRequest.java   |  19 ++-
 .../ratis/server/impl/TestReadAfterWrite.java      | 160 +++++++++++++++++++++
 ratis-proto/src/main/proto/Raft.proto              |   2 +
 .../apache/ratis/server/RaftServerConfigKeys.java  |  17 +++
 .../apache/ratis/server/impl/LeaderStateImpl.java  |  13 +-
 .../apache/ratis/server/impl/RaftServerImpl.java   |  43 ++++--
 .../ratis/server/impl/ReadIndexHeartbeats.java     |   1 +
 .../apache/ratis/server/impl/ServerProtoUtils.java |   4 +-
 .../apache/ratis/server/impl/WriteIndexCache.java  |  68 +++++++++
 .../statemachine/impl/TransactionContextImpl.java  |  10 ++
 .../org/apache/ratis/ReadOnlyRequestTests.java     |  32 +++++
 15 files changed, 373 insertions(+), 23 deletions(-)

diff --git 
a/ratis-client/src/main/java/org/apache/ratis/client/api/AsyncApi.java 
b/ratis-client/src/main/java/org/apache/ratis/client/api/AsyncApi.java
index c6f5e4181..483a22205 100644
--- a/ratis-client/src/main/java/org/apache/ratis/client/api/AsyncApi.java
+++ b/ratis-client/src/main/java/org/apache/ratis/client/api/AsyncApi.java
@@ -55,6 +55,15 @@ public interface AsyncApi {
    */
   CompletableFuture<RaftClientReply> sendReadOnly(Message message, RaftPeerId 
server);
 
+  /**
+   * Send the given readonly message asynchronously to the raft service.
+   * The result will be read-after-write consistent, i.e. reflecting the latest successful write by the same client.
+   * @param message The request message.
+   * @return a future of the reply.
+   */
+  CompletableFuture<RaftClientReply> sendReadAfterWrite(Message message);
+
+
   /**
    * Send the given readonly message asynchronously to the raft service using 
non-linearizable read.
    * This method is useful when linearizable read is enabled
diff --git 
a/ratis-client/src/main/java/org/apache/ratis/client/api/BlockingApi.java 
b/ratis-client/src/main/java/org/apache/ratis/client/api/BlockingApi.java
index dc03e1b8d..64d63ff29 100644
--- a/ratis-client/src/main/java/org/apache/ratis/client/api/BlockingApi.java
+++ b/ratis-client/src/main/java/org/apache/ratis/client/api/BlockingApi.java
@@ -65,6 +65,14 @@ public interface BlockingApi {
    */
   RaftClientReply sendReadOnlyNonLinearizable(Message message) throws 
IOException;
 
+  /**
+   * Send the given readonly message to the raft service.
+   * The result will be read-after-write consistent, i.e. reflecting the latest successful write by the same client.
+   * @param message The request message.
+   * @return the reply.
+   */
+  RaftClientReply sendReadAfterWrite(Message message) throws IOException;
+
   /**
    * Send the given stale-read message to the given server (not the raft 
service).
    * If the server commit index is larger than or equal to the given 
min-index, the request will be processed.
diff --git 
a/ratis-client/src/main/java/org/apache/ratis/client/impl/AsyncImpl.java 
b/ratis-client/src/main/java/org/apache/ratis/client/impl/AsyncImpl.java
index 9bdc9d50a..2f7069f39 100644
--- a/ratis-client/src/main/java/org/apache/ratis/client/impl/AsyncImpl.java
+++ b/ratis-client/src/main/java/org/apache/ratis/client/impl/AsyncImpl.java
@@ -51,6 +51,11 @@ class AsyncImpl implements AsyncRpcApi {
     return send(RaftClientRequest.readRequestType(), message, server);
   }
 
+  @Override
+  public CompletableFuture<RaftClientReply> sendReadAfterWrite(Message message) {
+    return send(RaftClientRequest.readAfterWriteConsistentRequestType(), message, null); // null: no specific server
+  }
+
   @Override
   public CompletableFuture<RaftClientReply> 
sendReadOnlyNonLinearizable(Message message) {
     return send(RaftClientRequest.readRequestType(true), message, null);
diff --git 
a/ratis-client/src/main/java/org/apache/ratis/client/impl/BlockingImpl.java 
b/ratis-client/src/main/java/org/apache/ratis/client/impl/BlockingImpl.java
index 7e81baf8d..742ee2901 100644
--- a/ratis-client/src/main/java/org/apache/ratis/client/impl/BlockingImpl.java
+++ b/ratis-client/src/main/java/org/apache/ratis/client/impl/BlockingImpl.java
@@ -69,6 +69,11 @@ class BlockingImpl implements BlockingApi {
     return send(RaftClientRequest.readRequestType(true), message, null);
   }
 
+  @Override
+  public RaftClientReply sendReadAfterWrite(Message message) throws IOException {
+    return send(RaftClientRequest.readAfterWriteConsistentRequestType(), message, null); // null: no specific server
+  }
+
   @Override
   public RaftClientReply sendStaleRead(Message message, long minIndex, 
RaftPeerId server)
       throws IOException {
diff --git 
a/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientRequest.java 
b/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientRequest.java
index ae76607ff..220694ce0 100644
--- 
a/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientRequest.java
+++ 
b/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientRequest.java
@@ -35,9 +35,11 @@ public class RaftClientRequest extends RaftClientMessage {
   private static final Type WATCH_DEFAULT = new Type(
       
WatchRequestTypeProto.newBuilder().setIndex(0L).setReplication(ReplicationLevel.MAJORITY).build());
 
+  private static final Type READ_AFTER_WRITE_CONSISTENT_DEFAULT
+      = new Type(ReadRequestTypeProto.newBuilder().setReadAfterWriteConsistent(true).build());
   private static final Type READ_DEFAULT = new Type(ReadRequestTypeProto.getDefaultInstance());
-  private static final Type
-      READ_NONLINEARIZABLE_DEFAULT = new Type(ReadRequestTypeProto.newBuilder().setPreferNonLinearizable(true).build());
+  private static final Type READ_NONLINEARIZABLE_DEFAULT
+      = new Type(ReadRequestTypeProto.newBuilder().setPreferNonLinearizable(true).build());
   private static final Type STALE_READ_DEFAULT = new Type(StaleReadRequestTypeProto.getDefaultInstance());
 
   public static Type writeRequestType() {
@@ -60,6 +62,10 @@ public class RaftClientRequest extends RaftClientMessage {
         .build());
   }
 
+  public static Type readAfterWriteConsistentRequestType() {
+    return READ_AFTER_WRITE_CONSISTENT_DEFAULT; // shared READ type with readAfterWriteConsistent set
+  }
+
   public static Type readRequestType() {
     return READ_DEFAULT;
   }
@@ -95,7 +101,9 @@ public class RaftClientRequest extends RaftClientMessage {
     }
 
     public static Type valueOf(ReadRequestTypeProto read) {
-      return read.getPreferNonLinearizable()? READ_NONLINEARIZABLE_DEFAULT: READ_DEFAULT;
+      return read.getPreferNonLinearizable()? READ_NONLINEARIZABLE_DEFAULT // preferNonLinearizable takes precedence
+          : read.getReadAfterWriteConsistent()? READ_AFTER_WRITE_CONSISTENT_DEFAULT
+          : READ_DEFAULT;
     }
 
     public static Type valueOf(StaleReadRequestTypeProto staleRead) {
@@ -219,7 +227,10 @@ public class RaftClientRequest extends RaftClientMessage {
         case MESSAGESTREAM:
           return toString(getMessageStream());
         case READ:
-          return "RO";
+          final ReadRequestTypeProto read = getRead();
+          return read.getReadAfterWriteConsistent()? "RaW"
+              : read.getPreferNonLinearizable()? "RO(pNL)"
+              : "RO";
         case STALEREAD:
           return "StaleRead(" + getStaleRead().getMinIndex() + ")";
         case WATCH:
diff --git 
a/ratis-examples/src/test/java/org/apache/ratis/server/impl/TestReadAfterWrite.java
 
b/ratis-examples/src/test/java/org/apache/ratis/server/impl/TestReadAfterWrite.java
new file mode 100644
index 000000000..f515628c9
--- /dev/null
+++ 
b/ratis-examples/src/test/java/org/apache/ratis/server/impl/TestReadAfterWrite.java
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ratis.server.impl;
+
+import org.apache.ratis.BaseTest;
+import org.apache.ratis.client.RaftClient;
+import org.apache.ratis.client.api.AsyncApi;
+import org.apache.ratis.conf.RaftProperties;
+import org.apache.ratis.examples.arithmetic.ArithmeticStateMachine;
+import org.apache.ratis.examples.arithmetic.expression.DoubleValue;
+import org.apache.ratis.examples.arithmetic.expression.Expression;
+import org.apache.ratis.examples.arithmetic.expression.Variable;
+import org.apache.ratis.grpc.MiniRaftClusterWithGrpc;
+import org.apache.ratis.protocol.Message;
+import org.apache.ratis.protocol.RaftClientReply;
+import org.apache.ratis.server.RaftServer;
+import org.apache.ratis.server.RaftServerConfigKeys;
+import org.apache.ratis.statemachine.StateMachine;
+import org.apache.ratis.util.CodeInjectionForTesting;
+import org.apache.ratis.util.Slf4jUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.slf4j.event.Level;
+
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+import static org.apache.ratis.examples.arithmetic.expression.BinaryExpression.Op.ADD;
+
+/** Test read-after-write consistent reads, i.e. {@link AsyncApi#sendReadAfterWrite(Message)}. */
+public class TestReadAfterWrite
+    extends BaseTest
+    implements MiniRaftClusterWithGrpc.FactoryGet {
+
+  @Before
+  public void setup() {
+    Slf4jUtils.setLogLevel(ArithmeticStateMachine.LOG, Level.DEBUG);
+    Slf4jUtils.setLogLevel(CodeInjectionForTesting.LOG, Level.DEBUG);
+    Slf4jUtils.setLogLevel(RaftServer.Division.LOG, Level.DEBUG);
+    RaftServerTestUtil.setStateMachineUpdaterLogLevel(Level.DEBUG);
+
+    final RaftProperties p = getProperties();
+    p.setClass(MiniRaftCluster.STATEMACHINE_CLASS_KEY,
+        ArithmeticStateMachine.class, StateMachine.class);
+    RaftServerConfigKeys.Read.setOption(p, RaftServerConfigKeys.Read.Option.LINEARIZABLE);
+  }
+
+  /** Block every {@link #execute} call until {@link #complete()} is called. */
+  static class BlockingCode implements CodeInjectionForTesting.Code {
+    private final CompletableFuture<Void> future = new CompletableFuture<>();
+
+    void complete() {
+      future.complete(null);
+    }
+
+    @Override
+    public boolean execute(Object localId, Object remoteId, Object... args) {
+      final boolean blocked = !future.isDone();
+      if (blocked) {
+        LOG.info("Server {} blocks client {}: {}", localId, remoteId, args[0]);
+      }
+      future.join();
+      if (blocked) {
+        LOG.info("Server {} unblocks client {}", localId, remoteId);
+      }
+      return true;
+    }
+  }
+
+  @Test
+  public void testReadAfterWriteSingleServer() throws Exception {
+    runWithNewCluster(1, cluster -> {
+      try (final RaftClient client = cluster.createClient()) {
+        runTestReadAfterWrite(client);
+      }
+    });
+  }
+
+  @Test
+  public void testReadAfterWrite() throws Exception {
+    runWithNewCluster(3, cluster -> {
+      try (final RaftClient client = cluster.createClient()) {
+        runTestReadAfterWrite(client);
+      }
+    });
+  }
+
+  void runTestReadAfterWrite(RaftClient client) throws Exception {
+    final Variable a = new Variable("a");
+    final Expression a_plus_2 = ADD.apply(a, new DoubleValue(2));
+
+    final AsyncApi async = client.async();
+    final int initialValue = 10;
+    final RaftClientReply assign = async.send(a.assign(new DoubleValue(initialValue))).join();
+    Assert.assertTrue(assign.isSuccess());
+
+    final Message query = Expression.Utils.toMessage(a);
+    assertReply(async.sendReadOnly(query), initialValue);
+
+    // block appendTransaction so that the next write cannot be appended to the log
+    final BlockingCode blockingCode = new BlockingCode();
+    CodeInjectionForTesting.put(RaftServerImpl.APPEND_TRANSACTION, blockingCode);
+    try {
+      final CompletableFuture<RaftClientReply> plus2 = async.send(a.assign(a_plus_2));
+
+      final CompletableFuture<RaftClientReply> readOnlyUnordered = async.sendReadOnlyUnordered(query);
+      final CompletableFuture<RaftClientReply> readAfterWrite = async.sendReadAfterWrite(query);
+
+      Thread.sleep(1000);
+      // readOnlyUnordered should get 10
+      assertReply(readOnlyUnordered, initialValue);
+
+      LOG.info("readAfterWrite.get");
+      try {
+        // readAfterWrite should time out
+        final RaftClientReply reply = readAfterWrite.get(100, TimeUnit.MILLISECONDS);
+        final DoubleValue result = (DoubleValue) Expression.Utils.bytes2Expression(
+            reply.getMessage().getContent().toByteArray(), 0);
+        Assert.fail("result=" + result + ", reply=" + reply);
+      } catch (TimeoutException e) {
+        LOG.info("Good", e);
+      }
+
+      // plus2 should still be blocked.
+      Assert.assertFalse(plus2.isDone());
+      // readAfterWrite should still be blocked.
+      Assert.assertFalse(readAfterWrite.isDone());
+
+      // unblock plus2
+      blockingCode.complete();
+
+      // readAfterWrite should get 12.
+      assertReply(readAfterWrite, initialValue + 2);
+    } finally {
+      // Always unblock; otherwise a failure above could leave the servers blocked in appendTransaction.
+      blockingCode.complete();
+    }
+  }
+
+  void assertReply(CompletableFuture<RaftClientReply> future, int expected) {
+    LOG.info("assertReply, expected {}", expected);
+    final RaftClientReply reply = future.join();
+    Assert.assertTrue(reply.isSuccess());
+    LOG.info("reply {}", reply);
+    final DoubleValue result = (DoubleValue) Expression.Utils.bytes2Expression(
+        reply.getMessage().getContent().toByteArray(), 0);
+    Assert.assertEquals(expected, (int) (double) result.evaluate(null));
+  }
+}
diff --git a/ratis-proto/src/main/proto/Raft.proto 
b/ratis-proto/src/main/proto/Raft.proto
index b8680051f..49a107c45 100644
--- a/ratis-proto/src/main/proto/Raft.proto
+++ b/ratis-proto/src/main/proto/Raft.proto
@@ -241,6 +241,7 @@ message InstallSnapshotReplyProto {
 
 message ReadIndexRequestProto {
   RaftRpcRequestProto serverRequest = 1;
+  RaftClientRequestProto clientRequest = 2; // clientRequest is used to 
support read-after-write consistency
 }
 
 message ReadIndexReplyProto {
@@ -295,6 +296,7 @@ message ForwardRequestTypeProto {
 
 message ReadRequestTypeProto {
   bool preferNonLinearizable = 1;
+  bool readAfterWriteConsistent = 2;
 }
 
 message StaleReadRequestTypeProto {
diff --git 
a/ratis-server-api/src/main/java/org/apache/ratis/server/RaftServerConfigKeys.java
 
b/ratis-server-api/src/main/java/org/apache/ratis/server/RaftServerConfigKeys.java
index 211edd796..e561a3cd9 100644
--- 
a/ratis-server-api/src/main/java/org/apache/ratis/server/RaftServerConfigKeys.java
+++ 
b/ratis-server-api/src/main/java/org/apache/ratis/server/RaftServerConfigKeys.java
@@ -191,6 +191,27 @@ public interface RaftServerConfigKeys {
     static void setOption(RaftProperties properties, Option option) {
       set(properties::setEnum, OPTION_KEY, option);
     }
+
+    /** Configurations for read-after-write consistent reads. */
+    interface ReadAfterWriteConsistent {
+      String PREFIX = Read.PREFIX + ".read-after-write-consistent";
+
+      /** The time, after the last access, for an entry in the per-client write index cache to expire. */
+      String WRITE_INDEX_CACHE_EXPIRY_TIME_KEY = PREFIX + ".write-index-cache.expiry-time";
+      /** Must be larger than {@link Read#TIMEOUT_DEFAULT}. */
+      TimeDuration WRITE_INDEX_CACHE_EXPIRY_TIME_DEFAULT = TimeDuration.valueOf(60, TimeUnit.SECONDS);
+
+      /** @return the configured {@link #WRITE_INDEX_CACHE_EXPIRY_TIME_KEY} value, or the default. */
+      static TimeDuration writeIndexCacheExpiryTime(RaftProperties properties) {
+        return getTimeDuration(properties.getTimeDuration(WRITE_INDEX_CACHE_EXPIRY_TIME_DEFAULT.getUnit()),
+            WRITE_INDEX_CACHE_EXPIRY_TIME_KEY, WRITE_INDEX_CACHE_EXPIRY_TIME_DEFAULT, getDefaultLog());
+      }
+
+      /** Set {@link #WRITE_INDEX_CACHE_EXPIRY_TIME_KEY} to the given expiry time. */
+      static void setWriteIndexCacheExpiryTime(RaftProperties properties, TimeDuration expiryTime) {
+        setTimeDuration(properties::setTimeDuration, WRITE_INDEX_CACHE_EXPIRY_TIME_KEY, expiryTime);
+      }
+    }
   }
 
   interface Write {
diff --git 
a/ratis-server/src/main/java/org/apache/ratis/server/impl/LeaderStateImpl.java 
b/ratis-server/src/main/java/org/apache/ratis/server/impl/LeaderStateImpl.java
index ac8c3599f..5156585f8 100644
--- 
a/ratis-server/src/main/java/org/apache/ratis/server/impl/LeaderStateImpl.java
+++ 
b/ratis-server/src/main/java/org/apache/ratis/server/impl/LeaderStateImpl.java
@@ -1089,8 +1089,14 @@ class LeaderStateImpl implements LeaderState {
    * 4. If majority respond success, returns readIndex.
    * @return current readIndex.
    */
-  CompletableFuture<Long> getReadIndex() {
-    final long readIndex = server.getRaftLog().getLastCommittedIndex();
+  CompletableFuture<Long> getReadIndex(Long readAfterWriteConsistentIndex) {
+    final long readIndex;
+    if (readAfterWriteConsistentIndex != null) {
+      readIndex = readAfterWriteConsistentIndex;
+    } else {
+      readIndex = server.getRaftLog().getLastCommittedIndex();
+    }
+    LOG.debug("readIndex={}, readAfterWriteConsistentIndex={}", readIndex, 
readAfterWriteConsistentIndex);
 
     // if group contains only one member, fast path
     if (server.getRaftConf().isSingleton()) {
@@ -1098,7 +1104,8 @@ class LeaderStateImpl implements LeaderState {
     }
 
     // leader has not committed any entries in this term, reject
-    if (server.getRaftLog().getTermIndex(readIndex).getTerm() != 
getCurrentTerm()) {
+    // TODO: wait for leader to become ready instead of failing the request.
+    if (!isReady()) {
       return JavaUtils.completeExceptionally(new ReadIndexException(
           "Failed to getReadIndex " + readIndex + " since the term is not yet 
committed.",
           new LeaderNotReadyException(server.getMemberId())));
diff --git 
a/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java 
b/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java
index 3b225dc3c..c11a1a4c4 100644
--- 
a/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java
+++ 
b/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java
@@ -125,6 +125,7 @@ import org.apache.ratis.server.util.ServerStringUtils;
 import org.apache.ratis.statemachine.SnapshotInfo;
 import org.apache.ratis.statemachine.StateMachine;
 import org.apache.ratis.statemachine.TransactionContext;
+import org.apache.ratis.statemachine.impl.TransactionContextImpl;
 import 
org.apache.ratis.thirdparty.com.google.protobuf.InvalidProtocolBufferException;
 import org.apache.ratis.util.CodeInjectionForTesting;
 import org.apache.ratis.util.CollectionUtils;
@@ -157,6 +158,7 @@ class RaftServerImpl implements RaftServer.Division,
   static final String REQUEST_VOTE = CLASS_NAME + ".requestVote";
   static final String APPEND_ENTRIES = CLASS_NAME + ".appendEntries";
   static final String INSTALL_SNAPSHOT = CLASS_NAME + ".installSnapshot";
+  static final String APPEND_TRANSACTION = CLASS_NAME + ".appendTransaction";
   static final String LOG_SYNC = APPEND_ENTRIES + ".logComplete";
   static final String START_LEADER_ELECTION = CLASS_NAME + 
".startLeaderElection";
 
@@ -222,6 +224,7 @@ class RaftServerImpl implements RaftServer.Division,
 
   private final RetryCacheImpl retryCache;
   private final CommitInfoCache commitInfoCache = new CommitInfoCache();
+  private final WriteIndexCache writeIndexCache;
 
   private final RaftServerJmxAdapter jmxAdapter;
   private final LeaderElectionMetrics leaderElectionMetrics;
@@ -262,6 +265,7 @@ class RaftServerImpl implements RaftServer.Division,
     this.retryCache = new RetryCacheImpl(properties);
     this.dataStreamMap = new DataStreamMapImpl(id);
     this.readOption = RaftServerConfigKeys.Read.option(properties);
+    this.writeIndexCache = new WriteIndexCache(properties);
 
     this.jmxAdapter = new RaftServerJmxAdapter();
     this.leaderElectionMetrics = 
LeaderElectionMetrics.getLeaderElectionMetrics(
@@ -803,7 +807,10 @@ class RaftServerImpl implements RaftServer.Division,
    * Handle a normal update request from client.
    */
   private CompletableFuture<RaftClientReply> appendTransaction(
-      RaftClientRequest request, TransactionContext context, CacheEntry 
cacheEntry) throws IOException {
+      RaftClientRequest request, TransactionContextImpl context, CacheEntry 
cacheEntry) throws IOException {
+    CodeInjectionForTesting.execute(APPEND_TRANSACTION, getId(),
+        request.getClientId(), request, context, cacheEntry);
+
     assertLifeCycleState(LifeCycle.States.RUNNING);
     CompletableFuture<RaftClientReply> reply;
 
@@ -816,6 +823,8 @@ class RaftServerImpl implements RaftServer.Division,
 
       // append the message to its local log
       final LeaderStateImpl leaderState = role.getLeaderStateNonNull();
+      writeIndexCache.add(request.getClientId(), context.getLogIndexFuture());
+
       final PendingRequests.Permit permit = 
leaderState.tryAcquirePendingRequest(request.getMessage());
       if (permit == null) {
         cacheEntry.failWithException(new ResourceUnavailableException(
@@ -923,7 +932,8 @@ class RaftServerImpl implements RaftServer.Division,
           // TODO: this client request will not be added to pending requests 
until
           // later which means that any failure in between will leave partial 
state in
           // the state machine. We should call cancelTransaction() for failed 
requests
-          TransactionContext context = 
stateMachine.startTransaction(filterDataStreamRaftClientRequest(request));
+          final TransactionContextImpl context = (TransactionContextImpl) 
stateMachine.startTransaction(
+              filterDataStreamRaftClientRequest(request));
           if (context.getException() != null) {
             final StateMachineException e = new 
StateMachineException(getMemberId(), context.getException());
             final RaftClientReply exceptionReply = newExceptionReply(request, 
e);
@@ -970,12 +980,13 @@ class RaftServerImpl implements RaftServer.Division,
     return getState().getReadRequests();
   }
 
-  private CompletableFuture<ReadIndexReplyProto> sendReadIndexAsync() {
+  private CompletableFuture<ReadIndexReplyProto> 
sendReadIndexAsync(RaftClientRequest clientRequest) {
     final RaftPeerId leaderId = getInfo().getLeaderId();
     if (leaderId == null) {
       return JavaUtils.completeExceptionally(new 
ReadIndexException(getMemberId() + ": Leader is unknown."));
     }
-    final ReadIndexRequestProto request = 
ServerProtoUtils.toReadIndexRequestProto(getMemberId(), leaderId);
+    final ReadIndexRequestProto request =
+        ServerProtoUtils.toReadIndexRequestProto(clientRequest, getMemberId(), 
leaderId);
     try {
       return getServerRpc().async().readIndexAsync(request);
     } catch (IOException e) {
@@ -983,6 +994,10 @@ class RaftServerImpl implements RaftServer.Division,
     }
   }
 
+  private CompletableFuture<Long> getReadIndex(RaftClientRequest request, LeaderStateImpl leader) {
+    return writeIndexCache.getWriteIndexFuture(request).thenCompose(leader::getReadIndex); // null: use commit index
+  }
+
   private CompletableFuture<RaftClientReply> readAsync(RaftClientRequest 
request) {
     if (readOption == RaftServerConfigKeys.Read.Option.LINEARIZABLE
         && !request.getType().getRead().getPreferNonLinearizable()) {
@@ -996,9 +1011,9 @@ class RaftServerImpl implements RaftServer.Division,
 
       final CompletableFuture<Long> replyFuture;
       if (leader != null) {
-        replyFuture = leader.getReadIndex();
+        replyFuture = getReadIndex(request, leader);
       } else {
-        replyFuture = sendReadIndexAsync().thenApply(reply   -> {
+        replyFuture = sendReadIndexAsync(request).thenApply(reply   -> {
           if (reply.getServerReply().getSuccess()) {
             return reply.getReadIndex();
           } else {
@@ -1454,7 +1469,7 @@ class RaftServerImpl implements RaftServer.Division,
           ServerProtoUtils.toReadIndexReplyProto(peerId, getMemberId(), false, 
INVALID_LOG_INDEX));
     }
 
-    return leader.getReadIndex()
+    return 
getReadIndex(ClientProtoUtils.toRaftClientRequest(request.getClientRequest()), 
leader)
         .thenApply(index -> ServerProtoUtils.toReadIndexReplyProto(peerId, 
getMemberId(), true, index))
         .exceptionally(throwable ->
             ServerProtoUtils.toReadIndexReplyProto(peerId, getMemberId(), 
false, INVALID_LOG_INDEX));
@@ -1754,14 +1769,11 @@ class RaftServerImpl implements RaftServer.Division,
   /**
    * The log has been submitted to the state machine. Use the future to update
    * the pending requests and retry cache.
-   * @param logEntry the log entry that has been submitted to the state machine
    * @param stateMachineFuture the future returned by the state machine
    *                           from which we will get transaction result later
    */
   private CompletableFuture<Message> replyPendingRequest(
-      LogEntryProto logEntry, CompletableFuture<Message> stateMachineFuture) {
-    Preconditions.assertTrue(logEntry.hasStateMachineLogEntry());
-    final ClientInvocationId invocationId = 
ClientInvocationId.valueOf(logEntry.getStateMachineLogEntry());
+      ClientInvocationId invocationId, long logIndex, 
CompletableFuture<Message> stateMachineFuture) {
     // update the retry cache
     final CacheEntry cacheEntry = retryCache.getOrCreateEntry(invocationId);
     Preconditions.assertTrue(cacheEntry != null);
@@ -1772,7 +1784,6 @@ class RaftServerImpl implements RaftServer.Division,
       retryCache.refreshEntry(new CacheEntry(cacheEntry.getKey()));
     }
 
-    final long logIndex = logEntry.getIndex();
     return stateMachineFuture.whenComplete((reply, exception) -> {
       final RaftClientReply.Builder b = newReplyBuilder(invocationId, 
logIndex);
       final RaftClientReply r;
@@ -1805,19 +1816,21 @@ class RaftServerImpl implements RaftServer.Division,
     } else if (next.hasStateMachineLogEntry()) {
       // check whether there is a TransactionContext because we are the leader.
       TransactionContext trx = role.getLeaderState()
-          .map(leader -> 
leader.getTransactionContext(next.getIndex())).orElseGet(
-              () -> TransactionContext.newBuilder()
+          .map(leader -> leader.getTransactionContext(next.getIndex()))
+          .orElseGet(() -> TransactionContext.newBuilder()
                   .setServerRole(role.getCurrentRole())
                   .setStateMachine(stateMachine)
                   .setLogEntry(next)
                   .build());
+      final ClientInvocationId invocationId = 
ClientInvocationId.valueOf(next.getStateMachineLogEntry());
+      writeIndexCache.add(invocationId.getClientId(), 
((TransactionContextImpl) trx).getLogIndexFuture());
 
       try {
         // Let the StateMachine inject logic for committed transactions in 
sequential order.
         trx = stateMachine.applyTransactionSerial(trx);
 
         final CompletableFuture<Message> stateMachineFuture = 
stateMachine.applyTransaction(trx);
-        return replyPendingRequest(next, stateMachineFuture);
+        return replyPendingRequest(invocationId, next.getIndex(), 
stateMachineFuture);
       } catch (Exception e) {
         throw new RaftLogIOException(e);
       }
diff --git 
a/ratis-server/src/main/java/org/apache/ratis/server/impl/ReadIndexHeartbeats.java
 
b/ratis-server/src/main/java/org/apache/ratis/server/impl/ReadIndexHeartbeats.java
index 7e252f7ad..3f31a2530 100644
--- 
a/ratis-server/src/main/java/org/apache/ratis/server/impl/ReadIndexHeartbeats.java
+++ 
b/ratis-server/src/main/java/org/apache/ratis/server/impl/ReadIndexHeartbeats.java
@@ -167,6 +167,7 @@ class ReadIndexHeartbeats {
     if (commitIndex <= ackedCommitIndex.get()) {
       return null;
     }
+    LOG.debug("listen commitIndex {}", commitIndex);
     return appendEntriesListeners.add(commitIndex, constructor);
   }
 
diff --git 
a/ratis-server/src/main/java/org/apache/ratis/server/impl/ServerProtoUtils.java 
b/ratis-server/src/main/java/org/apache/ratis/server/impl/ServerProtoUtils.java
index 108b6c939..c2ec88a32 100644
--- 
a/ratis-server/src/main/java/org/apache/ratis/server/impl/ServerProtoUtils.java
+++ 
b/ratis-server/src/main/java/org/apache/ratis/server/impl/ServerProtoUtils.java
@@ -20,6 +20,7 @@ package org.apache.ratis.server.impl;
 import org.apache.ratis.client.impl.ClientProtoUtils;
 import org.apache.ratis.proto.RaftProtos.*;
 import org.apache.ratis.proto.RaftProtos.AppendEntriesReplyProto.AppendResult;
+import org.apache.ratis.protocol.RaftClientRequest;
 import org.apache.ratis.protocol.RaftGroupMemberId;
 import org.apache.ratis.protocol.RaftPeer;
 import org.apache.ratis.protocol.RaftPeerId;
@@ -110,9 +111,10 @@ final class ServerProtoUtils {
   }
 
   static ReadIndexRequestProto toReadIndexRequestProto(
-      RaftGroupMemberId requestorId, RaftPeerId replyId) {
+      RaftClientRequest clientRequest, RaftGroupMemberId requestorId, RaftPeerId replyId) {
     return ReadIndexRequestProto.newBuilder()
         .setServerRequest(ClientProtoUtils.toRaftRpcRequestProtoBuilder(requestorId, replyId))
+        .setClientRequest(ClientProtoUtils.toRaftClientRequestProto(clientRequest)) // for read-after-write
         .build();
   }
 
diff --git 
a/ratis-server/src/main/java/org/apache/ratis/server/impl/WriteIndexCache.java 
b/ratis-server/src/main/java/org/apache/ratis/server/impl/WriteIndexCache.java
new file mode 100644
index 000000000..df4448622
--- /dev/null
+++ 
b/ratis-server/src/main/java/org/apache/ratis/server/impl/WriteIndexCache.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ratis.server.impl;
+
+import org.apache.ratis.conf.RaftProperties;
+import org.apache.ratis.protocol.ClientId;
+import org.apache.ratis.protocol.RaftClientRequest;
+import org.apache.ratis.server.RaftServerConfigKeys;
+import org.apache.ratis.thirdparty.com.google.common.cache.Cache;
+import org.apache.ratis.thirdparty.com.google.common.cache.CacheBuilder;
+import org.apache.ratis.util.TimeDuration;
+
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.atomic.AtomicReference;
+
+/** Caching the per client write index in order to support read-after-write consistency. */
+class WriteIndexCache {
+  /** For each client, a future of the largest log index among the writes added so far. */
+  private final Cache<ClientId, AtomicReference<CompletableFuture<Long>>> cache;
+
+  WriteIndexCache(RaftProperties properties) {
+    this(RaftServerConfigKeys.Read.ReadAfterWriteConsistent.writeIndexCacheExpiryTime(properties));
+  }
+
+  /**
+   * @param cacheExpiryTime time for a cache entry to expire after it was last accessed.
+   */
+  WriteIndexCache(TimeDuration cacheExpiryTime) {
+    this.cache = CacheBuilder.newBuilder()
+        .expireAfterAccess(cacheExpiryTime.getDuration(), cacheExpiryTime.getUnit())
+        .build();
+  }
+
+  /**
+   * Add the log index future of a write from the given client.
+   * <p>
+   * A client may have multiple outstanding writes, and the same write may be added more than once
+   * (e.g. RaftServerImpl adds it when appending the transaction and again when applying the entry).
+   * Instead of overwriting, the futures are combined by taking the max index
+   * so that adding an earlier write can never hide a later one.
+   * <p>
+   * NOTE(review): if an added future never completes (e.g. the write fails before a log index is assigned),
+   * the combined future never completes either; confirm such reads are bounded by the read timeout.
+   *
+   * @param key the id of the client sending the write.
+   * @param current a future of the log index of the write.
+   */
+  void add(ClientId key, CompletableFuture<Long> current) {
+    final AtomicReference<CompletableFuture<Long>> ref;
+    try {
+      ref = cache.get(key, AtomicReference::new);
+    } catch (ExecutionException e) {
+      throw new IllegalStateException(e);
+    }
+    ref.updateAndGet(previous -> previous == null || previous == current? current
+        : previous.thenCombine(current, Math::max));
+  }
+
+  /**
+   * @param request the read request.
+   * @return a future of the largest write index of the same client if the request is
+   *         a read-after-write consistent read and a write of the client is cached;
+   *         otherwise, a future completed with null.
+   */
+  CompletableFuture<Long> getWriteIndexFuture(RaftClientRequest request) {
+    if (request != null && request.getType().getRead().getReadAfterWriteConsistent()) {
+      final AtomicReference<CompletableFuture<Long>> ref = cache.getIfPresent(request.getClientId());
+      // The reference can still be empty when add(..) has created it but not yet set the future.
+      final CompletableFuture<Long> future = ref == null? null: ref.get();
+      if (future != null) {
+        return future;
+      }
+    }
+    return CompletableFuture.completedFuture(null);
+  }
+}
diff --git 
a/ratis-server/src/main/java/org/apache/ratis/statemachine/impl/TransactionContextImpl.java
 
b/ratis-server/src/main/java/org/apache/ratis/statemachine/impl/TransactionContextImpl.java
index 8cedb4a7c..a1a878e7d 100644
--- 
a/ratis-server/src/main/java/org/apache/ratis/statemachine/impl/TransactionContextImpl.java
+++ 
b/ratis-server/src/main/java/org/apache/ratis/statemachine/impl/TransactionContextImpl.java
@@ -29,6 +29,7 @@ import org.apache.ratis.util.Preconditions;
 
 import java.io.IOException;
 import java.util.Objects;
+import java.util.concurrent.CompletableFuture;
 
 /**
  * Implementation of {@link TransactionContext}
@@ -68,6 +69,8 @@ public class TransactionContextImpl implements TransactionContext {
   /** Committed LogEntry. */
   private LogEntryProto logEntry;
 
+  /**
+   * Completed with the log index of this transaction's entry once that index is known.
+   * This happens immediately when constructed from an existing log entry, or when the
+   * leader creates the entry.
+   */
+  private final CompletableFuture<Long> logIndexFuture = new CompletableFuture<>();
+
   private TransactionContextImpl(RaftPeerRole serverRole, RaftClientRequest 
clientRequest, StateMachine stateMachine,
       StateMachineLogEntryProto stateMachineLogEntry) {
     this.serverRole = serverRole;
@@ -107,6 +110,7 @@ public class TransactionContextImpl implements TransactionContext {
   TransactionContextImpl(RaftPeerRole serverRole, StateMachine stateMachine, LogEntryProto logEntry) {
     this(serverRole, null, stateMachine, logEntry.getStateMachineLogEntry());
     this.logEntry = logEntry;
+    // The entry (and therefore its index) is already known here, so the index future completes immediately.
     this.logIndexFuture.complete(logEntry.getIndex());
   }
 
   @Override
@@ -145,9 +149,15 @@ public class TransactionContextImpl implements TransactionContext {
     Preconditions.assertTrue(serverRole == RaftPeerRole.LEADER);
     Preconditions.assertNull(logEntry, "logEntry");
     Objects.requireNonNull(stateMachineLogEntry, "stateMachineLogEntry == 
null");
+
+    logIndexFuture.complete(index);
     return logEntry = LogProtoUtils.toLogEntryProto(stateMachineLogEntry, 
term, index);
   }
 
+  /**
+   * @return the future log index of this transaction; it completes once the index is assigned.
+   */
+  public CompletableFuture<Long> getLogIndexFuture() {
+    return this.logIndexFuture;
+  }
+
   @Override
   public LogEntryProto getLogEntry() {
     return logEntry;
diff --git a/ratis-server/src/test/java/org/apache/ratis/ReadOnlyRequestTests.java b/ratis-server/src/test/java/org/apache/ratis/ReadOnlyRequestTests.java
index c4c31cd22..a919a9292 100644
--- a/ratis-server/src/test/java/org/apache/ratis/ReadOnlyRequestTests.java
+++ b/ratis-server/src/test/java/org/apache/ratis/ReadOnlyRequestTests.java
@@ -43,6 +43,7 @@ import org.slf4j.event.Level;
 import java.nio.charset.StandardCharsets;
 import java.util.List;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicLong;
 
@@ -244,6 +245,37 @@ public abstract class ReadOnlyRequestTests<CLUSTER extends MiniRaftCluster>
     }
   }
 
+  /** Run the read-after-write scenario against a new cluster of {@code NUM_SERVERS} servers. */
+  @Test
+  public void testReadAfterWrite() throws Exception {
+    runWithNewCluster(NUM_SERVERS, newCluster -> testReadAfterWriteImpl(newCluster));
+  }
+
+  /**
+   * Check that a read-after-write read observes the client's preceding writes,
+   * first with the blocking API and then with the async API.
+   */
+  private void testReadAfterWriteImpl(CLUSTER cluster) throws Exception {
+    RaftTestUtil.waitForLeader(cluster);
+    try (RaftClient client = cluster.createClient()) {
+      // test blocking read-after-write
+      client.io().send(incrementMessage);
+      final RaftClientReply blockReply = client.io().sendReadAfterWrite(queryMessage);
+      Assert.assertEquals(1, retrieve(blockReply));
+
+      // Test asynchronous read-after-write. Wait for the reply and assert on the test thread.
+      // An assertion inside thenAccept(..) would only complete a discarded future
+      // exceptionally, so it could never fail the test.
+      client.async().send(incrementMessage);
+      final RaftClientReply asyncReply = client.async().sendReadAfterWrite(queryMessage).get();
+      Assert.assertEquals(2, retrieve(asyncReply));
+
+      for (int i = 0; i < 20; i++) {
+        client.async().send(incrementMessage);
+      }
+      final CompletableFuture<RaftClientReply> linearizable = client.async().sendReadOnly(queryMessage);
+      final CompletableFuture<RaftClientReply> readAfterWrite = client.async().sendReadAfterWrite(queryMessage);
+
+      CompletableFuture.allOf(linearizable, readAfterWrite).get();
+      // The read-after-write read also waits for this client's own writes, so it is expected
+      // to observe at least as many increments as the linearizable read.
+      Assert.assertTrue(retrieve(readAfterWrite.get()) >= retrieve(linearizable.get()));
+    }
+  }
+
   static int retrieve(RaftClientReply reply) {
+    // Decode the reply content as UTF-8 and parse it as a decimal int.
     return Integer.parseInt(reply.getMessage().getContent().toString(StandardCharsets.UTF_8));
   }


Reply via email to