This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new c42dc8212 Fix bug in subquery join filters referencing outer query (#2416)
c42dc8212 is described below

commit c42dc82126908f06155e9df9a7158db29f86b801
Author: Andy Grove <[email protected]>
AuthorDate: Tue May 3 11:59:47 2022 -0600

    Fix bug in subquery join filters referencing outer query (#2416)
---
 datafusion/core/src/logical_plan/expr_rewriter.rs |  2 +-
 datafusion/core/src/logical_plan/mod.rs           |  5 +-
 datafusion/core/src/sql/planner.rs                | 80 +++++++++++++++++++++--
 3 files changed, 78 insertions(+), 9 deletions(-)

diff --git a/datafusion/core/src/logical_plan/expr_rewriter.rs b/datafusion/core/src/logical_plan/expr_rewriter.rs
index 2c09d378e..4e9476899 100644
--- a/datafusion/core/src/logical_plan/expr_rewriter.rs
+++ b/datafusion/core/src/logical_plan/expr_rewriter.rs
@@ -361,7 +361,7 @@ pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
 
 /// Recursively call [`Column::normalize_with_schemas`] on all Column 
expressions
 /// in the `expr` expression tree.
-fn normalize_col_with_schemas(
+pub fn normalize_col_with_schemas(
     expr: Expr,
     schemas: &[&Arc<DFSchema>],
     using_columns: &[HashSet<Column>],
diff --git a/datafusion/core/src/logical_plan/mod.rs b/datafusion/core/src/logical_plan/mod.rs
index cb30acf1a..55295e22e 100644
--- a/datafusion/core/src/logical_plan/mod.rs
+++ b/datafusion/core/src/logical_plan/mod.rs
@@ -51,8 +51,9 @@ pub use expr::{
     when, Column, Expr, ExprSchema, Literal,
 };
 pub use expr_rewriter::{
-    normalize_col, normalize_cols, replace_col, rewrite_sort_cols_by_aggs,
-    unnormalize_col, unnormalize_cols, ExprRewritable, ExprRewriter, 
RewriteRecursion,
+    normalize_col, normalize_col_with_schemas, normalize_cols, replace_col,
+    rewrite_sort_cols_by_aggs, unnormalize_col, unnormalize_cols, 
ExprRewritable,
+    ExprRewriter, RewriteRecursion,
 };
 pub use expr_simplier::{ExprSimplifiable, SimplifyInfo};
 pub use expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion};
diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs
index ba1a551c9..cab6c774d 100644
--- a/datafusion/core/src/sql/planner.rs
+++ b/datafusion/core/src/sql/planner.rs
@@ -29,10 +29,10 @@ use crate::logical_plan::window_frames::{WindowFrame, WindowFrameUnits};
 use crate::logical_plan::Expr::Alias;
 use crate::logical_plan::{
     and, builder::expand_qualified_wildcard, builder::expand_wildcard, col, 
lit,
-    normalize_col, union_with_alias, Column, CreateCatalog, 
CreateCatalogSchema,
-    CreateExternalTable as PlanCreateExternalTable, CreateMemoryTable, 
DFSchema,
-    DFSchemaRef, DropTable, Expr, FileType, LogicalPlan, LogicalPlanBuilder, 
Operator,
-    PlanType, ToDFSchema, ToStringifiedPlan,
+    normalize_col, normalize_col_with_schemas, union_with_alias, Column, 
CreateCatalog,
+    CreateCatalogSchema, CreateExternalTable as PlanCreateExternalTable,
+    CreateMemoryTable, DFSchema, DFSchemaRef, DropTable, Expr, FileType, 
LogicalPlan,
+    LogicalPlanBuilder, Operator, PlanType, ToDFSchema, ToStringifiedPlan,
 };
 use crate::optimizer::utils::exprlist_to_columns;
 use crate::prelude::JoinType;
@@ -50,7 +50,7 @@ use datafusion_expr::{window_function::WindowFunction, BuiltinScalarFunction};
 use hashbrown::HashMap;
 
 use datafusion_common::field_not_found;
-use datafusion_expr::logical_plan::Subquery;
+use datafusion_expr::logical_plan::{Filter, Subquery};
 use sqlparser::ast::{
     BinaryOperator, DataType as SQLDataType, DateTimeField, Expr as SQLExpr, 
FunctionArg,
     FunctionArgExpr, Ident, Join, JoinConstraint, JoinOperator, ObjectName, 
Query,
@@ -803,6 +803,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
 
                 let mut all_join_keys = HashSet::new();
 
+                let orig_plans = plans.clone();
                 let mut plans = plans.into_iter();
                 let mut left = plans.next().unwrap(); // have at least one plan
 
@@ -885,7 +886,33 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 // remove join expressions from filter
                 match remove_join_expressions(&filter_expr, &all_join_keys)? {
                     Some(filter_expr) => {
-                        
LogicalPlanBuilder::from(left).filter(filter_expr)?.build()
+                        // this logic is adapted from 
[`LogicalPlanBuilder::filter`] to take
+                        // the query outer schema into account so that joins 
in subqueries
+                        // can reference outer query fields.
+                        let mut all_schemas: Vec<DFSchemaRef> = vec![];
+                        for plan in orig_plans {
+                            for schema in plan.all_schemas() {
+                                all_schemas.push(schema.clone());
+                            }
+                        }
+                        if let Some(outer_query_schema) = outer_query_schema {
+                            
all_schemas.push(Arc::new(outer_query_schema.clone()));
+                        }
+                        let mut join_columns = HashSet::new();
+                        for (l, r) in &all_join_keys {
+                            join_columns.insert(l.clone());
+                            join_columns.insert(r.clone());
+                        }
+                        let x: Vec<&DFSchemaRef> = 
all_schemas.iter().collect();
+                        let filter_expr = normalize_col_with_schemas(
+                            filter_expr,
+                            x.as_slice(),
+                            &[join_columns],
+                        )?;
+                        Ok(LogicalPlan::Filter(Filter {
+                            predicate: filter_expr,
+                            input: Arc::new(left),
+                        }))
                     }
                     _ => Ok(left),
                 }
@@ -4244,6 +4271,18 @@ mod tests {
                     Field::new("t_date32", DataType::Date32, false),
                     Field::new("t_date64", DataType::Date64, false),
                 ])),
+                "j1" => Some(Schema::new(vec![
+                    Field::new("j1_id", DataType::Int32, false),
+                    Field::new("j1_string", DataType::Utf8, false),
+                ])),
+                "j2" => Some(Schema::new(vec![
+                    Field::new("j2_id", DataType::Int32, false),
+                    Field::new("j2_string", DataType::Utf8, false),
+                ])),
+                "j3" => Some(Schema::new(vec![
+                    Field::new("j3_id", DataType::Int32, false),
+                    Field::new("j3_string", DataType::Utf8, false),
+                ])),
                 "person" => Some(Schema::new(vec![
                     Field::new("id", DataType::UInt32, false),
                     Field::new("first_name", DataType::Utf8, false),
@@ -4518,6 +4557,35 @@ mod tests {
         quick_test(sql, &expected);
     }
 
+    #[test]
+    fn scalar_subquery_reference_outer_field() {
+        let sql = "SELECT j1_string, j2_string \
+        FROM j1, j2 \
+        WHERE j1_id = j2_id - 1 \
+        AND j2_id < (SELECT count(*) \
+            FROM j1, j3 \
+            WHERE j2_id = j1_id \
+            AND j1_id = j3_id)";
+
+        let subquery = "Subquery: Projection: #COUNT(UInt8(1))\
+            \n  Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
+            \n    Filter: #j2.j2_id = #j1.j1_id\
+            \n      Inner Join: #j1.j1_id = #j3.j3_id\
+            \n        TableScan: j1 projection=None\
+            \n        TableScan: j3 projection=None";
+
+        let expected = format!(
+            "Projection: #j1.j1_string, #j2.j2_string\
+            \n  Filter: #j1.j1_id = #j2.j2_id - Int64(1) AND #j2.j2_id < ({})\
+            \n    CrossJoin:\
+            \n      TableScan: j1 projection=None\
+            \n      TableScan: j2 projection=None",
+            subquery
+        );
+
+        quick_test(sql, &expected);
+    }
+
     #[tokio::test]
     async fn subquery_references_cte() {
         let sql = "WITH \

Reply via email to