This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new a2c2d056c [#2697] refactor(spark): Involve related writer stats info 
into ShuffleWriteTaskStats (#2698)
a2c2d056c is described below

commit a2c2d056ce8fd49af025eaa48785b94d21515557
Author: Junfan Zhang <[email protected]>
AuthorDate: Thu Dec 18 11:40:46 2025 +0800

    [#2697] refactor(spark): Involve related writer stats info into 
ShuffleWriteTaskStats (#2698)
    
    ### What changes were proposed in this pull request?
    
    Involve related writer stats info into ShuffleWriteTaskStats for the 
upcoming integrity-validation block-checksum implementation
    
    ### Why are the changes needed?
    
    for #2697
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing unit tests
    
    ---------
    
    Co-authored-by: Copilot <[email protected]>
---
 .../org/apache/uniffle/shuffle/BlockStats.java     |  57 +++++++++
 .../uniffle/shuffle/ShuffleWriteTaskStats.java     | 112 +++++++++++++++---
 .../org/apache/uniffle/shuffle/BlockStatsTest.java |  83 +++++++++++++
 .../spark/shuffle/writer/RssShuffleWriter.java     | 131 +++++++++++----------
 4 files changed, 300 insertions(+), 83 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/BlockStats.java 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/BlockStats.java
new file mode 100644
index 000000000..9dd69f746
--- /dev/null
+++ 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/BlockStats.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.shuffle;
+
+import java.util.HashSet;
+import java.util.Set;
+
+/** Multi blocks stats */
+public class BlockStats {
+  private long recordNumber;
+  private Set<Long> blockIds;
+
+  public BlockStats() {
+    this.recordNumber = 0;
+    this.blockIds = new HashSet<>();
+  }
+
+  public BlockStats(long recordNumber, long blockId) {
+    this.recordNumber = recordNumber;
+    this.blockIds = new HashSet<>(1);
+    this.blockIds.add(blockId);
+  }
+
+  public long getRecordNumber() {
+    return recordNumber;
+  }
+
+  public Set<Long> getBlockIds() {
+    return blockIds;
+  }
+
+  public void merge(BlockStats other) {
+    recordNumber += other.recordNumber;
+    blockIds.addAll(other.blockIds);
+  }
+
+  public void remove(BlockStats other) {
+    if (blockIds.removeAll(other.blockIds)) {
+      recordNumber -= other.recordNumber;
+    }
+  }
+}
diff --git 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStats.java
 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStats.java
index 4f226569d..e4e9785dd 100644
--- 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStats.java
+++ 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStats.java
@@ -19,11 +19,14 @@ package org.apache.uniffle.shuffle;
 
 import java.nio.ByteBuffer;
 import java.util.Arrays;
+import java.util.Map;
 import java.util.Optional;
 
+import com.google.common.collect.Maps;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.compression.Codec;
 import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
@@ -39,19 +42,28 @@ import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_DATA_INTEGRITY_VALIDAT
 public class ShuffleWriteTaskStats {
   private static final Logger LOGGER = 
LoggerFactory.getLogger(ShuffleWriteTaskStats.class);
 
+  private RssConf rssConf;
+  private boolean blockNumberCheckEnabled;
+
   // the unique task id across all stages
   private long taskId;
   // this is only unique for one stage and defined in uniffle side instead of 
spark
   private long taskAttemptId;
+  // partition number
   private int partitions;
-  private long[] partitionRecordsWritten;
-  private long[] partitionBlocksWritten;
-  private boolean blockNumberCheckEnabled;
-  private RssConf rssConf;
+
+  private long[] partitionRecords;
+  private long[] partitionBlocks;
+
+  // server -> partitionId -> block-stats
+  private Map<ShuffleServerInfo, Map<Integer, BlockStats>> 
serverToPartitionToBlockStats;
+
+  // partitions data length
+  private long[] partitionLengths;
 
   public ShuffleWriteTaskStats(RssConf rssConf, int partitions, long 
taskAttemptId, long taskId) {
-    this.partitionRecordsWritten = new long[partitions];
-    Arrays.fill(this.partitionRecordsWritten, 0L);
+    this.partitionRecords = new long[partitions];
+    Arrays.fill(this.partitionRecords, 0L);
 
     this.partitions = partitions;
     this.taskAttemptId = taskAttemptId;
@@ -61,9 +73,14 @@ public class ShuffleWriteTaskStats {
         rssConf.get(RSS_DATA_INTEGRITY_VALIDATION_BLOCK_NUMBER_CHECK_ENABLED);
 
     if (blockNumberCheckEnabled) {
-      this.partitionBlocksWritten = new long[partitions];
-      Arrays.fill(this.partitionBlocksWritten, 0L);
+      this.partitionBlocks = new long[partitions];
+      Arrays.fill(this.partitionBlocks, 0L);
     }
+
+    this.serverToPartitionToBlockStats = Maps.newConcurrentMap();
+
+    this.partitionLengths = new long[partitions];
+    Arrays.fill(this.partitionLengths, 0L);
   }
 
   public ShuffleWriteTaskStats(int partitions, long taskAttemptId, long 
taskId) {
@@ -71,22 +88,28 @@ public class ShuffleWriteTaskStats {
   }
 
   public long getRecordsWritten(int partitionId) {
-    return partitionRecordsWritten[partitionId];
+    return partitionRecords[partitionId];
   }
 
   public void incPartitionRecord(int partitionId) {
-    partitionRecordsWritten[partitionId]++;
+    partitionRecords[partitionId]++;
   }
 
   public void incPartitionBlock(int partitionId) {
     if (blockNumberCheckEnabled) {
-      partitionBlocksWritten[partitionId]++;
+      partitionBlocks[partitionId]++;
+    }
+  }
+
+  public void decPartitionBlock(int partitionId) {
+    if (blockNumberCheckEnabled) {
+      partitionBlocks[partitionId]--;
     }
   }
 
   public long getBlocksWritten(int partitionId) {
     if (blockNumberCheckEnabled) {
-      return partitionBlocksWritten[partitionId];
+      return partitionBlocks[partitionId];
     }
     return -1L;
   }
@@ -97,7 +120,7 @@ public class ShuffleWriteTaskStats {
 
   public String encode() {
     final long start = System.currentTimeMillis();
-    int partitions = partitionRecordsWritten.length;
+    int partitions = partitionRecords.length;
     int capacity = 2 * Long.BYTES + Integer.BYTES + partitions * Long.BYTES;
     if (blockNumberCheckEnabled) {
       capacity += partitions * Long.BYTES;
@@ -106,11 +129,11 @@ public class ShuffleWriteTaskStats {
     buffer.putLong(taskId);
     buffer.putLong(taskAttemptId);
     buffer.putInt(partitions);
-    for (long records : partitionRecordsWritten) {
+    for (long records : partitionRecords) {
       buffer.putLong(records);
     }
     if (blockNumberCheckEnabled) {
-      for (long blocks : partitionBlocksWritten) {
+      for (long blocks : partitionBlocks) {
         buffer.putLong(blocks);
       }
     }
@@ -156,11 +179,11 @@ public class ShuffleWriteTaskStats {
     ShuffleWriteTaskStats stats =
         new ShuffleWriteTaskStats(rssConf, partitions, taskAttemptId, taskId);
     for (int i = 0; i < partitions; i++) {
-      stats.partitionRecordsWritten[i] = outBuffer.getLong();
+      stats.partitionRecords[i] = outBuffer.getLong();
     }
     if (rssConf.get(RSS_DATA_INTEGRITY_VALIDATION_BLOCK_NUMBER_CHECK_ENABLED)) 
{
       for (int i = 0; i < partitions; i++) {
-        stats.partitionBlocksWritten[i] = outBuffer.getLong();
+        stats.partitionBlocks[i] = outBuffer.getLong();
       }
     }
     return stats;
@@ -172,7 +195,7 @@ public class ShuffleWriteTaskStats {
 
   public void log() {
     StringBuilder infoBuilder = new StringBuilder();
-    int partitions = partitionRecordsWritten.length;
+    int partitions = partitionRecords.length;
     for (int i = 0; i < partitions; i++) {
       long records = getRecordsWritten(i);
       long blocks = getBlocksWritten(i);
@@ -182,7 +205,10 @@ public class ShuffleWriteTaskStats {
         "Partition records/blocks written for taskId[{}]: {}", taskId, 
infoBuilder.toString());
   }
 
-  public void check(long[] partitionLens) {
+  /** Internal check */
+  public void check() {
+    // 1. partition length check
+    final long[] partitionLens = partitionLengths;
     for (int idx = 0; idx < partitions; idx++) {
       long records = getRecordsWritten(idx);
       long blocks = getBlocksWritten(idx);
@@ -199,5 +225,53 @@ public class ShuffleWriteTaskStats {
                 + length);
       }
     }
+
+    // 2. blockIds check
+    if (blockNumberCheckEnabled) {
+      long expected = 0L;
+      for (long partitionBlockNumber : partitionBlocks) {
+        expected += partitionBlockNumber;
+      }
+      long actual =
+          serverToPartitionToBlockStats.entrySet().stream()
+              .flatMap(x -> x.getValue().entrySet().stream())
+              .map(x -> x.getValue().getBlockIds().size())
+              .reduce(Integer::sum)
+              .orElse(0);
+      if (expected != actual) {
+        throw new RssException(
+            "Illegal block number. Expected: " + expected + ", actual: " + 
actual);
+      }
+    }
+  }
+
+  public void incPartitionLength(int partitionId, long length) {
+    partitionLengths[partitionId] += length;
+  }
+
+  public long[] getPartitionLengths() {
+    return partitionLengths;
+  }
+
+  public void mergeBlockStats(
+      ShuffleServerInfo serverInfo, int partitionId, BlockStats blockStats) {
+    BlockStats existing =
+        serverToPartitionToBlockStats
+            .computeIfAbsent(serverInfo, x -> Maps.newConcurrentMap())
+            .computeIfAbsent(partitionId, x -> new BlockStats());
+    existing.merge(blockStats);
+  }
+
+  public void removeBlockStats(
+      ShuffleServerInfo serverInfo, int partitionId, BlockStats blockStats) {
+    BlockStats existing =
+        serverToPartitionToBlockStats
+            .computeIfAbsent(serverInfo, x -> Maps.newConcurrentMap())
+            .computeIfAbsent(partitionId, x -> new BlockStats());
+    existing.remove(blockStats);
+  }
+
+  public Map<ShuffleServerInfo, Map<Integer, BlockStats>> getAllBlockStats() {
+    return serverToPartitionToBlockStats;
   }
 }
diff --git 
a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/BlockStatsTest.java
 
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/BlockStatsTest.java
new file mode 100644
index 000000000..bea1d69db
--- /dev/null
+++ 
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/BlockStatsTest.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.shuffle;
+
+import java.util.HashSet;
+import java.util.Set;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class BlockStatsTest {
+
+  private static Set<Long> setsOf(long... vars) {
+    Set<Long> set = new HashSet<>();
+    for (long var : vars) {
+      set.add(var);
+    }
+    return set;
+  }
+
+  @Test
+  public void testConstructor() {
+    BlockStats stats = new BlockStats();
+    assertEquals(0, stats.getRecordNumber());
+    assertTrue(stats.getBlockIds().isEmpty());
+
+    BlockStats stats2 = new BlockStats(10, 100L);
+    assertEquals(10, stats2.getRecordNumber());
+    assertEquals(setsOf(100L), stats2.getBlockIds());
+  }
+
+  @Test
+  public void testMerge() {
+    BlockStats s1 = new BlockStats(10, 1L);
+    BlockStats s2 = new BlockStats(5, 2L);
+
+    s1.merge(s2);
+
+    assertEquals(15, s1.getRecordNumber());
+    assertEquals(setsOf(1L, 2L), s1.getBlockIds());
+  }
+
+  @Test
+  public void testRemove() {
+    BlockStats s1 = new BlockStats(20, 1L);
+    BlockStats s2 = new BlockStats(5, 1L);
+
+    s1.remove(s2);
+
+    assertEquals(15, s1.getRecordNumber());
+    assertTrue(s1.getBlockIds().isEmpty());
+  }
+
+  @Test
+  public void testRemoveMultipleBlocks() {
+    BlockStats base = new BlockStats(30, 1L);
+    base.getBlockIds().add(2L);
+
+    BlockStats toRemove = new BlockStats(10, 1L);
+
+    base.remove(toRemove);
+
+    assertEquals(20, base.getRecordNumber());
+    assertEquals(setsOf(2L), base.getBlockIds());
+  }
+}
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index ae61cb1fe..fc9bfe53a 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -46,7 +46,6 @@ import scala.collection.Iterator;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Lists;
-import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
 import com.google.common.util.concurrent.Uninterruptibles;
 import org.apache.commons.collections4.CollectionUtils;
@@ -89,6 +88,7 @@ import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.exception.RssSendFailedException;
 import org.apache.uniffle.common.exception.RssWaitFailedException;
 import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.shuffle.BlockStats;
 import org.apache.uniffle.shuffle.ShuffleWriteTaskStats;
 import org.apache.uniffle.storage.util.StorageType;
 
@@ -117,13 +117,8 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private final long sendCheckTimeout;
   private final long sendCheckInterval;
   private final int bitmapSplitNum;
-  // server -> partitionId -> blockIds
-  private Map<ShuffleServerInfo, Map<Integer, Set<Long>>> 
serverToPartitionToBlockIds;
-  // server -> partitionId -> recordNumbers
-  private Map<ShuffleServerInfo, Map<Integer, Long>> 
serverToPartitionToRecordNumbers;
   private final ShuffleWriteClient shuffleWriteClient;
   private final Set<ShuffleServerInfo> shuffleServersForData;
-  private final PartitionLengthStatistic partitionLengthStatistic;
   // Gluten needs this variable
   protected final boolean isMemoryShuffleEnabled;
   private final Function<String, Boolean> taskFailureCallback;
@@ -158,7 +153,9 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private Optional<String> shuffleWriteFailureReason = Optional.empty();
 
   // Visible for the Gluten
-  protected Optional<ShuffleWriteTaskStats> shuffleTaskStats = 
Optional.empty();
+  protected ShuffleWriteTaskStats shuffleTaskStats;
+
+  private boolean isIntegrityValidationClientManagementEnabled = false;
 
   // Only for tests
   @VisibleForTesting
@@ -227,11 +224,8 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     this.sendCheckTimeout = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS);
     this.sendCheckInterval = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS);
     this.bitmapSplitNum = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_BITMAP_SPLIT_NUM);
-    this.serverToPartitionToBlockIds = Maps.newHashMap();
-    this.serverToPartitionToRecordNumbers = Maps.newHashMap();
     this.shuffleWriteClient = shuffleWriteClient;
     this.shuffleServersForData = shuffleHandleInfo.getServers();
-    this.partitionLengthStatistic = new 
PartitionLengthStatistic(partitioner.numPartitions());
     this.isMemoryShuffleEnabled =
         
isMemoryShuffleEnabled(sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()));
     this.taskFailureCallback = taskFailureCallback;
@@ -248,16 +242,11 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     this.blockFailSentRetryMaxTimes = 
rssConf.get(RSS_PARTITION_REASSIGN_BLOCK_RETRY_MAX_TIMES);
     this.enableWriteFailureRetry = 
rssConf.get(RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED);
     this.recordReportFailedShuffleservers = Sets.newConcurrentHashSet();
-
-    if 
(RssShuffleManager.isIntegrityValidationClientManagementEnabled(rssConf)) {
-      this.shuffleTaskStats =
-          Optional.of(
-              new ShuffleWriteTaskStats(
-                  rssConf,
-                  partitioner.numPartitions(),
-                  taskAttemptId,
-                  taskContext.taskAttemptId()));
-    }
+    this.isIntegrityValidationClientManagementEnabled =
+        
RssShuffleManager.isIntegrityValidationClientManagementEnabled(rssConf);
+    this.shuffleTaskStats =
+        new ShuffleWriteTaskStats(
+            rssConf, partitioner.numPartitions(), taskAttemptId, 
taskContext.taskAttemptId());
   }
 
   // Gluten needs this method
@@ -389,7 +378,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       if (shuffleBlockInfos != null && !shuffleBlockInfos.isEmpty()) {
         processShuffleBlockInfos(shuffleBlockInfos);
       }
-      shuffleTaskStats.ifPresent(x -> x.incPartitionRecord(partition));
+      shuffleTaskStats.incPartitionRecord(partition);
     }
     final long start = System.currentTimeMillis();
     shuffleBlockInfos = bufferManager.clear(1.0);
@@ -444,14 +433,11 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     long expected = blockIds.size();
     long bufferManagerTracked = bufferManager.getBlockCount();
 
-    if (serverToPartitionToBlockIds == null) {
-      throw new RssException("serverToPartitionToBlockIds should not be null");
-    }
-
     // to filter the multiple replica's duplicate blockIds
     Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf();
-    for (Map<Integer, Set<Long>> partitionBlockIds : 
serverToPartitionToBlockIds.values()) {
-      partitionBlockIds.values().forEach(x -> 
x.forEach(blockIdBitmap::addLong));
+    for (Map<Integer, BlockStats> partitionBlockStats :
+        shuffleTaskStats.getAllBlockStats().values()) {
+      partitionBlockStats.values().forEach(x -> 
x.getBlockIds().forEach(blockIdBitmap::addLong));
     }
     long serverTracked = blockIdBitmap.getLongCardinality();
     if (expected != serverTracked || expected != bufferManagerTracked) {
@@ -485,22 +471,13 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
             blockIds.add(blockId);
             int partitionId = sbi.getPartitionId();
             // record blocks number for per-partition
-            shuffleTaskStats.ifPresent(x -> x.incPartitionBlock(partitionId));
+            shuffleTaskStats.incPartitionBlock(partitionId);
             // update [partition, blockIds], it will be sent to shuffle server
             sbi.getShuffleServerInfos()
                 .forEach(
-                    shuffleServerInfo -> {
-                      Map<Integer, Set<Long>> pToBlockIds =
-                          serverToPartitionToBlockIds.computeIfAbsent(
-                              shuffleServerInfo, k -> Maps.newHashMap());
-                      pToBlockIds.computeIfAbsent(partitionId, v -> 
Sets.newHashSet()).add(blockId);
-
-                      // update the [partition, recordNumber]
-                      serverToPartitionToRecordNumbers
-                          .computeIfAbsent(shuffleServerInfo, k -> 
Maps.newHashMap())
-                          .compute(
-                              partitionId, (k, v) -> v == null ? recordNumber 
: v + recordNumber);
-                    });
+                    s ->
+                        shuffleTaskStats.mergeBlockStats(
+                            s, partitionId, new BlockStats(recordNumber, 
blockId)));
           });
       return postBlockEvent(shuffleBlockInfoList);
     }
@@ -519,7 +496,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
               // Otherwise, the block is released immediately once completed.
               if (!blockFailSentRetryEnabled || isSuccessful) {
                 bufferManager.releaseBlockResource(b);
-                partitionLengthStatistic.inc(b);
+                shuffleTaskStats.incPartitionLength(b.getPartitionId(), 
b.getLength());
               }
             });
       }
@@ -886,19 +863,14 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
 
   private void clearFailedBlockState(ShuffleBlockInfo block) {
     
shuffleManager.getBlockIdsFailedSendTracker(taskId).remove(block.getBlockId());
+    shuffleTaskStats.decPartitionBlock(block.getPartitionId());
     block.getShuffleServerInfos().stream()
         .forEach(
-            s -> {
-              serverToPartitionToBlockIds
-                  .get(s)
-                  .get(block.getPartitionId())
-                  .remove(block.getBlockId());
-              serverToPartitionToRecordNumbers
-                  .get(s)
-                  .compute(
-                      block.getPartitionId(),
-                      (pid, recordNumber) -> recordNumber - 
block.getRecordNumber());
-            });
+            s ->
+                shuffleTaskStats.removeBlockStats(
+                    s,
+                    block.getPartitionId(),
+                    new BlockStats(block.getRecordNumber(), 
block.getBlockId())));
     blockIds.remove(block.getBlockId());
   }
 
@@ -949,14 +921,16 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       if (success) {
         long start = System.currentTimeMillis();
         shuffleWriteClient.reportShuffleResult(
-            serverToPartitionToBlockIds,
+            getServerToPartitionToBlockIds(),
             appId,
             shuffleId,
             taskAttemptId,
             bitmapSplitNum,
             recordReportFailedShuffleservers,
             enableWriteFailureRetry,
-            serverToPartitionToRecordNumbers);
+            isIntegrityValidationClientManagementEnabled
+                ? null
+                : getServerToPartitionToRecordNumbers());
         long reportDuration = System.currentTimeMillis() - start;
         LOG.info(
             "Reported all shuffle result for shuffleId[{}] task[{}] with 
bitmapNum[{}] cost {} ms",
@@ -967,14 +941,10 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         
shuffleWriteMetrics.incWriteTime(TimeUnit.MILLISECONDS.toNanos(reportDuration));
 
         if 
(RssShuffleManager.isIntegrationValidationFailureAnalysisEnabled(rssConf)) {
-          shuffleTaskStats.ifPresent(x -> x.log());
+          shuffleTaskStats.log();
         }
 
-        long[] partitionLens = partitionLengthStatistic.toArray();
-
-        if (shuffleTaskStats.isPresent()) {
-          shuffleTaskStats.get().check(partitionLens);
-        }
+        shuffleTaskStats.check();
 
         // todo: we can replace the dummy host and port with the real shuffle 
server which we prefer
         // to read
@@ -984,10 +954,11 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
                 DUMMY_HOST,
                 DUMMY_PORT,
                 Option.apply(
-                    shuffleTaskStats.isPresent()
-                        ? shuffleTaskStats.get().encode()
+                    isIntegrityValidationClientManagementEnabled
+                        ? shuffleTaskStats.encode()
                         : Long.toString(taskAttemptId)));
-        MapStatus mapStatus = MapStatus.apply(blockManagerId, partitionLens, 
taskAttemptId);
+        MapStatus mapStatus =
+            MapStatus.apply(blockManagerId, 
shuffleTaskStats.getPartitionLengths(), taskAttemptId);
         return Option.apply(mapStatus);
       } else {
         return Option.empty();
@@ -1062,7 +1033,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
 
   @VisibleForTesting
   Map<Integer, Set<Long>> getPartitionToBlockIds() {
-    return serverToPartitionToBlockIds.values().stream()
+    return getServerToPartitionToBlockIds().values().stream()
         .flatMap(s -> s.entrySet().stream())
         .collect(
             Collectors.toMap(
@@ -1127,9 +1098,41 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
 
   @VisibleForTesting
   protected Map<ShuffleServerInfo, Map<Integer, Set<Long>>> 
getServerToPartitionToBlockIds() {
+    Map<ShuffleServerInfo, Map<Integer, Set<Long>>> 
serverToPartitionToBlockIds = new HashMap<>();
+    Map<ShuffleServerInfo, Map<Integer, BlockStats>> allBlockStats =
+        shuffleTaskStats.getAllBlockStats();
+    for (Map.Entry<ShuffleServerInfo, Map<Integer, BlockStats>> entry : 
allBlockStats.entrySet()) {
+      ShuffleServerInfo server = entry.getKey();
+      for (Map.Entry<Integer, BlockStats> childEntry : 
entry.getValue().entrySet()) {
+        int partitionId = childEntry.getKey();
+        BlockStats stats = childEntry.getValue();
+        serverToPartitionToBlockIds
+            .computeIfAbsent(server, k -> new HashMap<>())
+            .computeIfAbsent(partitionId, z -> new HashSet<>())
+            .addAll(stats.getBlockIds());
+      }
+    }
     return serverToPartitionToBlockIds;
   }
 
+  @VisibleForTesting
+  protected Map<ShuffleServerInfo, Map<Integer, Long>> 
getServerToPartitionToRecordNumbers() {
+    Map<ShuffleServerInfo, Map<Integer, Long>> 
serverToPartitionToRecordNumbers = new HashMap<>();
+    Map<ShuffleServerInfo, Map<Integer, BlockStats>> allBlockStats =
+        shuffleTaskStats.getAllBlockStats();
+    for (Map.Entry<ShuffleServerInfo, Map<Integer, BlockStats>> entry : 
allBlockStats.entrySet()) {
+      ShuffleServerInfo server = entry.getKey();
+      for (Map.Entry<Integer, BlockStats> childEntry : 
entry.getValue().entrySet()) {
+        int partitionId = childEntry.getKey();
+        BlockStats stats = childEntry.getValue();
+        serverToPartitionToRecordNumbers
+            .computeIfAbsent(server, k -> new HashMap<>())
+            .merge(partitionId, stats.getRecordNumber(), Long::sum);
+      }
+    }
+    return serverToPartitionToRecordNumbers;
+  }
+
   @VisibleForTesting
   protected RssShuffleManager getShuffleManager() {
     return shuffleManager;

Reply via email to