This is an automated email from the ASF dual-hosted git repository.
lhotari pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/pulsar.git
The following commit(s) were added to refs/heads/branch-4.1 by this push:
new b2ca58adaf1 [fix][broker] Guard AsyncTokenBucket against long overflow (#25262)
b2ca58adaf1 is described below
commit b2ca58adaf1e99a9466fe8494f9ff6a082c2545b
Author: Penghui Li <[email protected]>
AuthorDate: Thu Feb 26 03:11:36 2026 -0800
[fix][broker] Guard AsyncTokenBucket against long overflow (#25262)
Co-authored-by: Lari Hotari <[email protected]>
(cherry picked from commit d207d9c09abaf00d9a0f0e66c6cc149e70e70198)
---
.../apache/pulsar/broker/qos/AsyncTokenBucket.java | 50 +++++++++++++++++++--
.../pulsar/broker/qos/AsyncTokenBucketTest.java | 51 +++++++++++++++++++++-
2 files changed, 96 insertions(+), 5 deletions(-)
diff --git a/pulsar-broker/src/main/java/org/apache/pulsar/broker/qos/AsyncTokenBucket.java b/pulsar-broker/src/main/java/org/apache/pulsar/broker/qos/AsyncTokenBucket.java
index e4feb24453a..f7fc0031ccd 100644
--- a/pulsar-broker/src/main/java/org/apache/pulsar/broker/qos/AsyncTokenBucket.java
+++ b/pulsar-broker/src/main/java/org/apache/pulsar/broker/qos/AsyncTokenBucket.java
@@ -19,6 +19,7 @@
package org.apache.pulsar.broker.qos;
+import java.math.BigInteger;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLongFieldUpdater;
import java.util.concurrent.atomic.LongAdder;
@@ -206,9 +207,10 @@ public abstract class AsyncTokenBucket {
long currentRatePeriodNanos = getRatePeriodNanos();
// new tokens is the amount of tokens that are created in the
duration since the last update
// with the configured rate
- newTokens = (durationNanos * currentRate) / currentRatePeriodNanos;
+ newTokens = safeMulDivFloor(durationNanos, currentRate,
currentRatePeriodNanos);
// carry forward the remainder nanos so that the rounding error is
eliminated
- long remainderNanos = durationNanos - ((newTokens *
currentRatePeriodNanos) / currentRate);
+ long consumedNanos = safeMulDivFloor(newTokens,
currentRatePeriodNanos, currentRate);
+ long remainderNanos = durationNanos >= consumedNanos ?
durationNanos - consumedNanos : 0;
if (remainderNanos > 0) {
REMAINDER_NANOS_UPDATER.addAndGet(this, remainderNanos);
}
@@ -263,13 +265,53 @@ public abstract class AsyncTokenBucket {
*/
public long calculateThrottlingDuration(long requiredTokens) {
long currentTokens = consumeTokensAndMaybeUpdateTokensBalance(0);
+
if (currentTokens >= requiredTokens) {
return 0L;
}
// when currentTokens is negative, subtracting a negative value
results in
// adding the absolute value (-(-x) -> +x)
- long needTokens = requiredTokens - currentTokens;
- return (needTokens * getRatePeriodNanos()) / getRate();
+ long needTokens;
+ try {
+ needTokens = Math.subtractExact(requiredTokens, currentTokens);
+ } catch (ArithmeticException e) {
+ // requiredTokens - currentTokens overflowed a signed long (only possible
+ // when currentTokens is negative); saturate instead of wrapping around
+ needTokens = Long.MAX_VALUE;
+ }
+ // safeMulDivFloor caps its result at Long.MAX_VALUE instead of overflowing,
+ // so an enormous deficit yields the maximum duration rather than a negative one
+ return safeMulDivFloor(needTokens, getRatePeriodNanos(), getRate());
+ }
+
+ /**
+ * Computes {@code floor(multiplicand * multiplier / divisor)} without
+ * overflowing the intermediate product, saturating at {@code Long.MAX_VALUE}
+ * when the quotient does not fit in a signed long. All operands are
+ * non-negative here, so truncating integer division equals floor.
+ *
+ * @param multiplicand first factor, must be >= 0
+ * @param multiplier second factor, must be >= 0
+ * @param divisor divisor, must be > 0
+ * @return the floored quotient, or {@code Long.MAX_VALUE} if it exceeds the long range
+ * @throws IllegalArgumentException if a factor is negative or the divisor is not positive
+ */
+ private static long safeMulDivFloor(long multiplicand, long multiplier,
long divisor) {
+ if (multiplicand < 0 || multiplier < 0) {
+ throw new IllegalArgumentException("multiplicand and multiplier
must be >= 0");
+ }
+ if (divisor <= 0) {
+ throw new IllegalArgumentException("divisor must be > 0");
+ }
+ if (multiplicand == 0 || multiplier == 0) {
+ return 0;
+ }
+ // Fast path
+ // Check if multiplication fits in a 64-bit value
+ // Math.multiplyHigh is intrinsified by the JVM (single mulq/mul
instruction),
+ // avoiding the cost of a division-based overflow check.
+ // It returns the upper 64 bits of the full 128-bit multiplication
result.
+ // When the result is 0, the product fits in 64 bits.
+ if (Math.multiplyHigh(multiplicand, multiplier) == 0) {
+ long product = multiplicand * multiplier;
+ if (product >= 0) {
+ // product fits in signed 64-bit
+ return product / divisor;
+ }
+ // product is in [2^63, 2^64): fits unsigned but not signed
+ long result = Long.divideUnsigned(product, divisor);
+ // cap at Long.MAX_VALUE if result itself overflows signed long
+ return result >= 0 ? result : Long.MAX_VALUE;
+ }
+ // Fallback to BigInteger division
+ BigInteger result = BigInteger.valueOf(multiplicand)
+ .multiply(BigInteger.valueOf(multiplier))
+ .divide(BigInteger.valueOf(divisor));
+ // the quotient is non-negative, so bitLength() < Long.SIZE means it fits in a signed long
+ return result.bitLength() < Long.SIZE ? result.longValue() :
Long.MAX_VALUE;
}
/**
diff --git a/pulsar-broker/src/test/java/org/apache/pulsar/broker/qos/AsyncTokenBucketTest.java b/pulsar-broker/src/test/java/org/apache/pulsar/broker/qos/AsyncTokenBucketTest.java
index 55ca9940541..f709cb65448 100644
--- a/pulsar-broker/src/test/java/org/apache/pulsar/broker/qos/AsyncTokenBucketTest.java
+++ b/pulsar-broker/src/test/java/org/apache/pulsar/broker/qos/AsyncTokenBucketTest.java
@@ -24,6 +24,7 @@ import static org.testng.Assert.assertEquals;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
@@ -195,4 +196,52 @@ public class AsyncTokenBucketTest {
// iteration, the tokens should be equal to the initial tokens
.isEqualTo(initialTokens);
}
-}
\ No newline at end of file
+
+ // Token rates ranging from below to above the point where
+ // durationNanos * rate exceeds Long.MAX_VALUE for a ~10 s duration,
+ // up to Long.MAX_VALUE itself.
+ @DataProvider(name = "largeRates")
+ public Object[][] largeRates() {
+ return new Object[][]{
+ {500_000_000L},
+ {980_000_000L},
+ {1_000_000_000L},
+ {1_500_000_000L},
+ {2_000_000_000L},
+ {Long.MAX_VALUE / 100L},
+ {Long.MAX_VALUE / 10L},
+ {Long.MAX_VALUE / 9L},
+ {Long.MAX_VALUE}
+ };
+ }
+
+ // Refilling over a 10 s rate period with a large rate must not overflow
+ // the intermediate token computation.
+ @Test(dataProvider = "largeRates")
+ void shouldRefillTokensWithoutOverflowForLargeRateAnd10sPeriod(long rate) {
+ long ratePeriodNanos = TimeUnit.SECONDS.toNanos(10);
+ asyncTokenBucket =
+ AsyncTokenBucket.builder()
+ .rate(rate)
+ .ratePeriodNanos(ratePeriodNanos)
+ .addTokensResolutionNanos(ratePeriodNanos)
+ .initialTokens(0)
+ .clock(clockSource)
+ .build();
+
+ // advance slightly more than one full rate period (10 s + 1 ms)
+ incrementSeconds(10);
+ incrementMillis(1);
+
+ // expected balance equals rate; NOTE(review): presumably limited by the
+ // bucket capacity, which this builder does not set — confirm the default
+ assertEquals(asyncTokenBucket.getTokens(), rate);
+ }
+
+ // With 1 token per 10 s, ~1e9 needed tokens * 1e10 ns exceeds Long.MAX_VALUE,
+ // so the throttling duration must saturate at Long.MAX_VALUE instead of wrapping.
+ @Test
+ void shouldCalculateThrottlingDurationWithoutOverflowForLargeNeedTokens() {
+ asyncTokenBucket =
+ AsyncTokenBucket.builder()
+ .rate(1)
+ .ratePeriodNanos(TimeUnit.SECONDS.toNanos(10))
+ .initialTokens(0)
+ .clock(clockSource)
+ .build();
+ // consume from an empty bucket, presumably leaving a negative balance (-1)
+ asyncTokenBucket.consumeTokens(1);
+
+ long throttlingDuration =
asyncTokenBucket.calculateThrottlingDuration(1_000_000_000L);
+ assertEquals(throttlingDuration, Long.MAX_VALUE);
+ }
+}