This is an automated email from the ASF dual-hosted git repository.
akurmustafa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 6f5230ffc7 Bugfix: Add functional dependency check and aggregate
try_new schema (#8584)
6f5230ffc7 is described below
commit 6f5230ffc77ec0151a7aa870808d2fb31e6146c7
Author: Mustafa Akur <[email protected]>
AuthorDate: Wed Dec 20 10:58:49 2023 +0300
Bugfix: Add functional dependency check and aggregate try_new schema (#8584)
* Add functional dependency check and aggregate try_new schema
* Update comments, make implementation idiomatic
* Use constraint during stream table initialization
---
datafusion/common/src/dfschema.rs | 16 +++++
datafusion/core/src/datasource/stream.rs | 3 +-
datafusion/expr/src/utils.rs | 13 ++--
datafusion/physical-plan/src/aggregates/mod.rs | 92 ++++++++++++++++++++++++--
datafusion/sqllogictest/test_files/groupby.slt | 12 ++++
5 files changed, 125 insertions(+), 11 deletions(-)
diff --git a/datafusion/common/src/dfschema.rs
b/datafusion/common/src/dfschema.rs
index e06f947ad5..d6e4490cec 100644
--- a/datafusion/common/src/dfschema.rs
+++ b/datafusion/common/src/dfschema.rs
@@ -347,6 +347,22 @@ impl DFSchema {
.collect()
}
+ /// Find all field indices that have the given qualifier
+ pub fn fields_indices_with_qualified(
+ &self,
+ qualifier: &TableReference,
+ ) -> Vec<usize> {
+ self.fields
+ .iter()
+ .enumerate()
+ .filter_map(|(idx, field)| {
+ field
+ .qualifier()
+ .and_then(|q| q.eq(qualifier).then_some(idx))
+ })
+ .collect()
+ }
+
/// Find all fields match the given name
pub fn fields_with_unqualified_name(&self, name: &str) -> Vec<&DFField> {
self.fields
diff --git a/datafusion/core/src/datasource/stream.rs
b/datafusion/core/src/datasource/stream.rs
index b9b45a6c74..830cd7a07e 100644
--- a/datafusion/core/src/datasource/stream.rs
+++ b/datafusion/core/src/datasource/stream.rs
@@ -64,7 +64,8 @@ impl TableProviderFactory for StreamTableFactory {
.with_encoding(encoding)
.with_order(cmd.order_exprs.clone())
.with_header(cmd.has_header)
- .with_batch_size(state.config().batch_size());
+ .with_batch_size(state.config().batch_size())
+ .with_constraints(cmd.constraints.clone());
Ok(Arc::new(StreamTable(Arc::new(config))))
}
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index abdd7f5f57..09f4842c9e 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -32,6 +32,7 @@ use crate::{
use arrow::datatypes::{DataType, TimeUnit};
use datafusion_common::tree_node::{TreeNode, VisitRecursion};
+use datafusion_common::utils::get_at_indices;
use datafusion_common::{
internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema,
DFSchemaRef,
DataFusionError, Result, ScalarValue, TableReference,
@@ -425,18 +426,18 @@ pub fn expand_qualified_wildcard(
wildcard_options: Option<&WildcardAdditionalOptions>,
) -> Result<Vec<Expr>> {
let qualifier = TableReference::from(qualifier);
- let qualified_fields: Vec<DFField> = schema
- .fields_with_qualified(&qualifier)
- .into_iter()
- .cloned()
- .collect();
+ let qualified_indices = schema.fields_indices_with_qualified(&qualifier);
+ let projected_func_dependencies = schema
+ .functional_dependencies()
+ .project_functional_dependencies(&qualified_indices,
qualified_indices.len());
+ let qualified_fields = get_at_indices(schema.fields(),
&qualified_indices)?;
if qualified_fields.is_empty() {
return plan_err!("Invalid qualifier {qualifier}");
}
let qualified_schema =
DFSchema::new_with_metadata(qualified_fields,
schema.metadata().clone())?
// We can use the functional dependencies as is, since it only
stores indices:
-
.with_functional_dependencies(schema.functional_dependencies().clone())?;
+ .with_functional_dependencies(projected_func_dependencies)?;
let excluded_columns = if let Some(WildcardAdditionalOptions {
opt_exclude,
opt_except,
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs
b/datafusion/physical-plan/src/aggregates/mod.rs
index c74c4ac0f8..921de96252 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -43,7 +43,7 @@ use datafusion_execution::TaskContext;
use datafusion_expr::Accumulator;
use datafusion_physical_expr::{
aggregate::is_order_sensitive,
- equivalence::collapse_lex_req,
+ equivalence::{collapse_lex_req, ProjectionMapping},
expressions::{Column, Max, Min, UnKnownColumn},
physical_exprs_contains, reverse_order_bys, AggregateExpr,
EquivalenceProperties,
LexOrdering, LexRequirement, PhysicalExpr, PhysicalSortExpr,
PhysicalSortRequirement,
@@ -59,7 +59,6 @@ mod topk;
mod topk_stream;
pub use datafusion_expr::AggregateFunction;
-use datafusion_physical_expr::equivalence::ProjectionMapping;
pub use datafusion_physical_expr::expressions::create_aggregate_expr;
/// Hash aggregate modes
@@ -464,7 +463,7 @@ impl AggregateExec {
pub fn try_new(
mode: AggregateMode,
group_by: PhysicalGroupBy,
- mut aggr_expr: Vec<Arc<dyn AggregateExpr>>,
+ aggr_expr: Vec<Arc<dyn AggregateExpr>>,
filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
input: Arc<dyn ExecutionPlan>,
input_schema: SchemaRef,
@@ -482,6 +481,37 @@ impl AggregateExec {
group_by.expr.len(),
));
let original_schema = Arc::new(original_schema);
+ AggregateExec::try_new_with_schema(
+ mode,
+ group_by,
+ aggr_expr,
+ filter_expr,
+ input,
+ input_schema,
+ schema,
+ original_schema,
+ )
+ }
+
+ /// Create a new hash aggregate execution plan with the given schema.
+ /// This constructor isn't part of the public API; it is used internally
+ /// by DataFusion to enforce schema consistency when re-creating
+ /// `AggregateExec`s inside optimization rules. Schema field names of an
+ /// `AggregateExec` depend on the names of aggregate expressions. Since
+ /// a rule may re-write aggregate expressions (e.g. reverse them) during
+ /// initialization, field names may change inadvertently if one re-creates
+ /// the schema in such cases.
+ #[allow(clippy::too_many_arguments)]
+ fn try_new_with_schema(
+ mode: AggregateMode,
+ group_by: PhysicalGroupBy,
+ mut aggr_expr: Vec<Arc<dyn AggregateExpr>>,
+ filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
+ input: Arc<dyn ExecutionPlan>,
+ input_schema: SchemaRef,
+ schema: SchemaRef,
+ original_schema: SchemaRef,
+ ) -> Result<Self> {
// Reset ordering requirement to `None` if aggregator is not
order-sensitive
let mut order_by_expr = aggr_expr
.iter()
@@ -858,13 +888,15 @@ impl ExecutionPlan for AggregateExec {
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
- let mut me = AggregateExec::try_new(
+ let mut me = AggregateExec::try_new_with_schema(
self.mode,
self.group_by.clone(),
self.aggr_expr.clone(),
self.filter_expr.clone(),
children[0].clone(),
self.input_schema.clone(),
+ self.schema.clone(),
+ self.original_schema.clone(),
)?;
me.limit = self.limit;
Ok(Arc::new(me))
@@ -2162,4 +2194,56 @@ mod tests {
assert_eq!(res, common_requirement);
Ok(())
}
+
+ #[test]
+ fn test_agg_exec_same_schema() -> Result<()> {
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("a", DataType::Float32, true),
+ Field::new("b", DataType::Float32, true),
+ ]));
+
+ let col_a = col("a", &schema)?;
+ let col_b = col("b", &schema)?;
+ let option_desc = SortOptions {
+ descending: true,
+ nulls_first: true,
+ };
+ let sort_expr = vec![PhysicalSortExpr {
+ expr: col_b.clone(),
+ options: option_desc,
+ }];
+ let sort_expr_reverse = reverse_order_bys(&sort_expr);
+ let groups = PhysicalGroupBy::new_single(vec![(col_a,
"a".to_string())]);
+
+ let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![
+ Arc::new(FirstValue::new(
+ col_b.clone(),
+ "FIRST_VALUE(b)".to_string(),
+ DataType::Float64,
+ sort_expr_reverse.clone(),
+ vec![DataType::Float64],
+ )),
+ Arc::new(LastValue::new(
+ col_b.clone(),
+ "LAST_VALUE(b)".to_string(),
+ DataType::Float64,
+ sort_expr.clone(),
+ vec![DataType::Float64],
+ )),
+ ];
+ let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema),
1));
+ let aggregate_exec = Arc::new(AggregateExec::try_new(
+ AggregateMode::Partial,
+ groups,
+ aggregates.clone(),
+ vec![None, None],
+ blocking_exec.clone(),
+ schema,
+ )?);
+ let new_agg = aggregate_exec
+ .clone()
+ .with_new_children(vec![blocking_exec])?;
+ assert_eq!(new_agg.schema(), aggregate_exec.schema());
+ Ok(())
+ }
}
diff --git a/datafusion/sqllogictest/test_files/groupby.slt
b/datafusion/sqllogictest/test_files/groupby.slt
index 44d30ba0b3..f1b6a57287 100644
--- a/datafusion/sqllogictest/test_files/groupby.slt
+++ b/datafusion/sqllogictest/test_files/groupby.slt
@@ -4280,3 +4280,15 @@ LIMIT 5
2 0 0
3 0 0
4 0 1
+
+
+query ITIPTR rowsort
+SELECT r.*
+FROM sales_global_with_pk as l, sales_global_with_pk as r
+LIMIT 5
+----
+0 GRC 0 2022-01-01T06:00:00 EUR 30
+1 FRA 1 2022-01-01T08:00:00 EUR 50
+1 FRA 3 2022-01-02T12:00:00 EUR 200
+1 TUR 2 2022-01-01T11:30:00 TRY 75
+1 TUR 4 2022-01-03T10:00:00 TRY 100