This is an automated email from the ASF dual-hosted git repository.

showuon pushed a commit to branch 3.2
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/3.2 by this push:
     new d8541b20a1 KAFKA-14024: Consumer keeps Commit offset in onJoinPrepare 
in Cooperative rebalance (#12349)
d8541b20a1 is described below

commit d8541b20a106f22736947e0c2f293833f3c3873b
Author: Shawn <wangxiaofan....@bytedance.com>
AuthorDate: Wed Jul 20 10:03:43 2022 +0800

    KAFKA-14024: Consumer keeps Commit offset in onJoinPrepare in Cooperative 
rebalance (#12349)
    
    In KAFKA-13310, we tried to fix an issue where consumer#poll(duration) could
    return only after the provided duration had already elapsed. This happened
    because, when a rebalance was needed, we first committed the current offsets
    synchronously before rebalancing, and if the offset commit took too long,
    consumer#poll spent more time than the provided duration. To fix that, we
    replaced the synchronous commit with an asynchronous commit before the
    rebalance (i.e. in onJoinPrepare).
    
    However, in this ticket, we found that the async commit keeps sending a new
    commit request on each Consumer#poll, because the offset commit never
    completes in time. As a result, the existing consumer is kicked out of the
    group after the rebalance timeout without ever rejoining. For example,
    suppose consumer A is in group G and consumer B then joins the group; after
    the rebalance, only consumer B remains in the group.
    
    We also found another bug while fixing this one. Before KAFKA-13310, we
    committed offsets synchronously, bounded by the rebalance timeout, and
    retried on retriable errors until that timeout expired. After KAFKA-13310,
    we assumed the commit was still retried, but the retry only happened after
    the partitions had already been revoked. So even if the retried offset
    commit eventually succeeded, offsets for some partitions were left
    uncommitted, and after the rebalance other consumers would consume
    overlapping records.
    
    Reviewers: RivenSun <riven....@zoom.us>, Luke Chen <show...@gmail.com>
---
 .../consumer/internals/AbstractCoordinator.java    |   5 +-
 .../consumer/internals/ConsumerCoordinator.java    |  73 ++++++++++---
 .../internals/AbstractCoordinatorTest.java         |   2 +-
 .../internals/ConsumerCoordinatorTest.java         | 120 ++++++++++++++++-----
 .../runtime/distributed/WorkerCoordinator.java     |   2 +-
 .../kafka/api/AbstractConsumerTest.scala           |  11 +-
 .../kafka/api/PlaintextConsumerTest.scala          |  87 +++++++++++++++
 7 files changed, 251 insertions(+), 49 deletions(-)

diff --git 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinator.java
 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinator.java
index 5fe8a6a0e1..4d71482562 100644
--- 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinator.java
+++ 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinator.java
@@ -187,11 +187,12 @@ public abstract class AbstractCoordinator implements 
Closeable {
     /**
      * Invoked prior to each group join or rejoin. This is typically used to 
perform any
      * cleanup from the previous generation (such as committing offsets for 
the consumer)
+     * @param timer Timer bounding how long this method can block
      * @param generation The previous generation or -1 if there was none
      * @param memberId The identifier of this member in the previous group or 
"" if there was none
      * @return true If onJoinPrepare async commit succeeded, false otherwise
      */
-    protected abstract boolean onJoinPrepare(int generation, String memberId);
+    protected abstract boolean onJoinPrepare(Timer timer, int generation, 
String memberId);
 
     /**
      * Invoked when the leader is elected. This is used by the leader to 
perform the assignment
@@ -426,7 +427,7 @@ public abstract class AbstractCoordinator implements 
Closeable {
                 // exception, in which case upon retry we should not retry 
onJoinPrepare either.
                 needsJoinPrepare = false;
                 // return false when onJoinPrepare is waiting for committing 
offset
-                if (!onJoinPrepare(generation.generationId, 
generation.memberId)) {
+                if (!onJoinPrepare(timer, generation.generationId, 
generation.memberId)) {
                     needsJoinPrepare = true;
                     //should not initiateJoinGroup if needsJoinPrepare still 
is true
                     return false;
diff --git 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
index b853ff99e8..9838e7dc8f 100644
--- 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
+++ 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
@@ -141,6 +141,12 @@ public final class ConsumerCoordinator extends 
AbstractCoordinator {
     }
 
     private final RebalanceProtocol protocol;
+    // pending commit offset request in onJoinPrepare
+    private RequestFuture<Void> autoCommitOffsetRequestFuture = null;
+    // a timer for join prepare to know when to stop.
+    // it'll set to rebalance timeout so that the member can join the group 
successfully
+    // even though offset commit failed.
+    private Timer joinPrepareTimer = null;
 
     /**
      * Initialize the coordination manager.
@@ -740,24 +746,58 @@ public final class ConsumerCoordinator extends 
AbstractCoordinator {
     }
 
     @Override
-    protected boolean onJoinPrepare(int generation, String memberId) {
+    protected boolean onJoinPrepare(Timer timer, int generation, String 
memberId) {
         log.debug("Executing onJoinPrepare with generation {} and memberId 
{}", generation, memberId);
-        boolean onJoinPrepareAsyncCommitCompleted = false;
+        if (joinPrepareTimer == null) {
+            // We should complete onJoinPrepare before rebalanceTimeout,
+            // and continue to join group to avoid member got kicked out from 
group
+            joinPrepareTimer = time.timer(rebalanceConfig.rebalanceTimeoutMs);
+        } else {
+            joinPrepareTimer.update();
+        }
+
         // async commit offsets prior to rebalance if auto-commit enabled
-        RequestFuture<Void> future = maybeAutoCommitOffsetsAsync();
-        // return true when
-        // 1. future is null, which means no commit request sent, so it is 
still considered completed
-        // 2. offset commit completed
-        // 3. offset commit failed with non-retriable exception
-        if (future == null)
-            onJoinPrepareAsyncCommitCompleted = true;
-        else if (future.succeeded())
-            onJoinPrepareAsyncCommitCompleted = true;
-        else if (future.failed() && !future.isRetriable()) {
-            log.error("Asynchronous auto-commit of offsets failed: {}", 
future.exception().getMessage());
-            onJoinPrepareAsyncCommitCompleted = true;
+        // and there is no in-flight offset commit request
+        if (autoCommitEnabled && autoCommitOffsetRequestFuture == null) {
+            autoCommitOffsetRequestFuture = maybeAutoCommitOffsetsAsync();
         }
 
+        // wait for commit offset response before timer expired
+        if (autoCommitOffsetRequestFuture != null) {
+            Timer pollTimer = timer.remainingMs() < 
joinPrepareTimer.remainingMs() ?
+                    timer : joinPrepareTimer;
+            client.poll(autoCommitOffsetRequestFuture, pollTimer);
+            joinPrepareTimer.update();
+
+            // Keep retrying/waiting the offset commit when:
+            // 1. offset commit haven't done (and joinPrepareTimer not expired)
+            // 2. failed with retryable exception (and joinPrepareTimer not 
expired)
+            // Otherwise, continue to revoke partitions, ex:
+            // 1. if joinPrepareTime has expired
+            // 2. if offset commit failed with no-retryable exception
+            // 3. if offset commit success
+            boolean onJoinPrepareAsyncCommitCompleted = true;
+            if (joinPrepareTimer.isExpired()) {
+                log.error("Asynchronous auto-commit of offsets failed: 
joinPrepare timeout. Will continue to join group");
+            } else if (!autoCommitOffsetRequestFuture.isDone()) {
+                onJoinPrepareAsyncCommitCompleted = false;
+            } else if (autoCommitOffsetRequestFuture.failed() && 
autoCommitOffsetRequestFuture.isRetriable()) {
+                log.debug("Asynchronous auto-commit of offsets failed with 
retryable error: {}. Will retry it.",
+                        
autoCommitOffsetRequestFuture.exception().getMessage());
+                onJoinPrepareAsyncCommitCompleted = false;
+            } else if (autoCommitOffsetRequestFuture.failed() && 
!autoCommitOffsetRequestFuture.isRetriable()) {
+                log.error("Asynchronous auto-commit of offsets failed: {}. 
Will continue to join group.",
+                        
autoCommitOffsetRequestFuture.exception().getMessage());
+            }
+            if (autoCommitOffsetRequestFuture.isDone()) {
+                autoCommitOffsetRequestFuture = null;
+            }
+            if (!onJoinPrepareAsyncCommitCompleted) {
+                pollTimer.sleep(Math.min(pollTimer.remainingMs(), 
rebalanceConfig.retryBackoffMs));
+                timer.update();
+                return false;
+            }
+        }
 
         // the generation / member-id can possibly be reset by the heartbeat 
thread
         // upon getting errors or heartbeat timeouts; in this case whatever is 
previously
@@ -809,11 +849,14 @@ public final class ConsumerCoordinator extends 
AbstractCoordinator {
 
         isLeader = false;
         subscriptions.resetGroupSubscription();
+        joinPrepareTimer = null;
+        autoCommitOffsetRequestFuture = null;
+        timer.update();
 
         if (exception != null) {
             throw new KafkaException("User rebalance callback throws an 
error", exception);
         }
-        return onJoinPrepareAsyncCommitCompleted;
+        return true;
     }
 
     @Override
diff --git 
a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinatorTest.java
 
b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinatorTest.java
index 7cf6ee0e66..69dd893e12 100644
--- 
a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinatorTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinatorTest.java
@@ -1661,7 +1661,7 @@ public class AbstractCoordinatorTest {
         }
 
         @Override
-        protected boolean onJoinPrepare(int generation, String memberId) {
+        protected boolean onJoinPrepare(Timer timer, int generation, String 
memberId) {
             onJoinPrepareInvokes++;
             return true;
         }
diff --git 
a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
 
b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
index c65d33176f..df88d84f08 100644
--- 
a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
@@ -77,6 +77,7 @@ import org.apache.kafka.common.requests.SyncGroupResponse;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.SystemTime;
+import org.apache.kafka.common.utils.Timer;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.test.TestUtils;
 import org.junit.jupiter.api.AfterEach;
@@ -1299,9 +1300,71 @@ public abstract class ConsumerCoordinatorTest {
         }
     }
 
+    @Test
+    public void testOnJoinPrepareWithOffsetCommitShouldSuccessAfterRetry() {
+        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, true, Optional.empty(), false)) {
+            int generationId = 42;
+            String memberId = "consumer-42";
+
+            Timer pollTimer = time.timer(100L);
+            client.prepareResponse(offsetCommitResponse(singletonMap(t1p, 
Errors.UNKNOWN_TOPIC_OR_PARTITION)));
+            boolean res = coordinator.onJoinPrepare(pollTimer, generationId, 
memberId);
+            assertFalse(res);
+
+            pollTimer = time.timer(100L);
+            client.prepareResponse(offsetCommitResponse(singletonMap(t1p, 
Errors.NONE)));
+            res = coordinator.onJoinPrepare(pollTimer, generationId, memberId);
+            assertTrue(res);
+
+            assertFalse(client.hasPendingResponses());
+            assertFalse(client.hasInFlightRequests());
+            assertFalse(coordinator.coordinatorUnknown());
+        }
+    }
+
+    @Test
+    public void 
testOnJoinPrepareWithOffsetCommitShouldKeepJoinAfterNonRetryableException() {
+        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, true, Optional.empty(), false)) {
+            int generationId = 42;
+            String memberId = "consumer-42";
+
+            Timer pollTimer = time.timer(100L);
+            client.prepareResponse(offsetCommitResponse(singletonMap(t1p, 
Errors.UNKNOWN_MEMBER_ID)));
+            boolean res = coordinator.onJoinPrepare(pollTimer, generationId, 
memberId);
+            assertTrue(res);
+
+            assertFalse(client.hasPendingResponses());
+            assertFalse(client.hasInFlightRequests());
+            assertFalse(coordinator.coordinatorUnknown());
+        }
+    }
+
+    @Test
+    public void 
testOnJoinPrepareWithOffsetCommitShouldKeepJoinAfterRebalanceTimeout() {
+        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, true, Optional.empty(), false)) {
+            int generationId = 42;
+            String memberId = "consumer-42";
+
+            Timer pollTimer = time.timer(100L);
+            time.sleep(150);
+            boolean res = coordinator.onJoinPrepare(pollTimer, generationId, 
memberId);
+            assertFalse(res);
+
+            pollTimer = time.timer(100L);
+            time.sleep(rebalanceTimeoutMs);
+            client.respond(offsetCommitResponse(singletonMap(t1p, 
Errors.UNKNOWN_TOPIC_OR_PARTITION)));
+            res = coordinator.onJoinPrepare(pollTimer, generationId, memberId);
+            assertTrue(res);
+
+            assertFalse(client.hasPendingResponses());
+            assertFalse(client.hasInFlightRequests());
+            assertFalse(coordinator.coordinatorUnknown());
+        }
+    }
+
     @Test
     public void testJoinPrepareWithDisableAutoCommit() {
-        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, false, Optional.of("group-id"))) {
+        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, false, Optional.of("group-id"), true)) {
             coordinator.ensureActiveGroup();
 
             prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE);
@@ -1309,7 +1372,7 @@ public abstract class ConsumerCoordinatorTest {
             int generationId = 42;
             String memberId = "consumer-42";
 
-            boolean res = coordinator.onJoinPrepare(generationId, memberId);
+            boolean res = coordinator.onJoinPrepare(time.timer(0L), 
generationId, memberId);
 
             assertTrue(res);
             assertTrue(client.hasPendingResponses());
@@ -1320,14 +1383,14 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testJoinPrepareAndCommitCompleted() {
-        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, true, Optional.of("group-id"))) {
+        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, true, Optional.of("group-id"), true)) {
             coordinator.ensureActiveGroup();
 
             prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE);
             int generationId = 42;
             String memberId = "consumer-42";
 
-            boolean res = coordinator.onJoinPrepare(generationId, memberId);
+            boolean res = coordinator.onJoinPrepare(time.timer(0L), 
generationId, memberId);
             coordinator.invokeCompletedOffsetCommitCallbacks();
 
             assertTrue(res);
@@ -1339,7 +1402,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testJoinPrepareAndCommitWithCoordinatorNotAvailable() {
-        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, true, Optional.of("group-id"))) {
+        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, true, Optional.of("group-id"), true)) {
             coordinator.ensureActiveGroup();
 
             prepareOffsetCommitRequest(singletonMap(t1p, 100L), 
Errors.COORDINATOR_NOT_AVAILABLE);
@@ -1347,7 +1410,7 @@ public abstract class ConsumerCoordinatorTest {
             int generationId = 42;
             String memberId = "consumer-42";
 
-            boolean res = coordinator.onJoinPrepare(generationId, memberId);
+            boolean res = coordinator.onJoinPrepare(time.timer(0L), 
generationId, memberId);
             coordinator.invokeCompletedOffsetCommitCallbacks();
 
             assertFalse(res);
@@ -1359,7 +1422,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testJoinPrepareAndCommitWithUnknownMemberId() {
-        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, true, Optional.of("group-id"))) {
+        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, true, Optional.of("group-id"), true)) {
             coordinator.ensureActiveGroup();
 
             prepareOffsetCommitRequest(singletonMap(t1p, 100L), 
Errors.UNKNOWN_MEMBER_ID);
@@ -1367,7 +1430,7 @@ public abstract class ConsumerCoordinatorTest {
             int generationId = 42;
             String memberId = "consumer-42";
 
-            boolean res = coordinator.onJoinPrepare(generationId, memberId);
+            boolean res = coordinator.onJoinPrepare(time.timer(0L), 
generationId, memberId);
             coordinator.invokeCompletedOffsetCommitCallbacks();
 
             assertTrue(res);
@@ -3078,21 +3141,21 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testCloseDynamicAssignment() {
-        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, true, Optional.empty())) {
+        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, true, Optional.empty(), true)) {
             gracefulCloseTest(coordinator, true);
         }
     }
 
     @Test
     public void testCloseManualAssignment() {
-        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(false, true, Optional.empty())) {
+        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(false, true, Optional.empty(), true)) {
             gracefulCloseTest(coordinator, false);
         }
     }
 
     @Test
     public void testCloseCoordinatorNotKnownManualAssignment() throws 
Exception {
-        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(false, true, Optional.empty())) {
+        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(false, true, Optional.empty(), true)) {
             makeCoordinatorUnknown(coordinator, Errors.NOT_COORDINATOR);
             time.sleep(autoCommitIntervalMs);
             closeVerifyTimeout(coordinator, 1000, 1000, 1000);
@@ -3101,7 +3164,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testCloseCoordinatorNotKnownNoCommits() throws Exception {
-        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, false, Optional.empty())) {
+        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, false, Optional.empty(), true)) {
             makeCoordinatorUnknown(coordinator, Errors.NOT_COORDINATOR);
             closeVerifyTimeout(coordinator, 1000, 0, 0);
         }
@@ -3109,7 +3172,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testCloseCoordinatorNotKnownWithCommits() throws Exception {
-        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, true, Optional.empty())) {
+        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, true, Optional.empty(), true)) {
             makeCoordinatorUnknown(coordinator, Errors.NOT_COORDINATOR);
             time.sleep(autoCommitIntervalMs);
             closeVerifyTimeout(coordinator, 1000, 1000, 1000);
@@ -3118,7 +3181,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testCloseCoordinatorUnavailableNoCommits() throws Exception {
-        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, false, Optional.empty())) {
+        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, false, Optional.empty(), true)) {
             makeCoordinatorUnknown(coordinator, 
Errors.COORDINATOR_NOT_AVAILABLE);
             closeVerifyTimeout(coordinator, 1000, 0, 0);
         }
@@ -3126,7 +3189,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testCloseTimeoutCoordinatorUnavailableForCommit() throws 
Exception {
-        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, true, groupInstanceId)) {
+        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, true, groupInstanceId, true)) {
             makeCoordinatorUnknown(coordinator, 
Errors.COORDINATOR_NOT_AVAILABLE);
             time.sleep(autoCommitIntervalMs);
             closeVerifyTimeout(coordinator, 1000, 1000, 1000);
@@ -3135,7 +3198,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testCloseMaxWaitCoordinatorUnavailableForCommit() throws 
Exception {
-        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, true, groupInstanceId)) {
+        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, true, groupInstanceId, true)) {
             makeCoordinatorUnknown(coordinator, 
Errors.COORDINATOR_NOT_AVAILABLE);
             time.sleep(autoCommitIntervalMs);
             closeVerifyTimeout(coordinator, Long.MAX_VALUE, requestTimeoutMs, 
requestTimeoutMs);
@@ -3144,7 +3207,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testCloseNoResponseForCommit() throws Exception {
-        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, true, groupInstanceId)) {
+        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, true, groupInstanceId, true)) {
             time.sleep(autoCommitIntervalMs);
             closeVerifyTimeout(coordinator, Long.MAX_VALUE, requestTimeoutMs, 
requestTimeoutMs);
         }
@@ -3152,14 +3215,14 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testCloseNoResponseForLeaveGroup() throws Exception {
-        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, false, Optional.empty())) {
+        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, false, Optional.empty(), true)) {
             closeVerifyTimeout(coordinator, Long.MAX_VALUE, requestTimeoutMs, 
requestTimeoutMs);
         }
     }
 
     @Test
     public void testCloseNoWait() throws Exception {
-        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, true, groupInstanceId)) {
+        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, true, groupInstanceId, true)) {
             time.sleep(autoCommitIntervalMs);
             closeVerifyTimeout(coordinator, 0, 0, 0);
         }
@@ -3167,7 +3230,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testHeartbeatThreadClose() throws Exception {
-        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, true, groupInstanceId)) {
+        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, true, groupInstanceId, true)) {
             coordinator.ensureActiveGroup();
             time.sleep(heartbeatIntervalMs + 100);
             Thread.yield(); // Give heartbeat thread a chance to attempt 
heartbeat
@@ -3234,7 +3297,7 @@ public abstract class ConsumerCoordinatorTest {
         assertEquals(JoinGroupRequest.UNKNOWN_MEMBER_ID, 
groupMetadata.memberId());
         assertFalse(groupMetadata.groupInstanceId().isPresent());
 
-        try (final ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, true, groupInstanceId)) {
+        try (final ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, true, groupInstanceId, true)) {
             coordinator.ensureActiveGroup();
 
             final ConsumerGroupMetadata joinedGroupMetadata = 
coordinator.groupMetadata();
@@ -3270,7 +3333,7 @@ public abstract class ConsumerCoordinatorTest {
     @Test
     public void testPrepareJoinAndRejoinAfterFailedRebalance() {
         final List<TopicPartition> partitions = singletonList(t1p);
-        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, false, Optional.of("group-id"))) {
+        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, false, Optional.of("group-id"), true)) {
             coordinator.ensureActiveGroup();
 
             prepareOffsetCommitRequest(singletonMap(t1p, 100L), 
Errors.REBALANCE_IN_PROGRESS);
@@ -3290,7 +3353,7 @@ public abstract class ConsumerCoordinatorTest {
             MockTime time = new MockTime(1);
 
             // onJoinPrepare will be executed and onJoinComplete will not.
-            boolean res = coordinator.joinGroupIfNeeded(time.timer(2));
+            boolean res = coordinator.joinGroupIfNeeded(time.timer(100));
 
             assertFalse(res);
             assertFalse(client.hasPendingResponses());
@@ -3335,7 +3398,7 @@ public abstract class ConsumerCoordinatorTest {
     @Test
     public void 
shouldLoseAllOwnedPartitionsBeforeRejoiningAfterDroppingOutOfTheGroup() {
         final List<TopicPartition> partitions = singletonList(t1p);
-        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, false, Optional.of("group-id"))) {
+        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, false, Optional.of("group-id"), true)) {
             final SystemTime realTime = new SystemTime();
             coordinator.ensureActiveGroup();
 
@@ -3368,7 +3431,7 @@ public abstract class ConsumerCoordinatorTest {
     @Test
     public void 
shouldLoseAllOwnedPartitionsBeforeRejoiningAfterResettingGenerationId() {
         final List<TopicPartition> partitions = singletonList(t1p);
-        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, false, Optional.of("group-id"))) {
+        try (ConsumerCoordinator coordinator = 
prepareCoordinatorForCloseTest(true, false, Optional.of("group-id"), true)) {
             final SystemTime realTime = new SystemTime();
             coordinator.ensureActiveGroup();
 
@@ -3462,7 +3525,8 @@ public abstract class ConsumerCoordinatorTest {
 
     private ConsumerCoordinator prepareCoordinatorForCloseTest(final boolean 
useGroupManagement,
                                                                final boolean 
autoCommit,
-                                                               final 
Optional<String> groupInstanceId) {
+                                                               final 
Optional<String> groupInstanceId,
+                                                               final boolean 
shouldPoll) {
         rebalanceConfig = buildRebalanceConfig(groupInstanceId);
         ConsumerCoordinator coordinator = buildCoordinator(rebalanceConfig,
                                                            new Metrics(),
@@ -3481,7 +3545,9 @@ public abstract class ConsumerCoordinatorTest {
         }
 
         subscriptions.seek(t1p, 100);
-        coordinator.poll(time.timer(Long.MAX_VALUE));
+        if (shouldPoll) {
+            coordinator.poll(time.timer(Long.MAX_VALUE));
+        }
 
         return coordinator;
     }
diff --git 
a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinator.java
 
b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinator.java
index 65720e2a78..ce1b82a9d5 100644
--- 
a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinator.java
+++ 
b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinator.java
@@ -224,7 +224,7 @@ public class WorkerCoordinator extends AbstractCoordinator 
implements Closeable
     }
 
     @Override
-    protected boolean onJoinPrepare(int generation, String memberId) {
+    protected boolean onJoinPrepare(Timer timer, int generation, String 
memberId) {
         log.info("Rebalance started");
         leaderState(null);
         final ExtendedAssignment localAssignmentSnapshot = assignmentSnapshot;
diff --git 
a/core/src/test/scala/integration/kafka/api/AbstractConsumerTest.scala 
b/core/src/test/scala/integration/kafka/api/AbstractConsumerTest.scala
index 56bc47c79e..23b56b8e91 100644
--- a/core/src/test/scala/integration/kafka/api/AbstractConsumerTest.scala
+++ b/core/src/test/scala/integration/kafka/api/AbstractConsumerTest.scala
@@ -342,15 +342,16 @@ abstract class AbstractConsumerTest extends 
BaseRequestTest {
 
   protected class ConsumerAssignmentPoller(consumer: Consumer[Array[Byte], 
Array[Byte]],
                                            topicsToSubscribe: List[String],
-                                           partitionsToAssign: 
Set[TopicPartition])
+                                           partitionsToAssign: 
Set[TopicPartition],
+                                           userRebalanceListener: 
ConsumerRebalanceListener)
     extends ShutdownableThread("daemon-consumer-assignment", false) {
 
     def this(consumer: Consumer[Array[Byte], Array[Byte]], topicsToSubscribe: 
List[String]) = {
-      this(consumer, topicsToSubscribe, Set.empty[TopicPartition])
+      this(consumer, topicsToSubscribe, Set.empty[TopicPartition], null)
     }
 
     def this(consumer: Consumer[Array[Byte], Array[Byte]], partitionsToAssign: 
Set[TopicPartition]) = {
-      this(consumer, List.empty[String], partitionsToAssign)
+      this(consumer, List.empty[String], partitionsToAssign, null)
     }
 
     @volatile var thrownException: Option[Throwable] = None
@@ -363,10 +364,14 @@ abstract class AbstractConsumerTest extends 
BaseRequestTest {
     val rebalanceListener: ConsumerRebalanceListener = new 
ConsumerRebalanceListener {
       override def onPartitionsAssigned(partitions: 
util.Collection[TopicPartition]) = {
         partitionAssignment ++= partitions.toArray(new 
Array[TopicPartition](0))
+        if (userRebalanceListener != null)
+          userRebalanceListener.onPartitionsAssigned(partitions)
       }
 
       override def onPartitionsRevoked(partitions: 
util.Collection[TopicPartition]) = {
         partitionAssignment --= partitions.toArray(new 
Array[TopicPartition](0))
+        if (userRebalanceListener != null)
+          userRebalanceListener.onPartitionsRevoked(partitions)
       }
     }
 
diff --git 
a/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala 
b/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
index 4ede241b0c..5dc7c2ada1 100644
--- a/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
+++ b/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
@@ -37,7 +37,11 @@ import kafka.server.QuotaType
 import kafka.server.KafkaServer
 import org.apache.kafka.clients.admin.NewPartitions
 import org.apache.kafka.clients.admin.NewTopic
+import org.junit.jupiter.params.ParameterizedTest
+import org.junit.jupiter.params.provider.ValueSource
 
+import java.util.concurrent.TimeUnit
+import java.util.concurrent.locks.ReentrantLock
 import scala.collection.mutable
 
 /* We have some tests in this class instead of `BaseConsumerTest` in order to 
keep the build time under control. */
@@ -969,6 +973,89 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     }
   }
 
+  @ParameterizedTest
+  @ValueSource(strings = Array(
+    "org.apache.kafka.clients.consumer.CooperativeStickyAssignor",
+    "org.apache.kafka.clients.consumer.RangeAssignor"))
+  def testRebalanceAndRejoin(assignmentStrategy: String): Unit = {
+    // Checks consumer1 keeps its memberId (is not kicked out) after consumer2 joins; create 2 consumers in one group
+    this.consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, 
"rebalance-and-rejoin-group")
+    
this.consumerConfig.setProperty(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG,
 assignmentStrategy)
+    this.consumerConfig.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, 
"true")
+    val consumer1 = createConsumer()
+    val consumer2 = createConsumer()
+
+    // create a new topic with 2 partitions and produce records to it
+    val topic = "topic1"
+    val producer = createProducer()
+    val expectedAssignment = createTopicAndSendRecords(producer, topic, 2, 100)
+
+    assertEquals(0, consumer1.assignment().size)
+    assertEquals(0, consumer2.assignment().size)
+
+    val lock = new ReentrantLock()
+    var generationId1 = -1
+    var memberId1 = ""
+    val customRebalanceListener = new ConsumerRebalanceListener {
+      override def onPartitionsRevoked(partitions: 
util.Collection[TopicPartition]): Unit = {
+      }
+      override def onPartitionsAssigned(partitions: 
util.Collection[TopicPartition]): Unit = {
+        if (!lock.tryLock(3000, TimeUnit.MILLISECONDS)) {
+          fail(s"Time out while awaiting for lock.")
+        }
+        try {
+          generationId1 = consumer1.groupMetadata().generationId()
+          memberId1 = consumer1.groupMetadata().memberId()
+        } finally {
+          lock.unlock()
+        }
+      }
+    }
+    val consumerPoller1 = new ConsumerAssignmentPoller(consumer1, List(topic), 
Set.empty, customRebalanceListener)
+    consumerPoller1.start()
+    TestUtils.waitUntilTrue(() => consumerPoller1.consumerAssignment() == 
expectedAssignment,
+      s"Timed out while awaiting expected assignment change to 
$expectedAssignment.")
+
+    // consumer1 has completed its first rebalance, so `onPartitionsAssigned` is expected to have set generationId1/memberId1.
+    // NOTE(review): the poller's wrapper listener updates its assignment before delegating to this listener, so this read 
may race with the callback (presumably consumerAssignment() reflects that assignment) — confirm.
+    var stableGeneration = -1
+    var stableMemberId1 = ""
+    if (!lock.tryLock(3000, TimeUnit.MILLISECONDS)) {
+      fail(s"Time out while awaiting for lock.")
+    }
+    try {
+      stableGeneration = generationId1
+      stableMemberId1 = memberId1
+    } finally {
+      lock.unlock()
+    }
+
+    val consumerPoller2 = subscribeConsumerAndStartPolling(consumer2, 
List(topic))
+    TestUtils.waitUntilTrue(() => consumerPoller1.consumerAssignment().size == 
1,
+      s"Timed out while awaiting expected assignment size change to 1.")
+    TestUtils.waitUntilTrue(() => consumerPoller2.consumerAssignment().size == 
1,
+      s"Timed out while awaiting expected assignment size change to 1.")
+
+    if (!lock.tryLock(3000, TimeUnit.MILLISECONDS)) {
+      fail(s"Time out while awaiting for lock.")
+    }
+    try {
+      if 
(assignmentStrategy.equals(classOf[CooperativeStickyAssignor].getName)) {
+        // cooperative rebalancing is expected to take two rebalances (generation + 2) before the group is stable
+        assertEquals(stableGeneration + 2, generationId1)
+      } else {
+        // eager rebalancing is expected to take a single rebalance (generation + 1) before the group is stable
+        assertEquals(stableGeneration + 1, generationId1)
+      }
+      assertEquals(stableMemberId1, memberId1)
+    } finally {
+      lock.unlock()
+    }
+
+    consumerPoller1.shutdown()
+    consumerPoller2.shutdown()
+  }
+
   /**
    * This test re-uses BaseConsumerTest's consumers.
    * As a result, it is testing the default assignment strategy set by 
BaseConsumerTest

Reply via email to