This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new cd36ee3b30 fix: make `columnize_expr` resistant to display_name collisions (#10459)
cd36ee3b30 is described below
commit cd36ee3b305dc4e50b9a5feb8ea88199ae17527a
Author: Jonah Gao <[email protected]>
AuthorDate: Tue May 14 18:23:05 2024 +0800
fix: make `columnize_expr` resistant to display_name collisions (#10459)
* fix: make `columnize_expr` resistant to display_name collisions
* fix simple_window_function test
* remove Projection
* add tests
* retry ci
* fix DataFrame tests
* Update datafusion/expr/src/logical_plan/plan.rs
Co-authored-by: Andrew Lamb <[email protected]>
* Remove copies
---------
Co-authored-by: Andrew Lamb <[email protected]>
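
The gist of the bug: the old `columnize_expr` matched parent expressions against the child plan's schema by display name, so a user-supplied alias that happened to spell the same name as a different expression was silently resolved to the wrong column. Below is a minimal standalone sketch of that false positive; the `Expr` type and `display_name` helper are hypothetical stand-ins, not DataFusion's API, and the fix in the `utils.rs` hunk further down instead keys a map on the expressions themselves.

```rust
use std::collections::HashMap;

// Hypothetical toy expression type, only to illustrate the collision.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum Expr {
    Column(String),
    Literal(i64),
    Add(Box<Expr>, Box<Expr>),
    Alias(Box<Expr>, String),
}

// Render an expression roughly the way a display name is rendered.
fn display_name(e: &Expr) -> String {
    match e {
        Expr::Column(c) => c.clone(),
        Expr::Literal(v) => format!("Int64({v})"),
        Expr::Add(l, r) => format!("{} + {}", display_name(l), display_name(r)),
        Expr::Alias(_, name) => name.clone(),
    }
}

fn main() {
    // Child plan output: `1 AS "a + b"`, i.e. a column literally named "a + b".
    let aliased = Expr::Alias(Box::new(Expr::Literal(1)), "a + b".to_string());
    // Parent SELECT expression: the *different* expression `a + b`.
    let select = Expr::Add(
        Box::new(Expr::Column("a".into())),
        Box::new(Expr::Column("b".into())),
    );

    // Name-based lookup (old behavior): false positive, `a + b` resolves to the literal.
    let by_name: HashMap<String, &Expr> =
        HashMap::from([(display_name(&aliased), &aliased)]);
    assert!(by_name.contains_key(&display_name(&select)));

    // Expression-based lookup (new behavior): no match, `a + b` stays as written.
    let by_expr: HashMap<&Expr, &str> = HashMap::from([(&aliased, "a + b")]);
    assert!(!by_expr.contains_key(&select));
}
```

In SQL terms this is the `select a + b from (select 1 as a, 2 as b, 1 as "a + b")` case in the `select.slt` hunk below, which now returns 3 instead of 1.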
---
datafusion/core/src/dataframe/mod.rs | 61 +++++++++++++++++++
datafusion/core/tests/dataframe/mod.rs | 2 +-
.../user_defined/user_defined_table_functions.rs | 4 +-
datafusion/expr/src/expr.rs | 13 +++--
datafusion/expr/src/logical_plan/builder.rs | 3 +-
datafusion/expr/src/logical_plan/plan.rs | 51 +++++++++++++++-
datafusion/expr/src/utils.rs | 68 +++++++++-------------
.../optimizer/src/common_subexpr_eliminate.rs | 4 +-
.../optimizer/src/single_distinct_to_groupby.rs | 37 +++---------
datafusion/sqllogictest/test_files/select.slt | 11 +++-
10 files changed, 168 insertions(+), 86 deletions(-)
diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs
index 787698c009..04aaf5a890 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -2045,6 +2045,67 @@ mod tests {
Ok(())
}
+ #[tokio::test]
+ async fn test_aggregate_subexpr() -> Result<()> {
+ let df = test_table().await?;
+
+ let group_expr = col("c2") + lit(1);
+ let aggr_expr = sum(col("c3") + lit(2));
+
+ let df = df
+ // GROUP BY `c2 + 1`
+ .aggregate(vec![group_expr.clone()], vec![aggr_expr.clone()])?
+ // SELECT `c2 + 1` as c2 + 10, sum(c3 + 2) + 20
+ // SELECT expressions contain aggr_expr and group_expr as subexpressions
+ .select(vec![
+ group_expr.alias("c2") + lit(10),
+ (aggr_expr + lit(20)).alias("sum"),
+ ])?;
+
+ let df_results = df.collect().await?;
+
+ #[rustfmt::skip]
+ assert_batches_sorted_eq!([
+ "+----------------+------+",
+ "| c2 + Int32(10) | sum |",
+ "+----------------+------+",
+ "| 12 | 431 |",
+ "| 13 | 248 |",
+ "| 14 | 453 |",
+ "| 15 | 95 |",
+ "| 16 | -146 |",
+ "+----------------+------+",
+ ],
+ &df_results
+ );
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_aggregate_name_collision() -> Result<()> {
+ let df = test_table().await?;
+
+ let collided_alias = "aggregate_test_100.c2 + aggregate_test_100.c3";
+ let group_expr = lit(1).alias(collided_alias);
+
+ let df = df
+ // GROUP BY 1
+ .aggregate(vec![group_expr], vec![])?
+ // SELECT `aggregate_test_100.c2 + aggregate_test_100.c3`
+ .select(vec![
+ (col("aggregate_test_100.c2") + col("aggregate_test_100.c3")),
+ ])
+ // The select expr has the same display_name as the group_expr,
+ // but since they are different expressions, it should fail.
+ .expect_err("Expected error");
+ let expected = "Schema error: No field named aggregate_test_100.c2. \
+ Valid fields are \"aggregate_test_100.c2 + aggregate_test_100.c3\".";
+ assert_eq!(df.strip_backtrace(), expected);
+
+ Ok(())
+ }
+
// Test issue: https://github.com/apache/datafusion/issues/10346
#[tokio::test]
async fn test_select_over_aggregate_schema() -> Result<()> {
diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs
index f565fba1db..009f45b280 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -210,7 +210,7 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> {
let sql_results = ctx
.sql("select count(*) from t1")
.await?
- .select(vec![count(wildcard())])?
+ .select(vec![col("COUNT(*)")])?
.explain(false, false)?
.collect()
.await?;
diff --git a/datafusion/core/tests/user_defined/user_defined_table_functions.rs b/datafusion/core/tests/user_defined/user_defined_table_functions.rs
index 7342851569..d3ddbed20d 100644
--- a/datafusion/core/tests/user_defined/user_defined_table_functions.rs
+++ b/datafusion/core/tests/user_defined/user_defined_table_functions.rs
@@ -156,8 +156,8 @@ impl SimpleCsvTable {
let logical_plan = Projection::try_new(
vec![columnize_expr(
normalize_col(self.exprs[0].clone(), &plan)?,
- plan.schema(),
- )],
+ &plan,
+ )?],
Arc::new(plan),
)
.map(LogicalPlan::Projection)?;
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 660a45c27a..36953742c1 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -833,15 +833,16 @@ impl GroupingSet {
/// Return all distinct exprs in the grouping set. For `CUBE` and `ROLLUP` this
/// is just the underlying list of exprs. For `GROUPING SET` we need to deduplicate
/// the exprs in the underlying sets.
- pub fn distinct_expr(&self) -> Vec<Expr> {
+ pub fn distinct_expr(&self) -> Vec<&Expr> {
match self {
- GroupingSet::Rollup(exprs) => exprs.clone(),
- GroupingSet::Cube(exprs) => exprs.clone(),
+ GroupingSet::Rollup(exprs) | GroupingSet::Cube(exprs) => {
+ exprs.iter().collect()
+ }
GroupingSet::GroupingSets(groups) => {
- let mut exprs: Vec<Expr> = vec![];
+ let mut exprs: Vec<&Expr> = vec![];
for exp in groups.iter().flatten() {
- if !exprs.contains(exp) {
- exprs.push(exp.clone());
+ if !exprs.contains(&exp) {
+ exprs.push(exp);
}
}
exprs
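
The `distinct_expr` change above keeps the same deduplication logic but returns borrowed `&Expr`s instead of cloning each expression. A small self-contained sketch of that pattern, with a generic `T` standing in for `Expr`:

```rust
// Deduplicate items across nested groups without cloning, returning references.
fn distinct<'a, T: PartialEq>(groups: &'a [Vec<T>]) -> Vec<&'a T> {
    let mut out: Vec<&T> = vec![];
    for item in groups.iter().flatten() {
        if !out.contains(&item) {
            out.push(item);
        }
    }
    out
}

fn main() {
    let groups = vec![vec!["a", "b"], vec!["b", "c"]];
    assert_eq!(distinct(&groups), vec![&"a", &"b", &"c"]);
}
```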
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index 6055537ac5..677d5a5da9 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -1435,8 +1435,7 @@ pub fn project(
input_schema,
None,
)?),
- _ => projected_expr
- .push(columnize_expr(normalize_col(e, &plan)?, input_schema)),
+ _ => projected_expr.push(columnize_expr(normalize_col(e, &plan)?, &plan)?),
}
}
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 266e7abc34..ddf075c2c2 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -1214,6 +1214,45 @@ impl LogicalPlan {
.unwrap();
contains
}
+
+ /// Get the output expressions and their corresponding columns.
+ ///
+ /// The parent node may reference the output columns of the plan by expressions, such as
+ /// projection over aggregate or window functions. This method helps to convert the
+ /// referenced expressions into columns.
+ ///
+ /// See also: [`crate::utils::columnize_expr`]
+ pub(crate) fn columnized_output_exprs(&self) -> Result<Vec<(&Expr, Column)>> {
+ match self {
+ LogicalPlan::Aggregate(aggregate) => Ok(aggregate
+ .output_expressions()?
+ .into_iter()
+ .zip(self.schema().columns())
+ .collect()),
+ LogicalPlan::Window(Window {
+ window_expr,
+ input,
+ schema,
+ }) => {
+ // The input could be another Window, so the result should also include the input's. For Example:
+ // `EXPLAIN SELECT RANK() OVER (PARTITION BY a ORDER BY b), SUM(b) OVER (PARTITION BY a) FROM t`
+ // Its plan is:
+ // Projection: RANK() PARTITION BY [t.a] ORDER BY [t.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(t.b) PARTITION BY [t.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
+ // WindowAggr: windowExpr=[[SUM(CAST(t.b AS Int64)) PARTITION BY [t.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
+ // WindowAggr: windowExpr=[[RANK() PARTITION BY [t.a] ORDER BY [t.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
+ // TableScan: t projection=[a, b]
+ let mut output_exprs = input.columnized_output_exprs()?;
+ let input_len = input.schema().fields().len();
+ output_exprs.extend(
+ window_expr
+ .iter()
+ .zip(schema.columns().into_iter().skip(input_len)),
+ );
+ Ok(output_exprs)
+ }
+ _ => Ok(vec![]),
+ }
+ }
}
impl LogicalPlan {
@@ -2480,9 +2519,9 @@ impl Aggregate {
let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]);
- let grouping_expr: Vec<Expr> = grouping_set_to_exprlist(group_expr.as_slice())?;
+ let grouping_expr: Vec<&Expr> = grouping_set_to_exprlist(group_expr.as_slice())?;
- let mut qualified_fields = exprlist_to_fields(grouping_expr.as_slice(), &input)?;
+ let mut qualified_fields = exprlist_to_fields(grouping_expr, &input)?;
// Even columns that cannot be null will become nullable when used in a grouping set.
if is_grouping_set {
@@ -2538,6 +2577,14 @@ impl Aggregate {
})
}
+ /// Get the output expressions.
+ fn output_expressions(&self) -> Result<Vec<&Expr>> {
+ let mut exprs = grouping_set_to_exprlist(self.group_expr.as_slice())?;
+ exprs.extend(self.aggr_expr.iter());
+ debug_assert!(exprs.len() == self.schema.fields().len());
+ Ok(exprs)
+ }
+
/// Get the length of the group by expression in the output schema
/// This is not simply group by expression length. Expression may be
/// GroupingSet, etc. In these case we need to get inner expression lengths.
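
A note on the Window branch of `columnized_output_exprs` above: a Window node's output schema is its input's columns followed by one column per window expression, which is why `window_expr` is zipped with `schema.columns()` after skipping `input_len`. A tiny sketch with plain strings standing in for expressions and columns:

```rust
fn main() {
    // Columns produced by the Window node's input (e.g. a scan of t).
    let input_columns = vec!["t.a", "t.b"];
    // Window expressions evaluated by this node.
    let window_exprs = vec!["RANK() PARTITION BY [t.a] ORDER BY [t.b ASC NULLS LAST]"];
    // Output schema of the Window node: input columns first, then window columns.
    let schema_columns = vec![
        "t.a",
        "t.b",
        "RANK() PARTITION BY [t.a] ORDER BY [t.b ASC NULLS LAST]",
    ];

    let input_len = input_columns.len();
    // Pair each window expression with its output column, skipping the input's columns.
    let pairs: Vec<(&str, &str)> = window_exprs
        .iter()
        .copied()
        .zip(schema_columns.iter().copied().skip(input_len))
        .collect();
    assert_eq!(pairs, vec![(window_exprs[0], schema_columns[2])]);
}
```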
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 43e8ff7b23..581e299cf9 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -18,19 +18,20 @@
//! Expression utilities
use std::cmp::Ordering;
-use std::collections::HashSet;
+use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use crate::expr::{Alias, Sort, WindowFunction};
use crate::expr_rewriter::strip_outer_reference;
use crate::signature::{Signature, TypeSignature};
use crate::{
- and, BinaryExpr, Cast, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan,
- Operator, TryCast,
+ and, BinaryExpr, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator,
};
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
-use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
+use datafusion_common::tree_node::{
+ Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
+};
use datafusion_common::utils::get_at_indices;
use datafusion_common::{
internal_err, plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, Result,
@@ -247,7 +248,7 @@ pub fn enumerate_grouping_sets(group_expr: Vec<Expr>) -> Result<Vec<Expr>> {
/// Find all distinct exprs in a list of group by expressions. If the
/// first element is a `GroupingSet` expression then it must be the only expr.
-pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<Expr>> {
+pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<&Expr>> {
if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
if group_expr.len() > 1 {
return plan_err!(
@@ -256,7 +257,7 @@ pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<Expr>> {
}
Ok(grouping_set.distinct_expr())
} else {
- Ok(group_expr.to_vec())
+ Ok(group_expr.iter().collect())
}
}
@@ -725,13 +726,16 @@ pub fn from_plan(
}
/// Create field meta-data from an expression, for use in a result set schema
-pub fn exprlist_to_fields(
- exprs: &[Expr],
+pub fn exprlist_to_fields<'a>(
+ exprs: impl IntoIterator<Item = &'a Expr>,
plan: &LogicalPlan,
) -> Result<Vec<(Option<TableReference>, Arc<Field>)>> {
// look for exact match in plan's output schema
let input_schema = &plan.schema();
- exprs.iter().map(|e| e.to_field(input_schema)).collect()
+ exprs
+ .into_iter()
+ .map(|e| e.to_field(input_schema))
+ .collect()
}
/// Convert an expression into Column expression if it's already provided as input plan.
@@ -749,37 +753,21 @@ pub fn exprlist_to_fields(
/// .aggregate(vec![col("c1")], vec![sum(col("c2"))])?
/// .project(vec![col("c1"), col("SUM(c2)")?
/// ```
-pub fn columnize_expr(e: Expr, input_schema: &DFSchema) -> Expr {
- match e {
- Expr::Column(_) => e,
- Expr::OuterReferenceColumn(_, _) => e,
- Expr::Alias(Alias {
- expr,
- relation,
- name,
- }) => columnize_expr(*expr, input_schema).alias_qualified(relation, name),
- Expr::Cast(Cast { expr, data_type }) => Expr::Cast(Cast {
- expr: Box::new(columnize_expr(*expr, input_schema)),
- data_type,
- }),
- Expr::TryCast(TryCast { expr, data_type }) => Expr::TryCast(TryCast::new(
- Box::new(columnize_expr(*expr, input_schema)),
- data_type,
+pub fn columnize_expr(e: Expr, input: &LogicalPlan) -> Result<Expr> {
+ let output_exprs = match input.columnized_output_exprs() {
+ Ok(exprs) if !exprs.is_empty() => exprs,
+ _ => return Ok(e),
+ };
+ let exprs_map: HashMap<&Expr, Column> = output_exprs.into_iter().collect();
+ e.transform_down(|node: Expr| match exprs_map.get(&node) {
+ Some(column) => Ok(Transformed::new(
+ Expr::Column(column.clone()),
+ true,
+ TreeNodeRecursion::Jump,
)),
- Expr::ScalarSubquery(_) => e.clone(),
- _ => match e.display_name() {
- Ok(name) => {
- match input_schema.qualified_field_with_unqualified_name(&name) {
- Ok((qualifier, field)) => {
- Expr::Column(Column::from((qualifier, field)))
- }
- // expression not provided as input, do not convert to a column reference
- Err(_) => e,
- }
- }
- Err(_) => e,
- },
- }
+ None => Ok(Transformed::no(node)),
+ })
+ .data()
}
/// Collect all deeply nested `Expr::Column`'s. They are returned in order of
@@ -1235,7 +1223,7 @@ mod tests {
use super::*;
use crate::{
col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, AggregateFunction,
- WindowFrame, WindowFunctionDefinition,
+ Cast, WindowFrame, WindowFunctionDefinition,
};
#[test]
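
To make the shape of the new `columnize_expr` concrete: it builds a map from each of the child plan's output expressions to its output column, then rewrites the parent expression top-down, replacing any subtree found in the map and not descending into the replacement (the `TreeNodeRecursion::Jump` above). The sketch below uses a hypothetical `Expr` type and plain recursion rather than DataFusion's `TreeNode` machinery:

```rust
use std::collections::HashMap;

// Hypothetical toy expression type; the real code operates on DataFusion's Expr.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum Expr {
    Column(String),
    Literal(i64),
    Add(Box<Expr>, Box<Expr>),
}

// Replace any subtree that is exactly one of the child plan's output
// expressions with a reference to its output column, and stop descending
// into the replaced subtree (the effect of TreeNodeRecursion::Jump).
fn columnize(e: Expr, outputs: &HashMap<Expr, String>) -> Expr {
    if let Some(col) = outputs.get(&e) {
        return Expr::Column(col.clone());
    }
    match e {
        Expr::Add(l, r) => Expr::Add(
            Box::new(columnize(*l, outputs)),
            Box::new(columnize(*r, outputs)),
        ),
        other => other,
    }
}

fn main() {
    // The child aggregate groups by `c2 + 1`, exposed as column "c2 + Int32(1)".
    let group_expr = Expr::Add(
        Box::new(Expr::Column("c2".to_string())),
        Box::new(Expr::Literal(1)),
    );
    let outputs = HashMap::from([(group_expr.clone(), "c2 + Int32(1)".to_string())]);

    // Parent SELECT expression `(c2 + 1) + 10` is rewritten to reference that column.
    let select = Expr::Add(Box::new(group_expr), Box::new(Expr::Literal(10)));
    assert_eq!(
        columnize(select, &outputs),
        Expr::Add(
            Box::new(Expr::Column("c2 + Int32(1)".to_string())),
            Box::new(Expr::Literal(10)),
        )
    );
}
```

Because the lookup is by expression equality rather than by rendered name, an alias such as `1 AS "a + b"` can no longer capture an unrelated `a + b`.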
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 0704fabea2..3532a57f62 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -1154,12 +1154,12 @@ mod test {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
- .project(vec![lit(1) + col("a")])?
+ .project(vec![lit(1) + col("a"), col("a")])?
.project(vec![lit(1) + col("a")])?
.build()?;
let expected = "Projection: Int32(1) + test.a\
- \n Projection: Int32(1) + test.a\
+ \n Projection: Int32(1) + test.a, test.a\
\n TableScan: test";
assert_optimized_plan_eq(expected, &plan);
diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index aaf4667fb0..5c82cf93cb 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -22,15 +22,15 @@ use std::sync::Arc;
use crate::optimizer::ApplyOrder;
use crate::{OptimizerConfig, OptimizerRule};
-use datafusion_common::{qualified_name, DFSchema, Result};
+use datafusion_common::{qualified_name, Result};
+use datafusion_expr::builder::project;
use datafusion_expr::expr::AggregateFunctionDefinition;
use datafusion_expr::{
aggregate_function::AggregateFunction::{Max, Min, Sum},
col,
expr::AggregateFunction,
- logical_plan::{Aggregate, LogicalPlan, Projection},
- utils::columnize_expr,
- Expr, ExprSchemable,
+ logical_plan::{Aggregate, LogicalPlan},
+ Expr,
};
use hashbrown::HashSet;
@@ -228,37 +228,18 @@ impl OptimizerRule for SingleDistinctToGroupBy {
.collect::<Result<Vec<_>>>()?;
// construct the inner AggrPlan
- let inner_fields = inner_group_exprs
- .iter()
- .chain(inner_aggr_exprs.iter())
- .map(|expr| expr.to_field(input.schema()))
- .collect::<Result<Vec<_>>>()?;
- let inner_schema = DFSchema::new_with_metadata(
- inner_fields,
- input.schema().metadata().clone(),
- )?;
let inner_agg = LogicalPlan::Aggregate(Aggregate::try_new(
input.clone(),
inner_group_exprs,
inner_aggr_exprs,
)?);
- let outer_fields = outer_group_exprs
- .iter()
- .chain(outer_aggr_exprs.iter())
- .map(|expr| expr.to_field(&inner_schema))
- .collect::<Result<Vec<_>>>()?;
- let outer_aggr_schema = Arc::new(DFSchema::new_with_metadata(
- outer_fields,
- input.schema().metadata().clone(),
- )?);
-
// so the aggregates are displayed in the same way even after the rewrite
// this optimizer has two kinds of alias:
// - group_by aggr
// - aggr expr
let group_size = group_expr.len();
- let alias_expr = out_group_expr_with_alias
+ let alias_expr: Vec<_> = out_group_expr_with_alias
.into_iter()
.map(|(group_expr, original_field)| {
if let Some(name) = original_field {
@@ -271,7 +252,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
let idx = idx + group_size;
let (qualifier, field) = schema.qualified_field(idx);
let name = qualified_name(qualifier, field.name());
- columnize_expr(expr.clone().alias(name), &outer_aggr_schema)
+ expr.clone().alias(name)
}))
.collect();
@@ -280,11 +261,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
outer_group_exprs,
outer_aggr_exprs,
)?);
-
- Ok(Some(LogicalPlan::Projection(Projection::try_new(
- alias_expr,
- Arc::new(outer_aggr),
- )?)))
+ Ok(Some(project(outer_aggr, alias_expr)?))
} else {
Ok(None)
}
diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt
index d73157570d..6b74156b52 100644
--- a/datafusion/sqllogictest/test_files/select.slt
+++ b/datafusion/sqllogictest/test_files/select.slt
@@ -1619,7 +1619,16 @@ select count(1) from v;
query I
select a + b from (select 1 as a, 2 as b, 1 as "a + b");
----
-1
+3
+
+# Can't reference an output column by expression over projection.
+query error DataFusion error: Schema error: No field named a\. Valid fields are "a \+ Int64\(1\)"\.
+select a + 1 from (select a+1 from (select 1 as a));
+
+query I
+select "a + Int64(1)" + 10 from (select a+1 from (select 1 as a));
+----
+12
# run below query without logical optimizations
statement ok