This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 0df9b99 Avoid changing expression names during constant folding
(#1319)
0df9b99 is described below
commit 0df9b99f3babdb5ba92c00c302d371fac9a743fd
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Mon Nov 22 13:59:23 2021 -0800
Avoid changing expression names during constant folding (#1319)
* Avoid changing output column name before pushdown projection.
* Avoid early optimization which is duplicate.
* Modify test.
* Revert "Modify test."
This reverts commit ebf67d347de60d2ce550a35f1a6c561632ca6fba.
* Revert "Avoid early optimization which is duplicate."
This reverts commit fe8445de9c62b9d9d0253d128f22866b46e0b0e4.
* Add aliases during constant folding.
* Some expressions don't support name.
* Don't create redundant alias.
* Only add alias for certain plans.
* Fix clippy.
* Fix.
* Revert "Fix."
This reverts commit d767aebc1a4413485dd6d5117c481f5e37070cad.
* Apply to all nodes and update tests.
* Unalias when pushing down to TableScan.
* Update more tests.
* Remove previous change.
---
datafusion/src/logical_plan/expr.rs | 9 +++++
datafusion/src/logical_plan/mod.rs | 2 +-
datafusion/src/optimizer/constant_folding.rs | 46 +++++++++++++++-------
datafusion/src/physical_plan/planner.rs | 7 ++--
datafusion/tests/sql.rs | 58 ++++++++++++++++++----------
5 files changed, 82 insertions(+), 40 deletions(-)
diff --git a/datafusion/src/logical_plan/expr.rs
b/datafusion/src/logical_plan/expr.rs
index 04e95e7..e7801e3 100644
--- a/datafusion/src/logical_plan/expr.rs
+++ b/datafusion/src/logical_plan/expr.rs
@@ -1349,6 +1349,15 @@ pub fn unnormalize_cols(exprs: impl IntoIterator<Item =
Expr>) -> Vec<Expr> {
exprs.into_iter().map(unnormalize_col).collect()
}
+/// Recursively un-alias an expression
+#[inline]
+pub fn unalias(expr: Expr) -> Expr {
+ match expr {
+ Expr::Alias(sub_expr, _) => unalias(*sub_expr),
+ _ => expr,
+ }
+}
+
/// Create an expression to represent the min() aggregate function
pub fn min(expr: Expr) -> Expr {
Expr::AggregateFunction {
diff --git a/datafusion/src/logical_plan/mod.rs
b/datafusion/src/logical_plan/mod.rs
index 73fdcb9..494501d 100644
--- a/datafusion/src/logical_plan/mod.rs
+++ b/datafusion/src/logical_plan/mod.rs
@@ -44,7 +44,7 @@ pub use expr::{
max, md5, min, normalize_col, normalize_cols, now, octet_length, or,
random,
regexp_match, regexp_replace, repeat, replace, replace_col, reverse,
right, round,
rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt,
- starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc,
+ starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc,
unalias,
unnormalize_col, unnormalize_cols, upper, when, Column, Expr, ExprRewriter,
ExpressionVisitor, Literal, Recursion, RewriteRecursion,
};
diff --git a/datafusion/src/optimizer/constant_folding.rs
b/datafusion/src/optimizer/constant_folding.rs
index 3243f37..cc23cf0 100644
--- a/datafusion/src/optimizer/constant_folding.rs
+++ b/datafusion/src/optimizer/constant_folding.rs
@@ -92,6 +92,10 @@ impl OptimizerRule for ConstantFolding {
.expressions()
.into_iter()
.map(|e| {
+ // We need to keep original expression name, if any.
+ // Constant folding should not change expression name.
+ let name = &e.name(plan.schema());
+
// TODO iterate until no changes are made
// during rewrite (evaluating constants can
// enable new simplifications and
@@ -101,7 +105,18 @@ impl OptimizerRule for ConstantFolding {
// fold constants and then simplify
.rewrite(&mut const_evaluator)?
.rewrite(&mut simplifier)?;
- Ok(new_e)
+
+ let new_name = &new_e.name(plan.schema());
+
+ if let (Ok(expr_name), Ok(new_expr_name)) = (name,
new_name) {
+ if expr_name != new_expr_name {
+ Ok(new_e.alias(expr_name))
+ } else {
+ Ok(new_e)
+ }
+ } else {
+ Ok(new_e)
+ }
})
.collect::<Result<Vec<_>>>()?;
@@ -626,8 +641,8 @@ mod tests {
let expected = "\
Projection: #test.a\
- \n Filter: NOT #test.c\
- \n Filter: #test.b\
+ \n Filter: NOT #test.c AS test.c = Boolean(false)\
+ \n Filter: #test.b AS test.b = Boolean(true)\
\n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
@@ -647,8 +662,8 @@ mod tests {
let expected = "\
Projection: #test.a\
\n Limit: 1\
- \n Filter: #test.c\
- \n Filter: NOT #test.b\
+ \n Filter: #test.c AS test.c != Boolean(false)\
+ \n Filter: NOT #test.b AS test.b != Boolean(true)\
\n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
@@ -665,7 +680,7 @@ mod tests {
let expected = "\
Projection: #test.a\
- \n Filter: NOT #test.b AND #test.c\
+ \n Filter: NOT #test.b AND #test.c AS test.b != Boolean(true) AND
test.c = Boolean(true)\
\n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
@@ -682,7 +697,7 @@ mod tests {
let expected = "\
Projection: #test.a\
- \n Filter: NOT #test.b OR NOT #test.c\
+ \n Filter: NOT #test.b OR NOT #test.c AS test.b != Boolean(true) OR
test.c = Boolean(false)\
\n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
@@ -699,7 +714,7 @@ mod tests {
let expected = "\
Projection: #test.a\
- \n Filter: #test.b\
+ \n Filter: #test.b AS NOT test.b = Boolean(false)\
\n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
@@ -714,7 +729,7 @@ mod tests {
.build()?;
let expected = "\
- Projection: #test.a, #test.d, NOT #test.b\
+ Projection: #test.a, #test.d, NOT #test.b AS test.b = Boolean(false)\
\n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
@@ -733,7 +748,7 @@ mod tests {
.build()?;
let expected = "\
- Aggregate: groupBy=[[#test.a, #test.c]], aggr=[[MAX(#test.b),
MIN(#test.b)]]\
+ Aggregate: groupBy=[[#test.a, #test.c]], aggr=[[MAX(#test.b) AS
MAX(test.b = Boolean(true)), MIN(#test.b)]]\
\n Projection: #test.a, #test.c, #test.b\
\n TableScan: test projection=None";
@@ -789,7 +804,7 @@ mod tests {
.build()
.unwrap();
- let expected = "Projection: TimestampNanosecond(1599566400000000000)\
+ let expected = "Projection: TimestampNanosecond(1599566400000000000)
AS totimestamp(Utf8(\"2020-09-08T12:00:00+00:00\"))\
\n TableScan: test projection=None"
.to_string();
let actual = get_optimized_plan_formatted(&plan, &Utc::now());
@@ -824,7 +839,7 @@ mod tests {
.build()
.unwrap();
- let expected = "Projection: Int32(0)\
+ let expected = "Projection: Int32(0) AS CAST(Utf8(\"0\") AS Int32)\
\n TableScan: test projection=None";
let actual = get_optimized_plan_formatted(&plan, &Utc::now());
assert_eq!(expected, actual);
@@ -873,7 +888,7 @@ mod tests {
// expect the same timestamp appears in both exprs
let actual = get_optimized_plan_formatted(&plan, &time);
let expected = format!(
- "Projection: TimestampNanosecond({}), TimestampNanosecond({}) AS
t2\
+ "Projection: TimestampNanosecond({}) AS now(),
TimestampNanosecond({}) AS t2\
\n TableScan: test projection=None",
time.timestamp_nanos(),
time.timestamp_nanos()
@@ -897,7 +912,8 @@ mod tests {
.unwrap();
let actual = get_optimized_plan_formatted(&plan, &time);
- let expected = "Projection: NOT #test.a\
+ let expected =
+ "Projection: NOT #test.a AS Boolean(true) OR Boolean(false) !=
test.a\
\n TableScan: test projection=None";
assert_eq!(actual, expected);
@@ -929,7 +945,7 @@ mod tests {
// Note that constant folder runs and folds the entire
// expression down to a single constant (true)
- let expected = "Filter: Boolean(true)\
+ let expected = "Filter: Boolean(true) AS CAST(now() AS Int64) <
CAST(totimestamp(Utf8(\"2020-09-08T12:05:00+00:00\")) AS Int64) + Int32(50000)\
\n TableScan: test projection=None";
let actual = get_optimized_plan_formatted(&plan, &time);
diff --git a/datafusion/src/physical_plan/planner.rs
b/datafusion/src/physical_plan/planner.rs
index 7fcacd2..a7ba64a 100644
--- a/datafusion/src/physical_plan/planner.rs
+++ b/datafusion/src/physical_plan/planner.rs
@@ -25,7 +25,7 @@ use super::{
use crate::execution::context::ExecutionContextState;
use crate::logical_plan::plan::{EmptyRelation, Filter, Projection, Window};
use crate::logical_plan::{
- unnormalize_cols, CrossJoin, DFSchema, Expr, LogicalPlan, Operator,
+ unalias, unnormalize_cols, CrossJoin, DFSchema, Expr, LogicalPlan,
Operator,
Partitioning as LogicalPartitioning, PlanType, Repartition,
ToStringifiedPlan, Union,
UserDefinedLogicalNode,
};
@@ -339,7 +339,8 @@ impl DefaultPhysicalPlanner {
// doesn't know (nor should care) how the relation was
// referred to in the query
let filters = unnormalize_cols(filters.iter().cloned());
- source.scan(projection, batch_size, &filters, *limit).await
+ let unaliased: Vec<Expr> =
filters.into_iter().map(unalias).collect();
+ source.scan(projection, batch_size, &unaliased,
*limit).await
}
LogicalPlan::Values(Values {
values,
@@ -1340,7 +1341,7 @@ impl DefaultPhysicalPlanner {
physical_input_schema: &Schema,
ctx_state: &ExecutionContextState,
) -> Result<Arc<dyn AggregateExpr>> {
- // unpack aliased logical expressions, e.g. "sum(col) as total"
+ // unpack (nested) aliased logical expressions, e.g. "sum(col) as
total"
let (name, e) = match e {
Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()),
_ => (physical_name(e)?, e),
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index 24b2a49..5ea008c 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -1392,6 +1392,22 @@ async fn csv_query_approx_count() -> Result<()> {
}
#[tokio::test]
+async fn query_count_without_from() -> Result<()> {
+ let mut ctx = ExecutionContext::new();
+ let sql = "SELECT count(1 + 1)";
+ let actual = execute_to_batches(&mut ctx, sql).await;
+ let expected = vec![
+ "+----------------------------+",
+ "| COUNT(Int64(1) + Int64(1)) |",
+ "+----------------------------+",
+ "| 1 |",
+ "+----------------------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
async fn csv_query_array_agg() -> Result<()> {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx).await?;
@@ -1663,12 +1679,12 @@ async fn csv_query_cast_literal() -> Result<()> {
let actual = execute_to_batches(&mut ctx, sql).await;
let expected = vec![
- "+--------------------+------------+",
- "| c12 | Float64(1) |",
- "+--------------------+------------+",
- "| 0.9294097332465232 | 1 |",
- "| 0.3114712539863804 | 1 |",
- "+--------------------+------------+",
+ "+--------------------+---------------------------+",
+ "| c12 | CAST(Int64(1) AS Float64) |",
+ "+--------------------+---------------------------+",
+ "| 0.9294097332465232 | 1 |",
+ "| 0.3114712539863804 | 1 |",
+ "+--------------------+---------------------------+",
];
assert_batches_eq!(expected, &actual);
@@ -4410,11 +4426,11 @@ async fn query_without_from() -> Result<()> {
let sql = "SELECT 1+2, 3/4, cos(0)";
let actual = execute_to_batches(&mut ctx, sql).await;
let expected = vec![
- "+----------+----------+------------+",
- "| Int64(3) | Int64(0) | Float64(1) |",
- "+----------+----------+------------+",
- "| 3 | 0 | 1 |",
- "+----------+----------+------------+",
+ "+---------------------+---------------------+---------------+",
+ "| Int64(1) + Int64(2) | Int64(3) / Int64(4) | cos(Int64(0)) |",
+ "+---------------------+---------------------+---------------+",
+ "| 3 | 0 | 1 |",
+ "+---------------------+---------------------+---------------+",
];
assert_batches_eq!(expected, &actual);
@@ -5865,11 +5881,11 @@ async fn case_with_bool_type_result() -> Result<()> {
let sql = "select case when 'cpu' != 'cpu' then true else false end";
let actual = execute_to_batches(&mut ctx, sql).await;
let expected = vec![
- "+----------------+",
- "| Boolean(false) |",
- "+----------------+",
- "| false |",
- "+----------------+",
+
"+---------------------------------------------------------------------------------+",
+ "| CASE WHEN Utf8(\"cpu\") != Utf8(\"cpu\") THEN Boolean(true) ELSE
Boolean(false) END |",
+
"+---------------------------------------------------------------------------------+",
+ "| false
|",
+
"+---------------------------------------------------------------------------------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
@@ -5882,11 +5898,11 @@ async fn use_between_expression_in_select_query() ->
Result<()> {
let sql = "SELECT 1 NOT BETWEEN 3 AND 5";
let actual = execute_to_batches(&mut ctx, sql).await;
let expected = vec![
- "+---------------+",
- "| Boolean(true) |",
- "+---------------+",
- "| true |",
- "+---------------+",
+ "+--------------------------------------------+",
+ "| Int64(1) NOT BETWEEN Int64(3) AND Int64(5) |",
+ "+--------------------------------------------+",
+ "| true |",
+ "+--------------------------------------------+",
];
assert_batches_eq!(expected, &actual);