This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new c42dc8212 Fix bug in subquery join filters referencing outer query (#2416)
c42dc8212 is described below
commit c42dc82126908f06155e9df9a7158db29f86b801
Author: Andy Grove <[email protected]>
AuthorDate: Tue May 3 11:59:47 2022 -0600
Fix bug in subquery join filters referencing outer query (#2416)
---
datafusion/core/src/logical_plan/expr_rewriter.rs | 2 +-
datafusion/core/src/logical_plan/mod.rs | 5 +-
datafusion/core/src/sql/planner.rs | 80 +++++++++++++++++++++--
3 files changed, 78 insertions(+), 9 deletions(-)
diff --git a/datafusion/core/src/logical_plan/expr_rewriter.rs b/datafusion/core/src/logical_plan/expr_rewriter.rs
index 2c09d378e..4e9476899 100644
--- a/datafusion/core/src/logical_plan/expr_rewriter.rs
+++ b/datafusion/core/src/logical_plan/expr_rewriter.rs
@@ -361,7 +361,7 @@ pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
/// Recursively call [`Column::normalize_with_schemas`] on all Column expressions
/// in the `expr` expression tree.
-fn normalize_col_with_schemas(
+pub fn normalize_col_with_schemas(
expr: Expr,
schemas: &[&Arc<DFSchema>],
using_columns: &[HashSet<Column>],
diff --git a/datafusion/core/src/logical_plan/mod.rs b/datafusion/core/src/logical_plan/mod.rs
index cb30acf1a..55295e22e 100644
--- a/datafusion/core/src/logical_plan/mod.rs
+++ b/datafusion/core/src/logical_plan/mod.rs
@@ -51,8 +51,9 @@ pub use expr::{
when, Column, Expr, ExprSchema, Literal,
};
pub use expr_rewriter::{
- normalize_col, normalize_cols, replace_col, rewrite_sort_cols_by_aggs,
- unnormalize_col, unnormalize_cols, ExprRewritable, ExprRewriter, RewriteRecursion,
+ normalize_col, normalize_col_with_schemas, normalize_cols, replace_col,
+ rewrite_sort_cols_by_aggs, unnormalize_col, unnormalize_cols, ExprRewritable,
+ ExprRewriter, RewriteRecursion,
};
pub use expr_simplier::{ExprSimplifiable, SimplifyInfo};
pub use expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion};
diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs
index ba1a551c9..cab6c774d 100644
--- a/datafusion/core/src/sql/planner.rs
+++ b/datafusion/core/src/sql/planner.rs
@@ -29,10 +29,10 @@ use crate::logical_plan::window_frames::{WindowFrame, WindowFrameUnits};
use crate::logical_plan::Expr::Alias;
use crate::logical_plan::{
and, builder::expand_qualified_wildcard, builder::expand_wildcard, col, lit,
- normalize_col, union_with_alias, Column, CreateCatalog, CreateCatalogSchema,
- CreateExternalTable as PlanCreateExternalTable, CreateMemoryTable, DFSchema,
- DFSchemaRef, DropTable, Expr, FileType, LogicalPlan, LogicalPlanBuilder, Operator,
- PlanType, ToDFSchema, ToStringifiedPlan,
+ normalize_col, normalize_col_with_schemas, union_with_alias, Column, CreateCatalog,
+ CreateCatalogSchema, CreateExternalTable as PlanCreateExternalTable,
+ CreateMemoryTable, DFSchema, DFSchemaRef, DropTable, Expr, FileType, LogicalPlan,
+ LogicalPlanBuilder, Operator, PlanType, ToDFSchema, ToStringifiedPlan,
};
use crate::optimizer::utils::exprlist_to_columns;
use crate::prelude::JoinType;
@@ -50,7 +50,7 @@ use datafusion_expr::{window_function::WindowFunction, BuiltinScalarFunction};
use hashbrown::HashMap;
use datafusion_common::field_not_found;
-use datafusion_expr::logical_plan::Subquery;
+use datafusion_expr::logical_plan::{Filter, Subquery};
use sqlparser::ast::{
BinaryOperator, DataType as SQLDataType, DateTimeField, Expr as SQLExpr, FunctionArg,
FunctionArgExpr, Ident, Join, JoinConstraint, JoinOperator, ObjectName, Query,
@@ -803,6 +803,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
let mut all_join_keys = HashSet::new();
+ let orig_plans = plans.clone();
let mut plans = plans.into_iter();
let mut left = plans.next().unwrap(); // have at least one plan
@@ -885,7 +886,33 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
// remove join expressions from filter
match remove_join_expressions(&filter_expr, &all_join_keys)? {
Some(filter_expr) => {
- LogicalPlanBuilder::from(left).filter(filter_expr)?.build()
+ // this logic is adapted from [`LogicalPlanBuilder::filter`] to take
+ // the query outer schema into account so that joins in subqueries
+ // can reference outer query fields.
+ let mut all_schemas: Vec<DFSchemaRef> = vec![];
+ for plan in orig_plans {
+ for schema in plan.all_schemas() {
+ all_schemas.push(schema.clone());
+ }
+ }
+ if let Some(outer_query_schema) = outer_query_schema {
+ all_schemas.push(Arc::new(outer_query_schema.clone()));
+ }
+ let mut join_columns = HashSet::new();
+ for (l, r) in &all_join_keys {
+ join_columns.insert(l.clone());
+ join_columns.insert(r.clone());
+ }
+ let x: Vec<&DFSchemaRef> = all_schemas.iter().collect();
+ let filter_expr = normalize_col_with_schemas(
+ filter_expr,
+ x.as_slice(),
+ &[join_columns],
+ )?;
+ Ok(LogicalPlan::Filter(Filter {
+ predicate: filter_expr,
+ input: Arc::new(left),
+ }))
}
_ => Ok(left),
}
@@ -4244,6 +4271,18 @@ mod tests {
Field::new("t_date32", DataType::Date32, false),
Field::new("t_date64", DataType::Date64, false),
])),
+ "j1" => Some(Schema::new(vec![
+ Field::new("j1_id", DataType::Int32, false),
+ Field::new("j1_string", DataType::Utf8, false),
+ ])),
+ "j2" => Some(Schema::new(vec![
+ Field::new("j2_id", DataType::Int32, false),
+ Field::new("j2_string", DataType::Utf8, false),
+ ])),
+ "j3" => Some(Schema::new(vec![
+ Field::new("j3_id", DataType::Int32, false),
+ Field::new("j3_string", DataType::Utf8, false),
+ ])),
"person" => Some(Schema::new(vec![
Field::new("id", DataType::UInt32, false),
Field::new("first_name", DataType::Utf8, false),
@@ -4518,6 +4557,35 @@ mod tests {
quick_test(sql, &expected);
}
+ #[test]
+ fn scalar_subquery_reference_outer_field() {
+ let sql = "SELECT j1_string, j2_string \
+ FROM j1, j2 \
+ WHERE j1_id = j2_id - 1 \
+ AND j2_id < (SELECT count(*) \
+ FROM j1, j3 \
+ WHERE j2_id = j1_id \
+ AND j1_id = j3_id)";
+
+ let subquery = "Subquery: Projection: #COUNT(UInt8(1))\
+ \n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
+ \n Filter: #j2.j2_id = #j1.j1_id\
+ \n Inner Join: #j1.j1_id = #j3.j3_id\
+ \n TableScan: j1 projection=None\
+ \n TableScan: j3 projection=None";
+
+ let expected = format!(
+ "Projection: #j1.j1_string, #j2.j2_string\
+ \n Filter: #j1.j1_id = #j2.j2_id - Int64(1) AND #j2.j2_id < ({})\
+ \n CrossJoin:\
+ \n TableScan: j1 projection=None\
+ \n TableScan: j2 projection=None",
+ subquery
+ );
+
+ quick_test(sql, &expected);
+ }
+
#[tokio::test]
async fn subquery_references_cte() {
let sql = "WITH \