This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new fb2f0db6c support for non-correlated subqueries (#3287)
fb2f0db6c is described below
commit fb2f0db6cd3777c1de342ce2b0d4233f69d549fb
Author: kmitchener <[email protected]>
AuthorDate: Tue Aug 30 16:51:35 2022 -0400
support for non-correlated subqueries (#3287)
---
.../optimizer/src/decorrelate_scalar_subquery.rs | 90 +++++++++++++++++-----
datafusion/optimizer/src/utils.rs | 4 +-
2 files changed, 73 insertions(+), 21 deletions(-)
diff --git a/datafusion/optimizer/src/decorrelate_scalar_subquery.rs b/datafusion/optimizer/src/decorrelate_scalar_subquery.rs
index 561757dc8..1d6e5d533 100644
--- a/datafusion/optimizer/src/decorrelate_scalar_subquery.rs
+++ b/datafusion/optimizer/src/decorrelate_scalar_subquery.rs
@@ -101,16 +101,17 @@ impl OptimizerRule for DecorrelateScalarSubquery {
let (subqueries, other_exprs) =
self.extract_subquery_exprs(predicate, optimizer_config)?;
- let optimized_plan = LogicalPlan::Filter(Filter {
- predicate: predicate.clone(),
- input: Arc::new(optimized_input),
- });
+
if subqueries.is_empty() {
// regular filter, no subquery exists clause here
+ let optimized_plan = LogicalPlan::Filter(Filter {
+ predicate: predicate.clone(),
+ input: Arc::new(optimized_input),
+ });
return Ok(optimized_plan);
}
- // iterate through all exists clauses in predicate, turning each into a join
+ // iterate through all subqueries in predicate, turning each into a join
let mut cur_input = (**input).clone();
for subquery in subqueries {
cur_input = optimize_scalar(
@@ -136,22 +137,39 @@ impl OptimizerRule for DecorrelateScalarSubquery {
/// Takes a query like:
///
-/// ```select id from customers where balance >
+/// ```text
+/// select id from customers where balance >
/// (select avg(total) from orders where orders.c_id = customers.id)
/// ```
///
/// and optimizes it into:
///
-/// ```select c.id from customers c
+/// ```text
+/// select c.id from customers c
/// inner join (select c_id, avg(total) as val from orders group by c_id) o on o.c_id = c.c_id
-/// where c.balance > o.val```
+/// where c.balance > o.val
+/// ```
+///
+/// Or a query like:
+///
+/// ```text
+/// select id from customers where balance >
+/// (select avg(total) from orders)
+/// ```
+///
+/// and optimizes it into:
+///
+/// ```text
+/// select c.id from customers c
+/// cross join (select avg(total) as val from orders) a
+/// where c.balance > a.val
+/// ```
///
/// # Arguments
///
-/// * `subqry` - The subquery portion of the `where exists` (select * from orders)
-/// * `negated` - True if the subquery is a `where not exists`
+/// * `query_info` - The subquery portion of the `where` (select avg(total) from orders)
/// * `filter_input` - The non-subquery portion (from customers)
-/// * `other_filter_exprs` - Any additional parts to the `where` expression (and c.x = y)
+/// * `outer_others` - Any additional parts to the `where` expression (and c.x = y)
/// * `optimizer_config` - Used to generate unique subquery aliases
fn optimize_scalar(
query_info: &SubqueryInfo,
@@ -173,20 +191,27 @@ fn optimize_scalar(
.map_err(|e| context!("Exactly one input is expected. Is this a join?", e))?;
let aggr = Aggregate::try_from_plan(sub_input)
.map_err(|e| context!("scalar subqueries must aggregate a value", e))?;
- let filter = Filter::try_from_plan(&aggr.input).map_err(|e| {
- context!("scalar subqueries must have a filter to be correlated", e)
- })?;
+ let filter = Filter::try_from_plan(&aggr.input).ok();
- // split into filters
+ // if there were filters, we use that logical plan, otherwise the plan from the aggregate
+ let input = if let Some(filter) = filter {
+ &filter.input
+ } else {
+ &aggr.input
+ };
+
+ // if there were filters, split and capture them
let mut subqry_filter_exprs = vec![];
- split_conjunction(&filter.predicate, &mut subqry_filter_exprs);
+ if let Some(filter) = filter {
+ split_conjunction(&filter.predicate, &mut subqry_filter_exprs);
+ }
verify_not_disjunction(&subqry_filter_exprs)?;
// Grab column names to join on
let (col_exprs, other_subqry_exprs) =
- find_join_exprs(subqry_filter_exprs, filter.input.schema())?;
+ find_join_exprs(subqry_filter_exprs, input.schema())?;
let (outer_cols, subqry_cols, join_filters) =
- exprs_to_join_cols(&col_exprs, filter.input.schema(), false)?;
+ exprs_to_join_cols(&col_exprs, input.schema(), false)?;
if join_filters.is_some() {
plan_err!("only joins on column equality are presently supported")?;
}
@@ -199,7 +224,7 @@ fn optimize_scalar(
.collect();
// build subquery side of join - the thing the subquery was querying
- let mut subqry_plan = LogicalPlanBuilder::from((*filter.input).clone());
+ let mut subqry_plan = LogicalPlanBuilder::from((**input).clone());
if let Some(expr) = combine_filters(&other_subqry_exprs) {
subqry_plan = subqry_plan.filter(expr)? // if the subquery had additional expressions, restore them
}
@@ -702,4 +727,31 @@ mod tests {
assert_optimized_plan_eq(&DecorrelateScalarSubquery::new(), &plan, expected);
Ok(())
}
+
+ /// Test for non-correlated scalar subquery with no filters
+ #[test]
+ fn scalar_subquery_non_correlated_no_filters() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
+ .project(vec![max(col("orders.o_custkey"))])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64]
+ Filter: #customer.c_custkey = #__sq_1.__value [c_custkey:Int64, c_name:Utf8, __value:Int64;N]
+ CrossJoin: [c_custkey:Int64, c_name:Utf8, __value:Int64;N]
+ TableScan: customer [c_custkey:Int64, c_name:Utf8]
+ Projection: #MAX(orders.o_custkey) AS __value, alias=__sq_1 [__value:Int64;N]
+ Aggregate: groupBy=[[]], aggr=[[MAX(#orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N]
+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
+
+ assert_optimized_plan_eq(&DecorrelateScalarSubquery::new(), &plan, expected);
+ Ok(())
+ }
}
diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs
index 41c75d689..d962dd7b4 100644
--- a/datafusion/optimizer/src/utils.rs
+++ b/datafusion/optimizer/src/utils.rs
@@ -125,7 +125,7 @@ pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> LogicalPlan {
/// # Arguments
///
/// * `exprs` - List of expressions that may or may not be joins
-/// * `fields` - HashSet of fully qualified (table.col) fields in subquery schema
+/// * `schema` - HashSet of fully qualified (table.col) fields in subquery schema
///
/// # Return value
///
@@ -191,7 +191,7 @@ pub fn find_join_exprs(
/// # Arguments
///
/// * `exprs` - List of expressions that correlate a subquery to an outer scope
-/// * `fields` - HashSet of fully qualified (table.col) fields in subquery schema
+/// * `schema` - subquery schema
/// * `include_negated` - true if `NotEq` counts as a join operator
///
/// # Return value