This is an automated email from the ASF dual-hosted git repository.
zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new dc022d684 [#1594] feat(client):support generating larger block size
during shuffle map task by spill partial partitions data (#1670)
dc022d684 is described below
commit dc022d6840ddb84221c35937f174603f2a479ef0
Author: leslizhang <[email protected]>
AuthorDate: Tue May 14 14:52:12 2024 +0800
[#1594] feat(client):support generating larger block size during shuffle
map task by spill partial partitions data (#1670)
### What changes were proposed in this pull request?
when spilling shuffle data, we spill only part of the reduce partition
data — the partitions that hold the majority of the space.
so, in each spilling process, the WriteBufferManager.clear() method should
implement one additional piece of logic: sort the to-be-spilled buffers by
their size and select the top-N buffers to spill.
### Why are the changes needed?
related feature https://github.com/apache/incubator-uniffle/issues/1594
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
new UTs.
---------
Co-authored-by: leslizhang <[email protected]>
---
.../org/apache/spark/shuffle/RssSparkConfig.java | 6 +++
.../spark/shuffle/writer/BufferManagerOptions.java | 9 ++++
.../spark/shuffle/writer/WriteBufferManager.java | 44 ++++++++++++------
.../shuffle/writer/WriteBufferManagerTest.java | 52 +++++++++++++++++++++-
.../spark/shuffle/writer/RssShuffleWriter.java | 2 +-
.../spark/shuffle/writer/RssShuffleWriter.java | 2 +-
6 files changed, 98 insertions(+), 17 deletions(-)
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index 54a08a524..75b4b998b 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -71,6 +71,12 @@ public class RssSparkConfig {
.withDescription(
"The memory spill switch triggered by Spark TaskMemoryManager,
default value is false.");
+ public static final ConfigOption<Double> RSS_MEMORY_SPILL_RATIO =
+ ConfigOptions.key("rss.client.memory.spill.ratio")
+ .doubleType()
+ .defaultValue(1.0d)
+ .withDescription(
+ "The buffer size to spill when spill triggered by config
spark.rss.writer.buffer.spill.size");
public static final ConfigOption<Integer>
RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM =
ConfigOptions.key("rss.client.reassign.maxReassignServerNum")
.intType()
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BufferManagerOptions.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BufferManagerOptions.java
index f09d55214..7a2bcd74e 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BufferManagerOptions.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BufferManagerOptions.java
@@ -35,6 +35,7 @@ public class BufferManagerOptions {
private long preAllocatedBufferSize;
private long requireMemoryInterval;
private int requireMemoryRetryMax;
+ private double bufferSpillPercent;
public BufferManagerOptions(SparkConf sparkConf) {
bufferSize =
@@ -53,6 +54,10 @@ public class BufferManagerOptions {
sparkConf.getSizeAsBytes(
RssSparkConfig.RSS_WRITER_BUFFER_SPILL_SIZE.key(),
RssSparkConfig.RSS_WRITER_BUFFER_SPILL_SIZE.defaultValue().get());
+ bufferSpillPercent =
+ sparkConf.getDouble(
+ RssSparkConfig.RSS_MEMORY_SPILL_RATIO.key(),
+ RssSparkConfig.RSS_MEMORY_SPILL_RATIO.defaultValue());
preAllocatedBufferSize =
sparkConf.getSizeAsBytes(
RssSparkConfig.RSS_WRITER_PRE_ALLOCATED_BUFFER_SIZE.key(),
@@ -119,6 +124,10 @@ public class BufferManagerOptions {
return bufferSpillThreshold;
}
+ public double getBufferSpillPercent() {
+ return bufferSpillPercent;
+ }
+
public long getRequireMemoryInterval() {
return requireMemoryInterval;
}
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index 671742817..c290f965a 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -19,10 +19,9 @@ package org.apache.spark.shuffle.writer;
import java.util.ArrayList;
import java.util.Collections;
-import java.util.Iterator;
+import java.util.Comparator;
import java.util.List;
import java.util.Map;
-import java.util.Map.Entry;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
@@ -97,6 +96,7 @@ public class WriteBufferManager extends MemoryConsumer {
private int memorySpillTimeoutSec;
private boolean isRowBased;
private BlockIdLayout blockIdLayout;
+ private double bufferSpillRatio;
private Function<Integer, List<ShuffleServerInfo>>
partitionAssignmentRetrieveFunc;
public WriteBufferManager(
@@ -162,6 +162,7 @@ public class WriteBufferManager extends MemoryConsumer {
this.sendSizeLimit =
rssConf.get(RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMITATION);
this.memorySpillTimeoutSec =
rssConf.get(RssSparkConfig.RSS_MEMORY_SPILL_TIMEOUT);
this.memorySpillEnabled =
rssConf.get(RssSparkConfig.RSS_MEMORY_SPILL_ENABLED);
+ this.bufferSpillRatio = rssConf.get(RssSparkConfig.RSS_MEMORY_SPILL_RATIO);
this.blockIdLayout = BlockIdLayout.from(rssConf);
this.partitionAssignmentRetrieveFunc = partitionAssignmentRetrieveFunc;
}
@@ -204,13 +205,12 @@ public class WriteBufferManager extends MemoryConsumer {
// check buffer size > spill threshold
if (usedBytes.get() - inSendListBytes.get() > spillSize) {
LOG.info(
- "ShuffleBufferManager spill for buffer size exceeding spill
threshold,"
- + "usedBytes[{}],inSendListBytes[{}],spillSize[{}]",
+ "ShuffleBufferManager spill for buffer size exceeding spill
threshold, "
+ + "usedBytes[{}], inSendListBytes[{}], spill size threshold[{}]",
usedBytes.get(),
inSendListBytes.get(),
spillSize);
- List<ShuffleBlockInfo> multiSendingBlocks = clear();
-
+ List<ShuffleBlockInfo> multiSendingBlocks = clear(bufferSpillRatio);
multiSendingBlocks.addAll(singleOrEmptySendingBlocks);
writeTime += System.currentTimeMillis() - start;
return multiSendingBlocks;
@@ -323,20 +323,34 @@ public class WriteBufferManager extends MemoryConsumer {
}
// transform all [partition, records] to [partition, ShuffleBlockInfo] and
clear cache
- public synchronized List<ShuffleBlockInfo> clear() {
+ public synchronized List<ShuffleBlockInfo> clear(double bufferSpillRatio) {
List<ShuffleBlockInfo> result = Lists.newArrayList();
long dataSize = 0;
long memoryUsed = 0;
- Iterator<Entry<Integer, WriterBuffer>> iterator =
buffers.entrySet().iterator();
- while (iterator.hasNext()) {
- Entry<Integer, WriterBuffer> entry = iterator.next();
- WriterBuffer wb = entry.getValue();
+ bufferSpillRatio = Math.max(0.1, Math.min(1.0, bufferSpillRatio));
+ List<Integer> partitionList = new ArrayList(buffers.keySet());
+ if (Double.compare(bufferSpillRatio, 1.0) < 0) {
+ partitionList.sort(
+ Comparator.comparingInt(o -> buffers.get(o) == null ? 0 :
buffers.get(o).getMemoryUsed())
+ .reversed());
+ }
+ long targetSpillSize = (long) ((usedBytes.get() - inSendListBytes.get()) *
bufferSpillRatio);
+ for (int partitionId : partitionList) {
+ WriterBuffer wb = buffers.get(partitionId);
+ if (wb == null) {
+ LOG.warn("get partition buffer failed,this should not happen!");
+ continue;
+ }
dataSize += wb.getDataLength();
memoryUsed += wb.getMemoryUsed();
- result.add(createShuffleBlock(entry.getKey(), wb));
+ result.add(createShuffleBlock(partitionId, wb));
recordCounter.addAndGet(wb.getRecordCount());
- iterator.remove();
copyTime += wb.getCopyTime();
+ buffers.remove(partitionId);
+ // got enough buffer to spill
+ if (memoryUsed >= targetSpillSize) {
+ break;
+ }
}
LOG.info(
"Flush total buffer for shuffleId["
@@ -349,6 +363,8 @@ public class WriteBufferManager extends MemoryConsumer {
+ memoryUsed
+ "], number of blocks["
+ result.size()
+ + "], flush ratio["
+ + bufferSpillRatio
+ "]");
return result;
}
@@ -491,7 +507,7 @@ public class WriteBufferManager extends MemoryConsumer {
return 0L;
}
- List<CompletableFuture<Long>> futures = spillFunc.apply(clear());
+ List<CompletableFuture<Long>> futures =
spillFunc.apply(clear(bufferSpillRatio));
CompletableFuture<Void> allOfFutures =
CompletableFuture.allOf(futures.toArray(new
CompletableFuture[futures.size()]));
try {
diff --git
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
index 22143bc0e..4734f442c 100644
---
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
+++
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
@@ -185,7 +185,7 @@ public class WriteBufferManagerTest {
wbm.addRecord(0, testKey, testValue);
wbm.addRecord(1, testKey, testValue);
wbm.addRecord(2, testKey, testValue);
- result = wbm.clear();
+ result = wbm.clear(1.0);
assertEquals(3, result.size());
assertEquals(224, wbm.getAllocatedBytes());
assertEquals(96, wbm.getUsedBytes());
@@ -433,6 +433,56 @@ public class WriteBufferManagerTest {
Awaitility.await().timeout(5, TimeUnit.SECONDS).until(() ->
spyManager.getUsedBytes() == 0);
}
+ @Test
+ public void spillPartial() {
+ SparkConf conf = getConf();
+ conf.set("spark.rss.client.send.size.limit", "1000");
+ conf.set("spark.rss.client.memory.spill.ratio", "0.5");
+ conf.set("spark.rss.client.memory.spill.enabled", "true");
+ TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class);
+ BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
+
+ WriteBufferManager wbm =
+ new WriteBufferManager(
+ 0,
+ "taskId_spillPartialTest",
+ 0,
+ bufferOptions,
+ new KryoSerializer(conf),
+ Maps.newHashMap(),
+ mockTaskMemoryManager,
+ new ShuffleWriteMetrics(),
+ RssSparkConfig.toRssConf(conf),
+ null);
+
+ Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc =
+ blocks -> {
+ long sum = 0L;
+ List<AddBlockEvent> events = wbm.buildBlockEvents(blocks);
+ for (AddBlockEvent event : events) {
+ event.getProcessedCallbackChain().stream().forEach(x -> x.run());
+ sum += event.getShuffleDataInfoList().stream().mapToLong(x ->
x.getFreeMemory()).sum();
+ }
+ return Arrays.asList(CompletableFuture.completedFuture(sum));
+ };
+ wbm.setSpillFunc(spillFunc);
+
+ when(wbm.acquireMemory(512)).thenReturn(512L);
+
+ String testKey = "Key";
+ String testValue = "Value";
+ wbm.addRecord(0, testKey, testValue);
+ wbm.addRecord(1, testKey, testValue);
+ wbm.addRecord(1, testKey, testValue);
+ wbm.addRecord(1, testKey, testValue);
+ wbm.addRecord(1, testKey, testValue);
+
+ long releasedSize = wbm.spill(1000, wbm);
+ assertEquals(64, releasedSize);
+ assertEquals(96, wbm.getUsedBytes());
+ assertEquals(0, wbm.getBuffers().keySet().toArray()[0]);
+ }
+
public static class FakedTaskMemoryManager extends TaskMemoryManager {
private static final Logger LOGGER =
LoggerFactory.getLogger(FakedTaskMemoryManager.class);
private int invokedCnt = 0;
diff --git
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index c38f159d3..2e116c72c 100644
---
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -264,7 +264,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
}
final long start = System.currentTimeMillis();
- shuffleBlockInfos = bufferManager.clear();
+ shuffleBlockInfos = bufferManager.clear(1.0);
processShuffleBlockInfos(shuffleBlockInfos);
long s = System.currentTimeMillis();
checkSentRecordCount(recordCount);
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 06c877275..3cdcf9aa8 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -307,7 +307,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
}
}
final long start = System.currentTimeMillis();
- shuffleBlockInfos = bufferManager.clear();
+ shuffleBlockInfos = bufferManager.clear(1.0);
if (shuffleBlockInfos != null && !shuffleBlockInfos.isEmpty()) {
processShuffleBlockInfos(shuffleBlockInfos);
}