This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 74a6a8d58 feat: Move shuffle block decompression and decoding to
native code and add LZ4 & Snappy support (#1192)
74a6a8d58 is described below
commit 74a6a8d58f931e1cedaaac0ef59616d5b6b27fa2
Author: Andy Grove <[email protected]>
AuthorDate: Mon Jan 6 17:47:16 2025 -0700
feat: Move shuffle block decompression and decoding to native code and add
LZ4 & Snappy support (#1192)
* Implement native decoding and decompression
* revert some variable renaming for smaller diff
* fix oom issues?
* make NativeBatchDecoderIterator more consistent with ArrowReaderIterator
* fix oom and prep for review
* format
* Add LZ4 support
* clippy, new benchmark
* rename metrics, clean up lz4 code
* update test
* Add support for snappy
* format
* change default back to lz4
* make metrics more accurate
* format
* clippy
* use faster unsafe version of lz4_flex
* Make compression codec configurable for columnar shuffle
* clippy
* fix bench
* fmt
* address feedback
* address feedback
* address feedback
* minor code simplification
* cargo fmt
* overflow check
* rename compression level config
* address feedback
* address feedback
* rename constant
---
.../main/scala/org/apache/comet/CometConf.scala | 23 +--
docs/source/user-guide/configs.md | 4 +-
native/Cargo.lock | 48 +++---
native/core/Cargo.toml | 3 +
native/core/benches/row_columnar.rs | 2 +
native/core/benches/shuffle_writer.rs | 41 ++++-
native/core/src/execution/jni_api.rs | 47 +++++-
native/core/src/execution/planner.rs | 2 +
native/core/src/execution/shuffle/mod.rs | 4 +-
native/core/src/execution/shuffle/row.rs | 5 +-
.../core/src/execution/shuffle/shuffle_writer.rs | 132 ++++++++++++---
native/proto/src/proto/operator.proto | 2 +
.../shuffle/sort/CometShuffleExternalSorter.java | 17 +-
.../execution/shuffle/CometDiskBlockWriter.java | 14 +-
.../sql/comet/execution/shuffle/SpillWriter.java | 8 +-
spark/src/main/scala/org/apache/comet/Native.scala | 30 +++-
.../apache/spark/sql/comet/CometMetricNode.scala | 5 +-
.../shuffle/CometBlockStoreShuffleReader.scala | 34 ++--
.../execution/shuffle/CometShuffleDependency.scala | 4 +-
.../shuffle/CometShuffleExchangeExec.scala | 13 +-
.../execution/shuffle/IpcInputStreamIterator.scala | 126 --------------
.../shuffle/NativeBatchDecoderIterator.scala | 188 +++++++++++++++++++++
.../comet/exec/CometColumnarShuffleSuite.scala | 3 +-
23 files changed, 524 insertions(+), 231 deletions(-)
diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala
b/common/src/main/scala/org/apache/comet/CometConf.scala
index 4d63de75a..4115fa9ab 100644
--- a/common/src/main/scala/org/apache/comet/CometConf.scala
+++ b/common/src/main/scala/org/apache/comet/CometConf.scala
@@ -272,18 +272,19 @@ object CometConf extends ShimCometConf {
.booleanConf
.createWithDefault(false)
- val COMET_EXEC_SHUFFLE_COMPRESSION_CODEC: ConfigEntry[String] = conf(
- s"$COMET_EXEC_CONFIG_PREFIX.shuffle.compression.codec")
- .doc(
- "The codec of Comet native shuffle used to compress shuffle data. Only
zstd is supported. " +
- "Compression can be disabled by setting spark.shuffle.compress=false.")
- .stringConf
- .checkValues(Set("zstd"))
- .createWithDefault("zstd")
+ val COMET_EXEC_SHUFFLE_COMPRESSION_CODEC: ConfigEntry[String] =
+ conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.compression.codec")
+ .doc(
+ "The codec of Comet native shuffle used to compress shuffle data. lz4,
zstd, and " +
+ "snappy are supported. Compression can be disabled by setting " +
+ "spark.shuffle.compress=false.")
+ .stringConf
+ .checkValues(Set("zstd", "lz4", "snappy"))
+ .createWithDefault("lz4")
- val COMET_EXEC_SHUFFLE_COMPRESSION_LEVEL: ConfigEntry[Int] =
- conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.compression.level")
- .doc("The compression level to use when compression shuffle files.")
+ val COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL: ConfigEntry[Int] =
+ conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.compression.zstd.level")
+ .doc("The compression level to use when compressing shuffle files with
zstd.")
.intConf
.createWithDefault(1)
diff --git a/docs/source/user-guide/configs.md
b/docs/source/user-guide/configs.md
index 20923b93a..d78e6111d 100644
--- a/docs/source/user-guide/configs.md
+++ b/docs/source/user-guide/configs.md
@@ -50,8 +50,8 @@ Comet provides the following configuration settings.
| spark.comet.exec.memoryPool | The type of memory pool to be used for Comet
native execution. Available memory pool types are 'greedy', 'fair_spill',
'greedy_task_shared', 'fair_spill_task_shared', 'greedy_global' and
'fair_spill_global', By default, this config is 'greedy_task_shared'. |
greedy_task_shared |
| spark.comet.exec.project.enabled | Whether to enable project by default. |
true |
| spark.comet.exec.replaceSortMergeJoin | Experimental feature to force Spark
to replace SortMergeJoin with ShuffledHashJoin for improved performance. This
feature is not stable yet. For more information, refer to the Comet Tuning
Guide (https://datafusion.apache.org/comet/user-guide/tuning.html). | false |
-| spark.comet.exec.shuffle.compression.codec | The codec of Comet native
shuffle used to compress shuffle data. Only zstd is supported. Compression can
be disabled by setting spark.shuffle.compress=false. | zstd |
-| spark.comet.exec.shuffle.compression.level | The compression level to use
when compression shuffle files. | 1 |
+| spark.comet.exec.shuffle.compression.codec | The codec of Comet native
shuffle used to compress shuffle data. lz4, zstd, and snappy are supported.
Compression can be disabled by setting spark.shuffle.compress=false. | lz4 |
+| spark.comet.exec.shuffle.compression.zstd.level | The compression level to
use when compressing shuffle files with zstd. | 1 |
| spark.comet.exec.shuffle.enabled | Whether to enable Comet native shuffle.
Note that this requires setting 'spark.shuffle.manager' to
'org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager'.
'spark.shuffle.manager' must be set before starting the Spark application and
cannot be changed during the application. | true |
| spark.comet.exec.sort.enabled | Whether to enable sort by default. | true |
| spark.comet.exec.sortMergeJoin.enabled | Whether to enable sortMergeJoin by
default. | true |
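For anyone trying the new settings, here is a minimal, hypothetical sketch. It
assumes an existing SparkSession bound to a value named spark; the config keys
and defaults are the ones documented in the table above:

    // Hypothetical usage sketch: select the Comet shuffle codec at runtime.
    // lz4 is the default; zstd and snappy are also accepted.
    spark.conf.set("spark.comet.exec.shuffle.compression.codec", "zstd")
    // Only consulted when the codec is zstd.
    spark.conf.set("spark.comet.exec.shuffle.compression.zstd.level", "3")
    // Compression can still be disabled with the standard Spark setting:
    // spark.conf.set("spark.shuffle.compress", "false")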
diff --git a/native/Cargo.lock b/native/Cargo.lock
index bbc0ff97a..1c44e3cc5 100644
--- a/native/Cargo.lock
+++ b/native/Cargo.lock
@@ -1,6 +1,6 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
-version = 4
+version = 3
[[package]]
name = "addr2line"
@@ -346,7 +346,7 @@ checksum =
"721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.92",
+ "syn 2.0.93",
]
[[package]]
@@ -903,6 +903,7 @@ dependencies = [
"lazy_static",
"log",
"log4rs",
+ "lz4_flex",
"mimalloc",
"num",
"once_cell",
@@ -914,6 +915,7 @@ dependencies = [
"regex",
"serde",
"simd-adler32",
+ "snap",
"tempfile",
"thiserror",
"tokio",
@@ -1168,7 +1170,7 @@ version = "44.0.0"
source =
"git+https://github.com/apache/datafusion.git?rev=44.0.0-rc2#3cc3fca31e6edc2d953e663bfd7f856bcb70d8c4"
dependencies = [
"quote",
- "syn 2.0.92",
+ "syn 2.0.93",
]
[[package]]
@@ -1333,7 +1335,7 @@ checksum =
"97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.92",
+ "syn 2.0.93",
]
[[package]]
@@ -1473,7 +1475,7 @@ checksum =
"162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.92",
+ "syn 2.0.93",
]
[[package]]
@@ -1746,7 +1748,7 @@ checksum =
"1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.92",
+ "syn 2.0.93",
]
[[package]]
@@ -2556,7 +2558,7 @@ dependencies = [
"itertools 0.12.1",
"proc-macro2",
"quote",
- "syn 2.0.92",
+ "syn 2.0.93",
]
[[package]]
@@ -2778,7 +2780,7 @@ checksum =
"5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.92",
+ "syn 2.0.93",
]
[[package]]
@@ -2868,7 +2870,7 @@ dependencies = [
"heck 0.5.0",
"proc-macro2",
"quote",
- "syn 2.0.92",
+ "syn 2.0.93",
]
[[package]]
@@ -2895,7 +2897,7 @@ checksum =
"da5fc6819faabb412da764b99d3b713bb55083c11e7e0c00144d386cd6a1939c"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.92",
+ "syn 2.0.93",
]
[[package]]
@@ -2932,7 +2934,7 @@ dependencies = [
"proc-macro2",
"quote",
"rustversion",
- "syn 2.0.92",
+ "syn 2.0.93",
]
[[package]]
@@ -2977,9 +2979,9 @@ dependencies = [
[[package]]
name = "syn"
-version = "2.0.92"
+version = "2.0.93"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "70ae51629bf965c5c098cc9e87908a3df5301051a9e087d6f9bef5c9771ed126"
+checksum = "9c786062daee0d6db1132800e623df74274a0a87322d8e183338e01b3d98d058"
dependencies = [
"proc-macro2",
"quote",
@@ -2994,7 +2996,7 @@ checksum =
"c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.92",
+ "syn 2.0.93",
]
[[package]]
@@ -3027,7 +3029,7 @@ checksum =
"4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.92",
+ "syn 2.0.93",
]
[[package]]
@@ -3100,7 +3102,7 @@ checksum =
"693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.92",
+ "syn 2.0.93",
]
[[package]]
@@ -3122,7 +3124,7 @@ checksum =
"395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.92",
+ "syn 2.0.93",
]
[[package]]
@@ -3276,7 +3278,7 @@ dependencies = [
"log",
"proc-macro2",
"quote",
- "syn 2.0.92",
+ "syn 2.0.93",
"wasm-bindgen-shared",
]
@@ -3298,7 +3300,7 @@ checksum =
"30d7a95b763d3c45903ed6c81f156801839e5ee968bb07e534c44df0fcd330c2"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.92",
+ "syn 2.0.93",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
@@ -3561,7 +3563,7 @@ checksum =
"2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.92",
+ "syn 2.0.93",
"synstructure",
]
@@ -3583,7 +3585,7 @@ checksum =
"fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.92",
+ "syn 2.0.93",
]
[[package]]
@@ -3603,7 +3605,7 @@ checksum =
"595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.92",
+ "syn 2.0.93",
"synstructure",
]
@@ -3626,7 +3628,7 @@ checksum =
"6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.92",
+ "syn 2.0.93",
]
[[package]]
diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml
index 5089e67a0..8937236dd 100644
--- a/native/core/Cargo.toml
+++ b/native/core/Cargo.toml
@@ -52,6 +52,9 @@ serde = { version = "1", features = ["derive"] }
lazy_static = "1.4.0"
prost = "0.12.1"
jni = "0.21"
+snap = "1.1"
+# we disable default features in lz4_flex to force the use of the faster
unsafe encoding and decoding implementation
+lz4_flex = { version = "0.11.3", default-features = false }
zstd = "0.11"
rand = { workspace = true}
num = { workspace = true }
diff --git a/native/core/benches/row_columnar.rs
b/native/core/benches/row_columnar.rs
index 60b41330e..a62574111 100644
--- a/native/core/benches/row_columnar.rs
+++ b/native/core/benches/row_columnar.rs
@@ -19,6 +19,7 @@ use arrow::datatypes::DataType as ArrowDataType;
use comet::execution::shuffle::row::{
process_sorted_row_partition, SparkUnsafeObject, SparkUnsafeRow,
};
+use comet::execution::shuffle::CompressionCodec;
use criterion::{criterion_group, criterion_main, Criterion};
use tempfile::Builder;
@@ -77,6 +78,7 @@ fn benchmark(c: &mut Criterion) {
false,
0,
None,
+ &CompressionCodec::Zstd(1),
)
.unwrap();
});
diff --git a/native/core/benches/shuffle_writer.rs
b/native/core/benches/shuffle_writer.rs
index 865ca73b4..0d22c62cc 100644
--- a/native/core/benches/shuffle_writer.rs
+++ b/native/core/benches/shuffle_writer.rs
@@ -35,23 +35,52 @@ fn criterion_benchmark(c: &mut Criterion) {
group.bench_function("shuffle_writer: encode (no compression))", |b| {
let batch = create_batch(8192, true);
let mut buffer = vec![];
- let mut cursor = Cursor::new(&mut buffer);
let ipc_time = Time::default();
- b.iter(|| write_ipc_compressed(&batch, &mut cursor,
&CompressionCodec::None, &ipc_time));
+ b.iter(|| {
+ buffer.clear();
+ let mut cursor = Cursor::new(&mut buffer);
+ write_ipc_compressed(&batch, &mut cursor, &CompressionCodec::None,
&ipc_time)
+ });
+ });
+ group.bench_function("shuffle_writer: encode and compress (snappy)", |b| {
+ let batch = create_batch(8192, true);
+ let mut buffer = vec![];
+ let ipc_time = Time::default();
+ b.iter(|| {
+ buffer.clear();
+ let mut cursor = Cursor::new(&mut buffer);
+ write_ipc_compressed(&batch, &mut cursor,
&CompressionCodec::Snappy, &ipc_time)
+ });
+ });
+ group.bench_function("shuffle_writer: encode and compress (lz4)", |b| {
+ let batch = create_batch(8192, true);
+ let mut buffer = vec![];
+ let ipc_time = Time::default();
+ b.iter(|| {
+ buffer.clear();
+ let mut cursor = Cursor::new(&mut buffer);
+ write_ipc_compressed(&batch, &mut cursor,
&CompressionCodec::Lz4Frame, &ipc_time)
+ });
});
group.bench_function("shuffle_writer: encode and compress (zstd level 1)",
|b| {
let batch = create_batch(8192, true);
let mut buffer = vec![];
- let mut cursor = Cursor::new(&mut buffer);
let ipc_time = Time::default();
- b.iter(|| write_ipc_compressed(&batch, &mut cursor,
&CompressionCodec::Zstd(1), &ipc_time));
+ b.iter(|| {
+ buffer.clear();
+ let mut cursor = Cursor::new(&mut buffer);
+ write_ipc_compressed(&batch, &mut cursor,
&CompressionCodec::Zstd(1), &ipc_time)
+ });
});
group.bench_function("shuffle_writer: encode and compress (zstd level 6)",
|b| {
let batch = create_batch(8192, true);
let mut buffer = vec![];
- let mut cursor = Cursor::new(&mut buffer);
let ipc_time = Time::default();
- b.iter(|| write_ipc_compressed(&batch, &mut cursor,
&CompressionCodec::Zstd(6), &ipc_time));
+ b.iter(|| {
+ buffer.clear();
+ let mut cursor = Cursor::new(&mut buffer);
+ write_ipc_compressed(&batch, &mut cursor,
&CompressionCodec::Zstd(6), &ipc_time)
+ });
});
group.bench_function("shuffle_writer: end to end", |b| {
let ctx = SessionContext::new();
diff --git a/native/core/src/execution/jni_api.rs
b/native/core/src/execution/jni_api.rs
index 7d8d577fe..aaac7ec8c 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -17,6 +17,7 @@
//! Define JNI APIs which can be called from Java/Scala.
+use super::{serde, utils::SparkArrowConvert, CometMemoryPool};
use arrow::datatypes::DataType as ArrowDataType;
use arrow_array::RecordBatch;
use datafusion::{
@@ -40,8 +41,6 @@ use jni::{
use std::time::{Duration, Instant};
use std::{collections::HashMap, sync::Arc, task::Poll};
-use super::{serde, utils::SparkArrowConvert, CometMemoryPool};
-
use crate::{
errors::{try_unwrap_or_throw, CometError, CometResult},
execution::{
@@ -54,6 +53,7 @@ use datafusion_comet_proto::spark_operator::Operator;
use datafusion_common::ScalarValue;
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
use futures::stream::StreamExt;
+use jni::objects::JByteBuffer;
use jni::sys::JNI_FALSE;
use jni::{
objects::GlobalRef,
@@ -64,6 +64,7 @@ use std::sync::Mutex;
use tokio::runtime::Runtime;
use crate::execution::operators::ScanExec;
+use crate::execution::shuffle::{read_ipc_compressed, CompressionCodec};
use crate::execution::spark_plan::SparkPlan;
use log::info;
use once_cell::sync::{Lazy, OnceCell};
@@ -147,7 +148,7 @@ impl PerTaskMemoryPool {
/// Accept serialized query plan and return the address of the native query
plan.
/// # Safety
-/// This function is inheritly unsafe since it deals with raw pointers passed
from JNI.
+/// This function is inherently unsafe since it deals with raw pointers passed
from JNI.
#[no_mangle]
pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
e: JNIEnv,
@@ -444,7 +445,7 @@ fn pull_input_batches(exec_context: &mut ExecutionContext)
-> Result<(), CometEr
/// Accept serialized query plan and the addresses of Arrow Arrays from Spark,
/// then execute the query. Return addresses of arrow vector.
/// # Safety
-/// This function is inheritly unsafe since it deals with raw pointers passed
from JNI.
+/// This function is inherently unsafe since it deals with raw pointers passed
from JNI.
#[no_mangle]
pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
e: JNIEnv,
@@ -618,7 +619,7 @@ fn get_execution_context<'a>(id: i64) -> &'a mut
ExecutionContext {
/// Used by Comet shuffle external sorter to write sorted records to disk.
/// # Safety
-/// This function is inheritly unsafe since it deals with raw pointers passed
from JNI.
+/// This function is inherently unsafe since it deals with raw pointers passed
from JNI.
#[no_mangle]
pub unsafe extern "system" fn
Java_org_apache_comet_Native_writeSortedFileNative(
e: JNIEnv,
@@ -632,6 +633,8 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_writeSortedFileNative
checksum_enabled: jboolean,
checksum_algo: jint,
current_checksum: jlong,
+ compression_codec: jstring,
+ compression_level: jint,
) -> jlongArray {
try_unwrap_or_throw(&e, |mut env| unsafe {
let data_types = convert_datatype_arrays(&mut env,
serialized_datatypes)?;
@@ -659,6 +662,18 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_writeSortedFileNative
Some(current_checksum as u32)
};
+ let compression_codec: String = env
+ .get_string(&JString::from_raw(compression_codec))
+ .unwrap()
+ .into();
+
+ let compression_codec = match compression_codec.as_str() {
+ "zstd" => CompressionCodec::Zstd(compression_level),
+ "lz4" => CompressionCodec::Lz4Frame,
+ "snappy" => CompressionCodec::Snappy,
+ _ => CompressionCodec::Lz4Frame,
+ };
+
let (written_bytes, checksum) = process_sorted_row_partition(
row_num,
batch_size as usize,
@@ -670,6 +685,7 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_writeSortedFileNative
checksum_enabled,
checksum_algo,
current_checksum,
+ &compression_codec,
)?;
let checksum = if let Some(checksum) = checksum {
@@ -703,3 +719,24 @@ pub extern "system" fn
Java_org_apache_comet_Native_sortRowPartitionsNative(
Ok(())
})
}
+
+#[no_mangle]
+/// Used by Comet native shuffle reader
+/// # Safety
+/// This function is inherently unsafe since it deals with raw pointers passed
from JNI.
+pub unsafe extern "system" fn Java_org_apache_comet_Native_decodeShuffleBlock(
+ e: JNIEnv,
+ _class: JClass,
+ byte_buffer: JByteBuffer,
+ length: jint,
+ array_addrs: jlongArray,
+ schema_addrs: jlongArray,
+) -> jlong {
+ try_unwrap_or_throw(&e, |mut env| {
+ let raw_pointer = env.get_direct_buffer_address(&byte_buffer)?;
+ let length = length as usize;
+ let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer,
length) };
+ let batch = read_ipc_compressed(slice)?;
+ prepare_output(&mut env, array_addrs, schema_addrs, batch, false)
+ })
+}
diff --git a/native/core/src/execution/planner.rs
b/native/core/src/execution/planner.rs
index da452c2f1..294922f2f 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -1067,9 +1067,11 @@ impl PhysicalPlanner {
let codec = match writer.codec.try_into() {
Ok(SparkCompressionCodec::None) =>
Ok(CompressionCodec::None),
+ Ok(SparkCompressionCodec::Snappy) =>
Ok(CompressionCodec::Snappy),
Ok(SparkCompressionCodec::Zstd) => {
Ok(CompressionCodec::Zstd(writer.compression_level))
}
+ Ok(SparkCompressionCodec::Lz4) =>
Ok(CompressionCodec::Lz4Frame),
_ => Err(ExecutionError::GeneralError(format!(
"Unsupported shuffle compression codec: {:?}",
writer.codec
diff --git a/native/core/src/execution/shuffle/mod.rs
b/native/core/src/execution/shuffle/mod.rs
index 8111f5eed..178aff1fa 100644
--- a/native/core/src/execution/shuffle/mod.rs
+++ b/native/core/src/execution/shuffle/mod.rs
@@ -19,4 +19,6 @@ mod list;
mod map;
pub mod row;
mod shuffle_writer;
-pub use shuffle_writer::{write_ipc_compressed, CompressionCodec,
ShuffleWriterExec};
+pub use shuffle_writer::{
+ read_ipc_compressed, write_ipc_compressed, CompressionCodec,
ShuffleWriterExec,
+};
diff --git a/native/core/src/execution/shuffle/row.rs
b/native/core/src/execution/shuffle/row.rs
index 405f64216..9037bd794 100644
--- a/native/core/src/execution/shuffle/row.rs
+++ b/native/core/src/execution/shuffle/row.rs
@@ -3297,6 +3297,7 @@ pub fn process_sorted_row_partition(
// this is the initial checksum for this method, as it also gets updated
iteratively
// inside the loop within the method across batches.
initial_checksum: Option<u32>,
+ codec: &CompressionCodec,
) -> Result<(i64, Option<u32>), CometError> {
// TODO: We can tune this parameter automatically based on row size and
cache size.
let row_step = 10;
@@ -3359,9 +3360,7 @@ pub fn process_sorted_row_partition(
// we do not collect metrics in Native_writeSortedFileNative
let ipc_time = Time::default();
- // compression codec is not configurable for
CometBypassMergeSortShuffleWriter
- let codec = CompressionCodec::Zstd(1);
- written += write_ipc_compressed(&batch, &mut cursor, &codec,
&ipc_time)?;
+ written += write_ipc_compressed(&batch, &mut cursor, codec,
&ipc_time)?;
if let Some(checksum) = &mut current_checksum {
checksum.update(&mut cursor)?;
diff --git a/native/core/src/execution/shuffle/shuffle_writer.rs
b/native/core/src/execution/shuffle/shuffle_writer.rs
index f3fa685b8..e6679d13d 100644
--- a/native/core/src/execution/shuffle/shuffle_writer.rs
+++ b/native/core/src/execution/shuffle/shuffle_writer.rs
@@ -21,6 +21,7 @@ use crate::{
common::bit::ceil,
errors::{CometError, CometResult},
};
+use arrow::ipc::reader::StreamReader;
use arrow::{datatypes::*, ipc::writer::StreamWriter};
use async_trait::async_trait;
use bytes::Buf;
@@ -312,7 +313,7 @@ impl PartitionBuffer {
repart_timer.stop();
if self.num_active_rows >= self.batch_size {
- let flush = self.flush(&metrics.ipc_time);
+ let flush = self.flush(metrics);
if let Err(e) = flush {
return AppendRowStatus::MemDiff(Err(e));
}
@@ -330,7 +331,7 @@ impl PartitionBuffer {
}
/// flush active data into frozen bytes
- fn flush(&mut self, ipc_time: &Time) -> Result<isize> {
+ fn flush(&mut self, metrics: &ShuffleRepartitionerMetrics) ->
Result<isize> {
if self.num_active_rows == 0 {
return Ok(0);
}
@@ -340,14 +341,24 @@ impl PartitionBuffer {
let active = std::mem::take(&mut self.active);
let num_rows = self.num_active_rows;
self.num_active_rows = 0;
+
+ let mut mempool_timer = metrics.mempool_time.timer();
self.reservation.try_shrink(self.active_slots_mem_size)?;
+ mempool_timer.stop();
+ let mut repart_timer = metrics.repart_time.timer();
let frozen_batch = make_batch(Arc::clone(&self.schema), active,
num_rows)?;
+ repart_timer.stop();
let frozen_capacity_old = self.frozen.capacity();
let mut cursor = Cursor::new(&mut self.frozen);
cursor.seek(SeekFrom::End(0))?;
- write_ipc_compressed(&frozen_batch, &mut cursor, &self.codec,
ipc_time)?;
+ write_ipc_compressed(
+ &frozen_batch,
+ &mut cursor,
+ &self.codec,
+ &metrics.encode_time,
+ )?;
mem_diff += (self.frozen.capacity() - frozen_capacity_old) as isize;
Ok(mem_diff)
@@ -652,7 +663,7 @@ struct ShuffleRepartitionerMetrics {
mempool_time: Time,
/// Time encoding batches to IPC format
- ipc_time: Time,
+ encode_time: Time,
/// Time spent writing to disk. Maps to "shuffleWriteTime" in Spark SQL
Metrics.
write_time: Time,
@@ -676,7 +687,7 @@ impl ShuffleRepartitionerMetrics {
baseline: BaselineMetrics::new(metrics, partition),
repart_time:
MetricBuilder::new(metrics).subset_time("repart_time", partition),
mempool_time:
MetricBuilder::new(metrics).subset_time("mempool_time", partition),
- ipc_time: MetricBuilder::new(metrics).subset_time("ipc_time",
partition),
+ encode_time:
MetricBuilder::new(metrics).subset_time("encode_time", partition),
write_time: MetricBuilder::new(metrics).subset_time("write_time",
partition),
input_batches:
MetricBuilder::new(metrics).counter("input_batches", partition),
spill_count: MetricBuilder::new(metrics).spill_count(partition),
@@ -790,6 +801,8 @@ impl ShuffleRepartitioner {
Partitioning::Hash(exprs, _) => {
let (partition_starts, shuffled_partition_ids): (Vec<usize>,
Vec<usize>) = {
let mut timer = self.metrics.repart_time.timer();
+
+ // evaluate partition expressions
let arrays = exprs
.iter()
.map(|expr|
expr.evaluate(&input)?.into_array(input.num_rows()))
@@ -923,7 +936,7 @@ impl ShuffleRepartitioner {
let mut output_batches: Vec<Vec<u8>> = vec![vec![];
num_output_partitions];
let mut offsets = vec![0; num_output_partitions + 1];
for i in 0..num_output_partitions {
- buffered_partitions[i].flush(&self.metrics.ipc_time)?;
+ buffered_partitions[i].flush(&self.metrics)?;
output_batches[i] = std::mem::take(&mut
buffered_partitions[i].frozen);
}
@@ -1023,20 +1036,19 @@ impl ShuffleRepartitioner {
}
let mut timer = self.metrics.write_time.timer();
-
let spillfile = self
.runtime
.disk_manager
.create_tmp_file("shuffle writer spill")?;
+ timer.stop();
+
let offsets = spill_into(
&mut self.buffered_partitions,
spillfile.path(),
self.num_output_partitions,
- &self.metrics.ipc_time,
+ &self.metrics,
)?;
- timer.stop();
-
let mut spills = self.spills.lock().await;
let used = self.reservation.size();
self.metrics.spill_count.add(1);
@@ -1107,16 +1119,18 @@ fn spill_into(
buffered_partitions: &mut [PartitionBuffer],
path: &Path,
num_output_partitions: usize,
- ipc_time: &Time,
+ metrics: &ShuffleRepartitionerMetrics,
) -> Result<Vec<u64>> {
let mut output_batches: Vec<Vec<u8>> = vec![vec![]; num_output_partitions];
for i in 0..num_output_partitions {
- buffered_partitions[i].flush(ipc_time)?;
+ buffered_partitions[i].flush(metrics)?;
output_batches[i] = std::mem::take(&mut buffered_partitions[i].frozen);
}
let path = path.to_owned();
+ let mut write_timer = metrics.write_time.timer();
+
let mut offsets = vec![0; num_output_partitions + 1];
let mut spill_data = OpenOptions::new()
.write(true)
@@ -1130,6 +1144,8 @@ fn spill_into(
spill_data.write_all(&output_batches[i])?;
output_batches[i].clear();
}
+ write_timer.stop();
+
// add one extra offset at last to ease partition length computation
offsets[num_output_partitions] = spill_data.stream_position()?;
Ok(offsets)
@@ -1549,7 +1565,9 @@ impl Checksum {
#[derive(Debug, Clone)]
pub enum CompressionCodec {
None,
+ Lz4Frame,
Zstd(i32),
+ Snappy,
}
/// Writes given record batch as Arrow IPC bytes into given writer.
@@ -1567,17 +1585,41 @@ pub fn write_ipc_compressed<W: Write + Seek>(
let mut timer = ipc_time.timer();
let start_pos = output.stream_position()?;
- // write ipc_length placeholder
- output.write_all(&[0u8; 8])?;
+ // seek past ipc_length placeholder
+ output.seek_relative(8)?;
+
+ // write number of columns because JVM side needs to know how many
addresses to allocate
+ let field_count = batch.schema().fields().len();
+ output.write_all(&field_count.to_le_bytes())?;
let output = match codec {
CompressionCodec::None => {
+ output.write_all(b"NONE")?;
let mut arrow_writer = StreamWriter::try_new(output,
&batch.schema())?;
arrow_writer.write(batch)?;
arrow_writer.finish()?;
arrow_writer.into_inner()?
}
+ CompressionCodec::Snappy => {
+ output.write_all(b"SNAP")?;
+ let mut wtr = snap::write::FrameEncoder::new(output);
+ let mut arrow_writer = StreamWriter::try_new(&mut wtr,
&batch.schema())?;
+ arrow_writer.write(batch)?;
+ arrow_writer.finish()?;
+ wtr.into_inner()
+                .map_err(|e| DataFusionError::Execution(format!("snappy
compression error: {}", e)))?
+ }
+ CompressionCodec::Lz4Frame => {
+ output.write_all(b"LZ4_")?;
+ let mut wtr = lz4_flex::frame::FrameEncoder::new(output);
+ let mut arrow_writer = StreamWriter::try_new(&mut wtr,
&batch.schema())?;
+ arrow_writer.write(batch)?;
+ arrow_writer.finish()?;
+ wtr.finish()
+ .map_err(|e| DataFusionError::Execution(format!("lz4
compression error: {}", e)))?
+ }
CompressionCodec::Zstd(level) => {
+ output.write_all(b"ZSTD")?;
let encoder = zstd::Encoder::new(output, *level)?;
let mut arrow_writer = StreamWriter::try_new(encoder,
&batch.schema())?;
arrow_writer.write(batch)?;
@@ -1590,6 +1632,13 @@ pub fn write_ipc_compressed<W: Write + Seek>(
// fill ipc length
let end_pos = output.stream_position()?;
let ipc_length = end_pos - start_pos - 8;
+ let max_size = i32::MAX as u64;
+ if ipc_length > max_size {
+ return Err(DataFusionError::Execution(format!(
+ "Shuffle block size {ipc_length} exceeds maximum size of
{max_size}. \
+ Try reducing batch size or increasing compression level"
+ )));
+ }
// fill ipc length
output.seek(SeekFrom::Start(start_pos))?;
@@ -1601,6 +1650,33 @@ pub fn write_ipc_compressed<W: Write + Seek>(
Ok((end_pos - start_pos) as usize)
}
+pub fn read_ipc_compressed(bytes: &[u8]) -> Result<RecordBatch> {
+ match &bytes[0..4] {
+ b"SNAP" => {
+ let decoder = snap::read::FrameDecoder::new(&bytes[4..]);
+ let mut reader = StreamReader::try_new(decoder, None)?;
+ reader.next().unwrap().map_err(|e| e.into())
+ }
+ b"LZ4_" => {
+ let decoder = lz4_flex::frame::FrameDecoder::new(&bytes[4..]);
+ let mut reader = StreamReader::try_new(decoder, None)?;
+ reader.next().unwrap().map_err(|e| e.into())
+ }
+ b"ZSTD" => {
+ let decoder = zstd::Decoder::new(&bytes[4..])?;
+ let mut reader = StreamReader::try_new(decoder, None)?;
+ reader.next().unwrap().map_err(|e| e.into())
+ }
+ b"NONE" => {
+ let mut reader = StreamReader::try_new(&bytes[4..], None)?;
+ reader.next().unwrap().map_err(|e| e.into())
+ }
+ _ => Err(DataFusionError::Execution(
+ "Failed to decode batch: invalid compression codec".to_string(),
+ )),
+ }
+}
+
/// A stream that yields no record batches which represent end of output.
pub struct EmptyStream {
/// Schema representing the data
@@ -1650,18 +1726,24 @@ mod test {
#[test]
#[cfg_attr(miri, ignore)] // miri can't call foreign function
`ZSTD_createCCtx`
- fn write_ipc_zstd() {
+ fn roundtrip_ipc() {
let batch = create_batch(8192);
- let mut output = vec![];
- let mut cursor = Cursor::new(&mut output);
- write_ipc_compressed(
- &batch,
- &mut cursor,
- &CompressionCodec::Zstd(1),
- &Time::default(),
- )
- .unwrap();
- assert_eq!(40218, output.len());
+ for codec in &[
+ CompressionCodec::None,
+ CompressionCodec::Zstd(1),
+ CompressionCodec::Snappy,
+ CompressionCodec::Lz4Frame,
+ ] {
+ let mut output = vec![];
+ let mut cursor = Cursor::new(&mut output);
+ let length =
+ write_ipc_compressed(&batch, &mut cursor, codec,
&Time::default()).unwrap();
+ assert_eq!(length, output.len());
+
+ let ipc_without_length_prefix = &output[16..];
+ let batch2 =
read_ipc_compressed(ipc_without_length_prefix).unwrap();
+ assert_eq!(batch, batch2);
+ }
}
#[test]
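For reviewers following the format change: write_ipc_compressed now frames each
block as an 8-byte little-endian length (counting everything after itself), an
8-byte little-endian field count, a 4-byte codec magic (NONE, SNAP, LZ4_ or
ZSTD), and then the Arrow IPC stream compressed with that codec;
read_ipc_compressed and the JVM-side decoder rely on this layout. A rough Scala
sketch of parsing just the header, using illustrative names that are not part
of this patch:

    import java.nio.{ByteBuffer, ByteOrder}
    import java.nio.charset.StandardCharsets

    // Illustrative only: decode the per-block header described above.
    // The real decoding happens in native code via decodeShuffleBlock.
    case class BlockHeader(blockLength: Long, fieldCount: Long, codec: String)

    def readBlockHeader(buf: ByteBuffer): BlockHeader = {
      buf.order(ByteOrder.LITTLE_ENDIAN)
      val blockLength = buf.getLong() // bytes following this length field
      val fieldCount = buf.getLong()  // lets the JVM size FFI address arrays
      val magic = new Array[Byte](4)  // "NONE", "SNAP", "LZ4_" or "ZSTD"
      buf.get(magic)
      BlockHeader(blockLength, fieldCount,
        new String(magic, StandardCharsets.UTF_8))
    }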
diff --git a/native/proto/src/proto/operator.proto
b/native/proto/src/proto/operator.proto
index 5cb2802da..a3480086c 100644
--- a/native/proto/src/proto/operator.proto
+++ b/native/proto/src/proto/operator.proto
@@ -85,6 +85,8 @@ message Limit {
enum CompressionCodec {
None = 0;
Zstd = 1;
+ Lz4 = 2;
+ Snappy = 3;
}
message ShuffleWriter {
diff --git
a/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java
b/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java
index cc4495570..1e3762a6c 100644
---
a/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java
+++
b/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java
@@ -107,6 +107,8 @@ public final class CometShuffleExternalSorter implements
CometShuffleChecksumSup
private final long[] partitionChecksums;
private final String checksumAlgorithm;
+ private final String compressionCodec;
+ private final int compressionLevel;
// The memory allocator for this sorter. It is used to allocate/free memory
pages for this sorter.
// Because we need to allocate off-heap memory regardless of configured
Spark memory mode
@@ -153,6 +155,9 @@ public final class CometShuffleExternalSorter implements
CometShuffleChecksumSup
this.peakMemoryUsedBytes = getMemoryUsage();
this.partitionChecksums = createPartitionChecksums(numPartitions, conf);
this.checksumAlgorithm = getChecksumAlgorithm(conf);
+ this.compressionCodec =
CometConf$.MODULE$.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC().get();
+ this.compressionLevel =
+ (int)
CometConf$.MODULE$.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL().get();
this.initialSize = initialSize;
@@ -556,7 +561,9 @@ public final class CometShuffleExternalSorter implements
CometShuffleChecksumSup
spillInfo.file,
rowPartition,
writeMetricsToUse,
- preferDictionaryRatio);
+ preferDictionaryRatio,
+ compressionCodec,
+ compressionLevel);
spillInfo.partitionLengths[currentPartition] = written;
// Store the checksum for the current partition.
@@ -578,7 +585,13 @@ public final class CometShuffleExternalSorter implements
CometShuffleChecksumSup
if (currentPartition != -1) {
long written =
doSpilling(
- dataTypes, spillInfo.file, rowPartition, writeMetricsToUse,
preferDictionaryRatio);
+ dataTypes,
+ spillInfo.file,
+ rowPartition,
+ writeMetricsToUse,
+ preferDictionaryRatio,
+ compressionCodec,
+ compressionLevel);
spillInfo.partitionLengths[currentPartition] = written;
synchronized (spills) {
diff --git
a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometDiskBlockWriter.java
b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometDiskBlockWriter.java
index dcb9d99d3..006e8ce97 100644
---
a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometDiskBlockWriter.java
+++
b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometDiskBlockWriter.java
@@ -103,6 +103,8 @@ public final class CometDiskBlockWriter {
private long totalWritten = 0L;
private boolean initialized = false;
private final int columnarBatchSize;
+ private final String compressionCodec;
+ private final int compressionLevel;
private final boolean isAsync;
private final int asyncThreadNum;
private final ExecutorService threadPool;
@@ -153,6 +155,9 @@ public final class CometDiskBlockWriter {
this.threadPool = threadPool;
this.columnarBatchSize = (int)
CometConf$.MODULE$.COMET_COLUMNAR_SHUFFLE_BATCH_SIZE().get();
+ this.compressionCodec =
CometConf$.MODULE$.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC().get();
+ this.compressionLevel =
+ (int)
CometConf$.MODULE$.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL().get();
this.numElementsForSpillThreshold =
(int)
CometConf$.MODULE$.COMET_COLUMNAR_SHUFFLE_SPILL_THRESHOLD().get();
@@ -397,7 +402,14 @@ public final class CometDiskBlockWriter {
synchronized (file) {
outputRecords += rowPartition.getNumRows();
written =
- doSpilling(dataTypes, file, rowPartition, writeMetricsToUse,
preferDictionaryRatio);
+ doSpilling(
+ dataTypes,
+ file,
+ rowPartition,
+ writeMetricsToUse,
+ preferDictionaryRatio,
+ compressionCodec,
+ compressionLevel);
}
// Update metrics
diff --git
a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java
b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java
index 3dc86b05b..a4f09b415 100644
---
a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java
+++
b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java
@@ -171,7 +171,9 @@ public abstract class SpillWriter {
File file,
RowPartition rowPartition,
ShuffleWriteMetricsReporter writeMetricsToUse,
- double preferDictionaryRatio) {
+ double preferDictionaryRatio,
+ String compressionCodec,
+ int compressionLevel) {
long[] addresses = rowPartition.getRowAddresses();
int[] sizes = rowPartition.getRowSizes();
@@ -190,7 +192,9 @@ public abstract class SpillWriter {
batchSize,
checksumEnabled,
checksumAlgo,
- currentChecksum);
+ currentChecksum,
+ compressionCodec,
+ compressionLevel);
long written = results[0];
checksum = results[1];
diff --git a/spark/src/main/scala/org/apache/comet/Native.scala
b/spark/src/main/scala/org/apache/comet/Native.scala
index e5728009e..dbcab15b4 100644
--- a/spark/src/main/scala/org/apache/comet/Native.scala
+++ b/spark/src/main/scala/org/apache/comet/Native.scala
@@ -19,6 +19,8 @@
package org.apache.comet
+import java.nio.ByteBuffer
+
import org.apache.spark.CometTaskMemoryManager
import org.apache.spark.sql.comet.CometMetricNode
@@ -118,9 +120,14 @@ class Native extends NativeBase {
* @param currentChecksum
* the current checksum of the file. As the checksum is computed
incrementally, this is used
* to resume the computation of checksum for previous written data.
+ * @param compressionCodec
+ * the compression codec
+ * @param compressionLevel
+ * the compression level
* @return
* [the number of bytes written to disk, the checksum]
*/
+ // scalastyle:off
@native def writeSortedFileNative(
addresses: Array[Long],
rowSizes: Array[Int],
@@ -130,7 +137,10 @@ class Native extends NativeBase {
batchSize: Int,
checksumEnabled: Boolean,
checksumAlgo: Int,
- currentChecksum: Long): Array[Long]
+ currentChecksum: Long,
+ compressionCodec: String,
+ compressionLevel: Int): Array[Long]
+ // scalastyle:on
/**
* Sorts partition ids of Spark unsafe rows in place. Used by Comet shuffle
external sorter.
@@ -141,4 +151,22 @@ class Native extends NativeBase {
* the size of the array.
*/
@native def sortRowPartitionsNative(addr: Long, size: Long): Unit
+
+ /**
+ * Decompress and decode a native shuffle block.
+ * @param shuffleBlock
+   *   the encoded and compressed shuffle block.
+ * @param length
+ * the limit of the byte buffer.
+   * @param arrayAddrs
+   *   the addresses of the Arrow FFI array structs (one per column) to
+   *   populate with the decoded batch.
+   * @param schemaAddrs
+   *   the addresses of the Arrow FFI schema structs (one per column) to
+   *   populate with the decoded batch.
+ */
+ @native def decodeShuffleBlock(
+ shuffleBlock: ByteBuffer,
+ length: Int,
+ arrayAddrs: Array[Long],
+ schemaAddrs: Array[Long]): Long
+
}
diff --git
a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala
b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala
index a26fa28c8..53370a03b 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala
@@ -132,10 +132,11 @@ object CometMetricNode {
def shuffleMetrics(sc: SparkContext): Map[String, SQLMetric] = {
Map(
- "elapsed_compute" -> SQLMetrics.createNanoTimingMetric(sc, "native
shuffle time"),
+ "elapsed_compute" -> SQLMetrics.createNanoTimingMetric(sc, "native
shuffle writer time"),
"mempool_time" -> SQLMetrics.createNanoTimingMetric(sc, "memory pool
time"),
"repart_time" -> SQLMetrics.createNanoTimingMetric(sc, "repartition
time"),
- "ipc_time" -> SQLMetrics.createNanoTimingMetric(sc, "encoding and
compression time"),
+ "encode_time" -> SQLMetrics.createNanoTimingMetric(sc, "encoding and
compression time"),
+ "decode_time" -> SQLMetrics.createNanoTimingMetric(sc, "decoding and
decompression time"),
"spill_count" -> SQLMetrics.createMetric(sc, "number of spills"),
"spilled_bytes" -> SQLMetrics.createMetric(sc, "spilled bytes"),
"input_batches" -> SQLMetrics.createMetric(sc, "number of input
batches"))
diff --git
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala
index 74c655950..1283a745a 100644
---
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala
+++
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala
@@ -25,8 +25,14 @@ import org.apache.spark.{InterruptibleIterator,
MapOutputTracker, SparkEnv, Task
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.serializer.SerializerManager
-import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader,
ShuffleReadMetricsReporter}
-import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId,
ShuffleBlockFetcherIterator}
+import org.apache.spark.shuffle.BaseShuffleHandle
+import org.apache.spark.shuffle.ShuffleReader
+import org.apache.spark.shuffle.ShuffleReadMetricsReporter
+import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.storage.BlockId
+import org.apache.spark.storage.BlockManager
+import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.storage.ShuffleBlockFetcherIterator
import org.apache.spark.util.CompletionIterator
/**
@@ -79,7 +85,7 @@ class CometBlockStoreShuffleReader[K, C](
/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
- var currentReadIterator: ArrowReaderIterator = null
+ var currentReadIterator: NativeBatchDecoderIterator = null
// Closes last read iterator after the task is finished.
// We need to close read iterator during iterating input streams,
@@ -91,18 +97,16 @@ class CometBlockStoreShuffleReader[K, C](
}
}
- val recordIter = fetchIterator
- .flatMap { case (_, inputStream) =>
- IpcInputStreamIterator(inputStream, decompressingNeeded = true,
context)
- .flatMap { channel =>
- if (currentReadIterator != null) {
- // Closes previous read iterator.
- currentReadIterator.close()
- }
- currentReadIterator = new ArrowReaderIterator(channel,
this.getClass.getSimpleName)
- currentReadIterator.map((0, _)) // use 0 as key since it's not used
- }
- }
+ val recordIter: Iterator[(Int, ColumnarBatch)] = fetchIterator
+ .flatMap(blockIdAndStream => {
+ if (currentReadIterator != null) {
+ currentReadIterator.close()
+ }
+ currentReadIterator =
+ NativeBatchDecoderIterator(blockIdAndStream._2, context,
dep.decodeTime)
+ currentReadIterator
+ })
+ .map(b => (0, b))
// Update the context task metrics for each record read.
val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
diff --git
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala
index 7b1d1f127..8c8aed28e 100644
---
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala
+++
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala
@@ -25,6 +25,7 @@ import org.apache.spark.{Aggregator, Partitioner,
ShuffleDependency, SparkEnv}
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleWriteProcessor
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.StructType
/**
@@ -39,7 +40,8 @@ class CometShuffleDependency[K: ClassTag, V: ClassTag, C:
ClassTag](
override val mapSideCombine: Boolean = false,
override val shuffleWriterProcessor: ShuffleWriteProcessor = new
ShuffleWriteProcessor,
val shuffleType: ShuffleType = CometNativeShuffle,
- val schema: Option[StructType] = None)
+ val schema: Option[StructType] = None,
+ val decodeTime: SQLMetric)
extends ShuffleDependency[K, V, C](
_rdd,
partitioner,
diff --git
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
index 3a11b8b28..041411b3f 100644
---
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
+++
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
@@ -238,7 +238,8 @@ object CometShuffleExchangeExec extends
ShimCometShuffleExchangeExec {
partitioner = new Partitioner {
override def numPartitions: Int = outputPartitioning.numPartitions
override def getPartition(key: Any): Int = key.asInstanceOf[Int]
- })
+ },
+ decodeTime = metrics("decode_time"))
dependency
}
@@ -435,7 +436,8 @@ object CometShuffleExchangeExec extends
ShimCometShuffleExchangeExec {
serializer,
shuffleWriterProcessor =
ShuffleExchangeExec.createShuffleWriteProcessor(writeMetrics),
shuffleType = CometColumnarShuffle,
- schema = Some(fromAttributes(outputAttributes)))
+ schema = Some(fromAttributes(outputAttributes)),
+ decodeTime = writeMetrics("decode_time"))
dependency
}
@@ -481,7 +483,7 @@ class CometShuffleWriteProcessor(
val detailedMetrics = Seq(
"elapsed_compute",
- "ipc_time",
+ "encode_time",
"repart_time",
"mempool_time",
"input_batches",
@@ -557,13 +559,16 @@ class CometShuffleWriteProcessor(
if (SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)) {
val codec = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get() match
{
case "zstd" => CompressionCodec.Zstd
+ case "lz4" => CompressionCodec.Lz4
+ case "snappy" => CompressionCodec.Snappy
case other => throw new UnsupportedOperationException(s"invalid
codec: $other")
}
shuffleWriterBuilder.setCodec(codec)
} else {
shuffleWriterBuilder.setCodec(CompressionCodec.None)
}
-
shuffleWriterBuilder.setCompressionLevel(CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_LEVEL.get)
+ shuffleWriterBuilder.setCompressionLevel(
+ CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL.get)
outputPartitioning match {
case _: HashPartitioning =>
diff --git
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/IpcInputStreamIterator.scala
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/IpcInputStreamIterator.scala
deleted file mode 100644
index aa4055048..000000000
---
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/IpcInputStreamIterator.scala
+++ /dev/null
@@ -1,126 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.sql.comet.execution.shuffle
-
-import java.io.{EOFException, InputStream}
-import java.nio.{ByteBuffer, ByteOrder}
-import java.nio.channels.{Channels, ReadableByteChannel}
-
-import org.apache.spark.TaskContext
-import org.apache.spark.internal.Logging
-import org.apache.spark.network.util.LimitedInputStream
-
-case class IpcInputStreamIterator(
- var in: InputStream,
- decompressingNeeded: Boolean,
- taskContext: TaskContext)
- extends Iterator[ReadableByteChannel]
- with Logging {
-
- private[execution] val channel: ReadableByteChannel = if (in != null) {
- Channels.newChannel(in)
- } else {
- null
- }
-
- private val ipcLengthsBuf =
ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN)
-
- // NOTE:
- // since all ipcs are sharing the same input stream and channel, the second
- // hasNext() must be called after the first ipc has been completely
processed.
-
- private[execution] var consumed = true
- private var finished = false
- private var currentIpcLength = 0L
- private var currentLimitedInputStream: LimitedInputStream = _
-
- taskContext.addTaskCompletionListener[Unit](_ => {
- closeInputStream()
- })
-
- override def hasNext: Boolean = {
- if (in == null || finished) {
- return false
- }
-
- // If we've read the length of the next IPC, we don't need to read it
again.
- if (!consumed) {
- return true
- }
-
- if (currentLimitedInputStream != null) {
- currentLimitedInputStream.skip(Int.MaxValue)
- currentLimitedInputStream = null
- }
-
- // Reads the length of IPC bytes
- ipcLengthsBuf.clear()
- while (ipcLengthsBuf.hasRemaining && channel.read(ipcLengthsBuf) >= 0) {}
-
- // If we reach the end of the stream, we are done, or if we read partial
length
- // then the stream is corrupted.
- if (ipcLengthsBuf.hasRemaining) {
- if (ipcLengthsBuf.position() == 0) {
- finished = true
- closeInputStream()
- return false
- }
- throw new EOFException("Data corrupt: unexpected EOF while reading
compressed ipc lengths")
- }
-
- ipcLengthsBuf.flip()
- currentIpcLength = ipcLengthsBuf.getLong
-
- // Skips empty IPC
- if (currentIpcLength == 0) {
- return hasNext
- }
- consumed = false
- return true
- }
-
- override def next(): ReadableByteChannel = {
- if (!hasNext) {
- throw new NoSuchElementException
- }
- assert(!consumed)
- consumed = true
-
- val is = new LimitedInputStream(Channels.newInputStream(channel),
currentIpcLength, false)
- currentLimitedInputStream = is
-
- if (decompressingNeeded) {
- ShuffleUtils.compressionCodecForShuffling match {
- case Some(codec) =>
Channels.newChannel(codec.compressedInputStream(is))
- case _ => Channels.newChannel(is)
- }
- } else {
- Channels.newChannel(is)
- }
- }
-
- private def closeInputStream(): Unit =
- synchronized {
- if (in != null) {
- in.close()
- in = null
- }
- }
-}
diff --git
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala
new file mode 100644
index 000000000..2839c9bd8
--- /dev/null
+++
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.execution.shuffle
+
+import java.io.{EOFException, InputStream}
+import java.nio.{ByteBuffer, ByteOrder}
+import java.nio.channels.{Channels, ReadableByteChannel}
+
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+import org.apache.comet.Native
+import org.apache.comet.vector.NativeUtil
+
+/**
+ * This iterator wraps a Spark input stream that reads shuffle blocks
generated by the Comet
+ * native ShuffleWriterExec, calls native code to decompress and decode each
shuffle block,
+ * and uses Arrow FFI to return the decoded Arrow record batches.
+ */
+case class NativeBatchDecoderIterator(
+ var in: InputStream,
+ taskContext: TaskContext,
+ decodeTime: SQLMetric)
+ extends Iterator[ColumnarBatch] {
+
+ private var isClosed = false
+ private val longBuf = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN)
+ private val native = new Native()
+ private val nativeUtil = new NativeUtil()
+ private var currentBatch: ColumnarBatch = null
+ private var batch = fetchNext()
+
+ import NativeBatchDecoderIterator.threadLocalDataBuf
+
+ if (taskContext != null) {
+ taskContext.addTaskCompletionListener[Unit](_ => {
+ close()
+ })
+ }
+
+ private val channel: ReadableByteChannel = if (in != null) {
+ Channels.newChannel(in)
+ } else {
+ null
+ }
+
+ def hasNext(): Boolean = {
+ if (channel == null || isClosed) {
+ return false
+ }
+ if (batch.isDefined) {
+ return true
+ }
+
+ // Release the previous batch.
+ if (currentBatch != null) {
+ currentBatch.close()
+ currentBatch = null
+ }
+
+ batch = fetchNext()
+ if (batch.isEmpty) {
+ close()
+ return false
+ }
+ true
+ }
+
+ def next(): ColumnarBatch = {
+ if (!hasNext) {
+ throw new NoSuchElementException
+ }
+
+ val nextBatch = batch.get
+
+ currentBatch = nextBatch
+ batch = None
+ currentBatch
+ }
+
+ private def fetchNext(): Option[ColumnarBatch] = {
+ if (channel == null || isClosed) {
+ return None
+ }
+
+ // read compressed batch size from header
+ try {
+ longBuf.clear()
+ while (longBuf.hasRemaining && channel.read(longBuf) >= 0) {}
+ } catch {
+ case _: EOFException =>
+ close()
+ return None
+ }
+
+ // If we reach the end of the stream, we are done, or if we read partial
length
+ // then the stream is corrupted.
+ if (longBuf.hasRemaining) {
+ if (longBuf.position() == 0) {
+ close()
+ return None
+ }
+ throw new EOFException("Data corrupt: unexpected EOF while reading
compressed ipc lengths")
+ }
+
+ // get compressed length (including headers)
+ longBuf.flip()
+ val compressedLength = longBuf.getLong
+
+ // read field count from header
+ longBuf.clear()
+ while (longBuf.hasRemaining && channel.read(longBuf) >= 0) {}
+ if (longBuf.hasRemaining) {
+ throw new EOFException("Data corrupt: unexpected EOF while reading field
count")
+ }
+ longBuf.flip()
+ val fieldCount = longBuf.getLong.toInt
+
+ // read body
+ val bytesToRead = compressedLength - 8
+ if (bytesToRead > Integer.MAX_VALUE) {
+ // very unlikely that shuffle block will reach 2GB
+ throw new IllegalStateException(
+ s"Native shuffle block size of $bytesToRead exceeds " +
+ s"maximum of ${Integer.MAX_VALUE}. Try reducing shuffle batch size.")
+ }
+ var dataBuf = threadLocalDataBuf.get()
+ if (dataBuf.capacity() < bytesToRead) {
+ val newCapacity = (bytesToRead * 2L).min(Integer.MAX_VALUE).toInt
+ dataBuf = ByteBuffer.allocateDirect(newCapacity)
+ threadLocalDataBuf.set(dataBuf)
+ }
+ dataBuf.clear()
+ dataBuf.limit(bytesToRead.toInt)
+ while (dataBuf.hasRemaining && channel.read(dataBuf) >= 0) {}
+ if (dataBuf.hasRemaining) {
+ throw new EOFException("Data corrupt: unexpected EOF while reading
compressed batch")
+ }
+
+ // make native call to decode batch
+ val startTime = System.nanoTime()
+ val batch = nativeUtil.getNextBatch(
+ fieldCount,
+ (arrayAddrs, schemaAddrs) => {
+ native.decodeShuffleBlock(dataBuf, bytesToRead.toInt, arrayAddrs,
schemaAddrs)
+ })
+ decodeTime.add(System.nanoTime() - startTime)
+
+ batch
+ }
+
+ def close(): Unit = {
+ synchronized {
+ if (!isClosed) {
+ if (currentBatch != null) {
+ currentBatch.close()
+ currentBatch = null
+ }
+ in.close()
+ isClosed = true
+ }
+ }
+ }
+}
+
+object NativeBatchDecoderIterator {
+ private val threadLocalDataBuf: ThreadLocal[ByteBuffer] =
ThreadLocal.withInitial(() => {
+ ByteBuffer.allocateDirect(128 * 1024)
+ })
+}
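A hypothetical usage sketch of the new iterator, assuming a SparkSession bound
to spark and a local file that already contains blocks in the framing described
above (in the real code path the stream, task context, and decode_time metric
come from CometBlockStoreShuffleReader):

    import java.io.FileInputStream

    import org.apache.spark.TaskContext
    import org.apache.spark.sql.comet.execution.shuffle.NativeBatchDecoderIterator
    import org.apache.spark.sql.execution.metric.SQLMetrics

    // Placeholder inputs for illustration only.
    val in = new FileInputStream("/tmp/comet-shuffle-block.bin")
    val decodeTime =
      SQLMetrics.createNanoTimingMetric(spark.sparkContext, "decode time")

    // TaskContext.get() may be null outside a task; the iterator tolerates that.
    val iter = NativeBatchDecoderIterator(in, TaskContext.get(), decodeTime)
    while (iter.hasNext()) {
      // Batches are closed by the iterator, so consume each before advancing.
      val batch = iter.next()
      println(s"decoded ${batch.numRows()} rows, ${batch.numCols()} columns")
    }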
diff --git
a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
index 6130e4cd5..13344c0ed 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
@@ -1132,7 +1132,8 @@ class CometShuffleManagerSuite extends CometTestBase {
partitioner = new Partitioner {
override def numPartitions: Int = 50
override def getPartition(key: Any): Int = key.asInstanceOf[Int]
- })
+ },
+ decodeTime = null)
assert(CometShuffleManager.shouldBypassMergeSort(conf, dependency))
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]