This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new c1cdf46b fix: Input batch to ShuffleRepartitioner.insert_batch should not be larger than configured batch size (#523)
c1cdf46b is described below
commit c1cdf46baaec3d75e1f853662588a4ab4ce8cb5a
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Thu Jun 6 00:04:14 2024 -0700
fix: Input batch to ShuffleRepartitioner.insert_batch should not be larger than configured batch size (#523)
* fix: Input batch to ShuffleRepartitioner.insert_batch should not be larger than configured batch size
* Add test
* For review
---
core/src/execution/datafusion/shuffle_writer.rs | 64 +++++++++++++++++++++++--
1 file changed, 61 insertions(+), 3 deletions(-)
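
For context, the slicing pattern the new `insert_batch` relies on can be tried in isolation. The sketch below is not part of this commit: the Int32 column, the `batch_size` value of 4, and `main()` are illustrative stand-ins. The point it demonstrates is that `RecordBatch::slice` is zero-copy (each chunk shares the parent batch's buffers), so capping chunks at `batch_size` rows adds no data movement.

use std::sync::Arc;

use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;

fn main() {
    // Build a batch that is larger than the (illustrative) configured limit.
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let array = Int32Array::from((0..10).collect::<Vec<i32>>());
    let batch = RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap();

    // `batch_size` stands in for the configured batch size; 4 is arbitrary here.
    let batch_size = 4;
    let mut start = 0;
    while start < batch.num_rows() {
        let end = (start + batch_size).min(batch.num_rows());
        // Zero-copy view of rows [start, end): no buffers are duplicated.
        let chunk = batch.slice(start, end - start);
        assert!(chunk.num_rows() <= batch_size);
        println!("chunk starting at row {start}: {} rows", chunk.num_rows());
        start = end;
    }
}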
diff --git a/core/src/execution/datafusion/shuffle_writer.rs b/core/src/execution/datafusion/shuffle_writer.rs
index 96734097..99ac885b 100644
--- a/core/src/execution/datafusion/shuffle_writer.rs
+++ b/core/src/execution/datafusion/shuffle_writer.rs
@@ -575,6 +575,8 @@ struct ShuffleRepartitioner {
     hashes_buf: Vec<u32>,
     /// Partition ids for each row in the current batch
     partition_ids: Vec<u64>,
+    /// The configured batch size
+    batch_size: usize,
 }
 
 struct ShuffleRepartitionerMetrics {
@@ -642,17 +644,41 @@ impl ShuffleRepartitioner {
             reservation,
             hashes_buf,
             partition_ids,
+            batch_size,
         }
     }
 
+    /// Shuffles rows in input batch into corresponding partition buffer.
+    /// This function will slice input batch according to configured batch size and then
+    /// shuffle rows into corresponding partition buffer.
+    async fn insert_batch(&mut self, batch: RecordBatch) -> Result<()> {
+        let mut start = 0;
+        while start < batch.num_rows() {
+            let end = (start + self.batch_size).min(batch.num_rows());
+            let batch = batch.slice(start, end - start);
+            self.partitioning_batch(batch).await?;
+            start = end;
+        }
+        Ok(())
+    }
+
     /// Shuffles rows in input batch into corresponding partition buffer.
     /// This function first calculates hashes for rows and then takes rows in same
     /// partition as a record batch which is appended into partition buffer.
-    async fn insert_batch(&mut self, input: RecordBatch) -> Result<()> {
+    /// This should not be called directly. Use `insert_batch` instead.
+    async fn partitioning_batch(&mut self, input: RecordBatch) -> Result<()> {
         if input.num_rows() == 0 {
             // skip empty batch
             return Ok(());
         }
+
+        if input.num_rows() > self.batch_size {
+            return Err(DataFusionError::Internal(
+                "Input batch size exceeds configured batch size. Call `insert_batch` instead."
+                    .to_string(),
+            ));
+        }
+
         let _timer = self.metrics.baseline.elapsed_compute().timer();
 
         // NOTE: in shuffle writer exec, the output_rows metrics represents the
@@ -951,8 +977,7 @@ async fn external_shuffle(
     );
 
     while let Some(batch) = input.next().await {
-        let batch = batch?;
-        repartitioner.insert_batch(batch).await?;
+        repartitioner.insert_batch(batch?).await?;
     }
     repartitioner.shuffle_write().await
 }
@@ -1387,6 +1412,11 @@ impl RecordBatchStream for EmptyStream {
 #[cfg(test)]
 mod test {
     use super::*;
+    use datafusion::physical_plan::common::collect;
+    use datafusion::physical_plan::memory::MemoryExec;
+    use datafusion::prelude::SessionContext;
+    use datafusion_physical_expr::expressions::Column;
+    use tokio::runtime::Runtime;
 
     #[test]
     fn test_slot_size() {
@@ -1415,4 +1445,32 @@ mod test {
             assert_eq!(slot_size, *expected);
         })
     }
+
+    #[test]
+    fn test_insert_larger_batch() {
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));
+        let mut b = StringBuilder::new();
+        for i in 0..10000 {
+            b.append_value(format!("{i}"));
+        }
+        let array = b.finish();
+        let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap();
+
+        let mut batches = Vec::new();
+        batches.push(batch.clone());
+
+        let partitions = &[batches];
+        let exec = ShuffleWriterExec::try_new(
+            Arc::new(MemoryExec::try_new(partitions, batch.schema(), None).unwrap()),
+            Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 16),
+            "/tmp/data.out".to_string(),
+            "/tmp/index.out".to_string(),
+        )
+        .unwrap();
+        let ctx = SessionContext::new();
+        let task_ctx = ctx.task_ctx();
+        let stream = exec.execute(0, task_ctx).unwrap();
+        let rt = Runtime::new().unwrap();
+        rt.block_on(collect(stream)).unwrap();
+    }
 }
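
A note on the new test: it feeds a single 10,000-row batch through ShuffleWriterExec. Assuming the default DataFusion batch size of 8192 is in effect for the session, that input exceeds the configured limit, so the stream only drains cleanly if `insert_batch` slices it before `partitioning_batch`'s size guard sees it; collecting the stream without hitting the new Internal error is the effective assertion.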
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]