This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 73ea6e1 [Ballista] Support Union in ballista. (#2098)
73ea6e1 is described below
commit 73ea6e16f5c8f34526c01490a5ec277a68f33791
Author: Yang Jiang <[email protected]>
AuthorDate: Sun Mar 27 18:43:08 2022 +0800
[Ballista] Support Union in ballista. (#2098)
* add Union messages (UnionNode, UnionExecNode) to ballista.proto
* add Union serialization/deserialization for Ballista logical and physical plans
* fix clippy
* add unit test
* fix clippy
* fix fmt
* fix comment
---
ballista/rust/client/src/context.rs | 60 +++++++++++++++++++++++
ballista/rust/core/proto/ballista.proto | 10 ++++
ballista/rust/core/src/serde/logical_plan/mod.rs | 37 +++++++++++++-
ballista/rust/core/src/serde/physical_plan/mod.rs | 21 ++++++++
datafusion/src/physical_plan/union.rs | 5 ++
5 files changed, 132 insertions(+), 1 deletion(-)
diff --git a/ballista/rust/client/src/context.rs
b/ballista/rust/client/src/context.rs
index 4a5fe6d..8db9a0c 100644
--- a/ballista/rust/client/src/context.rs
+++ b/ballista/rust/client/src/context.rs
@@ -561,4 +561,64 @@ mod tests {
let df = context.sql(sql).await.unwrap();
assert!(!df.collect().await.unwrap().is_empty());
}
+
+ #[tokio::test]
+ #[cfg(feature = "standalone")]
+ async fn test_union_and_union_all() {
+     use super::*;
+     use ballista_core::config::{
+         BallistaConfigBuilder, BALLISTA_WITH_INFORMATION_SCHEMA,
+     };
+     use datafusion::arrow::record_batch::RecordBatch;
+     use datafusion::arrow::util::pretty::pretty_format_batches;
+
+     /// Pretty-prints `batches` as a table and returns its trimmed lines,
+     /// ready to be compared against an expected table literal.
+     fn table_lines(batches: &[RecordBatch]) -> Vec<String> {
+         pretty_format_batches(batches)
+             .unwrap()
+             .to_string()
+             .trim()
+             .lines()
+             .map(String::from)
+             .collect()
+     }
+
+     let config = BallistaConfigBuilder::default()
+         .set(BALLISTA_WITH_INFORMATION_SCHEMA, "true")
+         .build()
+         .unwrap();
+     let context = BallistaContext::standalone(&config, 1).await.unwrap();
+
+     // UNION (distinct): the two identical rows collapse into one.
+     let df = context
+         .sql("SELECT 1 as NUMBER union SELECT 1 as NUMBER;")
+         .await
+         .unwrap();
+     let res1 = df.collect().await.unwrap();
+     let expected1 = vec![
+         "+--------+",
+         "| number |",
+         "+--------+",
+         "| 1 |",
+         "+--------+",
+     ];
+     assert_eq!(expected1, table_lines(&res1));
+
+     // UNION ALL: both rows are kept.
+     let expected2 = vec![
+         "+--------+",
+         "| number |",
+         "+--------+",
+         "| 1 |",
+         "| 1 |",
+         "+--------+",
+     ];
+     let df = context
+         .sql("SELECT 1 as NUMBER union all SELECT 1 as NUMBER;")
+         .await
+         .unwrap();
+     let res2 = df.collect().await.unwrap();
+     assert_eq!(expected2, table_lines(&res2));
+ }
}
diff --git a/ballista/rust/core/proto/ballista.proto
b/ballista/rust/core/proto/ballista.proto
index 5bb1289..4b49310 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -50,6 +50,7 @@ message LogicalPlanNode {
ValuesNode values = 16;
LogicalExtensionNode extension = 17;
CreateCatalogSchemaNode create_catalog_schema = 18;
+ UnionNode union = 19;
}
}
@@ -212,6 +213,10 @@ message JoinNode {
bool null_equals_null = 7;
}
+// Serialized LogicalPlan::Union: one LogicalPlanNode per union input.
+// Decoding rejects a UnionNode with fewer than two inputs.
+message UnionNode {
+ repeated LogicalPlanNode inputs = 1;
+}
+
message CrossJoinNode {
LogicalPlanNode left = 1;
LogicalPlanNode right = 2;
@@ -253,6 +258,7 @@ message PhysicalPlanNode {
CrossJoinExecNode cross_join = 19;
AvroScanExecNode avro_scan = 20;
PhysicalExtensionNode extension = 21;
+ UnionExecNode union = 22;
}
}
@@ -433,6 +439,10 @@ message HashJoinExecNode {
bool null_equals_null = 7;
}
+// Serialized physical UnionExec: one PhysicalPlanNode per child plan,
+// in the order returned by UnionExec::inputs().
+message UnionExecNode {
+ repeated PhysicalPlanNode inputs = 1;
+}
+
message CrossJoinExecNode {
PhysicalPlanNode left = 1;
PhysicalPlanNode right = 2;
diff --git a/ballista/rust/core/src/serde/logical_plan/mod.rs
b/ballista/rust/core/src/serde/logical_plan/mod.rs
index d64fac3..3f622e8 100644
--- a/ballista/rust/core/src/serde/logical_plan/mod.rs
+++ b/ballista/rust/core/src/serde/logical_plan/mod.rs
@@ -406,6 +406,25 @@ impl AsLogicalPlan for LogicalPlanNode {
builder.build().map_err(|e| e.into())
}
+ LogicalPlanType::Union(union) => {
+     // Decode every input first; the first failure aborts the whole plan.
+     let mut input_plans: Vec<LogicalPlan> = union
+         .inputs
+         .iter()
+         .map(|i| i.try_into_logical_plan(ctx, extension_codec))
+         .collect::<Result<_, BallistaError>>()?;
+
+     if input_plans.len() < 2 {
+         return Err(BallistaError::General(String::from(
+             "Protobuf deserialization error, Union requires at least two inputs.",
+         )));
+     }
+
+     // Seed the builder with the FIRST input and union the rest
+     // left-to-right so the rebuilt Union keeps the serialized input
+     // order; seeding with any other input would reorder the union.
+     let first = input_plans.remove(0);
+     let mut builder = LogicalPlanBuilder::from(first);
+     for plan in input_plans {
+         builder = builder.union(plan)?;
+     }
+     builder.build().map_err(|e| e.into())
+ }
LogicalPlanType::CrossJoin(crossjoin) => {
let left = into_logical_plan!(crossjoin.left, &ctx,
extension_codec)?;
let right = into_logical_plan!(crossjoin.right, &ctx,
extension_codec)?;
@@ -815,7 +834,23 @@ impl AsLogicalPlan for LogicalPlanNode {
))),
})
}
- LogicalPlan::Union(_) => unimplemented!(),
+ LogicalPlan::Union(union) => {
+     // Encode each input recursively; `?` stops at the first failure.
+     let mut inputs = Vec::with_capacity(union.inputs.len());
+     for input in &union.inputs {
+         let node = protobuf::LogicalPlanNode::try_from_logical_plan(
+             input,
+             extension_codec,
+         )?;
+         inputs.push(node);
+     }
+     let union_node = protobuf::UnionNode { inputs };
+     Ok(protobuf::LogicalPlanNode {
+         logical_plan_type: Some(LogicalPlanType::Union(union_node)),
+     })
+ }
LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => {
let left = protobuf::LogicalPlanNode::try_from_logical_plan(
left.as_ref(),
diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs
b/ballista/rust/core/src/serde/physical_plan/mod.rs
index 4b91a45..d7a8495 100644
--- a/ballista/rust/core/src/serde/physical_plan/mod.rs
+++ b/ballista/rust/core/src/serde/physical_plan/mod.rs
@@ -52,6 +52,7 @@ use datafusion::physical_plan::limit::{GlobalLimitExec,
LocalLimitExec};
use datafusion::physical_plan::projection::ProjectionExec;
use datafusion::physical_plan::repartition::RepartitionExec;
use datafusion::physical_plan::sorts::sort::SortExec;
+use datafusion::physical_plan::union::UnionExec;
use datafusion::physical_plan::windows::{create_window_expr, WindowAggExec};
use datafusion::physical_plan::{
AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, WindowExpr,
@@ -382,6 +383,13 @@ impl AsExecutionPlan for PhysicalPlanNode {
&hashjoin.null_equals_null,
)?))
}
+ PhysicalPlanType::Union(union) => {
+     // NOTE(review): a UnionExec with no children has no input to derive
+     // its output from, so reject a malformed message here rather than
+     // building an unusable plan (mirrors the logical-plan input check).
+     if union.inputs.is_empty() {
+         return Err(BallistaError::General(String::from(
+             "Protobuf deserialization error, Union requires at least one input.",
+         )));
+     }
+     // Decode children in their serialized order; the first failure aborts.
+     let inputs = union
+         .inputs
+         .iter()
+         .map(|input| input.try_into_physical_plan(ctx, extension_codec))
+         .collect::<Result<Vec<Arc<dyn ExecutionPlan>>, _>>()?;
+     Ok(Arc::new(UnionExec::new(inputs)))
+ }
PhysicalPlanType::CrossJoin(crossjoin) => {
let left: Arc<dyn ExecutionPlan> =
into_physical_plan!(crossjoin.left, ctx, extension_codec)?;
@@ -866,6 +874,19 @@ impl AsExecutionPlan for PhysicalPlanNode {
},
)),
})
+ } else if let Some(union) = plan.downcast_ref::<UnionExec>() {
+     // Encode every child plan, preserving the order UnionExec stores them in.
+     let inputs = union
+         .inputs()
+         .iter()
+         .map(|input| {
+             protobuf::PhysicalPlanNode::try_from_physical_plan(
+                 Arc::clone(input),
+                 extension_codec,
+             )
+         })
+         .collect::<Result<Vec<_>, _>>()?;
+     let union_node = protobuf::UnionExecNode { inputs };
+     Ok(protobuf::PhysicalPlanNode {
+         physical_plan_type: Some(PhysicalPlanType::Union(union_node)),
+     })
} else {
let mut buf: Vec<u8> = vec![];
extension_codec.try_encode(plan_clone.clone(), &mut buf)?;
diff --git a/datafusion/src/physical_plan/union.rs
b/datafusion/src/physical_plan/union.rs
index fb25cf3..bf6fb7c 100644
--- a/datafusion/src/physical_plan/union.rs
+++ b/datafusion/src/physical_plan/union.rs
@@ -56,6 +56,11 @@ impl UnionExec {
metrics: ExecutionPlanMetricsSet::new(),
}
}
+
+    /// Returns the child execution plans that feed this union.
+    pub fn inputs(&self) -> &Vec<Arc<dyn ExecutionPlan>> {
+        let Self { inputs, .. } = self;
+        inputs
+    }
}
#[async_trait]