This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 6aef84634 [#2652] feat(spark): Introduce partition records number 
check to ensure data consistency (#2653)
6aef84634 is described below

commit 6aef84634d7abfed1f65a4e8c37cfa95b3fd7591
Author: Junfan Zhang <[email protected]>
AuthorDate: Wed Oct 29 11:34:19 2025 +0800

    [#2652] feat(spark): Introduce partition records number check to ensure 
data consistency (#2653)
    
    ### What changes were proposed in this pull request?
    
    This PR ensures end-to-end data consistency by verifying the record counts 
of each partition.
    In this initial step, ShuffleWriteTaskStats is introduced to store record 
counts for validation.
    In the next phase, this mechanism will be extended to support row-level 
checksums and block count verification.
    
    This patch has only been validated on Spark 3.5.0, and the feature is 
enabled only for Spark versions of at least `3.5.0`.
    
    ### Why are the changes needed?
    
    For #2652.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. A new client-side configuration option 
`spark.rss.client.integrityValidation.enabled` is introduced, defaulting to 
`false` (the validation is disabled unless explicitly enabled).
    
    ### How was this patch tested?
    
    Unit tests.
---
 .../org/apache/spark/shuffle/RssSparkConfig.java   |  6 ++
 .../uniffle/shuffle/ShuffleWriteTaskStats.java     | 71 ++++++++++++++++++++++
 .../shuffle/ShuffleWriteTaskStatsTest.java}        | 24 ++++++--
 .../apache/spark/shuffle/RssShuffleManager.java    | 48 ++++++++++++---
 .../apache/spark/shuffle/Spark3VersionUtils.java   | 14 +++++
 .../spark/shuffle/reader/RssShuffleReader.java     | 61 ++++++++++++++++++-
 .../spark/shuffle/writer/RssShuffleWriter.java     | 17 +++++-
 .../org/apache/uniffle/test/AQESkewedJoinTest.java |  1 +
 .../test/GetShuffleReportForMultiPartTest.java     |  6 +-
 .../apache/uniffle/test/MapSideCombineTest.java    |  1 +
 10 files changed, 231 insertions(+), 18 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index 8dc17b70f..fd330b359 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -39,6 +39,12 @@ import org.apache.uniffle.common.config.RssConf;
 
 public class RssSparkConfig {
 
+  public static final ConfigOption<Boolean> 
RSS_CLIENT_INTEGRITY_VALIDATION_ENABLED =
+      ConfigOptions.key("rss.client.integrityValidation.enabled")
+          .booleanType()
+          .defaultValue(false)
+          .withDescription("Whether or not to enable shuffle data integrity 
validation mechanism");
+
   public static final ConfigOption<Boolean> 
RSS_READ_SHUFFLE_HANDLE_CACHE_ENABLED =
       ConfigOptions.key("rss.client.read.shuffleHandleCacheEnabled")
           .booleanType()
diff --git 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStats.java
 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStats.java
new file mode 100644
index 000000000..3625f9654
--- /dev/null
+++ 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStats.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.shuffle;
+
+import java.nio.ByteBuffer;
+
+import static java.nio.charset.StandardCharsets.ISO_8859_1;
+
+/**
+ * ShuffleWriteTaskStats stores statistics for a shuffle write task attempt, 
including the task
+ * attempt ID and the number of records written for each partition.
+ */
+public class ShuffleWriteTaskStats {
+  private long taskAttemptId;
+  private long[] partitionRecordsWritten;
+
+  public ShuffleWriteTaskStats(int partitions, long taskAttemptId) {
+    this.partitionRecordsWritten = new long[partitions];
+    this.taskAttemptId = taskAttemptId;
+  }
+
+  public long getRecordsWritten(int partitionId) {
+    return partitionRecordsWritten[partitionId];
+  }
+
+  public void incPartitionRecord(int partitionId) {
+    partitionRecordsWritten[partitionId]++;
+  }
+
+  public long getTaskAttemptId() {
+    return taskAttemptId;
+  }
+
+  public String encode() {
+    int partitions = partitionRecordsWritten.length;
+    ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES + Integer.BYTES + 
partitions * Long.BYTES);
+    buffer.putLong(taskAttemptId);
+    buffer.putInt(partitions);
+    for (long records : partitionRecordsWritten) {
+      buffer.putLong(records);
+    }
+    return new String(buffer.array(), ISO_8859_1);
+  }
+
+  public static ShuffleWriteTaskStats decode(String raw) {
+    byte[] bytes = raw.getBytes(ISO_8859_1);
+    ByteBuffer buffer = ByteBuffer.wrap(bytes);
+    long taskAttemptId = buffer.getLong();
+    int partitions = buffer.getInt();
+    ShuffleWriteTaskStats stats = new ShuffleWriteTaskStats(partitions, 
taskAttemptId);
+    for (int i = 0; i < partitions; i++) {
+      stats.partitionRecordsWritten[i] = buffer.getLong();
+    }
+    return stats;
+  }
+}
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/Spark3VersionUtils.java
 
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStatsTest.java
similarity index 54%
copy from 
client-spark/spark3/src/main/java/org/apache/spark/shuffle/Spark3VersionUtils.java
copy to 
client-spark/common/src/test/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStatsTest.java
index 76f464eec..3a70ae408 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/Spark3VersionUtils.java
+++ 
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStatsTest.java
@@ -15,14 +15,26 @@
  * limitations under the License.
  */
 
-package org.apache.spark.shuffle;
+package org.apache.uniffle.shuffle;
 
-import org.apache.spark.package$;
+import org.junit.jupiter.api.Test;
 
-public class Spark3VersionUtils extends SparkVersionUtils {
-  public static final String SPARK_VERSION_SHORT = 
package$.MODULE$.SPARK_VERSION_SHORT();
+import static org.junit.jupiter.api.Assertions.assertEquals;
 
-  public static boolean isSpark320() {
-    return SPARK_VERSION_SHORT.equals("3.2.0");
+public class ShuffleWriteTaskStatsTest {
+
+  @Test
+  public void testValidValidationInfo() {
+    long taskAttemptId = 12345L;
+    ShuffleWriteTaskStats stats = new ShuffleWriteTaskStats(2, taskAttemptId);
+    stats.incPartitionRecord(0);
+    stats.incPartitionRecord(1);
+
+    String encoded = stats.encode();
+    ShuffleWriteTaskStats decoded = ShuffleWriteTaskStats.decode(encoded);
+
+    assertEquals(taskAttemptId, decoded.getTaskAttemptId());
+    assertEquals(1, decoded.getRecordsWritten(0));
+    assertEquals(1, decoded.getRecordsWritten(1));
   }
 }
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index e4b180af1..3d1a69bb6 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -34,6 +34,7 @@ import scala.collection.Seq;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Sets;
 import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.spark.MapOutputTracker;
 import org.apache.spark.ShuffleDependency;
@@ -72,14 +73,20 @@ import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.exception.RssFetchFailedException;
 import org.apache.uniffle.common.util.RssUtils;
 import org.apache.uniffle.shuffle.RssShuffleClientFactory;
+import org.apache.uniffle.shuffle.ShuffleWriteTaskStats;
 import org.apache.uniffle.shuffle.manager.RssShuffleManagerBase;
 
+import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_CLIENT_INTEGRITY_VALIDATION_ENABLED;
+
 public class RssShuffleManager extends RssShuffleManagerBase {
   private static final Logger LOG = 
LoggerFactory.getLogger(RssShuffleManager.class);
 
   public RssShuffleManager(SparkConf conf, boolean isDriver) {
     super(conf, isDriver);
     this.dataDistributionType = getDataDistributionType(sparkConf);
+    if (isIntegrityValidationEnabled(rssConf)) {
+      LOG.info("shuffle row-based validation has been enabled.");
+    }
   }
 
   // For testing only
@@ -288,9 +295,11 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
       TaskContext context,
       ShuffleReadMetricsReporter metrics) {
     long start = System.currentTimeMillis();
-    Roaring64NavigableMap taskIdBitmap =
+    Pair<Roaring64NavigableMap, Long> info =
         getExpectedTasksByExecutorId(
             handle.shuffleId(), startPartition, endPartition, startMapIndex, 
endMapIndex);
+    Roaring64NavigableMap taskIdBitmap = info.getLeft();
+    long expectedRecordsRead = info.getRight();
     LOG.info(
         "Get taskId cost "
             + (System.currentTimeMillis() - start)
@@ -311,7 +320,8 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
         endPartition,
         context,
         metrics,
-        taskIdBitmap);
+        taskIdBitmap,
+        expectedRecordsRead);
   }
 
   // The interface is used for compatibility with spark 3.0.1
@@ -347,7 +357,8 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
         endPartition,
         context,
         metrics,
-        taskIdBitmap);
+        taskIdBitmap,
+        -1);
   }
 
   public <K, C> ShuffleReader<K, C> getReaderImpl(
@@ -358,7 +369,8 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
       int endPartition,
       TaskContext context,
       ShuffleReadMetricsReporter metrics,
-      Roaring64NavigableMap taskIdBitmap) {
+      Roaring64NavigableMap taskIdBitmap,
+      long expectedRecordsRead) {
     if (!(handle instanceof RssShuffleHandle)) {
       throw new RssException("Unexpected ShuffleHandle:" + 
handle.getClass().getName());
     }
@@ -443,7 +455,8 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
         managerClientSupplier,
         RssSparkConfig.toRssConf(sparkConf),
         dataDistributionType,
-        shuffleHandleInfo.getAllPartitionServersForReader());
+        shuffleHandleInfo.getAllPartitionServersForReader(),
+        expectedRecordsRead);
   }
 
   private Map<ShuffleServerInfo, Set<Integer>> getPartitionDataServers(
@@ -459,8 +472,16 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     return serverToPartitions;
   }
 
+  public static boolean isIntegrityValidationEnabled(RssConf rssConf) {
+    assert rssConf != null;
+    if (!Spark3VersionUtils.isSparkVersionAtLeast("3.5.0")) {
+      return false;
+    }
+    return rssConf.get(RSS_CLIENT_INTEGRITY_VALIDATION_ENABLED);
+  }
+
   @SuppressFBWarnings("REC_CATCH_EXCEPTION")
-  private Roaring64NavigableMap getExpectedTasksByExecutorId(
+  private Pair<Roaring64NavigableMap, Long> getExpectedTasksByExecutorId(
       int shuffleId, int startPartition, int endPartition, int startMapIndex, 
int endMapIndex) {
     Roaring64NavigableMap taskIdBitmap = Roaring64NavigableMap.bitmapOf();
     Iterator<Tuple2<BlockManagerId, Seq<Tuple3<BlockId, Object, Object>>>> 
mapStatusIter = null;
@@ -529,14 +550,25 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     } catch (Exception e) {
       throw new RssException(e);
     }
+    long expectedRecords = 0;
     while (mapStatusIter.hasNext()) {
       Tuple2<BlockManagerId, Seq<Tuple3<BlockId, Object, Object>>> tuple2 = 
mapStatusIter.next();
       if (!tuple2._1().topologyInfo().isDefined()) {
         throw new RssException("Can't get expected taskAttemptId");
       }
-      taskIdBitmap.add(Long.parseLong(tuple2._1().topologyInfo().get()));
+
+      String raw = tuple2._1().topologyInfo().get();
+      if (isIntegrityValidationEnabled(rssConf)) {
+        ShuffleWriteTaskStats shuffleWriteTaskStats = 
ShuffleWriteTaskStats.decode(raw);
+        taskIdBitmap.add(shuffleWriteTaskStats.getTaskAttemptId());
+        for (int i = startPartition; i < endPartition; i++) {
+          expectedRecords += shuffleWriteTaskStats.getRecordsWritten(i);
+        }
+      } else {
+        taskIdBitmap.add(Long.parseLong(raw));
+      }
     }
-    return taskIdBitmap;
+    return Pair.of(taskIdBitmap, expectedRecords);
   }
 
   // This API is only used by Spark3.0 and removed since 3.1,
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/Spark3VersionUtils.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/Spark3VersionUtils.java
index 76f464eec..7a9a42a0a 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/Spark3VersionUtils.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/Spark3VersionUtils.java
@@ -18,6 +18,7 @@
 package org.apache.spark.shuffle;
 
 import org.apache.spark.package$;
+import org.apache.spark.util.VersionUtils;
 
 public class Spark3VersionUtils extends SparkVersionUtils {
   public static final String SPARK_VERSION_SHORT = 
package$.MODULE$.SPARK_VERSION_SHORT();
@@ -25,4 +26,17 @@ public class Spark3VersionUtils extends SparkVersionUtils {
   public static boolean isSpark320() {
     return SPARK_VERSION_SHORT.equals("3.2.0");
   }
+
+  public static boolean isSparkVersionAtLeast(String target) {
+    int targetMajor = VersionUtils.majorVersion(target);
+    int targetMinor = VersionUtils.minorVersion(target);
+
+    if (MAJOR_VERSION > targetMajor) {
+      return true;
+    } else if (MAJOR_VERSION == targetMajor) {
+      return MINOR_VERSION >= targetMinor;
+    } else {
+      return false;
+    }
+  }
 }
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index b0a25efd8..015b12120 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -45,6 +45,7 @@ import org.apache.spark.executor.ShuffleReadMetrics;
 import org.apache.spark.serializer.Serializer;
 import org.apache.spark.shuffle.FunctionUtils;
 import org.apache.spark.shuffle.RssShuffleHandle;
+import org.apache.spark.shuffle.RssShuffleManager;
 import org.apache.spark.shuffle.RssSparkConfig;
 import org.apache.spark.shuffle.ShuffleReader;
 import org.apache.spark.util.CompletionIterator;
@@ -66,6 +67,7 @@ import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.compression.Codec;
 import org.apache.uniffle.common.config.RssClientConf;
 import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.rpc.StatusCode;
 import org.apache.uniffle.storage.handler.impl.ShuffleServerReadCostTracker;
 
@@ -107,6 +109,47 @@ public class RssShuffleReader<K, C> implements 
ShuffleReader<K, C> {
 
   private ShuffleReadTimes shuffleReadTimes = new ShuffleReadTimes();
 
+  private long expectedRecordsRead = 0L;
+  private long actualRecordsRead = 0L;
+
+  public RssShuffleReader(
+      int startPartition,
+      int endPartition,
+      int mapStartIndex,
+      int mapEndIndex,
+      TaskContext context,
+      RssShuffleHandle<K, ?, C> rssShuffleHandle,
+      String basePath,
+      Configuration hadoopConf,
+      int partitionNum,
+      Map<Integer, Roaring64NavigableMap> partitionToExpectBlocks,
+      Roaring64NavigableMap taskIdBitmap,
+      ShuffleReadMetrics readMetrics,
+      Supplier<ShuffleManagerClient> managerClientSupplier,
+      RssConf rssConf,
+      ShuffleDataDistributionType dataDistributionType,
+      Map<Integer, List<ShuffleServerInfo>> allPartitionToServers,
+      long expectedRecordsRead) {
+    this(
+        startPartition,
+        endPartition,
+        mapStartIndex,
+        mapEndIndex,
+        context,
+        rssShuffleHandle,
+        basePath,
+        hadoopConf,
+        partitionNum,
+        partitionToExpectBlocks,
+        taskIdBitmap,
+        readMetrics,
+        managerClientSupplier,
+        rssConf,
+        dataDistributionType,
+        allPartitionToServers);
+    this.expectedRecordsRead = expectedRecordsRead;
+  }
+
   public RssShuffleReader(
       int startPartition,
       int endPartition,
@@ -247,7 +290,9 @@ public class RssShuffleReader<K, C> implements 
ShuffleReader<K, C> {
         + mapStartIndex
         + ", "
         + mapEndIndex
-        + ")";
+        + "]"
+        + ", expected records: "
+        + expectedRecordsRead;
   }
 
   @VisibleForTesting
@@ -365,6 +410,7 @@ public class RssShuffleReader<K, C> implements 
ShuffleReader<K, C> {
         }
         while (!dataIterator.hasNext()) {
           if (!iterator.hasNext()) {
+            validate();
             postShuffleReadMetricsToDriver();
             return false;
           }
@@ -383,10 +429,23 @@ public class RssShuffleReader<K, C> implements 
ShuffleReader<K, C> {
     @Override
     public Product2<K, C> next() {
       Product2<K, C> result = dataIterator.next();
+      actualRecordsRead += 1;
       return result;
     }
   }
 
+  private void validate() {
+    if (RssShuffleManager.isIntegrityValidationEnabled(rssConf)
+        && expectedRecordsRead > 0
+        && (expectedRecordsRead != actualRecordsRead)) {
+      throw new RssException(
+          "Unexpected read records. expected: "
+              + expectedRecordsRead
+              + ", actual: "
+              + actualRecordsRead);
+    }
+  }
+
   private void postShuffleReadMetricsToDriver() {
     if (managerClientSupplier != null) {
       ShuffleManagerClient client = managerClientSupplier.get();
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index e6c48158e..f26246c1a 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -87,6 +87,7 @@ import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.exception.RssSendFailedException;
 import org.apache.uniffle.common.exception.RssWaitFailedException;
 import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.shuffle.ShuffleWriteTaskStats;
 import org.apache.uniffle.storage.util.StorageType;
 
 import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_CLIENT_MAP_SIDE_COMBINE_ENABLED;
@@ -150,6 +151,8 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private boolean isShuffleWriteFailed = false;
   private Optional<String> shuffleWriteFailureReason = Optional.empty();
 
+  private Optional<ShuffleWriteTaskStats> shuffleTaskStats = Optional.empty();
+
   // Only for tests
   @VisibleForTesting
   public RssShuffleWriter(
@@ -237,6 +240,11 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     this.enableWriteFailureRetry =
         
RssSparkConfig.toRssConf(sparkConf).get(RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED);
     this.recordReportFailedShuffleservers = Sets.newConcurrentHashSet();
+
+    if 
(RssShuffleManager.isIntegrityValidationEnabled(RssSparkConfig.toRssConf(sparkConf)))
 {
+      this.shuffleTaskStats =
+          Optional.of(new ShuffleWriteTaskStats(partitioner.numPartitions(), 
taskAttemptId));
+    }
   }
 
   // Gluten needs this method
@@ -368,6 +376,9 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       if (shuffleBlockInfos != null && !shuffleBlockInfos.isEmpty()) {
         processShuffleBlockInfos(shuffleBlockInfos);
       }
+      if (shuffleTaskStats.isPresent()) {
+        shuffleTaskStats.get().incPartitionRecord(partition);
+      }
     }
     final long start = System.currentTimeMillis();
     shuffleBlockInfos = bufferManager.clear(1.0);
@@ -933,6 +944,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
             bitmapSplitNum,
             reportDuration);
         
shuffleWriteMetrics.incWriteTime(TimeUnit.MILLISECONDS.toNanos(reportDuration));
+
         // todo: we can replace the dummy host and port with the real shuffle 
server which we prefer
         // to read
         final BlockManagerId blockManagerId =
@@ -940,7 +952,10 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
                 appId + "_" + taskId,
                 DUMMY_HOST,
                 DUMMY_PORT,
-                Option.apply(Long.toString(taskAttemptId)));
+                Option.apply(
+                    shuffleTaskStats.isPresent()
+                        ? shuffleTaskStats.get().encode()
+                        : Long.toString(taskAttemptId)));
         MapStatus mapStatus =
             MapStatus.apply(blockManagerId, 
partitionLengthStatistic.toArray(), taskAttemptId);
         return Option.apply(mapStatus);
diff --git 
a/integration-test/spark3/src/test/java/org/apache/uniffle/test/AQESkewedJoinTest.java
 
b/integration-test/spark3/src/test/java/org/apache/uniffle/test/AQESkewedJoinTest.java
index 78e936eb0..3cdf14fba 100644
--- 
a/integration-test/spark3/src/test/java/org/apache/uniffle/test/AQESkewedJoinTest.java
+++ 
b/integration-test/spark3/src/test/java/org/apache/uniffle/test/AQESkewedJoinTest.java
@@ -68,6 +68,7 @@ public class AQESkewedJoinTest extends 
SparkIntegrationTestBase {
   public void updateSparkConfCustomer(SparkConf sparkConf) {
     sparkConf.set(RssSparkConfig.RSS_STORAGE_TYPE.key(), 
StorageType.MEMORY_LOCALFILE_HDFS.name());
     sparkConf.set(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), HDFS_URI + 
"rss/test");
+    sparkConf.set("spark." + 
RssSparkConfig.RSS_CLIENT_INTEGRITY_VALIDATION_ENABLED.key(), "true");
   }
 
   @Test
diff --git 
a/integration-test/spark3/src/test/java/org/apache/uniffle/test/GetShuffleReportForMultiPartTest.java
 
b/integration-test/spark3/src/test/java/org/apache/uniffle/test/GetShuffleReportForMultiPartTest.java
index be7200273..5c1b4d7f2 100644
--- 
a/integration-test/spark3/src/test/java/org/apache/uniffle/test/GetShuffleReportForMultiPartTest.java
+++ 
b/integration-test/spark3/src/test/java/org/apache/uniffle/test/GetShuffleReportForMultiPartTest.java
@@ -247,7 +247,8 @@ public class GetShuffleReportForMultiPartTest extends 
SparkIntegrationTestBase {
         int endPartition,
         TaskContext context,
         ShuffleReadMetricsReporter metrics,
-        Roaring64NavigableMap taskIdBitmap) {
+        Roaring64NavigableMap taskIdBitmap,
+        long expectedRecordsRead) {
       int shuffleId = handle.shuffleId();
       RssShuffleHandle<?, ?, ?> rssShuffleHandle = (RssShuffleHandle<?, ?, ?>) 
handle;
       Map<Integer, List<ShuffleServerInfo>> allPartitionToServers =
@@ -268,7 +269,8 @@ public class GetShuffleReportForMultiPartTest extends 
SparkIntegrationTestBase {
           endPartition,
           context,
           metrics,
-          taskIdBitmap);
+          taskIdBitmap,
+          expectedRecordsRead);
     }
 
     public Map<Integer, AtomicInteger> getShuffleIdToPartitionNum() {
diff --git 
a/integration-test/spark3/src/test/java/org/apache/uniffle/test/MapSideCombineTest.java
 
b/integration-test/spark3/src/test/java/org/apache/uniffle/test/MapSideCombineTest.java
index dcca62a8d..46f96eccf 100644
--- 
a/integration-test/spark3/src/test/java/org/apache/uniffle/test/MapSideCombineTest.java
+++ 
b/integration-test/spark3/src/test/java/org/apache/uniffle/test/MapSideCombineTest.java
@@ -57,6 +57,7 @@ public class MapSideCombineTest extends 
SparkIntegrationTestBase {
     sparkConf.set(RssSparkConfig.RSS_STORAGE_TYPE.key(), 
StorageType.MEMORY_LOCALFILE_HDFS.name());
     sparkConf.set(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), HDFS_URI + 
"rss/test");
     sparkConf.set("spark." + 
RssSparkConfig.RSS_CLIENT_MAP_SIDE_COMBINE_ENABLED.key(), "true");
+    sparkConf.set("spark." + 
RssSparkConfig.RSS_CLIENT_INTEGRITY_VALIDATION_ENABLED.key(), "true");
   }
 
   @Test

Reply via email to