This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new d9815c06f [#2652] feat(spark): Add detailed integrity validation 
failure analysis (#2657)
d9815c06f is described below

commit d9815c06fea0f98a4afe2135869a7dc443c0fc93
Author: Junfan Zhang <[email protected]>
AuthorDate: Mon Nov 3 10:06:25 2025 +0800

    [#2652] feat(spark): Add detailed integrity validation failure analysis 
(#2657)
    
    ### What changes were proposed in this pull request?
    
    This PR is to add the detailed integrity validation failure analysis to 
hopefully obtain the concrete upstream task attempt id for further investigation.
    
    ### Why are the changes needed?
    
    This is a follow-up to #2653.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, a new configuration option is introduced:
    `spark.rss.client.integrityValidation.failureAnalysisEnabled=false`
    
    ### How was this patch tested?
    
    Unit tests.
---
 .../org/apache/spark/shuffle/RssSparkConfig.java   |  9 +-
 .../shuffle/reader/RssShuffleDataIterator.java     | 25 +++++-
 .../apache/spark/shuffle/writer/DataPusher.java    | 11 +--
 .../uniffle/shuffle/ShuffleReadTaskStats.java      | 99 ++++++++++++++++++++++
 .../uniffle/shuffle/ShuffleWriteTaskStats.java     | 78 ++++++++++++++++-
 .../uniffle/shuffle/ShuffleReadTaskStatsTest.java  | 82 ++++++++++++++++++
 .../uniffle/shuffle/ShuffleWriteTaskStatsTest.java | 11 ++-
 .../apache/spark/shuffle/RssShuffleManager.java    | 76 ++++++++++++++---
 .../spark/shuffle/reader/RssShuffleReader.java     | 30 ++++++-
 .../spark/shuffle/writer/RssShuffleWriter.java     | 37 +++++---
 .../uniffle/client/impl/DecompressionWorker.java   |  5 +-
 .../uniffle/client/impl/ShuffleReadClientImpl.java |  3 +-
 .../client/response/CompressedShuffleBlock.java    |  7 +-
 .../client/response/DecompressedShuffleBlock.java  |  6 +-
 .../uniffle/client/response/ShuffleBlock.java      | 15 +++-
 15 files changed, 447 insertions(+), 47 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index fd330b359..a17fa57bd 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -41,9 +41,16 @@ public class RssSparkConfig {
 
   public static final ConfigOption<Boolean> 
RSS_CLIENT_INTEGRITY_VALIDATION_ENABLED =
       ConfigOptions.key("rss.client.integrityValidation.enabled")
+          .booleanType()
+          .defaultValue(true)
+          .withDescription(
+              "Whether or not to enable shuffle data integrity validation 
mechanism (spark version >= 3.5.0)");
+
+  public static final ConfigOption<Boolean> 
RSS_DATA_INTEGRATION_VALIDATION_ANALYSIS_ENABLED =
+      
ConfigOptions.key("rss.client.integrityValidation.failureAnalysisEnabled")
           .booleanType()
           .defaultValue(false)
-          .withDescription("Whether or not to enable shuffle data integrity 
validation mechanism");
+          .withDescription("Whether or not to enable validation failure 
analysis");
 
   public static final ConfigOption<Boolean> 
RSS_READ_SHUFFLE_HANDLE_CACHE_ENABLED =
       ConfigOptions.key("rss.client.read.shuffleHandleCacheEnabled")
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
index 646807d2b..15f9ca3c6 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
@@ -45,6 +45,7 @@ import org.apache.uniffle.common.ShuffleReadTimes;
 import org.apache.uniffle.common.compression.Codec;
 import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.util.RssUtils;
+import org.apache.uniffle.shuffle.ShuffleReadTaskStats;
 
 public class RssShuffleDataIterator<K, C> extends AbstractIterator<Product2<K, 
C>> {
 
@@ -64,6 +65,10 @@ public class RssShuffleDataIterator<K, C> extends 
AbstractIterator<Product2<K, C
   private ByteBuffer uncompressedData;
   private Optional<Codec> codec;
 
+  private final int partitionId;
+  private Optional<ShuffleReadTaskStats> shuffleReadTaskStats;
+  private long currentBlockTaskAttemptId = -1L;
+
   // only for tests
   @VisibleForTesting
   public RssShuffleDataIterator(
@@ -71,7 +76,14 @@ public class RssShuffleDataIterator<K, C> extends 
AbstractIterator<Product2<K, C
       ShuffleReadClient shuffleReadClient,
       ShuffleReadMetrics shuffleReadMetrics,
       RssConf rssConf) {
-    this(serializer, shuffleReadClient, shuffleReadMetrics, rssConf, 
Optional.empty());
+    this(
+        serializer,
+        shuffleReadClient,
+        shuffleReadMetrics,
+        rssConf,
+        Optional.empty(),
+        Optional.empty(),
+        0);
     boolean compress =
         rssConf.getBoolean(
             RssSparkConfig.SPARK_SHUFFLE_COMPRESS_KEY.substring(
@@ -85,11 +97,15 @@ public class RssShuffleDataIterator<K, C> extends 
AbstractIterator<Product2<K, C
       ShuffleReadClient shuffleReadClient,
       ShuffleReadMetrics shuffleReadMetrics,
       RssConf rssConf,
-      Optional<Codec> codec) {
+      Optional<Codec> codec,
+      Optional<ShuffleReadTaskStats> shuffleReadTaskStats,
+      int partitionId) {
     this.serializerInstance = serializer.newInstance();
     this.shuffleReadClient = shuffleReadClient;
     this.shuffleReadMetrics = shuffleReadMetrics;
     this.codec = codec;
+    this.shuffleReadTaskStats = shuffleReadTaskStats;
+    this.partitionId = partitionId;
   }
 
   public Iterator<Tuple2<Object, Object>> createKVIterator(ByteBuffer data) {
@@ -136,6 +152,9 @@ public class RssShuffleDataIterator<K, C> extends 
AbstractIterator<Product2<K, C
       long getBufferDuration = System.currentTimeMillis() - getBuffer;
 
       if (rawData != null) {
+        this.currentBlockTaskAttemptId = shuffleBlock.getTaskAttemptId();
+        shuffleReadTaskStats.ifPresent(
+            stats -> stats.incPartitionBlock(partitionId, 
shuffleBlock.getTaskAttemptId()));
         // collect metrics from raw data
         long rawDataLength = rawData.limit() - rawData.position();
         totalRawBytesLength += rawDataLength;
@@ -241,6 +260,8 @@ public class RssShuffleDataIterator<K, C> extends 
AbstractIterator<Product2<K, C
   @Override
   public Product2<K, C> next() {
     shuffleReadMetrics.incRecordsRead(1L);
+    shuffleReadTaskStats.ifPresent(
+        x -> x.incPartitionRecord(partitionId, currentBlockTaskAttemptId));
     return (Product2<K, C>) recordsIterator.next();
   }
 
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
index 9c74e328b..05dfce694 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
@@ -108,6 +108,12 @@ public class DataPusher implements Closeable {
                         event.getStageAttemptNumber(),
                         validBlocks,
                         () -> !isValidTask(taskId));
+                // completionCallback should be executed before updating 
taskToSuccessBlockIds
+                // structure to avoid side effect
+                Set<Long> succeedBlockIds = getSucceedBlockIds(result);
+                for (ShuffleBlockInfo block : validBlocks) {
+                  
block.executeCompletionCallback(succeedBlockIds.contains(block.getBlockId()));
+                }
                 putBlockId(taskToSuccessBlockIds, taskId, 
result.getSuccessBlockIds());
                 putFailedBlockSendTracker(
                     taskToFailedBlockSendTracker, taskId, 
result.getFailedBlockSendTracker());
@@ -119,11 +125,6 @@ public class DataPusher implements Closeable {
                   bufferManager.merge(shuffleServerPushCostTracker);
                 }
 
-                Set<Long> succeedBlockIds = getSucceedBlockIds(result);
-                for (ShuffleBlockInfo block : validBlocks) {
-                  
block.executeCompletionCallback(succeedBlockIds.contains(block.getBlockId()));
-                }
-
                 List<Runnable> callbackChain =
                     
Optional.of(event.getProcessedCallbackChain()).orElse(Collections.EMPTY_LIST);
                 for (Runnable runnable : callbackChain) {
diff --git 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleReadTaskStats.java
 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleReadTaskStats.java
new file mode 100644
index 000000000..d78f24a2f
--- /dev/null
+++ 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleReadTaskStats.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.shuffle;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Optional;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class ShuffleReadTaskStats {
+  private static final Logger LOGGER = 
LoggerFactory.getLogger(ShuffleReadTaskStats.class);
+
+  // partition_id -> upstream_map_id -> records_read
+  private Map<Integer, Map<Long, Long>> partitionRecordsReadPerMap = new 
HashMap<>();
+  // partition_id -> upstream_map_id -> blocks_read
+  private Map<Integer, Map<Long, Long>> partitionBlocksReadPerMap = new 
HashMap<>();
+
+  public void incPartitionRecord(int partitionId, long taskAttemptId) {
+    Map<Long, Long> records =
+        partitionRecordsReadPerMap.computeIfAbsent(partitionId, k -> new 
HashMap<>());
+    records.compute(taskAttemptId, (k, v) -> v == null ? 1 : v + 1);
+  }
+
+  public void incPartitionBlock(int partitionId, long taskAttemptId) {
+    Map<Long, Long> records =
+        partitionBlocksReadPerMap.computeIfAbsent(partitionId, k -> new 
HashMap<>());
+    records.compute(taskAttemptId, (k, v) -> v == null ? 1 : v + 1);
+  }
+
+  public Map<Long, Long> getPartitionRecords(int partitionId) {
+    return partitionRecordsReadPerMap.get(partitionId);
+  }
+
+  public Map<Long, Long> getPartitionBlocks(int partitionId) {
+    return partitionBlocksReadPerMap.get(partitionId);
+  }
+
+  public boolean diff(
+      Map<Long, ShuffleWriteTaskStats> writeStats, int startPartition, int 
endPartition) {
+    StringBuilder infoBuilder = new StringBuilder();
+    for (int idx = startPartition; idx < endPartition; idx++) {
+      for (Map.Entry<Long, Long> recordEntry : 
partitionRecordsReadPerMap.get(idx).entrySet()) {
+        long taskAttemptId = recordEntry.getKey();
+        long recordsRead = recordEntry.getValue();
+        long blocksRead =
+            Optional.ofNullable(partitionBlocksReadPerMap.get(idx))
+                .map(m -> m.getOrDefault(taskAttemptId, 0L))
+                .orElse(0L);
+
+        ShuffleWriteTaskStats stats = writeStats.get(taskAttemptId);
+        if (stats == null) {
+          LOGGER.warn("Should not happen that task attempt {} has no write 
stats", taskAttemptId);
+          continue;
+        }
+        long recordsUpstream = stats.getRecordsWritten(idx);
+        long blocksUpstream = stats.getBlocksWritten(idx);
+        if (recordsRead != recordsUpstream || blocksRead != blocksUpstream) {
+          infoBuilder.append(idx);
+          infoBuilder.append("/");
+          infoBuilder.append(stats.getTaskId());
+          infoBuilder.append("/");
+          infoBuilder.append(recordsRead);
+          infoBuilder.append("-");
+          infoBuilder.append(recordsUpstream);
+          infoBuilder.append("/");
+          infoBuilder.append(blocksRead);
+          infoBuilder.append("-");
+          infoBuilder.append(blocksUpstream);
+          infoBuilder.append(",");
+        }
+      }
+    }
+    if (infoBuilder.length() > 0) {
+      infoBuilder.insert(
+          0,
+          "Errors on integrity validating. 
Details(partitionId/upstreamTaskId/recordsRead-recordsUpstream/blocksRead-blocksUpstream):
 ");
+      LOGGER.warn(infoBuilder.toString());
+      return false;
+    }
+    return true;
+  }
+}
diff --git 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStats.java
 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStats.java
index 3625f9654..e86da494b 100644
--- 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStats.java
+++ 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStats.java
@@ -18,6 +18,12 @@
 package org.apache.uniffle.shuffle;
 
 import java.nio.ByteBuffer;
+import java.util.Arrays;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.exception.RssException;
 
 import static java.nio.charset.StandardCharsets.ISO_8859_1;
 
@@ -26,12 +32,23 @@ import static java.nio.charset.StandardCharsets.ISO_8859_1;
  * attempt ID and the number of records written for each partition.
  */
 public class ShuffleWriteTaskStats {
+  private static final Logger LOGGER = 
LoggerFactory.getLogger(ShuffleWriteTaskStats.class);
+
+  // the unique task id across all stages
+  private long taskId;
+  // this is only unique for one stage and defined in uniffle side instead of 
spark
   private long taskAttemptId;
   private long[] partitionRecordsWritten;
+  private long[] partitionBlocksWritten;
 
-  public ShuffleWriteTaskStats(int partitions, long taskAttemptId) {
+  public ShuffleWriteTaskStats(int partitions, long taskAttemptId, long 
taskId) {
     this.partitionRecordsWritten = new long[partitions];
+    this.partitionBlocksWritten = new long[partitions];
     this.taskAttemptId = taskAttemptId;
+    this.taskId = taskId;
+
+    Arrays.fill(this.partitionRecordsWritten, 0L);
+    Arrays.fill(this.partitionBlocksWritten, 0L);
   }
 
   public long getRecordsWritten(int partitionId) {
@@ -42,30 +59,85 @@ public class ShuffleWriteTaskStats {
     partitionRecordsWritten[partitionId]++;
   }
 
+  public void incPartitionBlock(int partitionId) {
+    partitionBlocksWritten[partitionId]++;
+  }
+
+  public long getBlocksWritten(int partitionId) {
+    return partitionBlocksWritten[partitionId];
+  }
+
   public long getTaskAttemptId() {
     return taskAttemptId;
   }
 
   public String encode() {
     int partitions = partitionRecordsWritten.length;
-    ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES + Integer.BYTES + 
partitions * Long.BYTES);
+    ByteBuffer buffer =
+        ByteBuffer.allocate(2 * Long.BYTES + Integer.BYTES + partitions * 
Long.BYTES * 2);
+    buffer.putLong(taskId);
     buffer.putLong(taskAttemptId);
     buffer.putInt(partitions);
     for (long records : partitionRecordsWritten) {
       buffer.putLong(records);
     }
+    for (long blocks : partitionBlocksWritten) {
+      buffer.putLong(blocks);
+    }
     return new String(buffer.array(), ISO_8859_1);
   }
 
   public static ShuffleWriteTaskStats decode(String raw) {
     byte[] bytes = raw.getBytes(ISO_8859_1);
     ByteBuffer buffer = ByteBuffer.wrap(bytes);
+    long taskId = buffer.getLong();
     long taskAttemptId = buffer.getLong();
     int partitions = buffer.getInt();
-    ShuffleWriteTaskStats stats = new ShuffleWriteTaskStats(partitions, 
taskAttemptId);
+    ShuffleWriteTaskStats stats = new ShuffleWriteTaskStats(partitions, 
taskAttemptId, taskId);
     for (int i = 0; i < partitions; i++) {
       stats.partitionRecordsWritten[i] = buffer.getLong();
     }
+    for (int i = 0; i < partitions; i++) {
+      stats.partitionBlocksWritten[i] = buffer.getLong();
+    }
     return stats;
   }
+
+  public long getTaskId() {
+    return taskId;
+  }
+
+  public void log() {
+    StringBuilder infoBuilder = new StringBuilder();
+    int partitions = partitionRecordsWritten.length;
+    for (int i = 0; i < partitions; i++) {
+      long records = partitionRecordsWritten[i];
+      long blocks = partitionBlocksWritten[i];
+      
infoBuilder.append(i).append("/").append(records).append("/").append(blocks).append(",");
+    }
+    LOGGER.info(
+        "Partition records/blocks written for taskId[{}]: {}", taskId, 
infoBuilder.toString());
+  }
+
+  public void check(long[] partitionLens) {
+    int partitions = partitionRecordsWritten.length;
+    for (int idx = 0; idx < partitions; idx++) {
+      long records = partitionRecordsWritten[idx];
+      long blocks = partitionBlocksWritten[idx];
+      long length = partitionLens[idx];
+      if (records > 0) {
+        if (blocks <= 0 || length <= 0) {
+          throw new RssException(
+              "Illegal partition:"
+                  + idx
+                  + " stats. records/blocks/length: "
+                  + records
+                  + "/"
+                  + blocks
+                  + "/"
+                  + length);
+        }
+      }
+    }
+  }
 }
diff --git 
a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/ShuffleReadTaskStatsTest.java
 
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/ShuffleReadTaskStatsTest.java
new file mode 100644
index 000000000..36a84290e
--- /dev/null
+++ 
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/ShuffleReadTaskStatsTest.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.shuffle;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class ShuffleReadTaskStatsTest {
+
+  @Test
+  public void testReadStats() {
+    ShuffleReadTaskStats readTaskStats = new ShuffleReadTaskStats();
+    readTaskStats.incPartitionRecord(0, 1);
+    readTaskStats.incPartitionRecord(0, 2);
+    readTaskStats.incPartitionRecord(0, 3);
+    readTaskStats.incPartitionRecord(0, 3);
+
+    Map<Long, Long> records = readTaskStats.getPartitionRecords(0);
+    assertEquals(1, records.get(1L));
+    assertEquals(1, records.get(2L));
+    assertEquals(2, records.get(3L));
+
+    readTaskStats.incPartitionBlock(0, 1);
+    readTaskStats.incPartitionBlock(0, 2);
+    Map<Long, Long> blocks = readTaskStats.getPartitionBlocks(0);
+    assertEquals(2, blocks.size());
+  }
+
+  @ParameterizedTest
+  @ValueSource(booleans = {true, false})
+  void testDiffWithInconsistentStats(boolean inconsistent) {
+    ShuffleReadTaskStats readTaskStats = new ShuffleReadTaskStats();
+    int partitionId = 0;
+    long taskId = 10;
+    long taskAttemptId = 1001L;
+    int expectedRecords = 10;
+    int expectedBlocks = 2;
+
+    int readRecords = inconsistent ? expectedRecords + 1 : expectedRecords;
+    // 10 records from 1 upstream tasks
+    for (int i = 0; i < readRecords; i++) {
+      readTaskStats.incPartitionRecord(partitionId, taskAttemptId);
+    }
+    // 2 blocks from 1 upstream tasks
+    for (int i = 0; i < expectedBlocks; i++) {
+      readTaskStats.incPartitionBlock(partitionId, taskAttemptId);
+    }
+
+    ShuffleWriteTaskStats writeStat = new ShuffleWriteTaskStats(1, 
taskAttemptId, taskId);
+    for (int i = 0; i < expectedRecords; i++) {
+      writeStat.incPartitionRecord(partitionId);
+    }
+    for (int i = 0; i < expectedBlocks; i++) {
+      writeStat.incPartitionBlock(partitionId);
+    }
+    Map<Long, ShuffleWriteTaskStats> writeStats = new HashMap<>();
+    writeStats.put(taskAttemptId, writeStat);
+    boolean result = readTaskStats.diff(writeStats, 0, 1);
+    assertEquals(!inconsistent, result);
+  }
+}
diff --git 
a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStatsTest.java
 
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStatsTest.java
index 3a70ae408..a0440ac6f 100644
--- 
a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStatsTest.java
+++ 
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStatsTest.java
@@ -25,16 +25,25 @@ public class ShuffleWriteTaskStatsTest {
 
   @Test
   public void testValidValidationInfo() {
+    long taskId = 10;
     long taskAttemptId = 12345L;
-    ShuffleWriteTaskStats stats = new ShuffleWriteTaskStats(2, taskAttemptId);
+    ShuffleWriteTaskStats stats = new ShuffleWriteTaskStats(2, taskAttemptId, 
taskId);
     stats.incPartitionRecord(0);
     stats.incPartitionRecord(1);
 
+    stats.incPartitionBlock(0);
+    stats.incPartitionBlock(1);
+
     String encoded = stats.encode();
     ShuffleWriteTaskStats decoded = ShuffleWriteTaskStats.decode(encoded);
 
+    assertEquals(10, stats.getTaskId());
+
     assertEquals(taskAttemptId, decoded.getTaskAttemptId());
     assertEquals(1, decoded.getRecordsWritten(0));
     assertEquals(1, decoded.getRecordsWritten(1));
+
+    assertEquals(1, decoded.getBlocksWritten(0));
+    assertEquals(1, decoded.getBlocksWritten(1));
   }
 }
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 51cdc7218..45cc0c933 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -19,6 +19,7 @@ package org.apache.spark.shuffle;
 
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -77,6 +78,7 @@ import org.apache.uniffle.shuffle.ShuffleWriteTaskStats;
 import org.apache.uniffle.shuffle.manager.RssShuffleManagerBase;
 
 import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_CLIENT_INTEGRITY_VALIDATION_ENABLED;
+import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_DATA_INTEGRATION_VALIDATION_ANALYSIS_ENABLED;
 
 public class RssShuffleManager extends RssShuffleManagerBase {
   private static final Logger LOG = 
LoggerFactory.getLogger(RssShuffleManager.class);
@@ -85,7 +87,7 @@ public class RssShuffleManager extends RssShuffleManagerBase {
     super(conf, isDriver);
     this.dataDistributionType = getDataDistributionType(sparkConf);
     if (isIntegrityValidationEnabled(rssConf)) {
-      LOG.info("shuffle row-based validation has been enabled.");
+      LOG.info("shuffle integrity validation has been enabled.");
     }
   }
 
@@ -296,7 +298,7 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
       ShuffleReadMetricsReporter metrics) {
     long start = System.currentTimeMillis();
     Pair<Roaring64NavigableMap, Long> info =
-        getExpectedTasksByExecutorId(
+        getExpectedTasksAndRecordsForReader(
             handle.shuffleId(), startPartition, endPartition, startMapIndex, 
endMapIndex);
     Roaring64NavigableMap taskIdBitmap = info.getLeft();
     long expectedRecordsRead = info.getRight();
@@ -480,10 +482,16 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     return rssConf.get(RSS_CLIENT_INTEGRITY_VALIDATION_ENABLED);
   }
 
+  public static boolean isIntegrationValidationFailureAnalysisEnabled(RssConf 
rssConf) {
+    if (!isIntegrityValidationEnabled(rssConf)) {
+      return false;
+    }
+    return rssConf.get(RSS_DATA_INTEGRATION_VALIDATION_ANALYSIS_ENABLED);
+  }
+
   @SuppressFBWarnings("REC_CATCH_EXCEPTION")
-  private Pair<Roaring64NavigableMap, Long> getExpectedTasksByExecutorId(
+  private static Iterator<BlockManagerId> 
getUpstreamBlockManagerIdsForShuffleReader(
       int shuffleId, int startPartition, int endPartition, int startMapIndex, 
int endMapIndex) {
-    Roaring64NavigableMap taskIdBitmap = Roaring64NavigableMap.bitmapOf();
     Iterator<Tuple2<BlockManagerId, Seq<Tuple3<BlockId, Object, Object>>>> 
mapStatusIter = null;
     try {
       // Since Spark 3.1 refactors the interface of getMapSizesByExecutorId,
@@ -550,25 +558,71 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     } catch (Exception e) {
       throw new RssException(e);
     }
-    long expectedRecords = 0;
-    while (mapStatusIter.hasNext()) {
-      Tuple2<BlockManagerId, Seq<Tuple3<BlockId, Object, Object>>> tuple2 = 
mapStatusIter.next();
-      if (!tuple2._1().topologyInfo().isDefined()) {
+    final Iterator<Tuple2<BlockManagerId, Seq<Tuple3<BlockId, Object, 
Object>>>> immutableIter =
+        mapStatusIter;
+    Iterator<BlockManagerId> iter =
+        new Iterator<BlockManagerId>() {
+          @Override
+          public boolean hasNext() {
+            return immutableIter.hasNext();
+          }
+
+          @Override
+          public BlockManagerId next() {
+            return immutableIter.next()._1();
+          }
+        };
+    return iter;
+  }
+
+  public static Map<Long, ShuffleWriteTaskStats> getUpstreamWriteTaskStats(
+      RssConf rssConf,
+      int shuffleId,
+      int startPartition,
+      int endPartition,
+      int startMapIndex,
+      int endMapIndex) {
+    if (!isIntegrityValidationEnabled(rssConf)) {
+      return Collections.emptyMap();
+    }
+    Iterator<BlockManagerId> iter =
+        getUpstreamBlockManagerIdsForShuffleReader(
+            shuffleId, startPartition, endPartition, startMapIndex, 
endMapIndex);
+    Map<Long, ShuffleWriteTaskStats> upstreamStats = new HashMap<>();
+    while (iter.hasNext()) {
+      BlockManagerId blockManagerId = iter.next();
+      ShuffleWriteTaskStats shuffleWriteTaskStats =
+          ShuffleWriteTaskStats.decode(blockManagerId.topologyInfo().get());
+      upstreamStats.put(shuffleWriteTaskStats.getTaskAttemptId(), 
shuffleWriteTaskStats);
+    }
+    return upstreamStats;
+  }
+
+  private Pair<Roaring64NavigableMap, Long> 
getExpectedTasksAndRecordsForReader(
+      int shuffleId, int startPartition, int endPartition, int startMapIndex, 
int endMapIndex) {
+    Iterator<BlockManagerId> iter =
+        getUpstreamBlockManagerIdsForShuffleReader(
+            shuffleId, startPartition, endPartition, startMapIndex, 
endMapIndex);
+    Roaring64NavigableMap taskIdBitmap = Roaring64NavigableMap.bitmapOf();
+    long expectedRecordsRead = 0;
+    while (iter.hasNext()) {
+      BlockManagerId blockManagerId = iter.next();
+      if (!blockManagerId.topologyInfo().isDefined()) {
         throw new RssException("Can't get expected taskAttemptId");
       }
 
-      String raw = tuple2._1().topologyInfo().get();
+      String raw = blockManagerId.topologyInfo().get();
       if (isIntegrityValidationEnabled(rssConf)) {
         ShuffleWriteTaskStats shuffleWriteTaskStats = 
ShuffleWriteTaskStats.decode(raw);
         taskIdBitmap.add(shuffleWriteTaskStats.getTaskAttemptId());
         for (int i = startPartition; i < endPartition; i++) {
-          expectedRecords += shuffleWriteTaskStats.getRecordsWritten(i);
+          expectedRecordsRead += shuffleWriteTaskStats.getRecordsWritten(i);
         }
       } else {
         taskIdBitmap.add(Long.parseLong(raw));
       }
     }
-    return Pair.of(taskIdBitmap, expectedRecords);
+    return Pair.of(taskIdBitmap, expectedRecordsRead);
   }
 
   // This API is only used by Spark3.0 and removed since 3.1,
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index 015b12120..00b76c46c 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -69,6 +69,8 @@ import org.apache.uniffle.common.config.RssClientConf;
 import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.shuffle.ShuffleReadTaskStats;
+import org.apache.uniffle.shuffle.ShuffleWriteTaskStats;
 import org.apache.uniffle.storage.handler.impl.ShuffleServerReadCostTracker;
 
 import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_READ_OVERLAPPING_DECOMPRESSION_ENABLED;
@@ -112,6 +114,8 @@ public class RssShuffleReader<K, C> implements 
ShuffleReader<K, C> {
   private long expectedRecordsRead = 0L;
   private long actualRecordsRead = 0L;
 
+  private Optional<ShuffleReadTaskStats> shuffleReadTaskStats = 
Optional.empty();
+
   public RssShuffleReader(
       int startPartition,
       int endPartition,
@@ -148,6 +152,9 @@ public class RssShuffleReader<K, C> implements 
ShuffleReader<K, C> {
         dataDistributionType,
         allPartitionToServers);
     this.expectedRecordsRead = expectedRecordsRead;
+    if 
(RssShuffleManager.isIntegrationValidationFailureAnalysisEnabled(rssConf)) {
+      this.shuffleReadTaskStats = Optional.of(new ShuffleReadTaskStats());
+    }
   }
 
   public RssShuffleReader(
@@ -376,7 +383,13 @@ public class RssShuffleReader<K, C> implements 
ShuffleReader<K, C> {
             
ShuffleClientFactory.getInstance().createShuffleReadClient(builder);
         RssShuffleDataIterator<K, C> iterator =
             new RssShuffleDataIterator<>(
-                shuffleDependency.serializer(), shuffleReadClient, 
readMetrics, rssConf, codec);
+                shuffleDependency.serializer(),
+                shuffleReadClient,
+                readMetrics,
+                rssConf,
+                codec,
+                shuffleReadTaskStats,
+                partition);
         CompletionIterator<Product2<K, C>, RssShuffleDataIterator<K, C>> 
completionIterator =
             CompletionIterator$.MODULE$.apply(
                 iterator,
@@ -438,11 +451,20 @@ public class RssShuffleReader<K, C> implements 
ShuffleReader<K, C> {
     if (RssShuffleManager.isIntegrityValidationEnabled(rssConf)
         && expectedRecordsRead > 0
         && (expectedRecordsRead != actualRecordsRead)) {
+      // dig to analyze the missing records from the upstream map id
+      if (shuffleReadTaskStats.isPresent()) {
+        ShuffleReadTaskStats readTaskStats = shuffleReadTaskStats.get();
+        Map<Long, ShuffleWriteTaskStats> upstreamWriteTaskStats =
+            RssShuffleManager.getUpstreamWriteTaskStats(
+                rssConf, shuffleId, startPartition, endPartition, 
mapStartIndex, mapEndIndex);
+        readTaskStats.diff(upstreamWriteTaskStats, startPartition, 
endPartition);
+      }
       throw new RssException(
-          "Unexpected read records. expected: "
+          "Inconsistent number of records: "
               + expectedRecordsRead
-              + ", actual: "
-              + actualRecordsRead);
+              + " written, "
+              + actualRecordsRead
+              + " read");
     }
   }
 
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index f26246c1a..10c939baa 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -83,6 +83,7 @@ import org.apache.uniffle.common.ReceivingFailureServer;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.config.RssClientConf;
+import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.exception.RssSendFailedException;
 import org.apache.uniffle.common.exception.RssWaitFailedException;
@@ -93,6 +94,7 @@ import org.apache.uniffle.storage.util.StorageType;
 import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_CLIENT_MAP_SIDE_COMBINE_ENABLED;
 import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_PARTITION_REASSIGN_BLOCK_RETRY_MAX_TIMES;
 import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED;
+import static org.apache.spark.shuffle.RssSparkConfig.toRssConf;
 
 public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
 
@@ -125,6 +127,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private final Set<Long> blockIds = Sets.newConcurrentHashSet();
   private TaskContext taskContext;
   private SparkConf sparkConf;
+  private RssConf rssConf;
   private boolean blockFailSentRetryEnabled;
   private int blockFailSentRetryMaxTimes = 1;
 
@@ -229,21 +232,22 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     this.shuffleHandleInfo = shuffleHandleInfo;
     this.taskContext = context;
     this.sparkConf = sparkConf;
+    this.rssConf = toRssConf(sparkConf);
     this.managerClientSupplier = managerClientSupplier;
     this.blockFailSentRetryEnabled =
         sparkConf.getBoolean(
             RssSparkConfig.SPARK_RSS_CONFIG_PREFIX
                 + RssClientConf.RSS_CLIENT_REASSIGN_ENABLED.key(),
             RssClientConf.RSS_CLIENT_REASSIGN_ENABLED.defaultValue());
-    this.blockFailSentRetryMaxTimes =
-        
RssSparkConfig.toRssConf(sparkConf).get(RSS_PARTITION_REASSIGN_BLOCK_RETRY_MAX_TIMES);
-    this.enableWriteFailureRetry =
-        
RssSparkConfig.toRssConf(sparkConf).get(RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED);
+    this.blockFailSentRetryMaxTimes = 
rssConf.get(RSS_PARTITION_REASSIGN_BLOCK_RETRY_MAX_TIMES);
+    this.enableWriteFailureRetry = 
rssConf.get(RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED);
     this.recordReportFailedShuffleservers = Sets.newConcurrentHashSet();
 
-    if 
(RssShuffleManager.isIntegrityValidationEnabled(RssSparkConfig.toRssConf(sparkConf)))
 {
+    if (RssShuffleManager.isIntegrityValidationEnabled(rssConf)) {
       this.shuffleTaskStats =
-          Optional.of(new ShuffleWriteTaskStats(partitioner.numPartitions(), 
taskAttemptId));
+          Optional.of(
+              new ShuffleWriteTaskStats(
+                  partitioner.numPartitions(), taskAttemptId, 
taskContext.taskAttemptId()));
     }
   }
 
@@ -376,9 +380,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       if (shuffleBlockInfos != null && !shuffleBlockInfos.isEmpty()) {
         processShuffleBlockInfos(shuffleBlockInfos);
       }
-      if (shuffleTaskStats.isPresent()) {
-        shuffleTaskStats.get().incPartitionRecord(partition);
-      }
+      shuffleTaskStats.ifPresent(x -> x.incPartitionRecord(partition));
     }
     final long start = System.currentTimeMillis();
     shuffleBlockInfos = bufferManager.clear(1.0);
@@ -477,8 +479,10 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
             long blockId = sbi.getBlockId();
             // add blockId to set, check if it is sent later
             blockIds.add(blockId);
-            // update [partition, blockIds], it will be sent to shuffle server
             int partitionId = sbi.getPartitionId();
+            // record blocks number for per-partition
+            shuffleTaskStats.ifPresent(x -> x.incPartitionBlock(partitionId));
+            // update [partition, blockIds], it will be sent to shuffle server
             sbi.getShuffleServerInfos()
                 .forEach(
                     shuffleServerInfo -> {
@@ -945,6 +949,16 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
             reportDuration);
         
shuffleWriteMetrics.incWriteTime(TimeUnit.MILLISECONDS.toNanos(reportDuration));
 
+        if 
(RssShuffleManager.isIntegrationValidationFailureAnalysisEnabled(rssConf)) {
+          shuffleTaskStats.ifPresent(x -> x.log());
+        }
+
+        long[] partitionLens = partitionLengthStatistic.toArray();
+
+        if (shuffleTaskStats.isPresent()) {
+          shuffleTaskStats.get().check(partitionLens);
+        }
+
         // todo: we can replace the dummy host and port with the real shuffle 
server which we prefer
         // to read
         final BlockManagerId blockManagerId =
@@ -956,8 +970,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
                     shuffleTaskStats.isPresent()
                         ? shuffleTaskStats.get().encode()
                         : Long.toString(taskAttemptId)));
-        MapStatus mapStatus =
-            MapStatus.apply(blockManagerId, 
partitionLengthStatistic.toArray(), taskAttemptId);
+        MapStatus mapStatus = MapStatus.apply(blockManagerId, partitionLens, 
taskAttemptId);
         return Option.apply(mapStatus);
       } else {
         return Option.empty();
diff --git 
a/client/src/main/java/org/apache/uniffle/client/impl/DecompressionWorker.java 
b/client/src/main/java/org/apache/uniffle/client/impl/DecompressionWorker.java
index 41d29511f..db6dc8b48 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/impl/DecompressionWorker.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/impl/DecompressionWorker.java
@@ -101,7 +101,10 @@ public class DecompressionWorker {
           tasks.computeIfAbsent(batchIndex, k -> new ConcurrentHashMap<>());
       blocks.put(
           index++,
-          new DecompressedShuffleBlock(f, waitMillis -> 
this.waitMillis.addAndGet(waitMillis)));
+          new DecompressedShuffleBlock(
+              f,
+              waitMillis -> this.waitMillis.addAndGet(waitMillis),
+              bufferSegment.getTaskAttemptId()));
     }
   }
 
diff --git 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
index 8834eaef2..ba21e2146 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
@@ -312,7 +312,8 @@ public class ShuffleReadClientImpl implements 
ShuffleReadClient {
           ByteBuffer compressedBuffer = readBuffer.duplicate();
           compressedBuffer.position(bs.getOffset());
           compressedBuffer.limit(bs.getOffset() + bs.getLength());
-          return new CompressedShuffleBlock(compressedBuffer, 
bs.getUncompressLength());
+          return new CompressedShuffleBlock(
+              compressedBuffer, bs.getUncompressLength(), 
bs.getTaskAttemptId());
         } else {
           DecompressedShuffleBlock block = decompressionWorker.get(batchIndex 
- 1, segmentIndex++);
           if (block == null) {
diff --git 
a/client/src/main/java/org/apache/uniffle/client/response/CompressedShuffleBlock.java
 
b/client/src/main/java/org/apache/uniffle/client/response/CompressedShuffleBlock.java
index 3ae949286..355c024f7 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/response/CompressedShuffleBlock.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/response/CompressedShuffleBlock.java
@@ -19,11 +19,16 @@ package org.apache.uniffle.client.response;
 
 import java.nio.ByteBuffer;
 
-public class CompressedShuffleBlock implements ShuffleBlock {
+public class CompressedShuffleBlock extends ShuffleBlock {
   private ByteBuffer byteBuffer;
   private int uncompressLength;
 
   public CompressedShuffleBlock(ByteBuffer byteBuffer, int uncompressLength) {
+    this(byteBuffer, uncompressLength, -1);
+  }
+
+  public CompressedShuffleBlock(ByteBuffer byteBuffer, int uncompressLength, 
long taskAttemptId) {
+    super(taskAttemptId);
     this.byteBuffer = byteBuffer;
     this.uncompressLength = uncompressLength;
   }
diff --git 
a/client/src/main/java/org/apache/uniffle/client/response/DecompressedShuffleBlock.java
 
b/client/src/main/java/org/apache/uniffle/client/response/DecompressedShuffleBlock.java
index 0462bb768..22df42a2d 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/response/DecompressedShuffleBlock.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/response/DecompressedShuffleBlock.java
@@ -23,11 +23,13 @@ import java.util.function.Consumer;
 
 import org.apache.uniffle.common.exception.RssException;
 
-public class DecompressedShuffleBlock implements ShuffleBlock {
+public class DecompressedShuffleBlock extends ShuffleBlock {
   private CompletableFuture<ByteBuffer> f;
   private Consumer<Long> waitMillisCallback;
 
-  public DecompressedShuffleBlock(CompletableFuture<ByteBuffer> f, 
Consumer<Long> consumer) {
+  public DecompressedShuffleBlock(
+      CompletableFuture<ByteBuffer> f, Consumer<Long> consumer, long 
taskAttemptId) {
+    super(taskAttemptId);
     this.f = f;
     this.waitMillisCallback = consumer;
   }
diff --git 
a/client/src/main/java/org/apache/uniffle/client/response/ShuffleBlock.java 
b/client/src/main/java/org/apache/uniffle/client/response/ShuffleBlock.java
index 0fca1ad64..e240f4f3d 100644
--- a/client/src/main/java/org/apache/uniffle/client/response/ShuffleBlock.java
+++ b/client/src/main/java/org/apache/uniffle/client/response/ShuffleBlock.java
@@ -19,9 +19,18 @@ package org.apache.uniffle.client.response;
 
 import java.nio.ByteBuffer;
 
-public interface ShuffleBlock {
+public abstract class ShuffleBlock {
+  private long taskAttemptId;
 
-  int getUncompressLength();
+  public ShuffleBlock(long taskAttemptId) {
+    this.taskAttemptId = taskAttemptId;
+  }
 
-  ByteBuffer getByteBuffer();
+  public abstract int getUncompressLength();
+
+  public abstract ByteBuffer getByteBuffer();
+
+  public long getTaskAttemptId() {
+    return taskAttemptId;
+  }
 }


Reply via email to