This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new c8e1c84e6b Detect when filters make subqueries scalar (#8312)
c8e1c84e6b is described below

commit c8e1c84e6b4f1292afa6f5517bc6978b55758723
Author: Jesse <[email protected]>
AuthorDate: Wed Dec 6 23:33:53 2023 +0100

    Detect when filters make subqueries scalar (#8312)
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/common/src/functional_dependencies.rs |   8 ++
 datafusion/expr/src/logical_plan/plan.rs         | 141 ++++++++++++++++++++++-
 datafusion/sqllogictest/test_files/subquery.slt  |  18 +++
 3 files changed, 164 insertions(+), 3 deletions(-)

diff --git a/datafusion/common/src/functional_dependencies.rs 
b/datafusion/common/src/functional_dependencies.rs
index fbddcddab4..4587677e77 100644
--- a/datafusion/common/src/functional_dependencies.rs
+++ b/datafusion/common/src/functional_dependencies.rs
@@ -413,6 +413,14 @@ impl FunctionalDependencies {
     }
 }
 
+/// Exposes the dependencies as a slice, e.g. so callers can use `.iter()`.
+impl Deref for FunctionalDependencies {
+    type Target = [FunctionalDependence];
+    fn deref(&self) -> &[FunctionalDependence] {
+        &self.deps
+    }
+}
+
 /// Calculates functional dependencies for aggregate output, when there is a 
GROUP BY expression.
 pub fn aggregate_functional_dependencies(
     aggr_input_schema: &DFSchema,
diff --git a/datafusion/expr/src/logical_plan/plan.rs 
b/datafusion/expr/src/logical_plan/plan.rs
index 2988e7536b..d85e0b5b0a 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -33,6 +33,7 @@ use crate::logical_plan::{DmlStatement, Statement};
 use crate::utils::{
     enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs,
     grouping_set_expr_count, grouping_set_to_exprlist, inspect_expr_pre,
+    split_conjunction,
 };
 use crate::{
     build_join_schema, expr_vec_fmt, BinaryExpr, CreateMemoryTable, 
CreateView, Expr,
@@ -47,7 +48,7 @@ use datafusion_common::tree_node::{
 };
 use datafusion_common::{
     aggregate_functional_dependencies, internal_err, plan_err, Column, 
Constraints,
-    DFField, DFSchema, DFSchemaRef, DataFusionError, FunctionalDependencies,
+    DFField, DFSchema, DFSchemaRef, DataFusionError, Dependency, 
FunctionalDependencies,
     OwnedTableReference, ParamValues, Result, UnnestOptions,
 };
 // backwards compatibility
@@ -1032,7 +1033,13 @@ impl LogicalPlan {
     pub fn max_rows(self: &LogicalPlan) -> Option<usize> {
         match self {
             LogicalPlan::Projection(Projection { input, .. }) => 
input.max_rows(),
-            LogicalPlan::Filter(Filter { input, .. }) => input.max_rows(),
+            LogicalPlan::Filter(filter) => {
+                if filter.is_scalar() {
+                    Some(1)
+                } else {
+                    filter.input.max_rows()
+                }
+            }
             LogicalPlan::Window(Window { input, .. }) => input.max_rows(),
             LogicalPlan::Aggregate(Aggregate {
                 input, group_expr, ..
@@ -1913,6 +1920,73 @@ impl Filter {
 
         Ok(Self { predicate, input })
     }
+
+    /// Is this filter guaranteed to return 0 or 1 row in a given instantiation?
+    ///
+    /// Returns `true` if the `col = <expr>` terms of the predicate's conjunction
+    /// cover a unique key of the input and no such `<expr>` varies per input row
+    /// (i.e. it has no input column and no scalar subquery; literals, placeholders
+    /// and outer references are fine). Volatile functions are not detected.
+    ///
+    /// For example, for the table:
+    /// ```sql
+    /// CREATE TABLE t (a INTEGER PRIMARY KEY, b INTEGER);
+    /// ```
+    /// `Filter(a = 2).is_scalar() == true`, whereas `Filter(b = 2)`,
+    /// `Filter(a = 2 OR b = 2)` and `Filter(a = b + 1)` are not scalar.
+    fn is_scalar(&self) -> bool {
+        let schema = self.input.schema();
+
+        // Keys that identify at most one row and cannot be defeated by NULLs
+        let mut unique_keys = schema.functional_dependencies().iter().filter(|dep| {
+            let nullable = dep.nullable
+                && dep
+                    .source_indices
+                    .iter()
+                    .any(|&source| schema.field(source).is_nullable());
+            !nullable
+                && dep.mode == Dependency::Single
+                && dep.target_indices.len() == schema.fields().len()
+        });
+
+        // `true` if `expr` may evaluate differently for different input rows
+        let varies_per_row = |expr: &Expr| {
+            inspect_expr_pre(expr, |e| match e {
+                Expr::Column(_) | Expr::ScalarSubquery(_) => Err(()),
+                _ => Ok(()),
+            })
+            .is_err()
+        };
+
+        // Input column indices pinned to a single value by an `=` conjunct
+        let eq_pred_cols: HashSet<usize> = split_conjunction(&self.predicate)
+            .into_iter()
+            .filter_map(|expr| {
+                let Expr::BinaryExpr(BinaryExpr {
+                    left,
+                    op: Operator::Eq,
+                    right,
+                }) = expr
+                else {
+                    return None;
+                };
+                let (column, other) = match (left.as_ref(), right.as_ref()) {
+                    (Expr::Column(c), other) | (other, Expr::Column(c)) => (c, other),
+                    _ => return None,
+                };
+                // Rejects no-op (`a = a`), join (`a = b`) and `a = b + 1` terms
+                if varies_per_row(other) {
+                    return None;
+                }
+                // A column that does not resolve against the input pins no key
+                schema.index_of_column(column).ok()
+            })
+            .collect();
+
+        // Scalar if some unique key is fully covered by the pinned columns
+        unique_keys
+            .any(|key| key.source_indices.iter().all(|c| eq_pred_cols.contains(c)))
+    }
 }
 
 /// Window its input based on a set of window spec and window function (e.g. 
SUM or RANK)
@@ -2554,12 +2628,16 @@ pub struct Unnest {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::builder::LogicalTableSource;
     use crate::logical_plan::table_scan;
     use crate::{col, count, exists, in_subquery, lit, placeholder, 
GroupingSet};
     use arrow::datatypes::{DataType, Field, Schema};
     use datafusion_common::tree_node::TreeNodeVisitor;
-    use datafusion_common::{not_impl_err, DFSchema, ScalarValue, 
TableReference};
+    use datafusion_common::{
+        not_impl_err, Constraint, DFSchema, ScalarValue, TableReference,
+    };
     use std::collections::HashMap;
+    use std::sync::Arc;
 
     fn employee_schema() -> Schema {
         Schema::new(vec![
@@ -3056,6 +3134,63 @@ digraph {
             .is_nullable());
     }
 
+    #[test]
+    fn test_filter_is_scalar() {
+        // Table `tab` with a single non-nullable `id` column, scanned once with
+        // its plain schema and once with a UNIQUE constraint on `id`.
+        let arrow_schema =
+            Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
+        let source = Arc::new(LogicalTableSource::new(arrow_schema));
+        let plain_schema = Arc::new(
+            DFSchema::try_from_qualified_schema(
+                TableReference::bare("tab"),
+                &source.schema(),
+            )
+            .unwrap(),
+        );
+        let unique_schema =
+            Arc::new(plain_schema.as_ref().clone().with_functional_dependencies(
+                FunctionalDependencies::new_from_constraints(
+                    Some(&Constraints::new_unverified(vec![Constraint::Unique(
+                        vec![0],
+                    )])),
+                    1,
+                ),
+            ));
+
+        // Builds `Filter(predicate)` over a scan of `source` exposing `schema`
+        let filter = |schema: &DFSchemaRef, predicate: Expr| {
+            let scan = Arc::new(LogicalPlan::TableScan(TableScan {
+                table_name: TableReference::bare("tab"),
+                source: source.clone(),
+                projection: None,
+                projected_schema: schema.clone(),
+                filters: vec![],
+                fetch: None,
+            }));
+            Filter::try_new(predicate, scan).unwrap()
+        };
+        let id = || Expr::Column(plain_schema.field(0).qualified_column());
+        let int = |v: i32| Expr::Literal(ScalarValue::Int32(Some(v)));
+
+        // Without a unique key nothing bounds the number of matching rows
+        assert!(!filter(&plain_schema, id().eq(int(1))).is_scalar());
+
+        // With one, `id = <literal>` pins the key on either side of `=`, and
+        // also when it is just one term of a conjunction
+        assert!(filter(&unique_schema, id().eq(int(1))).is_scalar());
+        assert!(filter(&unique_schema, int(1).eq(id())).is_scalar());
+        let conjunction = id().eq(int(1)).and(id().gt(int(0)));
+        assert!(filter(&unique_schema, conjunction).is_scalar());
+
+        // A disjunction can match one row per alternative, so it is not scalar
+        let disjunction = id().eq(int(1)).or(id().eq(int(2)));
+        assert!(!filter(&unique_schema, disjunction).is_scalar());
+
+        // Without an `=` term the key is not pinned at all
+        assert!(!filter(&unique_schema, id().gt(int(1))).is_scalar());
+    }
+
     #[test]
     fn test_transform_explain() {
         let schema = Schema::new(vec![
diff --git a/datafusion/sqllogictest/test_files/subquery.slt 
b/datafusion/sqllogictest/test_files/subquery.slt
index 430e676fa4..3e0fcb7aa9 100644
--- a/datafusion/sqllogictest/test_files/subquery.slt
+++ b/datafusion/sqllogictest/test_files/subquery.slt
@@ -49,6 +49,13 @@ CREATE TABLE t2(t2_id INT, t2_name TEXT, t2_int INT) AS 
VALUES
 (44, 'x', 3),
 (55, 'w', 3);
 
+statement ok
+CREATE TABLE t3(t3_id INT PRIMARY KEY, t3_name TEXT, t3_int INT) AS VALUES
+(11, 'e', 3),
+(22, 'f', 1),
+(44, 'g', 3),
+(55, 'h', 3);
+
 statement ok
 CREATE EXTERNAL TABLE IF NOT EXISTS customer (
         c_custkey BIGINT,
@@ -419,6 +426,17 @@ SELECT t1_id, t1_name, t1_int FROM t1 order by t1_int in 
(SELECT t2_int FROM t2
 statement error DataFusion error: check_analyzed_plan\ncaused by\nError during 
planning: Correlated scalar subquery must be aggregated to return at most one 
row
 SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1.t1_int) as t2_int 
from t1
 
+#non_aggregated_correlated_scalar_subquery_unique
+# t3_id is t3's PRIMARY KEY, so the correlated filter matches at most one row
+query II rowsort
+SELECT t1_id, (SELECT t3_int FROM t3 WHERE t3.t3_id = t1.t1_id) as t3_int from t1
+----
+11 3
+22 1
+33 NULL
+44 3
+
+#non_aggregated_correlated_scalar_subquery
 statement error DataFusion error: check_analyzed_plan\ncaused by\nError during 
planning: Correlated scalar subquery must be aggregated to return at most one 
row
 SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1_int group by t2_int) 
as t2_int from t1
 

Reply via email to