This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 0495fbca9 [#2056] feat(client): Add a NONE type to bypass de/compression for gluten (#2057)
0495fbca9 is described below

commit 0495fbca9509e74d07eb370aae4a9a4acd7a6e23
Author: Zhen Wang <[email protected]>
AuthorDate: Tue Aug 20 10:48:11 2024 +0800

    [#2056] feat(client): Add a NONE type to bypass de/compression for gluten (#2057)
    
    ### What changes were proposed in this pull request?
    
    Add a `NONE` type to `compression.codec` that makes the RSS client skip compression on write and decompression on read.
    
    ### Why are the changes needed?
    
    Allow disabling RSS client compression even when spark.shuffle.compress is enabled (e.g. for Gluten).
    
    Fix: #2056
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Adds a `NONE` compression type.
    
    ### How was this patch tested?
    
    Existing unit tests.
---
 .../hadoop/mapred/SortWriteBufferManager.java      |  5 ++--
 .../hadoop/mapreduce/task/reduce/RssFetcher.java   | 25 ++++++++++------
 .../shuffle/reader/RssShuffleDataIterator.java     | 13 ++++----
 .../spark/shuffle/writer/WriteBufferManager.java   |  9 +++---
 .../shuffle/reader/AbstractRssReaderTest.java      |  6 ++--
 .../shuffle/reader/RssShuffleDataIteratorTest.java |  9 ++++--
 .../shuffle/writer/WriteBufferManagerTest.java     |  8 +++--
 .../library/common/shuffle/impl/RssTezFetcher.java | 25 ++++++++++------
 .../orderedgrouped/RssTezShuffleDataFetcher.java   | 25 ++++++++++------
 .../common/sort/buffer/WriteBufferManager.java     |  5 ++--
 .../apache/uniffle/common/compression/Codec.java   | 14 +++++----
 .../uniffle/common/util/ByteBufferUtils.java       | 35 ++++++++++++++++++++++
 .../common/compression/CompressionTest.java        |  2 +-
 13 files changed, 126 insertions(+), 55 deletions(-)

diff --git 
a/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java
 
b/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java
index 1860fe856..b31766652 100644
--- 
a/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java
+++ 
b/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java
@@ -23,6 +23,7 @@ import java.util.Comparator;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
@@ -96,7 +97,7 @@ public class SortWriteBufferManager<K, V> {
   private final long maxBufferSize;
   private final ExecutorService sendExecutorService;
   private final RssConf rssConf;
-  private final Codec codec;
+  private final Optional<Codec> codec;
   private final Task.CombinerRunner<K, V> combinerRunner;
 
   public SortWriteBufferManager(
@@ -383,7 +384,7 @@ public class SortWriteBufferManager<K, V> {
     int partitionId = wb.getPartitionId();
     final int uncompressLength = data.length;
     long start = System.currentTimeMillis();
-    final byte[] compressed = codec.compress(data);
+    final byte[] compressed = codec.map(c -> c.compress(data)).orElse(data);
     final long crc32 = ChecksumUtils.getCrc32(compressed);
     compressTime += System.currentTimeMillis() - start;
     final long blockId =
diff --git 
a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssFetcher.java
 
b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssFetcher.java
index b07581a2e..0e41490f8 100644
--- 
a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssFetcher.java
+++ 
b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssFetcher.java
@@ -20,6 +20,7 @@ package org.apache.hadoop.mapreduce.task.reduce;
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.text.DecimalFormat;
+import java.util.Optional;
 
 import com.google.common.annotations.VisibleForTesting;
 import org.apache.hadoop.mapred.Counters;
@@ -38,6 +39,7 @@ import 
org.apache.uniffle.client.response.CompressedShuffleBlock;
 import org.apache.uniffle.common.compression.Codec;
 import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.util.ByteBufferUtils;
 import org.apache.uniffle.common.util.ByteUnit;
 
 public class RssFetcher<K, V> {
@@ -90,7 +92,7 @@ public class RssFetcher<K, V> {
   private int waitCount = 0;
   private byte[] uncompressedData = null;
   private RssConf rssConf;
-  private Codec codec;
+  private Optional<Codec> codec;
 
   RssFetcher(
       JobConf job,
@@ -161,14 +163,19 @@ public class RssFetcher<K, V> {
 
     // uncompress the block
     if (!hasPendingData && compressedData != null) {
-      final long startDecompress = System.currentTimeMillis();
-      int uncompressedLen = compressedBlock.getUncompressLength();
-      ByteBuffer decompressedBuffer = ByteBuffer.allocate(uncompressedLen);
-      codec.decompress(compressedData, uncompressedLen, decompressedBuffer, 0);
-      uncompressedData = decompressedBuffer.array();
-      unCompressionLength += compressedBlock.getUncompressLength();
-      long decompressDuration = System.currentTimeMillis() - startDecompress;
-      decompressTime += decompressDuration;
+      if (codec.isPresent()) {
+        final long startDecompress = System.currentTimeMillis();
+        int uncompressedLen = compressedBlock.getUncompressLength();
+        ByteBuffer decompressedBuffer = ByteBuffer.allocate(uncompressedLen);
+        codec.get().decompress(compressedData, uncompressedLen, 
decompressedBuffer, 0);
+        uncompressedData = decompressedBuffer.array();
+        unCompressionLength += compressedBlock.getUncompressLength();
+        long decompressDuration = System.currentTimeMillis() - startDecompress;
+        decompressTime += decompressDuration;
+      } else {
+        uncompressedData = ByteBufferUtils.bufferToArray(compressedData);
+        unCompressionLength += uncompressedData.length;
+      }
     }
 
     if (uncompressedData != null) {
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
index 88b2d22d8..4f9900ce7 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
@@ -19,6 +19,7 @@ package org.apache.spark.shuffle.reader;
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.util.Optional;
 
 import scala.Product2;
 import scala.Tuple2;
@@ -59,7 +60,7 @@ public class RssShuffleDataIterator<K, C> extends 
AbstractIterator<Product2<K, C
   private long totalRawBytesLength = 0;
   private long unCompressedBytesLength = 0;
   private ByteBuffer uncompressedData;
-  private Codec codec;
+  private Optional<Codec> codec;
 
   public RssShuffleDataIterator(
       Serializer serializer,
@@ -74,7 +75,7 @@ public class RssShuffleDataIterator<K, C> extends 
AbstractIterator<Product2<K, C
             RssSparkConfig.SPARK_SHUFFLE_COMPRESS_KEY.substring(
                 RssSparkConfig.SPARK_RSS_CONFIG_PREFIX.length()),
             RssSparkConfig.SPARK_SHUFFLE_COMPRESS_DEFAULT);
-    this.codec = compress ? Codec.newInstance(rssConf) : null;
+    this.codec = compress ? Codec.newInstance(rssConf) : Optional.empty();
   }
 
   public Iterator<Tuple2<Object, Object>> createKVIterator(ByteBuffer data) {
@@ -131,7 +132,7 @@ public class RssShuffleDataIterator<K, C> extends 
AbstractIterator<Product2<K, C
         shuffleReadClient.checkProcessedBlockIds();
         shuffleReadClient.logStatics();
         String decInfo =
-            codec == null
+            !codec.isPresent()
                 ? "."
                 : (", "
                     + decompressTime
@@ -160,7 +161,7 @@ public class RssShuffleDataIterator<K, C> extends 
AbstractIterator<Product2<K, C
     shuffleReadMetrics.incRemoteBytesRead(rawDataLength);
 
     int uncompressedLen = rawBlock.getUncompressLength();
-    if (codec != null) {
+    if (codec.isPresent()) {
       if (uncompressedData == null
           || uncompressedData.capacity() < uncompressedLen
           || !isSameMemoryType(uncompressedData, rawData)) {
@@ -185,7 +186,7 @@ public class RssShuffleDataIterator<K, C> extends 
AbstractIterator<Product2<K, C
       }
       uncompressedData.clear();
       long startDecompress = System.currentTimeMillis();
-      codec.decompress(rawData, uncompressedLen, uncompressedData, 0);
+      codec.get().decompress(rawData, uncompressedLen, uncompressedData, 0);
       unCompressedBytesLength += uncompressedLen;
       long decompressDuration = System.currentTimeMillis() - startDecompress;
       decompressTime += decompressDuration;
@@ -210,7 +211,7 @@ public class RssShuffleDataIterator<K, C> extends 
AbstractIterator<Product2<K, C
     // Uncompressed data is released in this class, Compressed data is release 
in the class
     // ShuffleReadClientImpl
     // So if codec is null, we don't release the data when the stream is closed
-    if (codec != null) {
+    if (codec.isPresent()) {
       RssUtils.releaseByteBuffer(uncompressedData);
     }
     if (shuffleReadClient != null) {
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index bfd929777..08eec1c2e 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -22,6 +22,7 @@ import java.util.Collections;
 import java.util.Comparator;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
@@ -91,7 +92,7 @@ public class WriteBufferManager extends MemoryConsumer {
   private long uncompressedDataLen = 0;
   private long requireMemoryInterval;
   private int requireMemoryRetryMax;
-  private Codec codec;
+  private Optional<Codec> codec;
   private Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> 
spillFunc;
   private long sendSizeLimit;
   private boolean memorySpillEnabled;
@@ -159,7 +160,7 @@ public class WriteBufferManager extends MemoryConsumer {
             RssSparkConfig.SPARK_SHUFFLE_COMPRESS_KEY.substring(
                 RssSparkConfig.SPARK_RSS_CONFIG_PREFIX.length()),
             RssSparkConfig.SPARK_SHUFFLE_COMPRESS_DEFAULT);
-    this.codec = compress ? Codec.newInstance(rssConf) : null;
+    this.codec = compress ? Codec.newInstance(rssConf) : Optional.empty();
     this.spillFunc = spillFunc;
     this.sendSizeLimit = 
rssConf.get(RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMITATION);
     this.memorySpillTimeoutSec = 
rssConf.get(RssSparkConfig.RSS_MEMORY_SPILL_TIMEOUT);
@@ -384,9 +385,9 @@ public class WriteBufferManager extends MemoryConsumer {
     byte[] data = wb.getData();
     final int uncompressLength = data.length;
     byte[] compressed = data;
-    if (codec != null) {
+    if (codec.isPresent()) {
       long start = System.currentTimeMillis();
-      compressed = codec.compress(data);
+      compressed = codec.get().compress(data);
       compressTime += System.currentTimeMillis() - start;
     }
     final long crc32 = ChecksumUtils.getCrc32(compressed);
diff --git 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/AbstractRssReaderTest.java
 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/AbstractRssReaderTest.java
index f761c6ea6..7099fd9eb 100644
--- 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/AbstractRssReaderTest.java
+++ 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/AbstractRssReaderTest.java
@@ -19,6 +19,7 @@ package org.apache.spark.shuffle.reader;
 
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.atomic.AtomicInteger;
 
@@ -171,8 +172,9 @@ public abstract class AbstractRssReaderTest extends 
HadoopTestBase {
   protected ShufflePartitionedBlock createShuffleBlock(
       byte[] data, long blockId, boolean compress) {
     byte[] compressData = data;
-    if (compress) {
-      compressData = Codec.newInstance(new RssConf()).compress(data);
+    Optional<Codec> codec = Codec.newInstance(new RssConf());
+    if (compress && codec.isPresent()) {
+      compressData = codec.get().compress(data);
     }
     long crc = ChecksumUtils.getCrc32(compressData);
     return new ShufflePartitionedBlock(
diff --git 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java
 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java
index 3f6993c82..5550f67c0 100644
--- 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java
+++ 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java
@@ -20,6 +20,7 @@ package org.apache.spark.shuffle.reader;
 import java.nio.ByteBuffer;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.stream.Stream;
 
 import com.google.common.collect.Lists;
@@ -46,6 +47,7 @@ import org.apache.uniffle.client.factory.ShuffleClientFactory;
 import org.apache.uniffle.client.impl.ShuffleReadClientImpl;
 import org.apache.uniffle.common.ClientType;
 import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.compression.Codec;
 import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.util.BlockIdLayout;
 import org.apache.uniffle.common.util.ChecksumUtils;
@@ -321,11 +323,12 @@ public class RssShuffleDataIteratorTest extends 
AbstractRssReaderTest {
     RssShuffleDataIterator rssShuffleDataIterator =
         getDataIterator(
             basePath, blockIdBitmap, taskIdBitmap, Lists.newArrayList(ssi1, 
ssi2), compress);
-    Object codec = FieldUtils.readField(rssShuffleDataIterator, "codec", true);
+    Optional<Codec> codec =
+        (Optional<Codec>) FieldUtils.readField(rssShuffleDataIterator, 
"codec", true);
     if (compress) {
-      Assertions.assertNotNull(codec);
+      Assertions.assertTrue(codec.isPresent());
     } else {
-      Assertions.assertNull(codec);
+      Assertions.assertFalse(codec.isPresent());
     }
 
     validateResult(rssShuffleDataIterator, expectedData, 20);
diff --git 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
index 49ebeef25..19c9f6d10 100644
--- 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
+++ 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
@@ -21,6 +21,7 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
+import java.util.Optional;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Function;
@@ -46,6 +47,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.compression.Codec;
 import org.apache.uniffle.common.config.RssClientConf;
 import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.util.BlockIdLayout;
@@ -122,11 +124,11 @@ public class WriteBufferManagerTest {
       conf.set(RssSparkConfig.SPARK_SHUFFLE_COMPRESS_KEY, 
String.valueOf(false));
     }
     WriteBufferManager wbm = createManager(conf);
-    Object codec = FieldUtils.readField(wbm, "codec", true);
+    Optional<Codec> codec = (Optional<Codec>) FieldUtils.readField(wbm, 
"codec", true);
     if (compress) {
-      Assertions.assertNotNull(codec);
+      Assertions.assertTrue(codec.isPresent());
     } else {
-      Assertions.assertNull(codec);
+      Assertions.assertFalse(codec.isPresent());
     }
     wbm.setShuffleWriteMetrics(new ShuffleWriteMetrics());
     String testKey = "Key";
diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcher.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcher.java
index 5ff38333c..7edb4f443 100644
--- 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcher.java
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcher.java
@@ -20,6 +20,7 @@ package org.apache.tez.runtime.library.common.shuffle.impl;
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.Map;
+import java.util.Optional;
 
 import com.google.common.annotations.VisibleForTesting;
 import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
@@ -36,6 +37,7 @@ import 
org.apache.uniffle.client.response.CompressedShuffleBlock;
 import org.apache.uniffle.common.compression.Codec;
 import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.util.ByteBufferUtils;
 
 public class RssTezFetcher {
   private static final Logger LOG = 
LoggerFactory.getLogger(RssTezFetcher.class);
@@ -62,7 +64,7 @@ public class RssTezFetcher {
   private long startWait;
   private int waitCount = 0;
   private byte[] uncompressedData = null;
-  private Codec codec;
+  private Optional<Codec> codec;
 
   RssTezFetcher(
       FetcherCallback fetcherCallback,
@@ -109,14 +111,19 @@ public class RssTezFetcher {
 
     // uncompress the block
     if (!hasPendingData && compressedData != null) {
-      final long startDecompress = System.currentTimeMillis();
-      int uncompressedLen = compressedBlock.getUncompressLength();
-      ByteBuffer decompressedBuffer = ByteBuffer.allocate(uncompressedLen);
-      codec.decompress(compressedData, uncompressedLen, decompressedBuffer, 0);
-      uncompressedData = decompressedBuffer.array();
-      unCompressionLength += compressedBlock.getUncompressLength();
-      long decompressDuration = System.currentTimeMillis() - startDecompress;
-      decompressTime += decompressDuration;
+      if (codec.isPresent()) {
+        final long startDecompress = System.currentTimeMillis();
+        int uncompressedLen = compressedBlock.getUncompressLength();
+        ByteBuffer decompressedBuffer = ByteBuffer.allocate(uncompressedLen);
+        codec.get().decompress(compressedData, uncompressedLen, 
decompressedBuffer, 0);
+        uncompressedData = decompressedBuffer.array();
+        unCompressionLength += compressedBlock.getUncompressLength();
+        long decompressDuration = System.currentTimeMillis() - startDecompress;
+        decompressTime += decompressDuration;
+      } else {
+        uncompressedData = ByteBufferUtils.bufferToArray(compressedData);
+        unCompressionLength += uncompressedData.length;
+      }
     }
 
     if (uncompressedData != null) {
diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssTezShuffleDataFetcher.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssTezShuffleDataFetcher.java
index 992f509d7..06de81013 100644
--- 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssTezShuffleDataFetcher.java
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssTezShuffleDataFetcher.java
@@ -19,6 +19,7 @@ package 
org.apache.tez.runtime.library.common.shuffle.orderedgrouped;
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.util.Optional;
 import java.util.concurrent.atomic.AtomicInteger;
 
 import com.google.common.annotations.VisibleForTesting;
@@ -34,6 +35,7 @@ import 
org.apache.uniffle.client.response.CompressedShuffleBlock;
 import org.apache.uniffle.common.compression.Codec;
 import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.util.ByteBufferUtils;
 
 public class RssTezShuffleDataFetcher extends CallableWithNdc<Void> {
   private static final Logger LOG = 
LoggerFactory.getLogger(RssTezShuffleDataFetcher.class);
@@ -70,7 +72,7 @@ public class RssTezShuffleDataFetcher extends 
CallableWithNdc<Void> {
   private long startWait;
   private int waitCount = 0;
   private byte[] uncompressedData = null;
-  private final Codec rssCodec;
+  private final Optional<Codec> rssCodec;
   private Integer partitionId;
   private final ExceptionReporter exceptionReporter;
 
@@ -151,14 +153,19 @@ public class RssTezShuffleDataFetcher extends 
CallableWithNdc<Void> {
 
     // uncompress the block
     if (!hasPendingData && compressedData != null) {
-      final long startDecompress = System.currentTimeMillis();
-      int uncompressedLen = compressedBlock.getUncompressLength();
-      ByteBuffer decompressedBuffer = ByteBuffer.allocate(uncompressedLen);
-      rssCodec.decompress(compressedData, uncompressedLen, decompressedBuffer, 
0);
-      uncompressedData = decompressedBuffer.array();
-      unCompressionLength += compressedBlock.getUncompressLength();
-      long decompressDuration = System.currentTimeMillis() - startDecompress;
-      decompressTime += decompressDuration;
+      if (rssCodec.isPresent()) {
+        final long startDecompress = System.currentTimeMillis();
+        int uncompressedLen = compressedBlock.getUncompressLength();
+        ByteBuffer decompressedBuffer = ByteBuffer.allocate(uncompressedLen);
+        rssCodec.get().decompress(compressedData, uncompressedLen, 
decompressedBuffer, 0);
+        uncompressedData = decompressedBuffer.array();
+        unCompressionLength += compressedBlock.getUncompressLength();
+        long decompressDuration = System.currentTimeMillis() - startDecompress;
+        decompressTime += decompressDuration;
+      } else {
+        uncompressedData = ByteBufferUtils.bufferToArray(compressedData);
+        unCompressionLength += uncompressedData.length;
+      }
     }
 
     if (uncompressedData != null) {
diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java
index 53cfeba45..93735efa4 100644
--- 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java
@@ -23,6 +23,7 @@ import java.util.Comparator;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
@@ -84,7 +85,7 @@ public class WriteBufferManager<K, V> {
   private final double memoryThreshold;
   private final double sendThreshold;
   private final int batch;
-  private final Codec codec;
+  private final Optional<Codec> codec;
   private final Map<Integer, List<ShuffleServerInfo>> partitionToServers;
   private final Set<Long> allBlockIds = Sets.newConcurrentHashSet();
   // server -> partitionId -> blockIds
@@ -370,7 +371,7 @@ public class WriteBufferManager<K, V> {
     final int uncompressLength = data.length;
     long start = System.currentTimeMillis();
 
-    final byte[] compressed = codec.compress(data);
+    final byte[] compressed = codec.map(c -> c.compress(data)).orElse(data);
     final long crc32 = ChecksumUtils.getCrc32(compressed);
     compressTime += System.currentTimeMillis() - start;
     final long blockId =
diff --git 
a/common/src/main/java/org/apache/uniffle/common/compression/Codec.java 
b/common/src/main/java/org/apache/uniffle/common/compression/Codec.java
index b2ac5f0bb..72c69dc06 100644
--- a/common/src/main/java/org/apache/uniffle/common/compression/Codec.java
+++ b/common/src/main/java/org/apache/uniffle/common/compression/Codec.java
@@ -18,6 +18,7 @@
 package org.apache.uniffle.common.compression;
 
 import java.nio.ByteBuffer;
+import java.util.Optional;
 
 import org.apache.uniffle.common.config.RssConf;
 
@@ -26,18 +27,20 @@ import static 
org.apache.uniffle.common.config.RssClientConf.ZSTD_COMPRESSION_LE
 
 public abstract class Codec {
 
-  public static Codec newInstance(RssConf rssConf) {
+  public static Optional<Codec> newInstance(RssConf rssConf) {
     Type type = rssConf.get(COMPRESSION_TYPE);
     switch (type) {
+      case NONE:
+        return Optional.empty();
       case ZSTD:
-        return ZstdCodec.getInstance(rssConf.get(ZSTD_COMPRESSION_LEVEL));
+        return 
Optional.of(ZstdCodec.getInstance(rssConf.get(ZSTD_COMPRESSION_LEVEL)));
       case SNAPPY:
-        return SnappyCodec.getInstance();
+        return Optional.of(SnappyCodec.getInstance());
       case NOOP:
-        return NoOpCodec.getInstance();
+        return Optional.of(NoOpCodec.getInstance());
       case LZ4:
       default:
-        return Lz4Codec.getInstance();
+        return Optional.of(Lz4Codec.getInstance());
     }
   }
 
@@ -72,5 +75,6 @@ public abstract class Codec {
     ZSTD,
     NOOP,
     SNAPPY,
+    NONE,
   }
 }
diff --git 
a/common/src/main/java/org/apache/uniffle/common/util/ByteBufferUtils.java 
b/common/src/main/java/org/apache/uniffle/common/util/ByteBufferUtils.java
new file mode 100644
index 000000000..f32f3d5a8
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/util/ByteBufferUtils.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.util;
+
+import java.nio.ByteBuffer;
+
+/** Helpers for turning {@link ByteBuffer}s into other representations. */
+public class ByteBufferUtils {
+
+  /**
+   * Returns the bytes between {@code buffer}'s position and limit as a byte array.
+   *
+   * <p>If {@code buffer} is a heap buffer whose remaining bytes cover its entire backing array,
+   * that array is returned as-is (no copy), so writes through either the buffer or the returned
+   * array are visible to the other. In every other case (direct, read-only, sliced or partially
+   * consumed buffers) the remaining bytes are copied into a new array.
+   *
+   * <p>The position, limit and mark of {@code buffer} are left unchanged on both paths.
+   *
+   * @param buffer the buffer to read; must not be {@code null}
+   * @return the remaining bytes of {@code buffer}
+   */
+  public static byte[] bufferToArray(ByteBuffer buffer) {
+    if (buffer.hasArray()
+        && buffer.arrayOffset() == 0
+        && buffer.array().length == buffer.remaining()) {
+      // Only reachable when position == 0 and limit == array length, so nothing is skipped.
+      return buffer.array();
+    }
+    byte[] bytes = new byte[buffer.remaining()];
+    // Read through a duplicate so the caller's position is not advanced, matching the
+    // zero-copy path above.
+    buffer.duplicate().get(bytes);
+    return bytes;
+  }
+}
diff --git 
a/common/src/test/java/org/apache/uniffle/common/compression/CompressionTest.java
 
b/common/src/test/java/org/apache/uniffle/common/compression/CompressionTest.java
index 629ad4728..ac5af5aa7 100644
--- 
a/common/src/test/java/org/apache/uniffle/common/compression/CompressionTest.java
+++ 
b/common/src/test/java/org/apache/uniffle/common/compression/CompressionTest.java
@@ -56,7 +56,7 @@ public class CompressionTest {
     conf.set(COMPRESSION_TYPE, type);
 
     // case1: heap bytebuffer
-    Codec codec = Codec.newInstance(conf);
+    Codec codec = Codec.newInstance(conf).get();
     byte[] compressed = codec.compress(data);
 
     ByteBuffer dest = ByteBuffer.allocate(size);

Reply via email to