This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new d6df94ca8 [#2673] feat(spark)(part-2): Merge partition stats for 
partition split on integrity validation (#2681)
d6df94ca8 is described below

commit d6df94ca8ed313ec63986273a744fa3574e78b98
Author: Junfan Zhang <[email protected]>
AuthorDate: Fri Nov 21 11:37:16 2025 +0800

    [#2673] feat(spark)(part-2): Merge partition stats for partition split on 
integrity validation (#2681)
    
    ### What changes were proposed in this pull request?
    
    This PR fixes the incorrect aggregated expected record numbers when 
partition split is active.
    
    ### Why are the changes needed?
    
    If partition split is active and the server management is enabled for 
the integrity validation, the record-number check will fail due to the 
unmerged partition stats.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Unit tests.
---
 .../uniffle/shuffle/ShuffleReadTaskStats.java      |  9 +++
 .../apache/spark/shuffle/RssShuffleManager.java    | 27 ++++----
 .../spark/shuffle/writer/RssShuffleWriter.java     | 16 +++--
 .../apache/uniffle/client/api/ShuffleResult.java   | 21 +++----
 .../uniffle/client/impl/MergedPartitionStats.java  | 70 +++++++++++++++++++++
 .../client/impl/ShuffleWriteClientImpl.java        | 11 ++--
 .../client/impl/MergedPartitionStatsTest.java      | 71 ++++++++++++++++++++++
 7 files changed, 190 insertions(+), 35 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleReadTaskStats.java
 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleReadTaskStats.java
index 0898648a5..7ebee7b0b 100644
--- 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleReadTaskStats.java
+++ 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleReadTaskStats.java
@@ -52,6 +52,15 @@ public class ShuffleReadTaskStats {
     return partitionBlocksReadPerMap.get(partitionId);
   }
 
+  /**
+   * Compares the read-side stats with the upstream shuffleWriteTaskStats that
+   * are re-built from the client mapOutputTracker metadata.
+   *
+   * @param writeStats upstream write stats keyed by taskAttemptId
+   * @param startPartition the start partition id of the checked range
+   * @param endPartition the end partition id of the checked range
+   * @return true if a difference is detected, false otherwise
+   */
   public boolean diff(
       Map<Long, ShuffleWriteTaskStats> writeStats, int startPartition, int 
endPartition) {
     StringBuilder infoBuilder = new StringBuilder();
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 85786881a..b81fd479d 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -65,7 +65,9 @@ import 
org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
 import org.apache.uniffle.client.api.ShuffleResult;
 import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.client.impl.FailedBlockSendTracker;
+import org.apache.uniffle.client.impl.MergedPartitionStats;
 import org.apache.uniffle.client.util.ClientUtils;
+import org.apache.uniffle.client.util.RssClientConfig;
 import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleDataDistributionType;
 import org.apache.uniffle.common.ShuffleServerInfo;
@@ -392,21 +394,17 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
             context.stageAttemptNumber(),
             shuffleHandleInfo.createPartitionReplicaTracking());
     Roaring64NavigableMap blockIdBitmap = shuffleResult.getBlockIds();
+
     if (isIntegrityValidationServerManagementEnabled(rssConf)) {
-      Map<Integer, Map<Long, Long>> partitionToTaskAttemptIdToRecordNumbers =
-          shuffleResult.getPartitionToTaskAttemptIdToRecordNumbers();
-      if (partitionToTaskAttemptIdToRecordNumbers != null) {
-        long total =
-            partitionToTaskAttemptIdToRecordNumbers.values().stream()
-                .flatMap(x -> x.entrySet().stream())
-                .filter(x -> taskIdBitmap.contains(x.getKey()))
-                .mapToLong(Map.Entry::getValue)
-                .sum();
-        if (total > 0) {
-          expectedRecordsRead = total;
+      MergedPartitionStats mergedPartitionStats = 
shuffleResult.getMergedPartitionStats();
+      if (mergedPartitionStats != null) {
+        long records = 
mergedPartitionStats.getExpectedRecordNumberByTaskIds(taskIdBitmap);
+        if (records > 0) {
+          expectedRecordsRead = records;
         }
       }
     }
+
     LOG.info(
         "Retrieved {} upstream task ids in {} ms and {} block IDs from {} 
shuffle-servers in {} ms for shuffleId[{}], partitionId[{},{}]",
         taskIdBitmap.getLongCardinality(),
@@ -467,6 +465,13 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
 
   public static boolean isIntegrityValidationEnabled(RssConf rssConf) {
     assert rssConf != null;
+    // Disable integrity validation when multiple replicas are enabled.
+    if (rssConf.getInteger(
+            RssClientConfig.RSS_DATA_REPLICA, 
RssClientConfig.RSS_DATA_REPLICA_DEFAULT_VALUE)
+        > 1) {
+      return false;
+    }
+    // only enable integrity validation when the spark version >= 3.5.0
     if (!Spark3VersionUtils.isSparkVersionAtLeast("3.5.0")) {
       return false;
     }
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 48e2f2ab3..26229a841 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -886,11 +886,17 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     
shuffleManager.getBlockIdsFailedSendTracker(taskId).remove(block.getBlockId());
     block.getShuffleServerInfos().stream()
         .forEach(
-            s ->
-                serverToPartitionToBlockIds
-                    .get(s)
-                    .get(block.getPartitionId())
-                    .remove(block.getBlockId()));
+            s -> {
+              serverToPartitionToBlockIds
+                  .get(s)
+                  .get(block.getPartitionId())
+                  .remove(block.getBlockId());
+              serverToPartitionToRecordNumbers
+                  .get(s)
+                  .compute(
+                      block.getPartitionId(),
+                      (pid, recordNumber) -> recordNumber - 
block.getRecordNumber());
+            });
     blockIds.remove(block.getBlockId());
   }
 
diff --git 
a/client/src/main/java/org/apache/uniffle/client/api/ShuffleResult.java 
b/client/src/main/java/org/apache/uniffle/client/api/ShuffleResult.java
index 5cae357d9..12d006861 100644
--- a/client/src/main/java/org/apache/uniffle/client/api/ShuffleResult.java
+++ b/client/src/main/java/org/apache/uniffle/client/api/ShuffleResult.java
@@ -17,27 +17,24 @@
 
 package org.apache.uniffle.client.api;
 
-import java.util.Map;
-
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 
+import org.apache.uniffle.client.impl.MergedPartitionStats;
+
 public class ShuffleResult {
   private Roaring64NavigableMap blockIds;
-  // partitionId -> taskAttemptId -> recordNumber
-  private Map<Integer, Map<Long, Long>> 
partitionToTaskAttemptIdToRecordNumbers;
+  private MergedPartitionStats mergedPartitionStats;
 
-  public ShuffleResult(
-      Roaring64NavigableMap blockIds,
-      Map<Integer, Map<Long, Long>> partitionToTaskAttemptIdToRecordNumbers) {
+  public ShuffleResult(Roaring64NavigableMap blockIds, MergedPartitionStats 
mergedPartitionStats) {
     this.blockIds = blockIds;
-    this.partitionToTaskAttemptIdToRecordNumbers = 
partitionToTaskAttemptIdToRecordNumbers;
+    this.mergedPartitionStats = mergedPartitionStats;
   }
 
-  public Roaring64NavigableMap getBlockIds() {
-    return blockIds;
+  public MergedPartitionStats getMergedPartitionStats() {
+    return mergedPartitionStats;
   }
 
-  public Map<Integer, Map<Long, Long>> 
getPartitionToTaskAttemptIdToRecordNumbers() {
-    return partitionToTaskAttemptIdToRecordNumbers;
+  public Roaring64NavigableMap getBlockIds() {
+    return blockIds;
   }
 }
diff --git 
a/client/src/main/java/org/apache/uniffle/client/impl/MergedPartitionStats.java 
b/client/src/main/java/org/apache/uniffle/client/impl/MergedPartitionStats.java
new file mode 100644
index 000000000..9464d2f2f
--- /dev/null
+++ 
b/client/src/main/java/org/apache/uniffle/client/impl/MergedPartitionStats.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.client.impl;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+
+/**
+ * Merge the partition stats (now includes the partition records number from 
the different upstream
+ * taskAttemptIds) from the shuffle-servers. The partition records checksum 
also can be supported in
+ * this class.
+ */
+public class MergedPartitionStats {
+  // partitionId -> taskAttemptId -> records
+  private final Map<Integer, Map<Long, Long>> 
partitionToTaskAttemptIdToRecordNumbers;
+
+  public MergedPartitionStats() {
+    this.partitionToTaskAttemptIdToRecordNumbers = new HashMap<>();
+  }
+
+  public void merge(Map<Integer, Map<Long, Long>> 
partitionToTaskAttemptIdToRecordNumbers) {
+    if (partitionToTaskAttemptIdToRecordNumbers == null) {
+      return;
+    }
+
+    for (Map.Entry<Integer, Map<Long, Long>> entry :
+        partitionToTaskAttemptIdToRecordNumbers.entrySet()) {
+      int partitionId = entry.getKey();
+      Map<Long, Long> incomingTaskMap = entry.getValue();
+
+      Map<Long, Long> currentTaskMap =
+          this.partitionToTaskAttemptIdToRecordNumbers.computeIfAbsent(
+              partitionId, k -> new HashMap<>());
+
+      for (Map.Entry<Long, Long> taskEntry : incomingTaskMap.entrySet()) {
+        long taskAttemptId = taskEntry.getKey();
+        long recordCount = taskEntry.getValue();
+        currentTaskMap.merge(taskAttemptId, recordCount, Long::sum);
+      }
+    }
+  }
+
+  // get the expected total record number filter by the upstream taskIds
+  public long getExpectedRecordNumberByTaskIds(Roaring64NavigableMap 
taskIdBitmap) {
+    long total =
+        partitionToTaskAttemptIdToRecordNumbers.values().stream()
+            .flatMap(x -> x.entrySet().stream())
+            .filter(x -> taskIdBitmap.contains(x.getKey()))
+            .mapToLong(Map.Entry::getValue)
+            .sum();
+    return total;
+  }
+}
diff --git 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
index 36d49a125..7dc4059ac 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
@@ -24,7 +24,6 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
-import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutorService;
@@ -936,7 +935,7 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
       Set<Integer> failedPartitions,
       PartitionDataReplicaRequirementTracking replicaRequirementTracking) {
     Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf();
-    Map<Integer, Map<Long, Long>> partitionToTaskAttemptIdToRecordNumbers = 
new HashMap<>();
+    MergedPartitionStats mergedPartitionStats = new MergedPartitionStats();
     Set<Integer> allRequestedPartitionIds = new HashSet<>();
     for (Map.Entry<ShuffleServerInfo, Set<Integer>> entry : 
serverToPartitions.entrySet()) {
       ShuffleServerInfo shuffleServerInfo = entry.getKey();
@@ -958,10 +957,8 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
           Roaring64NavigableMap blockIdBitmapOfServer = 
response.getBlockIdBitmap();
           blockIdBitmap.or(blockIdBitmapOfServer);
 
-          // todo: should be more careful to handle this under the multi 
replicas.
-          //  Now, this integrity validation is not supported for multi 
replicas
-          
Optional.ofNullable(response.getPartitionToTaskAttemptIdToRecordNumbers())
-              .ifPresent(x -> 
partitionToTaskAttemptIdToRecordNumbers.putAll(x));
+          // merge stats reported by multiple servers for the same
+          // (partition, taskAttemptId) pair into one record
+          
mergedPartitionStats.merge(response.getPartitionToTaskAttemptIdToRecordNumbers());
 
           for (Integer partitionId : requestPartitions) {
             replicaRequirementTracking.markPartitionOfServerSuccessful(
@@ -990,7 +987,7 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
       throw new RssFetchFailedException(
           "Get shuffle result is failed for appId[" + appId + "], shuffleId[" 
+ shuffleId + "]");
     }
-    return new ShuffleResult(blockIdBitmap, 
partitionToTaskAttemptIdToRecordNumbers);
+    return new ShuffleResult(blockIdBitmap, mergedPartitionStats);
   }
 
   @Override
diff --git 
a/client/src/test/java/org/apache/uniffle/client/impl/MergedPartitionStatsTest.java
 
b/client/src/test/java/org/apache/uniffle/client/impl/MergedPartitionStatsTest.java
new file mode 100644
index 000000000..458fde32f
--- /dev/null
+++ 
b/client/src/test/java/org/apache/uniffle/client/impl/MergedPartitionStatsTest.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.client.impl;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.jupiter.api.Test;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class MergedPartitionStatsTest {
+
+  private Map<Integer, Map<Long, Long>> generateRecords(
+      int partitionId, long taskId, long records) {
+    Map<Integer, Map<Long, Long>> recordsMap = new HashMap<>();
+    recordsMap.computeIfAbsent(partitionId, x -> new HashMap<>()).put(taskId, 
records);
+    return recordsMap;
+  }
+
+  @Test
+  public void testMerge() {
+    MergedPartitionStats stats = new MergedPartitionStats();
+
+    // case1
+    Map<Integer, Map<Long, Long>> r1 = generateRecords(1, 1, 100);
+    stats.merge(r1);
+
+    Map<Integer, Map<Long, Long>> r2 = generateRecords(1, 1, 200);
+    stats.merge(r2);
+
+    Roaring64NavigableMap taskIds = new Roaring64NavigableMap();
+    taskIds.add(1);
+    long records = stats.getExpectedRecordNumberByTaskIds(taskIds);
+    assertEquals(300, records);
+
+    // case2
+    stats.merge(generateRecords(1, 2, 100));
+    taskIds.add(2);
+    assertEquals(400, stats.getExpectedRecordNumberByTaskIds(taskIds));
+
+    // case3: filter out some taskIds
+    taskIds.clear();
+    taskIds.add(2);
+    assertEquals(100, stats.getExpectedRecordNumberByTaskIds(taskIds));
+
+    // case4: different partitionIds
+    stats.merge(generateRecords(2, 3, 100));
+    taskIds.clear();
+    taskIds.add(1);
+    taskIds.add(2);
+    taskIds.add(3);
+    assertEquals(500, stats.getExpectedRecordNumberByTaskIds(taskIds));
+  }
+}

Reply via email to