This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new b7bb2cfba Fix Ballista executing during plan (#2428)
b7bb2cfba is described below

commit b7bb2cfba13cc04a08c2f687102dd14a8dedc7b6
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Wed May 4 15:12:53 2022 +0100

    Fix Ballista executing during plan (#2428)
---
 .../core/src/execution_plans/distributed_query.rs  | 195 ++++++------
 .../core/src/execution_plans/shuffle_reader.rs     |  37 ++-
 .../core/src/execution_plans/shuffle_writer.rs     | 350 +++++++++++----------
 ballista/rust/core/src/utils.rs                    |  42 +--
 datafusion/core/src/physical_plan/stream.rs        |   8 +-
 5 files changed, 314 insertions(+), 318 deletions(-)

diff --git a/ballista/rust/core/src/execution_plans/distributed_query.rs b/ballista/rust/core/src/execution_plans/distributed_query.rs
index b0d3bef1f..ed3c9fceb 100644
--- a/ballista/rust/core/src/execution_plans/distributed_query.rs
+++ b/ballista/rust/core/src/execution_plans/distributed_query.rs
@@ -30,9 +30,8 @@ use crate::serde::protobuf::{
     ExecuteQueryParams, GetJobStatusParams, GetJobStatusResult, KeyValuePair,
     PartitionLocation,
 };
-use crate::utils::WrappedStream;
 
-use datafusion::arrow::datatypes::{Schema, SchemaRef};
+use datafusion::arrow::datatypes::SchemaRef;
 use datafusion::error::{DataFusionError, Result};
 use datafusion::logical_plan::LogicalPlan;
 use datafusion::physical_plan::expressions::PhysicalSortExpr;
@@ -43,12 +42,14 @@ use datafusion::physical_plan::{
 use crate::serde::protobuf::execute_query_params::OptionalSessionId;
 use crate::serde::{AsLogicalPlan, DefaultLogicalExtensionCodec, 
LogicalExtensionCodec};
 use async_trait::async_trait;
+use datafusion::arrow::error::{ArrowError, Result as ArrowResult};
+use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::execution::context::TaskContext;
-use futures::future;
-use futures::StreamExt;
+use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
+use futures::{Stream, StreamExt, TryFutureExt, TryStreamExt};
 use log::{error, info};
 
-/// This operator sends a logial plan to a Ballista scheduler for execution and
+/// This operator sends a logical plan to a Ballista scheduler for execution and
 /// polls the scheduler until the query is complete and then fetches the resulting
 /// batches directly from the executors that hold the results from the final
 /// query stage.
@@ -168,15 +169,6 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
     ) -> Result<SendableRecordBatchStream> {
         assert_eq!(0, partition);
 
-        info!("Connecting to Ballista scheduler at {}", self.scheduler_url);
-        // TODO reuse the scheduler to avoid connecting to the Ballista 
scheduler again and again
-
-        let mut scheduler = 
SchedulerGrpcClient::connect(self.scheduler_url.clone())
-            .await
-            .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
-
-        let schema: Schema = self.plan.schema().as_ref().clone().into();
-
         let mut buf: Vec<u8> = vec![];
         let plan_message =
             T::try_from_logical_plan(&self.plan, 
self.extension_codec.as_ref()).map_err(
@@ -191,88 +183,30 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
             DataFusionError::Execution(format!("failed to encode logical plan: 
{:?}", e))
         })?;
 
-        let query_result = scheduler
-            .execute_query(ExecuteQueryParams {
-                query: Some(Query::LogicalPlan(buf)),
-                settings: self
-                    .config
-                    .settings()
-                    .iter()
-                    .map(|(k, v)| KeyValuePair {
-                        key: k.to_owned(),
-                        value: v.to_owned(),
-                    })
-                    .collect::<Vec<_>>(),
-                optional_session_id: Some(OptionalSessionId::SessionId(
-                    self.session_id.clone(),
-                )),
-            })
-            .await
-            .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?
-            .into_inner();
-
-        let response_session_id = query_result.session_id;
-        assert_eq!(
-            self.session_id.clone(),
-            response_session_id,
-            "Session id inconsistent between Client and Server side in 
DistributedQueryExec."
-        );
+        let query = ExecuteQueryParams {
+            query: Some(Query::LogicalPlan(buf)),
+            settings: self
+                .config
+                .settings()
+                .iter()
+                .map(|(k, v)| KeyValuePair {
+                    key: k.to_owned(),
+                    value: v.to_owned(),
+                })
+                .collect::<Vec<_>>(),
+            optional_session_id: Some(OptionalSessionId::SessionId(
+                self.session_id.clone(),
+            )),
+        };
 
-        let job_id = query_result.job_id;
-        let mut prev_status: Option<job_status::Status> = None;
+        let stream = futures::stream::once(
+            execute_query(self.scheduler_url.clone(), self.session_id.clone(), 
query)
+                .map_err(|e| ArrowError::ExternalError(Box::new(e))),
+        )
+        .try_flatten();
 
-        loop {
-            let GetJobStatusResult { status } = scheduler
-                .get_job_status(GetJobStatusParams {
-                    job_id: job_id.clone(),
-                })
-                .await
-                .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?
-                .into_inner();
-            let status = status.and_then(|s| s.status).ok_or_else(|| {
-                DataFusionError::Internal("Received empty status 
message".to_owned())
-            })?;
-            let wait_future = tokio::time::sleep(Duration::from_millis(100));
-            let has_status_change = prev_status.map(|x| x != 
status).unwrap_or(true);
-            match status {
-                job_status::Status::Queued(_) => {
-                    if has_status_change {
-                        info!("Job {} still queued...", job_id);
-                    }
-                    wait_future.await;
-                    prev_status = Some(status);
-                }
-                job_status::Status::Running(_) => {
-                    if has_status_change {
-                        info!("Job {} is running...", job_id);
-                    }
-                    wait_future.await;
-                    prev_status = Some(status);
-                }
-                job_status::Status::Failed(err) => {
-                    let msg = format!("Job {} failed: {}", job_id, err.error);
-                    error!("{}", msg);
-                    break Err(DataFusionError::Execution(msg));
-                }
-                job_status::Status::Completed(completed) => {
-                    let result = future::join_all(
-                        completed
-                            .partition_location
-                            .into_iter()
-                            .map(fetch_partition),
-                    )
-                    .await
-                    .into_iter()
-                    .collect::<Result<Vec<_>>>()?;
-
-                    let result = WrappedStream::new(
-                        Box::pin(futures::stream::iter(result).flatten()),
-                        Arc::new(schema),
-                    );
-                    break Ok(Box::pin(result));
-                }
-            };
-        }
+        let schema = self.schema();
+        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
     }
 
     fn fmt_as(
@@ -299,6 +233,79 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
     }
 }
 
+async fn execute_query(
+    scheduler_url: String,
+    session_id: String,
+    query: ExecuteQueryParams,
+) -> Result<impl Stream<Item = ArrowResult<RecordBatch>> + Send> {
+    info!("Connecting to Ballista scheduler at {}", scheduler_url);
+    // TODO reuse the scheduler to avoid connecting to the Ballista scheduler again and again
+
+    let mut scheduler = SchedulerGrpcClient::connect(scheduler_url.clone())
+        .await
+        .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
+
+    let query_result = scheduler
+        .execute_query(query)
+        .await
+        .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?
+        .into_inner();
+
+    assert_eq!(
+        session_id, query_result.session_id,
+        "Session id inconsistent between Client and Server side in DistributedQueryExec."
+    );
+
+    let job_id = query_result.job_id;
+    let mut prev_status: Option<job_status::Status> = None;
+
+    loop {
+        let GetJobStatusResult { status } = scheduler
+            .get_job_status(GetJobStatusParams {
+                job_id: job_id.clone(),
+            })
+            .await
+            .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?
+            .into_inner();
+        let status = status.and_then(|s| s.status).ok_or_else(|| {
+            DataFusionError::Internal("Received empty status 
message".to_owned())
+        })?;
+        let wait_future = tokio::time::sleep(Duration::from_millis(100));
+        let has_status_change = prev_status.map(|x| x != 
status).unwrap_or(true);
+        match status {
+            job_status::Status::Queued(_) => {
+                if has_status_change {
+                    info!("Job {} still queued...", job_id);
+                }
+                wait_future.await;
+                prev_status = Some(status);
+            }
+            job_status::Status::Running(_) => {
+                if has_status_change {
+                    info!("Job {} is running...", job_id);
+                }
+                wait_future.await;
+                prev_status = Some(status);
+            }
+            job_status::Status::Failed(err) => {
+                let msg = format!("Job {} failed: {}", job_id, err.error);
+                error!("{}", msg);
+                break Err(DataFusionError::Execution(msg));
+            }
+            job_status::Status::Completed(completed) => {
+                let streams = completed.partition_location.into_iter().map(|p| 
{
+                    let f = fetch_partition(p)
+                        .map_err(|e| ArrowError::ExternalError(Box::new(e)));
+
+                    futures::stream::once(f).try_flatten()
+                });
+
+                break Ok(futures::stream::iter(streams).flatten());
+            }
+        };
+    }
+}
+
 async fn fetch_partition(
     location: PartitionLocation,
 ) -> Result<SendableRecordBatchStream> {
diff --git a/ballista/rust/core/src/execution_plans/shuffle_reader.rs b/ballista/rust/core/src/execution_plans/shuffle_reader.rs
index b0aa6af11..27252b980 100644
--- a/ballista/rust/core/src/execution_plans/shuffle_reader.rs
+++ b/ballista/rust/core/src/execution_plans/shuffle_reader.rs
@@ -15,16 +15,16 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::any::Any;
 use std::sync::Arc;
-use std::{any::Any, pin::Pin};
 
 use crate::client::BallistaClient;
 use crate::serde::scheduler::{PartitionLocation, PartitionStats};
 
-use crate::utils::WrappedStream;
 use async_trait::async_trait;
 use datafusion::arrow::datatypes::SchemaRef;
 
+use datafusion::error::{DataFusionError, Result};
 use datafusion::physical_plan::expressions::PhysicalSortExpr;
 use datafusion::physical_plan::metrics::{
     ExecutionPlanMetricsSet, MetricBuilder, MetricsSet,
@@ -32,13 +32,11 @@ use datafusion::physical_plan::metrics::{
 use datafusion::physical_plan::{
     DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, 
Statistics,
 };
-use datafusion::{
-    error::{DataFusionError, Result},
-    physical_plan::RecordBatchStream,
-};
-use futures::{future, StreamExt};
+use futures::{StreamExt, TryStreamExt};
 
+use datafusion::arrow::error::ArrowError;
 use datafusion::execution::context::TaskContext;
+use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use log::info;
 
 /// ShuffleReaderExec reads partitions that have already been materialized by 
a ShuffleWriterExec
@@ -112,18 +110,23 @@ impl ExecutionPlan for ShuffleReaderExec {
 
         let fetch_time =
             MetricBuilder::new(&self.metrics).subset_time("fetch_time", 
partition);
-        let timer = fetch_time.timer();
 
-        let partition_locations = &self.partition[partition];
-        let result = 
future::join_all(partition_locations.iter().map(fetch_partition))
-            .await
-            .into_iter()
-            .collect::<Result<Vec<_>>>()?;
-        timer.done();
+        let locations = self.partition[partition].clone();
+        let stream = locations.into_iter().map(move |p| {
+            let fetch_time = fetch_time.clone();
+            futures::stream::once(async move {
+                let timer = fetch_time.timer();
+                let r = fetch_partition(&p).await;
+                timer.done();
+
+                r.map_err(|e| ArrowError::ExternalError(Box::new(e)))
+            })
+            .try_flatten()
+        });
 
-        let result = WrappedStream::new(
-            Box::pin(futures::stream::iter(result).flatten()),
+        let result = RecordBatchStreamAdapter::new(
             Arc::new(self.schema.as_ref().clone()),
+            futures::stream::iter(stream).flatten(),
         );
         Ok(Box::pin(result))
     }
@@ -201,7 +204,7 @@ fn stats_for_partitions(
 
 async fn fetch_partition(
     location: &PartitionLocation,
-) -> Result<Pin<Box<dyn RecordBatchStream + Send>>> {
+) -> Result<SendableRecordBatchStream> {
     let metadata = &location.executor_meta;
     let partition_id = &location.partition_id;
     let mut ballista_client =
diff --git a/ballista/rust/core/src/execution_plans/shuffle_writer.rs b/ballista/rust/core/src/execution_plans/shuffle_writer.rs
index 7a87406af..f5c98b200 100644
--- a/ballista/rust/core/src/execution_plans/shuffle_writer.rs
+++ b/ballista/rust/core/src/execution_plans/shuffle_writer.rs
@@ -23,6 +23,7 @@
 use datafusion::physical_plan::expressions::PhysicalSortExpr;
 
 use std::any::Any;
+use std::future::Future;
 use std::iter::Iterator;
 use std::path::PathBuf;
 use std::sync::Arc;
@@ -49,10 +50,12 @@ use datafusion::physical_plan::metrics::{
 use datafusion::physical_plan::{
     DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, 
Statistics,
 };
-use futures::StreamExt;
+use futures::{StreamExt, TryFutureExt, TryStreamExt};
 
+use datafusion::arrow::error::ArrowError;
 use datafusion::execution::context::TaskContext;
 use datafusion::physical_plan::repartition::BatchPartitioner;
+use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use log::{debug, info};
 
 /// ShuffleWriterExec represents a section of a query plan that has consistent 
partitioning and
@@ -137,149 +140,155 @@ impl ShuffleWriterExec {
         self.shuffle_output_partitioning.as_ref()
     }
 
-    pub async fn execute_shuffle_write(
+    pub fn execute_shuffle_write(
         &self,
         input_partition: usize,
         context: Arc<TaskContext>,
-    ) -> Result<Vec<ShuffleWritePartition>> {
-        let now = Instant::now();
-
-        let mut stream = self.plan.execute(input_partition, context).await?;
-
+    ) -> impl Future<Output = Result<Vec<ShuffleWritePartition>>> {
         let mut path = PathBuf::from(&self.work_dir);
         path.push(&self.job_id);
         path.push(&format!("{}", self.stage_id));
 
         let write_metrics = ShuffleWriteMetrics::new(input_partition, 
&self.metrics);
-
-        match &self.shuffle_output_partitioning {
-            None => {
-                let timer = write_metrics.write_time.timer();
-                path.push(&format!("{}", input_partition));
-                std::fs::create_dir_all(&path)?;
-                path.push("data.arrow");
-                let path = path.to_str().unwrap();
-                info!("Writing results to {}", path);
-
-                // stream results to disk
-                let stats = utils::write_stream_to_disk(
-                    &mut stream,
-                    path,
-                    &write_metrics.write_time,
-                )
-                .await
-                .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
-
-                write_metrics
-                    .input_rows
-                    .add(stats.num_rows.unwrap_or(0) as usize);
-                write_metrics
-                    .output_rows
-                    .add(stats.num_rows.unwrap_or(0) as usize);
-                timer.done();
-
-                info!(
-                    "Executed partition {} in {} seconds. Statistics: {}",
-                    input_partition,
-                    now.elapsed().as_secs(),
-                    stats
-                );
-
-                Ok(vec![ShuffleWritePartition {
-                    partition_id: input_partition as u64,
-                    path: path.to_owned(),
-                    num_batches: stats.num_batches.unwrap_or(0),
-                    num_rows: stats.num_rows.unwrap_or(0),
-                    num_bytes: stats.num_bytes.unwrap_or(0),
-                }])
-            }
-
-            Some(Partitioning::Hash(exprs, n)) => {
-                let num_output_partitions = *n;
-
-                // we won't necessary produce output for every possible 
partition, so we
-                // create writers on demand
-                let mut writers: Vec<Option<IPCWriter>> = vec![];
-                for _ in 0..num_output_partitions {
-                    writers.push(None);
+        let output_partitioning = self.shuffle_output_partitioning.clone();
+        let plan = self.plan.clone();
+
+        async move {
+            let now = Instant::now();
+            let mut stream = plan.execute(input_partition, context).await?;
+
+            match output_partitioning {
+                None => {
+                    let timer = write_metrics.write_time.timer();
+                    path.push(&format!("{}", input_partition));
+                    std::fs::create_dir_all(&path)?;
+                    path.push("data.arrow");
+                    let path = path.to_str().unwrap();
+                    info!("Writing results to {}", path);
+
+                    // stream results to disk
+                    let stats = utils::write_stream_to_disk(
+                        &mut stream,
+                        path,
+                        &write_metrics.write_time,
+                    )
+                    .await
+                    .map_err(|e| DataFusionError::Execution(format!("{:?}", 
e)))?;
+
+                    write_metrics
+                        .input_rows
+                        .add(stats.num_rows.unwrap_or(0) as usize);
+                    write_metrics
+                        .output_rows
+                        .add(stats.num_rows.unwrap_or(0) as usize);
+                    timer.done();
+
+                    info!(
+                        "Executed partition {} in {} seconds. Statistics: {}",
+                        input_partition,
+                        now.elapsed().as_secs(),
+                        stats
+                    );
+
+                    Ok(vec![ShuffleWritePartition {
+                        partition_id: input_partition as u64,
+                        path: path.to_owned(),
+                        num_batches: stats.num_batches.unwrap_or(0),
+                        num_rows: stats.num_rows.unwrap_or(0),
+                        num_bytes: stats.num_bytes.unwrap_or(0),
+                    }])
                 }
 
-                let mut partitioner = BatchPartitioner::try_new(
-                    Partitioning::Hash(exprs.clone(), *n),
-                    write_metrics.repart_time.clone(),
-                )?;
-
-                while let Some(result) = stream.next().await {
-                    let input_batch = result?;
-
-                    write_metrics.input_rows.add(input_batch.num_rows());
+                Some(Partitioning::Hash(exprs, num_output_partitions)) => {
+                    // we won't necessary produce output for every possible 
partition, so we
+                    // create writers on demand
+                    let mut writers: Vec<Option<IPCWriter>> = vec![];
+                    for _ in 0..num_output_partitions {
+                        writers.push(None);
+                    }
 
-                    partitioner.partition(
-                        input_batch,
-                        |output_partition, output_batch| {
-                            // write non-empty batch out
+                    let mut partitioner = BatchPartitioner::try_new(
+                        Partitioning::Hash(exprs, num_output_partitions),
+                        write_metrics.repart_time.clone(),
+                    )?;
 
-                            // TODO optimize so we don't write or fetch empty 
partitions
-                            // if output_batch.num_rows() > 0 {
-                            let timer = write_metrics.write_time.timer();
-                            match &mut writers[output_partition] {
-                                Some(w) => {
-                                    w.write(&output_batch)?;
+                    while let Some(result) = stream.next().await {
+                        let input_batch = result?;
+
+                        write_metrics.input_rows.add(input_batch.num_rows());
+
+                        partitioner.partition(
+                            input_batch,
+                            |output_partition, output_batch| {
+                                // write non-empty batch out
+
+                                // TODO optimize so we don't write or fetch 
empty partitions
+                                // if output_batch.num_rows() > 0 {
+                                let timer = write_metrics.write_time.timer();
+                                match &mut writers[output_partition] {
+                                    Some(w) => {
+                                        w.write(&output_batch)?;
+                                    }
+                                    None => {
+                                        let mut path = path.clone();
+                                        path.push(&format!("{}", 
output_partition));
+                                        std::fs::create_dir_all(&path)?;
+
+                                        path.push(format!(
+                                            "data-{}.arrow",
+                                            input_partition
+                                        ));
+                                        info!("Writing results to {:?}", path);
+
+                                        let mut writer = IPCWriter::new(
+                                            &path,
+                                            stream.schema().as_ref(),
+                                        )?;
+
+                                        writer.write(&output_batch)?;
+                                        writers[output_partition] = 
Some(writer);
+                                    }
                                 }
-                                None => {
-                                    let mut path = path.clone();
-                                    path.push(&format!("{}", 
output_partition));
-                                    std::fs::create_dir_all(&path)?;
-
-                                    path.push(format!("data-{}.arrow", 
input_partition));
-                                    info!("Writing results to {:?}", path);
-
-                                    let mut writer =
-                                        IPCWriter::new(&path, 
stream.schema().as_ref())?;
+                                write_metrics.output_rows.add(output_batch.num_rows());
+                                timer.done();
+                                Ok(())
+                            },
+                        )?;
+                    }
 
-                                    writer.write(&output_batch)?;
-                                    writers[output_partition] = Some(writer);
-                                }
+                    let mut part_locs = vec![];
+
+                    for (i, w) in writers.iter_mut().enumerate() {
+                        match w {
+                            Some(w) => {
+                                w.finish()?;
+                                info!(
+                                    "Finished writing shuffle partition {} at 
{:?}. Batches: {}. Rows: {}. Bytes: {}.",
+                                    i,
+                                    w.path(),
+                                    w.num_batches,
+                                    w.num_rows,
+                                    w.num_bytes
+                                );
+
+                                part_locs.push(ShuffleWritePartition {
+                                    partition_id: i as u64,
+                                    path: 
w.path().to_string_lossy().to_string(),
+                                    num_batches: w.num_batches,
+                                    num_rows: w.num_rows,
+                                    num_bytes: w.num_bytes,
+                                });
                             }
-                            
write_metrics.output_rows.add(output_batch.num_rows());
-                            timer.done();
-                            Ok(())
-                        },
-                    )?;
-                }
-
-                let mut part_locs = vec![];
-
-                for (i, w) in writers.iter_mut().enumerate() {
-                    match w {
-                        Some(w) => {
-                            w.finish()?;
-                            info!(
-                                "Finished writing shuffle partition {} at 
{:?}. Batches: {}. Rows: {}. Bytes: {}.",
-                                i,
-                                w.path(),
-                                w.num_batches,
-                                w.num_rows,
-                                w.num_bytes
-                            );
-
-                            part_locs.push(ShuffleWritePartition {
-                                partition_id: i as u64,
-                                path: w.path().to_string_lossy().to_string(),
-                                num_batches: w.num_batches,
-                                num_rows: w.num_rows,
-                                num_bytes: w.num_bytes,
-                            });
+                            None => {}
                         }
-                        None => {}
                     }
+                    Ok(part_locs)
                 }
-                Ok(part_locs)
-            }
 
-            _ => Err(DataFusionError::Execution(
-                "Invalid shuffle partitioning scheme".to_owned(),
-            )),
+                _ => Err(DataFusionError::Execution(
+                    "Invalid shuffle partitioning scheme".to_owned(),
+                )),
+            }
         }
     }
 }
@@ -332,50 +341,61 @@ impl ExecutionPlan for ShuffleWriterExec {
         partition: usize,
         context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
-        let part_loc = self.execute_shuffle_write(partition, context).await?;
-
-        // build metadata result batch
-        let num_writers = part_loc.len();
-        let mut partition_builder = UInt32Builder::new(num_writers);
-        let mut path_builder = StringBuilder::new(num_writers);
-        let mut num_rows_builder = UInt64Builder::new(num_writers);
-        let mut num_batches_builder = UInt64Builder::new(num_writers);
-        let mut num_bytes_builder = UInt64Builder::new(num_writers);
-
-        for loc in &part_loc {
-            path_builder.append_value(loc.path.clone())?;
-            partition_builder.append_value(loc.partition_id as u32)?;
-            num_rows_builder.append_value(loc.num_rows)?;
-            num_batches_builder.append_value(loc.num_batches)?;
-            num_bytes_builder.append_value(loc.num_bytes)?;
-        }
+        let schema = result_schema();
 
-        // build arrays
-        let partition_num: ArrayRef = Arc::new(partition_builder.finish());
-        let path: ArrayRef = Arc::new(path_builder.finish());
-        let field_builders: Vec<Box<dyn ArrayBuilder>> = vec![
-            Box::new(num_rows_builder),
-            Box::new(num_batches_builder),
-            Box::new(num_bytes_builder),
-        ];
-        let mut stats_builder = StructBuilder::new(
-            PartitionStats::default().arrow_struct_fields(),
-            field_builders,
-        );
-        for _ in 0..num_writers {
-            stats_builder.append(true)?;
-        }
-        let stats = Arc::new(stats_builder.finish());
+        let schema_captured = schema.clone();
+        let fut_stream = self
+            .execute_shuffle_write(partition, context)
+            .and_then(|part_loc| async move {
+                // build metadata result batch
+                let num_writers = part_loc.len();
+                let mut partition_builder = UInt32Builder::new(num_writers);
+                let mut path_builder = StringBuilder::new(num_writers);
+                let mut num_rows_builder = UInt64Builder::new(num_writers);
+                let mut num_batches_builder = UInt64Builder::new(num_writers);
+                let mut num_bytes_builder = UInt64Builder::new(num_writers);
+
+                for loc in &part_loc {
+                    path_builder.append_value(loc.path.clone())?;
+                    partition_builder.append_value(loc.partition_id as u32)?;
+                    num_rows_builder.append_value(loc.num_rows)?;
+                    num_batches_builder.append_value(loc.num_batches)?;
+                    num_bytes_builder.append_value(loc.num_bytes)?;
+                }
 
-        // build result batch containing metadata
-        let schema = result_schema();
-        let batch =
-            RecordBatch::try_new(schema.clone(), vec![partition_num, path, 
stats])
-                .map_err(DataFusionError::ArrowError)?;
+                // build arrays
+                let partition_num: ArrayRef = 
Arc::new(partition_builder.finish());
+                let path: ArrayRef = Arc::new(path_builder.finish());
+                let field_builders: Vec<Box<dyn ArrayBuilder>> = vec![
+                    Box::new(num_rows_builder),
+                    Box::new(num_batches_builder),
+                    Box::new(num_bytes_builder),
+                ];
+                let mut stats_builder = StructBuilder::new(
+                    PartitionStats::default().arrow_struct_fields(),
+                    field_builders,
+                );
+                for _ in 0..num_writers {
+                    stats_builder.append(true)?;
+                }
+                let stats = Arc::new(stats_builder.finish());
+
+                // build result batch containing metadata
+                let batch = RecordBatch::try_new(
+                    schema_captured.clone(),
+                    vec![partition_num, path, stats],
+                )?;
+
+                debug!("RESULTS METADATA:\n{:?}", batch);
 
-        debug!("RESULTS METADATA:\n{:?}", batch);
+                MemoryStream::try_new(vec![batch], schema_captured, None)
+            })
+            .map_err(|e| ArrowError::ExternalError(Box::new(e)));
 
-        Ok(Box::pin(MemoryStream::try_new(vec![batch], schema, None)?))
+        Ok(Box::pin(RecordBatchStreamAdapter::new(
+            schema,
+            futures::stream::once(fut_stream).try_flatten(),
+        )))
     }
 
     fn metrics(&self) -> Option<MetricsSet> {
diff --git a/ballista/rust/core/src/utils.rs b/ballista/rust/core/src/utils.rs
index 85a557e43..1418aecb3 100644
--- a/ballista/rust/core/src/utils.rs
+++ b/ballista/rust/core/src/utils.rs
@@ -32,10 +32,7 @@ use crate::config::BallistaConfig;
 use crate::serde::{AsLogicalPlan, DefaultLogicalExtensionCodec, 
LogicalExtensionCodec};
 use async_trait::async_trait;
 use datafusion::arrow::datatypes::Schema;
-use datafusion::arrow::error::Result as ArrowResult;
-use datafusion::arrow::{
-    datatypes::SchemaRef, ipc::writer::FileWriter, record_batch::RecordBatch,
-};
+use datafusion::arrow::{ipc::writer::FileWriter, record_batch::RecordBatch};
 use datafusion::error::DataFusionError;
 use datafusion::execution::context::{
     QueryPlanner, SessionConfig, SessionContext, SessionState,
@@ -55,7 +52,7 @@ use datafusion::physical_plan::hash_join::HashJoinExec;
 use datafusion::physical_plan::projection::ProjectionExec;
 use datafusion::physical_plan::sorts::sort::SortExec;
 use datafusion::physical_plan::{metrics, ExecutionPlan, RecordBatchStream};
-use futures::{Stream, StreamExt};
+use futures::StreamExt;
 
 /// Stream data to disk in Arrow IPC format
 
@@ -316,38 +313,3 @@ impl<T: 'static + AsLogicalPlan> QueryPlanner for BallistaQueryPlanner<T> {
         }
     }
 }
-
-pub struct WrappedStream {
-    stream: Pin<Box<dyn Stream<Item = ArrowResult<RecordBatch>> + Send>>,
-    schema: SchemaRef,
-}
-
-impl WrappedStream {
-    pub fn new(
-        stream: Pin<Box<dyn Stream<Item = ArrowResult<RecordBatch>> + Send>>,
-        schema: SchemaRef,
-    ) -> Self {
-        Self { stream, schema }
-    }
-}
-
-impl RecordBatchStream for WrappedStream {
-    fn schema(&self) -> SchemaRef {
-        self.schema.clone()
-    }
-}
-
-impl Stream for WrappedStream {
-    type Item = ArrowResult<RecordBatch>;
-
-    fn poll_next(
-        mut self: Pin<&mut Self>,
-        cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<Option<Self::Item>> {
-        self.stream.poll_next_unpin(cx)
-    }
-
-    fn size_hint(&self) -> (usize, Option<usize>) {
-        self.stream.size_hint()
-    }
-}
diff --git a/datafusion/core/src/physical_plan/stream.rs b/datafusion/core/src/physical_plan/stream.rs
index 99209121f..06d670ff4 100644
--- a/datafusion/core/src/physical_plan/stream.rs
+++ b/datafusion/core/src/physical_plan/stream.rs
@@ -78,7 +78,7 @@ impl RecordBatchStream for RecordBatchReceiverStream {
 pin_project! {
     /// Combines a [`Stream`] with a [`SchemaRef`] implementing
     /// [`RecordBatchStream`] for the combination
-    pub(crate) struct RecordBatchStreamAdapter<S> {
+    pub struct RecordBatchStreamAdapter<S> {
         schema: SchemaRef,
 
         #[pin]
@@ -88,7 +88,7 @@ pin_project! {
 
 impl<S> RecordBatchStreamAdapter<S> {
     /// Creates a new [`RecordBatchStreamAdapter`] from the provided schema 
and stream
-    pub(crate) fn new(schema: SchemaRef, stream: S) -> Self {
+    pub fn new(schema: SchemaRef, stream: S) -> Self {
         Self { schema, stream }
     }
 }
@@ -113,6 +113,10 @@ where
     ) -> std::task::Poll<Option<Self::Item>> {
         self.project().stream.poll_next(cx)
     }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        self.stream.size_hint()
+    }
 }
 
 impl<S> RecordBatchStream for RecordBatchStreamAdapter<S>

Reply via email to