This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new e4df37a  Ballista: Prep for fixing shuffle mechanism, part 1 (#738)
e4df37a is described below

commit e4df37a4001423909964348289360da66acdd0a3
Author: Andy Grove <[email protected]>
AuthorDate: Mon Jul 19 17:58:14 2021 -0600

    Ballista: Prep for fixing shuffle mechanism, part 1 (#738)
---
 ballista/rust/core/proto/ballista.proto            |  12 +
 .../core/src/execution_plans/shuffle_writer.rs     | 252 ++++++++++++---------
 .../rust/core/src/serde/physical_plan/to_proto.rs  |  38 ++--
 .../rust/core/src/serde/scheduler/from_proto.rs    |   1 +
 ballista/rust/core/src/serde/scheduler/mod.rs      |   7 +-
 ballista/rust/core/src/serde/scheduler/to_proto.rs |   1 +
 ballista/rust/executor/src/execution_loop.rs       |  12 +-
 ballista/rust/executor/src/executor.rs             |  21 +-
 ballista/rust/scheduler/src/planner.rs             |  44 +++-
 ballista/rust/scheduler/src/state/mod.rs           | 117 ++++++----
 10 files changed, 316 insertions(+), 189 deletions(-)

diff --git a/ballista/rust/core/proto/ballista.proto 
b/ballista/rust/core/proto/ballista.proto
index 0575460..50bd901 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -721,6 +721,7 @@ message PartitionLocation {
   PartitionId partition_id = 1;
   ExecutorMetadata executor_meta = 2;
   PartitionStats partition_stats = 3;
+  string path = 4;
 }
 
 // Unique identifier for a materialized partition of data
@@ -776,6 +777,17 @@ message FailedTask {
 
 message CompletedTask {
   string executor_id = 1;
+  // TODO tasks are currently always shuffle writes but this will not always 
be the case
+  // so we might want to think about some refactoring of the task definitions
+  repeated ShuffleWritePartition partitions = 2;
+}
+
+message ShuffleWritePartition {
+  uint64 partition_id = 1;
+  string path = 2;
+  uint64 num_batches = 3;
+  uint64 num_rows = 4;
+  uint64 num_bytes = 5;
 }
 
 message TaskStatus {
diff --git a/ballista/rust/core/src/execution_plans/shuffle_writer.rs 
b/ballista/rust/core/src/execution_plans/shuffle_writer.rs
index d5c7d8f..47bf2a2 100644
--- a/ballista/rust/core/src/execution_plans/shuffle_writer.rs
+++ b/ballista/rust/core/src/execution_plans/shuffle_writer.rs
@@ -31,7 +31,8 @@ use crate::error::BallistaError;
 use crate::memory_stream::MemoryStream;
 use crate::utils;
 
-use crate::serde::scheduler::PartitionStats;
+use crate::serde::protobuf::ShuffleWritePartition;
+use crate::serde::scheduler::{PartitionLocation, PartitionStats};
 use async_trait::async_trait;
 use datafusion::arrow::array::{
     Array, ArrayBuilder, ArrayRef, StringBuilder, StructBuilder, UInt32Builder,
@@ -39,16 +40,19 @@ use datafusion::arrow::array::{
 };
 use datafusion::arrow::compute::take;
 use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use datafusion::arrow::ipc::reader::FileReader;
 use datafusion::arrow::ipc::writer::FileWriter;
 use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::error::{DataFusionError, Result};
 use datafusion::physical_plan::hash_join::create_hashes;
+use datafusion::physical_plan::repartition::RepartitionExec;
+use datafusion::physical_plan::Partitioning::RoundRobinBatch;
 use datafusion::physical_plan::{
     DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, 
SQLMetric,
 };
 use futures::StreamExt;
 use hashbrown::HashMap;
-use log::info;
+use log::{debug, info};
 use uuid::Uuid;
 
 /// ShuffleWriterExec represents a section of a query plan that has consistent 
partitioning and
@@ -75,12 +79,16 @@ pub struct ShuffleWriterExec {
 struct ShuffleWriteMetrics {
     /// Time spend writing batches to shuffle files
     write_time: Arc<SQLMetric>,
+    input_rows: Arc<SQLMetric>,
+    output_rows: Arc<SQLMetric>,
 }
 
 impl ShuffleWriteMetrics {
     fn new() -> Self {
         Self {
             write_time: SQLMetric::time_nanos(),
+            input_rows: SQLMetric::counter(),
+            output_rows: SQLMetric::counter(),
         }
     }
 }
@@ -113,50 +121,19 @@ impl ShuffleWriterExec {
     pub fn stage_id(&self) -> usize {
         self.stage_id
     }
-}
-
-#[async_trait]
-impl ExecutionPlan for ShuffleWriterExec {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
-    fn schema(&self) -> SchemaRef {
-        self.plan.schema()
-    }
-
-    fn output_partitioning(&self) -> Partitioning {
-        match &self.shuffle_output_partitioning {
-            Some(p) => p.clone(),
-            _ => Partitioning::UnknownPartitioning(1),
-        }
-    }
 
-    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
-        vec![self.plan.clone()]
+    /// Get the true output partitioning
+    pub fn shuffle_output_partitioning(&self) -> Option<&Partitioning> {
+        self.shuffle_output_partitioning.as_ref()
     }
 
-    fn with_new_children(
+    pub async fn execute_shuffle_write(
         &self,
-        children: Vec<Arc<dyn ExecutionPlan>>,
-    ) -> Result<Arc<dyn ExecutionPlan>> {
-        assert!(children.len() == 1);
-        Ok(Arc::new(ShuffleWriterExec::try_new(
-            self.job_id.clone(),
-            self.stage_id,
-            children[0].clone(),
-            self.work_dir.clone(),
-            self.shuffle_output_partitioning.clone(),
-        )?))
-    }
-
-    async fn execute(
-        &self,
-        partition: usize,
-    ) -> Result<Pin<Box<dyn RecordBatchStream + Send + Sync>>> {
+        input_partition: usize,
+    ) -> Result<Vec<ShuffleWritePartition>> {
         let now = Instant::now();
 
-        let mut stream = self.plan.execute(partition).await?;
+        let mut stream = self.plan.execute(input_partition).await?;
 
         let mut path = PathBuf::from(&self.work_dir);
         path.push(&self.job_id);
@@ -164,7 +141,7 @@ impl ExecutionPlan for ShuffleWriterExec {
 
         match &self.shuffle_output_partitioning {
             None => {
-                path.push(&format!("{}", partition));
+                path.push(&format!("{}", input_partition));
                 std::fs::create_dir_all(&path)?;
                 path.push("data.arrow");
                 let path = path.to_str().unwrap();
@@ -181,29 +158,18 @@ impl ExecutionPlan for ShuffleWriterExec {
 
                 info!(
                     "Executed partition {} in {} seconds. Statistics: {}",
-                    partition,
+                    input_partition,
                     now.elapsed().as_secs(),
                     stats
                 );
 
-                let schema = result_schema();
-
-                // build result set with summary of the partition execution 
status
-                let mut part_builder = UInt32Builder::new(1);
-                part_builder.append_value(partition as u32)?;
-                let part: ArrayRef = Arc::new(part_builder.finish());
-
-                let mut path_builder = StringBuilder::new(1);
-                path_builder.append_value(&path)?;
-                let path: ArrayRef = Arc::new(path_builder.finish());
-
-                let stats: ArrayRef = stats
-                    .to_arrow_arrayref()
-                    .map_err(|e| DataFusionError::Execution(format!("{:?}", 
e)))?;
-                let batch = RecordBatch::try_new(schema.clone(), vec![part, 
path, stats])
-                    .map_err(DataFusionError::ArrowError)?;
-
-                Ok(Box::pin(MemoryStream::try_new(vec![batch], schema, None)?))
+                Ok(vec![ShuffleWritePartition {
+                    partition_id: input_partition as u64,
+                    path: path.to_owned(),
+                    num_batches: stats.num_batches.unwrap_or(0),
+                    num_rows: stats.num_rows.unwrap_or(0),
+                    num_bytes: stats.num_bytes.unwrap_or(0),
+                }])
             }
 
             Some(Partitioning::Hash(exprs, n)) => {
@@ -218,8 +184,12 @@ impl ExecutionPlan for ShuffleWriterExec {
 
                 let hashes_buf = &mut vec![];
                 let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0);
+
                 while let Some(result) = stream.next().await {
                     let input_batch = result?;
+
+                    self.metrics.input_rows.add(input_batch.num_rows());
+
                     let arrays = exprs
                         .iter()
                         .map(|expr| {
@@ -241,6 +211,7 @@ impl ExecutionPlan for ShuffleWriterExec {
                         indices.into_iter().enumerate()
                     {
                         let indices = partition_indices.into();
+
                         // Produce batches based on indices
                         let columns = input_batch
                             .columns()
@@ -255,7 +226,8 @@ impl ExecutionPlan for ShuffleWriterExec {
                         let output_batch =
                             RecordBatch::try_new(input_batch.schema(), 
columns)?;
 
-                        // write batch out
+                        // write non-empty batch out
+                        //if output_batch.num_rows() > 0 {
                         let start = Instant::now();
                         match &mut writers[output_partition] {
                             Some(w) => {
@@ -266,7 +238,7 @@ impl ExecutionPlan for ShuffleWriterExec {
                                 path.push(&format!("{}", output_partition));
                                 std::fs::create_dir_all(&path)?;
 
-                                path.push("data.arrow");
+                                path.push(format!("data-{}.arrow", 
input_partition));
                                 let path = path.to_str().unwrap();
                                 info!("Writing results to {}", path);
 
@@ -277,58 +249,39 @@ impl ExecutionPlan for ShuffleWriterExec {
                                 writers[output_partition] = Some(writer);
                             }
                         }
+                        self.metrics.output_rows.add(output_batch.num_rows());
                         self.metrics.write_time.add_elapsed(start);
+                        //}
                     }
                 }
 
-                // build metadata result batch
-                let num_writers = writers.iter().filter(|w| 
w.is_some()).count();
-                let mut partition_builder = UInt32Builder::new(num_writers);
-                let mut path_builder = StringBuilder::new(num_writers);
-                let mut num_rows_builder = UInt64Builder::new(num_writers);
-                let mut num_batches_builder = UInt64Builder::new(num_writers);
-                let mut num_bytes_builder = UInt64Builder::new(num_writers);
+                let mut part_locs = vec![];
 
                 for (i, w) in writers.iter_mut().enumerate() {
                     match w {
                         Some(w) => {
                             w.finish()?;
-                            path_builder.append_value(w.path())?;
-                            partition_builder.append_value(i as u32)?;
-                            num_rows_builder.append_value(w.num_rows)?;
-                            num_batches_builder.append_value(w.num_batches)?;
-                            num_bytes_builder.append_value(w.num_bytes)?;
+                            info!(
+                                "Finished writing shuffle partition {} at {}. 
Batches: {}. Rows: {}. Bytes: {}.",
+                                i,
+                                w.path(),
+                                w.num_batches,
+                                w.num_rows,
+                                w.num_bytes
+                            );
+
+                            part_locs.push(ShuffleWritePartition {
+                                partition_id: i as u64,
+                                path: w.path().to_owned(),
+                                num_batches: w.num_batches,
+                                num_rows: w.num_rows,
+                                num_bytes: w.num_bytes,
+                            });
                         }
                         None => {}
                     }
                 }
-
-                // build arrays
-                let partition_num: ArrayRef = 
Arc::new(partition_builder.finish());
-                let path: ArrayRef = Arc::new(path_builder.finish());
-                let field_builders: Vec<Box<dyn ArrayBuilder>> = vec![
-                    Box::new(num_rows_builder),
-                    Box::new(num_batches_builder),
-                    Box::new(num_bytes_builder),
-                ];
-                let mut stats_builder = StructBuilder::new(
-                    PartitionStats::default().arrow_struct_fields(),
-                    field_builders,
-                );
-                for _ in 0..num_writers {
-                    stats_builder.append(true)?;
-                }
-                let stats = Arc::new(stats_builder.finish());
-
-                // build result batch containing metadata
-                let schema = result_schema();
-                let batch = RecordBatch::try_new(
-                    schema.clone(),
-                    vec![partition_num, path, stats],
-                )
-                .map_err(DataFusionError::ArrowError)?;
-
-                Ok(Box::pin(MemoryStream::try_new(vec![batch], schema, None)?))
+                Ok(part_locs)
             }
 
             _ => Err(DataFusionError::Execution(
@@ -336,9 +289,98 @@ impl ExecutionPlan for ShuffleWriterExec {
             )),
         }
     }
+}
+
+#[async_trait]
+impl ExecutionPlan for ShuffleWriterExec {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        self.plan.schema()
+    }
+
+    fn output_partitioning(&self) -> Partitioning {
+        // This operator needs to be executed once for each *input* partition 
and there
+        // isn't really a mechanism yet in DataFusion to support this use case 
so we report
+        // the input partitioning as the output partitioning here. The 
executor reports
+        // output partition meta data back to the scheduler.
+        self.plan.output_partitioning()
+    }
+
+    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+        vec![self.plan.clone()]
+    }
+
+    fn with_new_children(
+        &self,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        assert!(children.len() == 1);
+        Ok(Arc::new(ShuffleWriterExec::try_new(
+            self.job_id.clone(),
+            self.stage_id,
+            children[0].clone(),
+            self.work_dir.clone(),
+            self.shuffle_output_partitioning.clone(),
+        )?))
+    }
+
+    async fn execute(
+        &self,
+        input_partition: usize,
+    ) -> Result<Pin<Box<dyn RecordBatchStream + Send + Sync>>> {
+        let part_loc = self.execute_shuffle_write(input_partition).await?;
+
+        // build metadata result batch
+        let num_writers = part_loc.len();
+        let mut partition_builder = UInt32Builder::new(num_writers);
+        let mut path_builder = StringBuilder::new(num_writers);
+        let mut num_rows_builder = UInt64Builder::new(num_writers);
+        let mut num_batches_builder = UInt64Builder::new(num_writers);
+        let mut num_bytes_builder = UInt64Builder::new(num_writers);
+
+        for loc in &part_loc {
+            path_builder.append_value(loc.path.clone())?;
+            partition_builder.append_value(loc.partition_id as u32)?;
+            num_rows_builder.append_value(loc.num_rows)?;
+            num_batches_builder.append_value(loc.num_batches)?;
+            num_bytes_builder.append_value(loc.num_bytes)?;
+        }
+
+        // build arrays
+        let partition_num: ArrayRef = Arc::new(partition_builder.finish());
+        let path: ArrayRef = Arc::new(path_builder.finish());
+        let field_builders: Vec<Box<dyn ArrayBuilder>> = vec![
+            Box::new(num_rows_builder),
+            Box::new(num_batches_builder),
+            Box::new(num_bytes_builder),
+        ];
+        let mut stats_builder = StructBuilder::new(
+            PartitionStats::default().arrow_struct_fields(),
+            field_builders,
+        );
+        for _ in 0..num_writers {
+            stats_builder.append(true)?;
+        }
+        let stats = Arc::new(stats_builder.finish());
+
+        // build result batch containing metadata
+        let schema = result_schema();
+        let batch =
+            RecordBatch::try_new(schema.clone(), vec![partition_num, path, 
stats])
+                .map_err(DataFusionError::ArrowError)?;
+
+        debug!("RESULTS METADATA:\n{:?}", batch);
+
+        Ok(Box::pin(MemoryStream::try_new(vec![batch], schema, None)?))
+    }
 
     fn metrics(&self) -> HashMap<String, SQLMetric> {
         let mut metrics = HashMap::new();
+        metrics.insert("inputRows".to_owned(), 
(*self.metrics.input_rows).clone());
+        metrics.insert("outputRows".to_owned(), 
(*self.metrics.output_rows).clone());
         metrics.insert("writeTime".to_owned(), 
(*self.metrics.write_time).clone());
         metrics
     }
@@ -454,13 +496,13 @@ mod tests {
 
         let file0 = path.value(0);
         assert!(
-            file0.ends_with("/jobOne/1/0/data.arrow")
-                || file0.ends_with("\\jobOne\\1\\0\\data.arrow")
+            file0.ends_with("/jobOne/1/0/data-0.arrow")
+                || file0.ends_with("\\jobOne\\1\\0\\data-0.arrow")
         );
         let file1 = path.value(1);
         assert!(
-            file1.ends_with("/jobOne/1/1/data.arrow")
-                || file1.ends_with("\\jobOne\\1\\1\\data.arrow")
+            file1.ends_with("/jobOne/1/1/data-0.arrow")
+                || file1.ends_with("\\jobOne\\1\\1\\data-0.arrow")
         );
 
         let stats = batch.columns()[2]
diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs 
b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
index 0429efb..fa35eb4 100644
--- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
@@ -361,29 +361,33 @@ impl TryInto<protobuf::PhysicalPlanNode> for Arc<dyn 
ExecutionPlan> {
         } else if let Some(exec) = plan.downcast_ref::<ShuffleWriterExec>() {
             let input: protobuf::PhysicalPlanNode =
                 exec.children()[0].to_owned().try_into()?;
+            // note that we use shuffle_output_partitioning() rather than 
output_partitioning()
+            // to get the true output partitioning
+            let output_partitioning = match exec.shuffle_output_partitioning() 
{
+                Some(Partitioning::Hash(exprs, partition_count)) => {
+                    Some(protobuf::PhysicalHashRepartition {
+                        hash_expr: exprs
+                            .iter()
+                            .map(|expr| expr.clone().try_into())
+                            .collect::<Result<Vec<_>, BallistaError>>()?,
+                        partition_count: *partition_count as u64,
+                    })
+                }
+                None => None,
+                other => {
+                    return Err(BallistaError::General(format!(
+                        "physical_plan::to_proto() invalid partitioning for 
ShuffleWriterExec: {:?}",
+                        other
+                    )))
+                }
+            };
             Ok(protobuf::PhysicalPlanNode {
                 physical_plan_type: 
Some(PhysicalPlanType::ShuffleWriter(Box::new(
                     protobuf::ShuffleWriterExecNode {
                         job_id: exec.job_id().to_string(),
                         stage_id: exec.stage_id() as u32,
                         input: Some(Box::new(input)),
-                        output_partitioning: match exec.output_partitioning() {
-                            Partitioning::Hash(exprs, partition_count) => {
-                                Some(protobuf::PhysicalHashRepartition {
-                                    hash_expr: exprs
-                                        .iter()
-                                        .map(|expr| expr.clone().try_into())
-                                        .collect::<Result<Vec<_>, 
BallistaError>>()?,
-                                    partition_count: partition_count as u64,
-                                })
-                            }
-                            other => {
-                                return Err(BallistaError::General(format!(
-                                    "physical_plan::to_proto() invalid 
partitioning for ShuffleWriterExec: {:?}",
-                                    other
-                                )))
-                            }
-                        },
+                        output_partitioning,
                     },
                 ))),
             })
diff --git a/ballista/rust/core/src/serde/scheduler/from_proto.rs 
b/ballista/rust/core/src/serde/scheduler/from_proto.rs
index 73f8f53..4f9c9bc 100644
--- a/ballista/rust/core/src/serde/scheduler/from_proto.rs
+++ b/ballista/rust/core/src/serde/scheduler/from_proto.rs
@@ -102,6 +102,7 @@ impl TryInto<PartitionLocation> for 
protobuf::PartitionLocation {
                     )
                 })?
                 .into(),
+            path: self.path,
         })
     }
 }
diff --git a/ballista/rust/core/src/serde/scheduler/mod.rs 
b/ballista/rust/core/src/serde/scheduler/mod.rs
index fa2c1b8..eeddfbb 100644
--- a/ballista/rust/core/src/serde/scheduler/mod.rs
+++ b/ballista/rust/core/src/serde/scheduler/mod.rs
@@ -62,6 +62,7 @@ pub struct PartitionLocation {
     pub partition_id: PartitionId,
     pub executor_meta: ExecutorMeta,
     pub partition_stats: PartitionStats,
+    pub path: String,
 }
 
 /// Meta-data for an executor, used when fetching shuffle partitions from 
other executors
@@ -96,9 +97,9 @@ impl From<protobuf::ExecutorMetadata> for ExecutorMeta {
 /// Summary of executed partition
 #[derive(Debug, Copy, Clone)]
 pub struct PartitionStats {
-    num_rows: Option<u64>,
-    num_batches: Option<u64>,
-    num_bytes: Option<u64>,
+    pub(crate) num_rows: Option<u64>,
+    pub(crate) num_batches: Option<u64>,
+    pub(crate) num_bytes: Option<u64>,
 }
 
 impl Default for PartitionStats {
diff --git a/ballista/rust/core/src/serde/scheduler/to_proto.rs 
b/ballista/rust/core/src/serde/scheduler/to_proto.rs
index c3f2046..57d4f61 100644
--- a/ballista/rust/core/src/serde/scheduler/to_proto.rs
+++ b/ballista/rust/core/src/serde/scheduler/to_proto.rs
@@ -70,6 +70,7 @@ impl TryInto<protobuf::PartitionLocation> for 
PartitionLocation {
             partition_id: Some(self.partition_id.into()),
             executor_meta: Some(self.executor_meta.into()),
             partition_stats: Some(self.partition_stats.into()),
+            path: self.path,
         })
     }
 }
diff --git a/ballista/rust/executor/src/execution_loop.rs 
b/ballista/rust/executor/src/execution_loop.rs
index 17f6e4d..b65b83b 100644
--- a/ballista/rust/executor/src/execution_loop.rs
+++ b/ballista/rust/executor/src/execution_loop.rs
@@ -27,7 +27,8 @@ use tonic::transport::Channel;
 use ballista_core::serde::protobuf::ExecutorRegistration;
 use ballista_core::serde::protobuf::{
     self, scheduler_grpc_client::SchedulerGrpcClient, task_status, FailedTask,
-    PartitionId, PollWorkParams, PollWorkResult, TaskDefinition, TaskStatus,
+    PartitionId, PollWorkParams, PollWorkResult, ShuffleWritePartition, 
TaskDefinition,
+    TaskStatus,
 };
 use protobuf::CompletedTask;
 
@@ -110,7 +111,7 @@ async fn run_received_tasks(
 
     tokio::spawn(async move {
         let execution_result = executor
-            .execute_partition(
+            .execute_shuffle_write(
                 task_id.job_id.clone(),
                 task_id.stage_id as usize,
                 task_id.partition_id as usize,
@@ -121,7 +122,7 @@ async fn run_received_tasks(
         debug!("Statistics: {:?}", execution_result);
         available_tasks_slots.fetch_add(1, Ordering::SeqCst);
         let _ = task_status_sender.send(as_task_status(
-            execution_result.map(|_| ()),
+            execution_result,
             executor_id,
             task_id,
         ));
@@ -129,18 +130,19 @@ async fn run_received_tasks(
 }
 
 fn as_task_status(
-    execution_result: ballista_core::error::Result<()>,
+    execution_result: ballista_core::error::Result<Vec<ShuffleWritePartition>>,
     executor_id: String,
     task_id: PartitionId,
 ) -> TaskStatus {
     match execution_result {
-        Ok(_) => {
+        Ok(partitions) => {
             info!("Task {:?} finished", task_id);
 
             TaskStatus {
                 partition_id: Some(task_id),
                 status: Some(task_status::Status::Completed(CompletedTask {
                     executor_id,
+                    partitions,
                 })),
             }
         }
diff --git a/ballista/rust/executor/src/executor.rs 
b/ballista/rust/executor/src/executor.rs
index 4a75448..cbf3eb0 100644
--- a/ballista/rust/executor/src/executor.rs
+++ b/ballista/rust/executor/src/executor.rs
@@ -21,8 +21,7 @@ use std::sync::Arc;
 
 use ballista_core::error::BallistaError;
 use ballista_core::execution_plans::ShuffleWriterExec;
-use ballista_core::utils;
-use datafusion::arrow::record_batch::RecordBatch;
+use ballista_core::serde::protobuf;
 use datafusion::physical_plan::display::DisplayableExecutionPlan;
 use datafusion::physical_plan::ExecutionPlan;
 
@@ -45,22 +44,26 @@ impl Executor {
     /// Execute one partition of a query stage and persist the result to disk 
in IPC format. On
     /// success, return a RecordBatch containing metadata about the results, 
including path
     /// and statistics.
-    pub async fn execute_partition(
+    pub async fn execute_shuffle_write(
         &self,
         job_id: String,
         stage_id: usize,
         part: usize,
         plan: Arc<dyn ExecutionPlan>,
-    ) -> Result<RecordBatch, BallistaError> {
+    ) -> Result<Vec<protobuf::ShuffleWritePartition>, BallistaError> {
+        // TODO to enable shuffling we need to specify the output partitioning 
here and
+        // until we do that there is always a single output partition
+        // see https://github.com/apache/arrow-datafusion/issues/707
+        let shuffle_output_partitioning = None;
+
         let exec = ShuffleWriterExec::try_new(
             job_id,
             stage_id,
             plan,
             self.work_dir.clone(),
-            None,
+            shuffle_output_partitioning,
         )?;
-        let mut stream = exec.execute(part).await?;
-        let batches = utils::collect_stream(&mut stream).await?;
+        let partitions = exec.execute_shuffle_write(part).await?;
 
         println!(
             "=== Physical plan with metrics ===\n{}\n",
@@ -69,9 +72,7 @@ impl Executor {
                 .to_string()
         );
 
-        // the output should be a single batch containing metadata (path and 
statistics)
-        assert!(batches.len() == 1);
-        Ok(batches[0].clone())
+        Ok(partitions)
     }
 
     pub fn work_dir(&self) -> &str {
diff --git a/ballista/rust/scheduler/src/planner.rs 
b/ballista/rust/scheduler/src/planner.rs
index 3f90da2..11f5c99 100644
--- a/ballista/rust/scheduler/src/planner.rs
+++ b/ballista/rust/scheduler/src/planner.rs
@@ -108,6 +108,10 @@ impl DistributedPlanner {
             let query_stage = create_shuffle_writer(
                 job_id,
                 self.next_stage_id(),
+                //TODO should be children[0].clone() so that we replace this
+                // with an UnresolvedShuffleExec instead of just executing this
+                // part of the plan again
+                // see https://github.com/apache/arrow-datafusion/issues/707
                 coalesce.children()[0].clone(),
                 None,
             )?;
@@ -127,6 +131,10 @@ impl DistributedPlanner {
             let query_stage = create_shuffle_writer(
                 job_id,
                 self.next_stage_id(),
+                //TODO should be children[0].clone() so that we replace this
+                // with an UnresolvedShuffleExec instead of just executing this
+                // part of the plan again
+                // see https://github.com/apache/arrow-datafusion/issues/707
                 repart.children()[0].clone(),
                 Some(repart.partitioning().to_owned()),
             )?;
@@ -158,7 +166,7 @@ impl DistributedPlanner {
 
 pub fn remove_unresolved_shuffles(
     stage: &dyn ExecutionPlan,
-    partition_locations: &HashMap<usize, Vec<Vec<PartitionLocation>>>,
+    partition_locations: &HashMap<usize, HashMap<usize, 
Vec<PartitionLocation>>>,
 ) -> Result<Arc<dyn ExecutionPlan>> {
     let mut new_children: Vec<Arc<dyn ExecutionPlan>> = vec![];
     for child in stage.children() {
@@ -166,16 +174,30 @@ pub fn remove_unresolved_shuffles(
             child.as_any().downcast_ref::<UnresolvedShuffleExec>()
         {
             let mut relevant_locations = vec![];
-            relevant_locations.append(
-                &mut partition_locations
-                    .get(&unresolved_shuffle.stage_id)
-                    .ok_or_else(|| {
-                        BallistaError::General(
-                            "Missing partition location. Could not remove 
unresolved shuffles"
-                                .to_owned(),
-                        )
-                    })?
-                    .clone(),
+            let p = partition_locations
+                .get(&unresolved_shuffle.stage_id)
+                .ok_or_else(|| {
+                    BallistaError::General(
+                        "Missing partition location. Could not remove 
unresolved shuffles"
+                            .to_owned(),
+                    )
+                })?
+                .clone();
+
+            for i in 0..unresolved_shuffle.partition_count {
+                if let Some(x) = p.get(&i) {
+                    relevant_locations.push(x.to_owned());
+                } else {
+                    relevant_locations.push(vec![]);
+                }
+            }
+            println!(
+                "create shuffle reader with {:?}",
+                relevant_locations
+                    .iter()
+                    .map(|c| format!("{:?}", c))
+                    .collect::<Vec<_>>()
+                    .join("\n")
             );
             new_children.push(Arc::new(ShuffleReaderExec::try_new(
                 relevant_locations,
diff --git a/ballista/rust/scheduler/src/state/mod.rs 
b/ballista/rust/scheduler/src/state/mod.rs
index 3ddbced..a4ae59e 100644
--- a/ballista/rust/scheduler/src/state/mod.rs
+++ b/ballista/rust/scheduler/src/state/mod.rs
@@ -27,16 +27,13 @@ use prost::Message;
 use tokio::sync::OwnedMutexGuard;
 
 use ballista_core::serde::protobuf::{
-    job_status, task_status, CompletedJob, CompletedTask, ExecutorHeartbeat,
+    self, job_status, task_status, CompletedJob, CompletedTask, 
ExecutorHeartbeat,
     ExecutorMetadata, FailedJob, FailedTask, JobStatus, PhysicalPlanNode, 
RunningJob,
     RunningTask, TaskStatus,
 };
 use ballista_core::serde::scheduler::PartitionStats;
 use ballista_core::{error::BallistaError, serde::scheduler::ExecutorMeta};
-use ballista_core::{
-    error::Result, execution_plans::UnresolvedShuffleExec,
-    serde::protobuf::PartitionLocation,
-};
+use ballista_core::{error::Result, execution_plans::UnresolvedShuffleExec};
 
 use super::planner::remove_unresolved_shuffles;
 
@@ -254,9 +251,9 @@ impl SchedulerState {
         executors: &[ExecutorMeta],
     ) -> Result<bool> {
         let executor_id: &str = match &task_status.status {
-            Some(task_status::Status::Completed(CompletedTask { executor_id 
})) => {
-                executor_id
-            }
+            Some(task_status::Status::Completed(CompletedTask {
+                executor_id, ..
+            })) => executor_id,
             Some(task_status::Status::Running(RunningTask { executor_id })) => 
{
                 executor_id
             }
@@ -298,8 +295,11 @@ impl SchedulerState {
                 // Let's try to resolve any unresolved shuffles we find
                 let unresolved_shuffles = find_unresolved_shuffles(&plan)?;
                 let mut partition_locations: HashMap<
-                    usize,
-                    
Vec<Vec<ballista_core::serde::scheduler::PartitionLocation>>,
+                    usize, // stage id
+                    HashMap<
+                        usize,                                                 
  // shuffle input partition id
+                        
Vec<ballista_core::serde::scheduler::PartitionLocation>, // shuffle output 
partitions
+                    >,
                 > = HashMap::new();
                 for unresolved_shuffle in unresolved_shuffles {
                     for partition_id in 0..unresolved_shuffle.partition_count {
@@ -317,30 +317,49 @@ impl SchedulerState {
                         if task_is_dead {
                             continue 'tasks;
                         } else if let Some(task_status::Status::Completed(
-                            CompletedTask { executor_id },
+                            CompletedTask {
+                                executor_id,
+                                partitions,
+                            },
                         )) = &referenced_task.status
                         {
-                            let empty = vec![];
                             let locations = partition_locations
                                 .entry(unresolved_shuffle.stage_id)
-                                .or_insert(empty);
+                                .or_insert_with(HashMap::new);
                             let executor_meta = executors
                                 .iter()
                                 .find(|exec| exec.id == *executor_id)
                                 .unwrap()
                                 .clone();
-                            locations.push(vec![
-                                
ballista_core::serde::scheduler::PartitionLocation {
-                                    partition_id:
-                                        
ballista_core::serde::scheduler::PartitionId {
-                                            job_id: partition.job_id.clone(),
-                                            stage_id: 
unresolved_shuffle.stage_id,
-                                            partition_id,
-                                        },
-                                    executor_meta,
-                                    partition_stats: PartitionStats::default(),
-                                },
-                            ]);
+
+                            let temp =
+                                
locations.entry(partition_id).or_insert_with(Vec::new);
+                            for p in partitions {
+                                let executor_meta = executor_meta.clone();
+                                let partition_location =
+                                    
ballista_core::serde::scheduler::PartitionLocation {
+                                        partition_id:
+                                            
ballista_core::serde::scheduler::PartitionId {
+                                                job_id: 
partition.job_id.clone(),
+                                                stage_id: 
unresolved_shuffle.stage_id,
+                                                partition_id,
+                                            },
+                                        executor_meta,
+                                        partition_stats: PartitionStats::new(
+                                            Some(p.num_rows),
+                                            Some(p.num_batches),
+                                            Some(p.num_bytes),
+                                        ),
+                                        path: p.path.clone(),
+                                    };
+                                info!(
+                                    "Scheduler storing stage {} partition {} 
path: {}",
+                                    unresolved_shuffle.stage_id,
+                                    partition_id,
+                                    partition_location.path
+                                );
+                                temp.push(partition_location);
+                            }
                         } else {
                             continue 'tasks;
                         }
@@ -452,24 +471,39 @@ impl SchedulerState {
         let mut job_status = statuses
             .iter()
             .map(|status| match &status.status {
-                Some(task_status::Status::Completed(CompletedTask { 
executor_id })) => {
-                    Ok((status, executor_id))
-                }
+                Some(task_status::Status::Completed(CompletedTask {
+                    executor_id,
+                    partitions,
+                })) => Ok((status, executor_id, partitions)),
                 _ => Err(BallistaError::General("Task not 
completed".to_string())),
             })
             .collect::<Result<Vec<_>>>()
             .ok()
             .map(|info| {
-                let partition_location = info
-                    .into_iter()
-                    .map(|(status, execution_id)| PartitionLocation {
-                        partition_id: status.partition_id.to_owned(),
-                        executor_meta: executors
-                            .get(execution_id)
-                            .map(|e| e.clone().into()),
-                        partition_stats: None,
-                    })
-                    .collect();
+                let mut partition_location = vec![];
+                for (status, executor_id, partitions) in info {
+                    let input_partition_id = 
status.partition_id.as_ref().unwrap(); //TODO unwrap
+                    let executor_meta =
+                        executors.get(executor_id).map(|e| e.clone().into());
+                    for shuffle_write_partition in partitions {
+                        let shuffle_input_partition_id = 
Some(protobuf::PartitionId {
+                            job_id: input_partition_id.job_id.clone(),
+                            stage_id: input_partition_id.stage_id,
+                            partition_id: input_partition_id.partition_id,
+                        });
+                        partition_location.push(protobuf::PartitionLocation {
+                            partition_id: shuffle_input_partition_id.clone(),
+                            executor_meta: executor_meta.clone(),
+                            partition_stats: Some(protobuf::PartitionStats {
+                                num_batches: 
shuffle_write_partition.num_batches as i64,
+                                num_rows: shuffle_write_partition.num_rows as 
i64,
+                                num_bytes: shuffle_write_partition.num_bytes 
as i64,
+                                column_stats: vec![],
+                            }),
+                            path: shuffle_write_partition.path.clone(),
+                        });
+                    }
+                }
                 job_status::Status::Completed(CompletedJob { 
partition_location })
             });
 
@@ -745,6 +779,7 @@ mod test {
         let meta = TaskStatus {
             status: Some(task_status::Status::Completed(CompletedTask {
                 executor_id: "".to_owned(),
+                partitions: vec![],
             })),
             partition_id: Some(PartitionId {
                 job_id: job_id.to_owned(),
@@ -784,6 +819,7 @@ mod test {
         let meta = TaskStatus {
             status: Some(task_status::Status::Completed(CompletedTask {
                 executor_id: "".to_owned(),
+                partitions: vec![],
             })),
             partition_id: Some(PartitionId {
                 job_id: job_id.to_owned(),
@@ -821,6 +857,7 @@ mod test {
         let meta = TaskStatus {
             status: Some(task_status::Status::Completed(CompletedTask {
                 executor_id: "".to_owned(),
+                partitions: vec![],
             })),
             partition_id: Some(PartitionId {
                 job_id: job_id.to_owned(),
@@ -832,6 +869,7 @@ mod test {
         let meta = TaskStatus {
             status: Some(task_status::Status::Completed(CompletedTask {
                 executor_id: "".to_owned(),
+                partitions: vec![],
             })),
             partition_id: Some(PartitionId {
                 job_id: job_id.to_owned(),
@@ -863,6 +901,7 @@ mod test {
         let meta = TaskStatus {
             status: Some(task_status::Status::Completed(CompletedTask {
                 executor_id: "".to_owned(),
+                partitions: vec![],
             })),
             partition_id: Some(PartitionId {
                 job_id: job_id.to_owned(),
@@ -874,6 +913,7 @@ mod test {
         let meta = TaskStatus {
             status: Some(task_status::Status::Completed(CompletedTask {
                 executor_id: "".to_owned(),
+                partitions: vec![],
             })),
             partition_id: Some(PartitionId {
                 job_id: job_id.to_owned(),
@@ -905,6 +945,7 @@ mod test {
         let meta = TaskStatus {
             status: Some(task_status::Status::Completed(CompletedTask {
                 executor_id: "".to_owned(),
+                partitions: vec![],
             })),
             partition_id: Some(PartitionId {
                 job_id: job_id.to_owned(),

Reply via email to