This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new a57ff95a4 [GLUTEN-6736] Phase 1: Use task-shared lock in ManagedReservationListener (#6741)
a57ff95a4 is described below

commit a57ff95a4110a5bc4d82e7aac7cd136952362d20
Author: Hongze Zhang <[email protected]>
AuthorDate: Thu Aug 8 10:06:01 2024 +0800

    [GLUTEN-6736] Phase 1: Use task-shared lock in ManagedReservationListener (#6741)
---
 .../gluten/memory/SimpleMemoryUsageRecorder.java   |   8 +-
 .../gluten/memory/memtarget/MemoryTargets.java     |   2 +-
 .../gluten/memory/memtarget/OverAcquire.java       |  33 ++--
 .../gluten/memory/memtarget/TreeMemoryTargets.java |   6 +
 .../memtarget/spark/TreeMemoryConsumerTest.java    |  80 ++++++++++
 gluten-data/pom.xml                                |  46 ++++++
 .../listener/ManagedReservationListener.java       |  17 ++-
 .../memory/listener/ReservationListeners.java      |  44 +++---
 .../execution/MassiveMemoryAllocationSuite.scala   | 167 +++++++++++++++++++++
 9 files changed, 352 insertions(+), 51 deletions(-)

diff --git a/gluten-core/src/main/java/org/apache/gluten/memory/SimpleMemoryUsageRecorder.java b/gluten-core/src/main/java/org/apache/gluten/memory/SimpleMemoryUsageRecorder.java
index 16b260469..fb8b0d1e2 100644
--- a/gluten-core/src/main/java/org/apache/gluten/memory/SimpleMemoryUsageRecorder.java
+++ b/gluten-core/src/main/java/org/apache/gluten/memory/SimpleMemoryUsageRecorder.java
@@ -30,13 +30,13 @@ public class SimpleMemoryUsageRecorder implements MemoryUsageRecorder {
   @Override
   public void inc(long bytes) {
     final long total = this.current.addAndGet(bytes);
-    long prev_peak;
+    long prevPeak;
     do {
-      prev_peak = this.peak.get();
-      if (total <= prev_peak) {
+      prevPeak = this.peak.get();
+      if (total <= prevPeak) {
         break;
       }
-    } while (!this.peak.compareAndSet(prev_peak, total));
+    } while (!this.peak.compareAndSet(prevPeak, total));
   }
 
   // peak used bytes
diff --git a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java
index 75e3db2e7..bb1e7102b 100644
--- a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java
+++ b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java
@@ -50,7 +50,7 @@ public final class MemoryTargets {
     return memoryTarget;
   }
 
-  public static MemoryTarget newConsumer(
+  public static TreeMemoryTarget newConsumer(
       TaskMemoryManager tmm,
       String name,
       Spiller spiller,
diff --git a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/OverAcquire.java b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/OverAcquire.java
index ac82161ba..e7321b4b7 100644
--- a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/OverAcquire.java
+++ b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/OverAcquire.java
@@ -52,31 +52,28 @@ public class OverAcquire implements MemoryTarget {
 
   @Override
   public long borrow(long size) {
-    Preconditions.checkArgument(size != 0, "Size to borrow is zero");
+    if (size == 0) {
+      return 0;
+    }
+    Preconditions.checkState(overTarget.usedBytes() == 0);
     long granted = target.borrow(size);
     long majorSize = target.usedBytes();
-    long expectedOverAcquired = (long) (ratio * majorSize);
-    long overAcquired = overTarget.usedBytes();
-    long diff = expectedOverAcquired - overAcquired;
-    if (diff >= 0) { // otherwise, there might be a spill happened during the 
last borrow() call
-      overTarget.borrow(diff); // we don't have to check the returned value
-    }
+    long overSize = (long) (ratio * majorSize);
+    long overAcquired = overTarget.borrow(overSize);
+    Preconditions.checkState(overAcquired == overTarget.usedBytes());
+    long releasedOverSize = overTarget.repay(overAcquired);
+    Preconditions.checkState(releasedOverSize == overAcquired);
+    Preconditions.checkState(overTarget.usedBytes() == 0);
     return granted;
   }
 
   @Override
   public long repay(long size) {
-    Preconditions.checkArgument(size != 0, "Size to repay is zero");
-    long freed = target.repay(size);
-    // clean up the over-acquired target
-    long overAcquired = overTarget.usedBytes();
-    long freedOverAcquired = overTarget.repay(overAcquired);
-    Preconditions.checkArgument(
-        freedOverAcquired == overAcquired,
-        "Freed over-acquired size is not equal to requested size");
-    Preconditions.checkArgument(
-        overTarget.usedBytes() == 0, "Over-acquired target was not cleaned 
up");
-    return freed;
+    if (size == 0) {
+      return 0;
+    }
+    Preconditions.checkState(overTarget.usedBytes() == 0);
+    return target.repay(size);
   }
 
   @Override
diff --git a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/TreeMemoryTargets.java b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/TreeMemoryTargets.java
index 24d9fc0e2..98f79bfff 100644
--- a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/TreeMemoryTargets.java
+++ b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/TreeMemoryTargets.java
@@ -114,6 +114,9 @@ public class TreeMemoryTargets {
 
     @Override
     public long borrow(long size) {
+      if (size == 0) {
+        return 0;
+      }
       ensureFreeCapacity(size);
       return borrow0(Math.min(freeBytes(), size));
     }
@@ -154,6 +157,9 @@ public class TreeMemoryTargets {
 
     @Override
     public long repay(long size) {
+      if (size == 0) {
+        return 0;
+      }
       long toFree = Math.min(usedBytes(), size);
       long freed = parent.repay(toFree);
       selfRecorder.inc(-freed);
diff --git a/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java b/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java
index 1632e5ef4..bbc43ba5d 100644
--- a/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java
+++ b/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java
@@ -17,6 +17,8 @@
 package org.apache.gluten.memory.memtarget.spark;
 
 import org.apache.gluten.GlutenConfig;
+import org.apache.gluten.memory.memtarget.MemoryTarget;
+import org.apache.gluten.memory.memtarget.Spiller;
 import org.apache.gluten.memory.memtarget.Spillers;
 import org.apache.gluten.memory.memtarget.TreeMemoryTarget;
 
@@ -28,6 +30,8 @@ import org.junit.Before;
 import org.junit.Test;
 
 import java.util.Collections;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
 
 import scala.Function0;
 
@@ -100,6 +104,82 @@ public class TreeMemoryConsumerTest {
         });
   }
 
+  @Test
+  public void testSpill() {
+    test(
+        () -> {
+          final Spillers.AppendableSpillerList spillers = 
Spillers.appendable();
+          final TreeMemoryTarget shared =
+              TreeMemoryConsumers.shared()
+                  .newConsumer(
+                      TaskContext.get().taskMemoryManager(),
+                      "FOO",
+                      spillers,
+                      Collections.emptyMap());
+          final AtomicInteger numSpills = new AtomicInteger(0);
+          final AtomicLong numSpilledBytes = new AtomicLong(0L);
+          spillers.append(
+              new Spiller() {
+                @Override
+                public long spill(MemoryTarget self, Phase phase, long size) {
+                  long repaid = shared.repay(size);
+                  numSpills.getAndIncrement();
+                  numSpilledBytes.getAndAdd(repaid);
+                  return repaid;
+                }
+              });
+          Assert.assertEquals(300, shared.borrow(300));
+          Assert.assertEquals(300, shared.borrow(300));
+          Assert.assertEquals(1, numSpills.get());
+          Assert.assertEquals(200, numSpilledBytes.get());
+          Assert.assertEquals(400, shared.usedBytes());
+
+          Assert.assertEquals(300, shared.borrow(300));
+          Assert.assertEquals(300, shared.borrow(300));
+          Assert.assertEquals(3, numSpills.get());
+          Assert.assertEquals(800, numSpilledBytes.get());
+          Assert.assertEquals(400, shared.usedBytes());
+        });
+  }
+
+  @Test
+  public void testOverSpill() {
+    test(
+        () -> {
+          final Spillers.AppendableSpillerList spillers = 
Spillers.appendable();
+          final TreeMemoryTarget shared =
+              TreeMemoryConsumers.shared()
+                  .newConsumer(
+                      TaskContext.get().taskMemoryManager(),
+                      "FOO",
+                      spillers,
+                      Collections.emptyMap());
+          final AtomicInteger numSpills = new AtomicInteger(0);
+          final AtomicLong numSpilledBytes = new AtomicLong(0L);
+          spillers.append(
+              new Spiller() {
+                @Override
+                public long spill(MemoryTarget self, Phase phase, long size) {
+                  long repaid = shared.repay(Long.MAX_VALUE);
+                  numSpills.getAndIncrement();
+                  numSpilledBytes.getAndAdd(repaid);
+                  return repaid;
+                }
+              });
+          Assert.assertEquals(300, shared.borrow(300));
+          Assert.assertEquals(300, shared.borrow(300));
+          Assert.assertEquals(1, numSpills.get());
+          Assert.assertEquals(300, numSpilledBytes.get());
+          Assert.assertEquals(300, shared.usedBytes());
+
+          Assert.assertEquals(300, shared.borrow(300));
+          Assert.assertEquals(300, shared.borrow(300));
+          Assert.assertEquals(3, numSpills.get());
+          Assert.assertEquals(900, numSpilledBytes.get());
+          Assert.assertEquals(300, shared.usedBytes());
+        });
+  }
+
   private void test(Runnable r) {
     TaskResources$.MODULE$.runUnsafe(
         new Function0<Object>() {
diff --git a/gluten-data/pom.xml b/gluten-data/pom.xml
index 500708d44..ffb56db43 100644
--- a/gluten-data/pom.xml
+++ b/gluten-data/pom.xml
@@ -195,6 +195,52 @@
         </exclusion>
       </exclusions>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.scalatest</groupId>
+      <artifactId>scalatest_${scala.binary.version}</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-core</artifactId>
+      <version>2.23.4</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>junit</groupId>
+      <artifactId>junit</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.scalatestplus</groupId>
+      <artifactId>scalatestplus-mockito_${scala.binary.version}</artifactId>
+      <version>1.0.0-M2</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.scalatestplus</groupId>
+      <artifactId>scalatestplus-scalacheck_${scala.binary.version}</artifactId>
+      <version>3.1.0.0-RC2</version>
+      <scope>test</scope>
+    </dependency>
   </dependencies>
 
 </project>
diff --git a/gluten-data/src/main/java/org/apache/gluten/memory/listener/ManagedReservationListener.java b/gluten-data/src/main/java/org/apache/gluten/memory/listener/ManagedReservationListener.java
index 4af8eb4e3..7c7fac8da 100644
--- a/gluten-data/src/main/java/org/apache/gluten/memory/listener/ManagedReservationListener.java
+++ b/gluten-data/src/main/java/org/apache/gluten/memory/listener/ManagedReservationListener.java
@@ -19,7 +19,6 @@ package org.apache.gluten.memory.listener;
 import org.apache.gluten.memory.SimpleMemoryUsageRecorder;
 import org.apache.gluten.memory.memtarget.MemoryTarget;
 
-import com.google.common.base.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -29,16 +28,23 @@ public class ManagedReservationListener implements ReservationListener {
   private static final Logger LOG = 
LoggerFactory.getLogger(ManagedReservationListener.class);
 
   private final MemoryTarget target;
-  private final SimpleMemoryUsageRecorder sharedUsage; // shared task metrics
+  // Metrics shared by task.
+  private final SimpleMemoryUsageRecorder sharedUsage;
+  // Lock shared by task. Using a common lock avoids ABBA deadlock
+  // when multiple listeners created under the same TMM.
+  // See: https://github.com/apache/incubator-gluten/issues/6622
+  private final Object sharedLock;
 
-  public ManagedReservationListener(MemoryTarget target, 
SimpleMemoryUsageRecorder sharedUsage) {
+  public ManagedReservationListener(
+      MemoryTarget target, SimpleMemoryUsageRecorder sharedUsage, Object 
sharedLock) {
     this.target = target;
     this.sharedUsage = sharedUsage;
+    this.sharedLock = sharedLock;
   }
 
   @Override
   public long reserve(long size) {
-    synchronized (this) {
+    synchronized (sharedLock) {
       try {
         long granted = target.borrow(size);
         sharedUsage.inc(granted);
@@ -52,11 +58,10 @@ public class ManagedReservationListener implements ReservationListener {
 
   @Override
   public long unreserve(long size) {
-    synchronized (this) {
+    synchronized (sharedLock) {
       try {
         long freed = target.repay(size);
         sharedUsage.inc(-freed);
-        Preconditions.checkState(freed == size);
         return freed;
       } catch (Exception e) {
         LOG.error("Error unreserving memory from target", e);
diff --git a/gluten-data/src/main/java/org/apache/gluten/memory/listener/ReservationListeners.java b/gluten-data/src/main/java/org/apache/gluten/memory/listener/ReservationListeners.java
index 47b9937eb..db5ac8426 100644
--- a/gluten-data/src/main/java/org/apache/gluten/memory/listener/ReservationListeners.java
+++ b/gluten-data/src/main/java/org/apache/gluten/memory/listener/ReservationListeners.java
@@ -29,7 +29,8 @@ import java.util.Map;
 
 public final class ReservationListeners {
   public static final ReservationListener NOOP =
-      new ManagedReservationListener(new NoopMemoryTarget(), new 
SimpleMemoryUsageRecorder());
+      new ManagedReservationListener(
+          new NoopMemoryTarget(), new SimpleMemoryUsageRecorder(), new 
Object());
 
   public static ReservationListener create(
       String name, Spiller spiller, Map<String, MemoryUsageStatsBuilder> 
mutableStats) {
@@ -46,32 +47,31 @@ public final class ReservationListeners {
     final double overAcquiredRatio = 
GlutenConfig.getConf().memoryOverAcquiredRatio();
     final long reservationBlockSize = 
GlutenConfig.getConf().memoryReservationBlockSize();
     final TaskMemoryManager tmm = 
TaskResources.getLocalTaskContext().taskMemoryManager();
+    final TreeMemoryTarget consumer =
+        MemoryTargets.newConsumer(
+            tmm, name, Spillers.withMinSpillSize(spiller, 
reservationBlockSize), mutableStats);
+    final MemoryTarget overConsumer =
+        MemoryTargets.newConsumer(
+            tmm,
+            consumer.name() + ".OverAcquire",
+            new Spiller() {
+              @Override
+              public long spill(MemoryTarget self, Phase phase, long size) {
+                if (!Spillers.PHASE_SET_ALL.contains(phase)) {
+                  return 0L;
+                }
+                return self.repay(size);
+              }
+            },
+            Collections.emptyMap());
     final MemoryTarget target =
         MemoryTargets.throwOnOom(
             MemoryTargets.overAcquire(
-                MemoryTargets.dynamicOffHeapSizingIfEnabled(
-                    MemoryTargets.newConsumer(
-                        tmm,
-                        name,
-                        Spillers.withMinSpillSize(spiller, 
reservationBlockSize),
-                        mutableStats)),
-                MemoryTargets.dynamicOffHeapSizingIfEnabled(
-                    MemoryTargets.newConsumer(
-                        tmm,
-                        "OverAcquire.DummyTarget",
-                        new Spiller() {
-                          @Override
-                          public long spill(MemoryTarget self, Spiller.Phase 
phase, long size) {
-                            if (!Spillers.PHASE_SET_ALL.contains(phase)) {
-                              return 0L;
-                            }
-                            return self.repay(size);
-                          }
-                        },
-                        Collections.emptyMap())),
+                MemoryTargets.dynamicOffHeapSizingIfEnabled(consumer),
+                MemoryTargets.dynamicOffHeapSizingIfEnabled(overConsumer),
                 overAcquiredRatio));
 
     // Listener.
-    return new ManagedReservationListener(target, 
TaskResources.getSharedUsage());
+    return new ManagedReservationListener(target, 
TaskResources.getSharedUsage(), tmm);
   }
 }
diff --git a/gluten-data/src/test/scala/org/apache/gluten/execution/MassiveMemoryAllocationSuite.scala b/gluten-data/src/test/scala/org/apache/gluten/execution/MassiveMemoryAllocationSuite.scala
new file mode 100644
index 000000000..ebfa0e612
--- /dev/null
+++ b/gluten-data/src/test/scala/org/apache/gluten/execution/MassiveMemoryAllocationSuite.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.execution
+
+import org.apache.gluten.memory.MemoryUsageStatsBuilder
+import org.apache.gluten.memory.listener.{ReservationListener, 
ReservationListeners}
+import org.apache.gluten.memory.memtarget.{MemoryTarget, Spiller, Spillers}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.util.TaskResources
+
+import java.util.concurrent.{Callable, Executors, TimeUnit}
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.JavaConverters._
+import scala.util.Random
+
+class MassiveMemoryAllocationSuite extends SparkFunSuite with 
SharedSparkSession {
+  test("concurrent allocation with spill - shared listener") {
+    val numThreads = 50
+    val offHeapSize = 500
+    val minExtraSpillSize = 2
+    val maxExtraSpillSize = 5
+    val numAllocations = 100
+    val minAllocationSize = 40
+    val maxAllocationSize = 100
+    val minAllocationDelayMs = 0
+    val maxAllocationDelayMs = 0
+    withSQLConf("spark.memory.offHeap.size" -> s"$offHeapSize") {
+      val total = new AtomicLong(0L)
+      TaskResources.runUnsafe {
+        val spiller = Spillers.appendable()
+        val listener = ReservationListeners.create(
+          s"listener",
+          spiller,
+          Map[String, MemoryUsageStatsBuilder]().asJava)
+        spiller.append(new Spiller() {
+          override def spill(self: MemoryTarget, phase: Spiller.Phase, size: 
Long): Long = {
+            val extraSpillSize = randomInt(minExtraSpillSize, 
maxExtraSpillSize)
+            val spillSize = size + extraSpillSize
+            val released = listener.unreserve(spillSize)
+            assert(released <= spillSize)
+            total.getAndAdd(-released)
+            spillSize
+          }
+        })
+        val pool = Executors.newFixedThreadPool(numThreads)
+        val tasks = (0 until numThreads).map {
+          _ =>
+            new Callable[Unit]() {
+              override def call(): Unit = {
+                (0 until numAllocations).foreach {
+                  _ =>
+                    val allocSize =
+                      randomInt(minAllocationSize, maxAllocationSize)
+                    val granted = listener.reserve(allocSize)
+                    assert(granted == allocSize)
+                    total.getAndAdd(granted)
+                    val sleepMs =
+                      randomInt(minAllocationDelayMs, maxAllocationDelayMs)
+                    Thread.sleep(sleepMs)
+                }
+              }
+            }
+        }.toList
+        val futures = pool.invokeAll(tasks.asJava)
+        pool.shutdown()
+        pool.awaitTermination(60, TimeUnit.SECONDS)
+        futures.forEach(_.get())
+        val totalBytes = total.get()
+        val released = listener.unreserve(totalBytes)
+        assert(released == totalBytes)
+        assert(listener.getUsedBytes == 0)
+      }
+    }
+  }
+
+  test("concurrent allocation with spill - dedicated listeners") {
+    val numThreads = 50
+    val offHeapSize = 500
+    val minExtraSpillSize = 2
+    val maxExtraSpillSize = 5
+    val numAllocations = 100
+    val minAllocationSize = 40
+    val maxAllocationSize = 100
+    val minAllocationDelayMs = 0
+    val maxAllocationDelayMs = 0
+    withSQLConf("spark.memory.offHeap.size" -> s"$offHeapSize") {
+      TaskResources.runUnsafe {
+        val total = new AtomicLong(0L)
+
+        def newListener(id: Int): ReservationListener = {
+          val spiller = Spillers.appendable()
+          val listener = ReservationListeners.create(
+            s"listener $id",
+            spiller,
+            Map[String, MemoryUsageStatsBuilder]().asJava)
+          spiller.append(new Spiller() {
+            override def spill(self: MemoryTarget, phase: Spiller.Phase, size: 
Long): Long = {
+              val extraSpillSize = randomInt(minExtraSpillSize, 
maxExtraSpillSize)
+              val spillSize = size + extraSpillSize
+              val released = listener.unreserve(spillSize)
+              assert(released <= spillSize)
+              total.getAndAdd(-released)
+              spillSize
+            }
+          })
+          listener
+        }
+
+        val listeners = (0 until numThreads).map(newListener).toList
+        val pool = Executors.newFixedThreadPool(numThreads)
+        val tasks = (0 until numThreads).map {
+          i =>
+            new Callable[Unit]() {
+              override def call(): Unit = {
+                val listener = listeners(i)
+                (0 until numAllocations).foreach {
+                  _ =>
+                    val allocSize =
+                      randomInt(minAllocationSize, maxAllocationSize)
+                    val granted = listener.reserve(allocSize)
+                    assert(granted == allocSize)
+                    total.getAndAdd(granted)
+                    val sleepMs =
+                      randomInt(minAllocationDelayMs, maxAllocationDelayMs)
+                    Thread.sleep(sleepMs)
+                }
+              }
+            }
+        }.toList
+        val futures = pool.invokeAll(tasks.asJava)
+        pool.shutdown()
+        pool.awaitTermination(60, TimeUnit.SECONDS)
+        futures.forEach(_.get())
+        val totalBytes = total.get()
+        val remaining = listeners.foldLeft(totalBytes) {
+          case (remainingBytes, listener) =>
+            assert(remainingBytes >= 0)
+            val unreserved = listener.unreserve(remainingBytes)
+            remainingBytes - unreserved
+        }
+        assert(remaining == 0)
+        assert(listeners.map(_.getUsedBytes).sum == 0)
+      }
+    }
+  }
+
+  private def randomInt(from: Int, to: Int): Int = {
+    from + Random.nextInt(to - from + 1)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to