This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new ea6d2051 feat: Make native shuffle compression configurable and respect `spark.shuffle.compress` (#1185)
ea6d2051 is described below

commit ea6d20511e813a2698c47a964b3a0739e9543add
Author: Andy Grove <[email protected]>
AuthorDate: Fri Dec 20 11:11:14 2024 -0700

    feat: Make native shuffle compression configurable and respect `spark.shuffle.compress` (#1185)
    
    * Make shuffle compression codec and level configurable
    
    * remove lz4 references
    
    * docs
    
    * update comment
    
    * clippy
    
    * fix benches
    
    * clippy
    
    * clippy
    
    * disable test for miri
    
    * remove lz4 reference from proto
---
 .../main/scala/org/apache/comet/CometConf.scala    | 14 +++-
 .../execution/shuffle/IpcInputStreamIterator.scala |  6 +-
 .../sql/comet/execution/shuffle/ShuffleUtils.scala | 31 +++++---
 docs/source/user-guide/configs.md                  |  3 +-
 docs/source/user-guide/tuning.md                   |  6 ++
 native/core/benches/shuffle_writer.rs              | 87 +++++++++++++++------
 native/core/src/execution/planner.rs               | 17 +++-
 native/core/src/execution/shuffle/mod.rs           |  2 +-
 native/core/src/execution/shuffle/row.rs           |  5 +-
 .../core/src/execution/shuffle/shuffle_writer.rs   | 90 ++++++++++++++++++----
 native/proto/src/proto/operator.proto              |  7 ++
 .../shuffle/CometShuffleExchangeExec.scala         | 14 +++-
 .../execution/shuffle/CometShuffleManager.scala    |  2 +-
 13 files changed, 221 insertions(+), 63 deletions(-)

diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala
index b602d7cf..8815ac4e 100644
--- a/common/src/main/scala/org/apache/comet/CometConf.scala
+++ b/common/src/main/scala/org/apache/comet/CometConf.scala
@@ -272,13 +272,21 @@ object CometConf extends ShimCometConf {
       .booleanConf
       .createWithDefault(false)
 
-  val COMET_EXEC_SHUFFLE_CODEC: ConfigEntry[String] = conf(
-    s"$COMET_EXEC_CONFIG_PREFIX.shuffle.codec")
+  val COMET_EXEC_SHUFFLE_COMPRESSION_CODEC: ConfigEntry[String] = conf(
+    s"$COMET_EXEC_CONFIG_PREFIX.shuffle.compression.codec")
     .doc(
-      "The codec of Comet native shuffle used to compress shuffle data. Only zstd is supported.")
+      "The codec of Comet native shuffle used to compress shuffle data. Only zstd is supported. " +
+        "Compression can be disabled by setting spark.shuffle.compress=false.")
     .stringConf
+    .checkValues(Set("zstd"))
     .createWithDefault("zstd")
 
+  val COMET_EXEC_SHUFFLE_COMPRESSION_LEVEL: ConfigEntry[Int] =
+    conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.compression.level")
+      .doc("The compression level to use when compression shuffle files.")
+      .intConf
+      .createWithDefault(1)
+
   val COMET_COLUMNAR_SHUFFLE_ASYNC_ENABLED: ConfigEntry[Boolean] =
     conf("spark.comet.columnar.shuffle.async.enabled")
       .doc("Whether to enable asynchronous shuffle for Arrow-based shuffle.")
diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/IpcInputStreamIterator.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/IpcInputStreamIterator.scala
index 281c4810..d1d5af35 100644
--- a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/IpcInputStreamIterator.scala
+++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/IpcInputStreamIterator.scala
@@ -110,8 +110,10 @@ case class IpcInputStreamIterator(
     currentLimitedInputStream = is
 
     if (decompressingNeeded) {
-      val zs = ShuffleUtils.compressionCodecForShuffling.compressedInputStream(is)
-      Channels.newChannel(zs)
+      ShuffleUtils.compressionCodecForShuffling match {
+        case Some(codec) => Channels.newChannel(codec.compressedInputStream(is))
+        case _ => Channels.newChannel(is)
+      }
     } else {
       Channels.newChannel(is)
     }
diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ShuffleUtils.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ShuffleUtils.scala
index eea134ab..23b4a5ec 100644
--- a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ShuffleUtils.scala
+++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ShuffleUtils.scala
@@ -21,22 +21,33 @@ package org.apache.spark.sql.comet.execution.shuffle
 
 import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
-import org.apache.spark.internal.config.IO_COMPRESSION_CODEC
+import org.apache.spark.internal.config.{IO_COMPRESSION_CODEC, SHUFFLE_COMPRESS}
 import org.apache.spark.io.CompressionCodec
-import org.apache.spark.sql.internal.SQLConf
 
 import org.apache.comet.CometConf
 
 private[spark] object ShuffleUtils extends Logging {
-  lazy val compressionCodecForShuffling: CompressionCodec = {
+  // optional compression codec to use when compressing shuffle files
+  lazy val compressionCodecForShuffling: Option[CompressionCodec] = {
     val sparkConf = SparkEnv.get.conf
-    val codecName = CometConf.COMET_EXEC_SHUFFLE_CODEC.get(SQLConf.get)
-
-    // only zstd compression is supported at the moment
-    if (codecName != "zstd") {
-      logWarning(
-        s"Overriding config ${IO_COMPRESSION_CODEC}=${codecName} in shuffling, force using zstd")
+    val shuffleCompressionEnabled = sparkConf.getBoolean(SHUFFLE_COMPRESS.key, true)
+    val sparkShuffleCodec = sparkConf.get(IO_COMPRESSION_CODEC.key, "lz4")
+    val cometShuffleCodec = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get()
+    if (shuffleCompressionEnabled) {
+      if (sparkShuffleCodec != cometShuffleCodec) {
+        logWarning(
+          s"Overriding config $IO_COMPRESSION_CODEC=$sparkShuffleCodec in shuffling, " +
+            s"force using $cometShuffleCodec")
+      }
+      cometShuffleCodec match {
+        case "zstd" =>
+          Some(CompressionCodec.createCodec(sparkConf, "zstd"))
+        case other =>
+          throw new UnsupportedOperationException(
+            s"Unsupported shuffle compression codec: $other")
+      }
+    } else {
+      None
     }
-    CompressionCodec.createCodec(sparkConf, "zstd")
   }
 }
diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md
index 69da7922..7881f076 100644
--- a/docs/source/user-guide/configs.md
+++ b/docs/source/user-guide/configs.md
@@ -50,7 +50,8 @@ Comet provides the following configuration settings.
 | spark.comet.exec.memoryFraction | The fraction of memory from Comet memory overhead that the native memory manager can use for execution. The purpose of this config is to set aside memory for untracked data structures, as well as imprecise size estimation during memory acquisition. | 0.7 |
 | spark.comet.exec.project.enabled | Whether to enable project by default. | true |
 | spark.comet.exec.replaceSortMergeJoin | Experimental feature to force Spark to replace SortMergeJoin with ShuffledHashJoin for improved performance. This feature is not stable yet. For more information, refer to the Comet Tuning Guide (https://datafusion.apache.org/comet/user-guide/tuning.html). | false |
-| spark.comet.exec.shuffle.codec | The codec of Comet native shuffle used to compress shuffle data. Only zstd is supported. | zstd |
+| spark.comet.exec.shuffle.compression.codec | The codec of Comet native shuffle used to compress shuffle data. Only zstd is supported. Compression can be disabled by setting spark.shuffle.compress=false. | zstd |
+| spark.comet.exec.shuffle.compression.level | The compression level to use when compression shuffle files. | 1 |
 | spark.comet.exec.shuffle.enabled | Whether to enable Comet native shuffle. Note that this requires setting 'spark.shuffle.manager' to 'org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager'. 'spark.shuffle.manager' must be set before starting the Spark application and cannot be changed during the application. | true |
 | spark.comet.exec.sort.enabled | Whether to enable sort by default. | true |
 | spark.comet.exec.sortMergeJoin.enabled | Whether to enable sortMergeJoin by default. | true |
diff --git a/docs/source/user-guide/tuning.md b/docs/source/user-guide/tuning.md
index d68481d1..e04e750b 100644
--- a/docs/source/user-guide/tuning.md
+++ b/docs/source/user-guide/tuning.md
@@ -103,6 +103,12 @@ native shuffle currently only supports `HashPartitioning` and `SinglePartitionin
 To enable native shuffle, set `spark.comet.exec.shuffle.mode` to `native`. If this mode is explicitly set,
 then any shuffle operations that cannot be supported in this mode will fall back to Spark.
 
+### Shuffle Compression
+
+By default, Spark compresses shuffle files using LZ4 compression. Comet overrides this behavior with ZSTD compression.
+Compression can be disabled by setting `spark.shuffle.compress=false`, which may result in faster shuffle times in 
+certain environments, such as single-node setups with fast NVMe drives, at the expense of increased disk space usage.
+
 ## Explain Plan
 ### Extended Explain
 With Spark 4.0.0 and newer, Comet can provide extended explain plan information in the Spark UI. Currently this lists
diff --git a/native/core/benches/shuffle_writer.rs b/native/core/benches/shuffle_writer.rs
index 27288723..865ca73b 100644
--- a/native/core/benches/shuffle_writer.rs
+++ b/native/core/benches/shuffle_writer.rs
@@ -15,36 +15,47 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use arrow_array::builder::Int32Builder;
 use arrow_array::{builder::StringBuilder, RecordBatch};
 use arrow_schema::{DataType, Field, Schema};
-use comet::execution::shuffle::ShuffleWriterExec;
+use comet::execution::shuffle::{write_ipc_compressed, CompressionCodec, ShuffleWriterExec};
 use criterion::{criterion_group, criterion_main, Criterion};
+use datafusion::physical_plan::metrics::Time;
 use datafusion::{
     physical_plan::{common::collect, memory::MemoryExec, ExecutionPlan},
     prelude::SessionContext,
 };
 use datafusion_physical_expr::{expressions::Column, Partitioning};
+use std::io::Cursor;
 use std::sync::Arc;
 use tokio::runtime::Runtime;
 
 fn criterion_benchmark(c: &mut Criterion) {
-    let batch = create_batch();
-    let mut batches = Vec::new();
-    for _ in 0..10 {
-        batches.push(batch.clone());
-    }
-    let partitions = &[batches];
-    let exec = ShuffleWriterExec::try_new(
-        Arc::new(MemoryExec::try_new(partitions, batch.schema(), None).unwrap()),
-        Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 16),
-        "/tmp/data.out".to_string(),
-        "/tmp/index.out".to_string(),
-    )
-    .unwrap();
-
     let mut group = c.benchmark_group("shuffle_writer");
-    group.bench_function("shuffle_writer", |b| {
+    group.bench_function("shuffle_writer: encode (no compression))", |b| {
+        let batch = create_batch(8192, true);
+        let mut buffer = vec![];
+        let mut cursor = Cursor::new(&mut buffer);
+        let ipc_time = Time::default();
+        b.iter(|| write_ipc_compressed(&batch, &mut cursor, &CompressionCodec::None, &ipc_time));
+    });
+    group.bench_function("shuffle_writer: encode and compress (zstd level 1)", |b| {
+        let batch = create_batch(8192, true);
+        let mut buffer = vec![];
+        let mut cursor = Cursor::new(&mut buffer);
+        let ipc_time = Time::default();
+        b.iter(|| write_ipc_compressed(&batch, &mut cursor, &CompressionCodec::Zstd(1), &ipc_time));
+    });
+    group.bench_function("shuffle_writer: encode and compress (zstd level 6)", |b| {
+        let batch = create_batch(8192, true);
+        let mut buffer = vec![];
+        let mut cursor = Cursor::new(&mut buffer);
+        let ipc_time = Time::default();
+        b.iter(|| write_ipc_compressed(&batch, &mut cursor, &CompressionCodec::Zstd(6), &ipc_time));
+    });
+    group.bench_function("shuffle_writer: end to end", |b| {
         let ctx = SessionContext::new();
+        let exec = create_shuffle_writer_exec(CompressionCodec::Zstd(1));
         b.iter(|| {
             let task_ctx = ctx.task_ctx();
             let stream = exec.execute(0, task_ctx).unwrap();
@@ -54,19 +65,47 @@ fn criterion_benchmark(c: &mut Criterion) {
     });
 }
 
-fn create_batch() -> RecordBatch {
-    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));
+fn create_shuffle_writer_exec(compression_codec: CompressionCodec) -> ShuffleWriterExec {
+    let batches = create_batches(8192, 10);
+    let schema = batches[0].schema();
+    let partitions = &[batches];
+    ShuffleWriterExec::try_new(
+        Arc::new(MemoryExec::try_new(partitions, schema, None).unwrap()),
+        Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 16),
+        compression_codec,
+        "/tmp/data.out".to_string(),
+        "/tmp/index.out".to_string(),
+    )
+    .unwrap()
+}
+
+fn create_batches(size: usize, count: usize) -> Vec<RecordBatch> {
+    let batch = create_batch(size, true);
+    let mut batches = Vec::new();
+    for _ in 0..count {
+        batches.push(batch.clone());
+    }
+    batches
+}
+
+fn create_batch(num_rows: usize, allow_nulls: bool) -> RecordBatch {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("c0", DataType::Int32, true),
+        Field::new("c1", DataType::Utf8, true),
+    ]));
+    let mut a = Int32Builder::new();
     let mut b = StringBuilder::new();
-    for i in 0..8192 {
-        if i % 10 == 0 {
+    for i in 0..num_rows {
+        a.append_value(i as i32);
+        if allow_nulls && i % 10 == 0 {
             b.append_null();
         } else {
-            b.append_value(format!("{i}"));
+            b.append_value(format!("this is string number {i}"));
         }
     }
-    let array = b.finish();
-
-    RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap()
+    let a = a.finish();
+    let b = b.finish();
+    RecordBatch::try_new(schema.clone(), vec![Arc::new(a), Arc::new(b)]).unwrap()
 }
 
 fn config() -> Criterion {
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 3ac830c0..0a749335 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -68,6 +68,7 @@ use datafusion_comet_spark_expr::{create_comet_physical_fun, create_negate_expr}
 use datafusion_functions_nested::concat::ArrayAppend;
 use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
 
+use crate::execution::shuffle::CompressionCodec;
 use crate::execution::spark_plan::SparkPlan;
 use datafusion_comet_proto::{
     spark_expression::{
@@ -76,8 +77,8 @@ use datafusion_comet_proto::{
     },
     spark_operator::{
         self, lower_window_frame_bound::LowerFrameBoundStruct, operator::OpStruct,
-        upper_window_frame_bound::UpperFrameBoundStruct, BuildSide, JoinType, Operator,
-        WindowFrameType,
+        upper_window_frame_bound::UpperFrameBoundStruct, BuildSide,
+        CompressionCodec as SparkCompressionCodec, JoinType, Operator, WindowFrameType,
     },
     spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
 };
@@ -1049,9 +1050,21 @@ impl PhysicalPlanner {
                 let partitioning = self
                     
.create_partitioning(writer.partitioning.as_ref().unwrap(), child.schema())?;
 
+                let codec = match writer.codec.try_into() {
+                    Ok(SparkCompressionCodec::None) => Ok(CompressionCodec::None),
+                    Ok(SparkCompressionCodec::Zstd) => {
+                        Ok(CompressionCodec::Zstd(writer.compression_level))
+                    }
+                    _ => Err(ExecutionError::GeneralError(format!(
+                        "Unsupported shuffle compression codec: {:?}",
+                        writer.codec
+                    ))),
+                }?;
+
                 let shuffle_writer = Arc::new(ShuffleWriterExec::try_new(
                     Arc::clone(&child.native_plan),
                     partitioning,
+                    codec,
                     writer.output_data_file.clone(),
                     writer.output_index_file.clone(),
                 )?);
diff --git a/native/core/src/execution/shuffle/mod.rs b/native/core/src/execution/shuffle/mod.rs
index 8721ead7..8111f5ee 100644
--- a/native/core/src/execution/shuffle/mod.rs
+++ b/native/core/src/execution/shuffle/mod.rs
@@ -19,4 +19,4 @@ mod list;
 mod map;
 pub mod row;
 mod shuffle_writer;
-pub use shuffle_writer::ShuffleWriterExec;
+pub use shuffle_writer::{write_ipc_compressed, CompressionCodec, ShuffleWriterExec};
diff --git a/native/core/src/execution/shuffle/row.rs b/native/core/src/execution/shuffle/row.rs
index ecab77d9..405f6421 100644
--- a/native/core/src/execution/shuffle/row.rs
+++ b/native/core/src/execution/shuffle/row.rs
@@ -292,6 +292,7 @@ macro_rules! downcast_builder_ref {
 }
 
 // Expose the macro for other modules.
+use crate::execution::shuffle::shuffle_writer::CompressionCodec;
 pub(crate) use downcast_builder_ref;
 
 /// Appends field of row to the given struct builder. `dt` is the data type of the field.
@@ -3358,7 +3359,9 @@ pub fn process_sorted_row_partition(
 
         // we do not collect metrics in Native_writeSortedFileNative
         let ipc_time = Time::default();
-        written += write_ipc_compressed(&batch, &mut cursor, &ipc_time)?;
+        // compression codec is not configurable for CometBypassMergeSortShuffleWriter
+        let codec = CompressionCodec::Zstd(1);
+        written += write_ipc_compressed(&batch, &mut cursor, &codec, &ipc_time)?;
 
         if let Some(checksum) = &mut current_checksum {
             checksum.update(&mut cursor)?;
diff --git a/native/core/src/execution/shuffle/shuffle_writer.rs b/native/core/src/execution/shuffle/shuffle_writer.rs
index fcc8c51f..01117199 100644
--- a/native/core/src/execution/shuffle/shuffle_writer.rs
+++ b/native/core/src/execution/shuffle/shuffle_writer.rs
@@ -90,6 +90,7 @@ pub struct ShuffleWriterExec {
     /// Metrics
     metrics: ExecutionPlanMetricsSet,
     cache: PlanProperties,
+    codec: CompressionCodec,
 }
 
 impl DisplayAs for ShuffleWriterExec {
@@ -126,6 +127,7 @@ impl ExecutionPlan for ShuffleWriterExec {
             1 => Ok(Arc::new(ShuffleWriterExec::try_new(
                 Arc::clone(&children[0]),
                 self.partitioning.clone(),
+                self.codec.clone(),
                 self.output_data_file.clone(),
                 self.output_index_file.clone(),
             )?)),
@@ -152,6 +154,7 @@ impl ExecutionPlan for ShuffleWriterExec {
                     self.partitioning.clone(),
                     metrics,
                     context,
+                    self.codec.clone(),
                 )
                 .map_err(|e| ArrowError::ExternalError(Box::new(e))),
             )
@@ -181,6 +184,7 @@ impl ShuffleWriterExec {
     pub fn try_new(
         input: Arc<dyn ExecutionPlan>,
         partitioning: Partitioning,
+        codec: CompressionCodec,
         output_data_file: String,
         output_index_file: String,
     ) -> Result<Self> {
@@ -197,6 +201,7 @@ impl ShuffleWriterExec {
             output_data_file,
             output_index_file,
             cache,
+            codec,
         })
     }
 }
@@ -217,6 +222,7 @@ struct PartitionBuffer {
     batch_size: usize,
     /// Memory reservation for this partition buffer.
     reservation: MemoryReservation,
+    codec: CompressionCodec,
 }
 
 impl PartitionBuffer {
@@ -225,6 +231,7 @@ impl PartitionBuffer {
         batch_size: usize,
         partition_id: usize,
         runtime: &Arc<RuntimeEnv>,
+        codec: CompressionCodec,
     ) -> Self {
         let reservation = MemoryConsumer::new(format!("PartitionBuffer[{}]", partition_id))
             .with_can_spill(true)
@@ -238,6 +245,7 @@ impl PartitionBuffer {
             num_active_rows: 0,
             batch_size,
             reservation,
+            codec,
         }
     }
 
@@ -337,7 +345,7 @@ impl PartitionBuffer {
         let frozen_capacity_old = self.frozen.capacity();
         let mut cursor = Cursor::new(&mut self.frozen);
         cursor.seek(SeekFrom::End(0))?;
-        write_ipc_compressed(&frozen_batch, &mut cursor, ipc_time)?;
+        write_ipc_compressed(&frozen_batch, &mut cursor, &self.codec, ipc_time)?;
 
         mem_diff += (self.frozen.capacity() - frozen_capacity_old) as isize;
         Ok(mem_diff)
@@ -687,6 +695,7 @@ impl ShuffleRepartitioner {
         metrics: ShuffleRepartitionerMetrics,
         runtime: Arc<RuntimeEnv>,
         batch_size: usize,
+        codec: CompressionCodec,
     ) -> Self {
         let num_output_partitions = partitioning.partition_count();
         let reservation = 
MemoryConsumer::new(format!("ShuffleRepartitioner[{}]", partition_id))
@@ -709,7 +718,13 @@ impl ShuffleRepartitioner {
             schema: Arc::clone(&schema),
             buffered_partitions: (0..num_output_partitions)
                 .map(|partition_id| {
-                    PartitionBuffer::new(Arc::clone(&schema), batch_size, partition_id, &runtime)
+                    PartitionBuffer::new(
+                        Arc::clone(&schema),
+                        batch_size,
+                        partition_id,
+                        &runtime,
+                        codec.clone(),
+                    )
                 })
                 .collect::<Vec<_>>(),
             spills: Mutex::new(vec![]),
@@ -1129,6 +1144,7 @@ impl Debug for ShuffleRepartitioner {
     }
 }
 
+#[allow(clippy::too_many_arguments)]
 async fn external_shuffle(
     mut input: SendableRecordBatchStream,
     partition_id: usize,
@@ -1137,6 +1153,7 @@ async fn external_shuffle(
     partitioning: Partitioning,
     metrics: ShuffleRepartitionerMetrics,
     context: Arc<TaskContext>,
+    codec: CompressionCodec,
 ) -> Result<SendableRecordBatchStream> {
     let schema = input.schema();
     let mut repartitioner = ShuffleRepartitioner::new(
@@ -1148,6 +1165,7 @@ async fn external_shuffle(
         metrics,
         context.runtime_env(),
         context.session_config().batch_size(),
+        codec,
     );
 
     while let Some(batch) = input.next().await {
@@ -1526,11 +1544,18 @@ impl Checksum {
     }
 }
 
+#[derive(Debug, Clone)]
+pub enum CompressionCodec {
+    None,
+    Zstd(i32),
+}
+
 /// Writes given record batch as Arrow IPC bytes into given writer.
 /// Returns number of bytes written.
-pub(crate) fn write_ipc_compressed<W: Write + Seek>(
+pub fn write_ipc_compressed<W: Write + Seek>(
     batch: &RecordBatch,
     output: &mut W,
+    codec: &CompressionCodec,
     ipc_time: &Time,
 ) -> Result<usize> {
     if batch.num_rows() == 0 {
@@ -1543,14 +1568,24 @@ pub(crate) fn write_ipc_compressed<W: Write + Seek>(
     // write ipc_length placeholder
     output.write_all(&[0u8; 8])?;
 
-    // write ipc data
-    // TODO: make compression level configurable
-    let mut arrow_writer = StreamWriter::try_new(zstd::Encoder::new(output, 1)?, &batch.schema())?;
-    arrow_writer.write(batch)?;
-    arrow_writer.finish()?;
+    let output = match codec {
+        CompressionCodec::None => {
+            let mut arrow_writer = StreamWriter::try_new(output, &batch.schema())?;
+            arrow_writer.write(batch)?;
+            arrow_writer.finish()?;
+            arrow_writer.into_inner()?
+        }
+        CompressionCodec::Zstd(level) => {
+            let encoder = zstd::Encoder::new(output, *level)?;
+            let mut arrow_writer = StreamWriter::try_new(encoder, &batch.schema())?;
+            arrow_writer.write(batch)?;
+            arrow_writer.finish()?;
+            let zstd_encoder = arrow_writer.into_inner()?;
+            zstd_encoder.finish()?
+        }
+    };
 
-    let zwriter = arrow_writer.into_inner()?;
-    let output = zwriter.finish()?;
+    // fill ipc length
     let end_pos = output.stream_position()?;
     let ipc_length = end_pos - start_pos - 8;
 
@@ -1611,6 +1646,22 @@ mod test {
     use datafusion_physical_expr::expressions::Column;
     use tokio::runtime::Runtime;
 
+    #[test]
+    #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx`
+    fn write_ipc_zstd() {
+        let batch = create_batch(8192);
+        let mut output = vec![];
+        let mut cursor = Cursor::new(&mut output);
+        write_ipc_compressed(
+            &batch,
+            &mut cursor,
+            &CompressionCodec::Zstd(1),
+            &Time::default(),
+        )
+        .unwrap();
+        assert_eq!(40218, output.len());
+    }
+
     #[test]
     fn test_slot_size() {
         let batch_size = 1usize;
@@ -1673,13 +1724,7 @@ mod test {
         num_partitions: usize,
         memory_limit: Option<usize>,
     ) {
-        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));
-        let mut b = StringBuilder::new();
-        for i in 0..batch_size {
-            b.append_value(format!("{i}"));
-        }
-        let array = b.finish();
-        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap();
+        let batch = create_batch(batch_size);
 
         let batches = (0..num_batches).map(|_| batch.clone()).collect::<Vec<_>>();
 
@@ -1687,6 +1732,7 @@ mod test {
         let exec = ShuffleWriterExec::try_new(
             Arc::new(MemoryExec::try_new(partitions, batch.schema(), None).unwrap()),
             Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], num_partitions),
+            CompressionCodec::Zstd(1),
             "/tmp/data.out".to_string(),
             "/tmp/index.out".to_string(),
         )
@@ -1707,6 +1753,16 @@ mod test {
         rt.block_on(collect(stream)).unwrap();
     }
 
+    fn create_batch(batch_size: usize) -> RecordBatch {
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));
+        let mut b = StringBuilder::new();
+        for i in 0..batch_size {
+            b.append_value(format!("{i}"));
+        }
+        let array = b.finish();
+        RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap()
+    }
+
     #[test]
     fn test_pmod() {
         let i: Vec<u32> = vec![0x99f0149d, 0x9c67b85d, 0xc8008529, 0xa05b5d7b, 0xcd1e64fb];
diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto
index 74ec80cb..5cb2802d 100644
--- a/native/proto/src/proto/operator.proto
+++ b/native/proto/src/proto/operator.proto
@@ -82,10 +82,17 @@ message Limit {
   int32 offset = 2;
 }
 
+enum CompressionCodec {
+  None = 0;
+  Zstd = 1;
+}
+
 message ShuffleWriter {
   spark.spark_partitioning.Partitioning partitioning = 1;
   string output_data_file = 3;
   string output_index_file = 4;
+  CompressionCodec codec = 5;
+  int32 compression_level = 6;
 }
 
 enum AggregateMode {
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
index 0cd8a9ce..3a11b8b2 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
@@ -52,8 +52,9 @@ import org.apache.spark.util.random.XORShiftRandom
 
 import com.google.common.base.Objects
 
+import org.apache.comet.CometConf
 import org.apache.comet.serde.{OperatorOuterClass, PartitioningOuterClass, QueryPlanSerde}
-import org.apache.comet.serde.OperatorOuterClass.Operator
+import org.apache.comet.serde.OperatorOuterClass.{CompressionCodec, Operator}
 import org.apache.comet.serde.QueryPlanSerde.serializeDataType
 import org.apache.comet.shims.ShimCometShuffleExchangeExec
 
@@ -553,6 +554,17 @@ class CometShuffleWriteProcessor(
       shuffleWriterBuilder.setOutputDataFile(dataFile)
       shuffleWriterBuilder.setOutputIndexFile(indexFile)
 
+      if (SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)) {
+        val codec = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get() match {
+          case "zstd" => CompressionCodec.Zstd
+          case other => throw new UnsupportedOperationException(s"invalid codec: $other")
+        }
+        shuffleWriterBuilder.setCodec(codec)
+      } else {
+        shuffleWriterBuilder.setCodec(CompressionCodec.None)
+      }
+      shuffleWriterBuilder.setCompressionLevel(CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_LEVEL.get)
+
       outputPartitioning match {
         case _: HashPartitioning =>
           val hashPartitioning = outputPartitioning.asInstanceOf[HashPartitioning]
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala
index ef67167c..b2cc2c2b 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala
@@ -243,7 +243,7 @@ object CometShuffleManager extends Logging {
 
   lazy val compressionCodecForShuffling: CompressionCodec = {
     val sparkConf = SparkEnv.get.conf
-    val codecName = CometConf.COMET_EXEC_SHUFFLE_CODEC.get(SQLConf.get)
+    val codecName = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get(SQLConf.get)
 
     // only zstd compression is supported at the moment
     if (codecName != "zstd") {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to