This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new cd36ee3b30 fix: make `columnize_expr` resistant to display_name collisions (#10459)
cd36ee3b30 is described below

commit cd36ee3b305dc4e50b9a5feb8ea88199ae17527a
Author: Jonah Gao <[email protected]>
AuthorDate: Tue May 14 18:23:05 2024 +0800

    fix: make `columnize_expr` resistant to display_name collisions (#10459)
    
    * fix: make `columnize_expr` resistant to display_name collisions
    
    * fix simple_window_function test
    
    * remove Projection
    
    * add tests
    
    * retry ci
    
    * fix DataFrame tests
    
    * Update datafusion/expr/src/logical_plan/plan.rs
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    * Remove copies
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/core/src/dataframe/mod.rs               | 61 +++++++++++++++++++
 datafusion/core/tests/dataframe/mod.rs             |  2 +-
 .../user_defined/user_defined_table_functions.rs   |  4 +-
 datafusion/expr/src/expr.rs                        | 13 +++--
 datafusion/expr/src/logical_plan/builder.rs        |  3 +-
 datafusion/expr/src/logical_plan/plan.rs           | 51 +++++++++++++++-
 datafusion/expr/src/utils.rs                       | 68 +++++++++-------------
 .../optimizer/src/common_subexpr_eliminate.rs      |  4 +-
 .../optimizer/src/single_distinct_to_groupby.rs    | 37 +++---------
 datafusion/sqllogictest/test_files/select.slt      | 11 +++-
 10 files changed, 168 insertions(+), 86 deletions(-)

diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs
index 787698c009..04aaf5a890 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -2045,6 +2045,67 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn test_aggregate_subexpr() -> Result<()> {
+        let df = test_table().await?;
+
+        let group_expr = col("c2") + lit(1);
+        let aggr_expr = sum(col("c3") + lit(2));
+
+        let df = df
+            // GROUP BY `c2 + 1`
+            .aggregate(vec![group_expr.clone()], vec![aggr_expr.clone()])?
+            // SELECT `c2 + 1` as c2 + 10, sum(c3 + 2) + 20
+            // SELECT expressions contain aggr_expr and group_expr as 
subexpressions
+            .select(vec![
+                group_expr.alias("c2") + lit(10),
+                (aggr_expr + lit(20)).alias("sum"),
+            ])?;
+
+        let df_results = df.collect().await?;
+
+        #[rustfmt::skip]
+        assert_batches_sorted_eq!([
+                "+----------------+------+",
+                "| c2 + Int32(10) | sum  |",
+                "+----------------+------+",
+                "| 12             | 431  |",
+                "| 13             | 248  |",
+                "| 14             | 453  |",
+                "| 15             | 95   |",
+                "| 16             | -146 |",
+                "+----------------+------+",
+            ],
+            &df_results
+        );
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_aggregate_name_collision() -> Result<()> {
+        let df = test_table().await?;
+
+        let collided_alias = "aggregate_test_100.c2 + aggregate_test_100.c3";
+        let group_expr = lit(1).alias(collided_alias);
+
+        let df = df
+            // GROUP BY 1
+            .aggregate(vec![group_expr], vec![])?
+            // SELECT `aggregate_test_100.c2 + aggregate_test_100.c3`
+            .select(vec![
+                (col("aggregate_test_100.c2") + col("aggregate_test_100.c3")),
+            ])
+            // The select expr has the same display_name as the group_expr,
+            // but since they are different expressions, it should fail.
+            .expect_err("Expected error");
+        let expected = "Schema error: No field named aggregate_test_100.c2. \
+            Valid fields are \"aggregate_test_100.c2 + 
aggregate_test_100.c3\".";
+        assert_eq!(df.strip_backtrace(), expected);
+
+        Ok(())
+    }
+
     // Test issue: https://github.com/apache/datafusion/issues/10346
     #[tokio::test]
     async fn test_select_over_aggregate_schema() -> Result<()> {
diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs
index f565fba1db..009f45b280 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -210,7 +210,7 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> {
     let sql_results = ctx
         .sql("select count(*) from t1")
         .await?
-        .select(vec![count(wildcard())])?
+        .select(vec![col("COUNT(*)")])?
         .explain(false, false)?
         .collect()
         .await?;
diff --git a/datafusion/core/tests/user_defined/user_defined_table_functions.rs b/datafusion/core/tests/user_defined/user_defined_table_functions.rs
index 7342851569..d3ddbed20d 100644
--- a/datafusion/core/tests/user_defined/user_defined_table_functions.rs
+++ b/datafusion/core/tests/user_defined/user_defined_table_functions.rs
@@ -156,8 +156,8 @@ impl SimpleCsvTable {
         let logical_plan = Projection::try_new(
             vec![columnize_expr(
                 normalize_col(self.exprs[0].clone(), &plan)?,
-                plan.schema(),
-            )],
+                &plan,
+            )?],
             Arc::new(plan),
         )
         .map(LogicalPlan::Projection)?;
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 660a45c27a..36953742c1 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -833,15 +833,16 @@ impl GroupingSet {
     /// Return all distinct exprs in the grouping set. For `CUBE` and `ROLLUP` 
this
     /// is just the underlying list of exprs. For `GROUPING SET` we need to 
deduplicate
     /// the exprs in the underlying sets.
-    pub fn distinct_expr(&self) -> Vec<Expr> {
+    pub fn distinct_expr(&self) -> Vec<&Expr> {
         match self {
-            GroupingSet::Rollup(exprs) => exprs.clone(),
-            GroupingSet::Cube(exprs) => exprs.clone(),
+            GroupingSet::Rollup(exprs) | GroupingSet::Cube(exprs) => {
+                exprs.iter().collect()
+            }
             GroupingSet::GroupingSets(groups) => {
-                let mut exprs: Vec<Expr> = vec![];
+                let mut exprs: Vec<&Expr> = vec![];
                 for exp in groups.iter().flatten() {
-                    if !exprs.contains(exp) {
-                        exprs.push(exp.clone());
+                    if !exprs.contains(&exp) {
+                        exprs.push(exp);
                     }
                 }
                 exprs
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index 6055537ac5..677d5a5da9 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -1435,8 +1435,7 @@ pub fn project(
                 input_schema,
                 None,
             )?),
-            _ => projected_expr
-                .push(columnize_expr(normalize_col(e, &plan)?, input_schema)),
+            _ => projected_expr.push(columnize_expr(normalize_col(e, &plan)?, 
&plan)?),
         }
     }
 
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 266e7abc34..ddf075c2c2 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -1214,6 +1214,45 @@ impl LogicalPlan {
         .unwrap();
         contains
     }
+
+    /// Get the output expressions and their corresponding columns.
+    ///
+    /// The parent node may reference the output columns of the plan by 
expressions, such as
+    /// projection over aggregate or window functions. This method helps to 
convert the
+    /// referenced expressions into columns.
+    ///
+    /// See also: [`crate::utils::columnize_expr`]
+    pub(crate) fn columnized_output_exprs(&self) -> Result<Vec<(&Expr, 
Column)>> {
+        match self {
+            LogicalPlan::Aggregate(aggregate) => Ok(aggregate
+                .output_expressions()?
+                .into_iter()
+                .zip(self.schema().columns())
+                .collect()),
+            LogicalPlan::Window(Window {
+                window_expr,
+                input,
+                schema,
+            }) => {
+                // The input could be another Window, so the result should 
also include the input's. For Example:
+                // `EXPLAIN SELECT RANK() OVER (PARTITION BY a ORDER BY b), 
SUM(b) OVER (PARTITION BY a) FROM t`
+                // Its plan is:
+                // Projection: RANK() PARTITION BY [t.a] ORDER BY [t.b ASC 
NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(t.b) 
PARTITION BY [t.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
+                //   WindowAggr: windowExpr=[[SUM(CAST(t.b AS Int64)) 
PARTITION BY [t.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
+                //     WindowAggr: windowExpr=[[RANK() PARTITION BY [t.a] 
ORDER BY [t.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT 
ROW]]/
+                //       TableScan: t projection=[a, b]
+                let mut output_exprs = input.columnized_output_exprs()?;
+                let input_len = input.schema().fields().len();
+                output_exprs.extend(
+                    window_expr
+                        .iter()
+                        .zip(schema.columns().into_iter().skip(input_len)),
+                );
+                Ok(output_exprs)
+            }
+            _ => Ok(vec![]),
+        }
+    }
 }
 
 impl LogicalPlan {
@@ -2480,9 +2519,9 @@ impl Aggregate {
 
         let is_grouping_set = matches!(group_expr.as_slice(), 
[Expr::GroupingSet(_)]);
 
-        let grouping_expr: Vec<Expr> = 
grouping_set_to_exprlist(group_expr.as_slice())?;
+        let grouping_expr: Vec<&Expr> = 
grouping_set_to_exprlist(group_expr.as_slice())?;
 
-        let mut qualified_fields = 
exprlist_to_fields(grouping_expr.as_slice(), &input)?;
+        let mut qualified_fields = exprlist_to_fields(grouping_expr, &input)?;
 
         // Even columns that cannot be null will become nullable when used in 
a grouping set.
         if is_grouping_set {
@@ -2538,6 +2577,14 @@ impl Aggregate {
         })
     }
 
+    /// Get the output expressions.
+    fn output_expressions(&self) -> Result<Vec<&Expr>> {
+        let mut exprs = grouping_set_to_exprlist(self.group_expr.as_slice())?;
+        exprs.extend(self.aggr_expr.iter());
+        debug_assert!(exprs.len() == self.schema.fields().len());
+        Ok(exprs)
+    }
+
     /// Get the length of the group by expression in the output schema
     /// This is not simply group by expression length. Expression may be
     /// GroupingSet, etc. In these case we need to get inner expression 
lengths.
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 43e8ff7b23..581e299cf9 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -18,19 +18,20 @@
 //! Expression utilities
 
 use std::cmp::Ordering;
-use std::collections::HashSet;
+use std::collections::{HashMap, HashSet};
 use std::sync::Arc;
 
 use crate::expr::{Alias, Sort, WindowFunction};
 use crate::expr_rewriter::strip_outer_reference;
 use crate::signature::{Signature, TypeSignature};
 use crate::{
-    and, BinaryExpr, Cast, Expr, ExprSchemable, Filter, GroupingSet, 
LogicalPlan,
-    Operator, TryCast,
+    and, BinaryExpr, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, 
Operator,
 };
 
 use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
-use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
+use datafusion_common::tree_node::{
+    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
+};
 use datafusion_common::utils::get_at_indices;
 use datafusion_common::{
     internal_err, plan_datafusion_err, plan_err, Column, DFSchema, 
DFSchemaRef, Result,
@@ -247,7 +248,7 @@ pub fn enumerate_grouping_sets(group_expr: Vec<Expr>) -> 
Result<Vec<Expr>> {
 
 /// Find all distinct exprs in a list of group by expressions. If the
 /// first element is a `GroupingSet` expression then it must be the only expr.
-pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<Expr>> {
+pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<&Expr>> {
     if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
         if group_expr.len() > 1 {
             return plan_err!(
@@ -256,7 +257,7 @@ pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> 
Result<Vec<Expr>> {
         }
         Ok(grouping_set.distinct_expr())
     } else {
-        Ok(group_expr.to_vec())
+        Ok(group_expr.iter().collect())
     }
 }
 
@@ -725,13 +726,16 @@ pub fn from_plan(
 }
 
 /// Create field meta-data from an expression, for use in a result set schema
-pub fn exprlist_to_fields(
-    exprs: &[Expr],
+pub fn exprlist_to_fields<'a>(
+    exprs: impl IntoIterator<Item = &'a Expr>,
     plan: &LogicalPlan,
 ) -> Result<Vec<(Option<TableReference>, Arc<Field>)>> {
     // look for exact match in plan's output schema
     let input_schema = &plan.schema();
-    exprs.iter().map(|e| e.to_field(input_schema)).collect()
+    exprs
+        .into_iter()
+        .map(|e| e.to_field(input_schema))
+        .collect()
 }
 
 /// Convert an expression into Column expression if it's already provided as 
input plan.
@@ -749,37 +753,21 @@ pub fn exprlist_to_fields(
 /// .aggregate(vec![col("c1")], vec![sum(col("c2"))])?
 /// .project(vec![col("c1"), col("SUM(c2)")?
 /// ```
-pub fn columnize_expr(e: Expr, input_schema: &DFSchema) -> Expr {
-    match e {
-        Expr::Column(_) => e,
-        Expr::OuterReferenceColumn(_, _) => e,
-        Expr::Alias(Alias {
-            expr,
-            relation,
-            name,
-        }) => columnize_expr(*expr, input_schema).alias_qualified(relation, 
name),
-        Expr::Cast(Cast { expr, data_type }) => Expr::Cast(Cast {
-            expr: Box::new(columnize_expr(*expr, input_schema)),
-            data_type,
-        }),
-        Expr::TryCast(TryCast { expr, data_type }) => 
Expr::TryCast(TryCast::new(
-            Box::new(columnize_expr(*expr, input_schema)),
-            data_type,
+pub fn columnize_expr(e: Expr, input: &LogicalPlan) -> Result<Expr> {
+    let output_exprs = match input.columnized_output_exprs() {
+        Ok(exprs) if !exprs.is_empty() => exprs,
+        _ => return Ok(e),
+    };
+    let exprs_map: HashMap<&Expr, Column> = output_exprs.into_iter().collect();
+    e.transform_down(|node: Expr| match exprs_map.get(&node) {
+        Some(column) => Ok(Transformed::new(
+            Expr::Column(column.clone()),
+            true,
+            TreeNodeRecursion::Jump,
         )),
-        Expr::ScalarSubquery(_) => e.clone(),
-        _ => match e.display_name() {
-            Ok(name) => {
-                match 
input_schema.qualified_field_with_unqualified_name(&name) {
-                    Ok((qualifier, field)) => {
-                        Expr::Column(Column::from((qualifier, field)))
-                    }
-                    // expression not provided as input, do not convert to a 
column reference
-                    Err(_) => e,
-                }
-            }
-            Err(_) => e,
-        },
-    }
+        None => Ok(Transformed::no(node)),
+    })
+    .data()
 }
 
 /// Collect all deeply nested `Expr::Column`'s. They are returned in order of
@@ -1235,7 +1223,7 @@ mod tests {
     use super::*;
     use crate::{
         col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, 
AggregateFunction,
-        WindowFrame, WindowFunctionDefinition,
+        Cast, WindowFrame, WindowFunctionDefinition,
     };
 
     #[test]
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 0704fabea2..3532a57f62 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -1154,12 +1154,12 @@ mod test {
         let table_scan = test_table_scan()?;
 
         let plan = LogicalPlanBuilder::from(table_scan)
-            .project(vec![lit(1) + col("a")])?
+            .project(vec![lit(1) + col("a"), col("a")])?
             .project(vec![lit(1) + col("a")])?
             .build()?;
 
         let expected = "Projection: Int32(1) + test.a\
-        \n  Projection: Int32(1) + test.a\
+        \n  Projection: Int32(1) + test.a, test.a\
         \n    TableScan: test";
 
         assert_optimized_plan_eq(expected, &plan);
diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index aaf4667fb0..5c82cf93cb 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -22,15 +22,15 @@ use std::sync::Arc;
 use crate::optimizer::ApplyOrder;
 use crate::{OptimizerConfig, OptimizerRule};
 
-use datafusion_common::{qualified_name, DFSchema, Result};
+use datafusion_common::{qualified_name, Result};
+use datafusion_expr::builder::project;
 use datafusion_expr::expr::AggregateFunctionDefinition;
 use datafusion_expr::{
     aggregate_function::AggregateFunction::{Max, Min, Sum},
     col,
     expr::AggregateFunction,
-    logical_plan::{Aggregate, LogicalPlan, Projection},
-    utils::columnize_expr,
-    Expr, ExprSchemable,
+    logical_plan::{Aggregate, LogicalPlan},
+    Expr,
 };
 
 use hashbrown::HashSet;
@@ -228,37 +228,18 @@ impl OptimizerRule for SingleDistinctToGroupBy {
                         .collect::<Result<Vec<_>>>()?;
 
                     // construct the inner AggrPlan
-                    let inner_fields = inner_group_exprs
-                        .iter()
-                        .chain(inner_aggr_exprs.iter())
-                        .map(|expr| expr.to_field(input.schema()))
-                        .collect::<Result<Vec<_>>>()?;
-                    let inner_schema = DFSchema::new_with_metadata(
-                        inner_fields,
-                        input.schema().metadata().clone(),
-                    )?;
                     let inner_agg = LogicalPlan::Aggregate(Aggregate::try_new(
                         input.clone(),
                         inner_group_exprs,
                         inner_aggr_exprs,
                     )?);
 
-                    let outer_fields = outer_group_exprs
-                        .iter()
-                        .chain(outer_aggr_exprs.iter())
-                        .map(|expr| expr.to_field(&inner_schema))
-                        .collect::<Result<Vec<_>>>()?;
-                    let outer_aggr_schema = 
Arc::new(DFSchema::new_with_metadata(
-                        outer_fields,
-                        input.schema().metadata().clone(),
-                    )?);
-
                     // so the aggregates are displayed in the same way even 
after the rewrite
                     // this optimizer has two kinds of alias:
                     // - group_by aggr
                     // - aggr expr
                     let group_size = group_expr.len();
-                    let alias_expr = out_group_expr_with_alias
+                    let alias_expr: Vec<_> = out_group_expr_with_alias
                         .into_iter()
                         .map(|(group_expr, original_field)| {
                             if let Some(name) = original_field {
@@ -271,7 +252,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
                             let idx = idx + group_size;
                             let (qualifier, field) = 
schema.qualified_field(idx);
                             let name = qualified_name(qualifier, field.name());
-                            columnize_expr(expr.clone().alias(name), 
&outer_aggr_schema)
+                            expr.clone().alias(name)
                         }))
                         .collect();
 
@@ -280,11 +261,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
                         outer_group_exprs,
                         outer_aggr_exprs,
                     )?);
-
-                    Ok(Some(LogicalPlan::Projection(Projection::try_new(
-                        alias_expr,
-                        Arc::new(outer_aggr),
-                    )?)))
+                    Ok(Some(project(outer_aggr, alias_expr)?))
                 } else {
                     Ok(None)
                 }
diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt
index d73157570d..6b74156b52 100644
--- a/datafusion/sqllogictest/test_files/select.slt
+++ b/datafusion/sqllogictest/test_files/select.slt
@@ -1619,7 +1619,16 @@ select count(1) from v;
 query I
 select a + b from (select 1 as a, 2 as b, 1 as "a + b");
 ----
-1
+3
+
+# Can't reference an output column by expression over projection.
+query error DataFusion error: Schema error: No field named a\. Valid fields 
are "a \+ Int64\(1\)"\.
+select a + 1 from (select a+1 from (select 1 as a));
+
+query I
+select "a + Int64(1)" + 10 from (select a+1 from (select 1 as a));
+----
+12
 
 # run below query without logical optimizations
 statement ok


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to