This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 0ec292f454 Make Logical Plans more readable by removing extra aliases (#10832)
0ec292f454 is described below
commit 0ec292f45404359356ab9125b1b0f5b21a135ab8
Author: Mohamed Abdeen <[email protected]>
AuthorDate: Tue Jun 11 21:37:10 2024 +0300
Make Logical Plans more readable by removing extra aliases (#10832)
* logical plan: remove unnecessary aliases
* revert EnterMark
* fix docs and benchmarks
* revert id_array change
* add alias counter
* fix alias counter bug
* fix slt test
* fix benchmark results
* revert alias/unalias changes
* remove TODO
* minor fix
* fix benchmark
---
.../optimizer/src/common_subexpr_eliminate.rs | 73 ++++++++++++++++------
datafusion/sqllogictest/test_files/group_by.slt | 18 +++---
.../sqllogictest/test_files/tpch/q1.slt.part | 4 +-
3 files changed, 65 insertions(+), 30 deletions(-)
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 6820ba04f0..3ed1309f15 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -128,7 +128,7 @@ impl CommonSubexprEliminate {
fn rewrite_exprs_list(
&self,
exprs_list: &[&[Expr]],
- arrays_list: &[&[Vec<(usize, String)>]],
+ arrays_list: &[&[IdArray]],
expr_stats: &ExprStats,
common_exprs: &mut CommonExprs,
) -> Result<Vec<Vec<Expr>>> {
@@ -159,7 +159,7 @@ impl CommonSubexprEliminate {
fn rewrite_expr(
&self,
exprs_list: &[&[Expr]],
- arrays_list: &[&[Vec<(usize, String)>]],
+ arrays_list: &[&[IdArray]],
input: &LogicalPlan,
expr_stats: &ExprStats,
config: &dyn OptimizerConfig,
@@ -480,7 +480,7 @@ fn to_arrays(
input_schema: DFSchemaRef,
expr_stats: &mut ExprStats,
expr_mask: ExprMask,
-) -> Result<Vec<Vec<(usize, String)>>> {
+) -> Result<Vec<IdArray>> {
expr.iter()
.map(|e| {
let mut id_array = vec![];
@@ -739,7 +739,7 @@ fn expr_identifier(expr: &Expr, sub_expr_identifier: Identifier) -> Identifier {
fn expr_to_identifier(
expr: &Expr,
expr_stats: &mut ExprStats,
- id_array: &mut Vec<(usize, Identifier)>,
+ id_array: &mut IdArray,
input_schema: DFSchemaRef,
expr_mask: ExprMask,
) -> Result<()> {
@@ -769,15 +769,28 @@ struct CommonSubexprRewriter<'a> {
common_exprs: &'a mut CommonExprs,
// preorder index, starts from 0.
down_index: usize,
+ // how many aliases have we seen so far
+ alias_counter: usize,
}
impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
type Node = Expr;
+ fn f_up(&mut self, expr: Expr) -> Result<Transformed<Self::Node>> {
+ if matches!(expr, Expr::Alias(_)) {
+ self.alias_counter -= 1
+ }
+ Ok(Transformed::no(expr))
+ }
+
fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
// The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate
// the `id_array`, which records the expr's identifier used to rewrite expr. So if we
// skip an expr in `ExprIdentifierVisitor`, we should skip it here, too.
+ if matches!(expr, Expr::Alias(_)) {
+ self.alias_counter += 1;
+ }
+
if expr.short_circuits() || expr.is_volatile()? {
return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump));
}
@@ -801,15 +814,16 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
let expr_name = expr.display_name()?;
self.common_exprs.insert(expr_id.clone(), expr);
- // Alias this `Column` expr to it original "expr name",
- // `projection_push_down` optimizer use "expr name" to eliminate useless
- // projections.
- // TODO: do we really need to alias here?
- Ok(Transformed::new(
- col(expr_id).alias(expr_name),
- true,
- TreeNodeRecursion::Jump,
- ))
+
+ // alias the expressions without an `Alias` ancestor node
+ let rewritten = if self.alias_counter > 0 {
+ col(expr_id)
+ } else {
+ self.alias_counter += 1;
+ col(expr_id).alias(expr_name)
+ };
+
+ Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump))
} else {
Ok(Transformed::no(expr))
}
@@ -829,6 +843,7 @@ fn replace_common_expr(
id_array,
common_exprs,
down_index: 0,
+ alias_counter: 0,
})
.data()
}
@@ -962,6 +977,26 @@ mod test {
Ok(())
}
+ #[test]
+ fn nested_aliases() -> Result<()> {
+ let table_scan = test_table_scan()?;
+
+ let plan = LogicalPlanBuilder::from(table_scan.clone())
+ .project(vec![
+ (col("a") + col("b") - col("c")).alias("alias1") * (col("a") + col("b")),
+ col("a") + col("b"),
+ ])?
+ .build()?;
+
+ let expected = "Projection: {test.a + test.b|{test.b}|{test.a}} - test.c AS alias1 * {test.a + test.b|{test.b}|{test.a}} AS test.a + test.b, {test.a + test.b|{test.b}|{test.a}} AS test.a + test.b\
+ \n Projection: test.a + test.b AS {test.a + test.b|{test.b}|{test.a}}, test.a, test.b, test.c\
+ \n TableScan: test";
+
+ assert_optimized_plan_eq(expected, &plan);
+
+ Ok(())
+ }
+
#[test]
fn aggregate() -> Result<()> {
let table_scan = test_table_scan()?;
@@ -1006,7 +1041,7 @@ mod test {
)?
.build()?;
- let expected = "Projection: {AVG(test.a)|{test.a}} AS AVG(test.a) AS col1, {AVG(test.a)|{test.a}} AS AVG(test.a) AS col2, col3, {AVG(test.c)} AS AVG(test.c), {my_agg(test.a)|{test.a}} AS my_agg(test.a) AS col4, {my_agg(test.a)|{test.a}} AS my_agg(test.a) AS col5, col6, {my_agg(test.c)} AS my_agg(test.c)\
+ let expected = "Projection: {AVG(test.a)|{test.a}} AS col1, {AVG(test.a)|{test.a}} AS col2, col3, {AVG(test.c)} AS AVG(test.c), {my_agg(test.a)|{test.a}} AS col4, {my_agg(test.a)|{test.a}} AS col5, col6, {my_agg(test.c)} AS my_agg(test.c)\
\n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS {AVG(test.a)|{test.a}}, my_agg(test.a) AS {my_agg(test.a)|{test.a}}, AVG(test.b) AS col3, AVG(test.c) AS {AVG(test.c)}, my_agg(test.b) AS col6, my_agg(test.c) AS {my_agg(test.c)}]]\
\n TableScan: test";
@@ -1042,7 +1077,7 @@ mod test {
)?
.build()?;
- let expected = "Aggregate: groupBy=[[]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS col1, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS col2]]\n Projection: UInt32(1) + test.a AS {UInt32(1) + test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\n TableScan: test";
+ let expected = "Aggregate: groupBy=[[]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS col1, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS col2]]\n Projection: UInt32(1) + test.a AS {UInt32(1) + test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\n TableScan: test";
assert_optimized_plan_eq(expected, &plan);
@@ -1057,7 +1092,7 @@ mod test {
)?
.build()?;
- let expected = "Aggregate: groupBy=[[{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS col1, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS col2]]\
+ let expected = "Aggregate: groupBy=[[{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS col1, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS col2]]\
\n Projection: UInt32(1) + test.a AS {UInt32(1) + test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\
\n TableScan: test";
@@ -1078,8 +1113,8 @@ mod test {
)?
.build()?;
- let expected = "Projection: UInt32(1) + test.a, UInt32(1) +
{AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) +
test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) +
test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}} AS AVG(UInt32(1) + test.a)
AS col1, UInt32(1) - {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS
UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) +
test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}} AS AVG(UInt32( [...]
- \n Aggregate: groupBy=[[{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS
UInt32(1) + test.a]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS
UInt32(1) + test.a) AS {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS
UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) +
test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}}, my_agg({UInt32(1) +
test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS {my_agg({UInt32(1) +
test.a|{test.a}|{UInt32(1)}} AS UIn [...]
+ let expected = "Projection: UInt32(1) + test.a, UInt32(1) +
{AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) +
test.a|{test.a}|{UInt32(1)}}}} AS col1, UInt32(1) - {AVG({UInt32(1) +
test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}} AS
col2, {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)} AS
AVG(UInt32(1) + test.a), UInt32(1) + {my_agg({UInt32(1) +
test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}} A
[...]
+ \n Aggregate: groupBy=[[{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS
UInt32(1) + test.a]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS
{AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) +
test.a|{test.a}|{UInt32(1)}}}}, my_agg({UInt32(1) +
test.a|{test.a}|{UInt32(1)}}) AS {my_agg({UInt32(1) +
test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}},
AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS
{AVG({UInt32(1) + t [...]
\n Projection: UInt32(1) + test.a AS {UInt32(1) +
test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\
\n TableScan: test";
@@ -1126,7 +1161,7 @@ mod test {
])?
.build()?;
- let expected = "Projection: {Int32(1) + test.a|{test.a}|{Int32(1)}} AS Int32(1) + test.a AS first, {Int32(1) + test.a|{test.a}|{Int32(1)}} AS Int32(1) + test.a AS second\
+ let expected = "Projection: {Int32(1) + test.a|{test.a}|{Int32(1)}} AS first, {Int32(1) + test.a|{test.a}|{Int32(1)}} AS second\
\n Projection: Int32(1) + test.a AS {Int32(1) + test.a|{test.a}|{Int32(1)}}, test.a, test.b, test.c\
\n TableScan: test";
diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt
index 24a301d4a7..9e8a2450e0 100644
--- a/datafusion/sqllogictest/test_files/group_by.slt
+++ b/datafusion/sqllogictest/test_files/group_by.slt
@@ -4538,7 +4538,7 @@ CREATE EXTERNAL TABLE timestamp_table (
c2 INT,
)
STORED AS CSV
-LOCATION 'test_files/scratch/group_by/timestamp_table'
+LOCATION 'test_files/scratch/group_by/timestamp_table'
OPTIONS ('format.has_header' 'true');
# Group By using date_trunc
@@ -4611,7 +4611,7 @@ DROP TABLE timestamp_table;
# Table with an int column and Dict<Int8> column:
statement ok
-CREATE TABLE int8_dict AS VALUES
+CREATE TABLE int8_dict AS VALUES
(1, arrow_cast('A', 'Dictionary(Int8, Utf8)')),
(2, arrow_cast('B', 'Dictionary(Int8, Utf8)')),
(2, arrow_cast('A', 'Dictionary(Int8, Utf8)')),
@@ -4649,7 +4649,7 @@ DROP TABLE int8_dict;
# Table with an int column and Dict<Int16> column:
statement ok
-CREATE TABLE int16_dict AS VALUES
+CREATE TABLE int16_dict AS VALUES
(1, arrow_cast('A', 'Dictionary(Int16, Utf8)')),
(2, arrow_cast('B', 'Dictionary(Int16, Utf8)')),
(2, arrow_cast('A', 'Dictionary(Int16, Utf8)')),
@@ -4687,7 +4687,7 @@ DROP TABLE int16_dict;
# Table with an int column and Dict<Int32> column:
statement ok
-CREATE TABLE int32_dict AS VALUES
+CREATE TABLE int32_dict AS VALUES
(1, arrow_cast('A', 'Dictionary(Int32, Utf8)')),
(2, arrow_cast('B', 'Dictionary(Int32, Utf8)')),
(2, arrow_cast('A', 'Dictionary(Int32, Utf8)')),
@@ -4725,7 +4725,7 @@ DROP TABLE int32_dict;
# Table with an int column and Dict<Int64> column:
statement ok
-CREATE TABLE int64_dict AS VALUES
+CREATE TABLE int64_dict AS VALUES
(1, arrow_cast('A', 'Dictionary(Int64, Utf8)')),
(2, arrow_cast('B', 'Dictionary(Int64, Utf8)')),
(2, arrow_cast('A', 'Dictionary(Int64, Utf8)')),
@@ -4763,7 +4763,7 @@ DROP TABLE int64_dict;
# Table with an int column and Dict<UInt8> column:
statement ok
-CREATE TABLE uint8_dict AS VALUES
+CREATE TABLE uint8_dict AS VALUES
(1, arrow_cast('A', 'Dictionary(UInt8, Utf8)')),
(2, arrow_cast('B', 'Dictionary(UInt8, Utf8)')),
(2, arrow_cast('A', 'Dictionary(UInt8, Utf8)')),
@@ -4801,7 +4801,7 @@ DROP TABLE uint8_dict;
# Table with an int column and Dict<UInt16> column:
statement ok
-CREATE TABLE uint16_dict AS VALUES
+CREATE TABLE uint16_dict AS VALUES
(1, arrow_cast('A', 'Dictionary(UInt16, Utf8)')),
(2, arrow_cast('B', 'Dictionary(UInt16, Utf8)')),
(2, arrow_cast('A', 'Dictionary(UInt16, Utf8)')),
@@ -4839,7 +4839,7 @@ DROP TABLE uint16_dict;
# Table with an int column and Dict<UInt32> column:
statement ok
-CREATE TABLE uint32_dict AS VALUES
+CREATE TABLE uint32_dict AS VALUES
(1, arrow_cast('A', 'Dictionary(UInt32, Utf8)')),
(2, arrow_cast('B', 'Dictionary(UInt32, Utf8)')),
(2, arrow_cast('A', 'Dictionary(UInt32, Utf8)')),
@@ -4877,7 +4877,7 @@ DROP TABLE uint32_dict;
# Table with an int column and Dict<UInt64> column:
statement ok
-CREATE TABLE uint64_dict AS VALUES
+CREATE TABLE uint64_dict AS VALUES
(1, arrow_cast('A', 'Dictionary(UInt64, Utf8)')),
(2, arrow_cast('B', 'Dictionary(UInt64, Utf8)')),
(2, arrow_cast('A', 'Dictionary(UInt64, Utf8)')),
diff --git a/datafusion/sqllogictest/test_files/tpch/q1.slt.part b/datafusion/sqllogictest/test_files/tpch/q1.slt.part
index 0583c6ef07..5e0930b992 100644
--- a/datafusion/sqllogictest/test_files/tpch/q1.slt.part
+++ b/datafusion/sqllogictest/test_files/tpch/q1.slt.part
@@ -42,7 +42,7 @@ explain select
logical_plan
01)Sort: lineitem.l_returnflag ASC NULLS LAST, lineitem.l_linestatus ASC NULLS LAST
02)--Projection: lineitem.l_returnflag, lineitem.l_linestatus, sum(lineitem.l_quantity) AS sum_qty, sum(lineitem.l_extendedprice) AS sum_base_price, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS sum_disc_price, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax) AS sum_charge, AVG(lineitem.l_quantity) AS avg_qty, AVG(lineitem.l_extendedprice) AS avg_price, AVG(lineitem.l_discount) AS avg_disc, COUNT(*) AS count_order
-03)----Aggregate: groupBy=[[lineitem.l_returnflag, lineitem.l_linestatus]],
aggr=[[sum(lineitem.l_quantity), sum(lineitem.l_extendedprice),
sum({lineitem.l_extendedprice * (Decimal128(Some(1),20,0) -
lineitem.l_discount)|{Decimal128(Some(1),20,0) -
lineitem.l_discount|{lineitem.l_discount}|{Decimal128(Some(1),20,0)}}|{lineitem.l_extendedprice}}
AS lineitem.l_extendedprice * Decimal128(Some(1),20,0) - lineitem.l_discount)
AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), [...]
+03)----Aggregate: groupBy=[[lineitem.l_returnflag, lineitem.l_linestatus]],
aggr=[[sum(lineitem.l_quantity), sum(lineitem.l_extendedprice),
sum({lineitem.l_extendedprice * (Decimal128(Some(1),20,0) -
lineitem.l_discount)|{Decimal128(Some(1),20,0) -
lineitem.l_discount|{lineitem.l_discount}|{Decimal128(Some(1),20,0)}}|{lineitem.l_extendedprice}})
AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount),
sum({lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discou
[...]
04)------Projection: lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS {lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)|{Decimal128(Some(1),20,0) - lineitem.l_discount|{lineitem.l_discount}|{Decimal128(Some(1),20,0)}}|{lineitem.l_extendedprice}}, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, lineitem.l_tax, lineitem.l_returnflag, lineitem.l_linestatus
05)--------Filter: lineitem.l_shipdate <= Date32("1998-09-02")
06)----------TableScan: lineitem projection=[l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate], partial_filters=[lineitem.l_shipdate <= Date32("1998-09-02")]
@@ -80,7 +80,7 @@ group by
l_linestatus
order by
l_returnflag,
- l_linestatus;
+ l_linestatus;
----
A F 3774200 5320753880.69 5054096266.6828 5256751331.449234 25.537587 36002.123829 0.050144 147790
N F 95257 133737795.84 127132372.6512 132286291.229445 25.300664 35521.326916 0.049394 3765
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]