This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 74a6a8d58 feat: Move shuffle block decompression and decoding to native code and add LZ4 & Snappy support (#1192)
74a6a8d58 is described below

commit 74a6a8d58f931e1cedaaac0ef59616d5b6b27fa2
Author: Andy Grove <[email protected]>
AuthorDate: Mon Jan 6 17:47:16 2025 -0700

    feat: Move shuffle block decompression and decoding to native code and add LZ4 & Snappy support (#1192)
    
    * Implement native decoding and decompression
    
    * revert some variable renaming for smaller diff
    
    * fix oom issues?
    
    * make NativeBatchDecoderIterator more consistent with ArrowReaderIterator
    
    * fix oom and prep for review
    
    * format
    
    * Add LZ4 support
    
    * clippy, new benchmark
    
    * rename metrics, clean up lz4 code
    
    * update test
    
    * Add support for snappy
    
    * format
    
    * change default back to lz4
    
    * make metrics more accurate
    
    * format
    
    * clippy
    
    * use faster unsafe version of lz4_flex
    
    * Make compression codec configurable for columnar shuffle
    
    * clippy
    
    * fix bench
    
    * fmt
    
    * address feedback
    
    * address feedback
    
    * address feedback
    
    * minor code simplification
    
    * cargo fmt
    
    * overflow check
    
    * rename compression level config
    
    * address feedback
    
    * address feedback
    
    * rename constant
---
 .../main/scala/org/apache/comet/CometConf.scala    |  23 +--
 docs/source/user-guide/configs.md                  |   4 +-
 native/Cargo.lock                                  |  48 +++---
 native/core/Cargo.toml                             |   3 +
 native/core/benches/row_columnar.rs                |   2 +
 native/core/benches/shuffle_writer.rs              |  41 ++++-
 native/core/src/execution/jni_api.rs               |  47 +++++-
 native/core/src/execution/planner.rs               |   2 +
 native/core/src/execution/shuffle/mod.rs           |   4 +-
 native/core/src/execution/shuffle/row.rs           |   5 +-
 .../core/src/execution/shuffle/shuffle_writer.rs   | 132 ++++++++++++---
 native/proto/src/proto/operator.proto              |   2 +
 .../shuffle/sort/CometShuffleExternalSorter.java   |  17 +-
 .../execution/shuffle/CometDiskBlockWriter.java    |  14 +-
 .../sql/comet/execution/shuffle/SpillWriter.java   |   8 +-
 spark/src/main/scala/org/apache/comet/Native.scala |  30 +++-
 .../apache/spark/sql/comet/CometMetricNode.scala   |   5 +-
 .../shuffle/CometBlockStoreShuffleReader.scala     |  34 ++--
 .../execution/shuffle/CometShuffleDependency.scala |   4 +-
 .../shuffle/CometShuffleExchangeExec.scala         |  13 +-
 .../execution/shuffle/IpcInputStreamIterator.scala | 126 --------------
 .../shuffle/NativeBatchDecoderIterator.scala       | 188 +++++++++++++++++++++
 .../comet/exec/CometColumnarShuffleSuite.scala     |   3 +-
 23 files changed, 524 insertions(+), 231 deletions(-)

diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala
index 4d63de75a..4115fa9ab 100644
--- a/common/src/main/scala/org/apache/comet/CometConf.scala
+++ b/common/src/main/scala/org/apache/comet/CometConf.scala
@@ -272,18 +272,19 @@ object CometConf extends ShimCometConf {
       .booleanConf
       .createWithDefault(false)
 
-  val COMET_EXEC_SHUFFLE_COMPRESSION_CODEC: ConfigEntry[String] = conf(
-    s"$COMET_EXEC_CONFIG_PREFIX.shuffle.compression.codec")
-    .doc(
-      "The codec of Comet native shuffle used to compress shuffle data. Only 
zstd is supported. " +
-        "Compression can be disabled by setting spark.shuffle.compress=false.")
-    .stringConf
-    .checkValues(Set("zstd"))
-    .createWithDefault("zstd")
+  val COMET_EXEC_SHUFFLE_COMPRESSION_CODEC: ConfigEntry[String] =
+    conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.compression.codec")
+      .doc(
+        "The codec of Comet native shuffle used to compress shuffle data. lz4, 
zstd, and " +
+          "snappy are supported. Compression can be disabled by setting " +
+          "spark.shuffle.compress=false.")
+      .stringConf
+      .checkValues(Set("zstd", "lz4", "snappy"))
+      .createWithDefault("lz4")
 
-  val COMET_EXEC_SHUFFLE_COMPRESSION_LEVEL: ConfigEntry[Int] =
-    conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.compression.level")
-      .doc("The compression level to use when compression shuffle files.")
+  val COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL: ConfigEntry[Int] =
+    conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.compression.zstd.level")
+      .doc("The compression level to use when compressing shuffle files with 
zstd.")
       .intConf
       .createWithDefault(1)
 
diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md
index 20923b93a..d78e6111d 100644
--- a/docs/source/user-guide/configs.md
+++ b/docs/source/user-guide/configs.md
@@ -50,8 +50,8 @@ Comet provides the following configuration settings.
 | spark.comet.exec.memoryPool | The type of memory pool to be used for Comet native execution. Available memory pool types are 'greedy', 'fair_spill', 'greedy_task_shared', 'fair_spill_task_shared', 'greedy_global' and 'fair_spill_global', By default, this config is 'greedy_task_shared'. | greedy_task_shared |
 | spark.comet.exec.project.enabled | Whether to enable project by default. | true |
 | spark.comet.exec.replaceSortMergeJoin | Experimental feature to force Spark to replace SortMergeJoin with ShuffledHashJoin for improved performance. This feature is not stable yet. For more information, refer to the Comet Tuning Guide (https://datafusion.apache.org/comet/user-guide/tuning.html). | false |
-| spark.comet.exec.shuffle.compression.codec | The codec of Comet native shuffle used to compress shuffle data. Only zstd is supported. Compression can be disabled by setting spark.shuffle.compress=false. | zstd |
-| spark.comet.exec.shuffle.compression.level | The compression level to use when compression shuffle files. | 1 |
+| spark.comet.exec.shuffle.compression.codec | The codec of Comet native shuffle used to compress shuffle data. lz4, zstd, and snappy are supported. Compression can be disabled by setting spark.shuffle.compress=false. | lz4 |
+| spark.comet.exec.shuffle.compression.zstd.level | The compression level to use when compressing shuffle files with zstd. | 1 |
 | spark.comet.exec.shuffle.enabled | Whether to enable Comet native shuffle. Note that this requires setting 'spark.shuffle.manager' to 'org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager'. 'spark.shuffle.manager' must be set before starting the Spark application and cannot be changed during the application. | true |
 | spark.comet.exec.sort.enabled | Whether to enable sort by default. | true |
 | spark.comet.exec.sortMergeJoin.enabled | Whether to enable sortMergeJoin by default. | true |
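
For readers trying the new settings: the two configuration entries above can be set like any other Spark conf. A minimal sketch, assuming Comet and its shuffle manager are already enabled for the session (the application name, the chosen values, and the query are purely illustrative):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession
      .builder()
      .appName("comet-shuffle-codec-example")
      // lz4 is the default; zstd and snappy are the other accepted values,
      // and compression is skipped entirely when spark.shuffle.compress=false
      .config("spark.comet.exec.shuffle.compression.codec", "zstd")
      // only consulted when the codec is zstd
      .config("spark.comet.exec.shuffle.compression.zstd.level", "3")
      .getOrCreate()

    // any shuffle-inducing job now writes Comet shuffle blocks with the chosen codec
    spark.range(0, 1000000).repartition(10).count()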
diff --git a/native/Cargo.lock b/native/Cargo.lock
index bbc0ff97a..1c44e3cc5 100644
--- a/native/Cargo.lock
+++ b/native/Cargo.lock
@@ -1,6 +1,6 @@
 # This file is automatically @generated by Cargo.
 # It is not intended for manual editing.
-version = 4
+version = 3
 
 [[package]]
 name = "addr2line"
@@ -346,7 +346,7 @@ checksum = 
"721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.92",
+ "syn 2.0.93",
 ]
 
 [[package]]
@@ -903,6 +903,7 @@ dependencies = [
  "lazy_static",
  "log",
  "log4rs",
+ "lz4_flex",
  "mimalloc",
  "num",
  "once_cell",
@@ -914,6 +915,7 @@ dependencies = [
  "regex",
  "serde",
  "simd-adler32",
+ "snap",
  "tempfile",
  "thiserror",
  "tokio",
@@ -1168,7 +1170,7 @@ version = "44.0.0"
 source = 
"git+https://github.com/apache/datafusion.git?rev=44.0.0-rc2#3cc3fca31e6edc2d953e663bfd7f856bcb70d8c4";
 dependencies = [
  "quote",
- "syn 2.0.92",
+ "syn 2.0.93",
 ]
 
 [[package]]
@@ -1333,7 +1335,7 @@ checksum = 
"97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.92",
+ "syn 2.0.93",
 ]
 
 [[package]]
@@ -1473,7 +1475,7 @@ checksum = 
"162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.92",
+ "syn 2.0.93",
 ]
 
 [[package]]
@@ -1746,7 +1748,7 @@ checksum = 
"1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.92",
+ "syn 2.0.93",
 ]
 
 [[package]]
@@ -2556,7 +2558,7 @@ dependencies = [
  "itertools 0.12.1",
  "proc-macro2",
  "quote",
- "syn 2.0.92",
+ "syn 2.0.93",
 ]
 
 [[package]]
@@ -2778,7 +2780,7 @@ checksum = 
"5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.92",
+ "syn 2.0.93",
 ]
 
 [[package]]
@@ -2868,7 +2870,7 @@ dependencies = [
  "heck 0.5.0",
  "proc-macro2",
  "quote",
- "syn 2.0.92",
+ "syn 2.0.93",
 ]
 
 [[package]]
@@ -2895,7 +2897,7 @@ checksum = 
"da5fc6819faabb412da764b99d3b713bb55083c11e7e0c00144d386cd6a1939c"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.92",
+ "syn 2.0.93",
 ]
 
 [[package]]
@@ -2932,7 +2934,7 @@ dependencies = [
  "proc-macro2",
  "quote",
  "rustversion",
- "syn 2.0.92",
+ "syn 2.0.93",
 ]
 
 [[package]]
@@ -2977,9 +2979,9 @@ dependencies = [
 
 [[package]]
 name = "syn"
-version = "2.0.92"
+version = "2.0.93"
 source = "registry+https://github.com/rust-lang/crates.io-index";
-checksum = "70ae51629bf965c5c098cc9e87908a3df5301051a9e087d6f9bef5c9771ed126"
+checksum = "9c786062daee0d6db1132800e623df74274a0a87322d8e183338e01b3d98d058"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -2994,7 +2996,7 @@ checksum = 
"c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.92",
+ "syn 2.0.93",
 ]
 
 [[package]]
@@ -3027,7 +3029,7 @@ checksum = 
"4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.92",
+ "syn 2.0.93",
 ]
 
 [[package]]
@@ -3100,7 +3102,7 @@ checksum = 
"693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.92",
+ "syn 2.0.93",
 ]
 
 [[package]]
@@ -3122,7 +3124,7 @@ checksum = 
"395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.92",
+ "syn 2.0.93",
 ]
 
 [[package]]
@@ -3276,7 +3278,7 @@ dependencies = [
  "log",
  "proc-macro2",
  "quote",
- "syn 2.0.92",
+ "syn 2.0.93",
  "wasm-bindgen-shared",
 ]
 
@@ -3298,7 +3300,7 @@ checksum = 
"30d7a95b763d3c45903ed6c81f156801839e5ee968bb07e534c44df0fcd330c2"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.92",
+ "syn 2.0.93",
  "wasm-bindgen-backend",
  "wasm-bindgen-shared",
 ]
@@ -3561,7 +3563,7 @@ checksum = 
"2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.92",
+ "syn 2.0.93",
  "synstructure",
 ]
 
@@ -3583,7 +3585,7 @@ checksum = 
"fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.92",
+ "syn 2.0.93",
 ]
 
 [[package]]
@@ -3603,7 +3605,7 @@ checksum = 
"595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.92",
+ "syn 2.0.93",
  "synstructure",
 ]
 
@@ -3626,7 +3628,7 @@ checksum = 
"6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.92",
+ "syn 2.0.93",
 ]
 
 [[package]]
diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml
index 5089e67a0..8937236dd 100644
--- a/native/core/Cargo.toml
+++ b/native/core/Cargo.toml
@@ -52,6 +52,9 @@ serde = { version = "1", features = ["derive"] }
 lazy_static = "1.4.0"
 prost = "0.12.1"
 jni = "0.21"
+snap = "1.1"
+# we disable default features in lz4_flex to force the use of the faster unsafe encoding and decoding implementation
+lz4_flex = { version = "0.11.3", default-features = false }
 zstd = "0.11"
 rand = { workspace = true}
 num = { workspace = true }
diff --git a/native/core/benches/row_columnar.rs 
b/native/core/benches/row_columnar.rs
index 60b41330e..a62574111 100644
--- a/native/core/benches/row_columnar.rs
+++ b/native/core/benches/row_columnar.rs
@@ -19,6 +19,7 @@ use arrow::datatypes::DataType as ArrowDataType;
 use comet::execution::shuffle::row::{
     process_sorted_row_partition, SparkUnsafeObject, SparkUnsafeRow,
 };
+use comet::execution::shuffle::CompressionCodec;
 use criterion::{criterion_group, criterion_main, Criterion};
 use tempfile::Builder;
 
@@ -77,6 +78,7 @@ fn benchmark(c: &mut Criterion) {
                 false,
                 0,
                 None,
+                &CompressionCodec::Zstd(1),
             )
             .unwrap();
         });
diff --git a/native/core/benches/shuffle_writer.rs 
b/native/core/benches/shuffle_writer.rs
index 865ca73b4..0d22c62cc 100644
--- a/native/core/benches/shuffle_writer.rs
+++ b/native/core/benches/shuffle_writer.rs
@@ -35,23 +35,52 @@ fn criterion_benchmark(c: &mut Criterion) {
     group.bench_function("shuffle_writer: encode (no compression))", |b| {
         let batch = create_batch(8192, true);
         let mut buffer = vec![];
-        let mut cursor = Cursor::new(&mut buffer);
         let ipc_time = Time::default();
-        b.iter(|| write_ipc_compressed(&batch, &mut cursor, 
&CompressionCodec::None, &ipc_time));
+        b.iter(|| {
+            buffer.clear();
+            let mut cursor = Cursor::new(&mut buffer);
+            write_ipc_compressed(&batch, &mut cursor, &CompressionCodec::None, 
&ipc_time)
+        });
+    });
+    group.bench_function("shuffle_writer: encode and compress (snappy)", |b| {
+        let batch = create_batch(8192, true);
+        let mut buffer = vec![];
+        let ipc_time = Time::default();
+        b.iter(|| {
+            buffer.clear();
+            let mut cursor = Cursor::new(&mut buffer);
+            write_ipc_compressed(&batch, &mut cursor, 
&CompressionCodec::Snappy, &ipc_time)
+        });
+    });
+    group.bench_function("shuffle_writer: encode and compress (lz4)", |b| {
+        let batch = create_batch(8192, true);
+        let mut buffer = vec![];
+        let ipc_time = Time::default();
+        b.iter(|| {
+            buffer.clear();
+            let mut cursor = Cursor::new(&mut buffer);
+            write_ipc_compressed(&batch, &mut cursor, 
&CompressionCodec::Lz4Frame, &ipc_time)
+        });
     });
     group.bench_function("shuffle_writer: encode and compress (zstd level 1)", 
|b| {
         let batch = create_batch(8192, true);
         let mut buffer = vec![];
-        let mut cursor = Cursor::new(&mut buffer);
         let ipc_time = Time::default();
-        b.iter(|| write_ipc_compressed(&batch, &mut cursor, 
&CompressionCodec::Zstd(1), &ipc_time));
+        b.iter(|| {
+            buffer.clear();
+            let mut cursor = Cursor::new(&mut buffer);
+            write_ipc_compressed(&batch, &mut cursor, 
&CompressionCodec::Zstd(1), &ipc_time)
+        });
     });
     group.bench_function("shuffle_writer: encode and compress (zstd level 6)", 
|b| {
         let batch = create_batch(8192, true);
         let mut buffer = vec![];
-        let mut cursor = Cursor::new(&mut buffer);
         let ipc_time = Time::default();
-        b.iter(|| write_ipc_compressed(&batch, &mut cursor, 
&CompressionCodec::Zstd(6), &ipc_time));
+        b.iter(|| {
+            buffer.clear();
+            let mut cursor = Cursor::new(&mut buffer);
+            write_ipc_compressed(&batch, &mut cursor, 
&CompressionCodec::Zstd(6), &ipc_time)
+        });
     });
     group.bench_function("shuffle_writer: end to end", |b| {
         let ctx = SessionContext::new();
diff --git a/native/core/src/execution/jni_api.rs 
b/native/core/src/execution/jni_api.rs
index 7d8d577fe..aaac7ec8c 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -17,6 +17,7 @@
 
 //! Define JNI APIs which can be called from Java/Scala.
 
+use super::{serde, utils::SparkArrowConvert, CometMemoryPool};
 use arrow::datatypes::DataType as ArrowDataType;
 use arrow_array::RecordBatch;
 use datafusion::{
@@ -40,8 +41,6 @@ use jni::{
 use std::time::{Duration, Instant};
 use std::{collections::HashMap, sync::Arc, task::Poll};
 
-use super::{serde, utils::SparkArrowConvert, CometMemoryPool};
-
 use crate::{
     errors::{try_unwrap_or_throw, CometError, CometResult},
     execution::{
@@ -54,6 +53,7 @@ use datafusion_comet_proto::spark_operator::Operator;
 use datafusion_common::ScalarValue;
 use datafusion_execution::runtime_env::RuntimeEnvBuilder;
 use futures::stream::StreamExt;
+use jni::objects::JByteBuffer;
 use jni::sys::JNI_FALSE;
 use jni::{
     objects::GlobalRef,
@@ -64,6 +64,7 @@ use std::sync::Mutex;
 use tokio::runtime::Runtime;
 
 use crate::execution::operators::ScanExec;
+use crate::execution::shuffle::{read_ipc_compressed, CompressionCodec};
 use crate::execution::spark_plan::SparkPlan;
 use log::info;
 use once_cell::sync::{Lazy, OnceCell};
@@ -147,7 +148,7 @@ impl PerTaskMemoryPool {
 
 /// Accept serialized query plan and return the address of the native query 
plan.
 /// # Safety
-/// This function is inheritly unsafe since it deals with raw pointers passed 
from JNI.
+/// This function is inherently unsafe since it deals with raw pointers passed 
from JNI.
 #[no_mangle]
 pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
     e: JNIEnv,
@@ -444,7 +445,7 @@ fn pull_input_batches(exec_context: &mut ExecutionContext) 
-> Result<(), CometEr
 /// Accept serialized query plan and the addresses of Arrow Arrays from Spark,
 /// then execute the query. Return addresses of arrow vector.
 /// # Safety
-/// This function is inheritly unsafe since it deals with raw pointers passed 
from JNI.
+/// This function is inherently unsafe since it deals with raw pointers passed 
from JNI.
 #[no_mangle]
 pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
     e: JNIEnv,
@@ -618,7 +619,7 @@ fn get_execution_context<'a>(id: i64) -> &'a mut 
ExecutionContext {
 
 /// Used by Comet shuffle external sorter to write sorted records to disk.
 /// # Safety
-/// This function is inheritly unsafe since it deals with raw pointers passed 
from JNI.
+/// This function is inherently unsafe since it deals with raw pointers passed 
from JNI.
 #[no_mangle]
 pub unsafe extern "system" fn 
Java_org_apache_comet_Native_writeSortedFileNative(
     e: JNIEnv,
@@ -632,6 +633,8 @@ pub unsafe extern "system" fn 
Java_org_apache_comet_Native_writeSortedFileNative
     checksum_enabled: jboolean,
     checksum_algo: jint,
     current_checksum: jlong,
+    compression_codec: jstring,
+    compression_level: jint,
 ) -> jlongArray {
     try_unwrap_or_throw(&e, |mut env| unsafe {
         let data_types = convert_datatype_arrays(&mut env, 
serialized_datatypes)?;
@@ -659,6 +662,18 @@ pub unsafe extern "system" fn 
Java_org_apache_comet_Native_writeSortedFileNative
             Some(current_checksum as u32)
         };
 
+        let compression_codec: String = env
+            .get_string(&JString::from_raw(compression_codec))
+            .unwrap()
+            .into();
+
+        let compression_codec = match compression_codec.as_str() {
+            "zstd" => CompressionCodec::Zstd(compression_level),
+            "lz4" => CompressionCodec::Lz4Frame,
+            "snappy" => CompressionCodec::Snappy,
+            _ => CompressionCodec::Lz4Frame,
+        };
+
         let (written_bytes, checksum) = process_sorted_row_partition(
             row_num,
             batch_size as usize,
@@ -670,6 +685,7 @@ pub unsafe extern "system" fn 
Java_org_apache_comet_Native_writeSortedFileNative
             checksum_enabled,
             checksum_algo,
             current_checksum,
+            &compression_codec,
         )?;
 
         let checksum = if let Some(checksum) = checksum {
@@ -703,3 +719,24 @@ pub extern "system" fn 
Java_org_apache_comet_Native_sortRowPartitionsNative(
         Ok(())
     })
 }
+
+#[no_mangle]
+/// Used by Comet native shuffle reader
+/// # Safety
+/// This function is inherently unsafe since it deals with raw pointers passed from JNI.
+pub unsafe extern "system" fn Java_org_apache_comet_Native_decodeShuffleBlock(
+    e: JNIEnv,
+    _class: JClass,
+    byte_buffer: JByteBuffer,
+    length: jint,
+    array_addrs: jlongArray,
+    schema_addrs: jlongArray,
+) -> jlong {
+    try_unwrap_or_throw(&e, |mut env| {
+        let raw_pointer = env.get_direct_buffer_address(&byte_buffer)?;
+        let length = length as usize;
+        let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) };
+        let batch = read_ipc_compressed(slice)?;
+        prepare_output(&mut env, array_addrs, schema_addrs, batch, false)
+    })
+}
diff --git a/native/core/src/execution/planner.rs 
b/native/core/src/execution/planner.rs
index da452c2f1..294922f2f 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -1067,9 +1067,11 @@ impl PhysicalPlanner {
 
                 let codec = match writer.codec.try_into() {
                     Ok(SparkCompressionCodec::None) => Ok(CompressionCodec::None),
+                    Ok(SparkCompressionCodec::Snappy) => Ok(CompressionCodec::Snappy),
                     Ok(SparkCompressionCodec::Zstd) => {
                         Ok(CompressionCodec::Zstd(writer.compression_level))
                     }
+                    Ok(SparkCompressionCodec::Lz4) => Ok(CompressionCodec::Lz4Frame),
                     _ => Err(ExecutionError::GeneralError(format!(
                         "Unsupported shuffle compression codec: {:?}",
                         writer.codec
diff --git a/native/core/src/execution/shuffle/mod.rs 
b/native/core/src/execution/shuffle/mod.rs
index 8111f5eed..178aff1fa 100644
--- a/native/core/src/execution/shuffle/mod.rs
+++ b/native/core/src/execution/shuffle/mod.rs
@@ -19,4 +19,6 @@ mod list;
 mod map;
 pub mod row;
 mod shuffle_writer;
-pub use shuffle_writer::{write_ipc_compressed, CompressionCodec, 
ShuffleWriterExec};
+pub use shuffle_writer::{
+    read_ipc_compressed, write_ipc_compressed, CompressionCodec, 
ShuffleWriterExec,
+};
diff --git a/native/core/src/execution/shuffle/row.rs 
b/native/core/src/execution/shuffle/row.rs
index 405f64216..9037bd794 100644
--- a/native/core/src/execution/shuffle/row.rs
+++ b/native/core/src/execution/shuffle/row.rs
@@ -3297,6 +3297,7 @@ pub fn process_sorted_row_partition(
     // this is the initial checksum for this method, as it also gets updated 
iteratively
     // inside the loop within the method across batches.
     initial_checksum: Option<u32>,
+    codec: &CompressionCodec,
 ) -> Result<(i64, Option<u32>), CometError> {
     // TODO: We can tune this parameter automatically based on row size and 
cache size.
     let row_step = 10;
@@ -3359,9 +3360,7 @@ pub fn process_sorted_row_partition(
 
         // we do not collect metrics in Native_writeSortedFileNative
         let ipc_time = Time::default();
-        // compression codec is not configurable for 
CometBypassMergeSortShuffleWriter
-        let codec = CompressionCodec::Zstd(1);
-        written += write_ipc_compressed(&batch, &mut cursor, &codec, 
&ipc_time)?;
+        written += write_ipc_compressed(&batch, &mut cursor, codec, 
&ipc_time)?;
 
         if let Some(checksum) = &mut current_checksum {
             checksum.update(&mut cursor)?;
diff --git a/native/core/src/execution/shuffle/shuffle_writer.rs 
b/native/core/src/execution/shuffle/shuffle_writer.rs
index f3fa685b8..e6679d13d 100644
--- a/native/core/src/execution/shuffle/shuffle_writer.rs
+++ b/native/core/src/execution/shuffle/shuffle_writer.rs
@@ -21,6 +21,7 @@ use crate::{
     common::bit::ceil,
     errors::{CometError, CometResult},
 };
+use arrow::ipc::reader::StreamReader;
 use arrow::{datatypes::*, ipc::writer::StreamWriter};
 use async_trait::async_trait;
 use bytes::Buf;
@@ -312,7 +313,7 @@ impl PartitionBuffer {
             repart_timer.stop();
 
             if self.num_active_rows >= self.batch_size {
-                let flush = self.flush(&metrics.ipc_time);
+                let flush = self.flush(metrics);
                 if let Err(e) = flush {
                     return AppendRowStatus::MemDiff(Err(e));
                 }
@@ -330,7 +331,7 @@ impl PartitionBuffer {
     }
 
     /// flush active data into frozen bytes
-    fn flush(&mut self, ipc_time: &Time) -> Result<isize> {
+    fn flush(&mut self, metrics: &ShuffleRepartitionerMetrics) -> Result<isize> {
         if self.num_active_rows == 0 {
             return Ok(0);
         }
@@ -340,14 +341,24 @@ impl PartitionBuffer {
         let active = std::mem::take(&mut self.active);
         let num_rows = self.num_active_rows;
         self.num_active_rows = 0;
+
+        let mut mempool_timer = metrics.mempool_time.timer();
         self.reservation.try_shrink(self.active_slots_mem_size)?;
+        mempool_timer.stop();
 
+        let mut repart_timer = metrics.repart_time.timer();
         let frozen_batch = make_batch(Arc::clone(&self.schema), active, 
num_rows)?;
+        repart_timer.stop();
 
         let frozen_capacity_old = self.frozen.capacity();
         let mut cursor = Cursor::new(&mut self.frozen);
         cursor.seek(SeekFrom::End(0))?;
-        write_ipc_compressed(&frozen_batch, &mut cursor, &self.codec, 
ipc_time)?;
+        write_ipc_compressed(
+            &frozen_batch,
+            &mut cursor,
+            &self.codec,
+            &metrics.encode_time,
+        )?;
 
         mem_diff += (self.frozen.capacity() - frozen_capacity_old) as isize;
         Ok(mem_diff)
@@ -652,7 +663,7 @@ struct ShuffleRepartitionerMetrics {
     mempool_time: Time,
 
     /// Time encoding batches to IPC format
-    ipc_time: Time,
+    encode_time: Time,
 
     /// Time spent writing to disk. Maps to "shuffleWriteTime" in Spark SQL 
Metrics.
     write_time: Time,
@@ -676,7 +687,7 @@ impl ShuffleRepartitionerMetrics {
             baseline: BaselineMetrics::new(metrics, partition),
             repart_time: 
MetricBuilder::new(metrics).subset_time("repart_time", partition),
             mempool_time: 
MetricBuilder::new(metrics).subset_time("mempool_time", partition),
-            ipc_time: MetricBuilder::new(metrics).subset_time("ipc_time", 
partition),
+            encode_time: 
MetricBuilder::new(metrics).subset_time("encode_time", partition),
             write_time: MetricBuilder::new(metrics).subset_time("write_time", 
partition),
             input_batches: 
MetricBuilder::new(metrics).counter("input_batches", partition),
             spill_count: MetricBuilder::new(metrics).spill_count(partition),
@@ -790,6 +801,8 @@ impl ShuffleRepartitioner {
             Partitioning::Hash(exprs, _) => {
                 let (partition_starts, shuffled_partition_ids): (Vec<usize>, 
Vec<usize>) = {
                     let mut timer = self.metrics.repart_time.timer();
+
+                    // evaluate partition expressions
                     let arrays = exprs
                         .iter()
                         .map(|expr| 
expr.evaluate(&input)?.into_array(input.num_rows()))
@@ -923,7 +936,7 @@ impl ShuffleRepartitioner {
         let mut output_batches: Vec<Vec<u8>> = vec![vec![]; 
num_output_partitions];
         let mut offsets = vec![0; num_output_partitions + 1];
         for i in 0..num_output_partitions {
-            buffered_partitions[i].flush(&self.metrics.ipc_time)?;
+            buffered_partitions[i].flush(&self.metrics)?;
             output_batches[i] = std::mem::take(&mut 
buffered_partitions[i].frozen);
         }
 
@@ -1023,20 +1036,19 @@ impl ShuffleRepartitioner {
         }
 
         let mut timer = self.metrics.write_time.timer();
-
         let spillfile = self
             .runtime
             .disk_manager
             .create_tmp_file("shuffle writer spill")?;
+        timer.stop();
+
         let offsets = spill_into(
             &mut self.buffered_partitions,
             spillfile.path(),
             self.num_output_partitions,
-            &self.metrics.ipc_time,
+            &self.metrics,
         )?;
 
-        timer.stop();
-
         let mut spills = self.spills.lock().await;
         let used = self.reservation.size();
         self.metrics.spill_count.add(1);
@@ -1107,16 +1119,18 @@ fn spill_into(
     buffered_partitions: &mut [PartitionBuffer],
     path: &Path,
     num_output_partitions: usize,
-    ipc_time: &Time,
+    metrics: &ShuffleRepartitionerMetrics,
 ) -> Result<Vec<u64>> {
     let mut output_batches: Vec<Vec<u8>> = vec![vec![]; num_output_partitions];
 
     for i in 0..num_output_partitions {
-        buffered_partitions[i].flush(ipc_time)?;
+        buffered_partitions[i].flush(metrics)?;
         output_batches[i] = std::mem::take(&mut buffered_partitions[i].frozen);
     }
     let path = path.to_owned();
 
+    let mut write_timer = metrics.write_time.timer();
+
     let mut offsets = vec![0; num_output_partitions + 1];
     let mut spill_data = OpenOptions::new()
         .write(true)
@@ -1130,6 +1144,8 @@ fn spill_into(
         spill_data.write_all(&output_batches[i])?;
         output_batches[i].clear();
     }
+    write_timer.stop();
+
     // add one extra offset at last to ease partition length computation
     offsets[num_output_partitions] = spill_data.stream_position()?;
     Ok(offsets)
@@ -1549,7 +1565,9 @@ impl Checksum {
 #[derive(Debug, Clone)]
 pub enum CompressionCodec {
     None,
+    Lz4Frame,
     Zstd(i32),
+    Snappy,
 }
 
 /// Writes given record batch as Arrow IPC bytes into given writer.
@@ -1567,17 +1585,41 @@ pub fn write_ipc_compressed<W: Write + Seek>(
     let mut timer = ipc_time.timer();
     let start_pos = output.stream_position()?;
 
-    // write ipc_length placeholder
-    output.write_all(&[0u8; 8])?;
+    // seek past ipc_length placeholder
+    output.seek_relative(8)?;
+
+    // write number of columns because JVM side needs to know how many addresses to allocate
+    let field_count = batch.schema().fields().len();
+    output.write_all(&field_count.to_le_bytes())?;
 
     let output = match codec {
         CompressionCodec::None => {
+            output.write_all(b"NONE")?;
             let mut arrow_writer = StreamWriter::try_new(output, &batch.schema())?;
             arrow_writer.write(batch)?;
             arrow_writer.finish()?;
             arrow_writer.into_inner()?
         }
+        CompressionCodec::Snappy => {
+            output.write_all(b"SNAP")?;
+            let mut wtr = snap::write::FrameEncoder::new(output);
+            let mut arrow_writer = StreamWriter::try_new(&mut wtr, &batch.schema())?;
+            arrow_writer.write(batch)?;
+            arrow_writer.finish()?;
+            wtr.into_inner()
+                .map_err(|e| DataFusionError::Execution(format!("snappy compression error: {}", e)))?
+        }
+        CompressionCodec::Lz4Frame => {
+            output.write_all(b"LZ4_")?;
+            let mut wtr = lz4_flex::frame::FrameEncoder::new(output);
+            let mut arrow_writer = StreamWriter::try_new(&mut wtr, &batch.schema())?;
+            arrow_writer.write(batch)?;
+            arrow_writer.finish()?;
+            wtr.finish()
+                .map_err(|e| DataFusionError::Execution(format!("lz4 compression error: {}", e)))?
+        }
         CompressionCodec::Zstd(level) => {
+            output.write_all(b"ZSTD")?;
             let encoder = zstd::Encoder::new(output, *level)?;
             let mut arrow_writer = StreamWriter::try_new(encoder, &batch.schema())?;
             arrow_writer.write(batch)?;
@@ -1590,6 +1632,13 @@ pub fn write_ipc_compressed<W: Write + Seek>(
     // fill ipc length
     let end_pos = output.stream_position()?;
     let ipc_length = end_pos - start_pos - 8;
+    let max_size = i32::MAX as u64;
+    if ipc_length > max_size {
+        return Err(DataFusionError::Execution(format!(
+            "Shuffle block size {ipc_length} exceeds maximum size of 
{max_size}. \
+            Try reducing batch size or increasing compression level"
+        )));
+    }
 
     // fill ipc length
     output.seek(SeekFrom::Start(start_pos))?;
@@ -1601,6 +1650,33 @@ pub fn write_ipc_compressed<W: Write + Seek>(
     Ok((end_pos - start_pos) as usize)
 }
 
+pub fn read_ipc_compressed(bytes: &[u8]) -> Result<RecordBatch> {
+    match &bytes[0..4] {
+        b"SNAP" => {
+            let decoder = snap::read::FrameDecoder::new(&bytes[4..]);
+            let mut reader = StreamReader::try_new(decoder, None)?;
+            reader.next().unwrap().map_err(|e| e.into())
+        }
+        b"LZ4_" => {
+            let decoder = lz4_flex::frame::FrameDecoder::new(&bytes[4..]);
+            let mut reader = StreamReader::try_new(decoder, None)?;
+            reader.next().unwrap().map_err(|e| e.into())
+        }
+        b"ZSTD" => {
+            let decoder = zstd::Decoder::new(&bytes[4..])?;
+            let mut reader = StreamReader::try_new(decoder, None)?;
+            reader.next().unwrap().map_err(|e| e.into())
+        }
+        b"NONE" => {
+            let mut reader = StreamReader::try_new(&bytes[4..], None)?;
+            reader.next().unwrap().map_err(|e| e.into())
+        }
+        _ => Err(DataFusionError::Execution(
+            "Failed to decode batch: invalid compression codec".to_string(),
+        )),
+    }
+}
+
 /// A stream that yields no record batches which represent end of output.
 pub struct EmptyStream {
     /// Schema representing the data
@@ -1650,18 +1726,24 @@ mod test {
 
     #[test]
     #[cfg_attr(miri, ignore)] // miri can't call foreign function 
`ZSTD_createCCtx`
-    fn write_ipc_zstd() {
+    fn roundtrip_ipc() {
         let batch = create_batch(8192);
-        let mut output = vec![];
-        let mut cursor = Cursor::new(&mut output);
-        write_ipc_compressed(
-            &batch,
-            &mut cursor,
-            &CompressionCodec::Zstd(1),
-            &Time::default(),
-        )
-        .unwrap();
-        assert_eq!(40218, output.len());
+        for codec in &[
+            CompressionCodec::None,
+            CompressionCodec::Zstd(1),
+            CompressionCodec::Snappy,
+            CompressionCodec::Lz4Frame,
+        ] {
+            let mut output = vec![];
+            let mut cursor = Cursor::new(&mut output);
+            let length =
+                write_ipc_compressed(&batch, &mut cursor, codec, &Time::default()).unwrap();
+            assert_eq!(length, output.len());
+
+            let ipc_without_length_prefix = &output[16..];
+            let batch2 = read_ipc_compressed(ipc_without_length_prefix).unwrap();
+            assert_eq!(batch, batch2);
+        }
     }
 
     #[test]
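
The writer and reader above define a simple framing for each shuffle block: an 8-byte little-endian length (counting everything after this prefix), an 8-byte little-endian column count (so the JVM knows how many FFI addresses to allocate), a 4-byte codec tag ("NONE", "SNAP", "LZ4_" or "ZSTD"), and then the (optionally compressed) Arrow IPC stream. A small JVM-side sketch of parsing that header, assuming the block is already in memory (the helper name is hypothetical; the production reader is NativeBatchDecoderIterator further down):

    import java.nio.{ByteBuffer, ByteOrder}
    import java.nio.charset.StandardCharsets

    // hypothetical helper: describe one Comet shuffle block held in a byte array
    def describeShuffleBlock(block: Array[Byte]): String = {
      val buf = ByteBuffer.wrap(block).order(ByteOrder.LITTLE_ENDIAN)
      val ipcLength = buf.getLong()    // bytes that follow this 8-byte prefix
      val columnCount = buf.getLong()  // field count written by write_ipc_compressed
      val tag = new Array[Byte](4)
      buf.get(tag)                     // "NONE", "SNAP", "LZ4_" or "ZSTD"
      val codec = new String(tag, StandardCharsets.UTF_8)
      s"codec=$codec, columns=$columnCount, payload=${ipcLength - 12} bytes"
    }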
diff --git a/native/proto/src/proto/operator.proto 
b/native/proto/src/proto/operator.proto
index 5cb2802da..a3480086c 100644
--- a/native/proto/src/proto/operator.proto
+++ b/native/proto/src/proto/operator.proto
@@ -85,6 +85,8 @@ message Limit {
 enum CompressionCodec {
   None = 0;
   Zstd = 1;
+  Lz4 = 2;
+  Snappy = 3;
 }
 
 message ShuffleWriter {
diff --git 
a/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java
 
b/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java
index cc4495570..1e3762a6c 100644
--- 
a/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java
+++ 
b/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java
@@ -107,6 +107,8 @@ public final class CometShuffleExternalSorter implements 
CometShuffleChecksumSup
   private final long[] partitionChecksums;
 
   private final String checksumAlgorithm;
+  private final String compressionCodec;
+  private final int compressionLevel;
 
   // The memory allocator for this sorter. It is used to allocate/free memory 
pages for this sorter.
   // Because we need to allocate off-heap memory regardless of configured 
Spark memory mode
@@ -153,6 +155,9 @@ public final class CometShuffleExternalSorter implements 
CometShuffleChecksumSup
     this.peakMemoryUsedBytes = getMemoryUsage();
     this.partitionChecksums = createPartitionChecksums(numPartitions, conf);
     this.checksumAlgorithm = getChecksumAlgorithm(conf);
+    this.compressionCodec = 
CometConf$.MODULE$.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC().get();
+    this.compressionLevel =
+        (int) 
CometConf$.MODULE$.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL().get();
 
     this.initialSize = initialSize;
 
@@ -556,7 +561,9 @@ public final class CometShuffleExternalSorter implements 
CometShuffleChecksumSup
                     spillInfo.file,
                     rowPartition,
                     writeMetricsToUse,
-                    preferDictionaryRatio);
+                    preferDictionaryRatio,
+                    compressionCodec,
+                    compressionLevel);
             spillInfo.partitionLengths[currentPartition] = written;
 
             // Store the checksum for the current partition.
@@ -578,7 +585,13 @@ public final class CometShuffleExternalSorter implements 
CometShuffleChecksumSup
       if (currentPartition != -1) {
         long written =
             doSpilling(
-                dataTypes, spillInfo.file, rowPartition, writeMetricsToUse, 
preferDictionaryRatio);
+                dataTypes,
+                spillInfo.file,
+                rowPartition,
+                writeMetricsToUse,
+                preferDictionaryRatio,
+                compressionCodec,
+                compressionLevel);
         spillInfo.partitionLengths[currentPartition] = written;
 
         synchronized (spills) {
diff --git 
a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometDiskBlockWriter.java
 
b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometDiskBlockWriter.java
index dcb9d99d3..006e8ce97 100644
--- 
a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometDiskBlockWriter.java
+++ 
b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometDiskBlockWriter.java
@@ -103,6 +103,8 @@ public final class CometDiskBlockWriter {
   private long totalWritten = 0L;
   private boolean initialized = false;
   private final int columnarBatchSize;
+  private final String compressionCodec;
+  private final int compressionLevel;
   private final boolean isAsync;
   private final int asyncThreadNum;
   private final ExecutorService threadPool;
@@ -153,6 +155,9 @@ public final class CometDiskBlockWriter {
     this.threadPool = threadPool;
 
     this.columnarBatchSize = (int) 
CometConf$.MODULE$.COMET_COLUMNAR_SHUFFLE_BATCH_SIZE().get();
+    this.compressionCodec = 
CometConf$.MODULE$.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC().get();
+    this.compressionLevel =
+        (int) 
CometConf$.MODULE$.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL().get();
 
     this.numElementsForSpillThreshold =
         (int) 
CometConf$.MODULE$.COMET_COLUMNAR_SHUFFLE_SPILL_THRESHOLD().get();
@@ -397,7 +402,14 @@ public final class CometDiskBlockWriter {
       synchronized (file) {
         outputRecords += rowPartition.getNumRows();
         written =
-            doSpilling(dataTypes, file, rowPartition, writeMetricsToUse, 
preferDictionaryRatio);
+            doSpilling(
+                dataTypes,
+                file,
+                rowPartition,
+                writeMetricsToUse,
+                preferDictionaryRatio,
+                compressionCodec,
+                compressionLevel);
       }
 
       // Update metrics
diff --git 
a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java
 
b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java
index 3dc86b05b..a4f09b415 100644
--- 
a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java
+++ 
b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java
@@ -171,7 +171,9 @@ public abstract class SpillWriter {
       File file,
       RowPartition rowPartition,
       ShuffleWriteMetricsReporter writeMetricsToUse,
-      double preferDictionaryRatio) {
+      double preferDictionaryRatio,
+      String compressionCodec,
+      int compressionLevel) {
     long[] addresses = rowPartition.getRowAddresses();
     int[] sizes = rowPartition.getRowSizes();
 
@@ -190,7 +192,9 @@ public abstract class SpillWriter {
             batchSize,
             checksumEnabled,
             checksumAlgo,
-            currentChecksum);
+            currentChecksum,
+            compressionCodec,
+            compressionLevel);
 
     long written = results[0];
     checksum = results[1];
diff --git a/spark/src/main/scala/org/apache/comet/Native.scala 
b/spark/src/main/scala/org/apache/comet/Native.scala
index e5728009e..dbcab15b4 100644
--- a/spark/src/main/scala/org/apache/comet/Native.scala
+++ b/spark/src/main/scala/org/apache/comet/Native.scala
@@ -19,6 +19,8 @@
 
 package org.apache.comet
 
+import java.nio.ByteBuffer
+
 import org.apache.spark.CometTaskMemoryManager
 import org.apache.spark.sql.comet.CometMetricNode
 
@@ -118,9 +120,14 @@ class Native extends NativeBase {
    * @param currentChecksum
    *   the current checksum of the file. As the checksum is computed 
incrementally, this is used
    *   to resume the computation of checksum for previous written data.
+   * @param compressionCodec
+   *   the compression codec
+   * @param compressionLevel
+   *   the compression level
    * @return
    *   [the number of bytes written to disk, the checksum]
    */
+  // scalastyle:off
   @native def writeSortedFileNative(
       addresses: Array[Long],
       rowSizes: Array[Int],
@@ -130,7 +137,10 @@ class Native extends NativeBase {
       batchSize: Int,
       checksumEnabled: Boolean,
       checksumAlgo: Int,
-      currentChecksum: Long): Array[Long]
+      currentChecksum: Long,
+      compressionCodec: String,
+      compressionLevel: Int): Array[Long]
+  // scalastyle:on
 
   /**
    * Sorts partition ids of Spark unsafe rows in place. Used by Comet shuffle 
external sorter.
@@ -141,4 +151,22 @@ class Native extends NativeBase {
    *   the size of the array.
    */
   @native def sortRowPartitionsNative(addr: Long, size: Long): Unit
+
+  /**
+   * Decompress and decode a native shuffle block.
+   * @param shuffleBlock
+   *   the encoded and compressed shuffle block.
+   * @param length
+   *   the limit of the byte buffer.
+   * @param arrayAddrs
+   *   the addresses of the Arrow array structures to populate via Arrow FFI.
+   * @param schemaAddrs
+   *   the addresses of the Arrow schema structures to populate via Arrow FFI.
+   */
+  @native def decodeShuffleBlock(
+      shuffleBlock: ByteBuffer,
+      length: Int,
+      arrayAddrs: Array[Long],
+      schemaAddrs: Array[Long]): Long
+
 }
diff --git 
a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala 
b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala
index a26fa28c8..53370a03b 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala
@@ -132,10 +132,11 @@ object CometMetricNode {
 
   def shuffleMetrics(sc: SparkContext): Map[String, SQLMetric] = {
     Map(
-      "elapsed_compute" -> SQLMetrics.createNanoTimingMetric(sc, "native 
shuffle time"),
+      "elapsed_compute" -> SQLMetrics.createNanoTimingMetric(sc, "native 
shuffle writer time"),
       "mempool_time" -> SQLMetrics.createNanoTimingMetric(sc, "memory pool 
time"),
       "repart_time" -> SQLMetrics.createNanoTimingMetric(sc, "repartition 
time"),
-      "ipc_time" -> SQLMetrics.createNanoTimingMetric(sc, "encoding and 
compression time"),
+      "encode_time" -> SQLMetrics.createNanoTimingMetric(sc, "encoding and 
compression time"),
+      "decode_time" -> SQLMetrics.createNanoTimingMetric(sc, "decoding and 
decompression time"),
       "spill_count" -> SQLMetrics.createMetric(sc, "number of spills"),
       "spilled_bytes" -> SQLMetrics.createMetric(sc, "spilled bytes"),
       "input_batches" -> SQLMetrics.createMetric(sc, "number of input 
batches"))
diff --git 
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala
 
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala
index 74c655950..1283a745a 100644
--- 
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala
+++ 
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala
@@ -25,8 +25,14 @@ import org.apache.spark.{InterruptibleIterator, 
MapOutputTracker, SparkEnv, Task
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.serializer.SerializerManager
-import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader, 
ShuffleReadMetricsReporter}
-import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, 
ShuffleBlockFetcherIterator}
+import org.apache.spark.shuffle.BaseShuffleHandle
+import org.apache.spark.shuffle.ShuffleReader
+import org.apache.spark.shuffle.ShuffleReadMetricsReporter
+import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.storage.BlockId
+import org.apache.spark.storage.BlockManager
+import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.storage.ShuffleBlockFetcherIterator
 import org.apache.spark.util.CompletionIterator
 
 /**
@@ -79,7 +85,7 @@ class CometBlockStoreShuffleReader[K, C](
 
   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
-    var currentReadIterator: ArrowReaderIterator = null
+    var currentReadIterator: NativeBatchDecoderIterator = null
 
     // Closes last read iterator after the task is finished.
     // We need to close read iterator during iterating input streams,
@@ -91,18 +97,16 @@ class CometBlockStoreShuffleReader[K, C](
       }
     }
 
-    val recordIter = fetchIterator
-      .flatMap { case (_, inputStream) =>
-        IpcInputStreamIterator(inputStream, decompressingNeeded = true, 
context)
-          .flatMap { channel =>
-            if (currentReadIterator != null) {
-              // Closes previous read iterator.
-              currentReadIterator.close()
-            }
-            currentReadIterator = new ArrowReaderIterator(channel, 
this.getClass.getSimpleName)
-            currentReadIterator.map((0, _)) // use 0 as key since it's not used
-          }
-      }
+    val recordIter: Iterator[(Int, ColumnarBatch)] = fetchIterator
+      .flatMap(blockIdAndStream => {
+        if (currentReadIterator != null) {
+          currentReadIterator.close()
+        }
+        currentReadIterator =
+          NativeBatchDecoderIterator(blockIdAndStream._2, context, dep.decodeTime)
+        currentReadIterator
+      })
+      .map(b => (0, b))
 
     // Update the context task metrics for each record read.
     val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
diff --git 
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala
 
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala
index 7b1d1f127..8c8aed28e 100644
--- 
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala
+++ 
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala
@@ -25,6 +25,7 @@ import org.apache.spark.{Aggregator, Partitioner, 
ShuffleDependency, SparkEnv}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.ShuffleWriteProcessor
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -39,7 +40,8 @@ class CometShuffleDependency[K: ClassTag, V: ClassTag, C: 
ClassTag](
     override val mapSideCombine: Boolean = false,
     override val shuffleWriterProcessor: ShuffleWriteProcessor = new 
ShuffleWriteProcessor,
     val shuffleType: ShuffleType = CometNativeShuffle,
-    val schema: Option[StructType] = None)
+    val schema: Option[StructType] = None,
+    val decodeTime: SQLMetric)
     extends ShuffleDependency[K, V, C](
       _rdd,
       partitioner,
diff --git 
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
 
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
index 3a11b8b28..041411b3f 100644
--- 
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
+++ 
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
@@ -238,7 +238,8 @@ object CometShuffleExchangeExec extends 
ShimCometShuffleExchangeExec {
       partitioner = new Partitioner {
         override def numPartitions: Int = outputPartitioning.numPartitions
         override def getPartition(key: Any): Int = key.asInstanceOf[Int]
-      })
+      },
+      decodeTime = metrics("decode_time"))
     dependency
   }
 
@@ -435,7 +436,8 @@ object CometShuffleExchangeExec extends 
ShimCometShuffleExchangeExec {
         serializer,
         shuffleWriterProcessor = 
ShuffleExchangeExec.createShuffleWriteProcessor(writeMetrics),
         shuffleType = CometColumnarShuffle,
-        schema = Some(fromAttributes(outputAttributes)))
+        schema = Some(fromAttributes(outputAttributes)),
+        decodeTime = writeMetrics("decode_time"))
 
     dependency
   }
@@ -481,7 +483,7 @@ class CometShuffleWriteProcessor(
 
     val detailedMetrics = Seq(
       "elapsed_compute",
-      "ipc_time",
+      "encode_time",
       "repart_time",
       "mempool_time",
       "input_batches",
@@ -557,13 +559,16 @@ class CometShuffleWriteProcessor(
       if (SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)) {
         val codec = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get() match 
{
           case "zstd" => CompressionCodec.Zstd
+          case "lz4" => CompressionCodec.Lz4
+          case "snappy" => CompressionCodec.Snappy
           case other => throw new UnsupportedOperationException(s"invalid 
codec: $other")
         }
         shuffleWriterBuilder.setCodec(codec)
       } else {
         shuffleWriterBuilder.setCodec(CompressionCodec.None)
       }
-      
shuffleWriterBuilder.setCompressionLevel(CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_LEVEL.get)
+      shuffleWriterBuilder.setCompressionLevel(
+        CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL.get)
 
       outputPartitioning match {
         case _: HashPartitioning =>
diff --git 
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/IpcInputStreamIterator.scala
 
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/IpcInputStreamIterator.scala
deleted file mode 100644
index aa4055048..000000000
--- 
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/IpcInputStreamIterator.scala
+++ /dev/null
@@ -1,126 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.sql.comet.execution.shuffle
-
-import java.io.{EOFException, InputStream}
-import java.nio.{ByteBuffer, ByteOrder}
-import java.nio.channels.{Channels, ReadableByteChannel}
-
-import org.apache.spark.TaskContext
-import org.apache.spark.internal.Logging
-import org.apache.spark.network.util.LimitedInputStream
-
-case class IpcInputStreamIterator(
-    var in: InputStream,
-    decompressingNeeded: Boolean,
-    taskContext: TaskContext)
-    extends Iterator[ReadableByteChannel]
-    with Logging {
-
-  private[execution] val channel: ReadableByteChannel = if (in != null) {
-    Channels.newChannel(in)
-  } else {
-    null
-  }
-
-  private val ipcLengthsBuf = 
ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN)
-
-  // NOTE:
-  // since all ipcs are sharing the same input stream and channel, the second
-  // hasNext() must be called after the first ipc has been completely 
processed.
-
-  private[execution] var consumed = true
-  private var finished = false
-  private var currentIpcLength = 0L
-  private var currentLimitedInputStream: LimitedInputStream = _
-
-  taskContext.addTaskCompletionListener[Unit](_ => {
-    closeInputStream()
-  })
-
-  override def hasNext: Boolean = {
-    if (in == null || finished) {
-      return false
-    }
-
-    // If we've read the length of the next IPC, we don't need to read it 
again.
-    if (!consumed) {
-      return true
-    }
-
-    if (currentLimitedInputStream != null) {
-      currentLimitedInputStream.skip(Int.MaxValue)
-      currentLimitedInputStream = null
-    }
-
-    // Reads the length of IPC bytes
-    ipcLengthsBuf.clear()
-    while (ipcLengthsBuf.hasRemaining && channel.read(ipcLengthsBuf) >= 0) {}
-
-    // If we reach the end of the stream, we are done, or if we read partial 
length
-    // then the stream is corrupted.
-    if (ipcLengthsBuf.hasRemaining) {
-      if (ipcLengthsBuf.position() == 0) {
-        finished = true
-        closeInputStream()
-        return false
-      }
-      throw new EOFException("Data corrupt: unexpected EOF while reading 
compressed ipc lengths")
-    }
-
-    ipcLengthsBuf.flip()
-    currentIpcLength = ipcLengthsBuf.getLong
-
-    // Skips empty IPC
-    if (currentIpcLength == 0) {
-      return hasNext
-    }
-    consumed = false
-    return true
-  }
-
-  override def next(): ReadableByteChannel = {
-    if (!hasNext) {
-      throw new NoSuchElementException
-    }
-    assert(!consumed)
-    consumed = true
-
-    val is = new LimitedInputStream(Channels.newInputStream(channel), 
currentIpcLength, false)
-    currentLimitedInputStream = is
-
-    if (decompressingNeeded) {
-      ShuffleUtils.compressionCodecForShuffling match {
-        case Some(codec) => 
Channels.newChannel(codec.compressedInputStream(is))
-        case _ => Channels.newChannel(is)
-      }
-    } else {
-      Channels.newChannel(is)
-    }
-  }
-
-  private def closeInputStream(): Unit =
-    synchronized {
-      if (in != null) {
-        in.close()
-        in = null
-      }
-    }
-}
diff --git 
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala
 
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala
new file mode 100644
index 000000000..2839c9bd8
--- /dev/null
+++ 
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.execution.shuffle
+
+import java.io.{EOFException, InputStream}
+import java.nio.{ByteBuffer, ByteOrder}
+import java.nio.channels.{Channels, ReadableByteChannel}
+
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+import org.apache.comet.Native
+import org.apache.comet.vector.NativeUtil
+
+/**
+ * This iterator wraps a Spark input stream that is reading shuffle blocks generated by the Comet
+ * native ShuffleWriterExec, calls native code to decompress and decode the shuffle blocks, and
+ * uses Arrow FFI to return the Arrow record batches.
+ */
+case class NativeBatchDecoderIterator(
+    var in: InputStream,
+    taskContext: TaskContext,
+    decodeTime: SQLMetric)
+    extends Iterator[ColumnarBatch] {
+
+  private var isClosed = false
+  private val longBuf = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN)
+  private val native = new Native()
+  private val nativeUtil = new NativeUtil()
+  private var currentBatch: ColumnarBatch = null
+  private var batch = fetchNext()
+
+  import NativeBatchDecoderIterator.threadLocalDataBuf
+
+  if (taskContext != null) {
+    taskContext.addTaskCompletionListener[Unit](_ => {
+      close()
+    })
+  }
+
+  private val channel: ReadableByteChannel = if (in != null) {
+    Channels.newChannel(in)
+  } else {
+    null
+  }
+
+  def hasNext(): Boolean = {
+    if (channel == null || isClosed) {
+      return false
+    }
+    if (batch.isDefined) {
+      return true
+    }
+
+    // Release the previous batch.
+    if (currentBatch != null) {
+      currentBatch.close()
+      currentBatch = null
+    }
+
+    batch = fetchNext()
+    if (batch.isEmpty) {
+      close()
+      return false
+    }
+    true
+  }
+
+  def next(): ColumnarBatch = {
+    if (!hasNext) {
+      throw new NoSuchElementException
+    }
+
+    val nextBatch = batch.get
+
+    currentBatch = nextBatch
+    batch = None
+    currentBatch
+  }
+
+  private def fetchNext(): Option[ColumnarBatch] = {
+    if (channel == null || isClosed) {
+      return None
+    }
+
+    // read compressed batch size from header
+    try {
+      longBuf.clear()
+      while (longBuf.hasRemaining && channel.read(longBuf) >= 0) {}
+    } catch {
+      case _: EOFException =>
+        close()
+        return None
+    }
+
+    // If we reach the end of the stream, we are done, or if we read partial length
+    // then the stream is corrupted.
+    if (longBuf.hasRemaining) {
+      if (longBuf.position() == 0) {
+        close()
+        return None
+      }
+      throw new EOFException("Data corrupt: unexpected EOF while reading 
compressed ipc lengths")
+    }
+
+    // get compressed length (including headers)
+    longBuf.flip()
+    val compressedLength = longBuf.getLong
+
+    // read field count from header
+    longBuf.clear()
+    while (longBuf.hasRemaining && channel.read(longBuf) >= 0) {}
+    if (longBuf.hasRemaining) {
+      throw new EOFException("Data corrupt: unexpected EOF while reading field 
count")
+    }
+    longBuf.flip()
+    val fieldCount = longBuf.getLong.toInt
+
+    // read body
+    val bytesToRead = compressedLength - 8
+    if (bytesToRead > Integer.MAX_VALUE) {
+      // very unlikely that shuffle block will reach 2GB
+      throw new IllegalStateException(
+        s"Native shuffle block size of $bytesToRead exceeds " +
+          s"maximum of ${Integer.MAX_VALUE}. Try reducing shuffle batch size.")
+    }
+    var dataBuf = threadLocalDataBuf.get()
+    if (dataBuf.capacity() < bytesToRead) {
+      val newCapacity = (bytesToRead * 2L).min(Integer.MAX_VALUE).toInt
+      dataBuf = ByteBuffer.allocateDirect(newCapacity)
+      threadLocalDataBuf.set(dataBuf)
+    }
+    dataBuf.clear()
+    dataBuf.limit(bytesToRead.toInt)
+    while (dataBuf.hasRemaining && channel.read(dataBuf) >= 0) {}
+    if (dataBuf.hasRemaining) {
+      throw new EOFException("Data corrupt: unexpected EOF while reading 
compressed batch")
+    }
+
+    // make native call to decode batch
+    val startTime = System.nanoTime()
+    val batch = nativeUtil.getNextBatch(
+      fieldCount,
+      (arrayAddrs, schemaAddrs) => {
+        native.decodeShuffleBlock(dataBuf, bytesToRead.toInt, arrayAddrs, schemaAddrs)
+      })
+    decodeTime.add(System.nanoTime() - startTime)
+
+    batch
+  }
+
+  def close(): Unit = {
+    synchronized {
+      if (!isClosed) {
+        if (currentBatch != null) {
+          currentBatch.close()
+          currentBatch = null
+        }
+        in.close()
+        isClosed = true
+      }
+    }
+  }
+}
+
+object NativeBatchDecoderIterator {
+  private val threadLocalDataBuf: ThreadLocal[ByteBuffer] = ThreadLocal.withInitial(() => {
+    ByteBuffer.allocateDirect(128 * 1024)
+  })
+}
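
Taken on its own, the decoder behaves like any other Iterator over ColumnarBatch. A usage sketch, assuming a SparkContext named sc is in scope and the code runs in (or imports from) the same package; the file path is hypothetical, and inside a real reduce task the stream and the metric come from ShuffleBlockFetcherIterator and CometShuffleDependency.decodeTime instead:

    import java.io.FileInputStream

    import org.apache.spark.TaskContext
    import org.apache.spark.sql.execution.metric.SQLMetrics

    val decodeTime = SQLMetrics.createNanoTimingMetric(sc, "decoding and decompression time")
    val it = NativeBatchDecoderIterator(
      new FileInputStream("/tmp/comet-shuffle-example.data"), // hypothetical shuffle data file
      TaskContext.get(), // may be null outside a task; the iterator only registers a listener if non-null
      decodeTime)
    try {
      while (it.hasNext()) {
        val batch = it.next()
        println(s"decoded batch: ${batch.numRows()} rows x ${batch.numCols()} columns")
      }
    } finally {
      it.close()
    }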
diff --git 
a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala 
b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
index 6130e4cd5..13344c0ed 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
@@ -1132,7 +1132,8 @@ class CometShuffleManagerSuite extends CometTestBase {
         partitioner = new Partitioner {
           override def numPartitions: Int = 50
           override def getPartition(key: Any): Int = key.asInstanceOf[Int]
-        })
+        },
+        decodeTime = null)
 
       assert(CometShuffleManager.shouldBypassMergeSort(conf, dependency))
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
