This is an automated email from the ASF dual-hosted git repository.
mjsax pushed a commit to branch 4.1
in repository https://gitbox.apache.org/repos/asf/kafka.git
The following commit(s) were added to refs/heads/4.1 by this push:
new 6c1ec56436e KAFKA-19694: Trigger StreamsRebalanceListener in
Consumer.close (#20511)
6c1ec56436e is described below
commit 6c1ec56436e1e362b7e0a1c7ff9a9c1499cc3d43
Author: Lucas Brutschy <[email protected]>
AuthorDate: Tue Sep 16 16:32:47 2025 +0200
KAFKA-19694: Trigger StreamsRebalanceListener in Consumer.close (#20511)
In the consumer, we invoke the ConsumerRebalanceListener onPartitionsRevoked
or onPartitionsLost callbacks when the consumer closes. The point is that
the application may want to commit, or wipe its state if we are closing
unsuccessfully.
In the StreamsRebalanceListener, we did not implement this behavior,
which means that when closing the consumer we may lose some progress, and
in the worst case also miss that we have to wipe our local state
since we got fenced.
In this PR we implement StreamsRebalanceListenerInvoker, which is very
similar to ConsumerRebalanceListenerInvoker, and invoke it in Consumer.close.
Reviewers: Lianet Magrans <[email protected]>, Matthias J. Sax
<[email protected]>, TengYao Chi <[email protected]>,
Uladzislau Blok <[email protected]>
---------
Co-authored-by: Copilot <[email protected]>
---
.../consumer/internals/AsyncKafkaConsumer.java | 101 +++----
.../internals/StreamsRebalanceListenerInvoker.java | 117 ++++++++
.../consumer/internals/AsyncKafkaConsumerTest.java | 67 +++++
.../StreamsRebalanceListenerInvokerTest.java | 293 +++++++++++++++++++++
.../internals/DefaultStreamsRebalanceListener.java | 2 +
.../DefaultStreamsRebalanceListenerTest.java | 98 ++++---
6 files changed, 577 insertions(+), 101 deletions(-)
diff --git
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java
index 6f1f8c8bc64..d457b7e48f0 100644
---
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java
+++
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java
@@ -186,25 +186,6 @@ public class AsyncKafkaConsumer<K, V> implements
ConsumerDelegate<K, V> {
*/
private class BackgroundEventProcessor implements
EventProcessor<BackgroundEvent> {
- private Optional<StreamsRebalanceListener> streamsRebalanceListener =
Optional.empty();
- private final Optional<StreamsRebalanceData> streamsRebalanceData;
-
- public BackgroundEventProcessor() {
- this.streamsRebalanceData = Optional.empty();
- }
-
- public BackgroundEventProcessor(final Optional<StreamsRebalanceData>
streamsRebalanceData) {
- this.streamsRebalanceData = streamsRebalanceData;
- }
-
- private void setStreamsRebalanceListener(final
StreamsRebalanceListener streamsRebalanceListener) {
- if (streamsRebalanceData.isEmpty()) {
- throw new IllegalStateException("Background event processor
was not created to be used with Streams " +
- "rebalance protocol events");
- }
- this.streamsRebalanceListener =
Optional.of(streamsRebalanceListener);
- }
-
@Override
public void process(final BackgroundEvent event) {
switch (event.type()) {
@@ -277,44 +258,26 @@ public class AsyncKafkaConsumer<K, V> implements
ConsumerDelegate<K, V> {
private StreamsOnTasksRevokedCallbackCompletedEvent
invokeOnTasksRevokedCallback(final Set<StreamsRebalanceData.TaskId>
activeTasksToRevoke,
final CompletableFuture<Void> future) {
- final Optional<Exception> exceptionFromCallback =
streamsRebalanceListener().onTasksRevoked(activeTasksToRevoke);
+ final Optional<Exception> exceptionFromCallback =
Optional.ofNullable(streamsRebalanceListenerInvoker().invokeTasksRevoked(activeTasksToRevoke));
final Optional<KafkaException> error = exceptionFromCallback.map(e
-> ConsumerUtils.maybeWrapAsKafkaException(e, "Task revocation callback throws
an error"));
return new StreamsOnTasksRevokedCallbackCompletedEvent(future,
error);
}
private StreamsOnTasksAssignedCallbackCompletedEvent
invokeOnTasksAssignedCallback(final StreamsRebalanceData.Assignment assignment,
final CompletableFuture<Void> future) {
- final Optional<KafkaException> error;
- final Optional<Exception> exceptionFromCallback =
streamsRebalanceListener().onTasksAssigned(assignment);
- if (exceptionFromCallback.isPresent()) {
- error =
Optional.of(ConsumerUtils.maybeWrapAsKafkaException(exceptionFromCallback.get(),
"Task assignment callback throws an error"));
- } else {
- error = Optional.empty();
- streamsRebalanceData().setReconciledAssignment(assignment);
- }
+ final Optional<Exception> exceptionFromCallback =
Optional.ofNullable(streamsRebalanceListenerInvoker().invokeTasksAssigned(assignment));
+ final Optional<KafkaException> error = exceptionFromCallback.map(e
-> ConsumerUtils.maybeWrapAsKafkaException(e, "Task assignment callback throws
an error"));
return new StreamsOnTasksAssignedCallbackCompletedEvent(future,
error);
}
private StreamsOnAllTasksLostCallbackCompletedEvent
invokeOnAllTasksLostCallback(final CompletableFuture<Void> future) {
- final Optional<KafkaException> error;
- final Optional<Exception> exceptionFromCallback =
streamsRebalanceListener().onAllTasksLost();
- if (exceptionFromCallback.isPresent()) {
- error =
Optional.of(ConsumerUtils.maybeWrapAsKafkaException(exceptionFromCallback.get(),
"All tasks lost callback throws an error"));
- } else {
- error = Optional.empty();
-
streamsRebalanceData().setReconciledAssignment(StreamsRebalanceData.Assignment.EMPTY);
- }
+ final Optional<Exception> exceptionFromCallback =
Optional.ofNullable(streamsRebalanceListenerInvoker().invokeAllTasksLost());
+ final Optional<KafkaException> error = exceptionFromCallback.map(e
-> ConsumerUtils.maybeWrapAsKafkaException(e, "All tasks lost callback throws
an error"));
return new StreamsOnAllTasksLostCallbackCompletedEvent(future,
error);
}
- private StreamsRebalanceData streamsRebalanceData() {
- return streamsRebalanceData.orElseThrow(
- () -> new IllegalStateException("Background event processor
was not created to be used with Streams " +
- "rebalance protocol events"));
- }
-
- private StreamsRebalanceListener streamsRebalanceListener() {
- return streamsRebalanceListener.orElseThrow(
+ private StreamsRebalanceListenerInvoker
streamsRebalanceListenerInvoker() {
+ return streamsRebalanceListenerInvoker.orElseThrow(
() -> new IllegalStateException("Background event processor
was not created to be used with Streams " +
"rebalance protocol events"));
}
@@ -365,6 +328,7 @@ public class AsyncKafkaConsumer<K, V> implements
ConsumerDelegate<K, V> {
private final WakeupTrigger wakeupTrigger = new WakeupTrigger();
private final OffsetCommitCallbackInvoker offsetCommitCallbackInvoker;
private final ConsumerRebalanceListenerInvoker rebalanceListenerInvoker;
+ private final Optional<StreamsRebalanceListenerInvoker>
streamsRebalanceListenerInvoker;
// Last triggered async commit future. Used to wait until all previous
async commits are completed.
// We only need to keep track of the last one, since they are guaranteed
to complete in order.
private CompletableFuture<Map<TopicPartition, OffsetAndMetadata>>
lastPendingAsyncCommit = null;
@@ -514,7 +478,9 @@ public class AsyncKafkaConsumer<K, V> implements
ConsumerDelegate<K, V> {
time,
new RebalanceCallbackMetricsManager(metrics)
);
- this.backgroundEventProcessor = new
BackgroundEventProcessor(streamsRebalanceData);
+ this.streamsRebalanceListenerInvoker = streamsRebalanceData.map(s
->
+ new StreamsRebalanceListenerInvoker(logContext, s));
+ this.backgroundEventProcessor = new BackgroundEventProcessor();
this.backgroundEventReaper =
backgroundEventReaperFactory.build(logContext);
// The FetchCollector is only used on the application thread.
@@ -574,6 +540,7 @@ public class AsyncKafkaConsumer<K, V> implements
ConsumerDelegate<K, V> {
this.time = time;
this.backgroundEventQueue = backgroundEventQueue;
this.rebalanceListenerInvoker = rebalanceListenerInvoker;
+ this.streamsRebalanceListenerInvoker = Optional.empty();
this.backgroundEventProcessor = new BackgroundEventProcessor();
this.backgroundEventReaper = backgroundEventReaper;
this.metrics = metrics;
@@ -694,6 +661,7 @@ public class AsyncKafkaConsumer<K, V> implements
ConsumerDelegate<K, V> {
networkClientDelegateSupplier,
requestManagersSupplier,
kafkaConsumerMetrics);
+ this.streamsRebalanceListenerInvoker = Optional.empty();
this.backgroundEventProcessor = new BackgroundEventProcessor();
this.backgroundEventReaper = new CompletableEventReaper(logContext);
}
@@ -1472,7 +1440,7 @@ public class AsyncKafkaConsumer<K, V> implements
ConsumerDelegate<K, V> {
() -> autoCommitOnClose(closeTimer), firstException);
swallow(log, Level.ERROR, "Failed to stop finding coordinator",
this::stopFindCoordinatorOnClose, firstException);
- swallow(log, Level.ERROR, "Failed to release group assignment",
+ swallow(log, Level.ERROR, "Failed to run rebalance callbacks",
this::runRebalanceCallbacksOnClose, firstException);
swallow(log, Level.ERROR, "Failed to leave group while closing
consumer",
() -> leaveGroupOnClose(closeTimer, membershipOperation),
firstException);
@@ -1526,21 +1494,34 @@ public class AsyncKafkaConsumer<K, V> implements
ConsumerDelegate<K, V> {
int memberEpoch = groupMetadata.get().get().generationId();
- Set<TopicPartition> assignedPartitions = groupAssignmentSnapshot.get();
+ Exception error = null;
- if (assignedPartitions.isEmpty())
- // Nothing to revoke.
- return;
+ if (streamsRebalanceListenerInvoker != null &&
streamsRebalanceListenerInvoker.isPresent()) {
+
+ if (memberEpoch > 0) {
+ error =
streamsRebalanceListenerInvoker.get().invokeAllTasksRevoked();
+ } else {
+ error =
streamsRebalanceListenerInvoker.get().invokeAllTasksLost();
+ }
- SortedSet<TopicPartition> droppedPartitions = new
TreeSet<>(TOPIC_PARTITION_COMPARATOR);
- droppedPartitions.addAll(assignedPartitions);
+ } else if (rebalanceListenerInvoker != null) {
- final Exception error;
+ Set<TopicPartition> assignedPartitions =
groupAssignmentSnapshot.get();
- if (memberEpoch > 0)
- error =
rebalanceListenerInvoker.invokePartitionsRevoked(droppedPartitions);
- else
- error =
rebalanceListenerInvoker.invokePartitionsLost(droppedPartitions);
+ if (assignedPartitions.isEmpty())
+ // Nothing to revoke.
+ return;
+
+ SortedSet<TopicPartition> droppedPartitions = new
TreeSet<>(TOPIC_PARTITION_COMPARATOR);
+ droppedPartitions.addAll(assignedPartitions);
+
+ if (memberEpoch > 0) {
+ error =
rebalanceListenerInvoker.invokePartitionsRevoked(droppedPartitions);
+ } else {
+ error =
rebalanceListenerInvoker.invokePartitionsLost(droppedPartitions);
+ }
+
+ }
if (error != null)
throw ConsumerUtils.maybeWrapAsKafkaException(error);
@@ -1957,8 +1938,12 @@ public class AsyncKafkaConsumer<K, V> implements
ConsumerDelegate<K, V> {
}
public void subscribe(Collection<String> topics, StreamsRebalanceListener
streamsRebalanceListener) {
+
+ streamsRebalanceListenerInvoker
+ .orElseThrow(() -> new IllegalStateException("Consumer was not
created to be used with Streams rebalance protocol events"))
+ .setRebalanceListener(streamsRebalanceListener);
+
subscribeInternal(topics, Optional.empty());
-
backgroundEventProcessor.setStreamsRebalanceListener(streamsRebalanceListener);
}
@Override
diff --git
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvoker.java
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvoker.java
new file mode 100644
index 00000000000..f4c5aa4addc
--- /dev/null
+++
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvoker.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.clients.consumer.internals;
+
+import org.apache.kafka.common.errors.InterruptException;
+import org.apache.kafka.common.errors.WakeupException;
+import org.apache.kafka.common.utils.LogContext;
+
+import org.slf4j.Logger;
+
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+
+/**
+ * This class encapsulates the invocation of the callback methods defined in
the {@link StreamsRebalanceListener}
+ * interface. When streams group task assignment changes, these methods are
invoked. This class wraps those
+ * callback calls with some logging and error handling.
+ */
+public class StreamsRebalanceListenerInvoker {
+
+ private final Logger log;
+
+ private final StreamsRebalanceData streamsRebalanceData;
+ private Optional<StreamsRebalanceListener> listener;
+
+ StreamsRebalanceListenerInvoker(LogContext logContext,
StreamsRebalanceData streamsRebalanceData) {
+ this.log = logContext.logger(getClass());
+ this.listener = Optional.empty();
+ this.streamsRebalanceData = streamsRebalanceData;
+ }
+
+ public void setRebalanceListener(StreamsRebalanceListener
streamsRebalanceListener) {
+ Objects.requireNonNull(streamsRebalanceListener,
"StreamsRebalanceListener cannot be null");
+ this.listener = Optional.of(streamsRebalanceListener);
+ }
+
+ public Exception invokeAllTasksRevoked() {
+ if (listener.isEmpty()) {
+ throw new IllegalStateException("StreamsRebalanceListener is not
defined");
+ }
+ return
invokeTasksRevoked(streamsRebalanceData.reconciledAssignment().activeTasks());
+ }
+
+ public Exception invokeTasksAssigned(final StreamsRebalanceData.Assignment
assignment) {
+ if (listener.isEmpty()) {
+ throw new IllegalStateException("StreamsRebalanceListener is not
defined");
+ }
+ log.info("Invoking tasks assigned callback for new assignment: {}",
assignment);
+ try {
+ listener.get().onTasksAssigned(assignment);
+ } catch (WakeupException | InterruptException e) {
+ throw e;
+ } catch (Exception e) {
+ log.error(
+ "Streams rebalance listener failed on invocation of
onTasksAssigned for tasks {}",
+ assignment,
+ e
+ );
+ return e;
+ }
+ return null;
+ }
+
+ public Exception invokeTasksRevoked(final Set<StreamsRebalanceData.TaskId>
tasks) {
+ if (listener.isEmpty()) {
+ throw new IllegalStateException("StreamsRebalanceListener is not
defined");
+ }
+ log.info("Invoking task revoked callback for revoked active tasks {}",
tasks);
+ try {
+ listener.get().onTasksRevoked(tasks);
+ } catch (WakeupException | InterruptException e) {
+ throw e;
+ } catch (Exception e) {
+ log.error(
+ "Streams rebalance listener failed on invocation of
onTasksRevoked for tasks {}",
+ tasks,
+ e
+ );
+ return e;
+ }
+ return null;
+ }
+
+ public Exception invokeAllTasksLost() {
+ if (listener.isEmpty()) {
+ throw new IllegalStateException("StreamsRebalanceListener is not
defined");
+ }
+ log.info("Invoking tasks lost callback for all tasks");
+ try {
+ listener.get().onAllTasksLost();
+ } catch (WakeupException | InterruptException e) {
+ throw e;
+ } catch (Exception e) {
+ log.error(
+ "Streams rebalance listener failed on invocation of
onTasksLost.",
+ e
+ );
+ return e;
+ }
+ return null;
+ }
+}
diff --git
a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumerTest.java
b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumerTest.java
index 513ad3fe294..c74c5f90dab 100644
---
a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumerTest.java
+++
b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumerTest.java
@@ -2211,6 +2211,73 @@ public class AsyncKafkaConsumerTest {
}).when(applicationEventHandler).add(ArgumentMatchers.isA(CommitEvent.class));
}
+ @Test
+ public void
testCloseInvokesStreamsRebalanceListenerOnTasksRevokedWhenMemberEpochPositive()
{
+ final String groupId = "streamsGroup";
+ final StreamsRebalanceData streamsRebalanceData = new
StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of());
+
+ try (final MockedStatic<RequestManagers> requestManagers =
mockStatic(RequestManagers.class)) {
+ consumer =
newConsumerWithStreamRebalanceData(requiredConsumerConfigAndGroupId(groupId),
streamsRebalanceData);
+ StreamsRebalanceListener mockStreamsListener =
mock(StreamsRebalanceListener.class);
+
when(mockStreamsListener.onTasksRevoked(any())).thenReturn(Optional.empty());
+ consumer.subscribe(singletonList("topic"), mockStreamsListener);
+ final MemberStateListener groupMetadataUpdateListener =
captureGroupMetadataUpdateListener(requestManagers);
+ final int memberEpoch = 42;
+ final String memberId = "memberId";
+
groupMetadataUpdateListener.onMemberEpochUpdated(Optional.of(memberEpoch),
memberId);
+
+ consumer.close(CloseOptions.timeout(Duration.ZERO));
+
+ verify(mockStreamsListener).onTasksRevoked(any());
+ }
+ }
+
+ @Test
+ public void
testCloseInvokesStreamsRebalanceListenerOnAllTasksLostWhenMemberEpochZeroOrNegative()
{
+ final String groupId = "streamsGroup";
+ final StreamsRebalanceData streamsRebalanceData = new
StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of());
+
+ try (final MockedStatic<RequestManagers> requestManagers =
mockStatic(RequestManagers.class)) {
+ consumer =
newConsumerWithStreamRebalanceData(requiredConsumerConfigAndGroupId(groupId),
streamsRebalanceData);
+ StreamsRebalanceListener mockStreamsListener =
mock(StreamsRebalanceListener.class);
+
when(mockStreamsListener.onAllTasksLost()).thenReturn(Optional.empty());
+ consumer.subscribe(singletonList("topic"), mockStreamsListener);
+ final MemberStateListener groupMetadataUpdateListener =
captureGroupMetadataUpdateListener(requestManagers);
+ final int memberEpoch = 0;
+ final String memberId = "memberId";
+
groupMetadataUpdateListener.onMemberEpochUpdated(Optional.of(memberEpoch),
memberId);
+
+ consumer.close(CloseOptions.timeout(Duration.ZERO));
+
+ verify(mockStreamsListener).onAllTasksLost();
+ }
+ }
+
+ @Test
+ public void testCloseWrapsStreamsRebalanceListenerException() {
+ final String groupId = "streamsGroup";
+ final StreamsRebalanceData streamsRebalanceData = new
StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of());
+
+ try (final MockedStatic<RequestManagers> requestManagers =
mockStatic(RequestManagers.class)) {
+ consumer =
newConsumerWithStreamRebalanceData(requiredConsumerConfigAndGroupId(groupId),
streamsRebalanceData);
+ StreamsRebalanceListener mockStreamsListener =
mock(StreamsRebalanceListener.class);
+ RuntimeException testException = new RuntimeException("Test
streams listener exception");
+
doThrow(testException).when(mockStreamsListener).onTasksRevoked(any());
+ consumer.subscribe(singletonList("topic"), mockStreamsListener);
+ final MemberStateListener groupMetadataUpdateListener =
captureGroupMetadataUpdateListener(requestManagers);
+ final int memberEpoch = 1;
+ final String memberId = "memberId";
+
groupMetadataUpdateListener.onMemberEpochUpdated(Optional.of(memberEpoch),
memberId);
+
+ KafkaException thrownException =
assertThrows(KafkaException.class,
+ () -> consumer.close(CloseOptions.timeout(Duration.ZERO)));
+
+ assertInstanceOf(RuntimeException.class,
thrownException.getCause());
+ assertTrue(thrownException.getCause().getMessage().contains("Test
streams listener exception"));
+ verify(mockStreamsListener).onTasksRevoked(any());
+ }
+ }
+
private void markReconcileAndAutoCommitCompleteForPollEvent() {
doAnswer(invocation -> {
PollEvent event = invocation.getArgument(0);
diff --git
a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvokerTest.java
b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvokerTest.java
new file mode 100644
index 00000000000..2f3e5ab0523
--- /dev/null
+++
b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvokerTest.java
@@ -0,0 +1,293 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.clients.consumer.internals;
+
+import org.apache.kafka.common.errors.InterruptException;
+import org.apache.kafka.common.errors.WakeupException;
+import org.apache.kafka.common.utils.LogContext;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.mockito.junit.jupiter.MockitoSettings;
+import org.mockito.quality.Strictness;
+
+import java.util.Optional;
+import java.util.Set;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+@ExtendWith(MockitoExtension.class)
+@MockitoSettings(strictness = Strictness.STRICT_STUBS)
+public class StreamsRebalanceListenerInvokerTest {
+
+ @Mock
+ private StreamsRebalanceListener mockListener;
+
+ @Mock
+ private StreamsRebalanceData streamsRebalanceData;
+
+ private StreamsRebalanceListenerInvoker invoker;
+ private final LogContext logContext = new LogContext();
+
+ @BeforeEach
+ public void setup() {
+ invoker = new StreamsRebalanceListenerInvoker(logContext,
streamsRebalanceData);
+ }
+
+ @Test
+ public void testSetRebalanceListenerWithNull() {
+ NullPointerException exception =
assertThrows(NullPointerException.class,
+ () -> invoker.setRebalanceListener(null));
+ assertEquals("StreamsRebalanceListener cannot be null",
exception.getMessage());
+ }
+
+ @Test
+ public void testSetRebalanceListenerOverwritesExisting() {
+ StreamsRebalanceListener firstListener =
org.mockito.Mockito.mock(StreamsRebalanceListener.class);
+ StreamsRebalanceListener secondListener =
org.mockito.Mockito.mock(StreamsRebalanceListener.class);
+
+ StreamsRebalanceData.Assignment mockAssignment =
createMockAssignment();
+
when(streamsRebalanceData.reconciledAssignment()).thenReturn(mockAssignment);
+
when(secondListener.onTasksRevoked(any())).thenReturn(Optional.empty());
+
+ // Set first listener
+ invoker.setRebalanceListener(firstListener);
+
+ // Overwrite with second listener
+ invoker.setRebalanceListener(secondListener);
+
+ // Should use second listener
+ invoker.invokeAllTasksRevoked();
+ verify(firstListener, never()).onTasksRevoked(any());
+
verify(secondListener).onTasksRevoked(eq(mockAssignment.activeTasks()));
+ }
+
+ @Test
+ public void testInvokeMethodsWithNoListener() {
+ IllegalStateException exception1 =
assertThrows(IllegalStateException.class,
+ () -> invoker.invokeAllTasksRevoked());
+ assertEquals("StreamsRebalanceListener is not defined",
exception1.getMessage());
+
+ IllegalStateException exception2 =
assertThrows(IllegalStateException.class,
+ () -> invoker.invokeTasksAssigned(createMockAssignment()));
+ assertEquals("StreamsRebalanceListener is not defined",
exception2.getMessage());
+
+ IllegalStateException exception3 =
assertThrows(IllegalStateException.class,
+ () -> invoker.invokeTasksRevoked(createMockTasks()));
+ assertEquals("StreamsRebalanceListener is not defined",
exception3.getMessage());
+
+ IllegalStateException exception4 =
assertThrows(IllegalStateException.class,
+ () -> invoker.invokeAllTasksLost());
+ assertEquals("StreamsRebalanceListener is not defined",
exception4.getMessage());
+ }
+
+ @Test
+ public void testInvokeAllTasksRevokedWithListener() {
+ invoker.setRebalanceListener(mockListener);
+
+ StreamsRebalanceData.Assignment mockAssignment =
createMockAssignment();
+
when(streamsRebalanceData.reconciledAssignment()).thenReturn(mockAssignment);
+ when(mockListener.onTasksRevoked(any())).thenReturn(Optional.empty());
+
+ Exception result = invoker.invokeAllTasksRevoked();
+
+ assertNull(result);
+ verify(mockListener).onTasksRevoked(eq(mockAssignment.activeTasks()));
+ }
+
+ @Test
+ public void testInvokeTasksAssignedWithListener() {
+ invoker.setRebalanceListener(mockListener);
+ StreamsRebalanceData.Assignment assignment = createMockAssignment();
+
when(mockListener.onTasksAssigned(assignment)).thenReturn(Optional.empty());
+
+ Exception result = invoker.invokeTasksAssigned(assignment);
+
+ assertNull(result);
+ verify(mockListener).onTasksAssigned(eq(assignment));
+ }
+
+ @Test
+ public void testInvokeTasksAssignedWithWakeupException() {
+ invoker.setRebalanceListener(mockListener);
+ StreamsRebalanceData.Assignment assignment = createMockAssignment();
+ WakeupException wakeupException = new WakeupException();
+
doThrow(wakeupException).when(mockListener).onTasksAssigned(assignment);
+
+ WakeupException thrownException = assertThrows(WakeupException.class,
+ () -> invoker.invokeTasksAssigned(assignment));
+
+ assertEquals(wakeupException, thrownException);
+ verify(mockListener).onTasksAssigned(eq(assignment));
+ }
+
+ @Test
+ public void testInvokeTasksAssignedWithInterruptException() {
+ invoker.setRebalanceListener(mockListener);
+ StreamsRebalanceData.Assignment assignment = createMockAssignment();
+ InterruptException interruptException = new InterruptException("Test
interrupt");
+
doThrow(interruptException).when(mockListener).onTasksAssigned(assignment);
+
+ InterruptException thrownException =
assertThrows(InterruptException.class,
+ () -> invoker.invokeTasksAssigned(assignment));
+
+ assertEquals(interruptException, thrownException);
+ verify(mockListener).onTasksAssigned(eq(assignment));
+ }
+
+ @Test
+ public void testInvokeTasksAssignedWithOtherException() {
+ invoker.setRebalanceListener(mockListener);
+ StreamsRebalanceData.Assignment assignment = createMockAssignment();
+ RuntimeException runtimeException = new RuntimeException("Test
exception");
+
doThrow(runtimeException).when(mockListener).onTasksAssigned(assignment);
+
+ Exception result = invoker.invokeTasksAssigned(assignment);
+
+ assertEquals(runtimeException, result);
+ verify(mockListener).onTasksAssigned(eq(assignment));
+ }
+
+ @Test
+ public void testInvokeTasksRevokedWithListener() {
+ invoker.setRebalanceListener(mockListener);
+ Set<StreamsRebalanceData.TaskId> tasks = createMockTasks();
+ when(mockListener.onTasksRevoked(tasks)).thenReturn(Optional.empty());
+
+ Exception result = invoker.invokeTasksRevoked(tasks);
+
+ assertNull(result);
+ verify(mockListener).onTasksRevoked(eq(tasks));
+ }
+
+ @Test
+ public void testInvokeTasksRevokedWithWakeupException() {
+ invoker.setRebalanceListener(mockListener);
+ Set<StreamsRebalanceData.TaskId> tasks = createMockTasks();
+ WakeupException wakeupException = new WakeupException();
+ doThrow(wakeupException).when(mockListener).onTasksRevoked(tasks);
+
+ WakeupException thrownException = assertThrows(WakeupException.class,
+ () -> invoker.invokeTasksRevoked(tasks));
+
+ assertEquals(wakeupException, thrownException);
+ verify(mockListener).onTasksRevoked(eq(tasks));
+ }
+
+ @Test
+ public void testInvokeTasksRevokedWithInterruptException() {
+ invoker.setRebalanceListener(mockListener);
+ Set<StreamsRebalanceData.TaskId> tasks = createMockTasks();
+ InterruptException interruptException = new InterruptException("Test
interrupt");
+ doThrow(interruptException).when(mockListener).onTasksRevoked(tasks);
+
+ InterruptException thrownException =
assertThrows(InterruptException.class,
+ () -> invoker.invokeTasksRevoked(tasks));
+
+ assertEquals(interruptException, thrownException);
+ verify(mockListener).onTasksRevoked(eq(tasks));
+ }
+
+ @Test
+ public void testInvokeTasksRevokedWithOtherException() {
+ invoker.setRebalanceListener(mockListener);
+ Set<StreamsRebalanceData.TaskId> tasks = createMockTasks();
+ RuntimeException runtimeException = new RuntimeException("Test
exception");
+ doThrow(runtimeException).when(mockListener).onTasksRevoked(tasks);
+
+ Exception result = invoker.invokeTasksRevoked(tasks);
+
+ assertEquals(runtimeException, result);
+ verify(mockListener).onTasksRevoked(eq(tasks));
+ }
+
+ @Test
+ public void testInvokeAllTasksLostWithListener() {
+ invoker.setRebalanceListener(mockListener);
+ when(mockListener.onAllTasksLost()).thenReturn(Optional.empty());
+
+ Exception result = invoker.invokeAllTasksLost();
+
+ assertNull(result);
+ verify(mockListener).onAllTasksLost();
+ }
+
+ @Test
+ public void testInvokeAllTasksLostWithWakeupException() {
+ invoker.setRebalanceListener(mockListener);
+ WakeupException wakeupException = new WakeupException();
+ doThrow(wakeupException).when(mockListener).onAllTasksLost();
+
+ WakeupException thrownException = assertThrows(WakeupException.class,
+ () -> invoker.invokeAllTasksLost());
+
+ assertEquals(wakeupException, thrownException);
+ verify(mockListener).onAllTasksLost();
+ }
+
+ @Test
+ public void testInvokeAllTasksLostWithInterruptException() {
+ invoker.setRebalanceListener(mockListener);
+ InterruptException interruptException = new InterruptException("Test
interrupt");
+ doThrow(interruptException).when(mockListener).onAllTasksLost();
+
+ InterruptException thrownException =
assertThrows(InterruptException.class,
+ () -> invoker.invokeAllTasksLost());
+
+ assertEquals(interruptException, thrownException);
+ verify(mockListener).onAllTasksLost();
+ }
+
+ @Test
+ public void testInvokeAllTasksLostWithOtherException() {
+ invoker.setRebalanceListener(mockListener);
+ RuntimeException runtimeException = new RuntimeException("Test
exception");
+ doThrow(runtimeException).when(mockListener).onAllTasksLost();
+
+ Exception result = invoker.invokeAllTasksLost();
+
+ assertEquals(runtimeException, result);
+ verify(mockListener).onAllTasksLost();
+ }
+
+ private StreamsRebalanceData.Assignment createMockAssignment() {
+ Set<StreamsRebalanceData.TaskId> activeTasks = createMockTasks();
+ Set<StreamsRebalanceData.TaskId> standbyTasks = Set.of();
+ Set<StreamsRebalanceData.TaskId> warmupTasks = Set.of();
+
+ return new StreamsRebalanceData.Assignment(activeTasks, standbyTasks,
warmupTasks);
+ }
+
+ private Set<StreamsRebalanceData.TaskId> createMockTasks() {
+ return Set.of(
+ new StreamsRebalanceData.TaskId("subtopology1", 0),
+ new StreamsRebalanceData.TaskId("subtopology1", 1)
+ );
+ }
+
+}
diff --git
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListener.java
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListener.java
index dcc4821f2a8..a95fcef5a6c 100644
---
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListener.java
+++
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListener.java
@@ -89,6 +89,7 @@ public class DefaultStreamsRebalanceListener implements StreamsRebalanceListener
taskManager.handleAssignment(activeTasksWithPartitions, standbyTasksWithPartitions);
streamThread.setState(StreamThread.State.PARTITIONS_ASSIGNED);
taskManager.handleRebalanceComplete();
+ streamsRebalanceData.setReconciledAssignment(assignment);
} catch (final Exception exception) {
return Optional.of(exception);
}
@@ -99,6 +100,7 @@ public class DefaultStreamsRebalanceListener implements StreamsRebalanceListener
public Optional<Exception> onAllTasksLost() {
try {
taskManager.handleLostAll();
+ streamsRebalanceData.setReconciledAssignment(StreamsRebalanceData.Assignment.EMPTY);
} catch (final Exception exception) {
return Optional.of(exception);
}
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListenerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListenerTest.java
index 66cb8e5185b..1297df7b1ee 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListenerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListenerTest.java
@@ -118,49 +118,46 @@ public class DefaultStreamsRebalanceListenerTest {
@Test
void testOnTasksAssigned() {
- createRebalanceListenerWithRebalanceData(new StreamsRebalanceData(
- UUID.randomUUID(),
- Optional.empty(),
- Map.of(
- "1",
- new StreamsRebalanceData.Subtopology(
- Set.of("source1"),
- Set.of(),
- Map.of("repartition1", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of())),
- Map.of(),
- Set.of()
- ),
- "2",
- new StreamsRebalanceData.Subtopology(
- Set.of("source2"),
- Set.of(),
- Map.of("repartition2", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of())),
- Map.of(),
- Set.of()
- ),
- "3",
- new StreamsRebalanceData.Subtopology(
- Set.of("source3"),
- Set.of(),
- Map.of("repartition3", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of())),
- Map.of(),
- Set.of()
- )
+ final StreamsRebalanceData streamsRebalanceData = mock(StreamsRebalanceData.class);
+ when(streamsRebalanceData.subtopologies()).thenReturn(Map.of(
+ "1",
+ new StreamsRebalanceData.Subtopology(
+ Set.of("source1"),
+ Set.of(),
+ Map.of("repartition1", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of())),
+ Map.of(),
+ Set.of()
),
- Map.of()
+ "2",
+ new StreamsRebalanceData.Subtopology(
+ Set.of("source2"),
+ Set.of(),
+ Map.of("repartition2", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of())),
+ Map.of(),
+ Set.of()
+ ),
+ "3",
+ new StreamsRebalanceData.Subtopology(
+ Set.of("source3"),
+ Set.of(),
+ Map.of("repartition3", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of())),
+ Map.of(),
+ Set.of()
+ )
));
+ createRebalanceListenerWithRebalanceData(streamsRebalanceData);
- final Optional<Exception> result = defaultStreamsRebalanceListener.onTasksAssigned(
- new StreamsRebalanceData.Assignment(
- Set.of(new StreamsRebalanceData.TaskId("1", 0)),
- Set.of(new StreamsRebalanceData.TaskId("2", 0)),
- Set.of(new StreamsRebalanceData.TaskId("3", 0))
- )
+ final StreamsRebalanceData.Assignment assignment = new StreamsRebalanceData.Assignment(
+ Set.of(new StreamsRebalanceData.TaskId("1", 0)),
+ Set.of(new StreamsRebalanceData.TaskId("2", 0)),
+ Set.of(new StreamsRebalanceData.TaskId("3", 0))
);
+ final Optional<Exception> result = defaultStreamsRebalanceListener.onTasksAssigned(assignment);
+
assertTrue(result.isEmpty());
- final InOrder inOrder = inOrder(taskManager, streamThread);
+ final InOrder inOrder = inOrder(taskManager, streamThread, streamsRebalanceData);
inOrder.verify(taskManager).handleAssignment(
Map.of(new TaskId(1, 0), Set.of(new TopicPartition("source1", 0), new TopicPartition("repartition1", 0))),
Map.of(
@@ -170,6 +167,7 @@ public class DefaultStreamsRebalanceListenerTest {
);
inOrder.verify(streamThread).setState(StreamThread.State.PARTITIONS_ASSIGNED);
inOrder.verify(taskManager).handleRebalanceComplete();
+ inOrder.verify(streamsRebalanceData).setReconciledAssignment(assignment);
}
@Test
@@ -177,21 +175,32 @@ public class DefaultStreamsRebalanceListenerTest {
final Exception exception = new RuntimeException("sample exception");
doThrow(exception).when(taskManager).handleAssignment(any(), any());
- createRebalanceListenerWithRebalanceData(new StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of()));
- final Optional<Exception> result = defaultStreamsRebalanceListener.onTasksAssigned(new StreamsRebalanceData.Assignment(Set.of(), Set.of(), Set.of()));
- assertTrue(defaultStreamsRebalanceListener.onAllTasksLost().isEmpty());
+ final StreamsRebalanceData streamsRebalanceData = mock(StreamsRebalanceData.class);
+ when(streamsRebalanceData.subtopologies()).thenReturn(Map.of());
+ createRebalanceListenerWithRebalanceData(streamsRebalanceData);
+
+ final Optional<Exception> result = defaultStreamsRebalanceListener.onTasksAssigned(
+ new StreamsRebalanceData.Assignment(Set.of(), Set.of(), Set.of())
+ );
assertTrue(result.isPresent());
assertEquals(exception, result.get());
- verify(taskManager).handleLostAll();
+ verify(taskManager).handleAssignment(any(), any());
verify(streamThread, never()).setState(StreamThread.State.PARTITIONS_ASSIGNED);
verify(taskManager, never()).handleRebalanceComplete();
+ verify(streamsRebalanceData, never()).setReconciledAssignment(any());
}
@Test
void testOnAllTasksLost() {
- createRebalanceListenerWithRebalanceData(new StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of()));
+ final StreamsRebalanceData streamsRebalanceData = mock(StreamsRebalanceData.class);
+ when(streamsRebalanceData.subtopologies()).thenReturn(Map.of());
+ createRebalanceListenerWithRebalanceData(streamsRebalanceData);
+
assertTrue(defaultStreamsRebalanceListener.onAllTasksLost().isEmpty());
- verify(taskManager).handleLostAll();
+
+ final InOrder inOrder = inOrder(taskManager, streamsRebalanceData);
+ inOrder.verify(taskManager).handleLostAll();
+ inOrder.verify(streamsRebalanceData).setReconciledAssignment(StreamsRebalanceData.Assignment.EMPTY);
}
@Test
@@ -199,10 +208,13 @@ public class DefaultStreamsRebalanceListenerTest {
final Exception exception = new RuntimeException("sample exception");
doThrow(exception).when(taskManager).handleLostAll();
- createRebalanceListenerWithRebalanceData(new StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of()));
+ final StreamsRebalanceData streamsRebalanceData = mock(StreamsRebalanceData.class);
+ when(streamsRebalanceData.subtopologies()).thenReturn(Map.of());
+ createRebalanceListenerWithRebalanceData(streamsRebalanceData);
final Optional<Exception> result = defaultStreamsRebalanceListener.onAllTasksLost();
assertTrue(result.isPresent());
assertEquals(exception, result.get());
verify(taskManager).handleLostAll();
+ verify(streamsRebalanceData, never()).setReconciledAssignment(any());
}
}