This is an automated email from the ASF dual-hosted git repository. ferenc-csaky pushed a commit to branch v6.0 in repository https://gitbox.apache.org/repos/asf/flink-connector-aws.git
commit 3fe3aa7e7c8c4b173034b028bf487072962f9dd9 Author: chenylee-aws <[email protected]> AuthorDate: Tue May 26 12:03:42 2026 -0700 [FLINK-39660] Fix Netty event loop threads blocking and race condition in `FanOutKinesisShardSubscription` (cherry picked from commit ede02dae4f0f98efe8bc174872f94db8194442b1) --- .../kinesis/source/KinesisStreamsSource.java | 3 +- .../fanout/FanOutKinesisShardSplitReader.java | 17 +- .../fanout/FanOutKinesisShardSubscription.java | 457 ++++++----- .../fanout/FanOutKinesisShardSubscriptionTest.java | 893 ++++++++++++++++----- 4 files changed, 959 insertions(+), 411 deletions(-) diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/KinesisStreamsSource.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/KinesisStreamsSource.java index e07add1..69c8065 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/KinesisStreamsSource.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/KinesisStreamsSource.java @@ -72,7 +72,6 @@ import software.amazon.awssdk.services.kinesis.model.DescribeStreamConsumerRespo import software.amazon.awssdk.services.kinesis.model.LimitExceededException; import software.amazon.awssdk.services.kinesis.model.Record; import software.amazon.awssdk.services.kinesis.model.ResourceNotFoundException; -import software.amazon.awssdk.utils.AttributeMap; import java.time.Duration; import java.util.Map; @@ -271,7 +270,7 @@ public class KinesisStreamsSource<T> SdkAsyncHttpClient asyncHttpClient = AWSGeneralUtil.createAsyncHttpClient( - AttributeMap.builder().build(), NettyNioAsyncHttpClient.builder()); + kinesisClientProperties, NettyNioAsyncHttpClient.builder()); KinesisAsyncClient kinesisAsyncClient = AWSClientUtil.createAwsAsyncClient( kinesisClientProperties, diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReader.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReader.java index c0aefee..a8baaab 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReader.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReader.java @@ -32,6 +32,7 @@ import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEvent; import java.time.Duration; import java.util.HashMap; import java.util.Map; +import java.util.concurrent.ScheduledThreadPoolExecutor; import static org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions.EFO_CONSUMER_SUBSCRIPTION_TIMEOUT; @@ -44,6 +45,7 @@ public class FanOutKinesisShardSplitReader extends KinesisShardSplitReaderBase { private final AsyncStreamProxy asyncStreamProxy; private final String consumerArn; private final Duration subscriptionTimeout; + private final ScheduledThreadPoolExecutor timeoutScheduler; private final Map<String, FanOutKinesisShardSubscription> splitSubscriptions = new HashMap<>(); @@ -56,6 +58,15 @@ public class FanOutKinesisShardSplitReader extends KinesisShardSplitReaderBase { this.asyncStreamProxy = asyncStreamProxy; this.consumerArn = consumerArn; this.subscriptionTimeout = configuration.get(EFO_CONSUMER_SUBSCRIPTION_TIMEOUT); + this.timeoutScheduler = + new ScheduledThreadPoolExecutor( + 1, + r -> { + Thread t = new Thread(r, "subscription-timeout-scheduler"); + t.setDaemon(true); + return t; + }); + this.timeoutScheduler.setRemoveOnCancelPolicy(true); } @Override @@ -71,6 +82,7 @@ public class FanOutKinesisShardSplitReader extends KinesisShardSplitReaderBase { boolean shardCompleted = event.continuationSequenceNumber() == null; if (shardCompleted) { splitSubscriptions.remove(splitState.getShardId()); + subscription.close(); } return new RecordBatch(event.records(), event.millisBehindLatest(), shardCompleted); } @@ -85,7 +97,8 @@ public class FanOutKinesisShardSplitReader extends KinesisShardSplitReaderBase { consumerArn, split.getShardId(), split.getStartingPosition(), - subscriptionTimeout); + subscriptionTimeout, + timeoutScheduler); subscription.activateSubscription(); splitSubscriptions.put(split.splitId(), subscription); } @@ -93,6 +106,8 @@ public class FanOutKinesisShardSplitReader extends KinesisShardSplitReaderBase { @Override public void close() throws Exception { + splitSubscriptions.values().forEach(FanOutKinesisShardSubscription::close); + timeoutScheduler.shutdownNow(); asyncStreamProxy.close(); } } diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscription.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscription.java index 9674951..72df724 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscription.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscription.java @@ -44,12 +44,11 @@ import java.util.Arrays; import java.util.List; import java.util.Optional; import java.util.concurrent.BlockingQueue; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; /** @@ -70,132 +69,171 @@ public class FanOutKinesisShardSubscription { TimeoutException.class, IOException.class, LimitExceededException.class); - + private final ScheduledExecutorService timeoutScheduler; private final AsyncStreamProxy kinesis; private final String consumerArn; private final String shardId; - private final Duration subscriptionTimeout; - // Queue is meant for eager retrieval of records from the Kinesis stream. We will always have 2 - // record batches available on next read. - private final BlockingQueue<SubscribeToShardEvent> eventQueue = new LinkedBlockingQueue<>(2); + /** + * Number of events to keep in flight per subscriber. Pipelining the fetch overlaps the server's + * next-event work with the consumer's drain work. Must match the capacity of {@link + * #eventQueue}. + */ + private static final int PREFETCH = 2; + + private final BlockingQueue<SubscribeToShardEvent> eventQueue = + new LinkedBlockingQueue<>(PREFETCH); private final AtomicReference<Throwable> subscriptionException = new AtomicReference<>(); - // Store the current starting position for this subscription. Will be updated each time new - // batch of records is consumed - private StartingPosition startingPosition; + // All fields below are guarded by lockObject + private final Object lockObject = new Object(); + private ScheduledFuture<?> timeoutFuture; private FanOutShardSubscriber shardSubscriber; + private boolean closed = false; + private StartingPosition startingPosition; public FanOutKinesisShardSubscription( AsyncStreamProxy kinesis, String consumerArn, String shardId, StartingPosition startingPosition, - Duration subscriptionTimeout) { + Duration subscriptionTimeout, + ScheduledExecutorService timeoutScheduler) { this.kinesis = kinesis; this.consumerArn = consumerArn; this.shardId = shardId; this.startingPosition = startingPosition; this.subscriptionTimeout = subscriptionTimeout; + this.timeoutScheduler = timeoutScheduler; } /** Method to allow eager activation of the subscription. */ public void activateSubscription() { - LOG.info( - "Activating subscription to shard {} with starting position {} for consumer {}.", - shardId, - startingPosition, - consumerArn); - if (shardSubscriber != null - && shardSubscriber.getSubscriptionState() == SubscriptionState.SUBSCRIBED) { - LOG.warn("Skipping activation of subscription since it is already active."); - return; - } + synchronized (lockObject) { + if (closed) { + LOG.info("Subscription for shard {} is closed; skipping activation.", shardId); + return; + } + if (startingPosition == null) { + LOG.info( + "Shard {} has been completely consumed (shard end). Skipping re-subscription.", + shardId); + return; + } + if (shardSubscriber != null) { + LOG.warn( + "Shard {} Skipping activation of subscription since one is already active or in progress.", + shardId); + return; + } - // We have to use our own CountDownLatch to wait for subscription to be acquired because - // subscription event is tracked via the handler. - CountDownLatch waitForSubscriptionLatch = new CountDownLatch(1); - shardSubscriber = new FanOutShardSubscriber(waitForSubscriptionLatch); - SubscribeToShardResponseHandler responseHandler = - SubscribeToShardResponseHandler.builder() - .subscriber(() -> shardSubscriber) - .onError( - throwable -> { - // Errors that occur when obtaining a subscription are thrown - // here. - // After subscription is acquired, these errors can be ignored. - if (waitForSubscriptionLatch.getCount() > 0) { - terminateSubscription(throwable); - waitForSubscriptionLatch.countDown(); + LOG.info( + "Activating subscription to shard {} with starting position {} for consumer {}.", + shardId, + startingPosition, + consumerArn); + + FanOutShardSubscriber subscriber = new FanOutShardSubscriber(); + shardSubscriber = subscriber; + + SubscribeToShardResponseHandler responseHandler = + SubscribeToShardResponseHandler.builder() + .subscriber(() -> subscriber) + .onError( + throwable -> { + synchronized (lockObject) { + if (!disposeIfActive(subscriber)) { + return; + } + } + LOG.error( + "Error onError subscribing to shard {} with " + + "starting position {} for consumer {}.", + shardId, + startingPosition, + consumerArn, + throwable); + setSubscriptionException(throwable); + }) + .build(); + + cancelTimeoutFuture(); + timeoutFuture = + timeoutScheduler.schedule( + () -> { + String errorMessage = + "Timeout when subscribing to shard " + + shardId + + " with starting position " + + startingPosition + + " for consumer " + + consumerArn + + "."; + synchronized (lockObject) { + // The timeout future was cancelled between firing and + // acquiring the lock (e.g. onSubscribe succeeded, or another + // error path disposed the subscriber). Do nothing. + if (timeoutFuture == null) { + return; } - }) - .build(); - - // We don't need to keep track of the future here because we monitor subscription success - // using our own CountDownLatch - kinesis.subscribeToShard(consumerArn, shardId, startingPosition, responseHandler) - .exceptionally( - throwable -> { - // If consumer exists and is still activating, we want to countdown. - if (ExceptionUtils.findThrowable( - throwable, ResourceInUseException.class) - .isPresent()) { - waitForSubscriptionLatch.countDown(); + if (!disposeIfActive(subscriber)) { + return; + } + } + LOG.error(errorMessage); + setSubscriptionException(new TimeoutException(errorMessage)); + }, + subscriptionTimeout.toMillis(), + TimeUnit.MILLISECONDS); + + kinesis.subscribeToShard(consumerArn, shardId, startingPosition, responseHandler) + .exceptionally( + throwable -> { + synchronized (lockObject) { + if (!disposeIfActive(subscriber)) { + return null; + } + } + LOG.error( + "Error exceptionally subscribing to shard {} with starting position {} for " + + "consumer {}.", + shardId, + startingPosition, + consumerArn, + throwable); + setSubscriptionException(throwable); return null; - } - LOG.error( - "Error subscribing to shard {} with starting position {} for consumer {}.", - shardId, - startingPosition, - consumerArn, - throwable); - terminateSubscription(throwable); - return null; - }); - - // We have to handle timeout for subscriptions separately because Java 8 does not support a - // fluent orTimeout() methods on CompletableFuture. - CompletableFuture.runAsync( - () -> { - try { - if (waitForSubscriptionLatch.await( - subscriptionTimeout.toMillis(), TimeUnit.MILLISECONDS)) { - LOG.info( - "Successfully subscribed to shard {} with starting position {} for consumer {}.", - shardId, - startingPosition, - consumerArn); - // Request first batch of records. - shardSubscriber.requestRecords(); - - } else { - String errorMessage = - "Timeout when subscribing to shard " - + shardId - + " with starting position " - + startingPosition - + " for consumer " - + consumerArn - + "."; - LOG.error(errorMessage); - terminateSubscription(new TimeoutException(errorMessage)); - } - } catch (InterruptedException e) { - LOG.warn("Interrupted while waiting for subscription to complete.", e); - terminateSubscription(e); - Thread.currentThread().interrupt(); - } - }); + }); + } + } + + // Must be called while holding lockObject + private void cancelTimeoutFuture() { + if (timeoutFuture != null) { + timeoutFuture.cancel(false); + timeoutFuture = null; + } + } + + // Must be called while holding lockObject + private boolean disposeIfActive(FanOutShardSubscriber subscriber) { + if (shardSubscriber != subscriber) { + return false; + } + cancelTimeoutFuture(); + shardSubscriber.cancelSubscription(); + shardSubscriber = null; + return true; } - private void terminateSubscription(Throwable t) { + private void setSubscriptionException(Throwable t) { if (!subscriptionException.compareAndSet(null, t)) { LOG.warn( - "Another subscription exception has been queued, ignoring subsequent exceptions", + "Another subscription exception has been queued for shardId {}, ignoring subsequent exceptions", + shardId, t); } - shardSubscriber.cancel(); } /** @@ -209,10 +247,6 @@ public class FanOutKinesisShardSubscription { public SubscribeToShardEvent nextEvent() { Throwable throwable = subscriptionException.getAndSet(null); if (throwable != null) { - // If consumer is still activating, we want to wait. - if (ExceptionUtils.findThrowable(throwable, ResourceInUseException.class).isPresent()) { - return null; - } // We don't want to wrap ResourceNotFoundExceptions because it is handled via a // try-catch loop if (throwable instanceof ResourceNotFoundException) { @@ -226,46 +260,29 @@ public class FanOutKinesisShardSubscription { .findFirst(); if (recoverableException.isPresent()) { LOG.warn( - "Recoverable exception encountered while subscribing to shard. Ignoring.", + "Recoverable exception encountered for shard {} while subscribing to shard. Ignoring: {}", + shardId, recoverableException.get()); - shardSubscriber.cancel(); activateSubscription(); return null; } - LOG.error("Subscription encountered unrecoverable exception.", throwable); + LOG.error("Subscription encountered unrecoverable exception. {}", shardId, throwable); throw new KinesisStreamsSourceException( "Subscription encountered unrecoverable exception.", throwable); } - final SubscriptionState state = - Optional.ofNullable(shardSubscriber) - .map(FanOutShardSubscriber::getSubscriptionState) - .orElse(SubscriptionState.NOT_STARTED); - - switch (state) { - case NOT_STARTED: - LOG.debug( - "Subscription to shard {} for consumer {} is not yet active. Skipping.", - shardId, - consumerArn); - return null; - case COMPLETED: - if (shardSubscriber.isShardEndReached()) { - LOG.info( - "Subscription reached SHARD_END for shard {} for consumer {}.", - shardId, - consumerArn); - return null; - } - LOG.info( - "Subscription expired to shard {} for consumer {}. Restarting.", - shardId, - consumerArn); - activateSubscription(); - return null; - case SUBSCRIBED: - return eventQueue.poll(); - default: - throw new IllegalStateException("Unknown subscription state: " + state); + + return pollAndRequestNext(); + } + + private SubscribeToShardEvent pollAndRequestNext() { + synchronized (lockObject) { + SubscribeToShardEvent event = eventQueue.poll(); + // If shardSubscriber is null, the subscriber has either completed or been disposed. + // In either case, do not issue a follow-up request. + if (event != null && shardSubscriber != null) { + shardSubscriber.requestRecords(); + } + return event; } } @@ -274,61 +291,52 @@ public class FanOutKinesisShardSubscription { * Streams. */ private class FanOutShardSubscriber implements Subscriber<SubscribeToShardEventStream> { - private final CountDownLatch subscriptionLatch; - private Subscription subscription; - - private final AtomicReference<SubscriptionState> subscriptionState = - new AtomicReference<>(SubscriptionState.NOT_STARTED); - private final AtomicBoolean isShardEnd = new AtomicBoolean(false); - - private FanOutShardSubscriber(CountDownLatch subscriptionLatch) { - this.subscriptionLatch = subscriptionLatch; - } - /** - * Fetch the state that the subscriber is in. - * - * @return Subscription state for the subscriber. - */ - public SubscriptionState getSubscriptionState() { - return subscriptionState.get(); - } - - /** - * Boolean whether this subscriber has reached the end of a shard. - * - * @return True if ShardEnd. false otherwise. - */ - public boolean isShardEndReached() { - return isShardEnd.get(); - } + private Subscription subscription; public void requestRecords() { - subscription.request(1); - } - - public void cancel() { - if (this.subscriptionState.get() == SubscriptionState.COMPLETED) { - LOG.warn("Trying to cancel inactive subscription. Ignoring."); - return; + // subscription can be null if onSubscribe has not yet fired on a freshly activated + // subscriber. In that case the initial request(1) will be issued from onSubscribe + // itself, so it is safe to skip here. + if (subscription != null) { + subscription.request(1); } + } + public void cancelSubscription() { if (subscription != null) { subscription.cancel(); } - this.subscriptionState.set(SubscriptionState.COMPLETED); } @Override public void onSubscribe(Subscription subscription) { - LOG.info( - "Successfully subscribed to shard {} at {} using consumer {}.", - shardId, - startingPosition, - consumerArn); - this.subscription = subscription; - this.subscriptionState.set(SubscriptionState.SUBSCRIBED); - subscriptionLatch.countDown(); + synchronized (lockObject) { + if (shardSubscriber != this) { + // Timeout/error disposed this subscriber and a new one was created before SDK + // called onSubscribe + subscription.cancel(); + return; + } + cancelTimeoutFuture(); + this.subscription = subscription; + + int priming = PREFETCH - eventQueue.size(); + if (priming > 0) { + subscription.request(priming); + } else { + LOG.debug( + "Shard {} reactivated with {} buffered event(s). Skipping initial " + + "priming; request(1) will come from the consumer-drain path.", + shardId, + eventQueue.size()); + } + LOG.info( + "Successfully subscribed to shard {} at {} using consumer {}.", + shardId, + startingPosition, + consumerArn); + } } @Override @@ -337,31 +345,58 @@ public class FanOutKinesisShardSubscription { new SubscribeToShardResponseHandler.Visitor() { @Override public void visit(SubscribeToShardEvent event) { - try { + synchronized (lockObject) { + if (shardSubscriber != FanOutShardSubscriber.this) { + LOG.warn( + "Ignoring late event for shard {} from a disposed " + + "subscriber; it will be re-delivered after " + + "reactivation.", + shardId); + return; + } + LOG.debug( "Received event: {}, {}", event.getClass().getSimpleName(), event); - eventQueue.put(event); - if (event.continuationSequenceNumber() == null) { - isShardEnd.set(true); + // Non-blocking offer. Under the prefetch discipline maintained + // by onSubscribe (primes PREFETCH - queue.size() requests) and + // pollAndRequestNext (issues request(1) after each consumer + // drain), the invariant queue.size + outstanding == PREFETCH + // holds in steady state, so the queue is guaranteed to have + // room for each delivered event. If offer() ever returns false + // it indicates a protocol / state invariant violation (e.g. the + // server delivered an unrequested event) - fail loud rather + // than block the Netty event loop. The subscription will be + // reactivated from the previous startingPosition (which has + // not yet been advanced below) and the server will re-deliver + // this event. + if (!eventQueue.offer(event)) { + LOG.error( + "Event queue overflow for shard {}; server delivered " + + "an unrequested event. Failing subscription " + + "to recover.", + shardId); + + if (disposeIfActive(FanOutShardSubscriber.this)) { + setSubscriptionException( + new IOException( + "Event queue overflow for shard " + + shardId + + "; server delivered an " + + "unrequested event.")); + } return; } - // Update the starting position in case we have to recreate the - // subscription - startingPosition = - StartingPosition.continueFromSequenceNumber( - event.continuationSequenceNumber()); - - // Replace the record just consumed in the Queue - requestRecords(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new KinesisStreamsSourceException( - "Interrupted while adding Kinesis record to internal buffer.", - e); + if (event.continuationSequenceNumber() == null) { + startingPosition = null; + } else { + startingPosition = + StartingPosition.continueFromSequenceNumber( + event.continuationSequenceNumber()); + } } } }); @@ -369,24 +404,36 @@ public class FanOutKinesisShardSubscription { @Override public void onError(Throwable throwable) { - if (!subscriptionException.compareAndSet(null, throwable)) { - LOG.warn( - "Another subscription exception has been queued, ignoring subsequent exceptions", - throwable); + synchronized (lockObject) { + if (!disposeIfActive(this)) { + return; + } } + setSubscriptionException(throwable); } @Override public void onComplete() { - LOG.info("Subscription complete - {} ({})", shardId, consumerArn); - this.subscriptionState.set(SubscriptionState.COMPLETED); + synchronized (lockObject) { + if (shardSubscriber != this) { + LOG.warn( + "Ignoring late onComplete for shard {} from a disposed subscriber.", + shardId); + return; + } + LOG.info("Subscription complete - {} ({})", shardId, consumerArn); + shardSubscriber = null; + } + activateSubscription(); } } - /** States that the {@code FanOutShardSubscriber} may be in. */ - private enum SubscriptionState { - NOT_STARTED, - SUBSCRIBED, - COMPLETED + public void close() { + synchronized (lockObject) { + closed = true; + if (shardSubscriber != null) { + disposeIfActive(shardSubscriber); + } + } } } diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscriptionTest.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscriptionTest.java index b6d66b7..3496d63 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscriptionTest.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscriptionTest.java @@ -21,267 +21,754 @@ package org.apache.flink.connector.kinesis.source.reader.fanout; import org.apache.flink.connector.kinesis.source.exception.KinesisStreamsSourceException; import org.apache.flink.connector.kinesis.source.proxy.AsyncStreamProxy; import org.apache.flink.connector.kinesis.source.split.StartingPosition; -import org.apache.flink.connector.kinesis.source.util.FakeKinesisFanOutBehaviorsFactory; import org.junit.jupiter.api.Test; +import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; +import software.amazon.awssdk.services.kinesis.model.Record; import software.amazon.awssdk.services.kinesis.model.ResourceNotFoundException; import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEvent; -import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponse; +import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEventStream; import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponseHandler; +import java.io.IOException; import java.time.Duration; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import static org.apache.flink.connector.kinesis.source.util.TestUtil.CONSUMER_ARN; -import static org.apache.flink.connector.kinesis.source.util.TestUtil.generateShardId; +import static org.apache.flink.connector.kinesis.source.util.TestUtil.SHARD_ID; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.awaitility.Awaitility.await; -/** Tests for {@link FanOutKinesisShardSubscription}. */ +/** + * Tests for {@link FanOutKinesisShardSubscription}. + * + * <p>These tests focus on the concurrent lifecycle of a subscription: activation guards, error + * disposal, timeout handling, and the identity-based cleanup that prevents the connection-leak bug. + */ class FanOutKinesisShardSubscriptionTest { - private static final String TEST_SHARD_ID = generateShardId(1); - private static final Duration SUBSCRIPTION_TIMEOUT = Duration.ofSeconds(5); + private static final Duration DEFAULT_TIMEOUT = Duration.ofMillis(500); + private static final Duration LONG_TIMEOUT = Duration.ofSeconds(10); + + // ----- Happy path ----- @Test - void testNextEventReturnsNullBeforeActivation() { - AsyncStreamProxy proxy = FakeKinesisFanOutBehaviorsFactory.boundedShard().build(); - FanOutKinesisShardSubscription subscription = - new FanOutKinesisShardSubscription( - proxy, - CONSUMER_ARN, - TEST_SHARD_ID, - StartingPosition.fromStart(), - SUBSCRIPTION_TIMEOUT); + void nextEventReturnsNullBeforeActivation() { + ScriptedProxy proxy = new ScriptedProxy(); + FanOutKinesisShardSubscription subscription = newSubscription(proxy); + // Before activateSubscription is called, nextEvent returns null assertThat(subscription.nextEvent()).isNull(); + assertThat(proxy.subscribeCallCount()).isEqualTo(0); } @Test - void testResourceNotFoundExceptionThrown() { - AsyncStreamProxy proxy = - FakeKinesisFanOutBehaviorsFactory.resourceNotFoundWhenObtainingSubscription(); - FanOutKinesisShardSubscription subscription = - new FanOutKinesisShardSubscription( - proxy, - CONSUMER_ARN, - TEST_SHARD_ID, - StartingPosition.fromStart(), - SUBSCRIPTION_TIMEOUT); + void activatesAndDeliversEvents() throws Exception { + ScriptedProxy proxy = new ScriptedProxy(); + FanOutKinesisShardSubscription subscription = newSubscription(proxy); subscription.activateSubscription(); + ScriptedSubscription s1 = proxy.awaitSubscription(); + + // Deliver onSubscribe and an event with records + s1.onSubscribeDelivered(); + SubscribeToShardEvent event = + SubscribeToShardEvent.builder() + .records(record("seq-1")) + .continuationSequenceNumber("cont-1") + .build(); + s1.deliverEvent(event); + + SubscribeToShardEvent received = pollEvent(subscription); + assertThat(received).isNotNull(); + assertThat(received.continuationSequenceNumber()).isEqualTo("cont-1"); + } - // Poll until exception surfaces - assertThatThrownBy( - () -> { - for (int i = 0; i < 200; i++) { - subscription.nextEvent(); - Thread.sleep(50); - } - }) - .isInstanceOf(ResourceNotFoundException.class); + @Test + void onCompleteTriggersReactivationForOngoingShard() throws Exception { + ScriptedProxy proxy = new ScriptedProxy(); + FanOutKinesisShardSubscription subscription = newSubscription(proxy); + + subscription.activateSubscription(); + ScriptedSubscription s1 = proxy.awaitSubscription(); + s1.onSubscribeDelivered(); + // deliver an event with a non-null continuation so the shard is not at end + s1.deliverEvent( + SubscribeToShardEvent.builder() + .records(record("seq-1")) + .continuationSequenceNumber("cont-1") + .build()); + pollEvent(subscription); + + // Natural end of subscription triggers EFO rotation + s1.onComplete(); + + // A second subscribeToShard call must occur + ScriptedSubscription s2 = proxy.awaitSubscription(); + assertThat(proxy.subscribeCallCount()).isEqualTo(2); + // The new activation should resume from where s1 left off + assertThat(s2.startingPosition) + .isEqualTo(StartingPosition.continueFromSequenceNumber("cont-1")); } @Test - void testUnrecoverableExceptionWrappedInSourceException() throws Exception { - AsyncStreamProxy proxy = - new AsyncStreamProxy() { - @Override - public CompletableFuture<Void> subscribeToShard( - String consumerArn, - String shardId, - StartingPosition startingPosition, - SubscribeToShardResponseHandler responseHandler) { - responseHandler.exceptionOccurred( - new IllegalStateException("unrecoverable")); - return CompletableFuture.completedFuture(null); - } - - @Override - public void close() {} - }; - FanOutKinesisShardSubscription subscription = - new FanOutKinesisShardSubscription( - proxy, - CONSUMER_ARN, - TEST_SHARD_ID, - StartingPosition.fromStart(), - SUBSCRIPTION_TIMEOUT); + void shardEndPreventsReactivation() throws Exception { + ScriptedProxy proxy = new ScriptedProxy(); + FanOutKinesisShardSubscription subscription = newSubscription(proxy); subscription.activateSubscription(); + ScriptedSubscription s1 = proxy.awaitSubscription(); + s1.onSubscribeDelivered(); + // Deliver event with null continuation number to signal shard end + s1.deliverEvent( + SubscribeToShardEvent.builder() + .records(record("seq-1")) + .continuationSequenceNumber(null) + .build()); + pollEvent(subscription); + + // onComplete normally triggers reactivation, but not when startingPosition is null + s1.onComplete(); + + // There should be no further subscribe attempts. + // We give the scheduler a moment to make any mistake visible. + waitShort(); + assertThat(proxy.subscribeCallCount()).isEqualTo(1); + } - assertThatThrownBy( - () -> { - for (int i = 0; i < 200; i++) { - subscription.nextEvent(); - Thread.sleep(50); - } - }) + // ----- Concurrent activation prevention ----- + + @Test + void activateIsNoOpWhenSubscriptionAlreadyInFlight() throws Exception { + ScriptedProxy proxy = new ScriptedProxy(); + // Important: use a long timeout so the first subscribe doesn't time out mid-test + FanOutKinesisShardSubscription subscription = newSubscription(proxy, LONG_TIMEOUT); + + subscription.activateSubscription(); + proxy.awaitSubscription(); + assertThat(proxy.subscribeCallCount()).isEqualTo(1); + + // Now call activateSubscription again. The first subscription has not received + // onSubscribe yet, so the broken main-branch guard (subscriptionActive) would + // allow a second subscribe to fire. The fixed guard (shardSubscriber != null) + // must block it. + subscription.activateSubscription(); + subscription.activateSubscription(); + subscription.activateSubscription(); + + waitShort(); + assertThat(proxy.subscribeCallCount()).isEqualTo(1); + } + + @Test + void activateIsNoOpAfterOnSubscribeSucceeds() throws Exception { + ScriptedProxy proxy = new ScriptedProxy(); + FanOutKinesisShardSubscription subscription = newSubscription(proxy, LONG_TIMEOUT); + + subscription.activateSubscription(); + ScriptedSubscription s1 = proxy.awaitSubscription(); + s1.onSubscribeDelivered(); + + subscription.activateSubscription(); + subscription.activateSubscription(); + + waitShort(); + assertThat(proxy.subscribeCallCount()).isEqualTo(1); + } + + // ----- Error handling and retries ----- + + @Test + void recoverableErrorTriggersRetry() throws Exception { + ScriptedProxy proxy = new ScriptedProxy(); + FanOutKinesisShardSubscription subscription = newSubscription(proxy); + + subscription.activateSubscription(); + ScriptedSubscription s1 = proxy.awaitSubscription(); + + // Fail the subscription via the future path (common for acquire timeouts) + s1.completeExceptionally(new IOException("simulated connection reset")); + + // nextEvent must drain the exception, classify as recoverable, and retry + SubscribeToShardEvent e = subscription.nextEvent(); + assertThat(e).isNull(); + + ScriptedSubscription s2 = proxy.awaitSubscription(); + assertThat(s2).isNotSameAs(s1); + assertThat(proxy.subscribeCallCount()).isEqualTo(2); + } + + @Test + void unrecoverableErrorPropagatesFromNextEvent() throws Exception { + ScriptedProxy proxy = new ScriptedProxy(); + FanOutKinesisShardSubscription subscription = newSubscription(proxy); + + subscription.activateSubscription(); + ScriptedSubscription s1 = proxy.awaitSubscription(); + + // Use a non-recoverable runtime exception + s1.completeExceptionally(new RuntimeException("nope")); + + assertThatThrownBy(subscription::nextEvent) .isInstanceOf(KinesisStreamsSourceException.class) .hasMessageContaining("unrecoverable"); } @Test - void testSubscriptionTimeoutTerminatesSubscription() throws Exception { - AsyncStreamProxy proxy = - new AsyncStreamProxy() { - @Override - public CompletableFuture<Void> subscribeToShard( - String consumerArn, - String shardId, - StartingPosition startingPosition, - SubscribeToShardResponseHandler responseHandler) { - return new CompletableFuture<>(); - } - - @Override - public void close() {} - }; - FanOutKinesisShardSubscription subscription = - new FanOutKinesisShardSubscription( - proxy, - CONSUMER_ARN, - TEST_SHARD_ID, - StartingPosition.fromStart(), - Duration.ofMillis(200)); + void resourceNotFoundIsRethrownDirectly() throws Exception { + ScriptedProxy proxy = new ScriptedProxy(); + FanOutKinesisShardSubscription subscription = newSubscription(proxy); subscription.activateSubscription(); + ScriptedSubscription s1 = proxy.awaitSubscription(); - // Wait for timeout to trigger, then poll - should recover - Thread.sleep(500); - SubscribeToShardEvent event = subscription.nextEvent(); - assertThat(event).isNull(); + s1.completeExceptionally(ResourceNotFoundException.builder().message("gone").build()); + + assertThatThrownBy(subscription::nextEvent).isInstanceOf(ResourceNotFoundException.class); } + // ----- Dual error path dedup (disposeIfActive identity check) ----- + @Test - void testExpiredSubscriptionResubscribes() throws Exception { - AtomicInteger subscribeCount = new AtomicInteger(0); - AsyncStreamProxy proxy = - new AsyncStreamProxy() { - @Override - public CompletableFuture<Void> subscribeToShard( - String consumerArn, - String shardId, - StartingPosition startingPosition, - SubscribeToShardResponseHandler responseHandler) { - subscribeCount.incrementAndGet(); - return CompletableFuture.supplyAsync( - () -> { - responseHandler.responseReceived( - SubscribeToShardResponse.builder().build()); - responseHandler.onEventStream( - subscriber -> { - subscriber.onSubscribe( - new Subscription() { - @Override - public void request(long n) { - // Complete without sending any - // events (simulates 5-min expiry) - subscriber.onComplete(); - } - - @Override - public void cancel() {} - }); - }); - return null; - }); - } - - @Override - public void close() {} - }; + void dualErrorPathQueuesAtMostOneException() throws Exception { + ScriptedProxy proxy = new ScriptedProxy(); + FanOutKinesisShardSubscription subscription = newSubscription(proxy); + + subscription.activateSubscription(); + ScriptedSubscription s1 = proxy.awaitSubscription(); + + // Both error signals fire for the same subscribe (as AWS SDK does in practice) + s1.fireHandlerError(new IOException("handler error")); + s1.completeExceptionally(new IOException("future error")); + + // First nextEvent drains exactly one exception and retries + assertThat(subscription.nextEvent()).isNull(); + + // After retry, there must be exactly ONE new subscribe (not one per error signal) + ScriptedSubscription s2 = proxy.awaitSubscription(); + assertThat(s2).isNotSameAs(s1); + assertThat(proxy.subscribeCallCount()).isEqualTo(2); + + // Second nextEvent must not find a lingering exception from the losing error path + assertThat(subscription.nextEvent()).isNull(); + waitShort(); + // Still only 2 subscribes (no runaway retry) + assertThat(proxy.subscribeCallCount()).isEqualTo(2); + } + + // ----- Subscription timeout handling ----- + + @Test + void subscriptionTimeoutTriggersCleanupAndRetry() throws Exception { + ScriptedProxy proxy = new ScriptedProxy(); + FanOutKinesisShardSubscription subscription = + newSubscription(proxy, Duration.ofMillis(100)); + subscription.activateSubscription(); + ScriptedSubscription s1 = proxy.awaitSubscription(); + // Deliberately do NOT call s1.onSubscribeDelivered() — the latch will time out + + // Wait for the timeout to fire and queue a TimeoutException + await().atMost(Duration.ofSeconds(2)) + .untilAsserted( + () -> { + SubscribeToShardEvent e = subscription.nextEvent(); + // After the timeout, nextEvent should drain the exception and fire a + // retry. + // Keep polling until the retry shows up. + assertThat(proxy.subscribeCallCount()).isGreaterThanOrEqualTo(2); + assertThat(e).isNull(); + }); + } + + // ----- Stale onSubscribe (race: timeout before onSubscribe) ----- + + @Test + void lateOnSubscribeOnStaleSubscriberIsCancelled() throws Exception { + ScriptedProxy proxy = new ScriptedProxy(); FanOutKinesisShardSubscription subscription = - new FanOutKinesisShardSubscription( - proxy, - CONSUMER_ARN, - TEST_SHARD_ID, - StartingPosition.fromStart(), - SUBSCRIPTION_TIMEOUT); + newSubscription(proxy, Duration.ofMillis(100)); subscription.activateSubscription(); - Thread.sleep(500); + ScriptedSubscription s1 = proxy.awaitSubscription(); + // Let the timeout fire (no onSubscribe) + await().atMost(Duration.ofSeconds(2)) + .until(() -> subscription.nextEvent() == null && proxy.subscribeCallCount() >= 2); + + // Now s1 is stale; a new subscription s2 has been created + ScriptedSubscription s2 = proxy.latestSubscription(); + assertThat(s2).isNotSameAs(s1); + + // Late onSubscribe delivery on the STALE subscriber must result in its + // Subscription being cancelled (to free the underlying HTTP/2 stream slot). + s1.onSubscribeDelivered(); + assertThat(s1.subscription.isCancelled()).isTrue(); + } - // nextEvent() should detect COMPLETED without shard-end and trigger resubscription - subscription.nextEvent(); - Thread.sleep(500); + // ----- Pull-based backpressure: at most one event in flight per subscriber ----- - assertThat(subscribeCount.get()).isEqualTo(2); + @Test + void initialRequestIssuedOnlyOnceAfterOnSubscribe() throws Exception { + ScriptedProxy proxy = new ScriptedProxy(); + FanOutKinesisShardSubscription subscription = newSubscription(proxy, LONG_TIMEOUT); + + subscription.activateSubscription(); + ScriptedSubscription s1 = proxy.awaitSubscription(); + s1.onSubscribeDelivered(); + + // After onSubscribe on an empty queue, exactly PREFETCH requests must have been issued. + // The "max-one-in-flight" variant would prime with 1; depth-2 pipelining primes with 2. + assertThat(s1.subscription.getRequestedCount()).isEqualTo(2); + + // No drain yet → no further requests + waitShort(); + assertThat(s1.subscription.getRequestedCount()).isEqualTo(2); } @Test - void testShardEndDoesNotResubscribe() throws Exception { - AtomicInteger subscribeCount = new AtomicInteger(0); - AsyncStreamProxy proxy = - new AsyncStreamProxy() { - @Override - public CompletableFuture<Void> subscribeToShard( - String consumerArn, - String shardId, - StartingPosition startingPosition, - SubscribeToShardResponseHandler responseHandler) { - subscribeCount.incrementAndGet(); - return CompletableFuture.supplyAsync( - () -> { - responseHandler.responseReceived( - SubscribeToShardResponse.builder().build()); - responseHandler.onEventStream( - subscriber -> { - subscriber.onSubscribe( - new Subscription() { - private boolean sent = false; - - @Override - public void request(long n) { - if (!sent) { - sent = true; - // Send event with null - // continuation (shard end) - subscriber.onNext( - SubscribeToShardEvent - .builder() - .millisBehindLatest( - 0L) - .continuationSequenceNumber( - null) - .build()); - } else { - subscriber.onComplete(); - } - } - - @Override - public void cancel() {} - }); - }); - return null; - }); - } - - @Override - public void close() {} - }; + void requestOneAfterEachConsumerDrain() throws Exception { + ScriptedProxy proxy = new ScriptedProxy(); + FanOutKinesisShardSubscription subscription = newSubscription(proxy, LONG_TIMEOUT); - FanOutKinesisShardSubscription subscription = - new FanOutKinesisShardSubscription( - proxy, - CONSUMER_ARN, - TEST_SHARD_ID, - StartingPosition.fromStart(), - SUBSCRIPTION_TIMEOUT); + subscription.activateSubscription(); + ScriptedSubscription s1 = proxy.awaitSubscription(); + s1.onSubscribeDelivered(); + + // Initial priming: PREFETCH (=2) outstanding requests. + assertThat(s1.subscription.getRequestedCount()).isEqualTo(2); + + // Server delivers one event; onNext must NOT issue another request (backpressure). + s1.deliverEvent( + SubscribeToShardEvent.builder() + .records(record("seq-1")) + .continuationSequenceNumber("cont-1") + .build()); + waitShort(); + assertThat(s1.subscription.getRequestedCount()).isEqualTo(2); + + // Consumer drains via nextEvent → exactly one more request(1) must fire + SubscribeToShardEvent drained = pollEvent(subscription); + assertThat(drained).isNotNull(); + assertThat(s1.subscription.getRequestedCount()).isEqualTo(3); + + // Second event/drain cycle + s1.deliverEvent( + SubscribeToShardEvent.builder() + .records(record("seq-2")) + .continuationSequenceNumber("cont-2") + .build()); + assertThat(pollEvent(subscription)).isNotNull(); + assertThat(s1.subscription.getRequestedCount()).isEqualTo(4); + } + + @Test + void pipelineDepthMatchesPrefetch() throws Exception { + // Verifies the invariant: queue.size + outstanding == PREFETCH (=2). Fills the queue to + // capacity, then drains one at a time, checking the queue never overflows and + // request counts increment as expected. + ScriptedProxy proxy = new ScriptedProxy(); + FanOutKinesisShardSubscription subscription = newSubscription(proxy, LONG_TIMEOUT); subscription.activateSubscription(); - Thread.sleep(500); + ScriptedSubscription s1 = proxy.awaitSubscription(); + s1.onSubscribeDelivered(); + assertThat(s1.subscription.getRequestedCount()).isEqualTo(2); + + // Deliver 2 events back-to-back, filling the queue to PREFETCH. + s1.deliverEvent( + SubscribeToShardEvent.builder() + .records(record("seq-1")) + .continuationSequenceNumber("cont-1") + .build()); + s1.deliverEvent( + SubscribeToShardEvent.builder() + .records(record("seq-2")) + .continuationSequenceNumber("cont-2") + .build()); + + // No further requests should have fired yet; outstanding = 0, queue = 2, sum = 2. + waitShort(); + assertThat(s1.subscription.getRequestedCount()).isEqualTo(2); + + // Drain event 1 → request(1). Outstanding = 1, queue = 1, sum = 2. + SubscribeToShardEvent e1 = pollEvent(subscription); + assertThat(e1.continuationSequenceNumber()).isEqualTo("cont-1"); + assertThat(s1.subscription.getRequestedCount()).isEqualTo(3); + + // Server delivers event 3; queue = 2 again. + s1.deliverEvent( + SubscribeToShardEvent.builder() + .records(record("seq-3")) + .continuationSequenceNumber("cont-3") + .build()); + waitShort(); + assertThat(s1.subscription.getRequestedCount()).isEqualTo(3); + + // Drain event 2 → request(1). Sum stays at 2. + SubscribeToShardEvent e2 = pollEvent(subscription); + assertThat(e2.continuationSequenceNumber()).isEqualTo("cont-2"); + assertThat(s1.subscription.getRequestedCount()).isEqualTo(4); + } - // Drain the shard-end event from the queue - subscription.nextEvent(); - Thread.sleep(500); + @Test + void nextEventOnEmptyQueueDoesNotRequestMore() throws Exception { + ScriptedProxy proxy = new ScriptedProxy(); + FanOutKinesisShardSubscription subscription = newSubscription(proxy, LONG_TIMEOUT); - // Should not have resubscribed — shard has ended - assertThat(subscribeCount.get()).isEqualTo(1); + subscription.activateSubscription(); + ScriptedSubscription s1 = proxy.awaitSubscription(); + s1.onSubscribeDelivered(); + int priming = s1.subscription.getRequestedCount(); + + // nextEvent on an empty queue returns null and must NOT issue a spurious request(1) + assertThat(subscription.nextEvent()).isNull(); + assertThat(subscription.nextEvent()).isNull(); assertThat(subscription.nextEvent()).isNull(); + assertThat(s1.subscription.getRequestedCount()).isEqualTo(priming); + } + + // ----- Reactivation invariant: onSubscribe must not re-prime if queue has leftover events + // ----- + + @Test + void onSubscribeWithBufferedEventFromPreviousSubscriberDoesNotPrimeRequest() throws Exception { + ScriptedProxy proxy = new ScriptedProxy(); + FanOutKinesisShardSubscription subscription = newSubscription(proxy, LONG_TIMEOUT); + + // s1 activates and delivers one event that gets buffered + subscription.activateSubscription(); + ScriptedSubscription s1 = proxy.awaitSubscription(); + s1.onSubscribeDelivered(); + assertThat(s1.subscription.getRequestedCount()).isEqualTo(2); + + s1.deliverEvent( + SubscribeToShardEvent.builder() + .records(record("seq-1")) + .continuationSequenceNumber("cont-1") + .build()); + + // Before draining, s1 errors (handler path). This disposes s1 but leaves event1 + // in the shared queue. + s1.fireHandlerError(new IOException("connection closed")); + + // nextEvent drains the exception and reactivates. We should now have s2 pending. + assertThat(subscription.nextEvent()).isNull(); + ScriptedSubscription s2 = proxy.awaitSubscription(); + assertThat(s2).isNotSameAs(s1); + + // s2's onSubscribe fires. The queue has 1 leftover event, so s2 primes only + // (PREFETCH - 1) = 1 to preserve the queue.size + outstanding == PREFETCH invariant. + s2.onSubscribeDelivered(); + assertThat(s2.subscription.getRequestedCount()) + .as("Reactivation with 1 buffered event primes only PREFETCH-1") + .isEqualTo(1); + + // Consumer drains the leftover event; pollAndRequestNext adds another request(1), + // bringing outstanding on s2 back to PREFETCH (=2). + SubscribeToShardEvent drained = pollEvent(subscription); + assertThat(drained.continuationSequenceNumber()).isEqualTo("cont-1"); + assertThat(s2.subscription.getRequestedCount()) + .as("After consumer drain, total requests equal PREFETCH") + .isEqualTo(2); + } + + @Test + void onSubscribeWithEmptyQueuePrimesRequestImmediately() throws Exception { + // Counterpart of the above: in the normal case (queue empty on onSubscribe), + // priming equals PREFETCH. + ScriptedProxy proxy = new ScriptedProxy(); + FanOutKinesisShardSubscription subscription = newSubscription(proxy, LONG_TIMEOUT); + + subscription.activateSubscription(); + ScriptedSubscription s1 = proxy.awaitSubscription(); + s1.onSubscribeDelivered(); + + assertThat(s1.subscription.getRequestedCount()) + .as("Normal activation with empty queue primes PREFETCH requests") + .isEqualTo(2); + } + + // ----- close() shutdown behavior ----- + + @Test + void closeCancelsActiveSubscription() throws Exception { + ScriptedProxy proxy = new ScriptedProxy(); + FanOutKinesisShardSubscription subscription = newSubscription(proxy, LONG_TIMEOUT); + + subscription.activateSubscription(); + ScriptedSubscription s1 = proxy.awaitSubscription(); + s1.onSubscribeDelivered(); + assertThat(s1.subscription.isCancelled()).isFalse(); + + subscription.close(); + + // close must cancel the active reactive Subscription so the underlying HTTP/2 + // stream slot is released promptly. + assertThat(s1.subscription.isCancelled()).isTrue(); + } + + @Test + void activateAfterCloseIsNoOp() throws Exception { + ScriptedProxy proxy = new ScriptedProxy(); + FanOutKinesisShardSubscription subscription = newSubscription(proxy, LONG_TIMEOUT); + + subscription.close(); + + // Any further activation attempts must be silently ignored. + subscription.activateSubscription(); + subscription.activateSubscription(); + + waitShort(); + assertThat(proxy.subscribeCallCount()).isEqualTo(0); + } + + @Test + void closeIsIdempotent() throws Exception { + ScriptedProxy proxy = new ScriptedProxy(); + FanOutKinesisShardSubscription subscription = newSubscription(proxy, LONG_TIMEOUT); + + subscription.activateSubscription(); + ScriptedSubscription s1 = proxy.awaitSubscription(); + s1.onSubscribeDelivered(); + + // Two close() calls must not throw and must leave state consistent. + subscription.close(); + subscription.close(); + + assertThat(s1.subscription.isCancelled()).isTrue(); + } + + // ----- Stale onNext drop (identity check in onNext) ----- + + @Test + void onNextFromStaleSubscriberIsDropped() throws Exception { + ScriptedProxy proxy = new ScriptedProxy(); + FanOutKinesisShardSubscription subscription = newSubscription(proxy, LONG_TIMEOUT); + + subscription.activateSubscription(); + ScriptedSubscription s1 = proxy.awaitSubscription(); + s1.onSubscribeDelivered(); + + // Disposal via error path retires s1 + s1.fireHandlerError(new IOException("dispose s1")); + + // Drain the error and trigger reactivation + assertThat(subscription.nextEvent()).isNull(); + ScriptedSubscription s2 = proxy.awaitSubscription(); + s2.onSubscribeDelivered(); + + // The AWS SDK may deliver a late onNext on the disposed subscriber before cancel + // propagates. Such events must be silently dropped (identity check in onNext). + s1.deliverEvent( + SubscribeToShardEvent.builder() + .records(record("stale")) + .continuationSequenceNumber("stale-cont") + .build()); + + // nextEvent must NOT return the stale event + waitShort(); + assertThat(subscription.nextEvent()).isNull(); + + // The active subscriber can still deliver events normally + s2.deliverEvent( + SubscribeToShardEvent.builder() + .records(record("live")) + .continuationSequenceNumber("live-cont") + .build()); + SubscribeToShardEvent got = pollEvent(subscription); + assertThat(got.continuationSequenceNumber()).isEqualTo("live-cont"); + } + + @Test + void onCompleteFromStaleSubscriberDoesNotDisposeActiveSubscriber() throws Exception { + ScriptedProxy proxy = new ScriptedProxy(); + FanOutKinesisShardSubscription subscription = newSubscription(proxy, LONG_TIMEOUT); + + subscription.activateSubscription(); + ScriptedSubscription s1 = proxy.awaitSubscription(); + s1.onSubscribeDelivered(); + + s1.fireHandlerError(new IOException("dispose s1")); + assertThat(subscription.nextEvent()).isNull(); + ScriptedSubscription s2 = proxy.awaitSubscription(); + s2.onSubscribeDelivered(); + + s1.onComplete(); + + waitShort(); + assertThat(proxy.subscribeCallCount()).isEqualTo(2); + + s2.deliverEvent( + SubscribeToShardEvent.builder() + .records(record("live")) + .continuationSequenceNumber("live-cont") + .build()); + SubscribeToShardEvent got = pollEvent(subscription); + assertThat(got.continuationSequenceNumber()).isEqualTo("live-cont"); + } + + private FanOutKinesisShardSubscription newSubscription(AsyncStreamProxy proxy) { + return newSubscription(proxy, DEFAULT_TIMEOUT); + } + + private FanOutKinesisShardSubscription newSubscription( + AsyncStreamProxy proxy, Duration subscriptionTimeout) { + ScheduledThreadPoolExecutor timeoutScheduler = new ScheduledThreadPoolExecutor(1); + timeoutScheduler.setRemoveOnCancelPolicy(true); + return new FanOutKinesisShardSubscription( + proxy, + CONSUMER_ARN, + SHARD_ID, + StartingPosition.fromStart(), + subscriptionTimeout, + timeoutScheduler); + } + + private static Record record(String sequenceNumber) { + return Record.builder().sequenceNumber(sequenceNumber).partitionKey("pk").build(); + } + + private static SubscribeToShardEvent pollEvent(FanOutKinesisShardSubscription subscription) { + return await().atMost(Duration.ofSeconds(2)).until(subscription::nextEvent, e -> e != null); + } + + private static void waitShort() throws InterruptedException { + // Give the async machinery a window to do anything incorrect. + // Not a timing assertion — just bounded patience for "nothing should happen". + Thread.sleep(100); + } + + // ----- Programmable fake AsyncStreamProxy ----- + + /** + * A fake {@link AsyncStreamProxy} that records subscribe calls and exposes each one via a + * {@link ScriptedSubscription} handle for the test to drive events and errors. + */ + private static final class ScriptedProxy implements AsyncStreamProxy { + private final ConcurrentLinkedQueue<ScriptedSubscription> calls = + new ConcurrentLinkedQueue<>(); + + @Override + public CompletableFuture<Void> subscribeToShard( + String consumerArn, + String shardId, + StartingPosition startingPosition, + SubscribeToShardResponseHandler responseHandler) { + ScriptedSubscription call = new ScriptedSubscription(startingPosition, responseHandler); + calls.add(call); + return call.future; + } + + @Override + public void close() {} + + int subscribeCallCount() { + return calls.size(); + } + + ScriptedSubscription awaitSubscription() { + return await().atMost(Duration.ofSeconds(2)) + .until( + () -> { + // Return the last unseen call + for (ScriptedSubscription c : calls) { + if (!c.observed) { + c.observed = true; + return c; + } + } + return null; + }, + c -> c != null); + } + + ScriptedSubscription latestSubscription() { + ScriptedSubscription last = null; + for (ScriptedSubscription c : calls) { + last = c; + } + return last; + } + } + + /** A handle for the test to drive a specific {@code subscribeToShard} call. */ + private static final class ScriptedSubscription { + final StartingPosition startingPosition; + final SubscribeToShardResponseHandler handler; + final CompletableFuture<Void> future = new CompletableFuture<>(); + final TestSubscription subscription = new TestSubscription(); + volatile boolean observed = false; + + private final AtomicReference<Subscriber<? super SubscribeToShardEventStream>> + subscriberRef = new AtomicReference<>(); + + ScriptedSubscription( + StartingPosition startingPosition, SubscribeToShardResponseHandler handler) { + this.startingPosition = startingPosition; + this.handler = handler; + // Start the event stream handshake: the handler.onEventStream callback tells us + // the Subscriber instance we should deliver signals to. + handler.onEventStream(subscriberRef::set); + } + + void onSubscribeDelivered() { + Subscriber<? super SubscribeToShardEventStream> s = awaitSubscriber(); + s.onSubscribe(subscription); + } + + void deliverEvent(SubscribeToShardEvent event) { + Subscriber<? super SubscribeToShardEventStream> s = awaitSubscriber(); + // SubscribeToShardEvent itself implements SubscribeToShardEventStream, so we can + // pass it directly. Its accept() dispatches to visitor.visit(this). + s.onNext(event); + } + + void onComplete() { + Subscriber<? super SubscribeToShardEventStream> s = awaitSubscriber(); + s.onComplete(); + } + + void fireHandlerError(Throwable t) { + handler.exceptionOccurred(t); + } + + void completeExceptionally(Throwable t) { + future.completeExceptionally(t); + } + + private Subscriber<? super SubscribeToShardEventStream> awaitSubscriber() { + return await().atMost(Duration.ofSeconds(2)).until(subscriberRef::get, s -> s != null); + } + } + + /** A {@link Subscription} that records whether cancel was called. */ + private static final class TestSubscription implements Subscription { + private final AtomicInteger requested = new AtomicInteger(0); + private final AtomicBoolean cancelled = new AtomicBoolean(false); + + @Override + public void request(long n) { + requested.addAndGet((int) n); + } + + @Override + public void cancel() { + cancelled.set(true); + } + + boolean isCancelled() { + return cancelled.get(); + } + + int getRequestedCount() { + return requested.get(); + } } }
