This is an automated email from the ASF dual-hosted git repository. ytyou pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push: new 5421825b62 feat: add multi level merge sort that will always fit in memory (#15700) 5421825b62 is described below commit 5421825b62a4bbb977cb91ac45152b3a14ef3e88 Author: Raz Luvaton <16746759+rluva...@users.noreply.github.com> AuthorDate: Tue Jul 29 05:36:52 2025 +0300 feat: add multi level merge sort that will always fit in memory (#15700) * feat: add multi level merge sort that will always fit in memory * test: add fuzz test for aggregate * update * add more tests * fix test * update tests * added more aggregate fuzz * align with add fuzz tests * add sort fuzz * fix lints and formatting * moved spill in memory constrained envs to separate test * rename `StreamExec` to `OnceExec` * added comment on the usize in the `in_progress_spill_file` inside ExternalSorter * rename buffer_size to buffer_len * reuse code in spill fuzz * double the amount of memory needed to sort * add diagram for explaining the overview * update based on code review * fix test based on new memory calculation * remove get_size in favor of get_sliced_size * change to result --- datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs | 23 +- datafusion/core/tests/fuzz_cases/mod.rs | 2 + datafusion/core/tests/fuzz_cases/once_exec.rs | 115 ++++ .../spilling_fuzz_in_memory_constrained_env.rs | 654 +++++++++++++++++++++ .../physical-plan/src/aggregates/row_hash.rs | 32 +- datafusion/physical-plan/src/sorts/mod.rs | 1 + .../physical-plan/src/sorts/multi_level_merge.rs | 449 ++++++++++++++ datafusion/physical-plan/src/sorts/sort.rs | 54 +- .../physical-plan/src/sorts/streaming_merge.rs | 92 ++- .../physical-plan/src/spill/spill_manager.rs | 90 ++- 10 files changed, 1458 insertions(+), 54 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 52a6cc4811..bcf60eb2d7 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -38,21 +38,21 @@ use datafusion::prelude::{DataFrame, SessionConfig, SessionContext}; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; use datafusion_common::{HashMap, Result}; use datafusion_common_runtime::JoinSet; +use datafusion_functions_aggregate::sum::sum_udaf; +use datafusion_physical_expr::expressions::{col, lit, Column}; +use datafusion_physical_expr::PhysicalSortExpr; +use datafusion_physical_plan::InputOrderMode; +use test_utils::{add_empty_batches, StringBatchGenerator}; + use datafusion_execution::memory_pool::FairSpillPool; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_execution::TaskContext; -use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::aggregate::AggregateExprBuilder; -use datafusion_physical_expr::expressions::{col, lit, Column}; -use datafusion_physical_expr::PhysicalSortExpr; use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; use datafusion_physical_plan::metrics::MetricValue; -use datafusion_physical_plan::InputOrderMode; use datafusion_physical_plan::{collect, displayable, ExecutionPlan}; -use test_utils::{add_empty_batches, StringBatchGenerator}; - use rand::rngs::StdRng; use rand::{random, rng, Rng, SeedableRng}; @@ -632,8 +632,11 @@ fn extract_result_counts(results: Vec<RecordBatch>) -> HashMap<Option<String>, i output } -fn assert_spill_count_metric(expect_spill: bool, single_aggregate: Arc<AggregateExec>) { - if let Some(metrics_set) = 
single_aggregate.metrics() { +pub(crate) fn assert_spill_count_metric( + expect_spill: bool, + plan_that_spills: Arc<dyn ExecutionPlan>, +) -> usize { + if let Some(metrics_set) = plan_that_spills.metrics() { let mut spill_count = 0; // Inspect metrics for SpillCount @@ -649,6 +652,8 @@ fn assert_spill_count_metric(expect_spill: bool, single_aggregate: Arc<Aggregate } else if !expect_spill && spill_count > 0 { panic!("Expected no spill but found SpillCount metric with value greater than 0."); } + + spill_count } else { panic!("No metrics returned from the operator; cannot verify spilling."); } @@ -656,7 +661,7 @@ fn assert_spill_count_metric(expect_spill: bool, single_aggregate: Arc<Aggregate // Fix for https://github.com/apache/datafusion/issues/15530 #[tokio::test] -async fn test_single_mode_aggregate_with_spill() -> Result<()> { +async fn test_single_mode_aggregate_single_mode_aggregate_with_spill() -> Result<()> { let scan_schema = Arc::new(Schema::new(vec![ Field::new("col_0", DataType::Int64, true), Field::new("col_1", DataType::Utf8, true), diff --git a/datafusion/core/tests/fuzz_cases/mod.rs b/datafusion/core/tests/fuzz_cases/mod.rs index 9e01621c02..9e2fd170f7 100644 --- a/datafusion/core/tests/fuzz_cases/mod.rs +++ b/datafusion/core/tests/fuzz_cases/mod.rs @@ -33,4 +33,6 @@ mod sort_preserving_repartition_fuzz; mod window_fuzz; // Utility modules +mod once_exec; mod record_batch_generator; +mod spilling_fuzz_in_memory_constrained_env; diff --git a/datafusion/core/tests/fuzz_cases/once_exec.rs b/datafusion/core/tests/fuzz_cases/once_exec.rs new file mode 100644 index 0000000000..ec77c1e64c --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/once_exec.rs @@ -0,0 +1,115 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_schema::SchemaRef; +use datafusion_common::DataFusionError; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; +use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; +use datafusion_physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, +}; +use std::any::Any; +use std::fmt::{Debug, Formatter}; +use std::sync::{Arc, Mutex}; + +/// Execution plan that returns the stream on the first call to `execute`;
further calls to `execute` will +/// return an error +pub struct OnceExec { + /// the results to send back + stream: Mutex<Option<SendableRecordBatchStream>>, + cache: PlanProperties, +} + +impl Debug for OnceExec { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "OnceExec") + } +} + +impl OnceExec { + pub fn new(stream: SendableRecordBatchStream) -> Self { + let cache = Self::compute_properties(stream.schema()); + Self { + stream: Mutex::new(Some(stream)), + cache, + } + } + + /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. + fn compute_properties(schema: SchemaRef) -> PlanProperties { + PlanProperties::new( + EquivalenceProperties::new(schema), + Partitioning::UnknownPartitioning(1), + EmissionType::Incremental, + Boundedness::Bounded, + ) + } +} + +impl DisplayAs for OnceExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "OnceExec:") + } + DisplayFormatType::TreeRender => { + write!(f, "") + } + } + } +} + +impl ExecutionPlan for OnceExec { + fn name(&self) -> &'static str { + Self::static_name() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> { + vec![] + } + + fn with_new_children( + self: Arc<Self>, + _: Vec<Arc<dyn ExecutionPlan>>, + ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> { + unimplemented!() + } + + /// Returns a stream which yields data + fn execute( + &self, + partition: usize, + _context: Arc<TaskContext>, + ) -> datafusion_common::Result<SendableRecordBatchStream> { + assert_eq!(partition, 0); + + let stream = self.stream.lock().unwrap().take(); + + stream.ok_or(DataFusionError::Internal( + "Stream already consumed".to_string(), + )) + } +} diff --git a/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs b/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs new file mode 100644 index 0000000000..6c1bd316cd --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs @@ -0,0 +1,654 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Fuzz Test for different operators in memory constrained environment + +use std::pin::Pin; +use std::sync::Arc; + +use crate::fuzz_cases::aggregate_fuzz::assert_spill_count_metric; +use crate::fuzz_cases::once_exec::OnceExec; +use arrow::array::UInt64Array; +use arrow::{array::StringArray, compute::SortOptions, record_batch::RecordBatch}; +use arrow_schema::{DataType, Field, Schema}; +use datafusion::common::Result; +use datafusion::execution::runtime_env::RuntimeEnvBuilder; +use datafusion::physical_plan::expressions::PhysicalSortExpr; +use datafusion::physical_plan::sorts::sort::SortExec; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::SessionConfig; +use datafusion_execution::memory_pool::units::{KB, MB}; +use datafusion_execution::memory_pool::{ + FairSpillPool, MemoryConsumer, MemoryReservation, +}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_functions_aggregate::array_agg::array_agg_udaf; +use datafusion_physical_expr::aggregate::AggregateExprBuilder; +use datafusion_physical_expr::expressions::{col, Column}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, +}; +use datafusion_physical_plan::stream::RecordBatchStreamAdapter; +use futures::StreamExt; + +#[tokio::test] +async fn test_sort_with_limited_memory() -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + let record_batch_size = pool_size / 16; + + // Basic test with a lot of groups that cannot all fit in memory and 1 record batch + // from each spill file is too much memory + let spill_count = run_sort_test_with_limited_memory(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |_| record_batch_size), + memory_behavior: Default::default(), + }) + .await?; + + let total_spill_files_size = spill_count * record_batch_size; + assert!( + total_spill_files_size > pool_size, + "Total spill files size {total_spill_files_size} should be greater than pool size {pool_size}", + ); + + Ok(()) +} + +#[tokio::test] +async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch() -> Result<()> +{ + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_sort_test_with_limited_memory(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 6 + } else { + 16 * KB as usize + } + }), + memory_behavior: Default::default(), + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_changing_memory_reservation( +) -> 
Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_sort_test_with_limited_memory(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 6 + } else { + 16 * KB as usize + } + }), + memory_behavior: MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(10), + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_take_all_memory( +) -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_sort_test_with_limited_memory(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 6 + } else { + 16 * KB as usize + } + }), + memory_behavior: MemoryBehavior::TakeAllMemoryAtTheBeginning, + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_sort_with_limited_memory_and_large_record_batch() -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + // Test that the merge degree of multi level merge sort cannot be fixed size when there is not enough memory + run_sort_test_with_limited_memory(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |_| pool_size / 6), + memory_behavior: Default::default(), + }) + .await?; + + Ok(()) +} + +struct RunTestWithLimitedMemoryArgs { + pool_size: usize, + task_ctx: Arc<TaskContext>, + number_of_record_batches: usize, + get_size_of_record_batch_to_generate: + Pin<Box<dyn Fn(usize) -> usize + Send + 'static>>, + memory_behavior: MemoryBehavior, +} + +#[derive(Default)] +enum MemoryBehavior { + #[default] + AsIs, + TakeAllMemoryAtTheBeginning, + TakeAllMemoryAndReleaseEveryNthBatch(usize), +} + +async fn run_sort_test_with_limited_memory( + mut args: RunTestWithLimitedMemoryArgs, +) -> Result<usize> { + let get_size_of_record_batch_to_generate = std::mem::replace( + &mut args.get_size_of_record_batch_to_generate, + Box::pin(move |_| unreachable!("should not be called after take")), + ); + + let scan_schema = Arc::new(Schema::new(vec![ + Field::new("col_0", DataType::UInt64, true), + Field::new("col_1", DataType::Utf8, true), + ])); + + let record_batch_size = 
args.task_ctx.session_config().batch_size() as u64; + + let schema = Arc::clone(&scan_schema); + let plan: Arc<dyn ExecutionPlan> = + Arc::new(OnceExec::new(Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&schema), + futures::stream::iter((0..args.number_of_record_batches as u64).map( + move |index| { + let mut record_batch_memory_size = + get_size_of_record_batch_to_generate(index as usize); + record_batch_memory_size = record_batch_memory_size + .saturating_sub(size_of::<u64>() * record_batch_size as usize); + + let string_item_size = + record_batch_memory_size / record_batch_size as usize; + let string_array = Arc::new(StringArray::from_iter_values( + (0..record_batch_size).map(|_| "a".repeat(string_item_size)), + )); + + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(UInt64Array::from_iter_values( + (index * record_batch_size) + ..(index * record_batch_size) + record_batch_size, + )), + string_array, + ], + ) + .map_err(|err| err.into()) + }, + )), + )))); + let sort_exec = Arc::new(SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr { + expr: col("col_0", &scan_schema).unwrap(), + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]) + .unwrap(), + plan, + )); + + let result = sort_exec.execute(0, Arc::clone(&args.task_ctx))?; + + run_test(args, sort_exec, result).await +} + +fn grow_memory_as_much_as_possible( + memory_step: usize, + memory_reservation: &mut MemoryReservation, +) -> Result<bool> { + let mut was_able_to_grow = false; + while memory_reservation.try_grow(memory_step).is_ok() { + was_able_to_grow = true; + } + + Ok(was_able_to_grow) +} + +#[tokio::test] +async fn test_aggregate_with_high_cardinality_with_limited_memory() -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + let record_batch_size = pool_size / 16; + + // Basic test with a lot of groups that cannot all fit in memory and 1 record batch + // from each spill file is too much memory + let spill_count = + run_test_aggregate_with_high_cardinality(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |_| record_batch_size), + memory_behavior: Default::default(), + }) + .await?; + + let total_spill_files_size = spill_count * record_batch_size; + assert!( + total_spill_files_size > pool_size, + "Total spill files size {total_spill_files_size} should be greater than pool size {pool_size}", + ); + + Ok(()) +} + +#[tokio::test] +async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch( +) -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_test_aggregate_with_high_cardinality(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size 
/ 6 + } else { + (16 * KB) as usize + } + }), + memory_behavior: Default::default(), + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch_and_changing_memory_reservation( +) -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_test_aggregate_with_high_cardinality(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 6 + } else { + (16 * KB) as usize + } + }), + memory_behavior: MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(10), + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch_and_take_all_memory( +) -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_test_aggregate_with_high_cardinality(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 6 + } else { + (16 * KB) as usize + } + }), + memory_behavior: MemoryBehavior::TakeAllMemoryAtTheBeginning, + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_aggregate_with_high_cardinality_with_limited_memory_and_large_record_batch( +) -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + // Test that the merge degree of multi level merge sort cannot be fixed size when there is not enough memory + run_test_aggregate_with_high_cardinality(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |_| pool_size / 6), + memory_behavior: Default::default(), + }) + .await?; + + Ok(()) +} + +async fn run_test_aggregate_with_high_cardinality( + mut args: RunTestWithLimitedMemoryArgs, +) -> Result<usize> { + let get_size_of_record_batch_to_generate = std::mem::replace( + &mut args.get_size_of_record_batch_to_generate, + Box::pin(move |_| unreachable!("should not be called after take")), + ); + let scan_schema = Arc::new(Schema::new(vec![ + Field::new("col_0", DataType::UInt64, true), + Field::new("col_1", DataType::Utf8, true), + ])); + + let group_by = PhysicalGroupBy::new_single(vec![( + Arc::new(Column::new("col_0", 0)), + "col_0".to_string(), + )]); + + let aggregate_expressions = vec![Arc::new( + AggregateExprBuilder::new( + array_agg_udaf(), + vec![col("col_1", 
&scan_schema).unwrap()], + ) + .schema(Arc::clone(&scan_schema)) + .alias("array_agg(col_1)") + .build()?, + )]; + + let record_batch_size = args.task_ctx.session_config().batch_size() as u64; + + let schema = Arc::clone(&scan_schema); + let plan: Arc<dyn ExecutionPlan> = + Arc::new(OnceExec::new(Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&schema), + futures::stream::iter((0..args.number_of_record_batches as u64).map( + move |index| { + let mut record_batch_memory_size = + get_size_of_record_batch_to_generate(index as usize); + record_batch_memory_size = record_batch_memory_size + .saturating_sub(size_of::<u64>() * record_batch_size as usize); + + let string_item_size = + record_batch_memory_size / record_batch_size as usize; + let string_array = Arc::new(StringArray::from_iter_values( + (0..record_batch_size).map(|_| "a".repeat(string_item_size)), + )); + + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + // Grouping key + Arc::new(UInt64Array::from_iter_values( + (index * record_batch_size) + ..(index * record_batch_size) + record_batch_size, + )), + // Grouping value + string_array, + ], + ) + .map_err(|err| err.into()) + }, + )), + )))); + + let aggregate_exec = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + group_by.clone(), + aggregate_expressions.clone(), + vec![None; aggregate_expressions.len()], + plan, + Arc::clone(&scan_schema), + )?); + let aggregate_final = Arc::new(AggregateExec::try_new( + AggregateMode::Final, + group_by, + aggregate_expressions.clone(), + vec![None; aggregate_expressions.len()], + aggregate_exec, + Arc::clone(&scan_schema), + )?); + + let result = aggregate_final.execute(0, Arc::clone(&args.task_ctx))?; + + run_test(args, aggregate_final, result).await +} + +async fn run_test( + args: RunTestWithLimitedMemoryArgs, + plan: Arc<dyn ExecutionPlan>, + result_stream: SendableRecordBatchStream, +) -> Result<usize> { + let number_of_record_batches = args.number_of_record_batches; + + consume_stream_and_simulate_other_running_memory_consumers(args, result_stream) + .await?; + + let spill_count = assert_spill_count_metric(true, plan); + + assert!( + spill_count > 0, + "Expected spill, but did not, number of record batches: {number_of_record_batches}", + ); + + Ok(spill_count) +} + +/// Consume the stream and change the amount of memory used while consuming it based on the [`MemoryBehavior`] provided +async fn consume_stream_and_simulate_other_running_memory_consumers( + args: RunTestWithLimitedMemoryArgs, + mut result_stream: SendableRecordBatchStream, +) -> Result<()> { + let mut number_of_rows = 0; + let record_batch_size = args.task_ctx.session_config().batch_size() as u64; + + let memory_pool = args.task_ctx.memory_pool(); + let memory_consumer = MemoryConsumer::new("mock_memory_consumer"); + let mut memory_reservation = memory_consumer.register(memory_pool); + + let mut index = 0; + let mut memory_took = false; + + while let Some(batch) = result_stream.next().await { + match args.memory_behavior { + MemoryBehavior::AsIs => { + // Do nothing + } + MemoryBehavior::TakeAllMemoryAtTheBeginning => { + if !memory_took { + memory_took = true; + grow_memory_as_much_as_possible(10, &mut memory_reservation)?; + } + } + MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(n) => { + if !memory_took { + memory_took = true; + grow_memory_as_much_as_possible( + args.pool_size, + &mut memory_reservation, + )?; + } else if index % n == 0 { + // release memory + memory_reservation.free(); + } + } + } + + let batch = batch?; + number_of_rows += 
batch.num_rows(); + + index += 1; + } + + assert_eq!( + number_of_rows, + args.number_of_record_batches * record_batch_size as usize + ); + + Ok(()) +} diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 1d659d7280..6132a8b0ad 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -31,7 +31,7 @@ use crate::aggregates::{ }; use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput}; use crate::sorts::sort::sort_batch; -use crate::sorts::streaming_merge::StreamingMergeBuilder; +use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder}; use crate::spill::spill_manager::SpillManager; use crate::stream::RecordBatchStreamAdapter; use crate::{aggregates, metrics, PhysicalExpr}; @@ -40,7 +40,6 @@ use crate::{RecordBatchStream, SendableRecordBatchStream}; use arrow::array::*; use arrow::datatypes::SchemaRef; use datafusion_common::{internal_err, DataFusionError, Result}; -use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; @@ -99,7 +98,7 @@ struct SpillState { // ======================================================================== /// If data has previously been spilled, the locations of the /// spill files (in Arrow IPC format) - spills: Vec<RefCountedTempFile>, + spills: Vec<SortedSpillFile>, /// true when streaming merge is in progress is_stream_merging: bool, @@ -1000,13 +999,21 @@ impl GroupedHashAggregateStream { let sorted = sort_batch(&emit, &self.spill_state.spill_expr, None)?; // Spill sorted state to disk - let spillfile = self.spill_state.spill_manager.spill_record_batch_by_size( - &sorted, - "HashAggSpill", - self.batch_size, - )?; + let spillfile = self + .spill_state + .spill_manager + .spill_record_batch_by_size_and_return_max_batch_memory( + &sorted, + "HashAggSpill", + self.batch_size, + )?; match spillfile { - Some(spillfile) => self.spill_state.spills.push(spillfile), + Some((spillfile, max_record_batch_memory)) => { + self.spill_state.spills.push(SortedSpillFile { + file: spillfile, + max_record_batch_memory, + }) + } None => { return internal_err!( "Calling spill with no intermediate batch to spill" @@ -1067,14 +1074,13 @@ impl GroupedHashAggregateStream { sort_batch(&batch, &expr, None) })), ))); - for spill in self.spill_state.spills.drain(..) 
{ - let stream = self.spill_state.spill_manager.read_spill_as_stream(spill)?; - streams.push(stream); - } + self.spill_state.is_stream_merging = true; self.input = StreamingMergeBuilder::new() .with_streams(streams) .with_schema(schema) + .with_spill_manager(self.spill_state.spill_manager.clone()) + .with_sorted_spill_files(std::mem::take(&mut self.spill_state.spills)) .with_expressions(&self.spill_state.spill_expr) .with_metrics(self.baseline_metrics.clone()) .with_batch_size(self.batch_size) diff --git a/datafusion/physical-plan/src/sorts/mod.rs b/datafusion/physical-plan/src/sorts/mod.rs index c7ffae4061..9c72e34fe3 100644 --- a/datafusion/physical-plan/src/sorts/mod.rs +++ b/datafusion/physical-plan/src/sorts/mod.rs @@ -20,6 +20,7 @@ mod builder; mod cursor; mod merge; +mod multi_level_merge; pub mod partial_sort; pub mod sort; pub mod sort_preserving_merge; diff --git a/datafusion/physical-plan/src/sorts/multi_level_merge.rs b/datafusion/physical-plan/src/sorts/multi_level_merge.rs new file mode 100644 index 0000000000..bb6fc751b8 --- /dev/null +++ b/datafusion/physical-plan/src/sorts/multi_level_merge.rs @@ -0,0 +1,449 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Create a stream that performs a multi-level merge + +use crate::metrics::BaselineMetrics; +use crate::{EmptyRecordBatchStream, SpillManager}; +use arrow::array::RecordBatch; +use std::fmt::{Debug, Formatter}; +use std::mem; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use arrow::datatypes::SchemaRef; +use datafusion_common::Result; +use datafusion_execution::memory_pool::MemoryReservation; + +use crate::sorts::sort::get_reserved_byte_for_record_batch_size; +use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder}; +use crate::stream::RecordBatchStreamAdapter; +use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use futures::TryStreamExt; +use futures::{Stream, StreamExt}; + +/// Merges a stream of sorted cursors and record batches into a single sorted stream +/// +/// This is a wrapper around [`SortPreservingMergeStream`](crate::sorts::merge::SortPreservingMergeStream) +/// that provides it with the sorted streams/files to merge while making sure we can merge them in memory. +/// In case we can't merge all of them in a single pass we will spill the intermediate results to disk +/// and repeat the process. +/// +/// ## High level Algorithm +/// 1. Get the maximum number of sorted in-memory streams and spill files we can merge with the available memory +/// 2. Merge them into a single sorted stream +/// 3. Do we have more spill files to merge? 
+/// - Yes: write that sorted stream to a spill file, +/// add that spill file back to the spill files to merge and +/// repeat the process +/// +/// - No: return that sorted stream as the final output stream +/// +/// ```text +/// Initial State: Multiple sorted streams + spill files +/// ┌───────────┐ +/// │ Phase 1 │ +/// └───────────┘ +/// ┌──Can hold in memory─┐ +/// │ ┌──────────────┐ │ +/// │ │ In-memory │ +/// │ │sorted stream │──┼────────┐ +/// │ │ 1 │ │ │ +/// └──────────────┘ │ │ +/// │ ┌──────────────┐ │ │ +/// │ │ In-memory │ │ +/// │ │sorted stream │──┼────────┤ +/// │ │ 2 │ │ │ +/// └──────────────┘ │ │ +/// │ ┌──────────────┐ │ │ +/// │ │ In-memory │ │ +/// │ │sorted stream │──┼────────┤ +/// │ │ 3 │ │ │ +/// └──────────────┘ │ │ +/// │ ┌──────────────┐ │ │ ┌───────────┐ +/// │ │ Sorted Spill │ │ │ Phase 2 │ +/// │ │ file 1 │──┼────────┤ └───────────┘ +/// │ └──────────────┘ │ │ +/// ──── ──── ──── ──── ─┘ │ ┌──Can hold in memory─┐ +/// │ │ │ +/// ┌──────────────┐ │ │ ┌──────────────┐ +/// │ Sorted Spill │ │ │ │ Sorted Spill │ │ +/// │ file 2 │──────────────────────▶│ file 2 │──┼─────┐ +/// └──────────────┘ │ └──────────────┘ │ │ +/// ┌──────────────┐ │ │ ┌──────────────┐ │ │ +/// │ Sorted Spill │ │ │ │ Sorted Spill │ │ +/// │ file 3 │──────────────────────▶│ file 3 │──┼─────┤ +/// └──────────────┘ │ │ └──────────────┘ │ │ +/// ┌──────────────┐ │ ┌──────────────┐ │ │ +/// │ Sorted Spill │ │ │ │ Sorted Spill │ │ │ +/// │ file 4 │──────────────────────▶│ file 4 │────────┤ ┌───────────┐ +/// └──────────────┘ │ │ └──────────────┘ │ │ │ Phase 3 │ +/// │ │ │ │ └───────────┘ +/// │ ──── ──── ──── ──── ─┘ │ ┌──Can hold in memory─┐ +/// │ │ │ │ +/// ┌──────────────┐ │ ┌──────────────┐ │ │ ┌──────────────┐ +/// │ Sorted Spill │ │ │ Sorted Spill │ │ │ │ Sorted Spill │ │ +/// │ file 5 │──────────────────────▶│ file 5 │────────────────▶│ file 5 │───┼───┐ +/// └──────────────┘ │ └──────────────┘ │ │ └──────────────┘ │ │ +/// │ │ │ │ │ +/// │ ┌──────────────┐ │ │ ┌──────────────┐ │ +/// │ │ Sorted Spill │ │ │ │ Sorted Spill │ │ │ ┌── ─── ─── ─── ─── ─── ─── ──┐ +/// └──────────▶│ file 6 │────────────────▶│ file 6 │───┼───┼──────▶ Output Stream +/// └──────────────┘ │ │ └──────────────┘ │ │ └── ─── ─── ─── ─── ─── ─── ──┘ +/// │ │ │ │ +/// │ │ ┌──────────────┐ │ +/// │ │ │ Sorted Spill │ │ │ +/// └───────▶│ file 7 │───┼───┘ +/// │ └──────────────┘ │ +/// │ │ +/// └─ ──── ──── ──── ──── +/// ``` +/// +/// ## Memory Management Strategy +/// +/// This multi-level merge make sure that we can handle any amount of data to sort as long as +/// we have enough memory to merge at least 2 streams at a time. +/// +/// 1. **Worst-Case Memory Reservation**: Reserves memory based on the largest +/// batch size encountered in each spill file to merge, ensuring sufficient memory is always +/// available during merge operations. +/// 2. **Adaptive Buffer Sizing**: Reduces buffer sizes when memory is constrained +/// 3. 
**Spill-to-Disk**: Spill to disk when we cannot merge all files in memory +/// +pub(crate) struct MultiLevelMergeBuilder { + spill_manager: SpillManager, + schema: SchemaRef, + sorted_spill_files: Vec<SortedSpillFile>, + sorted_streams: Vec<SendableRecordBatchStream>, + expr: LexOrdering, + metrics: BaselineMetrics, + batch_size: usize, + reservation: MemoryReservation, + fetch: Option<usize>, + enable_round_robin_tie_breaker: bool, +} + +impl Debug for MultiLevelMergeBuilder { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "MultiLevelMergeBuilder") + } +} + +impl MultiLevelMergeBuilder { + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + spill_manager: SpillManager, + schema: SchemaRef, + sorted_spill_files: Vec<SortedSpillFile>, + sorted_streams: Vec<SendableRecordBatchStream>, + expr: LexOrdering, + metrics: BaselineMetrics, + batch_size: usize, + reservation: MemoryReservation, + fetch: Option<usize>, + enable_round_robin_tie_breaker: bool, + ) -> Self { + Self { + spill_manager, + schema, + sorted_spill_files, + sorted_streams, + expr, + metrics, + batch_size, + reservation, + enable_round_robin_tie_breaker, + fetch, + } + } + + pub(crate) fn create_spillable_merge_stream(self) -> SendableRecordBatchStream { + Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&self.schema), + futures::stream::once(self.create_stream()).try_flatten(), + )) + } + + async fn create_stream(mut self) -> Result<SendableRecordBatchStream> { + loop { + let mut stream = self.merge_sorted_runs_within_mem_limit()?; + + // TODO - add a threshold for number of files to disk even if empty and reading from disk so + // we can avoid the memory reservation + + // If no spill files are left, we can return the stream as this is the last sorted run + // TODO - We can write to disk before reading it back to avoid having multiple streams in memory + if self.sorted_spill_files.is_empty() { + assert!( + self.sorted_streams.is_empty(), + "We should not have any sorted streams left" + ); + + return Ok(stream); + } + + // Need to sort to a spill file + let Some((spill_file, max_record_batch_memory)) = self + .spill_manager + .spill_record_batch_stream_and_return_max_batch_memory( + &mut stream, + "MultiLevelMergeBuilder intermediate spill", + ) + .await? + else { + continue; + }; + + // Add the spill file + self.sorted_spill_files.push(SortedSpillFile { + file: spill_file, + max_record_batch_memory, + }); + } + } + + /// This tries to create a stream that merges the most sorted streams and sorted spill files + /// as possible within the memory limit. 
+ fn merge_sorted_runs_within_mem_limit( + &mut self, + ) -> Result<SendableRecordBatchStream> { + match (self.sorted_spill_files.len(), self.sorted_streams.len()) { + // No data so empty batch + (0, 0) => Ok(Box::pin(EmptyRecordBatchStream::new(Arc::clone( + &self.schema, + )))), + + // Only in-memory stream, return that + (0, 1) => Ok(self.sorted_streams.remove(0)), + + // Only single sorted spill file so return it + (1, 0) => { + let spill_file = self.sorted_spill_files.remove(0); + + // Not reserving any memory for this disk as we are not holding it in memory + self.spill_manager.read_spill_as_stream(spill_file.file) + } + + // Only in memory streams, so merge them all in a single pass + (0, _) => { + let sorted_stream = mem::take(&mut self.sorted_streams); + self.create_new_merge_sort( + sorted_stream, + // If we have no sorted spill files left, this is the last run + true, + true, + ) + } + + // Need to merge multiple streams + (_, _) => { + let mut memory_reservation = self.reservation.new_empty(); + + // Don't account for existing streams memory + // as we are not holding the memory for them + let mut sorted_streams = mem::take(&mut self.sorted_streams); + + let (sorted_spill_files, buffer_size) = self + .get_sorted_spill_files_to_merge( + 2, + // we must have at least 2 streams to merge + 2_usize.saturating_sub(sorted_streams.len()), + &mut memory_reservation, + )?; + + let is_only_merging_memory_streams = sorted_spill_files.is_empty(); + + for spill in sorted_spill_files { + let stream = self + .spill_manager + .clone() + .with_batch_read_buffer_capacity(buffer_size) + .read_spill_as_stream(spill.file)?; + sorted_streams.push(stream); + } + + let merge_sort_stream = self.create_new_merge_sort( + sorted_streams, + // If we have no sorted spill files left, this is the last run + self.sorted_spill_files.is_empty(), + is_only_merging_memory_streams, + )?; + + // If we're only merging memory streams, we don't need to attach the memory reservation + // as it's empty + if is_only_merging_memory_streams { + assert_eq!(memory_reservation.size(), 0, "when only merging memory streams, we should not have any memory reservation and let the merge sort handle the memory"); + + Ok(merge_sort_stream) + } else { + // Attach the memory reservation to the stream to make sure we have enough memory + // throughout the merge process as we bypassed the memory pool for the merge sort stream + Ok(Box::pin(StreamAttachedReservation::new( + merge_sort_stream, + memory_reservation, + ))) + } + } + } + } + + fn create_new_merge_sort( + &mut self, + streams: Vec<SendableRecordBatchStream>, + is_output: bool, + all_in_memory: bool, + ) -> Result<SendableRecordBatchStream> { + let mut builder = StreamingMergeBuilder::new() + .with_schema(Arc::clone(&self.schema)) + .with_expressions(&self.expr) + .with_batch_size(self.batch_size) + .with_fetch(self.fetch) + .with_metrics(if is_output { + // Only add the metrics to the last run + self.metrics.clone() + } else { + self.metrics.intermediate() + }) + .with_round_robin_tie_breaker(self.enable_round_robin_tie_breaker) + .with_streams(streams); + + if !all_in_memory { + // Don't track memory used by this stream as we reserve that memory by worst case sceneries + // (reserving memory for the biggest batch in each stream) + // TODO - avoid this hack as this can be broken easily when `SortPreservingMergeStream` + // changes the implementation to use more/less memory + builder = builder.with_bypass_mempool(); + } else { + // If we are only merging in-memory streams, we 
need to use the memory reservation + // because we don't know the maximum size of the batches in the streams + builder = builder.with_reservation(self.reservation.new_empty()); + } + + builder.build() + } + + /// Return the sorted spill files to use for the next phase, and the buffer size + /// This will try to get as many spill files as possible to merge, and if we don't have enough streams + /// it will try to reduce the buffer size until we have enough streams to merge + /// otherwise it will return an error + fn get_sorted_spill_files_to_merge( + &mut self, + buffer_len: usize, + minimum_number_of_required_streams: usize, + reservation: &mut MemoryReservation, + ) -> Result<(Vec<SortedSpillFile>, usize)> { + assert_ne!(buffer_len, 0, "Buffer length must be greater than 0"); + let mut number_of_spills_to_read_for_current_phase = 0; + + for spill in &self.sorted_spill_files { + // For memory pools that are not shared this is good, for other this is not + // and there should be some upper limit to memory reservation so we won't starve the system + match reservation.try_grow(get_reserved_byte_for_record_batch_size( + spill.max_record_batch_memory * buffer_len, + )) { + Ok(_) => { + number_of_spills_to_read_for_current_phase += 1; + } + // If we can't grow the reservation, we need to stop + Err(err) => { + // We must have at least 2 streams to merge, so if we don't have enough memory + // fail + if minimum_number_of_required_streams + > number_of_spills_to_read_for_current_phase + { + // Free the memory we reserved for this merge as we either try again or fail + reservation.free(); + if buffer_len > 1 { + // Try again with smaller buffer size, it will be slower but at least we can merge + return self.get_sorted_spill_files_to_merge( + buffer_len - 1, + minimum_number_of_required_streams, + reservation, + ); + } + + return Err(err); + } + + // We reached the maximum amount of memory we can use + // for this merge + break; + } + } + } + + let spills = self + .sorted_spill_files + .drain(..number_of_spills_to_read_for_current_phase) + .collect::<Vec<_>>(); + + Ok((spills, buffer_len)) + } +} + +struct StreamAttachedReservation { + stream: SendableRecordBatchStream, + reservation: MemoryReservation, +} + +impl StreamAttachedReservation { + fn new(stream: SendableRecordBatchStream, reservation: MemoryReservation) -> Self { + Self { + stream, + reservation, + } + } +} + +impl Stream for StreamAttachedReservation { + type Item = Result<RecordBatch>; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Option<Self::Item>> { + let res = self.stream.poll_next_unpin(cx); + + match res { + Poll::Ready(res) => { + match res { + Some(Ok(batch)) => Poll::Ready(Some(Ok(batch))), + Some(Err(err)) => { + // Had an error so drop the data + self.reservation.free(); + Poll::Ready(Some(Err(err))) + } + None => { + // Stream is done so free the memory + self.reservation.free(); + + Poll::Ready(None) + } + } + } + Poll::Pending => Poll::Pending, + } + } +} + +impl RecordBatchStream for StreamAttachedReservation { + fn schema(&self) -> SchemaRef { + self.stream.schema() + } +} diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index bb572c4315..0b7d3977d2 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -35,10 +35,10 @@ use crate::metrics::{ BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, SpillMetrics, }; use crate::projection::{make_with_child, update_ordering, 
ProjectionExec}; -use crate::sorts::streaming_merge::StreamingMergeBuilder; +use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder}; use crate::spill::get_record_batch_memory_size; use crate::spill::in_progress_spill_file::InProgressSpillFile; -use crate::spill::spill_manager::SpillManager; +use crate::spill::spill_manager::{GetSlicedSize, SpillManager}; use crate::stream::RecordBatchStreamAdapter; use crate::topk::TopK; use crate::{ @@ -52,7 +52,6 @@ use arrow::compute::{concat_batches, lexsort_to_indices, take_arrays}; use arrow::datatypes::SchemaRef; use datafusion_common::config::SpillCompression; use datafusion_common::{internal_datafusion_err, internal_err, DataFusionError, Result}; -use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; @@ -222,12 +221,16 @@ struct ExternalSorter { /// During external sorting, in-memory intermediate data will be appended to /// this file incrementally. Once finished, this file will be moved to [`Self::finished_spill_files`]. - in_progress_spill_file: Option<InProgressSpillFile>, + /// + /// this is a tuple of: + /// 1. `InProgressSpillFile` - the file that is being written to + /// 2. `max_record_batch_memory` - the maximum memory usage of a single batch in this spill file. + in_progress_spill_file: Option<(InProgressSpillFile, usize)>, /// If data has previously been spilled, the locations of the spill files (in /// Arrow IPC format) /// Within the same spill file, the data might be chunked into multiple batches, /// and ordered by sort keys. - finished_spill_files: Vec<RefCountedTempFile>, + finished_spill_files: Vec<SortedSpillFile>, // ======================================================================== // EXECUTION RESOURCES: @@ -335,8 +338,6 @@ impl ExternalSorter { self.merge_reservation.free(); if self.spilled_before() { - let mut streams = vec![]; - // Sort `in_mem_batches` and spill it first. If there are many // `in_mem_batches` and the memory limit is almost reached, merging // them with the spilled files at the same time might cause OOM. @@ -344,16 +345,9 @@ impl ExternalSorter { self.sort_and_spill_in_mem_batches().await?; } - for spill in self.finished_spill_files.drain(..) 
{ - if !spill.path().exists() { - return internal_err!("Spill file {:?} does not exist", spill.path()); - } - let stream = self.spill_manager.read_spill_as_stream(spill)?; - streams.push(stream); - } - StreamingMergeBuilder::new() - .with_streams(streams) + .with_sorted_spill_files(std::mem::take(&mut self.finished_spill_files)) + .with_spill_manager(self.spill_manager.clone()) .with_schema(Arc::clone(&self.schema)) .with_expressions(&self.expr.clone()) .with_metrics(self.metrics.baseline.clone()) @@ -399,7 +393,7 @@ impl ExternalSorter { // Lazily initialize the in-progress spill file if self.in_progress_spill_file.is_none() { self.in_progress_spill_file = - Some(self.spill_manager.create_in_progress_file("Sorting")?); + Some((self.spill_manager.create_in_progress_file("Sorting")?, 0)); } Self::organize_stringview_arrays(globally_sorted_batches)?; @@ -409,12 +403,16 @@ impl ExternalSorter { let batches_to_spill = std::mem::take(globally_sorted_batches); self.reservation.free(); - let in_progress_file = self.in_progress_spill_file.as_mut().ok_or_else(|| { - internal_datafusion_err!("In-progress spill file should be initialized") - })?; + let (in_progress_file, max_record_batch_size) = + self.in_progress_spill_file.as_mut().ok_or_else(|| { + internal_datafusion_err!("In-progress spill file should be initialized") + })?; for batch in batches_to_spill { in_progress_file.append_batch(&batch)?; + + *max_record_batch_size = + (*max_record_batch_size).max(batch.get_sliced_size()?); } if !globally_sorted_batches.is_empty() { @@ -426,14 +424,17 @@ impl ExternalSorter { /// Finishes the in-progress spill file and moves it to the finished spill files. async fn spill_finish(&mut self) -> Result<()> { - let mut in_progress_file = + let (mut in_progress_file, max_record_batch_memory) = self.in_progress_spill_file.take().ok_or_else(|| { internal_datafusion_err!("Should be called after `spill_append`") })?; let spill_file = in_progress_file.finish()?; if let Some(spill_file) = spill_file { - self.finished_spill_files.push(spill_file); + self.finished_spill_files.push(SortedSpillFile { + file: spill_file, + max_record_batch_memory, + }); } Ok(()) @@ -784,11 +785,16 @@ impl ExternalSorter { /// in sorting and merging. The sorted copies are in either row format or array format. /// Please refer to cursor.rs and stream.rs for more details. No matter what format the /// sorted copies are, they will use more memory than the original record batch. -fn get_reserved_byte_for_record_batch(batch: &RecordBatch) -> usize { +pub(crate) fn get_reserved_byte_for_record_batch_size(record_batch_size: usize) -> usize { // 2x may not be enough for some cases, but it's a good start. // If 2x is not enough, user can set a larger value for `sort_spill_reservation_bytes` // to compensate for the extra memory needed. - get_record_batch_memory_size(batch) * 2 + record_batch_size * 2 +} + +/// Estimate how much memory is needed to sort a `RecordBatch`. +fn get_reserved_byte_for_record_batch(batch: &RecordBatch) -> usize { + get_reserved_byte_for_record_batch_size(get_record_batch_memory_size(batch)) } impl Debug for ExternalSorter { diff --git a/datafusion/physical-plan/src/sorts/streaming_merge.rs b/datafusion/physical-plan/src/sorts/streaming_merge.rs index b74954eb96..191b135753 100644 --- a/datafusion/physical-plan/src/sorts/streaming_merge.rs +++ b/datafusion/physical-plan/src/sorts/streaming_merge.rs @@ -19,16 +19,22 @@ //! This is an order-preserving merge. 
use crate::metrics::BaselineMetrics; +use crate::sorts::multi_level_merge::MultiLevelMergeBuilder; use crate::sorts::{ merge::SortPreservingMergeStream, stream::{FieldCursorStream, RowCursorStream}, }; -use crate::SendableRecordBatchStream; +use crate::{SendableRecordBatchStream, SpillManager}; use arrow::array::*; use arrow::datatypes::{DataType, SchemaRef}; use datafusion_common::{internal_err, Result}; -use datafusion_execution::memory_pool::MemoryReservation; +use datafusion_execution::disk_manager::RefCountedTempFile; +use datafusion_execution::memory_pool::{ + human_readable_size, MemoryConsumer, MemoryPool, MemoryReservation, + UnboundedMemoryPool, +}; use datafusion_physical_expr_common::sort_expr::LexOrdering; +use std::sync::Arc; macro_rules! primitive_merge_helper { ($t:ty, $($v:ident),+) => { @@ -52,9 +58,29 @@ macro_rules! merge_helper { }}; } +pub struct SortedSpillFile { + pub file: RefCountedTempFile, + + /// how much memory the largest memory batch is taking + pub max_record_batch_memory: usize, +} + +impl std::fmt::Debug for SortedSpillFile { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "SortedSpillFile({:?}) takes {}", + self.file.path(), + human_readable_size(self.max_record_batch_memory) + ) + } +} + #[derive(Default)] pub struct StreamingMergeBuilder<'a> { streams: Vec<SendableRecordBatchStream>, + sorted_spill_files: Vec<SortedSpillFile>, + spill_manager: Option<SpillManager>, schema: Option<SchemaRef>, expressions: Option<&'a LexOrdering>, metrics: Option<BaselineMetrics>, @@ -77,6 +103,19 @@ impl<'a> StreamingMergeBuilder<'a> { self } + pub fn with_sorted_spill_files( + mut self, + sorted_spill_files: Vec<SortedSpillFile>, + ) -> Self { + self.sorted_spill_files = sorted_spill_files; + self + } + + pub fn with_spill_manager(mut self, spill_manager: SpillManager) -> Self { + self.spill_manager = Some(spill_manager); + self + } + pub fn with_schema(mut self, schema: SchemaRef) -> Self { self.schema = Some(schema); self @@ -119,9 +158,22 @@ impl<'a> StreamingMergeBuilder<'a> { self } + /// Bypass the mempool and avoid using the memory reservation. 
+ /// + /// This is not marked as `pub` because it is not recommended to use this method + pub(super) fn with_bypass_mempool(self) -> Self { + let mem_pool: Arc<dyn MemoryPool> = Arc::new(UnboundedMemoryPool::default()); + + self.with_reservation( + MemoryConsumer::new("merge stream mock memory").register(&mem_pool), + ) + } + pub fn build(self) -> Result<SendableRecordBatchStream> { let Self { streams, + sorted_spill_files, + spill_manager, schema, metrics, batch_size, @@ -131,14 +183,42 @@ impl<'a> StreamingMergeBuilder<'a> { enable_round_robin_tie_breaker, } = self; - // Early return if streams or expressions are empty: - if streams.is_empty() { - return internal_err!("Streams cannot be empty for streaming merge"); - } + // Early return if expressions are empty: let Some(expressions) = expressions else { return internal_err!("Sort expressions cannot be empty for streaming merge"); }; + if !sorted_spill_files.is_empty() { + // Unwrapping mandatory fields + let schema = schema.expect("Schema cannot be empty for streaming merge"); + let metrics = metrics.expect("Metrics cannot be empty for streaming merge"); + let batch_size = + batch_size.expect("Batch size cannot be empty for streaming merge"); + let reservation = + reservation.expect("Reservation cannot be empty for streaming merge"); + + return Ok(MultiLevelMergeBuilder::new( + spill_manager.expect("spill_manager should exist"), + schema, + sorted_spill_files, + streams, + expressions.clone(), + metrics, + batch_size, + reservation, + fetch, + enable_round_robin_tie_breaker, + ) + .create_spillable_merge_stream()); + } + + // Early return if streams are empty: + if streams.is_empty() { + return internal_err!( + "Streams/sorted spill files cannot be empty for streaming merge" + ); + } + // Unwrapping mandatory fields let schema = schema.expect("Schema cannot be empty for streaming merge"); let metrics = metrics.expect("Metrics cannot be empty for streaming merge"); diff --git a/datafusion/physical-plan/src/spill/spill_manager.rs b/datafusion/physical-plan/src/spill/spill_manager.rs index 4f3afc5d12..6c47af129f 100644 --- a/datafusion/physical-plan/src/spill/spill_manager.rs +++ b/datafusion/physical-plan/src/spill/spill_manager.rs @@ -17,11 +17,10 @@ //! Define the `SpillManager` struct, which is responsible for reading and writing `RecordBatch`es to raw files based on the provided configurations. -use std::sync::Arc; - use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_execution::runtime_env::RuntimeEnv; +use std::sync::Arc; use datafusion_common::{config::SpillCompression, Result}; use datafusion_execution::disk_manager::RefCountedTempFile; @@ -59,6 +58,14 @@ impl SpillManager { } } + pub fn with_batch_read_buffer_capacity( + mut self, + batch_read_buffer_capacity: usize, + ) -> Self { + self.batch_read_buffer_capacity = batch_read_buffer_capacity; + self + } + pub fn with_compression_type(mut self, spill_compression: SpillCompression) -> Self { self.compression = spill_compression; self @@ -125,6 +132,69 @@ impl SpillManager { self.spill_record_batch_and_finish(&batches, request_description) } + /// Refer to the documentation for [`Self::spill_record_batch_and_finish`]. This method + /// additionally spills the `RecordBatch` into smaller batches, divided by `row_limit`. 
+ /// + /// # Errors + /// - Returns an error if spilling would exceed the disk usage limit configured + /// by `max_temp_directory_size` in `DiskManager` + pub(crate) fn spill_record_batch_by_size_and_return_max_batch_memory( + &self, + batch: &RecordBatch, + request_description: &str, + row_limit: usize, + ) -> Result<Option<(RefCountedTempFile, usize)>> { + let total_rows = batch.num_rows(); + let mut batches = Vec::new(); + let mut offset = 0; + + // It's ok to calculate all slices first, because slicing is zero-copy. + while offset < total_rows { + let length = std::cmp::min(total_rows - offset, row_limit); + let sliced_batch = batch.slice(offset, length); + batches.push(sliced_batch); + offset += length; + } + + let mut in_progress_file = self.create_in_progress_file(request_description)?; + + let mut max_record_batch_size = 0; + + for batch in batches { + in_progress_file.append_batch(&batch)?; + + max_record_batch_size = max_record_batch_size.max(batch.get_sliced_size()?); + } + + let file = in_progress_file.finish()?; + + Ok(file.map(|f| (f, max_record_batch_size))) + } + + /// Spill a stream of `RecordBatch`es to disk and return the spill file and the size of the largest batch in memory + pub(crate) async fn spill_record_batch_stream_and_return_max_batch_memory( + &self, + stream: &mut SendableRecordBatchStream, + request_description: &str, + ) -> Result<Option<(RefCountedTempFile, usize)>> { + use futures::StreamExt; + + let mut in_progress_file = self.create_in_progress_file(request_description)?; + + let mut max_record_batch_size = 0; + + while let Some(batch) = stream.next().await { + let batch = batch?; + in_progress_file.append_batch(&batch)?; + + max_record_batch_size = max_record_batch_size.max(batch.get_sliced_size()?); + } + + let file = in_progress_file.finish()?; + + Ok(file.map(|f| (f, max_record_batch_size))) + } + /// Reads a spill file as a stream. The file must be created by the current `SpillManager`. /// This method will generate output in FIFO order: the batch appended first /// will be read first. @@ -140,3 +210,19 @@ impl SpillManager { Ok(spawn_buffered(stream, self.batch_read_buffer_capacity)) } } + +pub(crate) trait GetSlicedSize { + /// Returns the size of the `RecordBatch` when sliced. + fn get_sliced_size(&self) -> Result<usize>; +} + +impl GetSlicedSize for RecordBatch { + fn get_sliced_size(&self) -> Result<usize> { + let mut total = 0; + for array in self.columns() { + let data = array.to_data(); + total += data.get_slice_memory_size()?; + } + Ok(total) + } +}
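
A note on the `GetSlicedSize` trait that closes the diff: it sizes a batch via Arrow's `get_slice_memory_size` rather than full buffer capacity, because the batches produced by `spill_record_batch_by_size_and_return_max_batch_memory` are zero-copy slices that share their parent's buffers, so capacity-based accounting would overcount what actually lands in the spill file. Below is a minimal standalone sketch of that behavior, assuming only the `arrow` crate; the free function `sliced_size` is an illustrative stand-in for the trait method, not part of the commit.

use std::sync::Arc;

use arrow::array::{Array, ArrayRef, UInt64Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;

// Sketch: sum the sliced size of every column, as the committed
// `GetSlicedSize` impl does.
fn sliced_size(batch: &RecordBatch) -> Result<usize, ArrowError> {
    let mut total = 0;
    for array in batch.columns() {
        // Counts only the bytes the slice's rows occupy, not the shared
        // parent buffers behind the slice.
        total += array.to_data().get_slice_memory_size()?;
    }
    Ok(total)
}

fn main() -> Result<(), ArrowError> {
    let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::UInt64, false)]));
    let col: ArrayRef = Arc::new(UInt64Array::from_iter_values(0..8192));
    let batch = RecordBatch::try_new(schema, vec![col])?;

    // A 10-row slice still points at the 8192-row buffers...
    let slice = batch.slice(0, 10);
    // ...but its sliced size reflects roughly 10 * 8 bytes, not 8192 * 8.
    assert!(sliced_size(&slice)? < sliced_size(&batch)?);
    println!(
        "full: {} bytes, slice: {} bytes",
        sliced_size(&batch)?,
        sliced_size(&slice)?
    );
    Ok(())
}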
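
Relatedly, the sort.rs hunk splits the existing 2x worst-case factor out into `get_reserved_byte_for_record_batch_size` so that multi_level_merge.rs can apply it to each spill file's tracked `max_record_batch_memory`. Here is a self-contained sketch of how that budget composes for one merge phase; `reserved_bytes_for_batch` mirrors the committed function, while `merge_reservation` is a hypothetical helper added for illustration only.

// Mirrors `get_reserved_byte_for_record_batch_size`: sorted copies (row
// format or array format) use more memory than the original batch, so 2x
// is reserved as a starting estimate.
fn reserved_bytes_for_batch(record_batch_size: usize) -> usize {
    record_batch_size * 2
}

// Hypothetical helper: worst-case memory for one merge phase, reserving
// `buffer_len` batches per spill file at that file's largest batch size,
// the same product `get_sorted_spill_files_to_merge` feeds the reservation.
fn merge_reservation(max_batch_sizes: &[usize], buffer_len: usize) -> usize {
    max_batch_sizes
        .iter()
        .map(|&max| reserved_bytes_for_batch(max * buffer_len))
        .sum()
}

fn main() {
    // Three spill files whose largest batches are 1 MiB, 2 MiB and 512 KiB.
    let sizes = [1 << 20, 2 << 20, 512 << 10];
    // (1 + 2 + 0.5) MiB * buffer_len (2) * factor (2) = 14 MiB.
    assert_eq!(merge_reservation(&sizes, 2), 14 << 20);
}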
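
`get_sorted_spill_files_to_merge` is where the merge degree becomes adaptive: reserve worst-case memory file by file, and if fewer than the minimum two streams fit, free the reservation and retry with a smaller read buffer before giving up. A simplified runnable sketch of that control flow, using plain integers in place of `MemoryReservation`; the names and error text are illustrative, not the crate's API.

fn pick_spill_files_to_merge(
    max_batch_sizes: &[usize], // max_record_batch_memory per sorted spill file
    min_required: usize,       // must merge at least this many streams
    budget: usize,             // memory available to this merge phase
    buffer_len: usize,         // batches read ahead per file
) -> Result<(usize, usize), String> {
    assert_ne!(buffer_len, 0, "Buffer length must be greater than 0");
    let mut remaining = budget;
    let mut picked = 0;
    for &max in max_batch_sizes {
        // Worst case per file: `buffer_len` buffered batches, doubled as in
        // `get_reserved_byte_for_record_batch_size`.
        let needed = max * buffer_len * 2;
        match remaining.checked_sub(needed) {
            Some(rest) => {
                remaining = rest;
                picked += 1;
            }
            None if picked < min_required && buffer_len > 1 => {
                // Not enough streams fit: retry the whole selection with a
                // smaller read buffer. Slower, but the merge can proceed.
                return pick_spill_files_to_merge(
                    max_batch_sizes,
                    min_required,
                    budget,
                    buffer_len - 1,
                );
            }
            None if picked < min_required => {
                return Err("not enough memory to merge at least 2 streams".into());
            }
            // Memory limit reached for this phase: merge what already fits.
            None => break,
        }
    }
    Ok((picked, buffer_len))
}

fn main() {
    // Four files with 4 MiB largest batches and a 20 MiB budget: only one
    // file fits at buffer_len = 2, so the selection falls back to
    // buffer_len = 1 and merges two files in this phase.
    let sizes = [4 << 20; 4];
    assert_eq!(pick_spill_files_to_merge(&sizes, 2, 20 << 20, 2), Ok((2, 1)));
}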
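
Lastly, the `OnceExec` test utility reduces to a one-shot hand-off through `Mutex<Option<T>>`: the first `execute` takes the stream, and later calls get an internal error. The same pattern in miniature; the `Once` type below is illustrative, not the test code itself.

use std::sync::Mutex;

// One-shot container: the first `take` succeeds, later calls error out,
// matching how `OnceExec::execute` hands out its stream exactly once.
struct Once<T>(Mutex<Option<T>>);

impl<T> Once<T> {
    fn new(value: T) -> Self {
        Self(Mutex::new(Some(value)))
    }

    fn take(&self) -> Result<T, String> {
        self.0
            .lock()
            .unwrap()
            .take()
            .ok_or_else(|| "Stream already consumed".to_string())
    }
}

fn main() {
    let once = Once::new("the stream");
    assert!(once.take().is_ok());
    assert!(once.take().is_err()); // a second `execute` fails the same way
}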