This is an automated email from the ASF dual-hosted git repository.

szetszwo pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ratis.git


The following commit(s) were added to refs/heads/master by this push:
     new 903995983 RATIS-1550. Rewrite stream client reply queue. (#740)
903995983 is described below

commit 90399598396af414c0de8252231fc171debb6bb2
Author: hao guo <[email protected]>
AuthorDate: Thu Aug 3 05:22:38 2023 +0800

    RATIS-1550. Rewrite stream client reply queue. (#740)
---
 .../ratis/client/impl/OrderedStreamAsync.java      |  21 ---
 .../ratis/netty/client/NettyClientReplies.java     | 179 +++++++++++++++++++++
 .../ratis/netty/client/NettyClientStreamRpc.java   | 144 +++++++----------
 3 files changed, 236 insertions(+), 108 deletions(-)

diff --git a/ratis-client/src/main/java/org/apache/ratis/client/impl/OrderedStreamAsync.java b/ratis-client/src/main/java/org/apache/ratis/client/impl/OrderedStreamAsync.java
index fe51359b2..3847adf03 100644
--- a/ratis-client/src/main/java/org/apache/ratis/client/impl/OrderedStreamAsync.java
+++ b/ratis-client/src/main/java/org/apache/ratis/client/impl/OrderedStreamAsync.java
@@ -24,16 +24,12 @@ import org.apache.ratis.datastream.impl.DataStreamPacketByteBuffer;
 import org.apache.ratis.datastream.impl.DataStreamRequestByteBuffer;
 import org.apache.ratis.datastream.impl.DataStreamRequestFilePositionCount;
 import org.apache.ratis.io.FilePositionCount;
-import org.apache.ratis.io.StandardWriteOption;
 import org.apache.ratis.protocol.DataStreamReply;
 import org.apache.ratis.protocol.DataStreamRequest;
 import org.apache.ratis.protocol.DataStreamRequestHeader;
-import org.apache.ratis.protocol.exceptions.TimeoutIOException;
 import org.apache.ratis.util.IOUtils;
 import org.apache.ratis.util.JavaUtils;
 import org.apache.ratis.util.SlidingWindow;
-import org.apache.ratis.util.TimeDuration;
-import org.apache.ratis.util.TimeoutExecutor;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -105,15 +101,10 @@ public class OrderedStreamAsync {
   private final DataStreamClientRpc dataStreamClientRpc;
 
   private final Semaphore requestSemaphore;
-  private final TimeDuration requestTimeout;
-  private final TimeDuration closeTimeout;
-  private final TimeoutExecutor scheduler = TimeoutExecutor.getInstance();
 
   OrderedStreamAsync(DataStreamClientRpc dataStreamClientRpc, RaftProperties 
properties){
     this.dataStreamClientRpc = dataStreamClientRpc;
     this.requestSemaphore = new 
Semaphore(RaftClientConfigKeys.DataStream.outstandingRequestsMax(properties));
-    this.requestTimeout = 
RaftClientConfigKeys.DataStream.requestTimeout(properties);
-    this.closeTimeout = requestTimeout.multiply(2);
   }
 
   CompletableFuture<DataStreamReply> sendRequest(DataStreamRequestHeader 
header, Object data,
@@ -149,9 +140,6 @@ public class OrderedStreamAsync {
         request.getDataStreamRequest());
     long seqNum = request.getSeqNum();
 
-    final boolean isClose = 
request.getDataStreamRequest().getWriteOptionList().contains(StandardWriteOption.CLOSE);
-    scheduleWithTimeout(request, isClose? closeTimeout: requestTimeout);
-
     requestFuture.thenApply(reply -> {
       slidingWindow.receiveReply(
           seqNum, reply, r -> sendRequestToNetwork(r, slidingWindow));
@@ -166,13 +154,4 @@ public class OrderedStreamAsync {
       return null;
     });
   }
-
-  private void scheduleWithTimeout(DataStreamWindowRequest request, 
TimeDuration timeout) {
-    scheduler.onTimeout(timeout, () -> {
-      if (!request.getReplyFuture().isDone()) {
-        request.getReplyFuture().completeExceptionally(
-            new TimeoutIOException("Timeout " + timeout + ": Failed to send " 
+ request));
-      }
-    }, LOG, () -> "Failed to completeExceptionally for " + request);
-  }
 }
diff --git a/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientReplies.java b/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientReplies.java
new file mode 100644
index 000000000..fc97b6fe3
--- /dev/null
+++ b/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientReplies.java
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ratis.netty.client;
+
+import org.apache.ratis.proto.RaftProtos.DataStreamPacketHeaderProto.Type;
+import org.apache.ratis.protocol.ClientInvocationId;
+import org.apache.ratis.protocol.DataStreamPacket;
+import org.apache.ratis.protocol.DataStreamReply;
+import org.apache.ratis.thirdparty.io.netty.util.concurrent.ScheduledFuture;
+import org.apache.ratis.util.MemoizedSupplier;
+import org.apache.ratis.util.Preconditions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.atomic.AtomicReference;
+
+public class NettyClientReplies {
+  public static final Logger LOG = 
LoggerFactory.getLogger(NettyClientReplies.class);
+
+  private final ConcurrentMap<ClientInvocationId, ReplyMap> replies = new 
ConcurrentHashMap<>();
+
+  ReplyMap getReplyMap(ClientInvocationId clientInvocationId) {
+    final MemoizedSupplier<ReplyMap> q = MemoizedSupplier.valueOf(() -> new 
ReplyMap(clientInvocationId));
+    return replies.computeIfAbsent(clientInvocationId, key -> q.get());
+  }
+
+  class ReplyMap {
+    private final ClientInvocationId clientInvocationId;
+    private final Map<RequestEntry, ReplyEntry> map = new 
ConcurrentHashMap<>();
+
+    ReplyMap(ClientInvocationId clientInvocationId) {
+      this.clientInvocationId = clientInvocationId;
+    }
+
+    ReplyEntry submitRequest(RequestEntry requestEntry, boolean isClose, 
CompletableFuture<DataStreamReply> f) {
+      LOG.debug("put {} to the map for {}", requestEntry, clientInvocationId);
+      final MemoizedSupplier<ReplyEntry> replySupplier = 
MemoizedSupplier.valueOf(() -> new ReplyEntry(isClose, f));
+      return map.computeIfAbsent(requestEntry, r -> replySupplier.get());
+    }
+
+    void receiveReply(DataStreamReply reply) {
+      final RequestEntry requestEntry = new RequestEntry(reply);
+      final ReplyEntry replyEntry = map.remove(requestEntry);
+      LOG.debug("remove: {}; replyEntry: {}; reply: {}", requestEntry, 
replyEntry, reply);
+      if (replyEntry == null) {
+        LOG.debug("Request not found: {}", this);
+        return;
+      }
+      replyEntry.complete(reply);
+      if (!reply.isSuccess()) {
+        failAll("a request failed with " + reply);
+      } else if (replyEntry.isClosed()) {  // stream closed clean up reply map
+        removeThisMap();
+      }
+    }
+
+    private void removeThisMap() {
+      final ReplyMap removed = replies.remove(clientInvocationId);
+      Preconditions.assertSame(removed, this, "removed");
+    }
+
+    void completeExceptionally(Throwable e) {
+      removeThisMap();
+      for (ReplyEntry entry : map.values()) {
+        entry.completeExceptionally(e);
+      }
+      map.clear();
+    }
+
+    private void failAll(String message) {
+      completeExceptionally(new IllegalStateException(this + ": " + message));
+    }
+
+    void fail(RequestEntry requestEntry) {
+      map.remove(requestEntry);
+      failAll(requestEntry + " failed ");
+    }
+
+    @Override
+    public String toString() {
+      final StringBuilder builder = new StringBuilder();
+      for (RequestEntry requestEntry : map.keySet()) {
+        builder.append(requestEntry).append(", ");
+      }
+      return builder.toString();
+    }
+  }
+
+  static class RequestEntry {
+    private final long streamOffset;
+    private final Type type;
+
+    RequestEntry(DataStreamPacket packet) {
+      this.streamOffset = packet.getStreamOffset();
+      this.type = packet.getType();
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+      final RequestEntry that = (RequestEntry) o;
+      return streamOffset == that.streamOffset
+          && type == that.type;
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(type, streamOffset);
+    }
+
+    @Override
+    public String toString() {
+      return "Request{" +
+          "streamOffset=" + streamOffset +
+          ", type=" + type +
+          '}';
+    }
+  }
+
+  static class ReplyEntry {
+    private final boolean isClosed;
+    private final CompletableFuture<DataStreamReply> replyFuture;
+    private final AtomicReference<ScheduledFuture<?>> timeoutFuture = new 
AtomicReference<>();
+
+    ReplyEntry(boolean isClosed, CompletableFuture<DataStreamReply> 
replyFuture) {
+      this.isClosed = isClosed;
+      this.replyFuture = replyFuture;
+    }
+
+    boolean isClosed() {
+      return isClosed;
+    }
+
+    void complete(DataStreamReply reply) {
+      cancelTimeoutFuture();
+      replyFuture.complete(reply);
+    }
+
+    void completeExceptionally(Throwable t) {
+      cancelTimeoutFuture();
+      replyFuture.completeExceptionally(t);
+    }
+
+    private void cancelTimeoutFuture() {
+      Optional.ofNullable(timeoutFuture.get()).ifPresent(f -> f.cancel(false));
+    }
+
+    void setTimeoutFuture(ScheduledFuture<?> timeoutFuture) {
+      this.timeoutFuture.compareAndSet(null, timeoutFuture);
+    }
+  }
+}
diff --git a/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientStreamRpc.java b/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientStreamRpc.java
index 51326d13e..f815bcffe 100644
--- a/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientStreamRpc.java
+++ b/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientStreamRpc.java
@@ -33,6 +33,7 @@ import org.apache.ratis.protocol.DataStreamReply;
 import org.apache.ratis.protocol.DataStreamRequest;
 import org.apache.ratis.protocol.RaftPeer;
 import org.apache.ratis.protocol.exceptions.AlreadyClosedException;
+import org.apache.ratis.protocol.exceptions.TimeoutIOException;
 import org.apache.ratis.security.TlsConf;
 import org.apache.ratis.thirdparty.io.netty.bootstrap.Bootstrap;
 import org.apache.ratis.thirdparty.io.netty.buffer.ByteBuf;
@@ -51,26 +52,21 @@ import org.apache.ratis.thirdparty.io.netty.channel.socket.nio.NioSocketChannel;
 import org.apache.ratis.thirdparty.io.netty.handler.codec.ByteToMessageDecoder;
 import 
org.apache.ratis.thirdparty.io.netty.handler.codec.MessageToMessageEncoder;
 import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContext;
+import org.apache.ratis.thirdparty.io.netty.util.concurrent.ScheduledFuture;
 import org.apache.ratis.util.JavaUtils;
 import org.apache.ratis.util.MemoizedSupplier;
 import org.apache.ratis.util.NetUtils;
 import org.apache.ratis.util.SizeInBytes;
 import org.apache.ratis.util.TimeDuration;
-import org.apache.ratis.util.TimeoutExecutor;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
 import java.net.InetSocketAddress;
 import java.nio.ByteBuffer;
-import java.util.Iterator;
 import java.util.List;
 import java.util.Optional;
-import java.util.Queue;
 import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.ConcurrentLinkedQueue;
-import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Function;
@@ -115,39 +111,6 @@ public class NettyClientStreamRpc implements DataStreamClientRpc {
     }
   }
 
-  static class ReplyQueue implements 
Iterable<CompletableFuture<DataStreamReply>> {
-    static final ReplyQueue EMPTY = new ReplyQueue();
-
-    private final Queue<CompletableFuture<DataStreamReply>> queue = new 
ConcurrentLinkedQueue<>();
-    private int emptyId;
-
-    /** @return an empty ID if the queue is empty; otherwise, the queue is 
non-empty, return null. */
-    synchronized Integer getEmptyId() {
-      return queue.isEmpty()? emptyId: null;
-    }
-
-    synchronized boolean offer(CompletableFuture<DataStreamReply> f) {
-      if (queue.offer(f)) {
-        emptyId++;
-        return true;
-      }
-      return false;
-    }
-
-    CompletableFuture<DataStreamReply> poll() {
-      return queue.poll();
-    }
-
-    int size() {
-      return queue.size();
-    }
-
-    @Override
-    public Iterator<CompletableFuture<DataStreamReply>> iterator() {
-      return queue.iterator();
-    }
-  }
-
   static class Connection {
     static final TimeDuration RECONNECT = TimeDuration.valueOf(100, 
TimeUnit.MILLISECONDS);
 
@@ -275,17 +238,19 @@ public class NettyClientStreamRpc implements DataStreamClientRpc {
   private final String name;
   private final Connection connection;
 
+  private final NettyClientReplies replies = new NettyClientReplies();
+  private final TimeDuration requestTimeout;
+  private final TimeDuration closeTimeout;
+
   private final int flushRequestCountMin;
   private final SizeInBytes flushRequestBytesMin;
   private final OutstandingRequests outstandingRequests = new 
OutstandingRequests();
 
-  private final ConcurrentMap<ClientInvocationId, ReplyQueue> replies = new 
ConcurrentHashMap<>();
-  private final TimeDuration replyQueueGracePeriod;
-  private final TimeoutExecutor timeoutScheduler = 
TimeoutExecutor.getInstance();
-
   public NettyClientStreamRpc(RaftPeer server, TlsConf tlsConf, RaftProperties 
properties) {
     this.name = JavaUtils.getClassSimpleName(getClass()) + "->" + server;
-    this.replyQueueGracePeriod = 
NettyConfigKeys.DataStream.Client.replyQueueGracePeriod(properties);
+    this.requestTimeout = 
RaftClientConfigKeys.DataStream.requestTimeout(properties);
+    this.closeTimeout = requestTimeout.multiply(2);
+
     this.flushRequestCountMin = 
RaftClientConfigKeys.DataStream.flushRequestCountMin(properties);
     this.flushRequestBytesMin = 
RaftClientConfigKeys.DataStream.flushRequestBytesMin(properties);
 
@@ -299,8 +264,6 @@ public class NettyClientStreamRpc implements DataStreamClientRpc {
   private ChannelInboundHandler getClientHandler(){
     return new ChannelInboundHandlerAdapter(){
 
-      private ClientInvocationId clientInvocationId;
-
       @Override
       public void channelRead(ChannelHandlerContext ctx, Object msg) {
         if (!(msg instanceof DataStreamReply)) {
@@ -309,29 +272,19 @@ public class NettyClientStreamRpc implements DataStreamClientRpc {
         }
         final DataStreamReply reply = (DataStreamReply) msg;
         LOG.debug("{}: read {}", this, reply);
-        clientInvocationId = ClientInvocationId.valueOf(reply.getClientId(), 
reply.getStreamId());
-        final ReplyQueue queue = reply.isSuccess() ? 
replies.get(clientInvocationId) :
-                replies.remove(clientInvocationId);
-        if (queue != null) {
-          final CompletableFuture<DataStreamReply> f = queue.poll();
-          if (f != null) {
-            f.complete(reply);
-
-            if (!reply.isSuccess() && queue.size() > 0) {
-              final IllegalStateException e = new IllegalStateException(
-                  this + ": an earlier request failed with " + reply);
-              queue.forEach(future -> future.completeExceptionally(e));
-            }
+        final ClientInvocationId clientInvocationId = 
ClientInvocationId.valueOf(
+            reply.getClientId(), reply.getStreamId());
+        final NettyClientReplies.ReplyMap replyMap = 
replies.getReplyMap(clientInvocationId);
+        if (replyMap == null) {
+          LOG.error("{}: {} replyMap not found for reply: {}", this, 
clientInvocationId, reply);
+          return;
+        }
 
-            final Integer emptyId = queue.getEmptyId();
-            if (emptyId != null) {
-              timeoutScheduler.onTimeout(replyQueueGracePeriod,
-                  // remove the queue if the same queue has been empty for the 
entire grace period.
-                  () -> replies.computeIfPresent(clientInvocationId,
-                      (key, q) -> q == queue && 
emptyId.equals(q.getEmptyId())? null: q),
-                  LOG, () -> "Timeout check failed, clientInvocationId=" + 
clientInvocationId);
-            }
-          }
+        try {
+          replyMap.receiveReply(reply);
+        } catch (Throwable cause) {
+          LOG.warn(name + ": channelRead error:", cause);
+          replyMap.completeExceptionally(cause);
         }
       }
 
@@ -339,10 +292,6 @@ public class NettyClientStreamRpc implements DataStreamClientRpc {
       public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
         LOG.warn(name + ": exceptionCaught", cause);
 
-        Optional.ofNullable(clientInvocationId)
-            .map(replies::remove)
-            .orElse(ReplyQueue.EMPTY)
-            .forEach(f -> f.completeExceptionally(cause));
         ctx.close();
       }
 
@@ -417,24 +366,45 @@ public class NettyClientStreamRpc implements DataStreamClientRpc {
   public CompletableFuture<DataStreamReply> streamAsync(DataStreamRequest 
request) {
     final CompletableFuture<DataStreamReply> f = new CompletableFuture<>();
     ClientInvocationId clientInvocationId = 
ClientInvocationId.valueOf(request.getClientId(), request.getStreamId());
-    final ReplyQueue q = replies.computeIfAbsent(clientInvocationId, key -> 
new ReplyQueue());
-    if (!q.offer(f)) {
-      f.completeExceptionally(new IllegalStateException(this + ": Failed to 
offer a future for " + request));
-      return f;
-    }
-    final Channel channel = connection.getChannelUninterruptibly();
-    if (channel == null) {
-      f.completeExceptionally(new AlreadyClosedException(this + ": Failed to 
send " + request));
-      return f;
+    final boolean isClose = 
request.getWriteOptionList().contains(StandardWriteOption.CLOSE);
+
+    final NettyClientReplies.ReplyMap replyMap = 
replies.getReplyMap(clientInvocationId);
+    final ChannelFuture channelFuture;
+    final Channel channel;
+    final NettyClientReplies.RequestEntry requestEntry = new 
NettyClientReplies.RequestEntry(request);
+    final NettyClientReplies.ReplyEntry replyEntry;
+    LOG.debug("{}: write begin {}", this, request);
+    synchronized (replyMap) {
+      channel = connection.getChannelUninterruptibly();
+      if (channel == null) {
+        f.completeExceptionally(new AlreadyClosedException(this + ": Failed to 
send " + request));
+        return f;
+      }
+      replyEntry = replyMap.submitRequest(requestEntry, isClose, f);
+      final Function<DataStreamRequest, ChannelFuture> writeMethod = 
outstandingRequests.write(request)?
+          channel::writeAndFlush: channel::write;
+      channelFuture = writeMethod.apply(request);
     }
-    LOG.debug("{}: write {}", this, request);
-    final Function<DataStreamRequest, ChannelFuture> writeMethod = 
outstandingRequests.write(request)?
-        channel::writeAndFlush: channel::write;
-    writeMethod.apply(request).addListener(future -> {
+    channelFuture.addListener(future -> {
       if (!future.isSuccess()) {
-        final IOException e = new IOException(this + ": Failed to send " + 
request, future.cause());
-        LOG.error("Channel write failed", e);
+        final IOException e = new IOException(this + ": Failed to send " + 
request + " to " + channel.remoteAddress(),
+            future.cause());
         f.completeExceptionally(e);
+        replyMap.fail(requestEntry);
+        LOG.error("Channel write failed", e);
+      } else {
+        LOG.debug("{}: write after {}", this, request);
+
+        final TimeDuration timeout = isClose ? closeTimeout : requestTimeout;
+        // if reply success cancel this future
+        final ScheduledFuture<?> timeoutFuture = 
channel.eventLoop().schedule(() -> {
+          if (!f.isDone()) {
+            f.completeExceptionally(new TimeoutIOException(
+                "Timeout " + timeout + ": Failed to send " + request + " 
channel: " + channel));
+            replyMap.fail(requestEntry);
+          }
+        }, timeout.toLong(timeout.getUnit()), timeout.getUnit());
+        replyEntry.setTimeoutFuture(timeoutFuture);
       }
     });
     return f;

Reply via email to