This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new d851b2e4b [#1571] fix(server): Memory may leak when 
`EventInvalidException` occurs (#1574)
d851b2e4b is described below

commit d851b2e4b903cbb63d7def8e63cb3795db38107c
Author: leslizhang <[email protected]>
AuthorDate: Mon Mar 18 15:26:37 2024 +0800

    [#1571] fix(server): Memory may leak when `EventInvalidException` occurs 
(#1574)
    
    ### What changes were proposed in this pull request?
    
    A per-app read-write lock now coordinates `flushBuffer`,
    `handleEventAndUpdateMetrics`, and `removeBufferByShuffleId`. This ensures that
    a `ShuffleBuffer` already converted into a `FlushEvent` won't be cleaned up
    again by `removeBufferByShuffleId`, and that a `ShuffleBuffer` already cleaned
    up by `removeBufferByShuffleId` won't be converted into a `FlushEvent`
    afterwards. This resolves the concurrency issue described below.
    
    ### Why are the changes needed?
    
    Fix https://github.com/apache/incubator-uniffle/issues/1571 & 
https://github.com/apache/incubator-uniffle/issues/1560 & 
https://github.com/apache/incubator-uniffle/issues/1542
    
    The key logic of the PR is as follows:
    
    Before this PR:
    1. A `ShuffleBuffer` is turned into a `FlushEvent`, and **_its blocks and 
size are cleared_**
    →
    2. The `FlushEvent` is added to the flushing queue
    →
    3. The method `removeBufferByShuffleId` is executed, which causes the 
following things to happen:
    
    3.1. Running the following code snippet, but please note that in the code 
below, `buffer.getBlocks()` **_is empty and size is 0_**, because of the step 1 
above:
    ```
    for (ShuffleBuffer buffer : buffers) {
      buffer.getBlocks().forEach(spb -> spb.getData().release());
      ShuffleServerMetrics.gaugeTotalPartitionNum.dec();
      size += buffer.getSize();
    }
    ```
    
    3.2. `appId` is removed from the `bufferPool`
    →
    4. The `FlushEvent` is taken out from the queue and encounters an 
`EventInvalidException` because the `appId` was removed before
    →
    5. The handler for `EventInvalidException` does nothing: `event.doCleanup()`
    **_is not called, causing a memory leak_**.
    
    This is only one of the possible races. In the previous code, which had no
    locking, an event could become invalid at any point while
    `processFlushEvent` was still running, which is what leads to
    https://github.com/apache/incubator-uniffle/issues/1542. See also
    https://github.com/apache/incubator-uniffle/issues/1560.
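    
    The sequence in steps 1 to 5 can be replayed in a single thread to make the
    leak visible. The following is a minimal sketch, not Uniffle code:
    `FlushLeakSketch`, `Buffer`, `Event`, and `usedMemory` are hypothetical
    stand-ins for the real classes and memory accounting, and the block sizes are
    arbitrary.
    
    ```java
    import java.util.ArrayList;
    import java.util.List;
    import java.util.concurrent.atomic.AtomicLong;

    /** Hypothetical single-threaded replay of steps 1-5; not Uniffle code. */
    final class FlushLeakSketch {
      /** Bytes the sketch server currently accounts as in use. */
      static final AtomicLong usedMemory = new AtomicLong();

      /** Stand-in for a ShuffleBuffer: owns block sizes until they are flushed. */
      static final class Buffer {
        final List<Long> blocks = new ArrayList<>();

        void add(long size) {
          blocks.add(size);
          usedMemory.addAndGet(size);
        }

        /** Step 1: hand the blocks over to an event, leaving the buffer empty. */
        Event toFlushEvent() {
          long size = blocks.stream().mapToLong(Long::longValue).sum();
          blocks.clear();
          return new Event(size);
        }

        /** Step 3.1: release whatever the buffer still owns. */
        void release() {
          blocks.forEach(size -> usedMemory.addAndGet(-size));
          blocks.clear();
        }
      }

      /** Stand-in for a flush event, which owns the memory after step 1. */
      static final class Event {
        final long size;

        Event(long size) {
          this.size = size;
        }

        void doCleanup() {
          usedMemory.addAndGet(-size);
        }
      }

      /** Replays the sequence and returns the bytes still accounted afterwards. */
      static long replay(boolean cleanupOnInvalidEvent) {
        usedMemory.set(0);
        Buffer buffer = new Buffer();
        buffer.add(64);
        buffer.add(36);

        Event event = buffer.toFlushEvent(); // steps 1-2: the event owns 100 bytes
        buffer.release();                    // step 3.1: buffer is empty, frees nothing
        boolean appStillRegistered = false;  // step 3.2: appId was removed

        if (!appStillRegistered && cleanupOnInvalidEvent) {
          event.doCleanup();                 // steps 4-5 with the fix applied
        }
        return usedMemory.get();
      }

      public static void main(String[] args) {
        System.out.println("leaked without cleanup: " + replay(false)); // 100
        System.out.println("leaked with cleanup: " + replay(true));     // 0
      }
    }
    ```
    
    Without the cleanup call the 100 bytes stay accounted forever, because the
    buffer no longer owns them and the invalid event never releases them.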
    
    ---
    
    After this PR:
    Steps 1 and 2 above run under a read lock, step 3 under a write lock, and
    step 4 under a read lock. When step 5 encounters an `EventInvalidException`,
    `event.doCleanup()` is called to release the memory.
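    
    The locking scheme can be sketched as follows. This is an illustration under
    stated assumptions, not the code in this commit: `AppLockSketch` and its
    methods are hypothetical, the locks live in a plain `ConcurrentHashMap`
    rather than the cache used by `ShuffleTaskManager`, and memory is modelled
    as a single counter.
    
    ```java
    import java.util.Map;
    import java.util.concurrent.ConcurrentHashMap;
    import java.util.concurrent.atomic.AtomicLong;
    import java.util.concurrent.locks.Lock;
    import java.util.concurrent.locks.ReentrantReadWriteLock;

    /** Hypothetical sketch of per-app read/write locking; names are illustrative. */
    final class AppLockSketch {
      private final Map<String, ReentrantReadWriteLock> appLocks = new ConcurrentHashMap<>();
      private final Map<String, Long> bufferedBytes = new ConcurrentHashMap<>();
      final AtomicLong usedMemory = new AtomicLong();

      private ReentrantReadWriteLock lockFor(String appId) {
        return appLocks.computeIfAbsent(appId, id -> new ReentrantReadWriteLock());
      }

      /** Buffers incoming data for an app and accounts for its memory. */
      void cache(String appId, long bytes) {
        bufferedBytes.merge(appId, bytes, Long::sum);
        usedMemory.addAndGet(bytes);
      }

      /** Steps 1-2 under the read lock: returns the event size, or -1 if the app is gone. */
      long toFlushEvent(String appId) {
        Lock readLock = lockFor(appId).readLock();
        readLock.lock();
        try {
          // replace() atomically empties the buffer and returns what it held.
          Long bytes = bufferedBytes.replace(appId, 0L);
          return bytes == null ? -1 : bytes; // returning inside try still unlocks
        } finally {
          readLock.unlock();
        }
      }

      /** Step 3 under the write lock: releases whatever the buffer still owns. */
      void removeApp(String appId) {
        Lock writeLock = lockFor(appId).writeLock();
        writeLock.lock();
        try {
          Long bytes = bufferedBytes.remove(appId);
          if (bytes != null) {
            usedMemory.addAndGet(-bytes);
          }
        } finally {
          writeLock.unlock();
        }
      }

      /** Steps 4-5 under the read lock: returns false if the event was invalid. */
      boolean handleEvent(String appId, long eventBytes) {
        Lock readLock = lockFor(appId).readLock();
        boolean valid;
        readLock.lock();
        try {
          valid = bufferedBytes.containsKey(appId);
          // A valid event would be written to storage here (omitted).
        } finally {
          readLock.unlock();
        }
        usedMemory.addAndGet(-eventBytes); // cleanup runs for valid and invalid events
        return valid;
      }

      public static void main(String[] args) {
        AppLockSketch server = new AppLockSketch();
        server.cache("app-1", 100);
        long eventBytes = server.toFlushEvent("app-1");
        server.removeApp("app-1");
        boolean flushed = server.handleEvent("app-1", eventBytes);
        System.out.println("flushed=" + flushed + ", usedMemory=" + server.usedMemory.get());
      }
    }
    ```
    
    Acquiring the lock immediately before `try` and performing every check inside
    the `try` block means each return path goes through `finally`, so the lock is
    always released.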
    
    In this way, we can ensure the following things when resources are being 
cleaned up:
    1. `ShuffleBuffers` that have not yet been converted to `FlushEvents` will 
not be converted in the future, but will be directly cleaned up.
    2. `FlushEvents` that have been converted from `ShuffleBuffers` will 
definitely encounter an `EventInvalidException`, and we will eventually handle 
this exception correctly, releasing memory.
    3. If there is already a `FlushEvent` being processed and it is about to be 
flushed to disk, the resource cleanup task will wait for all `FlushEvents` 
related to the `appId` to be completed before starting the cleanup task, 
ensuring that the cleanup and flushing tasks are completely independent and do 
not interfere with each other.
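    
    The third guarantee relies on a property of `ReentrantReadWriteLock`: the
    write lock cannot be acquired while another thread holds the read lock. The
    sketch below demonstrates this in isolation. `WriteWaitsForReadersSketch` is a
    hypothetical name, and it uses `tryLock()` only to observe the state without
    blocking, whereas the cleanup path in this commit calls the blocking `lock()`.
    
    ```java
    import java.util.concurrent.CountDownLatch;
    import java.util.concurrent.locks.ReentrantReadWriteLock;

    /** Hypothetical demo: cleanup (write lock) cannot start during a flush (read lock). */
    final class WriteWaitsForReadersSketch {

      /** Tries to take the write lock without blocking and releases it again. */
      static boolean tryWriteLock(ReentrantReadWriteLock lock) {
        boolean acquired = lock.writeLock().tryLock();
        if (acquired) {
          lock.writeLock().unlock();
        }
        return acquired;
      }

      /** Returns {acquiredWhileFlushing, acquiredAfterFlush}. */
      static boolean[] run() {
        ReentrantReadWriteLock appLock = new ReentrantReadWriteLock();
        CountDownLatch flushStarted = new CountDownLatch(1);
        CountDownLatch flushMayFinish = new CountDownLatch(1);

        // Stand-in for a flush thread that is still processing an event.
        Thread flusher = new Thread(() -> {
          appLock.readLock().lock();
          try {
            flushStarted.countDown();
            flushMayFinish.await();
          } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
          } finally {
            appLock.readLock().unlock();
          }
        });

        try {
          flusher.start();
          flushStarted.await();
          boolean whileFlushing = tryWriteLock(appLock); // flush in progress
          flushMayFinish.countDown();
          flusher.join();
          boolean afterFlush = tryWriteLock(appLock);    // flush completed
          return new boolean[] {whileFlushing, afterFlush};
        } catch (InterruptedException e) {
          Thread.currentThread().interrupt();
          throw new IllegalStateException(e);
        }
      }

      public static void main(String[] args) {
        boolean[] result = run();
        System.out.println("write lock while flushing: " + result[0]); // false
        System.out.println("write lock after flush: " + result[1]);    // true
      }
    }
    ```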
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing UTs.
    ---------
    
    Co-authored-by: leslizhang <[email protected]>
---
 .../uniffle/server/DefaultFlushEventHandler.java   | 18 ++++-
 .../apache/uniffle/server/ShuffleFlushManager.java |  3 +-
 .../apache/uniffle/server/ShuffleTaskManager.java  | 80 +++++++++++++---------
 .../server/buffer/ShuffleBufferManager.java        | 47 ++++++++-----
 .../ShuffleFlushManagerOnKerberizedHadoopTest.java |  7 ++
 .../uniffle/server/ShuffleFlushManagerTest.java    | 39 +++++++++++
 .../server/buffer/ShuffleBufferManagerTest.java    | 47 +++++++++++--
 7 files changed, 183 insertions(+), 58 deletions(-)

diff --git 
a/server/src/main/java/org/apache/uniffle/server/DefaultFlushEventHandler.java 
b/server/src/main/java/org/apache/uniffle/server/DefaultFlushEventHandler.java
index c5b320091..2ff85bab2 100644
--- 
a/server/src/main/java/org/apache/uniffle/server/DefaultFlushEventHandler.java
+++ 
b/server/src/main/java/org/apache/uniffle/server/DefaultFlushEventHandler.java
@@ -21,6 +21,7 @@ import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.Executor;
 import java.util.concurrent.ThreadPoolExecutor;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Queues;
@@ -50,17 +51,20 @@ public class DefaultFlushEventHandler implements 
FlushEventHandler {
   private final StorageType storageType;
   protected final BlockingQueue<ShuffleDataFlushEvent> flushQueue = 
Queues.newLinkedBlockingQueue();
   private ConsumerWithException<ShuffleDataFlushEvent> eventConsumer;
+  private final ShuffleServer shuffleServer;
 
   private volatile boolean stopped = false;
 
   public DefaultFlushEventHandler(
       ShuffleServerConf conf,
       StorageManager storageManager,
+      ShuffleServer shuffleServer,
       ConsumerWithException<ShuffleDataFlushEvent> eventConsumer) {
     this.shuffleServerConf = conf;
     this.storageType =
         
StorageType.valueOf(shuffleServerConf.get(RssBaseConf.RSS_STORAGE_TYPE).name());
     this.storageManager = storageManager;
+    this.shuffleServer = shuffleServer;
     this.eventConsumer = eventConsumer;
     initFlushEventExecutor();
   }
@@ -83,8 +87,17 @@ public class DefaultFlushEventHandler implements 
FlushEventHandler {
    */
   private void handleEventAndUpdateMetrics(ShuffleDataFlushEvent event, 
Storage storage) {
     long start = System.currentTimeMillis();
+    String appId = event.getAppId();
+    ReentrantReadWriteLock.ReadLock readLock =
+        shuffleServer.getShuffleTaskManager().getAppReadLock(appId);
     try {
-      eventConsumer.accept(event);
+      readLock.lock();
+      try {
+        eventConsumer.accept(event);
+      } finally {
+        readLock.unlock();
+      }
+
       if (storage != null) {
         
ShuffleServerMetrics.incStorageSuccessCounter(storage.getStorageHost());
       }
@@ -124,8 +137,7 @@ public class DefaultFlushEventHandler implements 
FlushEventHandler {
       }
 
       if (e instanceof EventInvalidException) {
-        // Invalid events have already been released / cleaned up
-        // so no need to call event.doCleanup() here
+        event.doCleanup();
         return;
       }
 
diff --git 
a/server/src/main/java/org/apache/uniffle/server/ShuffleFlushManager.java 
b/server/src/main/java/org/apache/uniffle/server/ShuffleFlushManager.java
index 41cb26b00..15ea147d2 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleFlushManager.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleFlushManager.java
@@ -81,7 +81,8 @@ public class ShuffleFlushManager {
     storageBasePaths = RssUtils.getConfiguredLocalDirs(shuffleServerConf);
     pendingEventTimeoutSec = 
shuffleServerConf.getLong(ShuffleServerConf.PENDING_EVENT_TIMEOUT_SEC);
     eventHandler =
-        new DefaultFlushEventHandler(shuffleServerConf, storageManager, 
this::processFlushEvent);
+        new DefaultFlushEventHandler(
+            shuffleServerConf, storageManager, shuffleServer, 
this::processFlushEvent);
   }
 
   public void addToFlushQueue(ShuffleDataFlushEvent event) {
diff --git 
a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java 
b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
index f9ed125b0..b26167a2f 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
@@ -32,7 +32,7 @@ import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.locks.Lock;
-import java.util.concurrent.locks.ReentrantLock;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.cache.Cache;
@@ -112,7 +112,7 @@ public class ShuffleTaskManager {
   private Map<Long, PreAllocatedBufferInfo> requireBufferIds = 
JavaUtils.newConcurrentMap();
   private Thread clearResourceThread;
   private BlockingQueue<PurgeEvent> expiredAppIdQueue = 
Queues.newLinkedBlockingQueue();
-  private final Cache<String, Lock> appLocks;
+  private final Cache<String, ReentrantReadWriteLock> appLocks;
 
   public ShuffleTaskManager(
       ShuffleServerConf conf,
@@ -222,9 +222,18 @@ public class ShuffleTaskManager {
     topNShuffleDataSizeOfAppCalcTask.start();
   }
 
-  private Lock getAppLock(String appId) {
+  public ReentrantReadWriteLock.WriteLock getAppWriteLock(String appId) {
     try {
-      return appLocks.get(appId, ReentrantLock::new);
+      return appLocks.get(appId, ReentrantReadWriteLock::new).writeLock();
+    } catch (ExecutionException e) {
+      LOG.error("Failed to get App lock.", e);
+      throw new RssException(e);
+    }
+  }
+
+  public ReentrantReadWriteLock.ReadLock getAppReadLock(String appId) {
+    try {
+      return appLocks.get(appId, ReentrantReadWriteLock::new).readLock();
     } catch (ExecutionException e) {
       LOG.error("Failed to get App lock.", e);
       throw new RssException(e);
@@ -257,7 +266,7 @@ public class ShuffleTaskManager {
       String user,
       ShuffleDataDistributionType dataDistType,
       int maxConcurrencyPerPartitionToWrite) {
-    Lock lock = getAppLock(appId);
+    ReentrantReadWriteLock.WriteLock lock = getAppWriteLock(appId);
     try {
       lock.lock();
       refreshAppId(appId);
@@ -692,35 +701,42 @@ public class ShuffleTaskManager {
    * @param shuffleIds
    */
   public void removeResourcesByShuffleIds(String appId, List<Integer> 
shuffleIds) {
-    if (CollectionUtils.isEmpty(shuffleIds)) {
-      return;
-    }
+    Lock writeLock = getAppWriteLock(appId);
+    writeLock.lock();
+    try {
+      if (CollectionUtils.isEmpty(shuffleIds)) {
+        return;
+      }
 
-    LOG.info("Start remove resource for appId[{}], shuffleIds[{}]", appId, 
shuffleIds);
-    final long start = System.currentTimeMillis();
-    final ShuffleTaskInfo taskInfo = shuffleTaskInfos.get(appId);
-    if (taskInfo != null) {
-      for (Integer shuffleId : shuffleIds) {
-        taskInfo.getCachedBlockIds().remove(shuffleId);
-        taskInfo.getCommitCounts().remove(shuffleId);
-        taskInfo.getCommitLocks().remove(shuffleId);
+      LOG.info("Start remove resource for appId[{}], shuffleIds[{}]", appId, 
shuffleIds);
+      final long start = System.currentTimeMillis();
+      final ShuffleTaskInfo taskInfo = shuffleTaskInfos.get(appId);
+      if (taskInfo != null) {
+        for (Integer shuffleId : shuffleIds) {
+          taskInfo.getCachedBlockIds().remove(shuffleId);
+          taskInfo.getCommitCounts().remove(shuffleId);
+          taskInfo.getCommitLocks().remove(shuffleId);
+        }
       }
+      Optional.ofNullable(partitionsToBlockIds.get(appId))
+          .ifPresent(
+              x -> {
+                for (Integer shuffleId : shuffleIds) {
+                  x.remove(shuffleId);
+                }
+              });
+      shuffleBufferManager.removeBufferByShuffleId(appId, shuffleIds);
+      shuffleFlushManager.removeResourcesOfShuffleId(appId, shuffleIds);
+      storageManager.removeResources(
+          new ShufflePurgeEvent(appId, getUserByAppId(appId), shuffleIds));
+      LOG.info(
+          "Finish remove resource for appId[{}], shuffleIds[{}], cost[{}]",
+          appId,
+          shuffleIds,
+          System.currentTimeMillis() - start);
+    } finally {
+      writeLock.unlock();
     }
-    Optional.ofNullable(partitionsToBlockIds.get(appId))
-        .ifPresent(
-            x -> {
-              for (Integer shuffleId : shuffleIds) {
-                x.remove(shuffleId);
-              }
-            });
-    shuffleBufferManager.removeBufferByShuffleId(appId, shuffleIds);
-    shuffleFlushManager.removeResourcesOfShuffleId(appId, shuffleIds);
-    storageManager.removeResources(new ShufflePurgeEvent(appId, 
getUserByAppId(appId), shuffleIds));
-    LOG.info(
-        "Finish remove resource for appId[{}], shuffleIds[{}], cost[{}]",
-        appId,
-        shuffleIds,
-        System.currentTimeMillis() - start);
   }
 
   public void checkLeakShuffleData() {
@@ -736,7 +752,7 @@ public class ShuffleTaskManager {
 
   @VisibleForTesting
   public void removeResources(String appId, boolean checkAppExpired) {
-    Lock lock = getAppLock(appId);
+    Lock lock = getAppWriteLock(appId);
     try {
       lock.lock();
       LOG.info("Start remove resource for appId[" + appId + "]");
diff --git 
a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
 
b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
index b03aec552..ceca59211 100644
--- 
a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
+++ 
b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
@@ -18,12 +18,14 @@
 package org.apache.uniffle.server.buffer;
 
 import java.util.Collection;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Lists;
@@ -291,23 +293,36 @@ public class ShuffleBufferManager {
       int startPartition,
       int endPartition,
       boolean isHugePartition) {
-    ShuffleDataFlushEvent event =
-        buffer.toFlushEvent(
-            appId,
-            shuffleId,
-            startPartition,
-            endPartition,
-            () -> bufferPool.containsKey(appId),
-            shuffleFlushManager.getDataDistributionType(appId));
-    if (event != null) {
-      event.addCleanupCallback(() -> releaseMemory(event.getSize(), true, 
false));
-      updateShuffleSize(appId, shuffleId, -event.getSize());
-      inFlushSize.addAndGet(event.getSize());
-      if (isHugePartition) {
-        event.markOwnedByHugePartition();
+    ReentrantReadWriteLock.ReadLock readLock = 
shuffleTaskManager.getAppReadLock(appId);
+    readLock.lock();
+    if (!bufferPool.getOrDefault(appId, new 
HashMap<>()).containsKey(shuffleId)) {
+      LOG.info(
+          "Shuffle[{}] for app[{}] has already been removed, no need to flush 
the buffer",
+          shuffleId,
+          appId);
+      return;
+    }
+    try {
+      ShuffleDataFlushEvent event =
+          buffer.toFlushEvent(
+              appId,
+              shuffleId,
+              startPartition,
+              endPartition,
+              () -> bufferPool.getOrDefault(appId, new 
HashMap<>()).containsKey(shuffleId),
+              shuffleFlushManager.getDataDistributionType(appId));
+      if (event != null) {
+        event.addCleanupCallback(() -> releaseMemory(event.getSize(), true, 
false));
+        updateShuffleSize(appId, shuffleId, -event.getSize());
+        inFlushSize.addAndGet(event.getSize());
+        if (isHugePartition) {
+          event.markOwnedByHugePartition();
+        }
+        ShuffleServerMetrics.gaugeInFlushBufferSize.set(inFlushSize.get());
+        shuffleFlushManager.addToFlushQueue(event);
       }
-      ShuffleServerMetrics.gaugeInFlushBufferSize.set(inFlushSize.get());
-      shuffleFlushManager.addToFlushQueue(event);
+    } finally {
+      readLock.unlock();
     }
   }
 
diff --git 
a/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerOnKerberizedHadoopTest.java
 
b/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerOnKerberizedHadoopTest.java
index ab384142d..76d06cb2c 100644
--- 
a/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerOnKerberizedHadoopTest.java
+++ 
b/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerOnKerberizedHadoopTest.java
@@ -22,6 +22,7 @@ import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
 
 import org.apache.hadoop.fs.FileStatus;
 import org.apache.hadoop.fs.Path;
@@ -104,6 +105,12 @@ public class ShuffleFlushManagerOnKerberizedHadoopTest 
extends KerberizedHadoopB
 
     
when(mockShuffleServer.getShuffleTaskManager().getUserByAppId(appId1)).thenReturn("alex");
     
when(mockShuffleServer.getShuffleTaskManager().getUserByAppId(appId2)).thenReturn("alex");
+    ReentrantReadWriteLock rsLock = new ReentrantReadWriteLock();
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId1))
+        .thenReturn(rsLock.readLock());
+    ReentrantReadWriteLock rsLock2 = new ReentrantReadWriteLock();
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId2))
+        .thenReturn(rsLock2.readLock());
 
     StorageManager storageManager =
         
StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
diff --git 
a/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerTest.java 
b/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerTest.java
index c17087073..274c2cbbb 100644
--- 
a/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerTest.java
+++ 
b/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerTest.java
@@ -26,6 +26,7 @@ import java.util.Random;
 import java.util.Set;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
 import java.util.function.Supplier;
 import java.util.stream.IntStream;
 
@@ -139,6 +140,9 @@ public class ShuffleFlushManagerTest extends HadoopTestBase 
{
         ShuffleServerConf.SERVER_MAX_CONCURRENCY_OF_ONE_PARTITION, 
maxConcurrency);
 
     String appId = "concurrentWrite2HdfsWriteOneByOne_appId";
+    ReentrantReadWriteLock rsLock = new ReentrantReadWriteLock();
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId))
+        .thenReturn(rsLock.readLock());
     StorageManager storageManager =
         
StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
     storageManager.registerRemoteStorage(appId, remoteStorage);
@@ -171,6 +175,9 @@ public class ShuffleFlushManagerTest extends HadoopTestBase 
{
         ShuffleServerConf.SERVER_MAX_CONCURRENCY_OF_ONE_PARTITION, 
maxConcurrency);
 
     String appId = "concurrentWrite2HdfsWriteOfSinglePartition_appId";
+    ReentrantReadWriteLock rsLock = new ReentrantReadWriteLock();
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId))
+        .thenReturn(rsLock.readLock());
     StorageManager storageManager =
         
StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
     storageManager.registerRemoteStorage(appId, remoteStorage);
@@ -198,6 +205,9 @@ public class ShuffleFlushManagerTest extends HadoopTestBase 
{
   @Test
   public void writeTest() throws Exception {
     String appId = "writeTest_appId";
+    ReentrantReadWriteLock rsLock = new ReentrantReadWriteLock();
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId))
+        .thenReturn(rsLock.readLock());
     StorageManager storageManager =
         
StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
     storageManager.registerRemoteStorage(appId, remoteStorage);
@@ -263,6 +273,8 @@ public class ShuffleFlushManagerTest extends HadoopTestBase 
{
     // test case for process event whose related app was cleared already
     assertEquals(0, ShuffleServerMetrics.gaugeWriteHandler.get(), 0.5);
     ShuffleDataFlushEvent fakeEvent = createShuffleDataFlushEvent("fakeAppId", 
1, 1, 1, null);
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock("fakeAppId"))
+        .thenReturn(rsLock.readLock());
     manager.addToFlushQueue(fakeEvent);
     waitForQueueClear(manager);
     waitForMetrics(ShuffleServerMetrics.gaugeWriteHandler, 0, 0.5);
@@ -276,6 +288,9 @@ public class ShuffleFlushManagerTest extends HadoopTestBase 
{
         ShuffleServerConf.RSS_STORAGE_TYPE.key(), 
StorageType.MEMORY_LOCALFILE.name());
 
     String appId = "localMetricsTest_appId";
+    ReentrantReadWriteLock rsLock = new ReentrantReadWriteLock();
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId))
+        .thenReturn(rsLock.readLock());
     StorageManager storageManager =
         
StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
     ShuffleFlushManager manager =
@@ -306,6 +321,9 @@ public class ShuffleFlushManagerTest extends HadoopTestBase 
{
         ShuffleServerConf.RSS_STORAGE_TYPE.key(), 
StorageType.LOCALFILE.name());
 
     String appId = "localMetricsTest_appId";
+    ReentrantReadWriteLock rsLock = new ReentrantReadWriteLock();
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId))
+        .thenReturn(rsLock.readLock());
     StorageManager storageManager =
         
StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
     ShuffleFlushManager manager =
@@ -355,6 +373,9 @@ public class ShuffleFlushManagerTest extends HadoopTestBase 
{
     StorageManager storageManager =
         
StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
     String appId = "complexWriteTest_appId";
+    ReentrantReadWriteLock rsLock = new ReentrantReadWriteLock();
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId))
+        .thenReturn(rsLock.readLock());
     storageManager.registerRemoteStorage(appId, remoteStorage);
     List<ShufflePartitionedBlock> expectedBlocks = Lists.newArrayList();
     List<ShuffleDataFlushEvent> flushEvents1 = Lists.newArrayList();
@@ -399,6 +420,12 @@ public class ShuffleFlushManagerTest extends 
HadoopTestBase {
         
StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
     String appId1 = "complexWriteTest_appId1";
     String appId2 = "complexWriteTest_appId2";
+    ReentrantReadWriteLock rsLock1 = new ReentrantReadWriteLock();
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId1))
+        .thenReturn(rsLock1.readLock());
+    ReentrantReadWriteLock rsLock2 = new ReentrantReadWriteLock();
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId2))
+        .thenReturn(rsLock2.readLock());
     storageManager.registerRemoteStorage(appId1, remoteStorage);
     storageManager.registerRemoteStorage(appId2, remoteStorage);
     ShuffleFlushManager manager =
@@ -456,6 +483,12 @@ public class ShuffleFlushManagerTest extends 
HadoopTestBase {
   public void clearLocalTest(@TempDir File tempDir) throws Exception {
     final String appId1 = "clearLocalTest_appId1";
     final String appId2 = "clearLocalTest_appId12";
+    ReentrantReadWriteLock rsLock1 = new ReentrantReadWriteLock();
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId1))
+        .thenReturn(rsLock1.readLock());
+    ReentrantReadWriteLock rsLock2 = new ReentrantReadWriteLock();
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId2))
+        .thenReturn(rsLock2.readLock());
     ShuffleServerConf serverConf = new ShuffleServerConf();
     serverConf.set(
         ShuffleServerConf.RSS_STORAGE_BASE_PATH, 
Arrays.asList(tempDir.getAbsolutePath()));
@@ -690,6 +723,9 @@ public class ShuffleFlushManagerTest extends HadoopTestBase 
{
     StorageManager storageManager =
         
StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
     String appId = "fallbackWrittenWhenMultiStorageManagerEnableTest";
+    ReentrantReadWriteLock rsLock = new ReentrantReadWriteLock();
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId))
+        .thenReturn(rsLock.readLock());
     storageManager.registerRemoteStorage(appId, new 
RemoteStorageInfo(remoteStorage.getPath()));
 
     ShuffleFlushManager flushManager =
@@ -740,6 +776,9 @@ public class ShuffleFlushManagerTest extends HadoopTestBase 
{
     StorageManager storageManager =
         
StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
     String appId = "fallbackWrittenWhenMultiStorageManagerEnableTest";
+    ReentrantReadWriteLock rsLock = new ReentrantReadWriteLock();
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId))
+        .thenReturn(rsLock.readLock());
     storageManager.registerRemoteStorage(appId, new 
RemoteStorageInfo(remoteStorage.getPath()));
 
     ShuffleFlushManager flushManager =
diff --git 
a/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java
 
b/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java
index 94e5d601b..08428cc74 100644
--- 
a/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java
+++ 
b/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java
@@ -24,6 +24,7 @@ import java.util.Map;
 import java.util.concurrent.ThreadPoolExecutor;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
 
 import com.google.common.collect.RangeMap;
 import com.google.common.util.concurrent.Uninterruptibles;
@@ -85,6 +86,7 @@ public class ShuffleBufferManagerTest extends BufferTestBase {
     mockShuffleTaskManager = mock(ShuffleTaskManager.class);
     
when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
     shuffleBufferManager = new ShuffleBufferManager(conf, 
mockShuffleFlushManager, false);
+    shuffleBufferManager.setShuffleTaskManager(mockShuffleTaskManager);
   }
 
   @Test
@@ -115,6 +117,10 @@ public class ShuffleBufferManagerTest extends 
BufferTestBase {
   @Test
   public void getShuffleDataWithExpectedTaskIdsTest() {
     String appId = "getShuffleDataWithExpectedTaskIdsTest";
+    shuffleBufferManager.setShuffleTaskManager(mockShuffleTaskManager);
+    ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock();
+    
when(mockShuffleTaskManager.getAppReadLock(appId)).thenReturn(rwLock.readLock());
+    
when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
     shuffleBufferManager.registerBuffer(appId, 1, 0, 1);
     ShufflePartitionedData spd1 = createData(0, 1, 68);
     ShufflePartitionedData spd2 = createData(0, 2, 68);
@@ -146,6 +152,10 @@ public class ShuffleBufferManagerTest extends 
BufferTestBase {
   @Test
   public void getShuffleDataTest() {
     String appId = "getShuffleDataTest";
+    shuffleBufferManager.setShuffleTaskManager(mockShuffleTaskManager);
+    ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock();
+    
when(mockShuffleTaskManager.getAppReadLock(appId)).thenReturn(rwLock.readLock());
+    
when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
     shuffleBufferManager.registerBuffer(appId, 1, 0, 1);
     shuffleBufferManager.registerBuffer(appId, 2, 0, 1);
     shuffleBufferManager.registerBuffer(appId, 3, 0, 1);
@@ -209,6 +219,10 @@ public class ShuffleBufferManagerTest extends 
BufferTestBase {
   public void shuffleIdToSizeTest() {
     String appId1 = "shuffleIdToSizeTest1";
     String appId2 = "shuffleIdToSizeTest2";
+    ReentrantReadWriteLock rwLock1 = new ReentrantReadWriteLock();
+    
when(mockShuffleTaskManager.getAppReadLock(appId1)).thenReturn(rwLock1.readLock());
+    ReentrantReadWriteLock rwLock2 = new ReentrantReadWriteLock();
+    
when(mockShuffleTaskManager.getAppReadLock(appId2)).thenReturn(rwLock2.readLock());
     shuffleBufferManager.registerBuffer(appId1, 1, 0, 0);
     shuffleBufferManager.registerBuffer(appId1, 2, 0, 0);
     shuffleBufferManager.registerBuffer(appId2, 1, 0, 0);
@@ -254,9 +268,12 @@ public class ShuffleBufferManagerTest extends 
BufferTestBase {
   @Test
   public void cacheShuffleDataTest() {
     String appId = "cacheShuffleDataTest";
-    int shuffleId = 1;
-
+    shuffleBufferManager.setShuffleTaskManager(mockShuffleTaskManager);
+    ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock();
+    
when(mockShuffleTaskManager.getAppReadLock(appId)).thenReturn(rwLock.readLock());
+    
when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
     int startPartitionNum = (int) 
ShuffleServerMetrics.gaugeTotalPartitionNum.get();
+    int shuffleId = 1;
     StatusCode sc =
         shuffleBufferManager.cacheShuffleData(appId, shuffleId, false, 
createData(0, 16));
     assertEquals(StatusCode.NO_REGISTER, sc);
@@ -322,8 +339,11 @@ public class ShuffleBufferManagerTest extends 
BufferTestBase {
   @Test
   public void cacheShuffleDataWithPreAllocationTest() {
     String appId = "cacheShuffleDataWithPreAllocationTest";
+    shuffleBufferManager.setShuffleTaskManager(mockShuffleTaskManager);
+    ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock();
+    
when(mockShuffleTaskManager.getAppReadLock(appId)).thenReturn(rwLock.readLock());
+    
when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
     int shuffleId = 1;
-
     shuffleBufferManager.registerBuffer(appId, shuffleId, 0, 1);
     // pre allocate memory
     shuffleBufferManager.requireMemory(48, true);
@@ -393,8 +413,12 @@ public class ShuffleBufferManagerTest extends 
BufferTestBase {
     
when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mock(ShuffleTaskManager.class));
 
     String appId = "bufferSizeTest";
+    
when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
+    shuffleBufferManager.setShuffleTaskManager(mockShuffleTaskManager);
+    ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock();
+    
when(mockShuffleTaskManager.getAppReadLock(appId)).thenReturn(rwLock.readLock());
     int shuffleId = 1;
-
+    
when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
     shuffleBufferManager.registerBuffer(appId, shuffleId, 0, 1);
     shuffleBufferManager.registerBuffer(appId, shuffleId, 2, 3);
     shuffleBufferManager.registerBuffer(appId, shuffleId, 4, 5);
@@ -523,8 +547,11 @@ public class ShuffleBufferManagerTest extends 
BufferTestBase {
     
when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mock(ShuffleTaskManager.class));
 
     String appId = "bufferSizeTest";
+    shuffleBufferManager.setShuffleTaskManager(mockShuffleTaskManager);
+    ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock();
+    
when(mockShuffleTaskManager.getAppReadLock(appId)).thenReturn(rwLock.readLock());
+    
when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
     int shuffleId = 1;
-
     shuffleBufferManager.registerBuffer(appId, shuffleId, 0, 1);
     shuffleBufferManager.registerBuffer(appId, shuffleId, 2, 3);
     shuffleBufferManager.cacheShuffleData(appId, shuffleId, false, 
createData(0, 64));
@@ -555,10 +582,14 @@ public class ShuffleBufferManagerTest extends 
BufferTestBase {
     shuffleBufferManager = new ShuffleBufferManager(serverConf, 
shuffleFlushManager, false);
 
     String appId = "shuffleFlushTest";
+    shuffleBufferManager.setShuffleTaskManager(mockShuffleTaskManager);
+    ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock();
+    
when(mockShuffleTaskManager.getAppReadLock(appId)).thenReturn(rwLock.readLock());
+    
when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
+
     int shuffleId = 0;
     int smallShuffleId = 1;
     int smallShuffleIdTwo = 2;
-
     shuffleBufferManager.registerBuffer(appId, shuffleId, 0, 1);
     shuffleBufferManager.registerBuffer(appId, shuffleId, 2, 3);
     shuffleBufferManager.registerBuffer(appId, smallShuffleId, 0, 1);
@@ -676,6 +707,10 @@ public class ShuffleBufferManagerTest extends 
BufferTestBase {
     
when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mock(ShuffleTaskManager.class));
 
     String appId = "bufferSizeTest";
+    shuffleBufferManager.setShuffleTaskManager(mockShuffleTaskManager);
+    ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock();
+    
when(mockShuffleTaskManager.getAppReadLock(appId)).thenReturn(rwLock.readLock());
+    
when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
     int shuffleId = 1;
     shuffleBufferManager.registerBuffer(appId, shuffleId, 0, 1);
     shuffleBufferManager.registerBuffer(appId, shuffleId, 2, 3);
