This is an automated email from the ASF dual-hosted git repository.
roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 649118487 [#1748] feat(remote merge): Introduce MergeManager to merge
records on the server side. (#1946)
649118487 is described below
commit 649118487a8e25e8008c1b90d1d750f72f730c1a
Author: zhengchenyu <[email protected]>
AuthorDate: Mon Aug 12 10:23:56 2024 +0800
[#1748] feat(remote merge): Introduce MergeManager to merge records on the
server side. (#1946)
### What changes were proposed in this pull request?
Merge records on the server side so that clients can read sorted
records.
### Why are the changes needed?
Fix: #1748
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Unit tests, and tested in a cluster.
---
.../apache/uniffle/common/ShuffleIndexResult.java | 8 +-
.../org/apache/uniffle/common/config/RssConf.java | 12 +
.../response/RssGetShuffleIndexResponse.java | 2 +-
.../org/apache/uniffle/server/ShuffleServer.java | 21 +-
.../apache/uniffle/server/ShuffleServerConf.java | 56 +++
.../uniffle/server/ShuffleServerMetrics.java | 3 +
.../apache/uniffle/server/ShuffleTaskManager.java | 20 +-
.../uniffle/server/merge/BlockFlushFileReader.java | 387 +++++++++++++++++
.../server/merge/DefaultMergeEventHandler.java | 111 +++++
.../apache/uniffle/server/merge/MergeEvent.java | 88 ++++
.../uniffle/server/merge/MergeEventHandler.java | 18 +-
.../apache/uniffle/server/merge/MergeStatus.java | 26 +-
.../apache/uniffle/server/merge/MergedResult.java | 109 +++++
.../org/apache/uniffle/server/merge/Partition.java | 478 +++++++++++++++++++++
.../org/apache/uniffle/server/merge/Shuffle.java | 100 +++++
.../uniffle/server/merge/ShuffleMergeManager.java | 293 +++++++++++++
.../server/merge/BlockFlushFileReaderTest.java | 257 +++++++++++
.../uniffle/server/merge/MergedResultTest.java | 176 ++++++++
.../server/merge/ShuffleMergeManagerTest.java | 223 ++++++++++
server/src/test/resources/log4j2.xml | 29 ++
.../handler/impl/LocalFileServerReadHandler.java | 11 +-
21 files changed, 2398 insertions(+), 30 deletions(-)
diff --git
a/common/src/main/java/org/apache/uniffle/common/ShuffleIndexResult.java
b/common/src/main/java/org/apache/uniffle/common/ShuffleIndexResult.java
index 71bb3df39..c90f8997e 100644
--- a/common/src/main/java/org/apache/uniffle/common/ShuffleIndexResult.java
+++ b/common/src/main/java/org/apache/uniffle/common/ShuffleIndexResult.java
@@ -28,6 +28,7 @@ import org.apache.uniffle.common.util.ByteBufUtils;
public class ShuffleIndexResult {
private final ManagedBuffer buffer;
private long dataFileLen;
+ private String dataFileName;
public ShuffleIndexResult() {
this(ByteBuffer.wrap(new byte[0]), -1);
@@ -43,9 +44,10 @@ public class ShuffleIndexResult {
this.dataFileLen = dataFileLen;
}
- public ShuffleIndexResult(ManagedBuffer buffer, long dataFileLen) {
+ public ShuffleIndexResult(ManagedBuffer buffer, long dataFileLen, String
dataFileName) {
this.buffer = buffer;
this.dataFileLen = dataFileLen;
+ this.dataFileName = dataFileName;
}
public byte[] getData() {
@@ -79,4 +81,8 @@ public class ShuffleIndexResult {
public ManagedBuffer getManagedBuffer() {
return buffer;
}
+
+  /**
+   * Returns the data file name supplied when this result was created; may be {@code null}, e.g.
+   * for results built from a client-side index response.
+   */
+  public String getDataFileName() {
+    return this.dataFileName;
+  }
}
diff --git a/common/src/main/java/org/apache/uniffle/common/config/RssConf.java
b/common/src/main/java/org/apache/uniffle/common/config/RssConf.java
index b77a50b23..74d1c2bdb 100644
--- a/common/src/main/java/org/apache/uniffle/common/config/RssConf.java
+++ b/common/src/main/java/org/apache/uniffle/common/config/RssConf.java
@@ -18,6 +18,7 @@
package org.apache.uniffle.common.config;
import java.util.Arrays;
+import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
@@ -682,4 +683,15 @@ public class RssConf implements Cloneable {
public void remove(String key) {
this.settings.remove(key);
}
+
+  /**
+   * Collects every setting whose key starts with {@code confPrefix}.
+   *
+   * @param confPrefix the key prefix to match; it is stripped from the keys of the returned map
+   * @return a new mutable map from the remainder of each matching key to its setting value
+   */
+  public Map<String, Object> getPropsWithPrefix(String confPrefix) {
+    Map<String, Object> matched = new HashMap<>();
+    settings.forEach(
+        (key, value) -> {
+          if (key.startsWith(confPrefix)) {
+            matched.put(key.substring(confPrefix.length()), value);
+          }
+        });
+    return matched;
+  }
}
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleIndexResponse.java
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleIndexResponse.java
index 37a31652e..4d3667ab1 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleIndexResponse.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleIndexResponse.java
@@ -26,7 +26,7 @@ public class RssGetShuffleIndexResponse extends
ClientResponse {
public RssGetShuffleIndexResponse(StatusCode statusCode, ManagedBuffer data,
long dataFileLen) {
super(statusCode);
- this.shuffleIndexResult = new ShuffleIndexResult(data, dataFileLen);
+ this.shuffleIndexResult = new ShuffleIndexResult(data, dataFileLen, null);
}
public ShuffleIndexResult getShuffleIndexResult() {
diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServer.java
b/server/src/main/java/org/apache/uniffle/server/ShuffleServer.java
index ee790bad0..60be04b65 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleServer.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServer.java
@@ -57,6 +57,7 @@ import org.apache.uniffle.common.util.ThreadUtils;
import org.apache.uniffle.common.web.CoalescedCollectorRegistry;
import org.apache.uniffle.common.web.JettyServer;
import org.apache.uniffle.server.buffer.ShuffleBufferManager;
+import org.apache.uniffle.server.merge.ShuffleMergeManager;
import org.apache.uniffle.server.netty.StreamServer;
import org.apache.uniffle.server.storage.StorageManager;
import org.apache.uniffle.server.storage.StorageManagerFactory;
@@ -90,6 +91,8 @@ public class ShuffleServer {
private ShuffleFlushManager shuffleFlushManager;
private ShuffleBufferManager shuffleBufferManager;
private StorageManager storageManager;
+ private boolean remoteMergeEnable;
+ private ShuffleMergeManager shuffleMergeManager;
private HealthCheck healthCheck;
private Set<String> tags = Sets.newHashSet();
private GRPCMetrics grpcMetrics;
@@ -305,9 +308,17 @@ public class ShuffleServer {
shuffleFlushManager = new ShuffleFlushManager(shuffleServerConf, this,
storageManager);
shuffleBufferManager =
new ShuffleBufferManager(shuffleServerConf, shuffleFlushManager,
nettyServerEnabled);
+ remoteMergeEnable =
shuffleServerConf.get(ShuffleServerConf.SERVER_MERGE_ENABLE);
+ if (remoteMergeEnable) {
+ shuffleMergeManager = new ShuffleMergeManager(shuffleServerConf, this);
+ }
shuffleTaskManager =
new ShuffleTaskManager(
- shuffleServerConf, shuffleFlushManager, shuffleBufferManager,
storageManager);
+ shuffleServerConf,
+ shuffleFlushManager,
+ shuffleBufferManager,
+ storageManager,
+ shuffleMergeManager);
shuffleTaskManager.start();
setServer();
@@ -569,4 +580,12 @@ public class ShuffleServer {
shuffleServer.getJettyPort(),
shuffleServer.getStartTimeMs());
}
+
+  /** Returns the merge manager; it is {@code null} unless remote merge is enabled. */
+  public ShuffleMergeManager getShuffleMergeManager() {
+    return this.shuffleMergeManager;
+  }
+
+  /** Returns whether remote merge ({@code rss.server.merge.enable}) is enabled on this server. */
+  public boolean isRemoteMergeEnable() {
+    return this.remoteMergeEnable;
+  }
}
diff --git
a/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java
b/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java
index 5b7aad8ed..eef889690 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java
@@ -659,6 +659,62 @@ public class ShuffleServerConf extends RssBaseConf {
.defaultValue(10 * 60L)
.withDescription("The storage remove resource operation timeout.");
+  public static final ConfigOption<Boolean> SERVER_MERGE_ENABLE =
+      ConfigOptions.key("rss.server.merge.enable")
+          .booleanType()
+          .defaultValue(false)
+          .withDescription("Whether to enable remote merge");
+
+  public static final ConfigOption<Integer> SERVER_MERGE_THREAD_POOL_SIZE =
+      ConfigOptions.key("rss.server.merge.threadPoolSize")
+          .intType()
+          .defaultValue(10)
+          .withDescription("The number of threads in the merge thread pool");
+
+  public static final ConfigOption<Integer> SERVER_MERGE_THREAD_POOL_QUEUE_SIZE =
+      ConfigOptions.key("rss.server.merge.threadPoolQueueSize")
+          .intType()
+          .defaultValue(Integer.MAX_VALUE)
+          .withDescription("The capacity of the waiting queue of the merge thread pool");
+
+  public static final ConfigOption<Integer> SERVER_MERGE_THREAD_ALIVE_TIME =
+      ConfigOptions.key("rss.server.merge.threadAliveTime")
+          .intType()
+          .defaultValue(120)
+          .withDescription(
+              "The keep-alive time in seconds for idle threads in the merge thread pool");
+
+  public static final ConfigOption<String> SERVER_MERGE_DEFAULT_MERGED_BLOCK_SIZE =
+      ConfigOptions.key("rss.server.merge.defaultMergedBlockSize")
+          .stringType()
+          .defaultValue("14m")
+          .withDescription("The default merged block size.");
+
+  public static final ConfigOption<Long> SERVER_MERGE_CACHE_MERGED_BLOCK_INIT_SLEEP_MS =
+      ConfigOptions.key("rss.server.merge.cacheMergedBlockInitSleepMs")
+          .longType()
+          .defaultValue(100L)
+          .withDescription(
+              "When caching a merged block, the minimum time in milliseconds to wait after "
+                  + "failing to acquire memory");
+
+  public static final ConfigOption<Long> SERVER_MERGE_CACHE_MERGED_BLOCK_MAX_SLEEP_MS =
+      ConfigOptions.key("rss.server.merge.cacheMergedBlockMaxSleepMs")
+          .longType()
+          .defaultValue(2000L)
+          .withDescription(
+              "When caching a merged block, the maximum time in milliseconds to wait after "
+                  + "failing to acquire memory");
+
+  public static final ConfigOption<Integer> SERVER_MERGE_BLOCK_RING_BUFFER_SIZE =
+      ConfigOptions.key("rss.server.merge.blockRingBufferSize")
+          .intType()
+          .defaultValue(2)
+          .withDescription(
+              "The number of buffers in the ring buffer used to read each block when merging; "
+                  + "must be a power of two");
+
+  public static final ConfigOption<String> SERVER_MERGE_CLASS_LOADER_JARS_PATH =
+      ConfigOptions.key("rss.server.merge.classLoaderJarsPath")
+          .stringType()
+          .defaultValue(null)
+          .withDescription("The path of the jars loaded by the class loader used when merging");
+
public ShuffleServerConf() {}
public ShuffleServerConf(String fileName) {
diff --git
a/server/src/main/java/org/apache/uniffle/server/ShuffleServerMetrics.java
b/server/src/main/java/org/apache/uniffle/server/ShuffleServerMetrics.java
index bac820520..3e886407c 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerMetrics.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerMetrics.java
@@ -58,6 +58,7 @@ public class ShuffleServerMetrics {
private static final String EVENT_SIZE_THRESHOLD_LEVEL3 =
"event_size_threshold_level3";
private static final String EVENT_SIZE_THRESHOLD_LEVEL4 =
"event_size_threshold_level4";
private static final String EVENT_QUEUE_SIZE = "event_queue_size";
+ private static final String MERGE_EVENT_QUEUE_SIZE =
"merge_event_queue_size";
private static final String HADOOP_FLUSH_THREAD_POOL_QUEUE_SIZE =
"hadoop_flush_thread_pool_queue_size";
private static final String LOCALFILE_FLUSH_THREAD_POOL_QUEUE_SIZE =
@@ -222,6 +223,7 @@ public class ShuffleServerMetrics {
public static Gauge.Child gaugeUsedDirectMemorySizeByGrpcNetty;
public static Gauge.Child gaugeWriteHandler;
public static Gauge.Child gaugeEventQueueSize;
+ public static Gauge.Child gaugeMergeEventQueueSize;
public static Gauge.Child gaugeHadoopFlushThreadPoolQueueSize;
public static Gauge.Child gaugeLocalfileFlushThreadPoolQueueSize;
public static Gauge.Child gaugeFallbackFlushThreadPoolQueueSize;
@@ -454,6 +456,7 @@ public class ShuffleServerMetrics {
metricsManager.addLabeledGauge(USED_DIRECT_MEMORY_SIZE_BY_GRPC_NETTY);
gaugeWriteHandler = metricsManager.addLabeledGauge(TOTAL_WRITE_HANDLER);
gaugeEventQueueSize = metricsManager.addLabeledGauge(EVENT_QUEUE_SIZE);
+ gaugeMergeEventQueueSize =
metricsManager.addLabeledGauge(MERGE_EVENT_QUEUE_SIZE);
gaugeHadoopFlushThreadPoolQueueSize =
metricsManager.addLabeledGauge(HADOOP_FLUSH_THREAD_POOL_QUEUE_SIZE);
gaugeLocalfileFlushThreadPoolQueueSize =
diff --git
a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
index 8dc1653ed..226682e63 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
@@ -79,6 +79,7 @@ import org.apache.uniffle.server.event.AppPurgeEvent;
import org.apache.uniffle.server.event.AppUnregisterPurgeEvent;
import org.apache.uniffle.server.event.PurgeEvent;
import org.apache.uniffle.server.event.ShufflePurgeEvent;
+import org.apache.uniffle.server.merge.ShuffleMergeManager;
import org.apache.uniffle.server.storage.StorageManager;
import org.apache.uniffle.storage.common.Storage;
import org.apache.uniffle.storage.common.StorageReadMetrics;
@@ -119,17 +120,28 @@ public class ShuffleTaskManager {
private BlockingQueue<PurgeEvent> expiredAppIdQueue =
Queues.newLinkedBlockingQueue();
private final Cache<String, ReentrantReadWriteLock> appLocks;
private final long storageRemoveOperationTimeoutSec;
+ private ShuffleMergeManager shuffleMergeManager;
public ShuffleTaskManager(
ShuffleServerConf conf,
ShuffleFlushManager shuffleFlushManager,
ShuffleBufferManager shuffleBufferManager,
StorageManager storageManager) {
+ this(conf, shuffleFlushManager, shuffleBufferManager, storageManager,
null);
+ }
+
+ public ShuffleTaskManager(
+ ShuffleServerConf conf,
+ ShuffleFlushManager shuffleFlushManager,
+ ShuffleBufferManager shuffleBufferManager,
+ StorageManager storageManager,
+ ShuffleMergeManager shuffleMergeManager) {
this.conf = conf;
this.shuffleFlushManager = shuffleFlushManager;
this.partitionsToBlockIds = JavaUtils.newConcurrentMap();
this.shuffleBufferManager = shuffleBufferManager;
this.storageManager = storageManager;
+ this.shuffleMergeManager = shuffleMergeManager;
this.appExpiredWithoutHB =
conf.getLong(ShuffleServerConf.SERVER_APP_EXPIRED_WITHOUT_HEARTBEAT);
this.commitCheckIntervalMax =
conf.getLong(ShuffleServerConf.SERVER_COMMIT_CHECK_INTERVAL_MAX);
this.preAllocationExpired =
conf.getLong(ShuffleServerConf.SERVER_PRE_ALLOCATION_EXPIRED);
@@ -804,7 +816,9 @@ public class ShuffleTaskManager {
},
storageRemoveOperationTimeoutSec,
operationMsg);
-
+ if (shuffleMergeManager != null) {
+ shuffleMergeManager.removeBuffer(appId, shuffleIds);
+ }
LOG.info(
"Finish remove resource for appId[{}], shuffleIds[{}], cost[{}]",
appId,
@@ -862,7 +876,9 @@ public class ShuffleTaskManager {
},
storageRemoveOperationTimeoutSec,
operationMsg);
-
+ if (shuffleMergeManager != null) {
+ shuffleMergeManager.removeBuffer(appId);
+ }
if (shuffleTaskInfo.hasHugePartition()) {
ShuffleServerMetrics.gaugeAppWithHugePartitionNum.dec();
ShuffleServerMetrics.gaugeHugePartitionNum.dec();
diff --git
a/server/src/main/java/org/apache/uniffle/server/merge/BlockFlushFileReader.java
b/server/src/main/java/org/apache/uniffle/server/merge/BlockFlushFileReader.java
new file mode 100644
index 000000000..76fa2cc1d
--- /dev/null
+++
b/server/src/main/java/org/apache/uniffle/server/merge/BlockFlushFileReader.java
@@ -0,0 +1,387 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.server.merge;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.FileChannel;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.concurrent.locks.ReentrantLock;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.netty.buffer.FileSegmentManagedBuffer;
+import org.apache.uniffle.common.serializer.PartialInputStream;
+import org.apache.uniffle.common.util.JavaUtils;
+import org.apache.uniffle.storage.common.FileBasedShuffleSegment;
+
+/**
+ * Reads the blocks of one partition's flushed data file, one stream per block, for Remote Merge.
+ *
+ * <p>Remote Merge merges the original blocks into a new block set starting with id 1. For now, all
+ * blocks of a partition are written into a single file, yet Remote Merge needs to read the content
+ * of each block separately. Managing each block with its own file handle would waste a large number
+ * of open files, so BlockFlushFileReader was introduced.
+ *
+ * <p>BlockFlushFileReader uses a single file handle for all blocks of the partition. The
+ * FlushFileReader thread reads this file: on every pass it loads the next part of each registered
+ * block whose buffers are not full, visiting blocks in index-file order. As long as the index lists
+ * blocks by increasing offset, reads therefore move forward through the file, which reduces random
+ * reads compared to opening a file per block. BlockInputStream reads the buffers that belong to its
+ * block, and a per-block ring buffer balances the buffers produced by FlushFileReader against those
+ * consumed by BlockInputStream.
+ *
+ * <p>The thread that constructs a reader acquires its lock and is expected to be the same thread
+ * that reads from every BlockInputStream of that reader (the merge thread).
+ */
+public class BlockFlushFileReader {
+
+  private static final Logger LOG = LoggerFactory.getLogger(BlockFlushFileReader.class);
+  // Size in bytes of each slot of a ring buffer.
+  private static final int BUFFER_SIZE = 4096;
+
+  private String dataFile;
+  private FileInputStream dataInput;
+  private FileChannel dataFileChannel;
+  // Set by close() on the merge thread or by FlushFileReader on failure, and read by both
+  // threads, hence volatile.
+  volatile boolean stop = false;
+
+  // blockId -> BlockInputStream
+  private final Map<Long, BlockInputStream> inputStreamMap = JavaUtils.newConcurrentMap();
+  // blockId -> index segment, in the order the segments appear in the index file.
+  private final LinkedHashMap<Long, FileBasedShuffleSegment> indexSegments = new LinkedHashMap<>();
+
+  private FlushFileReader flushFileReader;
+  // The failure that stopped FlushFileReader; reported to readers in their IOException message.
+  private volatile Throwable readThrowable = null;
+  // Although there may be many BlockInputStreams, they must all be read from the same thread, which
+  // we call the merge thread. The merge thread holds this lock while it consumes buffers. When the
+  // buffers of a BlockInputStream have been drained, the merge thread unlocks so that
+  // FlushFileReader can take the lock and load more data; the merge thread then blocks until
+  // FlushFileReader has finished loading and unlocks again. The lock is fair, presumably so that
+  // ownership alternates between the two threads.
+  private final ReentrantLock lock = new ReentrantLock(true);
+
+  private final int ringBufferSize;
+  // ringBufferSize - 1; usable as an index mask only because ringBufferSize is a power of two.
+  private final int mask;
+
+  /**
+   * Loads the block index, opens the data file and starts the FlushFileReader thread. The calling
+   * thread acquires the internal lock and is expected to be the thread that later reads from the
+   * block streams.
+   *
+   * @param dataFile path of the flushed data file
+   * @param indexFile path of the index file describing the blocks in {@code dataFile}
+   * @param ringBufferSize number of 4 KB buffers per block; must be a positive power of two
+   *     because ring indexes are mapped to slots with a bit mask
+   * @throws IOException if the data file cannot be opened
+   * @throws IllegalArgumentException if {@code ringBufferSize} is not a positive power of two
+   */
+  public BlockFlushFileReader(String dataFile, String indexFile, int ringBufferSize)
+      throws IOException {
+    // A non power of two would make different ring indexes alias the same slot and corrupt data.
+    if (ringBufferSize <= 0 || (ringBufferSize & (ringBufferSize - 1)) != 0) {
+      throw new IllegalArgumentException(
+          "ringBufferSize must be a positive power of two, but was " + ringBufferSize);
+    }
+    // The flush file must not be updated while this reader is open.
+    this.ringBufferSize = ringBufferSize;
+    this.mask = ringBufferSize - 1;
+    loadShuffleIndex(indexFile);
+    this.dataFile = dataFile;
+    this.dataInput = new FileInputStream(dataFile);
+    this.dataFileChannel = dataInput.getChannel();
+    // Take the lock before starting FlushFileReader so that it does not loop doing nothing
+    // before the first read.
+    this.lock.lock();
+    this.flushFileReader = new FlushFileReader();
+    this.flushFileReader.start();
+  }
+
+  /**
+   * Loads the block index from {@code indexFileName} into {@code indexSegments}. Each record is
+   * {@code FileBasedShuffleSegment.SEGMENT_SIZE} bytes and is decoded as offset (long), length
+   * (int), uncompressed length (int), crc (long), block id (long) and task attempt id (long). A
+   * trailing partial record is ignored.
+   *
+   * @param indexFileName path of the index file
+   */
+  public void loadShuffleIndex(String indexFileName) {
+    File indexFile = new File(indexFileName);
+    long indexFileSize = indexFile.length();
+    int indexNum = (int) (indexFileSize / FileBasedShuffleSegment.SEGMENT_SIZE);
+    int len = indexNum * FileBasedShuffleSegment.SEGMENT_SIZE;
+    ByteBuffer indexData = new FileSegmentManagedBuffer(indexFile, 0, len).nioByteBuffer();
+    while (indexData.hasRemaining()) {
+      long offset = indexData.getLong();
+      int length = indexData.getInt();
+      int uncompressLength = indexData.getInt();
+      long crc = indexData.getLong();
+      long blockId = indexData.getLong();
+      long taskAttemptId = indexData.getLong();
+      FileBasedShuffleSegment fileBasedShuffleSegment =
+          new FileBasedShuffleSegment(
+              blockId, offset, length, uncompressLength, crc, taskAttemptId);
+      indexSegments.put(fileBasedShuffleSegment.getBlockId(), fileBasedShuffleSegment);
+    }
+  }
+
+  /**
+   * Stops FlushFileReader and closes the data file. Calling it again has no further effect.
+   *
+   * @throws IOException if closing the data file fails
+   * @throws InterruptedException declared for callers; the current implementation does not throw it
+   */
+  public void close() throws IOException, InterruptedException {
+    if (!this.stop) {
+      stop = true;
+      flushFileReader.interrupt();
+      flushFileReader = null;
+    }
+    if (dataInput != null) {
+      this.dataInput.close();
+      this.dataInput = null;
+      this.dataFile = null;
+    }
+  }
+
+  /**
+   * Returns the stream for {@code blockId}, creating it on first use.
+   *
+   * @param blockId id of a block listed in the index file
+   * @return the block's stream, or {@code null} if the block is not in the index (or its stream has
+   *     already been closed)
+   */
+  public BlockInputStream registerBlockInputStream(long blockId) {
+    FileBasedShuffleSegment segment = indexSegments.get(blockId);
+    if (segment == null) {
+      return null;
+    }
+    return inputStreamMap.computeIfAbsent(
+        blockId, id -> new BlockInputStream(id, segment.getLength()));
+  }
+
+  /** Background thread that loads block data from the data file into the block ring buffers. */
+  class FlushFileReader extends Thread {
+
+    FlushFileReader() {
+      setName("FlushFileReader-" + dataFile);
+      // Never keep the JVM alive just because a reader was not closed.
+      setDaemon(true);
+    }
+
+    @Override
+    public void run() {
+      while (!stop) {
+        int available = 0;
+        int process = 0;
+        try {
+          lock.lockInterruptibly();
+          try {
+            Iterator<Map.Entry<Long, FileBasedShuffleSegment>> iterator =
+                indexSegments.entrySet().iterator();
+            while (iterator.hasNext()) {
+              FileBasedShuffleSegment segment = iterator.next().getValue();
+              long blockId = segment.getBlockId();
+              BlockInputStream inputStream = inputStreamMap.get(blockId);
+              if (inputStream == null || inputStream.eof) {
+                continue;
+              }
+              available++;
+              if (inputStream.isBufferFull()) {
+                continue;
+              }
+              process++;
+              long off = segment.getOffset() + inputStream.getOffsetInThisBlock();
+              if (dataFileChannel.position() != off) {
+                dataFileChannel.position(off);
+              }
+              inputStream.writeBuffer();
+            }
+          } catch (Throwable throwable) {
+            readThrowable = throwable;
+            LOG.info("FlushFileReader read failed, caused by ", throwable);
+            stop = true;
+          } finally {
+            lock.unlock();
+            if (LOG.isDebugEnabled()) {
+              LOG.debug(
+                  "statistics: load buffer available is {}, process is {}", available, process);
+            }
+          }
+        } catch (InterruptedException e) {
+          if (LOG.isDebugEnabled()) {
+            LOG.debug("FlushFileReader for {} have been interrupted.", dataFile);
+          }
+        }
+      }
+    }
+  }
+
+  /** One slot of a ring buffer: holds up to {@code BUFFER_SIZE} bytes read from the data file. */
+  class Buffer {
+
+    private final byte[] bytes = new byte[BUFFER_SIZE];
+    // Number of valid bytes currently held.
+    private int cap = BUFFER_SIZE;
+    // Index of the next byte to return; pos == cap means the slot has been drained.
+    private int pos = cap;
+
+    public int get() {
+      return this.bytes[pos++] & 0xFF;
+    }
+
+    public int get(byte[] bs, int off, int len) {
+      int r = Math.min(cap - pos, len);
+      System.arraycopy(bytes, pos, bs, off, r);
+      pos += r;
+      return r;
+    }
+
+    public boolean readable() {
+      return pos < cap;
+    }
+
+    /**
+     * Fills this slot with exactly {@code length} bytes read from the current position of the data
+     * file channel. {@code FileChannel.read} may return fewer bytes than requested, so it is called
+     * until the slot is full.
+     *
+     * @throws IOException if reading fails or the file ends before {@code length} bytes were read
+     */
+    public void writeBuffer(int length) throws IOException {
+      ByteBuffer target = ByteBuffer.wrap(this.bytes, 0, length);
+      while (target.hasRemaining()) {
+        if (dataFileChannel.read(target) < 0) {
+          throw new IOException(
+              "Unexpected end of file "
+                  + dataFile
+                  + ", still expected "
+                  + target.remaining()
+                  + " bytes");
+        }
+      }
+      this.pos = 0;
+      this.cap = length;
+    }
+  }
+
+  /** Fixed-size ring of {@code ringBufferSize} buffers belonging to a single block. */
+  class RingBuffer {
+
+    Buffer[] buffers;
+    // Both indexes only ever increase and are mapped onto a slot with "& mask". Overflowing an int
+    // would need a block of about 8 TB (2^31 slots of 4 KB), which a block never reaches, so the
+    // indexes are not wrapped explicitly.
+    int readIndex = 0;
+    int writeIndex = 0;
+
+    RingBuffer() {
+      this.buffers = new Buffer[ringBufferSize];
+      for (int i = 0; i < ringBufferSize; i++) {
+        this.buffers[i] = new Buffer();
+      }
+    }
+
+    boolean full() {
+      return (writeIndex - readIndex) == ringBufferSize;
+    }
+
+    boolean empty() {
+      return writeIndex == readIndex;
+    }
+
+    /**
+     * Loads up to {@code available} more bytes of the block into free slots.
+     *
+     * @param available bytes of the block that have not been loaded yet
+     * @return the number of bytes actually loaded, never more than {@code available}
+     */
+    int write(int available) throws IOException {
+      int left = available;
+      while (!full() && left > 0) {
+        // Bound each slot by what is left of the block; bounding by the original amount would let
+        // the last slot read past the end of this block into the next one.
+        int size = Math.min(left, BUFFER_SIZE);
+        this.buffers[writeIndex & mask].writeBuffer(size);
+        left -= size;
+        writeIndex++;
+      }
+      return available - left;
+    }
+
+    /** Returns the next byte; the caller must make sure the ring is not empty. */
+    int read() {
+      int ret = this.buffers[readIndex & mask].get();
+      if (!this.buffers[readIndex & mask].readable()) {
+        readIndex++;
+      }
+      return ret;
+    }
+
+    /** Copies up to {@code len} buffered bytes into {@code bs}; returns the number copied. */
+    int read(byte[] bs, int off, int len) {
+      int total = 0;
+      int end = off + len;
+      while (off < end && !this.empty()) {
+        Buffer buffer = this.buffers[readIndex & mask];
+        int r = buffer.get(bs, off, len);
+        if (!buffer.readable()) {
+          readIndex++;
+        }
+        off += r;
+        len -= r;
+        total += r;
+      }
+      return total;
+    }
+  }
+
+  /** Input stream over the bytes of a single block, fed by FlushFileReader via a ring buffer. */
+  public class BlockInputStream extends PartialInputStream {
+
+    private final long blockId;
+    private final RingBuffer ringBuffer;
+    // A zero-length block has nothing to load, so it starts at EOF instead of waiting forever.
+    private boolean eof;
+    private final int length;
+    private int pos = 0;
+    private int offsetInThisBlock = 0;
+
+    public BlockInputStream(long blockId, int length) {
+      this.blockId = blockId;
+      this.length = length;
+      this.eof = length <= 0;
+      this.ringBuffer = new RingBuffer();
+    }
+
+    @Override
+    public int available() throws IOException {
+      return length - pos;
+    }
+
+    @Override
+    public long getStart() {
+      return 0;
+    }
+
+    @Override
+    public long getEnd() {
+      return length;
+    }
+
+    public long getOffsetInThisBlock() {
+      return this.offsetInThisBlock;
+    }
+
+    /** Unregisters this block; closes the whole reader once no block stream remains. */
+    @Override
+    public void close() throws IOException {
+      try {
+        inputStreamMap.remove(blockId);
+        indexSegments.remove(blockId);
+        if (inputStreamMap.size() == 0) {
+          BlockFlushFileReader.this.close();
+        }
+      } catch (InterruptedException e) {
+        throw new IOException(e);
+      }
+    }
+
+    public boolean isBufferFull() {
+      return ringBuffer.full();
+    }
+
+    /** Called by FlushFileReader to load the next part of this block into free slots. */
+    public void writeBuffer() throws IOException {
+      int size = this.ringBuffer.write(length - offsetInThisBlock);
+      this.offsetInThisBlock += size;
+    }
+
+    @Override
+    public int read(byte[] bs, int off, int len) throws IOException {
+      if (stop) {
+        throw new IOException("Block flush file reader is closed, caused by " + readThrowable);
+      }
+      if (bs == null) {
+        throw new NullPointerException();
+      } else if (off < 0 || len < 0 || len > bs.length - off) {
+        throw new IndexOutOfBoundsException();
+      } else if (len == 0) {
+        return 0;
+      }
+      if (eof) {
+        return -1;
+      }
+      awaitData();
+      int c = this.ringBuffer.read(bs, off, len);
+      pos += c;
+      if (pos >= length) {
+        eof = true;
+      }
+      return c;
+    }
+
+    @Override
+    public int read() throws IOException {
+      if (stop) {
+        throw new IOException("Block flush file reader is closed, caused by " + readThrowable);
+      }
+      if (eof) {
+        return -1;
+      }
+      awaitData();
+      int c = this.ringBuffer.read();
+      pos++;
+      if (pos >= length) {
+        eof = true;
+      }
+      return c;
+    }
+
+    /**
+     * Blocks until this stream's ring buffer holds data. While waiting, the merge thread releases
+     * the lock so that FlushFileReader can refill the buffers, then waits to take it back.
+     *
+     * @throws IOException if interrupted, or if the reader stopped (closed or failed) before any
+     *     data arrived; consuming the empty ring at that point would return stale bytes
+     */
+    private void awaitData() throws IOException {
+      while (ringBuffer.empty() && !stop) {
+        if (lock.isHeldByCurrentThread()) {
+          lock.unlock();
+        }
+        try {
+          lock.lockInterruptibly();
+        } catch (InterruptedException e) {
+          Thread.currentThread().interrupt();
+          throw new IOException(e);
+        }
+      }
+      if (ringBuffer.empty()) {
+        throw new IOException("Block flush file reader is closed, caused by " + readThrowable);
+      }
+    }
+  }
+}
diff --git
a/server/src/main/java/org/apache/uniffle/server/merge/DefaultMergeEventHandler.java
b/server/src/main/java/org/apache/uniffle/server/merge/DefaultMergeEventHandler.java
new file mode 100644
index 000000000..05c9a3723
--- /dev/null
+++
b/server/src/main/java/org/apache/uniffle/server/merge/DefaultMergeEventHandler.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.server.merge;
+
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.Executor;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
+
+import com.google.common.collect.Queues;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.util.ThreadUtils;
+import org.apache.uniffle.server.ShuffleServerConf;
+import org.apache.uniffle.server.ShuffleServerMetrics;
+
+import static
org.apache.uniffle.server.ShuffleServerConf.SERVER_MERGE_THREAD_ALIVE_TIME;
+import static
org.apache.uniffle.server.ShuffleServerConf.SERVER_MERGE_THREAD_POOL_QUEUE_SIZE;
+import static
org.apache.uniffle.server.ShuffleServerConf.SERVER_MERGE_THREAD_POOL_SIZE;
+
+/**
+ * Default {@link MergeEventHandler}: queues merge events and hands each one to {@code
+ * eventConsumer} on a fixed-size thread pool. A single daemon thread moves events from the queue
+ * to the pool. {@code ShuffleServerMetrics.gaugeMergeEventQueueSize} counts an event from the
+ * moment it is queued until the consumer has finished with it (or it could not be submitted).
+ */
+public class DefaultMergeEventHandler implements MergeEventHandler {
+
+  private static final Logger LOG = LoggerFactory.getLogger(DefaultMergeEventHandler.class);
+
+  private final ThreadPoolExecutor threadPoolExecutor;
+  protected final BlockingQueue<MergeEvent> queue = Queues.newLinkedBlockingQueue();
+  private final Consumer<MergeEvent> eventConsumer;
+  private final Thread processEventThread;
+  private volatile boolean stopped = false;
+
+  /**
+   * Creates the handler and starts its event-processing thread.
+   *
+   * @param serverConf supplies the merge thread pool size, pool queue size and keep-alive time
+   * @param eventConsumer invoked on a pool thread for every handled event
+   */
+  public DefaultMergeEventHandler(
+      ShuffleServerConf serverConf, Consumer<MergeEvent> eventConsumer) {
+    this.eventConsumer = eventConsumer;
+    int poolSize = serverConf.get(SERVER_MERGE_THREAD_POOL_SIZE);
+    int queueSize = serverConf.get(SERVER_MERGE_THREAD_POOL_QUEUE_SIZE);
+    int keepAliveTime = serverConf.get(SERVER_MERGE_THREAD_ALIVE_TIME);
+    BlockingQueue<Runnable> waitQueue = Queues.newLinkedBlockingQueue(queueSize);
+    this.threadPoolExecutor =
+        new ThreadPoolExecutor(
+            poolSize,
+            poolSize,
+            keepAliveTime,
+            TimeUnit.SECONDS,
+            waitQueue,
+            ThreadUtils.getThreadFactory("DefaultMergeEventHandler"));
+    if (keepAliveTime > 0) {
+      // Core size equals max size, so idle threads only time out (honouring
+      // rss.server.merge.threadAliveTime) if core threads are allowed to.
+      this.threadPoolExecutor.allowCoreThreadTimeOut(true);
+    }
+    this.processEventThread = startEventProcessor();
+  }
+
+  private Thread startEventProcessor() {
+    Thread processEventThread = new Thread(this::eventLoop);
+    processEventThread.setName("ProcessEventThread");
+    processEventThread.setDaemon(true);
+    processEventThread.start();
+    return processEventThread;
+  }
+
+  protected void eventLoop() {
+    while (!stopped && !Thread.currentThread().isInterrupted()) {
+      processNextEvent();
+    }
+  }
+
+  /** Takes the next queued event and submits it to the thread pool. */
+  protected void processNextEvent() {
+    final MergeEvent event;
+    try {
+      event = queue.take();
+    } catch (InterruptedException e) {
+      // Restore the flag so that eventLoop observes it and exits.
+      Thread.currentThread().interrupt();
+      return;
+    }
+    try {
+      threadPoolExecutor.execute(() -> handleEventAndUpdateMetrics(event));
+    } catch (Exception e) {
+      // The event never reached a worker, so undo the gauge increment made in handle().
+      ShuffleServerMetrics.gaugeMergeEventQueueSize.dec();
+      LOG.error("Exception happened when process event.", e);
+    }
+  }
+
+  private void handleEventAndUpdateMetrics(MergeEvent event) {
+    try {
+      eventConsumer.accept(event);
+    } finally {
+      ShuffleServerMetrics.gaugeMergeEventQueueSize.dec();
+    }
+  }
+
+  /** Queues {@code event} for asynchronous processing. */
+  @Override
+  public void handle(MergeEvent event) {
+    if (queue.offer(event)) {
+      ShuffleServerMetrics.gaugeMergeEventQueueSize.inc();
+    }
+  }
+
+  /** Returns the number of queued events not yet handed to the thread pool. */
+  @Override
+  public int getEventNumInMerge() {
+    return queue.size();
+  }
+
+  /** Stops the event thread and shuts down the merge thread pool. */
+  @Override
+  public void stop() {
+    stopped = true;
+    // The event thread is usually blocked in queue.take(); interrupt it so it sees the flag.
+    processEventThread.interrupt();
+    threadPoolExecutor.shutdown();
+  }
+}
diff --git
a/server/src/main/java/org/apache/uniffle/server/merge/MergeEvent.java
b/server/src/main/java/org/apache/uniffle/server/merge/MergeEvent.java
new file mode 100644
index 000000000..250e9a1ee
--- /dev/null
+++ b/server/src/main/java/org/apache/uniffle/server/merge/MergeEvent.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.server.merge;
+
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+
+/**
+ * Request to merge the blocks of one shuffle partition, created by {@code
+ * Partition.startSortMerge} and passed to a {@link MergeEventHandler}.
+ *
+ * <p>None of the fields is reassigned after construction. The block-id bitmap is shared with the
+ * caller, not copied.
+ */
+public class MergeEvent {
+
+  private final String appId;
+  private final int shuffleId;
+  private final int partitionId;
+  // Record key and value classes used when decoding and merging the partition's records.
+  private final Class kClass;
+  private final Class vClass;
+  // Ids of the blocks the merge is expected to cover.
+  private final Roaring64NavigableMap expectedBlockIdMap;
+
+  public MergeEvent(
+      String appId,
+      int shuffleId,
+      int partitionId,
+      Class kClass,
+      Class vClass,
+      Roaring64NavigableMap expectedBlockIdMap) {
+    this.appId = appId;
+    this.shuffleId = shuffleId;
+    this.partitionId = partitionId;
+    this.kClass = kClass;
+    this.vClass = vClass;
+    this.expectedBlockIdMap = expectedBlockIdMap;
+  }
+
+  public String getAppId() {
+    return appId;
+  }
+
+  public int getShuffleId() {
+    return shuffleId;
+  }
+
+  public int getPartitionId() {
+    return partitionId;
+  }
+
+  public Roaring64NavigableMap getExpectedBlockIdMap() {
+    return expectedBlockIdMap;
+  }
+
+  public Class getKeyClass() {
+    return kClass;
+  }
+
+  public Class getValueClass() {
+    return vClass;
+  }
+
+  @Override
+  public String toString() {
+    return "MergeEvent{"
+        + "appId='"
+        + appId
+        + '\''
+        + ", shuffleId="
+        + shuffleId
+        + ", partitionId="
+        + partitionId
+        + ", kClass="
+        + kClass
+        + ", vClass="
+        + vClass
+        + ", expectedBlockIdMap="
+        + expectedBlockIdMap
+        + '}';
+  }
+}
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleIndexResponse.java
b/server/src/main/java/org/apache/uniffle/server/merge/MergeEventHandler.java
similarity index 56%
copy from
internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleIndexResponse.java
copy to
server/src/main/java/org/apache/uniffle/server/merge/MergeEventHandler.java
index 37a31652e..c4a248e3a 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleIndexResponse.java
+++
b/server/src/main/java/org/apache/uniffle/server/merge/MergeEventHandler.java
@@ -15,21 +15,13 @@
* limitations under the License.
*/
-package org.apache.uniffle.client.response;
+package org.apache.uniffle.server.merge;
-import org.apache.uniffle.common.ShuffleIndexResult;
-import org.apache.uniffle.common.netty.buffer.ManagedBuffer;
-import org.apache.uniffle.common.rpc.StatusCode;
+/**
+ * Receives {@link MergeEvent}s and dispatches them for merging.
+ *
+ * <p>NOTE(review): the only visible implementation is DefaultMergeEventHandler, which queues
+ * events and hands them to a thread pool; the method contracts below are inferred from it.
+ */
+public interface MergeEventHandler {
-public class RssGetShuffleIndexResponse extends ClientResponse {
- private final ShuffleIndexResult shuffleIndexResult;
+ /**
+ * Submits an event for asynchronous handling. DefaultMergeEventHandler drops the event if its
+ * queue rejects it (offer returns false).
+ */
+ void handle(MergeEvent event);
- public RssGetShuffleIndexResponse(StatusCode statusCode, ManagedBuffer data,
long dataFileLen) {
- super(statusCode);
- this.shuffleIndexResult = new ShuffleIndexResult(data, dataFileLen);
- }
+ /** Returns the number of submitted events still waiting in the handler's queue. */
+ int getEventNumInMerge();
- public ShuffleIndexResult getShuffleIndexResult() {
- return shuffleIndexResult;
- }
+ /** Signals the handler to stop dispatching events. */
+ void stop();
}
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleIndexResponse.java
b/server/src/main/java/org/apache/uniffle/server/merge/MergeStatus.java
similarity index 57%
copy from
internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleIndexResponse.java
copy to server/src/main/java/org/apache/uniffle/server/merge/MergeStatus.java
index 37a31652e..b2263aa34 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleIndexResponse.java
+++ b/server/src/main/java/org/apache/uniffle/server/merge/MergeStatus.java
@@ -15,21 +15,25 @@
* limitations under the License.
*/
-package org.apache.uniffle.client.response;
+package org.apache.uniffle.server.merge;
-import org.apache.uniffle.common.ShuffleIndexResult;
-import org.apache.uniffle.common.netty.buffer.ManagedBuffer;
-import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.merger.MergeState;
-public class RssGetShuffleIndexResponse extends ClientResponse {
- private final ShuffleIndexResult shuffleIndexResult;
+/**
+ * Immutable snapshot returned by {@code Partition.tryGetBlock}: the partition's merge state plus
+ * the size of the requested merged block.
+ */
+public class MergeStatus {
- public RssGetShuffleIndexResponse(StatusCode statusCode, ManagedBuffer data,
long dataFileLen) {
- super(statusCode);
- this.shuffleIndexResult = new ShuffleIndexResult(data, dataFileLen);
+  private final MergeState state;
+  // Size of the requested merged block, or -1 when that block is not available (yet).
+  private final long size;
+
+  public MergeStatus(MergeState state, long size) {
+    this.state = state;
+    this.size = size;
+  }
+
+  public MergeState getState() {
+    return state;
}
- public ShuffleIndexResult getShuffleIndexResult() {
- return shuffleIndexResult;
+  public long getSize() {
+    return size;
}
}
diff --git
a/server/src/main/java/org/apache/uniffle/server/merge/MergedResult.java
b/server/src/main/java/org/apache/uniffle/server/merge/MergedResult.java
new file mode 100644
index 000000000..6c7ce056f
--- /dev/null
+++ b/server/src/main/java/org/apache/uniffle/server/merge/MergedResult.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.server.merge;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.merger.Recordable;
+
+import static
org.apache.uniffle.server.ShuffleServerConf.SERVER_MERGE_DEFAULT_MERGED_BLOCK_SIZE;
+
+/**
+ * Collects the sorted output of a partition merge and slices it into merged blocks.
+ *
+ * <p>Merged records are written to the stream returned by {@link #getOutputStream()}. Each time
+ * {@link Recordable#record} sees at least {@code mergedBlockSize} pending bytes (or {@code force}
+ * with a non-empty buffer), the pending bytes are handed to the {@link CacheMergedBlockFuntion}
+ * as the next block. Block ids start at 1 and increase by one.
+ *
+ * <p>Block boundaries are appended by the merging thread, while {@link #isOutOfBound} and {@link
+ * #getBlockSize} may be called concurrently (Partition.tryGetBlock calls them while the merge is
+ * still running), so every access to {@code offsets} is synchronized on this instance.
+ */
+public class MergedResult {
+
+  private final RssConf rssConf;
+  private final long mergedBlockSize;
+  // offsets.get(i) is the end offset, in the merged stream, of block i. offsets.get(0) == 0, so
+  // block i (i >= 1) spans [offsets.get(i - 1), offsets.get(i)). Guarded by "this".
+  private final List<Long> offsets = new ArrayList<>();
+  private final CacheMergedBlockFuntion cachedMergedBlock;
+
+  /**
+   * @param rssConf used to look up the default merged block size
+   * @param cachedMergedBlock receives every finished merged block
+   * @param mergedBlockSize target block size; a non-positive value selects the server default
+   */
+  public MergedResult(
+      RssConf rssConf, CacheMergedBlockFuntion cachedMergedBlock, int mergedBlockSize) {
+    this.rssConf = rssConf;
+    this.cachedMergedBlock = cachedMergedBlock;
+    this.mergedBlockSize =
+        mergedBlockSize > 0
+            ? mergedBlockSize
+            : this.rssConf.getSizeAsBytes(
+                SERVER_MERGE_DEFAULT_MERGED_BLOCK_SIZE.key(),
+                SERVER_MERGE_DEFAULT_MERGED_BLOCK_SIZE.defaultValue());
+    offsets.add(0L);
+  }
+
+  /** Returns a new stream that buffers merged output and cuts it into blocks. */
+  public OutputStream getOutputStream() {
+    return new MergedSegmentOutputStream();
+  }
+
+  /** Returns true if no merged block with this id has been produced (yet). Ids start at 1. */
+  public synchronized boolean isOutOfBound(long blockId) {
+    // Id 0 is only the sentinel start offset; without this check getBlockSize(0) would read
+    // offsets.get(-1).
+    return blockId < 1 || blockId >= offsets.size();
+  }
+
+  /** Returns the length of the given block; callers must check {@link #isOutOfBound} first. */
+  public synchronized long getBlockSize(long blockId) {
+    return offsets.get((int) blockId) - offsets.get((int) (blockId - 1));
+  }
+
+  // End offset of the most recently completed block (0 before the first block).
+  private synchronized long lastOffset() {
+    return offsets.get(offsets.size() - 1);
+  }
+
+  // Id that the next completed block will receive.
+  private synchronized int nextBlockId() {
+    return offsets.size();
+  }
+
+  // Publishes the end offset of a newly completed block, making it visible to readers.
+  private synchronized void appendOffset(long endOffset) {
+    offsets.add(endOffset);
+  }
+
+  /** Callback receiving one finished merged block: its bytes, its id and its length. */
+  @FunctionalInterface
+  public interface CacheMergedBlockFuntion {
+    void cache(byte[] buffer, long blockId, int length);
+  }
+
+  class MergedSegmentOutputStream extends OutputStream implements Recordable {
+
+    ByteArrayOutputStream current;
+
+    MergedSegmentOutputStream() {
+      current = new ByteArrayOutputStream((int) mergedBlockSize);
+    }
+
+    @Override
+    public void write(int b) throws IOException {
+      current.write(b);
+    }
+
+    // OutputStream's default bulk write loops over write(int); delegate the whole range instead.
+    @Override
+    public void write(byte[] b, int off, int len) throws IOException {
+      current.write(b, off, len);
+    }
+
+    @Override
+    public void close() throws IOException {
+      if (current != null) {
+        current.close();
+        current = null;
+      }
+    }
+
+    /**
+     * Completes a block once {@code written} (total bytes produced so far) is at least one block
+     * size past the previous boundary, or when {@code force} is set and bytes are pending.
+     *
+     * @return true if a block was completed by this call
+     */
+    @Override
+    public boolean record(long written, Flushable flushable, boolean force) throws IOException {
+      assert written >= 0;
+      long currentOffsetInThisBlock = written - lastOffset();
+      if (currentOffsetInThisBlock >= mergedBlockSize || (currentOffsetInThisBlock > 0 && force)) {
+        if (flushable != null) {
+          flushable.flush();
+        }
+        // Hand the block to the cache before publishing its end offset, so it only becomes
+        // in-bound for readers after the callback has returned.
+        cachedMergedBlock.cache(
+            current.toByteArray(), nextBlockId(), (int) (currentOffsetInThisBlock));
+        appendOffset(written);
+        if (!force) {
+          current = new ByteArrayOutputStream((int) mergedBlockSize);
+        }
+        return true;
+      }
+      return false;
+    }
+  }
+}
diff --git
a/server/src/main/java/org/apache/uniffle/server/merge/Partition.java
b/server/src/main/java/org/apache/uniffle/server/merge/Partition.java
new file mode 100644
index 000000000..05e28ddaa
--- /dev/null
+++ b/server/src/main/java/org/apache/uniffle/server/merge/Partition.java
@@ -0,0 +1,478 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.server.merge;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import io.netty.buffer.ByteBuf;
+import org.apache.hadoop.io.RawComparator;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.ShuffleDataResult;
+import org.apache.uniffle.common.ShuffleIndexResult;
+import org.apache.uniffle.common.ShufflePartitionedBlock;
+import org.apache.uniffle.common.ShufflePartitionedData;
+import org.apache.uniffle.common.config.RssBaseConf;
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.exception.FileNotFoundException;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.merger.MergeState;
+import org.apache.uniffle.common.merger.Merger;
+import org.apache.uniffle.common.merger.Segment;
+import org.apache.uniffle.common.merger.StreamedSegment;
+import org.apache.uniffle.common.netty.buffer.FileSegmentManagedBuffer;
+import org.apache.uniffle.common.netty.buffer.ManagedBuffer;
+import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer;
+import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.util.JavaUtils;
+import org.apache.uniffle.server.ShuffleDataReadEvent;
+import org.apache.uniffle.storage.common.Storage;
+import org.apache.uniffle.storage.handler.impl.LocalFileServerReadHandler;
+import org.apache.uniffle.storage.request.CreateShuffleReadHandlerRequest;
+import org.apache.uniffle.storage.util.StorageType;
+
+import static org.apache.uniffle.common.merger.MergeState.DONE;
+import static org.apache.uniffle.common.merger.MergeState.INITED;
+import static org.apache.uniffle.common.merger.MergeState.INTERNAL_ERROR;
+import static org.apache.uniffle.common.merger.MergeState.MERGING;
+import static
org.apache.uniffle.server.ShuffleServerConf.SERVER_MERGE_BLOCK_RING_BUFFER_SIZE;
+import static
org.apache.uniffle.server.ShuffleServerConf.SERVER_MERGE_CACHE_MERGED_BLOCK_INIT_SLEEP_MS;
+import static
org.apache.uniffle.server.ShuffleServerConf.SERVER_MERGE_CACHE_MERGED_BLOCK_MAX_SLEEP_MS;
+import static
org.apache.uniffle.server.merge.ShuffleMergeManager.MERGE_APP_SUFFIX;
+
+/**
+ * Server-side merge state of one shuffle partition.
+ *
+ * <p>Original blocks are registered through {@link #cacheBlock} as they arrive. {@link
+ * #startSortMerge} submits a {@link MergeEvent} to the shuffle's {@link MergeEventHandler}; the
+ * merge inputs are built by {@link #getSegments} and merged by {@link #merge} (presumably driven
+ * by the event consumer in ShuffleMergeManager; confirm there). {@link MergedResult} splits the
+ * sorted output into merged blocks, which are cached under the app id {@code appId +
+ * MERGE_APP_SUFFIX} and served by {@link #tryGetBlock} and {@link #getShuffleData}.
+ */
+public class Partition<K, V> {
+
+  private static final Logger LOG = LoggerFactory.getLogger(Partition.class);
+
+  private final Shuffle shuffle;
+  private final int partitionId;
+  // Inserting into or deleting from ShuffleBuffer::blocks while traversing them may cause a
+  // ConcurrentModificationException, so the blocks are also cached here. Before using a cached
+  // block, check its refCnt to make sure its ByteBuf has not been released.
+  Map<Long, ShufflePartitionedBlock> cachedblockMap = JavaUtils.newConcurrentMap();
+  // Merged blocks produced by this partition, keyed by merged block id.
+  Map<Long, ShufflePartitionedBlock> mergedBlockMap = JavaUtils.newConcurrentMap();
+
+  // setState() is called from merge(), which is not synchronized, while tryGetBlock() reads the
+  // state without locking; volatile makes each transition visible to other threads.
+  private volatile MergeState state = MergeState.INITED;
+  private final MergedResult result;
+  private final ShuffleMeta shuffleMeta = new ShuffleMeta();
+
+  // These variables should be moved to ShuffleMergeManager; partition granularity is not needed.
+  private final long initSleepTime;
+  private final long maxSleepTime;
+  // Current back-off used by cachedMergedBlock while the server has no free buffer.
+  private long sleepTime;
+  private final int ringBufferSize;
+  private BlockFlushFileReader reader = null;
+
+  public Partition(Shuffle shuffle, int partitionId) throws IOException {
+    this.shuffle = shuffle;
+    this.partitionId = partitionId;
+    this.result =
+        new MergedResult(shuffle.serverConf, this::cachedMergedBlock, shuffle.mergedBlockSize);
+    this.initSleepTime = shuffle.serverConf.get(SERVER_MERGE_CACHE_MERGED_BLOCK_INIT_SLEEP_MS);
+    this.maxSleepTime = shuffle.serverConf.get(SERVER_MERGE_CACHE_MERGED_BLOCK_MAX_SLEEP_MS);
+    // Start the back-off at the configured initial value. Left at 0, the NO_BUFFER retry loop in
+    // cachedMergedBlock would never sleep (0 * 2 == 0) until the first successful cache.
+    this.sleepTime = this.initSleepTime;
+    int tmpRingBufferSize = shuffle.serverConf.get(SERVER_MERGE_BLOCK_RING_BUFFER_SIZE);
+    // Clamp into [2, 32], then round up to a power of two.
+    this.ringBufferSize =
+        Integer.highestOneBit((Math.min(32, Math.max(2, tmpRingBufferSize)) - 1) << 1);
+    if (tmpRingBufferSize != this.ringBufferSize) {
+      LOG.info(
+          "The ring buffer size will transient from {} to {}",
+          tmpRingBufferSize,
+          this.ringBufferSize);
+    }
+  }
+
+  /**
+   * Triggers the merge of this partition. Only the first call has an effect: a non-empty bitmap
+   * submits a {@link MergeEvent} and moves to MERGING; an empty one moves straight to DONE.
+   */
+  synchronized void startSortMerge(Roaring64NavigableMap expectedBlockIdMap) throws IOException {
+    if (getState() != INITED) {
+      LOG.warn("Partition is already merging, so ignore duplicate reports, partition is {}", this);
+      return;
+    }
+    if (expectedBlockIdMap.isEmpty()) {
+      setState(DONE);
+      return;
+    }
+    setState(MERGING);
+    MergeEvent event =
+        new MergeEvent(
+            shuffle.appId,
+            shuffle.shuffleId,
+            partitionId,
+            shuffle.kClass,
+            shuffle.vClass,
+            expectedBlockIdMap);
+    shuffle.eventHandler.handle(event);
+  }
+
+  /**
+   * Builds one segment per block id. Blocks still cached in memory with a live ByteBuf are read
+   * directly; all others are read from the partition's local data and index files.
+   *
+   * @throws IOException if a block is found neither in memory nor in the flushed files, or the
+   *     file-backed segments cannot be created
+   */
+  public List<Segment> getSegments(
+      RssConf rssConf, Iterator<Long> blockIds, Class keyClass, Class valueClass)
+      throws IOException {
+    boolean raw = isRawComparator();
+    List<Segment> segments = new ArrayList<>();
+    Set<Long> blocksFlushed = new HashSet<>();
+    while (blockIds.hasNext()) {
+      long blockId = blockIds.next();
+      // A single lookup, so a concurrent clear between containsKey and get cannot cause an NPE.
+      ShufflePartitionedBlock cached = cachedblockMap.get(blockId);
+      ByteBuf buf = cached == null ? null : cached.getData();
+      if (buf != null && buf.refCnt() > 0) {
+        try {
+          segments.add(new StreamedSegment(rssConf, buf, blockId, keyClass, valueClass, raw));
+        } catch (Exception e) {
+          // If flush cleanup releases the ByteBuf before the Segment retains it, constructing the
+          // segment fails. Fall back to reading this block from the flushed file.
+          LOG.warn("construct segment failed, caused by ", e);
+          blocksFlushed.add(blockId);
+        }
+      } else {
+        blocksFlushed.add(blockId);
+      }
+    }
+    if (blocksFlushed.isEmpty()) {
+      return segments;
+    }
+    try {
+      LocalFileServerReadHandler handler = getLocalFileServerReadHandler(rssConf, shuffle.appId);
+      this.reader =
+          new BlockFlushFileReader(
+              handler.getDataFileName(), handler.getIndexFileName(), ringBufferSize);
+      for (Long blockId : blocksFlushed) {
+        BlockFlushFileReader.BlockInputStream inputStream =
+            reader.registerBlockInputStream(blockId);
+        if (inputStream == null) {
+          throw new IOException("Can not find any buffer or file for block " + blockId);
+        }
+        segments.add(
+            new StreamedSegment(rssConf, inputStream, blockId, keyClass, valueClass, raw));
+      }
+      return segments;
+    } catch (IOException e) {
+      // Already the declared type; do not wrap it a second time.
+      throw e;
+    } catch (Throwable throwable) {
+      throw new IOException(throwable);
+    }
+  }
+
+  /** Merges the segments into the merged result; the state ends as DONE or INTERNAL_ERROR. */
+  void merge(List<Segment> segments) throws IOException {
+    try {
+      OutputStream outputStream = result.getOutputStream();
+      Merger.merge(
+          shuffle.serverConf,
+          outputStream,
+          segments,
+          shuffle.kClass,
+          shuffle.vClass,
+          shuffle.comparator,
+          isRawComparator());
+      setState(DONE);
+    } catch (Exception e) {
+      // TODO: should retry!!!
+      // The exception is the trailing argument (not bound to a "{}") so SLF4J logs its stack trace.
+      LOG.error("Partition {} remote merge failed", this, e);
+      setState(INTERNAL_ERROR);
+      throw new IOException(e);
+    }
+  }
+
+  // Whether the shuffle's comparator can compare serialized keys directly.
+  private boolean isRawComparator() {
+    return shuffle.comparator instanceof RawComparator;
+  }
+
+  public void setState(MergeState state) {
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Partition is {}, transient from {} to {}.", this, this.state.name(), state.name());
+    }
+    this.state = state;
+  }
+
+  public MergeState getState() {
+    return state;
+  }
+
+  /**
+   * Returns the current merge state together with the size of merged block {@code blockId}. The
+   * size is -1 unless the partition is MERGING or DONE and that block has already been produced.
+   */
+  public MergeStatus tryGetBlock(long blockId) {
+    long size = -1L;
+    // Read the volatile state once so the returned state matches the one checked here.
+    MergeState currentState = state;
+    if ((currentState == MERGING || currentState == DONE) && !result.isOutOfBound(blockId)) {
+      size = result.getBlockSize(blockId);
+    }
+    return new MergeStatus(currentState, size);
+  }
+
+  /** Remembers an original (unmerged) block so getSegments can read it from memory. */
+  public void cacheBlock(ShufflePartitionedBlock spb) {
+    cachedblockMap.put(spb.getBlockId(), spb);
+  }
+
+  // Callback for MergedResult: stores one merged block under the app id
+  // ${appId} + MERGE_APP_SUFFIX. Merged blocks are then handled like ordinary shuffle data:
+  // cached first and flushed to disk when necessary. While the server reports NO_BUFFER this
+  // retries with exponential back-off capped at maxSleepTime.
+  private void cachedMergedBlock(byte[] buffer, long blockId, int length) {
+    String appId = shuffle.appId + MERGE_APP_SUFFIX;
+    ShufflePartitionedBlock spb =
+        new ShufflePartitionedBlock(length, length, -1, blockId, -1, buffer);
+    ShufflePartitionedData spd =
+        new ShufflePartitionedData(partitionId, new ShufflePartitionedBlock[] {spb});
+    while (true) {
+      StatusCode ret =
+          shuffle
+              .shuffleServer
+              .getShuffleTaskManager()
+              .cacheShuffleData(appId, shuffle.shuffleId, false, spd);
+      if (ret == StatusCode.SUCCESS) {
+        mergedBlockMap.put(blockId, spb);
+        shuffle
+            .shuffleServer
+            .getShuffleTaskManager()
+            .updateCachedBlockIds(appId, shuffle.shuffleId, spd.getPartitionId(), spd.getBlockList());
+        sleepTime = initSleepTime;
+        return;
+      } else if (ret == StatusCode.NO_BUFFER) {
+        LOG.info("Can not allocate enough memory for {}, then will sleep {}ms", this, sleepTime);
+        try {
+          Thread.sleep(sleepTime);
+        } catch (InterruptedException ex) {
+          // Preserve the interrupt for callers further up the stack.
+          Thread.currentThread().interrupt();
+          throw new RssException(ex);
+        }
+        sleepTime = Math.min(maxSleepTime, sleepTime * 2);
+      } else {
+        String shuffleDataInfo =
+            "appId["
+                + appId
+                + "], shuffleId["
+                + shuffle.shuffleId
+                + "], partitionId["
+                + spd.getPartitionId()
+                + "]";
+        throw new RssException(
+            "Error happened when shuffleEngine.write for "
+                + shuffleDataInfo
+                + ", statusCode="
+                + ret);
+      }
+    }
+  }
+
+  /** Returns merged block {@code blockId}, from memory if still cached, else from the file. */
+  public ShuffleDataResult getShuffleData(long blockId) throws IOException {
+    // 1. Memory. Merged blocks are read while the merge is still adding blocks, so traversing the
+    // shuffle buffer could throw ConcurrentModificationException; use the partition's own cache.
+    ManagedBuffer managedBuffer = this.getMergedBlockBufferInMemory(blockId);
+    if (managedBuffer != null) {
+      return new ShuffleDataResult(managedBuffer);
+    }
+    // 2. Flushed file, if the block is no longer in memory.
+    managedBuffer = this.getMergedBlockBufferInFile(shuffle.serverConf, blockId);
+    return new ShuffleDataResult(managedBuffer);
+  }
+
+  // Returns a retained buffer for the cached merged block, or null if it is gone or released.
+  private NettyManagedBuffer getMergedBlockBufferInMemory(long blockId) {
+    try {
+      ShufflePartitionedBlock block = this.mergedBlockMap.get(blockId);
+      // refCnt > 0 means the ByteBuf has not been released by flush cleanup.
+      if (block != null && block.getData().refCnt() > 0) {
+        return new NettyManagedBuffer(block.getData().retain());
+      }
+      return null;
+    } catch (Exception e) {
+      // Flush cleanup may release the buffer between the refCnt check and retain(), which throws
+      // IllegalReferenceCountException. The caller then falls back to the file.
+      LOG.warn("Get ByteBuf from memory failed, cased by", e);
+      return null;
+    }
+  }
+
+  private synchronized ManagedBuffer getMergedBlockBufferInFile(RssConf rssConf, long blockId) {
+    String appId = shuffle.appId + MERGE_APP_SUFFIX;
+    if (!shuffleMeta.getSegments().containsKey(blockId)) {
+      // The index file keeps growing, so reload it when the block is not known yet.
+      reloadShuffleMeta(rssConf, appId);
+    }
+    ShuffleMeta.Segment segment = shuffleMeta.getSegments().get(blockId);
+    if (segment == null) {
+      throw new RssException("Can not find block for blockId " + blockId);
+    }
+    return new FileSegmentManagedBuffer(
+        new File(shuffleMeta.getDataFileName()), segment.getOffset(), segment.getLength());
+  }
+
+  // Re-reads the whole index file and replaces the cached blockId -> (offset, length) map.
+  private synchronized void reloadShuffleMeta(RssConf rssConf, String appId) {
+    ShuffleIndexResult indexResult = loadShuffleIndexResult(rssConf, appId);
+    shuffleMeta.setDataFileName(indexResult.getDataFileName());
+    ByteBuffer indexData = indexResult.getIndexData();
+    Map<Long, ShuffleMeta.Segment> segments = new HashMap<>();
+    while (indexData.hasRemaining()) {
+      long offset = indexData.getLong();
+      int length = indexData.getInt();
+      indexData.getInt(); // uncompressed length, not needed here
+      indexData.getLong(); // crc, not needed here
+      long blockId = indexData.getLong();
+      indexData.getLong(); // task attempt id, not needed here
+      segments.put(blockId, new ShuffleMeta.Segment(offset, length));
+    }
+    shuffleMeta.getSegments().clear();
+    shuffleMeta.getSegments().putAll(segments);
+  }
+
+  private ShuffleIndexResult loadShuffleIndexResult(RssConf rssConf, String appId) {
+    CreateShuffleReadHandlerRequest request = newReadHandlerRequest(rssConf, appId);
+    return selectStorage(appId).getOrCreateReadHandler(request).getShuffleIndex();
+  }
+
+  private LocalFileServerReadHandler getLocalFileServerReadHandler(RssConf rssConf, String appId) {
+    CreateShuffleReadHandlerRequest request = newReadHandlerRequest(rssConf, appId);
+    return (LocalFileServerReadHandler) selectStorage(appId).getOrCreateReadHandler(request);
+  }
+
+  // Builds a LOCALFILE read-handler request for this partition under the given app id.
+  private CreateShuffleReadHandlerRequest newReadHandlerRequest(RssConf rssConf, String appId) {
+    CreateShuffleReadHandlerRequest request = new CreateShuffleReadHandlerRequest();
+    request.setAppId(appId);
+    request.setShuffleId(shuffle.shuffleId);
+    request.setPartitionId(partitionId);
+    request.setPartitionNumPerRange(1);
+    request.setPartitionNum(Integer.MAX_VALUE); // ignore check partition number
+    request.setStorageType(StorageType.LOCALFILE.name());
+    request.setRssBaseConf((RssBaseConf) rssConf);
+    return request;
+  }
+
+  // Selects the storage holding this partition's data for the given app id, or throws.
+  private Storage selectStorage(String appId) {
+    Storage storage =
+        shuffle
+            .shuffleServer
+            .getStorageManager()
+            .selectStorage(
+                new ShuffleDataReadEvent(appId, shuffle.shuffleId, partitionId, partitionId));
+    if (storage == null) {
+      throw new FileNotFoundException("No such data in current storage manager.");
+    }
+    return storage;
+  }
+
+  /** Closes the file reader (if any) and drops all cached blocks and index metadata. */
+  void cleanup() {
+    try {
+      if (reader != null) {
+        reader.close();
+      }
+    } catch (Exception e) {
+      LOG.warn("Partition {} clean up failed", this, e);
+    } finally {
+      // Clear even if closing the reader failed, so cached buffers are not retained.
+      cachedblockMap.clear();
+      mergedBlockMap.clear();
+      shuffleMeta.clear();
+    }
+  }
+
+  @Override
+  public String toString() {
+    return "Partition{"
+        + "appId="
+        + shuffle.appId
+        + ", shuffle="
+        + shuffle.shuffleId
+        + ", partitionId="
+        + partitionId
+        + ", state="
+        + state
+        + '}';
+  }
+
+  /** Cached view of the merged-data index file: data file name and per-block file segments. */
+  public static class ShuffleMeta {
+
+    /** Location of one block inside the data file. */
+    public static class Segment {
+      private final long offset;
+      private final int length;
+
+      public Segment(long offset, int length) {
+        this.offset = offset;
+        this.length = length;
+      }
+
+      public long getOffset() {
+        return offset;
+      }
+
+      public int getLength() {
+        return length;
+      }
+    }
+
+    private String dataFileName;
+    private final Map<Long, Segment> segments = new HashMap<>();
+
+    public ShuffleMeta() {}
+
+    public void setDataFileName(String dataFileName) {
+      this.dataFileName = dataFileName;
+    }
+
+    public String getDataFileName() {
+      return dataFileName;
+    }
+
+    public Map<Long, Segment> getSegments() {
+      return segments;
+    }
+
+    public void clear() {
+      this.segments.clear();
+    }
+  }
+}
diff --git a/server/src/main/java/org/apache/uniffle/server/merge/Shuffle.java
b/server/src/main/java/org/apache/uniffle/server/merge/Shuffle.java
new file mode 100644
index 000000000..92097eff6
--- /dev/null
+++ b/server/src/main/java/org/apache/uniffle/server/merge/Shuffle.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.server.merge;
+
+import java.io.IOException;
+import java.util.Comparator;
+import java.util.Map;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+
+import org.apache.uniffle.common.ShufflePartitionedBlock;
+import org.apache.uniffle.common.ShufflePartitionedData;
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.util.JavaUtils;
+import org.apache.uniffle.server.ShuffleServer;
+
+/**
+ * Server-side merge state of one shuffle: holds the settings shared by all of its partitions
+ * (key/value classes, comparator, merged block size, class loader) and the {@link Partition}s
+ * themselves, created lazily on first use.
+ */
+public class Shuffle<K, V> {
+
+  final RssConf serverConf;
+  final String appId;
+  final int shuffleId;
+  final Class<K> kClass;
+  final Class<V> vClass;
+  final Comparator<K> comparator;
+  final MergeEventHandler eventHandler;
+  final ShuffleServer shuffleServer;
+  // partition id --> Partition
+  private final Map<Integer, Partition<K, V>> partitions = JavaUtils.newConcurrentMap();
+  final int mergedBlockSize;
+  final ClassLoader classLoader;
+
+  public Shuffle(
+      RssConf rssConf,
+      MergeEventHandler eventHandler,
+      ShuffleServer shuffleServer,
+      String appId,
+      int shuffleId,
+      Class<K> kClass,
+      Class<V> vClass,
+      Comparator<K> comparator,
+      int mergedBlockSize,
+      ClassLoader classLoader) {
+    this.serverConf = rssConf;
+    this.eventHandler = eventHandler;
+    this.shuffleServer = shuffleServer;
+    this.appId = appId;
+    this.shuffleId = shuffleId;
+    this.kClass = kClass;
+    this.vClass = vClass;
+    this.comparator = comparator;
+    this.mergedBlockSize = mergedBlockSize;
+    this.classLoader = classLoader;
+  }
+
+  /** Starts merging the given partition, creating it if it has not been seen yet. */
+  public void startSortMerge(int partitionId, Roaring64NavigableMap expectedBlockIdMap)
+      throws IOException {
+    getOrCreatePartition(partitionId).startSortMerge(expectedBlockIdMap);
+  }
+
+  /** Cleans up every partition and forgets them. */
+  void cleanup() {
+    for (Partition<K, V> partition : this.partitions.values()) {
+      partition.cleanup();
+    }
+    this.partitions.clear();
+  }
+
+  /** Registers every block of {@code spd} with its partition, creating the partition if needed. */
+  public void cacheBlock(ShufflePartitionedData spd) throws IOException {
+    Partition<K, V> partition = getOrCreatePartition(spd.getPartitionId());
+    for (ShufflePartitionedBlock block : spd.getBlockList()) {
+      partition.cacheBlock(block);
+    }
+  }
+
+  public ClassLoader getClassLoader() {
+    return classLoader;
+  }
+
+  @VisibleForTesting
+  Partition getPartition(int partition) {
+    return this.partitions.get(partition);
+  }
+
+  // Returns the partition for partitionId, constructing and registering it only when absent.
+  // Unlike putIfAbsent(new Partition(...)) followed by get(), this does not build a throwaway
+  // Partition on every call and cannot return null if cleanup() runs concurrently.
+  private Partition<K, V> getOrCreatePartition(int partitionId) throws IOException {
+    Partition<K, V> partition = partitions.get(partitionId);
+    if (partition != null) {
+      return partition;
+    }
+    Partition<K, V> created = new Partition<K, V>(this, partitionId);
+    Partition<K, V> existing = partitions.putIfAbsent(partitionId, created);
+    return existing != null ? existing : created;
+  }
+}
diff --git
a/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java
b/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java
new file mode 100644
index 000000000..2024459d5
--- /dev/null
+++
b/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java
@@ -0,0 +1,293 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.server.merge;
+
+import java.io.File;
+import java.io.IOException;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.InvocationTargetException;
+import java.net.URL;
+import java.net.URLClassLoader;
+import java.security.AccessController;
+import java.security.PrivilegedExceptionAction;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.apache.commons.lang3.ClassUtils;
+import org.apache.commons.lang3.StringUtils;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.ShuffleDataResult;
+import org.apache.uniffle.common.ShufflePartitionedData;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.merger.Segment;
+import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.util.JavaUtils;
+import org.apache.uniffle.server.ShuffleServer;
+import org.apache.uniffle.server.ShuffleServerConf;
+
+import static
org.apache.uniffle.server.ShuffleServerConf.SERVER_MERGE_CLASS_LOADER_JARS_PATH;
+
+public class ShuffleMergeManager {
+
+ private static final Logger LOG =
LoggerFactory.getLogger(ShuffleMergeManager.class);
+ public static final String MERGE_APP_SUFFIX = "@RemoteMerge";
+
+ private ShuffleServerConf serverConf;
+ private final ShuffleServer shuffleServer;
+ // appId -> shuffleid -> Shuffle
+ private final Map<String, Map<Integer, Shuffle>> shuffles =
JavaUtils.newConcurrentMap();
+ private final MergeEventHandler eventHandler;
+ private final Map<String, ClassLoader> cachedClassLoader = new HashMap<>();
+
+ // If comparator is not set, will use hashCode to compare. It is used for
shuffle that does not
+ // require
+ // sort but require combine.
+ private Comparator defaultComparator =
+ new Comparator() {
+ @Override
+ public int compare(Object o1, Object o2) {
+ int h1 = (o1 == null) ? 0 : o1.hashCode();
+ int h2 = (o2 == null) ? 0 : o2.hashCode();
+ return h1 < h2 ? -1 : h1 == h2 ? 0 : 1;
+ }
+ };
+
+ public ShuffleMergeManager(ShuffleServerConf serverConf, ShuffleServer
shuffleServer)
+ throws Exception {
+ this.serverConf = serverConf;
+ this.shuffleServer = shuffleServer;
+ this.eventHandler = new DefaultMergeEventHandler(this.serverConf,
this::processEvent);
+ initCacheClassLoader();
+ }
+
+ public void initCacheClassLoader() throws Exception {
+ addCacheClassLoader("",
serverConf.getString(SERVER_MERGE_CLASS_LOADER_JARS_PATH));
+ Map<String, Object> props =
+
serverConf.getPropsWithPrefix(SERVER_MERGE_CLASS_LOADER_JARS_PATH.key() + ".");
+ for (Map.Entry<String, Object> prop : props.entrySet()) {
+ addCacheClassLoader(prop.getKey(), (String) prop.getValue());
+ }
+ }
+
+ public void addCacheClassLoader(String label, String jarsPath) throws
Exception {
+ if (StringUtils.isNotBlank(jarsPath)) {
+ File jarsPathFile = new File(jarsPath);
+ if (jarsPathFile.exists()) {
+ if (jarsPathFile.isFile()) {
+ URLClassLoader urlClassLoader =
+ AccessController.doPrivileged(
+ new PrivilegedExceptionAction<URLClassLoader>() {
+ @Override
+ public URLClassLoader run() throws Exception {
+ return new URLClassLoader(
+ new URL[] {new URL("file://" + jarsPath)},
+ Thread.currentThread().getContextClassLoader());
+ }
+ });
+ cachedClassLoader.put(label, urlClassLoader);
+ } else if (jarsPathFile.isDirectory()) {
+ File[] files = jarsPathFile.listFiles();
+ List<URL> urlList = new ArrayList<>();
+ if (files != null) {
+ for (File file : files) {
+ if (file.getName().endsWith(".jar")) {
+ urlList.add(new URL("file://" + file.getAbsolutePath()));
+ }
+ }
+ }
+ URLClassLoader urlClassLoader =
+ AccessController.doPrivileged(
+ new PrivilegedExceptionAction<URLClassLoader>() {
+ @Override
+ public URLClassLoader run() throws Exception {
+ return new URLClassLoader(
+ urlList.toArray(new URL[urlList.size()]),
+ Thread.currentThread().getContextClassLoader());
+ }
+ });
+ cachedClassLoader.put(label, urlClassLoader);
+ } else {
+ // If not set, will use current thread classloader
+ cachedClassLoader.put(label,
Thread.currentThread().getContextClassLoader());
+ }
+ }
+ } else {
+ // If not set, will use current thread classloader
+ cachedClassLoader.put(label,
Thread.currentThread().getContextClassLoader());
+ }
+ }
+
+ public ClassLoader getClassLoader(String label) {
+ if (StringUtils.isBlank(label)) {
+ return cachedClassLoader.get("");
+ }
+ return cachedClassLoader.getOrDefault(label, cachedClassLoader.get(""));
+ }
+
+ public StatusCode registerShuffle(
+ String appId,
+ int shuffleId,
+ String keyClassName,
+ String valueClassName,
+ String comparatorClassName,
+ int mergedBlockSize,
+ String classLoaderLabel) {
+ try {
+ ClassLoader classLoader = getClassLoader(classLoaderLabel);
+ Class kClass = ClassUtils.getClass(classLoader, keyClassName);
+ Class vClass = ClassUtils.getClass(classLoader, valueClassName);
+ Comparator comparator;
+ if (StringUtils.isNotBlank(comparatorClassName)) {
+ Constructor constructor =
+ ClassUtils.getClass(classLoader,
comparatorClassName).getDeclaredConstructor();
+ constructor.setAccessible(true);
+ comparator = (Comparator) constructor.newInstance();
+ } else {
+ comparator = defaultComparator;
+ }
+ this.shuffles.putIfAbsent(appId, JavaUtils.newConcurrentMap());
+ this.shuffles
+ .get(appId)
+ .putIfAbsent(
+ shuffleId,
+ new Shuffle(
+ serverConf,
+ eventHandler,
+ shuffleServer,
+ appId,
+ shuffleId,
+ kClass,
+ vClass,
+ comparator,
+ mergedBlockSize,
+ classLoader));
+ } catch (ClassNotFoundException
+ | InstantiationException
+ | IllegalAccessException
+ | NoSuchMethodException
+ | InvocationTargetException e) {
+ LOG.info("Cant register shuffle, caused by ", e);
+ removeBuffer(appId, shuffleId);
+ return StatusCode.INTERNAL_ERROR;
+ }
+ return StatusCode.SUCCESS;
+ }
+
+ public void removeBuffer(String appId) {
+ if (this.shuffles.containsKey(appId)) {
+ for (Integer shuffleId : this.shuffles.get(appId).keySet()) {
+ removeBuffer(appId, shuffleId);
+ }
+ }
+ }
+
+ public void removeBuffer(String appId, List<Integer> shuffleIds) {
+ if (this.shuffles.containsKey(appId)) {
+ for (Integer shuffleId : shuffleIds) {
+ removeBuffer(appId, shuffleId);
+ }
+ }
+ }
+
+ public void removeBuffer(String appId, int shuffleId) {
+ if (this.shuffles.containsKey(appId)) {
+ if (this.shuffles.get(appId).containsKey(shuffleId)) {
+ this.shuffles.get(appId).get(shuffleId).cleanup();
+ this.shuffles.get(appId).remove(shuffleId);
+ }
+ if (this.shuffles.get(appId).size() == 0) {
+ this.shuffles.remove(appId);
+ }
+ }
+ }
+
+ public void startSortMerge(
+ String appId, int shuffleId, int partitionId, Roaring64NavigableMap
expectedBlockIdMap)
+ throws IOException {
+ Map<Integer, Shuffle> shuffleMap = this.shuffles.get(appId);
+ if (shuffleMap != null) {
+ Shuffle shuffle = shuffleMap.get(shuffleId);
+ if (shuffle != null) {
+ shuffle.startSortMerge(partitionId, expectedBlockIdMap);
+ }
+ }
+ }
+
+ public void processEvent(MergeEvent event) {
+ try {
+ ClassLoader original = Thread.currentThread().getContextClassLoader();
+ Thread.currentThread()
+ .setContextClassLoader(
+ this.getShuffle(event.getAppId(),
event.getShuffleId()).getClassLoader());
+ List<Segment> segments =
+ this.getPartition(event.getAppId(), event.getShuffleId(),
event.getPartitionId())
+ .getSegments(
+ serverConf,
+ event.getExpectedBlockIdMap().iterator(),
+ event.getKeyClass(),
+ event.getValueClass());
+ this.getPartition(event.getAppId(), event.getShuffleId(),
event.getPartitionId())
+ .merge(segments);
+ Thread.currentThread().setContextClassLoader(original);
+ } catch (Exception e) {
+ LOG.info("Found exception when merge, caused by ", e);
+ throw new RssException(e);
+ }
+ }
+
+ public ShuffleDataResult getShuffleData(
+ String appId, int shuffleId, int partitionId, long blockId) throws
IOException {
+ return this.getPartition(appId, shuffleId,
partitionId).getShuffleData(blockId);
+ }
+
+ public void cacheBlock(String appId, int shuffleId, ShufflePartitionedData
spd)
+ throws IOException {
+ if (this.shuffles.containsKey(appId) &&
this.shuffles.get(appId).containsKey(shuffleId)) {
+ this.getShuffle(appId, shuffleId).cacheBlock(spd);
+ }
+ }
+
+ public MergeStatus tryGetBlock(String appId, int shuffleId, int partitionId,
long blockId) {
+ return this.getPartition(appId, shuffleId,
partitionId).tryGetBlock(blockId);
+ }
+
+ @VisibleForTesting
+ MergeEventHandler getEventHandler() {
+ return eventHandler;
+ }
+
+ Shuffle getShuffle(String appId, int shuffleId) {
+ return this.shuffles.get(appId).get(shuffleId);
+ }
+
+ @VisibleForTesting
+ Partition getPartition(String appId, int shuffleId, int partitionId) {
+ return this.shuffles.get(appId).get(shuffleId).getPartition(partitionId);
+ }
+
+ public void refreshAppId(String appId) {
+ shuffleServer.getShuffleTaskManager().refreshAppId(appId +
MERGE_APP_SUFFIX);
+ }
+}
diff --git
a/server/src/test/java/org/apache/uniffle/server/merge/BlockFlushFileReaderTest.java
b/server/src/test/java/org/apache/uniffle/server/merge/BlockFlushFileReaderTest.java
new file mode 100644
index 000000000..534c74b7d
--- /dev/null
+++
b/server/src/test/java/org/apache/uniffle/server/merge/BlockFlushFileReaderTest.java
@@ -0,0 +1,257 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.server.merge;
+
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.io.RawComparator;
+import org.junit.jupiter.api.Timeout;
+import org.junit.jupiter.api.io.TempDir;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+
+import org.apache.uniffle.common.ShufflePartitionedBlock;
+import org.apache.uniffle.common.config.RssBaseConf;
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.merger.Merger;
+import org.apache.uniffle.common.merger.Segment;
+import org.apache.uniffle.common.merger.StreamedSegment;
+import org.apache.uniffle.common.records.RecordsReader;
+import org.apache.uniffle.common.serializer.PartialInputStream;
+import org.apache.uniffle.common.serializer.PartialInputStreamImpl;
+import org.apache.uniffle.common.serializer.SerializerUtils;
+import org.apache.uniffle.common.util.BlockIdLayout;
+import org.apache.uniffle.storage.handler.api.ShuffleWriteHandler;
+import org.apache.uniffle.storage.handler.impl.LocalFileServerReadHandler;
+import org.apache.uniffle.storage.handler.impl.LocalFileWriteHandler;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+public class BlockFlushFileReaderTest {
+
+ private static AtomicInteger ATOMIC_INT = new AtomicInteger(0);
+
+ @ParameterizedTest
+ @ValueSource(
+ strings = {
+ "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,2",
+ "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,4",
+ "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,32",
+ })
+ public void writeTestWithMerge(String classes, @TempDir File tmpDir) throws
Exception {
+ final String[] classArray = classes.split(",");
+ final Class keyClass = SerializerUtils.getClassByName(classArray[0]);
+ final Class valueClass = SerializerUtils.getClassByName(classArray[1]);
+ final Comparator comparator = SerializerUtils.getComparator(keyClass);
+ final int ringBufferSize = Integer.parseInt(classArray[2]);
+
+ final File dataOutput = new File(tmpDir, "dataOutput");
+ final File dataDir = new File(tmpDir, "data");
+ final String[] basePaths = new String[] {dataDir.getAbsolutePath()};
+ final LocalFileWriteHandler writeHandler1 =
+ new LocalFileWriteHandler("appId", 0, 1, 1, basePaths[0], "pre");
+
+ RssBaseConf conf = new RssBaseConf();
+ conf.setString("rss.storage.basePath", dataDir.getAbsolutePath());
+ final Set<Long> expectedBlockIds = new HashSet<>();
+ for (int i = 0; i < 10; i++) {
+ writeTestData(
+ generateBlocks(conf, keyClass, valueClass, i, 10, 10090),
+ writeHandler1,
+ expectedBlockIds);
+ }
+
+ LocalFileServerReadHandler readHandler =
+ new LocalFileServerReadHandler("appId", 0, 1, 1, 10,
dataDir.getAbsolutePath());
+ String dataFileName = readHandler.getDataFileName();
+ String indexFileName = readHandler.getIndexFileName();
+
+ BlockFlushFileReader blockFlushFileReader =
+ new BlockFlushFileReader(dataFileName, indexFileName, ringBufferSize);
+
+ List<Segment> segments = new ArrayList<>();
+ for (Long blockId : expectedBlockIds) {
+ PartialInputStream partialInputStream =
+ blockFlushFileReader.registerBlockInputStream(blockId);
+ segments.add(
+ new StreamedSegment(
+ conf,
+ partialInputStream,
+ blockId,
+ keyClass,
+ valueClass,
+ comparator instanceof RawComparator));
+ }
+ FileOutputStream outputStream = new FileOutputStream(dataOutput);
+ Merger.merge(
+ conf,
+ outputStream,
+ segments,
+ keyClass,
+ valueClass,
+ comparator,
+ comparator instanceof RawComparator);
+ outputStream.close();
+
+ int index = 0;
+ RecordsReader reader =
+ new RecordsReader(
+ conf,
+ PartialInputStreamImpl.newInputStream(dataOutput, 0,
dataOutput.length()),
+ keyClass,
+ valueClass,
+ false);
+ while (reader.next()) {
+ assertEquals(SerializerUtils.genData(keyClass, index),
reader.getCurrentKey());
+ assertEquals(SerializerUtils.genData(valueClass, index),
reader.getCurrentValue());
+ index++;
+ }
+ assertEquals(100900, index);
+ }
+
+ public static void writeTestData(
+ List<ShufflePartitionedBlock> blocks, ShuffleWriteHandler handler,
Set<Long> expectedBlockIds)
+ throws Exception {
+ blocks.forEach(block -> block.getData().retain());
+ handler.write(blocks);
+ blocks.forEach(block -> expectedBlockIds.add(block.getBlockId()));
+ blocks.forEach(block -> block.getData().release());
+ }
+
+ public static List<ShufflePartitionedBlock> generateBlocks(
+ RssConf rssConf, Class keyClass, Class valueClass, int start, int
interval, int length)
+ throws IOException {
+ BlockIdLayout layout = BlockIdLayout.DEFAULT;
+ List<ShufflePartitionedBlock> blocks = Lists.newArrayList();
+ byte[] bytes =
+ SerializerUtils.genSortedRecordBytes(
+ rssConf, keyClass, valueClass, start, interval, length, 1);
+ long blockId = layout.getBlockId(ATOMIC_INT.incrementAndGet(), 0, 100);
+ blocks.add(new ShufflePartitionedBlock(bytes.length, bytes.length, 0,
blockId, 100, bytes));
+ return blocks;
+ }
+
+ @Timeout(20)
+ @ParameterizedTest
+ @ValueSource(
+ strings = {
+ "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,2",
+ })
+ public void writeTestWithMergeWhenInterrupted(String classes, @TempDir File
tmpDir)
+ throws Exception {
+ String[] classArray = classes.split(",");
+ Class keyClass = SerializerUtils.getClassByName(classArray[0]);
+ Class valueClass = SerializerUtils.getClassByName(classArray[1]);
+ Comparator comparator = SerializerUtils.getComparator(keyClass);
+ int ringBufferSize = Integer.parseInt(classArray[2]);
+
+ File dataDir = new File(tmpDir, "data");
+ String[] basePaths = new String[] {dataDir.getAbsolutePath()};
+ final LocalFileWriteHandler writeHandler1 =
+ new LocalFileWriteHandler("appId", 0, 1, 1, basePaths[0], "pre");
+
+ RssBaseConf conf = new RssBaseConf();
+ conf.setString("rss.storage.basePath", dataDir.getAbsolutePath());
+ final Set<Long> expectedBlockIds = new HashSet<>();
+ for (int i = 0; i < 10; i++) {
+ writeTestData(
+ generateBlocks(conf, keyClass, valueClass, i, 10, 10090),
+ writeHandler1,
+ expectedBlockIds);
+ }
+
+ File dataOutput = new File(tmpDir, "dataOutput");
+ LocalFileServerReadHandler readHandler =
+ new LocalFileServerReadHandler("appId", 0, 1, 1, 10,
dataDir.getAbsolutePath());
+ String dataFileName = readHandler.getDataFileName();
+ String indexFileName = readHandler.getIndexFileName();
+
+ BlockFlushFileReader blockFlushFileReader =
+ new BlockFlushFileReader(dataFileName, indexFileName, ringBufferSize);
+
+ List<Segment> segments = new ArrayList<>();
+ for (Long blockId : expectedBlockIds) {
+ PartialInputStream partialInputStream =
+ blockFlushFileReader.registerBlockInputStream(blockId);
+ segments.add(
+ new MockedStreamedSegment(
+ conf,
+ partialInputStream,
+ blockId,
+ keyClass,
+ valueClass,
+ comparator instanceof RawComparator,
+ blockFlushFileReader));
+ }
+
+ FileOutputStream outputStream = new FileOutputStream(dataOutput);
+ assertThrows(
+ Exception.class,
+ () -> {
+ Merger.merge(
+ conf,
+ outputStream,
+ segments,
+ keyClass,
+ valueClass,
+ comparator,
+ comparator instanceof RawComparator);
+ });
+ outputStream.close();
+ }
+
+ class MockedStreamedSegment extends StreamedSegment {
+
+ BlockFlushFileReader reader;
+ int count;
+
+ MockedStreamedSegment(
+ RssConf rssConf,
+ PartialInputStream inputStream,
+ long blockId,
+ Class keyClass,
+ Class valueClass,
+ boolean raw,
+ BlockFlushFileReader reader) {
+ super(rssConf, inputStream, blockId, keyClass, valueClass, raw);
+ this.reader = reader;
+ }
+
+ public boolean next() throws IOException {
+ boolean ret = super.next();
+ if (this.count++ > 200) {
+ try {
+ this.reader.close();
+ } catch (InterruptedException e) {
+ throw new IOException(e);
+ }
+ }
+ return ret;
+ }
+ }
+}
diff --git
a/server/src/test/java/org/apache/uniffle/server/merge/MergedResultTest.java
b/server/src/test/java/org/apache/uniffle/server/merge/MergedResultTest.java
new file mode 100644
index 000000000..e8b09a5fd
--- /dev/null
+++ b/server/src/test/java/org/apache/uniffle/server/merge/MergedResultTest.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.server.merge;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.hadoop.io.RawComparator;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.merger.Merger;
+import org.apache.uniffle.common.merger.Recordable;
+import org.apache.uniffle.common.merger.Segment;
+import org.apache.uniffle.common.records.RecordsReader;
+import org.apache.uniffle.common.serializer.PartialInputStreamImpl;
+import org.apache.uniffle.common.serializer.SerializerUtils;
+
+import static
org.apache.uniffle.server.ShuffleServerConf.SERVER_MERGE_DEFAULT_MERGED_BLOCK_SIZE;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class MergedResultTest {
+
+ private static final int BYTES_LEN = 10240;
+ private static final int RECORDS = 1009;
+ private static final int SEGMENTS = 4;
+
+ @Test
+ public void testMergedResult() throws IOException {
+ // 1 Construct cache
+ List<Pair<Integer, byte[]>> blocks = new ArrayList<>();
+ MergedResult.CacheMergedBlockFuntion cache =
+ (byte[] buffer, long blockId, int length) -> {
+ assertEquals(blockId - 1, blocks.size());
+ blocks.add(Pair.of(length, buffer));
+ };
+
+ // 2 Write to merged result
+ RssConf rssConf = new RssConf();
+ rssConf.set(SERVER_MERGE_DEFAULT_MERGED_BLOCK_SIZE,
String.valueOf(BYTES_LEN / 10));
+ MergedResult result = new MergedResult(rssConf, cache, -1);
+ OutputStream output = result.getOutputStream();
+ for (int i = 0; i < BYTES_LEN; i++) {
+ output.write((byte) (i & 0x7F));
+ if (output instanceof Recordable) {
+ ((Recordable) output).record(i + 1, null, false);
+ }
+ }
+ output.close();
+
+ // 3 check blocks number
+ // Max merged block is 1024, every record have 2 bytes, so will result to
10 block
+ assertEquals(10, blocks.size());
+
+ // 4 check the blocks
+ int index = 0;
+ for (int i = 0; i < blocks.size(); i++) {
+ int length = blocks.get(i).getLeft();
+ byte[] buffer = blocks.get(i).getRight();
+ assertTrue(buffer.length >= length);
+ for (int j = 0; j < length; j++) {
+ assertEquals(index & 0x7F, buffer[j]);
+ index++;
+ }
+ }
+ assertEquals(BYTES_LEN, index);
+ }
+
+ @ParameterizedTest
+ @ValueSource(
+ strings = {
+ "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable",
+ })
+ public void testMergeSegmentToMergeResult(String classes, @TempDir File
tmpDir) throws Exception {
+ // 1 Parse arguments
+ String[] classArray = classes.split(",");
+ Class keyClass = SerializerUtils.getClassByName(classArray[0]);
+ Class valueClass = SerializerUtils.getClassByName(classArray[1]);
+
+ // 2 Construct cache
+ List<Pair<Integer, byte[]>> blocks = new ArrayList<>();
+ MergedResult.CacheMergedBlockFuntion cache =
+ (byte[] buffer, long blockId, int length) -> {
+ assertEquals(blockId - 1, blocks.size());
+ blocks.add(Pair.of(length, buffer));
+ };
+
+ // 3 Construct segments, then merge
+ RssConf rssConf = new RssConf();
+ List<Segment> segments = new ArrayList<>();
+ Comparator comparator = SerializerUtils.getComparator(keyClass);
+ for (int i = 0; i < SEGMENTS; i++) {
+ if (i % 2 == 0) {
+ segments.add(
+ SerializerUtils.genMemorySegment(
+ rssConf,
+ keyClass,
+ valueClass,
+ i,
+ i,
+ SEGMENTS,
+ RECORDS,
+ comparator instanceof RawComparator));
+ } else {
+ segments.add(
+ SerializerUtils.genFileSegment(
+ rssConf,
+ keyClass,
+ valueClass,
+ i,
+ i,
+ SEGMENTS,
+ RECORDS,
+ tmpDir,
+ comparator instanceof RawComparator));
+ }
+ }
+ MergedResult result = new MergedResult(rssConf, cache, -1);
+ OutputStream mergedOutputStream = result.getOutputStream();
+ Merger.merge(
+ rssConf,
+ mergedOutputStream,
+ segments,
+ keyClass,
+ valueClass,
+ comparator,
+ comparator instanceof RawComparator);
+ mergedOutputStream.flush();
+ mergedOutputStream.close();
+
+ // 4 check merged blocks
+ int index = 0;
+ for (int i = 0; i < blocks.size(); i++) {
+ int length = blocks.get(i).getLeft();
+ byte[] buffer = blocks.get(i).getRight();
+ RecordsReader reader =
+ new RecordsReader(
+ rssConf,
+ PartialInputStreamImpl.newInputStream(buffer, 0, length),
+ keyClass,
+ valueClass,
+ false);
+ while (reader.next()) {
+ assertEquals(SerializerUtils.genData(keyClass, index),
reader.getCurrentKey());
+ assertEquals(SerializerUtils.genData(valueClass, index),
reader.getCurrentValue());
+ index++;
+ }
+ reader.close();
+ }
+ assertEquals(RECORDS * SEGMENTS, index);
+ }
+}
diff --git
a/server/src/test/java/org/apache/uniffle/server/merge/ShuffleMergeManagerTest.java
b/server/src/test/java/org/apache/uniffle/server/merge/ShuffleMergeManagerTest.java
new file mode 100644
index 000000000..4b545d3a7
--- /dev/null
+++
b/server/src/test/java/org/apache/uniffle/server/merge/ShuffleMergeManagerTest.java
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.server.merge;
+
+import java.io.File;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.collect.ImmutableMap;
+import org.awaitility.Awaitility;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Timeout;
+import org.junit.jupiter.api.io.TempDir;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+
+import org.apache.uniffle.common.PartitionRange;
+import org.apache.uniffle.common.RemoteStorageInfo;
+import org.apache.uniffle.common.ShuffleDataResult;
+import org.apache.uniffle.common.ShufflePartitionedBlock;
+import org.apache.uniffle.common.ShufflePartitionedData;
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.merger.MergeState;
+import org.apache.uniffle.common.records.RecordsReader;
+import org.apache.uniffle.common.serializer.PartialInputStreamImpl;
+import org.apache.uniffle.common.serializer.SerializerUtils;
+import org.apache.uniffle.common.serializer.writable.WritableSerializer;
+import org.apache.uniffle.common.util.BlockIdLayout;
+import org.apache.uniffle.server.ShuffleServer;
+import org.apache.uniffle.server.ShuffleServerConf;
+import org.apache.uniffle.server.ShuffleServerMetrics;
+import org.apache.uniffle.server.ShuffleTaskManager;
+import org.apache.uniffle.storage.util.StorageType;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.fail;
+
+public class ShuffleMergeManagerTest {
+
+ private static final String APP_ID = "app1";
+ private static final int SHUFFLE_ID = 1;
+ private static final int PARTITION_ID = 2;
+ private static final int RECORDS_NUMBER = 1009;
+ private static final String USER = "testUser";
+
+ private ShuffleServer shuffleServer;
+ ShuffleServerConf serverConf;
+
+ @TempDir File tempDir1;
+ @TempDir File tempDir2;
+
+ @BeforeEach
+ public void beforeEach() {
+ String confFile = ClassLoader.getSystemResource("server.conf").getFile();
+ serverConf = new ShuffleServerConf(confFile);
+ serverConf.setString(
+ ShuffleServerConf.RSS_STORAGE_TYPE.key(),
StorageType.MEMORY_LOCALFILE.name());
+ serverConf.setString(
+ ShuffleServerConf.RSS_STORAGE_BASE_PATH.key(),
+ tempDir1.getAbsolutePath() + "," + tempDir2.getAbsolutePath());
+ serverConf.setLong(ShuffleServerConf.SERVER_APP_EXPIRED_WITHOUT_HEARTBEAT,
60L * 1000L * 60L);
+ serverConf.set(ShuffleServerConf.SERVER_MERGE_ENABLE, true);
+ ShuffleServerMetrics.clear();
+ ShuffleServerMetrics.register();
+ assertTrue(this.tempDir1.isDirectory());
+ assertTrue(this.tempDir2.isDirectory());
+ }
+
+ @AfterEach
+ public void afterEach() throws Exception {
+ serverConf = null;
+ if (shuffleServer != null) {
+ shuffleServer.stopServer();
+ shuffleServer = null;
+ }
+ }
+
+ @Timeout(10)
+ @ParameterizedTest
+ @ValueSource(
+ strings = {
+ "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable",
+ })
+ public void testMergerManager(String classes, @TempDir File tmpDir) throws
Exception {
+ // 1 Construct serializer and comparator
+ final String[] classArray = classes.split(",");
+ final String keyClassName = classArray[0];
+ final String valueClassName = classArray[1];
+ final Class keyClass = SerializerUtils.getClassByName(keyClassName);
+ final Class valueClass = SerializerUtils.getClassByName(valueClassName);
+ final Comparator comparator = SerializerUtils.getComparator(keyClass);
+ final String comparatorClassName = comparator.getClass().getName();
+ final WritableSerializer serializer = new WritableSerializer(new
RssConf());
+
+ // 2 Construct shuffle task manager and merge manager
+ shuffleServer = new ShuffleServer(serverConf);
+ final ShuffleTaskManager shuffleTaskManager =
shuffleServer.getShuffleTaskManager();
+ final ShuffleMergeManager mergeManager =
shuffleServer.getShuffleMergeManager();
+
+ // 3 register shuffle
+ List<PartitionRange> partitionRanges = new ArrayList<>();
+ partitionRanges.add(new PartitionRange(PARTITION_ID, PARTITION_ID));
+ shuffleTaskManager.registerShuffle(
+ APP_ID, SHUFFLE_ID, partitionRanges, new RemoteStorageInfo(""), USER);
+ shuffleTaskManager.registerShuffle(
+ APP_ID + ShuffleMergeManager.MERGE_APP_SUFFIX,
+ SHUFFLE_ID,
+ partitionRanges,
+ new RemoteStorageInfo(""),
+ USER);
+ mergeManager.registerShuffle(
+ APP_ID, SHUFFLE_ID, keyClassName, valueClassName, comparatorClassName,
-1, "");
+
+ // 4 report blocks
+ // 4.1 send shuffle data
+ // Upstream have 2 task, each task generate 2 blocks
+ BlockIdLayout blockIdLayout = BlockIdLayout.from(serverConf);
+ long[] blocks = new long[4];
+ blocks[0] = blockIdLayout.getBlockId(0, PARTITION_ID, 0);
+ blocks[1] = blockIdLayout.getBlockId(1, PARTITION_ID, 0);
+ blocks[2] = blockIdLayout.getBlockId(0, PARTITION_ID, 1);
+ blocks[3] = blockIdLayout.getBlockId(1, PARTITION_ID, 1);
+ ShufflePartitionedBlock[] shufflePartitionedBlocks = new
ShufflePartitionedBlock[4];
+ for (int i = 0; i < 4; i++) {
+ byte[] buffer =
+ SerializerUtils.genSortedRecordBytes(
+ serverConf, keyClass, valueClass, i, 4, RECORDS_NUMBER, 1);
+ shufflePartitionedBlocks[i] =
+ new ShufflePartitionedBlock(
+ buffer.length,
+ buffer.length,
+ 0,
+ blocks[i],
+ blockIdLayout.getTaskAttemptId(blocks[i]),
+ buffer);
+ }
+ ShufflePartitionedData spd = new ShufflePartitionedData(PARTITION_ID,
shufflePartitionedBlocks);
+ shuffleTaskManager.cacheShuffleData(APP_ID, SHUFFLE_ID, false, spd);
+ mergeManager.cacheBlock(APP_ID, SHUFFLE_ID, spd);
+ // 4.2 report shuffle result
+ shuffleTaskManager.addFinishedBlockIds(
+ APP_ID, SHUFFLE_ID, ImmutableMap.of(PARTITION_ID, blocks), 1);
+ // 4.3 report unique blockIds
+ Roaring64NavigableMap blockIdMap = Roaring64NavigableMap.bitmapOf();
+ blockIdMap.add(blocks);
+ mergeManager.startSortMerge(APP_ID, SHUFFLE_ID, PARTITION_ID, blockIdMap);
+
+ // 4 wait for drain event
+ Awaitility.await()
+ .atMost(10, TimeUnit.SECONDS)
+ .until(() -> mergeManager.getEventHandler().getEventNumInMerge() == 0);
+ Awaitility.await()
+ .atMost(10, TimeUnit.SECONDS)
+ .until(
+ () ->
+ mergeManager.getPartition(APP_ID, SHUFFLE_ID,
PARTITION_ID).getState()
+ == MergeState.DONE);
+
+ // 5 read and check result
+ int blockId = 1;
+ int index = 0;
+ boolean finish = false;
+ while (!finish) {
+ MergeStatus mergeStatus = mergeManager.tryGetBlock(APP_ID, SHUFFLE_ID,
PARTITION_ID, blockId);
+ MergeState mergeState = mergeStatus.getState();
+ long blockSize = mergeStatus.getSize();
+ switch (mergeState) {
+ case INITED:
+ case MERGING:
+ case INTERNAL_ERROR:
+ fail("Find wrong merge state!");
+ break;
+ case DONE:
+ if (blockSize != -1) {
+ ShuffleDataResult shuffleDataResult =
+ mergeManager.getShuffleData(APP_ID, SHUFFLE_ID, PARTITION_ID,
blockId);
+ PartialInputStreamImpl inputStream =
+ PartialInputStreamImpl.newInputStream(
+ shuffleDataResult.getData(), 0,
shuffleDataResult.getDataLength());
+ RecordsReader reader =
+ new RecordsReader(serverConf, inputStream, keyClass,
valueClass, false);
+ while (reader.next()) {
+ assertEquals(SerializerUtils.genData(keyClass, index),
reader.getCurrentKey());
+ assertEquals(SerializerUtils.genData(valueClass, index),
reader.getCurrentValue());
+ index++;
+ }
+ shuffleDataResult.release();
+ blockId++;
+ break;
+ } else {
+ finish = true;
+ break;
+ }
+ default:
+ fail("Find invalid merge state!");
+ }
+ }
+ assertEquals(RECORDS_NUMBER * 4, index);
+
+ // 8 cleanup
+ mergeManager.removeBuffer(APP_ID, SHUFFLE_ID);
+ }
+}
diff --git a/server/src/test/resources/log4j2.xml
b/server/src/test/resources/log4j2.xml
new file mode 100644
index 000000000..d26fb6958
--- /dev/null
+++ b/server/src/test/resources/log4j2.xml
@@ -0,0 +1,29 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one or more
+ ~ contributor license agreements. See the NOTICE file distributed with
+ ~ this work for additional information regarding copyright ownership.
+ ~ The ASF licenses this file to You under the Apache License, Version 2.0
+ ~ (the "License"); you may not use this file except in compliance with
+ ~ the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+-->
+<Configuration status="WARN" monitorInterval="30">
+ <Appenders>
+ <Console name="Console" target="SYSTEM_OUT">
+ <PatternLayout pattern="[%d{yyyy-MM-dd HH:mm:ss.SSS}] [%t] [%p] %c{1}.%M - %m%n%ex"/>
+ </Console>
+ </Appenders>
+ <Loggers>
+ <Root level="info">
+ <AppenderRef ref="Console"/>
+ </Root>
+ </Loggers>
+</Configuration>
diff --git
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandler.java
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandler.java
index f688a18bc..4eb7e2d52 100644
---
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandler.java
+++
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandler.java
@@ -149,6 +149,15 @@ public class LocalFileServerReadHandler implements
ServerReadHandler {
}
// get dataFileSize for read segment generation in
DataSkippableReadHandler#readShuffleData
long dataFileSize = new File(dataFileName).length();
- return new ShuffleIndexResult(new FileSegmentManagedBuffer(indexFile, 0,
len), dataFileSize);
+ return new ShuffleIndexResult(
+ new FileSegmentManagedBuffer(indexFile, 0, len), dataFileSize,
dataFileName);
+ }
+
+ /**
+ * Returns the data file name held by this handler. It is the file whose length is reported as
+ * the data file size, and the name attached to the {@code ShuffleIndexResult} this handler
+ * builds from its index file.
+ */
+ public String getDataFileName() {
+ return dataFileName;
+ }
+
+ /**
+ * Returns the index file name held by this handler.
+ *
+ * <p>NOTE(review): presumably this names the file read into {@code indexFile}; confirm against
+ * the constructor.
+ */
+ public String getIndexFileName() {
+ return indexFileName;
+ }
}