This is an automated email from the ASF dual-hosted git repository.
jorgecarleitao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 84e4b15 ARROW-10320 [Rust] [DataFusion] Migrated from batch iterators to batch streams.
84e4b15 is described below
commit 84e4b15889b7efda27115c51ad88cf20519269f4
Author: Jorge C. Leitao <[email protected]>
AuthorDate: Tue Oct 20 20:23:11 2020 +0200
ARROW-10320 [Rust] [DataFusion] Migrated from batch iterators to batch streams.
Recently, we introduced `async` to `execute`. This allowed us to parallelize across partitions by treating the execution of a part (of a partition) as the unit of work. However, a part is often a large task comprising multiple batches and steps.
This PR makes all our execution nodes return a dynamically-typed [`Stream<Item = ArrowResult<RecordBatch>>`](https://docs.rs/futures/0.3.6/futures/stream/trait.Stream.html) instead of an Iterator. For reference, a Stream is an iterator of futures; each item here is a future of a `Result<RecordBatch>`.
This effectively breaks the execution into smaller units of work (where an individual unit is an operation that returns a `Result<RecordBatch>`), allowing each task to chew on smaller bits.
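The `poll_next` shape the execution nodes now implement can be sketched in a dependency-free way. Below is a minimal stand-in for `futures::Stream` (same trait shape) over already-materialized items, driven by a hand-rolled no-op waker so it runs without tokio or futures; `SizedStream` and `drain` are illustrative names, not DataFusion's:

```rust
use std::pin::Pin;
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};

/// Minimal stand-in for `futures::Stream`: an iterator whose items are
/// produced through `poll_next` instead of `next`.
trait Stream {
    type Item;
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>;
}

/// A stream over already-materialized items: every item is immediately ready.
struct SizedStream {
    items: Vec<i32>,
    index: usize,
}

impl Stream for SizedStream {
    type Item = i32;
    fn poll_next(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<i32>> {
        Poll::Ready(if self.index < self.items.len() {
            self.index += 1;
            Some(self.items[self.index - 1])
        } else {
            None
        })
    }
}

/// A no-op waker: enough to drive a stream whose items are always ready.
fn noop_waker() -> Waker {
    fn clone(_: *const ()) -> RawWaker {
        RawWaker::new(std::ptr::null(), &VTABLE)
    }
    fn noop(_: *const ()) {}
    static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, noop, noop, noop);
    unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
}

/// Poll the stream to completion, collecting every ready item.
fn drain(mut s: SizedStream) -> Vec<i32> {
    let waker = noop_waker();
    let mut cx = Context::from_waker(&waker);
    let mut out = Vec::new();
    while let Poll::Ready(Some(v)) = Pin::new(&mut s).poll_next(&mut cx) {
        out.push(v);
    }
    out
}

fn main() {
    let s = SizedStream { items: vec![1, 2, 3], index: 0 };
    assert_eq!(drain(s), vec![1, 2, 3]);
}
```

An executor (tokio, in DataFusion's case) plays the role of `drain` here, except it can suspend at each `poll_next` and interleave other tasks; that is what makes each `Result<RecordBatch>` a schedulable unit of work.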
This adds `futures` as a direct dependency of DataFusion (it was previously only a dev-dependency).
This leads to a +2% regression in aggregate micro-benchmarks, which IMO is expected given that there is more context switching to handle. However, I expect (hope?) this to be independent of the number of batches and partitions, and to be offset by any async work we perform in our sources (readers) and sinks (writers).
I did not take the time to optimize: the primary goal was to implement the idea, have it compile and pass tests, and have some discussion about it. I expect that we should be able to replace some of our operations with `join_all`, thereby scheduling multiple tasks at once (instead of awaiting them one by one).
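As a rough sketch of that scheduling shape, the following uses `std::thread` as a dependency-free stand-in for `tokio::spawn` plus `join_all`; `load_partitions` and the numeric payloads are made up for illustration:

```rust
use std::thread;

/// One worker per partition, all started before any join, so the partitions
/// make progress concurrently instead of being processed one by one.
fn load_partitions(partition_count: usize) -> Vec<Vec<usize>> {
    let handles: Vec<_> = (0..partition_count)
        .map(|part_i| {
            thread::spawn(move || {
                // stand-in for collecting the partition's record batches
                vec![part_i * 10, part_i * 10 + 1]
            })
        })
        // this collect is needed so every worker is spawned before we join
        .collect();

    handles
        .into_iter()
        .map(|h| h.join().expect("worker panicked"))
        .collect()
}

fn main() {
    assert_eq!(load_partitions(2), vec![vec![0, 1], vec![10, 11]]);
}
```

The eager `collect` before joining is the same trick used in `MemTable::load` in this PR: without it, the lazy iterator would spawn and join each worker sequentially.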
<details>
<summary>Benchmarks</summary>
Aggregates:
```
aggregate_query_no_group_by 15 12
time: [782.11 us 784.12 us 786.19 us]
change: [+1.1421% +2.5252% +3.8963%] (p = 0.00 < 0.05)
Performance has regressed.
Found 10 outliers among 100 measurements (10.00%)
4 (4.00%) high mild
6 (6.00%) high severe
aggregate_query_group_by 15 12
time: [5.8751 ms 5.9206 ms 5.9679 ms]
change: [+1.0645% +2.0027% +3.0333%] (p = 0.00 < 0.05)
Performance has regressed.
aggregate_query_group_by_with_filter 15 12
time: [2.7652 ms 2.7983 ms 2.8340 ms]
change: [+0.3278% +1.8981% +3.3819%] (p = 0.02 < 0.05)
Change within noise threshold.
```
Math:
```
sqrt_20_9 time: [6.9844 ms 7.0582 ms 7.1363 ms]
change: [+0.0557% +1.5625% +3.0408%] (p = 0.05 < 0.05)
Change within noise threshold.
Found 3 outliers among 100 measurements (3.00%)
2 (2.00%) high mild
1 (1.00%) high severe
sqrt_20_12 time: [2.8350 ms 2.9504 ms 3.1204 ms]
change: [+3.8751% +8.2857% +14.671%] (p = 0.00 < 0.05)
Performance has regressed.
Found 5 outliers among 100 measurements (5.00%)
2 (2.00%) high mild
3 (3.00%) high severe
sqrt_22_12 time: [14.888 ms 15.242 ms 15.620 ms]
change: [+7.6388% +10.709% +14.098%] (p = 0.00 < 0.05)
Performance has regressed.
Found 5 outliers among 100 measurements (5.00%)
3 (3.00%) high mild
2 (2.00%) high severe
sqrt_22_14 time: [23.710 ms 23.817 ms 23.953 ms]
change: [-4.3401% -3.1824% -2.0952%] (p = 0.00 < 0.05)
Performance has improved.
Found 11 outliers among 100 measurements (11.00%)
5 (5.00%) high mild
6 (6.00%) high severe
```
</details>
I admit that this is a bit outside my comfort zone, and someone with more
experience in `async/await` could be of help.
IMO this would integrate very nicely with ARROW-10307 and ARROW-9275, I _think_ it would also help ARROW-9707, and I _think_ it also opens the possibility of consuming / producing batches from/to sources and sinks via arrow-flight / IPC.
Closes #8473 from jorgecarleitao/streams
Authored-by: Jorge C. Leitao <[email protected]>
Signed-off-by: Jorge C. Leitao <[email protected]>
---
rust/datafusion/Cargo.toml | 2 +-
rust/datafusion/src/datasource/memory.rs | 34 ++--
rust/datafusion/src/datasource/parquet.rs | 10 +-
rust/datafusion/src/execution/context.rs | 17 +-
rust/datafusion/src/physical_plan/common.rs | 36 ++--
rust/datafusion/src/physical_plan/csv.rs | 33 ++--
rust/datafusion/src/physical_plan/empty.rs | 10 +-
rust/datafusion/src/physical_plan/explain.rs | 8 +-
rust/datafusion/src/physical_plan/filter.rs | 102 +++++-----
.../datafusion/src/physical_plan/hash_aggregate.rs | 210 ++++++++++++---------
rust/datafusion/src/physical_plan/limit.rs | 25 +--
rust/datafusion/src/physical_plan/memory.rs | 32 ++--
rust/datafusion/src/physical_plan/merge.rs | 53 +++---
rust/datafusion/src/physical_plan/mod.rs | 21 ++-
rust/datafusion/src/physical_plan/parquet.rs | 38 ++--
rust/datafusion/src/physical_plan/planner.rs | 4 +-
rust/datafusion/src/physical_plan/projection.rs | 70 ++++---
rust/datafusion/src/physical_plan/sort.rs | 10 +-
rust/datafusion/tests/user_defined_plan.rs | 99 ++++++----
19 files changed, 469 insertions(+), 345 deletions(-)
diff --git a/rust/datafusion/Cargo.toml b/rust/datafusion/Cargo.toml
index f99756d..380e04c 100644
--- a/rust/datafusion/Cargo.toml
+++ b/rust/datafusion/Cargo.toml
@@ -55,13 +55,13 @@ paste = "0.1"
num_cpus = "1.13.0"
chrono = "0.4"
async-trait = "0.1.41"
+futures = "0.3"
tokio = { version = "0.2", features = ["macros", "rt-core", "rt-threaded"] }
[dev-dependencies]
rand = "0.7"
criterion = "0.3"
tempfile = "3"
-futures = "0.3"
prost = "0.6"
arrow-flight = { path = "../arrow-flight", version = "3.0.0-SNAPSHOT" }
tonic = "0.3"
diff --git a/rust/datafusion/src/datasource/memory.rs b/rust/datafusion/src/datasource/memory.rs
index b454315..28dc2a3 100644
--- a/rust/datafusion/src/datasource/memory.rs
+++ b/rust/datafusion/src/datasource/memory.rs
@@ -22,16 +22,14 @@
use std::sync::Arc;
use arrow::datatypes::{Field, Schema, SchemaRef};
-use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
use crate::datasource::TableProvider;
use crate::error::{ExecutionError, Result};
+use crate::physical_plan::common;
use crate::physical_plan::memory::MemoryExec;
use crate::physical_plan::ExecutionPlan;
-use tokio::task::{self, JoinHandle};
-
/// In-memory table
pub struct MemTable {
schema: SchemaRef,
@@ -63,19 +61,20 @@ impl MemTable {
let exec = t.scan(&None, batch_size)?;
let partition_count = exec.output_partitioning().partition_count();
- let mut tasks = Vec::with_capacity(partition_count);
- for partition in 0..partition_count {
- let exec = exec.clone();
- let task: JoinHandle<Result<Vec<RecordBatch>>> = task::spawn(async move {
- let it = exec.execute(partition).await?;
- it.into_iter()
- .collect::<ArrowResult<Vec<RecordBatch>>>()
- .map_err(ExecutionError::from)
- });
- tasks.push(task)
- }
+ let tasks = (0..partition_count)
+ .map(|part_i| {
+ let exec = exec.clone();
+ tokio::spawn(async move {
+ let stream = exec.execute(part_i).await?;
+ common::collect(stream).await
+ })
+ })
+ // this collect *is needed* so that the join below can
+ // switch between tasks
+ .collect::<Vec<_>>();
- let mut data: Vec<Vec<RecordBatch>> = Vec::with_capacity(partition_count);
+ let mut data: Vec<Vec<RecordBatch>> =
+ Vec::with_capacity(exec.output_partitioning().partition_count());
for task in tasks {
let result = task.await.expect("MemTable::load could not join task")?;
data.push(result);
@@ -135,6 +134,7 @@ mod tests {
use super::*;
use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field, Schema};
+ use futures::StreamExt;
#[tokio::test]
async fn test_with_projection() -> Result<()> {
@@ -158,7 +158,7 @@ mod tests {
// scan with projection
let exec = provider.scan(&Some(vec![2, 1]), 1024)?;
let mut it = exec.execute(0).await?;
- let batch2 = it.next().unwrap()?;
+ let batch2 = it.next().await.unwrap()?;
assert_eq!(2, batch2.schema().fields().len());
assert_eq!("c", batch2.schema().field(0).name());
assert_eq!("b", batch2.schema().field(1).name());
@@ -188,7 +188,7 @@ mod tests {
let exec = provider.scan(&None, 1024)?;
let mut it = exec.execute(0).await?;
- let batch1 = it.next().unwrap()?;
+ let batch1 = it.next().await.unwrap()?;
assert_eq!(3, batch1.schema().fields().len());
assert_eq!(3, batch1.num_columns());
diff --git a/rust/datafusion/src/datasource/parquet.rs b/rust/datafusion/src/datasource/parquet.rs
index 5308246..ca5ac6b 100644
--- a/rust/datafusion/src/datasource/parquet.rs
+++ b/rust/datafusion/src/datasource/parquet.rs
@@ -75,6 +75,7 @@ mod tests {
TimestampNanosecondArray,
};
use arrow::record_batch::RecordBatch;
+ use futures::StreamExt;
use std::env;
#[tokio::test]
@@ -82,16 +83,16 @@ mod tests {
let table = load_table("alltypes_plain.parquet")?;
let projection = None;
let exec = table.scan(&projection, 2)?;
- let it = exec.execute(0).await?;
+ let stream = exec.execute(0).await?;
- let count = it
- .into_iter()
+ let count = stream
.map(|batch| {
let batch = batch.unwrap();
assert_eq!(11, batch.num_columns());
assert_eq!(2, batch.num_rows());
})
- .count();
+ .fold(0, |acc, _| async move { acc + 1i32 })
+ .await;
// we should have seen 4 batches of 2 rows
assert_eq!(4, count);
@@ -305,6 +306,7 @@ mod tests {
let exec = table.scan(projection, 1024)?;
let mut it = exec.execute(0).await?;
it.next()
+ .await
.expect("should have received at least one batch")
.map_err(|e| e.into())
}
diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs
index a2dd6c9..ce7c159 100644
--- a/rust/datafusion/src/execution/context.rs
+++ b/rust/datafusion/src/execution/context.rs
@@ -23,9 +23,10 @@ use std::path::Path;
use std::string::String;
use std::sync::Arc;
+use futures::{StreamExt, TryStreamExt};
+
use arrow::csv;
use arrow::datatypes::*;
-use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
use crate::datasource::csv::CsvFile;
@@ -328,14 +329,14 @@ impl ExecutionContext {
0 => Ok(vec![]),
1 => {
let it = plan.execute(0).await?;
- common::collect(it)
+ common::collect(it).await
}
_ => {
// merge into a single partition
let plan = MergeExec::new(plan.clone());
// MergeExec must produce a single partition
assert_eq!(1, plan.output_partitioning().partition_count());
- common::collect(plan.execute(0).await?)
+ common::collect(plan.execute(0).await?).await
}
}
}
@@ -357,13 +358,13 @@ impl ExecutionContext {
let path = Path::new(&path).join(&filename);
let file = fs::File::create(path)?;
let mut writer = csv::Writer::new(file);
- let reader = plan.execute(i).await?;
+ let stream = plan.execute(i).await?;
- reader
- .into_iter()
+ stream
.map(|batch| writer.write(&batch?))
- .collect::<ArrowResult<_>>()
- .map_err(|e| ExecutionError::from(e))?
+ .try_collect()
+ .await
+ .map_err(|e| ExecutionError::from(e))?;
}
Ok(())
}
diff --git a/rust/datafusion/src/physical_plan/common.rs b/rust/datafusion/src/physical_plan/common.rs
index 2b1e76a..8bf9baf 100644
--- a/rust/datafusion/src/physical_plan/common.rs
+++ b/rust/datafusion/src/physical_plan/common.rs
@@ -20,8 +20,9 @@
use std::fs;
use std::fs::metadata;
use std::sync::Arc;
+use std::task::{Context, Poll};
-use super::SendableRecordBatchReader;
+use super::{RecordBatchStream, SendableRecordBatchStream};
use crate::error::{ExecutionError, Result};
use array::{
@@ -31,23 +32,24 @@ use array::{
};
use arrow::datatypes::{DataType, SchemaRef};
use arrow::error::Result as ArrowResult;
-use arrow::record_batch::{RecordBatch, RecordBatchReader};
+use arrow::record_batch::RecordBatch;
use arrow::{
array::{self, ArrayRef},
datatypes::Schema,
};
+use futures::{Stream, TryStreamExt};
-/// Iterator over a vector of record batches
-pub struct RecordBatchIterator {
+/// Stream of record batches
+pub struct SizedRecordBatchStream {
schema: SchemaRef,
batches: Vec<Arc<RecordBatch>>,
index: usize,
}
-impl RecordBatchIterator {
+impl SizedRecordBatchStream {
/// Create a new RecordBatchIterator
pub fn new(schema: SchemaRef, batches: Vec<Arc<RecordBatch>>) -> Self {
- RecordBatchIterator {
+ SizedRecordBatchStream {
schema,
index: 0,
batches,
@@ -55,29 +57,33 @@ impl RecordBatchIterator {
}
}
-impl Iterator for RecordBatchIterator {
+impl Stream for SizedRecordBatchStream {
type Item = ArrowResult<RecordBatch>;
- fn next(&mut self) -> Option<Self::Item> {
- if self.index < self.batches.len() {
+ fn poll_next(
+ mut self: std::pin::Pin<&mut Self>,
+ _: &mut Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
+ Poll::Ready(if self.index < self.batches.len() {
self.index += 1;
Some(Ok(self.batches[self.index - 1].as_ref().clone()))
} else {
None
- }
+ })
}
}
-impl RecordBatchReader for RecordBatchIterator {
+impl RecordBatchStream for SizedRecordBatchStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
-/// Create a vector of record batches from an iterator
-pub fn collect(it: SendableRecordBatchReader) -> Result<Vec<RecordBatch>> {
- it.into_iter()
- .collect::<ArrowResult<Vec<_>>>()
+/// Create a vector of record batches from a stream
+pub async fn collect(stream: SendableRecordBatchStream) -> Result<Vec<RecordBatch>> {
+ stream
+ .try_collect::<Vec<_>>()
+ .await
.map_err(|e| ExecutionError::from(e))
}
diff --git a/rust/datafusion/src/physical_plan/csv.rs b/rust/datafusion/src/physical_plan/csv.rs
index 32b7c26..c94f040 100644
--- a/rust/datafusion/src/physical_plan/csv.rs
+++ b/rust/datafusion/src/physical_plan/csv.rs
@@ -19,7 +19,9 @@
use std::any::Any;
use std::fs::File;
+use std::pin::Pin;
use std::sync::Arc;
+use std::task::{Context, Poll};
use crate::error::{ExecutionError, Result};
use crate::physical_plan::ExecutionPlan;
@@ -27,9 +29,10 @@ use crate::physical_plan::{common, Partitioning};
use arrow::csv;
use arrow::datatypes::{Schema, SchemaRef};
use arrow::error::Result as ArrowResult;
-use arrow::record_batch::{RecordBatch, RecordBatchReader};
+use arrow::record_batch::RecordBatch;
+use futures::Stream;
-use super::SendableRecordBatchReader;
+use super::{RecordBatchStream, SendableRecordBatchStream};
use async_trait::async_trait;
/// CSV file read option
@@ -218,8 +221,8 @@ impl ExecutionPlan for CsvExec {
}
}
- async fn execute(&self, partition: usize) -> Result<SendableRecordBatchReader> {
- Ok(Box::new(CsvIterator::try_new(
+ async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
+ Ok(Box::pin(CsvStream::try_new(
&self.filenames[partition],
self.schema.clone(),
self.has_header,
@@ -231,12 +234,12 @@ impl ExecutionPlan for CsvExec {
}
/// Iterator over batches
-struct CsvIterator {
+struct CsvStream {
/// Arrow CSV reader
reader: csv::Reader<File>,
}
-impl CsvIterator {
+impl CsvStream {
/// Create an iterator for a CSV file
pub fn try_new(
filename: &str,
@@ -260,15 +263,18 @@ impl CsvIterator {
}
}
-impl Iterator for CsvIterator {
+impl Stream for CsvStream {
type Item = ArrowResult<RecordBatch>;
- fn next(&mut self) -> Option<Self::Item> {
- self.reader.next()
+ fn poll_next(
+ mut self: Pin<&mut Self>,
+ _: &mut Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
+ Poll::Ready(self.reader.next())
}
}
-impl RecordBatchReader for CsvIterator {
+impl RecordBatchStream for CsvStream {
/// Get the schema
fn schema(&self) -> SchemaRef {
self.reader.schema()
@@ -279,6 +285,7 @@ impl RecordBatchReader for CsvIterator {
mod tests {
use super::*;
use crate::test::{aggr_test_schema, arrow_testdata_path};
+ use futures::StreamExt;
#[tokio::test]
async fn csv_exec_with_projection() -> Result<()> {
@@ -295,8 +302,8 @@ mod tests {
assert_eq!(13, csv.schema.fields().len());
assert_eq!(3, csv.projected_schema.fields().len());
assert_eq!(3, csv.schema().fields().len());
- let mut it = csv.execute(0).await?;
- let batch = it.next().unwrap()?;
+ let mut stream = csv.execute(0).await?;
+ let batch = stream.next().await.unwrap()?;
assert_eq!(3, batch.num_columns());
let batch_schema = batch.schema();
assert_eq!(3, batch_schema.fields().len());
@@ -318,7 +325,7 @@ mod tests {
assert_eq!(13, csv.projected_schema.fields().len());
assert_eq!(13, csv.schema().fields().len());
let mut it = csv.execute(0).await?;
- let batch = it.next().unwrap()?;
+ let batch = it.next().await.unwrap()?;
assert_eq!(13, batch.num_columns());
let batch_schema = batch.schema();
assert_eq!(13, batch_schema.fields().len());
diff --git a/rust/datafusion/src/physical_plan/empty.rs b/rust/datafusion/src/physical_plan/empty.rs
index 7e7019b..0d96479 100644
--- a/rust/datafusion/src/physical_plan/empty.rs
+++ b/rust/datafusion/src/physical_plan/empty.rs
@@ -21,11 +21,11 @@ use std::any::Any;
use std::sync::Arc;
use crate::error::{ExecutionError, Result};
-use crate::physical_plan::memory::MemoryIterator;
+use crate::physical_plan::memory::MemoryStream;
use crate::physical_plan::{Distribution, ExecutionPlan, Partitioning};
use arrow::datatypes::SchemaRef;
-use super::SendableRecordBatchReader;
+use super::SendableRecordBatchStream;
use async_trait::async_trait;
@@ -78,7 +78,7 @@ impl ExecutionPlan for EmptyExec {
}
}
- async fn execute(&self, partition: usize) -> Result<SendableRecordBatchReader> {
+ async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
// GlobalLimitExec has a single output partition
if 0 != partition {
return Err(ExecutionError::General(format!(
@@ -88,7 +88,7 @@ impl ExecutionPlan for EmptyExec {
}
let data = vec![];
- Ok(Box::new(MemoryIterator::try_new(
+ Ok(Box::pin(MemoryStream::try_new(
data,
self.schema.clone(),
None,
@@ -111,7 +111,7 @@ mod tests {
// we should have no results
let iter = empty.execute(0).await?;
- let batches = common::collect(iter)?;
+ let batches = common::collect(iter).await?;
assert!(batches.is_empty());
Ok(())
diff --git a/rust/datafusion/src/physical_plan/explain.rs b/rust/datafusion/src/physical_plan/explain.rs
index 34b824d..4a46ad6 100644
--- a/rust/datafusion/src/physical_plan/explain.rs
+++ b/rust/datafusion/src/physical_plan/explain.rs
@@ -23,13 +23,13 @@ use std::sync::Arc;
use crate::error::{ExecutionError, Result};
use crate::{
logical_plan::StringifiedPlan,
- physical_plan::{common::RecordBatchIterator, ExecutionPlan},
+ physical_plan::{common::SizedRecordBatchStream, ExecutionPlan},
};
use arrow::{array::StringBuilder, datatypes::SchemaRef, record_batch::RecordBatch};
use crate::physical_plan::Partitioning;
-use super::SendableRecordBatchReader;
+use super::SendableRecordBatchStream;
use async_trait::async_trait;
/// Explain execution plan operator. This operator contains the string
@@ -89,7 +89,7 @@ impl ExecutionPlan for ExplainExec {
}
}
- async fn execute(&self, partition: usize) -> Result<SendableRecordBatchReader> {
+ async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
if 0 != partition {
return Err(ExecutionError::General(format!(
"ExplainExec invalid partition {}",
@@ -113,7 +113,7 @@ impl ExecutionPlan for ExplainExec {
],
)?;
- Ok(Box::new(RecordBatchIterator::new(
+ Ok(Box::pin(SizedRecordBatchStream::new(
self.schema.clone(),
vec![Arc::new(record_batch)],
)))
diff --git a/rust/datafusion/src/physical_plan/filter.rs b/rust/datafusion/src/physical_plan/filter.rs
index b1e443a..e4ea3df 100644
--- a/rust/datafusion/src/physical_plan/filter.rs
+++ b/rust/datafusion/src/physical_plan/filter.rs
@@ -19,19 +19,23 @@
//! include in its output batches.
use std::any::Any;
+use std::pin::Pin;
use std::sync::Arc;
+use std::task::{Context, Poll};
-use super::SendableRecordBatchReader;
+use super::{RecordBatchStream, SendableRecordBatchStream};
use crate::error::{ExecutionError, Result};
use crate::physical_plan::{ExecutionPlan, Partitioning, PhysicalExpr};
use arrow::array::BooleanArray;
use arrow::compute::filter;
use arrow::datatypes::{DataType, SchemaRef};
use arrow::error::Result as ArrowResult;
-use arrow::record_batch::{RecordBatch, RecordBatchReader};
+use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
+use futures::stream::{Stream, StreamExt};
+
/// FilterExec evaluates a boolean predicate against all input batches to determine which rows to
/// include in its output batches.
#[derive(Debug)]
@@ -98,8 +102,8 @@ impl ExecutionPlan for FilterExec {
}
}
- async fn execute(&self, partition: usize) -> Result<SendableRecordBatchReader> {
- Ok(Box::new(FilterExecIter {
+ async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
+ Ok(Box::pin(FilterExecStream {
schema: self.input.schema().clone(),
predicate: self.predicate.clone(),
input: self.input.execute(partition).await?,
@@ -107,63 +111,67 @@ impl ExecutionPlan for FilterExec {
}
}
-/// The FilterExec iterator wraps the input iterator and applies the predicate expression to
+/// The FilterExec stream wraps the input stream and applies the predicate expression to
/// determine which rows to include in its output batches
-struct FilterExecIter {
+struct FilterExecStream {
/// Output schema, which is the same as the input schema for this operator
schema: SchemaRef,
/// The expression to filter on. This expression must evaluate to a boolean value.
predicate: Arc<dyn PhysicalExpr>,
/// The input partition to filter.
- input: SendableRecordBatchReader,
+ input: SendableRecordBatchStream,
+}
+
+fn batch_filter(
+ batch: &RecordBatch,
+ predicate: &Arc<dyn PhysicalExpr>,
+) -> ArrowResult<RecordBatch> {
+ predicate
+ .evaluate(&batch)
+ .map_err(ExecutionError::into_arrow_external_error)
+ .and_then(|array| {
+ array
+ .as_any()
+ .downcast_ref::<BooleanArray>()
+ .ok_or(
+ ExecutionError::InternalError(
+ "Filter predicate evaluated to non-boolean value".to_string(),
+ )
+ .into_arrow_external_error(),
+ )
+ // apply predicate to each column
+ .and_then(|predicate| {
+ batch
+ .columns()
+ .iter()
+ .map(|column| filter(column.as_ref(), predicate))
+ .collect::<ArrowResult<Vec<_>>>()
+ })
+ })
+ // build RecordBatch
+ .and_then(|columns| RecordBatch::try_new(batch.schema().clone(), columns))
}
-impl Iterator for FilterExecIter {
+impl Stream for FilterExecStream {
type Item = ArrowResult<RecordBatch>;
- /// Get the next batch
- fn next(&mut self) -> Option<ArrowResult<RecordBatch>> {
- match self.input.next() {
- Some(Ok(batch)) => {
- // evaluate the filter predicate to get a boolean array indicating which rows
- // to include in the output
- Some(
- self.predicate
- .evaluate(&batch)
- .map_err(ExecutionError::into_arrow_external_error)
- .and_then(|array| {
- array
- .as_any()
- .downcast_ref::<BooleanArray>()
- .ok_or(
- ExecutionError::InternalError(
- "Filter predicate evaluated to non-boolean value"
- .to_string(),
- )
- .into_arrow_external_error(),
- )
- // apply predicate to each column
- .and_then(|predicate| {
- batch
- .columns()
- .iter()
- .map(|column| filter(column.as_ref(), predicate))
- .collect::<ArrowResult<Vec<_>>>()
- })
- })
- // build RecordBatch
- .and_then(|columns| {
- RecordBatch::try_new(batch.schema().clone(), columns)
- }),
- )
- }
+ fn poll_next(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
+ self.input.poll_next_unpin(cx).map(|x| match x {
+ Some(Ok(batch)) => Some(batch_filter(&batch, &self.predicate)),
other => other,
- }
+ })
+ }
+
+ fn size_hint(&self) -> (usize, Option<usize>) {
+ // same number of record batches
+ self.input.size_hint()
}
}
-impl RecordBatchReader for FilterExecIter {
- /// Get the schema
+impl RecordBatchStream for FilterExecStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
diff --git a/rust/datafusion/src/physical_plan/hash_aggregate.rs b/rust/datafusion/src/physical_plan/hash_aggregate.rs
index 2860c3b..9329fcc 100644
--- a/rust/datafusion/src/physical_plan/hash_aggregate.rs
+++ b/rust/datafusion/src/physical_plan/hash_aggregate.rs
@@ -19,6 +19,10 @@
use std::any::Any;
use std::sync::Arc;
+use std::task::{Context, Poll};
+
+use futures::stream::{Stream, StreamExt, TryStreamExt};
+use futures::FutureExt;
use crate::error::{ExecutionError, Result};
use crate::physical_plan::{Accumulator, AggregateExpr};
@@ -27,7 +31,7 @@ use crate::physical_plan::{Distribution, ExecutionPlan,
Partitioning, PhysicalEx
use crate::arrow::array::PrimitiveArrayOps;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::error::Result as ArrowResult;
-use arrow::record_batch::{RecordBatch, RecordBatchReader};
+use arrow::record_batch::RecordBatch;
use arrow::{
array::{
ArrayRef, Int16Array, Int32Array, Int64Array, Int8Array, StringArray,
@@ -39,7 +43,8 @@ use arrow::{
use fnv::FnvHashMap;
use super::{
- common, expressions::Column, group_scalar::GroupByScalar, SendableRecordBatchReader,
+ common, expressions::Column, group_scalar::GroupByScalar, RecordBatchStream,
+ SendableRecordBatchStream,
};
use async_trait::async_trait;
@@ -145,19 +150,19 @@ impl ExecutionPlan for HashAggregateExec {
self.input.output_partitioning()
}
- async fn execute(&self, partition: usize) -> Result<SendableRecordBatchReader> {
+ async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
let input = self.input.execute(partition).await?;
let group_expr = self.group_expr.iter().map(|x| x.0.clone()).collect();
if self.group_expr.is_empty() {
- Ok(Box::new(HashAggregateIterator::new(
+ Ok(Box::pin(HashAggregateStream::new(
self.mode,
self.schema.clone(),
self.aggr_expr.clone(),
input,
)))
} else {
- Ok(Box::new(GroupedHashAggregateIterator::new(
+ Ok(Box::pin(GroupedHashAggregateStream::new(
self.mode.clone(),
self.schema.clone(),
group_expr,
@@ -210,12 +215,12 @@ Example: average
* Once all N record batches arrive, `merge` is performed, which builds a RecordBatch with N rows and 2 columns.
* Finally, `get_value` returns an array with one entry computed from the state
*/
-struct GroupedHashAggregateIterator {
+struct GroupedHashAggregateStream {
mode: AggregateMode,
schema: SchemaRef,
group_expr: Vec<Arc<dyn PhysicalExpr>>,
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
- input: SendableRecordBatchReader,
+ input: SendableRecordBatchStream,
finished: bool,
}
@@ -223,12 +228,12 @@ fn group_aggregate_batch(
mode: &AggregateMode,
group_expr: &Vec<Arc<dyn PhysicalExpr>>,
aggr_expr: &Vec<Arc<dyn AggregateExpr>>,
- batch: &RecordBatch,
- accumulators: &mut FnvHashMap<Vec<GroupByScalar>, (AccumulatorSet, Box<Vec<u32>>)>,
+ batch: RecordBatch,
+ mut accumulators: Accumulators,
aggregate_expressions: &Vec<Vec<Arc<dyn PhysicalExpr>>>,
-) -> Result<()> {
+) -> Result<Accumulators> {
// evaluate the grouping expressions
- let group_values = evaluate(group_expr, batch)?;
+ let group_values = evaluate(group_expr, &batch)?;
// evaluate the aggregation expressions.
// We could evaluate them after the `take`, but since we need to evaluate all
@@ -307,19 +312,20 @@ fn group_aggregate_batch(
// 2.5
.and(Ok(indices.clear()))
})
- .collect::<Result<()>>()
+ .collect::<Result<()>>()?;
+ Ok(accumulators)
}
-impl GroupedHashAggregateIterator {
- /// Create a new HashAggregateIterator
+impl GroupedHashAggregateStream {
+ /// Create a new HashAggregateStream
pub fn new(
mode: AggregateMode,
schema: SchemaRef,
group_expr: Vec<Arc<dyn PhysicalExpr>>,
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
- input: SendableRecordBatchReader,
+ input: SendableRecordBatchStream,
) -> Self {
- GroupedHashAggregateIterator {
+ GroupedHashAggregateStream {
mode,
schema,
group_expr,
@@ -331,72 +337,73 @@ impl GroupedHashAggregateIterator {
}
type AccumulatorSet = Vec<Box<dyn Accumulator>>;
+type Accumulators = FnvHashMap<Vec<GroupByScalar>, (AccumulatorSet, Box<Vec<u32>>)>;
-impl Iterator for GroupedHashAggregateIterator {
+impl Stream for GroupedHashAggregateStream {
type Item = ArrowResult<RecordBatch>;
- fn next(&mut self) -> Option<Self::Item> {
+ fn poll_next(
+ mut self: std::pin::Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
if self.finished {
- return None;
+ return Poll::Ready(None);
}
// return single batch
self.finished = true;
- let mode = &self.mode;
- let group_expr = &self.group_expr;
- let aggr_expr = &self.aggr_expr;
+ let mode = self.mode.clone();
+ let group_expr = self.group_expr.clone();
+ let aggr_expr = self.aggr_expr.clone();
+ let schema = self.schema.clone();
// the expressions to evaluate the batch, one vec of expressions per aggregation
let aggregate_expressions = match aggregate_expressions(&aggr_expr, &mode) {
Ok(e) => e,
- Err(e) => return Some(Err(ExecutionError::into_arrow_external_error(e))),
+ Err(e) => {
+ return Poll::Ready(Some(Err(ExecutionError::into_arrow_external_error(
+ e,
+ ))))
+ }
};
// mapping key -> (set of accumulators, indices of the key in the batch)
// * the indexes are updated at each row
// * the accumulators are updated at the end of each batch
// * the indexes are `clear`ed at the end of each batch
- let mut accumulators: FnvHashMap<
- Vec<GroupByScalar>,
- (AccumulatorSet, Box<Vec<u32>>),
- > = FnvHashMap::default();
+ //let mut accumulators: Accumulators = FnvHashMap::default();
// iterate over all input batches and update the accumulators
- match self
- .input
- .as_mut()
- .into_iter()
- .map(|batch| {
+ let future = self.input.as_mut().try_fold(
+ Accumulators::default(),
+ |accumulators, batch| async {
group_aggregate_batch(
&mode,
&group_expr,
&aggr_expr,
- &batch?,
- &mut accumulators,
+ batch,
+ accumulators,
&aggregate_expressions,
)
.map_err(ExecutionError::into_arrow_external_error)
- })
- .collect::<ArrowResult<()>>()
- {
- Err(e) => return Some(Err(e)),
- Ok(_) => {}
- }
+ },
+ );
- Some(
- create_batch_from_map(
- &self.mode,
- &accumulators,
- self.group_expr.len(),
- &self.schema,
- )
- .map_err(ExecutionError::into_arrow_external_error),
- )
+ let future = future.map(|maybe_accumulators| {
+ maybe_accumulators.map(|accumulators| {
+ create_batch_from_map(&mode, &accumulators, group_expr.len(), &schema)
+ })?
+ });
+
+ // send the stream to the heap, so that it outlives this function.
+ let mut combined = Box::pin(future.into_stream());
+
+ combined.poll_next_unpin(cx)
}
}
-impl RecordBatchReader for GroupedHashAggregateIterator {
+impl RecordBatchStream for GroupedHashAggregateStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
@@ -456,23 +463,23 @@ fn aggregate_expressions(
}
}
-struct HashAggregateIterator {
+struct HashAggregateStream {
mode: AggregateMode,
schema: SchemaRef,
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
- input: SendableRecordBatchReader,
+ input: SendableRecordBatchStream,
finished: bool,
}
-impl HashAggregateIterator {
- /// Create a new HashAggregateIterator
+impl HashAggregateStream {
+ /// Create a new HashAggregateStream
pub fn new(
mode: AggregateMode,
schema: SchemaRef,
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
- input: SendableRecordBatchReader,
+ input: SendableRecordBatchStream,
) -> Self {
- HashAggregateIterator {
+ HashAggregateStream {
mode,
schema,
aggr_expr,
@@ -485,9 +492,9 @@ impl HashAggregateIterator {
fn aggregate_batch(
mode: &AggregateMode,
batch: &RecordBatch,
- accumulators: &mut AccumulatorSet,
+ accumulators: AccumulatorSet,
expressions: &Vec<Vec<Arc<dyn PhysicalExpr>>>,
-) -> Result<()> {
+) -> Result<AccumulatorSet> {
// 1.1 iterate accumulators and respective expressions together
// 1.2 evaluate expressions
// 1.3 update / merge accumulators with the expressions' values
@@ -496,7 +503,7 @@ fn aggregate_batch(
accumulators
.into_iter()
.zip(expressions)
- .map(|(accum, expr)| {
+ .map(|(mut accum, expr)| {
// 1.2
let values = &expr
.iter()
@@ -505,62 +512,85 @@ fn aggregate_batch(
// 1.3
match mode {
- AggregateMode::Partial => accum.update_batch(values),
- AggregateMode::Final => accum.merge_batch(values),
+ AggregateMode::Partial => {
+ accum.update_batch(values)?;
+ }
+ AggregateMode::Final => {
+ accum.merge_batch(values)?;
+ }
}
+ Ok(accum)
})
- .collect::<Result<()>>()
+ .collect::<Result<Vec<_>>>()
}
-impl Iterator for HashAggregateIterator {
+impl Stream for HashAggregateStream {
type Item = ArrowResult<RecordBatch>;
- fn next(&mut self) -> Option<Self::Item> {
+ fn poll_next(
+ mut self: std::pin::Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
if self.finished {
- return None;
+ return Poll::Ready(None);
}
// return single batch
self.finished = true;
- let mut accumulators = match create_accumulators(&self.aggr_expr) {
+ let accumulators = match create_accumulators(&self.aggr_expr) {
Ok(e) => e,
- Err(e) => return Some(Err(ExecutionError::into_arrow_external_error(e))),
+ Err(e) => {
+ return Poll::Ready(Some(Err(ExecutionError::into_arrow_external_error(
+ e,
+ ))))
+ }
};
let expressions = match aggregate_expressions(&self.aggr_expr, &self.mode) {
Ok(e) => e,
- Err(e) => return Some(Err(ExecutionError::into_arrow_external_error(e))),
+ Err(e) => {
+ return Poll::Ready(Some(Err(ExecutionError::into_arrow_external_error(
+ e,
+ ))))
+ }
};
+ let expressions = Arc::new(expressions);
let mode = self.mode;
let schema = self.schema();
// 1 for each batch, update / merge accumulators with the expressions' values
- match self
+ // future is ready when all batches are computed
+ let future = self
.input
.as_mut()
- .into_iter()
- .map(|batch| {
- aggregate_batch(&mode, &batch?, &mut accumulators, &expressions)
+ .try_fold(
+ // pass the expressions on every fold to handle closures' mutability
+ (accumulators, expressions),
+ |(acc, expr), batch| async move {
+ aggregate_batch(&mode, &batch, acc, &expr)
+ .map_err(ExecutionError::into_arrow_external_error)
+ .map(|agg| (agg, expr))
+ },
+ )
+ // pick the accumulators (disregard the expressions)
+ .map(|e| e.map(|e| e.0));
+
+ let future = future.map(|maybe_accumulators| {
+ maybe_accumulators.map(|accumulators| {
+ // 2. convert values to a record batch
+ finalize_aggregation(&accumulators, &mode)
.map_err(ExecutionError::into_arrow_external_error)
- })
- .collect::<ArrowResult<()>>()
- {
- Err(e) => return Some(Err(e)),
- Ok(_) => {}
- }
+ .and_then(|columns| RecordBatch::try_new(schema.clone(), columns))
+ })?
+ });
- // 2 convert values to a record batch
- Some(
- finalize_aggregation(&accumulators, &mode)
- .map_err(ExecutionError::into_arrow_external_error)
- .and_then(|columns| RecordBatch::try_new(schema.clone(), columns)),
- )
+ Box::pin(future.into_stream()).poll_next_unpin(cx)
}
}
-impl RecordBatchReader for HashAggregateIterator {
+impl RecordBatchStream for HashAggregateStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
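`HashAggregateStream` is a "single-shot" stream: it does all its work on the first poll and then reports exhaustion via the `finished` flag. A std-only sketch of that shape, with a hand-rolled `Stream` trait and a no-op waker standing in for the `futures` machinery (all names here are illustrative):

```rust
use std::pin::Pin;
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};

// Minimal stand-in for futures::Stream: polled rather than iterated.
trait Stream {
    type Item;
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>;
}

// A stream that performs all of its work on the first poll and then ends,
// like HashAggregateStream's `finished` flag.
struct SingleBatch {
    finished: bool,
}

impl Stream for SingleBatch {
    type Item = Result<Vec<i64>, String>;
    fn poll_next(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        if self.finished {
            return Poll::Ready(None);
        }
        self.finished = true;
        // stand-in for "aggregate everything, emit one batch"
        Poll::Ready(Some(Ok(vec![1, 2, 3])))
    }
}

// No-op waker so the stream can be driven synchronously.
fn noop_waker() -> Waker {
    fn clone(_: *const ()) -> RawWaker {
        RawWaker::new(std::ptr::null(), &VTABLE)
    }
    fn noop(_: *const ()) {}
    static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, noop, noop, noop);
    unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
}

// Drain an always-ready stream into a Vec.
fn drain<S: Stream + Unpin>(mut s: S) -> Vec<S::Item> {
    let waker = noop_waker();
    let mut cx = Context::from_waker(&waker);
    let mut out = Vec::new();
    while let Poll::Ready(item) = Pin::new(&mut s).poll_next(&mut cx) {
        match item {
            Some(x) => out.push(x),
            None => break,
        }
    }
    out
}
```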
@@ -580,10 +610,10 @@ fn concatenate(arrays: Vec<Vec<ArrayRef>>) -> ArrowResult<Vec<ArrayRef>> {
/// Create a RecordBatch with all group keys and accumulators' states or values.
fn create_batch_from_map(
mode: &AggregateMode,
- accumulators: &FnvHashMap<Vec<GroupByScalar>, (AccumulatorSet, Box<Vec<u32>>)>,
+ accumulators: &Accumulators,
num_group_expr: usize,
output_schema: &Schema,
-) -> Result<RecordBatch> {
+) -> ArrowResult<RecordBatch> {
// 1. for each key
// 2. create single-row ArrayRef with all group expressions
// 3. create single-row ArrayRef with all aggregate states or values
@@ -618,7 +648,7 @@ fn create_batch_from_map(
Ok(groups)
})
// 4.
- .collect::<Result<Vec<Vec<ArrayRef>>>>()?;
+ .collect::<ArrowResult<Vec<Vec<ArrayRef>>>>()?;
let batch = if arrays.len() != 0 {
// 5.
@@ -787,7 +817,7 @@ mod tests {
input,
)?);
- let result = common::collect(partial_aggregate.execute(0).await?)?;
+ let result = common::collect(partial_aggregate.execute(0).await?).await?;
let keys = result[0]
.column(0)
@@ -826,7 +856,7 @@ mod tests {
merge,
)?);
- let result = common::collect(merged_aggregate.execute(0).await?)?;
+ let result = common::collect(merged_aggregate.execute(0).await?).await?;
assert_eq!(result.len(), 1);
let batch = &result[0];
diff --git a/rust/datafusion/src/physical_plan/limit.rs b/rust/datafusion/src/physical_plan/limit.rs
index 753cbf7..570e4d4 100644
--- a/rust/datafusion/src/physical_plan/limit.rs
+++ b/rust/datafusion/src/physical_plan/limit.rs
@@ -21,14 +21,15 @@ use std::any::Any;
use std::sync::Arc;
use crate::error::{ExecutionError, Result};
-use crate::physical_plan::memory::MemoryIterator;
+use crate::physical_plan::memory::MemoryStream;
use crate::physical_plan::{Distribution, ExecutionPlan, Partitioning};
use arrow::array::ArrayRef;
use arrow::compute::limit;
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
+use futures::StreamExt;
-use super::SendableRecordBatchReader;
+use super::SendableRecordBatchStream;
use async_trait::async_trait;
@@ -94,7 +95,7 @@ impl ExecutionPlan for GlobalLimitExec {
}
}
- async fn execute(&self, partition: usize) -> Result<SendableRecordBatchReader> {
+ async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
// GlobalLimitExec has a single output partition
if 0 != partition {
return Err(ExecutionError::General(format!(
@@ -111,8 +112,8 @@ impl ExecutionPlan for GlobalLimitExec {
}
let mut it = self.input.execute(0).await?;
- Ok(Box::new(MemoryIterator::try_new(
- collect_with_limit(&mut it, self.limit)?,
+ Ok(Box::pin(MemoryStream::try_new(
+ collect_with_limit(&mut it, self.limit).await?,
self.input.schema(),
None,
)?))
@@ -167,10 +168,10 @@ impl ExecutionPlan for LocalLimitExec {
}
}
- async fn execute(&self, _: usize) -> Result<SendableRecordBatchReader> {
+ async fn execute(&self, _: usize) -> Result<SendableRecordBatchStream> {
let mut it = self.input.execute(0).await?;
- Ok(Box::new(MemoryIterator::try_new(
- collect_with_limit(&mut it, self.limit)?,
+ Ok(Box::pin(MemoryStream::try_new(
+ collect_with_limit(&mut it, self.limit).await?,
self.input.schema(),
None,
)?))
@@ -190,14 +191,14 @@ pub fn truncate_batch(batch: &RecordBatch, n: usize) -> Result<RecordBatch> {
}
/// Create a vector of record batches from an iterator
-fn collect_with_limit(
- reader: &mut SendableRecordBatchReader,
+async fn collect_with_limit(
+ reader: &mut SendableRecordBatchStream,
limit: usize,
) -> Result<Vec<RecordBatch>> {
let mut count = 0;
let mut results: Vec<RecordBatch> = vec![];
loop {
- match reader.as_mut().next() {
+ match reader.as_mut().next().await {
Some(Ok(batch)) => {
let capacity = limit - count;
if batch.num_rows() <= capacity {
@@ -247,7 +248,7 @@ mod tests {
// the result should contain 4 batches (one per input partition)
let iter = limit.execute(0).await?;
- let batches = common::collect(iter)?;
+ let batches = common::collect(iter).await?;
// there should be a total of 100 rows
let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum();
diff --git a/rust/datafusion/src/physical_plan/memory.rs b/rust/datafusion/src/physical_plan/memory.rs
index 6219eb3..a3fc733 100644
--- a/rust/datafusion/src/physical_plan/memory.rs
+++ b/rust/datafusion/src/physical_plan/memory.rs
@@ -19,15 +19,16 @@
use std::any::Any;
use std::sync::Arc;
+use std::task::{Context, Poll};
+use super::{ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream};
use crate::error::{ExecutionError, Result};
-use crate::physical_plan::{ExecutionPlan, Partitioning};
use arrow::datatypes::SchemaRef;
use arrow::error::Result as ArrowResult;
-use arrow::record_batch::{RecordBatch, RecordBatchReader};
+use arrow::record_batch::RecordBatch;
-use super::SendableRecordBatchReader;
use async_trait::async_trait;
+use futures::Stream;
/// Execution plan for reading in-memory batches of data
#[derive(Debug)]
@@ -72,8 +73,8 @@ impl ExecutionPlan for MemoryExec {
)))
}
- async fn execute(&self, partition: usize) -> Result<SendableRecordBatchReader> {
- Ok(Box::new(MemoryIterator::try_new(
+ async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
+ Ok(Box::pin(MemoryStream::try_new(
self.partitions[partition].clone(),
self.schema.clone(),
self.projection.clone(),
@@ -97,7 +98,7 @@ impl MemoryExec {
}
/// Iterator over batches
-pub(crate) struct MemoryIterator {
+pub(crate) struct MemoryStream {
/// Vector of record batches
data: Vec<RecordBatch>,
/// Schema representing the data
@@ -108,7 +109,7 @@ pub(crate) struct MemoryIterator {
index: usize,
}
-impl MemoryIterator {
+impl MemoryStream {
/// Create an iterator for a vector of record batches
pub fn try_new(
data: Vec<RecordBatch>,
@@ -124,11 +125,14 @@ impl MemoryIterator {
}
}
-impl Iterator for MemoryIterator {
+impl Stream for MemoryStream {
type Item = ArrowResult<RecordBatch>;
- fn next(&mut self) -> Option<Self::Item> {
- if self.index < self.data.len() {
+ fn poll_next(
+ mut self: std::pin::Pin<&mut Self>,
+ _: &mut Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
+ Poll::Ready(if self.index < self.data.len() {
self.index += 1;
let batch = &self.data[self.index - 1];
// apply projection
@@ -141,11 +145,15 @@ impl Iterator for MemoryIterator {
}
} else {
None
- }
+ })
+ }
+
+ fn size_hint(&self) -> (usize, Option<usize>) {
+ (self.data.len(), Some(self.data.len()))
}
}
-impl RecordBatchReader for MemoryIterator {
+impl RecordBatchStream for MemoryStream {
/// Get the schema
fn schema(&self) -> SchemaRef {
self.schema.clone()
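`MemoryStream` is a cursor over pre-materialized batches: an index advances on every successful poll and `size_hint` is exact. The same cursor logic, sketched with `Iterator` for brevity (std-only; `VecBatches` is an illustrative name, and unlike the commit this sketch reports *remaining* items, which is what the `size_hint` contract asks for):

```rust
// Cursor-style emission over a pre-materialized Vec, like MemoryStream.
struct VecBatches {
    data: Vec<Vec<i32>>,
    index: usize,
}

impl Iterator for VecBatches {
    type Item = Vec<i32>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.index < self.data.len() {
            self.index += 1;
            Some(self.data[self.index - 1].clone())
        } else {
            None
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // exact bound: remaining batches (stays correct mid-iteration)
        let remaining = self.data.len() - self.index;
        (remaining, Some(remaining))
    }
}
```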
diff --git a/rust/datafusion/src/physical_plan/merge.rs b/rust/datafusion/src/physical_plan/merge.rs
index 7ce737c..8628551 100644
--- a/rust/datafusion/src/physical_plan/merge.rs
+++ b/rust/datafusion/src/physical_plan/merge.rs
@@ -19,17 +19,20 @@
//! into a single partition
use std::any::Any;
+use std::iter::Iterator;
use std::sync::Arc;
+use futures::future;
+
+use super::common;
use crate::error::{ExecutionError, Result};
-use crate::physical_plan::common::RecordBatchIterator;
+use crate::physical_plan::ExecutionPlan;
use crate::physical_plan::Partitioning;
-use crate::physical_plan::{common, ExecutionPlan};
-use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
+use arrow::{datatypes::SchemaRef, error::ArrowError};
-use super::SendableRecordBatchReader;
+use super::SendableRecordBatchStream;
use async_trait::async_trait;
use tokio;
@@ -81,7 +84,7 @@ impl ExecutionPlan for MergeExec {
}
}
- async fn execute(&self, partition: usize) -> Result<SendableRecordBatchReader> {
+ async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
// MergeExec produces a single partition
if 0 != partition {
return Err(ExecutionError::General(format!(
@@ -100,27 +103,29 @@ impl ExecutionPlan for MergeExec {
self.input.execute(0).await
}
_ => {
- let tasks = (0..input_partitions)
- .map(|part_i| {
- let input = self.input.clone();
- tokio::spawn(async move {
- let it = input.execute(part_i).await?;
- common::collect(it)
- })
+ let tasks = (0..input_partitions).map(|part_i| {
+ let input = self.input.clone();
+ tokio::spawn(async move {
+ let stream = input.execute(part_i).await?;
+ common::collect(stream).await
})
- // this collect *is needed* so that the join below can
- // switch between tasks
+ });
+
+ let results = future::try_join_all(tasks)
+ .await
+ .map_err(|e| ArrowError::from_external_error(Box::new(e)))?;
+
+ let combined_results = results
+ .into_iter()
+ .try_fold(Vec::<RecordBatch>::new(), |mut acc, maybe_batches| {
+ acc.append(&mut maybe_batches?);
+ Result::Ok(acc)
+ })?
+ .into_iter()
+ .map(|x| Arc::new(x))
.collect::<Vec<_>>();
- let mut combined_results: Vec<Arc<RecordBatch>> = vec![];
- for task in tasks {
- let result = task.await.unwrap()?;
- for batch in &result {
- combined_results.push(Arc::new(batch.clone()));
- }
- }
-
- Ok(Box::new(RecordBatchIterator::new(
+ Ok(Box::pin(common::SizedRecordBatchStream::new(
self.input.schema(),
combined_results,
)))
@@ -158,7 +163,7 @@ mod tests {
// the result should contain 4 batches (one per input partition)
let iter = merge.execute(0).await?;
- let batches = common::collect(iter)?;
+ let batches = common::collect(iter).await?;
assert_eq!(batches.len(), num_partitions);
// there should be a total of 100 rows
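The new `MergeExec` body spawns one task per partition, joins them all, and flattens the per-partition results with `try_fold`, propagating the first error. A std-only sketch of the same shape using threads in place of `tokio::spawn`/`try_join_all` (`merge_partitions` is an illustrative name):

```rust
use std::thread;

// One worker per partition, started before any join so they run concurrently;
// results are flattened with try_fold, propagating the first error.
fn merge_partitions(parts: Vec<Vec<i32>>) -> Result<Vec<i32>, String> {
    let handles: Vec<_> = parts
        .into_iter()
        .map(|part| thread::spawn(move || -> Result<Vec<i32>, String> { Ok(part) }))
        .collect(); // this collect *is needed* so all workers start first

    handles
        .into_iter()
        .map(|h| h.join().map_err(|_| "worker panicked".to_string())?)
        .try_fold(Vec::new(), |mut acc, maybe_batches| {
            acc.append(&mut maybe_batches?);
            Ok::<_, String>(acc)
        })
}
```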
diff --git a/rust/datafusion/src/physical_plan/mod.rs b/rust/datafusion/src/physical_plan/mod.rs
index 1d6c46a..a2bc3be 100644
--- a/rust/datafusion/src/physical_plan/mod.rs
+++ b/rust/datafusion/src/physical_plan/mod.rs
@@ -17,19 +17,32 @@
//! Traits for physical query plan, supporting parallel execution for partitioned relations.
-use std::any::Any;
use std::fmt::{Debug, Display};
use std::sync::Arc;
+use std::{any::Any, pin::Pin};
use crate::execution::context::ExecutionContextState;
use crate::logical_plan::LogicalPlan;
use crate::{error::Result, scalar::ScalarValue};
use arrow::datatypes::{DataType, Schema, SchemaRef};
-use arrow::record_batch::{RecordBatch, RecordBatchReader};
+use arrow::error::Result as ArrowResult;
+use arrow::record_batch::RecordBatch;
use arrow::{array::ArrayRef, datatypes::Field};
use async_trait::async_trait;
-type SendableRecordBatchReader = Box<dyn RecordBatchReader + Send>;
+use futures::stream::Stream;
+
+/// Trait for types that stream [arrow::record_batch::RecordBatch]
+pub trait RecordBatchStream: Stream<Item = ArrowResult<RecordBatch>> {
+ /// Returns the schema of this `RecordBatchStream`.
+ ///
+ /// Implementations of this trait should guarantee that all `RecordBatch`es returned by this
+ /// stream have the same schema as returned from this method.
+ fn schema(&self) -> SchemaRef;
+}
+
+/// A pinned, sendable stream of record batches.
+pub type SendableRecordBatchStream = Pin<Box<dyn RecordBatchStream + Send>>;
/// Physical query planner that converts a `LogicalPlan` to an
/// `ExecutionPlan` suitable for execution.
@@ -68,7 +81,7 @@ pub trait ExecutionPlan: Debug + Send + Sync {
) -> Result<Arc<dyn ExecutionPlan>>;
/// creates an iterator
- async fn execute(&self, partition: usize) -> Result<SendableRecordBatchReader>;
+ async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream>;
}
/// Partitioning schemes supported by operators.
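The reason `RecordBatchStream` exists as a supertrait (rather than `dyn Stream + SchemaProvider`) is that a trait object may name only one non-auto trait, so the `schema()` accessor must be fused into the stream trait itself. A std-only sketch of the pattern, with `Iterator` standing in for `Stream` and illustrative names throughout; the real alias additionally pins (`Pin<Box<dyn RecordBatchStream + Send>>`) because `Stream::poll_next` takes `Pin<&mut Self>`:

```rust
// The schema accessor is folded into a supertrait so a single trait object
// can carry both the item stream and the schema.
trait RecordBatchIter: Iterator<Item = Result<u32, String>> {
    fn schema(&self) -> &'static str;
}

// Boxed, sendable handle, analogous to SendableRecordBatchStream.
type SendableRecordBatchIter = Box<dyn RecordBatchIter + Send>;

struct Countdown {
    n: u32,
}

impl Iterator for Countdown {
    type Item = Result<u32, String>;
    fn next(&mut self) -> Option<Self::Item> {
        if self.n == 0 {
            None
        } else {
            self.n -= 1;
            Some(Ok(self.n))
        }
    }
}

impl RecordBatchIter for Countdown {
    fn schema(&self) -> &'static str {
        "u32"
    }
}

fn make_stream() -> SendableRecordBatchIter {
    Box::new(Countdown { n: 3 })
}
```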
diff --git a/rust/datafusion/src/physical_plan/parquet.rs b/rust/datafusion/src/physical_plan/parquet.rs
index b15c5a3..946f015 100644
--- a/rust/datafusion/src/physical_plan/parquet.rs
+++ b/rust/datafusion/src/physical_plan/parquet.rs
@@ -21,22 +21,24 @@ use std::any::Any;
use std::fs::File;
use std::rc::Rc;
use std::sync::Arc;
+use std::task::{Context, Poll};
use std::{fmt, thread};
+use super::{RecordBatchStream, SendableRecordBatchStream};
use crate::error::{ExecutionError, Result};
use crate::physical_plan::ExecutionPlan;
use crate::physical_plan::{common, Partitioning};
use arrow::datatypes::{Schema, SchemaRef};
use arrow::error::{ArrowError, Result as ArrowResult};
-use arrow::record_batch::{RecordBatch, RecordBatchReader};
+use arrow::record_batch::RecordBatch;
use parquet::file::reader::SerializedFileReader;
use crossbeam::channel::{bounded, Receiver, RecvError, Sender};
use fmt::Debug;
use parquet::arrow::{ArrowReader, ParquetFileArrowReader};
-use super::SendableRecordBatchReader;
use async_trait::async_trait;
+use futures::stream::Stream;
/// Execution plan for scanning a Parquet file
#[derive(Debug, Clone)]
@@ -125,7 +127,7 @@ impl ExecutionPlan for ParquetExec {
}
}
- async fn execute(&self, partition: usize) -> Result<SendableRecordBatchReader> {
+ async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
// because the parquet implementation is not thread-safe, it is necessary to execute
// on a thread and communicate with channels
let (response_tx, response_rx): (
@@ -143,12 +145,10 @@ impl ExecutionPlan for ParquetExec {
}
});
- let iterator = Box::new(ParquetIterator {
+ Ok(Box::pin(ParquetStream {
schema: self.schema.clone(),
response_rx,
- });
-
- Ok(iterator)
+ }))
}
}
@@ -197,24 +197,27 @@ fn read_file(
Ok(())
}
-struct ParquetIterator {
+struct ParquetStream {
schema: SchemaRef,
response_rx: Receiver<Option<ArrowResult<RecordBatch>>>,
}
-impl Iterator for ParquetIterator {
+impl Stream for ParquetStream {
type Item = ArrowResult<RecordBatch>;
- fn next(&mut self) -> Option<Self::Item> {
+ fn poll_next(
+ self: std::pin::Pin<&mut Self>,
+ _: &mut Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
match self.response_rx.recv() {
- Ok(batch) => batch,
+ Ok(batch) => Poll::Ready(batch),
// RecvError means receiver has exited and closed the channel
- Err(RecvError) => None,
+ Err(RecvError) => Poll::Ready(None),
}
}
}
-impl RecordBatchReader for ParquetIterator {
+impl RecordBatchStream for ParquetStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
@@ -223,6 +226,7 @@ impl RecordBatchReader for ParquetIterator {
#[cfg(test)]
mod tests {
use super::*;
+ use futures::StreamExt;
use std::env;
#[tokio::test]
@@ -234,7 +238,7 @@ mod tests {
assert_eq!(parquet_exec.output_partitioning().partition_count(), 1);
let mut results = parquet_exec.execute(0).await?;
- let batch = results.next().unwrap()?;
+ let batch = results.next().await.unwrap()?;
assert_eq!(8, batch.num_rows());
assert_eq!(3, batch.num_columns());
@@ -244,13 +248,13 @@ mod tests {
schema.fields().iter().map(|f| f.name().as_str()).collect();
assert_eq!(vec!["id", "bool_col", "tinyint_col"], field_names);
- let batch = results.next();
+ let batch = results.next().await;
assert!(batch.is_none());
- let batch = results.next();
+ let batch = results.next().await;
assert!(batch.is_none());
- let batch = results.next();
+ let batch = results.next().await;
assert!(batch.is_none());
Ok(())
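`ParquetStream` bridges a worker thread to the consumer over a channel, treating `RecvError` (sender dropped) as end-of-stream. A std-only sketch of that bridge (simplified: the real code sends `Option<ArrowResult<RecordBatch>>`; `read_all` is an illustrative name):

```rust
use std::sync::mpsc::{channel, RecvError};
use std::thread;

// A producer thread sends batches over a channel; the consumer treats a
// RecvError (sender dropped) as end-of-stream, like ParquetStream does.
fn read_all(batches: Vec<Result<Vec<i32>, String>>) -> Vec<Result<Vec<i32>, String>> {
    let (tx, rx) = channel();
    thread::spawn(move || {
        for batch in batches {
            if tx.send(batch).is_err() {
                return; // receiver hung up
            }
        }
        // `tx` dropped here closes the channel
    });

    let mut out = Vec::new();
    loop {
        match rx.recv() {
            Ok(batch) => out.push(batch),
            Err(RecvError) => break, // channel closed: no more batches
        }
    }
    out
}
```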
diff --git a/rust/datafusion/src/physical_plan/planner.rs b/rust/datafusion/src/physical_plan/planner.rs
index c4ae2dc..cfc8858 100644
--- a/rust/datafusion/src/physical_plan/planner.rs
+++ b/rust/datafusion/src/physical_plan/planner.rs
@@ -553,7 +553,7 @@ mod tests {
use crate::physical_plan::{csv::CsvReadOptions, expressions, Partitioning};
use crate::{
logical_plan::{col, lit, sum, LogicalPlanBuilder},
- physical_plan::SendableRecordBatchReader,
+ physical_plan::SendableRecordBatchStream,
};
use crate::{prelude::ExecutionConfig, test::arrow_testdata_path};
use arrow::datatypes::{DataType, Field, SchemaRef};
@@ -804,7 +804,7 @@ mod tests {
unimplemented!("NoOpExecutionPlan::with_new_children");
}
- async fn execute(&self, _partition: usize) -> Result<SendableRecordBatchReader> {
+ async fn execute(&self, _partition: usize) -> Result<SendableRecordBatchStream> {
unimplemented!("NoOpExecutionPlan::execute");
}
}
diff --git a/rust/datafusion/src/physical_plan/projection.rs b/rust/datafusion/src/physical_plan/projection.rs
index bb8f853..895148d 100644
--- a/rust/datafusion/src/physical_plan/projection.rs
+++ b/rust/datafusion/src/physical_plan/projection.rs
@@ -21,17 +21,22 @@
//! projection expressions.
use std::any::Any;
+use std::pin::Pin;
use std::sync::Arc;
+use std::task::{Context, Poll};
use crate::error::{ExecutionError, Result};
use crate::physical_plan::{ExecutionPlan, Partitioning, PhysicalExpr};
use arrow::datatypes::{Field, Schema, SchemaRef};
use arrow::error::Result as ArrowResult;
-use arrow::record_batch::{RecordBatch, RecordBatchReader};
+use arrow::record_batch::RecordBatch;
-use super::SendableRecordBatchReader;
+use super::{RecordBatchStream, SendableRecordBatchStream};
use async_trait::async_trait;
+use futures::stream::Stream;
+use futures::stream::StreamExt;
+
/// Execution plan for a projection
#[derive(Debug)]
pub struct ProjectionExec {
@@ -108,8 +113,8 @@ impl ExecutionPlan for ProjectionExec {
}
}
- async fn execute(&self, partition: usize) -> Result<SendableRecordBatchReader> {
- Ok(Box::new(ProjectionIterator {
+ async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
+ Ok(Box::pin(ProjectionStream {
schema: self.schema.clone(),
expr: self.expr.iter().map(|x| x.0.clone()).collect(),
input: self.input.execute(partition).await?,
@@ -117,34 +122,48 @@ impl ExecutionPlan for ProjectionExec {
}
}
+fn batch_project(
+ batch: &RecordBatch,
+ expressions: &Vec<Arc<dyn PhysicalExpr>>,
+ schema: &SchemaRef,
+) -> ArrowResult<RecordBatch> {
+ expressions
+ .iter()
+ .map(|expr| expr.evaluate(&batch))
+ .collect::<Result<Vec<_>>>()
+ .map_or_else(
+ |e| Err(ExecutionError::into_arrow_external_error(e)),
+ |arrays| RecordBatch::try_new(schema.clone(), arrays),
+ )
+}
+
/// Projection iterator
-struct ProjectionIterator {
+struct ProjectionStream {
schema: SchemaRef,
expr: Vec<Arc<dyn PhysicalExpr>>,
- input: SendableRecordBatchReader,
+ input: SendableRecordBatchStream,
}
-impl Iterator for ProjectionIterator {
+impl Stream for ProjectionStream {
type Item = ArrowResult<RecordBatch>;
- fn next(&mut self) -> Option<Self::Item> {
- match self.input.next() {
- Some(Ok(batch)) => Some(
- self.expr
- .iter()
- .map(|expr| expr.evaluate(&batch))
- .collect::<Result<Vec<_>>>()
- .map_or_else(
- |e| Err(ExecutionError::into_arrow_external_error(e)),
- |arrays| RecordBatch::try_new(self.schema.clone(), arrays),
- ),
- ),
+ fn poll_next(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
+ self.input.poll_next_unpin(cx).map(|x| match x {
+ Some(Ok(batch)) => Some(batch_project(&batch, &self.expr, &self.schema)),
other => other,
- }
+ })
+ }
+
+ fn size_hint(&self) -> (usize, Option<usize>) {
+ // same number of record batches
+ self.input.size_hint()
}
}
-impl RecordBatchReader for ProjectionIterator {
+impl RecordBatchStream for ProjectionStream {
/// Get the schema
fn schema(&self) -> SchemaRef {
self.schema.clone()
@@ -158,6 +177,7 @@ mod tests {
use crate::physical_plan::csv::{CsvExec, CsvReadOptions};
use crate::physical_plan::expressions::col;
use crate::test;
+ use futures::future;
#[tokio::test]
async fn project_first_column() -> Result<()> {
@@ -177,16 +197,16 @@ mod tests {
let mut row_count = 0;
for partition in 0..projection.output_partitioning().partition_count() {
partition_count += 1;
- let iterator = projection.execute(partition).await?;
+ let stream = projection.execute(partition).await?;
- row_count += iterator
- .into_iter()
+ row_count += stream
.map(|batch| {
let batch = batch.unwrap();
assert_eq!(1, batch.num_columns());
batch.num_rows()
})
- .sum::<usize>();
+ .fold(0, |acc, x| future::ready(acc + x))
+ .await;
}
assert_eq!(partitions, partition_count);
assert_eq!(100, row_count);
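The heart of `ProjectionStream::poll_next` is a pass-through map: only `Ok` batches are projected, while errors and end-of-stream (`None`) flow through unchanged. A std-only sketch of that item-mapping shape (`map_poll_item` and the doubling "projection" are illustrative):

```rust
// Stand-in projection: evaluate an expression against each column value.
fn project(batch: &[i32]) -> Result<Vec<i32>, String> {
    Ok(batch.iter().map(|x| x * 2).collect())
}

// Apply the projection only to Ok batches; errors and end-of-stream (None)
// pass through untouched, mirroring ProjectionStream's poll_next body.
fn map_poll_item(
    item: Option<Result<Vec<i32>, String>>,
) -> Option<Result<Vec<i32>, String>> {
    match item {
        Some(Ok(batch)) => Some(project(&batch)),
        other => other,
    }
}
```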
diff --git a/rust/datafusion/src/physical_plan/sort.rs b/rust/datafusion/src/physical_plan/sort.rs
index 7c00cc5..61d8bd8 100644
--- a/rust/datafusion/src/physical_plan/sort.rs
+++ b/rust/datafusion/src/physical_plan/sort.rs
@@ -26,12 +26,12 @@ use arrow::compute::{concat, lexsort_to_indices, take, SortColumn, TakeOptions};
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
+use super::SendableRecordBatchStream;
use crate::error::{ExecutionError, Result};
-use crate::physical_plan::common::RecordBatchIterator;
+use crate::physical_plan::common::SizedRecordBatchStream;
use crate::physical_plan::expressions::PhysicalSortExpr;
use crate::physical_plan::{common, Distribution, ExecutionPlan, Partitioning};
-use super::SendableRecordBatchReader;
use async_trait::async_trait;
/// Sort execution plan
@@ -100,7 +100,7 @@ impl ExecutionPlan for SortExec {
}
}
- async fn execute(&self, partition: usize) -> Result<SendableRecordBatchReader> {
+ async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
if 0 != partition {
return Err(ExecutionError::General(format!(
"SortExec invalid partition {}",
@@ -115,7 +115,7 @@ impl ExecutionPlan for SortExec {
));
}
let it = self.input.execute(0).await?;
- let batches = common::collect(it)?;
+ let batches = common::collect(it).await?;
// combine all record batches into one for each column
let combined_batch = RecordBatch::try_new(
@@ -164,7 +164,7 @@ impl ExecutionPlan for SortExec {
.collect::<Result<Vec<ArrayRef>>>()?,
)?;
- Ok(Box::new(RecordBatchIterator::new(
+ Ok(Box::pin(SizedRecordBatchStream::new(
self.schema(),
vec![Arc::new(sorted_batch)],
)))
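`SortExec` is collect-then-sort: it now awaits `common::collect`, concatenates everything into one batch, computes a sort permutation with `lexsort_to_indices`, and gathers every column through it with `take`. The permutation-then-gather step, sketched std-only on plain slices (`sort_batch` is an illustrative name):

```rust
// Collect-then-sort: build a permutation from the key column, then gather
// every column through it (the lexsort_to_indices + take pattern).
fn sort_batch(keys: &[i64], names: &[&str]) -> (Vec<i64>, Vec<String>) {
    let mut indices: Vec<usize> = (0..keys.len()).collect();
    indices.sort_by_key(|&i| keys[i]); // lexsort_to_indices analogue
    let sorted_keys = indices.iter().map(|&i| keys[i]).collect();
    let sorted_names = indices.iter().map(|&i| names[i].to_string()).collect();
    (sorted_keys, sorted_names) // one combined, sorted "batch"
}
```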
diff --git a/rust/datafusion/tests/user_defined_plan.rs b/rust/datafusion/tests/user_defined_plan.rs
index 71957f2..6c50459 100644
--- a/rust/datafusion/tests/user_defined_plan.rs
+++ b/rust/datafusion/tests/user_defined_plan.rs
@@ -58,11 +58,13 @@
//! N elements, reducing the total amount of required buffer memory.
//!
+use futures::{FutureExt, Stream, StreamExt, TryStreamExt};
+
use arrow::{
array::{Int64Array, PrimitiveArrayOps, StringArray},
datatypes::SchemaRef,
error::ArrowError,
- record_batch::{RecordBatch, RecordBatchReader},
+ record_batch::RecordBatch,
util::pretty::pretty_format_batches,
};
use datafusion::{
@@ -73,11 +75,13 @@ use datafusion::{
optimizer::{optimizer::OptimizerRule, utils::optimize_explain},
physical_plan::{
planner::{DefaultPhysicalPlanner, ExtensionPlanner},
- Distribution, ExecutionPlan, Partitioning, PhysicalPlanner,
+ Distribution, ExecutionPlan, Partitioning, PhysicalPlanner, RecordBatchStream,
+ SendableRecordBatchStream,
},
prelude::{ExecutionConfig, ExecutionContext},
};
use fmt::Debug;
+use std::task::{Context, Poll};
use std::{any::Any, collections::BTreeMap, fmt, sync::Arc};
use async_trait::async_trait;
@@ -392,10 +396,7 @@ impl ExecutionPlan for TopKExec {
}
/// Execute one partition and return an iterator over RecordBatch
- async fn execute(
- &self,
- partition: usize,
- ) -> Result<Box<dyn RecordBatchReader + Send>> {
+ async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
if 0 != partition {
return Err(ExecutionError::General(format!(
"TopKExec invalid partition {}",
@@ -403,7 +404,7 @@ impl ExecutionPlan for TopKExec {
)));
}
- Ok(Box::new(TopKReader {
+ Ok(Box::pin(TopKReader {
input: self.input.execute(partition).await?,
k: self.k,
done: false,
@@ -414,7 +415,7 @@ impl ExecutionPlan for TopKExec {
// A very specialized TopK implementation
struct TopKReader {
/// The input to read data from
- input: Box<dyn RecordBatchReader + Send>,
+ input: SendableRecordBatchStream,
/// Maximum number of output values
k: usize,
/// Have we produced the output yet?
@@ -448,9 +449,9 @@ fn remove_lowest_value(top_values: &mut BTreeMap<i64, String>) {
fn accumulate_batch(
input_batch: &RecordBatch,
- top_values: &mut BTreeMap<i64, String>,
+ mut top_values: BTreeMap<i64, String>,
k: &usize,
-) -> Result<()> {
+) -> Result<BTreeMap<i64, String>> {
let num_rows = input_batch.num_rows();
// Assuming the input columns are
// column[0]: customer_id / UTF8
@@ -468,51 +469,69 @@ fn accumulate_batch(
.expect("Column 1 is not revenue");
for row in 0..num_rows {
- add_row(top_values, customer_id.value(row), revenue.value(row), k);
+ add_row(
+ &mut top_values,
+ customer_id.value(row),
+ revenue.value(row),
+ k,
+ );
}
- Ok(())
+ Ok(top_values)
}
-impl Iterator for TopKReader {
+impl Stream for TopKReader {
type Item = std::result::Result<RecordBatch, ArrowError>;
- /// Reads the next `RecordBatch`.
- fn next(&mut self) -> Option<Self::Item> {
+ fn poll_next(
+ mut self: std::pin::Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
if self.done {
- return None;
+ return Poll::Ready(None);
}
-
- // Hard coded implementation for sales / customer_id example
- let mut top_values: BTreeMap<i64, String> = BTreeMap::new();
+ // this aggregates and thus returns a single RecordBatch.
+ self.done = true;
// take this as immutable
- let k = &self.k;
+ let k = self.k;
+ let schema = self.schema().clone();
- self.input
+ let top_values = self
+ .input
.as_mut()
- .into_iter()
- .map(|batch| accumulate_batch(&batch?, &mut top_values, k))
- .collect::<Result<()>>()
- .unwrap();
-
- // make output by walking over the map backwards (so values are descending)
- let (revenue, customer): (Vec<i64>, Vec<&String>) =
- top_values.iter().rev().unzip();
-
- let customer: Vec<&str> = customer.iter().map(|&s| &**s).collect();
+ // Hard coded implementation for sales / customer_id example as BTree
+ .try_fold(
+ BTreeMap::<i64, String>::new(),
+ move |top_values, batch| async move {
+ accumulate_batch(&batch, top_values, &k)
+ .map_err(ExecutionError::into_arrow_external_error)
+ },
+ );
+
+ let top_values = top_values.map(|top_values| match top_values {
+ Ok(top_values) => {
+ // make output by walking over the map backwards (so values are descending)
+ let (revenue, customer): (Vec<i64>, Vec<&String>) =
+ top_values.iter().rev().unzip();
+
+ let customer: Vec<&str> = customer.iter().map(|&s| &**s).collect();
+ Ok(RecordBatch::try_new(
+ schema,
+ vec![
+ Arc::new(StringArray::from(customer)),
+ Arc::new(Int64Array::from(revenue)),
+ ],
+ )?)
+ }
+ Err(e) => Err(e),
+ });
+ let mut top_values = Box::pin(top_values.into_stream());
- self.done = true;
- Some(RecordBatch::try_new(
- self.schema().clone(),
- vec![
- Arc::new(StringArray::from(customer)),
- Arc::new(Int64Array::from(revenue)),
- ],
- ))
+ top_values.poll_next_unpin(cx)
}
}
-impl RecordBatchReader for TopKReader {
+impl RecordBatchStream for TopKReader {
fn schema(&self) -> SchemaRef {
self.input.schema()
}
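The `TopKReader` logic above keeps a `BTreeMap` keyed by revenue as a running top-k: insert each row, evict the smallest key whenever the map exceeds `k`, and finally walk the map in reverse for descending output. A std-only sketch of that accumulation (`top_k` is an illustrative name; like the example in the test, ties on revenue overwrite):

```rust
use std::collections::BTreeMap;

// Running top-k keyed by revenue: insert, then evict the smallest key when
// the map exceeds k; reverse iteration yields descending order.
fn top_k(rows: &[(i64, &str)], k: usize) -> Vec<(i64, String)> {
    let mut top = BTreeMap::new();
    for &(revenue, customer) in rows {
        top.insert(revenue, customer.to_string());
        if top.len() > k {
            let lowest = *top.keys().next().unwrap(); // smallest revenue
            top.remove(&lowest);
        }
    }
    top.into_iter().rev().collect()
}
```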