This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new f29bcf3618 Support no distinct aggregate sum/min/max in
`single_distinct_to_group_by` rule (#8266)
f29bcf3618 is described below
commit f29bcf36184691dc0417b6be2eb3e33fa8a6f1cc
Author: Huaijin <[email protected]>
AuthorDate: Sun Nov 26 19:53:46 2023 +0800
Support no distinct aggregate sum/min/max in `single_distinct_to_group_by`
rule (#8266)
* init impl
* add some tests
* add filter tests
* minor
* add more tests
* update test
---
.../optimizer/src/single_distinct_to_groupby.rs | 280 ++++++++++++++++++---
datafusion/sqllogictest/test_files/groupby.slt | 82 ++++++
2 files changed, 330 insertions(+), 32 deletions(-)
diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs
b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index ac18e596b7..fa142438c4 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -24,6 +24,7 @@ use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::{DFSchema, Result};
use datafusion_expr::{
+ aggregate_function::AggregateFunction::{Max, Min, Sum},
col,
expr::AggregateFunction,
logical_plan::{Aggregate, LogicalPlan, Projection},
@@ -35,17 +36,19 @@ use hashbrown::HashSet;
/// single distinct to group by optimizer rule
/// ```text
-/// SELECT F1(DISTINCT s),F2(DISTINCT s)
-/// ...
-/// GROUP BY k
+/// Before:
+/// SELECT a, COUNT(DISTINCT b), SUM(c)
+/// FROM t
+/// GROUP BY a
///
-/// Into
-///
-/// SELECT F1(alias1),F2(alias1)
+/// After:
+/// SELECT a, COUNT(alias1), SUM(alias2)
/// FROM (
-/// SELECT s as alias1, k ... GROUP BY s, k
+/// SELECT a, b as alias1, SUM(c) as alias2
+/// FROM t
+/// GROUP BY a, b
/// )
-/// GROUP BY k
+/// GROUP BY a
/// ```
#[derive(Default)]
pub struct SingleDistinctToGroupBy {}
@@ -64,22 +67,30 @@ fn is_single_distinct_agg(plan: &LogicalPlan) ->
Result<bool> {
match plan {
LogicalPlan::Aggregate(Aggregate { aggr_expr, .. }) => {
let mut fields_set = HashSet::new();
- let mut distinct_count = 0;
+ let mut aggregate_count = 0;
for expr in aggr_expr {
if let Expr::AggregateFunction(AggregateFunction {
- distinct, args, ..
+ fun,
+ distinct,
+ args,
+ filter,
+ order_by,
}) = expr
{
- if *distinct {
- distinct_count += 1;
+ if filter.is_some() || order_by.is_some() {
+ return Ok(false);
}
- for e in args {
- fields_set.insert(e.canonical_name());
+ aggregate_count += 1;
+ if *distinct {
+ for e in args {
+ fields_set.insert(e.canonical_name());
+ }
+ } else if !matches!(fun, Sum | Min | Max) {
+ return Ok(false);
}
}
}
- let res = distinct_count == aggr_expr.len() && fields_set.len() ==
1;
- Ok(res)
+ Ok(aggregate_count == aggr_expr.len() && fields_set.len() == 1)
}
_ => Ok(false),
}
@@ -152,30 +163,57 @@ impl OptimizerRule for SingleDistinctToGroupBy {
.collect::<Vec<_>>();
// replace the distinct arg with alias
+ let mut index = 1;
let mut group_fields_set = HashSet::new();
- let new_aggr_exprs = aggr_expr
+ let mut inner_aggr_exprs = vec![];
+ let outer_aggr_exprs = aggr_expr
.iter()
.map(|aggr_expr| match aggr_expr {
Expr::AggregateFunction(AggregateFunction {
fun,
args,
- filter,
- order_by,
+ distinct,
..
}) => {
// is_single_distinct_agg ensure args.len=1
- if
group_fields_set.insert(args[0].display_name()?) {
+ if *distinct
+ &&
group_fields_set.insert(args[0].display_name()?)
+ {
inner_group_exprs.push(
args[0].clone().alias(SINGLE_DISTINCT_ALIAS),
);
}
-
Ok(Expr::AggregateFunction(AggregateFunction::new(
- fun.clone(),
- vec![col(SINGLE_DISTINCT_ALIAS)],
- false, // intentional to remove distinct
here
- filter.clone(),
- order_by.clone(),
- )))
+
+ // if the aggregate function is not distinct,
we rewrite it as a two-phase aggregation
+ if !(*distinct) {
+ index += 1;
+ let alias_str = format!("alias{}", index);
+ inner_aggr_exprs.push(
+
Expr::AggregateFunction(AggregateFunction::new(
+ fun.clone(),
+ args.clone(),
+ false,
+ None,
+ None,
+ ))
+ .alias(&alias_str),
+ );
+
Ok(Expr::AggregateFunction(AggregateFunction::new(
+ fun.clone(),
+ vec![col(&alias_str)],
+ false,
+ None,
+ None,
+ )))
+ } else {
+
Ok(Expr::AggregateFunction(AggregateFunction::new(
+ fun.clone(),
+ vec![col(SINGLE_DISTINCT_ALIAS)],
+ false, // intentional to remove
distinct here
+ None,
+ None,
+ )))
+ }
}
_ => Ok(aggr_expr.clone()),
})
@@ -184,6 +222,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
// construct the inner AggrPlan
let inner_fields = inner_group_exprs
.iter()
+ .chain(inner_aggr_exprs.iter())
.map(|expr| expr.to_field(input.schema()))
.collect::<Result<Vec<_>>>()?;
let inner_schema = DFSchema::new_with_metadata(
@@ -193,12 +232,12 @@ impl OptimizerRule for SingleDistinctToGroupBy {
let inner_agg = LogicalPlan::Aggregate(Aggregate::try_new(
input.clone(),
inner_group_exprs,
- Vec::new(),
+ inner_aggr_exprs,
)?);
let outer_fields = outer_group_exprs
.iter()
- .chain(new_aggr_exprs.iter())
+ .chain(outer_aggr_exprs.iter())
.map(|expr| expr.to_field(&inner_schema))
.collect::<Result<Vec<_>>>()?;
let outer_aggr_schema =
Arc::new(DFSchema::new_with_metadata(
@@ -220,7 +259,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
group_expr
}
})
- .chain(new_aggr_exprs.iter().enumerate().map(|(idx,
expr)| {
+ .chain(outer_aggr_exprs.iter().enumerate().map(|(idx,
expr)| {
let idx = idx + group_size;
let name = fields[idx].qualified_name();
columnize_expr(expr.clone().alias(name),
&outer_aggr_schema)
@@ -230,7 +269,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
let outer_aggr = LogicalPlan::Aggregate(Aggregate::try_new(
Arc::new(inner_agg),
outer_group_exprs,
- new_aggr_exprs,
+ outer_aggr_exprs,
)?);
Ok(Some(LogicalPlan::Projection(Projection::try_new(
@@ -262,7 +301,7 @@ mod tests {
use datafusion_expr::expr::GroupingSet;
use datafusion_expr::{
col, count, count_distinct, lit,
logical_plan::builder::LogicalPlanBuilder, max,
- AggregateFunction,
+ min, sum, AggregateFunction,
};
fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) ->
Result<()> {
@@ -478,4 +517,181 @@ mod tests {
assert_optimized_plan_equal(&plan, expected)
}
+
+ #[test]
+ fn two_distinct_and_one_common() -> Result<()> {
+ let table_scan = test_table_scan()?;
+
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .aggregate(
+ vec![col("a")],
+ vec![
+ sum(col("c")),
+ count_distinct(col("b")),
+ Expr::AggregateFunction(expr::AggregateFunction::new(
+ AggregateFunction::Max,
+ vec![col("b")],
+ true,
+ None,
+ None,
+ )),
+ ],
+ )?
+ .build()?;
+ // Should work
+ let expected = "Projection: test.a, SUM(alias2) AS SUM(test.c),
COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b)
[a:UInt32, SUM(test.c):UInt64;N, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT
test.b):UInt32;N]\
+ \n Aggregate: groupBy=[[test.a]],
aggr=[[SUM(alias2), COUNT(alias1), MAX(alias1)]] [a:UInt32,
SUM(alias2):UInt64;N, COUNT(alias1):Int64;N, MAX(alias1):UInt32;N]\
+ \n Aggregate: groupBy=[[test.a, test.b AS
alias1]], aggr=[[SUM(test.c) AS alias2]] [a:UInt32, alias1:UInt32,
alias2:UInt64;N]\
+ \n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
+
+ assert_optimized_plan_equal(&plan, expected)
+ }
+
+ #[test]
+ fn one_distinctand_and_two_common() -> Result<()> {
+ let table_scan = test_table_scan()?;
+
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .aggregate(
+ vec![col("a")],
+ vec![sum(col("c")), max(col("c")), count_distinct(col("b"))],
+ )?
+ .build()?;
+ // Should work
+ let expected = "Projection: test.a, SUM(alias2) AS SUM(test.c),
MAX(alias3) AS MAX(test.c), COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32,
SUM(test.c):UInt64;N, MAX(test.c):UInt32;N, COUNT(DISTINCT test.b):Int64;N]\
+ \n Aggregate: groupBy=[[test.a]],
aggr=[[SUM(alias2), MAX(alias3), COUNT(alias1)]] [a:UInt32,
SUM(alias2):UInt64;N, MAX(alias3):UInt32;N, COUNT(alias1):Int64;N]\
+ \n Aggregate: groupBy=[[test.a, test.b AS
alias1]], aggr=[[SUM(test.c) AS alias2, MAX(test.c) AS alias3]] [a:UInt32,
alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\
+ \n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
+
+ assert_optimized_plan_equal(&plan, expected)
+ }
+
+ #[test]
+ fn one_distinct_and_one_common() -> Result<()> {
+ let table_scan = test_table_scan()?;
+
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .aggregate(
+ vec![col("c")],
+ vec![min(col("a")), count_distinct(col("b"))],
+ )?
+ .build()?;
+ // Should work
+ let expected = "Projection: test.c, MIN(alias2) AS MIN(test.a),
COUNT(alias1) AS COUNT(DISTINCT test.b) [c:UInt32, MIN(test.a):UInt32;N,
COUNT(DISTINCT test.b):Int64;N]\
+ \n Aggregate: groupBy=[[test.c]],
aggr=[[MIN(alias2), COUNT(alias1)]] [c:UInt32, MIN(alias2):UInt32;N,
COUNT(alias1):Int64;N]\
+ \n Aggregate: groupBy=[[test.c, test.b AS
alias1]], aggr=[[MIN(test.a) AS alias2]] [c:UInt32, alias1:UInt32,
alias2:UInt32;N]\
+ \n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
+
+ assert_optimized_plan_equal(&plan, expected)
+ }
+
+ #[test]
+ fn common_with_filter() -> Result<()> {
+ let table_scan = test_table_scan()?;
+
+ // SUM(a) FILTER (WHERE a > 5)
+ let expr = Expr::AggregateFunction(expr::AggregateFunction::new(
+ AggregateFunction::Sum,
+ vec![col("a")],
+ false,
+ Some(Box::new(col("a").gt(lit(5)))),
+ None,
+ ));
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])?
+ .build()?;
+ // Do nothing
+ let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a)
FILTER (WHERE test.a > Int32(5)), COUNT(DISTINCT test.b)]] [c:UInt32,
SUM(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, COUNT(DISTINCT
test.b):Int64;N]\
+ \n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
+
+ assert_optimized_plan_equal(&plan, expected)
+ }
+
+ #[test]
+ fn distinct_with_filter() -> Result<()> {
+ let table_scan = test_table_scan()?;
+
+ // COUNT(DISTINCT a) FILTER (WHERE a > 5)
+ let expr = Expr::AggregateFunction(expr::AggregateFunction::new(
+ AggregateFunction::Count,
+ vec![col("a")],
+ true,
+ Some(Box::new(col("a").gt(lit(5)))),
+ None,
+ ));
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .aggregate(vec![col("c")], vec![sum(col("a")), expr])?
+ .build()?;
+ // Do nothing
+ let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a),
COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32,
SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a >
Int32(5)):Int64;N]\
+ \n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
+
+ assert_optimized_plan_equal(&plan, expected)
+ }
+
+ #[test]
+ fn common_with_order_by() -> Result<()> {
+ let table_scan = test_table_scan()?;
+
+ // SUM(a ORDER BY a)
+ let expr = Expr::AggregateFunction(expr::AggregateFunction::new(
+ AggregateFunction::Sum,
+ vec![col("a")],
+ false,
+ None,
+ Some(vec![col("a")]),
+ ));
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])?
+ .build()?;
+ // Do nothing
+ let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a)
ORDER BY [test.a], COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) ORDER BY
[test.a]:UInt64;N, COUNT(DISTINCT test.b):Int64;N]\
+ \n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
+
+ assert_optimized_plan_equal(&plan, expected)
+ }
+
+ #[test]
+ fn distinct_with_order_by() -> Result<()> {
+ let table_scan = test_table_scan()?;
+
+ // COUNT(DISTINCT a ORDER BY a)
+ let expr = Expr::AggregateFunction(expr::AggregateFunction::new(
+ AggregateFunction::Count,
+ vec![col("a")],
+ true,
+ None,
+ Some(vec![col("a")]),
+ ));
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .aggregate(vec![col("c")], vec![sum(col("a")), expr])?
+ .build()?;
+ // Do nothing
+ let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a),
COUNT(DISTINCT test.a) ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N,
COUNT(DISTINCT test.a) ORDER BY [test.a]:Int64;N]\
+ \n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
+
+ assert_optimized_plan_equal(&plan, expected)
+ }
+
+ #[test]
+ fn aggregate_with_filter_and_order_by() -> Result<()> {
+ let table_scan = test_table_scan()?;
+
+ // COUNT(DISTINCT a ORDER BY a) FILTER (WHERE a > 5)
+ let expr = Expr::AggregateFunction(expr::AggregateFunction::new(
+ AggregateFunction::Count,
+ vec![col("a")],
+ true,
+ Some(Box::new(col("a").gt(lit(5)))),
+ Some(vec![col("a")]),
+ ));
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .aggregate(vec![col("c")], vec![sum(col("a")), expr])?
+ .build()?;
+ // Do nothing
+ let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a),
COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]]]
[c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a >
Int32(5)) ORDER BY [test.a]:Int64;N]\
+ \n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
+
+ assert_optimized_plan_equal(&plan, expected)
+ }
}
diff --git a/datafusion/sqllogictest/test_files/groupby.slt
b/datafusion/sqllogictest/test_files/groupby.slt
index 756d3f7374..d6f9adb023 100644
--- a/datafusion/sqllogictest/test_files/groupby.slt
+++ b/datafusion/sqllogictest/test_files/groupby.slt
@@ -3965,3 +3965,85 @@ select date_bin(interval '1 year', time) as bla,
count(distinct state) as count
statement ok
drop table t1
+
+statement ok
+CREATE EXTERNAL TABLE aggregate_test_100 (
+ c1 VARCHAR NOT NULL,
+ c2 TINYINT NOT NULL,
+ c3 SMALLINT NOT NULL,
+ c4 SMALLINT,
+ c5 INT,
+ c6 BIGINT NOT NULL,
+ c7 SMALLINT NOT NULL,
+ c8 INT NOT NULL,
+ c9 INT UNSIGNED NOT NULL,
+ c10 BIGINT UNSIGNED NOT NULL,
+ c11 FLOAT NOT NULL,
+ c12 DOUBLE NOT NULL,
+ c13 VARCHAR NOT NULL
+)
+STORED AS CSV
+WITH HEADER ROW
+LOCATION '../../testing/data/csv/aggregate_test_100.csv'
+
+query TIIII
+SELECT c1, count(distinct c2), min(distinct c2), min(c3), max(c4) FROM
aggregate_test_100 GROUP BY c1 ORDER BY c1;
+----
+a 5 1 -101 32064
+b 5 1 -117 25286
+c 5 1 -117 29106
+d 5 1 -99 31106
+e 5 1 -95 32514
+
+query TT
+EXPLAIN SELECT c1, count(distinct c2), min(distinct c2), sum(c3), max(c4) FROM
aggregate_test_100 GROUP BY c1 ORDER BY c1;
+----
+logical_plan
+Sort: aggregate_test_100.c1 ASC NULLS LAST
+--Projection: aggregate_test_100.c1, COUNT(alias1) AS COUNT(DISTINCT
aggregate_test_100.c2), MIN(alias1) AS MIN(DISTINCT aggregate_test_100.c2),
SUM(alias2) AS SUM(aggregate_test_100.c3), MAX(alias3) AS
MAX(aggregate_test_100.c4)
+----Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[COUNT(alias1),
MIN(alias1), SUM(alias2), MAX(alias3)]]
+------Aggregate: groupBy=[[aggregate_test_100.c1, aggregate_test_100.c2 AS
alias1]], aggr=[[SUM(CAST(aggregate_test_100.c3 AS Int64)) AS alias2,
MAX(aggregate_test_100.c4) AS alias3]]
+--------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4]
+physical_plan
+SortPreservingMergeExec: [c1@0 ASC NULLS LAST]
+--SortExec: expr=[c1@0 ASC NULLS LAST]
+----ProjectionExec: expr=[c1@0 as c1, COUNT(alias1)@1 as COUNT(DISTINCT
aggregate_test_100.c2), MIN(alias1)@2 as MIN(DISTINCT aggregate_test_100.c2),
SUM(alias2)@3 as SUM(aggregate_test_100.c3), MAX(alias3)@4 as
MAX(aggregate_test_100.c4)]
+------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1],
aggr=[COUNT(alias1), MIN(alias1), SUM(alias2), MAX(alias3)]
+--------CoalesceBatchesExec: target_batch_size=2
+----------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8
+------------AggregateExec: mode=Partial, gby=[c1@0 as c1],
aggr=[COUNT(alias1), MIN(alias1), SUM(alias2), MAX(alias3)]
+--------------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1, alias1@1
as alias1], aggr=[alias2, alias3]
+----------------CoalesceBatchesExec: target_batch_size=2
+------------------RepartitionExec: partitioning=Hash([c1@0, alias1@1], 8),
input_partitions=8
+--------------------AggregateExec: mode=Partial, gby=[c1@0 as c1, c2@1 as
alias1], aggr=[alias2, alias3]
+----------------------RepartitionExec: partitioning=RoundRobinBatch(8),
input_partitions=1
+------------------------CsvExec: file_groups={1 group:
[[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1,
c2, c3, c4], has_header=true
+
+# Use PostgreSQL dialect
+statement ok
+set datafusion.sql_parser.dialect = 'Postgres';
+
+query II
+SELECT c2, count(distinct c3) FILTER (WHERE c1 != 'a') FROM aggregate_test_100
GROUP BY c2 ORDER BY c2;
+----
+1 17
+2 17
+3 13
+4 19
+5 11
+
+query III
+SELECT c2, count(distinct c3) FILTER (WHERE c1 != 'a'), count(c5) FILTER
(WHERE c1 != 'b') FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
+----
+1 17 19
+2 17 18
+3 13 17
+4 19 18
+5 11 9
+
+# Restore the default dialect
+statement ok
+set datafusion.sql_parser.dialect = 'Generic';
+
+statement ok
+drop table aggregate_test_100;