This is an automated email from the ASF dual-hosted git repository.
zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new a37936f2b [#2652] feat(spark): Add compression for task write stats (#2666)
a37936f2b is described below
commit a37936f2bc34f6c38082ea85559ea70ebfdc0f26
Author: Junfan Zhang <[email protected]>
AuthorDate: Fri Nov 7 13:56:36 2025 +0800
[#2652] feat(spark): Add compression for task write stats (#2666)
### What changes were proposed in this pull request?
1. Add compression for task write stats.
2. Add an optional block number check mechanism (disabled by default).
### Why are the changes needed?
To reduce the size of the encoded task write stats.
### Does this PR introduce _any_ user-facing change?
Yes. Two new client options are introduced: `spark.rss.client.integrityValidation.blockNumberCheckEnabled` (default: `false`) and `spark.rss.client.integrityValidation.statsCompressionType` (default: `ZSTD`).
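A hypothetical sketch of overriding both defaults from the Spark side (the wrapper class name is invented; the options are only consulted when the client-side integrity validation is enabled):

```java
import org.apache.spark.SparkConf;

public class RssIntegrityConfSketch {
  public static void main(String[] args) {
    // Override the new defaults (false / ZSTD); the codec value is any Codec.Type name,
    // e.g. LZ4, NONE, ZSTD as exercised by the test below.
    SparkConf conf = new SparkConf()
        .set("spark.rss.client.integrityValidation.blockNumberCheckEnabled", "true")
        .set("spark.rss.client.integrityValidation.statsCompressionType", "LZ4");
    System.out.println(conf.toDebugString());
  }
}
```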
### How was this patch tested?
Unit tests.
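For reference, a minimal sketch of the new encode/decode round trip, mirroring the updated `ShuffleWriteTaskStatsTest` below (the wrapper class is invented for illustration; when a codec is configured, the encoded payload is prefixed with the 4-byte uncompressed size):

```java
import org.apache.uniffle.common.compression.Codec;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.shuffle.ShuffleWriteTaskStats;

import static org.apache.spark.shuffle.RssSparkConfig.RSS_CLIENT_INTEGRITY_VALIDATION_STATS_COMPRESSION_TYPE;

public class StatsRoundTripSketch {
  public static void main(String[] args) {
    // ZSTD is already the default; it is set explicitly here for clarity.
    RssConf rssConf = new RssConf();
    rssConf.set(RSS_CLIENT_INTEGRITY_VALIDATION_STATS_COMPRESSION_TYPE, Codec.Type.ZSTD);

    // Two partitions, one record each; block counters are only tracked when the check is enabled.
    ShuffleWriteTaskStats stats =
        new ShuffleWriteTaskStats(rssConf, 2, /* taskAttemptId */ 12345L, /* taskId */ 10L);
    stats.incPartitionRecord(0);
    stats.incPartitionRecord(1);

    // encode() serializes [taskId | taskAttemptId | partitions | records... (| blocks...)],
    // compresses it, and returns the bytes as an ISO-8859-1 string; decode() reverses this.
    String encoded = stats.encode();
    ShuffleWriteTaskStats decoded = ShuffleWriteTaskStats.decode(rssConf, encoded);

    System.out.println(decoded.getRecordsWritten(0)); // 1
    System.out.println(decoded.getBlocksWritten(0));  // -1: block number check disabled by default
  }
}
```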
---
.../org/apache/spark/shuffle/RssSparkConfig.java | 15 +++
.../uniffle/shuffle/ShuffleReadTaskStats.java | 4 +-
.../uniffle/shuffle/ShuffleWriteTaskStats.java | 136 +++++++++++++++------
.../uniffle/shuffle/ShuffleWriteTaskStatsTest.java | 32 +++--
.../apache/spark/shuffle/RssShuffleManager.java | 4 +-
.../spark/shuffle/writer/RssShuffleWriter.java | 5 +-
.../apache/uniffle/common/compression/Codec.java | 4 +
7 files changed, 150 insertions(+), 50 deletions(-)
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index 5eef97bfa..e83bc437f 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -31,6 +31,7 @@ import org.apache.spark.internal.config.ConfigEntry;
import org.apache.spark.internal.config.TypedConfigBuilder;
import org.apache.uniffle.client.util.RssClientConfig;
+import org.apache.uniffle.common.compression.Codec;
import org.apache.uniffle.common.config.ConfigOption;
import org.apache.uniffle.common.config.ConfigOptions;
import org.apache.uniffle.common.config.ConfigUtils;
@@ -52,6 +53,20 @@ public class RssSparkConfig {
.defaultValue(false)
.withDescription("Whether or not to enable validation failure
analysis");
+ public static final ConfigOption<Boolean>
+ RSS_DATA_INTEGRITY_VALIDATION_BLOCK_NUMBER_CHECK_ENABLED =
+ ConfigOptions.key("rss.client.integrityValidation.blockNumberCheckEnabled")
+ .booleanType()
+ .defaultValue(false)
+ .withDescription("Whether or not to enable validation block
number check");
+
+ public static final ConfigOption<Codec.Type>
+ RSS_CLIENT_INTEGRITY_VALIDATION_STATS_COMPRESSION_TYPE =
+ ConfigOptions.key("rss.client.integrityValidation.statsCompressionType")
+ .enumType(Codec.Type.class)
+ .defaultValue(Codec.Type.ZSTD)
+ .withDescription("stats compression type");
+
public static final ConfigOption<Boolean> RSS_READ_SHUFFLE_HANDLE_CACHE_ENABLED =
ConfigOptions.key("rss.client.read.shuffleHandleCacheEnabled")
.booleanType()
diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleReadTaskStats.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleReadTaskStats.java
index d78f24a2f..0898648a5 100644
--- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleReadTaskStats.java
+++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleReadTaskStats.java
@@ -70,8 +70,10 @@ public class ShuffleReadTaskStats {
continue;
}
long recordsUpstream = stats.getRecordsWritten(idx);
+ // If blocksUpstream is less than 0, it indicates that the block number check is disabled.
long blocksUpstream = stats.getBlocksWritten(idx);
- if (recordsRead != recordsUpstream || blocksRead != blocksUpstream) {
+ if (recordsRead != recordsUpstream
+ || (blocksUpstream >= 0 && blocksRead != blocksUpstream)) {
infoBuilder.append(idx);
infoBuilder.append("/");
infoBuilder.append(stats.getTaskId());
diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStats.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStats.java
index e86da494b..4f226569d 100644
--- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStats.java
+++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStats.java
@@ -19,13 +19,18 @@ package org.apache.uniffle.shuffle;
import java.nio.ByteBuffer;
import java.util.Arrays;
+import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.apache.uniffle.common.compression.Codec;
+import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import static java.nio.charset.StandardCharsets.ISO_8859_1;
+import static org.apache.spark.shuffle.RssSparkConfig.RSS_CLIENT_INTEGRITY_VALIDATION_STATS_COMPRESSION_TYPE;
+import static org.apache.spark.shuffle.RssSparkConfig.RSS_DATA_INTEGRITY_VALIDATION_BLOCK_NUMBER_CHECK_ENABLED;
/**
* ShuffleWriteTaskStats stores statistics for a shuffle write task attempt, including the task
@@ -38,17 +43,31 @@ public class ShuffleWriteTaskStats {
private long taskId;
// this is only unique for one stage and defined in uniffle side instead of spark
private long taskAttemptId;
+ private int partitions;
private long[] partitionRecordsWritten;
private long[] partitionBlocksWritten;
+ private boolean blockNumberCheckEnabled;
+ private RssConf rssConf;
- public ShuffleWriteTaskStats(int partitions, long taskAttemptId, long taskId) {
+ public ShuffleWriteTaskStats(RssConf rssConf, int partitions, long taskAttemptId, long taskId) {
this.partitionRecordsWritten = new long[partitions];
- this.partitionBlocksWritten = new long[partitions];
+ Arrays.fill(this.partitionRecordsWritten, 0L);
+
+ this.partitions = partitions;
this.taskAttemptId = taskAttemptId;
this.taskId = taskId;
+ this.rssConf = rssConf;
+ this.blockNumberCheckEnabled =
+ rssConf.get(RSS_DATA_INTEGRITY_VALIDATION_BLOCK_NUMBER_CHECK_ENABLED);
- Arrays.fill(this.partitionRecordsWritten, 0L);
- Arrays.fill(this.partitionBlocksWritten, 0L);
+ if (blockNumberCheckEnabled) {
+ this.partitionBlocksWritten = new long[partitions];
+ Arrays.fill(this.partitionBlocksWritten, 0L);
+ }
+ }
+
+ public ShuffleWriteTaskStats(int partitions, long taskAttemptId, long taskId) {
+ this(new RssConf(), partitions, taskAttemptId, taskId);
}
public long getRecordsWritten(int partitionId) {
@@ -60,11 +79,16 @@ public class ShuffleWriteTaskStats {
}
public void incPartitionBlock(int partitionId) {
- partitionBlocksWritten[partitionId]++;
+ if (blockNumberCheckEnabled) {
+ partitionBlocksWritten[partitionId]++;
+ }
}
public long getBlocksWritten(int partitionId) {
- return partitionBlocksWritten[partitionId];
+ if (blockNumberCheckEnabled) {
+ return partitionBlocksWritten[partitionId];
+ }
+ return -1L;
}
public long getTaskAttemptId() {
@@ -72,33 +96,72 @@ public class ShuffleWriteTaskStats {
}
public String encode() {
+ final long start = System.currentTimeMillis();
int partitions = partitionRecordsWritten.length;
- ByteBuffer buffer =
- ByteBuffer.allocate(2 * Long.BYTES + Integer.BYTES + partitions * Long.BYTES * 2);
+ int capacity = 2 * Long.BYTES + Integer.BYTES + partitions * Long.BYTES;
+ if (blockNumberCheckEnabled) {
+ capacity += partitions * Long.BYTES;
+ }
+ ByteBuffer buffer = ByteBuffer.allocate(capacity);
buffer.putLong(taskId);
buffer.putLong(taskAttemptId);
buffer.putInt(partitions);
for (long records : partitionRecordsWritten) {
buffer.putLong(records);
}
- for (long blocks : partitionBlocksWritten) {
- buffer.putLong(blocks);
+ if (blockNumberCheckEnabled) {
+ for (long blocks : partitionBlocksWritten) {
+ buffer.putLong(blocks);
+ }
+ }
+ Optional<Codec> optionalCodec = getCodec(rssConf);
+ if (optionalCodec.isPresent()) {
+ Codec codec = optionalCodec.get();
+ byte[] compressed = codec.compress(buffer.array());
+ ByteBuffer compositedBuffer = ByteBuffer.allocate(Integer.BYTES + compressed.length);
+ compositedBuffer.putInt(capacity);
+ compositedBuffer.put(compressed);
+ LOGGER.info(
+ "Encoded task stats for {} partitions with {} bytes (original: {}
bytes) in {} ms",
+ partitions,
+ compositedBuffer.capacity(),
+ capacity,
+ System.currentTimeMillis() - start);
+ return new String(compositedBuffer.array(), ISO_8859_1);
+ } else {
+ return new String(buffer.array(), ISO_8859_1);
}
- return new String(buffer.array(), ISO_8859_1);
}
- public static ShuffleWriteTaskStats decode(String raw) {
- byte[] bytes = raw.getBytes(ISO_8859_1);
- ByteBuffer buffer = ByteBuffer.wrap(bytes);
- long taskId = buffer.getLong();
- long taskAttemptId = buffer.getLong();
- int partitions = buffer.getInt();
- ShuffleWriteTaskStats stats = new ShuffleWriteTaskStats(partitions, taskAttemptId, taskId);
- for (int i = 0; i < partitions; i++) {
- stats.partitionRecordsWritten[i] = buffer.getLong();
+ private static Optional<Codec> getCodec(RssConf rssConf) {
+ return Codec.newInstance(
+ rssConf.get(RSS_CLIENT_INTEGRITY_VALIDATION_STATS_COMPRESSION_TYPE), rssConf);
+ }
+
+ public static ShuffleWriteTaskStats decode(RssConf rssConf, String raw) {
+ byte[] rawBytes = raw.getBytes(ISO_8859_1);
+ ByteBuffer outBuffer = ByteBuffer.wrap(rawBytes);
+
+ Optional<Codec> optionalCodec = getCodec(rssConf);
+ if (optionalCodec.isPresent()) {
+ ByteBuffer inBuffer = ByteBuffer.wrap(rawBytes);
+ int capacity = inBuffer.getInt();
+ outBuffer = ByteBuffer.allocate(capacity);
+ optionalCodec.get().decompress(inBuffer, capacity, outBuffer, 0);
}
+
+ long taskId = outBuffer.getLong();
+ long taskAttemptId = outBuffer.getLong();
+ int partitions = outBuffer.getInt();
+ ShuffleWriteTaskStats stats =
+ new ShuffleWriteTaskStats(rssConf, partitions, taskAttemptId, taskId);
for (int i = 0; i < partitions; i++) {
- stats.partitionBlocksWritten[i] = buffer.getLong();
+ stats.partitionRecordsWritten[i] = outBuffer.getLong();
+ }
+ if (rssConf.get(RSS_DATA_INTEGRITY_VALIDATION_BLOCK_NUMBER_CHECK_ENABLED)) {
+ for (int i = 0; i < partitions; i++) {
+ stats.partitionBlocksWritten[i] = outBuffer.getLong();
+ }
}
return stats;
}
@@ -111,8 +174,8 @@ public class ShuffleWriteTaskStats {
StringBuilder infoBuilder = new StringBuilder();
int partitions = partitionRecordsWritten.length;
for (int i = 0; i < partitions; i++) {
- long records = partitionRecordsWritten[i];
- long blocks = partitionBlocksWritten[i];
+ long records = getRecordsWritten(i);
+ long blocks = getBlocksWritten(i);
infoBuilder.append(i).append("/").append(records).append("/").append(blocks).append(",");
}
LOGGER.info(
@@ -120,23 +183,20 @@ public class ShuffleWriteTaskStats {
}
public void check(long[] partitionLens) {
- int partitions = partitionRecordsWritten.length;
for (int idx = 0; idx < partitions; idx++) {
- long records = partitionRecordsWritten[idx];
- long blocks = partitionBlocksWritten[idx];
+ long records = getRecordsWritten(idx);
+ long blocks = getBlocksWritten(idx);
long length = partitionLens[idx];
- if (records > 0) {
- if (blocks <= 0 || length <= 0) {
- throw new RssException(
- "Illegal partition:"
- + idx
- + " stats. records/blocks/length: "
- + records
- + "/"
- + blocks
- + "/"
- + length);
- }
+ if (records > 0 && length <= 0) {
+ throw new RssException(
+ "Illegal partition:"
+ + idx
+ + " stats. records/blocks/length: "
+ + records
+ + "/"
+ + blocks
+ + "/"
+ + length);
}
}
}
diff --git a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStatsTest.java b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStatsTest.java
index a0440ac6f..94cfa2cf5 100644
--- a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStatsTest.java
+++ b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/ShuffleWriteTaskStatsTest.java
@@ -17,17 +17,35 @@
package org.apache.uniffle.shuffle;
-import org.junit.jupiter.api.Test;
+import java.util.stream.Stream;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import org.apache.uniffle.common.compression.Codec;
+import org.apache.uniffle.common.config.RssConf;
+
+import static org.apache.spark.shuffle.RssSparkConfig.RSS_CLIENT_INTEGRITY_VALIDATION_STATS_COMPRESSION_TYPE;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class ShuffleWriteTaskStatsTest {
- @Test
- public void testValidValidationInfo() {
+ private static Stream<Arguments> codecType() {
+ RssConf conf1 = new RssConf();
+ conf1.set(RSS_CLIENT_INTEGRITY_VALIDATION_STATS_COMPRESSION_TYPE, Codec.Type.LZ4);
+ RssConf conf2 = new RssConf();
+ RssConf conf3 = new RssConf();
+ conf3.set(RSS_CLIENT_INTEGRITY_VALIDATION_STATS_COMPRESSION_TYPE, Codec.Type.NONE);
+ return Stream.of(Arguments.of(conf1), Arguments.of(conf2), Arguments.of(conf3));
+ }
+
+ @ParameterizedTest
+ @MethodSource("codecType")
+ public void testValidValidationInfo(RssConf rssConf) {
long taskId = 10;
long taskAttemptId = 12345L;
- ShuffleWriteTaskStats stats = new ShuffleWriteTaskStats(2, taskAttemptId, taskId);
+ ShuffleWriteTaskStats stats = new ShuffleWriteTaskStats(rssConf, 2, taskAttemptId, taskId);
stats.incPartitionRecord(0);
stats.incPartitionRecord(1);
@@ -35,15 +53,13 @@ public class ShuffleWriteTaskStatsTest {
stats.incPartitionBlock(1);
String encoded = stats.encode();
- ShuffleWriteTaskStats decoded = ShuffleWriteTaskStats.decode(encoded);
+ System.out.println("Encoded length: " + encoded.length());
+ ShuffleWriteTaskStats decoded = ShuffleWriteTaskStats.decode(rssConf, encoded);
assertEquals(10, stats.getTaskId());
assertEquals(taskAttemptId, decoded.getTaskAttemptId());
assertEquals(1, decoded.getRecordsWritten(0));
assertEquals(1, decoded.getRecordsWritten(1));
-
- assertEquals(1, decoded.getBlocksWritten(0));
- assertEquals(1, decoded.getBlocksWritten(1));
}
}
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index cc1959939..9a5cec0ac 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -582,7 +582,7 @@ public class RssShuffleManager extends RssShuffleManagerBase {
while (iter.hasNext()) {
BlockManagerId blockManagerId = iter.next();
ShuffleWriteTaskStats shuffleWriteTaskStats =
- ShuffleWriteTaskStats.decode(blockManagerId.topologyInfo().get());
+ ShuffleWriteTaskStats.decode(rssConf, blockManagerId.topologyInfo().get());
upstreamStats.put(shuffleWriteTaskStats.getTaskAttemptId(), shuffleWriteTaskStats);
}
return upstreamStats;
@@ -603,7 +603,7 @@ public class RssShuffleManager extends RssShuffleManagerBase {
String raw = blockManagerId.topologyInfo().get();
if (isIntegrityValidationEnabled(rssConf)) {
- ShuffleWriteTaskStats shuffleWriteTaskStats = ShuffleWriteTaskStats.decode(raw);
+ ShuffleWriteTaskStats shuffleWriteTaskStats = ShuffleWriteTaskStats.decode(rssConf, raw);
taskIdBitmap.add(shuffleWriteTaskStats.getTaskAttemptId());
for (int i = startPartition; i < endPartition; i++) {
expectedRecordsRead += shuffleWriteTaskStats.getRecordsWritten(i);
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index fd0ba9382..066f9368e 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -248,7 +248,10 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
this.shuffleTaskStats =
Optional.of(
new ShuffleWriteTaskStats(
- partitioner.numPartitions(), taskAttemptId, taskContext.taskAttemptId()));
+ rssConf,
+ partitioner.numPartitions(),
+ taskAttemptId,
+ taskContext.taskAttemptId()));
}
}
diff --git a/common/src/main/java/org/apache/uniffle/common/compression/Codec.java b/common/src/main/java/org/apache/uniffle/common/compression/Codec.java
index eb3b2d770..440e13785 100644
--- a/common/src/main/java/org/apache/uniffle/common/compression/Codec.java
+++ b/common/src/main/java/org/apache/uniffle/common/compression/Codec.java
@@ -37,6 +37,10 @@ public abstract class Codec {
public static Optional<Codec> newInstance(RssConf rssConf) {
Type type = rssConf.get(COMPRESSION_TYPE);
+ return newInstance(type, rssConf);
+ }
+
+ public static Optional<Codec> newInstance(Codec.Type type, RssConf rssConf) {
switch (type) {
case NONE:
return Optional.empty();