This is an automated email from the ASF dual-hosted git repository.

akurmustafa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 6f5230ffc7 Bugfix: Add functional dependency check and aggregate 
try_new schema (#8584)
6f5230ffc7 is described below

commit 6f5230ffc77ec0151a7aa870808d2fb31e6146c7
Author: Mustafa Akur <[email protected]>
AuthorDate: Wed Dec 20 10:58:49 2023 +0300

    Bugfix: Add functional dependency check and aggregate try_new schema (#8584)
    
    * Add functional dependency check and aggregate try_new schema
    
    * Update comments, make implementation idiomatic
    
    * Use constraint during stream table initialization
---
 datafusion/common/src/dfschema.rs              | 16 +++++
 datafusion/core/src/datasource/stream.rs       |  3 +-
 datafusion/expr/src/utils.rs                   | 13 ++--
 datafusion/physical-plan/src/aggregates/mod.rs | 92 ++++++++++++++++++++++++--
 datafusion/sqllogictest/test_files/groupby.slt | 12 ++++
 5 files changed, 125 insertions(+), 11 deletions(-)

diff --git a/datafusion/common/src/dfschema.rs 
b/datafusion/common/src/dfschema.rs
index e06f947ad5..d6e4490cec 100644
--- a/datafusion/common/src/dfschema.rs
+++ b/datafusion/common/src/dfschema.rs
@@ -347,6 +347,22 @@ impl DFSchema {
             .collect()
     }
 
+    /// Find all fields indices having the given qualifier
+    pub fn fields_indices_with_qualified(
+        &self,
+        qualifier: &TableReference,
+    ) -> Vec<usize> {
+        self.fields
+            .iter()
+            .enumerate()
+            .filter_map(|(idx, field)| {
+                field
+                    .qualifier()
+                    .and_then(|q| q.eq(qualifier).then_some(idx))
+            })
+            .collect()
+    }
+
     /// Find all fields match the given name
     pub fn fields_with_unqualified_name(&self, name: &str) -> Vec<&DFField> {
         self.fields
diff --git a/datafusion/core/src/datasource/stream.rs 
b/datafusion/core/src/datasource/stream.rs
index b9b45a6c74..830cd7a07e 100644
--- a/datafusion/core/src/datasource/stream.rs
+++ b/datafusion/core/src/datasource/stream.rs
@@ -64,7 +64,8 @@ impl TableProviderFactory for StreamTableFactory {
             .with_encoding(encoding)
             .with_order(cmd.order_exprs.clone())
             .with_header(cmd.has_header)
-            .with_batch_size(state.config().batch_size());
+            .with_batch_size(state.config().batch_size())
+            .with_constraints(cmd.constraints.clone());
 
         Ok(Arc::new(StreamTable(Arc::new(config))))
     }
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index abdd7f5f57..09f4842c9e 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -32,6 +32,7 @@ use crate::{
 
 use arrow::datatypes::{DataType, TimeUnit};
 use datafusion_common::tree_node::{TreeNode, VisitRecursion};
+use datafusion_common::utils::get_at_indices;
 use datafusion_common::{
     internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema, 
DFSchemaRef,
     DataFusionError, Result, ScalarValue, TableReference,
@@ -425,18 +426,18 @@ pub fn expand_qualified_wildcard(
     wildcard_options: Option<&WildcardAdditionalOptions>,
 ) -> Result<Vec<Expr>> {
     let qualifier = TableReference::from(qualifier);
-    let qualified_fields: Vec<DFField> = schema
-        .fields_with_qualified(&qualifier)
-        .into_iter()
-        .cloned()
-        .collect();
+    let qualified_indices = schema.fields_indices_with_qualified(&qualifier);
+    let projected_func_dependencies = schema
+        .functional_dependencies()
+        .project_functional_dependencies(&qualified_indices, 
qualified_indices.len());
+    let qualified_fields = get_at_indices(schema.fields(), 
&qualified_indices)?;
     if qualified_fields.is_empty() {
         return plan_err!("Invalid qualifier {qualifier}");
     }
     let qualified_schema =
         DFSchema::new_with_metadata(qualified_fields, 
schema.metadata().clone())?
             // We can use the functional dependencies as is, since it only 
stores indices:
-            
.with_functional_dependencies(schema.functional_dependencies().clone())?;
+            .with_functional_dependencies(projected_func_dependencies)?;
     let excluded_columns = if let Some(WildcardAdditionalOptions {
         opt_exclude,
         opt_except,
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs 
b/datafusion/physical-plan/src/aggregates/mod.rs
index c74c4ac0f8..921de96252 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -43,7 +43,7 @@ use datafusion_execution::TaskContext;
 use datafusion_expr::Accumulator;
 use datafusion_physical_expr::{
     aggregate::is_order_sensitive,
-    equivalence::collapse_lex_req,
+    equivalence::{collapse_lex_req, ProjectionMapping},
     expressions::{Column, Max, Min, UnKnownColumn},
     physical_exprs_contains, reverse_order_bys, AggregateExpr, 
EquivalenceProperties,
     LexOrdering, LexRequirement, PhysicalExpr, PhysicalSortExpr, 
PhysicalSortRequirement,
@@ -59,7 +59,6 @@ mod topk;
 mod topk_stream;
 
 pub use datafusion_expr::AggregateFunction;
-use datafusion_physical_expr::equivalence::ProjectionMapping;
 pub use datafusion_physical_expr::expressions::create_aggregate_expr;
 
 /// Hash aggregate modes
@@ -464,7 +463,7 @@ impl AggregateExec {
     pub fn try_new(
         mode: AggregateMode,
         group_by: PhysicalGroupBy,
-        mut aggr_expr: Vec<Arc<dyn AggregateExpr>>,
+        aggr_expr: Vec<Arc<dyn AggregateExpr>>,
         filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
         input: Arc<dyn ExecutionPlan>,
         input_schema: SchemaRef,
@@ -482,6 +481,37 @@ impl AggregateExec {
             group_by.expr.len(),
         ));
         let original_schema = Arc::new(original_schema);
+        AggregateExec::try_new_with_schema(
+            mode,
+            group_by,
+            aggr_expr,
+            filter_expr,
+            input,
+            input_schema,
+            schema,
+            original_schema,
+        )
+    }
+
+    /// Create a new hash aggregate execution plan with the given schema.
+    /// This constructor isn't part of the public API, it is used internally
+    /// by Datafusion to enforce schema consistency when re-creating
+    /// `AggregateExec`s inside optimization rules. Schema field names of an
+    /// `AggregateExec` depend on the names of aggregate expressions. Since
+    /// a rule may re-write aggregate expressions (e.g. reverse them) during
+    /// initialization, field names may change inadvertently if one re-creates
+    /// the schema in such cases.
+    #[allow(clippy::too_many_arguments)]
+    fn try_new_with_schema(
+        mode: AggregateMode,
+        group_by: PhysicalGroupBy,
+        mut aggr_expr: Vec<Arc<dyn AggregateExpr>>,
+        filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
+        input: Arc<dyn ExecutionPlan>,
+        input_schema: SchemaRef,
+        schema: SchemaRef,
+        original_schema: SchemaRef,
+    ) -> Result<Self> {
         // Reset ordering requirement to `None` if aggregator is not 
order-sensitive
         let mut order_by_expr = aggr_expr
             .iter()
@@ -858,13 +888,15 @@ impl ExecutionPlan for AggregateExec {
         self: Arc<Self>,
         children: Vec<Arc<dyn ExecutionPlan>>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        let mut me = AggregateExec::try_new(
+        let mut me = AggregateExec::try_new_with_schema(
             self.mode,
             self.group_by.clone(),
             self.aggr_expr.clone(),
             self.filter_expr.clone(),
             children[0].clone(),
             self.input_schema.clone(),
+            self.schema.clone(),
+            self.original_schema.clone(),
         )?;
         me.limit = self.limit;
         Ok(Arc::new(me))
@@ -2162,4 +2194,56 @@ mod tests {
         assert_eq!(res, common_requirement);
         Ok(())
     }
+
+    #[test]
+    fn test_agg_exec_same_schema() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Float32, true),
+            Field::new("b", DataType::Float32, true),
+        ]));
+
+        let col_a = col("a", &schema)?;
+        let col_b = col("b", &schema)?;
+        let option_desc = SortOptions {
+            descending: true,
+            nulls_first: true,
+        };
+        let sort_expr = vec![PhysicalSortExpr {
+            expr: col_b.clone(),
+            options: option_desc,
+        }];
+        let sort_expr_reverse = reverse_order_bys(&sort_expr);
+        let groups = PhysicalGroupBy::new_single(vec![(col_a, 
"a".to_string())]);
+
+        let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![
+            Arc::new(FirstValue::new(
+                col_b.clone(),
+                "FIRST_VALUE(b)".to_string(),
+                DataType::Float64,
+                sort_expr_reverse.clone(),
+                vec![DataType::Float64],
+            )),
+            Arc::new(LastValue::new(
+                col_b.clone(),
+                "LAST_VALUE(b)".to_string(),
+                DataType::Float64,
+                sort_expr.clone(),
+                vec![DataType::Float64],
+            )),
+        ];
+        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 
1));
+        let aggregate_exec = Arc::new(AggregateExec::try_new(
+            AggregateMode::Partial,
+            groups,
+            aggregates.clone(),
+            vec![None, None],
+            blocking_exec.clone(),
+            schema,
+        )?);
+        let new_agg = aggregate_exec
+            .clone()
+            .with_new_children(vec![blocking_exec])?;
+        assert_eq!(new_agg.schema(), aggregate_exec.schema());
+        Ok(())
+    }
 }
diff --git a/datafusion/sqllogictest/test_files/groupby.slt 
b/datafusion/sqllogictest/test_files/groupby.slt
index 44d30ba0b3..f1b6a57287 100644
--- a/datafusion/sqllogictest/test_files/groupby.slt
+++ b/datafusion/sqllogictest/test_files/groupby.slt
@@ -4280,3 +4280,15 @@ LIMIT 5
 2 0 0
 3 0 0
 4 0 1
+
+
+query ITIPTR rowsort
+SELECT r.*
+FROM sales_global_with_pk as l, sales_global_with_pk as r
+LIMIT 5
+----
+0 GRC 0 2022-01-01T06:00:00 EUR 30
+1 FRA 1 2022-01-01T08:00:00 EUR 50
+1 FRA 3 2022-01-02T12:00:00 EUR 200
+1 TUR 2 2022-01-01T11:30:00 TRY 75
+1 TUR 4 2022-01-03T10:00:00 TRY 100

Reply via email to