This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 73ea6e1  [Ballista] Support Union in ballista. (#2098)
73ea6e1 is described below

commit 73ea6e16f5c8f34526c01490a5ec277a68f33791
Author: Yang Jiang <[email protected]>
AuthorDate: Sun Mar 27 18:43:08 2022 +0800

    [Ballista] Support Union in ballista. (#2098)
    
    * add union in ballista.proto
    
    * add ballista plan to proto
    
    * fix clippy
    
    * add unit test
    
    * fix clippy
    
    * fix fmt
    
    * fix comment
---
 ballista/rust/client/src/context.rs               | 60 +++++++++++++++++++++++
 ballista/rust/core/proto/ballista.proto           | 10 ++++
 ballista/rust/core/src/serde/logical_plan/mod.rs  | 37 +++++++++++++-
 ballista/rust/core/src/serde/physical_plan/mod.rs | 21 ++++++++
 datafusion/src/physical_plan/union.rs             |  5 ++
 5 files changed, 132 insertions(+), 1 deletion(-)

diff --git a/ballista/rust/client/src/context.rs b/ballista/rust/client/src/context.rs
index 4a5fe6d..8db9a0c 100644
--- a/ballista/rust/client/src/context.rs
+++ b/ballista/rust/client/src/context.rs
@@ -561,4 +561,50 @@ mod tests {
         let df = context.sql(sql).await.unwrap();
         assert!(!df.collect().await.unwrap().is_empty());
     }
+
+    #[tokio::test]
+    #[cfg(feature = "standalone")]
+    async fn test_union_and_union_all() {
+        use super::*;
+        use ballista_core::config::{
+            BallistaConfigBuilder, BALLISTA_WITH_INFORMATION_SCHEMA,
+        };
+        use datafusion::assert_batches_eq;
+        let config = BallistaConfigBuilder::default()
+            .set(BALLISTA_WITH_INFORMATION_SCHEMA, "true")
+            .build()
+            .unwrap();
+        let context = BallistaContext::standalone(&config, 1).await.unwrap();
+
+        // UNION removes duplicates, so the two identical rows collapse to one.
+        let df = context
+            .sql("SELECT 1 as NUMBER union SELECT 1 as NUMBER;")
+            .await
+            .unwrap();
+        let res1 = df.collect().await.unwrap();
+        let expected1 = vec![
+            "+--------+",
+            "| number |",
+            "+--------+",
+            "| 1      |",
+            "+--------+",
+        ];
+        assert_batches_eq!(expected1, &res1);
+
+        // UNION ALL keeps duplicates, so both rows are returned.
+        let df = context
+            .sql("SELECT 1 as NUMBER union all SELECT 1 as NUMBER;")
+            .await
+            .unwrap();
+        let res2 = df.collect().await.unwrap();
+        let expected2 = vec![
+            "+--------+",
+            "| number |",
+            "+--------+",
+            "| 1      |",
+            "| 1      |",
+            "+--------+",
+        ];
+        assert_batches_eq!(expected2, &res2);
+    }
 }
diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index 5bb1289..4b49310 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -50,6 +50,7 @@ message LogicalPlanNode {
     ValuesNode values = 16;
     LogicalExtensionNode extension = 17;
     CreateCatalogSchemaNode create_catalog_schema = 18;
+    UnionNode union = 19;
   }
 }
 
@@ -212,6 +213,11 @@ message JoinNode {
   bool null_equals_null = 7;
 }
 
+// A logical Union: the serialized input plans, in order (at least two).
+message UnionNode {
+  repeated LogicalPlanNode inputs = 1;
+}
+
 message CrossJoinNode {
   LogicalPlanNode left = 1;
   LogicalPlanNode right = 2;
@@ -253,6 +259,7 @@ message PhysicalPlanNode {
     CrossJoinExecNode cross_join = 19;
     AvroScanExecNode avro_scan = 20;
     PhysicalExtensionNode extension = 21;
+    UnionExecNode union = 22;
   }
 }
 
@@ -433,6 +440,11 @@ message HashJoinExecNode {
   bool null_equals_null = 7;
 }
 
+// A physical UnionExec: the serialized child plans, in order.
+message UnionExecNode {
+  repeated PhysicalPlanNode inputs = 1;
+}
+
 message CrossJoinExecNode {
   PhysicalPlanNode left = 1;
   PhysicalPlanNode right = 2;
diff --git a/ballista/rust/core/src/serde/logical_plan/mod.rs b/ballista/rust/core/src/serde/logical_plan/mod.rs
index d64fac3..3f622e8 100644
--- a/ballista/rust/core/src/serde/logical_plan/mod.rs
+++ b/ballista/rust/core/src/serde/logical_plan/mod.rs
@@ -406,6 +406,29 @@ impl AsLogicalPlan for LogicalPlanNode {
 
                 builder.build().map_err(|e| e.into())
             }
+            LogicalPlanType::Union(union) => {
+                let input_plans: Vec<LogicalPlan> = union
+                    .inputs
+                    .iter()
+                    .map(|i| i.try_into_logical_plan(ctx, extension_codec))
+                    .collect::<Result<_, BallistaError>>()?;
+
+                if input_plans.len() < 2 {
+                    return Err(BallistaError::General(String::from(
+                        "Protobuf deserialization error, Union requires at least two inputs.",
+                    )));
+                }
+
+                // Rebuild the union in the serialized input order: the first
+                // input seeds the builder and the rest are appended after it.
+                let mut input_plans = input_plans.into_iter();
+                let first = input_plans.next().expect("length checked above");
+                let mut builder = LogicalPlanBuilder::from(first);
+                for plan in input_plans {
+                    builder = builder.union(plan)?;
+                }
+                builder.build().map_err(|e| e.into())
+            }
             LogicalPlanType::CrossJoin(crossjoin) => {
                 let left = into_logical_plan!(crossjoin.left, &ctx, extension_codec)?;
                 let right = into_logical_plan!(crossjoin.right, &ctx, extension_codec)?;
@@ -815,7 +838,23 @@ impl AsLogicalPlan for LogicalPlanNode {
                     ))),
                 })
             }
-            LogicalPlan::Union(_) => unimplemented!(),
+            LogicalPlan::Union(union) => {
+                let inputs: Vec<LogicalPlanNode> = union
+                    .inputs
+                    .iter()
+                    .map(|i| {
+                        protobuf::LogicalPlanNode::try_from_logical_plan(
+                            i,
+                            extension_codec,
+                        )
+                    })
+                    .collect::<Result<_, BallistaError>>()?;
+                Ok(protobuf::LogicalPlanNode {
+                    logical_plan_type: Some(LogicalPlanType::Union(
+                        protobuf::UnionNode { inputs },
+                    )),
+                })
+            }
             LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => {
                 let left = protobuf::LogicalPlanNode::try_from_logical_plan(
                     left.as_ref(),
diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs b/ballista/rust/core/src/serde/physical_plan/mod.rs
index 4b91a45..d7a8495 100644
--- a/ballista/rust/core/src/serde/physical_plan/mod.rs
+++ b/ballista/rust/core/src/serde/physical_plan/mod.rs
@@ -52,6 +52,7 @@ use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
 use datafusion::physical_plan::projection::ProjectionExec;
 use datafusion::physical_plan::repartition::RepartitionExec;
 use datafusion::physical_plan::sorts::sort::SortExec;
+use datafusion::physical_plan::union::UnionExec;
 use datafusion::physical_plan::windows::{create_window_expr, WindowAggExec};
 use datafusion::physical_plan::{
     AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, WindowExpr,
@@ -382,6 +383,21 @@ impl AsExecutionPlan for PhysicalPlanNode {
                     &hashjoin.null_equals_null,
                 )?))
             }
+            PhysicalPlanType::Union(union) => {
+                // Reject an empty union up front instead of building a
+                // UnionExec with no child plans.
+                if union.inputs.is_empty() {
+                    return Err(BallistaError::General(String::from(
+                        "Protobuf deserialization error, Union requires at least one input.",
+                    )));
+                }
+                let inputs: Vec<Arc<dyn ExecutionPlan>> = union
+                    .inputs
+                    .iter()
+                    .map(|input| input.try_into_physical_plan(ctx, extension_codec))
+                    .collect::<Result<_, BallistaError>>()?;
+                Ok(Arc::new(UnionExec::new(inputs)))
+            }
             PhysicalPlanType::CrossJoin(crossjoin) => {
                 let left: Arc<dyn ExecutionPlan> =
                     into_physical_plan!(crossjoin.left, ctx, extension_codec)?;
@@ -866,6 +882,19 @@ impl AsExecutionPlan for PhysicalPlanNode {
                     },
                 )),
             })
+        } else if let Some(union) = plan.downcast_ref::<UnionExec>() {
+            let mut inputs: Vec<PhysicalPlanNode> = vec![];
+            for input in union.inputs() {
+                inputs.push(protobuf::PhysicalPlanNode::try_from_physical_plan(
+                    input.to_owned(),
+                    extension_codec,
+                )?);
+            }
+            Ok(protobuf::PhysicalPlanNode {
+                physical_plan_type: Some(PhysicalPlanType::Union(
+                    protobuf::UnionExecNode { inputs },
+                )),
+            })
         } else {
             let mut buf: Vec<u8> = vec![];
             extension_codec.try_encode(plan_clone.clone(), &mut buf)?;
diff --git a/datafusion/src/physical_plan/union.rs b/datafusion/src/physical_plan/union.rs
index fb25cf3..bf6fb7c 100644
--- a/datafusion/src/physical_plan/union.rs
+++ b/datafusion/src/physical_plan/union.rs
@@ -56,6 +56,11 @@ impl UnionExec {
             metrics: ExecutionPlanMetricsSet::new(),
         }
     }
+
+    /// Returns the input (child) execution plans of this union.
+    pub fn inputs(&self) -> &Vec<Arc<dyn ExecutionPlan>> {
+        &self.inputs
+    }
 }
 
 #[async_trait]

Reply via email to