This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 81c915c58e Add distinct union optimization (#7788)
81c915c58e is described below
commit 81c915c58e90967b1586fc0b3145194a6fddf6f5
Author: Eugene Marushchenko <[email protected]>
AuthorDate: Sat Oct 14 03:24:37 2023 +1000
Add distinct union optimization (#7788)
* Add eliminate_distinct_nested_union spaghetti implementation
* update test
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/optimizer/src/eliminate_nested_union.rs | 208 +++++++++++++++++++--
datafusion/sqllogictest/test_files/union.slt | 24 +--
2 files changed, 201 insertions(+), 31 deletions(-)
diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs
b/datafusion/optimizer/src/eliminate_nested_union.rs
index e22c73e579..89bcc90bc0 100644
--- a/datafusion/optimizer/src/eliminate_nested_union.rs
+++ b/datafusion/optimizer/src/eliminate_nested_union.rs
@@ -16,12 +16,11 @@
// under the License.
//! Optimizer rule to replace nested unions to single union.
+use crate::optimizer::ApplyOrder;
use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::Result;
-use datafusion_expr::logical_plan::{LogicalPlan, Union};
-
-use crate::optimizer::ApplyOrder;
use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
+use datafusion_expr::{Distinct, LogicalPlan, Union};
use std::sync::Arc;
#[derive(Default)]
@@ -41,22 +40,11 @@ impl OptimizerRule for EliminateNestedUnion {
plan: &LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
- // TODO: Add optimization for nested distinct unions.
match plan {
LogicalPlan::Union(Union { inputs, schema }) => {
let inputs = inputs
.iter()
- .flat_map(|plan| match plan.as_ref() {
- LogicalPlan::Union(Union { inputs, schema }) => inputs
- .iter()
- .map(|plan| {
- Arc::new(
- coerce_plan_expr_for_schema(plan,
schema).unwrap(),
- )
- })
- .collect::<Vec<_>>(),
- _ => vec![plan.clone()],
- })
+ .flat_map(extract_plans_from_union)
.collect::<Vec<_>>();
Ok(Some(LogicalPlan::Union(Union {
@@ -64,6 +52,23 @@ impl OptimizerRule for EliminateNestedUnion {
schema: schema.clone(),
})))
}
+ LogicalPlan::Distinct(Distinct { input: plan }) => match
plan.as_ref() {
+ LogicalPlan::Union(Union { inputs, schema }) => {
+ let inputs = inputs
+ .iter()
+ .map(extract_plan_from_distinct)
+ .flat_map(extract_plans_from_union)
+ .collect::<Vec<_>>();
+
+ Ok(Some(LogicalPlan::Distinct(Distinct {
+ input: Arc::new(LogicalPlan::Union(Union {
+ inputs,
+ schema: schema.clone(),
+ })),
+ })))
+ }
+ _ => Ok(None),
+ },
_ => Ok(None),
}
}
@@ -77,6 +82,23 @@ impl OptimizerRule for EliminateNestedUnion {
}
}
+fn extract_plans_from_union(plan: &Arc<LogicalPlan>) -> Vec<Arc<LogicalPlan>> {
+ match plan.as_ref() {
+ LogicalPlan::Union(Union { inputs, schema }) => inputs
+ .iter()
+ .map(|plan| Arc::new(coerce_plan_expr_for_schema(plan,
schema).unwrap()))
+ .collect::<Vec<_>>(),
+ _ => vec![plan.clone()],
+ }
+}
+
+fn extract_plan_from_distinct(plan: &Arc<LogicalPlan>) -> &Arc<LogicalPlan> {
+ match plan.as_ref() {
+ LogicalPlan::Distinct(Distinct { input: plan }) => plan,
+ _ => plan,
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -112,6 +134,22 @@ mod tests {
assert_optimized_plan_equal(&plan, expected)
}
+ #[test]
+ fn eliminate_distinct_nothing() -> Result<()> {
+ let plan_builder = table_scan(Some("table"), &schema(), None)?;
+
+ let plan = plan_builder
+ .clone()
+ .union_distinct(plan_builder.clone().build()?)?
+ .build()?;
+
+ let expected = "Distinct:\
+ \n Union\
+ \n TableScan: table\
+ \n TableScan: table";
+ assert_optimized_plan_equal(&plan, expected)
+ }
+
#[test]
fn eliminate_nested_union() -> Result<()> {
let plan_builder = table_scan(Some("table"), &schema(), None)?;
@@ -132,6 +170,69 @@ mod tests {
assert_optimized_plan_equal(&plan, expected)
}
+ #[test]
+ fn eliminate_nested_union_with_distinct_union() -> Result<()> {
+ let plan_builder = table_scan(Some("table"), &schema(), None)?;
+
+ let plan = plan_builder
+ .clone()
+ .union_distinct(plan_builder.clone().build()?)?
+ .union(plan_builder.clone().build()?)?
+ .union(plan_builder.clone().build()?)?
+ .build()?;
+
+ let expected = "Union\
+ \n Distinct:\
+ \n Union\
+ \n TableScan: table\
+ \n TableScan: table\
+ \n TableScan: table\
+ \n TableScan: table";
+ assert_optimized_plan_equal(&plan, expected)
+ }
+
+ #[test]
+ fn eliminate_nested_distinct_union() -> Result<()> {
+ let plan_builder = table_scan(Some("table"), &schema(), None)?;
+
+ let plan = plan_builder
+ .clone()
+ .union(plan_builder.clone().build()?)?
+ .union_distinct(plan_builder.clone().build()?)?
+ .union(plan_builder.clone().build()?)?
+ .union_distinct(plan_builder.clone().build()?)?
+ .build()?;
+
+ let expected = "Distinct:\
+ \n Union\
+ \n TableScan: table\
+ \n TableScan: table\
+ \n TableScan: table\
+ \n TableScan: table\
+ \n TableScan: table";
+ assert_optimized_plan_equal(&plan, expected)
+ }
+
+ #[test]
+ fn eliminate_nested_distinct_union_with_distinct_table() -> Result<()> {
+ let plan_builder = table_scan(Some("table"), &schema(), None)?;
+
+ let plan = plan_builder
+ .clone()
+ .union_distinct(plan_builder.clone().distinct()?.build()?)?
+ .union(plan_builder.clone().distinct()?.build()?)?
+ .union_distinct(plan_builder.clone().build()?)?
+ .build()?;
+
+ let expected = "Distinct:\
+ \n Union\
+ \n TableScan: table\
+ \n TableScan: table\
+ \n TableScan: table\
+ \n TableScan: table";
+ assert_optimized_plan_equal(&plan, expected)
+ }
+
// We don't need to use project_with_column_index in logical optimizer,
// after LogicalPlanBuilder::union, we already have all equal expression
aliases
#[test]
@@ -163,6 +264,36 @@ mod tests {
assert_optimized_plan_equal(&plan, expected)
}
+ #[test]
+ fn eliminate_nested_distinct_union_with_projection() -> Result<()> {
+ let plan_builder = table_scan(Some("table"), &schema(), None)?;
+
+ let plan = plan_builder
+ .clone()
+ .union_distinct(
+ plan_builder
+ .clone()
+ .project(vec![col("id").alias("table_id"), col("key"),
col("value")])?
+ .build()?,
+ )?
+ .union_distinct(
+ plan_builder
+ .clone()
+ .project(vec![col("id").alias("_id"), col("key"),
col("value")])?
+ .build()?,
+ )?
+ .build()?;
+
+ let expected = "Distinct:\
+ \n Union\
+ \n TableScan: table\
+ \n Projection: table.id AS id, table.key, table.value\
+ \n TableScan: table\
+ \n Projection: table.id AS id, table.key, table.value\
+ \n TableScan: table";
+ assert_optimized_plan_equal(&plan, expected)
+ }
+
#[test]
fn eliminate_nested_union_with_type_cast_projection() -> Result<()> {
let table_1 = table_scan(
@@ -208,4 +339,51 @@ mod tests {
\n TableScan: table_1";
assert_optimized_plan_equal(&plan, expected)
}
+
+ #[test]
+ fn eliminate_nested_distinct_union_with_type_cast_projection() ->
Result<()> {
+ let table_1 = table_scan(
+ Some("table_1"),
+ &Schema::new(vec![
+ Field::new("id", DataType::Int64, false),
+ Field::new("key", DataType::Utf8, false),
+ Field::new("value", DataType::Float64, false),
+ ]),
+ None,
+ )?;
+
+ let table_2 = table_scan(
+ Some("table_1"),
+ &Schema::new(vec![
+ Field::new("id", DataType::Int32, false),
+ Field::new("key", DataType::Utf8, false),
+ Field::new("value", DataType::Float32, false),
+ ]),
+ None,
+ )?;
+
+ let table_3 = table_scan(
+ Some("table_1"),
+ &Schema::new(vec![
+ Field::new("id", DataType::Int16, false),
+ Field::new("key", DataType::Utf8, false),
+ Field::new("value", DataType::Float32, false),
+ ]),
+ None,
+ )?;
+
+ let plan = table_1
+ .union_distinct(table_2.build()?)?
+ .union_distinct(table_3.build()?)?
+ .build()?;
+
+ let expected = "Distinct:\
+ \n Union\
+ \n TableScan: table_1\
+ \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key,
CAST(table_1.value AS Float64) AS value\
+ \n TableScan: table_1\
+ \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key,
CAST(table_1.value AS Float64) AS value\
+ \n TableScan: table_1";
+ assert_optimized_plan_equal(&plan, expected)
+ }
}
diff --git a/datafusion/sqllogictest/test_files/union.slt
b/datafusion/sqllogictest/test_files/union.slt
index b11a687d8b..cbb1896efb 100644
--- a/datafusion/sqllogictest/test_files/union.slt
+++ b/datafusion/sqllogictest/test_files/union.slt
@@ -186,8 +186,7 @@ Bob_new
John
John_new
-# should be un-nested
-# https://github.com/apache/arrow-datafusion/issues/7786
+# should be un-nested, with a single (logical) aggregate
query TT
EXPLAIN SELECT name FROM t1 UNION (SELECT name from t2 UNION SELECT name ||
'_new' from t2)
----
@@ -195,26 +194,19 @@ logical_plan
Aggregate: groupBy=[[t1.name]], aggr=[[]]
--Union
----TableScan: t1 projection=[name]
-----Aggregate: groupBy=[[t2.name]], aggr=[[]]
-------Union
---------TableScan: t2 projection=[name]
---------Projection: t2.name || Utf8("_new") AS name
-----------TableScan: t2 projection=[name]
+----TableScan: t2 projection=[name]
+----Projection: t2.name || Utf8("_new") AS name
+------TableScan: t2 projection=[name]
physical_plan
AggregateExec: mode=FinalPartitioned, gby=[name@0 as name], aggr=[]
--CoalesceBatchesExec: target_batch_size=8192
-----RepartitionExec: partitioning=Hash([name@0], 4), input_partitions=8
+----RepartitionExec: partitioning=Hash([name@0], 4), input_partitions=12
------AggregateExec: mode=Partial, gby=[name@0 as name], aggr=[]
--------UnionExec
----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0]
-----------AggregateExec: mode=FinalPartitioned, gby=[name@0 as name], aggr=[]
-------------CoalesceBatchesExec: target_batch_size=8192
---------------RepartitionExec: partitioning=Hash([name@0], 4),
input_partitions=8
-----------------AggregateExec: mode=Partial, gby=[name@0 as name], aggr=[]
-------------------UnionExec
---------------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0]
---------------------ProjectionExec: expr=[name@0 || _new as name]
-----------------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0]
+----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0]
+----------ProjectionExec: expr=[name@0 || _new as name]
+------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0]
# nested_union_all
query T rowsort