This is an automated email from the ASF dual-hosted git repository.
zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new b324cc33c [#2716] feat(spark): Introduce option of max segments
decompression to control memory usage (#2735)
b324cc33c is described below
commit b324cc33c1457cc55eda981c577d2d5888177ed2
Author: Junfan Zhang <[email protected]>
AuthorDate: Mon Mar 2 14:52:14 2026 +0800
[#2716] feat(spark): Introduce option of max segments decompression to
control memory usage (#2735)
### What changes were proposed in this pull request?
Introduce an option that caps the number of concurrently decompressed segments, to control memory usage
### Why are the changes needed?
To address issue #2716, this PR introduces an option to set the maximum
number of concurrent decompression segments, allowing better control over
overall memory usage. Setting this value to 1 restricts decompression to a
single segment at a time.
### Does this PR introduce _any_ user-facing change?
Yes
### How was this patch tested?
Unit test
---
.../uniffle/client/impl/DecompressionWorker.java | 103 +++++++++++++++------
.../uniffle/client/impl/ShuffleReadClientImpl.java | 8 +-
.../client/impl/DecompressionWorkerTest.java | 38 +++++++-
.../uniffle/common/config/RssClientConf.java | 7 ++
4 files changed, 124 insertions(+), 32 deletions(-)
diff --git
a/client/src/main/java/org/apache/uniffle/client/impl/DecompressionWorker.java
b/client/src/main/java/org/apache/uniffle/client/impl/DecompressionWorker.java
index 120844718..66ff8d9cd 100644
---
a/client/src/main/java/org/apache/uniffle/client/impl/DecompressionWorker.java
+++
b/client/src/main/java/org/apache/uniffle/client/impl/DecompressionWorker.java
@@ -19,12 +19,15 @@ package org.apache.uniffle.client.impl;
import java.nio.ByteBuffer;
import java.util.List;
+import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
+import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicLong;
+import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -55,7 +58,10 @@ public class DecompressionWorker {
private AtomicLong peekMemoryUsed = new AtomicLong(0);
private AtomicLong nowMemoryUsed = new AtomicLong(0);
- public DecompressionWorker(Codec codec, int threads, int
fetchSecondsThreshold) {
+ private final Optional<Semaphore> segmentPermits;
+
+ public DecompressionWorker(
+ Codec codec, int threads, int fetchSecondsThreshold, int
maxConcurrentDecompressionSegments) {
if (codec == null) {
throw new IllegalArgumentException("Codec cannot be null");
}
@@ -67,6 +73,16 @@ public class DecompressionWorker {
Executors.newFixedThreadPool(threads,
ThreadUtils.getThreadFactory("decompressionWorker"));
this.codec = codec;
this.fetchSecondsThreshold = fetchSecondsThreshold;
+
+ if (maxConcurrentDecompressionSegments <= 0) {
+ this.segmentPermits = Optional.empty();
+ } else if (threads != 1) {
+ LOG.info(
+ "Disable backpressure control since threads is {} to avoid potential
deadlock", threads);
+ this.segmentPermits = Optional.empty();
+ } else {
+ this.segmentPermits = Optional.of(new
Semaphore(maxConcurrentDecompressionSegments));
+ }
}
public void add(int batchIndex, ShuffleDataResult shuffleDataResult) {
@@ -80,34 +96,49 @@ public class DecompressionWorker {
for (BufferSegment bufferSegment : bufferSegments) {
CompletableFuture<ByteBuffer> f =
CompletableFuture.supplyAsync(
- () -> {
- int offset = bufferSegment.getOffset();
- int length = bufferSegment.getLength();
- ByteBuffer buffer = sharedByteBuffer.duplicate();
- buffer.position(offset);
- buffer.limit(offset + length);
-
- int uncompressedLen = bufferSegment.getUncompressLength();
-
- long startBufferAllocation = System.currentTimeMillis();
- ByteBuffer dst =
- buffer.isDirect()
- ? ByteBuffer.allocateDirect(uncompressedLen)
- : ByteBuffer.allocate(uncompressedLen);
- decompressionBufferAllocationMillis.addAndGet(
- System.currentTimeMillis() - startBufferAllocation);
-
- long startDecompression = System.currentTimeMillis();
- codec.decompress(buffer, uncompressedLen, dst, 0);
- decompressionMillis.addAndGet(System.currentTimeMillis() -
startDecompression);
- decompressionBytes.addAndGet(length);
-
- nowMemoryUsed.addAndGet(uncompressedLen);
- resetPeekMemoryUsed();
-
- return dst;
- },
- executorService);
+ () -> {
+ try {
+ if (segmentPermits.isPresent()) {
+ segmentPermits.get().acquire();
+ }
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ LOG.warn("Interrupted while acquiring segment permit",
e);
+ return null;
+ }
+
+ int offset = bufferSegment.getOffset();
+ int length = bufferSegment.getLength();
+ ByteBuffer buffer = sharedByteBuffer.duplicate();
+ buffer.position(offset);
+ buffer.limit(offset + length);
+
+ int uncompressedLen = bufferSegment.getUncompressLength();
+
+ long startBufferAllocation = System.currentTimeMillis();
+ ByteBuffer dst =
+ buffer.isDirect()
+ ? ByteBuffer.allocateDirect(uncompressedLen)
+ : ByteBuffer.allocate(uncompressedLen);
+ decompressionBufferAllocationMillis.addAndGet(
+ System.currentTimeMillis() - startBufferAllocation);
+
+ long startDecompression = System.currentTimeMillis();
+ codec.decompress(buffer, uncompressedLen, dst, 0);
+ decompressionMillis.addAndGet(System.currentTimeMillis() -
startDecompression);
+ decompressionBytes.addAndGet(length);
+
+ nowMemoryUsed.addAndGet(uncompressedLen);
+ resetPeekMemoryUsed();
+
+ return dst;
+ },
+ executorService)
+ .exceptionally(
+ ex -> {
+ LOG.error("Errors on decompressing shuffle block", ex);
+ return null;
+ });
ConcurrentHashMap<Integer, DecompressedShuffleBlock> blocks =
tasks.computeIfAbsent(batchIndex, k -> new ConcurrentHashMap<>());
blocks.put(
@@ -132,6 +163,7 @@ public class DecompressionWorker {
// block
if (block != null) {
nowMemoryUsed.addAndGet(-block.getUncompressLength());
+ segmentPermits.ifPresent(x -> x.release());
}
return block;
}
@@ -163,4 +195,17 @@ public class DecompressionWorker {
public long decompressionMillis() {
return decompressionMillis.get() +
decompressionBufferAllocationMillis.get();
}
+
+ @VisibleForTesting
+ protected long getPeekMemoryUsed() {
+ return peekMemoryUsed.get();
+ }
+
+ @VisibleForTesting
+ protected int getAvailablePermits() {
+ if (segmentPermits.isPresent()) {
+ return segmentPermits.get().availablePermits();
+ }
+ return -1;
+ }
}
diff --git
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
index 364e98526..7cc386ad7 100644
---
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
+++
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
@@ -62,6 +62,7 @@ import
org.apache.uniffle.storage.request.CreateShuffleReadHandlerRequest;
import static
org.apache.uniffle.common.config.RssClientConf.READ_CLIENT_NEXT_SEGMENTS_REPORT_COUNT;
import static
org.apache.uniffle.common.config.RssClientConf.READ_CLIENT_NEXT_SEGMENTS_REPORT_ENABLED;
import static
org.apache.uniffle.common.config.RssClientConf.RSS_READ_OVERLAPPING_DECOMPRESSION_FETCH_SECONDS_THRESHOLD;
+import static
org.apache.uniffle.common.config.RssClientConf.RSS_READ_OVERLAPPING_DECOMPRESSION_MAX_CONCURRENT_SEGMENTS;
public class ShuffleReadClientImpl implements ShuffleReadClient {
@@ -165,9 +166,14 @@ public class ShuffleReadClientImpl implements
ShuffleReadClient {
if (builder.isOverlappingDecompressionEnabled()) {
int fetchThreshold =
builder.getRssConf().get(RSS_READ_OVERLAPPING_DECOMPRESSION_FETCH_SECONDS_THRESHOLD);
+ int maxSegments =
+
builder.getRssConf().get(RSS_READ_OVERLAPPING_DECOMPRESSION_MAX_CONCURRENT_SEGMENTS);
this.decompressionWorker =
new DecompressionWorker(
- builder.getCodec(),
builder.getOverlappingDecompressionThreadNum(), fetchThreshold);
+ builder.getCodec(),
+ builder.getOverlappingDecompressionThreadNum(),
+ fetchThreshold,
+ maxSegments);
}
this.shuffleId = builder.getShuffleId();
this.partitionId = builder.getPartitionId();
diff --git
a/client/src/test/java/org/apache/uniffle/client/impl/DecompressionWorkerTest.java
b/client/src/test/java/org/apache/uniffle/client/impl/DecompressionWorkerTest.java
index 5d41ac481..92b33b218 100644
---
a/client/src/test/java/org/apache/uniffle/client/impl/DecompressionWorkerTest.java
+++
b/client/src/test/java/org/apache/uniffle/client/impl/DecompressionWorkerTest.java
@@ -21,7 +21,9 @@ import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
+import java.util.concurrent.TimeUnit;
+import org.awaitility.Awaitility;
import org.junit.jupiter.api.Test;
import org.apache.uniffle.client.response.DecompressedShuffleBlock;
@@ -36,10 +38,42 @@ import static org.junit.jupiter.api.Assertions.assertNull;
public class DecompressionWorkerTest {
+ @Test
+ public void testBackpressure() throws Exception {
+ RssConf rssConf = new RssConf();
+ rssConf.set(COMPRESSION_TYPE, Codec.Type.NOOP);
+ Codec codec = Codec.newInstance(rssConf).get();
+
+ int threads = 1;
+ int maxSegments = 10;
+ int fetchSecondsThreshold = 2;
+ DecompressionWorker worker =
+ new DecompressionWorker(codec, threads, fetchSecondsThreshold,
maxSegments);
+
+ ShuffleDataResult shuffleDataResult = createShuffleDataResult(maxSegments
+ 1, codec, 1024);
+ worker.add(0, shuffleDataResult);
+
+ // case1: check the peek memory used is correct when the decompression is
in progress
+ Awaitility.await()
+ .timeout(200, TimeUnit.MILLISECONDS)
+ .until(() -> 1024 * maxSegments == worker.getPeekMemoryUsed());
+ assertEquals(0, worker.getAvailablePermits());
+
+ // case2: after the previous segments are consumed, the blocked segments
can be gotten after the
+ // decompression is done
+ for (int i = 0; i < maxSegments; i++) {
+ worker.get(0, i);
+ }
+ Thread.sleep(10);
+ worker.get(0, maxSegments).getByteBuffer();
+ assertEquals(1024 * maxSegments, worker.getPeekMemoryUsed());
+ assertEquals(maxSegments, worker.getAvailablePermits());
+ }
+
@Test
public void testEmptyGet() throws Exception {
DecompressionWorker worker =
- new DecompressionWorker(Codec.newInstance(new RssConf()).get(), 1, 10);
+ new DecompressionWorker(Codec.newInstance(new RssConf()).get(), 1, 10,
10000);
assertNull(worker.get(1, 1));
}
@@ -78,7 +112,7 @@ public class DecompressionWorkerTest {
RssConf rssConf = new RssConf();
rssConf.set(COMPRESSION_TYPE, Codec.Type.NOOP);
Codec codec = Codec.newInstance(rssConf).get();
- DecompressionWorker worker = new DecompressionWorker(codec, 1, 10);
+ DecompressionWorker worker = new DecompressionWorker(codec, 1, 10, 100000);
// create some data
ShuffleDataResult shuffleDataResult = createShuffleDataResult(10, codec,
100);
diff --git
a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
index 0ce7be635..c1672fca8 100644
--- a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
+++ b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
@@ -389,4 +389,11 @@ public class RssClientConf {
.defaultValue(-1)
.withDescription(
"Fetch seconds threshold for overlapping decompress shuffle
blocks.");
+
+ public static final ConfigOption<Integer>
+ RSS_READ_OVERLAPPING_DECOMPRESSION_MAX_CONCURRENT_SEGMENTS =
+
ConfigOptions.key("rss.client.read.overlappingDecompressionMaxConcurrentSegments")
+ .intType()
+ .defaultValue(10)
+ .withDescription("Max concurrent segments number for overlapping
decompression.");
}