This is an automated email from the ASF dual-hosted git repository.

mjsax pushed a commit to branch 4.1
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/4.1 by this push:
     new 6c1ec56436e KAFKA-19694: Trigger StreamsRebalanceListener in 
Consumer.close (#20511)
6c1ec56436e is described below

commit 6c1ec56436e1e362b7e0a1c7ff9a9c1499cc3d43
Author: Lucas Brutschy <[email protected]>
AuthorDate: Tue Sep 16 16:32:47 2025 +0200

    KAFKA-19694: Trigger StreamsRebalanceListener in Consumer.close (#20511)
    
    In the consumer, we invoke the ConsumerRebalanceListener
    onPartitionsRevoked or onPartitionsLost callbacks when the consumer
    closes. The point is that the application may want to commit, or wipe
    its state if we are closing unsuccessfully.
    
    In the StreamsRebalanceListener, we did not implement this behavior,
    which means when closing the consumer we may lose some progress, and in
    the worst case also miss that we have to wipe our local state
    because we got fenced.
    
    In this PR we implement StreamsRebalanceListenerInvoker, closely modeled
    on ConsumerRebalanceListenerInvoker, and invoke it in Consumer.close.
    
    Reviewers: Lianet Magrans <[email protected]>, Matthias J. Sax
     <[email protected]>, TengYao Chi <[email protected]>,
     Uladzislau Blok <[email protected]>
    
    ---------
    
    Co-authored-by: Copilot <[email protected]>
---
 .../consumer/internals/AsyncKafkaConsumer.java     | 101 +++----
 .../internals/StreamsRebalanceListenerInvoker.java | 117 ++++++++
 .../consumer/internals/AsyncKafkaConsumerTest.java |  67 +++++
 .../StreamsRebalanceListenerInvokerTest.java       | 293 +++++++++++++++++++++
 .../internals/DefaultStreamsRebalanceListener.java |   2 +
 .../DefaultStreamsRebalanceListenerTest.java       |  98 ++++---
 6 files changed, 577 insertions(+), 101 deletions(-)

diff --git 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java
 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java
index 6f1f8c8bc64..d457b7e48f0 100644
--- 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java
+++ 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java
@@ -186,25 +186,6 @@ public class AsyncKafkaConsumer<K, V> implements 
ConsumerDelegate<K, V> {
      */
     private class BackgroundEventProcessor implements 
EventProcessor<BackgroundEvent> {
 
-        private Optional<StreamsRebalanceListener> streamsRebalanceListener = 
Optional.empty();
-        private final Optional<StreamsRebalanceData> streamsRebalanceData;
-
-        public BackgroundEventProcessor() {
-            this.streamsRebalanceData = Optional.empty();
-        }
-
-        public BackgroundEventProcessor(final Optional<StreamsRebalanceData> 
streamsRebalanceData) {
-            this.streamsRebalanceData = streamsRebalanceData;
-        }
-
-        private void setStreamsRebalanceListener(final 
StreamsRebalanceListener streamsRebalanceListener) {
-            if (streamsRebalanceData.isEmpty()) {
-                throw new IllegalStateException("Background event processor 
was not created to be used with Streams " +
-                    "rebalance protocol events");
-            }
-            this.streamsRebalanceListener = 
Optional.of(streamsRebalanceListener);
-        }
-
         @Override
         public void process(final BackgroundEvent event) {
             switch (event.type()) {
@@ -277,44 +258,26 @@ public class AsyncKafkaConsumer<K, V> implements 
ConsumerDelegate<K, V> {
 
         private StreamsOnTasksRevokedCallbackCompletedEvent 
invokeOnTasksRevokedCallback(final Set<StreamsRebalanceData.TaskId> 
activeTasksToRevoke,
                                                                                
          final CompletableFuture<Void> future) {
-            final Optional<Exception> exceptionFromCallback = 
streamsRebalanceListener().onTasksRevoked(activeTasksToRevoke);
+            final Optional<Exception> exceptionFromCallback = 
Optional.ofNullable(streamsRebalanceListenerInvoker().invokeTasksRevoked(activeTasksToRevoke));
             final Optional<KafkaException> error = exceptionFromCallback.map(e 
-> ConsumerUtils.maybeWrapAsKafkaException(e, "Task revocation callback throws 
an error"));
             return new StreamsOnTasksRevokedCallbackCompletedEvent(future, 
error);
         }
 
         private StreamsOnTasksAssignedCallbackCompletedEvent 
invokeOnTasksAssignedCallback(final StreamsRebalanceData.Assignment assignment,
                                                                                
            final CompletableFuture<Void> future) {
-            final Optional<KafkaException> error;
-            final Optional<Exception> exceptionFromCallback = 
streamsRebalanceListener().onTasksAssigned(assignment);
-            if (exceptionFromCallback.isPresent()) {
-                error = 
Optional.of(ConsumerUtils.maybeWrapAsKafkaException(exceptionFromCallback.get(),
 "Task assignment callback throws an error"));
-            } else {
-                error = Optional.empty();
-                streamsRebalanceData().setReconciledAssignment(assignment);
-            }
+            final Optional<Exception> exceptionFromCallback = 
Optional.ofNullable(streamsRebalanceListenerInvoker().invokeTasksAssigned(assignment));
+            final Optional<KafkaException> error = exceptionFromCallback.map(e 
-> ConsumerUtils.maybeWrapAsKafkaException(e, "Task assignment callback throws 
an error"));
             return new StreamsOnTasksAssignedCallbackCompletedEvent(future, 
error);
         }
 
         private StreamsOnAllTasksLostCallbackCompletedEvent 
invokeOnAllTasksLostCallback(final CompletableFuture<Void> future) {
-            final Optional<KafkaException> error;
-            final Optional<Exception> exceptionFromCallback = 
streamsRebalanceListener().onAllTasksLost();
-            if (exceptionFromCallback.isPresent()) {
-                error = 
Optional.of(ConsumerUtils.maybeWrapAsKafkaException(exceptionFromCallback.get(),
 "All tasks lost callback throws an error"));
-            } else {
-                error = Optional.empty();
-                
streamsRebalanceData().setReconciledAssignment(StreamsRebalanceData.Assignment.EMPTY);
-            }
+            final Optional<Exception> exceptionFromCallback = 
Optional.ofNullable(streamsRebalanceListenerInvoker().invokeAllTasksLost());
+            final Optional<KafkaException> error = exceptionFromCallback.map(e 
-> ConsumerUtils.maybeWrapAsKafkaException(e, "All tasks lost callback throws 
an error"));
             return new StreamsOnAllTasksLostCallbackCompletedEvent(future, 
error);
         }
 
-        private StreamsRebalanceData streamsRebalanceData() {
-            return streamsRebalanceData.orElseThrow(
-                () -> new IllegalStateException("Background event processor 
was not created to be used with Streams " +
-                    "rebalance protocol events"));
-        }
-
-        private StreamsRebalanceListener streamsRebalanceListener() {
-            return streamsRebalanceListener.orElseThrow(
+        private StreamsRebalanceListenerInvoker 
streamsRebalanceListenerInvoker() {
+            return streamsRebalanceListenerInvoker.orElseThrow(
                 () -> new IllegalStateException("Background event processor 
was not created to be used with Streams " +
                     "rebalance protocol events"));
         }
@@ -365,6 +328,7 @@ public class AsyncKafkaConsumer<K, V> implements 
ConsumerDelegate<K, V> {
     private final WakeupTrigger wakeupTrigger = new WakeupTrigger();
     private final OffsetCommitCallbackInvoker offsetCommitCallbackInvoker;
     private final ConsumerRebalanceListenerInvoker rebalanceListenerInvoker;
+    private final Optional<StreamsRebalanceListenerInvoker> 
streamsRebalanceListenerInvoker;
     // Last triggered async commit future. Used to wait until all previous 
async commits are completed.
     // We only need to keep track of the last one, since they are guaranteed 
to complete in order.
     private CompletableFuture<Map<TopicPartition, OffsetAndMetadata>> 
lastPendingAsyncCommit = null;
@@ -514,7 +478,9 @@ public class AsyncKafkaConsumer<K, V> implements 
ConsumerDelegate<K, V> {
                     time,
                     new RebalanceCallbackMetricsManager(metrics)
             );
-            this.backgroundEventProcessor = new 
BackgroundEventProcessor(streamsRebalanceData);
+            this.streamsRebalanceListenerInvoker = streamsRebalanceData.map(s 
->
+                new StreamsRebalanceListenerInvoker(logContext, s));
+            this.backgroundEventProcessor = new BackgroundEventProcessor();
             this.backgroundEventReaper = 
backgroundEventReaperFactory.build(logContext);
 
             // The FetchCollector is only used on the application thread.
@@ -574,6 +540,7 @@ public class AsyncKafkaConsumer<K, V> implements 
ConsumerDelegate<K, V> {
         this.time = time;
         this.backgroundEventQueue = backgroundEventQueue;
         this.rebalanceListenerInvoker = rebalanceListenerInvoker;
+        this.streamsRebalanceListenerInvoker = Optional.empty();
         this.backgroundEventProcessor = new BackgroundEventProcessor();
         this.backgroundEventReaper = backgroundEventReaper;
         this.metrics = metrics;
@@ -694,6 +661,7 @@ public class AsyncKafkaConsumer<K, V> implements 
ConsumerDelegate<K, V> {
                 networkClientDelegateSupplier,
                 requestManagersSupplier,
                 kafkaConsumerMetrics);
+        this.streamsRebalanceListenerInvoker = Optional.empty();
         this.backgroundEventProcessor = new BackgroundEventProcessor();
         this.backgroundEventReaper = new CompletableEventReaper(logContext);
     }
@@ -1472,7 +1440,7 @@ public class AsyncKafkaConsumer<K, V> implements 
ConsumerDelegate<K, V> {
             () -> autoCommitOnClose(closeTimer), firstException);
         swallow(log, Level.ERROR, "Failed to stop finding coordinator",
             this::stopFindCoordinatorOnClose, firstException);
-        swallow(log, Level.ERROR, "Failed to release group assignment",
+        swallow(log, Level.ERROR, "Failed to run rebalance callbacks",
             this::runRebalanceCallbacksOnClose, firstException);
         swallow(log, Level.ERROR, "Failed to leave group while closing 
consumer",
             () -> leaveGroupOnClose(closeTimer, membershipOperation), 
firstException);
@@ -1526,21 +1494,34 @@ public class AsyncKafkaConsumer<K, V> implements 
ConsumerDelegate<K, V> {
 
         int memberEpoch = groupMetadata.get().get().generationId();
 
-        Set<TopicPartition> assignedPartitions = groupAssignmentSnapshot.get();
+        Exception error = null;
 
-        if (assignedPartitions.isEmpty())
-            // Nothing to revoke.
-            return;
+        if (streamsRebalanceListenerInvoker != null && 
streamsRebalanceListenerInvoker.isPresent()) {
+
+            if (memberEpoch > 0) {
+                error = 
streamsRebalanceListenerInvoker.get().invokeAllTasksRevoked();
+            } else {
+                error = 
streamsRebalanceListenerInvoker.get().invokeAllTasksLost();
+            }
 
-        SortedSet<TopicPartition> droppedPartitions = new 
TreeSet<>(TOPIC_PARTITION_COMPARATOR);
-        droppedPartitions.addAll(assignedPartitions);
+        } else if (rebalanceListenerInvoker != null) {
 
-        final Exception error;
+            Set<TopicPartition> assignedPartitions = 
groupAssignmentSnapshot.get();
 
-        if (memberEpoch > 0)
-            error = 
rebalanceListenerInvoker.invokePartitionsRevoked(droppedPartitions);
-        else
-            error = 
rebalanceListenerInvoker.invokePartitionsLost(droppedPartitions);
+            if (assignedPartitions.isEmpty())
+                // Nothing to revoke.
+                return;
+
+            SortedSet<TopicPartition> droppedPartitions = new 
TreeSet<>(TOPIC_PARTITION_COMPARATOR);
+            droppedPartitions.addAll(assignedPartitions);
+
+            if (memberEpoch > 0) {
+                error = 
rebalanceListenerInvoker.invokePartitionsRevoked(droppedPartitions);
+            } else {
+                error = 
rebalanceListenerInvoker.invokePartitionsLost(droppedPartitions);
+            }
+
+        }
 
         if (error != null)
             throw ConsumerUtils.maybeWrapAsKafkaException(error);
@@ -1957,8 +1938,12 @@ public class AsyncKafkaConsumer<K, V> implements 
ConsumerDelegate<K, V> {
     }
 
     public void subscribe(Collection<String> topics, StreamsRebalanceListener 
streamsRebalanceListener) {
+
+        streamsRebalanceListenerInvoker
+            .orElseThrow(() -> new IllegalStateException("Consumer was not 
created to be used with Streams rebalance protocol events"))
+            .setRebalanceListener(streamsRebalanceListener);
+
         subscribeInternal(topics, Optional.empty());
-        
backgroundEventProcessor.setStreamsRebalanceListener(streamsRebalanceListener);
     }
 
     @Override
diff --git 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvoker.java
 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvoker.java
new file mode 100644
index 00000000000..f4c5aa4addc
--- /dev/null
+++ 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvoker.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.clients.consumer.internals;
+
+import org.apache.kafka.common.errors.InterruptException;
+import org.apache.kafka.common.errors.WakeupException;
+import org.apache.kafka.common.utils.LogContext;
+
+import org.slf4j.Logger;
+
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+
+/**
+ * This class encapsulates the invocation of the callback methods defined in 
the {@link StreamsRebalanceListener}
+ * interface. When streams group task assignment changes, these methods are 
invoked. This class wraps those
+ * callback calls with some logging and error handling.
+ */
+public class StreamsRebalanceListenerInvoker {
+
+    private final Logger log;
+
+    private final StreamsRebalanceData streamsRebalanceData;
+    private Optional<StreamsRebalanceListener> listener;
+
+    StreamsRebalanceListenerInvoker(LogContext logContext, 
StreamsRebalanceData streamsRebalanceData) {
+        this.log = logContext.logger(getClass());
+        this.listener = Optional.empty();
+        this.streamsRebalanceData = streamsRebalanceData;
+    }
+
+    public void setRebalanceListener(StreamsRebalanceListener 
streamsRebalanceListener) {
+        Objects.requireNonNull(streamsRebalanceListener, 
"StreamsRebalanceListener cannot be null");
+        this.listener = Optional.of(streamsRebalanceListener);
+    }
+
+    public Exception invokeAllTasksRevoked() {
+        if (listener.isEmpty()) {
+            throw new IllegalStateException("StreamsRebalanceListener is not 
defined");
+        }
+        return 
invokeTasksRevoked(streamsRebalanceData.reconciledAssignment().activeTasks());
+    }
+
+    public Exception invokeTasksAssigned(final StreamsRebalanceData.Assignment 
assignment) {
+        if (listener.isEmpty()) {
+            throw new IllegalStateException("StreamsRebalanceListener is not 
defined");
+        }
+        log.info("Invoking tasks assigned callback for new assignment: {}", 
assignment);
+        try {
+            listener.get().onTasksAssigned(assignment);
+        } catch (WakeupException | InterruptException e) {
+            throw e;
+        } catch (Exception e) {
+            log.error(
+                "Streams rebalance listener failed on invocation of 
onTasksAssigned for tasks {}",
+                assignment,
+                e
+            );
+            return e;
+        }
+        return null;
+    }
+
+    public Exception invokeTasksRevoked(final Set<StreamsRebalanceData.TaskId> 
tasks) {
+        if (listener.isEmpty()) {
+            throw new IllegalStateException("StreamsRebalanceListener is not 
defined");
+        }
+        log.info("Invoking task revoked callback for revoked active tasks {}", 
tasks);
+        try {
+            listener.get().onTasksRevoked(tasks);
+        } catch (WakeupException | InterruptException e) {
+            throw e;
+        } catch (Exception e) {
+            log.error(
+                "Streams rebalance listener failed on invocation of 
onTasksRevoked for tasks {}",
+                tasks,
+                e
+            );
+            return e;
+        }
+        return null;
+    }
+
+    public Exception invokeAllTasksLost() {
+        if (listener.isEmpty()) {
+            throw new IllegalStateException("StreamsRebalanceListener is not 
defined");
+        }
+        log.info("Invoking tasks lost callback for all tasks");
+        try {
+            listener.get().onAllTasksLost();
+        } catch (WakeupException | InterruptException e) {
+            throw e;
+        } catch (Exception e) {
+            log.error(
+                "Streams rebalance listener failed on invocation of 
onTasksLost.",
+                e
+            );
+            return e;
+        }
+        return null;
+    }
+}
diff --git 
a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumerTest.java
 
b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumerTest.java
index 513ad3fe294..c74c5f90dab 100644
--- 
a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumerTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumerTest.java
@@ -2211,6 +2211,73 @@ public class AsyncKafkaConsumerTest {
         
}).when(applicationEventHandler).add(ArgumentMatchers.isA(CommitEvent.class));
     }
 
+    @Test
+    public void 
testCloseInvokesStreamsRebalanceListenerOnTasksRevokedWhenMemberEpochPositive() 
{
+        final String groupId = "streamsGroup";
+        final StreamsRebalanceData streamsRebalanceData = new 
StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of());
+        
+        try (final MockedStatic<RequestManagers> requestManagers = 
mockStatic(RequestManagers.class)) {
+            consumer = 
newConsumerWithStreamRebalanceData(requiredConsumerConfigAndGroupId(groupId), 
streamsRebalanceData);
+            StreamsRebalanceListener mockStreamsListener = 
mock(StreamsRebalanceListener.class);
+            
when(mockStreamsListener.onTasksRevoked(any())).thenReturn(Optional.empty());
+            consumer.subscribe(singletonList("topic"), mockStreamsListener);
+            final MemberStateListener groupMetadataUpdateListener = 
captureGroupMetadataUpdateListener(requestManagers);
+            final int memberEpoch = 42;
+            final String memberId = "memberId";
+            
groupMetadataUpdateListener.onMemberEpochUpdated(Optional.of(memberEpoch), 
memberId);
+            
+            consumer.close(CloseOptions.timeout(Duration.ZERO));
+            
+            verify(mockStreamsListener).onTasksRevoked(any());
+        }
+    }
+    
+    @Test
+    public void 
testCloseInvokesStreamsRebalanceListenerOnAllTasksLostWhenMemberEpochZeroOrNegative()
 {
+        final String groupId = "streamsGroup";
+        final StreamsRebalanceData streamsRebalanceData = new 
StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of());
+        
+        try (final MockedStatic<RequestManagers> requestManagers = 
mockStatic(RequestManagers.class)) {
+            consumer = 
newConsumerWithStreamRebalanceData(requiredConsumerConfigAndGroupId(groupId), 
streamsRebalanceData);
+            StreamsRebalanceListener mockStreamsListener = 
mock(StreamsRebalanceListener.class);
+            
when(mockStreamsListener.onAllTasksLost()).thenReturn(Optional.empty());
+            consumer.subscribe(singletonList("topic"), mockStreamsListener);
+            final MemberStateListener groupMetadataUpdateListener = 
captureGroupMetadataUpdateListener(requestManagers);
+            final int memberEpoch = 0;
+            final String memberId = "memberId";
+            
groupMetadataUpdateListener.onMemberEpochUpdated(Optional.of(memberEpoch), 
memberId);
+            
+            consumer.close(CloseOptions.timeout(Duration.ZERO));
+            
+            verify(mockStreamsListener).onAllTasksLost();
+        }
+    }
+    
+    @Test
+    public void testCloseWrapsStreamsRebalanceListenerException() {
+        final String groupId = "streamsGroup";
+        final StreamsRebalanceData streamsRebalanceData = new 
StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of());
+        
+        try (final MockedStatic<RequestManagers> requestManagers = 
mockStatic(RequestManagers.class)) {
+            consumer = 
newConsumerWithStreamRebalanceData(requiredConsumerConfigAndGroupId(groupId), 
streamsRebalanceData);
+            StreamsRebalanceListener mockStreamsListener = 
mock(StreamsRebalanceListener.class);
+            RuntimeException testException = new RuntimeException("Test 
streams listener exception");
+            
doThrow(testException).when(mockStreamsListener).onTasksRevoked(any());
+            consumer.subscribe(singletonList("topic"), mockStreamsListener);
+            final MemberStateListener groupMetadataUpdateListener = 
captureGroupMetadataUpdateListener(requestManagers);
+            final int memberEpoch = 1;
+            final String memberId = "memberId";
+            
groupMetadataUpdateListener.onMemberEpochUpdated(Optional.of(memberEpoch), 
memberId);
+            
+            KafkaException thrownException = 
assertThrows(KafkaException.class, 
+                () -> consumer.close(CloseOptions.timeout(Duration.ZERO)));
+
+            assertInstanceOf(RuntimeException.class, 
thrownException.getCause());
+            assertTrue(thrownException.getCause().getMessage().contains("Test 
streams listener exception"));
+            verify(mockStreamsListener).onTasksRevoked(any());
+        }
+    }
+
     private void markReconcileAndAutoCommitCompleteForPollEvent() {
         doAnswer(invocation -> {
             PollEvent event = invocation.getArgument(0);
diff --git 
a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvokerTest.java
 
b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvokerTest.java
new file mode 100644
index 00000000000..2f3e5ab0523
--- /dev/null
+++ 
b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvokerTest.java
@@ -0,0 +1,293 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.clients.consumer.internals;
+
+import org.apache.kafka.common.errors.InterruptException;
+import org.apache.kafka.common.errors.WakeupException;
+import org.apache.kafka.common.utils.LogContext;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.mockito.junit.jupiter.MockitoSettings;
+import org.mockito.quality.Strictness;
+
+import java.util.Optional;
+import java.util.Set;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+@ExtendWith(MockitoExtension.class)
+@MockitoSettings(strictness = Strictness.STRICT_STUBS)
+public class StreamsRebalanceListenerInvokerTest {
+
+    @Mock
+    private StreamsRebalanceListener mockListener;
+
+    @Mock
+    private StreamsRebalanceData streamsRebalanceData;
+
+    private StreamsRebalanceListenerInvoker invoker;
+    private final LogContext logContext = new LogContext();
+
+    @BeforeEach
+    public void setup() {
+        invoker = new StreamsRebalanceListenerInvoker(logContext, 
streamsRebalanceData);
+    }
+
+    @Test
+    public void testSetRebalanceListenerWithNull() {
+        NullPointerException exception = 
assertThrows(NullPointerException.class, 
+            () -> invoker.setRebalanceListener(null));
+        assertEquals("StreamsRebalanceListener cannot be null", 
exception.getMessage());
+    }
+
+    @Test
+    public void testSetRebalanceListenerOverwritesExisting() {
+        StreamsRebalanceListener firstListener = 
org.mockito.Mockito.mock(StreamsRebalanceListener.class);
+        StreamsRebalanceListener secondListener = 
org.mockito.Mockito.mock(StreamsRebalanceListener.class);
+
+        StreamsRebalanceData.Assignment mockAssignment = 
createMockAssignment();
+        
when(streamsRebalanceData.reconciledAssignment()).thenReturn(mockAssignment);
+        
when(secondListener.onTasksRevoked(any())).thenReturn(Optional.empty());
+
+        // Set first listener
+        invoker.setRebalanceListener(firstListener);
+
+        // Overwrite with second listener
+        invoker.setRebalanceListener(secondListener);
+
+        // Should use second listener
+        invoker.invokeAllTasksRevoked();
+        verify(firstListener, never()).onTasksRevoked(any());
+        
verify(secondListener).onTasksRevoked(eq(mockAssignment.activeTasks()));
+    }
+
+    @Test
+    public void testInvokeMethodsWithNoListener() {
+        IllegalStateException exception1 = 
assertThrows(IllegalStateException.class, 
+            () -> invoker.invokeAllTasksRevoked());
+        assertEquals("StreamsRebalanceListener is not defined", 
exception1.getMessage());
+
+        IllegalStateException exception2 = 
assertThrows(IllegalStateException.class, 
+            () -> invoker.invokeTasksAssigned(createMockAssignment()));
+        assertEquals("StreamsRebalanceListener is not defined", 
exception2.getMessage());
+
+        IllegalStateException exception3 = 
assertThrows(IllegalStateException.class, 
+            () -> invoker.invokeTasksRevoked(createMockTasks()));
+        assertEquals("StreamsRebalanceListener is not defined", 
exception3.getMessage());
+
+        IllegalStateException exception4 = 
assertThrows(IllegalStateException.class, 
+            () -> invoker.invokeAllTasksLost());
+        assertEquals("StreamsRebalanceListener is not defined", 
exception4.getMessage());
+    }
+
+    @Test
+    public void testInvokeAllTasksRevokedWithListener() {
+        invoker.setRebalanceListener(mockListener);
+        
+        StreamsRebalanceData.Assignment mockAssignment = 
createMockAssignment();
+        
when(streamsRebalanceData.reconciledAssignment()).thenReturn(mockAssignment);
+        when(mockListener.onTasksRevoked(any())).thenReturn(Optional.empty());
+        
+        Exception result = invoker.invokeAllTasksRevoked();
+        
+        assertNull(result);
+        verify(mockListener).onTasksRevoked(eq(mockAssignment.activeTasks()));
+    }
+
+    @Test
+    public void testInvokeTasksAssignedWithListener() {
+        invoker.setRebalanceListener(mockListener);
+        StreamsRebalanceData.Assignment assignment = createMockAssignment();
+        
when(mockListener.onTasksAssigned(assignment)).thenReturn(Optional.empty());
+        
+        Exception result = invoker.invokeTasksAssigned(assignment);
+        
+        assertNull(result);
+        verify(mockListener).onTasksAssigned(eq(assignment));
+    }
+
+    @Test
+    public void testInvokeTasksAssignedWithWakeupException() {
+        invoker.setRebalanceListener(mockListener);
+        StreamsRebalanceData.Assignment assignment = createMockAssignment();
+        WakeupException wakeupException = new WakeupException();
+        
doThrow(wakeupException).when(mockListener).onTasksAssigned(assignment);
+        
+        WakeupException thrownException = assertThrows(WakeupException.class, 
+            () -> invoker.invokeTasksAssigned(assignment));
+        
+        assertEquals(wakeupException, thrownException);
+        verify(mockListener).onTasksAssigned(eq(assignment));
+    }
+
+    @Test
+    public void testInvokeTasksAssignedWithInterruptException() {
+        invoker.setRebalanceListener(mockListener);
+        StreamsRebalanceData.Assignment assignment = createMockAssignment();
+        InterruptException interruptException = new InterruptException("Test 
interrupt");
+        
doThrow(interruptException).when(mockListener).onTasksAssigned(assignment);
+        
+        InterruptException thrownException = 
assertThrows(InterruptException.class, 
+            () -> invoker.invokeTasksAssigned(assignment));
+        
+        assertEquals(interruptException, thrownException);
+        verify(mockListener).onTasksAssigned(eq(assignment));
+    }
+
+    @Test
+    public void testInvokeTasksAssignedWithOtherException() {
+        invoker.setRebalanceListener(mockListener);
+        StreamsRebalanceData.Assignment assignment = createMockAssignment();
+        RuntimeException runtimeException = new RuntimeException("Test 
exception");
+        
doThrow(runtimeException).when(mockListener).onTasksAssigned(assignment);
+        
+        Exception result = invoker.invokeTasksAssigned(assignment);
+        
+        assertEquals(runtimeException, result);
+        verify(mockListener).onTasksAssigned(eq(assignment));
+    }
+
+    @Test
+    public void testInvokeTasksRevokedWithListener() {
+        invoker.setRebalanceListener(mockListener);
+        Set<StreamsRebalanceData.TaskId> tasks = createMockTasks();
+        when(mockListener.onTasksRevoked(tasks)).thenReturn(Optional.empty());
+        
+        Exception result = invoker.invokeTasksRevoked(tasks);
+        
+        assertNull(result);
+        verify(mockListener).onTasksRevoked(eq(tasks));
+    }
+
+    @Test
+    public void testInvokeTasksRevokedWithWakeupException() {
+        invoker.setRebalanceListener(mockListener);
+        Set<StreamsRebalanceData.TaskId> tasks = createMockTasks();
+        WakeupException wakeupException = new WakeupException();
+        doThrow(wakeupException).when(mockListener).onTasksRevoked(tasks);
+        
+        WakeupException thrownException = assertThrows(WakeupException.class, 
+            () -> invoker.invokeTasksRevoked(tasks));
+        
+        assertEquals(wakeupException, thrownException);
+        verify(mockListener).onTasksRevoked(eq(tasks));
+    }
+
+    @Test
+    public void testInvokeTasksRevokedWithInterruptException() {
+        invoker.setRebalanceListener(mockListener);
+        Set<StreamsRebalanceData.TaskId> tasks = createMockTasks();
+        InterruptException interruptException = new InterruptException("Test interrupt");
+        doThrow(interruptException).when(mockListener).onTasksRevoked(tasks);
+        
+        InterruptException thrownException = assertThrows(InterruptException.class, 
+            () -> invoker.invokeTasksRevoked(tasks));
+        
+        assertEquals(interruptException, thrownException);
+        verify(mockListener).onTasksRevoked(eq(tasks));
+    }
+
+    @Test
+    public void testInvokeTasksRevokedWithOtherException() {
+        invoker.setRebalanceListener(mockListener);
+        Set<StreamsRebalanceData.TaskId> tasks = createMockTasks();
+        RuntimeException runtimeException = new RuntimeException("Test exception");
+        doThrow(runtimeException).when(mockListener).onTasksRevoked(tasks);
+        
+        Exception result = invoker.invokeTasksRevoked(tasks);
+        
+        assertEquals(runtimeException, result);
+        verify(mockListener).onTasksRevoked(eq(tasks));
+    }
+
+    @Test
+    public void testInvokeAllTasksLostWithListener() {
+        invoker.setRebalanceListener(mockListener);
+        when(mockListener.onAllTasksLost()).thenReturn(Optional.empty());
+        
+        Exception result = invoker.invokeAllTasksLost();
+        
+        assertNull(result);
+        verify(mockListener).onAllTasksLost();
+    }
+
+    @Test
+    public void testInvokeAllTasksLostWithWakeupException() {
+        invoker.setRebalanceListener(mockListener);
+        WakeupException wakeupException = new WakeupException();
+        doThrow(wakeupException).when(mockListener).onAllTasksLost();
+        
+        WakeupException thrownException = assertThrows(WakeupException.class, 
+            () -> invoker.invokeAllTasksLost());
+        
+        assertEquals(wakeupException, thrownException);
+        verify(mockListener).onAllTasksLost();
+    }
+
+    @Test
+    public void testInvokeAllTasksLostWithInterruptException() {
+        invoker.setRebalanceListener(mockListener);
+        InterruptException interruptException = new InterruptException("Test interrupt");
+        doThrow(interruptException).when(mockListener).onAllTasksLost();
+        
+        InterruptException thrownException = assertThrows(InterruptException.class, 
+            () -> invoker.invokeAllTasksLost());
+        
+        assertEquals(interruptException, thrownException);
+        verify(mockListener).onAllTasksLost();
+    }
+
+    @Test
+    public void testInvokeAllTasksLostWithOtherException() {
+        invoker.setRebalanceListener(mockListener);
+        RuntimeException runtimeException = new RuntimeException("Test exception");
+        doThrow(runtimeException).when(mockListener).onAllTasksLost();
+        
+        Exception result = invoker.invokeAllTasksLost();
+        
+        assertEquals(runtimeException, result);
+        verify(mockListener).onAllTasksLost();
+    }
+
+    private StreamsRebalanceData.Assignment createMockAssignment() {
+        Set<StreamsRebalanceData.TaskId> activeTasks = createMockTasks();
+        Set<StreamsRebalanceData.TaskId> standbyTasks = Set.of();
+        Set<StreamsRebalanceData.TaskId> warmupTasks = Set.of();
+        
+        return new StreamsRebalanceData.Assignment(activeTasks, standbyTasks, warmupTasks);
+    }
+
+    private Set<StreamsRebalanceData.TaskId> createMockTasks() {
+        return Set.of(
+            new StreamsRebalanceData.TaskId("subtopology1", 0),
+            new StreamsRebalanceData.TaskId("subtopology1", 1)
+        );
+    }
+
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListener.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListener.java
index dcc4821f2a8..a95fcef5a6c 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListener.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListener.java
@@ -89,6 +89,7 @@ public class DefaultStreamsRebalanceListener implements StreamsRebalanceListener
             taskManager.handleAssignment(activeTasksWithPartitions, standbyTasksWithPartitions);
             streamThread.setState(StreamThread.State.PARTITIONS_ASSIGNED);
             taskManager.handleRebalanceComplete();
+            streamsRebalanceData.setReconciledAssignment(assignment);
         } catch (final Exception exception) {
             return Optional.of(exception);
         }
@@ -99,6 +100,7 @@ public class DefaultStreamsRebalanceListener implements StreamsRebalanceListener
     public Optional<Exception> onAllTasksLost() {
         try {
             taskManager.handleLostAll();
+            streamsRebalanceData.setReconciledAssignment(StreamsRebalanceData.Assignment.EMPTY);
         } catch (final Exception exception) {
             return Optional.of(exception);
         }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListenerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListenerTest.java
index 66cb8e5185b..1297df7b1ee 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListenerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListenerTest.java
@@ -118,49 +118,46 @@ public class DefaultStreamsRebalanceListenerTest {
 
     @Test
     void testOnTasksAssigned() {
-        createRebalanceListenerWithRebalanceData(new StreamsRebalanceData(
-            UUID.randomUUID(),
-            Optional.empty(),
-            Map.of(
-                "1",
-                new StreamsRebalanceData.Subtopology(
-                    Set.of("source1"),
-                    Set.of(),
-                    Map.of("repartition1", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of())),
-                    Map.of(),
-                    Set.of()
-                ),
-                "2",
-                new StreamsRebalanceData.Subtopology(
-                    Set.of("source2"),
-                    Set.of(),
-                    Map.of("repartition2", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of())),
-                    Map.of(),
-                    Set.of()
-                ),
-                "3",
-                new StreamsRebalanceData.Subtopology(
-                    Set.of("source3"),
-                    Set.of(),
-                    Map.of("repartition3", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of())),
-                    Map.of(),
-                    Set.of()
-                )
+        final StreamsRebalanceData streamsRebalanceData = mock(StreamsRebalanceData.class);
+        when(streamsRebalanceData.subtopologies()).thenReturn(Map.of(
+            "1",
+            new StreamsRebalanceData.Subtopology(
+                Set.of("source1"),
+                Set.of(),
+                Map.of("repartition1", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of())),
+                Map.of(),
+                Set.of()
             ),
-            Map.of()
+            "2",
+            new StreamsRebalanceData.Subtopology(
+                Set.of("source2"),
+                Set.of(),
+                Map.of("repartition2", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of())),
+                Map.of(),
+                Set.of()
+            ),
+            "3",
+            new StreamsRebalanceData.Subtopology(
+                Set.of("source3"),
+                Set.of(),
+                Map.of("repartition3", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of())),
+                Map.of(),
+                Set.of()
+            )
         ));
+        createRebalanceListenerWithRebalanceData(streamsRebalanceData);
 
-        final Optional<Exception> result = defaultStreamsRebalanceListener.onTasksAssigned(
-            new StreamsRebalanceData.Assignment(
-                Set.of(new StreamsRebalanceData.TaskId("1", 0)),
-                Set.of(new StreamsRebalanceData.TaskId("2", 0)),
-                Set.of(new StreamsRebalanceData.TaskId("3", 0))
-            )
+        final StreamsRebalanceData.Assignment assignment = new StreamsRebalanceData.Assignment(
+            Set.of(new StreamsRebalanceData.TaskId("1", 0)),
+            Set.of(new StreamsRebalanceData.TaskId("2", 0)),
+            Set.of(new StreamsRebalanceData.TaskId("3", 0))
         );
 
+        final Optional<Exception> result = defaultStreamsRebalanceListener.onTasksAssigned(assignment);
+
         assertTrue(result.isEmpty());
 
-        final InOrder inOrder = inOrder(taskManager, streamThread);
+        final InOrder inOrder = inOrder(taskManager, streamThread, streamsRebalanceData);
         inOrder.verify(taskManager).handleAssignment(
             Map.of(new TaskId(1, 0), Set.of(new TopicPartition("source1", 0), new TopicPartition("repartition1", 0))),
             Map.of(
@@ -170,6 +167,7 @@ public class DefaultStreamsRebalanceListenerTest {
         );
         inOrder.verify(streamThread).setState(StreamThread.State.PARTITIONS_ASSIGNED);
         inOrder.verify(taskManager).handleRebalanceComplete();
+        inOrder.verify(streamsRebalanceData).setReconciledAssignment(assignment);
     }
 
     @Test
@@ -177,21 +175,32 @@ public class DefaultStreamsRebalanceListenerTest {
         final Exception exception = new RuntimeException("sample exception");
         doThrow(exception).when(taskManager).handleAssignment(any(), any());
 
-        createRebalanceListenerWithRebalanceData(new StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of()));
-        final Optional<Exception> result = defaultStreamsRebalanceListener.onTasksAssigned(new StreamsRebalanceData.Assignment(Set.of(), Set.of(), Set.of()));
-        assertTrue(defaultStreamsRebalanceListener.onAllTasksLost().isEmpty());
+        final StreamsRebalanceData streamsRebalanceData = mock(StreamsRebalanceData.class);
+        when(streamsRebalanceData.subtopologies()).thenReturn(Map.of());
+        createRebalanceListenerWithRebalanceData(streamsRebalanceData);
+
+        final Optional<Exception> result = defaultStreamsRebalanceListener.onTasksAssigned(
+            new StreamsRebalanceData.Assignment(Set.of(), Set.of(), Set.of())
+        );
         assertTrue(result.isPresent());
         assertEquals(exception, result.get());
-        verify(taskManager).handleLostAll();
+        verify(taskManager).handleAssignment(any(), any());
         verify(streamThread, never()).setState(StreamThread.State.PARTITIONS_ASSIGNED);
         verify(taskManager, never()).handleRebalanceComplete();
+        verify(streamsRebalanceData, never()).setReconciledAssignment(any());
     }
 
     @Test
     void testOnAllTasksLost() {
-        createRebalanceListenerWithRebalanceData(new StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of()));
+        final StreamsRebalanceData streamsRebalanceData = mock(StreamsRebalanceData.class);
+        when(streamsRebalanceData.subtopologies()).thenReturn(Map.of());
+        createRebalanceListenerWithRebalanceData(streamsRebalanceData);
+        
         assertTrue(defaultStreamsRebalanceListener.onAllTasksLost().isEmpty());
-        verify(taskManager).handleLostAll();
+        
+        final InOrder inOrder = inOrder(taskManager, streamsRebalanceData);
+        inOrder.verify(taskManager).handleLostAll();
+        inOrder.verify(streamsRebalanceData).setReconciledAssignment(StreamsRebalanceData.Assignment.EMPTY);
     }
 
     @Test
@@ -199,10 +208,13 @@ public class DefaultStreamsRebalanceListenerTest {
         final Exception exception = new RuntimeException("sample exception");
         doThrow(exception).when(taskManager).handleLostAll();
 
-        createRebalanceListenerWithRebalanceData(new StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of()));
+        final StreamsRebalanceData streamsRebalanceData = mock(StreamsRebalanceData.class);
+        when(streamsRebalanceData.subtopologies()).thenReturn(Map.of());
+        createRebalanceListenerWithRebalanceData(streamsRebalanceData);
         final Optional<Exception> result = defaultStreamsRebalanceListener.onAllTasksLost();
         assertTrue(result.isPresent());
         assertEquals(exception, result.get());
         verify(taskManager).handleLostAll();
+        verify(streamsRebalanceData, never()).setReconciledAssignment(any());
     }
 }


Reply via email to