This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 6038f4cfac Track parquet writer encoding memory usage on MemoryPool (#11345)
6038f4cfac is described below

commit 6038f4cfac536dbb54ea2761828f7344a23b94f0
Author: wiedld <[email protected]>
AuthorDate: Wed Jul 10 11:21:01 2024 -0700

    Track parquet writer encoding memory usage on MemoryPool (#11345)
    
    * feat(11344): track memory used for non-parallel writes
    
    * feat(11344): track memory usage during parallel writes
    
    * test(11344): create bounded stream for testing
    
    * test(11344): test ParquetSink memory reservation
    
    * feat(11344): track bytes in file writer
    
    * refactor(11344): tweak the ordering to add col bytes to rg_reservation before shrinking for the data bytes flushed
    
    * refactor: move each col_reservation and rg_reservation to match the parallelized call stack for col vs rg
    
    * test(11344): add memory_limit enforcement test for parquet sink
    
    * chore: cleanup to remove unnecessary reservation management steps
    
    * fix: fix CI test failure due to file extension rename
---
 .../core/src/datasource/file_format/parquet.rs     | 165 +++++++++++++++++++--
 datafusion/core/src/test_util/mod.rs               |  36 +++++
 datafusion/core/tests/memory_limit/mod.rs          |  25 ++++
 3 files changed, 216 insertions(+), 10 deletions(-)
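
For readers skimming the change, the core pattern is the one visible in the hunks below: each writer registers a MemoryConsumer against the query's MemoryPool, resizes the resulting MemoryReservation as its estimated encoding memory grows, and frees it once the encoded bytes are flushed. The following is a minimal, self-contained sketch of that pattern, separate from the commit itself; GreedyMemoryPool, the consumer name, and the byte counts are illustrative stand-ins.

    use std::sync::Arc;

    use datafusion_common::Result;
    use datafusion_execution::memory_pool::{GreedyMemoryPool, MemoryConsumer, MemoryPool};

    fn main() -> Result<()> {
        // Pool with a 1 MiB limit; ParquetSink instead uses the pool from the TaskContext.
        let pool: Arc<dyn MemoryPool> = Arc::new(GreedyMemoryPool::new(1024 * 1024));
        let mut reservation = MemoryConsumer::new("ParquetSink[example]").register(&pool);

        // Stand-in for `writer.memory_size()` growing as batches are buffered and encoded.
        for estimated_bytes in [16 * 1024, 64 * 1024, 256 * 1024] {
            // Errors if the pool limit would be exceeded, which is how the sink
            // surfaces memory-limit violations to the caller.
            reservation.try_resize(estimated_bytes)?;
        }
        assert_eq!(pool.reserved(), 256 * 1024);

        // Once the encoded bytes have been flushed, the reservation returns to the pool.
        reservation.free();
        assert_eq!(pool.reserved(), 0);
        Ok(())
    }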

diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs
index 27d783cd89..694c949285 100644
--- a/datafusion/core/src/datasource/file_format/parquet.rs
+++ b/datafusion/core/src/datasource/file_format/parquet.rs
@@ -48,6 +48,7 @@ use datafusion_common::{
     DEFAULT_PARQUET_EXTENSION,
 };
 use datafusion_common_runtime::SpawnedTask;
+use datafusion_execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation};
 use datafusion_execution::TaskContext;
 use datafusion_physical_expr::expressions::{MaxAccumulator, MinAccumulator};
 use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement};
@@ -749,9 +750,13 @@ impl DataSink for ParquetSink {
                         parquet_props.writer_options().clone(),
                     )
                     .await?;
+                let mut reservation =
+                    MemoryConsumer::new(format!("ParquetSink[{}]", path))
+                        .register(context.memory_pool());
                 file_write_tasks.spawn(async move {
                     while let Some(batch) = rx.recv().await {
                         writer.write(&batch).await?;
+                        reservation.try_resize(writer.memory_size())?;
                     }
                     let file_metadata = writer
                         .close()
@@ -771,6 +776,7 @@ impl DataSink for ParquetSink {
                 let schema = self.get_writer_schema();
                 let props = parquet_props.clone();
                 let parallel_options_clone = parallel_options.clone();
+                let pool = Arc::clone(context.memory_pool());
                 file_write_tasks.spawn(async move {
                     let file_metadata = output_single_parquet_file_parallelized(
                         writer,
@@ -778,6 +784,7 @@ impl DataSink for ParquetSink {
                         schema,
                         props.writer_options(),
                         parallel_options_clone,
+                        pool,
                     )
                     .await?;
                     Ok((path, file_metadata))
@@ -818,14 +825,16 @@ impl DataSink for ParquetSink {
 async fn column_serializer_task(
     mut rx: Receiver<ArrowLeafColumn>,
     mut writer: ArrowColumnWriter,
-) -> Result<ArrowColumnWriter> {
+    mut reservation: MemoryReservation,
+) -> Result<(ArrowColumnWriter, MemoryReservation)> {
     while let Some(col) = rx.recv().await {
         writer.write(&col)?;
+        reservation.try_resize(writer.memory_size())?;
     }
-    Ok(writer)
+    Ok((writer, reservation))
 }
 
-type ColumnWriterTask = SpawnedTask<Result<ArrowColumnWriter>>;
+type ColumnWriterTask = SpawnedTask<Result<(ArrowColumnWriter, MemoryReservation)>>;
 type ColSender = Sender<ArrowLeafColumn>;
 
 /// Spawns a parallel serialization task for each column
@@ -835,6 +844,7 @@ fn spawn_column_parallel_row_group_writer(
     schema: Arc<Schema>,
     parquet_props: Arc<WriterProperties>,
     max_buffer_size: usize,
+    pool: &Arc<dyn MemoryPool>,
 ) -> Result<(Vec<ColumnWriterTask>, Vec<ColSender>)> {
     let schema_desc = arrow_to_parquet_schema(&schema)?;
     let col_writers = get_column_writers(&schema_desc, &parquet_props, &schema)?;
@@ -848,7 +858,13 @@ fn spawn_column_parallel_row_group_writer(
             mpsc::channel::<ArrowLeafColumn>(max_buffer_size);
         col_array_channels.push(send_array);
 
-        let task = SpawnedTask::spawn(column_serializer_task(recieve_array, writer));
+        let reservation =
+            MemoryConsumer::new("ParquetSink(ArrowColumnWriter)").register(pool);
+        let task = SpawnedTask::spawn(column_serializer_task(
+            recieve_array,
+            writer,
+            reservation,
+        ));
         col_writer_tasks.push(task);
     }
 
@@ -864,7 +880,7 @@ struct ParallelParquetWriterOptions {
 
 /// This is the return type of calling [ArrowColumnWriter].close() on each column
 /// i.e. the Vec of encoded columns which can be appended to a row group
-type RBStreamSerializeResult = Result<(Vec<ArrowColumnChunk>, usize)>;
+type RBStreamSerializeResult = Result<(Vec<ArrowColumnChunk>, MemoryReservation, usize)>;
 
 /// Sends the ArrowArrays in passed [RecordBatch] through the channels to their respective
 /// parallel column serializers.
@@ -895,16 +911,22 @@ async fn send_arrays_to_col_writers(
 fn spawn_rg_join_and_finalize_task(
     column_writer_tasks: Vec<ColumnWriterTask>,
     rg_rows: usize,
+    pool: &Arc<dyn MemoryPool>,
 ) -> SpawnedTask<RBStreamSerializeResult> {
+    let mut rg_reservation =
+        MemoryConsumer::new("ParquetSink(SerializedRowGroupWriter)").register(pool);
+
     SpawnedTask::spawn(async move {
         let num_cols = column_writer_tasks.len();
         let mut finalized_rg = Vec::with_capacity(num_cols);
         for task in column_writer_tasks.into_iter() {
-            let writer = task.join_unwind().await?;
+            let (writer, _col_reservation) = task.join_unwind().await?;
+            let encoded_size = writer.get_estimated_total_bytes();
+            rg_reservation.grow(encoded_size);
             finalized_rg.push(writer.close()?);
         }
 
-        Ok((finalized_rg, rg_rows))
+        Ok((finalized_rg, rg_reservation, rg_rows))
     })
 }
 
@@ -922,6 +944,7 @@ fn spawn_parquet_parallel_serialization_task(
     schema: Arc<Schema>,
     writer_props: Arc<WriterProperties>,
     parallel_options: ParallelParquetWriterOptions,
+    pool: Arc<dyn MemoryPool>,
 ) -> SpawnedTask<Result<(), DataFusionError>> {
     SpawnedTask::spawn(async move {
         let max_buffer_rb = parallel_options.max_buffered_record_batches_per_stream;
@@ -931,6 +954,7 @@ fn spawn_parquet_parallel_serialization_task(
                 schema.clone(),
                 writer_props.clone(),
                 max_buffer_rb,
+                &pool,
             )?;
         let mut current_rg_rows = 0;
 
@@ -957,6 +981,7 @@ fn spawn_parquet_parallel_serialization_task(
                     let finalize_rg_task = spawn_rg_join_and_finalize_task(
                         column_writer_handles,
                         max_row_group_rows,
+                        &pool,
                     );
 
                     serialize_tx.send(finalize_rg_task).await.map_err(|_| {
@@ -973,6 +998,7 @@ fn spawn_parquet_parallel_serialization_task(
                             schema.clone(),
                             writer_props.clone(),
                             max_buffer_rb,
+                            &pool,
                         )?;
                 }
             }
@@ -981,8 +1007,11 @@ fn spawn_parquet_parallel_serialization_task(
         drop(col_array_channels);
         // Handle leftover rows as final rowgroup, which may be smaller than max_row_group_rows
         if current_rg_rows > 0 {
-            let finalize_rg_task =
-                spawn_rg_join_and_finalize_task(column_writer_handles, current_rg_rows);
+            let finalize_rg_task = spawn_rg_join_and_finalize_task(
+                column_writer_handles,
+                current_rg_rows,
+                &pool,
+            );
 
             serialize_tx.send(finalize_rg_task).await.map_err(|_| {
                 DataFusionError::Internal(
@@ -1002,9 +1031,13 @@ async fn concatenate_parallel_row_groups(
     schema: Arc<Schema>,
     writer_props: Arc<WriterProperties>,
     mut object_store_writer: Box<dyn AsyncWrite + Send + Unpin>,
+    pool: Arc<dyn MemoryPool>,
 ) -> Result<FileMetaData> {
     let merged_buff = SharedBuffer::new(INITIAL_BUFFER_BYTES);
 
+    let mut file_reservation =
+        MemoryConsumer::new("ParquetSink(SerializedFileWriter)").register(&pool);
+
     let schema_desc = arrow_to_parquet_schema(schema.as_ref())?;
     let mut parquet_writer = SerializedFileWriter::new(
         merged_buff.clone(),
@@ -1015,15 +1048,20 @@ async fn concatenate_parallel_row_groups(
     while let Some(task) = serialize_rx.recv().await {
         let result = task.join_unwind().await;
         let mut rg_out = parquet_writer.next_row_group()?;
-        let (serialized_columns, _cnt) = result?;
+        let (serialized_columns, mut rg_reservation, _cnt) = result?;
         for chunk in serialized_columns {
             chunk.append_to_row_group(&mut rg_out)?;
+            rg_reservation.free();
+
             let mut buff_to_flush = merged_buff.buffer.try_lock().unwrap();
+            file_reservation.try_resize(buff_to_flush.len())?;
+
             if buff_to_flush.len() > BUFFER_FLUSH_BYTES {
                 object_store_writer
                     .write_all(buff_to_flush.as_slice())
                     .await?;
                 buff_to_flush.clear();
+                file_reservation.try_resize(buff_to_flush.len())?; // will set to zero
             }
         }
         rg_out.close()?;
@@ -1034,6 +1072,7 @@ async fn concatenate_parallel_row_groups(
 
     object_store_writer.write_all(final_buff.as_slice()).await?;
     object_store_writer.shutdown().await?;
+    file_reservation.free();
 
     Ok(file_metadata)
 }
@@ -1048,6 +1087,7 @@ async fn output_single_parquet_file_parallelized(
     output_schema: Arc<Schema>,
     parquet_props: &WriterProperties,
     parallel_options: ParallelParquetWriterOptions,
+    pool: Arc<dyn MemoryPool>,
 ) -> Result<FileMetaData> {
     let max_rowgroups = parallel_options.max_parallel_row_groups;
     // Buffer size of this channel limits maximum number of RowGroups being worked on in parallel
@@ -1061,12 +1101,14 @@ async fn output_single_parquet_file_parallelized(
         output_schema.clone(),
         arc_props.clone(),
         parallel_options,
+        Arc::clone(&pool),
     );
     let file_metadata = concatenate_parallel_row_groups(
         serialize_rx,
         output_schema.clone(),
         arc_props.clone(),
         object_store_writer,
+        pool,
     )
     .await?;
 
@@ -1158,8 +1200,10 @@ mod tests {
     use super::super::test_util::scan_format;
     use crate::datasource::listing::{ListingTableUrl, PartitionedFile};
     use crate::physical_plan::collect;
+    use crate::test_util::bounded_stream;
     use std::fmt::{Display, Formatter};
     use std::sync::atomic::{AtomicUsize, Ordering};
+    use std::time::Duration;
 
     use super::*;
 
@@ -2177,4 +2221,105 @@ mod tests {
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn parquet_sink_write_memory_reservation() -> Result<()> {
+        async fn test_memory_reservation(global: ParquetOptions) -> Result<()> {
+            let field_a = Field::new("a", DataType::Utf8, false);
+            let field_b = Field::new("b", DataType::Utf8, false);
+            let schema = Arc::new(Schema::new(vec![field_a, field_b]));
+            let object_store_url = ObjectStoreUrl::local_filesystem();
+
+            let file_sink_config = FileSinkConfig {
+                object_store_url: object_store_url.clone(),
+                file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)],
+                table_paths: vec![ListingTableUrl::parse("file:///")?],
+                output_schema: schema.clone(),
+                table_partition_cols: vec![],
+                overwrite: true,
+                keep_partition_by_columns: false,
+            };
+            let parquet_sink = Arc::new(ParquetSink::new(
+                file_sink_config,
+                TableParquetOptions {
+                    key_value_metadata: std::collections::HashMap::from([
+                        ("my-data".to_string(), Some("stuff".to_string())),
+                        ("my-data-bool-key".to_string(), None),
+                    ]),
+                    global,
+                    ..Default::default()
+                },
+            ));
+
+            // create data
+            let col_a: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar"]));
+            let col_b: ArrayRef = Arc::new(StringArray::from(vec!["baz", "baz"]));
+            let batch =
+                RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)]).unwrap();
+
+            // create task context
+            let task_context = build_ctx(object_store_url.as_ref());
+            assert_eq!(
+                task_context.memory_pool().reserved(),
+                0,
+                "no bytes are reserved yet"
+            );
+
+            let mut write_task = parquet_sink.write_all(
+                Box::pin(RecordBatchStreamAdapter::new(
+                    schema,
+                    bounded_stream(batch, 1000),
+                )),
+                &task_context,
+            );
+
+            // incrementally poll and check for memory reservation
+            let mut reserved_bytes = 0;
+            while futures::poll!(&mut write_task).is_pending() {
+                reserved_bytes += task_context.memory_pool().reserved();
+                tokio::time::sleep(Duration::from_micros(1)).await;
+            }
+            assert!(
+                reserved_bytes > 0,
+                "should have bytes reserved during write"
+            );
+            assert_eq!(
+                task_context.memory_pool().reserved(),
+                0,
+                "no leaking byte reservation"
+            );
+
+            Ok(())
+        }
+
+        let write_opts = ParquetOptions {
+            allow_single_file_parallelism: false,
+            ..Default::default()
+        };
+        test_memory_reservation(write_opts)
+            .await
+            .expect("should track for non-parallel writes");
+
+        let row_parallel_write_opts = ParquetOptions {
+            allow_single_file_parallelism: true,
+            maximum_parallel_row_group_writers: 10,
+            maximum_buffered_record_batches_per_stream: 1,
+            ..Default::default()
+        };
+        test_memory_reservation(row_parallel_write_opts)
+            .await
+            .expect("should track for row-parallel writes");
+
+        let col_parallel_write_opts = ParquetOptions {
+            allow_single_file_parallelism: true,
+            maximum_parallel_row_group_writers: 1,
+            maximum_buffered_record_batches_per_stream: 2,
+            ..Default::default()
+        };
+        test_memory_reservation(col_parallel_write_opts)
+            .await
+            .expect("should track for column-parallel writes");
+
+        Ok(())
+    }
 }
diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs
index 059fa8fc6d..ba0509f3f5 100644
--- a/datafusion/core/src/test_util/mod.rs
+++ b/datafusion/core/src/test_util/mod.rs
@@ -366,3 +366,39 @@ pub fn register_unbounded_file_with_ordering(
     ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?;
     Ok(())
 }
+
+struct BoundedStream {
+    limit: usize,
+    count: usize,
+    batch: RecordBatch,
+}
+
+impl Stream for BoundedStream {
+    type Item = Result<RecordBatch>;
+
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        if self.count >= self.limit {
+            return Poll::Ready(None);
+        }
+        self.count += 1;
+        Poll::Ready(Some(Ok(self.batch.clone())))
+    }
+}
+
+impl RecordBatchStream for BoundedStream {
+    fn schema(&self) -> SchemaRef {
+        self.batch.schema()
+    }
+}
+
+/// Creates a bounded stream for testing purposes.
+pub fn bounded_stream(batch: RecordBatch, limit: usize) -> SendableRecordBatchStream {
+    Box::pin(BoundedStream {
+        count: 0,
+        limit,
+        batch,
+    })
+}
diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs
index f61ee5d9ab..f7402357d1 100644
--- a/datafusion/core/tests/memory_limit/mod.rs
+++ b/datafusion/core/tests/memory_limit/mod.rs
@@ -31,6 +31,7 @@ use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr};
 use futures::StreamExt;
 use std::any::Any;
 use std::sync::{Arc, OnceLock};
+use tokio::fs::File;
 
 use datafusion::datasource::streaming::StreamingTable;
 use datafusion::datasource::{MemTable, TableProvider};
@@ -323,6 +324,30 @@ async fn oom_recursive_cte() {
         .await
 }
 
+#[tokio::test]
+async fn oom_parquet_sink() {
+    let dir = tempfile::tempdir().unwrap();
+    let path = dir.into_path().join("test.parquet");
+    let _ = File::create(path.clone()).await.unwrap();
+
+    TestCase::new()
+        .with_query(format!(
+            "
+            COPY (select * from t)
+            TO '{}'
+            STORED AS PARQUET OPTIONS (compression 'uncompressed');
+        ",
+            path.to_string_lossy()
+        ))
+        .with_expected_errors(vec![
+            // TODO: update error handling in ParquetSink
+            "Unable to send array to writer!",
+        ])
+        .with_memory_limit(200_000)
+        .run()
+        .await
+}
+
 /// Run the query with the specified memory limit,
 /// and verifies the expected errors are returned
 #[derive(Clone, Debug)]


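As a side note for anyone reusing the new test helper: bounded_stream (added to datafusion/core/src/test_util/mod.rs above) yields the same RecordBatch a fixed number of times and then ends. Below is a hypothetical usage sketch, not taken from the commit; the column name, values, and limit are made up.

    use std::sync::Arc;

    use datafusion::arrow::array::{ArrayRef, Int32Array};
    use datafusion::arrow::record_batch::RecordBatch;
    use datafusion::error::Result;
    use datafusion::test_util::bounded_stream;
    use futures::StreamExt;

    #[tokio::main]
    async fn main() -> Result<()> {
        // One small batch that the stream will repeat a bounded number of times.
        let column: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
        let batch = RecordBatch::try_from_iter(vec![("a", column)])?;

        // The helper yields the batch `limit` times and then ends, giving a
        // predictable amount of data for driving a sink (the ParquetSink
        // reservation test above uses 1000 repetitions).
        let mut stream = bounded_stream(batch, 3);
        let mut total_rows = 0;
        while let Some(batch) = stream.next().await {
            total_rows += batch?.num_rows();
        }
        assert_eq!(total_rows, 9);
        Ok(())
    }
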
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
