This is an automated email from the ASF dual-hosted git repository.

schofielaj pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 800612e4a7a KAFKA-19015: Remove share session from cache on share 
consumer connection drop (#19329)
800612e4a7a is described below

commit 800612e4a7ae5511d20816de73ae6cea0596edfc
Author: Chirag Wadhwa <cwad...@confluent.io>
AuthorDate: Thu May 1 19:06:18 2025 +0530

    KAFKA-19015: Remove share session from cache on share consumer connection 
drop (#19329)
    
    Until now, the broker only attempted to evict share sessions when the
    share session cache was full and a new session was trying to register.
    With the changes in this PR, whenever a share consumer disconnects
    from the broker, the corresponding share session is evicted from the
    cache.
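    
    For illustration, the eviction hook can be pictured as a session cache
    that also indexes sessions by client connection id, so that the
    listener registered through
    ShareSessionCache.connectionDisconnectListener() (see the
    BrokerServer.scala hunk below) can drop the matching session when the
    connection closes. The sketch below is not the actual ShareSessionCache
    code; the class and member names are hypothetical.
    
        // Illustrative sketch only: a session cache keyed by client
        // connection id. ConnectionKeyedSessionCache, SessionKey, register
        // and onConnectionDrop are made-up names for this example.
        import java.util.Map;
        import java.util.concurrent.ConcurrentHashMap;

        class ConnectionKeyedSessionCache {
            // session key -> session state
            private final Map<SessionKey, Object> sessions = new ConcurrentHashMap<>();
            // client connection id -> session key, recorded when the session is created
            private final Map<String, SessionKey> sessionsByConnection = new ConcurrentHashMap<>();

            void register(String connectionId, SessionKey key, Object session) {
                sessions.put(key, session);
                sessionsByConnection.put(connectionId, key);
            }

            // Invoked when the socket server reports a dropped client connection.
            void onConnectionDrop(String connectionId) {
                SessionKey key = sessionsByConnection.remove(connectionId);
                if (key != null) {
                    sessions.remove(key); // evict the share session immediately
                }
            }

            record SessionKey(String groupId, String memberId) { }
        }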
    
    Note - `connectAndReceiveWithoutClosingSocket` has been introduced in
    `GroupCoordinatorBaseRequestTest`. This method creates a socket
    connection, sends the request and receives a response, but does not
    close the connection. Instead, the sockets are stored in a ListBuffer
    `openSockets`, which is drained and closed in the tearDown method
    after each test runs. Also, all the `connectAndReceive` calls in
    `ShareFetchAcknowledgeRequestTest` have been replaced by
    `connectAndReceiveWithoutClosingSocket`, because these tests depend on
    share sessions persisting on the broker once registered. With the new
    code, as soon as the socket connection is closed, the broker treats it
    as a connection drop and evicts the session.
    
    Reviewers: Apoorv Mittal <apoorvmitta...@gmail.com>, Andrew Schofield 
<aschofi...@confluent.io>
---
 .../kafka/clients/consumer/ShareConsumerTest.java  |  52 ++++++
 .../kafka/server/share/SharePartitionManager.java  |  16 +-
 .../src/main/scala/kafka/server/BrokerServer.scala |  10 +-
 core/src/main/scala/kafka/server/KafkaApis.scala   |   2 +-
 .../server/share/SharePartitionManagerTest.java    |  93 +++++-----
 .../server/GroupCoordinatorBaseRequestTest.scala   |  29 +++
 .../scala/unit/kafka/server/KafkaApisTest.scala    |  62 ++++---
 .../server/ShareFetchAcknowledgeRequestTest.scala  | 200 +++++++++++++++------
 .../server/ShareGroupHeartbeatRequestTest.scala    |  74 ++++++++
 .../server/share/session/ShareSessionCache.java    |  35 +++-
 .../share/session/ShareSessionCacheTest.java       |  46 ++++-
 11 files changed, 467 insertions(+), 152 deletions(-)

diff --git 
a/clients/clients-integration-tests/src/test/java/org/apache/kafka/clients/consumer/ShareConsumerTest.java
 
b/clients/clients-integration-tests/src/test/java/org/apache/kafka/clients/consumer/ShareConsumerTest.java
index 4aa08032463..51ec647fd7c 100644
--- 
a/clients/clients-integration-tests/src/test/java/org/apache/kafka/clients/consumer/ShareConsumerTest.java
+++ 
b/clients/clients-integration-tests/src/test/java/org/apache/kafka/clients/consumer/ShareConsumerTest.java
@@ -36,6 +36,7 @@ import org.apache.kafka.common.TopicIdPartition;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.Uuid;
 import org.apache.kafka.common.config.ConfigResource;
+import org.apache.kafka.common.errors.GroupMaxSizeReachedException;
 import org.apache.kafka.common.errors.InterruptException;
 import org.apache.kafka.common.errors.InvalidConfigurationException;
 import org.apache.kafka.common.errors.InvalidRecordStateException;
@@ -2057,6 +2058,57 @@ public class ShareConsumerTest {
         verifyShareGroupStateTopicRecordsProduced();
     }
 
+    @ClusterTest(
+        brokers = 1,
+        serverProperties = {
+            @ClusterConfigProperty(key = "auto.create.topics.enable", value = 
"false"),
+            @ClusterConfigProperty(key = 
"group.coordinator.rebalance.protocols", value = "classic,consumer,share"),
+            @ClusterConfigProperty(key = "group.share.enable", value = "true"),
+            @ClusterConfigProperty(key = 
"group.share.partition.max.record.locks", value = "10000"),
+            @ClusterConfigProperty(key = 
"group.share.record.lock.duration.ms", value = "15000"),
+            @ClusterConfigProperty(key = "offsets.topic.replication.factor", 
value = "1"),
+            @ClusterConfigProperty(key = 
"share.coordinator.state.topic.min.isr", value = "1"),
+            @ClusterConfigProperty(key = 
"share.coordinator.state.topic.num.partitions", value = "3"),
+            @ClusterConfigProperty(key = 
"share.coordinator.state.topic.replication.factor", value = "1"),
+            @ClusterConfigProperty(key = "transaction.state.log.min.isr", 
value = "1"),
+            @ClusterConfigProperty(key = 
"transaction.state.log.replication.factor", value = "1"),
+            @ClusterConfigProperty(key = "group.share.max.size", value = "3") 
// Setting max group size to 3
+        }
+    )
+    public void testShareGroupMaxSizeConfigExceeded() throws Exception {
+        // creating 3 consumers in the group1
+        ShareConsumer<byte[], byte[]> shareConsumer1 = 
createShareConsumer("group1");
+        ShareConsumer<byte[], byte[]> shareConsumer2 = 
createShareConsumer("group1");
+        ShareConsumer<byte[], byte[]> shareConsumer3 = 
createShareConsumer("group1");
+
+        shareConsumer1.subscribe(Set.of(tp.topic()));
+        shareConsumer2.subscribe(Set.of(tp.topic()));
+        shareConsumer3.subscribe(Set.of(tp.topic()));
+
+        shareConsumer1.poll(Duration.ofMillis(5000));
+        shareConsumer2.poll(Duration.ofMillis(5000));
+        shareConsumer3.poll(Duration.ofMillis(5000));
+
+        ShareConsumer<byte[], byte[]> shareConsumer4 = 
createShareConsumer("group1");
+        shareConsumer4.subscribe(Set.of(tp.topic()));
+
+        TestUtils.waitForCondition(() -> {
+            try {
+                shareConsumer4.poll(Duration.ofMillis(5000));
+            } catch (GroupMaxSizeReachedException e) {
+                return true;
+            } catch (Exception e) {
+                return false;
+            }
+            return false;
+        }, 30000, 200L, () -> "The 4th consumer was not kicked out of the 
group");
+
+        shareConsumer1.close();
+        shareConsumer2.close();
+        shareConsumer3.close();
+        shareConsumer4.close();
+    }
+
     @ClusterTest
     public void testReadCommittedIsolationLevel() {
         alterShareAutoOffsetReset("group1", "earliest");
diff --git a/core/src/main/java/kafka/server/share/SharePartitionManager.java 
b/core/src/main/java/kafka/server/share/SharePartitionManager.java
index a53f846a01c..44af40ec8f8 100644
--- a/core/src/main/java/kafka/server/share/SharePartitionManager.java
+++ b/core/src/main/java/kafka/server/share/SharePartitionManager.java
@@ -420,11 +420,18 @@ public class SharePartitionManager implements 
AutoCloseable {
      * @param shareFetchData The topic-partitions in the share fetch request.
      * @param toForget The topic-partitions to forget present in the share 
fetch request.
      * @param reqMetadata The metadata in the share fetch request.
-     * @param isAcknowledgeDataPresent This tells whether the fetch request 
received includes piggybacked acknowledgements or not
+     * @param isAcknowledgeDataPresent This tells whether the fetch request 
received includes piggybacked acknowledgements or not.
+     * @param clientConnectionId The client connection id.
      * @return The new share fetch context object
      */
-    public ShareFetchContext newContext(String groupId, List<TopicIdPartition> 
shareFetchData,
-                                        List<TopicIdPartition> toForget, 
ShareRequestMetadata reqMetadata, Boolean isAcknowledgeDataPresent) {
+    public ShareFetchContext newContext(
+        String groupId,
+        List<TopicIdPartition> shareFetchData,
+        List<TopicIdPartition> toForget,
+        ShareRequestMetadata reqMetadata,
+        Boolean isAcknowledgeDataPresent,
+        String clientConnectionId
+    ) {
         ShareFetchContext context;
         // If the request's epoch is FINAL_EPOCH or INITIAL_EPOCH, we should 
remove the existing sessions. Also, start a
         // new session in case it is INITIAL_EPOCH. Hence, we need to treat 
them as special cases.
@@ -448,7 +455,8 @@ public class SharePartitionManager implements AutoCloseable 
{
                         ImplicitLinkedHashCollection<>(shareFetchData.size());
                 shareFetchData.forEach(topicIdPartition ->
                     cachedSharePartitions.mustAdd(new 
CachedSharePartition(topicIdPartition, false)));
-                ShareSessionKey responseShareSessionKey = 
cache.maybeCreateSession(groupId, reqMetadata.memberId(), 
cachedSharePartitions);
+                ShareSessionKey responseShareSessionKey = 
cache.maybeCreateSession(groupId, reqMetadata.memberId(),
+                    cachedSharePartitions, clientConnectionId);
                 if (responseShareSessionKey == null) {
                     log.error("Could not create a share session for group {} 
member {}", groupId, reqMetadata.memberId());
                     throw Errors.SHARE_SESSION_NOT_FOUND.exception();
diff --git a/core/src/main/scala/kafka/server/BrokerServer.scala 
b/core/src/main/scala/kafka/server/BrokerServer.scala
index 62d01d89b90..d219e6461cd 100644
--- a/core/src/main/scala/kafka/server/BrokerServer.scala
+++ b/core/src/main/scala/kafka/server/BrokerServer.scala
@@ -259,7 +259,13 @@ class BrokerServer(
         Optional.of(clientMetricsManager)
       )
 
-      val connectionDisconnectListeners = 
Seq(clientMetricsManager.connectionDisconnectListener())
+      val shareFetchSessionCache : ShareSessionCache = new 
ShareSessionCache(config.shareGroupConfig.shareGroupMaxShareSessions())
+
+      val connectionDisconnectListeners = Seq(
+        clientMetricsManager.connectionDisconnectListener(),
+        shareFetchSessionCache.connectionDisconnectListener()
+      )
+
       // Create and start the socket server acceptor threads so that the bound 
port is known.
       // Delay starting processors until the end of the initialization 
sequence to ensure
       // that credentials have been loaded before processing authentications.
@@ -426,8 +432,6 @@ class BrokerServer(
         ))
       val fetchManager = new FetchManager(Time.SYSTEM, new 
FetchSessionCache(fetchSessionCacheShards))
 
-      val shareFetchSessionCache : ShareSessionCache = new 
ShareSessionCache(config.shareGroupConfig.shareGroupMaxShareSessions())
-
       sharePartitionManager = new SharePartitionManager(
         replicaManager,
         time,
diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala 
b/core/src/main/scala/kafka/server/KafkaApis.scala
index e6dda8626b8..1e5180ae8db 100644
--- a/core/src/main/scala/kafka/server/KafkaApis.scala
+++ b/core/src/main/scala/kafka/server/KafkaApis.scala
@@ -3041,7 +3041,7 @@ class KafkaApis(val requestChannel: RequestChannel,
 
     try {
       // Creating the shareFetchContext for Share Session Handling. if context 
creation fails, the request is failed directly here.
-      shareFetchContext = sharePartitionManager.newContext(groupId, 
shareFetchData, forgottenTopics, newReqMetadata, isAcknowledgeDataPresent)
+      shareFetchContext = sharePartitionManager.newContext(groupId, 
shareFetchData, forgottenTopics, newReqMetadata, isAcknowledgeDataPresent, 
request.context.connectionId)
     } catch {
       case e: Exception =>
         requestHelper.sendMaybeThrottle(request, 
shareFetchRequest.getErrorResponse(AbstractResponse.DEFAULT_THROTTLE_TIME, e))
diff --git 
a/core/src/test/java/kafka/server/share/SharePartitionManagerTest.java 
b/core/src/test/java/kafka/server/share/SharePartitionManagerTest.java
index a45724dd1e7..0b3f8e4828d 100644
--- a/core/src/test/java/kafka/server/share/SharePartitionManagerTest.java
+++ b/core/src/test/java/kafka/server/share/SharePartitionManagerTest.java
@@ -153,6 +153,7 @@ public class SharePartitionManagerTest {
         FetchRequest.ORDINARY_CONSUMER_ID, -1, DELAYED_SHARE_FETCH_MAX_WAIT_MS,
         1, 1024 * 1024, FetchIsolation.HIGH_WATERMARK, Optional.empty(), true);
     private static final String TIMER_NAME_PREFIX = "share-partition-manager";
+    private static final String CONNECTION_ID = "id-1";
 
     static final int DELAYED_SHARE_FETCH_PURGATORY_PURGE_INTERVAL = 1000;
 
@@ -200,12 +201,12 @@ public class SharePartitionManagerTest {
         List<TopicIdPartition> reqData1 = List.of(tp0, tp1);
 
         ShareRequestMetadata reqMetadata1 = new ShareRequestMetadata(memberId, 
ShareRequestMetadata.INITIAL_EPOCH);
-        ShareFetchContext context1 = sharePartitionManager.newContext(groupId, 
reqData1, EMPTY_PART_LIST, reqMetadata1, false);
+        ShareFetchContext context1 = sharePartitionManager.newContext(groupId, 
reqData1, EMPTY_PART_LIST, reqMetadata1, false, CONNECTION_ID);
         assertInstanceOf(ShareSessionContext.class, context1);
         assertFalse(((ShareSessionContext) context1).isSubsequent());
 
         ShareRequestMetadata reqMetadata2 = new ShareRequestMetadata(memberId, 
ShareRequestMetadata.FINAL_EPOCH);
-        ShareFetchContext context2 = sharePartitionManager.newContext(groupId, 
List.of(), List.of(), reqMetadata2, true);
+        ShareFetchContext context2 = sharePartitionManager.newContext(groupId, 
List.of(), List.of(), reqMetadata2, true, CONNECTION_ID);
         assertEquals(FinalContext.class, context2.getClass());
     }
 
@@ -217,7 +218,6 @@ public class SharePartitionManagerTest {
             .build();
 
         Uuid tpId0 = Uuid.randomUuid();
-        Uuid tpId1 = Uuid.randomUuid();
         TopicIdPartition tp0 = new TopicIdPartition(tpId0, new 
TopicPartition("foo", 0));
         TopicIdPartition tp1 = new TopicIdPartition(tpId0, new 
TopicPartition("foo", 1));
 
@@ -228,16 +228,15 @@ public class SharePartitionManagerTest {
         List<TopicIdPartition> reqData1 = List.of(tp0, tp1);
 
         ShareRequestMetadata reqMetadata1 = new ShareRequestMetadata(memberId, 
ShareRequestMetadata.INITIAL_EPOCH);
-        ShareFetchContext context1 = sharePartitionManager.newContext(groupId, 
reqData1, EMPTY_PART_LIST, reqMetadata1, false);
+        ShareFetchContext context1 = sharePartitionManager.newContext(groupId, 
reqData1, EMPTY_PART_LIST, reqMetadata1, false, CONNECTION_ID);
         assertInstanceOf(ShareSessionContext.class, context1);
         assertFalse(((ShareSessionContext) context1).isSubsequent());
 
         ShareRequestMetadata reqMetadata2 = new ShareRequestMetadata(memberId, 
ShareRequestMetadata.FINAL_EPOCH);
 
-        // shareFetch is not empty, but the maxBytes of topic partition is 0, 
which means this is added only for acknowledgements.
-        // New context should be created successfully
-        List<TopicIdPartition> reqData3 = List.of(new TopicIdPartition(tpId1, 
new TopicPartition("foo", 0)));
-        ShareFetchContext context2 = sharePartitionManager.newContext(groupId, 
reqData3, List.of(), reqMetadata2, true);
+        // Sending a Request with FINAL_EPOCH. This should return a 
FinalContext.
+        List<TopicIdPartition> reqData2 = List.of(tp0, tp1);
+        ShareFetchContext context2 = sharePartitionManager.newContext(groupId, 
reqData2, EMPTY_PART_LIST, reqMetadata2, true, CONNECTION_ID);
         assertEquals(FinalContext.class, context2.getClass());
     }
 
@@ -260,16 +259,16 @@ public class SharePartitionManagerTest {
         List<TopicIdPartition> reqData1 = List.of(tp0, tp1);
 
         ShareRequestMetadata reqMetadata1 = new ShareRequestMetadata(memberId, 
ShareRequestMetadata.INITIAL_EPOCH);
-        ShareFetchContext context1 = sharePartitionManager.newContext(groupId, 
reqData1, EMPTY_PART_LIST, reqMetadata1, false);
+        ShareFetchContext context1 = sharePartitionManager.newContext(groupId, 
reqData1, EMPTY_PART_LIST, reqMetadata1, false, CONNECTION_ID);
         assertInstanceOf(ShareSessionContext.class, context1);
         assertFalse(((ShareSessionContext) context1).isSubsequent());
 
         ShareRequestMetadata reqMetadata2 = new ShareRequestMetadata(memberId, 
ShareRequestMetadata.FINAL_EPOCH);
 
         // shareFetch is not empty, and it contains tpId1, which should return 
FinalContext instance since it is FINAL_EPOCH
-        List<TopicIdPartition> reqData3 = List.of(new TopicIdPartition(tpId1, 
new TopicPartition("foo", 0)));
+        List<TopicIdPartition> reqData2 = List.of(new TopicIdPartition(tpId1, 
new TopicPartition("foo", 0)));
         assertInstanceOf(FinalContext.class,
-            sharePartitionManager.newContext(groupId, reqData3, List.of(), 
reqMetadata2, true));
+            sharePartitionManager.newContext(groupId, reqData2, List.of(), 
reqMetadata2, true, CONNECTION_ID));
     }
 
     @Test
@@ -295,7 +294,7 @@ public class SharePartitionManagerTest {
         List<TopicIdPartition> reqData2 = List.of(tp0, tp1);
 
         ShareRequestMetadata reqMetadata2 = new 
ShareRequestMetadata(Uuid.randomUuid(), ShareRequestMetadata.INITIAL_EPOCH);
-        ShareFetchContext context2 = sharePartitionManager.newContext(groupId, 
reqData2, EMPTY_PART_LIST, reqMetadata2, false);
+        ShareFetchContext context2 = sharePartitionManager.newContext(groupId, 
reqData2, EMPTY_PART_LIST, reqMetadata2, false, CONNECTION_ID);
         assertInstanceOf(ShareSessionContext.class, context2);
         assertFalse(((ShareSessionContext) context2).isSubsequent());
 
@@ -314,16 +313,16 @@ public class SharePartitionManagerTest {
 
         // Test trying to create a new session with an invalid epoch
         assertThrows(InvalidShareSessionEpochException.class, () -> 
sharePartitionManager.newContext(groupId, reqData2, EMPTY_PART_LIST,
-            new ShareRequestMetadata(shareSessionKey2.memberId(), 5), true));
+            new ShareRequestMetadata(shareSessionKey2.memberId(), 5), true, 
"id-2"));
 
         // Test trying to create a new session with a non-existent session key
         Uuid memberId4 = Uuid.randomUuid();
         assertThrows(ShareSessionNotFoundException.class, () -> 
sharePartitionManager.newContext(groupId, reqData2, EMPTY_PART_LIST,
-            new ShareRequestMetadata(memberId4, 1), true));
+            new ShareRequestMetadata(memberId4, 1), true, "id-3"));
 
         // Continue the first share session we created.
         ShareFetchContext context5 = sharePartitionManager.newContext(groupId, 
List.of(), EMPTY_PART_LIST,
-            new ShareRequestMetadata(shareSessionKey2.memberId(), 1), true);
+            new ShareRequestMetadata(shareSessionKey2.memberId(), 1), true, 
CONNECTION_ID);
         assertInstanceOf(ShareSessionContext.class, context5);
         assertTrue(((ShareSessionContext) context5).isSubsequent());
 
@@ -341,18 +340,18 @@ public class SharePartitionManagerTest {
 
         // Test setting an invalid share session epoch.
         assertThrows(InvalidShareSessionEpochException.class, () -> 
sharePartitionManager.newContext(groupId, reqData2, EMPTY_PART_LIST,
-            new ShareRequestMetadata(shareSessionKey2.memberId(), 5), true));
+            new ShareRequestMetadata(shareSessionKey2.memberId(), 5), true, 
CONNECTION_ID));
 
         // Test generating a throttled response for a subsequent share session
         ShareFetchContext context7 = sharePartitionManager.newContext(groupId, 
List.of(), EMPTY_PART_LIST,
-            new ShareRequestMetadata(shareSessionKey2.memberId(), 2), true);
+            new ShareRequestMetadata(shareSessionKey2.memberId(), 2), true, 
CONNECTION_ID);
         ShareFetchResponse resp7 = context7.throttleResponse(100);
         assertEquals(Errors.NONE, resp7.error());
         assertEquals(100, resp7.throttleTimeMs());
 
         // Get the final share session.
         ShareFetchContext context8 = sharePartitionManager.newContext(groupId, 
List.of(), EMPTY_PART_LIST,
-            new ShareRequestMetadata(reqMetadata2.memberId(), 
ShareRequestMetadata.FINAL_EPOCH), true);
+            new ShareRequestMetadata(reqMetadata2.memberId(), 
ShareRequestMetadata.FINAL_EPOCH), true, CONNECTION_ID);
         assertEquals(FinalContext.class, context8.getClass());
         assertEquals(1, cache.size());
 
@@ -389,7 +388,7 @@ public class SharePartitionManagerTest {
         String groupId = "grp";
         ShareRequestMetadata reqMetadata1 = new 
ShareRequestMetadata(Uuid.randomUuid(), ShareRequestMetadata.INITIAL_EPOCH);
 
-        ShareFetchContext context1 = sharePartitionManager.newContext(groupId, 
reqData1, EMPTY_PART_LIST, reqMetadata1, false);
+        ShareFetchContext context1 = sharePartitionManager.newContext(groupId, 
reqData1, EMPTY_PART_LIST, reqMetadata1, false, CONNECTION_ID);
         assertInstanceOf(ShareSessionContext.class, context1);
 
         LinkedHashMap<TopicIdPartition, ShareFetchResponseData.PartitionData> 
respData1 = new LinkedHashMap<>();
@@ -405,7 +404,7 @@ public class SharePartitionManagerTest {
         List<TopicIdPartition> removed2 = new ArrayList<>();
         removed2.add(tp0);
         ShareFetchContext context2 = sharePartitionManager.newContext(groupId, 
reqData2, removed2,
-                new ShareRequestMetadata(reqMetadata1.memberId(), 1), true);
+                new ShareRequestMetadata(reqMetadata1.memberId(), 1), true, 
CONNECTION_ID);
         assertInstanceOf(ShareSessionContext.class, context2);
 
         Set<TopicIdPartition> expectedTopicIdPartitions2 = new HashSet<>();
@@ -452,7 +451,7 @@ public class SharePartitionManagerTest {
         String groupId = "grp";
         ShareRequestMetadata reqMetadata1 = new 
ShareRequestMetadata(Uuid.randomUuid(), ShareRequestMetadata.INITIAL_EPOCH);
 
-        ShareFetchContext context1 = sharePartitionManager.newContext(groupId, 
reqData1, EMPTY_PART_LIST, reqMetadata1, false);
+        ShareFetchContext context1 = sharePartitionManager.newContext(groupId, 
reqData1, EMPTY_PART_LIST, reqMetadata1, false, CONNECTION_ID);
         assertInstanceOf(ShareSessionContext.class, context1);
 
         LinkedHashMap<TopicIdPartition, ShareFetchResponseData.PartitionData> 
respData1 = new LinkedHashMap<>();
@@ -469,7 +468,7 @@ public class SharePartitionManagerTest {
         removed2.add(foo0);
         removed2.add(foo1);
         ShareFetchContext context2 = sharePartitionManager.newContext(groupId, 
List.of(), removed2,
-                new ShareRequestMetadata(reqMetadata1.memberId(), 1), true);
+                new ShareRequestMetadata(reqMetadata1.memberId(), 1), true, 
CONNECTION_ID);
         assertInstanceOf(ShareSessionContext.class, context2);
 
         LinkedHashMap<TopicIdPartition, ShareFetchResponseData.PartitionData> 
respData2 = new LinkedHashMap<>();
@@ -495,14 +494,14 @@ public class SharePartitionManagerTest {
 
         List<TopicIdPartition> reqData1 = List.of(foo, bar);
 
-        ShareFetchContext context1 = sharePartitionManager.newContext(groupId, 
reqData1, EMPTY_PART_LIST, reqMetadata1, false);
+        ShareFetchContext context1 = sharePartitionManager.newContext(groupId, 
reqData1, EMPTY_PART_LIST, reqMetadata1, false, CONNECTION_ID);
         assertInstanceOf(ShareSessionContext.class, context1);
         assertPartitionsPresent((ShareSessionContext) context1, List.of(foo, 
bar));
 
         mockUpdateAndGenerateResponseData(context1, groupId, 
reqMetadata1.memberId());
 
         ShareFetchContext context2 = sharePartitionManager.newContext(groupId, 
List.of(), List.of(foo),
-                new ShareRequestMetadata(reqMetadata1.memberId(), 1), true);
+                new ShareRequestMetadata(reqMetadata1.memberId(), 1), true, 
CONNECTION_ID);
 
         // So foo is removed but not the others.
         assertPartitionsPresent((ShareSessionContext) context2, List.of(bar));
@@ -510,7 +509,7 @@ public class SharePartitionManagerTest {
         mockUpdateAndGenerateResponseData(context2, groupId, 
reqMetadata1.memberId());
 
         ShareFetchContext context3 = sharePartitionManager.newContext(groupId, 
List.of(), List.of(bar),
-                new ShareRequestMetadata(reqMetadata1.memberId(), 2), true);
+                new ShareRequestMetadata(reqMetadata1.memberId(), 2), true, 
CONNECTION_ID);
         assertPartitionsPresent((ShareSessionContext) context3, List.of());
     }
 
@@ -537,7 +536,7 @@ public class SharePartitionManagerTest {
         List<TopicIdPartition> reqData1 = List.of(foo, bar);
 
         ShareRequestMetadata reqMetadata1 = new 
ShareRequestMetadata(Uuid.randomUuid(), ShareRequestMetadata.INITIAL_EPOCH);
-        ShareFetchContext context1 = sharePartitionManager.newContext(groupId, 
reqData1, EMPTY_PART_LIST, reqMetadata1, false);
+        ShareFetchContext context1 = sharePartitionManager.newContext(groupId, 
reqData1, EMPTY_PART_LIST, reqMetadata1, false, CONNECTION_ID);
 
         assertInstanceOf(ShareSessionContext.class, context1);
         assertFalse(((ShareSessionContext) context1).isSubsequent());
@@ -553,7 +552,7 @@ public class SharePartitionManagerTest {
 
         // Create a subsequent share fetch request as though no topics changed.
         ShareFetchContext context2 = sharePartitionManager.newContext(groupId, 
List.of(), EMPTY_PART_LIST,
-                new ShareRequestMetadata(reqMetadata1.memberId(), 1), true);
+                new ShareRequestMetadata(reqMetadata1.memberId(), 1), true, 
CONNECTION_ID);
 
         assertInstanceOf(ShareSessionContext.class, context2);
         assertTrue(((ShareSessionContext) context2).isSubsequent());
@@ -587,7 +586,7 @@ public class SharePartitionManagerTest {
         List<TopicIdPartition> reqData2 = List.of(tp0, tp1, tpNull1);
 
         ShareRequestMetadata reqMetadata2 = new 
ShareRequestMetadata(Uuid.randomUuid(), ShareRequestMetadata.INITIAL_EPOCH);
-        ShareFetchContext context2 = sharePartitionManager.newContext(groupId, 
reqData2, EMPTY_PART_LIST, reqMetadata2, false);
+        ShareFetchContext context2 = sharePartitionManager.newContext(groupId, 
reqData2, EMPTY_PART_LIST, reqMetadata2, false, CONNECTION_ID);
         assertInstanceOf(ShareSessionContext.class, context2);
         assertFalse(((ShareSessionContext) context2).isSubsequent());
         
assertErroneousAndValidTopicIdPartitions(context2.getErroneousAndValidTopicIdPartitions(),
 List.of(tpNull1), List.of(tp0, tp1));
@@ -609,15 +608,15 @@ public class SharePartitionManagerTest {
 
         // Test trying to create a new session with an invalid epoch
         assertThrows(InvalidShareSessionEpochException.class, () -> 
sharePartitionManager.newContext(groupId, reqData2, EMPTY_PART_LIST,
-                new ShareRequestMetadata(shareSessionKey2.memberId(), 5), 
true));
+                new ShareRequestMetadata(shareSessionKey2.memberId(), 5), 
true, CONNECTION_ID));
 
         // Test trying to create a new session with a non-existent session key
         assertThrows(ShareSessionNotFoundException.class, () -> 
sharePartitionManager.newContext(groupId, reqData2, EMPTY_PART_LIST,
-                new ShareRequestMetadata(Uuid.randomUuid(), 1), true));
+                new ShareRequestMetadata(Uuid.randomUuid(), 1), true, 
CONNECTION_ID));
 
         // Continue the first share session we created.
         ShareFetchContext context5 = sharePartitionManager.newContext(groupId, 
List.of(), EMPTY_PART_LIST,
-                new ShareRequestMetadata(shareSessionKey2.memberId(), 1), 
true);
+                new ShareRequestMetadata(shareSessionKey2.memberId(), 1), 
true, CONNECTION_ID);
         assertInstanceOf(ShareSessionContext.class, context5);
         assertTrue(((ShareSessionContext) context5).isSubsequent());
 
@@ -628,12 +627,12 @@ public class SharePartitionManagerTest {
 
         // Test setting an invalid share session epoch.
         assertThrows(InvalidShareSessionEpochException.class, () -> 
sharePartitionManager.newContext(groupId, reqData2, EMPTY_PART_LIST,
-                new ShareRequestMetadata(shareSessionKey2.memberId(), 5), 
true));
+                new ShareRequestMetadata(shareSessionKey2.memberId(), 5), 
true, CONNECTION_ID));
 
         // Test generating a throttled response for a subsequent share session
         List<TopicIdPartition> reqData7 = List.of(tpNull2);
         ShareFetchContext context7 = sharePartitionManager.newContext(groupId, 
reqData7, EMPTY_PART_LIST,
-                new ShareRequestMetadata(shareSessionKey2.memberId(), 2), 
true);
+                new ShareRequestMetadata(shareSessionKey2.memberId(), 2), 
true, CONNECTION_ID);
         // Check for throttled response
         ShareFetchResponse resp7 = context7.throttleResponse(100);
         assertEquals(Errors.NONE, resp7.error());
@@ -643,7 +642,7 @@ public class SharePartitionManagerTest {
 
         // Get the final share session.
         ShareFetchContext context8 = sharePartitionManager.newContext(groupId, 
List.of(), EMPTY_PART_LIST,
-                new ShareRequestMetadata(reqMetadata2.memberId(), 
ShareRequestMetadata.FINAL_EPOCH), true);
+                new ShareRequestMetadata(reqMetadata2.memberId(), 
ShareRequestMetadata.FINAL_EPOCH), true, CONNECTION_ID);
         assertEquals(FinalContext.class, context8.getClass());
         assertEquals(1, cache.size());
 
@@ -688,7 +687,7 @@ public class SharePartitionManagerTest {
         short version = ApiKeys.SHARE_FETCH.latestVersion();
 
         ShareRequestMetadata reqMetadata2 = new 
ShareRequestMetadata(Uuid.randomUuid(), ShareRequestMetadata.INITIAL_EPOCH);
-        ShareFetchContext context2 = sharePartitionManager.newContext(groupId, 
reqData2, EMPTY_PART_LIST, reqMetadata2, false);
+        ShareFetchContext context2 = sharePartitionManager.newContext(groupId, 
reqData2, EMPTY_PART_LIST, reqMetadata2, false, CONNECTION_ID);
         assertInstanceOf(ShareSessionContext.class, context2);
         assertFalse(((ShareSessionContext) context2).isSubsequent());
 
@@ -708,17 +707,17 @@ public class SharePartitionManagerTest {
 
         // Test trying to create a new session with an invalid epoch
         assertThrows(InvalidShareSessionEpochException.class, () -> 
sharePartitionManager.newContext(groupId, reqData2, EMPTY_PART_LIST,
-                new ShareRequestMetadata(shareSessionKey2.memberId(), 5), 
true));
+                new ShareRequestMetadata(shareSessionKey2.memberId(), 5), 
true, CONNECTION_ID));
 
         // Test trying to create a new session with a non-existent session key
         Uuid memberId4 = Uuid.randomUuid();
         assertThrows(ShareSessionNotFoundException.class, () -> 
sharePartitionManager.newContext(groupId, reqData2, EMPTY_PART_LIST,
-                new ShareRequestMetadata(memberId4, 1), true));
+                new ShareRequestMetadata(memberId4, 1), true, CONNECTION_ID));
 
         // Continue the first share session we created.
         List<TopicIdPartition> reqData5 = List.of(tp2);
         ShareFetchContext context5 = sharePartitionManager.newContext(groupId, 
reqData5, EMPTY_PART_LIST,
-                new ShareRequestMetadata(shareSessionKey2.memberId(), 1), 
true);
+                new ShareRequestMetadata(shareSessionKey2.memberId(), 1), 
true, CONNECTION_ID);
         assertInstanceOf(ShareSessionContext.class, context5);
         assertTrue(((ShareSessionContext) context5).isSubsequent());
 
@@ -733,11 +732,11 @@ public class SharePartitionManagerTest {
 
         // Test setting an invalid share session epoch.
         assertThrows(InvalidShareSessionEpochException.class, () -> 
sharePartitionManager.newContext(groupId, reqData2, EMPTY_PART_LIST,
-                new ShareRequestMetadata(shareSessionKey2.memberId(), 5), 
true));
+                new ShareRequestMetadata(shareSessionKey2.memberId(), 5), 
true, CONNECTION_ID));
 
         // Test generating a throttled response for a subsequent share session
         ShareFetchContext context7 = sharePartitionManager.newContext(groupId, 
List.of(), EMPTY_PART_LIST,
-                new ShareRequestMetadata(shareSessionKey2.memberId(), 2), 
true);
+                new ShareRequestMetadata(shareSessionKey2.memberId(), 2), 
true, CONNECTION_ID);
 
         int respSize7 = context7.responseSize(respData2, version);
         ShareFetchResponse resp7 = context7.throttleResponse(100);
@@ -748,7 +747,7 @@ public class SharePartitionManagerTest {
 
         // Get the final share session.
         ShareFetchContext context8 = sharePartitionManager.newContext(groupId, 
List.of(), EMPTY_PART_LIST,
-                new ShareRequestMetadata(reqMetadata2.memberId(), 
ShareRequestMetadata.FINAL_EPOCH), true);
+                new ShareRequestMetadata(reqMetadata2.memberId(), 
ShareRequestMetadata.FINAL_EPOCH), true, CONNECTION_ID);
         assertEquals(FinalContext.class, context8.getClass());
         assertEquals(1, cache.size());
 
@@ -794,7 +793,7 @@ public class SharePartitionManagerTest {
         List<TopicIdPartition> reqData1 = List.of(tp0, tp1);
 
         ShareRequestMetadata reqMetadata1 = new 
ShareRequestMetadata(memberId1, ShareRequestMetadata.INITIAL_EPOCH);
-        ShareFetchContext context1 = sharePartitionManager.newContext(groupId, 
reqData1, EMPTY_PART_LIST, reqMetadata1, false);
+        ShareFetchContext context1 = sharePartitionManager.newContext(groupId, 
reqData1, EMPTY_PART_LIST, reqMetadata1, false, CONNECTION_ID);
         assertInstanceOf(ShareSessionContext.class, context1);
         assertFalse(((ShareSessionContext) context1).isSubsequent());
 
@@ -815,7 +814,7 @@ public class SharePartitionManagerTest {
         List<TopicIdPartition> reqData2 = List.of(tp2);
 
         ShareRequestMetadata reqMetadata2 = new 
ShareRequestMetadata(memberId2, ShareRequestMetadata.INITIAL_EPOCH);
-        ShareFetchContext context2 = sharePartitionManager.newContext(groupId, 
reqData2, EMPTY_PART_LIST, reqMetadata2, false);
+        ShareFetchContext context2 = sharePartitionManager.newContext(groupId, 
reqData2, EMPTY_PART_LIST, reqMetadata2, false, CONNECTION_ID);
         assertInstanceOf(ShareSessionContext.class, context2);
         assertFalse(((ShareSessionContext) context2).isSubsequent());
 
@@ -833,7 +832,7 @@ public class SharePartitionManagerTest {
         // Continue the first share session we created.
         List<TopicIdPartition> reqData3 = List.of(tp2);
         ShareFetchContext context3 = sharePartitionManager.newContext(groupId, 
reqData3, EMPTY_PART_LIST,
-                new ShareRequestMetadata(shareSessionKey1.memberId(), 1), 
true);
+                new ShareRequestMetadata(shareSessionKey1.memberId(), 1), 
true, CONNECTION_ID);
         assertInstanceOf(ShareSessionContext.class, context3);
         assertTrue(((ShareSessionContext) context3).isSubsequent());
 
@@ -848,7 +847,7 @@ public class SharePartitionManagerTest {
         // Continue the second session we created.
         List<TopicIdPartition> reqData4 = List.of(tp3);
         ShareFetchContext context4 = sharePartitionManager.newContext(groupId, 
reqData4, List.of(tp2),
-                new ShareRequestMetadata(shareSessionKey2.memberId(), 1), 
true);
+                new ShareRequestMetadata(shareSessionKey2.memberId(), 1), 
true, CONNECTION_ID);
         assertInstanceOf(ShareSessionContext.class, context4);
         assertTrue(((ShareSessionContext) context4).isSubsequent());
 
@@ -861,7 +860,7 @@ public class SharePartitionManagerTest {
 
         // Get the final share session.
         ShareFetchContext context5 = sharePartitionManager.newContext(groupId, 
List.of(), EMPTY_PART_LIST,
-                new ShareRequestMetadata(reqMetadata1.memberId(), 
ShareRequestMetadata.FINAL_EPOCH), true);
+                new ShareRequestMetadata(reqMetadata1.memberId(), 
ShareRequestMetadata.FINAL_EPOCH), true, CONNECTION_ID);
         assertEquals(FinalContext.class, context5.getClass());
 
         LinkedHashMap<TopicIdPartition, ShareFetchResponseData.PartitionData> 
respData5 = new LinkedHashMap<>();
@@ -876,7 +875,7 @@ public class SharePartitionManagerTest {
 
         // Continue the second share session .
         ShareFetchContext context6 = sharePartitionManager.newContext(groupId, 
List.of(), List.of(tp3),
-                new ShareRequestMetadata(shareSessionKey2.memberId(), 2), 
true);
+                new ShareRequestMetadata(shareSessionKey2.memberId(), 2), 
true, CONNECTION_ID);
         assertInstanceOf(ShareSessionContext.class, context6);
         assertTrue(((ShareSessionContext) context6).isSubsequent());
 
diff --git 
a/core/src/test/scala/unit/kafka/server/GroupCoordinatorBaseRequestTest.scala 
b/core/src/test/scala/unit/kafka/server/GroupCoordinatorBaseRequestTest.scala
index df939c29ffb..e7926df3e36 100644
--- 
a/core/src/test/scala/unit/kafka/server/GroupCoordinatorBaseRequestTest.scala
+++ 
b/core/src/test/scala/unit/kafka/server/GroupCoordinatorBaseRequestTest.scala
@@ -33,9 +33,11 @@ import org.apache.kafka.common.utils.ProducerIdAndEpoch
 import 
org.apache.kafka.controller.ControllerRequestContextUtil.ANONYMOUS_CONTEXT
 import org.junit.jupiter.api.Assertions.{assertEquals, fail}
 
+import java.net.Socket
 import java.util.{Comparator, Properties}
 import java.util.stream.Collectors
 import scala.collection.Seq
+import scala.collection.mutable.ListBuffer
 import scala.jdk.CollectionConverters._
 import scala.reflect.ClassTag
 
@@ -46,6 +48,8 @@ class GroupCoordinatorBaseRequestTest(cluster: 
ClusterInstance) {
 
   protected var producer: KafkaProducer[String, String] = _
 
+  protected var openSockets: ListBuffer[Socket] = ListBuffer[Socket]()
+
   protected def createOffsetsTopic(): Unit = {
     val admin = cluster.admin()
     try {
@@ -140,6 +144,14 @@ class GroupCoordinatorBaseRequestTest(cluster: 
ClusterInstance) {
       keySerializer = new StringSerializer, valueSerializer = new 
StringSerializer)
   }
 
+  protected def closeSockets(): Unit = {
+    while (openSockets.nonEmpty) {
+      val socket = openSockets.head
+      socket.close()
+      openSockets.remove(0)
+    }
+  }
+
   protected def closeProducer(): Unit = {
     if(producer != null)
       producer.close()
@@ -922,6 +934,23 @@ class GroupCoordinatorBaseRequestTest(cluster: 
ClusterInstance) {
     )
   }
 
+  protected def connectAndReceiveWithoutClosingSocket[T <: AbstractResponse](
+    request: AbstractRequest,
+    destination: Int
+  )(implicit classTag: ClassTag[T]): T = {
+    val socket = IntegrationTestUtils.connect(brokerSocketServer(destination), 
cluster.clientListener())
+    openSockets += socket
+    IntegrationTestUtils.sendAndReceive[T](request, socket)
+  }
+
+  protected def connectAndReceiveWithoutClosingSocket[T <: AbstractResponse](
+    request: AbstractRequest
+  )(implicit classTag: ClassTag[T]): T = {
+    val socket = IntegrationTestUtils.connect(cluster.anyBrokerSocketServer(), 
cluster.clientListener())
+    openSockets += socket
+    IntegrationTestUtils.sendAndReceive[T](request, socket)
+  }
+
   private def brokerSocketServer(brokerId: Int): SocketServer = {
     getBrokers.find { broker =>
       broker.config.brokerId == brokerId
diff --git a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala 
b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
index 093ef4943f9..18c1d7b3b6c 100644
--- a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
+++ b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
@@ -4188,7 +4188,7 @@ class KafkaApisTest extends Logging {
       ).asJava)
     )
 
-    when(sharePartitionManager.newContext(any(), any(), any(), any(), 
any())).thenReturn(
+    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any())).thenReturn(
       new ShareSessionContext(new ShareRequestMetadata(memberId, 
shareSessionEpoch), util.List.of(
         new TopicIdPartition(topicId, partitionIndex, topicName)))
     )
@@ -4257,10 +4257,11 @@ class KafkaApisTest extends Logging {
     cachedSharePartitions.mustAdd(new CachedSharePartition(
       new TopicIdPartition(topicId, partitionIndex, topicName), false))
 
-    when(sharePartitionManager.newContext(any(), any(), any(), any(), 
any())).thenThrow(
+    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any())).thenThrow(
       Errors.INVALID_REQUEST.exception()
-    ).thenReturn(new ShareSessionContext(new ShareRequestMetadata(memberId, 1),
-      new ShareSession(new ShareSessionKey(groupId, memberId), 
cachedSharePartitions, 2)))
+    ).thenReturn(new ShareSessionContext(new ShareRequestMetadata(memberId, 
1), new ShareSession(
+      new ShareSessionKey(groupId, memberId), cachedSharePartitions, 2
+    )))
 
     when(clientQuotaManager.maybeRecordAndGetThrottleTimeMs(
       any[RequestChannel.Request](), anyDouble, anyLong)).thenReturn(0)
@@ -4351,7 +4352,7 @@ class KafkaApisTest extends Logging {
       ).asJava)
     )
 
-    when(sharePartitionManager.newContext(any(), any(), any(), any(), 
any())).thenReturn(
+    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any())).thenReturn(
       new ShareSessionContext(new ShareRequestMetadata(memberId, 0), 
util.List.of(
         new TopicIdPartition(topicId, partitionIndex, topicName)
       ))
@@ -4436,7 +4437,7 @@ class KafkaApisTest extends Logging {
       FutureUtils.failedFuture[util.Map[TopicIdPartition, 
ShareFetchResponseData.PartitionData]](Errors.UNKNOWN_SERVER_ERROR.exception())
     )
 
-    when(sharePartitionManager.newContext(any(), any(), any(), any(), 
any())).thenReturn(
+    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any())).thenReturn(
       new ShareSessionContext(new ShareRequestMetadata(memberId, 0), 
util.List.of(
         new TopicIdPartition(topicId, partitionIndex, topicName)
       ))
@@ -4501,9 +4502,10 @@ class KafkaApisTest extends Logging {
     cachedSharePartitions.mustAdd(new CachedSharePartition(
       new TopicIdPartition(topicId, partitionIndex, topicName), false))
 
-    when(sharePartitionManager.newContext(any(), any(), any(), any(), any()))
-      .thenReturn(new ShareSessionContext(new ShareRequestMetadata(memberId, 
1),
-        new ShareSession(new ShareSessionKey(groupId, memberId), 
cachedSharePartitions, 2)))
+    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any()))
+      .thenReturn(new ShareSessionContext(new ShareRequestMetadata(memberId, 
1), new ShareSession(
+        new ShareSessionKey(groupId, memberId), cachedSharePartitions, 2))
+      )
 
     when(clientQuotaManager.maybeRecordAndGetThrottleTimeMs(
       any[RequestChannel.Request](), anyDouble, anyLong)).thenReturn(0)
@@ -4559,9 +4561,10 @@ class KafkaApisTest extends Logging {
     cachedSharePartitions.mustAdd(new CachedSharePartition(
       new TopicIdPartition(topicId, partitionIndex, topicName), false))
 
-    when(sharePartitionManager.newContext(any(), any(), any(), any(), any()))
-      .thenReturn(new ShareSessionContext(new ShareRequestMetadata(memberId, 
1),
-        new ShareSession(new ShareSessionKey(groupId, memberId), 
cachedSharePartitions, 2)))
+    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any()))
+      .thenReturn(new ShareSessionContext(new ShareRequestMetadata(memberId, 
1), new ShareSession(
+        new ShareSessionKey(groupId, memberId), cachedSharePartitions, 2))
+      )
 
     when(clientQuotaManager.maybeRecordAndGetThrottleTimeMs(
       any[RequestChannel.Request](), anyDouble, anyLong)).thenReturn(0)
@@ -4615,7 +4618,7 @@ class KafkaApisTest extends Logging {
       ).asJava)
     )
 
-    when(sharePartitionManager.newContext(any(), any(), any(), any(), 
any())).thenReturn(
+    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any())).thenReturn(
       new ShareSessionContext(new ShareRequestMetadata(memberId, 0), 
util.List.of(
         new TopicIdPartition(topicId, partitionIndex, topicName)
       ))
@@ -4679,7 +4682,7 @@ class KafkaApisTest extends Logging {
       ).asJava)
     )
 
-    when(sharePartitionManager.newContext(any(), any(), any(), any(), 
any())).thenReturn(
+    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any())).thenReturn(
       new ShareSessionContext(new ShareRequestMetadata(memberId, 0), 
util.List.of(
         new TopicIdPartition(topicId, partitionIndex, topicName)
       ))
@@ -4765,7 +4768,7 @@ class KafkaApisTest extends Logging {
       ).asJava)
     )
 
-    when(sharePartitionManager.newContext(any(), any(), any(), any(), 
any())).thenReturn(
+    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any())).thenReturn(
       new ShareSessionContext(new ShareRequestMetadata(memberId, 0), 
util.List.of(
         new TopicIdPartition(topicId, partitionIndex, topicName)
       ))
@@ -4900,14 +4903,14 @@ class KafkaApisTest extends Logging {
       new TopicIdPartition(topicId, partitionIndex, topicName), false)
     )
 
-    when(sharePartitionManager.newContext(any(), any(), any(), any(), 
any())).thenReturn(
+    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any())).thenReturn(
       new ShareSessionContext(new ShareRequestMetadata(memberId, 0), 
util.List.of(
         new TopicIdPartition(topicId, partitionIndex, topicName)
       ))
-    ).thenReturn(new ShareSessionContext(new ShareRequestMetadata(memberId, 1),
-      new ShareSession(new ShareSessionKey(groupId, memberId), 
cachedSharePartitions, 2))
-    ).thenReturn(new ShareSessionContext(new ShareRequestMetadata(memberId, 2),
-      new ShareSession(new ShareSessionKey(groupId, memberId), 
cachedSharePartitions, 3))
+    ).thenReturn(new ShareSessionContext(new ShareRequestMetadata(memberId, 
1), new ShareSession(
+      new ShareSessionKey(groupId, memberId), cachedSharePartitions, 2))
+    ).thenReturn(new ShareSessionContext(new ShareRequestMetadata(memberId, 
2), new ShareSession(
+      new ShareSessionKey(groupId, memberId), cachedSharePartitions, 3))
     )
 
     when(clientQuotaManager.maybeRecordAndGetThrottleTimeMs(
@@ -5162,17 +5165,17 @@ class KafkaApisTest extends Logging {
       new TopicIdPartition(topicId4, 0, topicName4), false
     ))
 
-    when(sharePartitionManager.newContext(any(), any(), any(), any(), 
any())).thenReturn(
+    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any())).thenReturn(
       new ShareSessionContext(new ShareRequestMetadata(memberId, 0), 
util.List.of(
         new TopicIdPartition(topicId1, new TopicPartition(topicName1, 0)),
         new TopicIdPartition(topicId1, new TopicPartition(topicName1, 1)),
         new TopicIdPartition(topicId2, new TopicPartition(topicName2, 0)),
         new TopicIdPartition(topicId2, new TopicPartition(topicName2, 1))
       ))
-    ).thenReturn(new ShareSessionContext(new ShareRequestMetadata(memberId, 1),
-      new ShareSession(new ShareSessionKey(groupId, memberId), 
cachedSharePartitions1, 2))
-    ).thenReturn(new ShareSessionContext(new ShareRequestMetadata(memberId, 2),
-      new ShareSession(new ShareSessionKey(groupId, memberId), 
cachedSharePartitions2, 3))
+    ).thenReturn(new ShareSessionContext(new ShareRequestMetadata(memberId, 
1), new ShareSession(
+      new ShareSessionKey(groupId, memberId), cachedSharePartitions1, 2))
+    ).thenReturn(new ShareSessionContext(new ShareRequestMetadata(memberId, 
2), new ShareSession(
+      new ShareSessionKey(groupId, memberId), cachedSharePartitions2, 3))
     ).thenReturn(new FinalContext())
 
     when(sharePartitionManager.releaseSession(any(), any())).thenReturn(
@@ -6127,12 +6130,13 @@ class KafkaApisTest extends Logging {
       new TopicIdPartition(topicId, 0, topicName), false
     ))
 
-    when(sharePartitionManager.newContext(any(), any(), any(), any(), 
any())).thenReturn(
+    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any())).thenReturn(
       new ShareSessionContext(new ShareRequestMetadata(memberId, 0), 
util.List.of(
         new TopicIdPartition(topicId, partitionIndex, topicName)
       ))
-    ).thenReturn(new ShareSessionContext(new ShareRequestMetadata(memberId, 1),
-      new ShareSession(new ShareSessionKey(groupId, memberId), 
cachedSharePartitions, 2)))
+    ).thenReturn(new ShareSessionContext(new ShareRequestMetadata(memberId, 
1), new ShareSession(
+      new ShareSessionKey(groupId, memberId), cachedSharePartitions, 2))
+    )
 
     when(clientQuotaManager.maybeRecordAndGetThrottleTimeMs(
       any[RequestChannel.Request](), anyDouble, anyLong)).thenReturn(0)
@@ -6329,7 +6333,7 @@ class KafkaApisTest extends Logging {
       ).asJava)
     )
 
-    when(sharePartitionManager.newContext(any(), any(), any(), any(), 
any())).thenReturn(
+    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any())).thenReturn(
       new FinalContext()
     )
 
diff --git 
a/core/src/test/scala/unit/kafka/server/ShareFetchAcknowledgeRequestTest.scala 
b/core/src/test/scala/unit/kafka/server/ShareFetchAcknowledgeRequestTest.scala
index ef02c576dd1..7aba491536f 100644
--- 
a/core/src/test/scala/unit/kafka/server/ShareFetchAcknowledgeRequestTest.scala
+++ 
b/core/src/test/scala/unit/kafka/server/ShareFetchAcknowledgeRequestTest.scala
@@ -43,6 +43,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
   @AfterEach
   def tearDown(): Unit = {
     closeProducer
+    closeSockets
   }
 
   @ClusterTest(
@@ -59,7 +60,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     )
 
     val shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, Map.empty)
-    val shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    val shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     assertEquals(Errors.UNSUPPORTED_VERSION.code, 
shareFetchResponse.data.errorCode)
     assertEquals(0, shareFetchResponse.data.acquisitionLockTimeoutMs)
@@ -75,7 +76,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     val metadata: ShareRequestMetadata = new 
ShareRequestMetadata(Uuid.randomUuid(), ShareRequestMetadata.INITIAL_EPOCH)
 
     val shareAcknowledgeRequest = createShareAcknowledgeRequest(groupId, 
metadata, Map.empty)
-    val shareAcknowledgeResponse = 
connectAndReceive[ShareAcknowledgeResponse](shareAcknowledgeRequest)
+    val shareAcknowledgeResponse = 
connectAndReceiveWithoutClosingSocket[ShareAcknowledgeResponse](shareAcknowledgeRequest)
 
     assertEquals(Errors.UNSUPPORTED_VERSION.code, 
shareAcknowledgeResponse.data.errorCode)
   }
@@ -123,9 +124,8 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
 
     // Send the share fetch request to the non-replica and verify the error 
code
     val shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, Map.empty)
-    val shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest, nonReplicaId)
+    val shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest, 
nonReplicaId)
     assertEquals(30000, shareFetchResponse.data.acquisitionLockTimeoutMs)
-
     val partitionData = 
shareFetchResponse.responseData(topicNames).get(topicIdPartition)
     assertEquals(Errors.NOT_LEADER_OR_FOLLOWER.code, partitionData.errorCode)
     assertEquals(leader, partitionData.currentLeader().leaderId())
@@ -174,7 +174,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     val metadata = new ShareRequestMetadata(memberId, 
ShareRequestMetadata.nextEpoch(ShareRequestMetadata.INITIAL_EPOCH))
     val acknowledgementsMap: Map[TopicIdPartition, 
util.List[ShareFetchRequestData.AcknowledgementBatch]] = Map.empty
     val shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, acknowledgementsMap)
-    val shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    val shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     val shareFetchResponseData = shareFetchResponse.data()
     assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
@@ -245,7 +245,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     // as the share partitions might not be initialized yet. So, we retry 
until we get the response.
     var responses = Seq[ShareFetchResponseData.PartitionData]()
     TestUtils.waitUntilTrue(() => {
-      val shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+      val shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
       val shareFetchResponseData = shareFetchResponse.data()
       assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
       assertEquals(30000, shareFetchResponseData.acquisitionLockTimeoutMs)
@@ -340,9 +340,9 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     var shareFetchRequest2 = createShareFetchRequest(groupId, metadata, send2, 
Seq.empty, acknowledgementsMap)
     var shareFetchRequest3 = createShareFetchRequest(groupId, metadata, send3, 
Seq.empty, acknowledgementsMap)
 
-    var shareFetchResponse1 = 
connectAndReceive[ShareFetchResponse](shareFetchRequest1, destination = leader1)
-    var shareFetchResponse2 = 
connectAndReceive[ShareFetchResponse](shareFetchRequest2, destination = leader2)
-    var shareFetchResponse3 = 
connectAndReceive[ShareFetchResponse](shareFetchRequest3, destination = leader3)
+    var shareFetchResponse1 = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest1, 
destination = leader1)
+    var shareFetchResponse2 = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest2, 
destination = leader2)
+    var shareFetchResponse3 = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest3, 
destination = leader3)
 
     initProducer()
     // Producing 10 records to the topic partitions created above
@@ -356,9 +356,9 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     shareFetchRequest2 = createShareFetchRequest(groupId, metadata, send2, 
Seq.empty, acknowledgementsMap)
     shareFetchRequest3 = createShareFetchRequest(groupId, metadata, send3, 
Seq.empty, acknowledgementsMap)
 
-    shareFetchResponse1 = 
connectAndReceive[ShareFetchResponse](shareFetchRequest1, destination = leader1)
-    shareFetchResponse2 = 
connectAndReceive[ShareFetchResponse](shareFetchRequest2, destination = leader2)
-    shareFetchResponse3 = 
connectAndReceive[ShareFetchResponse](shareFetchRequest3, destination = leader3)
+    shareFetchResponse1 = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest1, 
destination = leader1)
+    shareFetchResponse2 = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest2, 
destination = leader2)
+    shareFetchResponse3 = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest3, 
destination = leader3)
 
     val shareFetchResponseData1 = shareFetchResponse1.data()
     assertEquals(Errors.NONE.code, shareFetchResponseData1.errorCode)
@@ -451,7 +451,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     var metadata = new ShareRequestMetadata(memberId, shareSessionEpoch)
     val acknowledgementsMapForFetch: Map[TopicIdPartition, 
util.List[ShareFetchRequestData.AcknowledgementBatch]] = Map.empty
     var shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, acknowledgementsMapForFetch)
-    var shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    var shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     var shareFetchResponseData = shareFetchResponse.data()
     assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
@@ -478,7 +478,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
       .setLastOffset(9)
       .setAcknowledgeTypes(Collections.singletonList(1.toByte))).asJava) // 
Accept the records
     val shareAcknowledgeRequest = createShareAcknowledgeRequest(groupId, 
metadata, acknowledgementsMapForAcknowledge)
-    val shareAcknowledgeResponse = 
connectAndReceive[ShareAcknowledgeResponse](shareAcknowledgeRequest)
+    val shareAcknowledgeResponse = 
connectAndReceiveWithoutClosingSocket[ShareAcknowledgeResponse](shareAcknowledgeRequest)
 
     val shareAcknowledgeResponseData = shareAcknowledgeResponse.data()
     assertEquals(Errors.NONE.code, shareAcknowledgeResponseData.errorCode)
@@ -500,7 +500,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     shareSessionEpoch = ShareRequestMetadata.nextEpoch(shareSessionEpoch)
     metadata = new ShareRequestMetadata(memberId, shareSessionEpoch)
     shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, Map.empty)
-    shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     shareFetchResponseData = shareFetchResponse.data()
     assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
@@ -566,7 +566,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     var metadata = new ShareRequestMetadata(memberId, shareSessionEpoch)
     var acknowledgementsMapForFetch: Map[TopicIdPartition, 
util.List[ShareFetchRequestData.AcknowledgementBatch]] = Map.empty
     var shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, acknowledgementsMapForFetch)
-    var shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    var shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     var shareFetchResponseData = shareFetchResponse.data()
     assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
@@ -595,7 +595,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
         .setLastOffset(9)
         .setAcknowledgeTypes(Collections.singletonList(1.toByte))).asJava) // 
Accept the records
     shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, acknowledgementsMapForFetch)
-    shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     shareFetchResponseData = shareFetchResponse.data()
     assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
@@ -620,7 +620,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     shareSessionEpoch = ShareRequestMetadata.nextEpoch(shareSessionEpoch)
     metadata = new ShareRequestMetadata(memberId, shareSessionEpoch)
     shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, Map.empty)
-    shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     shareFetchResponseData = shareFetchResponse.data()
     assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
@@ -684,7 +684,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     var metadata = new ShareRequestMetadata(memberId, shareSessionEpoch)
     val acknowledgementsMapForFetch: Map[TopicIdPartition, 
util.List[ShareFetchRequestData.AcknowledgementBatch]] = Map.empty
     var shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, acknowledgementsMapForFetch)
-    var shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    var shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     var shareFetchResponseData = shareFetchResponse.data()
     assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
@@ -711,7 +711,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
         .setLastOffset(9)
         .setAcknowledgeTypes(Collections.singletonList(2.toByte))).asJava) // 
Release the records
     val shareAcknowledgeRequest = createShareAcknowledgeRequest(groupId, 
metadata, acknowledgementsMapForAcknowledge)
-    val shareAcknowledgeResponse = 
connectAndReceive[ShareAcknowledgeResponse](shareAcknowledgeRequest)
+    val shareAcknowledgeResponse = 
connectAndReceiveWithoutClosingSocket[ShareAcknowledgeResponse](shareAcknowledgeRequest)
 
     val shareAcknowledgeResponseData = shareAcknowledgeResponse.data()
     assertEquals(Errors.NONE.code, shareAcknowledgeResponseData.errorCode)
@@ -730,7 +730,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     shareSessionEpoch = ShareRequestMetadata.nextEpoch(shareSessionEpoch)
     metadata = new ShareRequestMetadata(memberId, shareSessionEpoch)
     shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, Map.empty)
-    shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     shareFetchResponseData = shareFetchResponse.data()
     assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
@@ -794,7 +794,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     var metadata = new ShareRequestMetadata(memberId, shareSessionEpoch)
     var acknowledgementsMapForFetch: Map[TopicIdPartition, 
util.List[ShareFetchRequestData.AcknowledgementBatch]] = Map.empty
     var shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, acknowledgementsMapForFetch)
-    var shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    var shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     var shareFetchResponseData = shareFetchResponse.data()
     assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
@@ -839,7 +839,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
         releaseAcknowledgementSent = true
       }
       shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, acknowledgementsMapForFetch)
-      shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+      shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
       shareFetchResponseData = shareFetchResponse.data()
       assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
@@ -908,7 +908,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     var metadata = new ShareRequestMetadata(memberId, shareSessionEpoch)
     val acknowledgementsMapForFetch: Map[TopicIdPartition, 
util.List[ShareFetchRequestData.AcknowledgementBatch]] = Map.empty
     var shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, acknowledgementsMapForFetch)
-    var shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    var shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     var shareFetchResponseData = shareFetchResponse.data()
     assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
@@ -935,7 +935,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
         .setLastOffset(9)
         .setAcknowledgeTypes(Collections.singletonList(3.toByte))).asJava) // 
Reject the records
     val shareAcknowledgeRequest = createShareAcknowledgeRequest(groupId, 
metadata, acknowledgementsMapForAcknowledge)
-    val shareAcknowledgeResponse = 
connectAndReceive[ShareAcknowledgeResponse](shareAcknowledgeRequest)
+    val shareAcknowledgeResponse = 
connectAndReceiveWithoutClosingSocket[ShareAcknowledgeResponse](shareAcknowledgeRequest)
 
     val shareAcknowledgeResponseData = shareAcknowledgeResponse.data()
     assertEquals(Errors.NONE.code, shareAcknowledgeResponseData.errorCode)
@@ -957,7 +957,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     shareSessionEpoch = ShareRequestMetadata.nextEpoch(shareSessionEpoch)
     metadata = new ShareRequestMetadata(memberId, shareSessionEpoch)
     shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, Map.empty)
-    shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     shareFetchResponseData = shareFetchResponse.data()
     assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
@@ -1021,7 +1021,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     var metadata = new ShareRequestMetadata(memberId, shareSessionEpoch)
     var acknowledgementsMapForFetch: Map[TopicIdPartition, 
util.List[ShareFetchRequestData.AcknowledgementBatch]] = Map.empty
     var shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, acknowledgementsMapForFetch)
-    var shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    var shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     var shareFetchResponseData = shareFetchResponse.data()
     assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
@@ -1050,7 +1050,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
       .setLastOffset(9)
       .setAcknowledgeTypes(Collections.singletonList(3.toByte))).asJava) // 
Reject the records
     shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, acknowledgementsMapForFetch)
-    shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     shareFetchResponseData = shareFetchResponse.data()
     assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
@@ -1075,7 +1075,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     shareSessionEpoch = ShareRequestMetadata.nextEpoch(shareSessionEpoch)
     metadata = new ShareRequestMetadata(memberId, shareSessionEpoch)
     shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, Map.empty)
-    shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     shareFetchResponseData = shareFetchResponse.data()
     assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
@@ -1141,7 +1141,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     var metadata = new ShareRequestMetadata(memberId, shareSessionEpoch)
     val acknowledgementsMapForFetch: Map[TopicIdPartition, 
util.List[ShareFetchRequestData.AcknowledgementBatch]] = Map.empty
     var shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, acknowledgementsMapForFetch)
-    var shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    var shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     var shareFetchResponseData = shareFetchResponse.data()
     assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
@@ -1168,7 +1168,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
         .setLastOffset(9)
         .setAcknowledgeTypes(Collections.singletonList(2.toByte))).asJava) // 
Release the records
     var shareAcknowledgeRequest = createShareAcknowledgeRequest(groupId, 
metadata, acknowledgementsMapForAcknowledge)
-    var shareAcknowledgeResponse = 
connectAndReceive[ShareAcknowledgeResponse](shareAcknowledgeRequest)
+    var shareAcknowledgeResponse = 
connectAndReceiveWithoutClosingSocket[ShareAcknowledgeResponse](shareAcknowledgeRequest)
 
     var shareAcknowledgeResponseData = shareAcknowledgeResponse.data()
     assertEquals(Errors.NONE.code, shareAcknowledgeResponseData.errorCode)
@@ -1187,7 +1187,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     shareSessionEpoch = ShareRequestMetadata.nextEpoch(shareSessionEpoch)
     metadata = new ShareRequestMetadata(memberId, shareSessionEpoch)
     shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, Map.empty)
-    shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     shareFetchResponseData = shareFetchResponse.data()
     assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
@@ -1213,7 +1213,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
         .setLastOffset(9)
         .setAcknowledgeTypes(Collections.singletonList(2.toByte))).asJava) // 
Release the records again
     shareAcknowledgeRequest = createShareAcknowledgeRequest(groupId, metadata, 
acknowledgementsMapForAcknowledge)
-    shareAcknowledgeResponse = 
connectAndReceive[ShareAcknowledgeResponse](shareAcknowledgeRequest)
+    shareAcknowledgeResponse = 
connectAndReceiveWithoutClosingSocket[ShareAcknowledgeResponse](shareAcknowledgeRequest)
 
     shareAcknowledgeResponseData = shareAcknowledgeResponse.data()
     assertEquals(Errors.NONE.code, shareAcknowledgeResponseData.errorCode)
@@ -1235,7 +1235,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     shareSessionEpoch = ShareRequestMetadata.nextEpoch(shareSessionEpoch)
     metadata = new ShareRequestMetadata(memberId, shareSessionEpoch)
     shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, Map.empty)
-    shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     shareFetchResponseData = shareFetchResponse.data()
     assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
@@ -1312,9 +1312,9 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     val acknowledgementsMap3: Map[TopicIdPartition, 
util.List[ShareFetchRequestData.AcknowledgementBatch]] = Map.empty
     val shareFetchRequest3 = createShareFetchRequest(groupId, metadata3, send, 
Seq.empty, acknowledgementsMap3, minBytes = 100, maxBytes = 1500)
 
-    val shareFetchResponse1 = 
connectAndReceive[ShareFetchResponse](shareFetchRequest1)
-    val shareFetchResponse2 = 
connectAndReceive[ShareFetchResponse](shareFetchRequest2)
-    val shareFetchResponse3 = 
connectAndReceive[ShareFetchResponse](shareFetchRequest3)
+    val shareFetchResponse1 = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest1)
+    val shareFetchResponse2 = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest2)
+    val shareFetchResponse3 = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest3)
 
 
     val shareFetchResponseData1 = shareFetchResponse1.data()
@@ -1407,9 +1407,9 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     val acknowledgementsMap3: Map[TopicIdPartition, 
util.List[ShareFetchRequestData.AcknowledgementBatch]] = Map.empty
     val shareFetchRequest3 = createShareFetchRequest(groupId3, metadata3, 
send, Seq.empty, acknowledgementsMap3)
 
-    val shareFetchResponse1 = 
connectAndReceive[ShareFetchResponse](shareFetchRequest1)
-    val shareFetchResponse2 = 
connectAndReceive[ShareFetchResponse](shareFetchRequest2)
-    val shareFetchResponse3 = 
connectAndReceive[ShareFetchResponse](shareFetchRequest3)
+    val shareFetchResponse1 = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest1)
+    val shareFetchResponse2 = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest2)
+    val shareFetchResponse3 = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest3)
 
 
     val shareFetchResponseData1 = shareFetchResponse1.data()
@@ -1487,7 +1487,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     var metadata = new ShareRequestMetadata(memberId, shareSessionEpoch)
     var acknowledgementsMapForFetch: Map[TopicIdPartition, 
util.List[ShareFetchRequestData.AcknowledgementBatch]] = Map.empty
     var shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, acknowledgementsMapForFetch)
-    var shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    var shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     var shareFetchResponseData = shareFetchResponse.data()
     assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
@@ -1516,7 +1516,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
       .setLastOffset(9)
       .setAcknowledgeTypes(Collections.singletonList(1.toByte))).asJava) // 
Accept the records
     shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, acknowledgementsMapForFetch)
-    shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     shareFetchResponseData = shareFetchResponse.data()
     assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
@@ -1542,7 +1542,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
       .setLastOffset(19)
       .setAcknowledgeTypes(Collections.singletonList(1.toByte))).asJava) // 
Accept the records
     shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, Map.empty)
-    shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     shareFetchResponseData = shareFetchResponse.data()
     assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
@@ -1595,7 +1595,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     var metadata = new ShareRequestMetadata(memberId, shareSessionEpoch)
     var acknowledgementsMapForFetch: Map[TopicIdPartition, 
util.List[ShareFetchRequestData.AcknowledgementBatch]] = Map.empty
     var shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, acknowledgementsMapForFetch)
-    var shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    var shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     var shareFetchResponseData = shareFetchResponse.data()
     assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
@@ -1624,7 +1624,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
       .setLastOffset(9)
       .setAcknowledgeTypes(Collections.singletonList(1.toByte))).asJava) // 
Accept the records
     shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, acknowledgementsMapForFetch)
-    shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     shareFetchResponseData = shareFetchResponse.data()
     assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
@@ -1651,7 +1651,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
       .setLastOffset(19)
       .setAcknowledgeTypes(Collections.singletonList(1.toByte))).asJava) // 
Accept the records
     val shareAcknowledgeRequest = createShareAcknowledgeRequest(groupId, 
metadata, acknowledgementsMapForAcknowledge)
-    val shareAcknowledgeResponse = 
connectAndReceive[ShareAcknowledgeResponse](shareAcknowledgeRequest)
+    val shareAcknowledgeResponse = 
connectAndReceiveWithoutClosingSocket[ShareAcknowledgeResponse](shareAcknowledgeRequest)
 
     val shareAcknowledgeResponseData = shareAcknowledgeResponse.data()
     assertEquals(Errors.NONE.code, shareAcknowledgeResponseData.errorCode)
@@ -1711,7 +1711,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
       .setLastOffset(9)
       .setAcknowledgeTypes(Collections.singletonList(1.toByte))).asJava) // 
Acknowledgements in the Initial Fetch Request
     val shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, acknowledgementsMap)
-    val shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    val shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     val shareFetchResponseData = shareFetchResponse.data()
     // The response will have a top level error code because this is an 
Initial Fetch request with acknowledgement data present
@@ -1759,7 +1759,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
         .setLastOffset(9)
        .setAcknowledgeTypes(Collections.singletonList(1.toByte))).asJava)
     val shareAcknowledgeRequest = createShareAcknowledgeRequest(groupId, 
metadata, acknowledgementsMap)
-    val shareAcknowledgeResponse = 
connectAndReceive[ShareAcknowledgeResponse](shareAcknowledgeRequest)
+    val shareAcknowledgeResponse = 
connectAndReceiveWithoutClosingSocket[ShareAcknowledgeResponse](shareAcknowledgeRequest)
 
     val shareAcknowledgeResponseData = shareAcknowledgeResponse.data()
     assertEquals(Errors.INVALID_SHARE_SESSION_EPOCH.code, 
shareAcknowledgeResponseData.errorCode)
@@ -1809,7 +1809,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     var shareSessionEpoch = 
ShareRequestMetadata.nextEpoch(ShareRequestMetadata.INITIAL_EPOCH)
     var metadata = new ShareRequestMetadata(memberId, shareSessionEpoch)
     var shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, Map.empty)
-    var shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    var shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     var shareFetchResponseData = shareFetchResponse.data()
     assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
@@ -1831,7 +1831,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     shareSessionEpoch = 
ShareRequestMetadata.nextEpoch(ShareRequestMetadata.nextEpoch(shareSessionEpoch))
     metadata = new ShareRequestMetadata(memberId, shareSessionEpoch)
     shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, Map.empty)
-    shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     shareFetchResponseData = shareFetchResponse.data()
     assertEquals(Errors.INVALID_SHARE_SESSION_EPOCH.code, 
shareFetchResponseData.errorCode)
@@ -1881,7 +1881,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     var shareSessionEpoch = 
ShareRequestMetadata.nextEpoch(ShareRequestMetadata.INITIAL_EPOCH)
     var metadata = new ShareRequestMetadata(memberId, shareSessionEpoch)
     val shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, Map.empty)
-    val shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    val shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     val shareFetchResponseData = shareFetchResponse.data()
     assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
@@ -1908,7 +1908,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
       .setLastOffset(9)
       .setAcknowledgeTypes(Collections.singletonList(1.toByte))).asJava)
     val shareAcknowledgeRequest = createShareAcknowledgeRequest(groupId, 
metadata, acknowledgementsMap)
-    val shareAcknowledgeResponse = 
connectAndReceive[ShareAcknowledgeResponse](shareAcknowledgeRequest)
+    val shareAcknowledgeResponse = 
connectAndReceiveWithoutClosingSocket[ShareAcknowledgeResponse](shareAcknowledgeRequest)
 
     val shareAcknowledgeResponseData = shareAcknowledgeResponse.data()
     assertEquals(Errors.INVALID_SHARE_SESSION_EPOCH.code, 
shareAcknowledgeResponseData.errorCode)
@@ -1959,7 +1959,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     var shareSessionEpoch = 
ShareRequestMetadata.nextEpoch(ShareRequestMetadata.INITIAL_EPOCH)
     var metadata = new ShareRequestMetadata(memberId, shareSessionEpoch)
     var shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, Map.empty)
-    var shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    var shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     var shareFetchResponseData = shareFetchResponse.data()
     assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
@@ -1981,12 +1981,94 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     shareSessionEpoch = ShareRequestMetadata.nextEpoch(shareSessionEpoch)
     metadata = new ShareRequestMetadata(wrongMemberId, shareSessionEpoch)
     shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, Map.empty)
-    shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     shareFetchResponseData = shareFetchResponse.data()
     assertEquals(Errors.SHARE_SESSION_NOT_FOUND.code, 
shareFetchResponseData.errorCode)
   }
 
+  @ClusterTests(
+    Array(
+      new ClusterTest(
+        serverProperties = Array(
+          new ClusterConfigProperty(key = "offsets.topic.num.partitions", 
value = "1"),
+          new ClusterConfigProperty(key = "offsets.topic.replication.factor", 
value = "1"),
+          new ClusterConfigProperty(key = "group.share.max.share.sessions", 
value="2"),
+          new ClusterConfigProperty(key = "group.share.max.size", value="2")
+        )
+      ),
+      new ClusterTest(
+        serverProperties = Array(
+          new ClusterConfigProperty(key = "offsets.topic.num.partitions", 
value = "1"),
+          new ClusterConfigProperty(key = "offsets.topic.replication.factor", 
value = "1"),
+          new ClusterConfigProperty(key = "group.share.persister.class.name", 
value = "org.apache.kafka.server.share.persister.DefaultStatePersister"),
+          new ClusterConfigProperty(key = 
"share.coordinator.state.topic.replication.factor", value = "1"),
+          new ClusterConfigProperty(key = 
"share.coordinator.state.topic.num.partitions", value = "1"),
+          new ClusterConfigProperty(key = "group.share.max.share.sessions", 
value="2"),
+          new ClusterConfigProperty(key = "group.share.max.size", value="2")
+        )
+      ),
+    )
+  )
+  def testShareSessionEvictedOnConnectionDrop(): Unit = {
+    val groupId: String = "group"
+    val memberId1 = Uuid.randomUuid()
+    val memberId2 = Uuid.randomUuid()
+    val memberId3 = Uuid.randomUuid()
+
+    val topic = "topic"
+    val partition = 0
+
+    createTopicAndReturnLeaders(topic, numPartitions = 3)
+    val topicIds = getTopicIds.asJava
+    val topicId = topicIds.get(topic)
+    val topicIdPartition = new TopicIdPartition(topicId, new 
TopicPartition(topic, partition))
+
+    val send: Seq[TopicIdPartition] = Seq(topicIdPartition)
+
+    // member1 sends a share fetch request to register its share session. Note 
that it does not close the socket connection afterwards.
+    TestUtils.waitUntilTrue(() => {
+      val metadata = new ShareRequestMetadata(memberId1, 
ShareRequestMetadata.INITIAL_EPOCH)
+      val shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, Map.empty)
+      val shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
+      val shareFetchResponseData = shareFetchResponse.data()
+      shareFetchResponseData.errorCode == Errors.NONE.code
+    }, "Share fetch request failed", 5000)
+
+    // member2 sends a share fetch request to register its share session. Note 
that it does not close the socket connection afterwards.
+    TestUtils.waitUntilTrue(() => {
+      val metadata = new ShareRequestMetadata(memberId2, 
ShareRequestMetadata.INITIAL_EPOCH)
+      val shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, Map.empty)
+      val shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
+      val shareFetchResponseData = shareFetchResponse.data()
+      shareFetchResponseData.errorCode == Errors.NONE.code
+    }, "Share fetch request failed", 5000)
+
+    // member3 sends a share fetch request to register its share session. Since 
the maximum number of share sessions that can
+    // exist in the share session cache is 2 (group.share.max.share.sessions), 
the attempt to register a third
+    // share session with the ShareSessionCache fails with 
SHARE_SESSION_LIMIT_REACHED
+    TestUtils.waitUntilTrue(() => {
+      val metadata = new ShareRequestMetadata(memberId3, 
ShareRequestMetadata.INITIAL_EPOCH)
+      val shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, Map.empty)
+      val shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
+      val shareFetchResponseData = shareFetchResponse.data()
+      shareFetchResponseData.errorCode == Errors.SHARE_SESSION_NOT_FOUND.code
+    }, "Share fetch request failed", 5000)
+
+    // Now we will close the socket connections for the above three members, 
mimicking a client disconnection
+    closeSockets()
+
+    // Since the socket connections were closed above, the corresponding 
share sessions have been dropped from the ShareSessionCache
+    // on the broker. Now that the cache is empty, new share sessions can be 
registered.
+    TestUtils.waitUntilTrue(() => {
+      val metadata = new ShareRequestMetadata(memberId3, 
ShareRequestMetadata.INITIAL_EPOCH)
+      val shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, Map.empty)
+      val shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
+      val shareFetchResponseData = shareFetchResponse.data()
+      shareFetchResponseData.errorCode == Errors.NONE.code
+    }, "Share fetch request failed", 5000)
+  }
+
   @ClusterTests(
     Array(
       new ClusterTest(
@@ -2032,7 +2114,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     var shareSessionEpoch = 
ShareRequestMetadata.nextEpoch(ShareRequestMetadata.INITIAL_EPOCH)
     var metadata = new ShareRequestMetadata(memberId, shareSessionEpoch)
     val shareFetchRequest = createShareFetchRequest(groupId, metadata, send, 
Seq.empty, Map.empty)
-    val shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    val shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     val shareFetchResponseData = shareFetchResponse.data()
     assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
@@ -2059,7 +2141,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
         .setLastOffset(9)
         .setAcknowledgeTypes(Collections.singletonList(1.toByte))).asJava)
     val shareAcknowledgeRequest = createShareAcknowledgeRequest(groupId, 
metadata, acknowledgementsMap)
-    val shareAcknowledgeResponse = 
connectAndReceive[ShareAcknowledgeResponse](shareAcknowledgeRequest)
+    val shareAcknowledgeResponse = 
connectAndReceiveWithoutClosingSocket[ShareAcknowledgeResponse](shareAcknowledgeRequest)
 
     val shareAcknowledgeResponseData = shareAcknowledgeResponse.data()
     assertEquals(Errors.SHARE_SESSION_NOT_FOUND.code, 
shareAcknowledgeResponseData.errorCode)
@@ -2118,7 +2200,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     // as the share partitions might not be initialized yet. So, we retry 
until we get the response.
     var responses = Seq[ShareFetchResponseData.PartitionData]()
     TestUtils.waitUntilTrue(() => {
-      val shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+      val shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
       val shareFetchResponseData = shareFetchResponse.data()
       assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
       assertEquals(30000, shareFetchResponseData.acquisitionLockTimeoutMs)
@@ -2144,7 +2226,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     metadata = new ShareRequestMetadata(memberId, shareSessionEpoch)
     val forget: Seq[TopicIdPartition] = Seq(topicIdPartition1)
     shareFetchRequest = createShareFetchRequest(groupId, metadata, Seq.empty, 
forget, acknowledgementsMap)
-    val shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+    val shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
 
     val shareFetchResponseData = shareFetchResponse.data()
     assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
@@ -2294,7 +2376,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: 
ClusterInstance) extends GroupCo
     TestUtils.waitUntilTrue(() => {
       val metadata = new ShareRequestMetadata(memberId, 
ShareRequestMetadata.INITIAL_EPOCH)
       val shareFetchRequest = createShareFetchRequest(groupId, metadata, 
topicIdPartitions, Seq.empty, Map.empty)
-      val shareFetchResponse = 
connectAndReceive[ShareFetchResponse](shareFetchRequest)
+      val shareFetchResponse = 
connectAndReceiveWithoutClosingSocket[ShareFetchResponse](shareFetchRequest)
       val shareFetchResponseData = shareFetchResponse.data()
 
       assertEquals(Errors.NONE.code, shareFetchResponseData.errorCode)
diff --git 
a/core/src/test/scala/unit/kafka/server/ShareGroupHeartbeatRequestTest.scala 
b/core/src/test/scala/unit/kafka/server/ShareGroupHeartbeatRequestTest.scala
index a20cdfe9057..865870eef3b 100644
--- a/core/src/test/scala/unit/kafka/server/ShareGroupHeartbeatRequestTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ShareGroupHeartbeatRequestTest.scala
@@ -587,6 +587,80 @@ class ShareGroupHeartbeatRequestTest(cluster: 
ClusterInstance) {
   }
 
   @ClusterTest(
+    serverProperties = Array(
+      new ClusterConfigProperty(key = "offsets.topic.num.partitions", value = 
"1"),
+      new ClusterConfigProperty(key = "offsets.topic.replication.factor", 
value = "1"),
+      new ClusterConfigProperty(key = "group.share.max.size", value = "2")
+    ))
+  def testShareGroupMaxSizeConfigExceeded(): Unit = {
+    val groupId: String = "group"
+    val memberId1 = Uuid.randomUuid()
+    val memberId2 = Uuid.randomUuid()
+    val memberId3 = Uuid.randomUuid()
+
+    val admin = cluster.admin()
+
+    // Creates the __consumer_offsets topics because it won't be created 
automatically
+    // in this test because it does not use FindCoordinator API.
+    try {
+      TestUtils.createOffsetsTopicWithAdmin(
+        admin = admin,
+        brokers = cluster.brokers.values().asScala.toSeq,
+        controllers = cluster.controllers().values().asScala.toSeq
+      )
+
+      // Heartbeat request to join the group by the first member (memberId1).
+      var shareGroupHeartbeatRequest = new ShareGroupHeartbeatRequest.Builder(
+        new ShareGroupHeartbeatRequestData()
+          .setGroupId(groupId)
+          .setMemberId(memberId1.toString)
+          .setMemberEpoch(0)
+          .setSubscribedTopicNames(List("foo").asJava)
+      ).build()
+
+      // Send the request until receiving a successful response. There is a 
delay
+      // here because the group coordinator is loaded in the background.
+      var shareGroupHeartbeatResponse: ShareGroupHeartbeatResponse = null
+      TestUtils.waitUntilTrue(() => {
+        shareGroupHeartbeatResponse = 
connectAndReceive(shareGroupHeartbeatRequest)
+        shareGroupHeartbeatResponse.data.errorCode == Errors.NONE.code
+      }, msg = s"Could not join the group successfully. Last response 
$shareGroupHeartbeatResponse.")
+
+      // Heartbeat request to join the group by the second member (memberId2).
+      shareGroupHeartbeatRequest = new ShareGroupHeartbeatRequest.Builder(
+        new ShareGroupHeartbeatRequestData()
+          .setGroupId(groupId)
+          .setMemberId(memberId2.toString)
+          .setMemberEpoch(0)
+          .setSubscribedTopicNames(List("foo").asJava)
+      ).build()
+
+      // Send the request until receiving a successful response
+      TestUtils.waitUntilTrue(() => {
+        shareGroupHeartbeatResponse = 
connectAndReceive(shareGroupHeartbeatRequest)
+        shareGroupHeartbeatResponse.data.errorCode == Errors.NONE.code
+      }, msg = s"Could not join the group successfully. Last response 
$shareGroupHeartbeatResponse.")
+
+      // Heartbeat request to join the group by the third member (memberId3).
+      shareGroupHeartbeatRequest = new ShareGroupHeartbeatRequest.Builder(
+        new ShareGroupHeartbeatRequestData()
+          .setGroupId(groupId)
+          .setMemberId(memberId3.toString)
+          .setMemberEpoch(0)
+          .setSubscribedTopicNames(List("foo").asJava)
+      ).build()
+
+      shareGroupHeartbeatResponse = 
connectAndReceive(shareGroupHeartbeatRequest)
+      // Since the group.share.max.size config is set to 2, a third member 
cannot join the same group.
assertEquals(Errors.GROUP_MAX_SIZE_REACHED.code, 
shareGroupHeartbeatResponse.data.errorCode)
+
+    } finally {
+      admin.close()
+    }
+  }
+
+  @ClusterTest(
+    types = Array(Type.KRAFT),
     serverProperties = Array(
       new ClusterConfigProperty(key = "offsets.topic.num.partitions", value = 
"1"),
       new ClusterConfigProperty(key = "offsets.topic.replication.factor", 
value = "1"),
diff --git 
a/server/src/main/java/org/apache/kafka/server/share/session/ShareSessionCache.java
 
b/server/src/main/java/org/apache/kafka/server/share/session/ShareSessionCache.java
index 0b06ea535be..f0f37d9ec7d 100644
--- 
a/server/src/main/java/org/apache/kafka/server/share/session/ShareSessionCache.java
+++ 
b/server/src/main/java/org/apache/kafka/server/share/session/ShareSessionCache.java
@@ -21,6 +21,7 @@ import org.apache.kafka.common.Uuid;
 import org.apache.kafka.common.requests.ShareRequestMetadata;
 import org.apache.kafka.common.utils.ImplicitLinkedHashCollection;
 import org.apache.kafka.server.metrics.KafkaMetricsGroup;
+import org.apache.kafka.server.network.ConnectionDisconnectListener;
 import org.apache.kafka.server.share.CachedSharePartition;
 
 import com.yammer.metrics.core.Meter;
@@ -53,10 +54,13 @@ public class ShareSessionCache {
 
     private final int maxEntries;
     private long numPartitions = 0;
+    private final ConnectionDisconnectListener connectionDisconnectListener;
 
     // A map of session key to ShareSession.
     private final Map<ShareSessionKey, ShareSession> sessions = new 
HashMap<>();
 
+    private final Map<String, ShareSessionKey> connectionIdToSessionMap;
+
     @SuppressWarnings("this-escape")
     public ShareSessionCache(int maxEntries) {
         this.maxEntries = maxEntries;
@@ -64,6 +68,8 @@ public class ShareSessionCache {
         KafkaMetricsGroup metricsGroup = new KafkaMetricsGroup("kafka.server", 
"ShareSessionCache");
         metricsGroup.newGauge(SHARE_SESSIONS_COUNT, this::size);
         metricsGroup.newGauge(SHARE_PARTITIONS_COUNT, this::totalPartitions);
+        this.connectionIdToSessionMap = new HashMap<>();
+        this.connectionDisconnectListener = new 
ClientConnectionDisconnectListener();
         this.evictionsMeter = 
metricsGroup.newMeter(SHARE_SESSION_EVICTIONS_PER_SEC, "evictions", 
TimeUnit.SECONDS);
     }
 
@@ -123,21 +129,48 @@ public class ShareSessionCache {
      * @param groupId - The group id in the share fetch request.
      * @param memberId - The member id in the share fetch request.
      * @param partitionMap - The topic partitions to be added to the session.
+     * @param clientConnectionId - The client connection id.
      * @return - The session key if the session was created, or null if the 
session was not created.
      */
-    public synchronized ShareSessionKey maybeCreateSession(String groupId, 
Uuid memberId, ImplicitLinkedHashCollection<CachedSharePartition> partitionMap) 
{
+    public synchronized ShareSessionKey maybeCreateSession(
+        String groupId,
+        Uuid memberId,
+        ImplicitLinkedHashCollection<CachedSharePartition> partitionMap,
+        String clientConnectionId
+    ) {
         if (sessions.size() < maxEntries) {
             ShareSession session = new ShareSession(new 
ShareSessionKey(groupId, memberId), partitionMap,
                 
ShareRequestMetadata.nextEpoch(ShareRequestMetadata.INITIAL_EPOCH));
             sessions.put(session.key(), session);
             updateNumPartitions(session);
+            connectionIdToSessionMap.put(clientConnectionId, session.key());
             return session.key();
         }
         return null;
     }
 
+    public ConnectionDisconnectListener connectionDisconnectListener() {
+        return connectionDisconnectListener;
+    }
+
     // Visible for testing.
     Meter evictionsMeter() {
         return evictionsMeter;
     }
+
+    private final class ClientConnectionDisconnectListener implements 
ConnectionDisconnectListener {
+
+        // When the client disconnects, the corresponding session should be 
removed from the cache.
+        @Override
+        public void onDisconnect(String connectionId) {
+            ShareSessionKey shareSessionKey = 
connectionIdToSessionMap.remove(connectionId);
+            if (shareSessionKey != null) {
+                // Remove the session from the cache.
+                ShareSession removedSession = remove(shareSessionKey);
+                if (removedSession != null) {
+                    evictionsMeter.mark();
+                }
+            }
+        }
+    }
 }
diff --git 
a/server/src/test/java/org/apache/kafka/server/share/session/ShareSessionCacheTest.java
 
b/server/src/test/java/org/apache/kafka/server/share/session/ShareSessionCacheTest.java
index ca18de5b65c..c9692063b5c 100644
--- 
a/server/src/test/java/org/apache/kafka/server/share/session/ShareSessionCacheTest.java
+++ 
b/server/src/test/java/org/apache/kafka/server/share/session/ShareSessionCacheTest.java
@@ -45,11 +45,11 @@ public class ShareSessionCacheTest {
     public void testShareSessionCache() throws InterruptedException {
         ShareSessionCache cache = new ShareSessionCache(3);
         assertEquals(0, cache.size());
-        ShareSessionKey key1 = cache.maybeCreateSession("grp", 
Uuid.randomUuid(), mockedSharePartitionMap(10));
-        ShareSessionKey key2 = cache.maybeCreateSession("grp", 
Uuid.randomUuid(), mockedSharePartitionMap(20));
-        ShareSessionKey key3 = cache.maybeCreateSession("grp", 
Uuid.randomUuid(), mockedSharePartitionMap(30));
-        assertNull(cache.maybeCreateSession("grp", Uuid.randomUuid(), 
mockedSharePartitionMap(40)));
-        assertNull(cache.maybeCreateSession("grp", Uuid.randomUuid(), 
mockedSharePartitionMap(5)));
+        ShareSessionKey key1 = cache.maybeCreateSession("grp", 
Uuid.randomUuid(), mockedSharePartitionMap(10), "conn-1");
+        ShareSessionKey key2 = cache.maybeCreateSession("grp", 
Uuid.randomUuid(), mockedSharePartitionMap(20), "conn-2");
+        ShareSessionKey key3 = cache.maybeCreateSession("grp", 
Uuid.randomUuid(), mockedSharePartitionMap(30), "conn-3");
+        assertNull(cache.maybeCreateSession("grp", Uuid.randomUuid(), 
mockedSharePartitionMap(40), "conn-4"));
+        assertNull(cache.maybeCreateSession("grp", Uuid.randomUuid(), 
mockedSharePartitionMap(5), "conn-5"));
         assertShareCacheContains(cache, List.of(key1, key2, key3));
 
         assertMetricsValues(3, 60, 0, cache);
@@ -60,7 +60,7 @@ public class ShareSessionCacheTest {
         ShareSessionCache cache = new ShareSessionCache(2);
         assertEquals(0, cache.size());
         assertEquals(0, cache.totalPartitions());
-        ShareSessionKey key1 = cache.maybeCreateSession("grp", 
Uuid.randomUuid(), mockedSharePartitionMap(2));
+        ShareSessionKey key1 = cache.maybeCreateSession("grp", 
Uuid.randomUuid(), mockedSharePartitionMap(2), "conn-1");
         assertNotNull(key1);
         assertShareCacheContains(cache, List.of(key1));
         ShareSession session1 = cache.get(key1);
@@ -70,7 +70,7 @@ public class ShareSessionCacheTest {
 
         assertMetricsValues(1, 2, 0, cache);
 
-        ShareSessionKey key2 = cache.maybeCreateSession("grp", 
Uuid.randomUuid(), mockedSharePartitionMap(4));
+        ShareSessionKey key2 = cache.maybeCreateSession("grp", 
Uuid.randomUuid(), mockedSharePartitionMap(4), "conn-2");
         assertNotNull(key2);
         assertShareCacheContains(cache, List.of(key1, key2));
         ShareSession session2 = cache.get(key2);
@@ -81,7 +81,7 @@ public class ShareSessionCacheTest {
 
         assertMetricsValues(2, 6, 0, cache);
 
-        ShareSessionKey key3 = cache.maybeCreateSession("grp", 
Uuid.randomUuid(), mockedSharePartitionMap(5));
+        ShareSessionKey key3 = cache.maybeCreateSession("grp", 
Uuid.randomUuid(), mockedSharePartitionMap(5), "conn-3");
         assertNull(key3);
         assertShareCacheContains(cache, List.of(key1, key2));
         assertEquals(6, cache.totalPartitions());
@@ -109,6 +109,36 @@ public class ShareSessionCacheTest {
         assertMetricsValues(1, 3, 0, cache);
     }
 
+    @Test
+    public void testRemoveConnection() throws InterruptedException {
+        ShareSessionCache cache = new ShareSessionCache(3);
+        assertEquals(0, cache.size());
+        ShareSessionKey key1 = cache.maybeCreateSession("grp", 
Uuid.randomUuid(), mockedSharePartitionMap(1), "conn-1");
+        ShareSessionKey key2 = cache.maybeCreateSession("grp", 
Uuid.randomUuid(), mockedSharePartitionMap(2), "conn-2");
+        ShareSessionKey key3 = cache.maybeCreateSession("grp", 
Uuid.randomUuid(), mockedSharePartitionMap(3), "conn-3");
+
+        assertMetricsValues(3, 6, 0, cache);
+
+        // Since the cache size now equals the maximum entries allowed (3), no new 
session can be created.
+        assertNull(cache.maybeCreateSession("grp", Uuid.randomUuid(), 
mockedSharePartitionMap(40), "conn-4"));
+        assertNull(cache.maybeCreateSession("grp", Uuid.randomUuid(), 
mockedSharePartitionMap(5), "conn-5"));
+        assertShareCacheContains(cache, List.of(key1, key2, key3));
+
+        assertMetricsValues(3, 6, 0, cache);
+
+        // Simulate the disconnection of the client with connection id conn-1
+        cache.connectionDisconnectListener().onDisconnect("conn-1");
+        assertShareCacheContains(cache, List.of(key2, key3));
+
+        assertMetricsValues(2, 5, 1, cache);
+
+        // Since one client got disconnected, we can add another one now
+        ShareSessionKey key4 = cache.maybeCreateSession("grp", 
Uuid.randomUuid(), mockedSharePartitionMap(4), "conn-6");
+        assertShareCacheContains(cache, List.of(key2, key3, key4));
+
+        assertMetricsValues(3, 9, 1, cache);
+    }
+
     private ImplicitLinkedHashCollection<CachedSharePartition> 
mockedSharePartitionMap(int size) {
         ImplicitLinkedHashCollection<CachedSharePartition> cacheMap = new
                 ImplicitLinkedHashCollection<>(size);
