This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new a37936f2b [#2652] feat(spark): Add compression for task write stats 
(#2666)
a37936f2b is described below

commit a37936f2bc34f6c38082ea85559ea70ebfdc0f26
Author: Junfan Zhang <[email protected]>
AuthorDate: Fri Nov 7 13:56:36 2025 +0800

    [#2652] feat(spark): Add compression for task write stats (#2666)
    
    ### What changes were proposed in this pull request?
    
    1. Add compression for task write stats
    2. Add an optional block-number check mechanism (disabled by default)
    
    ### Why are the changes needed?
    
    To reduce the size of the task write stats
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. A new configuration option is introduced (disabled by default):
    `spark.rss.client.integrityValidation.blockNumberCheckEnabled=false`
    
    ### How was this patch tested?
    
    Unit tests.
---
 .../org/apache/spark/shuffle/RssSparkConfig.java   |  15 +++
 .../uniffle/shuffle/ShuffleReadTaskStats.java      |   4 +-
 .../uniffle/shuffle/ShuffleWriteTaskStats.java     | 136 +++++++++++++++------
 .../uniffle/shuffle/ShuffleWriteTaskStatsTest.java |  32 +++--
 .../apache/spark/shuffle/RssShuffleManager.java    |   4 +-
 .../spark/shuffle/writer/RssShuffleWriter.java     |   5 +-
 .../apache/uniffle/common/compression/Codec.java   |   4 +
 7 files changed, 150 insertions(+), 50 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index 5eef97bfa..e83bc437f 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -31,6 +31,7 @@ import org.apache.spark.internal.config.ConfigEntry;
 import org.apache.spark.internal.config.TypedConfigBuilder;
 
 import org.apache.uniffle.client.util.RssClientConfig;
+import org.apache.uniffle.common.compression.Codec;
 import org.apache.uniffle.common.config.ConfigOption;
 import org.apache.uniffle.common.config.ConfigOptions;
 import org.apache.uniffle.common.config.ConfigUtils;
@@ -52,6 +53,20 @@ public class RssSparkConfig {
           .defaultValue(false)
           .withDescription("Whether or not to enable validation failure 
analysis");
 
+  public static final ConfigOption<Boolean>
+      RSS_DATA_INTEGRITY_VALIDATION_BLOCK_NUMBER_CHECK_ENABLED =
+          
ConfigOptions.key("rss.client.integrityValidation.blockNumberCheckEnabled")
+              .booleanType()
+              .defaultValue(false)
+              .withDescription("Whether or not to enable validation block 
number check");
+
+  public static final ConfigOption<Codec.Type>
+      RSS_CLIENT_INTEGRITY_VALIDATION_STATS_COMPRESSION_TYPE =
+          
ConfigOptions.key("rss.client.integrityValidation.statsCompressionType")
+              .enumType(Codec.Type.class)
+              .defaultValue(Codec.Type.ZSTD)
+              .withDescription("stats compression type");
+
   public static final ConfigOption<Boolean> 
RSS_READ_SHUFFLE_HANDLE_CACHE_ENABLED =
       ConfigOptions.key("rss.client.read.shuffleHandleCacheEnabled")
           .booleanType()
diff --git 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleReadTaskStats.java
 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleReadTaskStats.java
index d78f24a2f..0898648a5 100644
--- 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleReadTaskStats.java
+++ 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleReadTaskStats.java
@@ -70,8 +70,10 @@ public class ShuffleReadTaskStats {
           continue;
         }
         long recordsUpstream = stats.getRecordsWritten(idx);
+        // If blocksUpstream is less than 0, it indicates that the block 
number check is disabled.
         long blocksUpstream = stats.getBlocksWritten(idx);
-        if (recordsRead != recordsUpstream || blocksRead != blocksUpstream) {
+        if (recordsRead != recordsUpstream
+            || (blocksUpstream >= 0 && blocksRead != blocksUpstream)) {
           infoBuilder.append(idx);
           infoBuilder.append("/");
           infoBuilder.append(stats.getTaskId());
diff --git 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStats.java
 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStats.java
index e86da494b..4f226569d 100644
--- 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStats.java
+++ 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStats.java
@@ -19,13 +19,18 @@ package org.apache.uniffle.shuffle;
 
 import java.nio.ByteBuffer;
 import java.util.Arrays;
+import java.util.Optional;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.uniffle.common.compression.Codec;
+import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
 
 import static java.nio.charset.StandardCharsets.ISO_8859_1;
+import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_CLIENT_INTEGRITY_VALIDATION_STATS_COMPRESSION_TYPE;
+import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_DATA_INTEGRITY_VALIDATION_BLOCK_NUMBER_CHECK_ENABLED;
 
 /**
  * ShuffleWriteTaskStats stores statistics for a shuffle write task attempt, 
including the task
@@ -38,17 +43,31 @@ public class ShuffleWriteTaskStats {
   private long taskId;
   // this is only unique for one stage and defined in uniffle side instead of 
spark
   private long taskAttemptId;
+  private int partitions;
   private long[] partitionRecordsWritten;
   private long[] partitionBlocksWritten;
+  private boolean blockNumberCheckEnabled;
+  private RssConf rssConf;
 
-  public ShuffleWriteTaskStats(int partitions, long taskAttemptId, long 
taskId) {
+  public ShuffleWriteTaskStats(RssConf rssConf, int partitions, long 
taskAttemptId, long taskId) {
     this.partitionRecordsWritten = new long[partitions];
-    this.partitionBlocksWritten = new long[partitions];
+    Arrays.fill(this.partitionRecordsWritten, 0L);
+
+    this.partitions = partitions;
     this.taskAttemptId = taskAttemptId;
     this.taskId = taskId;
+    this.rssConf = rssConf;
+    this.blockNumberCheckEnabled =
+        rssConf.get(RSS_DATA_INTEGRITY_VALIDATION_BLOCK_NUMBER_CHECK_ENABLED);
 
-    Arrays.fill(this.partitionRecordsWritten, 0L);
-    Arrays.fill(this.partitionBlocksWritten, 0L);
+    if (blockNumberCheckEnabled) {
+      this.partitionBlocksWritten = new long[partitions];
+      Arrays.fill(this.partitionBlocksWritten, 0L);
+    }
+  }
+
+  public ShuffleWriteTaskStats(int partitions, long taskAttemptId, long 
taskId) {
+    this(new RssConf(), partitions, taskAttemptId, taskId);
   }
 
   public long getRecordsWritten(int partitionId) {
@@ -60,11 +79,16 @@ public class ShuffleWriteTaskStats {
   }
 
   public void incPartitionBlock(int partitionId) {
-    partitionBlocksWritten[partitionId]++;
+    if (blockNumberCheckEnabled) {
+      partitionBlocksWritten[partitionId]++;
+    }
   }
 
   public long getBlocksWritten(int partitionId) {
-    return partitionBlocksWritten[partitionId];
+    if (blockNumberCheckEnabled) {
+      return partitionBlocksWritten[partitionId];
+    }
+    return -1L;
   }
 
   public long getTaskAttemptId() {
@@ -72,33 +96,72 @@ public class ShuffleWriteTaskStats {
   }
 
   public String encode() {
+    final long start = System.currentTimeMillis();
     int partitions = partitionRecordsWritten.length;
-    ByteBuffer buffer =
-        ByteBuffer.allocate(2 * Long.BYTES + Integer.BYTES + partitions * 
Long.BYTES * 2);
+    int capacity = 2 * Long.BYTES + Integer.BYTES + partitions * Long.BYTES;
+    if (blockNumberCheckEnabled) {
+      capacity += partitions * Long.BYTES;
+    }
+    ByteBuffer buffer = ByteBuffer.allocate(capacity);
     buffer.putLong(taskId);
     buffer.putLong(taskAttemptId);
     buffer.putInt(partitions);
     for (long records : partitionRecordsWritten) {
       buffer.putLong(records);
     }
-    for (long blocks : partitionBlocksWritten) {
-      buffer.putLong(blocks);
+    if (blockNumberCheckEnabled) {
+      for (long blocks : partitionBlocksWritten) {
+        buffer.putLong(blocks);
+      }
+    }
+    Optional<Codec> optionalCodec = getCodec(rssConf);
+    if (optionalCodec.isPresent()) {
+      Codec codec = optionalCodec.get();
+      byte[] compressed = codec.compress(buffer.array());
+      ByteBuffer compositedBuffer = ByteBuffer.allocate(Integer.BYTES + 
compressed.length);
+      compositedBuffer.putInt(capacity);
+      compositedBuffer.put(compressed);
+      LOGGER.info(
+          "Encoded task stats for {} partitions with {} bytes (original: {} 
bytes) in {} ms",
+          partitions,
+          compositedBuffer.capacity(),
+          capacity,
+          System.currentTimeMillis() - start);
+      return new String(compositedBuffer.array(), ISO_8859_1);
+    } else {
+      return new String(buffer.array(), ISO_8859_1);
     }
-    return new String(buffer.array(), ISO_8859_1);
   }
 
-  public static ShuffleWriteTaskStats decode(String raw) {
-    byte[] bytes = raw.getBytes(ISO_8859_1);
-    ByteBuffer buffer = ByteBuffer.wrap(bytes);
-    long taskId = buffer.getLong();
-    long taskAttemptId = buffer.getLong();
-    int partitions = buffer.getInt();
-    ShuffleWriteTaskStats stats = new ShuffleWriteTaskStats(partitions, 
taskAttemptId, taskId);
-    for (int i = 0; i < partitions; i++) {
-      stats.partitionRecordsWritten[i] = buffer.getLong();
+  private static Optional<Codec> getCodec(RssConf rssConf) {
+    return Codec.newInstance(
+        rssConf.get(RSS_CLIENT_INTEGRITY_VALIDATION_STATS_COMPRESSION_TYPE), 
rssConf);
+  }
+
+  public static ShuffleWriteTaskStats decode(RssConf rssConf, String raw) {
+    byte[] rawBytes = raw.getBytes(ISO_8859_1);
+    ByteBuffer outBuffer = ByteBuffer.wrap(rawBytes);
+
+    Optional<Codec> optionalCodec = getCodec(rssConf);
+    if (optionalCodec.isPresent()) {
+      ByteBuffer inBuffer = ByteBuffer.wrap(rawBytes);
+      int capacity = inBuffer.getInt();
+      outBuffer = ByteBuffer.allocate(capacity);
+      optionalCodec.get().decompress(inBuffer, capacity, outBuffer, 0);
     }
+
+    long taskId = outBuffer.getLong();
+    long taskAttemptId = outBuffer.getLong();
+    int partitions = outBuffer.getInt();
+    ShuffleWriteTaskStats stats =
+        new ShuffleWriteTaskStats(rssConf, partitions, taskAttemptId, taskId);
     for (int i = 0; i < partitions; i++) {
-      stats.partitionBlocksWritten[i] = buffer.getLong();
+      stats.partitionRecordsWritten[i] = outBuffer.getLong();
+    }
+    if (rssConf.get(RSS_DATA_INTEGRITY_VALIDATION_BLOCK_NUMBER_CHECK_ENABLED)) 
{
+      for (int i = 0; i < partitions; i++) {
+        stats.partitionBlocksWritten[i] = outBuffer.getLong();
+      }
     }
     return stats;
   }
@@ -111,8 +174,8 @@ public class ShuffleWriteTaskStats {
     StringBuilder infoBuilder = new StringBuilder();
     int partitions = partitionRecordsWritten.length;
     for (int i = 0; i < partitions; i++) {
-      long records = partitionRecordsWritten[i];
-      long blocks = partitionBlocksWritten[i];
+      long records = getRecordsWritten(i);
+      long blocks = getBlocksWritten(i);
       
infoBuilder.append(i).append("/").append(records).append("/").append(blocks).append(",");
     }
     LOGGER.info(
@@ -120,23 +183,20 @@ public class ShuffleWriteTaskStats {
   }
 
   public void check(long[] partitionLens) {
-    int partitions = partitionRecordsWritten.length;
     for (int idx = 0; idx < partitions; idx++) {
-      long records = partitionRecordsWritten[idx];
-      long blocks = partitionBlocksWritten[idx];
+      long records = getRecordsWritten(idx);
+      long blocks = getBlocksWritten(idx);
       long length = partitionLens[idx];
-      if (records > 0) {
-        if (blocks <= 0 || length <= 0) {
-          throw new RssException(
-              "Illegal partition:"
-                  + idx
-                  + " stats. records/blocks/length: "
-                  + records
-                  + "/"
-                  + blocks
-                  + "/"
-                  + length);
-        }
+      if (records > 0 && length <= 0) {
+        throw new RssException(
+            "Illegal partition:"
+                + idx
+                + " stats. records/blocks/length: "
+                + records
+                + "/"
+                + blocks
+                + "/"
+                + length);
       }
     }
   }
diff --git 
a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStatsTest.java
 
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStatsTest.java
index a0440ac6f..94cfa2cf5 100644
--- 
a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStatsTest.java
+++ 
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStatsTest.java
@@ -17,17 +17,35 @@
 
 package org.apache.uniffle.shuffle;
 
-import org.junit.jupiter.api.Test;
+import java.util.stream.Stream;
 
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import org.apache.uniffle.common.compression.Codec;
+import org.apache.uniffle.common.config.RssConf;
+
+import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_CLIENT_INTEGRITY_VALIDATION_STATS_COMPRESSION_TYPE;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 
 public class ShuffleWriteTaskStatsTest {
 
-  @Test
-  public void testValidValidationInfo() {
+  private static Stream<Arguments> codecType() {
+    RssConf conf1 = new RssConf();
+    conf1.set(RSS_CLIENT_INTEGRITY_VALIDATION_STATS_COMPRESSION_TYPE, 
Codec.Type.LZ4);
+    RssConf conf2 = new RssConf();
+    RssConf conf3 = new RssConf();
+    conf3.set(RSS_CLIENT_INTEGRITY_VALIDATION_STATS_COMPRESSION_TYPE, 
Codec.Type.NONE);
+    return Stream.of(Arguments.of(conf1), Arguments.of(conf2), 
Arguments.of(conf3));
+  }
+
+  @ParameterizedTest
+  @MethodSource("codecType")
+  public void testValidValidationInfo(RssConf rssConf) {
     long taskId = 10;
     long taskAttemptId = 12345L;
-    ShuffleWriteTaskStats stats = new ShuffleWriteTaskStats(2, taskAttemptId, 
taskId);
+    ShuffleWriteTaskStats stats = new ShuffleWriteTaskStats(rssConf, 2, 
taskAttemptId, taskId);
     stats.incPartitionRecord(0);
     stats.incPartitionRecord(1);
 
@@ -35,15 +53,13 @@ public class ShuffleWriteTaskStatsTest {
     stats.incPartitionBlock(1);
 
     String encoded = stats.encode();
-    ShuffleWriteTaskStats decoded = ShuffleWriteTaskStats.decode(encoded);
+    System.out.println("Encoded length: " + encoded.length());
+    ShuffleWriteTaskStats decoded = ShuffleWriteTaskStats.decode(rssConf, 
encoded);
 
     assertEquals(10, stats.getTaskId());
 
     assertEquals(taskAttemptId, decoded.getTaskAttemptId());
     assertEquals(1, decoded.getRecordsWritten(0));
     assertEquals(1, decoded.getRecordsWritten(1));
-
-    assertEquals(1, decoded.getBlocksWritten(0));
-    assertEquals(1, decoded.getBlocksWritten(1));
   }
 }
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index cc1959939..9a5cec0ac 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -582,7 +582,7 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     while (iter.hasNext()) {
       BlockManagerId blockManagerId = iter.next();
       ShuffleWriteTaskStats shuffleWriteTaskStats =
-          ShuffleWriteTaskStats.decode(blockManagerId.topologyInfo().get());
+          ShuffleWriteTaskStats.decode(rssConf, 
blockManagerId.topologyInfo().get());
       upstreamStats.put(shuffleWriteTaskStats.getTaskAttemptId(), 
shuffleWriteTaskStats);
     }
     return upstreamStats;
@@ -603,7 +603,7 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
 
       String raw = blockManagerId.topologyInfo().get();
       if (isIntegrityValidationEnabled(rssConf)) {
-        ShuffleWriteTaskStats shuffleWriteTaskStats = 
ShuffleWriteTaskStats.decode(raw);
+        ShuffleWriteTaskStats shuffleWriteTaskStats = 
ShuffleWriteTaskStats.decode(rssConf, raw);
         taskIdBitmap.add(shuffleWriteTaskStats.getTaskAttemptId());
         for (int i = startPartition; i < endPartition; i++) {
           expectedRecordsRead += shuffleWriteTaskStats.getRecordsWritten(i);
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index fd0ba9382..066f9368e 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -248,7 +248,10 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       this.shuffleTaskStats =
           Optional.of(
               new ShuffleWriteTaskStats(
-                  partitioner.numPartitions(), taskAttemptId, 
taskContext.taskAttemptId()));
+                  rssConf,
+                  partitioner.numPartitions(),
+                  taskAttemptId,
+                  taskContext.taskAttemptId()));
     }
   }
 
diff --git 
a/common/src/main/java/org/apache/uniffle/common/compression/Codec.java 
b/common/src/main/java/org/apache/uniffle/common/compression/Codec.java
index eb3b2d770..440e13785 100644
--- a/common/src/main/java/org/apache/uniffle/common/compression/Codec.java
+++ b/common/src/main/java/org/apache/uniffle/common/compression/Codec.java
@@ -37,6 +37,10 @@ public abstract class Codec {
 
   public static Optional<Codec> newInstance(RssConf rssConf) {
     Type type = rssConf.get(COMPRESSION_TYPE);
+    return newInstance(type, rssConf);
+  }
+
+  public static Optional<Codec> newInstance(Codec.Type type, RssConf rssConf) {
     switch (type) {
       case NONE:
         return Optional.empty();

Reply via email to