alamb commented on code in PR #7801:
URL: https://github.com/apache/arrow-datafusion/pull/7801#discussion_r1365531237
########## datafusion/core/src/datasource/file_format/write/demux.rs: ########## @@ -0,0 +1,380 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Module containing helper methods/traits related to enabling +//! dividing input stream into multiple output files at execution time + +use std::collections::HashMap; + +use std::sync::Arc; + +use crate::datasource::listing::ListingTableUrl; + +use crate::error::Result; +use crate::physical_plan::SendableRecordBatchStream; + +use arrow_array::builder::UInt64Builder; +use arrow_array::cast::AsArray; +use arrow_array::{RecordBatch, StructArray}; +use arrow_schema::{DataType, Schema}; +use datafusion_common::cast::as_string_array; +use datafusion_common::DataFusionError; + +use datafusion_execution::TaskContext; + +use futures::StreamExt; +use object_store::path::Path; + +use rand::distributions::DistString; + +use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver, UnboundedSender}; +use tokio::task::JoinHandle; + +type RecordBatchReceiver = Receiver<RecordBatch>; +type DemuxedStreamReceiver = UnboundedReceiver<(Path, RecordBatchReceiver)>; + +/// Splits a single [SendableRecordBatchStream] into a dynamically determined Review Comment: thank you for this description @devinjdangelo -- it is very clear to read and extremely helpful for review (and long term understanding how this code works) ########## datafusion/core/src/datasource/file_format/write/demux.rs: ########## @@ -0,0 +1,380 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Module containing helper methods/traits related to enabling +//! 
dividing input stream into multiple output files at execution time + +use std::collections::HashMap; + +use std::sync::Arc; + +use crate::datasource::listing::ListingTableUrl; + +use crate::error::Result; +use crate::physical_plan::SendableRecordBatchStream; + +use arrow_array::builder::UInt64Builder; +use arrow_array::cast::AsArray; +use arrow_array::{RecordBatch, StructArray}; +use arrow_schema::{DataType, Schema}; +use datafusion_common::cast::as_string_array; +use datafusion_common::DataFusionError; + +use datafusion_execution::TaskContext; + +use futures::StreamExt; +use object_store::path::Path; + +use rand::distributions::DistString; + +use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver, UnboundedSender}; +use tokio::task::JoinHandle; + +type RecordBatchReceiver = Receiver<RecordBatch>; +type DemuxedStreamReceiver = UnboundedReceiver<(Path, RecordBatchReceiver)>; + +/// Splits a single [SendableRecordBatchStream] into a dynamically determined +/// number of partitions at execution time. The partitions are determined by +/// factors known only at execution time, such as total number of rows and +/// partition column values. The demuxer task communicates to the caller +/// by sending channels over a channel. The inner channels send RecordBatches +/// which should be contained within the same output file. The outer channel +/// is used to send a dynamic number of inner channels, representing a dynamic +/// number of total output files. The caller is also responsible to monitor +/// the demux task for errors and abort accordingly. The single_file_ouput parameter +/// overrides all other settings to force only a single file to be written. +/// partition_by parameter will additionally split the input based on the unique +/// values of a specific column `<https://github.com/apache/arrow-datafusion/issues/7744>`` Review Comment: I think the formatting got messed up here somehow: ```suggestion /// values of a specific column `<https://github.com/apache/arrow-datafusion/issues/7744> /// /// ```text ``` ########## datafusion/sqllogictest/test_files/insert_to_external.slt: ########## @@ -87,6 +87,127 @@ SELECT * from ordered_insert_test; 7 8 7 7 +# test partitioned insert + +statement ok +CREATE EXTERNAL TABLE +partitioned_insert_test(a string, b string, c bigint) +STORED AS csv +LOCATION 'test_files/scratch/insert_to_external/insert_to_partitioned/' +PARTITIONED BY (a, b) +OPTIONS( +create_local_path 'true', +insert_mode 'append_new_files', +); + +#note that partitioned cols are moved to the end so value tuples are (c, a, b) Review Comment: Thank you for the comment, I recommend updating the test to use values like `(1, 10, 100), (1, 20, 100), ...` so it is more clear which values belong to which columns ########## datafusion/core/src/datasource/file_format/write/demux.rs: ########## @@ -0,0 +1,380 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Module containing helper methods/traits related to enabling +//! dividing input stream into multiple output files at execution time + +use std::collections::HashMap; + +use std::sync::Arc; + +use crate::datasource::listing::ListingTableUrl; + +use crate::error::Result; +use crate::physical_plan::SendableRecordBatchStream; + +use arrow_array::builder::UInt64Builder; +use arrow_array::cast::AsArray; +use arrow_array::{RecordBatch, StructArray}; +use arrow_schema::{DataType, Schema}; +use datafusion_common::cast::as_string_array; +use datafusion_common::DataFusionError; + +use datafusion_execution::TaskContext; + +use futures::StreamExt; +use object_store::path::Path; + +use rand::distributions::DistString; + +use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver, UnboundedSender}; +use tokio::task::JoinHandle; + +type RecordBatchReceiver = Receiver<RecordBatch>; +type DemuxedStreamReceiver = UnboundedReceiver<(Path, RecordBatchReceiver)>; + +/// Splits a single [SendableRecordBatchStream] into a dynamically determined +/// number of partitions at execution time. The partitions are determined by +/// factors known only at execution time, such as total number of rows and +/// partition column values. The demuxer task communicates to the caller +/// by sending channels over a channel. The inner channels send RecordBatches +/// which should be contained within the same output file. The outer channel +/// is used to send a dynamic number of inner channels, representing a dynamic +/// number of total output files. The caller is also responsible to monitor +/// the demux task for errors and abort accordingly. The single_file_ouput parameter +/// overrides all other settings to force only a single file to be written. +/// partition_by parameter will additionally split the input based on the unique +/// values of a specific column `<https://github.com/apache/arrow-datafusion/issues/7744>`` +/// ┌───────────┐ ┌────────────┐ ┌─────────────┐ +/// ┌──────▶ │ batch 1 ├────▶...──────▶│ Batch a │ │ Output File1│ +/// │ └───────────┘ └────────────┘ └─────────────┘ +/// │ +/// ┌──────────┐ │ ┌───────────┐ ┌────────────┐ ┌─────────────┐ +/// ┌───────────┐ ┌────────────┐ │ │ ├──────▶ │ batch a+1├────▶...──────▶│ Batch b │ │ Output File2│ +/// │ batch 1 ├────▶...──────▶│ Batch N ├─────▶│ Demux ├────────┤ ... 
└───────────┘ └────────────┘ └─────────────┘ +/// └───────────┘ └────────────┘ │ │ │ +/// └──────────┘ │ ┌───────────┐ ┌────────────┐ ┌─────────────┐ +/// └──────▶ │ batch d ├────▶...──────▶│ Batch n │ │ Output FileN│ +/// └───────────┘ └────────────┘ └─────────────┘ +pub(crate) fn start_demuxer_task( + input: SendableRecordBatchStream, + context: &Arc<TaskContext>, + partition_by: Option<Vec<(String, DataType)>>, + base_output_path: ListingTableUrl, + file_extension: String, + single_file_output: bool, +) -> (JoinHandle<Result<()>>, DemuxedStreamReceiver) { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let context = context.clone(); + let task: JoinHandle<std::result::Result<(), DataFusionError>> = match partition_by { + Some(parts) => { + // There could be an arbitrarily large number of parallel hive style partitions being written to, so we cannot + // bound this channel without risking a deadlock. + tokio::spawn(async move { + hive_style_partitions_demuxer( + tx, + input, + context, + parts, + base_output_path, + file_extension, + ) + .await + }) + } + None => tokio::spawn(async move { + row_count_demuxer( + tx, + input, + context, + base_output_path, + file_extension, + single_file_output, + ) + .await + }), + }; + + (task, rx) +} + +/// Dynamically partitions input stream to acheive desired maximum rows per file Review Comment: ```suggestion /// Dynamically partitions input stream to achieve desired maximum rows per file ``` ########## datafusion/sqllogictest/test_files/insert_to_external.slt: ########## @@ -87,6 +87,127 @@ SELECT * from ordered_insert_test; 7 8 7 7 +# test partitioned insert + +statement ok +CREATE EXTERNAL TABLE +partitioned_insert_test(a string, b string, c bigint) +STORED AS csv +LOCATION 'test_files/scratch/insert_to_external/insert_to_partitioned/' +PARTITIONED BY (a, b) +OPTIONS( +create_local_path 'true', +insert_mode 'append_new_files', +); + +#note that partitioned cols are moved to the end so value tuples are (c, a, b) +query ITT +INSERT INTO partitioned_insert_test values (1, 1, 1), (1, 1, 2), (1, 2, 1), (1, 2, 2), (2, 2, 1), (2, 2, 2); +---- +6 + +query ITT +select * from partitioned_insert_test order by a,b,c +---- +1 1 1 +1 1 2 +1 2 1 +2 2 1 +1 2 2 +2 2 2 + +statement ok +CREATE EXTERNAL TABLE +partitioned_insert_test_verify(c bigint) +STORED AS csv +LOCATION 'test_files/scratch/insert_to_external/insert_to_partitioned/a=2/b=1/' Review Comment: that is very cool ########## datafusion/core/src/datasource/file_format/write/demux.rs: ########## @@ -0,0 +1,380 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Module containing helper methods/traits related to enabling +//! 
dividing input stream into multiple output files at execution time + +use std::collections::HashMap; + +use std::sync::Arc; + +use crate::datasource::listing::ListingTableUrl; + +use crate::error::Result; +use crate::physical_plan::SendableRecordBatchStream; + +use arrow_array::builder::UInt64Builder; +use arrow_array::cast::AsArray; +use arrow_array::{RecordBatch, StructArray}; +use arrow_schema::{DataType, Schema}; +use datafusion_common::cast::as_string_array; +use datafusion_common::DataFusionError; + +use datafusion_execution::TaskContext; + +use futures::StreamExt; +use object_store::path::Path; + +use rand::distributions::DistString; + +use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver, UnboundedSender}; +use tokio::task::JoinHandle; + +type RecordBatchReceiver = Receiver<RecordBatch>; +type DemuxedStreamReceiver = UnboundedReceiver<(Path, RecordBatchReceiver)>; + +/// Splits a single [SendableRecordBatchStream] into a dynamically determined +/// number of partitions at execution time. The partitions are determined by +/// factors known only at execution time, such as total number of rows and +/// partition column values. The demuxer task communicates to the caller +/// by sending channels over a channel. The inner channels send RecordBatches +/// which should be contained within the same output file. The outer channel +/// is used to send a dynamic number of inner channels, representing a dynamic +/// number of total output files. The caller is also responsible to monitor +/// the demux task for errors and abort accordingly. The single_file_ouput parameter +/// overrides all other settings to force only a single file to be written. +/// partition_by parameter will additionally split the input based on the unique +/// values of a specific column `<https://github.com/apache/arrow-datafusion/issues/7744>`` +/// ┌───────────┐ ┌────────────┐ ┌─────────────┐ +/// ┌──────▶ │ batch 1 ├────▶...──────▶│ Batch a │ │ Output File1│ +/// │ └───────────┘ └────────────┘ └─────────────┘ +/// │ +/// ┌──────────┐ │ ┌───────────┐ ┌────────────┐ ┌─────────────┐ +/// ┌───────────┐ ┌────────────┐ │ │ ├──────▶ │ batch a+1├────▶...──────▶│ Batch b │ │ Output File2│ +/// │ batch 1 ├────▶...──────▶│ Batch N ├─────▶│ Demux ├────────┤ ... └───────────┘ └────────────┘ └─────────────┘ +/// └───────────┘ └────────────┘ │ │ │ +/// └──────────┘ │ ┌───────────┐ ┌────────────┐ ┌─────────────┐ +/// └──────▶ │ batch d ├────▶...──────▶│ Batch n │ │ Output FileN│ +/// └───────────┘ └────────────┘ └─────────────┘ +pub(crate) fn start_demuxer_task( + input: SendableRecordBatchStream, + context: &Arc<TaskContext>, + partition_by: Option<Vec<(String, DataType)>>, + base_output_path: ListingTableUrl, + file_extension: String, + single_file_output: bool, +) -> (JoinHandle<Result<()>>, DemuxedStreamReceiver) { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let context = context.clone(); + let task: JoinHandle<std::result::Result<(), DataFusionError>> = match partition_by { + Some(parts) => { + // There could be an arbitrarily large number of parallel hive style partitions being written to, so we cannot + // bound this channel without risking a deadlock. 
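For context on the channel-over-channel protocol described in the doc comment above, a consumer of the demuxer output might look roughly like the following sketch. The `consume_demuxed_streams` name and the `write_batch` closure are hypothetical and not part of this PR; the point is simply that the outer channel yields one `(Path, Receiver<RecordBatch>)` pair per output file.

```rust
use arrow_array::RecordBatch;
use object_store::path::Path;
use tokio::sync::mpsc::{Receiver, UnboundedReceiver};

// Hypothetical consumer of the demuxer output: for every (Path, Receiver) pair
// pushed on the outer channel, drain the inner channel and hand each batch to
// some per-file writer (represented here by a closure).
async fn consume_demuxed_streams(
    mut file_streams: UnboundedReceiver<(Path, Receiver<RecordBatch>)>,
    mut write_batch: impl FnMut(&Path, RecordBatch),
) {
    // Outer loop: one iteration per output file the demuxer decides to create
    while let Some((path, mut batches)) = file_streams.recv().await {
        // Inner loop: all RecordBatches destined for this particular file
        while let Some(batch) = batches.recv().await {
            write_batch(&path, batch);
        }
    }
}
```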
+ tokio::spawn(async move { + hive_style_partitions_demuxer( + tx, + input, + context, + parts, + base_output_path, + file_extension, + ) + .await + }) + } + None => tokio::spawn(async move { + row_count_demuxer( + tx, + input, + context, + base_output_path, + file_extension, + single_file_output, + ) + .await + }), + }; + + (task, rx) +} + +/// Dynamically partitions input stream to acheive desired maximum rows per file +async fn row_count_demuxer( + mut tx: UnboundedSender<(Path, Receiver<RecordBatch>)>, + mut input: SendableRecordBatchStream, + context: Arc<TaskContext>, + base_output_path: ListingTableUrl, + file_extension: String, + single_file_output: bool, +) -> Result<()> { + let exec_options = &context.session_config().options().execution; + let max_rows_per_file = exec_options.soft_max_rows_per_output_file; + let max_buffered_batches = exec_options.max_buffered_batches_per_output_file; + let mut total_rows_current_file = 0; + let mut part_idx = 0; + let write_id = + rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + + let mut tx_file = create_new_file_stream( + &base_output_path, + &write_id, + part_idx, + &file_extension, + single_file_output, + max_buffered_batches, + &mut tx, + )?; + part_idx += 1; + + while let Some(rb) = input.next().await.transpose()? { + total_rows_current_file += rb.num_rows(); + tx_file.send(rb).await.map_err(|_| { + DataFusionError::Execution("Error sending RecordBatch to file stream!".into()) + })?; + + if total_rows_current_file >= max_rows_per_file && !single_file_output { + total_rows_current_file = 0; + tx_file = create_new_file_stream( + &base_output_path, + &write_id, + part_idx, + &file_extension, + single_file_output, + max_buffered_batches, + &mut tx, + )?; + part_idx += 1; + } + } + Ok(()) +} + +/// Helper for row count demuxer +fn generate_file_path( + base_output_path: &ListingTableUrl, + write_id: &str, + part_idx: usize, + file_extension: &str, + single_file_output: bool, +) -> Path { + if !single_file_output { + base_output_path + .prefix() + .child(format!("{}_{}.{}", write_id, part_idx, file_extension)) + } else { + base_output_path.prefix().to_owned() + } +} + +/// Helper for row count demuxer +fn create_new_file_stream( + base_output_path: &ListingTableUrl, + write_id: &str, + part_idx: usize, + file_extension: &str, + single_file_output: bool, + max_buffered_batches: usize, + tx: &mut UnboundedSender<(Path, Receiver<RecordBatch>)>, +) -> Result<Sender<RecordBatch>> { + let file_path = generate_file_path( + base_output_path, + write_id, + part_idx, + file_extension, + single_file_output, + ); + let (tx_file, rx_file) = mpsc::channel(max_buffered_batches / 2); + tx.send((file_path, rx_file)).map_err(|_| { + DataFusionError::Execution("Error sending RecordBatch to file stream!".into()) + })?; + Ok(tx_file) +} + +/// Splits an input stream based on the distinct values of a set of columns +/// Assumes standard hive style partition paths such as +/// /col1=val1/col2=val2/outputfile.parquet +async fn hive_style_partitions_demuxer( + tx: UnboundedSender<(Path, Receiver<RecordBatch>)>, + mut input: SendableRecordBatchStream, + context: Arc<TaskContext>, + partition_by: Vec<(String, DataType)>, + base_output_path: ListingTableUrl, + file_extension: String, +) -> Result<()> { + let write_id = + rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + + let exec_options = &context.session_config().options().execution; + let max_buffered_recordbatches = 
exec_options.max_buffered_batches_per_output_file; + + // To support non string partition col types, cast the type to &str first + let mut value_map: HashMap<Vec<String>, Sender<RecordBatch>> = HashMap::new(); + + while let Some(rb) = input.next().await.transpose()? { + // First compute partition key for each row of batch, e.g. (col1=val1, col2=val2, ...) + let all_partition_values = compute_partition_keys_by_row(&rb, &partition_by)?; + + // Next compute how the batch should be split up to take each distinct key to its own batch + let take_map = compute_take_arrays(&rb, all_partition_values); + + // Divide up the batch into distinct partition key batches and send each batch + for (part_key, mut builder) in take_map.into_iter() { + // Take method adapted from https://github.com/lancedb/lance/pull/1337/files + // TODO: upstream RecordBatch::take to arrow-rs + let take_indices = builder.finish(); + let struct_array: StructArray = rb.clone().into(); + let parted_batch = RecordBatch::try_from( + arrow::compute::take(&struct_array, &take_indices, None)?.as_struct(), + ) + .map_err(|_| { + DataFusionError::Internal("Unexpected error partitioning batch!".into()) + })?; + + // Get or create channel for this batch + let part_tx = match value_map.get_mut(&part_key) { + Some(part_tx) => part_tx, + None => { + // Create channel for previously unseen distinct partition key and notify consumer of new file + let (part_tx, part_rx) = tokio::sync::mpsc::channel::<RecordBatch>( + max_buffered_recordbatches, + ); + let file_path = compute_hive_style_file_path( + &part_key, + &partition_by, + &write_id, + &file_extension, + &base_output_path, + ); + + tx.send((file_path, part_rx)).map_err(|_| { + DataFusionError::Execution( + "Error sending new file stream!".into(), + ) + })?; + + value_map.insert(part_key.clone(), part_tx); + value_map + .get_mut(&part_key) + .ok_or(DataFusionError::Internal( + "Key must exist since it was just inserted!".into(), + ))? + } + }; + + // remove partitions columns + let final_batch_to_send = + remove_partition_by_columns(&parted_batch, &partition_by)?; + + // Finally send the partial batch partitioned by distinct value! + part_tx.send(final_batch_to_send).await.map_err(|_| { + DataFusionError::Internal("Unexpected error sending parted batch!".into()) + })?; + } + } + + Ok(()) +} + +fn compute_partition_keys_by_row<'a>( + rb: &'a RecordBatch, + partition_by: &'a [(String, DataType)], +) -> Result<Vec<Vec<&'a str>>> { + let mut all_partition_values = vec![]; + + for (col, dtype) in partition_by.iter() { + let mut partition_values = vec![]; + let col_array = + rb.column_by_name(col) + .ok_or(DataFusionError::Execution(format!( + "PartitionBy Column {} does not exist in source data!", + col + )))?; + + match dtype { + DataType::Utf8 => { + let array = as_string_array(col_array)?; + for i in 0..rb.num_rows() { + partition_values.push(array.value(i)); + } + } + _ => { + return Err(DataFusionError::NotImplemented(format!( + "it is not yet supported to write to hive partitions with datatype {}", + dtype + ))) + } + } + + all_partition_values.push(partition_values); + } + + Ok(all_partition_values) +} + +fn compute_take_arrays( + rb: &RecordBatch, + all_partition_values: Vec<Vec<&str>>, +) -> HashMap<Vec<String>, UInt64Builder> { + let mut take_map = HashMap::new(); + for i in 0..rb.num_rows() { + let mut part_key = vec![]; + for vals in all_partition_values.iter() { + part_key.push(vals[i].to_owned()); Review Comment: is there any reason to use `to_owned` here? 
given `take_map` isn't returned I think it would be fine to store `&str` values here too ```suggestion part_key.push(vals[i]); ``` ########## datafusion/core/src/datasource/file_format/write/demux.rs: ########## @@ -0,0 +1,380 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Module containing helper methods/traits related to enabling +//! dividing input stream into multiple output files at execution time + +use std::collections::HashMap; + +use std::sync::Arc; + +use crate::datasource::listing::ListingTableUrl; + +use crate::error::Result; +use crate::physical_plan::SendableRecordBatchStream; + +use arrow_array::builder::UInt64Builder; +use arrow_array::cast::AsArray; +use arrow_array::{RecordBatch, StructArray}; +use arrow_schema::{DataType, Schema}; +use datafusion_common::cast::as_string_array; +use datafusion_common::DataFusionError; + +use datafusion_execution::TaskContext; + +use futures::StreamExt; +use object_store::path::Path; + +use rand::distributions::DistString; + +use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver, UnboundedSender}; +use tokio::task::JoinHandle; + +type RecordBatchReceiver = Receiver<RecordBatch>; +type DemuxedStreamReceiver = UnboundedReceiver<(Path, RecordBatchReceiver)>; + +/// Splits a single [SendableRecordBatchStream] into a dynamically determined +/// number of partitions at execution time. The partitions are determined by +/// factors known only at execution time, such as total number of rows and +/// partition column values. The demuxer task communicates to the caller +/// by sending channels over a channel. The inner channels send RecordBatches +/// which should be contained within the same output file. The outer channel +/// is used to send a dynamic number of inner channels, representing a dynamic +/// number of total output files. The caller is also responsible to monitor +/// the demux task for errors and abort accordingly. The single_file_ouput parameter +/// overrides all other settings to force only a single file to be written. +/// partition_by parameter will additionally split the input based on the unique +/// values of a specific column `<https://github.com/apache/arrow-datafusion/issues/7744>`` +/// ┌───────────┐ ┌────────────┐ ┌─────────────┐ +/// ┌──────▶ │ batch 1 ├────▶...──────▶│ Batch a │ │ Output File1│ +/// │ └───────────┘ └────────────┘ └─────────────┘ +/// │ +/// ┌──────────┐ │ ┌───────────┐ ┌────────────┐ ┌─────────────┐ +/// ┌───────────┐ ┌────────────┐ │ │ ├──────▶ │ batch a+1├────▶...──────▶│ Batch b │ │ Output File2│ +/// │ batch 1 ├────▶...──────▶│ Batch N ├─────▶│ Demux ├────────┤ ... 
└───────────┘ └────────────┘ └─────────────┘ +/// └───────────┘ └────────────┘ │ │ │ +/// └──────────┘ │ ┌───────────┐ ┌────────────┐ ┌─────────────┐ +/// └──────▶ │ batch d ├────▶...──────▶│ Batch n │ │ Output FileN│ +/// └───────────┘ └────────────┘ └─────────────┘ +pub(crate) fn start_demuxer_task( + input: SendableRecordBatchStream, + context: &Arc<TaskContext>, + partition_by: Option<Vec<(String, DataType)>>, + base_output_path: ListingTableUrl, + file_extension: String, + single_file_output: bool, +) -> (JoinHandle<Result<()>>, DemuxedStreamReceiver) { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let context = context.clone(); + let task: JoinHandle<std::result::Result<(), DataFusionError>> = match partition_by { + Some(parts) => { + // There could be an arbitrarily large number of parallel hive style partitions being written to, so we cannot + // bound this channel without risking a deadlock. + tokio::spawn(async move { + hive_style_partitions_demuxer( + tx, + input, + context, + parts, + base_output_path, + file_extension, + ) + .await + }) + } + None => tokio::spawn(async move { + row_count_demuxer( + tx, + input, + context, + base_output_path, + file_extension, + single_file_output, + ) + .await + }), + }; + + (task, rx) +} + +/// Dynamically partitions input stream to acheive desired maximum rows per file +async fn row_count_demuxer( + mut tx: UnboundedSender<(Path, Receiver<RecordBatch>)>, + mut input: SendableRecordBatchStream, + context: Arc<TaskContext>, + base_output_path: ListingTableUrl, + file_extension: String, + single_file_output: bool, +) -> Result<()> { + let exec_options = &context.session_config().options().execution; + let max_rows_per_file = exec_options.soft_max_rows_per_output_file; + let max_buffered_batches = exec_options.max_buffered_batches_per_output_file; + let mut total_rows_current_file = 0; + let mut part_idx = 0; + let write_id = + rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + + let mut tx_file = create_new_file_stream( + &base_output_path, + &write_id, + part_idx, + &file_extension, + single_file_output, + max_buffered_batches, + &mut tx, + )?; + part_idx += 1; + + while let Some(rb) = input.next().await.transpose()? 
{ + total_rows_current_file += rb.num_rows(); + tx_file.send(rb).await.map_err(|_| { + DataFusionError::Execution("Error sending RecordBatch to file stream!".into()) + })?; + + if total_rows_current_file >= max_rows_per_file && !single_file_output { + total_rows_current_file = 0; + tx_file = create_new_file_stream( + &base_output_path, + &write_id, + part_idx, + &file_extension, + single_file_output, + max_buffered_batches, + &mut tx, + )?; + part_idx += 1; + } + } + Ok(()) +} + +/// Helper for row count demuxer +fn generate_file_path( + base_output_path: &ListingTableUrl, + write_id: &str, + part_idx: usize, + file_extension: &str, + single_file_output: bool, +) -> Path { + if !single_file_output { + base_output_path + .prefix() + .child(format!("{}_{}.{}", write_id, part_idx, file_extension)) + } else { + base_output_path.prefix().to_owned() + } +} + +/// Helper for row count demuxer +fn create_new_file_stream( + base_output_path: &ListingTableUrl, + write_id: &str, + part_idx: usize, + file_extension: &str, + single_file_output: bool, + max_buffered_batches: usize, + tx: &mut UnboundedSender<(Path, Receiver<RecordBatch>)>, +) -> Result<Sender<RecordBatch>> { + let file_path = generate_file_path( + base_output_path, + write_id, + part_idx, + file_extension, + single_file_output, + ); + let (tx_file, rx_file) = mpsc::channel(max_buffered_batches / 2); + tx.send((file_path, rx_file)).map_err(|_| { + DataFusionError::Execution("Error sending RecordBatch to file stream!".into()) + })?; + Ok(tx_file) +} + +/// Splits an input stream based on the distinct values of a set of columns +/// Assumes standard hive style partition paths such as +/// /col1=val1/col2=val2/outputfile.parquet +async fn hive_style_partitions_demuxer( + tx: UnboundedSender<(Path, Receiver<RecordBatch>)>, + mut input: SendableRecordBatchStream, + context: Arc<TaskContext>, + partition_by: Vec<(String, DataType)>, + base_output_path: ListingTableUrl, + file_extension: String, +) -> Result<()> { + let write_id = + rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + + let exec_options = &context.session_config().options().execution; + let max_buffered_recordbatches = exec_options.max_buffered_batches_per_output_file; + + // To support non string partition col types, cast the type to &str first + let mut value_map: HashMap<Vec<String>, Sender<RecordBatch>> = HashMap::new(); + + while let Some(rb) = input.next().await.transpose()? { + // First compute partition key for each row of batch, e.g. (col1=val1, col2=val2, ...) 
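To illustrate the grouping step performed here, a hypothetical single-column version of `compute_take_arrays` could group row indices by partition value while keeping the keys as borrowed `&str` (in the spirit of the `to_owned` question raised in the review). `group_rows_by_partition_value` below is a sketch, not code from the PR:

```rust
use std::collections::HashMap;

use arrow_array::{Array, RecordBatch, StringArray, UInt64Array};

// Hypothetical single-column version of compute_take_arrays: group the row
// indices of `batch` by the value of one Utf8 partition column. The keys stay
// borrowed as &str since the map never outlives the batch it was built from.
fn group_rows_by_partition_value<'a>(
    batch: &'a RecordBatch,
    partition_col: &str,
) -> HashMap<&'a str, UInt64Array> {
    let values = batch
        .column_by_name(partition_col)
        .expect("partition column must exist")
        .as_any()
        .downcast_ref::<StringArray>()
        .expect("only Utf8 partition columns in this sketch");

    let mut groups: HashMap<&str, Vec<u64>> = HashMap::new();
    for row in 0..batch.num_rows() {
        groups.entry(values.value(row)).or_default().push(row as u64);
    }

    // Convert each index list into an arrow UInt64Array usable with take()
    groups
        .into_iter()
        .map(|(key, indices)| (key, UInt64Array::from(indices)))
        .collect()
}
```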
+ let all_partition_values = compute_partition_keys_by_row(&rb, &partition_by)?; + + // Next compute how the batch should be split up to take each distinct key to its own batch + let take_map = compute_take_arrays(&rb, all_partition_values); + + // Divide up the batch into distinct partition key batches and send each batch + for (part_key, mut builder) in take_map.into_iter() { + // Take method adapted from https://github.com/lancedb/lance/pull/1337/files + // TODO: upstream RecordBatch::take to arrow-rs + let take_indices = builder.finish(); + let struct_array: StructArray = rb.clone().into(); + let parted_batch = RecordBatch::try_from( + arrow::compute::take(&struct_array, &take_indices, None)?.as_struct(), + ) + .map_err(|_| { + DataFusionError::Internal("Unexpected error partitioning batch!".into()) + })?; + + // Get or create channel for this batch + let part_tx = match value_map.get_mut(&part_key) { + Some(part_tx) => part_tx, + None => { + // Create channel for previously unseen distinct partition key and notify consumer of new file + let (part_tx, part_rx) = tokio::sync::mpsc::channel::<RecordBatch>( + max_buffered_recordbatches, + ); + let file_path = compute_hive_style_file_path( + &part_key, + &partition_by, + &write_id, + &file_extension, + &base_output_path, + ); + + tx.send((file_path, part_rx)).map_err(|_| { + DataFusionError::Execution( + "Error sending new file stream!".into(), + ) + })?; + + value_map.insert(part_key.clone(), part_tx); + value_map + .get_mut(&part_key) + .ok_or(DataFusionError::Internal( + "Key must exist since it was just inserted!".into(), + ))? + } + }; + + // remove partitions columns + let final_batch_to_send = + remove_partition_by_columns(&parted_batch, &partition_by)?; + + // Finally send the partial batch partitioned by distinct value! + part_tx.send(final_batch_to_send).await.map_err(|_| { + DataFusionError::Internal("Unexpected error sending parted batch!".into()) + })?; + } + } + + Ok(()) +} + +fn compute_partition_keys_by_row<'a>( + rb: &'a RecordBatch, + partition_by: &'a [(String, DataType)], +) -> Result<Vec<Vec<&'a str>>> { + let mut all_partition_values = vec![]; + + for (col, dtype) in partition_by.iter() { + let mut partition_values = vec![]; + let col_array = + rb.column_by_name(col) + .ok_or(DataFusionError::Execution(format!( + "PartitionBy Column {} does not exist in source data!", + col + )))?; + + match dtype { + DataType::Utf8 => { + let array = as_string_array(col_array)?; + for i in 0..rb.num_rows() { + partition_values.push(array.value(i)); + } + } + _ => { + return Err(DataFusionError::NotImplemented(format!( + "it is not yet supported to write to hive partitions with datatype {}", Review Comment: ```suggestion "it is not yet supported to write to hive partitions with datatype {} while writing column {col}", ``` ########## datafusion/core/src/datasource/file_format/write/demux.rs: ########## @@ -0,0 +1,380 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Module containing helper methods/traits related to enabling +//! dividing input stream into multiple output files at execution time + +use std::collections::HashMap; + +use std::sync::Arc; + +use crate::datasource::listing::ListingTableUrl; + +use crate::error::Result; +use crate::physical_plan::SendableRecordBatchStream; + +use arrow_array::builder::UInt64Builder; +use arrow_array::cast::AsArray; +use arrow_array::{RecordBatch, StructArray}; +use arrow_schema::{DataType, Schema}; +use datafusion_common::cast::as_string_array; +use datafusion_common::DataFusionError; + +use datafusion_execution::TaskContext; + +use futures::StreamExt; +use object_store::path::Path; + +use rand::distributions::DistString; + +use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver, UnboundedSender}; +use tokio::task::JoinHandle; + +type RecordBatchReceiver = Receiver<RecordBatch>; +type DemuxedStreamReceiver = UnboundedReceiver<(Path, RecordBatchReceiver)>; + +/// Splits a single [SendableRecordBatchStream] into a dynamically determined +/// number of partitions at execution time. The partitions are determined by +/// factors known only at execution time, such as total number of rows and +/// partition column values. The demuxer task communicates to the caller +/// by sending channels over a channel. The inner channels send RecordBatches +/// which should be contained within the same output file. The outer channel +/// is used to send a dynamic number of inner channels, representing a dynamic +/// number of total output files. The caller is also responsible to monitor +/// the demux task for errors and abort accordingly. The single_file_ouput parameter +/// overrides all other settings to force only a single file to be written. +/// partition_by parameter will additionally split the input based on the unique +/// values of a specific column `<https://github.com/apache/arrow-datafusion/issues/7744>`` +/// ┌───────────┐ ┌────────────┐ ┌─────────────┐ +/// ┌──────▶ │ batch 1 ├────▶...──────▶│ Batch a │ │ Output File1│ +/// │ └───────────┘ └────────────┘ └─────────────┘ +/// │ +/// ┌──────────┐ │ ┌───────────┐ ┌────────────┐ ┌─────────────┐ +/// ┌───────────┐ ┌────────────┐ │ │ ├──────▶ │ batch a+1├────▶...──────▶│ Batch b │ │ Output File2│ +/// │ batch 1 ├────▶...──────▶│ Batch N ├─────▶│ Demux ├────────┤ ... 
└───────────┘ └────────────┘ └─────────────┘ +/// └───────────┘ └────────────┘ │ │ │ +/// └──────────┘ │ ┌───────────┐ ┌────────────┐ ┌─────────────┐ +/// └──────▶ │ batch d ├────▶...──────▶│ Batch n │ │ Output FileN│ +/// └───────────┘ └────────────┘ └─────────────┘ +pub(crate) fn start_demuxer_task( + input: SendableRecordBatchStream, + context: &Arc<TaskContext>, + partition_by: Option<Vec<(String, DataType)>>, + base_output_path: ListingTableUrl, + file_extension: String, + single_file_output: bool, +) -> (JoinHandle<Result<()>>, DemuxedStreamReceiver) { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let context = context.clone(); + let task: JoinHandle<std::result::Result<(), DataFusionError>> = match partition_by { + Some(parts) => { + // There could be an arbitrarily large number of parallel hive style partitions being written to, so we cannot + // bound this channel without risking a deadlock. + tokio::spawn(async move { + hive_style_partitions_demuxer( + tx, + input, + context, + parts, + base_output_path, + file_extension, + ) + .await + }) + } + None => tokio::spawn(async move { + row_count_demuxer( + tx, + input, + context, + base_output_path, + file_extension, + single_file_output, + ) + .await + }), + }; + + (task, rx) +} + +/// Dynamically partitions input stream to acheive desired maximum rows per file +async fn row_count_demuxer( + mut tx: UnboundedSender<(Path, Receiver<RecordBatch>)>, + mut input: SendableRecordBatchStream, + context: Arc<TaskContext>, + base_output_path: ListingTableUrl, + file_extension: String, + single_file_output: bool, +) -> Result<()> { + let exec_options = &context.session_config().options().execution; + let max_rows_per_file = exec_options.soft_max_rows_per_output_file; + let max_buffered_batches = exec_options.max_buffered_batches_per_output_file; + let mut total_rows_current_file = 0; + let mut part_idx = 0; + let write_id = + rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + + let mut tx_file = create_new_file_stream( + &base_output_path, + &write_id, + part_idx, + &file_extension, + single_file_output, + max_buffered_batches, + &mut tx, + )?; + part_idx += 1; + + while let Some(rb) = input.next().await.transpose()? 
{ + total_rows_current_file += rb.num_rows(); + tx_file.send(rb).await.map_err(|_| { + DataFusionError::Execution("Error sending RecordBatch to file stream!".into()) + })?; + + if total_rows_current_file >= max_rows_per_file && !single_file_output { + total_rows_current_file = 0; + tx_file = create_new_file_stream( + &base_output_path, + &write_id, + part_idx, + &file_extension, + single_file_output, + max_buffered_batches, + &mut tx, + )?; + part_idx += 1; + } + } + Ok(()) +} + +/// Helper for row count demuxer +fn generate_file_path( + base_output_path: &ListingTableUrl, + write_id: &str, + part_idx: usize, + file_extension: &str, + single_file_output: bool, +) -> Path { + if !single_file_output { + base_output_path + .prefix() + .child(format!("{}_{}.{}", write_id, part_idx, file_extension)) + } else { + base_output_path.prefix().to_owned() + } +} + +/// Helper for row count demuxer +fn create_new_file_stream( + base_output_path: &ListingTableUrl, + write_id: &str, + part_idx: usize, + file_extension: &str, + single_file_output: bool, + max_buffered_batches: usize, + tx: &mut UnboundedSender<(Path, Receiver<RecordBatch>)>, +) -> Result<Sender<RecordBatch>> { + let file_path = generate_file_path( + base_output_path, + write_id, + part_idx, + file_extension, + single_file_output, + ); + let (tx_file, rx_file) = mpsc::channel(max_buffered_batches / 2); + tx.send((file_path, rx_file)).map_err(|_| { + DataFusionError::Execution("Error sending RecordBatch to file stream!".into()) + })?; + Ok(tx_file) +} + +/// Splits an input stream based on the distinct values of a set of columns +/// Assumes standard hive style partition paths such as +/// /col1=val1/col2=val2/outputfile.parquet +async fn hive_style_partitions_demuxer( + tx: UnboundedSender<(Path, Receiver<RecordBatch>)>, + mut input: SendableRecordBatchStream, + context: Arc<TaskContext>, + partition_by: Vec<(String, DataType)>, + base_output_path: ListingTableUrl, + file_extension: String, +) -> Result<()> { + let write_id = + rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + + let exec_options = &context.session_config().options().execution; + let max_buffered_recordbatches = exec_options.max_buffered_batches_per_output_file; + + // To support non string partition col types, cast the type to &str first + let mut value_map: HashMap<Vec<String>, Sender<RecordBatch>> = HashMap::new(); + + while let Some(rb) = input.next().await.transpose()? { + // First compute partition key for each row of batch, e.g. (col1=val1, col2=val2, ...) + let all_partition_values = compute_partition_keys_by_row(&rb, &partition_by)?; + + // Next compute how the batch should be split up to take each distinct key to its own batch + let take_map = compute_take_arrays(&rb, all_partition_values); + + // Divide up the batch into distinct partition key batches and send each batch + for (part_key, mut builder) in take_map.into_iter() { + // Take method adapted from https://github.com/lancedb/lance/pull/1337/files + // TODO: upstream RecordBatch::take to arrow-rs Review Comment: ```suggestion // TODO: use RecordBatch::take in arrow-rs https://github.com/apache/arrow-rs/issues/4958 ``` ########## datafusion/core/src/datasource/file_format/write/demux.rs: ########## @@ -0,0 +1,380 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Module containing helper methods/traits related to enabling +//! dividing input stream into multiple output files at execution time + +use std::collections::HashMap; + +use std::sync::Arc; + +use crate::datasource::listing::ListingTableUrl; + +use crate::error::Result; +use crate::physical_plan::SendableRecordBatchStream; + +use arrow_array::builder::UInt64Builder; +use arrow_array::cast::AsArray; +use arrow_array::{RecordBatch, StructArray}; +use arrow_schema::{DataType, Schema}; +use datafusion_common::cast::as_string_array; +use datafusion_common::DataFusionError; + +use datafusion_execution::TaskContext; + +use futures::StreamExt; +use object_store::path::Path; + +use rand::distributions::DistString; + +use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver, UnboundedSender}; +use tokio::task::JoinHandle; + +type RecordBatchReceiver = Receiver<RecordBatch>; +type DemuxedStreamReceiver = UnboundedReceiver<(Path, RecordBatchReceiver)>; + +/// Splits a single [SendableRecordBatchStream] into a dynamically determined +/// number of partitions at execution time. The partitions are determined by +/// factors known only at execution time, such as total number of rows and +/// partition column values. The demuxer task communicates to the caller +/// by sending channels over a channel. The inner channels send RecordBatches +/// which should be contained within the same output file. The outer channel +/// is used to send a dynamic number of inner channels, representing a dynamic +/// number of total output files. The caller is also responsible to monitor +/// the demux task for errors and abort accordingly. The single_file_ouput parameter +/// overrides all other settings to force only a single file to be written. +/// partition_by parameter will additionally split the input based on the unique +/// values of a specific column `<https://github.com/apache/arrow-datafusion/issues/7744>`` +/// ┌───────────┐ ┌────────────┐ ┌─────────────┐ +/// ┌──────▶ │ batch 1 ├────▶...──────▶│ Batch a │ │ Output File1│ +/// │ └───────────┘ └────────────┘ └─────────────┘ +/// │ +/// ┌──────────┐ │ ┌───────────┐ ┌────────────┐ ┌─────────────┐ +/// ┌───────────┐ ┌────────────┐ │ │ ├──────▶ │ batch a+1├────▶...──────▶│ Batch b │ │ Output File2│ +/// │ batch 1 ├────▶...──────▶│ Batch N ├─────▶│ Demux ├────────┤ ... 
└───────────┘ └────────────┘ └─────────────┘ +/// └───────────┘ └────────────┘ │ │ │ +/// └──────────┘ │ ┌───────────┐ ┌────────────┐ ┌─────────────┐ +/// └──────▶ │ batch d ├────▶...──────▶│ Batch n │ │ Output FileN│ +/// └───────────┘ └────────────┘ └─────────────┘ +pub(crate) fn start_demuxer_task( + input: SendableRecordBatchStream, + context: &Arc<TaskContext>, + partition_by: Option<Vec<(String, DataType)>>, + base_output_path: ListingTableUrl, + file_extension: String, + single_file_output: bool, +) -> (JoinHandle<Result<()>>, DemuxedStreamReceiver) { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let context = context.clone(); + let task: JoinHandle<std::result::Result<(), DataFusionError>> = match partition_by { + Some(parts) => { + // There could be an arbitrarily large number of parallel hive style partitions being written to, so we cannot + // bound this channel without risking a deadlock. + tokio::spawn(async move { + hive_style_partitions_demuxer( + tx, + input, + context, + parts, + base_output_path, + file_extension, + ) + .await + }) + } + None => tokio::spawn(async move { + row_count_demuxer( + tx, + input, + context, + base_output_path, + file_extension, + single_file_output, + ) + .await + }), + }; + + (task, rx) +} + +/// Dynamically partitions input stream to acheive desired maximum rows per file +async fn row_count_demuxer( + mut tx: UnboundedSender<(Path, Receiver<RecordBatch>)>, + mut input: SendableRecordBatchStream, + context: Arc<TaskContext>, + base_output_path: ListingTableUrl, + file_extension: String, + single_file_output: bool, +) -> Result<()> { + let exec_options = &context.session_config().options().execution; + let max_rows_per_file = exec_options.soft_max_rows_per_output_file; + let max_buffered_batches = exec_options.max_buffered_batches_per_output_file; + let mut total_rows_current_file = 0; + let mut part_idx = 0; + let write_id = + rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + + let mut tx_file = create_new_file_stream( + &base_output_path, + &write_id, + part_idx, + &file_extension, + single_file_output, + max_buffered_batches, + &mut tx, + )?; + part_idx += 1; + + while let Some(rb) = input.next().await.transpose()? 
{ + total_rows_current_file += rb.num_rows(); + tx_file.send(rb).await.map_err(|_| { + DataFusionError::Execution("Error sending RecordBatch to file stream!".into()) + })?; + + if total_rows_current_file >= max_rows_per_file && !single_file_output { + total_rows_current_file = 0; + tx_file = create_new_file_stream( + &base_output_path, + &write_id, + part_idx, + &file_extension, + single_file_output, + max_buffered_batches, + &mut tx, + )?; + part_idx += 1; + } + } + Ok(()) +} + +/// Helper for row count demuxer +fn generate_file_path( + base_output_path: &ListingTableUrl, + write_id: &str, + part_idx: usize, + file_extension: &str, + single_file_output: bool, +) -> Path { + if !single_file_output { + base_output_path + .prefix() + .child(format!("{}_{}.{}", write_id, part_idx, file_extension)) + } else { + base_output_path.prefix().to_owned() + } +} + +/// Helper for row count demuxer +fn create_new_file_stream( + base_output_path: &ListingTableUrl, + write_id: &str, + part_idx: usize, + file_extension: &str, + single_file_output: bool, + max_buffered_batches: usize, + tx: &mut UnboundedSender<(Path, Receiver<RecordBatch>)>, +) -> Result<Sender<RecordBatch>> { + let file_path = generate_file_path( + base_output_path, + write_id, + part_idx, + file_extension, + single_file_output, + ); + let (tx_file, rx_file) = mpsc::channel(max_buffered_batches / 2); + tx.send((file_path, rx_file)).map_err(|_| { + DataFusionError::Execution("Error sending RecordBatch to file stream!".into()) + })?; + Ok(tx_file) +} + +/// Splits an input stream based on the distinct values of a set of columns +/// Assumes standard hive style partition paths such as +/// /col1=val1/col2=val2/outputfile.parquet +async fn hive_style_partitions_demuxer( + tx: UnboundedSender<(Path, Receiver<RecordBatch>)>, + mut input: SendableRecordBatchStream, + context: Arc<TaskContext>, + partition_by: Vec<(String, DataType)>, + base_output_path: ListingTableUrl, + file_extension: String, +) -> Result<()> { + let write_id = + rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + + let exec_options = &context.session_config().options().execution; + let max_buffered_recordbatches = exec_options.max_buffered_batches_per_output_file; + + // To support non string partition col types, cast the type to &str first + let mut value_map: HashMap<Vec<String>, Sender<RecordBatch>> = HashMap::new(); + + while let Some(rb) = input.next().await.transpose()? { + // First compute partition key for each row of batch, e.g. (col1=val1, col2=val2, ...) 
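The row-selection step used below (a `StructArray` round trip plus `arrow::compute::take`) can be factored into a small helper. The sketch below uses a hypothetical `take_record_batch` name and shows the same pattern that the TODO proposes upstreaming as `RecordBatch::take` in arrow-rs:

```rust
use arrow_array::{cast::AsArray, RecordBatch, StructArray, UInt64Array};
use arrow_schema::ArrowError;

// Sketch of selecting a subset of rows from a RecordBatch by index, going
// through a StructArray until a RecordBatch::take exists upstream in arrow-rs.
fn take_record_batch(
    batch: &RecordBatch,
    indices: &UInt64Array,
) -> Result<RecordBatch, ArrowError> {
    // RecordBatch <-> StructArray conversions are cheap, Arc-based clones
    let struct_array: StructArray = batch.clone().into();
    let taken = arrow::compute::take(&struct_array, indices, None)?;
    Ok(RecordBatch::from(taken.as_struct().clone()))
}
```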
+ let all_partition_values = compute_partition_keys_by_row(&rb, &partition_by)?; + + // Next compute how the batch should be split up to take each distinct key to its own batch + let take_map = compute_take_arrays(&rb, all_partition_values); + + // Divide up the batch into distinct partition key batches and send each batch + for (part_key, mut builder) in take_map.into_iter() { + // Take method adapted from https://github.com/lancedb/lance/pull/1337/files + // TODO: upstream RecordBatch::take to arrow-rs + let take_indices = builder.finish(); + let struct_array: StructArray = rb.clone().into(); + let parted_batch = RecordBatch::try_from( + arrow::compute::take(&struct_array, &take_indices, None)?.as_struct(), + ) + .map_err(|_| { + DataFusionError::Internal("Unexpected error partitioning batch!".into()) + })?; + + // Get or create channel for this batch + let part_tx = match value_map.get_mut(&part_key) { + Some(part_tx) => part_tx, + None => { + // Create channel for previously unseen distinct partition key and notify consumer of new file + let (part_tx, part_rx) = tokio::sync::mpsc::channel::<RecordBatch>( + max_buffered_recordbatches, + ); + let file_path = compute_hive_style_file_path( + &part_key, + &partition_by, + &write_id, + &file_extension, + &base_output_path, + ); + + tx.send((file_path, part_rx)).map_err(|_| { + DataFusionError::Execution( + "Error sending new file stream!".into(), + ) + })?; + + value_map.insert(part_key.clone(), part_tx); + value_map + .get_mut(&part_key) + .ok_or(DataFusionError::Internal( + "Key must exist since it was just inserted!".into(), + ))? + } + }; + + // remove partitions columns + let final_batch_to_send = + remove_partition_by_columns(&parted_batch, &partition_by)?; + + // Finally send the partial batch partitioned by distinct value! 
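For the file path computed above, a hive-style output path has the shape `base/col1=val1/col2=val2/<write_id>_<n>.<ext>`. A simplified sketch follows; the `hive_style_path` name is hypothetical and it takes a plain `Path` instead of the `ListingTableUrl` prefix used by the real `compute_hive_style_file_path`:

```rust
use object_store::path::Path;

// Hypothetical sketch of building a hive-style output path from a partition
// key, e.g. base/col1=val1/col2=val2/<write_id>_0.<ext>.
fn hive_style_path(
    base: &Path,
    partition_cols: &[&str],
    partition_values: &[&str],
    write_id: &str,
    file_extension: &str,
) -> Path {
    let mut path = base.clone();
    for (col, val) in partition_cols.iter().zip(partition_values) {
        // Each partition column adds one `col=value` directory level
        path = path.child(format!("{col}={val}"));
    }
    // The data file itself gets a unique name within the leaf directory
    path.child(format!("{write_id}_0.{file_extension}"))
}
```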
+ part_tx.send(final_batch_to_send).await.map_err(|_| { + DataFusionError::Internal("Unexpected error sending parted batch!".into()) + })?; + } + } + + Ok(()) +} + +fn compute_partition_keys_by_row<'a>( + rb: &'a RecordBatch, + partition_by: &'a [(String, DataType)], +) -> Result<Vec<Vec<&'a str>>> { + let mut all_partition_values = vec![]; + + for (col, dtype) in partition_by.iter() { + let mut partition_values = vec![]; + let col_array = + rb.column_by_name(col) + .ok_or(DataFusionError::Execution(format!( + "PartitionBy Column {} does not exist in source data!", + col + )))?; + + match dtype { + DataType::Utf8 => { + let array = as_string_array(col_array)?; + for i in 0..rb.num_rows() { + partition_values.push(array.value(i)); + } + } + _ => { + return Err(DataFusionError::NotImplemented(format!( + "it is not yet supported to write to hive partitions with datatype {}", + dtype + ))) + } + } + + all_partition_values.push(partition_values); + } + + Ok(all_partition_values) +} + +fn compute_take_arrays( + rb: &RecordBatch, + all_partition_values: Vec<Vec<&str>>, +) -> HashMap<Vec<String>, UInt64Builder> { + let mut take_map = HashMap::new(); + for i in 0..rb.num_rows() { + let mut part_key = vec![]; + for vals in all_partition_values.iter() { + part_key.push(vals[i].to_owned()); + } + let builder = take_map.entry(part_key).or_insert(UInt64Builder::new()); + builder.append_value(i as u64); + } + take_map +} + +fn remove_partition_by_columns( + parted_batch: &RecordBatch, + partition_by: &Vec<(String, DataType)>, +) -> Result<RecordBatch> { + let end_idx = parted_batch.num_columns() - partition_by.len(); Review Comment: You could probably express this more concisely using `[filter()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.filter) to discard any partition columns ########## datafusion/core/src/datasource/file_format/write/demux.rs: ########## @@ -0,0 +1,380 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Module containing helper methods/traits related to enabling +//! 
dividing input stream into multiple output files at execution time + +use std::collections::HashMap; + +use std::sync::Arc; + +use crate::datasource::listing::ListingTableUrl; + +use crate::error::Result; +use crate::physical_plan::SendableRecordBatchStream; + +use arrow_array::builder::UInt64Builder; +use arrow_array::cast::AsArray; +use arrow_array::{RecordBatch, StructArray}; +use arrow_schema::{DataType, Schema}; +use datafusion_common::cast::as_string_array; +use datafusion_common::DataFusionError; + +use datafusion_execution::TaskContext; + +use futures::StreamExt; +use object_store::path::Path; + +use rand::distributions::DistString; + +use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver, UnboundedSender}; +use tokio::task::JoinHandle; + +type RecordBatchReceiver = Receiver<RecordBatch>; +type DemuxedStreamReceiver = UnboundedReceiver<(Path, RecordBatchReceiver)>; + +/// Splits a single [SendableRecordBatchStream] into a dynamically determined +/// number of partitions at execution time. The partitions are determined by +/// factors known only at execution time, such as total number of rows and +/// partition column values. The demuxer task communicates to the caller +/// by sending channels over a channel. The inner channels send RecordBatches +/// which should be contained within the same output file. The outer channel +/// is used to send a dynamic number of inner channels, representing a dynamic +/// number of total output files. The caller is also responsible to monitor +/// the demux task for errors and abort accordingly. The single_file_ouput parameter +/// overrides all other settings to force only a single file to be written. +/// partition_by parameter will additionally split the input based on the unique +/// values of a specific column `<https://github.com/apache/arrow-datafusion/issues/7744>`` +/// ┌───────────┐ ┌────────────┐ ┌─────────────┐ +/// ┌──────▶ │ batch 1 ├────▶...──────▶│ Batch a │ │ Output File1│ +/// │ └───────────┘ └────────────┘ └─────────────┘ +/// │ +/// ┌──────────┐ │ ┌───────────┐ ┌────────────┐ ┌─────────────┐ +/// ┌───────────┐ ┌────────────┐ │ │ ├──────▶ │ batch a+1├────▶...──────▶│ Batch b │ │ Output File2│ +/// │ batch 1 ├────▶...──────▶│ Batch N ├─────▶│ Demux ├────────┤ ... └───────────┘ └────────────┘ └─────────────┘ +/// └───────────┘ └────────────┘ │ │ │ +/// └──────────┘ │ ┌───────────┐ ┌────────────┐ ┌─────────────┐ +/// └──────▶ │ batch d ├────▶...──────▶│ Batch n │ │ Output FileN│ +/// └───────────┘ └────────────┘ └─────────────┘ +pub(crate) fn start_demuxer_task( + input: SendableRecordBatchStream, + context: &Arc<TaskContext>, + partition_by: Option<Vec<(String, DataType)>>, + base_output_path: ListingTableUrl, + file_extension: String, + single_file_output: bool, +) -> (JoinHandle<Result<()>>, DemuxedStreamReceiver) { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let context = context.clone(); + let task: JoinHandle<std::result::Result<(), DataFusionError>> = match partition_by { + Some(parts) => { + // There could be an arbitrarily large number of parallel hive style partitions being written to, so we cannot + // bound this channel without risking a deadlock. 
+ tokio::spawn(async move { + hive_style_partitions_demuxer( + tx, + input, + context, + parts, + base_output_path, + file_extension, + ) + .await + }) + } + None => tokio::spawn(async move { + row_count_demuxer( + tx, + input, + context, + base_output_path, + file_extension, + single_file_output, + ) + .await + }), + }; + + (task, rx) +} + +/// Dynamically partitions input stream to acheive desired maximum rows per file +async fn row_count_demuxer( + mut tx: UnboundedSender<(Path, Receiver<RecordBatch>)>, + mut input: SendableRecordBatchStream, + context: Arc<TaskContext>, + base_output_path: ListingTableUrl, + file_extension: String, + single_file_output: bool, +) -> Result<()> { + let exec_options = &context.session_config().options().execution; + let max_rows_per_file = exec_options.soft_max_rows_per_output_file; + let max_buffered_batches = exec_options.max_buffered_batches_per_output_file; + let mut total_rows_current_file = 0; + let mut part_idx = 0; + let write_id = + rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + + let mut tx_file = create_new_file_stream( + &base_output_path, + &write_id, + part_idx, + &file_extension, + single_file_output, + max_buffered_batches, + &mut tx, + )?; + part_idx += 1; + + while let Some(rb) = input.next().await.transpose()? { + total_rows_current_file += rb.num_rows(); + tx_file.send(rb).await.map_err(|_| { + DataFusionError::Execution("Error sending RecordBatch to file stream!".into()) + })?; + + if total_rows_current_file >= max_rows_per_file && !single_file_output { + total_rows_current_file = 0; + tx_file = create_new_file_stream( + &base_output_path, + &write_id, + part_idx, + &file_extension, + single_file_output, + max_buffered_batches, + &mut tx, + )?; + part_idx += 1; + } + } + Ok(()) +} + +/// Helper for row count demuxer +fn generate_file_path( + base_output_path: &ListingTableUrl, + write_id: &str, + part_idx: usize, + file_extension: &str, + single_file_output: bool, +) -> Path { + if !single_file_output { + base_output_path + .prefix() + .child(format!("{}_{}.{}", write_id, part_idx, file_extension)) + } else { + base_output_path.prefix().to_owned() + } +} + +/// Helper for row count demuxer +fn create_new_file_stream( + base_output_path: &ListingTableUrl, + write_id: &str, + part_idx: usize, + file_extension: &str, + single_file_output: bool, + max_buffered_batches: usize, + tx: &mut UnboundedSender<(Path, Receiver<RecordBatch>)>, +) -> Result<Sender<RecordBatch>> { + let file_path = generate_file_path( + base_output_path, + write_id, + part_idx, + file_extension, + single_file_output, + ); + let (tx_file, rx_file) = mpsc::channel(max_buffered_batches / 2); + tx.send((file_path, rx_file)).map_err(|_| { + DataFusionError::Execution("Error sending RecordBatch to file stream!".into()) + })?; + Ok(tx_file) +} + +/// Splits an input stream based on the distinct values of a set of columns +/// Assumes standard hive style partition paths such as +/// /col1=val1/col2=val2/outputfile.parquet +async fn hive_style_partitions_demuxer( + tx: UnboundedSender<(Path, Receiver<RecordBatch>)>, + mut input: SendableRecordBatchStream, + context: Arc<TaskContext>, + partition_by: Vec<(String, DataType)>, + base_output_path: ListingTableUrl, + file_extension: String, +) -> Result<()> { + let write_id = + rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + + let exec_options = &context.session_config().options().execution; + let max_buffered_recordbatches = 
exec_options.max_buffered_batches_per_output_file; + + // To support non string partition col types, cast the type to &str first + let mut value_map: HashMap<Vec<String>, Sender<RecordBatch>> = HashMap::new(); + + while let Some(rb) = input.next().await.transpose()? { + // First compute partition key for each row of batch, e.g. (col1=val1, col2=val2, ...) + let all_partition_values = compute_partition_keys_by_row(&rb, &partition_by)?; + + // Next compute how the batch should be split up to take each distinct key to its own batch + let take_map = compute_take_arrays(&rb, all_partition_values); + + // Divide up the batch into distinct partition key batches and send each batch + for (part_key, mut builder) in take_map.into_iter() { + // Take method adapted from https://github.com/lancedb/lance/pull/1337/files + // TODO: upstream RecordBatch::take to arrow-rs + let take_indices = builder.finish(); + let struct_array: StructArray = rb.clone().into(); + let parted_batch = RecordBatch::try_from( + arrow::compute::take(&struct_array, &take_indices, None)?.as_struct(), + ) + .map_err(|_| { + DataFusionError::Internal("Unexpected error partitioning batch!".into()) + })?; + + // Get or create channel for this batch + let part_tx = match value_map.get_mut(&part_key) { + Some(part_tx) => part_tx, + None => { + // Create channel for previously unseen distinct partition key and notify consumer of new file + let (part_tx, part_rx) = tokio::sync::mpsc::channel::<RecordBatch>( + max_buffered_recordbatches, + ); + let file_path = compute_hive_style_file_path( + &part_key, + &partition_by, + &write_id, + &file_extension, + &base_output_path, + ); + + tx.send((file_path, part_rx)).map_err(|_| { + DataFusionError::Execution( + "Error sending new file stream!".into(), + ) + })?; + + value_map.insert(part_key.clone(), part_tx); + value_map + .get_mut(&part_key) + .ok_or(DataFusionError::Internal( + "Key must exist since it was just inserted!".into(), + ))? + } + }; + + // remove partitions columns + let final_batch_to_send = + remove_partition_by_columns(&parted_batch, &partition_by)?; + + // Finally send the partial batch partitioned by distinct value! 
+ part_tx.send(final_batch_to_send).await.map_err(|_| { + DataFusionError::Internal("Unexpected error sending parted batch!".into()) + })?; + } + } + + Ok(()) +} + +fn compute_partition_keys_by_row<'a>( + rb: &'a RecordBatch, + partition_by: &'a [(String, DataType)], +) -> Result<Vec<Vec<&'a str>>> { + let mut all_partition_values = vec![]; + + for (col, dtype) in partition_by.iter() { + let mut partition_values = vec![]; + let col_array = + rb.column_by_name(col) + .ok_or(DataFusionError::Execution(format!( + "PartitionBy Column {} does not exist in source data!", + col + )))?; + + match dtype { + DataType::Utf8 => { + let array = as_string_array(col_array)?; + for i in 0..rb.num_rows() { + partition_values.push(array.value(i)); + } + } + _ => { + return Err(DataFusionError::NotImplemented(format!( + "it is not yet supported to write to hive partitions with datatype {}", + dtype + ))) + } + } + + all_partition_values.push(partition_values); + } + + Ok(all_partition_values) +} + +fn compute_take_arrays( + rb: &RecordBatch, + all_partition_values: Vec<Vec<&str>>, +) -> HashMap<Vec<String>, UInt64Builder> { + let mut take_map = HashMap::new(); + for i in 0..rb.num_rows() { + let mut part_key = vec![]; + for vals in all_partition_values.iter() { + part_key.push(vals[i].to_owned()); + } + let builder = take_map.entry(part_key).or_insert(UInt64Builder::new()); + builder.append_value(i as u64); + } + take_map +} + +fn remove_partition_by_columns( + parted_batch: &RecordBatch, + partition_by: &Vec<(String, DataType)>, +) -> Result<RecordBatch> { + let end_idx = parted_batch.num_columns() - partition_by.len(); + let non_part_cols = &parted_batch.columns()[..end_idx]; + let mut non_part_fields = vec![]; + 'outer: for field in parted_batch.schema().all_fields() { + let name = field.name(); + for (part_name, _) in partition_by.iter() { + if name == part_name { + continue 'outer; + } + } + non_part_fields.push(field.to_owned()) + } + let schema = Schema::new(non_part_fields); + let final_batch_to_send = + RecordBatch::try_new(Arc::new(schema), non_part_cols.into())?; + + Ok(final_batch_to_send) +} + +fn compute_hive_style_file_path( + part_key: &Vec<String>, + partition_by: &[(String, DataType)], + write_id: &str, + file_extension: &str, + base_output_path: &ListingTableUrl, +) -> Path { + let mut file_path = base_output_path.prefix().clone(); + for j in 0..part_key.len() { + file_path = file_path.child(format!("{}={}", partition_by[j].0, part_key[j])); + } + + file_path.child(format!("{}.{}", write_id, file_extension)) +} Review Comment: What would you think about writing some unit tests of the partition extraction and schema pruning code? Perhaps those tests could be used to cover the issue I found while testing this end to end? ########## datafusion/core/src/datasource/file_format/write/mod.rs: ########## @@ -0,0 +1,305 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Module containing helper methods/traits related to enabling +//! write support for the various file formats + +use std::io::Error; +use std::mem; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use crate::datasource::file_format::file_compression_type::FileCompressionType; + +use crate::datasource::physical_plan::FileMeta; +use crate::error::Result; + +use arrow_array::RecordBatch; + +use datafusion_common::{exec_err, DataFusionError}; + +use async_trait::async_trait; +use bytes::Bytes; + +use futures::future::BoxFuture; +use futures::ready; +use futures::FutureExt; +use object_store::path::Path; +use object_store::{MultipartId, ObjectMeta, ObjectStore}; + +use tokio::io::AsyncWrite; + +pub(crate) mod demux; +pub(crate) mod orchestration; + +/// `AsyncPutWriter` is an object that facilitates asynchronous writing to object stores. +/// It is specifically designed for the `object_store` crate's `put` method and sends +/// whole bytes at once when the buffer is flushed. +pub struct AsyncPutWriter { + /// Object metadata + object_meta: ObjectMeta, + /// A shared reference to the object store + store: Arc<dyn ObjectStore>, + /// A buffer that stores the bytes to be sent + current_buffer: Vec<u8>, + /// Used for async handling in flush method + inner_state: AsyncPutState, +} + +impl AsyncPutWriter { + /// Constructor for the `AsyncPutWriter` object + pub fn new(object_meta: ObjectMeta, store: Arc<dyn ObjectStore>) -> Self { + Self { + object_meta, + store, + current_buffer: vec![], + // The writer starts out in buffering mode + inner_state: AsyncPutState::Buffer, + } + } + + /// Separate implementation function that unpins the [`AsyncPutWriter`] so + /// that partial borrows work correctly + fn poll_shutdown_inner( + &mut self, + cx: &mut Context<'_>, + ) -> Poll<std::result::Result<(), Error>> { + loop { + match &mut self.inner_state { + AsyncPutState::Buffer => { + // Convert the current buffer to bytes and take ownership of it + let bytes = Bytes::from(mem::take(&mut self.current_buffer)); + // Set the inner state to Put variant with the bytes + self.inner_state = AsyncPutState::Put { bytes } + } + AsyncPutState::Put { bytes } => { + // Send the bytes to the object store's put method + return Poll::Ready( + ready!(self + .store + .put(&self.object_meta.location, bytes.clone()) + .poll_unpin(cx)) + .map_err(Error::from), + ); + } + } + } + } +} + +/// An enum that represents the inner state of AsyncPut +enum AsyncPutState { + /// Building Bytes struct in this state + Buffer, + /// Data in the buffer is being sent to the object store + Put { bytes: Bytes }, +} + +impl AsyncWrite for AsyncPutWriter { + // Define the implementation of the AsyncWrite trait for the `AsyncPutWriter` struct + fn poll_write( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &[u8], + ) -> Poll<std::result::Result<usize, Error>> { + // Extend the current buffer with the incoming buffer + self.current_buffer.extend_from_slice(buf); + // Return a ready poll with the length of the incoming buffer + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush( + self: 
Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll<std::result::Result<(), Error>> { + // Return a ready poll with an empty result + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<std::result::Result<(), Error>> { + // Call the poll_shutdown_inner method to handle the actual sending of data to the object store + self.poll_shutdown_inner(cx) + } +} + +/// Stores data needed during abortion of MultiPart writers +pub(crate) struct MultiPart { + /// A shared reference to the object store + store: Arc<dyn ObjectStore>, + multipart_id: MultipartId, + location: Path, +} + +impl MultiPart { + /// Create a new `MultiPart` + pub fn new( + store: Arc<dyn ObjectStore>, + multipart_id: MultipartId, + location: Path, + ) -> Self { + Self { + store, + multipart_id, + location, + } + } +} + +pub(crate) enum AbortMode { + Put, + Append, + MultiPart(MultiPart), +} + +/// A wrapper struct with abort method and writer +pub(crate) struct AbortableWrite<W: AsyncWrite + Unpin + Send> { + writer: W, + mode: AbortMode, +} + +impl<W: AsyncWrite + Unpin + Send> AbortableWrite<W> { + /// Create a new `AbortableWrite` instance with the given writer, and write mode. + pub(crate) fn new(writer: W, mode: AbortMode) -> Self { + Self { writer, mode } + } + + /// handling of abort for different write modes + pub(crate) fn abort_writer(&self) -> Result<BoxFuture<'static, Result<()>>> { + match &self.mode { + AbortMode::Put => Ok(async { Ok(()) }.boxed()), + AbortMode::Append => exec_err!("Cannot abort in append mode"), + AbortMode::MultiPart(MultiPart { + store, + multipart_id, + location, + }) => { + let location = location.clone(); + let multipart_id = multipart_id.clone(); + let store = store.clone(); + Ok(Box::pin(async move { + store + .abort_multipart(&location, &multipart_id) + .await + .map_err(DataFusionError::ObjectStore) + })) + } + } + } +} + +impl<W: AsyncWrite + Unpin + Send> AsyncWrite for AbortableWrite<W> { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<std::result::Result<usize, Error>> { + Pin::new(&mut self.get_mut().writer).poll_write(cx, buf) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<std::result::Result<(), Error>> { + Pin::new(&mut self.get_mut().writer).poll_flush(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<std::result::Result<(), Error>> { + Pin::new(&mut self.get_mut().writer).poll_shutdown(cx) + } +} + +/// An enum that defines different file writer modes. +#[derive(Debug, Clone, Copy)] +pub enum FileWriterMode { + /// Data is appended to an existing file. + Append, + /// Data is written to a new file. + Put, + /// Data is written to a new file in multiple parts. + PutMultipart, +} +/// A trait that defines the methods required for a RecordBatch serializer. +#[async_trait] +pub trait BatchSerializer: Unpin + Send { Review Comment: once we get your outstanding PRs merged, it might be a good time to refactor the code into (even more) modules ########## datafusion/core/src/datasource/file_format/write/demux.rs: ########## @@ -0,0 +1,380 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Module containing helper methods/traits related to enabling +//! dividing input stream into multiple output files at execution time + +use std::collections::HashMap; + +use std::sync::Arc; + +use crate::datasource::listing::ListingTableUrl; + +use crate::error::Result; +use crate::physical_plan::SendableRecordBatchStream; + +use arrow_array::builder::UInt64Builder; +use arrow_array::cast::AsArray; +use arrow_array::{RecordBatch, StructArray}; +use arrow_schema::{DataType, Schema}; +use datafusion_common::cast::as_string_array; +use datafusion_common::DataFusionError; + +use datafusion_execution::TaskContext; + +use futures::StreamExt; +use object_store::path::Path; + +use rand::distributions::DistString; + +use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver, UnboundedSender}; +use tokio::task::JoinHandle; + +type RecordBatchReceiver = Receiver<RecordBatch>; +type DemuxedStreamReceiver = UnboundedReceiver<(Path, RecordBatchReceiver)>; + +/// Splits a single [SendableRecordBatchStream] into a dynamically determined +/// number of partitions at execution time. The partitions are determined by +/// factors known only at execution time, such as total number of rows and +/// partition column values. The demuxer task communicates to the caller +/// by sending channels over a channel. The inner channels send RecordBatches +/// which should be contained within the same output file. The outer channel +/// is used to send a dynamic number of inner channels, representing a dynamic +/// number of total output files. The caller is also responsible to monitor +/// the demux task for errors and abort accordingly. The single_file_ouput parameter +/// overrides all other settings to force only a single file to be written. +/// partition_by parameter will additionally split the input based on the unique +/// values of a specific column `<https://github.com/apache/arrow-datafusion/issues/7744>`` +/// ┌───────────┐ ┌────────────┐ ┌─────────────┐ +/// ┌──────▶ │ batch 1 ├────▶...──────▶│ Batch a │ │ Output File1│ +/// │ └───────────┘ └────────────┘ └─────────────┘ +/// │ +/// ┌──────────┐ │ ┌───────────┐ ┌────────────┐ ┌─────────────┐ +/// ┌───────────┐ ┌────────────┐ │ │ ├──────▶ │ batch a+1├────▶...──────▶│ Batch b │ │ Output File2│ +/// │ batch 1 ├────▶...──────▶│ Batch N ├─────▶│ Demux ├────────┤ ... 
└───────────┘ └────────────┘ └─────────────┘ +/// └───────────┘ └────────────┘ │ │ │ +/// └──────────┘ │ ┌───────────┐ ┌────────────┐ ┌─────────────┐ +/// └──────▶ │ batch d ├────▶...──────▶│ Batch n │ │ Output FileN│ +/// └───────────┘ └────────────┘ └─────────────┘ +pub(crate) fn start_demuxer_task( + input: SendableRecordBatchStream, + context: &Arc<TaskContext>, + partition_by: Option<Vec<(String, DataType)>>, + base_output_path: ListingTableUrl, + file_extension: String, + single_file_output: bool, +) -> (JoinHandle<Result<()>>, DemuxedStreamReceiver) { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let context = context.clone(); + let task: JoinHandle<std::result::Result<(), DataFusionError>> = match partition_by { + Some(parts) => { + // There could be an arbitrarily large number of parallel hive style partitions being written to, so we cannot + // bound this channel without risking a deadlock. + tokio::spawn(async move { + hive_style_partitions_demuxer( + tx, + input, + context, + parts, + base_output_path, + file_extension, + ) + .await + }) + } + None => tokio::spawn(async move { + row_count_demuxer( + tx, + input, + context, + base_output_path, + file_extension, + single_file_output, + ) + .await + }), + }; + + (task, rx) +} + +/// Dynamically partitions input stream to acheive desired maximum rows per file +async fn row_count_demuxer( + mut tx: UnboundedSender<(Path, Receiver<RecordBatch>)>, + mut input: SendableRecordBatchStream, + context: Arc<TaskContext>, + base_output_path: ListingTableUrl, + file_extension: String, + single_file_output: bool, +) -> Result<()> { + let exec_options = &context.session_config().options().execution; + let max_rows_per_file = exec_options.soft_max_rows_per_output_file; + let max_buffered_batches = exec_options.max_buffered_batches_per_output_file; + let mut total_rows_current_file = 0; + let mut part_idx = 0; + let write_id = + rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + + let mut tx_file = create_new_file_stream( + &base_output_path, + &write_id, + part_idx, + &file_extension, + single_file_output, + max_buffered_batches, + &mut tx, + )?; + part_idx += 1; + + while let Some(rb) = input.next().await.transpose()? 
{ + total_rows_current_file += rb.num_rows(); + tx_file.send(rb).await.map_err(|_| { + DataFusionError::Execution("Error sending RecordBatch to file stream!".into()) + })?; + + if total_rows_current_file >= max_rows_per_file && !single_file_output { + total_rows_current_file = 0; + tx_file = create_new_file_stream( + &base_output_path, + &write_id, + part_idx, + &file_extension, + single_file_output, + max_buffered_batches, + &mut tx, + )?; + part_idx += 1; + } + } + Ok(()) +} + +/// Helper for row count demuxer +fn generate_file_path( + base_output_path: &ListingTableUrl, + write_id: &str, + part_idx: usize, + file_extension: &str, + single_file_output: bool, +) -> Path { + if !single_file_output { + base_output_path + .prefix() + .child(format!("{}_{}.{}", write_id, part_idx, file_extension)) + } else { + base_output_path.prefix().to_owned() + } +} + +/// Helper for row count demuxer +fn create_new_file_stream( + base_output_path: &ListingTableUrl, + write_id: &str, + part_idx: usize, + file_extension: &str, + single_file_output: bool, + max_buffered_batches: usize, + tx: &mut UnboundedSender<(Path, Receiver<RecordBatch>)>, +) -> Result<Sender<RecordBatch>> { + let file_path = generate_file_path( + base_output_path, + write_id, + part_idx, + file_extension, + single_file_output, + ); + let (tx_file, rx_file) = mpsc::channel(max_buffered_batches / 2); + tx.send((file_path, rx_file)).map_err(|_| { + DataFusionError::Execution("Error sending RecordBatch to file stream!".into()) + })?; + Ok(tx_file) +} + +/// Splits an input stream based on the distinct values of a set of columns +/// Assumes standard hive style partition paths such as +/// /col1=val1/col2=val2/outputfile.parquet +async fn hive_style_partitions_demuxer( + tx: UnboundedSender<(Path, Receiver<RecordBatch>)>, + mut input: SendableRecordBatchStream, + context: Arc<TaskContext>, + partition_by: Vec<(String, DataType)>, + base_output_path: ListingTableUrl, + file_extension: String, +) -> Result<()> { + let write_id = + rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + + let exec_options = &context.session_config().options().execution; + let max_buffered_recordbatches = exec_options.max_buffered_batches_per_output_file; + + // To support non string partition col types, cast the type to &str first + let mut value_map: HashMap<Vec<String>, Sender<RecordBatch>> = HashMap::new(); + + while let Some(rb) = input.next().await.transpose()? { + // First compute partition key for each row of batch, e.g. (col1=val1, col2=val2, ...) + let all_partition_values = compute_partition_keys_by_row(&rb, &partition_by)?; + + // Next compute how the batch should be split up to take each distinct key to its own batch + let take_map = compute_take_arrays(&rb, all_partition_values); + + // Divide up the batch into distinct partition key batches and send each batch + for (part_key, mut builder) in take_map.into_iter() { + // Take method adapted from https://github.com/lancedb/lance/pull/1337/files + // TODO: upstream RecordBatch::take to arrow-rs Review Comment: I filed https://github.com/apache/arrow-rs/issues/4958 to upstream the issue to arrow-rs -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
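Picking up the two suggestions above for `demux.rs` (expressing `remove_partition_by_columns` with `filter()`, and unit testing the partition extraction / schema pruning code), one possible shape is sketched below. It is written against arrow's `RecordBatch`/`Schema` APIs as an illustration rather than taken from the PR, and unlike the index-slicing version it does not rely on the partition columns having been moved to the end of the schema.

```rust
use std::sync::Arc;

use arrow_array::{ArrayRef, RecordBatch};
use arrow_schema::{DataType, Schema};
use datafusion_common::Result;

/// Drops the partition columns by filtering (field, column) pairs instead of
/// slicing by index, so no assumption is made about column ordering.
fn remove_partition_by_columns(
    parted_batch: &RecordBatch,
    partition_by: &[(String, DataType)],
) -> Result<RecordBatch> {
    let is_partition_col = |name: &str| partition_by.iter().any(|(p, _)| p == name);

    let (fields, columns): (Vec<_>, Vec<_>) = parted_batch
        .schema()
        .fields()
        .iter()
        .zip(parted_batch.columns())
        .filter(|(field, _)| !is_partition_col(field.name()))
        .map(|(field, column)| (field.as_ref().clone(), column.clone()))
        .unzip();

    Ok(RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)?)
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, StringArray};

    #[test]
    fn drops_only_the_partition_columns() -> Result<()> {
        // Two-column batch where `a` is the (string) partition column.
        let batch = RecordBatch::try_from_iter(vec![
            ("c", Arc::new(Int64Array::from(vec![1, 2])) as ArrayRef),
            ("a", Arc::new(StringArray::from(vec!["x", "y"])) as ArrayRef),
        ])?;
        let pruned =
            remove_partition_by_columns(&batch, &[("a".to_string(), DataType::Utf8)])?;
        assert_eq!(pruned.num_columns(), 1);
        assert_eq!(pruned.schema().field(0).name(), "c");
        Ok(())
    }
}
```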
