This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 4edd904ee refactor: reorganize shuffle crate module structure (#3772)
4edd904ee is described below
commit 4edd904eeaca8c985e5dd928486888e4ee7e9f60
Author: Andy Grove <[email protected]>
AuthorDate: Fri Mar 27 11:50:17 2026 -0600
refactor: reorganize shuffle crate module structure (#3772)
---
.../core/src/execution/operators/shuffle_scan.rs | 7 +-
native/shuffle/benches/row_columnar.rs | 5 +-
native/shuffle/src/comet_partitioning.rs | 1 +
native/shuffle/src/ipc.rs | 52 +++++
native/shuffle/src/lib.rs | 5 +-
native/shuffle/src/metrics.rs | 1 +
native/shuffle/src/partitioners/mod.rs | 13 +-
native/shuffle/src/partitioners/multi_partition.rs | 1 +
.../src/partitioners/partitioned_batch_iterator.rs | 1 +
.../shuffle/src/partitioners/{mod.rs => traits.rs} | 8 -
native/shuffle/src/spark_unsafe/list.rs | 7 +-
native/shuffle/src/spark_unsafe/map.rs | 1 +
native/shuffle/src/spark_unsafe/mod.rs | 1 +
native/shuffle/src/spark_unsafe/row.rs | 214 +-------------------
native/shuffle/src/spark_unsafe/unsafe_object.rs | 224 +++++++++++++++++++++
native/shuffle/src/writers/buf_batch_writer.rs | 2 +-
native/shuffle/src/writers/checksum.rs | 81 ++++++++
native/shuffle/src/writers/mod.rs | 8 +-
.../{codec.rs => writers/shuffle_block_writer.rs} | 97 +--------
.../src/writers/{partition_writer.rs => spill.rs} | 4 +-
20 files changed, 395 insertions(+), 338 deletions(-)
diff --git a/native/core/src/execution/operators/shuffle_scan.rs
b/native/core/src/execution/operators/shuffle_scan.rs
index 824965d48..a1ad52310 100644
--- a/native/core/src/execution/operators/shuffle_scan.rs
+++ b/native/core/src/execution/operators/shuffle_scan.rs
@@ -18,8 +18,7 @@
use crate::{
errors::CometError,
execution::{
- operators::ExecutionError, planner::TEST_EXEC_CONTEXT_ID,
- shuffle::codec::read_ipc_compressed,
+ operators::ExecutionError, planner::TEST_EXEC_CONTEXT_ID,
shuffle::ipc::read_ipc_compressed,
},
jvm_bridge::{jni_call, JVMClasses},
};
@@ -352,7 +351,7 @@ impl RecordBatchStream for ShuffleScanStream {
#[cfg(test)]
mod tests {
- use crate::execution::shuffle::codec::{CompressionCodec,
ShuffleBlockWriter};
+ use crate::execution::shuffle::{CompressionCodec, ShuffleBlockWriter};
use arrow::array::{Int32Array, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
@@ -360,7 +359,7 @@ mod tests {
use std::io::Cursor;
use std::sync::Arc;
- use crate::execution::shuffle::codec::read_ipc_compressed;
+ use crate::execution::shuffle::ipc::read_ipc_compressed;
#[test]
#[cfg_attr(miri, ignore)] // Miri cannot call FFI functions (zstd)
diff --git a/native/shuffle/benches/row_columnar.rs
b/native/shuffle/benches/row_columnar.rs
index 7d3951b4d..cc98f3fac 100644
--- a/native/shuffle/benches/row_columnar.rs
+++ b/native/shuffle/benches/row_columnar.rs
@@ -23,9 +23,8 @@
use arrow::datatypes::{DataType as ArrowDataType, Field, Fields};
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
-use datafusion_comet_shuffle::spark_unsafe::row::{
- process_sorted_row_partition, SparkUnsafeObject, SparkUnsafeRow,
-};
+use
datafusion_comet_shuffle::spark_unsafe::row::{process_sorted_row_partition,
SparkUnsafeRow};
+use datafusion_comet_shuffle::spark_unsafe::unsafe_object::SparkUnsafeObject;
use datafusion_comet_shuffle::CompressionCodec;
use std::sync::Arc;
use tempfile::Builder;
diff --git a/native/shuffle/src/comet_partitioning.rs
b/native/shuffle/src/comet_partitioning.rs
index c269539a6..15912e648 100644
--- a/native/shuffle/src/comet_partitioning.rs
+++ b/native/shuffle/src/comet_partitioning.rs
@@ -19,6 +19,7 @@ use arrow::row::{OwnedRow, RowConverter};
use datafusion::physical_expr::{LexOrdering, PhysicalExpr};
use std::sync::Arc;
+/// Partitioning scheme for distributing rows across shuffle output partitions.
#[derive(Debug, Clone)]
pub enum CometPartitioning {
SinglePartition,
diff --git a/native/shuffle/src/ipc.rs b/native/shuffle/src/ipc.rs
new file mode 100644
index 000000000..81ee41332
--- /dev/null
+++ b/native/shuffle/src/ipc.rs
@@ -0,0 +1,52 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::RecordBatch;
+use arrow::ipc::reader::StreamReader;
+use datafusion::common::DataFusionError;
+use datafusion::error::Result;
+
+pub fn read_ipc_compressed(bytes: &[u8]) -> Result<RecordBatch> {
+ match &bytes[0..4] {
+ b"SNAP" => {
+ let decoder = snap::read::FrameDecoder::new(&bytes[4..]);
+ let mut reader =
+ unsafe { StreamReader::try_new(decoder,
None)?.with_skip_validation(true) };
+ reader.next().unwrap().map_err(|e| e.into())
+ }
+ b"LZ4_" => {
+ let decoder = lz4_flex::frame::FrameDecoder::new(&bytes[4..]);
+ let mut reader =
+ unsafe { StreamReader::try_new(decoder,
None)?.with_skip_validation(true) };
+ reader.next().unwrap().map_err(|e| e.into())
+ }
+ b"ZSTD" => {
+ let decoder = zstd::Decoder::new(&bytes[4..])?;
+ let mut reader =
+ unsafe { StreamReader::try_new(decoder,
None)?.with_skip_validation(true) };
+ reader.next().unwrap().map_err(|e| e.into())
+ }
+ b"NONE" => {
+ let mut reader =
+ unsafe { StreamReader::try_new(&bytes[4..],
None)?.with_skip_validation(true) };
+ reader.next().unwrap().map_err(|e| e.into())
+ }
+ other => Err(DataFusionError::Execution(format!(
+ "Failed to decode batch: invalid compression codec: {other:?}"
+ ))),
+ }
+}
diff --git a/native/shuffle/src/lib.rs b/native/shuffle/src/lib.rs
index 7c2fc8403..f29588f2e 100644
--- a/native/shuffle/src/lib.rs
+++ b/native/shuffle/src/lib.rs
@@ -15,14 +15,15 @@
// specific language governing permissions and limitations
// under the License.
-pub mod codec;
pub(crate) mod comet_partitioning;
+pub mod ipc;
pub(crate) mod metrics;
pub(crate) mod partitioners;
mod shuffle_writer;
pub mod spark_unsafe;
pub(crate) mod writers;
-pub use codec::{read_ipc_compressed, CompressionCodec, ShuffleBlockWriter};
pub use comet_partitioning::CometPartitioning;
+pub use ipc::read_ipc_compressed;
pub use shuffle_writer::ShuffleWriterExec;
+pub use writers::{CompressionCodec, ShuffleBlockWriter};
diff --git a/native/shuffle/src/metrics.rs b/native/shuffle/src/metrics.rs
index 1aba4677d..1de751cf4 100644
--- a/native/shuffle/src/metrics.rs
+++ b/native/shuffle/src/metrics.rs
@@ -19,6 +19,7 @@ use datafusion::physical_plan::metrics::{
BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, Time,
};
+/// Execution metrics for a shuffle partition operation.
pub(crate) struct ShufflePartitionerMetrics {
/// metrics
pub(crate) baseline: BaselineMetrics,
diff --git a/native/shuffle/src/partitioners/mod.rs
b/native/shuffle/src/partitioners/mod.rs
index a6d589677..3eedef62c 100644
--- a/native/shuffle/src/partitioners/mod.rs
+++ b/native/shuffle/src/partitioners/mod.rs
@@ -18,18 +18,9 @@
mod multi_partition;
mod partitioned_batch_iterator;
mod single_partition;
-
-use arrow::record_batch::RecordBatch;
-use datafusion::common::Result;
+mod traits;
pub(crate) use multi_partition::MultiPartitionShuffleRepartitioner;
pub(crate) use partitioned_batch_iterator::PartitionedBatchIterator;
pub(crate) use single_partition::SinglePartitionShufflePartitioner;
-
-#[async_trait::async_trait]
-pub(crate) trait ShufflePartitioner: Send + Sync {
- /// Insert a batch into the partitioner
- async fn insert_batch(&mut self, batch: RecordBatch) -> Result<()>;
- /// Write shuffle data and shuffle index file to disk
- fn shuffle_write(&mut self) -> Result<()>;
-}
+pub(crate) use traits::ShufflePartitioner;
diff --git a/native/shuffle/src/partitioners/multi_partition.rs
b/native/shuffle/src/partitioners/multi_partition.rs
index 42290c551..655bee351 100644
--- a/native/shuffle/src/partitioners/multi_partition.rs
+++ b/native/shuffle/src/partitioners/multi_partition.rs
@@ -39,6 +39,7 @@ use std::io::{BufReader, BufWriter, Seek, Write};
use std::sync::Arc;
use tokio::time::Instant;
+/// Reusable scratch buffers for computing row-to-partition assignments.
#[derive(Default)]
struct ScratchSpace {
/// Hashes for each row in the current batch.
diff --git a/native/shuffle/src/partitioners/partitioned_batch_iterator.rs
b/native/shuffle/src/partitioners/partitioned_batch_iterator.rs
index 77010938c..8309a8ed4 100644
--- a/native/shuffle/src/partitioners/partitioned_batch_iterator.rs
+++ b/native/shuffle/src/partitioners/partitioned_batch_iterator.rs
@@ -50,6 +50,7 @@ impl PartitionedBatchesProducer {
}
}
+/// Iterates over the shuffled record batches belonging to a single output
partition.
pub(crate) struct PartitionedBatchIterator<'a> {
record_batches: Vec<&'a RecordBatch>,
batch_size: usize,
diff --git a/native/shuffle/src/partitioners/mod.rs
b/native/shuffle/src/partitioners/traits.rs
similarity index 80%
copy from native/shuffle/src/partitioners/mod.rs
copy to native/shuffle/src/partitioners/traits.rs
index a6d589677..9572b70db 100644
--- a/native/shuffle/src/partitioners/mod.rs
+++ b/native/shuffle/src/partitioners/traits.rs
@@ -15,17 +15,9 @@
// specific language governing permissions and limitations
// under the License.
-mod multi_partition;
-mod partitioned_batch_iterator;
-mod single_partition;
-
use arrow::record_batch::RecordBatch;
use datafusion::common::Result;
-pub(crate) use multi_partition::MultiPartitionShuffleRepartitioner;
-pub(crate) use partitioned_batch_iterator::PartitionedBatchIterator;
-pub(crate) use single_partition::SinglePartitionShufflePartitioner;
-
#[async_trait::async_trait]
pub(crate) trait ShufflePartitioner: Send + Sync {
/// Insert a batch into the partitioner
diff --git a/native/shuffle/src/spark_unsafe/list.rs
b/native/shuffle/src/spark_unsafe/list.rs
index 4eb293895..3fea3fade 100644
--- a/native/shuffle/src/spark_unsafe/list.rs
+++ b/native/shuffle/src/spark_unsafe/list.rs
@@ -17,10 +17,8 @@
use crate::spark_unsafe::{
map::append_map_elements,
- row::{
- append_field, downcast_builder_ref, impl_primitive_accessors,
SparkUnsafeObject,
- SparkUnsafeRow,
- },
+ row::{append_field, downcast_builder_ref, SparkUnsafeRow},
+ unsafe_object::{impl_primitive_accessors, SparkUnsafeObject},
};
use arrow::array::{
builder::{
@@ -86,6 +84,7 @@ macro_rules! impl_append_to_builder {
};
}
+/// A Spark `UnsafeArray` backed by JVM-allocated memory, providing element
access by index.
pub struct SparkUnsafeArray {
row_addr: i64,
num_elements: usize,
diff --git a/native/shuffle/src/spark_unsafe/map.rs
b/native/shuffle/src/spark_unsafe/map.rs
index 57444cee7..026e6f71d 100644
--- a/native/shuffle/src/spark_unsafe/map.rs
+++ b/native/shuffle/src/spark_unsafe/map.rs
@@ -20,6 +20,7 @@ use arrow::array::builder::{ArrayBuilder, MapBuilder,
MapFieldNames};
use arrow::datatypes::{DataType, FieldRef};
use datafusion_comet_jni_bridge::errors::CometError;
+/// A Spark `UnsafeMap` backed by JVM-allocated memory, containing parallel
keys and values arrays.
pub struct SparkUnsafeMap {
pub(crate) keys: SparkUnsafeArray,
pub(crate) values: SparkUnsafeArray,
diff --git a/native/shuffle/src/spark_unsafe/mod.rs
b/native/shuffle/src/spark_unsafe/mod.rs
index 6390a0f23..abda69a08 100644
--- a/native/shuffle/src/spark_unsafe/mod.rs
+++ b/native/shuffle/src/spark_unsafe/mod.rs
@@ -18,3 +18,4 @@
pub mod list;
mod map;
pub mod row;
+pub mod unsafe_object;
diff --git a/native/shuffle/src/spark_unsafe/row.rs
b/native/shuffle/src/spark_unsafe/row.rs
index da980af8f..3c9867719 100644
--- a/native/shuffle/src/spark_unsafe/row.rs
+++ b/native/shuffle/src/spark_unsafe/row.rs
@@ -17,11 +17,13 @@
//! Utils for supporting native sort-based columnar shuffle.
-use crate::codec::{Checksum, ShuffleBlockWriter};
+use crate::spark_unsafe::unsafe_object::{impl_primitive_accessors,
SparkUnsafeObject};
use crate::spark_unsafe::{
- list::{append_list_element, SparkUnsafeArray},
- map::{append_map_elements, get_map_key_value_fields, SparkUnsafeMap},
+ list::append_list_element,
+ map::{append_map_elements, get_map_key_value_fields},
};
+use crate::writers::Checksum;
+use crate::writers::ShuffleBlockWriter;
use arrow::array::{
builder::{
ArrayBuilder, BinaryBuilder, BinaryDictionaryBuilder, BooleanBuilder,
Date32Builder,
@@ -36,219 +38,17 @@ use arrow::compute::cast;
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use arrow::error::ArrowError;
use datafusion::physical_plan::metrics::Time;
-use datafusion_comet_common::bytes_to_i128;
use datafusion_comet_jni_bridge::errors::CometError;
use jni::sys::{jint, jlong};
use std::{
fs::OpenOptions,
io::{Cursor, Write},
- str::from_utf8,
sync::Arc,
};
-const MAX_LONG_DIGITS: u8 = 18;
const NESTED_TYPE_BUILDER_CAPACITY: usize = 100;
-/// A common trait for Spark Unsafe classes that can be used to access the
underlying data,
-/// e.g., `UnsafeRow` and `UnsafeArray`. This defines a set of methods that
can be used to
-/// access the underlying data with index.
-///
-/// # Safety
-///
-/// Implementations must ensure that:
-/// - `get_row_addr()` returns a valid pointer to JVM-allocated memory
-/// - `get_element_offset()` returns a valid pointer within the row/array data
region
-/// - The memory layout follows Spark's UnsafeRow/UnsafeArray format
-/// - The memory remains valid for the lifetime of the object (guaranteed by
JVM ownership)
-///
-/// All accessor methods (get_boolean, get_int, etc.) use unsafe pointer
operations but are
-/// safe to call as long as:
-/// - The index is within bounds (caller's responsibility)
-/// - The object was constructed from valid Spark UnsafeRow/UnsafeArray data
-///
-/// # Alignment
-///
-/// Primitive accessor methods are implemented separately for each type
because they have
-/// different alignment guarantees:
-/// - `SparkUnsafeRow`: All field offsets are 8-byte aligned (bitset width is
a multiple of 8,
-/// and each field slot is 8 bytes), so accessors use aligned `ptr::read()`.
-/// - `SparkUnsafeArray`: The array base address may be unaligned when nested
within a row's
-/// variable-length region, so accessors use `ptr::read_unaligned()`.
-pub trait SparkUnsafeObject {
- /// Returns the address of the row.
- fn get_row_addr(&self) -> i64;
-
- /// Returns the offset of the element at the given index.
- fn get_element_offset(&self, index: usize, element_size: usize) -> *const
u8;
-
- fn get_boolean(&self, index: usize) -> bool;
- fn get_byte(&self, index: usize) -> i8;
- fn get_short(&self, index: usize) -> i16;
- fn get_int(&self, index: usize) -> i32;
- fn get_long(&self, index: usize) -> i64;
- fn get_float(&self, index: usize) -> f32;
- fn get_double(&self, index: usize) -> f64;
- fn get_date(&self, index: usize) -> i32;
- fn get_timestamp(&self, index: usize) -> i64;
-
- /// Returns the offset and length of the element at the given index.
- #[inline]
- fn get_offset_and_len(&self, index: usize) -> (i32, i32) {
- let offset_and_size = self.get_long(index);
- let offset = (offset_and_size >> 32) as i32;
- let len = offset_and_size as i32;
- (offset, len)
- }
-
- /// Returns string value at the given index of the object.
- fn get_string(&self, index: usize) -> &str {
- let (offset, len) = self.get_offset_and_len(index);
- let addr = self.get_row_addr() + offset as i64;
- // SAFETY: addr points to valid UTF-8 string data within the
variable-length region.
- // Offset and length are read from the fixed-length portion of the
row/array.
- debug_assert!(addr != 0, "get_string: null address at index {index}");
- debug_assert!(
- len >= 0,
- "get_string: negative length {len} at index {index}"
- );
- let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const
u8, len as usize) };
-
- from_utf8(slice).unwrap()
- }
-
- /// Returns binary value at the given index of the object.
- fn get_binary(&self, index: usize) -> &[u8] {
- let (offset, len) = self.get_offset_and_len(index);
- let addr = self.get_row_addr() + offset as i64;
- // SAFETY: addr points to valid binary data within the variable-length
region.
- // Offset and length are read from the fixed-length portion of the
row/array.
- debug_assert!(addr != 0, "get_binary: null address at index {index}");
- debug_assert!(
- len >= 0,
- "get_binary: negative length {len} at index {index}"
- );
- unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) }
- }
-
- /// Returns decimal value at the given index of the object.
- fn get_decimal(&self, index: usize, precision: u8) -> i128 {
- if precision <= MAX_LONG_DIGITS {
- self.get_long(index) as i128
- } else {
- let slice = self.get_binary(index);
- bytes_to_i128(slice)
- }
- }
-
- /// Returns struct value at the given index of the object.
- fn get_struct(&self, index: usize, num_fields: usize) -> SparkUnsafeRow {
- let (offset, len) = self.get_offset_and_len(index);
- let mut row = SparkUnsafeRow::new_with_num_fields(num_fields);
- row.point_to(self.get_row_addr() + offset as i64, len);
-
- row
- }
-
- /// Returns array value at the given index of the object.
- fn get_array(&self, index: usize) -> SparkUnsafeArray {
- let (offset, _) = self.get_offset_and_len(index);
- SparkUnsafeArray::new(self.get_row_addr() + offset as i64)
- }
-
- fn get_map(&self, index: usize) -> SparkUnsafeMap {
- let (offset, len) = self.get_offset_and_len(index);
- SparkUnsafeMap::new(self.get_row_addr() + offset as i64, len)
- }
-}
-
-/// Generates primitive accessor implementations for `SparkUnsafeObject`.
-///
-/// Uses `$read_method` to read typed values from raw pointers:
-/// - `read` for aligned access (SparkUnsafeRow — all offsets are 8-byte
aligned)
-/// - `read_unaligned` for potentially unaligned access (SparkUnsafeArray)
-macro_rules! impl_primitive_accessors {
- ($read_method:ident) => {
- #[inline]
- fn get_boolean(&self, index: usize) -> bool {
- let addr = self.get_element_offset(index, 1);
- debug_assert!(
- !addr.is_null(),
- "get_boolean: null pointer at index {index}"
- );
- // SAFETY: addr points to valid element data within the row/array
region.
- unsafe { *addr != 0 }
- }
-
- #[inline]
- fn get_byte(&self, index: usize) -> i8 {
- let addr = self.get_element_offset(index, 1);
- debug_assert!(!addr.is_null(), "get_byte: null pointer at index
{index}");
- // SAFETY: addr points to valid element data (1 byte) within the
row/array region.
- unsafe { *(addr as *const i8) }
- }
-
- #[inline]
- fn get_short(&self, index: usize) -> i16 {
- let addr = self.get_element_offset(index, 2) as *const i16;
- debug_assert!(!addr.is_null(), "get_short: null pointer at index
{index}");
- // SAFETY: addr points to valid element data (2 bytes) within the
row/array region.
- unsafe { addr.$read_method() }
- }
-
- #[inline]
- fn get_int(&self, index: usize) -> i32 {
- let addr = self.get_element_offset(index, 4) as *const i32;
- debug_assert!(!addr.is_null(), "get_int: null pointer at index
{index}");
- // SAFETY: addr points to valid element data (4 bytes) within the
row/array region.
- unsafe { addr.$read_method() }
- }
-
- #[inline]
- fn get_long(&self, index: usize) -> i64 {
- let addr = self.get_element_offset(index, 8) as *const i64;
- debug_assert!(!addr.is_null(), "get_long: null pointer at index
{index}");
- // SAFETY: addr points to valid element data (8 bytes) within the
row/array region.
- unsafe { addr.$read_method() }
- }
-
- #[inline]
- fn get_float(&self, index: usize) -> f32 {
- let addr = self.get_element_offset(index, 4) as *const f32;
- debug_assert!(!addr.is_null(), "get_float: null pointer at index
{index}");
- // SAFETY: addr points to valid element data (4 bytes) within the
row/array region.
- unsafe { addr.$read_method() }
- }
-
- #[inline]
- fn get_double(&self, index: usize) -> f64 {
- let addr = self.get_element_offset(index, 8) as *const f64;
- debug_assert!(!addr.is_null(), "get_double: null pointer at index
{index}");
- // SAFETY: addr points to valid element data (8 bytes) within the
row/array region.
- unsafe { addr.$read_method() }
- }
-
- #[inline]
- fn get_date(&self, index: usize) -> i32 {
- let addr = self.get_element_offset(index, 4) as *const i32;
- debug_assert!(!addr.is_null(), "get_date: null pointer at index
{index}");
- // SAFETY: addr points to valid element data (4 bytes) within the
row/array region.
- unsafe { addr.$read_method() }
- }
-
- #[inline]
- fn get_timestamp(&self, index: usize) -> i64 {
- let addr = self.get_element_offset(index, 8) as *const i64;
- debug_assert!(
- !addr.is_null(),
- "get_timestamp: null pointer at index {index}"
- );
- // SAFETY: addr points to valid element data (8 bytes) within the
row/array region.
- unsafe { addr.$read_method() }
- }
- };
-}
-pub(crate) use impl_primitive_accessors;
-
+/// A Spark `UnsafeRow` backed by JVM-allocated memory, providing field access
by index.
pub struct SparkUnsafeRow {
row_addr: i64,
row_size: i32,
@@ -323,7 +123,7 @@ impl SparkUnsafeRow {
}
/// Points the row to the given address with specified row size.
- fn point_to(&mut self, row_addr: i64, row_size: i32) {
+ pub(crate) fn point_to(&mut self, row_addr: i64, row_size: i32) {
self.row_addr = row_addr;
self.row_size = row_size;
}
diff --git a/native/shuffle/src/spark_unsafe/unsafe_object.rs
b/native/shuffle/src/spark_unsafe/unsafe_object.rs
new file mode 100644
index 000000000..f32ea8c23
--- /dev/null
+++ b/native/shuffle/src/spark_unsafe/unsafe_object.rs
@@ -0,0 +1,224 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use super::list::SparkUnsafeArray;
+use super::map::SparkUnsafeMap;
+use super::row::SparkUnsafeRow;
+use datafusion_comet_common::bytes_to_i128;
+use std::str::from_utf8;
+
+const MAX_LONG_DIGITS: u8 = 18;
+
+/// A common trait for Spark Unsafe classes that can be used to access the
underlying data,
+/// e.g., `UnsafeRow` and `UnsafeArray`. This defines a set of methods that
can be used to
+/// access the underlying data with index.
+///
+/// # Safety
+///
+/// Implementations must ensure that:
+/// - `get_row_addr()` returns a valid pointer to JVM-allocated memory
+/// - `get_element_offset()` returns a valid pointer within the row/array data
region
+/// - The memory layout follows Spark's UnsafeRow/UnsafeArray format
+/// - The memory remains valid for the lifetime of the object (guaranteed by
JVM ownership)
+///
+/// All accessor methods (get_boolean, get_int, etc.) use unsafe pointer
operations but are
+/// safe to call as long as:
+/// - The index is within bounds (caller's responsibility)
+/// - The object was constructed from valid Spark UnsafeRow/UnsafeArray data
+///
+/// # Alignment
+///
+/// Primitive accessor methods are implemented separately for each type
because they have
+/// different alignment guarantees:
+/// - `SparkUnsafeRow`: All field offsets are 8-byte aligned (bitset width is
a multiple of 8,
+/// and each field slot is 8 bytes), so accessors use aligned `ptr::read()`.
+/// - `SparkUnsafeArray`: The array base address may be unaligned when nested
within a row's
+/// variable-length region, so accessors use `ptr::read_unaligned()`.
+pub trait SparkUnsafeObject {
+ /// Returns the address of the row.
+ fn get_row_addr(&self) -> i64;
+
+ /// Returns the offset of the element at the given index.
+ fn get_element_offset(&self, index: usize, element_size: usize) -> *const
u8;
+
+ fn get_boolean(&self, index: usize) -> bool;
+ fn get_byte(&self, index: usize) -> i8;
+ fn get_short(&self, index: usize) -> i16;
+ fn get_int(&self, index: usize) -> i32;
+ fn get_long(&self, index: usize) -> i64;
+ fn get_float(&self, index: usize) -> f32;
+ fn get_double(&self, index: usize) -> f64;
+ fn get_date(&self, index: usize) -> i32;
+ fn get_timestamp(&self, index: usize) -> i64;
+
+ /// Returns the offset and length of the element at the given index.
+ #[inline]
+ fn get_offset_and_len(&self, index: usize) -> (i32, i32) {
+ let offset_and_size = self.get_long(index);
+ let offset = (offset_and_size >> 32) as i32;
+ let len = offset_and_size as i32;
+ (offset, len)
+ }
+
+ /// Returns string value at the given index of the object.
+ fn get_string(&self, index: usize) -> &str {
+ let (offset, len) = self.get_offset_and_len(index);
+ let addr = self.get_row_addr() + offset as i64;
+ // SAFETY: addr points to valid UTF-8 string data within the
variable-length region.
+ // Offset and length are read from the fixed-length portion of the
row/array.
+ debug_assert!(addr != 0, "get_string: null address at index {index}");
+ debug_assert!(
+ len >= 0,
+ "get_string: negative length {len} at index {index}"
+ );
+ let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const
u8, len as usize) };
+
+ from_utf8(slice).unwrap()
+ }
+
+ /// Returns binary value at the given index of the object.
+ fn get_binary(&self, index: usize) -> &[u8] {
+ let (offset, len) = self.get_offset_and_len(index);
+ let addr = self.get_row_addr() + offset as i64;
+ // SAFETY: addr points to valid binary data within the variable-length
region.
+ // Offset and length are read from the fixed-length portion of the
row/array.
+ debug_assert!(addr != 0, "get_binary: null address at index {index}");
+ debug_assert!(
+ len >= 0,
+ "get_binary: negative length {len} at index {index}"
+ );
+ unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) }
+ }
+
+ /// Returns decimal value at the given index of the object.
+ fn get_decimal(&self, index: usize, precision: u8) -> i128 {
+ if precision <= MAX_LONG_DIGITS {
+ self.get_long(index) as i128
+ } else {
+ let slice = self.get_binary(index);
+ bytes_to_i128(slice)
+ }
+ }
+
+ /// Returns struct value at the given index of the object.
+ fn get_struct(&self, index: usize, num_fields: usize) -> SparkUnsafeRow {
+ let (offset, len) = self.get_offset_and_len(index);
+ let mut row = SparkUnsafeRow::new_with_num_fields(num_fields);
+ row.point_to(self.get_row_addr() + offset as i64, len);
+
+ row
+ }
+
+ /// Returns array value at the given index of the object.
+ fn get_array(&self, index: usize) -> SparkUnsafeArray {
+ let (offset, _) = self.get_offset_and_len(index);
+ SparkUnsafeArray::new(self.get_row_addr() + offset as i64)
+ }
+
+ fn get_map(&self, index: usize) -> SparkUnsafeMap {
+ let (offset, len) = self.get_offset_and_len(index);
+ SparkUnsafeMap::new(self.get_row_addr() + offset as i64, len)
+ }
+}
+
+/// Generates primitive accessor implementations for `SparkUnsafeObject`.
+///
+/// Uses `$read_method` to read typed values from raw pointers:
+/// - `read` for aligned access (SparkUnsafeRow — all offsets are 8-byte
aligned)
+/// - `read_unaligned` for potentially unaligned access (SparkUnsafeArray)
+macro_rules! impl_primitive_accessors {
+ ($read_method:ident) => {
+ #[inline]
+ fn get_boolean(&self, index: usize) -> bool {
+ let addr = self.get_element_offset(index, 1);
+ debug_assert!(
+ !addr.is_null(),
+ "get_boolean: null pointer at index {index}"
+ );
+ // SAFETY: addr points to valid element data within the row/array
region.
+ unsafe { *addr != 0 }
+ }
+
+ #[inline]
+ fn get_byte(&self, index: usize) -> i8 {
+ let addr = self.get_element_offset(index, 1);
+ debug_assert!(!addr.is_null(), "get_byte: null pointer at index
{index}");
+ // SAFETY: addr points to valid element data (1 byte) within the
row/array region.
+ unsafe { *(addr as *const i8) }
+ }
+
+ #[inline]
+ fn get_short(&self, index: usize) -> i16 {
+ let addr = self.get_element_offset(index, 2) as *const i16;
+ debug_assert!(!addr.is_null(), "get_short: null pointer at index
{index}");
+ // SAFETY: addr points to valid element data (2 bytes) within the
row/array region.
+ unsafe { addr.$read_method() }
+ }
+
+ #[inline]
+ fn get_int(&self, index: usize) -> i32 {
+ let addr = self.get_element_offset(index, 4) as *const i32;
+ debug_assert!(!addr.is_null(), "get_int: null pointer at index
{index}");
+ // SAFETY: addr points to valid element data (4 bytes) within the
row/array region.
+ unsafe { addr.$read_method() }
+ }
+
+ #[inline]
+ fn get_long(&self, index: usize) -> i64 {
+ let addr = self.get_element_offset(index, 8) as *const i64;
+ debug_assert!(!addr.is_null(), "get_long: null pointer at index
{index}");
+ // SAFETY: addr points to valid element data (8 bytes) within the
row/array region.
+ unsafe { addr.$read_method() }
+ }
+
+ #[inline]
+ fn get_float(&self, index: usize) -> f32 {
+ let addr = self.get_element_offset(index, 4) as *const f32;
+ debug_assert!(!addr.is_null(), "get_float: null pointer at index
{index}");
+ // SAFETY: addr points to valid element data (4 bytes) within the
row/array region.
+ unsafe { addr.$read_method() }
+ }
+
+ #[inline]
+ fn get_double(&self, index: usize) -> f64 {
+ let addr = self.get_element_offset(index, 8) as *const f64;
+ debug_assert!(!addr.is_null(), "get_double: null pointer at index
{index}");
+ // SAFETY: addr points to valid element data (8 bytes) within the
row/array region.
+ unsafe { addr.$read_method() }
+ }
+
+ #[inline]
+ fn get_date(&self, index: usize) -> i32 {
+ let addr = self.get_element_offset(index, 4) as *const i32;
+ debug_assert!(!addr.is_null(), "get_date: null pointer at index
{index}");
+ // SAFETY: addr points to valid element data (4 bytes) within the
row/array region.
+ unsafe { addr.$read_method() }
+ }
+
+ #[inline]
+ fn get_timestamp(&self, index: usize) -> i64 {
+ let addr = self.get_element_offset(index, 8) as *const i64;
+ debug_assert!(
+ !addr.is_null(),
+ "get_timestamp: null pointer at index {index}"
+ );
+ // SAFETY: addr points to valid element data (8 bytes) within the
row/array region.
+ unsafe { addr.$read_method() }
+ }
+ };
+}
+pub(crate) use impl_primitive_accessors;
diff --git a/native/shuffle/src/writers/buf_batch_writer.rs
b/native/shuffle/src/writers/buf_batch_writer.rs
index 6344a8e5f..cfddb4653 100644
--- a/native/shuffle/src/writers/buf_batch_writer.rs
+++ b/native/shuffle/src/writers/buf_batch_writer.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-use crate::ShuffleBlockWriter;
+use super::ShuffleBlockWriter;
use arrow::array::RecordBatch;
use arrow::compute::kernels::coalesce::BatchCoalescer;
use datafusion::physical_plan::metrics::Time;
diff --git a/native/shuffle/src/writers/checksum.rs
b/native/shuffle/src/writers/checksum.rs
new file mode 100644
index 000000000..b240302e6
--- /dev/null
+++ b/native/shuffle/src/writers/checksum.rs
@@ -0,0 +1,81 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use bytes::Buf;
+use crc32fast::Hasher;
+use datafusion_comet_jni_bridge::errors::{CometError, CometResult};
+use simd_adler32::Adler32;
+use std::io::{Cursor, SeekFrom};
+
+/// Checksum algorithms for writing IPC bytes.
+#[derive(Clone)]
+pub(crate) enum Checksum {
+ /// CRC32 checksum algorithm.
+ CRC32(Hasher),
+ /// Adler32 checksum algorithm.
+ Adler32(Adler32),
+}
+
+impl Checksum {
+ pub(crate) fn try_new(algo: i32, initial_opt: Option<u32>) -> CometResult<Self> {
+ match algo {
+ 0 => {
+ let hasher = if let Some(initial) = initial_opt {
+ Hasher::new_with_initial(initial)
+ } else {
+ Hasher::new()
+ };
+ Ok(Checksum::CRC32(hasher))
+ }
+ 1 => {
+ let hasher = if let Some(initial) = initial_opt {
+ // Note that Adler32 initial state is not zero.
+ // i.e., `Adler32::from_checksum(0)` is not the same as `Adler32::new()`.
+ Adler32::from_checksum(initial)
+ } else {
+ Adler32::new()
+ };
+ Ok(Checksum::Adler32(hasher))
+ }
+ _ => Err(CometError::Internal(
+ "Unsupported checksum algorithm".to_string(),
+ )),
+ }
+ }
+
+ pub(crate) fn update(&mut self, cursor: &mut Cursor<&mut Vec<u8>>) -> CometResult<()> {
+ match self {
+ Checksum::CRC32(hasher) => {
+ std::io::Seek::seek(cursor, SeekFrom::Start(0))?;
+ hasher.update(cursor.chunk());
+ Ok(())
+ }
+ Checksum::Adler32(hasher) => {
+ std::io::Seek::seek(cursor, SeekFrom::Start(0))?;
+ hasher.write(cursor.chunk());
+ Ok(())
+ }
+ }
+ }
+
+ pub(crate) fn finalize(self) -> u32 {
+ match self {
+ Checksum::CRC32(hasher) => hasher.finalize(),
+ Checksum::Adler32(hasher) => hasher.finish(),
+ }
+ }
+}
diff --git a/native/shuffle/src/writers/mod.rs b/native/shuffle/src/writers/mod.rs
index b58989e46..75caf9f3a 100644
--- a/native/shuffle/src/writers/mod.rs
+++ b/native/shuffle/src/writers/mod.rs
@@ -16,7 +16,11 @@
// under the License.
mod buf_batch_writer;
-mod partition_writer;
+mod checksum;
+mod shuffle_block_writer;
+mod spill;
pub(crate) use buf_batch_writer::BufBatchWriter;
-pub(crate) use partition_writer::PartitionWriter;
+pub(crate) use checksum::Checksum;
+pub use shuffle_block_writer::{CompressionCodec, ShuffleBlockWriter};
+pub(crate) use spill::PartitionWriter;
diff --git a/native/shuffle/src/codec.rs b/native/shuffle/src/writers/shuffle_block_writer.rs
similarity index 60%
rename from native/shuffle/src/codec.rs
rename to native/shuffle/src/writers/shuffle_block_writer.rs
index c8edc2468..5ed5330e3 100644
--- a/native/shuffle/src/codec.rs
+++ b/native/shuffle/src/writers/shuffle_block_writer.rs
@@ -17,17 +17,13 @@
use arrow::array::RecordBatch;
use arrow::datatypes::Schema;
-use arrow::ipc::reader::StreamReader;
use arrow::ipc::writer::StreamWriter;
-use bytes::Buf;
-use crc32fast::Hasher;
use datafusion::common::DataFusionError;
use datafusion::error::Result;
use datafusion::physical_plan::metrics::Time;
-use datafusion_comet_jni_bridge::errors::{CometError, CometResult};
-use simd_adler32::Adler32;
use std::io::{Cursor, Seek, SeekFrom, Write};
+/// Compression algorithm applied to shuffle IPC blocks.
#[derive(Debug, Clone)]
pub enum CompressionCodec {
None,
@@ -36,6 +32,7 @@ pub enum CompressionCodec {
Snappy,
}
+/// Writes a record batch as a length-prefixed, compressed Arrow IPC block.
#[derive(Clone)]
pub struct ShuffleBlockWriter {
codec: CompressionCodec,
@@ -147,93 +144,3 @@ impl ShuffleBlockWriter {
Ok((end_pos - start_pos) as usize)
}
}
-
-pub fn read_ipc_compressed(bytes: &[u8]) -> Result<RecordBatch> {
- match &bytes[0..4] {
- b"SNAP" => {
- let decoder = snap::read::FrameDecoder::new(&bytes[4..]);
- let mut reader =
- unsafe { StreamReader::try_new(decoder,
None)?.with_skip_validation(true) };
- reader.next().unwrap().map_err(|e| e.into())
- }
- b"LZ4_" => {
- let decoder = lz4_flex::frame::FrameDecoder::new(&bytes[4..]);
- let mut reader =
- unsafe { StreamReader::try_new(decoder,
None)?.with_skip_validation(true) };
- reader.next().unwrap().map_err(|e| e.into())
- }
- b"ZSTD" => {
- let decoder = zstd::Decoder::new(&bytes[4..])?;
- let mut reader =
- unsafe { StreamReader::try_new(decoder,
None)?.with_skip_validation(true) };
- reader.next().unwrap().map_err(|e| e.into())
- }
- b"NONE" => {
- let mut reader =
- unsafe { StreamReader::try_new(&bytes[4..],
None)?.with_skip_validation(true) };
- reader.next().unwrap().map_err(|e| e.into())
- }
- other => Err(DataFusionError::Execution(format!(
- "Failed to decode batch: invalid compression codec: {other:?}"
- ))),
- }
-}
-
-/// Checksum algorithms for writing IPC bytes.
-#[derive(Clone)]
-pub(crate) enum Checksum {
- /// CRC32 checksum algorithm.
- CRC32(Hasher),
- /// Adler32 checksum algorithm.
- Adler32(Adler32),
-}
-
-impl Checksum {
- pub(crate) fn try_new(algo: i32, initial_opt: Option<u32>) ->
CometResult<Self> {
- match algo {
- 0 => {
- let hasher = if let Some(initial) = initial_opt {
- Hasher::new_with_initial(initial)
- } else {
- Hasher::new()
- };
- Ok(Checksum::CRC32(hasher))
- }
- 1 => {
- let hasher = if let Some(initial) = initial_opt {
- // Note that Adler32 initial state is not zero.
- // i.e., `Adler32::from_checksum(0)` is not the same as
`Adler32::new()`.
- Adler32::from_checksum(initial)
- } else {
- Adler32::new()
- };
- Ok(Checksum::Adler32(hasher))
- }
- _ => Err(CometError::Internal(
- "Unsupported checksum algorithm".to_string(),
- )),
- }
- }
-
- pub(crate) fn update(&mut self, cursor: &mut Cursor<&mut Vec<u8>>) ->
CometResult<()> {
- match self {
- Checksum::CRC32(hasher) => {
- std::io::Seek::seek(cursor, SeekFrom::Start(0))?;
- hasher.update(cursor.chunk());
- Ok(())
- }
- Checksum::Adler32(hasher) => {
- std::io::Seek::seek(cursor, SeekFrom::Start(0))?;
- hasher.write(cursor.chunk());
- Ok(())
- }
- }
- }
-
- pub(crate) fn finalize(self) -> u32 {
- match self {
- Checksum::CRC32(hasher) => hasher.finalize(),
- Checksum::Adler32(hasher) => hasher.finish(),
- }
- }
-}
diff --git a/native/shuffle/src/writers/partition_writer.rs b/native/shuffle/src/writers/spill.rs
similarity index 95%
rename from native/shuffle/src/writers/partition_writer.rs
rename to native/shuffle/src/writers/spill.rs
index 48017871d..c16caddbf 100644
--- a/native/shuffle/src/writers/partition_writer.rs
+++ b/native/shuffle/src/writers/spill.rs
@@ -15,20 +15,22 @@
// specific language governing permissions and limitations
// under the License.
+use super::ShuffleBlockWriter;
use crate::metrics::ShufflePartitionerMetrics;
use crate::partitioners::PartitionedBatchIterator;
use crate::writers::buf_batch_writer::BufBatchWriter;
-use crate::ShuffleBlockWriter;
use datafusion::common::DataFusionError;
use datafusion::execution::disk_manager::RefCountedTempFile;
use datafusion::execution::runtime_env::RuntimeEnv;
use std::fs::{File, OpenOptions};
+/// A temporary disk file for spilling a partition's intermediate shuffle data.
struct SpillFile {
temp_file: RefCountedTempFile,
file: File,
}
+/// Manages encoding and optional disk spilling for a single shuffle partition.
pub(crate) struct PartitionWriter {
/// Spill file for intermediate shuffle output for this partition. Each
spill event
/// will append to this file and the contents will be copied to the
shuffle file at
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]