This is an automated email from the ASF dual-hosted git repository.

trohrmann pushed a commit to branch release-1.14
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 9d742bceef839bb3e9087946f434979d2d3f847d
Author: Till Rohrmann <[email protected]>
AuthorDate: Fri Aug 27 09:27:19 2021 +0200

    [FLINK-9925] Rework Client to encapsulate concurrency of single connection
    
    This commit replaces Client.pendingConnections and Client.establishedConnections
    with a single data structure, connections, which maps each server address to a
    ServerConnection. The logic of pending and established connections is moved into
    this class. This allows us to get rid of concurrent access to the Client data
    structures.
    
    This closes #16915.
---
 .../flink/queryablestate/network/Client.java       | 477 ++------------------
 .../queryablestate/network/ServerConnection.java   | 490 +++++++++++++++++++++
 .../flink/queryablestate/network/ClientTest.java   |  16 +-
 3 files changed, 528 insertions(+), 455 deletions(-)

diff --git 
a/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/network/Client.java
 
b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/network/Client.java
index 148952f..0781f39 100644
--- 
a/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/network/Client.java
+++ 
b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/network/Client.java
@@ -28,10 +28,7 @@ import org.apache.flink.util.concurrent.FutureUtils;
 
 import 
org.apache.flink.shaded.guava30.com.google.common.util.concurrent.ThreadFactoryBuilder;
 import org.apache.flink.shaded.netty4.io.netty.bootstrap.Bootstrap;
-import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
 import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBufAllocator;
-import org.apache.flink.shaded.netty4.io.netty.channel.Channel;
-import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFuture;
 import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFutureListener;
 import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInitializer;
 import org.apache.flink.shaded.netty4.io.netty.channel.ChannelOption;
@@ -46,8 +43,6 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.net.InetSocketAddress;
-import java.nio.channels.ClosedChannelException;
-import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
@@ -55,7 +50,6 @@ import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ThreadFactory;
 import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicReference;
 
 /**
@@ -82,12 +76,7 @@ public class Client<REQ extends MessageBody, RESP extends 
MessageBody> {
     /** Statistics tracker. */
     private final KvStateRequestStats stats;
 
-    /** Established connections. */
-    private final Map<InetSocketAddress, EstablishedConnection> 
establishedConnections =
-            new ConcurrentHashMap<>();
-
-    /** Pending connections. */
-    private final Map<InetSocketAddress, PendingConnection> pendingConnections 
=
+    private final Map<InetSocketAddress, ServerConnection<REQ, RESP>> 
connections =
             new ConcurrentHashMap<>();
 
     /** Atomic shut down future. */
@@ -132,8 +121,7 @@ public class Client<REQ extends MessageBody, RESP extends 
MessageBody> {
                         .handler(
                                 new ChannelInitializer<SocketChannel>() {
                                     @Override
-                                    protected void initChannel(SocketChannel 
channel)
-                                            throws Exception {
+                                    protected void initChannel(SocketChannel 
channel) {
                                         channel.pipeline()
                                                 .addLast(
                                                         new 
LengthFieldBasedFrameDecoder(
@@ -154,31 +142,30 @@ public class Client<REQ extends MessageBody, RESP extends 
MessageBody> {
                     new IllegalStateException(clientName + " is already shut 
down."));
         }
 
-        EstablishedConnection connection = 
establishedConnections.get(serverAddress);
-        if (connection != null) {
-            return connection.sendRequest(request);
-        } else {
-            PendingConnection pendingConnection = 
pendingConnections.get(serverAddress);
-            if (pendingConnection != null) {
-                // There was a race, use the existing pending connection.
-                return pendingConnection.sendRequest(request);
-            } else {
-                // We try to connect to the server.
-                PendingConnection pending = new 
PendingConnection(serverAddress, messageSerializer);
-                PendingConnection previous = 
pendingConnections.putIfAbsent(serverAddress, pending);
-
-                if (previous == null) {
-                    // OK, we are responsible to connect.
-                    bootstrap
-                            .connect(serverAddress.getAddress(), 
serverAddress.getPort())
-                            .addListener(pending);
-                    return pending.sendRequest(request);
-                } else {
-                    // There was a race, use the existing pending connection.
-                    return previous.sendRequest(request);
-                }
-            }
-        }
+        final ServerConnection<REQ, RESP> connection =
+                connections.computeIfAbsent(
+                        serverAddress,
+                        ignored -> {
+                            final ServerConnection<REQ, RESP> newConnection =
+                                    ServerConnection.createPendingConnection(
+                                            clientName, messageSerializer, 
stats);
+                            bootstrap
+                                    .connect(serverAddress.getAddress(), 
serverAddress.getPort())
+                                    .addListener(
+                                            (ChannelFutureListener)
+                                                    
newConnection::establishConnection);
+
+                            newConnection
+                                    .getCloseFuture()
+                                    .handle(
+                                            (ignoredA, ignoredB) ->
+                                                    connections.remove(
+                                                            serverAddress, 
newConnection));
+
+                            return newConnection;
+                        });
+
+        return connection.sendRequest(request);
     }
 
     /**
@@ -194,16 +181,9 @@ public class Client<REQ extends MessageBody, RESP extends 
MessageBody> {
 
             final List<CompletableFuture<Void>> connectionFutures = new 
ArrayList<>();
 
-            for (Map.Entry<InetSocketAddress, EstablishedConnection> conn :
-                    establishedConnections.entrySet()) {
-                if (establishedConnections.remove(conn.getKey(), 
conn.getValue())) {
-                    connectionFutures.add(conn.getValue().close());
-                }
-            }
-
-            for (Map.Entry<InetSocketAddress, PendingConnection> conn :
-                    pendingConnections.entrySet()) {
-                if (pendingConnections.remove(conn.getKey()) != null) {
+            for (Map.Entry<InetSocketAddress, ServerConnection<REQ, RESP>> 
conn :
+                    connections.entrySet()) {
+                if (connections.remove(conn.getKey(), conn.getValue())) {
                     connectionFutures.add(conn.getValue().close());
                 }
             }
@@ -247,405 +227,6 @@ public class Client<REQ extends MessageBody, RESP extends 
MessageBody> {
         return clientShutdownFuture.get();
     }
 
-    /** A pending connection that is in the process of connecting. */
-    private class PendingConnection implements ChannelFutureListener {
-
-        /** Lock to guard the connect call, channel hand in, etc. */
-        private final Object connectLock = new Object();
-
-        /** Address of the server we are connecting to. */
-        private final InetSocketAddress serverAddress;
-
-        private final MessageSerializer<REQ, RESP> serializer;
-
-        /** Queue of requests while connecting. */
-        private final ArrayDeque<PendingRequest> queuedRequests = new 
ArrayDeque<>();
-
-        /** The established connection after the connect succeeds. */
-        private EstablishedConnection established;
-
-        /** Atomic shut down future. */
-        private final AtomicReference<CompletableFuture<Void>> 
connectionShutdownFuture =
-                new AtomicReference<>(null);
-
-        /** Failure cause if something goes wrong. */
-        private Throwable failureCause;
-
-        /**
-         * Creates a pending connection to the given server.
-         *
-         * @param serverAddress Address of the server to connect to.
-         */
-        private PendingConnection(
-                final InetSocketAddress serverAddress,
-                final MessageSerializer<REQ, RESP> serializer) {
-            this.serverAddress = serverAddress;
-            this.serializer = serializer;
-        }
-
-        @Override
-        public void operationComplete(ChannelFuture future) throws Exception {
-            if (future.isSuccess()) {
-                handInChannel(future.channel());
-            } else {
-                close(future.cause());
-            }
-        }
-
-        /**
-         * Returns a future holding the serialized request result.
-         *
-         * <p>If the channel has been established, forward the call to the 
established channel,
-         * otherwise queue it for when the channel is handed in.
-         *
-         * @param request the request to be sent.
-         * @return Future holding the serialized result
-         */
-        CompletableFuture<RESP> sendRequest(REQ request) {
-            synchronized (connectLock) {
-                if (failureCause != null) {
-                    return FutureUtils.completedExceptionally(failureCause);
-                } else if (connectionShutdownFuture.get() != null) {
-                    return FutureUtils.completedExceptionally(new 
ClosedChannelException());
-                } else {
-                    if (established != null) {
-                        return established.sendRequest(request);
-                    } else {
-                        // Queue this and handle when connected
-                        final PendingRequest pending = new 
PendingRequest(request);
-                        queuedRequests.add(pending);
-                        return pending;
-                    }
-                }
-            }
-        }
-
-        /**
-         * Hands in a channel after a successful connection.
-         *
-         * @param channel Channel to hand in
-         */
-        private void handInChannel(Channel channel) {
-            synchronized (connectLock) {
-                if (connectionShutdownFuture.get() != null || failureCause != 
null) {
-                    // Close the channel and we are done. Any queued requests
-                    // are removed on the close/failure call and after that no
-                    // new ones can be enqueued.
-                    channel.close();
-                } else {
-                    established = new EstablishedConnection(serverAddress, 
serializer, channel);
-
-                    while (!queuedRequests.isEmpty()) {
-                        final PendingRequest pending = queuedRequests.poll();
-
-                        established
-                                .sendRequest(pending.request)
-                                .whenComplete(
-                                        (response, throwable) -> {
-                                            if (throwable != null) {
-                                                
pending.completeExceptionally(throwable);
-                                            } else {
-                                                pending.complete(response);
-                                            }
-                                        });
-                    }
-
-                    // Publish the channel for the general public
-                    establishedConnections.put(serverAddress, established);
-                    pendingConnections.remove(serverAddress);
-
-                    // Check shut down for possible race with shut down. We
-                    // don't want any lingering connections after shut down,
-                    // which can happen if we don't check this here.
-                    if (clientShutdownFuture.get() != null) {
-                        if (establishedConnections.remove(serverAddress, 
established)) {
-                            established.close();
-                        }
-                    }
-                }
-            }
-        }
-
-        /** Close the connecting channel with a ClosedChannelException. */
-        private CompletableFuture<Void> close() {
-            return close(new ClosedChannelException());
-        }
-
-        /**
-         * Close the connecting channel with an Exception (can be {@code 
null}) or forward to the
-         * established channel.
-         */
-        private CompletableFuture<Void> close(Throwable cause) {
-            CompletableFuture<Void> future = new CompletableFuture<>();
-            if (connectionShutdownFuture.compareAndSet(null, future)) {
-                synchronized (connectLock) {
-                    if (failureCause == null) {
-                        failureCause = cause;
-                    }
-
-                    if (established != null) {
-                        established
-                                .close()
-                                .whenComplete(
-                                        (result, throwable) -> {
-                                            if (throwable != null) {
-                                                
future.completeExceptionally(throwable);
-                                            } else {
-                                                future.complete(null);
-                                            }
-                                        });
-                    } else {
-                        PendingRequest pending;
-                        while ((pending = queuedRequests.poll()) != null) {
-                            pending.completeExceptionally(cause);
-                        }
-                        future.complete(null);
-                    }
-                }
-            }
-            return connectionShutdownFuture.get();
-        }
-
-        @Override
-        public String toString() {
-            synchronized (connectLock) {
-                return "PendingConnection{"
-                        + "serverAddress="
-                        + serverAddress
-                        + ", queuedRequests="
-                        + queuedRequests.size()
-                        + ", established="
-                        + (established != null)
-                        + ", closed="
-                        + (connectionShutdownFuture.get() != null)
-                        + '}';
-            }
-        }
-
-        /** A pending request queued while the channel is connecting. */
-        private final class PendingRequest extends CompletableFuture<RESP> {
-
-            private final REQ request;
-
-            private PendingRequest(REQ request) {
-                this.request = request;
-            }
-        }
-    }
-
-    /**
-     * An established connection that wraps the actual channel instance and is 
registered at the
-     * {@link ClientHandler} for callbacks.
-     */
-    private class EstablishedConnection implements ClientHandlerCallback<RESP> 
{
-
-        /** Address of the server we are connected to. */
-        private final InetSocketAddress serverAddress;
-
-        /** The actual TCP channel. */
-        private final Channel channel;
-
-        /** Pending requests keyed by request ID. */
-        private final ConcurrentHashMap<Long, TimestampedCompletableFuture> 
pendingRequests =
-                new ConcurrentHashMap<>();
-
-        /** Current request number used to assign unique request IDs. */
-        private final AtomicLong requestCount = new AtomicLong();
-
-        /** Atomic shut down future. */
-        private final AtomicReference<CompletableFuture<Void>> 
connectionShutdownFuture =
-                new AtomicReference<>(null);
-
-        /**
-         * Creates an established connection with the given channel.
-         *
-         * @param serverAddress Address of the server connected to
-         * @param channel The actual TCP channel
-         */
-        EstablishedConnection(
-                final InetSocketAddress serverAddress,
-                final MessageSerializer<REQ, RESP> serializer,
-                final Channel channel) {
-
-            this.serverAddress = Preconditions.checkNotNull(serverAddress);
-            this.channel = Preconditions.checkNotNull(channel);
-
-            // Add the client handler with the callback
-            channel.pipeline()
-                    .addLast(
-                            getClientName() + " Handler",
-                            new ClientHandler<>(clientName, serializer, this));
-
-            stats.reportActiveConnection();
-        }
-
-        /** Close the channel with a ClosedChannelException. */
-        CompletableFuture<Void> close() {
-            return close(new ClosedChannelException());
-        }
-
-        /**
-         * Close the channel with a cause.
-         *
-         * @param cause The cause to close the channel with.
-         * @return Channel close future
-         */
-        private CompletableFuture<Void> close(final Throwable cause) {
-            final CompletableFuture<Void> shutdownFuture = new 
CompletableFuture<>();
-
-            if (connectionShutdownFuture.compareAndSet(null, shutdownFuture)) {
-                channel.close()
-                        .addListener(
-                                finished -> {
-                                    stats.reportInactiveConnection();
-                                    for (long requestId : 
pendingRequests.keySet()) {
-                                        TimestampedCompletableFuture pending =
-                                                
pendingRequests.remove(requestId);
-                                        if (pending != null
-                                                && 
pending.completeExceptionally(cause)) {
-                                            stats.reportFailedRequest();
-                                        }
-                                    }
-
-                                    // when finishing, if netty successfully 
closes the channel,
-                                    // then the provided exception is used
-                                    // as the reason for the closing. If there 
was something wrong
-                                    // at the netty side, then that exception
-                                    // is prioritized over the provided one.
-                                    if (finished.isSuccess()) {
-                                        
shutdownFuture.completeExceptionally(cause);
-                                    } else {
-                                        LOG.warn(
-                                                "Something went wrong when 
trying to close connection due to : ",
-                                                cause);
-                                        
shutdownFuture.completeExceptionally(finished.cause());
-                                    }
-                                });
-            }
-
-            // in case we had a race condition, return the winner of the race.
-            return connectionShutdownFuture.get();
-        }
-
-        /**
-         * Returns a future holding the serialized request result.
-         *
-         * @param request the request to be sent.
-         * @return Future holding the serialized result
-         */
-        CompletableFuture<RESP> sendRequest(REQ request) {
-            TimestampedCompletableFuture requestPromiseTs =
-                    new TimestampedCompletableFuture(System.nanoTime());
-            try {
-                final long requestId = requestCount.getAndIncrement();
-                pendingRequests.put(requestId, requestPromiseTs);
-
-                stats.reportRequest();
-
-                ByteBuf buf =
-                        MessageSerializer.serializeRequest(channel.alloc(), 
requestId, request);
-
-                channel.writeAndFlush(buf)
-                        .addListener(
-                                (ChannelFutureListener)
-                                        future -> {
-                                            if (!future.isSuccess()) {
-                                                // Fail promise if not failed 
to write
-                                                TimestampedCompletableFuture 
pending =
-                                                        
pendingRequests.remove(requestId);
-                                                if (pending != null
-                                                        && 
pending.completeExceptionally(
-                                                                
future.cause())) {
-                                                    
stats.reportFailedRequest();
-                                                }
-                                            }
-                                        });
-
-                // Check for possible race. We don't want any lingering
-                // promises after a failure, which can happen if we don't check
-                // this here. Note that close is treated as a failure as well.
-                CompletableFuture<Void> clShutdownFuture = 
clientShutdownFuture.get();
-                if (clShutdownFuture != null) {
-                    TimestampedCompletableFuture pending = 
pendingRequests.remove(requestId);
-                    if (pending != null) {
-                        clShutdownFuture.whenComplete(
-                                (ignored, throwable) -> {
-                                    if (throwable != null
-                                            && 
pending.completeExceptionally(throwable)) {
-                                        stats.reportFailedRequest();
-                                    } else {
-                                        // the shutdown future is always 
completed exceptionally so
-                                        // we should not arrive here.
-                                        // but in any case, we complete the 
pending connection
-                                        // request exceptionally.
-                                        pending.completeExceptionally(new 
ClosedChannelException());
-                                    }
-                                });
-                    }
-                }
-            } catch (Throwable t) {
-                requestPromiseTs.completeExceptionally(t);
-            }
-
-            return requestPromiseTs;
-        }
-
-        @Override
-        public void onRequestResult(long requestId, RESP response) {
-            TimestampedCompletableFuture pending = 
pendingRequests.remove(requestId);
-            if (pending != null && !pending.isDone()) {
-                long durationMillis = (System.nanoTime() - 
pending.getTimestamp()) / 1_000_000L;
-                stats.reportSuccessfulRequest(durationMillis);
-                pending.complete(response);
-            }
-        }
-
-        @Override
-        public void onRequestFailure(long requestId, Throwable cause) {
-            TimestampedCompletableFuture pending = 
pendingRequests.remove(requestId);
-            if (pending != null && !pending.isDone()) {
-                stats.reportFailedRequest();
-                pending.completeExceptionally(cause);
-            }
-        }
-
-        @Override
-        public void onFailure(Throwable cause) {
-            close(cause)
-                    .handle(
-                            (cancelled, ignored) ->
-                                    
establishedConnections.remove(serverAddress, this));
-        }
-
-        @Override
-        public String toString() {
-            return "EstablishedConnection{"
-                    + "serverAddress="
-                    + serverAddress
-                    + ", channel="
-                    + channel
-                    + ", pendingRequests="
-                    + pendingRequests.size()
-                    + ", requestCount="
-                    + requestCount
-                    + '}';
-        }
-
-        /** Pair of promise and a timestamp. */
-        private class TimestampedCompletableFuture extends 
CompletableFuture<RESP> {
-
-            private final long timestampInNanos;
-
-            TimestampedCompletableFuture(long timestampInNanos) {
-                this.timestampInNanos = timestampInNanos;
-            }
-
-            public long getTimestamp() {
-                return timestampInNanos;
-            }
-        }
-    }
-
     @VisibleForTesting
     public boolean isEventGroupShutdown() {
         return bootstrap == null || bootstrap.group().isTerminated();
diff --git 
a/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/network/ServerConnection.java
 
b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/network/ServerConnection.java
new file mode 100644
index 0000000..2fdad01
--- /dev/null
+++ 
b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/network/ServerConnection.java
@@ -0,0 +1,490 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.queryablestate.network;
+
+import org.apache.flink.queryablestate.network.messages.MessageBody;
+import org.apache.flink.queryablestate.network.messages.MessageSerializer;
+import org.apache.flink.queryablestate.network.stats.KvStateRequestStats;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.concurrent.FutureUtils;
+
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+import org.apache.flink.shaded.netty4.io.netty.channel.Channel;
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFuture;
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFutureListener;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.GuardedBy;
+
+import java.nio.channels.ClosedChannelException;
+import java.util.ArrayDeque;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * Connection class used by the {@link Client}.
+ *
+ * @param <REQ> Request type
+ * @param <RESP> Response type
+ */
+final class ServerConnection<REQ extends MessageBody, RESP extends 
MessageBody> {
+    private static final Logger LOG = 
LoggerFactory.getLogger(ServerConnection.class);
+
+    private final Object connectionLock = new Object();
+
+    @GuardedBy("connectionLock")
+    private InternalConnection<REQ, RESP> internalConnection;
+
+    @GuardedBy("connectionLock")
+    private boolean running = true;
+
+    private final CompletableFuture<Void> closeFuture = new 
CompletableFuture<>();
+
+    private ServerConnection(InternalConnection<REQ, RESP> internalConnection) 
{
+        this.internalConnection = internalConnection;
+        forwardCloseFuture();
+    }
+
+    @GuardedBy("connectionLock")
+    private void forwardCloseFuture() {
+        final InternalConnection<REQ, RESP> currentConnection = 
this.internalConnection;
+        currentConnection
+                .getCloseFuture()
+                .whenComplete(
+                        (unused, throwable) -> {
+                            synchronized (connectionLock) {
+                                if (internalConnection == currentConnection) {
+                                    if (throwable != null) {
+                                        
closeFuture.completeExceptionally(throwable);
+                                    } else {
+                                        closeFuture.complete(null);
+                                    }
+                                }
+                            }
+                        });
+    }
+
+    CompletableFuture<RESP> sendRequest(REQ request) {
+        synchronized (connectionLock) {
+            Preconditions.checkState(running, "Connection has already been 
closed.");
+            return internalConnection.sendRequest(request);
+        }
+    }
+
+    void establishConnection(ChannelFuture future) {
+        synchronized (connectionLock) {
+            Preconditions.checkState(running, "Connection has already been 
closed.");
+            this.internalConnection = 
internalConnection.establishConnection(future);
+            forwardCloseFuture();
+        }
+    }
+
+    CompletableFuture<Void> close() {
+        synchronized (connectionLock) {
+            if (running) {
+                running = false;
+                internalConnection.close();
+            }
+
+            return closeFuture;
+        }
+    }
+
+    CompletableFuture<Void> getCloseFuture() {
+        return closeFuture;
+    }
+
+    static <REQ extends MessageBody, RESP extends MessageBody>
+            ServerConnection<REQ, RESP> createPendingConnection(
+                    final String clientName,
+                    final MessageSerializer<REQ, RESP> serializer,
+                    final KvStateRequestStats stats) {
+        return new ServerConnection<>(new PendingConnection<>(clientName, 
serializer, stats));
+    }
+
+    interface InternalConnection<REQ, RESP> {
+        CompletableFuture<RESP> sendRequest(REQ request);
+
+        InternalConnection<REQ, RESP> establishConnection(ChannelFuture 
future);
+
+        boolean isEstablished();
+
+        CompletableFuture<Void> getCloseFuture();
+
+        CompletableFuture<Void> close();
+    }
+
+    /** A pending connection that is in the process of connecting. */
+    private static final class PendingConnection<REQ extends MessageBody, RESP 
extends MessageBody>
+            implements InternalConnection<REQ, RESP> {
+
+        private final String clientName;
+
+        private final MessageSerializer<REQ, RESP> serializer;
+
+        private final KvStateRequestStats stats;
+
+        private final CompletableFuture<Void> closeFuture = new 
CompletableFuture<>();
+
+        /** Queue of requests while connecting. */
+        private final ArrayDeque<PendingConnection.PendingRequest<REQ, RESP>> 
queuedRequests =
+                new ArrayDeque<>();
+
+        /** Failure cause if something goes wrong. */
+        @Nullable private Throwable failureCause = null;
+
+        private boolean running = true;
+
+        /** Creates a pending connection to the given server. */
+        private PendingConnection(
+                final String clientName,
+                final MessageSerializer<REQ, RESP> serializer,
+                final KvStateRequestStats stats) {
+            this.clientName = clientName;
+            this.serializer = serializer;
+            this.stats = stats;
+        }
+
+        /**
+         * Returns a future holding the serialized request result.
+         *
+         * <p>Queues the request for when the channel is handed in.
+         *
+         * @param request the request to be sent.
+         * @return Future holding the serialized result
+         */
+        @Override
+        public CompletableFuture<RESP> sendRequest(REQ request) {
+            if (failureCause != null) {
+                return FutureUtils.completedExceptionally(failureCause);
+            } else if (!running) {
+                return FutureUtils.completedExceptionally(new 
ClosedChannelException());
+            } else {
+                // Queue this and handle when connected
+                final PendingConnection.PendingRequest<REQ, RESP> pending =
+                        new PendingConnection.PendingRequest<>(request);
+                queuedRequests.add(pending);
+                return pending;
+            }
+        }
+
+        @Override
+        public InternalConnection<REQ, RESP> establishConnection(ChannelFuture 
future) {
+            if (future.isSuccess()) {
+                return createEstablishedConnection(future.channel());
+            } else {
+                close(future.cause());
+                return this;
+            }
+        }
+
+        @Override
+        public boolean isEstablished() {
+            return false;
+        }
+
+        @Override
+        public CompletableFuture<Void> getCloseFuture() {
+            return closeFuture;
+        }
+
+        /**
+         * Creates an established connection from the given channel.
+         *
+         * @param channel Channel to create an established connection from
+         */
+        private InternalConnection<REQ, RESP> 
createEstablishedConnection(Channel channel) {
+            if (failureCause != null || !running) {
+                // Close the channel and we are done. Any queued requests
+                // are removed on the close/failure call and after that no
+                // new ones can be enqueued.
+                channel.close();
+                return this;
+            } else {
+                final EstablishedConnection<REQ, RESP> establishedConnection =
+                        new EstablishedConnection<>(clientName, serializer, 
channel, stats);
+
+                while (!queuedRequests.isEmpty()) {
+                    final PendingConnection.PendingRequest<REQ, RESP> pending =
+                            queuedRequests.poll();
+
+                    FutureUtils.forward(
+                            
establishedConnection.sendRequest(pending.getRequest()), pending);
+                }
+
+                return establishedConnection;
+            }
+        }
+
+        /** Close the connecting channel with a ClosedChannelException. */
+        @Override
+        public CompletableFuture<Void> close() {
+            return close(new ClosedChannelException());
+        }
+
+        /**
+         * Close the connecting channel with an Exception (can be {@code 
null}) or forward to the
+         * established channel.
+         */
+        private CompletableFuture<Void> close(Throwable cause) {
+            if (running) {
+                running = false;
+                failureCause = cause;
+
+                for (PendingConnection.PendingRequest<REQ, RESP> 
pendingRequest : queuedRequests) {
+                    pendingRequest.completeExceptionally(cause);
+                }
+                queuedRequests.clear();
+
+                closeFuture.completeExceptionally(cause);
+            }
+
+            return closeFuture;
+        }
+
+        /** A pending request queued while the channel is connecting. */
+        private static final class PendingRequest<REQ extends MessageBody, 
RESP extends MessageBody>
+                extends CompletableFuture<RESP> {
+
+            private final REQ request;
+
+            private PendingRequest(REQ request) {
+                this.request = request;
+            }
+
+            public REQ getRequest() {
+                return request;
+            }
+        }
+    }
+
+    /**
+     * An established connection that wraps the actual channel instance and is 
registered at the
+     * {@link ClientHandler} for callbacks.
+     */
+    private static class EstablishedConnection<REQ extends MessageBody, RESP 
extends MessageBody>
+            implements ClientHandlerCallback<RESP>, InternalConnection<REQ, 
RESP> {
+
+        private final Object lock = new Object();
+
+        /** The actual TCP channel. */
+        private final Channel channel;
+
+        private final KvStateRequestStats stats;
+
+        /** Pending requests keyed by request ID. */
+        private final ConcurrentHashMap<
+                        Long, 
EstablishedConnection.TimestampedCompletableFuture<RESP>>
+                pendingRequests = new ConcurrentHashMap<>();
+
+        private final CompletableFuture<Void> closeFuture = new 
CompletableFuture<>();
+
+        /** Current request number used to assign unique request IDs. */
+        @GuardedBy("lock")
+        private long requestCount = 0;
+
+        @GuardedBy("lock")
+        private boolean running = true;
+
+        /**
+         * Creates an established connection with the given channel.
+         *
+         * @param channel The actual TCP channel
+         */
+        EstablishedConnection(
+                final String clientName,
+                final MessageSerializer<REQ, RESP> serializer,
+                final Channel channel,
+                final KvStateRequestStats stats) {
+
+            this.channel = Preconditions.checkNotNull(channel);
+
+            // Add the client handler with the callback
+            channel.pipeline()
+                    .addLast(
+                            clientName + " Handler",
+                            new ClientHandler<>(clientName, serializer, this));
+
+            this.stats = stats;
+            stats.reportActiveConnection();
+        }
+
+        /** Close the channel with a ClosedChannelException. */
+        @Override
+        public CompletableFuture<Void> close() {
+            return close(new ClosedChannelException());
+        }
+
+        /**
+         * Close the channel with a cause.
+         *
+         * @param cause The cause to close the channel with.
+         * @return Channel close future
+         */
+        private CompletableFuture<Void> close(final Throwable cause) {
+            synchronized (lock) {
+                if (running) {
+                    running = false;
+                    channel.close()
+                            .addListener(
+                                    finished -> {
+                                        stats.reportInactiveConnection();
+                                        for (long requestId : 
pendingRequests.keySet()) {
+                                            
EstablishedConnection.TimestampedCompletableFuture<RESP>
+                                                    pending = 
pendingRequests.remove(requestId);
+                                            if (pending != null
+                                                    && 
pending.completeExceptionally(cause)) {
+                                                stats.reportFailedRequest();
+                                            }
+                                        }
+
+                                        // when finishing, if netty 
successfully closes the channel,
+                                        // then the provided exception is used
+                                        // as the reason for the closing. If 
there was something
+                                        // wrong
+                                        // at the netty side, then that 
exception
+                                        // is prioritized over the provided 
one.
+                                        if (finished.isSuccess()) {
+                                            
closeFuture.completeExceptionally(cause);
+                                        } else {
+                                            LOG.warn(
+                                                    "Something went wrong when 
trying to close connection due to : ",
+                                                    cause);
+                                            
closeFuture.completeExceptionally(finished.cause());
+                                        }
+                                    });
+                }
+            }
+
+            return closeFuture;
+        }
+
+        /**
+         * Returns a future holding the serialized request result.
+         *
+         * @param request the request to be sent.
+         * @return Future holding the serialized result
+         */
+        @Override
+        public CompletableFuture<RESP> sendRequest(REQ request) {
+            synchronized (lock) {
+                if (running) {
+                    EstablishedConnection.TimestampedCompletableFuture<RESP> 
requestPromiseTs =
+                            new 
EstablishedConnection.TimestampedCompletableFuture<>(
+                                    System.nanoTime());
+                    try {
+                        final long requestId = requestCount++;
+                        pendingRequests.put(requestId, requestPromiseTs);
+
+                        stats.reportRequest();
+
+                        ByteBuf buf =
+                                MessageSerializer.serializeRequest(
+                                        channel.alloc(), requestId, request);
+
+                        channel.writeAndFlush(buf)
+                                .addListener(
+                                        (ChannelFutureListener)
+                                                future -> {
+                                                    if (!future.isSuccess()) {
+                                                        // Fail promise if not 
failed to write
+                                                        EstablishedConnection
+                                                                               
 .TimestampedCompletableFuture<
+                                                                        RESP>
+                                                                pending =
+                                                                        
pendingRequests.remove(
+                                                                               
 requestId);
+                                                        if (pending != null
+                                                                && 
pending.completeExceptionally(
+                                                                        
future.cause())) {
+                                                            
stats.reportFailedRequest();
+                                                        }
+                                                    }
+                                                });
+                    } catch (Throwable t) {
+                        requestPromiseTs.completeExceptionally(t);
+                    }
+
+                    return requestPromiseTs;
+                } else {
+                    return FutureUtils.completedExceptionally(new 
ClosedChannelException());
+                }
+            }
+        }
+
+        @Override
+        public InternalConnection<REQ, RESP> establishConnection(ChannelFuture 
future) {
+            throw new IllegalStateException("The connection is already 
established.");
+        }
+
+        @Override
+        public boolean isEstablished() {
+            return true;
+        }
+
+        @Override
+        public CompletableFuture<Void> getCloseFuture() {
+            return closeFuture;
+        }
+
+        @Override
+        public void onRequestResult(long requestId, RESP response) {
+            EstablishedConnection.TimestampedCompletableFuture<RESP> pending =
+                    pendingRequests.remove(requestId);
+            if (pending != null && !pending.isDone()) {
+                long durationMillis = (System.nanoTime() - 
pending.getTimestamp()) / 1_000_000L;
+                stats.reportSuccessfulRequest(durationMillis);
+                pending.complete(response);
+            }
+        }
+
+        @Override
+        public void onRequestFailure(long requestId, Throwable cause) {
+            EstablishedConnection.TimestampedCompletableFuture<RESP> pending =
+                    pendingRequests.remove(requestId);
+            if (pending != null && !pending.isDone()) {
+                stats.reportFailedRequest();
+                pending.completeExceptionally(cause);
+            }
+        }
+
+        @Override
+        public void onFailure(Throwable cause) {
+            close(cause);
+        }
+
+        /** Pair of promise and a timestamp. */
+        private static final class TimestampedCompletableFuture<RESP extends 
MessageBody>
+                extends CompletableFuture<RESP> {
+
+            private final long timestampInNanos;
+
+            TimestampedCompletableFuture(long timestampInNanos) {
+                this.timestampInNanos = timestampInNanos;
+            }
+
+            public long getTimestamp() {
+                return timestampInNanos;
+            }
+        }
+    }
+}
diff --git 
a/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/ClientTest.java
 
b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/ClientTest.java
index 7516d76..13bb73b 100644
--- 
a/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/ClientTest.java
+++ 
b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/ClientTest.java
@@ -23,6 +23,7 @@ import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeutils.base.IntSerializer;
 import org.apache.flink.core.fs.CloseableRegistry;
+import org.apache.flink.core.testutils.FlinkMatchers;
 import org.apache.flink.metrics.groups.UnregisteredMetricsGroup;
 import org.apache.flink.queryablestate.KvStateID;
 import org.apache.flink.queryablestate.client.VoidNamespace;
@@ -59,6 +60,7 @@ import 
org.apache.flink.shaded.netty4.io.netty.channel.socket.SocketChannel;
 import 
org.apache.flink.shaded.netty4.io.netty.channel.socket.nio.NioServerSocketChannel;
 import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.LengthFieldBasedFrameDecoder;
 
+import org.hamcrest.core.CombinableMatcher;
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
@@ -84,9 +86,11 @@ import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.ThreadLocalRandom;
 import java.util.concurrent.atomic.AtomicReference;
 
+import static org.hamcrest.CoreMatchers.either;
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertThat;
 import static org.junit.Assert.fail;
 
 /** Tests for {@link Client}. */
@@ -727,18 +731,16 @@ public class ClientTest extends TestLogger {
             }
             Assert.assertTrue(client.isEventGroupShutdown());
 
+            final CombinableMatcher<Throwable> exceptionMatcher =
+                    
either(FlinkMatchers.containsCause(ClosedChannelException.class))
+                            
.or(FlinkMatchers.containsCause(IllegalStateException.class));
+
             for (Future<Void> future : taskFutures) {
                 try {
                     future.get();
                     fail("Did not throw expected Exception after shut down");
                 } catch (ExecutionException t) {
-                    if (t.getCause().getCause() instanceof 
ClosedChannelException
-                            || t.getCause().getCause() instanceof 
IllegalStateException) {
-                        // Expected
-                    } else {
-                        t.printStackTrace();
-                        fail("Failed with unexpected Exception type: " + 
t.getClass().getName());
-                    }
+                    assertThat(t, exceptionMatcher);
                 }
             }
 

Reply via email to