This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new de55bd90e [#2673] feat(spark)(part-1): Add client-side support for 
storing partition stats on shuffle servers (#2669)
de55bd90e is described below

commit de55bd90eecab19301c688a105f85a9b7873774a
Author: Junfan Zhang <[email protected]>
AuthorDate: Fri Nov 14 14:06:33 2025 +0800

    [#2673] feat(spark)(part-1): Add client-side support for storing partition 
stats on shuffle servers (#2669)
    
    ### What changes were proposed in this pull request?
    
    This is the part-1 PR containing only the Uniffle client changes that store 
the partition stats on the shuffle-server side, to make the integrity 
validation mechanism more stable. The shuffle-server side changes will be 
implemented in follow-up PRs, and this PR remains compatible with the legacy 
shuffle-server protocol.
    
    ### Why are the changes needed?
    
    This is a subtask of issue #2673.
    
    By leveraging PR #2653, we can ensure end-to-end data consistency. However, 
the partition stats are currently stored on the Spark driver side. For normal 
Spark stages this design works well, but with 100000 tasks and 10000 
partitions it overloads the Spark driver. For large cluster Spark jobs, some 
huge jobs hang when getting the blockManagerIds, which can cost almost 20 
minutes for a single reader task — that is unacceptable.
    
    Therefore, this PR makes the shuffle server store the partition stats, in 
the same way the blockId store does.
    
    ### Does this PR introduce _any_ user-facing change?
    
    `spark.rss.client.integrityValidation.serverManagementEnabled=false`
    
    ### How was this patch tested?
    
    Internal job tests.
---
 .../org/apache/spark/shuffle/RssSparkConfig.java   |  8 +++
 .../spark/shuffle/writer/WriteBufferManager.java   |  7 ++-
 .../apache/spark/shuffle/RssShuffleManager.java    | 52 +++++++++++++----
 .../spark/shuffle/writer/RssShuffleWriter.java     | 15 ++++-
 .../apache/uniffle/client/api/ShuffleResult.java   | 43 ++++++++++++++
 .../uniffle/client/api/ShuffleWriteClient.java     | 24 ++++++++
 .../client/impl/ShuffleWriteClientImpl.java        | 66 ++++++++++++++++++++--
 .../uniffle/common/DeferredCompressedBlock.java    |  6 +-
 .../apache/uniffle/common/ShuffleBlockInfo.java    | 38 ++++++++++++-
 .../common/DeferredCompressedBlockTest.java        |  3 +-
 .../client/impl/grpc/ShuffleServerGrpcClient.java  | 26 ++-------
 .../request/RssReportShuffleResultRequest.java     | 34 ++++++++++-
 .../response/RssGetShuffleResultResponse.java      | 30 +++++++++-
 proto/src/main/proto/Rss.proto                     | 12 ++++
 14 files changed, 316 insertions(+), 48 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index e83bc437f..be336ef33 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -67,6 +67,14 @@ public class RssSparkConfig {
               .defaultValue(Codec.Type.ZSTD)
               .withDescription("stats compression type");
 
+  public static final ConfigOption<Boolean>
+      RSS_DATA_INTEGRITY_VALIDATION_SERVER_MANAGEMENT_ENABLED =
+          
ConfigOptions.key("rss.client.integrityValidation.serverManagementEnabled")
+              .booleanType()
+              .defaultValue(false)
+              .withDescription(
+                  "Whether or not to enable validation management by 
shuffle-server rather than client side");
+
   public static final ConfigOption<Boolean> 
RSS_READ_SHUFFLE_HANDLE_CACHE_ENABLED =
       ConfigOptions.key("rss.client.read.shuffleHandleCacheEnabled")
           .booleanType()
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index e76201d04..dbb5e6943 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -428,6 +428,7 @@ public class WriteBufferManager extends MemoryConsumer {
     byte[] data = writerBuffer.getData();
     final int uncompressLength = data.length;
     final int memoryUsed = writerBuffer.getMemoryUsed();
+    final long records = writerBuffer.getRecordCount();
 
     this.blockCounter.incrementAndGet();
     this.uncompressedDataLen += uncompressLength;
@@ -467,7 +468,8 @@ public class WriteBufferManager extends MemoryConsumer {
         taskAttemptId,
         partitionAssignmentRetrieveFunc,
         rebuildFunction,
-        estimatedCompressedSize);
+        estimatedCompressedSize,
+        records);
   }
 
   // transform records to shuffleBlock
@@ -504,7 +506,8 @@ public class WriteBufferManager extends MemoryConsumer {
         uncompressLength,
         wb.getMemoryUsed(),
         taskAttemptId,
-        partitionAssignmentRetrieveFunc);
+        partitionAssignmentRetrieveFunc,
+        wb.getRecordCount());
   }
 
   // it's run in single thread, and is not thread safe
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 650ece88e..85786881a 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -62,6 +62,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
+import org.apache.uniffle.client.api.ShuffleResult;
 import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.client.impl.FailedBlockSendTracker;
 import org.apache.uniffle.client.util.ClientUtils;
@@ -77,9 +78,6 @@ import org.apache.uniffle.shuffle.RssShuffleClientFactory;
 import org.apache.uniffle.shuffle.ShuffleWriteTaskStats;
 import org.apache.uniffle.shuffle.manager.RssShuffleManagerBase;
 
-import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_CLIENT_INTEGRITY_VALIDATION_ENABLED;
-import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_DATA_INTEGRATION_VALIDATION_ANALYSIS_ENABLED;
-
 public class RssShuffleManager extends RssShuffleManagerBase {
   private static final Logger LOG = 
LoggerFactory.getLogger(RssShuffleManager.class);
 
@@ -385,7 +383,7 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     Map<ShuffleServerInfo, Set<Integer>> serverToPartitions =
         getPartitionDataServers(shuffleHandleInfo, startPartition, 
endPartition);
     long start = System.currentTimeMillis();
-    Roaring64NavigableMap blockIdBitmap =
+    ShuffleResult shuffleResult =
         getShuffleResultForMultiPart(
             clientType,
             serverToPartitions,
@@ -393,7 +391,22 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
             shuffleId,
             context.stageAttemptNumber(),
             shuffleHandleInfo.createPartitionReplicaTracking());
-
+    Roaring64NavigableMap blockIdBitmap = shuffleResult.getBlockIds();
+    if (isIntegrityValidationServerManagementEnabled(rssConf)) {
+      Map<Integer, Map<Long, Long>> partitionToTaskAttemptIdToRecordNumbers =
+          shuffleResult.getPartitionToTaskAttemptIdToRecordNumbers();
+      if (partitionToTaskAttemptIdToRecordNumbers != null) {
+        long total =
+            partitionToTaskAttemptIdToRecordNumbers.values().stream()
+                .flatMap(x -> x.entrySet().stream())
+                .filter(x -> taskIdBitmap.contains(x.getKey()))
+                .mapToLong(Map.Entry::getValue)
+                .sum();
+        if (total > 0) {
+          expectedRecordsRead = total;
+        }
+      }
+    }
     LOG.info(
         "Retrieved {} upstream task ids in {} ms and {} block IDs from {} 
shuffle-servers in {} ms for shuffleId[{}], partitionId[{},{}]",
         taskIdBitmap.getLongCardinality(),
@@ -457,14 +470,29 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     if (!Spark3VersionUtils.isSparkVersionAtLeast("3.5.0")) {
       return false;
     }
-    return rssConf.get(RSS_CLIENT_INTEGRITY_VALIDATION_ENABLED);
+    return rssConf.get(RssSparkConfig.RSS_CLIENT_INTEGRITY_VALIDATION_ENABLED);
   }
 
-  public static boolean isIntegrationValidationFailureAnalysisEnabled(RssConf 
rssConf) {
+  public static boolean isIntegrityValidationServerManagementEnabled(RssConf 
rssConf) {
     if (!isIntegrityValidationEnabled(rssConf)) {
       return false;
     }
-    return rssConf.get(RSS_DATA_INTEGRATION_VALIDATION_ANALYSIS_ENABLED);
+    return 
rssConf.get(RssSparkConfig.RSS_DATA_INTEGRITY_VALIDATION_SERVER_MANAGEMENT_ENABLED);
+  }
+
+  public static boolean isIntegrityValidationClientManagementEnabled(RssConf 
rssConf) {
+    if (!isIntegrityValidationEnabled(rssConf)) {
+      return false;
+    }
+    return !isIntegrityValidationServerManagementEnabled(rssConf);
+  }
+
+  public static boolean isIntegrationValidationFailureAnalysisEnabled(RssConf 
rssConf) {
+    // todo: enable the validation failure analysis when the server management 
is enabled
+    if (isIntegrityValidationServerManagementEnabled(rssConf)) {
+      return false;
+    }
+    return 
rssConf.get(RssSparkConfig.RSS_DATA_INTEGRATION_VALIDATION_ANALYSIS_ENABLED);
   }
 
   @SuppressFBWarnings("REC_CATCH_EXCEPTION")
@@ -560,7 +588,7 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
       int endPartition,
       int startMapIndex,
       int endMapIndex) {
-    if (!isIntegrityValidationEnabled(rssConf)) {
+    if (isIntegrityValidationServerManagementEnabled(rssConf)) {
       return Collections.emptyMap();
     }
     Iterator<BlockManagerId> iter =
@@ -590,7 +618,7 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
       }
 
       String raw = blockManagerId.topologyInfo().get();
-      if (isIntegrityValidationEnabled(rssConf)) {
+      if (isIntegrityValidationClientManagementEnabled(rssConf)) {
         ShuffleWriteTaskStats shuffleWriteTaskStats = 
ShuffleWriteTaskStats.decode(rssConf, raw);
         taskIdBitmap.add(shuffleWriteTaskStats.getTaskAttemptId());
         for (int i = startPartition; i < endPartition; i++) {
@@ -736,7 +764,7 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     this.id = new AtomicReference<>(appId);
   }
 
-  private Roaring64NavigableMap getShuffleResultForMultiPart(
+  private ShuffleResult getShuffleResultForMultiPart(
       String clientType,
       Map<ShuffleServerInfo, Set<Integer>> serverToPartitions,
       String appId,
@@ -745,7 +773,7 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
       PartitionDataReplicaRequirementTracking replicaRequirementTracking) {
     Set<Integer> failedPartitions = Sets.newHashSet();
     try {
-      return shuffleWriteClient.getShuffleResultForMultiPart(
+      return shuffleWriteClient.getShuffleResultForMultiPartV2(
           clientType,
           serverToPartitions,
           appId,
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 066f9368e..48e2f2ab3 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -118,6 +118,8 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private final int bitmapSplitNum;
   // server -> partitionId -> blockIds
   private Map<ShuffleServerInfo, Map<Integer, Set<Long>>> 
serverToPartitionToBlockIds;
+  // server -> partitionId -> recordNumbers
+  private Map<ShuffleServerInfo, Map<Integer, Long>> 
serverToPartitionToRecordNumbers;
   private final ShuffleWriteClient shuffleWriteClient;
   private final Set<ShuffleServerInfo> shuffleServersForData;
   private final PartitionLengthStatistic partitionLengthStatistic;
@@ -224,6 +226,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     this.sendCheckInterval = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS);
     this.bitmapSplitNum = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_BITMAP_SPLIT_NUM);
     this.serverToPartitionToBlockIds = Maps.newHashMap();
+    this.serverToPartitionToRecordNumbers = Maps.newHashMap();
     this.shuffleWriteClient = shuffleWriteClient;
     this.shuffleServersForData = shuffleHandleInfo.getServers();
     this.partitionLengthStatistic = new 
PartitionLengthStatistic(partitioner.numPartitions());
@@ -244,7 +247,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     this.enableWriteFailureRetry = 
rssConf.get(RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED);
     this.recordReportFailedShuffleservers = Sets.newConcurrentHashSet();
 
-    if (RssShuffleManager.isIntegrityValidationEnabled(rssConf)) {
+    if 
(RssShuffleManager.isIntegrityValidationClientManagementEnabled(rssConf)) {
       this.shuffleTaskStats =
           Optional.of(
               new ShuffleWriteTaskStats(
@@ -475,6 +478,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       shuffleBlockInfoList.forEach(
           sbi -> {
             long blockId = sbi.getBlockId();
+            long recordNumber = sbi.getRecordNumber();
             // add blockId to set, check if it is sent later
             blockIds.add(blockId);
             int partitionId = sbi.getPartitionId();
@@ -488,6 +492,12 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
                           serverToPartitionToBlockIds.computeIfAbsent(
                               shuffleServerInfo, k -> Maps.newHashMap());
                       pToBlockIds.computeIfAbsent(partitionId, v -> 
Sets.newHashSet()).add(blockId);
+
+                      // update the [partition, recordNumber]
+                      serverToPartitionToRecordNumbers
+                          .computeIfAbsent(shuffleServerInfo, k -> 
Maps.newHashMap())
+                          .compute(
+                              partitionId, (k, v) -> v == null ? recordNumber 
: v + recordNumber);
                     });
           });
       return postBlockEvent(shuffleBlockInfoList);
@@ -937,7 +947,8 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
             taskAttemptId,
             bitmapSplitNum,
             recordReportFailedShuffleservers,
-            enableWriteFailureRetry);
+            enableWriteFailureRetry,
+            serverToPartitionToRecordNumbers);
         long reportDuration = System.currentTimeMillis() - start;
         LOG.info(
             "Reported all shuffle result for shuffleId[{}] task[{}] with 
bitmapNum[{}] cost {} ms",
diff --git 
a/client/src/main/java/org/apache/uniffle/client/api/ShuffleResult.java 
b/client/src/main/java/org/apache/uniffle/client/api/ShuffleResult.java
new file mode 100644
index 000000000..5cae357d9
--- /dev/null
+++ b/client/src/main/java/org/apache/uniffle/client/api/ShuffleResult.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.client.api;
+
+import java.util.Map;
+
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+
+public class ShuffleResult {
+  private Roaring64NavigableMap blockIds;
+  // partitionId -> taskAttemptId -> recordNumber
+  private Map<Integer, Map<Long, Long>> 
partitionToTaskAttemptIdToRecordNumbers;
+
+  public ShuffleResult(
+      Roaring64NavigableMap blockIds,
+      Map<Integer, Map<Long, Long>> partitionToTaskAttemptIdToRecordNumbers) {
+    this.blockIds = blockIds;
+    this.partitionToTaskAttemptIdToRecordNumbers = 
partitionToTaskAttemptIdToRecordNumbers;
+  }
+
+  public Roaring64NavigableMap getBlockIds() {
+    return blockIds;
+  }
+
+  public Map<Integer, Map<Long, Long>> 
getPartitionToTaskAttemptIdToRecordNumbers() {
+    return partitionToTaskAttemptIdToRecordNumbers;
+  }
+}
diff --git 
a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java 
b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
index a0df5a5e8..63d1bbcc8 100644
--- a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
+++ b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
@@ -150,6 +150,20 @@ public interface ShuffleWriteClient {
       long taskAttemptId,
       int bitmapNum);
 
+  default void reportShuffleResult(
+      Map<ShuffleServerInfo, Map<Integer, Set<Long>>> 
serverToPartitionToBlockIds,
+      String appId,
+      int shuffleId,
+      long taskAttemptId,
+      int bitmapNum,
+      Set<ShuffleServerInfo> reportFailureServers,
+      boolean enableWriteFailureRetry,
+      Map<ShuffleServerInfo, Map<Integer, Long>> 
serverToPartitionToRecordNumbers) {
+    throw new UnsupportedOperationException(
+        this.getClass().getName()
+            + " doesn't implement reportShuffleResult with integrity 
validation mechanism");
+  }
+
   default void reportShuffleResult(
       Map<ShuffleServerInfo, Map<Integer, Set<Long>>> 
serverToPartitionToBlockIds,
       String appId,
@@ -227,6 +241,16 @@ public interface ShuffleWriteClient {
       Set<Integer> failedPartitions,
       PartitionDataReplicaRequirementTracking replicaRequirementTracking);
 
+  default ShuffleResult getShuffleResultForMultiPartV2(
+      String clientType,
+      Map<ShuffleServerInfo, Set<Integer>> serverToPartitions,
+      String appId,
+      int shuffleId,
+      Set<Integer> failedPartitions,
+      PartitionDataReplicaRequirementTracking replicaRequirementTracking) {
+    throw new UnsupportedOperationException();
+  }
+
   void close();
 
   void unregisterShuffle(String appId, int shuffleId);
diff --git 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
index 25d17205a..36d49a125 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
@@ -24,6 +24,7 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutorService;
@@ -46,6 +47,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
+import org.apache.uniffle.client.api.ShuffleResult;
 import org.apache.uniffle.client.api.ShuffleServerClient;
 import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.client.common.ShuffleServerPushCostTracker;
@@ -755,15 +757,25 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
       long taskAttemptId,
       int bitmapNum,
       Set<ShuffleServerInfo> reportFailureServers,
-      boolean enableWriteFailureRetry) {
+      boolean enableWriteFailureRetry,
+      Map<ShuffleServerInfo, Map<Integer, Long>> 
serverToPartitionToRecordNumbers) {
     // record blockId count for quora check,but this is not a good realization.
     Map<Long, Integer> blockReportTracker = 
createBlockReportTracker(serverToPartitionToBlockIds);
     for (Map.Entry<ShuffleServerInfo, Map<Integer, Set<Long>>> entry :
         serverToPartitionToBlockIds.entrySet()) {
+      ShuffleServerInfo serverInfo = entry.getKey();
       Map<Integer, Set<Long>> requestBlockIds = entry.getValue();
       if (requestBlockIds.isEmpty()) {
         continue;
       }
+      Map<Integer, Long> partitionToRecordNumbers = null;
+      if (serverToPartitionToRecordNumbers != null) {
+        partitionToRecordNumbers = 
serverToPartitionToRecordNumbers.get(serverInfo);
+        if (partitionToRecordNumbers == null) {
+          throw new RssException(
+              "Should not happen that partitionToRecordNumbers is null but 
blockIds is not empty!");
+        }
+      }
       RssReportShuffleResultRequest request =
           new RssReportShuffleResultRequest(
               appId,
@@ -771,7 +783,8 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
               taskAttemptId,
               requestBlockIds.entrySet().stream()
                   .collect(Collectors.toMap(Map.Entry::getKey, e -> new 
ArrayList<>(e.getValue()))),
-              bitmapNum);
+              bitmapNum,
+              partitionToRecordNumbers);
       ShuffleServerInfo ssi = entry.getKey();
       try {
         long start = System.currentTimeMillis();
@@ -830,6 +843,26 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
     }
   }
 
+  @Override
+  public void reportShuffleResult(
+      Map<ShuffleServerInfo, Map<Integer, Set<Long>>> 
serverToPartitionToBlockIds,
+      String appId,
+      int shuffleId,
+      long taskAttemptId,
+      int bitmapNum,
+      Set<ShuffleServerInfo> reportFailureServers,
+      boolean enableWriteFailureRetry) {
+    reportShuffleResult(
+        serverToPartitionToBlockIds,
+        appId,
+        shuffleId,
+        taskAttemptId,
+        bitmapNum,
+        reportFailureServers,
+        enableWriteFailureRetry,
+        null);
+  }
+
   private void recordFailedBlockIds(
       Map<Long, Integer> blockReportTracker, Map<Integer, Set<Long>> 
requestBlockIds) {
     requestBlockIds.values().stream()
@@ -895,7 +928,7 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
   }
 
   @Override
-  public Roaring64NavigableMap getShuffleResultForMultiPart(
+  public ShuffleResult getShuffleResultForMultiPartV2(
       String clientType,
       Map<ShuffleServerInfo, Set<Integer>> serverToPartitions,
       String appId,
@@ -903,6 +936,7 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
       Set<Integer> failedPartitions,
       PartitionDataReplicaRequirementTracking replicaRequirementTracking) {
     Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf();
+    Map<Integer, Map<Long, Long>> partitionToTaskAttemptIdToRecordNumbers = 
new HashMap<>();
     Set<Integer> allRequestedPartitionIds = new HashSet<>();
     for (Map.Entry<ShuffleServerInfo, Set<Integer>> entry : 
serverToPartitions.entrySet()) {
       ShuffleServerInfo shuffleServerInfo = entry.getKey();
@@ -923,6 +957,12 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
           // merge into blockIds from multiple servers.
           Roaring64NavigableMap blockIdBitmapOfServer = 
response.getBlockIdBitmap();
           blockIdBitmap.or(blockIdBitmapOfServer);
+
+          // todo: should be more careful to handle this under the multi 
replicas.
+          //  Now, this integrity validation is not supported for multi 
replicas
+          
Optional.ofNullable(response.getPartitionToTaskAttemptIdToRecordNumbers())
+              .ifPresent(x -> 
partitionToTaskAttemptIdToRecordNumbers.putAll(x));
+
           for (Integer partitionId : requestPartitions) {
             replicaRequirementTracking.markPartitionOfServerSuccessful(
                 partitionId, shuffleServerInfo);
@@ -950,7 +990,25 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
       throw new RssFetchFailedException(
           "Get shuffle result is failed for appId[" + appId + "], shuffleId[" 
+ shuffleId + "]");
     }
-    return blockIdBitmap;
+    return new ShuffleResult(blockIdBitmap, 
partitionToTaskAttemptIdToRecordNumbers);
+  }
+
+  @Override
+  public Roaring64NavigableMap getShuffleResultForMultiPart(
+      String clientType,
+      Map<ShuffleServerInfo, Set<Integer>> serverToPartitions,
+      String appId,
+      int shuffleId,
+      Set<Integer> failedPartitions,
+      PartitionDataReplicaRequirementTracking replicaRequirementTracking) {
+    return getShuffleResultForMultiPartV2(
+            clientType,
+            serverToPartitions,
+            appId,
+            shuffleId,
+            failedPartitions,
+            replicaRequirementTracking)
+        .getBlockIds();
   }
 
   @Override
diff --git 
a/common/src/main/java/org/apache/uniffle/common/DeferredCompressedBlock.java 
b/common/src/main/java/org/apache/uniffle/common/DeferredCompressedBlock.java
index 4fcbd11f9..d2cb4076e 100644
--- 
a/common/src/main/java/org/apache/uniffle/common/DeferredCompressedBlock.java
+++ 
b/common/src/main/java/org/apache/uniffle/common/DeferredCompressedBlock.java
@@ -43,7 +43,8 @@ public class DeferredCompressedBlock extends ShuffleBlockInfo 
{
       long taskAttemptId,
       Function<Integer, List<ShuffleServerInfo>> 
partitionAssignmentRetrieveFunc,
       Function<DeferredCompressedBlock, DeferredCompressedBlock> 
rebuildFunction,
-      int estimatedCompressedSize) {
+      int estimatedCompressedSize,
+      long records) {
     super(
         shuffleId,
         partitionId,
@@ -52,7 +53,8 @@ public class DeferredCompressedBlock extends ShuffleBlockInfo 
{
         uncompressLength,
         freeMemory,
         taskAttemptId,
-        partitionAssignmentRetrieveFunc);
+        partitionAssignmentRetrieveFunc,
+        records);
     this.rebuildFunction = rebuildFunction;
     this.estimatedCompressedSize = estimatedCompressedSize;
   }
diff --git 
a/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java 
b/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
index c429ea7a6..1169b2931 100644
--- a/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
+++ b/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
@@ -41,6 +41,36 @@ public class ShuffleBlockInfo {
   protected long crc;
   protected ByteBuf data;
 
+  protected long recordNumber;
+
+  public ShuffleBlockInfo(
+      int shuffleId,
+      int partitionId,
+      long blockId,
+      int length,
+      long crc,
+      byte[] data,
+      List<ShuffleServerInfo> shuffleServerInfos,
+      int uncompressLength,
+      long freeMemory,
+      long taskAttemptId,
+      Function<Integer, List<ShuffleServerInfo>> 
partitionAssignmentRetrieveFunc,
+      long records) {
+    this(
+        shuffleId,
+        partitionId,
+        blockId,
+        length,
+        crc,
+        data,
+        shuffleServerInfos,
+        uncompressLength,
+        freeMemory,
+        taskAttemptId);
+    this.partitionAssignmentRetrieveFunc = partitionAssignmentRetrieveFunc;
+    this.recordNumber = records;
+  }
+
   public ShuffleBlockInfo(
       int shuffleId,
       int partitionId,
@@ -75,7 +105,8 @@ public class ShuffleBlockInfo {
       int uncompressLength,
       long freeMemory,
       long taskAttemptId,
-      Function<Integer, List<ShuffleServerInfo>> 
partitionAssignmentRetrieveFunc) {
+      Function<Integer, List<ShuffleServerInfo>> 
partitionAssignmentRetrieveFunc,
+      long recordNumber) {
     this.shuffleId = shuffleId;
     this.partitionId = partitionId;
     this.blockId = blockId;
@@ -84,6 +115,7 @@ public class ShuffleBlockInfo {
     this.freeMemory = freeMemory;
     this.taskAttemptId = taskAttemptId;
     this.partitionAssignmentRetrieveFunc = partitionAssignmentRetrieveFunc;
+    this.recordNumber = recordNumber;
   }
 
   public ShuffleBlockInfo(
@@ -247,4 +279,8 @@ public class ShuffleBlockInfo {
     }
     return false;
   }
+
+  public long getRecordNumber() {
+    return recordNumber;
+  }
 }
diff --git 
a/common/src/test/java/org/apache/uniffle/common/DeferredCompressedBlockTest.java
 
b/common/src/test/java/org/apache/uniffle/common/DeferredCompressedBlockTest.java
index 5ec2e851d..71e70c18a 100644
--- 
a/common/src/test/java/org/apache/uniffle/common/DeferredCompressedBlockTest.java
+++ 
b/common/src/test/java/org/apache/uniffle/common/DeferredCompressedBlockTest.java
@@ -44,7 +44,8 @@ public class DeferredCompressedBlockTest {
               deferredCompressedBlock.reset(new byte[10], 10, 10);
               return deferredCompressedBlock;
             },
-            10);
+            10,
+            -1L);
 
     // case1: some params accessing won't trigger initialization
     block.getBlockId();
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
index 00c25e672..366866d06 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
@@ -105,7 +105,6 @@ import 
org.apache.uniffle.proto.RssProtos.GetShuffleResultForMultiPartResponse;
 import org.apache.uniffle.proto.RssProtos.GetShuffleResultRequest;
 import org.apache.uniffle.proto.RssProtos.GetShuffleResultResponse;
 import org.apache.uniffle.proto.RssProtos.MergeContext;
-import org.apache.uniffle.proto.RssProtos.PartitionToBlockIds;
 import org.apache.uniffle.proto.RssProtos.RemoteStorage;
 import org.apache.uniffle.proto.RssProtos.RemoteStorageConfItem;
 import org.apache.uniffle.proto.RssProtos.ReportShuffleResultRequest;
@@ -788,26 +787,7 @@ public class ShuffleServerGrpcClient extends GrpcClient 
implements ShuffleServer
 
   @Override
   public RssReportShuffleResultResponse 
reportShuffleResult(RssReportShuffleResultRequest request) {
-    List<PartitionToBlockIds> partitionToBlockIds = Lists.newArrayList();
-    for (Map.Entry<Integer, List<Long>> entry : 
request.getPartitionToBlockIds().entrySet()) {
-      List<Long> blockIds = entry.getValue();
-      if (blockIds != null && !blockIds.isEmpty()) {
-        partitionToBlockIds.add(
-            PartitionToBlockIds.newBuilder()
-                .setPartitionId(entry.getKey())
-                .addAllBlockIds(entry.getValue())
-                .build());
-      }
-    }
-
-    ReportShuffleResultRequest recRequest =
-        ReportShuffleResultRequest.newBuilder()
-            .setAppId(request.getAppId())
-            .setShuffleId(request.getShuffleId())
-            .setTaskAttemptId(request.getTaskAttemptId())
-            .setBitmapNum(request.getBitmapNum())
-            .addAllPartitionToBlockIds(partitionToBlockIds)
-            .build();
+    ReportShuffleResultRequest recRequest = request.toProto();
     ReportShuffleResultResponse rpcResponse = 
doReportShuffleResult(recRequest);
 
     RssProtos.StatusCode statusCode = rpcResponse.getStatus();
@@ -927,7 +907,9 @@ public class ShuffleServerGrpcClient extends GrpcClient implements ShuffleServer
         try {
           response =
               new RssGetShuffleResultResponse(
-                  StatusCode.SUCCESS, rpcResponse.getSerializedBitmap().toByteArray());
+                  StatusCode.SUCCESS,
+                  rpcResponse.getSerializedBitmap().toByteArray(),
+                  rpcResponse.getPartitionStatsList());
         } catch (Exception e) {
           throw new RssException(e);
         }
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssReportShuffleResultRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssReportShuffleResultRequest.java
index 3a4f9fb22..5667b0abf 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssReportShuffleResultRequest.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssReportShuffleResultRequest.java
@@ -17,6 +17,7 @@
 
 package org.apache.uniffle.client.request;
 
+import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 
@@ -31,18 +32,30 @@ public class RssReportShuffleResultRequest {
   private long taskAttemptId;
   private int bitmapNum;
   private Map<Integer, List<Long>> partitionToBlockIds;
+  private Map<Integer, Long> partitionToRecordNumbers;
 
   public RssReportShuffleResultRequest(
       String appId,
       int shuffleId,
       long taskAttemptId,
       Map<Integer, List<Long>> partitionToBlockIds,
-      int bitmapNum) {
+      int bitmapNum,
+      Map<Integer, Long> partitionToRecordNumbers) {
     this.appId = appId;
     this.shuffleId = shuffleId;
     this.taskAttemptId = taskAttemptId;
     this.bitmapNum = bitmapNum;
     this.partitionToBlockIds = partitionToBlockIds;
+    this.partitionToRecordNumbers = partitionToRecordNumbers;
+  }
+
+  public RssReportShuffleResultRequest(
+      String appId,
+      int shuffleId,
+      long taskAttemptId,
+      Map<Integer, List<Long>> partitionToBlockIds,
+      int bitmapNum) {
+    this(appId, shuffleId, taskAttemptId, partitionToBlockIds, bitmapNum, null);
   }
 
   public String getAppId() {
@@ -79,6 +92,24 @@ public class RssReportShuffleResultRequest {
       }
     }
 
+    List<RssProtos.PartitionStats> partitionStats = Lists.newArrayList();
+    if (partitionToRecordNumbers != null) {
+      for (Map.Entry<Integer, Long> entry : partitionToRecordNumbers.entrySet()) {
+        int partitionId = entry.getKey();
+        long recordNumber = entry.getValue();
+        RssProtos.TaskAttemptIdToRecords taskAttemptIdToRecords =
+            RssProtos.TaskAttemptIdToRecords.newBuilder()
+                .setTaskAttemptId(taskAttemptId)
+                .setRecordNumber(recordNumber)
+                .build();
+        partitionStats.add(
+            RssProtos.PartitionStats.newBuilder()
+                .setPartitionId(partitionId)
+                .addAllTaskAttemptIdToRecords(Arrays.asList(taskAttemptIdToRecords))
+                .build());
+      }
+    }
+
     RssProtos.ReportShuffleResultRequest rpcRequest =
         RssProtos.ReportShuffleResultRequest.newBuilder()
             .setAppId(request.getAppId())
@@ -86,6 +117,7 @@ public class RssReportShuffleResultRequest {
             .setTaskAttemptId(request.getTaskAttemptId())
             .setBitmapNum(request.getBitmapNum())
             .addAllPartitionToBlockIds(partitionToBlockIds)
+            .addAllPartitionStats(partitionStats)
             .build();
     return rpcRequest;
   }
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleResultResponse.java b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleResultResponse.java
index aca33aaed..2a5451220 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleResultResponse.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleResultResponse.java
@@ -18,6 +18,9 @@
 package org.apache.uniffle.client.response;
 
 import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
 
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 
@@ -29,6 +32,26 @@ import org.apache.uniffle.proto.RssProtos;
 public class RssGetShuffleResultResponse extends ClientResponse {
 
   private Roaring64NavigableMap blockIdBitmap;
+  // partitionId -> taskAttemptId -> recordNumber
+  private Map<Integer, Map<Long, Long>> partitionToTaskAttemptIdToRecordNumbers = new HashMap<>();
+
+  public RssGetShuffleResultResponse(
+      StatusCode statusCode,
+      byte[] serializedBitmap,
+      List<RssProtos.PartitionStats> partitionStatsList)
+      throws IOException {
+    super(statusCode);
+    blockIdBitmap = RssUtils.deserializeBitMap(serializedBitmap);
+    for (RssProtos.PartitionStats partitionStats : partitionStatsList) {
+      int partitionId = partitionStats.getPartitionId();
+      for (RssProtos.TaskAttemptIdToRecords record :
+          partitionStats.getTaskAttemptIdToRecordsList()) {
+        partitionToTaskAttemptIdToRecordNumbers
+            .computeIfAbsent(partitionId, k -> new HashMap<>())
+            .put(record.getTaskAttemptId(), record.getRecordNumber());
+      }
+    }
+  }
 
  public RssGetShuffleResultResponse(StatusCode statusCode, byte[] serializedBitmap)
       throws IOException {
@@ -40,6 +63,10 @@ public class RssGetShuffleResultResponse extends ClientResponse {
     return blockIdBitmap;
   }
 
+  public Map<Integer, Map<Long, Long>> getPartitionToTaskAttemptIdToRecordNumbers() {
+    return partitionToTaskAttemptIdToRecordNumbers;
+  }
+
   public static RssGetShuffleResultResponse fromProto(
       RssProtos.GetShuffleResultResponse rpcResponse) {
     try {
@@ -56,7 +83,8 @@ public class RssGetShuffleResultResponse extends ClientResponse {
     try {
       return new RssGetShuffleResultResponse(
           StatusCode.fromProto(rpcResponse.getStatus()),
-          rpcResponse.getSerializedBitmap().toByteArray());
+          rpcResponse.getSerializedBitmap().toByteArray(),
+          rpcResponse.getPartitionStatsList());
     } catch (Exception e) {
       throw new RssException(e);
     }
diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto
index c48fe5c38..89845fe18 100644
--- a/proto/src/main/proto/Rss.proto
+++ b/proto/src/main/proto/Rss.proto
@@ -140,6 +140,7 @@ message ReportShuffleResultRequest {
   int64 taskAttemptId = 3;
   int32 bitmapNum = 4;
   repeated PartitionToBlockIds partitionToBlockIds = 5;
+  repeated PartitionStats partitionStats = 6;
 }
 
 message PartitionToBlockIds {
@@ -182,6 +183,17 @@ message GetShuffleResultForMultiPartResponse {
   StatusCode status = 1;
   string retMsg = 2;
   bytes serializedBitmap = 3;
+  repeated PartitionStats partitionStats = 4;
+}
+
+message PartitionStats {
+  int32 partitionId = 1;
+  repeated TaskAttemptIdToRecords taskAttemptIdToRecords = 2;
+}
+
+message TaskAttemptIdToRecords {
+  int64 taskAttemptId = 1;
+  int64 recordNumber = 2;
 }
 
 message ShufflePartitionRange {


Reply via email to