This is an automated email from the ASF dual-hosted git repository.

lhotari pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pulsar.git


The following commit(s) were added to refs/heads/master by this push:
     new d207d9c09ab [fix][broker] Guard AsyncTokenBucket against long overflow (#25262)
d207d9c09ab is described below

commit d207d9c09abaf00d9a0f0e66c6cc149e70e70198
Author: Penghui Li <[email protected]>
AuthorDate: Thu Feb 26 03:11:36 2026 -0800

    [fix][broker] Guard AsyncTokenBucket against long overflow (#25262)
    
    Co-authored-by: Lari Hotari <[email protected]>
---
 .../apache/pulsar/broker/qos/AsyncTokenBucket.java | 50 +++++++++++++++++++--
 .../pulsar/broker/qos/AsyncTokenBucketTest.java    | 51 +++++++++++++++++++++-
 2 files changed, 96 insertions(+), 5 deletions(-)

diff --git a/pulsar-broker/src/main/java/org/apache/pulsar/broker/qos/AsyncTokenBucket.java b/pulsar-broker/src/main/java/org/apache/pulsar/broker/qos/AsyncTokenBucket.java
index e4feb24453a..f7fc0031ccd 100644
--- a/pulsar-broker/src/main/java/org/apache/pulsar/broker/qos/AsyncTokenBucket.java
+++ b/pulsar-broker/src/main/java/org/apache/pulsar/broker/qos/AsyncTokenBucket.java
@@ -19,6 +19,7 @@
 
 package org.apache.pulsar.broker.qos;
 
+import java.math.BigInteger;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicLongFieldUpdater;
 import java.util.concurrent.atomic.LongAdder;
@@ -206,9 +207,10 @@ public abstract class AsyncTokenBucket {
             long currentRatePeriodNanos = getRatePeriodNanos();
             // new tokens is the amount of tokens that are created in the 
duration since the last update
             // with the configured rate
-            newTokens = (durationNanos * currentRate) / currentRatePeriodNanos;
+            newTokens = safeMulDivFloor(durationNanos, currentRate, 
currentRatePeriodNanos);
             // carry forward the remainder nanos so that the rounding error is 
eliminated
-            long remainderNanos = durationNanos - ((newTokens * 
currentRatePeriodNanos) / currentRate);
+            long consumedNanos = safeMulDivFloor(newTokens, 
currentRatePeriodNanos, currentRate);
+            long remainderNanos = durationNanos >= consumedNanos ? 
durationNanos - consumedNanos : 0;
             if (remainderNanos > 0) {
                 REMAINDER_NANOS_UPDATER.addAndGet(this, remainderNanos);
             }
@@ -263,13 +265,53 @@ public abstract class AsyncTokenBucket {
      */
     public long calculateThrottlingDuration(long requiredTokens) {
         long currentTokens = consumeTokensAndMaybeUpdateTokensBalance(0);
+
         if (currentTokens >= requiredTokens) {
             return 0L;
         }
         // when currentTokens is negative, subtracting a negative value 
results in
         // adding the absolute value (-(-x) -> +x)
-        long needTokens = requiredTokens - currentTokens;
-        return (needTokens * getRatePeriodNanos()) / getRate();
+        long needTokens;
+        try {
+            needTokens = Math.subtractExact(requiredTokens, currentTokens);
+        } catch (ArithmeticException e) {
+            needTokens = Long.MAX_VALUE;
+        }
+        return safeMulDivFloor(needTokens, getRatePeriodNanos(), getRate());
+    }
+
+    private static long safeMulDivFloor(long multiplicand, long multiplier, 
long divisor) {
+        if (multiplicand < 0 || multiplier < 0) {
+            throw new IllegalArgumentException("multiplicand and multiplier 
must be >= 0");
+        }
+        if (divisor <= 0) {
+            throw new IllegalArgumentException("divisor must be > 0");
+        }
+        if (multiplicand == 0 || multiplier == 0) {
+            return 0;
+        }
+        // Fast path
+        // Check if multiplication fits in a 64-bit value
+        // Math.multiplyHigh is intrinsified by the JVM (single mulq/mul 
instruction),
+        // avoiding the cost of a division-based overflow check.
+        // It returns the upper 64 bits of the full 128-bit multiplication 
result.
+        // When the result is 0, the product fits in 64 bits.
+        if (Math.multiplyHigh(multiplicand, multiplier) == 0) {
+            long product = multiplicand * multiplier;
+            if (product >= 0) {
+                // product fits in signed 64-bit
+                return product / divisor;
+            }
+            // product is in [2^63, 2^64): fits unsigned but not signed
+            long result = Long.divideUnsigned(product, divisor);
+            // cap at Long.MAX_VALUE if result itself overflows signed long
+            return result >= 0 ? result : Long.MAX_VALUE;
+        }
+        // Fallback to BigInteger division
+        BigInteger result = BigInteger.valueOf(multiplicand)
+                .multiply(BigInteger.valueOf(multiplier))
+                .divide(BigInteger.valueOf(divisor));
+        return result.bitLength() < Long.SIZE ? result.longValue() : 
Long.MAX_VALUE;
     }
 
     /**
diff --git a/pulsar-broker/src/test/java/org/apache/pulsar/broker/qos/AsyncTokenBucketTest.java b/pulsar-broker/src/test/java/org/apache/pulsar/broker/qos/AsyncTokenBucketTest.java
index 55ca9940541..f709cb65448 100644
--- a/pulsar-broker/src/test/java/org/apache/pulsar/broker/qos/AsyncTokenBucketTest.java
+++ b/pulsar-broker/src/test/java/org/apache/pulsar/broker/qos/AsyncTokenBucketTest.java
@@ -24,6 +24,7 @@ import static org.testng.Assert.assertEquals;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicLong;
 import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.DataProvider;
 import org.testng.annotations.Test;
 
 
@@ -195,4 +196,52 @@ public class AsyncTokenBucketTest {
                 // iteration, the tokens should be equal to the initial tokens
                 .isEqualTo(initialTokens);
     }
-}
\ No newline at end of file
+
+    @DataProvider(name = "largeRates")
+    public Object[][] largeRates() {
+        return new Object[][]{
+                {500_000_000L},
+                {980_000_000L},
+                {1_000_000_000L},
+                {1_500_000_000L},
+                {2_000_000_000L},
+                {Long.MAX_VALUE / 100L},
+                {Long.MAX_VALUE / 10L},
+                {Long.MAX_VALUE / 9L},
+                {Long.MAX_VALUE}
+        };
+    }
+
+    @Test(dataProvider = "largeRates")
+    void shouldRefillTokensWithoutOverflowForLargeRateAnd10sPeriod(long rate) {
+        long ratePeriodNanos = TimeUnit.SECONDS.toNanos(10);
+        asyncTokenBucket =
+                AsyncTokenBucket.builder()
+                        .rate(rate)
+                        .ratePeriodNanos(ratePeriodNanos)
+                        .addTokensResolutionNanos(ratePeriodNanos)
+                        .initialTokens(0)
+                        .clock(clockSource)
+                        .build();
+
+        incrementSeconds(10);
+        incrementMillis(1);
+
+        assertEquals(asyncTokenBucket.getTokens(), rate);
+    }
+
+    @Test
+    void shouldCalculateThrottlingDurationWithoutOverflowForLargeNeedTokens() {
+        asyncTokenBucket =
+                AsyncTokenBucket.builder()
+                        .rate(1)
+                        .ratePeriodNanos(TimeUnit.SECONDS.toNanos(10))
+                        .initialTokens(0)
+                        .clock(clockSource)
+                        .build();
+        asyncTokenBucket.consumeTokens(1);
+
+        long throttlingDuration = 
asyncTokenBucket.calculateThrottlingDuration(1_000_000_000L);
+        assertEquals(throttlingDuration, Long.MAX_VALUE);
+    }
+}

Reply via email to