This is an automated email from the ASF dual-hosted git repository.

KKcorps pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new fb32b305370 Improve Kinesis consumer throttling handling (#18531)
fb32b305370 is described below

commit fb32b305370d4320ac62cd53c7d166c85f84ea07
Author: Kartik Khare <[email protected]>
AuthorDate: Fri May 22 08:57:35 2026 +0530

    Improve Kinesis consumer throttling handling (#18531)
    
    Kinesis GetRecords and GetShardIterator throttling happens before Pinot's
    post-fetch consumption limiters can help, so the consumer now coordinates
    the request rate per JVM, supports fractional RPS configuration, and applies
    bounded backoff on AWS throttle responses before returning a no-progress batch.
    
    Co-authored-by: Kartik Khare <[email protected]>
---
 .../pinot/plugin/stream/kinesis/KinesisConfig.java |  36 ++-
 .../plugin/stream/kinesis/KinesisConsumer.java     | 266 ++++++++++++++++++---
 .../plugin/stream/kinesis/KinesisConsumerTest.java | 227 +++++++++++++++++-
 3 files changed, 482 insertions(+), 47 deletions(-)

diff --git 
a/pinot-plugins/pinot-stream-ingestion/pinot-kinesis/src/main/java/org/apache/pinot/plugin/stream/kinesis/KinesisConfig.java
 
b/pinot-plugins/pinot-stream-ingestion/pinot-kinesis/src/main/java/org/apache/pinot/plugin/stream/kinesis/KinesisConfig.java
index 6f84407006b..7a47af7292a 100644
--- 
a/pinot-plugins/pinot-stream-ingestion/pinot-kinesis/src/main/java/org/apache/pinot/plugin/stream/kinesis/KinesisConfig.java
+++ 
b/pinot-plugins/pinot-stream-ingestion/pinot-kinesis/src/main/java/org/apache/pinot/plugin/stream/kinesis/KinesisConfig.java
@@ -76,7 +76,7 @@ public class KinesisConfig {
   // We are setting it to 1 to avoid hitting the limit  in a replicated setup,
   // where multiple replicas are fetching from the same shard.
   // see - 
https://docs.aws.amazon.com/kinesis/latest/APIReference/API_GetRecords.html
-  public static final String DEFAULT_RPS_LIMIT = "1";
+  public static final String DEFAULT_RPS_LIMIT = "1.0";
 
   private final String _streamTopicName;
   private final String _awsRegion;
@@ -94,7 +94,7 @@ public class KinesisConfig {
   private String _externalId;
   private int _sessionDurationSeconds;
   private boolean _asyncSessionUpdateEnabled;
-  private int _rpsLimit;
+  private double _rpsLimit;
 
   public KinesisConfig(StreamConfig streamConfig) {
     Map<String, String> props = streamConfig.getStreamConfigsMap();
@@ -103,13 +103,7 @@ public class KinesisConfig {
     Preconditions.checkNotNull(_awsRegion, "Must provide 'region' in stream 
config for table: %s",
         streamConfig.getTableNameWithType());
     _numMaxRecordsToFetch = 
Integer.parseInt(props.getOrDefault(MAX_RECORDS_TO_FETCH, DEFAULT_MAX_RECORDS));
-    _rpsLimit = Integer.parseInt(props.getOrDefault(RPS_LIMIT, 
DEFAULT_RPS_LIMIT));
-
-    if (_rpsLimit <= 0) {
-      LOGGER.warn("Invalid 'requests_per_second_limit' value: {}."
-          + " Please provide value greater than 0. Using default: {}", 
_rpsLimit, DEFAULT_RPS_LIMIT);
-      _rpsLimit = Integer.parseInt(DEFAULT_RPS_LIMIT);
-    }
+    _rpsLimit = parseRpsLimit(props.getOrDefault(RPS_LIMIT, 
DEFAULT_RPS_LIMIT));
 
     _shardIteratorType =
         ShardIteratorType.fromValue(props.getOrDefault(SHARD_ITERATOR_TYPE, 
DEFAULT_SHARD_ITERATOR_TYPE));
@@ -149,7 +143,15 @@ public class KinesisConfig {
     return _numMaxRecordsToFetch;
   }
 
+  /**
+   * @deprecated Use {@link #getRpsLimitPerSecond()} to preserve fractional 
limits.
+   */
+  @Deprecated
   public int getRpsLimit() {
+    return (int) Math.ceil(_rpsLimit);
+  }
+
+  public double getRpsLimitPerSecond() {
     return _rpsLimit;
   }
 
@@ -196,4 +198,20 @@ public class KinesisConfig {
   public boolean isPopulateMetadata() {
     return _populateMetadata;
   }
+
+  private double parseRpsLimit(String rpsLimit) {
+    try {
+      double parsedRpsLimit = Double.parseDouble(rpsLimit);
+      if (parsedRpsLimit > 0) {
+        return parsedRpsLimit;
+      }
+    } catch (NumberFormatException e) {
+      LOGGER.warn("Invalid '{}' value: {}. Please provide a numeric value 
greater than 0. Using default: {}",
+          RPS_LIMIT, rpsLimit, DEFAULT_RPS_LIMIT);
+      return Double.parseDouble(DEFAULT_RPS_LIMIT);
+    }
+    LOGGER.warn("Invalid '{}' value: {}. Please provide value greater than 0. 
Using default: {}", RPS_LIMIT, rpsLimit,
+        DEFAULT_RPS_LIMIT);
+    return Double.parseDouble(DEFAULT_RPS_LIMIT);
+  }
 }
diff --git 
a/pinot-plugins/pinot-stream-ingestion/pinot-kinesis/src/main/java/org/apache/pinot/plugin/stream/kinesis/KinesisConsumer.java
 
b/pinot-plugins/pinot-stream-ingestion/pinot-kinesis/src/main/java/org/apache/pinot/plugin/stream/kinesis/KinesisConsumer.java
index d63560bb4db..5c017acc076 100644
--- 
a/pinot-plugins/pinot-stream-ingestion/pinot-kinesis/src/main/java/org/apache/pinot/plugin/stream/kinesis/KinesisConsumer.java
+++ 
b/pinot-plugins/pinot-stream-ingestion/pinot-kinesis/src/main/java/org/apache/pinot/plugin/stream/kinesis/KinesisConsumer.java
@@ -19,11 +19,17 @@
 package org.apache.pinot.plugin.stream.kinesis;
 
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.cache.Cache;
+import com.google.common.cache.CacheBuilder;
+import com.google.common.util.concurrent.RateLimiter;
 import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.ThreadLocalRandom;
 import java.util.concurrent.TimeUnit;
+import java.util.function.Supplier;
 import org.apache.pinot.spi.stream.BytesStreamMessage;
 import org.apache.pinot.spi.stream.PartitionGroupConsumer;
 import org.apache.pinot.spi.stream.StreamMessageMetadata;
@@ -44,21 +50,31 @@ import 
software.amazon.awssdk.services.kinesis.model.ShardIteratorType;
  */
 public class KinesisConsumer extends KinesisConnectionHandler implements 
PartitionGroupConsumer {
   private static final Logger LOGGER = 
LoggerFactory.getLogger(KinesisConsumer.class);
+  private static final int INITIAL_RATE_LIMIT_BACKOFF_MS = 1000;
+  private static final int MAX_RATE_LIMIT_BACKOFF_MS = 5000;
+  private static final int RATE_LIMIT_BACKOFF_JITTER_BOUND_MS = 250;
+  private static final RequestRateLimiter SHARED_REQUEST_RATE_LIMITER = new 
SharedKinesisRequestRateLimiter();
 
   private String _nextStartSequenceNumber = null;
   private String _nextShardIterator = null;
-  private int _currentSecond = 0;
-  private int _numRequestsInCurrentSecond = 0;
+  private final RequestRateLimiter _requestRateLimiter;
 
   public KinesisConsumer(KinesisConfig config) {
     super(config);
+    _requestRateLimiter = SHARED_REQUEST_RATE_LIMITER;
     LOGGER.info("Created Kinesis consumer with topic: {}, RPS limit: {}, max 
records per fetch: {}",
-        config.getStreamTopicName(), config.getRpsLimit(), 
config.getNumMaxRecordsToFetch());
+        config.getStreamTopicName(), config.getRpsLimitPerSecond(), 
config.getNumMaxRecordsToFetch());
   }
 
   @VisibleForTesting
   public KinesisConsumer(KinesisConfig config, KinesisClient kinesisClient) {
+    this(config, kinesisClient, SHARED_REQUEST_RATE_LIMITER);
+  }
+
+  @VisibleForTesting
+  KinesisConsumer(KinesisConfig config, KinesisClient kinesisClient, 
RequestRateLimiter requestRateLimiter) {
     super(config, kinesisClient);
+    _requestRateLimiter = requestRateLimiter;
   }
 
   /**
@@ -73,16 +89,39 @@ public class KinesisConsumer extends 
KinesisConnectionHandler implements Partiti
    */
   @Override
   public synchronized KinesisMessageBatch 
fetchMessages(StreamPartitionMsgOffset startMsgOffset, int timeoutMs) {
-    try {
-      return getKinesisMessageBatch((KinesisPartitionGroupOffset) 
startMsgOffset);
-    } catch (ProvisionedThroughputExceededException pte) {
-      LOGGER.error("Rate limit exceeded while fetching messages from Kinesis 
stream: {} with threshold: {}",
-          pte.getMessage(), _config.getRpsLimit());
-      return new KinesisMessageBatch(List.of(), (KinesisPartitionGroupOffset) 
startMsgOffset, false, 0);
+    KinesisPartitionGroupOffset startOffset = (KinesisPartitionGroupOffset) 
startMsgOffset;
+    long deadlineMs = currentTimeMillis() + Math.max(timeoutMs, 0);
+    int attempts = 0;
+    KinesisRateLimitException lastRateLimitException = null;
+    while (true) {
+      if (lastRateLimitException != null && currentTimeMillis() >= deadlineMs) 
{
+        logRateLimitTimeout(startOffset, attempts, lastRateLimitException);
+        return new KinesisMessageBatch(List.of(), startOffset, false, 0);
+      }
+      try {
+        return getKinesisMessageBatch(startOffset, deadlineMs);
+      } catch (KinesisRateLimitException e) {
+        lastRateLimitException = e;
+        attempts++;
+        long remainingMs = deadlineMs - currentTimeMillis();
+        if (remainingMs <= 0) {
+          logRateLimitTimeout(startOffset, attempts, e);
+          return new KinesisMessageBatch(List.of(), startOffset, false, 0);
+        }
+        long backoffMs = Math.min(computeRateLimitBackoffMs(attempts), 
remainingMs);
+        LOGGER.warn("Rate limit exceeded while fetching messages from Kinesis 
stream: {}, shard: {}, operation: {}, "
+                + "threshold: {}, attempt: {}, backing off for {} ms. Error: 
{}", _config.getStreamTopicName(),
+            startOffset.getShardId(), e.getRequestType(), 
_config.getRpsLimitPerSecond(), attempts, backoffMs,
+            e.getCause().getMessage());
+        sleep(backoffMs);
+      } catch (KinesisRequestTimeoutException e) {
+        logRequestLimiterTimeout(startOffset, e);
+        return new KinesisMessageBatch(List.of(), startOffset, false, 0);
+      }
     }
   }
 
-  private KinesisMessageBatch 
getKinesisMessageBatch(KinesisPartitionGroupOffset startMsgOffset) {
+  private KinesisMessageBatch 
getKinesisMessageBatch(KinesisPartitionGroupOffset startMsgOffset, long 
deadlineMs) {
     KinesisPartitionGroupOffset startOffset = startMsgOffset;
     String shardId = startOffset.getShardId();
     String startSequenceNumber = startOffset.getSequenceNumber();
@@ -97,17 +136,20 @@ public class KinesisConsumer extends 
KinesisConnectionHandler implements Partiti
           
GetShardIteratorRequest.builder().streamName(_config.getStreamTopicName()).shardId(shardId)
               
.startingSequenceNumber(startSequenceNumber).shardIteratorType(ShardIteratorType.AFTER_SEQUENCE_NUMBER)
               .build();
-      shardIterator = 
_kinesisClient.getShardIterator(getShardIteratorRequest).shardIterator();
+      shardIterator = executeKinesisRequest(shardId, 
RequestType.GET_SHARD_ITERATOR, deadlineMs,
+          () -> 
_kinesisClient.getShardIterator(getShardIteratorRequest)).shardIterator();
     }
     if (shardIterator == null) {
       return new KinesisMessageBatch(List.of(), startOffset, true, 0);
     }
+    _nextStartSequenceNumber = startSequenceNumber;
+    _nextShardIterator = shardIterator;
 
     // Read records
-    rateLimitRequests();
     GetRecordsRequest getRecordRequest =
         
GetRecordsRequest.builder().shardIterator(shardIterator).limit(_config.getNumMaxRecordsToFetch()).build();
-    GetRecordsResponse getRecordsResponse = 
_kinesisClient.getRecords(getRecordRequest);
+    GetRecordsResponse getRecordsResponse = executeKinesisRequest(shardId, 
RequestType.GET_RECORDS, deadlineMs,
+        () -> _kinesisClient.getRecords(getRecordRequest));
 
     List<Record> records = getRecordsResponse.records();
     List<BytesStreamMessage> messages;
@@ -120,7 +162,8 @@ public class KinesisConsumer extends 
KinesisConnectionHandler implements Partiti
         batchSizeInBytes += bytesStreamMessage.getLength();
         messages.add(bytesStreamMessage);
       }
-      offsetOfNextBatch = (KinesisPartitionGroupOffset) 
messages.get(messages.size() - 1).getMetadata().getNextOffset();
+      offsetOfNextBatch =
+          (KinesisPartitionGroupOffset) messages.get(messages.size() - 
1).getMetadata().getNextOffset();
     } else {
       // TODO: Revisit whether Kinesis can return empty batch when there are 
available records. The consumer cna handle
       //       empty message batch, but it will treat it as fully caught up.
@@ -133,31 +176,65 @@ public class KinesisConsumer extends 
KinesisConnectionHandler implements Partiti
     return new KinesisMessageBatch(messages, offsetOfNextBatch, 
_nextShardIterator == null, batchSizeInBytes);
   }
 
-  /**
-   * Kinesis enforces a limit of 5 getRecords request per second on each shard 
from AWS end, beyond which we start
-   * getting {@link ProvisionedThroughputExceededException}. Rate limit the 
requests to avoid this.
-   */
-  private void rateLimitRequests() {
-    long currentTimeMs = System.currentTimeMillis();
-    int currentTimeSeconds = (int) 
TimeUnit.MILLISECONDS.toSeconds(currentTimeMs);
-    if (currentTimeSeconds == _currentSecond) {
-      if (_numRequestsInCurrentSecond == _config.getRpsLimit()) {
-        try {
-          Thread.sleep(1000 - (currentTimeMs % 1000));
-        } catch (InterruptedException e) {
-          throw new RuntimeException(e);
-        }
-        _currentSecond = (int) 
TimeUnit.MILLISECONDS.toSeconds(System.currentTimeMillis());
-        _numRequestsInCurrentSecond = 1;
-      } else {
-        _numRequestsInCurrentSecond++;
-      }
-    } else {
-      _currentSecond = currentTimeSeconds;
-      _numRequestsInCurrentSecond = 1;
+  private <T> T executeKinesisRequest(String shardId, RequestType requestType, 
long deadlineMs,
+      Supplier<T> requestSupplier) {
+    long remainingMs = deadlineMs - currentTimeMillis();
+    if (remainingMs <= 0
+        || !_requestRateLimiter.tryAcquire(_config.getStreamTopicName(), 
shardId, requestType,
+            _config.getRpsLimitPerSecond(), remainingMs)) {
+      throw new KinesisRequestTimeoutException(requestType);
+    }
+    try {
+      return requestSupplier.get();
+    } catch (ProvisionedThroughputExceededException pte) {
+      throw new KinesisRateLimitException(requestType, pte);
+    }
+  }
+
+  private long computeRateLimitBackoffMs(int attempts) {
+    long baseBackoffMs = INITIAL_RATE_LIMIT_BACKOFF_MS * (1L << 
Math.min(attempts - 1, 20));
+    long cappedBaseBackoffMs = Math.min(baseBackoffMs, 
MAX_RATE_LIMIT_BACKOFF_MS);
+    long jitterMs = 
getRateLimitBackoffJitterMs(RATE_LIMIT_BACKOFF_JITTER_BOUND_MS);
+    return Math.min(cappedBaseBackoffMs + jitterMs, MAX_RATE_LIMIT_BACKOFF_MS);
+  }
+
+  @VisibleForTesting
+  long currentTimeMillis() {
+    return System.currentTimeMillis();
+  }
+
+  @VisibleForTesting
+  long getRateLimitBackoffJitterMs(long maxJitterMs) {
+    return ThreadLocalRandom.current().nextLong(maxJitterMs + 1);
+  }
+
+  @VisibleForTesting
+  void sleep(long backoffMs) {
+    try {
+      Thread.sleep(backoffMs);
+    } catch (InterruptedException e) {
+      Thread.currentThread().interrupt();
+      throw new RuntimeException("Interrupted while backing off after Kinesis 
rate limit exceeded", e);
     }
   }
 
+  private void logRateLimitTimeout(KinesisPartitionGroupOffset startOffset, 
int attempts,
+      KinesisRateLimitException rateLimitException) {
+    LOGGER.warn("Rate limit exceeded while fetching messages from Kinesis 
stream: {}, shard: {}, operation: {}, "
+            + "threshold: {}, attempts: {}. Fetch timeout exhausted; returning 
empty batch at original offset. "
+            + "Error: {}",
+        _config.getStreamTopicName(), startOffset.getShardId(), 
rateLimitException.getRequestType(),
+        _config.getRpsLimitPerSecond(), attempts, 
rateLimitException.getCause().getMessage());
+  }
+
+  private void logRequestLimiterTimeout(KinesisPartitionGroupOffset 
startOffset,
+      KinesisRequestTimeoutException timeoutException) {
+    LOGGER.warn("Timed out waiting for Kinesis request limiter while fetching 
messages from stream: {}, shard: {}, "
+            + "operation: {}, threshold: {}. Fetch timeout exhausted; 
returning empty batch at original offset.",
+        _config.getStreamTopicName(), startOffset.getShardId(), 
timeoutException.getRequestType(),
+        _config.getRpsLimitPerSecond());
+  }
+
   private BytesStreamMessage extractStreamMessage(Record record, String 
shardId) {
     byte[] key = record.partitionKey().getBytes(StandardCharsets.UTF_8);
     byte[] value = record.data().asByteArray();
@@ -180,4 +257,121 @@ public class KinesisConsumer extends 
KinesisConnectionHandler implements Partiti
   public void close() {
     super.close();
   }
+
+  enum RequestType {
+    GET_RECORDS,
+    GET_SHARD_ITERATOR
+  }
+
+  @VisibleForTesting
+  interface RequestRateLimiter {
+    boolean tryAcquire(String streamName, String shardId, RequestType 
requestType, double rpsLimit, long timeoutMs);
+  }
+
+  private static class KinesisRateLimitException extends RuntimeException {
+    private final RequestType _requestType;
+
+    KinesisRateLimitException(RequestType requestType, 
ProvisionedThroughputExceededException cause) {
+      super(cause);
+      _requestType = requestType;
+    }
+
+    RequestType getRequestType() {
+      return _requestType;
+    }
+
+    @Override
+    public synchronized ProvisionedThroughputExceededException getCause() {
+      return (ProvisionedThroughputExceededException) super.getCause();
+    }
+  }
+
+  private static class KinesisRequestTimeoutException extends RuntimeException 
{
+    private final RequestType _requestType;
+
+    KinesisRequestTimeoutException(RequestType requestType) {
+      _requestType = requestType;
+    }
+
+    RequestType getRequestType() {
+      return _requestType;
+    }
+  }
+
+  /**
+   * Shared per-JVM request limiter for Kinesis read operations.
+   * <p>
+   * This class is thread-safe. Limiters are keyed by stream, shard, and 
operation so multiple consumers on the same
+   * server share a single smooth request budget for the same AWS shard 
operation.
+   */
+  @VisibleForTesting
+  static class SharedKinesisRequestRateLimiter implements RequestRateLimiter {
+    private static final int RATE_LIMITER_EXPIRATION_HOURS = 1;
+    private final Cache<RequestRateLimiterKey, RateLimiter> _rateLimiters;
+
+    SharedKinesisRequestRateLimiter() {
+      
this(CacheBuilder.newBuilder().expireAfterAccess(RATE_LIMITER_EXPIRATION_HOURS, 
TimeUnit.HOURS).build());
+    }
+
+    @VisibleForTesting
+    SharedKinesisRequestRateLimiter(Cache<RequestRateLimiterKey, RateLimiter> 
rateLimiters) {
+      _rateLimiters = rateLimiters;
+    }
+
+    @Override
+    public boolean tryAcquire(String streamName, String shardId, RequestType 
requestType, double rpsLimit,
+        long timeoutMs) {
+      if (timeoutMs <= 0) {
+        return false;
+      }
+      RequestRateLimiterKey key = new RequestRateLimiterKey(streamName, 
shardId, requestType);
+      RateLimiter rateLimiter = _rateLimiters.asMap().computeIfAbsent(key, 
ignored -> RateLimiter.create(rpsLimit));
+      if (rpsLimit < rateLimiter.getRate()) {
+        rateLimiter.setRate(rpsLimit);
+      }
+      return rateLimiter.tryAcquire(timeoutMs, TimeUnit.MILLISECONDS);
+    }
+
+    @VisibleForTesting
+    double getRateForTesting(String streamName, String shardId, RequestType 
requestType) {
+      RateLimiter rateLimiter =
+          _rateLimiters.getIfPresent(new RequestRateLimiterKey(streamName, 
shardId, requestType));
+      return rateLimiter == null ? 0.0 : rateLimiter.getRate();
+    }
+
+    @VisibleForTesting
+    long getLimiterCountForTesting() {
+      return _rateLimiters.size();
+    }
+  }
+
+  private static class RequestRateLimiterKey {
+    private final String _streamName;
+    private final String _shardId;
+    private final RequestType _requestType;
+
+    RequestRateLimiterKey(String streamName, String shardId, RequestType 
requestType) {
+      _streamName = streamName;
+      _shardId = shardId;
+      _requestType = requestType;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (!(o instanceof RequestRateLimiterKey)) {
+        return false;
+      }
+      RequestRateLimiterKey that = (RequestRateLimiterKey) o;
+      return _streamName.equals(that._streamName) && 
_shardId.equals(that._shardId)
+          && _requestType == that._requestType;
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(_streamName, _shardId, _requestType);
+    }
+  }
 }
diff --git 
a/pinot-plugins/pinot-stream-ingestion/pinot-kinesis/src/test/java/org/apache/pinot/plugin/stream/kinesis/KinesisConsumerTest.java
 
b/pinot-plugins/pinot-stream-ingestion/pinot-kinesis/src/test/java/org/apache/pinot/plugin/stream/kinesis/KinesisConsumerTest.java
index 31b56a0c780..418c1c8520f 100644
--- 
a/pinot-plugins/pinot-stream-ingestion/pinot-kinesis/src/test/java/org/apache/pinot/plugin/stream/kinesis/KinesisConsumerTest.java
+++ 
b/pinot-plugins/pinot-stream-ingestion/pinot-kinesis/src/test/java/org/apache/pinot/plugin/stream/kinesis/KinesisConsumerTest.java
@@ -33,6 +33,7 @@ import 
software.amazon.awssdk.services.kinesis.model.GetRecordsRequest;
 import software.amazon.awssdk.services.kinesis.model.GetRecordsResponse;
 import software.amazon.awssdk.services.kinesis.model.GetShardIteratorRequest;
 import software.amazon.awssdk.services.kinesis.model.GetShardIteratorResponse;
+import 
software.amazon.awssdk.services.kinesis.model.ProvisionedThroughputExceededException;
 import software.amazon.awssdk.services.kinesis.model.Record;
 import software.amazon.awssdk.services.kinesis.model.ShardIteratorType;
 
@@ -57,11 +58,18 @@ public class KinesisConsumerTest {
   private static final String PARTITION_KEY_PREFIX = "PARTITION_KEY-";
   private static final String PLACEHOLDER = "DUMMY";
   private static final int MAX_RECORDS_TO_FETCH = 20;
+  private static final double DOUBLE_COMPARISON_DELTA = 0.000001;
+  private static final KinesisConsumer.RequestRateLimiter 
NO_OP_REQUEST_RATE_LIMITER =
+      (streamName, shardId, requestType, rpsLimit, timeoutMs) -> true;
 
   private KinesisConfig _kinesisConfig;
   private List<Record> _records;
 
   private KinesisConfig getKinesisConfig() {
+    return getKinesisConfig(Map.of());
+  }
+
+  private KinesisConfig getKinesisConfig(Map<String, String> overrides) {
     Map<String, String> props = new HashMap<>();
     props.put(StreamConfigProperties.STREAM_TYPE, STREAM_TYPE);
     props.put(StreamConfigProperties.constructStreamProperty(STREAM_TYPE, 
StreamConfigProperties.STREAM_TOPIC_NAME),
@@ -73,6 +81,7 @@ public class KinesisConsumerTest {
     props.put(KinesisConfig.REGION, AWS_REGION);
     props.put(KinesisConfig.MAX_RECORDS_TO_FETCH, 
String.valueOf(MAX_RECORDS_TO_FETCH));
     props.put(KinesisConfig.SHARD_ITERATOR_TYPE, 
ShardIteratorType.AT_SEQUENCE_NUMBER.toString());
+    props.putAll(overrides);
     return new KinesisConfig(new StreamConfig(TABLE_NAME_WITH_TYPE, props));
   }
 
@@ -96,7 +105,7 @@ public class KinesisConsumerTest {
     when(kinesisClient.getRecords(any(GetRecordsRequest.class))).thenReturn(
         
GetRecordsResponse.builder().nextShardIterator(PLACEHOLDER).records(_records).build());
 
-    KinesisConsumer kinesisConsumer = new KinesisConsumer(_kinesisConfig, 
kinesisClient);
+    KinesisConsumer kinesisConsumer = newTestConsumer(kinesisClient);
 
     // Fetch first batch
     KinesisPartitionGroupOffset startOffset = new 
KinesisPartitionGroupOffset("0", "1");
@@ -130,7 +139,7 @@ public class KinesisConsumerTest {
     when(kinesisClient.getRecords(any(GetRecordsRequest.class))).thenReturn(
         
GetRecordsResponse.builder().nextShardIterator(null).records(_records).build());
 
-    KinesisConsumer kinesisConsumer = new KinesisConsumer(_kinesisConfig, 
kinesisClient);
+    KinesisConsumer kinesisConsumer = newTestConsumer(kinesisClient);
 
     // Fetch first batch
     KinesisPartitionGroupOffset startOffset = new 
KinesisPartitionGroupOffset("0", "1");
@@ -154,4 +163,218 @@ public class KinesisConsumerTest {
   public String baToString(byte[] bytes) {
     return SdkBytes.fromByteArray(bytes).asUtf8String();
   }
+
+  @Test
+  public void testRpsLimitConfig() {
+    KinesisConfig defaultConfig = getKinesisConfig();
+    assertEquals(defaultConfig.getRpsLimitPerSecond(), 1.0, 
DOUBLE_COMPARISON_DELTA);
+    assertEquals(defaultConfig.getRpsLimit(), 1);
+
+    KinesisConfig integerConfig = 
getKinesisConfig(Map.of(KinesisConfig.RPS_LIMIT, "2"));
+    assertEquals(integerConfig.getRpsLimitPerSecond(), 2.0, 
DOUBLE_COMPARISON_DELTA);
+    assertEquals(integerConfig.getRpsLimit(), 2);
+
+    KinesisConfig decimalConfig = 
getKinesisConfig(Map.of(KinesisConfig.RPS_LIMIT, "0.25"));
+    assertEquals(decimalConfig.getRpsLimitPerSecond(), 0.25, 
DOUBLE_COMPARISON_DELTA);
+    assertEquals(decimalConfig.getRpsLimit(), 1);
+
+    assertEquals(getKinesisConfig(Map.of(KinesisConfig.RPS_LIMIT, 
"0")).getRpsLimitPerSecond(), 1.0,
+        DOUBLE_COMPARISON_DELTA);
+    assertEquals(getKinesisConfig(Map.of(KinesisConfig.RPS_LIMIT, 
"-1")).getRpsLimitPerSecond(), 1.0,
+        DOUBLE_COMPARISON_DELTA);
+  }
+
+  @Test
+  public void testFetchRetriesGetRecordsRateLimitExceeded() {
+    KinesisClient kinesisClient = mock(KinesisClient.class);
+    
when(kinesisClient.getShardIterator(any(GetShardIteratorRequest.class))).thenReturn(
+        GetShardIteratorResponse.builder().shardIterator(PLACEHOLDER).build());
+    when(kinesisClient.getRecords(any(GetRecordsRequest.class)))
+        .thenThrow(rateLimitException())
+        .thenThrow(rateLimitException())
+        
.thenReturn(GetRecordsResponse.builder().nextShardIterator(PLACEHOLDER).records(_records).build());
+
+    TestKinesisConsumer kinesisConsumer =
+        new TestKinesisConsumer(_kinesisConfig, kinesisClient, 
NO_OP_REQUEST_RATE_LIMITER);
+
+    KinesisPartitionGroupOffset startOffset = new 
KinesisPartitionGroupOffset("0", "1");
+    KinesisMessageBatch kinesisMessageBatch = 
kinesisConsumer.fetchMessages(startOffset, 10_000);
+
+    assertEquals(kinesisMessageBatch.getMessageCount(), NUM_RECORDS);
+    assertFalse(kinesisMessageBatch.isEndOfPartitionGroup());
+    assertEquals(kinesisConsumer.getSleepMsList(), List.of(1000L, 2000L));
+    verify(kinesisClient, 
times(1)).getShardIterator(any(GetShardIteratorRequest.class));
+    verify(kinesisClient, times(3)).getRecords(any(GetRecordsRequest.class));
+  }
+
+  @Test
+  public void 
testFetchReturnsEmptyBatchWhenGetRecordsRateLimitExceedsTimeout() {
+    KinesisClient kinesisClient = mock(KinesisClient.class);
+    
when(kinesisClient.getShardIterator(any(GetShardIteratorRequest.class))).thenReturn(
+        GetShardIteratorResponse.builder().shardIterator(PLACEHOLDER).build());
+    when(kinesisClient.getRecords(any(GetRecordsRequest.class)))
+        .thenThrow(rateLimitException())
+        .thenThrow(rateLimitException());
+
+    TestKinesisConsumer kinesisConsumer =
+        new TestKinesisConsumer(_kinesisConfig, kinesisClient, 
NO_OP_REQUEST_RATE_LIMITER);
+
+    KinesisPartitionGroupOffset startOffset = new 
KinesisPartitionGroupOffset("0", "1");
+    KinesisMessageBatch kinesisMessageBatch = 
kinesisConsumer.fetchMessages(startOffset, 2500);
+
+    assertEquals(kinesisMessageBatch.getMessageCount(), 0);
+    assertEquals(kinesisMessageBatch.getOffsetOfNextBatch(), startOffset);
+    assertFalse(kinesisMessageBatch.isEndOfPartitionGroup());
+    assertEquals(kinesisConsumer.getSleepMsList(), List.of(1000L, 1500L));
+    verify(kinesisClient, 
times(1)).getShardIterator(any(GetShardIteratorRequest.class));
+    verify(kinesisClient, times(2)).getRecords(any(GetRecordsRequest.class));
+  }
+
+  @Test
+  public void testFetchRetriesGetShardIteratorRateLimitExceeded() {
+    KinesisClient kinesisClient = mock(KinesisClient.class);
+    when(kinesisClient.getShardIterator(any(GetShardIteratorRequest.class)))
+        .thenThrow(rateLimitException())
+        
.thenReturn(GetShardIteratorResponse.builder().shardIterator(PLACEHOLDER).build());
+    when(kinesisClient.getRecords(any(GetRecordsRequest.class))).thenReturn(
+        
GetRecordsResponse.builder().nextShardIterator(PLACEHOLDER).records(_records).build());
+
+    CapturingRequestRateLimiter requestRateLimiter = new 
CapturingRequestRateLimiter();
+    TestKinesisConsumer kinesisConsumer = new 
TestKinesisConsumer(_kinesisConfig, kinesisClient, requestRateLimiter);
+
+    KinesisPartitionGroupOffset startOffset = new 
KinesisPartitionGroupOffset("0", "1");
+    KinesisMessageBatch kinesisMessageBatch = 
kinesisConsumer.fetchMessages(startOffset, 10_000);
+
+    assertEquals(kinesisMessageBatch.getMessageCount(), NUM_RECORDS);
+    assertEquals(kinesisConsumer.getSleepMsList(), List.of(1000L));
+    assertEquals(requestRateLimiter.getRequestTypes(),
+        List.of(KinesisConsumer.RequestType.GET_SHARD_ITERATOR, 
KinesisConsumer.RequestType.GET_SHARD_ITERATOR,
+            KinesisConsumer.RequestType.GET_RECORDS));
+    verify(kinesisClient, 
times(2)).getShardIterator(any(GetShardIteratorRequest.class));
+    verify(kinesisClient, times(1)).getRecords(any(GetRecordsRequest.class));
+  }
+
+  @Test
+  public void 
testFetchReturnsEmptyBatchWhenRequestLimiterExceedsRemainingTimeout() {
+    KinesisClient kinesisClient = mock(KinesisClient.class);
+    
when(kinesisClient.getShardIterator(any(GetShardIteratorRequest.class))).thenReturn(
+        GetShardIteratorResponse.builder().shardIterator(PLACEHOLDER).build());
+
+    AdvancingRequestRateLimiter requestRateLimiter = new 
AdvancingRequestRateLimiter();
+    TestKinesisConsumer kinesisConsumer = new 
TestKinesisConsumer(_kinesisConfig, kinesisClient, requestRateLimiter);
+    requestRateLimiter.setKinesisConsumer(kinesisConsumer);
+
+    KinesisPartitionGroupOffset startOffset = new 
KinesisPartitionGroupOffset("0", "1");
+    KinesisMessageBatch kinesisMessageBatch = 
kinesisConsumer.fetchMessages(startOffset, 5000);
+
+    assertEquals(kinesisMessageBatch.getMessageCount(), 0);
+    assertEquals(kinesisMessageBatch.getOffsetOfNextBatch(), startOffset);
+    assertFalse(kinesisMessageBatch.isEndOfPartitionGroup());
+    assertEquals(requestRateLimiter.getRequestTypes(),
+        List.of(KinesisConsumer.RequestType.GET_SHARD_ITERATOR, 
KinesisConsumer.RequestType.GET_RECORDS));
+    assertEquals(requestRateLimiter.getTimeoutMsList(), List.of(5000L, 1000L));
+    verify(kinesisClient, 
times(1)).getShardIterator(any(GetShardIteratorRequest.class));
+    verify(kinesisClient, times(0)).getRecords(any(GetRecordsRequest.class));
+  }
+
+  @Test
+  public void testSharedLimiterUsesSameLimiterForSameStreamShardOperation() {
+    KinesisConsumer.SharedKinesisRequestRateLimiter requestRateLimiter =
+        new KinesisConsumer.SharedKinesisRequestRateLimiter();
+
+    assertTrue(requestRateLimiter.tryAcquire(STREAM_NAME, "0", 
KinesisConsumer.RequestType.GET_RECORDS, 1_000_000.0,
+        1000));
+    assertTrue(requestRateLimiter.tryAcquire(STREAM_NAME, "0", 
KinesisConsumer.RequestType.GET_RECORDS, 500_000.0,
+        1000));
+
+    assertEquals(requestRateLimiter.getLimiterCountForTesting(), 1);
+    assertEquals(requestRateLimiter.getRateForTesting(STREAM_NAME, "0", 
KinesisConsumer.RequestType.GET_RECORDS),
+        500_000.0, DOUBLE_COMPARISON_DELTA);
+
+    assertTrue(requestRateLimiter.tryAcquire(STREAM_NAME, "0", 
KinesisConsumer.RequestType.GET_SHARD_ITERATOR,
+        250_000.0, 1000));
+    assertEquals(requestRateLimiter.getLimiterCountForTesting(), 2);
+  }
+
+  private KinesisConsumer newTestConsumer(KinesisClient kinesisClient) {
+    return new TestKinesisConsumer(_kinesisConfig, kinesisClient, 
NO_OP_REQUEST_RATE_LIMITER);
+  }
+
+  private ProvisionedThroughputExceededException rateLimitException() {
+    return 
ProvisionedThroughputExceededException.builder().message("throttled").build();
+  }
+
+  private static class TestKinesisConsumer extends KinesisConsumer {
+    private final List<Long> _sleepMsList = new ArrayList<>();
+    private long _currentTimeMs;
+
+    TestKinesisConsumer(KinesisConfig config, KinesisClient kinesisClient,
+        KinesisConsumer.RequestRateLimiter requestRateLimiter) {
+      super(config, kinesisClient, requestRateLimiter);
+    }
+
+    @Override
+    void sleep(long sleepMs) {
+      _sleepMsList.add(sleepMs);
+      _currentTimeMs += sleepMs;
+    }
+
+    @Override
+    long currentTimeMillis() {
+      return _currentTimeMs;
+    }
+
+    @Override
+    long getRateLimitBackoffJitterMs(long maxJitterMs) {
+      return 0L;
+    }
+
+    List<Long> getSleepMsList() {
+      return _sleepMsList;
+    }
+
+    void advanceTimeMs(long timeMs) {
+      _currentTimeMs += timeMs;
+    }
+  }
+
+  private static class CapturingRequestRateLimiter implements 
KinesisConsumer.RequestRateLimiter {
+    private final List<KinesisConsumer.RequestType> _requestTypes = new 
ArrayList<>();
+    private final List<Long> _timeoutMsList = new ArrayList<>();
+
+    @Override
+    public boolean tryAcquire(String streamName, String shardId, 
KinesisConsumer.RequestType requestType,
+        double rpsLimit, long timeoutMs) {
+      _requestTypes.add(requestType);
+      _timeoutMsList.add(timeoutMs);
+      return true;
+    }
+
+    List<KinesisConsumer.RequestType> getRequestTypes() {
+      return _requestTypes;
+    }
+
+    List<Long> getTimeoutMsList() {
+      return _timeoutMsList;
+    }
+  }
+
+  private static class AdvancingRequestRateLimiter extends 
CapturingRequestRateLimiter {
+    private TestKinesisConsumer _kinesisConsumer;
+
+    void setKinesisConsumer(TestKinesisConsumer kinesisConsumer) {
+      _kinesisConsumer = kinesisConsumer;
+    }
+
+    @Override
+    public boolean tryAcquire(String streamName, String shardId, 
KinesisConsumer.RequestType requestType,
+        double rpsLimit, long timeoutMs) {
+      super.tryAcquire(streamName, shardId, requestType, rpsLimit, timeoutMs);
+      if (requestType == KinesisConsumer.RequestType.GET_SHARD_ITERATOR) {
+        _kinesisConsumer.advanceTimeMs(4000);
+        return true;
+      }
+      return false;
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to