This is an automated email from the ASF dual-hosted git repository.
jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 910029d40a fix: Queries similar to `count-bug` produce incorrect
results (#15281)
910029d40a is described below
commit 910029d40a86aa2040f535ea3c49d746ff58eaed
Author: suibianwanwan <[email protected]>
AuthorDate: Fri Apr 4 08:18:40 2025 +0800
fix: Queries similar to `count-bug` produce incorrect results (#15281)
* fix: Queries similar to `count-bug` produce incorrect results
* Add more test
---
datafusion/optimizer/src/decorrelate.rs | 5 +-
.../optimizer/src/scalar_subquery_to_join.rs | 216 +++++++++++++--------
datafusion/optimizer/src/utils.rs | 72 ++++---
datafusion/sqllogictest/test_files/subquery.slt | 44 ++++-
4 files changed, 224 insertions(+), 113 deletions(-)
diff --git a/datafusion/optimizer/src/decorrelate.rs
b/datafusion/optimizer/src/decorrelate.rs
index 71ff863b51..418619c839 100644
--- a/datafusion/optimizer/src/decorrelate.rs
+++ b/datafusion/optimizer/src/decorrelate.rs
@@ -501,10 +501,7 @@ fn agg_exprs_evaluation_result_on_empty_batch(
let info =
SimplifyContext::new(&props).with_schema(Arc::clone(schema));
let simplifier = ExprSimplifier::new(info);
let result_expr = simplifier.simplify(result_expr)?;
- if matches!(result_expr, Expr::Literal(ScalarValue::Int64(_))) {
- expr_result_map_for_count_bug
- .insert(e.schema_name().to_string(), result_expr);
- }
+ expr_result_map_for_count_bug.insert(e.schema_name().to_string(),
result_expr);
}
Ok(())
}
diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs
b/datafusion/optimizer/src/scalar_subquery_to_join.rs
index 499447861a..5c89bc29a5 100644
--- a/datafusion/optimizer/src/scalar_subquery_to_join.rs
+++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs
@@ -22,9 +22,10 @@ use std::sync::Arc;
use crate::decorrelate::{PullUpCorrelatedExpr, UN_MATCHED_ROW_INDICATOR};
use crate::optimizer::ApplyOrder;
-use crate::utils::replace_qualified_name;
+use crate::utils::{evaluates_to_null, replace_qualified_name};
use crate::{OptimizerConfig, OptimizerRule};
+use crate::analyzer::type_coercion::TypeCoercionRewriter;
use datafusion_common::alias::AliasGenerator;
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
TreeNodeRewriter,
@@ -348,6 +349,10 @@ fn build_join(
let mut computation_project_expr = HashMap::new();
if let Some(expr_map) = collected_count_expr_map {
for (name, result) in expr_map {
+ if evaluates_to_null(result.clone(), result.column_refs())? {
+ // If expr always returns null when column is null, skip
processing
+ continue;
+ }
let computer_expr = if let Some(filter) =
&pull_up.pull_up_having_expr {
Expr::Case(expr::Case {
expr: None,
@@ -381,7 +386,11 @@ fn build_join(
)))),
})
};
- computation_project_expr.insert(name, computer_expr);
+ let mut expr_rewrite = TypeCoercionRewriter {
+ schema: new_plan.schema(),
+ };
+ computation_project_expr
+ .insert(name, computer_expr.rewrite(&mut
expr_rewrite).data()?);
}
}
@@ -425,18 +434,18 @@ mod tests {
.build()?;
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
- \n Filter: Int32(1) < __scalar_sq_1.max(orders.o_custkey) AND
Int32(1) < __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8,
max(orders.o_custkey):Int64;N, o_custkey:Int64;N,
max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
- \n Left Join: Filter: __scalar_sq_2.o_custkey = customer.c_custkey
[c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N,
o_custkey:Int64;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
- \n Left Join: Filter: __scalar_sq_1.o_custkey =
customer.c_custkey [c_custkey:Int64, c_name:Utf8,
max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
- \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
- \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N,
o_custkey:Int64]\
- \n Projection: max(orders.o_custkey), orders.o_custkey
[max(orders.o_custkey):Int64;N, o_custkey:Int64]\
- \n Aggregate: groupBy=[[orders.o_custkey]],
aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
- \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64,
o_orderstatus:Utf8, o_totalprice:Float64;N]\
- \n SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N,
o_custkey:Int64]\
- \n Projection: max(orders.o_custkey), orders.o_custkey
[max(orders.o_custkey):Int64;N, o_custkey:Int64]\
- \n Aggregate: groupBy=[[orders.o_custkey]],
aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
- \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64,
o_orderstatus:Utf8, o_totalprice:Float64;N]";
+ \n Filter: Int32(1) < __scalar_sq_1.max(orders.o_custkey) AND
Int32(1) < __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8,
max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N,
max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
+ \n Left Join: Filter: __scalar_sq_2.o_custkey =
customer.c_custkey [c_custkey:Int64, c_name:Utf8,
max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N,
max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
+ \n Left Join: Filter: __scalar_sq_1.o_custkey =
customer.c_custkey [c_custkey:Int64, c_name:Utf8,
max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
+ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+ \n SubqueryAlias: __scalar_sq_1
[max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
+ \n Projection: max(orders.o_custkey), orders.o_custkey,
__always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64,
__always_true:Boolean]\
+ \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true)
AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64,
__always_true:Boolean, max(orders.o_custkey):Int64;N]\
+ \n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
+ \n SubqueryAlias: __scalar_sq_2
[max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
+ \n Projection: max(orders.o_custkey), orders.o_custkey,
__always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64,
__always_true:Boolean]\
+ \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true)
AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64,
__always_true:Boolean, max(orders.o_custkey):Int64;N]\
+ \n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
plan,
@@ -480,19 +489,19 @@ mod tests {
.build()?;
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
- \n Filter: customer.c_acctbal <
__scalar_sq_1.sum(orders.o_totalprice) [c_custkey:Int64, c_name:Utf8,
sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N]\
- \n Left Join: Filter: __scalar_sq_1.o_custkey = customer.c_custkey
[c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N,
o_custkey:Int64;N]\
- \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
- \n SubqueryAlias: __scalar_sq_1
[sum(orders.o_totalprice):Float64;N, o_custkey:Int64]\
- \n Projection: sum(orders.o_totalprice), orders.o_custkey
[sum(orders.o_totalprice):Float64;N, o_custkey:Int64]\
- \n Aggregate: groupBy=[[orders.o_custkey]],
aggr=[[sum(orders.o_totalprice)]] [o_custkey:Int64,
sum(orders.o_totalprice):Float64;N]\
- \n Filter: orders.o_totalprice <
__scalar_sq_2.sum(lineitem.l_extendedprice) [o_orderkey:Int64, o_custkey:Int64,
o_orderstatus:Utf8, o_totalprice:Float64;N,
sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N]\
- \n Left Join: Filter: __scalar_sq_2.l_orderkey =
orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8,
o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N,
l_orderkey:Int64;N]\
- \n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
- \n SubqueryAlias: __scalar_sq_2
[sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64]\
- \n Projection: sum(lineitem.l_extendedprice),
lineitem.l_orderkey [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64]\
- \n Aggregate: groupBy=[[lineitem.l_orderkey]],
aggr=[[sum(lineitem.l_extendedprice)]] [l_orderkey:Int64,
sum(lineitem.l_extendedprice):Float64;N]\
- \n TableScan: lineitem [l_orderkey:Int64,
l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64,
l_extendedprice:Float64]";
+ \n Filter: customer.c_acctbal <
__scalar_sq_1.sum(orders.o_totalprice) [c_custkey:Int64, c_name:Utf8,
sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
+ \n Left Join: Filter: __scalar_sq_1.o_custkey =
customer.c_custkey [c_custkey:Int64, c_name:Utf8,
sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
+ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+ \n SubqueryAlias: __scalar_sq_1
[sum(orders.o_totalprice):Float64;N, o_custkey:Int64, __always_true:Boolean]\
+ \n Projection: sum(orders.o_totalprice), orders.o_custkey,
__always_true [sum(orders.o_totalprice):Float64;N, o_custkey:Int64,
__always_true:Boolean]\
+ \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true)
AS __always_true]], aggr=[[sum(orders.o_totalprice)]] [o_custkey:Int64,
__always_true:Boolean, sum(orders.o_totalprice):Float64;N]\
+ \n Filter: orders.o_totalprice <
__scalar_sq_2.sum(lineitem.l_extendedprice) [o_orderkey:Int64, o_custkey:Int64,
o_orderstatus:Utf8, o_totalprice:Float64;N,
sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N,
__always_true:Boolean;N]\
+ \n Left Join: Filter: __scalar_sq_2.l_orderkey =
orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8,
o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N,
l_orderkey:Int64;N, __always_true:Boolean;N]\
+ \n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
+ \n SubqueryAlias: __scalar_sq_2
[sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64,
__always_true:Boolean]\
+ \n Projection: sum(lineitem.l_extendedprice),
lineitem.l_orderkey, __always_true [sum(lineitem.l_extendedprice):Float64;N,
l_orderkey:Int64, __always_true:Boolean]\
+ \n Aggregate: groupBy=[[lineitem.l_orderkey,
Boolean(true) AS __always_true]], aggr=[[sum(lineitem.l_extendedprice)]]
[l_orderkey:Int64, __always_true:Boolean,
sum(lineitem.l_extendedprice):Float64;N]\
+ \n TableScan: lineitem [l_orderkey:Int64,
l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64,
l_extendedprice:Float64]";
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
plan,
@@ -522,14 +531,14 @@ mod tests {
.build()?;
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
- \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey)
[c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N,
o_custkey:Int64;N]\
- \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey
[c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N,
o_custkey:Int64;N]\
- \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
- \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N,
o_custkey:Int64]\
- \n Projection: max(orders.o_custkey), orders.o_custkey
[max(orders.o_custkey):Int64;N, o_custkey:Int64]\
- \n Aggregate: groupBy=[[orders.o_custkey]],
aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
- \n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
- \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64,
o_orderstatus:Utf8, o_totalprice:Float64;N]";
+ \n Filter: customer.c_custkey =
__scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8,
max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
+ \n Left Join: Filter: customer.c_custkey =
__scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8,
max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
+ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+ \n SubqueryAlias: __scalar_sq_1
[max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
+ \n Projection: max(orders.o_custkey), orders.o_custkey,
__always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64,
__always_true:Boolean]\
+ \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true)
AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64,
__always_true:Boolean, max(orders.o_custkey):Int64;N]\
+ \n Filter: orders.o_orderkey = Int32(1)
[o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
+ \n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
@@ -760,13 +769,56 @@ mod tests {
.build()?;
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
- \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) +
Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) +
Int32(1):Int64;N, o_custkey:Int64;N]\
- \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey
[c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N,
o_custkey:Int64;N]\
- \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
- \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey) +
Int32(1):Int64;N, o_custkey:Int64]\
- \n Projection: max(orders.o_custkey) + Int32(1),
orders.o_custkey [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64]\
- \n Aggregate: groupBy=[[orders.o_custkey]],
aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
- \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64,
o_orderstatus:Utf8, o_totalprice:Float64;N]";
+ \n Filter: customer.c_custkey =
__scalar_sq_1.max(orders.o_custkey) + Int32(1) [c_custkey:Int64, c_name:Utf8,
max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N,
__always_true:Boolean;N]\
+ \n Left Join: Filter: customer.c_custkey =
__scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) +
Int32(1):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
+ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+ \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey) +
Int32(1):Int64;N, o_custkey:Int64, __always_true:Boolean]\
+ \n Projection: max(orders.o_custkey) + Int32(1),
orders.o_custkey, __always_true [max(orders.o_custkey) + Int32(1):Int64;N,
o_custkey:Int64, __always_true:Boolean]\
+ \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true)
AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64,
__always_true:Boolean, max(orders.o_custkey):Int64;N]\
+ \n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
+
+ assert_multi_rules_optimized_plan_eq_display_indent(
+ vec![Arc::new(ScalarSubqueryToJoin::new())],
+ plan,
+ expected,
+ );
+ Ok(())
+ }
+
+ /// Test for correlated scalar subquery with non-strong project
+ #[test]
+ fn scalar_subquery_with_non_strong_project() -> Result<()> {
+ let case = Expr::Case(expr::Case {
+ expr: None,
+ when_then_expr: vec![(
+ Box::new(col("max(orders.o_totalprice)")),
+ Box::new(lit("a")),
+ )],
+ else_expr: Some(Box::new(lit("b"))),
+ });
+
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(
+ out_ref_col(DataType::Int64, "customer.c_custkey")
+ .eq(col("orders.o_custkey")),
+ )?
+ .aggregate(Vec::<Expr>::new(),
vec![max(col("orders.o_totalprice"))])?
+ .project(vec![case])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .project(vec![col("customer.c_custkey"), scalar_subquery(sq)])?
+ .build()?;
+
+ let expected = "Projection: customer.c_custkey, CASE WHEN
__scalar_sq_1.__always_true IS NULL THEN CASE WHEN CAST(NULL AS Boolean) THEN
Utf8(\"a\") ELSE Utf8(\"b\") END ELSE __scalar_sq_1.CASE WHEN
max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END END AS CASE WHEN
max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END
[c_custkey:Int64, CASE WHEN max(orders.o_totalprice) THEN Utf8(\"a\") ELSE
Utf8(\"b\") END:Utf8;N]\
+ \n Left Join: Filter: customer.c_custkey =
__scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, CASE WHEN
max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END:Utf8;N,
o_custkey:Int64;N, __always_true:Boolean;N]\
+ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+ \n SubqueryAlias: __scalar_sq_1 [CASE WHEN
max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END:Utf8,
o_custkey:Int64, __always_true:Boolean]\
+ \n Projection: CASE WHEN max(orders.o_totalprice) THEN
Utf8(\"a\") ELSE Utf8(\"b\") END, orders.o_custkey, __always_true [CASE WHEN
max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END:Utf8,
o_custkey:Int64, __always_true:Boolean]\
+ \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS
__always_true]], aggr=[[max(orders.o_totalprice)]] [o_custkey:Int64,
__always_true:Boolean, max(orders.o_totalprice):Float64;N]\
+ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64,
o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
@@ -824,13 +876,13 @@ mod tests {
.build()?;
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
- \n Filter: customer.c_custkey >= __scalar_sq_1.max(orders.o_custkey)
AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8,
max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
- \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey
[c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N,
o_custkey:Int64;N]\
- \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
- \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N,
o_custkey:Int64]\
- \n Projection: max(orders.o_custkey), orders.o_custkey
[max(orders.o_custkey):Int64;N, o_custkey:Int64]\
- \n Aggregate: groupBy=[[orders.o_custkey]],
aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
- \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64,
o_orderstatus:Utf8, o_totalprice:Float64;N]";
+ \n Filter: customer.c_custkey >=
__scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1)
[c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N,
o_custkey:Int64;N, __always_true:Boolean;N]\
+ \n Left Join: Filter: customer.c_custkey =
__scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8,
max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
+ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+ \n SubqueryAlias: __scalar_sq_1
[max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
+ \n Projection: max(orders.o_custkey), orders.o_custkey,
__always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64,
__always_true:Boolean]\
+ \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true)
AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64,
__always_true:Boolean, max(orders.o_custkey):Int64;N]\
+ \n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
@@ -863,13 +915,13 @@ mod tests {
.build()?;
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
- \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey)
AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8,
max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
- \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey
[c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N,
o_custkey:Int64;N]\
- \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
- \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N,
o_custkey:Int64]\
- \n Projection: max(orders.o_custkey), orders.o_custkey
[max(orders.o_custkey):Int64;N, o_custkey:Int64]\
- \n Aggregate: groupBy=[[orders.o_custkey]],
aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
- \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64,
o_orderstatus:Utf8, o_totalprice:Float64;N]";
+ \n Filter: customer.c_custkey =
__scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1)
[c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N,
o_custkey:Int64;N, __always_true:Boolean;N]\
+ \n Left Join: Filter: customer.c_custkey =
__scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8,
max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
+ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+ \n SubqueryAlias: __scalar_sq_1
[max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
+ \n Projection: max(orders.o_custkey), orders.o_custkey,
__always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64,
__always_true:Boolean]\
+ \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true)
AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64,
__always_true:Boolean, max(orders.o_custkey):Int64;N]\
+ \n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
@@ -903,13 +955,13 @@ mod tests {
.build()?;
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
- \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey)
OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8,
max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
- \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey
[c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N,
o_custkey:Int64;N]\
- \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
- \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N,
o_custkey:Int64]\
- \n Projection: max(orders.o_custkey), orders.o_custkey
[max(orders.o_custkey):Int64;N, o_custkey:Int64]\
- \n Aggregate: groupBy=[[orders.o_custkey]],
aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
- \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64,
o_orderstatus:Utf8, o_totalprice:Float64;N]";
+ \n Filter: customer.c_custkey =
__scalar_sq_1.max(orders.o_custkey) OR customer.c_custkey = Int32(1)
[c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N,
o_custkey:Int64;N, __always_true:Boolean;N]\
+ \n Left Join: Filter: customer.c_custkey =
__scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8,
max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
+ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+ \n SubqueryAlias: __scalar_sq_1
[max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
+ \n Projection: max(orders.o_custkey), orders.o_custkey,
__always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64,
__always_true:Boolean]\
+ \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true)
AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64,
__always_true:Boolean, max(orders.o_custkey):Int64;N]\
+ \n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
@@ -936,13 +988,13 @@ mod tests {
.build()?;
let expected = "Projection: test.c [c:UInt32]\
- \n Filter: test.c < __scalar_sq_1.min(sq.c) [a:UInt32, b:UInt32,
c:UInt32, min(sq.c):UInt32;N, a:UInt32;N]\
- \n Left Join: Filter: test.a = __scalar_sq_1.a [a:UInt32,
b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N]\
- \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
- \n SubqueryAlias: __scalar_sq_1 [min(sq.c):UInt32;N, a:UInt32]\
- \n Projection: min(sq.c), sq.a [min(sq.c):UInt32;N, a:UInt32]\
- \n Aggregate: groupBy=[[sq.a]], aggr=[[min(sq.c)]] [a:UInt32,
min(sq.c):UInt32;N]\
- \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
+ \n Filter: test.c < __scalar_sq_1.min(sq.c) [a:UInt32, b:UInt32,
c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N]\
+ \n Left Join: Filter: test.a = __scalar_sq_1.a [a:UInt32,
b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N]\
+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
+ \n SubqueryAlias: __scalar_sq_1 [min(sq.c):UInt32;N,
a:UInt32, __always_true:Boolean]\
+ \n Projection: min(sq.c), sq.a, __always_true
[min(sq.c):UInt32;N, a:UInt32, __always_true:Boolean]\
+ \n Aggregate: groupBy=[[sq.a, Boolean(true) AS
__always_true]], aggr=[[min(sq.c)]] [a:UInt32, __always_true:Boolean,
min(sq.c):UInt32;N]\
+ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
@@ -1051,18 +1103,18 @@ mod tests {
.build()?;
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
- \n Filter: customer.c_custkey BETWEEN
__scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey)
[c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N,
o_custkey:Int64;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
- \n Left Join: Filter: customer.c_custkey = __scalar_sq_2.o_custkey
[c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N,
o_custkey:Int64;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
- \n Left Join: Filter: customer.c_custkey =
__scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8,
min(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
- \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
- \n SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N,
o_custkey:Int64]\
- \n Projection: min(orders.o_custkey), orders.o_custkey
[min(orders.o_custkey):Int64;N, o_custkey:Int64]\
- \n Aggregate: groupBy=[[orders.o_custkey]],
aggr=[[min(orders.o_custkey)]] [o_custkey:Int64, min(orders.o_custkey):Int64;N]\
- \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64,
o_orderstatus:Utf8, o_totalprice:Float64;N]\
- \n SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N,
o_custkey:Int64]\
- \n Projection: max(orders.o_custkey), orders.o_custkey
[max(orders.o_custkey):Int64;N, o_custkey:Int64]\
- \n Aggregate: groupBy=[[orders.o_custkey]],
aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
- \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64,
o_orderstatus:Utf8, o_totalprice:Float64;N]";
+ \n Filter: customer.c_custkey BETWEEN
__scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey)
[c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N,
o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N,
o_custkey:Int64;N, __always_true:Boolean;N]\
+ \n Left Join: Filter: customer.c_custkey =
__scalar_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8,
min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N,
max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
+ \n Left Join: Filter: customer.c_custkey =
__scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8,
min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
+ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+ \n SubqueryAlias: __scalar_sq_1
[min(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
+ \n Projection: min(orders.o_custkey), orders.o_custkey,
__always_true [min(orders.o_custkey):Int64;N, o_custkey:Int64,
__always_true:Boolean]\
+ \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true)
AS __always_true]], aggr=[[min(orders.o_custkey)]] [o_custkey:Int64,
__always_true:Boolean, min(orders.o_custkey):Int64;N]\
+ \n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
+ \n SubqueryAlias: __scalar_sq_2
[max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
+ \n Projection: max(orders.o_custkey), orders.o_custkey,
__always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64,
__always_true:Boolean]\
+ \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true)
AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64,
__always_true:Boolean, max(orders.o_custkey):Int64;N]\
+ \n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
diff --git a/datafusion/optimizer/src/utils.rs
b/datafusion/optimizer/src/utils.rs
index c734d908f6..41c40ec06d 100644
--- a/datafusion/optimizer/src/utils.rs
+++ b/datafusion/optimizer/src/utils.rs
@@ -79,6 +79,50 @@ pub fn is_restrict_null_predicate<'a>(
return Ok(true);
}
+ // If result is single `true`, return false;
+ // If result is single `NULL` or `false`, return true;
+ Ok(
+ match evaluate_expr_with_null_column(predicate,
join_cols_of_predicate)? {
+ ColumnarValue::Array(array) => {
+ if array.len() == 1 {
+ let boolean_array = as_boolean_array(&array)?;
+ boolean_array.is_null(0) || !boolean_array.value(0)
+ } else {
+ false
+ }
+ }
+ ColumnarValue::Scalar(scalar) => matches!(
+ scalar,
+ ScalarValue::Boolean(None) | ScalarValue::Boolean(Some(false))
+ ),
+ },
+ )
+}
+
+/// Determines if an expression evaluates to null when all of `null_columns` are null.
+/// `c0 + 8` returns true
+/// `c0 IS NULL` returns false
+/// `CASE WHEN c0 > 1 then 0 else 1` returns false
+pub fn evaluates_to_null<'a>(
+ predicate: Expr,
+ null_columns: impl IntoIterator<Item = &'a Column>,
+) -> Result<bool> {
+ if matches!(predicate, Expr::Column(_)) {
+ return Ok(true);
+ }
+
+ Ok(
+ match evaluate_expr_with_null_column(predicate, null_columns)? {
+ ColumnarValue::Array(_) => false,
+ ColumnarValue::Scalar(scalar) => scalar.is_null(),
+ },
+ )
+}
+
+fn evaluate_expr_with_null_column<'a>(
+ predicate: Expr,
+ null_columns: impl IntoIterator<Item = &'a Column>,
+) -> Result<ColumnarValue> {
static DUMMY_COL_NAME: &str = "?";
let schema = Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Null,
true)]);
let input_schema = DFSchema::try_from(schema.clone())?;
@@ -87,37 +131,15 @@ pub fn is_restrict_null_predicate<'a>(
let execution_props = ExecutionProps::default();
let null_column = Column::from_name(DUMMY_COL_NAME);
- let join_cols_to_replace = join_cols_of_predicate
+ let join_cols_to_replace = null_columns
.into_iter()
.map(|column| (column, &null_column))
.collect::<HashMap<_, _>>();
let replaced_predicate = replace_col(predicate, &join_cols_to_replace)?;
let coerced_predicate = coerce(replaced_predicate, &input_schema)?;
- let phys_expr =
- create_physical_expr(&coerced_predicate, &input_schema,
&execution_props)?;
-
- let result_type = phys_expr.data_type(&schema)?;
- if !matches!(&result_type, DataType::Boolean) {
- return Ok(false);
- }
-
- // If result is single `true`, return false;
- // If result is single `NULL` or `false`, return true;
- Ok(match phys_expr.evaluate(&input_batch)? {
- ColumnarValue::Array(array) => {
- if array.len() == 1 {
- let boolean_array = as_boolean_array(&array)?;
- boolean_array.is_null(0) || !boolean_array.value(0)
- } else {
- false
- }
- }
- ColumnarValue::Scalar(scalar) => matches!(
- scalar,
- ScalarValue::Boolean(None) | ScalarValue::Boolean(Some(false))
- ),
- })
+ create_physical_expr(&coerced_predicate, &input_schema, &execution_props)?
+ .evaluate(&input_batch)
}
fn coerce(expr: Expr, schema: &DFSchema) -> Result<Expr> {
diff --git a/datafusion/sqllogictest/test_files/subquery.slt
b/datafusion/sqllogictest/test_files/subquery.slt
index aaccaaa43c..a0ac15b740 100644
--- a/datafusion/sqllogictest/test_files/subquery.slt
+++ b/datafusion/sqllogictest/test_files/subquery.slt
@@ -921,7 +921,7 @@ query TT
explain SELECT t1_id, (SELECT count(*) + 2 as cnt_plus_2 FROM t2 WHERE
t2.t2_int = t1.t1_int having count(*) = 0) from t1
----
logical_plan
-01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN
Int64(2) WHEN __scalar_sq_1.count(Int64(1)) != Int64(0) THEN NULL ELSE
__scalar_sq_1.cnt_plus_2 END AS cnt_plus_2
+01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN
Int64(2) WHEN __scalar_sq_1.count(Int64(1)) != Int64(0) THEN Int64(NULL) ELSE
__scalar_sq_1.cnt_plus_2 END AS cnt_plus_2
02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int
03)----TableScan: t1 projection=[t1_id, t1_int]
04)----SubqueryAlias: __scalar_sq_1
@@ -995,7 +995,7 @@ select t1.t1_int from t1 where (
----
logical_plan
01)Projection: t1.t1_int
-02)--Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) WHEN
__scalar_sq_1.count(Int64(1)) != Int64(0) THEN NULL ELSE
__scalar_sq_1.cnt_plus_two END = Int64(2)
+02)--Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) WHEN
__scalar_sq_1.count(Int64(1)) != Int64(0) THEN Int64(NULL) ELSE
__scalar_sq_1.cnt_plus_two END = Int64(2)
03)----Projection: t1.t1_int, __scalar_sq_1.cnt_plus_two,
__scalar_sq_1.count(Int64(1)), __scalar_sq_1.__always_true
04)------Left Join: t1.t1_int = __scalar_sq_1.t2_int
05)--------TableScan: t1 projection=[t1_int]
@@ -1049,6 +1049,46 @@ false
true
true
+query IT rowsort
+SELECT t1_id, (SELECT case when max(t2.t2_id) > 1 then 'a' else 'b' end FROM
t2 WHERE t2.t2_int = t1.t1_int) x from t1
+----
+11 a
+22 b
+33 a
+44 b
+
+query IB rowsort
+SELECT t1_id, (SELECT max(t2.t2_id) is null FROM t2 WHERE t2.t2_int =
t1.t1_int) x from t1
+----
+11 false
+22 true
+33 false
+44 true
+
+query TT
+explain SELECT t1_id, (SELECT max(t2.t2_id) is null FROM t2 WHERE t2.t2_int =
t1.t1_int) x from t1
+----
+logical_plan
+01)Projection: t1.t1_id, __scalar_sq_1.__always_true IS NULL OR
__scalar_sq_1.__always_true IS NOT NULL AND __scalar_sq_1.max(t2.t2_id) IS NULL
AS x
+02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int
+03)----TableScan: t1 projection=[t1_id, t1_int]
+04)----SubqueryAlias: __scalar_sq_1
+05)------Projection: max(t2.t2_id) IS NULL, t2.t2_int, Boolean(true) AS
__always_true
+06)--------Aggregate: groupBy=[[t2.t2_int]], aggr=[[max(t2.t2_id)]]
+07)----------TableScan: t2 projection=[t2_id, t2_int]
+
+query TT
+explain SELECT t1_id, (SELECT max(t2.t2_id) FROM t2 WHERE t2.t2_int =
t1.t1_int) x from t1
+----
+logical_plan
+01)Projection: t1.t1_id, __scalar_sq_1.max(t2.t2_id) AS x
+02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int
+03)----TableScan: t1 projection=[t1_id, t1_int]
+04)----SubqueryAlias: __scalar_sq_1
+05)------Projection: max(t2.t2_id), t2.t2_int
+06)--------Aggregate: groupBy=[[t2.t2_int]], aggr=[[max(t2.t2_id)]]
+07)----------TableScan: t2 projection=[t2_id, t2_int]
+
# in_subquery_to_join_with_correlated_outer_filter_disjunction
query TT
explain select t1.t1_id,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]