This is an automated email from the ASF dual-hosted git repository.

ptupitsyn pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/ignite-3.git


The following commit(s) were added to refs/heads/main by this push:
     new 2eb4d5ffdac IGNITE-28376 Fix potential invalid responses in ClientInboundMessageHandler (#7894)
2eb4d5ffdac is described below

commit 2eb4d5ffdac4a9815a8f1d69c863e392b147c6da
Author: Pavel Tupitsyn <[email protected]>
AuthorDate: Mon Mar 30 13:33:07 2026 +0200

    IGNITE-28376 Fix potential invalid responses in ClientInboundMessageHandler (#7894)
---
 .../handler/ClientInboundMessageHandler.java       | 108 +++++++++++++++------
 .../ignite/client/handler/ResponseWriteGuard.java  |  45 +++++++++
 .../client/handler/ResponseWriteGuardTest.java     |  52 ++++++++++
 3 files changed, 174 insertions(+), 31 deletions(-)

diff --git a/modules/client-handler/src/main/java/org/apache/ignite/client/handler/ClientInboundMessageHandler.java b/modules/client-handler/src/main/java/org/apache/ignite/client/handler/ClientInboundMessageHandler.java
index 36348e5c90f..895872994ea 100644
--- a/modules/client-handler/src/main/java/org/apache/ignite/client/handler/ClientInboundMessageHandler.java
+++ b/modules/client-handler/src/main/java/org/apache/ignite/client/handler/ClientInboundMessageHandler.java
@@ -474,6 +474,8 @@ public class ClientInboundMessageHandler
     }
 
     private void handshake(ChannelHandlerContext ctx, ClientMessageUnpacker 
unpacker) {
+        var guard = new ResponseWriteGuard();
+
         try (unpacker) {
             var clientVer = ProtocolVersion.unpack(unpacker);
 
@@ -497,13 +499,13 @@ public class ClientInboundMessageHandler
 
                     LOG.debug(msg);
 
-                    handshakeError(ctx, new IgniteException(PROTOCOL_ERR, 
msg));
+                    handshakeError(ctx, new IgniteException(PROTOCOL_ERR, 
msg), guard);
                 } else {
                     LOG.debug("Compute executor connected [connectionId=" + 
connectionId
                             + ", remoteAddress=" + 
ctx.channel().remoteAddress() + ", executorId=" + computeExecutorId + "]");
 
                     // Bypass authentication for compute executor connections.
-                    handshakeSuccess(ctx, UserDetails.UNKNOWN, clientFeatures, 
clientVer, clientCode);
+                    handshakeSuccess(ctx, UserDetails.UNKNOWN, clientFeatures, 
clientVer, clientCode, guard);
 
                     // Ready to handle compute requests now.
                     computeConnFut.complete(new ComputeConnection());
@@ -523,13 +525,13 @@ public class ClientInboundMessageHandler
                                         connectionId, 
ctx.channel().remoteAddress(), authReq.getIdentity(), err.getMessage());
                             }
 
-                            handshakeError(ctx, err);
+                            handshakeError(ctx, err, guard);
                         } else {
-                            handshakeSuccess(ctx, user, clientFeatures, 
clientVer, clientCode);
+                            handshakeSuccess(ctx, user, clientFeatures, 
clientVer, clientCode, guard);
                         }
                     }, ctx.executor());
         } catch (Throwable t) {
-            handshakeError(ctx, t);
+            handshakeError(ctx, t, guard);
         }
     }
 
@@ -538,8 +540,9 @@ public class ClientInboundMessageHandler
             UserDetails user,
             BitSet clientFeatures,
             ProtocolVersion clientVer,
-            int clientCode) {
-        // Disable direct mapping if not all required features are supported 
alltogether.
+            int clientCode,
+            ResponseWriteGuard guard) {
+        // Disable direct mapping if not all required features are supported 
altogether.
         boolean supportsDirectMapping = 
features.get(TX_DIRECT_MAPPING.featureId()) && 
clientFeatures.get(TX_DIRECT_MAPPING.featureId())
                 && features.get(TX_DELAYED_ACKS.featureId()) && 
clientFeatures.get(TX_DELAYED_ACKS.featureId())
                 && features.get(TX_PIGGYBACK.featureId()) && 
clientFeatures.get(TX_PIGGYBACK.featureId())
@@ -565,10 +568,10 @@ public class ClientInboundMessageHandler
 
         logConnectionEstablished(ctx);
 
-        sendHandshakeResponse(ctx, actualFeatures);
+        sendHandshakeResponse(ctx, actualFeatures, guard);
     }
 
-    private void handshakeError(ChannelHandlerContext ctx, Throwable t) {
+    private void handshakeError(ChannelHandlerContext ctx, Throwable t, 
ResponseWriteGuard guard) {
         // Authentication failures are already logged by the caller with more 
details (e.g. username).
         if (!isAuthenticationException(t)) {
             LOG.warn("Handshake failed [connectionId={}, remoteAddress={}]", 
t, connectionId, ctx.channel().remoteAddress());
@@ -581,7 +584,7 @@ public class ClientInboundMessageHandler
 
             writeErrorCore(t, errPacker);
 
-            writeAndFlushWithMagic(errPacker, ctx); // Releases packer.
+            writeAndFlushWithMagic(errPacker, ctx, guard); // Releases packer.
         } catch (Throwable t2) {
             LOG.warn("Handshake failed [connectionId=" + connectionId + ", 
remoteAddress=" + ctx.channel().remoteAddress() + "]: "
                     + t2.getMessage(), t2);
@@ -593,14 +596,14 @@ public class ClientInboundMessageHandler
         metrics.sessionsRejectedIncrement();
     }
 
-    private void sendHandshakeResponse(ChannelHandlerContext ctx, BitSet 
mutuallySupportedFeatures) {
+    private void sendHandshakeResponse(ChannelHandlerContext ctx, BitSet 
mutuallySupportedFeatures, ResponseWriteGuard guard) {
         state = STATE_HANDSHAKE_RESPONSE_SENT;
 
         ClientMessagePacker packer = getPacker(ctx.alloc());
 
         try {
             writeHandshakeResponse(mutuallySupportedFeatures, packer);
-            writeAndFlushWithMagic(packer, ctx); // Releases packer.
+            writeAndFlushWithMagic(packer, ctx, guard); // Releases packer.
         } catch (Throwable t) {
             packer.close();
             throw t;
@@ -667,24 +670,33 @@ public class ClientInboundMessageHandler
         throw new UnsupportedAuthenticationTypeException("Unsupported 
authentication type: " + authnType);
     }
 
-    private void writeAndFlush(ClientMessagePacker packer, 
ChannelHandlerContext ctx) {
+    private boolean writeAndFlush(ClientMessagePacker packer, 
ChannelHandlerContext ctx, ResponseWriteGuard guard) {
         var buf = packer.getBuffer();
         int bytes = buf.readableBytes();
 
         try {
-            // writeAndFlush releases pooled buffer.
-            ctx.writeAndFlush(buf);
+            // write releases pooled buffer.
+            if (!guard.write(ctx, buf)) {
+                // Response for this request has already been sent.
+                // Example: exception after response write, catch block in 
processOperation tries to write an error
+                //          => duplicate response. Guard prevents this.
+                packer.close();
+                return false;
+            }
+
+            ctx.flush();
         } catch (Throwable t) {
             packer.close();
             throw t;
         }
 
         metrics.bytesSentAdd(bytes);
+        return true;
     }
 
-    private void writeAndFlushWithMagic(ClientMessagePacker packer, 
ChannelHandlerContext ctx) {
+    private void writeAndFlushWithMagic(ClientMessagePacker packer, 
ChannelHandlerContext ctx, ResponseWriteGuard guard) {
         ctx.write(Unpooled.wrappedBuffer(ClientMessageCommon.MAGIC_BYTES));
-        writeAndFlush(packer, ctx);
+        writeAndFlush(packer, ctx, guard);
         metrics.bytesSentAdd(ClientMessageCommon.MAGIC_BYTES.length);
     }
 
@@ -704,7 +716,13 @@ public class ClientInboundMessageHandler
         packer.packLong(Math.max(clockService.currentLong(), timestamp));
     }
 
-    private void writeError(long requestId, int opCode, Throwable err, 
ChannelHandlerContext ctx, boolean isNotification) {
+    private void writeError(
+            long requestId,
+            int opCode,
+            Throwable err,
+            ChannelHandlerContext ctx,
+            boolean isNotification,
+            ResponseWriteGuard guard) {
         if (LOG.isDebugEnabled() && shouldLogError(err)) {
             if (isNotification) {
                 LOG.debug("Error processing client notification 
[connectionId={}, id={}, remoteAddress={}]",
@@ -723,7 +741,10 @@ public class ClientInboundMessageHandler
             writeResponseHeader(packer, requestId, ctx, isNotification, true, 
NULL_HYBRID_TIMESTAMP);
             writeErrorCore(err, packer);
 
-            writeAndFlush(packer, ctx);
+            if (!writeAndFlush(packer, ctx, guard)) {
+                LOG.warn("Failed to write error [connectionId={}, id={}, 
op={}, remoteAddress={}]: response already sent",
+                        connectionId, requestId, opCode, 
ctx.channel().remoteAddress());
+            }
         } catch (Throwable t) {
             packer.close();
             exceptionCaught(ctx, t);
@@ -840,6 +861,7 @@ public class ClientInboundMessageHandler
         long requestId = -1;
         int opCode = -1;
 
+        var guard = new ResponseWriteGuard();
         metrics.requestsActiveIncrement();
 
         try {
@@ -853,6 +875,7 @@ public class ClientInboundMessageHandler
 
             if (opCode == ClientOp.SERVER_OP_RESPONSE) {
                 processServerOpResponse(requestId, in);
+                metrics.requestsActiveDecrement();
                 return;
             }
 
@@ -862,24 +885,40 @@ public class ClientInboundMessageHandler
 
                 partitionOperationsExecutor.execute(() -> {
                     try {
-                        processOperationInternal(ctx, in, requestId0, opCode0);
+                        processOperationInternal(ctx, in, requestId0, opCode0, 
guard);
                     } catch (Throwable t) {
                         in.close();
 
-                        writeError(requestId0, opCode0, t, ctx, false);
+                        writeError(requestId0, opCode0, t, ctx, false, guard);
 
                         metrics.requestsFailedIncrement();
+                        metrics.requestsActiveDecrement();
                     }
                 });
             } else {
-                processOperationInternal(ctx, in, requestId, opCode);
+                processOperationInternal(ctx, in, requestId, opCode, guard);
             }
         } catch (Throwable t) {
             in.close();
 
-            writeError(requestId, opCode, t, ctx, false);
+            // If we failed to read the request ID, we cannot send a valid 
error response.
+            // Close the connection instead.
+            if (requestId == -1) {
+                LOG.warn("Failed to read client request, closing connection 
[connectionId={}, remoteAddress={}]",
+                        t, connectionId, ctx.channel().remoteAddress());
+
+                metrics.requestsFailedIncrement();
+                metrics.requestsActiveDecrement();
+
+                ctx.close();
+
+                return;
+            }
+
+            writeError(requestId, opCode, t, ctx, false, guard);
 
             metrics.requestsFailedIncrement();
+            metrics.requestsActiveDecrement();
         }
     }
 
@@ -1159,7 +1198,8 @@ public class ClientInboundMessageHandler
             ChannelHandlerContext ctx,
             ClientMessageUnpacker in,
             long requestId,
-            int opCode
+            int opCode,
+            ResponseWriteGuard guard
     ) {
         CompletableFuture<ResponseWriter> fut;
         HybridTimestampTracker tsTracker = 
HybridTimestampTracker.atomicTracker(null);
@@ -1176,7 +1216,7 @@ public class ClientInboundMessageHandler
             metrics.requestsActiveDecrement();
 
             if (err != null) {
-                writeError(requestId, opCode, (Throwable) err, ctx, false);
+                writeError(requestId, opCode, (Throwable) err, ctx, false, 
guard);
                 metrics.requestsFailedIncrement();
                 return;
             }
@@ -1194,7 +1234,7 @@ public class ClientInboundMessageHandler
 
                 out.setLong(observableTsIdx, tsTracker.getLong());
 
-                writeAndFlush(out, ctx);
+                writeAndFlush(out, ctx, guard);
 
                 metrics.requestsProcessedIncrement();
 
@@ -1204,7 +1244,7 @@ public class ClientInboundMessageHandler
                 }
             } catch (Throwable e) {
                 out.close();
-                writeError(requestId, opCode, e, ctx, false);
+                writeError(requestId, opCode, e, ctx, false, guard);
                 metrics.requestsFailedIncrement();
             }
         });
@@ -1278,9 +1318,15 @@ public class ClientInboundMessageHandler
         return null;
     }
 
-    private void sendNotification(long requestId, @Nullable 
Consumer<ClientMessagePacker> writer, @Nullable Throwable err, long timestamp) {
+    private void sendNotification(
+            long requestId,
+            @Nullable Consumer<ClientMessagePacker> writer,
+            @Nullable Throwable err,
+            long timestamp) {
+        var guard = new ResponseWriteGuard();
+
         if (err != null) {
-            writeError(requestId, -1, err, channelHandlerContext, true);
+            writeError(requestId, -1, err, channelHandlerContext, true, guard);
             return;
         }
 
@@ -1293,7 +1339,7 @@ public class ClientInboundMessageHandler
                 writer.accept(packer);
             }
 
-            writeAndFlush(packer, channelHandlerContext);
+            writeAndFlush(packer, channelHandlerContext, guard);
         } catch (Throwable t) {
             packer.close();
             exceptionCaught(channelHandlerContext, t);
@@ -1405,7 +1451,7 @@ public class ClientInboundMessageHandler
             var fut = new CompletableFuture<ClientMessageUnpacker>();
             serverToClientRequests.put(requestId, fut);
 
-            writeAndFlush(packer, channelHandlerContext);
+            writeAndFlush(packer, channelHandlerContext, new 
ResponseWriteGuard());
 
             return fut;
         } catch (Throwable t) {
diff --git a/modules/client-handler/src/main/java/org/apache/ignite/client/handler/ResponseWriteGuard.java b/modules/client-handler/src/main/java/org/apache/ignite/client/handler/ResponseWriteGuard.java
new file mode 100644
index 00000000000..e0b391cca74
--- /dev/null
+++ b/modules/client-handler/src/main/java/org/apache/ignite/client/handler/ResponseWriteGuard.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.client.handler;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.ChannelOutboundInvoker;
+import java.util.concurrent.locks.ReentrantLock;
+
+class ResponseWriteGuard {
+    private final ReentrantLock lock = new ReentrantLock();
+
+    private boolean responseWritten = false;
+
+    boolean write(ChannelOutboundInvoker ctx, ByteBuf buf) {
+        // No double-check, contention or false case is extremely rare.
+        lock.lock();
+        try {
+            if (responseWritten) {
+                return false;
+            }
+
+            ctx.write(buf);
+            responseWritten = true;
+
+            return true;
+        } finally {
+            lock.unlock();
+        }
+    }
+}
diff --git a/modules/client-handler/src/test/java/org/apache/ignite/client/handler/ResponseWriteGuardTest.java b/modules/client-handler/src/test/java/org/apache/ignite/client/handler/ResponseWriteGuardTest.java
new file mode 100644
index 00000000000..4a16f099315
--- /dev/null
+++ b/modules/client-handler/src/test/java/org/apache/ignite/client/handler/ResponseWriteGuardTest.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.client.handler;
+
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.ChannelOutboundInvoker;
+import org.apache.ignite.internal.testframework.BaseIgniteAbstractTest;
+import org.junit.jupiter.api.Test;
+
+/**
+ * Unit tests for {@link ResponseWriteGuard}: a rejected write must never reach the channel.
+ */
+class ResponseWriteGuardTest extends BaseIgniteAbstractTest {
+    @Test
+    void testWriteHappensOnce() {
+        ChannelOutboundInvoker channel = mock(ChannelOutboundInvoker.class);
+        ByteBuf acceptedBuf = mock(ByteBuf.class);
+        ByteBuf rejectedBuf = mock(ByteBuf.class);
+        ResponseWriteGuard guard = new ResponseWriteGuard();
+
+        boolean firstResult = guard.write(channel, acceptedBuf);
+        boolean secondResult = guard.write(channel, rejectedBuf);
+
+        assertTrue(firstResult);
+        assertFalse(secondResult);
+
+        // Only the first buffer may be handed to the channel; the rejected one must never be written.
+        verify(channel, times(1)).write(acceptedBuf);
+        verify(channel, times(0)).write(rejectedBuf);
+    }
+}

Reply via email to