This is an automated email from the ASF dual-hosted git repository.

williamsong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ratis.git


The following commit(s) were added to refs/heads/master by this push:
     new 22cbefa2c RATIS-872. Invalidate replied calls in retry cache. (#942)
22cbefa2c is described below

commit 22cbefa2c11c3471d2f763ccb4251806ed3529f5
Author: Tsz-Wo Nicholas Sze <[email protected]>
AuthorDate: Mon Oct 30 19:55:28 2023 -0700

    RATIS-872. Invalidate replied calls in retry cache. (#942)
---
 .../apache/ratis/client/impl/ClientProtoUtils.java |   2 +
 .../apache/ratis/client/impl/RaftClientImpl.java   |  59 +++++++++-
 .../apache/ratis/protocol/RaftClientRequest.java   |  17 ++-
 .../src/main/java/org/apache/ratis/rpc/CallId.java |   4 +-
 ratis-proto/src/main/proto/Raft.proto              |   1 +
 .../apache/ratis/server/impl/RaftServerImpl.java   |   6 +-
 .../apache/ratis/server/impl/RetryCacheImpl.java   |  49 +++++----
 .../java/org/apache/ratis/RetryCacheTests.java     |   3 +-
 .../ratis/client/impl/RaftClientTestUtil.java      |   5 +
 .../impl/SimpleStateMachine4Testing.java           |  12 ++-
 .../apache/ratis/grpc/TestRaftServerWithGrpc.java  |  26 +++--
 .../apache/ratis/grpc/TestRetryCacheWithGrpc.java  | 120 +++++++++++++++++++++
 12 files changed, 267 insertions(+), 37 deletions(-)

diff --git a/ratis-client/src/main/java/org/apache/ratis/client/impl/ClientProtoUtils.java b/ratis-client/src/main/java/org/apache/ratis/client/impl/ClientProtoUtils.java
index 3c0f14fb7..db1983195 100644
--- a/ratis-client/src/main/java/org/apache/ratis/client/impl/ClientProtoUtils.java
+++ b/ratis-client/src/main/java/org/apache/ratis/client/impl/ClientProtoUtils.java
@@ -131,6 +131,7 @@ public interface ClientProtoUtils {
 
     return b.setCallId(request.getCallId())
         .setToLeader(request.isToLeader())
+        .addAllRepliedCallIds(request.getRepliedCallIds())
         .setTimeoutMs(request.getTimeoutMs());
   }
 
@@ -192,6 +193,7 @@ public interface ClientProtoUtils {
         .setCallId(request.getCallId())
         .setMessage(toMessage(p.getMessage()))
         .setType(type)
+        .setRepliedCallIds(request.getRepliedCallIdsList())
         .setRoutingTable(getRoutingTable(request))
         .setTimeoutMs(request.getTimeoutMs())
         .build();
diff --git a/ratis-client/src/main/java/org/apache/ratis/client/impl/RaftClientImpl.java b/ratis-client/src/main/java/org/apache/ratis/client/impl/RaftClientImpl.java
index 9847beed7..ec16763c2 100644
--- a/ratis-client/src/main/java/org/apache/ratis/client/impl/RaftClientImpl.java
+++ b/ratis-client/src/main/java/org/apache/ratis/client/impl/RaftClientImpl.java
@@ -59,6 +59,8 @@ import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Optional;
+import java.util.Set;
+import java.util.TreeSet;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
@@ -127,6 +129,47 @@ public final class RaftClientImpl implements RaftClient {
     }
   }
 
+  static class RepliedCallIds {
+    private final Object name;
+    /** The replied callIds. */
+    private Set<Long> replied = new TreeSet<>();
+    /**
+     * Map: callId to-be-sent -> replied callIds to-be-included.
+     * When retrying the same callId, the request will include the same set of replied callIds.
+     *
+     * @see RaftClientRequest#getRepliedCallIds()
+     */
+    private final ConcurrentMap<Long, Set<Long>> sent = new ConcurrentHashMap<>();
+
+    RepliedCallIds(Object name) {
+      this.name = name;
+    }
+
+    /** The given callId is replied. */
+    void add(long repliedCallId) {
+      LOG.debug("{}: add replied callId {}", name, repliedCallId);
+      synchronized (this) {
+        // synchronized to avoid adding to a previous set.
+        replied.add(repliedCallId);
+      }
+      sent.remove(repliedCallId);
+    }
+
+    /** @return the replied callIds for the given callId. */
+    Iterable<Long> get(long callId) {
+      final Supplier<Set<Long>> supplier = MemoizedSupplier.valueOf(this::getAndReset);
+      final Set<Long> set = Collections.unmodifiableSet(sent.computeIfAbsent(callId, cid -> supplier.get()));
+      LOG.debug("{}: get {} returns {}", name, callId, set);
+      return set;
+    }
+
+    private synchronized Set<Long> getAndReset() {
+      final Set<Long> previous = replied;
+      replied = new TreeSet<>();
+      return previous;
+    }
+  }
+
   private final ClientId clientId;
   private final RaftClientRpc clientRpc;
   private final RaftPeerList peers = new RaftPeerList();
@@ -134,6 +177,8 @@ public final class RaftClientImpl implements RaftClient {
   private final RetryPolicy retryPolicy;
 
   private volatile RaftPeerId leaderId;
+  /** The callIds of the replied requests. */
+  private final RepliedCallIds repliedCallIds;
 
   private final TimeoutExecutor scheduler = TimeoutExecutor.getInstance();
 
@@ -158,6 +203,7 @@ public final class RaftClientImpl implements RaftClient {
 
     this.leaderId = Objects.requireNonNull(computeLeaderId(leaderId, group),
         () -> "this.leaderId is set to null, leaderId=" + leaderId + ", group=" + group);
+    this.repliedCallIds = new RepliedCallIds(clientId);
     this.retryPolicy = Objects.requireNonNull(retryPolicy, "retry policy can't be null");
 
     clientRpc.addRaftPeers(group.getPeers());
@@ -241,7 +287,8 @@ public final class RaftClientImpl implements RaftClient {
     if (server != null) {
       b.setServerId(server);
     } else {
-      b.setLeaderId(getLeaderId());
+      b.setLeaderId(getLeaderId())
+       .setRepliedCallIds(repliedCallIds.get(callId));
     }
     return b.setClientId(clientId)
         .setGroupId(groupId)
@@ -307,8 +354,14 @@ public final class RaftClientImpl implements RaftClient {
   }
 
   RaftClientReply handleReply(RaftClientRequest request, RaftClientReply reply) {
-    if (request.isToLeader() && reply != null && reply.getException() == null) {
-      LEADER_CACHE.put(reply.getRaftGroupId(), reply.getServerId());
+    if (request.isToLeader() && reply != null) {
+      if (!request.getType().isReadOnly()) {
+        repliedCallIds.add(reply.getCallId());
+      }
+
+      if (reply.getException() == null) {
+        LEADER_CACHE.put(reply.getRaftGroupId(), reply.getServerId());
+      }
     }
     return reply;
   }
diff --git a/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientRequest.java b/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientRequest.java
index 7c55a1822..9d853b48b 100644
--- a/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientRequest.java
+++ b/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientRequest.java
@@ -21,7 +21,9 @@ import org.apache.ratis.proto.RaftProtos.*;
 import org.apache.ratis.util.Preconditions;
 import org.apache.ratis.util.ProtoUtils;
 
+import java.util.Collections;
 import java.util.Objects;
+import java.util.Optional;
 
 import static org.apache.ratis.proto.RaftProtos.RaftClientRequestProto.TypeCase.*;
 
@@ -266,6 +268,7 @@ public class RaftClientRequest extends RaftClientMessage {
     private RaftGroupId groupId;
     private long callId;
     private boolean toLeader;
+    private Iterable<Long> repliedCallIds = Collections.emptyList();
 
     private Message message;
     private Type type;
@@ -304,6 +307,11 @@ public class RaftClientRequest extends RaftClientMessage {
       return this;
     }
 
+    public Builder setRepliedCallIds(Iterable<Long> repliedCallIds) {
+      this.repliedCallIds = repliedCallIds;
+      return this;
+    }
+
     public Builder setMessage(Message message) {
       this.message = message;
       return this;
@@ -350,6 +358,7 @@ public class RaftClientRequest extends RaftClientMessage {
   private final Message message;
   private final Type type;
 
+  private final Iterable<Long> repliedCallIds;
   private final SlidingWindowEntry slidingWindowEntry;
 
   private final RoutingTable routingTable;
@@ -386,8 +395,8 @@ public class RaftClientRequest extends RaftClientMessage {
 
     this.message = b.message;
     this.type = b.type;
-    this.slidingWindowEntry = b.slidingWindowEntry != null ? b.slidingWindowEntry
-        : SlidingWindowEntry.getDefaultInstance();
+    this.repliedCallIds = Optional.ofNullable(b.repliedCallIds).orElseGet(Collections::emptyList);
+    this.slidingWindowEntry = b.slidingWindowEntry;
     this.routingTable = b.routingTable;
     this.timeoutMs = b.timeoutMs;
   }
@@ -401,6 +410,10 @@ public class RaftClientRequest extends RaftClientMessage {
     return toLeader;
   }
 
+  public Iterable<Long> getRepliedCallIds() {
+    return repliedCallIds;
+  }
+
   public SlidingWindowEntry getSlidingWindowEntry() {
     return slidingWindowEntry;
   }
diff --git a/ratis-common/src/main/java/org/apache/ratis/rpc/CallId.java b/ratis-common/src/main/java/org/apache/ratis/rpc/CallId.java
index 85e6ef06b..abc24cc09 100644
--- a/ratis-common/src/main/java/org/apache/ratis/rpc/CallId.java
+++ b/ratis-common/src/main/java/org/apache/ratis/rpc/CallId.java
@@ -22,11 +22,11 @@ import java.util.concurrent.atomic.AtomicLong;
 
 /**
  * A long ID for RPC calls.
- *
+ * <p>
  * This class is threadsafe.
  */
 public final class CallId {
-  private static final AtomicLong CALL_ID_COUNTER = new AtomicLong();
+  private static final AtomicLong CALL_ID_COUNTER = new AtomicLong(1);
 
   private static final Comparator<Long> COMPARATOR = (left, right) -> {
     final long diff = left - right;
diff --git a/ratis-proto/src/main/proto/Raft.proto b/ratis-proto/src/main/proto/Raft.proto
index 49a107c45..d8a1b626a 100644
--- a/ratis-proto/src/main/proto/Raft.proto
+++ b/ratis-proto/src/main/proto/Raft.proto
@@ -117,6 +117,7 @@ message RaftRpcRequestProto {
   uint64 callId = 4;
   bool toLeader = 5;
 
+  repeated uint64 repliedCallIds = 12; // The call ids of the replied requests
   uint64 timeoutMs = 13;
   RoutingTableProto routingTable = 14;
   SlidingWindowEntry slidingWindowEntry = 15;
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java
index 667e611b4..8005be894 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java
@@ -900,6 +900,8 @@ class RaftServerImpl implements RaftServer.Division,
   }
 
   private CompletableFuture<RaftClientReply> replyFuture(RaftClientRequest request) throws IOException {
+    retryCache.invalidateRepliedRequests(request);
+
     final TypeCase type = request.getType().getTypeCase();
     switch (type) {
       case STALEREAD:
@@ -925,7 +927,7 @@ class RaftServerImpl implements RaftServer.Division,
     }
 
     // query the retry cache
-    final RetryCacheImpl.CacheQueryResult queryResult = retryCache.queryCache(ClientInvocationId.valueOf(request));
+    final RetryCacheImpl.CacheQueryResult queryResult = retryCache.queryCache(request);
     final CacheEntry cacheEntry = queryResult.getEntry();
     if (queryResult.isRetry()) {
       // return the cached future.
@@ -1784,7 +1786,7 @@ class RaftServerImpl implements RaftServer.Division,
       ClientInvocationId invocationId, long logIndex, CompletableFuture<Message> stateMachineFuture) {
     // update the retry cache
     final CacheEntry cacheEntry = retryCache.getOrCreateEntry(invocationId);
-    Preconditions.assertTrue(cacheEntry != null);
+    Objects.requireNonNull(cacheEntry , "cacheEntry == null");
     if (getInfo().isLeader() && cacheEntry.isCompletedNormally()) {
       LOG.warn("{} retry cache entry of leader should be pending: {}", this, cacheEntry);
     }
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/RetryCacheImpl.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/RetryCacheImpl.java
index 438315ed7..a8bac4e5e 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/impl/RetryCacheImpl.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/RetryCacheImpl.java
@@ -18,14 +18,18 @@
 package org.apache.ratis.server.impl;
 
 import org.apache.ratis.conf.RaftProperties;
+import org.apache.ratis.protocol.ClientId;
 import org.apache.ratis.protocol.ClientInvocationId;
 import org.apache.ratis.protocol.RaftClientReply;
+import org.apache.ratis.protocol.RaftClientRequest;
 import org.apache.ratis.server.RaftServerConfigKeys;
 import org.apache.ratis.server.RetryCache;
 import org.apache.ratis.thirdparty.com.google.common.cache.Cache;
 import org.apache.ratis.thirdparty.com.google.common.cache.CacheBuilder;
 import org.apache.ratis.thirdparty.com.google.common.cache.CacheStats;
+import org.apache.ratis.util.CollectionUtils;
 import org.apache.ratis.util.JavaUtils;
+import org.apache.ratis.util.MemoizedSupplier;
 import org.apache.ratis.util.TimeDuration;
 import org.apache.ratis.util.Timestamp;
 
@@ -33,6 +37,7 @@ import java.util.Optional;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Supplier;
 
 class RetryCacheImpl implements RetryCache {
   static class CacheEntry implements Entry {
@@ -181,13 +186,15 @@ class RetryCacheImpl implements RetryCache {
   }
 
   CacheEntry getOrCreateEntry(ClientInvocationId key) {
-    final CacheEntry entry;
+    return getOrCreateEntry(key, () -> new CacheEntry(key));
+  }
+
+  private CacheEntry getOrCreateEntry(ClientInvocationId key, Supplier<CacheEntry> constructor) {
     try {
-      entry = cache.get(key, () -> new CacheEntry(key));
+      return cache.get(key, constructor::get);
     } catch (ExecutionException e) {
-      throw new IllegalStateException(e);
+      throw new IllegalStateException("Failed to get " + key, e);
     }
-    return entry;
   }
 
   CacheEntry refreshEntry(CacheEntry newEntry) {
@@ -195,16 +202,11 @@ class RetryCacheImpl implements RetryCache {
     return newEntry;
   }
 
-  CacheQueryResult queryCache(ClientInvocationId key) {
-    final CacheEntry newEntry = new CacheEntry(key);
-    final CacheEntry cacheEntry;
-    try {
-      cacheEntry = cache.get(key, () -> newEntry);
-    } catch (ExecutionException e) {
-      throw new IllegalStateException(e);
-    }
-
-    if (cacheEntry == newEntry) {
+  CacheQueryResult queryCache(RaftClientRequest request) {
+    final ClientInvocationId key = ClientInvocationId.valueOf(request);
+    final MemoizedSupplier<CacheEntry> newEntry = MemoizedSupplier.valueOf(() -> new CacheEntry(key));
+    final CacheEntry cacheEntry = getOrCreateEntry(key, newEntry);
+    if (newEntry.isInitialized()) {
       // this is the entry we just newly created
       return new CacheQueryResult(cacheEntry, false);
     } else if (!cacheEntry.isDone() || !cacheEntry.isFailed()){
@@ -221,13 +223,24 @@ class RetryCacheImpl implements RetryCache {
       if (currentEntry == cacheEntry || currentEntry == null) {
         // if the failed entry has not got replaced by another retry, or the
         // failed entry got invalidated, we add a new cache entry
-        return new CacheQueryResult(refreshEntry(newEntry), false);
+        return new CacheQueryResult(refreshEntry(newEntry.get()), false);
       } else {
         return new CacheQueryResult(currentEntry, true);
       }
     }
   }
 
+  void invalidateRepliedRequests(RaftClientRequest request) {
+    final ClientId clientId = request.getClientId();
+    final Iterable<Long> callIds = request.getRepliedCallIds();
+    if (!callIds.iterator().hasNext()) {
+      return;
+    }
+
+    LOG.debug("invalidateRepliedRequests callIds {} for {}", callIds, clientId);
+    cache.invalidateAll(CollectionUtils.as(callIds, callId -> ClientInvocationId.valueOf(clientId, callId)));
+  }
+
   @Override
   public Statistics getStatistics() {
     return statistics.updateAndGet(old -> old == null || old.isExpired()? new StatisticsImpl(cache): old);
@@ -240,10 +253,8 @@ class RetryCacheImpl implements RetryCache {
 
   @Override
   public synchronized void close() {
-    if (cache != null) {
-      cache.invalidateAll();
-      statistics.set(null);
-    }
+    cache.invalidateAll();
+    statistics.set(null);
   }
 
   static CompletableFuture<RaftClientReply> failWithReply(
diff --git a/ratis-server/src/test/java/org/apache/ratis/RetryCacheTests.java b/ratis-server/src/test/java/org/apache/ratis/RetryCacheTests.java
index f729dcd2d..288aa71a9 100644
--- a/ratis-server/src/test/java/org/apache/ratis/RetryCacheTests.java
+++ b/ratis-server/src/test/java/org/apache/ratis/RetryCacheTests.java
@@ -82,11 +82,10 @@ public abstract class RetryCacheTests<CLUSTER extends MiniRaftCluster>
     }
   }
 
-  public static RaftClient assertReply(RaftClientReply reply, RaftClient client, long callId) {
+  public static void assertReply(RaftClientReply reply, RaftClient client, long callId) {
     Assert.assertEquals(client.getId(), reply.getClientId());
     Assert.assertEquals(callId, reply.getCallId());
     Assert.assertTrue(reply.isSuccess());
-    return client;
   }
 
   public void assertServer(MiniRaftCluster cluster, ClientId clientId, long callId, long oldLastApplied) throws Exception {
diff --git a/ratis-client/src/main/java/org/apache/ratis/client/impl/RaftClientTestUtil.java b/ratis-server/src/test/java/org/apache/ratis/client/impl/RaftClientTestUtil.java
similarity index 90%
rename from ratis-client/src/main/java/org/apache/ratis/client/impl/RaftClientTestUtil.java
rename to ratis-server/src/test/java/org/apache/ratis/client/impl/RaftClientTestUtil.java
index ba00b8f00..d90b0cc53 100644
--- a/ratis-client/src/main/java/org/apache/ratis/client/impl/RaftClientTestUtil.java
+++ b/ratis-server/src/test/java/org/apache/ratis/client/impl/RaftClientTestUtil.java
@@ -21,6 +21,7 @@ import org.apache.ratis.client.RaftClient;
 import org.apache.ratis.proto.RaftProtos.SlidingWindowEntry;
 import org.apache.ratis.protocol.ClientInvocationId;
 import org.apache.ratis.protocol.Message;
+import org.apache.ratis.protocol.RaftClientReply;
 import org.apache.ratis.protocol.RaftClientRequest;
 import org.apache.ratis.protocol.RaftPeerId;
 import org.apache.ratis.rpc.CallId;
@@ -39,4 +40,8 @@ public interface RaftClientTestUtil {
       long callId, Message message, RaftClientRequest.Type type, SlidingWindowEntry slidingWindowEntry) {
     return ((RaftClientImpl)client).newRaftClientRequest(server, callId, message, type, slidingWindowEntry);
   }
+
+  static void handleReply(RaftClientRequest request, RaftClientReply reply, RaftClient client) {
+    ((RaftClientImpl)client).handleReply(request, reply);
+  }
 }
diff --git a/ratis-server/src/test/java/org/apache/ratis/statemachine/impl/SimpleStateMachine4Testing.java b/ratis-server/src/test/java/org/apache/ratis/statemachine/impl/SimpleStateMachine4Testing.java
index 2ce5643cc..312c9508d 100644
--- a/ratis-server/src/test/java/org/apache/ratis/statemachine/impl/SimpleStateMachine4Testing.java
+++ b/ratis-server/src/test/java/org/apache/ratis/statemachine/impl/SimpleStateMachine4Testing.java
@@ -126,7 +126,7 @@ public class SimpleStateMachine4Testing extends BaseStateMachine {
 
   static class Blocking {
     enum Type {
-      START_TRANSACTION, READ_STATE_MACHINE_DATA, WRITE_STATE_MACHINE_DATA, FLUSH_STATE_MACHINE_DATA
+      START_TRANSACTION, APPLY_TRANSACTION, READ_STATE_MACHINE_DATA, WRITE_STATE_MACHINE_DATA, FLUSH_STATE_MACHINE_DATA
     }
 
     private final EnumMap<Type, CompletableFuture<Void>> maps = new EnumMap<>(Type.class);
@@ -243,7 +243,10 @@ public class SimpleStateMachine4Testing extends BaseStateMachine {
 
   @Override
   public CompletableFuture<Message> applyTransaction(TransactionContext trx) {
+    blocking.await(Blocking.Type.APPLY_TRANSACTION);
     LogEntryProto entry = Objects.requireNonNull(trx.getLogEntry());
+    LOG.info("applyTransaction for log index {}", entry.getIndex());
+
     put(entry);
     updateLastAppliedTermIndex(entry.getTerm(), entry.getIndex());
 
@@ -386,6 +389,13 @@ public class SimpleStateMachine4Testing extends BaseStateMachine {
     blocking.unblock(Blocking.Type.START_TRANSACTION);
   }
 
+  public void blockApplyTransaction() {
+    blocking.block(Blocking.Type.APPLY_TRANSACTION);
+  }
+  public void unblockApplyTransaction() {
+    blocking.unblock(Blocking.Type.APPLY_TRANSACTION);
+  }
+
   public void blockWriteStateMachineData() {
     blocking.block(Blocking.Type.WRITE_STATE_MACHINE_DATA);
   }
diff --git a/ratis-test/src/test/java/org/apache/ratis/grpc/TestRaftServerWithGrpc.java b/ratis-test/src/test/java/org/apache/ratis/grpc/TestRaftServerWithGrpc.java
index 7de1c4042..0af1d87cc 100644
--- a/ratis-test/src/test/java/org/apache/ratis/grpc/TestRaftServerWithGrpc.java
+++ b/ratis-test/src/test/java/org/apache/ratis/grpc/TestRaftServerWithGrpc.java
@@ -23,6 +23,8 @@ import static org.apache.ratis.server.metrics.RaftServerMetricsImpl.RAFT_CLIENT_
 import static org.apache.ratis.server.metrics.RaftServerMetricsImpl.RAFT_CLIENT_WRITE_REQUEST;
 
 import org.apache.ratis.metrics.RatisMetricRegistry;
+import org.apache.ratis.protocol.ClientInvocationId;
+import org.apache.ratis.server.RetryCache;
 import org.apache.ratis.util.JavaUtils;
 import org.slf4j.event.Level;
 import org.apache.ratis.conf.Parameters;
@@ -183,19 +185,31 @@ public class TestRaftServerWithGrpc extends BaseTest implements MiniRaftClusterW
       final RaftClientRpc rpc = client.getClientRpc();
 
       final AtomicLong seqNum = new AtomicLong();
+      final ClientInvocationId invocationId;
       {
         // send a request using rpc directly
-        final RaftClientRequest request = newRaftClientRequest(client, leader.getId(), seqNum.incrementAndGet());
+        final RaftClientRequest request = newRaftClientRequest(client, seqNum.incrementAndGet());
+        Assert.assertEquals(client.getId(), request.getClientId());
         final CompletableFuture<RaftClientReply> f = rpc.sendRequestAsync(request);
-        Assert.assertTrue(f.get().isSuccess());
+        final RaftClientReply reply = f.get();
+        Assert.assertTrue(reply.isSuccess());
+        RaftClientTestUtil.handleReply(request, reply, client);
+        invocationId = ClientInvocationId.valueOf(request.getClientId(), request.getCallId());
+        final RetryCache.Entry entry = leader.getRetryCache().getIfPresent(invocationId);
+        Assert.assertNotNull(entry);
+        LOG.info("cache entry {}", entry);
       }
 
       // send another request which will be blocked
       final SimpleStateMachine4Testing stateMachine = SimpleStateMachine4Testing.get(leader);
       stateMachine.blockStartTransaction();
-      final RaftClientRequest requestBlocked = newRaftClientRequest(client, leader.getId(), seqNum.incrementAndGet());
+      final RaftClientRequest requestBlocked = newRaftClientRequest(client, seqNum.incrementAndGet());
       final CompletableFuture<RaftClientReply> futureBlocked = rpc.sendRequestAsync(requestBlocked);
 
+      JavaUtils.attempt(() -> Assert.assertNull(leader.getRetryCache().getIfPresent(invocationId)),
+          10, HUNDRED_MILLIS, "invalidate cache entry", LOG);
+      LOG.info("cache entry not found for {}", invocationId);
+
       // change leader
       RaftTestUtil.changeLeader(cluster, leader.getId());
       Assert.assertNotEquals(RaftPeerRole.LEADER, 
leader.getInfo().getCurrentRole());
@@ -206,7 +220,7 @@ public class TestRaftServerWithGrpc extends BaseTest implements MiniRaftClusterW
       stateMachine.unblockStartTransaction();
 
       // send one more request which should timeout.
-      final RaftClientRequest requestTimeout = newRaftClientRequest(client, leader.getId(), seqNum.incrementAndGet());
+      final RaftClientRequest requestTimeout = newRaftClientRequest(client, seqNum.incrementAndGet());
       rpc.handleException(leader.getId(), new Exception(), true);
       final CompletableFuture<RaftClientReply> f = rpc.sendRequestAsync(requestTimeout);
       testFailureCase("request should timeout", f::get,
@@ -346,9 +360,9 @@ public class TestRaftServerWithGrpc extends BaseTest implements MiniRaftClusterW
     }
   }
 
-  static RaftClientRequest newRaftClientRequest(RaftClient client, RaftPeerId serverId, long seqNum) {
+  static RaftClientRequest newRaftClientRequest(RaftClient client, long seqNum) {
     final SimpleMessage m = new SimpleMessage("m" + seqNum);
-    return RaftClientTestUtil.newRaftClientRequest(client, serverId, seqNum, m,
+    return RaftClientTestUtil.newRaftClientRequest(client, null, seqNum, m,
         RaftClientRequest.writeRequestType(), ProtoUtils.toSlidingWindowEntry(seqNum, seqNum == 1L));
   }
 
diff --git a/ratis-test/src/test/java/org/apache/ratis/grpc/TestRetryCacheWithGrpc.java b/ratis-test/src/test/java/org/apache/ratis/grpc/TestRetryCacheWithGrpc.java
index 400e6e5a6..8a1878cd8 100644
--- a/ratis-test/src/test/java/org/apache/ratis/grpc/TestRetryCacheWithGrpc.java
+++ b/ratis-test/src/test/java/org/apache/ratis/grpc/TestRetryCacheWithGrpc.java
@@ -17,8 +17,12 @@
  */
 package org.apache.ratis.grpc;
 
+import org.apache.ratis.client.RaftClient;
+import org.apache.ratis.proto.RaftProtos;
+import org.apache.ratis.server.RetryCache;
 import org.apache.ratis.server.impl.MiniRaftCluster;
 import org.apache.ratis.RaftTestUtil;
+import org.apache.ratis.RaftTestUtil.SimpleMessage;
 import org.apache.ratis.RetryCacheTests;
 import org.apache.ratis.conf.RaftProperties;
 import org.apache.ratis.protocol.ClientId;
@@ -30,15 +34,131 @@ import org.apache.ratis.server.RaftServerConfigKeys;
 import org.apache.ratis.server.impl.RetryCacheTestUtil;
 import org.apache.ratis.statemachine.impl.SimpleStateMachine4Testing;
 import org.apache.ratis.statemachine.StateMachine;
+import org.apache.ratis.util.Slf4jUtils;
+import org.junit.Assert;
 import org.junit.Test;
+import org.slf4j.event.Level;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
 
 public class TestRetryCacheWithGrpc
     extends RetryCacheTests<MiniRaftClusterWithGrpc>
     implements MiniRaftClusterWithGrpc.FactoryGet {
+  {
+    Slf4jUtils.setLogLevel(RetryCache.LOG, Level.TRACE);
+
+    getProperties().setClass(MiniRaftCluster.STATEMACHINE_CLASS_KEY,
+        SimpleStateMachine4Testing.class, StateMachine.class);
+  }
+
+  @Test
+  public void testInvalidateRepliedCalls() throws Exception {
+    runWithNewCluster(3, cluster -> new InvalidateRepliedCallsTest(cluster).run());
+  }
+
+  static long assertReply(RaftClientReply reply) {
+    Assert.assertTrue(reply.isSuccess());
+    return reply.getCallId();
+  }
+
+  class InvalidateRepliedCallsTest {
+    private final MiniRaftCluster cluster;
+    private final RaftServer.Division leader;
+    private final AtomicInteger count = new AtomicInteger();
+
+    InvalidateRepliedCallsTest(MiniRaftCluster cluster) throws Exception {
+      this.cluster = cluster;
+      this.leader = RaftTestUtil.waitForLeader(cluster);
+    }
+
+    SimpleMessage nextMessage() {
+      return new SimpleMessage("m" + count.incrementAndGet());
+    }
+
+    void assertRetryCacheEntry(RaftClient client, long callId, boolean exist) {
+      final RetryCache.Entry e = RetryCacheTestUtil.get(leader, client.getId(), callId);
+      if (exist) {
+        Assert.assertNotNull(e);
+      } else {
+        Assert.assertNull(e);
+      }
+    }
+
+    long send(RaftClient client, Long previousCallId) throws Exception {
+      final RaftClientReply reply = client.io().send(nextMessage());
+      final long callId = assertReply(reply);
+      if (previousCallId != null) {
+        // the previous should be invalidated.
+        assertRetryCacheEntry(client, previousCallId, false);
+      }
+      // the current should exist.
+      assertRetryCacheEntry(client, callId, true);
+      return callId;
+    }
+
+    CompletableFuture<Long> sendAsync(RaftClient client) {
+      return client.async().send(nextMessage())
+              .thenApply(TestRetryCacheWithGrpc::assertReply);
+    }
+
+    CompletableFuture<Long> watch(long logIndex, RaftClient client) {
+      return client.async().watch(logIndex, RaftProtos.ReplicationLevel.MAJORITY)
+          .thenApply(TestRetryCacheWithGrpc::assertReply);
+    }
+
+    void run() throws Exception {
+      try (RaftClient client = cluster.createClient()) {
+        // test blocking io
+        Long lastBlockingCall = null;
+        for (int i = 0; i < 5; i++) {
+          lastBlockingCall = send(client, lastBlockingCall);
+        }
+        final long lastBlockingCallId = lastBlockingCall;
+
+        // test async
+        final SimpleStateMachine4Testing stateMachine = SimpleStateMachine4Testing.get(leader);
+        stateMachine.blockApplyTransaction();
+        final List<CompletableFuture<Long>> asyncCalls = new ArrayList<>();
+        for (int i = 0; i < 5; i++) {
+          // Since applyTransaction is blocked, the replied call id remains the same.
+          asyncCalls.add(sendAsync(client));
+        }
+        // async call will invalidate blocking calls even if applyTransaction is blocked.
+        assertRetryCacheEntry(client, lastBlockingCallId, false);
+
+        ONE_SECOND.sleep();
+        // No calls can be completed.
+        for (CompletableFuture<Long> f : asyncCalls) {
+          Assert.assertFalse(f.isDone());
+        }
+        stateMachine.unblockApplyTransaction();
+        // No calls can be invalidated.
+        for (CompletableFuture<Long> f : asyncCalls) {
+          assertRetryCacheEntry(client, f.join(), true);
+        }
+
+        // one more blocking call will invalidate all async calls
+        final long oneMoreBlockingCall = send(client, null);
+        LOG.info("oneMoreBlockingCall callId={}", oneMoreBlockingCall);
+        assertRetryCacheEntry(client, oneMoreBlockingCall, true);
+        for (CompletableFuture<Long> f : asyncCalls) {
+          assertRetryCacheEntry(client, f.join(), false);
+        }
+
+        // watch call will invalidate blocking calls
+        final long watchAsyncCall = watch(1, client).get();
+        LOG.info("watchAsyncCall callId={}", watchAsyncCall);
+        assertRetryCacheEntry(client, oneMoreBlockingCall, false);
+        // retry cache should not contain watch calls
+        assertRetryCacheEntry(client, watchAsyncCall, false);
+      }
+    }
+  }
 
   @Test(timeout = 10000)
   public void testRetryOnResourceUnavailableException()

Reply via email to