This is an automated email from the ASF dual-hosted git repository.
zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new a2c2d056c [#2697] refactor(spark): Involve related writer stats info
into ShuffleWriteTaskStats (#2698)
a2c2d056c is described below
commit a2c2d056ce8fd49af025eaa48785b94d21515557
Author: Junfan Zhang <[email protected]>
AuthorDate: Thu Dec 18 11:40:46 2025 +0800
[#2697] refactor(spark): Involve related writer stats info into
ShuffleWriteTaskStats (#2698)
### What changes were proposed in this pull request?
Involve related writer stats info into ShuffleWriteTaskStats for the
upcoming integrity-validation block-checksum implementation
### Why are the changes needed?
for #2697
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing unit tests
---------
Co-authored-by: Copilot <[email protected]>
---
.../org/apache/uniffle/shuffle/BlockStats.java | 57 +++++++++
.../uniffle/shuffle/ShuffleWriteTaskStats.java | 112 +++++++++++++++---
.../org/apache/uniffle/shuffle/BlockStatsTest.java | 83 +++++++++++++
.../spark/shuffle/writer/RssShuffleWriter.java | 131 +++++++++++----------
4 files changed, 300 insertions(+), 83 deletions(-)
diff --git
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/BlockStats.java
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/BlockStats.java
new file mode 100644
index 000000000..9dd69f746
--- /dev/null
+++
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/BlockStats.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.shuffle;
+
+import java.util.HashSet;
+import java.util.Set;
+
+/** Multi blocks stats */
+public class BlockStats {
+ private long recordNumber;
+ private Set<Long> blockIds;
+
+ public BlockStats() {
+ this.recordNumber = 0;
+ this.blockIds = new HashSet<>();
+ }
+
+ public BlockStats(long recordNumber, long blockId) {
+ this.recordNumber = recordNumber;
+ this.blockIds = new HashSet<>(1);
+ this.blockIds.add(blockId);
+ }
+
+ public long getRecordNumber() {
+ return recordNumber;
+ }
+
+ public Set<Long> getBlockIds() {
+ return blockIds;
+ }
+
+ public void merge(BlockStats other) {
+ recordNumber += other.recordNumber;
+ blockIds.addAll(other.blockIds);
+ }
+
+ public void remove(BlockStats other) {
+ if (blockIds.removeAll(other.blockIds)) {
+ recordNumber -= other.recordNumber;
+ }
+ }
+}
diff --git
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStats.java
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStats.java
index 4f226569d..e4e9785dd 100644
---
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStats.java
+++
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStats.java
@@ -19,11 +19,14 @@ package org.apache.uniffle.shuffle;
import java.nio.ByteBuffer;
import java.util.Arrays;
+import java.util.Map;
import java.util.Optional;
+import com.google.common.collect.Maps;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.compression.Codec;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
@@ -39,19 +42,28 @@ import static
org.apache.spark.shuffle.RssSparkConfig.RSS_DATA_INTEGRITY_VALIDAT
public class ShuffleWriteTaskStats {
private static final Logger LOGGER =
LoggerFactory.getLogger(ShuffleWriteTaskStats.class);
+ private RssConf rssConf;
+ private boolean blockNumberCheckEnabled;
+
// the unique task id across all stages
private long taskId;
// this is only unique for one stage and defined in uniffle side instead of
spark
private long taskAttemptId;
+ // partition number
private int partitions;
- private long[] partitionRecordsWritten;
- private long[] partitionBlocksWritten;
- private boolean blockNumberCheckEnabled;
- private RssConf rssConf;
+
+ private long[] partitionRecords;
+ private long[] partitionBlocks;
+
+ // server -> partitionId -> block-stats
+ private Map<ShuffleServerInfo, Map<Integer, BlockStats>>
serverToPartitionToBlockStats;
+
+ // partitions data length
+ private long[] partitionLengths;
public ShuffleWriteTaskStats(RssConf rssConf, int partitions, long
taskAttemptId, long taskId) {
- this.partitionRecordsWritten = new long[partitions];
- Arrays.fill(this.partitionRecordsWritten, 0L);
+ this.partitionRecords = new long[partitions];
+ Arrays.fill(this.partitionRecords, 0L);
this.partitions = partitions;
this.taskAttemptId = taskAttemptId;
@@ -61,9 +73,14 @@ public class ShuffleWriteTaskStats {
rssConf.get(RSS_DATA_INTEGRITY_VALIDATION_BLOCK_NUMBER_CHECK_ENABLED);
if (blockNumberCheckEnabled) {
- this.partitionBlocksWritten = new long[partitions];
- Arrays.fill(this.partitionBlocksWritten, 0L);
+ this.partitionBlocks = new long[partitions];
+ Arrays.fill(this.partitionBlocks, 0L);
}
+
+ this.serverToPartitionToBlockStats = Maps.newConcurrentMap();
+
+ this.partitionLengths = new long[partitions];
+ Arrays.fill(this.partitionLengths, 0L);
}
public ShuffleWriteTaskStats(int partitions, long taskAttemptId, long
taskId) {
@@ -71,22 +88,28 @@ public class ShuffleWriteTaskStats {
}
public long getRecordsWritten(int partitionId) {
- return partitionRecordsWritten[partitionId];
+ return partitionRecords[partitionId];
}
public void incPartitionRecord(int partitionId) {
- partitionRecordsWritten[partitionId]++;
+ partitionRecords[partitionId]++;
}
public void incPartitionBlock(int partitionId) {
if (blockNumberCheckEnabled) {
- partitionBlocksWritten[partitionId]++;
+ partitionBlocks[partitionId]++;
+ }
+ }
+
+ public void decPartitionBlock(int partitionId) {
+ if (blockNumberCheckEnabled) {
+ partitionBlocks[partitionId]--;
}
}
public long getBlocksWritten(int partitionId) {
if (blockNumberCheckEnabled) {
- return partitionBlocksWritten[partitionId];
+ return partitionBlocks[partitionId];
}
return -1L;
}
@@ -97,7 +120,7 @@ public class ShuffleWriteTaskStats {
public String encode() {
final long start = System.currentTimeMillis();
- int partitions = partitionRecordsWritten.length;
+ int partitions = partitionRecords.length;
int capacity = 2 * Long.BYTES + Integer.BYTES + partitions * Long.BYTES;
if (blockNumberCheckEnabled) {
capacity += partitions * Long.BYTES;
@@ -106,11 +129,11 @@ public class ShuffleWriteTaskStats {
buffer.putLong(taskId);
buffer.putLong(taskAttemptId);
buffer.putInt(partitions);
- for (long records : partitionRecordsWritten) {
+ for (long records : partitionRecords) {
buffer.putLong(records);
}
if (blockNumberCheckEnabled) {
- for (long blocks : partitionBlocksWritten) {
+ for (long blocks : partitionBlocks) {
buffer.putLong(blocks);
}
}
@@ -156,11 +179,11 @@ public class ShuffleWriteTaskStats {
ShuffleWriteTaskStats stats =
new ShuffleWriteTaskStats(rssConf, partitions, taskAttemptId, taskId);
for (int i = 0; i < partitions; i++) {
- stats.partitionRecordsWritten[i] = outBuffer.getLong();
+ stats.partitionRecords[i] = outBuffer.getLong();
}
if (rssConf.get(RSS_DATA_INTEGRITY_VALIDATION_BLOCK_NUMBER_CHECK_ENABLED))
{
for (int i = 0; i < partitions; i++) {
- stats.partitionBlocksWritten[i] = outBuffer.getLong();
+ stats.partitionBlocks[i] = outBuffer.getLong();
}
}
return stats;
@@ -172,7 +195,7 @@ public class ShuffleWriteTaskStats {
public void log() {
StringBuilder infoBuilder = new StringBuilder();
- int partitions = partitionRecordsWritten.length;
+ int partitions = partitionRecords.length;
for (int i = 0; i < partitions; i++) {
long records = getRecordsWritten(i);
long blocks = getBlocksWritten(i);
@@ -182,7 +205,10 @@ public class ShuffleWriteTaskStats {
"Partition records/blocks written for taskId[{}]: {}", taskId,
infoBuilder.toString());
}
- public void check(long[] partitionLens) {
+ /** Internal check */
+ public void check() {
+ // 1. partition length check
+ final long[] partitionLens = partitionLengths;
for (int idx = 0; idx < partitions; idx++) {
long records = getRecordsWritten(idx);
long blocks = getBlocksWritten(idx);
@@ -199,5 +225,53 @@ public class ShuffleWriteTaskStats {
+ length);
}
}
+
+ // 2. blockIds check
+ if (blockNumberCheckEnabled) {
+ long expected = 0L;
+ for (long partitionBlockNumber : partitionBlocks) {
+ expected += partitionBlockNumber;
+ }
+ long actual =
+ serverToPartitionToBlockStats.entrySet().stream()
+ .flatMap(x -> x.getValue().entrySet().stream())
+ .map(x -> x.getValue().getBlockIds().size())
+ .reduce(Integer::sum)
+ .orElse(0);
+ if (expected != actual) {
+ throw new RssException(
+ "Illegal block number. Expected: " + expected + ", actual: " +
actual);
+ }
+ }
+ }
+
+ public void incPartitionLength(int partitionId, long length) {
+ partitionLengths[partitionId] += length;
+ }
+
+ public long[] getPartitionLengths() {
+ return partitionLengths;
+ }
+
+ public void mergeBlockStats(
+ ShuffleServerInfo serverInfo, int partitionId, BlockStats blockStats) {
+ BlockStats existing =
+ serverToPartitionToBlockStats
+ .computeIfAbsent(serverInfo, x -> Maps.newConcurrentMap())
+ .computeIfAbsent(partitionId, x -> new BlockStats());
+ existing.merge(blockStats);
+ }
+
+ public void removeBlockStats(
+ ShuffleServerInfo serverInfo, int partitionId, BlockStats blockStats) {
+ BlockStats existing =
+ serverToPartitionToBlockStats
+ .computeIfAbsent(serverInfo, x -> Maps.newConcurrentMap())
+ .computeIfAbsent(partitionId, x -> new BlockStats());
+ existing.remove(blockStats);
+ }
+
+ public Map<ShuffleServerInfo, Map<Integer, BlockStats>> getAllBlockStats() {
+ return serverToPartitionToBlockStats;
}
}
diff --git
a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/BlockStatsTest.java
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/BlockStatsTest.java
new file mode 100644
index 000000000..bea1d69db
--- /dev/null
+++
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/BlockStatsTest.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.shuffle;
+
+import java.util.HashSet;
+import java.util.Set;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class BlockStatsTest {
+
+ private static Set<Long> setsOf(long... vars) {
+ Set<Long> set = new HashSet<>();
+ for (long var : vars) {
+ set.add(var);
+ }
+ return set;
+ }
+
+ @Test
+ public void testConstructor() {
+ BlockStats stats = new BlockStats();
+ assertEquals(0, stats.getRecordNumber());
+ assertTrue(stats.getBlockIds().isEmpty());
+
+ BlockStats stats2 = new BlockStats(10, 100L);
+ assertEquals(10, stats2.getRecordNumber());
+ assertEquals(setsOf(100L), stats2.getBlockIds());
+ }
+
+ @Test
+ public void testMerge() {
+ BlockStats s1 = new BlockStats(10, 1L);
+ BlockStats s2 = new BlockStats(5, 2L);
+
+ s1.merge(s2);
+
+ assertEquals(15, s1.getRecordNumber());
+ assertEquals(setsOf(1L, 2L), s1.getBlockIds());
+ }
+
+ @Test
+ public void testRemove() {
+ BlockStats s1 = new BlockStats(20, 1L);
+ BlockStats s2 = new BlockStats(5, 1L);
+
+ s1.remove(s2);
+
+ assertEquals(15, s1.getRecordNumber());
+ assertTrue(s1.getBlockIds().isEmpty());
+ }
+
+ @Test
+ public void testRemoveMultipleBlocks() {
+ BlockStats base = new BlockStats(30, 1L);
+ base.getBlockIds().add(2L);
+
+ BlockStats toRemove = new BlockStats(10, 1L);
+
+ base.remove(toRemove);
+
+ assertEquals(20, base.getRecordNumber());
+ assertEquals(setsOf(2L), base.getBlockIds());
+ }
+}
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index ae61cb1fe..fc9bfe53a 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -46,7 +46,6 @@ import scala.collection.Iterator;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
-import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.common.util.concurrent.Uninterruptibles;
import org.apache.commons.collections4.CollectionUtils;
@@ -89,6 +88,7 @@ import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.exception.RssSendFailedException;
import org.apache.uniffle.common.exception.RssWaitFailedException;
import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.shuffle.BlockStats;
import org.apache.uniffle.shuffle.ShuffleWriteTaskStats;
import org.apache.uniffle.storage.util.StorageType;
@@ -117,13 +117,8 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
private final long sendCheckTimeout;
private final long sendCheckInterval;
private final int bitmapSplitNum;
- // server -> partitionId -> blockIds
- private Map<ShuffleServerInfo, Map<Integer, Set<Long>>>
serverToPartitionToBlockIds;
- // server -> partitionId -> recordNumbers
- private Map<ShuffleServerInfo, Map<Integer, Long>>
serverToPartitionToRecordNumbers;
private final ShuffleWriteClient shuffleWriteClient;
private final Set<ShuffleServerInfo> shuffleServersForData;
- private final PartitionLengthStatistic partitionLengthStatistic;
// Gluten needs this variable
protected final boolean isMemoryShuffleEnabled;
private final Function<String, Boolean> taskFailureCallback;
@@ -158,7 +153,9 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
private Optional<String> shuffleWriteFailureReason = Optional.empty();
// Visible for the Gluten
- protected Optional<ShuffleWriteTaskStats> shuffleTaskStats =
Optional.empty();
+ protected ShuffleWriteTaskStats shuffleTaskStats;
+
+ private boolean isIntegrityValidationClientManagementEnabled = false;
// Only for tests
@VisibleForTesting
@@ -227,11 +224,8 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
this.sendCheckTimeout =
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS);
this.sendCheckInterval =
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS);
this.bitmapSplitNum =
sparkConf.get(RssSparkConfig.RSS_CLIENT_BITMAP_SPLIT_NUM);
- this.serverToPartitionToBlockIds = Maps.newHashMap();
- this.serverToPartitionToRecordNumbers = Maps.newHashMap();
this.shuffleWriteClient = shuffleWriteClient;
this.shuffleServersForData = shuffleHandleInfo.getServers();
- this.partitionLengthStatistic = new
PartitionLengthStatistic(partitioner.numPartitions());
this.isMemoryShuffleEnabled =
isMemoryShuffleEnabled(sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()));
this.taskFailureCallback = taskFailureCallback;
@@ -248,16 +242,11 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
this.blockFailSentRetryMaxTimes =
rssConf.get(RSS_PARTITION_REASSIGN_BLOCK_RETRY_MAX_TIMES);
this.enableWriteFailureRetry =
rssConf.get(RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED);
this.recordReportFailedShuffleservers = Sets.newConcurrentHashSet();
-
- if
(RssShuffleManager.isIntegrityValidationClientManagementEnabled(rssConf)) {
- this.shuffleTaskStats =
- Optional.of(
- new ShuffleWriteTaskStats(
- rssConf,
- partitioner.numPartitions(),
- taskAttemptId,
- taskContext.taskAttemptId()));
- }
+ this.isIntegrityValidationClientManagementEnabled =
+
RssShuffleManager.isIntegrityValidationClientManagementEnabled(rssConf);
+ this.shuffleTaskStats =
+ new ShuffleWriteTaskStats(
+ rssConf, partitioner.numPartitions(), taskAttemptId,
taskContext.taskAttemptId());
}
// Gluten needs this method
@@ -389,7 +378,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
if (shuffleBlockInfos != null && !shuffleBlockInfos.isEmpty()) {
processShuffleBlockInfos(shuffleBlockInfos);
}
- shuffleTaskStats.ifPresent(x -> x.incPartitionRecord(partition));
+ shuffleTaskStats.incPartitionRecord(partition);
}
final long start = System.currentTimeMillis();
shuffleBlockInfos = bufferManager.clear(1.0);
@@ -444,14 +433,11 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
long expected = blockIds.size();
long bufferManagerTracked = bufferManager.getBlockCount();
- if (serverToPartitionToBlockIds == null) {
- throw new RssException("serverToPartitionToBlockIds should not be null");
- }
-
// to filter the multiple replica's duplicate blockIds
Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf();
- for (Map<Integer, Set<Long>> partitionBlockIds :
serverToPartitionToBlockIds.values()) {
- partitionBlockIds.values().forEach(x ->
x.forEach(blockIdBitmap::addLong));
+ for (Map<Integer, BlockStats> partitionBlockStats :
+ shuffleTaskStats.getAllBlockStats().values()) {
+ partitionBlockStats.values().forEach(x ->
x.getBlockIds().forEach(blockIdBitmap::addLong));
}
long serverTracked = blockIdBitmap.getLongCardinality();
if (expected != serverTracked || expected != bufferManagerTracked) {
@@ -485,22 +471,13 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
blockIds.add(blockId);
int partitionId = sbi.getPartitionId();
// record blocks number for per-partition
- shuffleTaskStats.ifPresent(x -> x.incPartitionBlock(partitionId));
+ shuffleTaskStats.incPartitionBlock(partitionId);
// update [partition, blockIds], it will be sent to shuffle server
sbi.getShuffleServerInfos()
.forEach(
- shuffleServerInfo -> {
- Map<Integer, Set<Long>> pToBlockIds =
- serverToPartitionToBlockIds.computeIfAbsent(
- shuffleServerInfo, k -> Maps.newHashMap());
- pToBlockIds.computeIfAbsent(partitionId, v ->
Sets.newHashSet()).add(blockId);
-
- // update the [partition, recordNumber]
- serverToPartitionToRecordNumbers
- .computeIfAbsent(shuffleServerInfo, k ->
Maps.newHashMap())
- .compute(
- partitionId, (k, v) -> v == null ? recordNumber
: v + recordNumber);
- });
+ s ->
+ shuffleTaskStats.mergeBlockStats(
+ s, partitionId, new BlockStats(recordNumber,
blockId)));
});
return postBlockEvent(shuffleBlockInfoList);
}
@@ -519,7 +496,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
// Otherwise, the block is released immediately once completed.
if (!blockFailSentRetryEnabled || isSuccessful) {
bufferManager.releaseBlockResource(b);
- partitionLengthStatistic.inc(b);
+ shuffleTaskStats.incPartitionLength(b.getPartitionId(),
b.getLength());
}
});
}
@@ -886,19 +863,14 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
private void clearFailedBlockState(ShuffleBlockInfo block) {
shuffleManager.getBlockIdsFailedSendTracker(taskId).remove(block.getBlockId());
+ shuffleTaskStats.decPartitionBlock(block.getPartitionId());
block.getShuffleServerInfos().stream()
.forEach(
- s -> {
- serverToPartitionToBlockIds
- .get(s)
- .get(block.getPartitionId())
- .remove(block.getBlockId());
- serverToPartitionToRecordNumbers
- .get(s)
- .compute(
- block.getPartitionId(),
- (pid, recordNumber) -> recordNumber -
block.getRecordNumber());
- });
+ s ->
+ shuffleTaskStats.removeBlockStats(
+ s,
+ block.getPartitionId(),
+ new BlockStats(block.getRecordNumber(),
block.getBlockId())));
blockIds.remove(block.getBlockId());
}
@@ -949,14 +921,16 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
if (success) {
long start = System.currentTimeMillis();
shuffleWriteClient.reportShuffleResult(
- serverToPartitionToBlockIds,
+ getServerToPartitionToBlockIds(),
appId,
shuffleId,
taskAttemptId,
bitmapSplitNum,
recordReportFailedShuffleservers,
enableWriteFailureRetry,
- serverToPartitionToRecordNumbers);
+ isIntegrityValidationClientManagementEnabled
+ ? null
+ : getServerToPartitionToRecordNumbers());
long reportDuration = System.currentTimeMillis() - start;
LOG.info(
"Reported all shuffle result for shuffleId[{}] task[{}] with
bitmapNum[{}] cost {} ms",
@@ -967,14 +941,10 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
shuffleWriteMetrics.incWriteTime(TimeUnit.MILLISECONDS.toNanos(reportDuration));
if
(RssShuffleManager.isIntegrationValidationFailureAnalysisEnabled(rssConf)) {
- shuffleTaskStats.ifPresent(x -> x.log());
+ shuffleTaskStats.log();
}
- long[] partitionLens = partitionLengthStatistic.toArray();
-
- if (shuffleTaskStats.isPresent()) {
- shuffleTaskStats.get().check(partitionLens);
- }
+ shuffleTaskStats.check();
// todo: we can replace the dummy host and port with the real shuffle
server which we prefer
// to read
@@ -984,10 +954,11 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
DUMMY_HOST,
DUMMY_PORT,
Option.apply(
- shuffleTaskStats.isPresent()
- ? shuffleTaskStats.get().encode()
+ isIntegrityValidationClientManagementEnabled
+ ? shuffleTaskStats.encode()
: Long.toString(taskAttemptId)));
- MapStatus mapStatus = MapStatus.apply(blockManagerId, partitionLens,
taskAttemptId);
+ MapStatus mapStatus =
+ MapStatus.apply(blockManagerId,
shuffleTaskStats.getPartitionLengths(), taskAttemptId);
return Option.apply(mapStatus);
} else {
return Option.empty();
@@ -1062,7 +1033,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
@VisibleForTesting
Map<Integer, Set<Long>> getPartitionToBlockIds() {
- return serverToPartitionToBlockIds.values().stream()
+ return getServerToPartitionToBlockIds().values().stream()
.flatMap(s -> s.entrySet().stream())
.collect(
Collectors.toMap(
@@ -1127,9 +1098,41 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
@VisibleForTesting
protected Map<ShuffleServerInfo, Map<Integer, Set<Long>>>
getServerToPartitionToBlockIds() {
+ Map<ShuffleServerInfo, Map<Integer, Set<Long>>>
serverToPartitionToBlockIds = new HashMap<>();
+ Map<ShuffleServerInfo, Map<Integer, BlockStats>> allBlockStats =
+ shuffleTaskStats.getAllBlockStats();
+ for (Map.Entry<ShuffleServerInfo, Map<Integer, BlockStats>> entry :
allBlockStats.entrySet()) {
+ ShuffleServerInfo server = entry.getKey();
+ for (Map.Entry<Integer, BlockStats> childEntry :
entry.getValue().entrySet()) {
+ int partitionId = childEntry.getKey();
+ BlockStats stats = childEntry.getValue();
+ serverToPartitionToBlockIds
+ .computeIfAbsent(server, k -> new HashMap<>())
+ .computeIfAbsent(partitionId, z -> new HashSet<>())
+ .addAll(stats.getBlockIds());
+ }
+ }
return serverToPartitionToBlockIds;
}
+ @VisibleForTesting
+ protected Map<ShuffleServerInfo, Map<Integer, Long>>
getServerToPartitionToRecordNumbers() {
+ Map<ShuffleServerInfo, Map<Integer, Long>>
serverToPartitionToRecordNumbers = new HashMap<>();
+ Map<ShuffleServerInfo, Map<Integer, BlockStats>> allBlockStats =
+ shuffleTaskStats.getAllBlockStats();
+ for (Map.Entry<ShuffleServerInfo, Map<Integer, BlockStats>> entry :
allBlockStats.entrySet()) {
+ ShuffleServerInfo server = entry.getKey();
+ for (Map.Entry<Integer, BlockStats> childEntry :
entry.getValue().entrySet()) {
+ int partitionId = childEntry.getKey();
+ BlockStats stats = childEntry.getValue();
+ serverToPartitionToRecordNumbers
+ .computeIfAbsent(server, k -> new HashMap<>())
+ .merge(partitionId, stats.getRecordNumber(), Long::sum);
+ }
+ }
+ return serverToPartitionToRecordNumbers;
+ }
+
@VisibleForTesting
protected RssShuffleManager getShuffleManager() {
return shuffleManager;