This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new b324cc33c [#2716] feat(spark): Introduce option of max segments 
decompression to control memory usage (#2735)
b324cc33c is described below

commit b324cc33c1457cc55eda981c577d2d5888177ed2
Author: Junfan Zhang <[email protected]>
AuthorDate: Mon Mar 2 14:52:14 2026 +0800

    [#2716] feat(spark): Introduce option of max segments decompression to 
control memory usage (#2735)
    
    ### What changes were proposed in this pull request?
    
    Introduce an option for the maximum number of concurrently decompressed segments to control memory usage
    
    ### Why are the changes needed?
    
    To address issue #2716, this PR introduces an option to set the maximum 
number of concurrent decompression segments, allowing better control over 
overall memory usage. Setting this value to 1 restricts decompression to a 
single segment at a time.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes
    
    ### How was this patch tested?
    
    Unit test
---
 .../uniffle/client/impl/DecompressionWorker.java   | 103 +++++++++++++++------
 .../uniffle/client/impl/ShuffleReadClientImpl.java |   8 +-
 .../client/impl/DecompressionWorkerTest.java       |  38 +++++++-
 .../uniffle/common/config/RssClientConf.java       |   7 ++
 4 files changed, 124 insertions(+), 32 deletions(-)

diff --git 
a/client/src/main/java/org/apache/uniffle/client/impl/DecompressionWorker.java 
b/client/src/main/java/org/apache/uniffle/client/impl/DecompressionWorker.java
index 120844718..66ff8d9cd 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/impl/DecompressionWorker.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/impl/DecompressionWorker.java
@@ -19,12 +19,15 @@ package org.apache.uniffle.client.impl;
 
 import java.nio.ByteBuffer;
 import java.util.List;
+import java.util.Optional;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
+import java.util.concurrent.Semaphore;
 import java.util.concurrent.atomic.AtomicLong;
 
+import com.google.common.annotations.VisibleForTesting;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -55,7 +58,10 @@ public class DecompressionWorker {
   private AtomicLong peekMemoryUsed = new AtomicLong(0);
   private AtomicLong nowMemoryUsed = new AtomicLong(0);
 
-  public DecompressionWorker(Codec codec, int threads, int 
fetchSecondsThreshold) {
+  private final Optional<Semaphore> segmentPermits;
+
+  public DecompressionWorker(
+      Codec codec, int threads, int fetchSecondsThreshold, int 
maxConcurrentDecompressionSegments) {
     if (codec == null) {
       throw new IllegalArgumentException("Codec cannot be null");
     }
@@ -67,6 +73,16 @@ public class DecompressionWorker {
         Executors.newFixedThreadPool(threads, 
ThreadUtils.getThreadFactory("decompressionWorker"));
     this.codec = codec;
     this.fetchSecondsThreshold = fetchSecondsThreshold;
+
+    if (maxConcurrentDecompressionSegments <= 0) {
+      this.segmentPermits = Optional.empty();
+    } else if (threads != 1) {
+      LOG.info(
+          "Disable backpressure control since threads is {} to avoid potential 
deadlock", threads);
+      this.segmentPermits = Optional.empty();
+    } else {
+      this.segmentPermits = Optional.of(new 
Semaphore(maxConcurrentDecompressionSegments));
+    }
   }
 
   public void add(int batchIndex, ShuffleDataResult shuffleDataResult) {
@@ -80,34 +96,49 @@ public class DecompressionWorker {
     for (BufferSegment bufferSegment : bufferSegments) {
       CompletableFuture<ByteBuffer> f =
           CompletableFuture.supplyAsync(
-              () -> {
-                int offset = bufferSegment.getOffset();
-                int length = bufferSegment.getLength();
-                ByteBuffer buffer = sharedByteBuffer.duplicate();
-                buffer.position(offset);
-                buffer.limit(offset + length);
-
-                int uncompressedLen = bufferSegment.getUncompressLength();
-
-                long startBufferAllocation = System.currentTimeMillis();
-                ByteBuffer dst =
-                    buffer.isDirect()
-                        ? ByteBuffer.allocateDirect(uncompressedLen)
-                        : ByteBuffer.allocate(uncompressedLen);
-                decompressionBufferAllocationMillis.addAndGet(
-                    System.currentTimeMillis() - startBufferAllocation);
-
-                long startDecompression = System.currentTimeMillis();
-                codec.decompress(buffer, uncompressedLen, dst, 0);
-                decompressionMillis.addAndGet(System.currentTimeMillis() - 
startDecompression);
-                decompressionBytes.addAndGet(length);
-
-                nowMemoryUsed.addAndGet(uncompressedLen);
-                resetPeekMemoryUsed();
-
-                return dst;
-              },
-              executorService);
+                  () -> {
+                    try {
+                      if (segmentPermits.isPresent()) {
+                        segmentPermits.get().acquire();
+                      }
+                    } catch (InterruptedException e) {
+                      Thread.currentThread().interrupt();
+                      LOG.warn("Interrupted while acquiring segment permit", 
e);
+                      return null;
+                    }
+
+                    int offset = bufferSegment.getOffset();
+                    int length = bufferSegment.getLength();
+                    ByteBuffer buffer = sharedByteBuffer.duplicate();
+                    buffer.position(offset);
+                    buffer.limit(offset + length);
+
+                    int uncompressedLen = bufferSegment.getUncompressLength();
+
+                    long startBufferAllocation = System.currentTimeMillis();
+                    ByteBuffer dst =
+                        buffer.isDirect()
+                            ? ByteBuffer.allocateDirect(uncompressedLen)
+                            : ByteBuffer.allocate(uncompressedLen);
+                    decompressionBufferAllocationMillis.addAndGet(
+                        System.currentTimeMillis() - startBufferAllocation);
+
+                    long startDecompression = System.currentTimeMillis();
+                    codec.decompress(buffer, uncompressedLen, dst, 0);
+                    decompressionMillis.addAndGet(System.currentTimeMillis() - 
startDecompression);
+                    decompressionBytes.addAndGet(length);
+
+                    nowMemoryUsed.addAndGet(uncompressedLen);
+                    resetPeekMemoryUsed();
+
+                    return dst;
+                  },
+                  executorService)
+              .exceptionally(
+                  ex -> {
+                    LOG.error("Errors on decompressing shuffle block", ex);
+                    return null;
+                  });
       ConcurrentHashMap<Integer, DecompressedShuffleBlock> blocks =
           tasks.computeIfAbsent(batchIndex, k -> new ConcurrentHashMap<>());
       blocks.put(
@@ -132,6 +163,7 @@ public class DecompressionWorker {
     // block
     if (block != null) {
       nowMemoryUsed.addAndGet(-block.getUncompressLength());
+      segmentPermits.ifPresent(x -> x.release());
     }
     return block;
   }
@@ -163,4 +195,17 @@ public class DecompressionWorker {
   public long decompressionMillis() {
     return decompressionMillis.get() + 
decompressionBufferAllocationMillis.get();
   }
+
+  @VisibleForTesting
+  protected long getPeekMemoryUsed() {
+    return peekMemoryUsed.get();
+  }
+
+  @VisibleForTesting
+  protected int getAvailablePermits() {
+    if (segmentPermits.isPresent()) {
+      return segmentPermits.get().availablePermits();
+    }
+    return -1;
+  }
 }
diff --git 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
index 364e98526..7cc386ad7 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
@@ -62,6 +62,7 @@ import 
org.apache.uniffle.storage.request.CreateShuffleReadHandlerRequest;
 import static 
org.apache.uniffle.common.config.RssClientConf.READ_CLIENT_NEXT_SEGMENTS_REPORT_COUNT;
 import static 
org.apache.uniffle.common.config.RssClientConf.READ_CLIENT_NEXT_SEGMENTS_REPORT_ENABLED;
 import static 
org.apache.uniffle.common.config.RssClientConf.RSS_READ_OVERLAPPING_DECOMPRESSION_FETCH_SECONDS_THRESHOLD;
+import static 
org.apache.uniffle.common.config.RssClientConf.RSS_READ_OVERLAPPING_DECOMPRESSION_MAX_CONCURRENT_SEGMENTS;
 
 public class ShuffleReadClientImpl implements ShuffleReadClient {
 
@@ -165,9 +166,14 @@ public class ShuffleReadClientImpl implements 
ShuffleReadClient {
     if (builder.isOverlappingDecompressionEnabled()) {
       int fetchThreshold =
           
builder.getRssConf().get(RSS_READ_OVERLAPPING_DECOMPRESSION_FETCH_SECONDS_THRESHOLD);
+      int maxSegments =
+          
builder.getRssConf().get(RSS_READ_OVERLAPPING_DECOMPRESSION_MAX_CONCURRENT_SEGMENTS);
       this.decompressionWorker =
           new DecompressionWorker(
-              builder.getCodec(), 
builder.getOverlappingDecompressionThreadNum(), fetchThreshold);
+              builder.getCodec(),
+              builder.getOverlappingDecompressionThreadNum(),
+              fetchThreshold,
+              maxSegments);
     }
     this.shuffleId = builder.getShuffleId();
     this.partitionId = builder.getPartitionId();
diff --git 
a/client/src/test/java/org/apache/uniffle/client/impl/DecompressionWorkerTest.java
 
b/client/src/test/java/org/apache/uniffle/client/impl/DecompressionWorkerTest.java
index 5d41ac481..92b33b218 100644
--- 
a/client/src/test/java/org/apache/uniffle/client/impl/DecompressionWorkerTest.java
+++ 
b/client/src/test/java/org/apache/uniffle/client/impl/DecompressionWorkerTest.java
@@ -21,7 +21,9 @@ import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Random;
+import java.util.concurrent.TimeUnit;
 
+import org.awaitility.Awaitility;
 import org.junit.jupiter.api.Test;
 
 import org.apache.uniffle.client.response.DecompressedShuffleBlock;
@@ -36,10 +38,42 @@ import static org.junit.jupiter.api.Assertions.assertNull;
 
 public class DecompressionWorkerTest {
 
+  @Test
+  public void testBackpressure() throws Exception {
+    RssConf rssConf = new RssConf();
+    rssConf.set(COMPRESSION_TYPE, Codec.Type.NOOP);
+    Codec codec = Codec.newInstance(rssConf).get();
+
+    int threads = 1;
+    int maxSegments = 10;
+    int fetchSecondsThreshold = 2;
+    DecompressionWorker worker =
+        new DecompressionWorker(codec, threads, fetchSecondsThreshold, 
maxSegments);
+
+    ShuffleDataResult shuffleDataResult = createShuffleDataResult(maxSegments 
+ 1, codec, 1024);
+    worker.add(0, shuffleDataResult);
+
+    // case1: check the peek memory used is correct when the decompression is 
in progress
+    Awaitility.await()
+        .timeout(200, TimeUnit.MILLISECONDS)
+        .until(() -> 1024 * maxSegments == worker.getPeekMemoryUsed());
+    assertEquals(0, worker.getAvailablePermits());
+
+    // case2: after the previous segments are consumed, the blocked segments 
can be gotten after the
+    // decompression is done
+    for (int i = 0; i < maxSegments; i++) {
+      worker.get(0, i);
+    }
+    Thread.sleep(10);
+    worker.get(0, maxSegments).getByteBuffer();
+    assertEquals(1024 * maxSegments, worker.getPeekMemoryUsed());
+    assertEquals(maxSegments, worker.getAvailablePermits());
+  }
+
   @Test
   public void testEmptyGet() throws Exception {
     DecompressionWorker worker =
-        new DecompressionWorker(Codec.newInstance(new RssConf()).get(), 1, 10);
+        new DecompressionWorker(Codec.newInstance(new RssConf()).get(), 1, 10, 
10000);
     assertNull(worker.get(1, 1));
   }
 
@@ -78,7 +112,7 @@ public class DecompressionWorkerTest {
     RssConf rssConf = new RssConf();
     rssConf.set(COMPRESSION_TYPE, Codec.Type.NOOP);
     Codec codec = Codec.newInstance(rssConf).get();
-    DecompressionWorker worker = new DecompressionWorker(codec, 1, 10);
+    DecompressionWorker worker = new DecompressionWorker(codec, 1, 10, 100000);
 
     // create some data
     ShuffleDataResult shuffleDataResult = createShuffleDataResult(10, codec, 
100);
diff --git 
a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java 
b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
index 0ce7be635..c1672fca8 100644
--- a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
+++ b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
@@ -389,4 +389,11 @@ public class RssClientConf {
               .defaultValue(-1)
               .withDescription(
                   "Fetch seconds threshold for overlapping decompress shuffle 
blocks.");
+
+  public static final ConfigOption<Integer>
+      RSS_READ_OVERLAPPING_DECOMPRESSION_MAX_CONCURRENT_SEGMENTS =
+          
ConfigOptions.key("rss.client.read.overlappingDecompressionMaxConcurrentSegments")
+              .intType()
+              .defaultValue(10)
+              .withDescription("Max concurrent segments number for overlapping 
decompression.");
 }

Reply via email to