This is an automated email from the ASF dual-hosted git repository.
zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new d47f7039 [Improvement] Support skip memory data when use multiple
replicas (#400)
d47f7039 is described below
commit d47f70395a63418f78c1d9c2d28d3a572b27dd69
Author: xianjingfeng <[email protected]>
AuthorDate: Wed Dec 14 19:04:40 2022 +0800
[Improvement] Support skip memory data when use multiple replicas (#400)
### What changes were proposed in this pull request?
Support filtering data when multiple replicas are used.
### Why are the changes needed?
Currently, when multiple replicas are used, data that the client has already
read from the first replica is not filtered out. This causes duplicate data
to be read.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit tests (UT).
---
.../org/apache/hadoop/mapreduce/MRIdHelper.java | 2 +-
.../hadoop/mapreduce/task/reduce/RssShuffle.java | 3 +-
.../spark/shuffle/reader/RssShuffleReader.java | 5 +-
.../spark/shuffle/reader/RssShuffleReader.java | 8 +--
.../uniffle/client/impl/ShuffleReadClientImpl.java | 3 +-
.../request/CreateShuffleReadClientRequest.java | 14 ++--
.../uniffle/client/util/DefaultIdHelper.java | 1 +
.../org/apache/uniffle/client/ClientUtilsTest.java | 26 +++++++
.../org/apache/uniffle/common}/util/IdHelper.java | 2 +-
.../org/apache/uniffle/common/util/RssUtils.java | 9 +++
.../ShuffleServerWithLocalOfExceptionTest.java | 4 +-
.../test/ShuffleServerWithMemLocalHdfsTest.java | 5 +-
.../uniffle/test/ShuffleServerWithMemoryTest.java | 80 ++++++++++++++++++++--
.../storage/factory/ShuffleHandlerFactory.java | 20 ++++--
.../handler/impl/ComposedClientReadHandler.java | 45 +++++++++---
.../handler/impl/MemoryClientReadHandler.java | 30 ++------
.../request/CreateShuffleReadHandlerRequest.java | 11 +++
17 files changed, 206 insertions(+), 62 deletions(-)
diff --git
a/client-mr/src/main/java/org/apache/hadoop/mapreduce/MRIdHelper.java
b/client-mr/src/main/java/org/apache/hadoop/mapreduce/MRIdHelper.java
index 001d2338..e20a3d87 100644
--- a/client-mr/src/main/java/org/apache/hadoop/mapreduce/MRIdHelper.java
+++ b/client-mr/src/main/java/org/apache/hadoop/mapreduce/MRIdHelper.java
@@ -17,7 +17,7 @@
package org.apache.hadoop.mapreduce;
-import org.apache.uniffle.client.util.IdHelper;
+import org.apache.uniffle.common.util.IdHelper;
public class MRIdHelper implements IdHelper {
diff --git
a/client-mr/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssShuffle.java
b/client-mr/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssShuffle.java
index e5af9795..372a265e 100644
---
a/client-mr/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssShuffle.java
+++
b/client-mr/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssShuffle.java
@@ -191,10 +191,11 @@ public class RssShuffle<K, V> implements
ShuffleConsumerPlugin<K, V>, ExceptionR
LOG.info("In reduce: " + reduceId
+ ", Rss MR client starts to fetch blocks from RSS server");
JobConf readerJobConf = getRemoteConf();
+ boolean expectedTaskIdsBitmapFilterEnable = serverInfoList.size() > 1;
CreateShuffleReadClientRequest request = new
CreateShuffleReadClientRequest(
appId, 0, reduceId.getTaskID().getId(), storageType, basePath,
indexReadLimit, readBufferSize,
partitionNumPerRange, partitionNum, blockIdBitmap, taskIdBitmap,
serverInfoList,
- readerJobConf, new MRIdHelper());
+ readerJobConf, new MRIdHelper(), expectedTaskIdsBitmapFilterEnable);
ShuffleReadClient shuffleReadClient =
ShuffleClientFactory.getInstance().createShuffleReadClient(request);
RssFetcher fetcher = new RssFetcher(mrJobConf, reduceId, taskStatus,
merger, copyPhase, reporter, metrics,
shuffleReadClient, blockIdBitmap.getLongCardinality(),
RssMRConfig.toRssConf(rssJobConf));
diff --git
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index cc6fe254..ff1eac6e 100644
---
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -51,6 +51,7 @@ import org.apache.uniffle.common.config.RssConf;
public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
private static final Logger LOG =
LoggerFactory.getLogger(RssShuffleReader.class);
+ private final boolean expectedTaskIdsBitmapFilterEnable;
private String appId;
private int shuffleId;
@@ -107,6 +108,7 @@ public class RssShuffleReader<K, C> implements
ShuffleReader<K, C> {
this.shuffleServerInfoList =
(List<ShuffleServerInfo>)
(rssShuffleHandle.getPartitionToServers().get(startPartition));
this.rssConf = rssConf;
+ expectedTaskIdsBitmapFilterEnable = shuffleServerInfoList.size() > 1;
}
@Override
@@ -115,7 +117,8 @@ public class RssShuffleReader<K, C> implements
ShuffleReader<K, C> {
CreateShuffleReadClientRequest request = new
CreateShuffleReadClientRequest(
appId, shuffleId, startPartition, storageType, basePath,
indexReadLimit, readBufferSize,
- partitionNumPerRange, partitionNum, blockIdBitmap, taskIdBitmap,
shuffleServerInfoList, hadoopConf);
+ partitionNumPerRange, partitionNum, blockIdBitmap, taskIdBitmap,
+ shuffleServerInfoList, hadoopConf, expectedTaskIdsBitmapFilterEnable);
ShuffleReadClient shuffleReadClient =
ShuffleClientFactory.getInstance().createShuffleReadClient(request);
RssShuffleDataIterator rssShuffleDataIterator = new
RssShuffleDataIterator<K, C>(
shuffleDependency.serializer(), shuffleReadClient,
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index 89e17664..b810f611 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -78,7 +78,6 @@ public class RssShuffleReader<K, C> implements
ShuffleReader<K, C> {
private ShuffleReadMetrics readMetrics;
private RssConf rssConf;
private ShuffleDataDistributionType dataDistributionType;
- private boolean expectedTaskIdsBitmapFilterEnable;
public RssShuffleReader(
int startPartition,
@@ -120,9 +119,6 @@ public class RssShuffleReader<K, C> implements
ShuffleReader<K, C> {
this.partitionToShuffleServers = rssShuffleHandle.getPartitionToServers();
this.rssConf = rssConf;
this.dataDistributionType = dataDistributionType;
- // This mechanism of expectedTaskIdsBitmap filter is to filter out the
most of data.
- // especially for AQE skew optimization
- this.expectedTaskIdsBitmapFilterEnable = !(mapStartIndex == 0 &&
mapEndIndex == Integer.MAX_VALUE);
}
@Override
@@ -207,6 +203,10 @@ public class RssShuffleReader<K, C> implements
ShuffleReader<K, C> {
continue;
}
List<ShuffleServerInfo> shuffleServerInfoList =
partitionToShuffleServers.get(partition);
+ // This mechanism of expectedTaskIdsBitmap filter is to filter out the
most of data.
+ // especially for AQE skew optimization
+ boolean expectedTaskIdsBitmapFilterEnable = !(mapStartIndex == 0 &&
mapEndIndex == Integer.MAX_VALUE)
+ || shuffleServerInfoList.size() > 1;
CreateShuffleReadClientRequest request = new
CreateShuffleReadClientRequest(
appId, shuffleId, partition, storageType, basePath,
indexReadLimit, readBufferSize,
1, partitionNum, partitionToExpectBlocks.get(partition),
taskIdBitmap, shuffleServerInfoList,
diff --git
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
index 84758c7b..a7ff8ef0 100644
---
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
+++
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
@@ -32,13 +32,13 @@ import org.slf4j.LoggerFactory;
import org.apache.uniffle.client.api.ShuffleReadClient;
import org.apache.uniffle.client.response.CompressedShuffleBlock;
-import org.apache.uniffle.client.util.IdHelper;
import org.apache.uniffle.common.BufferSegment;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleDataResult;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.ChecksumUtils;
+import org.apache.uniffle.common.util.IdHelper;
import org.apache.uniffle.common.util.RssUtils;
import org.apache.uniffle.storage.factory.ShuffleHandlerFactory;
import org.apache.uniffle.storage.handler.api.ClientReadHandler;
@@ -101,6 +101,7 @@ public class ShuffleReadClientImpl implements
ShuffleReadClient {
request.setExpectBlockIds(blockIdBitmap);
request.setProcessBlockIds(processedBlockIds);
request.setDistributionType(dataDistributionType);
+ request.setIdHelper(idHelper);
request.setExpectTaskIds(taskIdBitmap);
if (expectedTaskIdsBitmapFilterEnable) {
request.useExpectedTaskIdsBitmapFilter();
diff --git
a/client/src/main/java/org/apache/uniffle/client/request/CreateShuffleReadClientRequest.java
b/client/src/main/java/org/apache/uniffle/client/request/CreateShuffleReadClientRequest.java
index db050304..a4b4a325 100644
---
a/client/src/main/java/org/apache/uniffle/client/request/CreateShuffleReadClientRequest.java
+++
b/client/src/main/java/org/apache/uniffle/client/request/CreateShuffleReadClientRequest.java
@@ -23,9 +23,9 @@ import org.apache.hadoop.conf.Configuration;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.apache.uniffle.client.util.DefaultIdHelper;
-import org.apache.uniffle.client.util.IdHelper;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.util.IdHelper;
public class CreateShuffleReadClientRequest {
@@ -64,9 +64,8 @@ public class CreateShuffleReadClientRequest {
boolean expectedTaskIdsBitmapFilterEnable) {
this(appId, shuffleId, partitionId, storageType, basePath, indexReadLimit,
readBufferSize,
partitionNumPerRange, partitionNum, blockIdBitmap, taskIdBitmap,
shuffleServerInfoList,
- hadoopConf, new DefaultIdHelper());
+ hadoopConf, new DefaultIdHelper(), expectedTaskIdsBitmapFilterEnable);
this.shuffleDataDistributionType = dataDistributionType;
- this.expectedTaskIdsBitmapFilterEnable = expectedTaskIdsBitmapFilterEnable;
}
public CreateShuffleReadClientRequest(
@@ -82,10 +81,11 @@ public class CreateShuffleReadClientRequest {
Roaring64NavigableMap blockIdBitmap,
Roaring64NavigableMap taskIdBitmap,
List<ShuffleServerInfo> shuffleServerInfoList,
- Configuration hadoopConf) {
+ Configuration hadoopConf,
+ boolean expectedTaskIdsBitmapFilterEnable) {
this(appId, shuffleId, partitionId, storageType, basePath, indexReadLimit,
readBufferSize,
partitionNumPerRange, partitionNum, blockIdBitmap, taskIdBitmap,
shuffleServerInfoList,
- hadoopConf, new DefaultIdHelper());
+ hadoopConf, new DefaultIdHelper(), expectedTaskIdsBitmapFilterEnable);
}
public CreateShuffleReadClientRequest(
@@ -102,7 +102,8 @@ public class CreateShuffleReadClientRequest {
Roaring64NavigableMap taskIdBitmap,
List<ShuffleServerInfo> shuffleServerInfoList,
Configuration hadoopConf,
- IdHelper idHelper) {
+ IdHelper idHelper,
+ boolean expectedTaskIdsBitmapFilterEnable) {
this.appId = appId;
this.shuffleId = shuffleId;
this.partitionId = partitionId;
@@ -117,6 +118,7 @@ public class CreateShuffleReadClientRequest {
this.shuffleServerInfoList = shuffleServerInfoList;
this.hadoopConf = hadoopConf;
this.idHelper = idHelper;
+ this.expectedTaskIdsBitmapFilterEnable = expectedTaskIdsBitmapFilterEnable;
}
public String getAppId() {
diff --git
a/client/src/main/java/org/apache/uniffle/client/util/DefaultIdHelper.java
b/client/src/main/java/org/apache/uniffle/client/util/DefaultIdHelper.java
index 084571a8..97376cc8 100644
--- a/client/src/main/java/org/apache/uniffle/client/util/DefaultIdHelper.java
+++ b/client/src/main/java/org/apache/uniffle/client/util/DefaultIdHelper.java
@@ -18,6 +18,7 @@
package org.apache.uniffle.client.util;
import org.apache.uniffle.common.util.Constants;
+import org.apache.uniffle.common.util.IdHelper;
public class DefaultIdHelper implements IdHelper {
@Override
diff --git
a/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java
b/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java
index 5f26ae17..77f9cba5 100644
--- a/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java
+++ b/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java
@@ -26,10 +26,14 @@ import java.util.concurrent.TimeUnit;
import org.awaitility.Awaitility;
import org.junit.jupiter.api.Test;
+import org.roaringbitmap.longlong.LongIterator;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.uniffle.client.util.ClientUtils;
+import org.apache.uniffle.client.util.DefaultIdHelper;
+import org.apache.uniffle.common.util.RssUtils;
import static org.apache.uniffle.client.util.ClientUtils.waitUntilDoneOrFail;
import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -64,6 +68,28 @@ public class ClientUtilsTest {
assertTrue(e3.getMessage().contains("Can't support sequence[262144], the
max value should be 262143"));
}
+ @Test
+ public void testGenerateTaskIdBitMap() {
+ int partitionId = 1;
+ Roaring64NavigableMap blockIdMap = Roaring64NavigableMap.bitmapOf();
+ int taskSize = 10;
+ long[] except = new long[taskSize];
+ for (int i = 0; i < taskSize; i++) {
+ except[i] = i;
+ for (int j = 0; j < 100; j++) {
+ Long blockId = ClientUtils.getBlockId(partitionId, i, j);
+ blockIdMap.addLong(blockId);
+ }
+
+ }
+ Roaring64NavigableMap taskIdBitMap =
RssUtils.generateTaskIdBitMap(blockIdMap, new DefaultIdHelper());
+ assertEquals(taskSize, taskIdBitMap.getLongCardinality());
+ LongIterator longIterator = taskIdBitMap.getLongIterator();
+ for (int i = 0; i < taskSize; i++) {
+ assertEquals(except[i], longIterator.next());
+ }
+ }
+
private List<CompletableFuture<Boolean>> getFutures(boolean fail) {
List<CompletableFuture<Boolean>> futures = new ArrayList<>();
for (int i = 0; i < 3; i++) {
diff --git a/client/src/main/java/org/apache/uniffle/client/util/IdHelper.java
b/common/src/main/java/org/apache/uniffle/common/util/IdHelper.java
similarity index 95%
rename from client/src/main/java/org/apache/uniffle/client/util/IdHelper.java
rename to common/src/main/java/org/apache/uniffle/common/util/IdHelper.java
index 725a62a2..df03376d 100644
--- a/client/src/main/java/org/apache/uniffle/client/util/IdHelper.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/IdHelper.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.uniffle.client.util;
+package org.apache.uniffle.common.util;
public interface IdHelper {
diff --git a/common/src/main/java/org/apache/uniffle/common/util/RssUtils.java
b/common/src/main/java/org/apache/uniffle/common/util/RssUtils.java
index fb05894e..a079753b 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/RssUtils.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/RssUtils.java
@@ -283,4 +283,13 @@ public class RssUtils {
+ " blocks, actual " + cloneBitmap.getLongCardinality() + " blocks");
}
}
+
+ public static Roaring64NavigableMap
generateTaskIdBitMap(Roaring64NavigableMap blockIdBitmap, IdHelper idHelper) {
+ Iterator<Long> iterator = blockIdBitmap.iterator();
+ Roaring64NavigableMap taskIdBitmap = Roaring64NavigableMap.bitmapOf();
+ while (iterator.hasNext()) {
+ taskIdBitmap.addLong(idHelper.getTaskAttemptId(iterator.next()));
+ }
+ return taskIdBitmap;
+ }
}
diff --git
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithLocalOfExceptionTest.java
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithLocalOfExceptionTest.java
index fb316a08..4e8aa0d6 100644
---
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithLocalOfExceptionTest.java
+++
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithLocalOfExceptionTest.java
@@ -24,6 +24,7 @@ import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.apache.uniffle.client.impl.grpc.ShuffleServerGrpcClient;
import org.apache.uniffle.common.exception.RssException;
@@ -75,7 +76,8 @@ public class ShuffleServerWithLocalOfExceptionTest extends
ShuffleReadWriteBase
int partitionId = 0;
MemoryClientReadHandler memoryClientReadHandler = new
MemoryClientReadHandler(
- testAppId, shuffleId, partitionId, 150, shuffleServerClient);
+ testAppId, shuffleId, partitionId, 150, shuffleServerClient,
+ Roaring64NavigableMap.bitmapOf());
shuffleServers.get(0).stopServer();
try {
memoryClientReadHandler.readShuffleData();
diff --git
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemLocalHdfsTest.java
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemLocalHdfsTest.java
index e3e45aa6..4dc9c706 100644
---
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemLocalHdfsTest.java
+++
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemLocalHdfsTest.java
@@ -123,10 +123,11 @@ public class ShuffleServerWithMemLocalHdfsTest extends
ShuffleReadWriteBase {
testAppId, 3, 1000, shuffleToBlocks);
shuffleServerClient.sendShuffleData(rssdr);
+ Roaring64NavigableMap processBlockIds = Roaring64NavigableMap.bitmapOf();
+ Roaring64NavigableMap exceptTaskIds = Roaring64NavigableMap.bitmapOf(0);
// read the 1-th segment from memory
MemoryClientReadHandler memoryClientReadHandler = new
MemoryClientReadHandler(
- testAppId, shuffleId, partitionId, 150, shuffleServerClient);
- Roaring64NavigableMap processBlockIds = Roaring64NavigableMap.bitmapOf();
+ testAppId, shuffleId, partitionId, 150, shuffleServerClient,
exceptTaskIds);
LocalFileClientReadHandler localFileClientReadHandler = new
LocalFileClientReadHandler(
testAppId, shuffleId, partitionId, 0, 1, 3,
75, expectBlockIds, processBlockIds, shuffleServerClient);
diff --git
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemoryTest.java
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemoryTest.java
index accc280f..81a7e31f 100644
---
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemoryTest.java
+++
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemoryTest.java
@@ -113,9 +113,11 @@ public class ShuffleServerWithMemoryTest extends
ShuffleReadWriteBase {
// data is cached
assertEquals(3, shuffleServers.get(0).getShuffleBufferManager()
.getShuffleBuffer(testAppId, shuffleId, 0).getBlocks().size());
+
+ Roaring64NavigableMap exceptTaskIds = Roaring64NavigableMap.bitmapOf(0);
// create memory handler to read data,
MemoryClientReadHandler memoryClientReadHandler = new
MemoryClientReadHandler(
- testAppId, shuffleId, partitionId, 20, shuffleServerClient);
+ testAppId, shuffleId, partitionId, 20, shuffleServerClient,
exceptTaskIds);
// start to read data, one block data for every call
ShuffleDataResult sdr = memoryClientReadHandler.readShuffleData();
Map<Long, byte[]> expectedData = Maps.newHashMap();
@@ -137,9 +139,11 @@ public class ShuffleServerWithMemoryTest extends
ShuffleReadWriteBase {
assertEquals(0, sdr.getBufferSegments().size());
// case: read with ComposedClientReadHandler
- Roaring64NavigableMap processBlockIds = Roaring64NavigableMap.bitmapOf();
memoryClientReadHandler = new MemoryClientReadHandler(
- testAppId, shuffleId, partitionId, 50, shuffleServerClient);
+ testAppId, shuffleId, partitionId, 50, shuffleServerClient,
+ exceptTaskIds);
+
+ Roaring64NavigableMap processBlockIds = Roaring64NavigableMap.bitmapOf();
LocalFileClientReadHandler localFileQuorumClientReadHandler = new
LocalFileClientReadHandler(
testAppId, shuffleId, partitionId, 0, 1, 3,
50, expectBlockIds, processBlockIds, shuffleServerClient);
@@ -210,6 +214,73 @@ public class ShuffleServerWithMemoryTest extends
ShuffleReadWriteBase {
assertNull(sdr);
}
+
+ @Test
+ public void memoryWriteReadWithMultiReplicaTest() throws Exception {
+ String testAppId = "memoryWriteReadWithMultiReplicaTest";
+ int shuffleId = 0;
+ int partitionId = 0;
+ RssRegisterShuffleRequest rrsr = new RssRegisterShuffleRequest(testAppId,
0,
+ Lists.newArrayList(new PartitionRange(0, 0)), "");
+ shuffleServerClient.registerShuffle(rrsr);
+ Roaring64NavigableMap expectBlockIds = Roaring64NavigableMap.bitmapOf();
+ Map<Long, byte[]> dataMap = Maps.newHashMap();
+ Roaring64NavigableMap[] bitmaps = new Roaring64NavigableMap[1];
+ bitmaps[0] = Roaring64NavigableMap.bitmapOf();
+ // create blocks which belong to different tasks
+ List<ShuffleBlockInfo> blocks = Lists.newArrayList();
+ for (int i = 0; i < 3; i++) {
+ blocks.addAll(createShuffleBlockList(
+ shuffleId, partitionId, i, 1, 25,
+ expectBlockIds, dataMap, mockSSI));
+ }
+ Map<Integer, List<ShuffleBlockInfo>> partitionToBlocks = Maps.newHashMap();
+ partitionToBlocks.put(partitionId, blocks);
+ Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleToBlocks =
Maps.newHashMap();
+ shuffleToBlocks.put(shuffleId, partitionToBlocks);
+
+ // send data to shuffle server
+ RssSendShuffleDataRequest rssdr = new RssSendShuffleDataRequest(
+ testAppId, 3, 1000, shuffleToBlocks);
+ shuffleServerClient.sendShuffleData(rssdr);
+
+ // data is cached
+ assertEquals(3, shuffleServers.get(0).getShuffleBufferManager()
+ .getShuffleBuffer(testAppId, shuffleId, 0).getBlocks().size());
+
+ Roaring64NavigableMap processBlockIds = Roaring64NavigableMap.bitmapOf();
+ Roaring64NavigableMap exceptTaskIds = Roaring64NavigableMap.bitmapOf(0, 1,
2);
+ // create memory handler to read data,
+ MemoryClientReadHandler memoryClientReadHandler = new
MemoryClientReadHandler(
+ testAppId, shuffleId, partitionId, 20, shuffleServerClient,
exceptTaskIds);
+ // start to read data, one block data for every call
+ ShuffleDataResult sdr = memoryClientReadHandler.readShuffleData();
+ Map<Long, byte[]> expectedData = Maps.newHashMap();
+ expectedData.put(blocks.get(0).getBlockId(), blocks.get(0).getData());
+ validateResult(expectedData, sdr);
+ // read by different reader, the first block should be skipped.
+ exceptTaskIds.removeLong(blocks.get(0).getTaskAttemptId());
+ MemoryClientReadHandler memoryClientReadHandler2 = new
MemoryClientReadHandler(
+ testAppId, shuffleId, partitionId, 20, shuffleServerClient,
exceptTaskIds);
+ sdr = memoryClientReadHandler2.readShuffleData();
+ expectedData.clear();
+ expectedData.put(blocks.get(1).getBlockId(), blocks.get(1).getData());
+ validateResult(expectedData, sdr);
+
+ sdr = memoryClientReadHandler.readShuffleData();
+ expectedData.clear();
+ expectedData.put(blocks.get(1).getBlockId(), blocks.get(1).getData());
+ validateResult(expectedData, sdr);
+
+ sdr = memoryClientReadHandler2.readShuffleData();
+ expectedData.clear();
+ expectedData.put(blocks.get(2).getBlockId(), blocks.get(2).getData());
+ validateResult(expectedData, sdr);
+ // no data in cache, empty return
+ sdr = memoryClientReadHandler2.readShuffleData();
+ assertEquals(0, sdr.getBufferSegments().size());
+ }
+
@Test
public void memoryAndLocalFileReadWithFilterTest() throws Exception {
String testAppId = "memoryAndLocalFileReadWithFilterTest";
@@ -237,8 +308,9 @@ public class ShuffleServerWithMemoryTest extends
ShuffleReadWriteBase {
// read the 1-th segment from memory
Roaring64NavigableMap processBlockIds = Roaring64NavigableMap.bitmapOf();
+ Roaring64NavigableMap exceptTaskIds = Roaring64NavigableMap.bitmapOf(0);
MemoryClientReadHandler memoryClientReadHandler = new
MemoryClientReadHandler(
- testAppId, shuffleId, partitionId, 150, shuffleServerClient);
+ testAppId, shuffleId, partitionId, 150, shuffleServerClient,
exceptTaskIds);
LocalFileClientReadHandler localFileClientReadHandler = new
LocalFileClientReadHandler(
testAppId, shuffleId, partitionId, 0, 1, 3,
75, expectBlockIds, processBlockIds, shuffleServerClient);
diff --git
a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
index e8fc4046..91253482 100644
---
a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
+++
b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
@@ -19,15 +19,18 @@ package org.apache.uniffle.storage.factory;
import java.util.ArrayList;
import java.util.List;
+import java.util.concurrent.Callable;
import com.google.common.collect.Lists;
import org.apache.commons.collections.CollectionUtils;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.apache.uniffle.client.api.ShuffleServerClient;
import org.apache.uniffle.client.factory.ShuffleServerClientFactory;
import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.util.RssUtils;
import org.apache.uniffle.storage.handler.api.ClientReadHandler;
import org.apache.uniffle.storage.handler.api.ShuffleDeleteHandler;
import org.apache.uniffle.storage.handler.impl.ComposedClientReadHandler;
@@ -90,20 +93,20 @@ public class ShuffleHandlerFactory {
return getLocalfileClientReaderHandler(request, serverInfo);
}
- List<ClientReadHandler> handlers = new ArrayList<>();
+ List<Callable<ClientReadHandler>> handlers = new ArrayList<>();
if (StorageType.withMemory(type)) {
handlers.add(
- getMemoryClientReadHandler(request, serverInfo)
+ () -> getMemoryClientReadHandler(request, serverInfo)
);
}
if (StorageType.withLocalfile(type)) {
handlers.add(
- getLocalfileClientReaderHandler(request, serverInfo)
+ () -> getLocalfileClientReaderHandler(request, serverInfo)
);
}
if (StorageType.withHDFS(type)) {
handlers.add(
- getHdfsClientReadHandler(request, serverInfo)
+ () -> getHdfsClientReadHandler(request, serverInfo)
);
}
if (handlers.isEmpty()) {
@@ -116,14 +119,19 @@ public class ShuffleHandlerFactory {
private ClientReadHandler
getMemoryClientReadHandler(CreateShuffleReadHandlerRequest request,
ShuffleServerInfo ssi) {
ShuffleServerClient shuffleServerClient =
ShuffleServerClientFactory.getInstance().getShuffleServerClient(
ClientType.GRPC.name(), ssi);
+ Roaring64NavigableMap expectTaskIds = null;
+ if (request.isExpectedTaskIdsBitmapFilterEnable()) {
+ Roaring64NavigableMap realExceptBlockIds =
RssUtils.cloneBitMap(request.getExpectBlockIds());
+ realExceptBlockIds.xor(request.getProcessBlockIds());
+ expectTaskIds = RssUtils.generateTaskIdBitMap(realExceptBlockIds,
request.getIdHelper());
+ }
ClientReadHandler memoryClientReadHandler = new MemoryClientReadHandler(
request.getAppId(),
request.getShuffleId(),
request.getPartitionId(),
request.getReadBufferSize(),
shuffleServerClient,
- request.getExpectTaskIds(),
- request.isExpectedTaskIdsBitmapFilterEnable()
+ expectTaskIds
);
return memoryClientReadHandler;
}
diff --git
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/ComposedClientReadHandler.java
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/ComposedClientReadHandler.java
index ea3d0be3..f3baeebe 100644
---
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/ComposedClientReadHandler.java
+++
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/ComposedClientReadHandler.java
@@ -18,9 +18,9 @@
package org.apache.uniffle.storage.handler.impl;
import java.util.List;
+import java.util.concurrent.Callable;
import com.google.common.annotations.VisibleForTesting;
-import com.google.common.collect.Lists;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -41,6 +41,10 @@ public class ComposedClientReadHandler extends
AbstractClientReadHandler {
private static final Logger LOG =
LoggerFactory.getLogger(ComposedClientReadHandler.class);
private final ShuffleServerInfo serverInfo;
+ private Callable<ClientReadHandler> hotHandlerCreator;
+ private Callable<ClientReadHandler> warmHandlerCreator;
+ private Callable<ClientReadHandler> coldHandlerCreator;
+ private Callable<ClientReadHandler> frozenHandlerCreator;
private ClientReadHandler hotDataReadHandler;
private ClientReadHandler warmDataReadHandler;
private ClientReadHandler coldDataReadHandler;
@@ -58,23 +62,36 @@ public class ComposedClientReadHandler extends
AbstractClientReadHandler {
private ClientReadHandlerMetric frozenHandlerMetric = new
ClientReadHandlerMetric();
public ComposedClientReadHandler(ShuffleServerInfo serverInfo,
ClientReadHandler... handlers) {
- this(serverInfo, Lists.newArrayList(handlers));
+ this.serverInfo = serverInfo;
+ topLevelOfHandler = handlers.length;
+ if (topLevelOfHandler > 0) {
+ this.hotDataReadHandler = handlers[0];
+ }
+ if (topLevelOfHandler > 1) {
+ this.warmDataReadHandler = handlers[1];
+ }
+ if (topLevelOfHandler > 2) {
+ this.coldDataReadHandler = handlers[2];
+ }
+ if (topLevelOfHandler > 3) {
+ this.frozenDataReadHandler = handlers[3];
+ }
}
- public ComposedClientReadHandler(ShuffleServerInfo serverInfo,
List<ClientReadHandler> handlers) {
+ public ComposedClientReadHandler(ShuffleServerInfo serverInfo,
List<Callable<ClientReadHandler>> callables) {
this.serverInfo = serverInfo;
- topLevelOfHandler = handlers.size();
+ topLevelOfHandler = callables.size();
if (topLevelOfHandler > 0) {
- this.hotDataReadHandler = handlers.get(0);
+ this.hotHandlerCreator = callables.get(0);
}
if (topLevelOfHandler > 1) {
- this.warmDataReadHandler = handlers.get(1);
+ this.warmHandlerCreator = callables.get(1);
}
if (topLevelOfHandler > 2) {
- this.coldDataReadHandler = handlers.get(2);
+ this.coldHandlerCreator = callables.get(2);
}
if (topLevelOfHandler > 3) {
- this.frozenDataReadHandler = handlers.get(3);
+ this.frozenHandlerCreator = callables.get(3);
}
}
@@ -84,15 +101,27 @@ public class ComposedClientReadHandler extends
AbstractClientReadHandler {
try {
switch (currentHandler) {
case HOT:
+ if (hotDataReadHandler == null) {
+ hotDataReadHandler = hotHandlerCreator.call();
+ }
shuffleDataResult = hotDataReadHandler.readShuffleData();
break;
case WARM:
+ if (warmDataReadHandler == null) {
+ warmDataReadHandler = warmHandlerCreator.call();
+ }
shuffleDataResult = warmDataReadHandler.readShuffleData();
break;
case COLD:
+ if (coldDataReadHandler == null) {
+ coldDataReadHandler = coldHandlerCreator.call();
+ }
shuffleDataResult = coldDataReadHandler.readShuffleData();
break;
case FROZEN:
+ if (frozenDataReadHandler == null) {
+ frozenDataReadHandler = frozenHandlerCreator.call();
+ }
shuffleDataResult = frozenDataReadHandler.readShuffleData();
break;
default:
diff --git
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MemoryClientReadHandler.java
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MemoryClientReadHandler.java
index 3cc2e6ba..49eaa0b0 100644
---
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MemoryClientReadHandler.java
+++
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MemoryClientReadHandler.java
@@ -19,7 +19,6 @@ package org.apache.uniffle.storage.handler.impl;
import java.util.List;
-import com.google.common.annotations.VisibleForTesting;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -32,48 +31,27 @@ import org.apache.uniffle.common.ShuffleDataResult;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.Constants;
+
public class MemoryClientReadHandler extends AbstractClientReadHandler {
private static final Logger LOG =
LoggerFactory.getLogger(MemoryClientReadHandler.class);
private long lastBlockId = Constants.INVALID_BLOCK_ID;
private ShuffleServerClient shuffleServerClient;
private Roaring64NavigableMap expectTaskIds;
- private boolean expectedTaskIdsBitmapFilterEnable;
-
- // Only for tests
- @VisibleForTesting
- public MemoryClientReadHandler(
- String appId,
- int shuffleId,
- int partitionId,
- int readBufferSize,
- ShuffleServerClient shuffleServerClient) {
- this(
- appId,
- shuffleId,
- partitionId,
- readBufferSize,
- shuffleServerClient,
- null,
- false
- );
- }
-
+
public MemoryClientReadHandler(
String appId,
int shuffleId,
int partitionId,
int readBufferSize,
ShuffleServerClient shuffleServerClient,
- Roaring64NavigableMap expectTaskIds,
- boolean expectedTaskIdsBitmapFilterEnable) {
+ Roaring64NavigableMap expectTaskIds) {
this.appId = appId;
this.shuffleId = shuffleId;
this.partitionId = partitionId;
this.readBufferSize = readBufferSize;
this.shuffleServerClient = shuffleServerClient;
this.expectTaskIds = expectTaskIds;
- this.expectedTaskIdsBitmapFilterEnable = expectedTaskIdsBitmapFilterEnable;
}
@Override
@@ -86,7 +64,7 @@ public class MemoryClientReadHandler extends
AbstractClientReadHandler {
partitionId,
lastBlockId,
readBufferSize,
- expectedTaskIdsBitmapFilterEnable ? expectTaskIds : null
+ expectTaskIds
);
try {
diff --git
a/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
b/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
index 58729485..0801893a 100644
---
a/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
+++
b/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
@@ -25,6 +25,7 @@ import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssBaseConf;
+import org.apache.uniffle.common.util.IdHelper;
public class CreateShuffleReadHandlerRequest {
@@ -46,6 +47,8 @@ public class CreateShuffleReadHandlerRequest {
private Roaring64NavigableMap expectTaskIds;
private boolean expectedTaskIdsBitmapFilterEnable;
+ private IdHelper idHelper;
+
public CreateShuffleReadHandlerRequest() {
}
@@ -184,4 +187,12 @@ public class CreateShuffleReadHandlerRequest {
public void useExpectedTaskIdsBitmapFilter() {
this.expectedTaskIdsBitmapFilterEnable = true;
}
+
+ public IdHelper getIdHelper() {
+ return idHelper;
+ }
+
+ public void setIdHelper(IdHelper idHelper) {
+ this.idHelper = idHelper;
+ }
}