This is an automated email from the ASF dual-hosted git repository.

jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 910029d40a fix: Queries similar to `count-bug` produce incorrect 
results (#15281)
910029d40a is described below

commit 910029d40a86aa2040f535ea3c49d746ff58eaed
Author: suibianwanwan <[email protected]>
AuthorDate: Fri Apr 4 08:18:40 2025 +0800

    fix: Queries similar to `count-bug` produce incorrect results (#15281)
    
    * fix: Queries similar to `count-bug` produce incorrect results
    
    * Add more test
---
 datafusion/optimizer/src/decorrelate.rs            |   5 +-
 .../optimizer/src/scalar_subquery_to_join.rs       | 216 +++++++++++++--------
 datafusion/optimizer/src/utils.rs                  |  72 ++++---
 datafusion/sqllogictest/test_files/subquery.slt    |  44 ++++-
 4 files changed, 224 insertions(+), 113 deletions(-)

diff --git a/datafusion/optimizer/src/decorrelate.rs 
b/datafusion/optimizer/src/decorrelate.rs
index 71ff863b51..418619c839 100644
--- a/datafusion/optimizer/src/decorrelate.rs
+++ b/datafusion/optimizer/src/decorrelate.rs
@@ -501,10 +501,7 @@ fn agg_exprs_evaluation_result_on_empty_batch(
         let info = 
SimplifyContext::new(&props).with_schema(Arc::clone(schema));
         let simplifier = ExprSimplifier::new(info);
         let result_expr = simplifier.simplify(result_expr)?;
-        if matches!(result_expr, Expr::Literal(ScalarValue::Int64(_))) {
-            expr_result_map_for_count_bug
-                .insert(e.schema_name().to_string(), result_expr);
-        }
+        expr_result_map_for_count_bug.insert(e.schema_name().to_string(), 
result_expr);
     }
     Ok(())
 }
diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs 
b/datafusion/optimizer/src/scalar_subquery_to_join.rs
index 499447861a..5c89bc29a5 100644
--- a/datafusion/optimizer/src/scalar_subquery_to_join.rs
+++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs
@@ -22,9 +22,10 @@ use std::sync::Arc;
 
 use crate::decorrelate::{PullUpCorrelatedExpr, UN_MATCHED_ROW_INDICATOR};
 use crate::optimizer::ApplyOrder;
-use crate::utils::replace_qualified_name;
+use crate::utils::{evaluates_to_null, replace_qualified_name};
 use crate::{OptimizerConfig, OptimizerRule};
 
+use crate::analyzer::type_coercion::TypeCoercionRewriter;
 use datafusion_common::alias::AliasGenerator;
 use datafusion_common::tree_node::{
     Transformed, TransformedResult, TreeNode, TreeNodeRecursion, 
TreeNodeRewriter,
@@ -348,6 +349,10 @@ fn build_join(
     let mut computation_project_expr = HashMap::new();
     if let Some(expr_map) = collected_count_expr_map {
         for (name, result) in expr_map {
+            if evaluates_to_null(result.clone(), result.column_refs())? {
+                // If expr always returns null when column is null, skip 
processing
+                continue;
+            }
             let computer_expr = if let Some(filter) = 
&pull_up.pull_up_having_expr {
                 Expr::Case(expr::Case {
                     expr: None,
@@ -381,7 +386,11 @@ fn build_join(
                     )))),
                 })
             };
-            computation_project_expr.insert(name, computer_expr);
+            let mut expr_rewrite = TypeCoercionRewriter {
+                schema: new_plan.schema(),
+            };
+            computation_project_expr
+                .insert(name, computer_expr.rewrite(&mut 
expr_rewrite).data()?);
         }
     }
 
@@ -425,18 +434,18 @@ mod tests {
             .build()?;
 
         let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
-        \n  Filter: Int32(1) < __scalar_sq_1.max(orders.o_custkey) AND 
Int32(1) < __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, 
max(orders.o_custkey):Int64;N, o_custkey:Int64;N, 
max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
-        \n    Left Join:  Filter: __scalar_sq_2.o_custkey = customer.c_custkey 
[c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, 
o_custkey:Int64;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
-        \n      Left Join:  Filter: __scalar_sq_1.o_custkey = 
customer.c_custkey [c_custkey:Int64, c_name:Utf8, 
max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
-        \n        TableScan: customer [c_custkey:Int64, c_name:Utf8]\
-        \n        SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, 
o_custkey:Int64]\
-        \n          Projection: max(orders.o_custkey), orders.o_custkey 
[max(orders.o_custkey):Int64;N, o_custkey:Int64]\
-        \n            Aggregate: groupBy=[[orders.o_custkey]], 
aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
-        \n              TableScan: orders [o_orderkey:Int64, o_custkey:Int64, 
o_orderstatus:Utf8, o_totalprice:Float64;N]\
-        \n      SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, 
o_custkey:Int64]\
-        \n        Projection: max(orders.o_custkey), orders.o_custkey 
[max(orders.o_custkey):Int64;N, o_custkey:Int64]\
-        \n          Aggregate: groupBy=[[orders.o_custkey]], 
aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
-        \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, 
o_orderstatus:Utf8, o_totalprice:Float64;N]";
+            \n  Filter: Int32(1) < __scalar_sq_1.max(orders.o_custkey) AND 
Int32(1) < __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, 
max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, 
max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
+            \n    Left Join:  Filter: __scalar_sq_2.o_custkey = 
customer.c_custkey [c_custkey:Int64, c_name:Utf8, 
max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, 
max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
+            \n      Left Join:  Filter: __scalar_sq_1.o_custkey = 
customer.c_custkey [c_custkey:Int64, c_name:Utf8, 
max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
+            \n        TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+            \n        SubqueryAlias: __scalar_sq_1 
[max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
+            \n          Projection: max(orders.o_custkey), orders.o_custkey, 
__always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, 
__always_true:Boolean]\
+            \n            Aggregate: groupBy=[[orders.o_custkey, Boolean(true) 
AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, 
__always_true:Boolean, max(orders.o_custkey):Int64;N]\
+            \n              TableScan: orders [o_orderkey:Int64, 
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
+            \n      SubqueryAlias: __scalar_sq_2 
[max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
+            \n        Projection: max(orders.o_custkey), orders.o_custkey, 
__always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, 
__always_true:Boolean]\
+            \n          Aggregate: groupBy=[[orders.o_custkey, Boolean(true) 
AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, 
__always_true:Boolean, max(orders.o_custkey):Int64;N]\
+            \n            TableScan: orders [o_orderkey:Int64, 
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
         assert_multi_rules_optimized_plan_eq_display_indent(
             vec![Arc::new(ScalarSubqueryToJoin::new())],
             plan,
@@ -480,19 +489,19 @@ mod tests {
             .build()?;
 
         let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
-        \n  Filter: customer.c_acctbal < 
__scalar_sq_1.sum(orders.o_totalprice) [c_custkey:Int64, c_name:Utf8, 
sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N]\
-        \n    Left Join:  Filter: __scalar_sq_1.o_custkey = customer.c_custkey 
[c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, 
o_custkey:Int64;N]\
-        \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
-        \n      SubqueryAlias: __scalar_sq_1 
[sum(orders.o_totalprice):Float64;N, o_custkey:Int64]\
-        \n        Projection: sum(orders.o_totalprice), orders.o_custkey 
[sum(orders.o_totalprice):Float64;N, o_custkey:Int64]\
-        \n          Aggregate: groupBy=[[orders.o_custkey]], 
aggr=[[sum(orders.o_totalprice)]] [o_custkey:Int64, 
sum(orders.o_totalprice):Float64;N]\
-        \n            Filter: orders.o_totalprice < 
__scalar_sq_2.sum(lineitem.l_extendedprice) [o_orderkey:Int64, o_custkey:Int64, 
o_orderstatus:Utf8, o_totalprice:Float64;N, 
sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N]\
-        \n              Left Join:  Filter: __scalar_sq_2.l_orderkey = 
orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, 
o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, 
l_orderkey:Int64;N]\
-        \n                TableScan: orders [o_orderkey:Int64, 
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
-        \n                SubqueryAlias: __scalar_sq_2 
[sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64]\
-        \n                  Projection: sum(lineitem.l_extendedprice), 
lineitem.l_orderkey [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64]\
-        \n                    Aggregate: groupBy=[[lineitem.l_orderkey]], 
aggr=[[sum(lineitem.l_extendedprice)]] [l_orderkey:Int64, 
sum(lineitem.l_extendedprice):Float64;N]\
-        \n                      TableScan: lineitem [l_orderkey:Int64, 
l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, 
l_extendedprice:Float64]";
+            \n  Filter: customer.c_acctbal < 
__scalar_sq_1.sum(orders.o_totalprice) [c_custkey:Int64, c_name:Utf8, 
sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
+            \n    Left Join:  Filter: __scalar_sq_1.o_custkey = 
customer.c_custkey [c_custkey:Int64, c_name:Utf8, 
sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
+            \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+            \n      SubqueryAlias: __scalar_sq_1 
[sum(orders.o_totalprice):Float64;N, o_custkey:Int64, __always_true:Boolean]\
+            \n        Projection: sum(orders.o_totalprice), orders.o_custkey, 
__always_true [sum(orders.o_totalprice):Float64;N, o_custkey:Int64, 
__always_true:Boolean]\
+            \n          Aggregate: groupBy=[[orders.o_custkey, Boolean(true) 
AS __always_true]], aggr=[[sum(orders.o_totalprice)]] [o_custkey:Int64, 
__always_true:Boolean, sum(orders.o_totalprice):Float64;N]\
+            \n            Filter: orders.o_totalprice < 
__scalar_sq_2.sum(lineitem.l_extendedprice) [o_orderkey:Int64, o_custkey:Int64, 
o_orderstatus:Utf8, o_totalprice:Float64;N, 
sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N, 
__always_true:Boolean;N]\
+            \n              Left Join:  Filter: __scalar_sq_2.l_orderkey = 
orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, 
o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, 
l_orderkey:Int64;N, __always_true:Boolean;N]\
+            \n                TableScan: orders [o_orderkey:Int64, 
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
+            \n                SubqueryAlias: __scalar_sq_2 
[sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64, 
__always_true:Boolean]\
+            \n                  Projection: sum(lineitem.l_extendedprice), 
lineitem.l_orderkey, __always_true [sum(lineitem.l_extendedprice):Float64;N, 
l_orderkey:Int64, __always_true:Boolean]\
+            \n                    Aggregate: groupBy=[[lineitem.l_orderkey, 
Boolean(true) AS __always_true]], aggr=[[sum(lineitem.l_extendedprice)]] 
[l_orderkey:Int64, __always_true:Boolean, 
sum(lineitem.l_extendedprice):Float64;N]\
+            \n                      TableScan: lineitem [l_orderkey:Int64, 
l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, 
l_extendedprice:Float64]";
         assert_multi_rules_optimized_plan_eq_display_indent(
             vec![Arc::new(ScalarSubqueryToJoin::new())],
             plan,
@@ -522,14 +531,14 @@ mod tests {
             .build()?;
 
         let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
-        \n  Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) 
[c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, 
o_custkey:Int64;N]\
-        \n    Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey 
[c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, 
o_custkey:Int64;N]\
-        \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
-        \n      SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, 
o_custkey:Int64]\
-        \n        Projection: max(orders.o_custkey), orders.o_custkey 
[max(orders.o_custkey):Int64;N, o_custkey:Int64]\
-        \n          Aggregate: groupBy=[[orders.o_custkey]], 
aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
-        \n            Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, 
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
-        \n              TableScan: orders [o_orderkey:Int64, o_custkey:Int64, 
o_orderstatus:Utf8, o_totalprice:Float64;N]";
+            \n  Filter: customer.c_custkey = 
__scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, 
max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
+            \n    Left Join:  Filter: customer.c_custkey = 
__scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, 
max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
+            \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+            \n      SubqueryAlias: __scalar_sq_1 
[max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
+            \n        Projection: max(orders.o_custkey), orders.o_custkey, 
__always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, 
__always_true:Boolean]\
+            \n          Aggregate: groupBy=[[orders.o_custkey, Boolean(true) 
AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, 
__always_true:Boolean, max(orders.o_custkey):Int64;N]\
+            \n            Filter: orders.o_orderkey = Int32(1) 
[o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
+            \n              TableScan: orders [o_orderkey:Int64, 
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
 
         assert_multi_rules_optimized_plan_eq_display_indent(
             vec![Arc::new(ScalarSubqueryToJoin::new())],
@@ -760,13 +769,56 @@ mod tests {
             .build()?;
 
         let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
-        \n  Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) + 
Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + 
Int32(1):Int64;N, o_custkey:Int64;N]\
-        \n    Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey 
[c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, 
o_custkey:Int64;N]\
-        \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
-        \n      SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey) + 
Int32(1):Int64;N, o_custkey:Int64]\
-        \n        Projection: max(orders.o_custkey) + Int32(1), 
orders.o_custkey [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64]\
-        \n          Aggregate: groupBy=[[orders.o_custkey]], 
aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
-        \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, 
o_orderstatus:Utf8, o_totalprice:Float64;N]";
+            \n  Filter: customer.c_custkey = 
__scalar_sq_1.max(orders.o_custkey) + Int32(1) [c_custkey:Int64, c_name:Utf8, 
max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N, 
__always_true:Boolean;N]\
+            \n    Left Join:  Filter: customer.c_custkey = 
__scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + 
Int32(1):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
+            \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+            \n      SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey) + 
Int32(1):Int64;N, o_custkey:Int64, __always_true:Boolean]\
+            \n        Projection: max(orders.o_custkey) + Int32(1), 
orders.o_custkey, __always_true [max(orders.o_custkey) + Int32(1):Int64;N, 
o_custkey:Int64, __always_true:Boolean]\
+            \n          Aggregate: groupBy=[[orders.o_custkey, Boolean(true) 
AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, 
__always_true:Boolean, max(orders.o_custkey):Int64;N]\
+            \n            TableScan: orders [o_orderkey:Int64, 
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
+
+        assert_multi_rules_optimized_plan_eq_display_indent(
+            vec![Arc::new(ScalarSubqueryToJoin::new())],
+            plan,
+            expected,
+        );
+        Ok(())
+    }
+
+    /// Test for correlated scalar subquery with non-strong project
+    #[test]
+    fn scalar_subquery_with_non_strong_project() -> Result<()> {
+        let case = Expr::Case(expr::Case {
+            expr: None,
+            when_then_expr: vec![(
+                Box::new(col("max(orders.o_totalprice)")),
+                Box::new(lit("a")),
+            )],
+            else_expr: Some(Box::new(lit("b"))),
+        });
+
+        let sq = Arc::new(
+            LogicalPlanBuilder::from(scan_tpch_table("orders"))
+                .filter(
+                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                        .eq(col("orders.o_custkey")),
+                )?
+                .aggregate(Vec::<Expr>::new(), 
vec![max(col("orders.o_totalprice"))])?
+                .project(vec![case])?
+                .build()?,
+        );
+
+        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+            .project(vec![col("customer.c_custkey"), scalar_subquery(sq)])?
+            .build()?;
+
+        let expected = "Projection: customer.c_custkey, CASE WHEN 
__scalar_sq_1.__always_true IS NULL THEN CASE WHEN CAST(NULL AS Boolean) THEN 
Utf8(\"a\") ELSE Utf8(\"b\") END ELSE __scalar_sq_1.CASE WHEN 
max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END END AS CASE WHEN 
max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END 
[c_custkey:Int64, CASE WHEN max(orders.o_totalprice) THEN Utf8(\"a\") ELSE 
Utf8(\"b\") END:Utf8;N]\
+            \n  Left Join:  Filter: customer.c_custkey = 
__scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, CASE WHEN 
max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END:Utf8;N, 
o_custkey:Int64;N, __always_true:Boolean;N]\
+            \n    TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+            \n    SubqueryAlias: __scalar_sq_1 [CASE WHEN 
max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END:Utf8, 
o_custkey:Int64, __always_true:Boolean]\
+            \n      Projection: CASE WHEN max(orders.o_totalprice) THEN 
Utf8(\"a\") ELSE Utf8(\"b\") END, orders.o_custkey, __always_true [CASE WHEN 
max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END:Utf8, 
o_custkey:Int64, __always_true:Boolean]\
+            \n        Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS 
__always_true]], aggr=[[max(orders.o_totalprice)]] [o_custkey:Int64, 
__always_true:Boolean, max(orders.o_totalprice):Float64;N]\
+            \n          TableScan: orders [o_orderkey:Int64, o_custkey:Int64, 
o_orderstatus:Utf8, o_totalprice:Float64;N]";
 
         assert_multi_rules_optimized_plan_eq_display_indent(
             vec![Arc::new(ScalarSubqueryToJoin::new())],
@@ -824,13 +876,13 @@ mod tests {
             .build()?;
 
         let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
-        \n  Filter: customer.c_custkey >= __scalar_sq_1.max(orders.o_custkey) 
AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, 
max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
-        \n    Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey 
[c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, 
o_custkey:Int64;N]\
-        \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
-        \n      SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, 
o_custkey:Int64]\
-        \n        Projection: max(orders.o_custkey), orders.o_custkey 
[max(orders.o_custkey):Int64;N, o_custkey:Int64]\
-        \n          Aggregate: groupBy=[[orders.o_custkey]], 
aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
-        \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, 
o_orderstatus:Utf8, o_totalprice:Float64;N]";
+            \n  Filter: customer.c_custkey >= 
__scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) 
[c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, 
o_custkey:Int64;N, __always_true:Boolean;N]\
+            \n    Left Join:  Filter: customer.c_custkey = 
__scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, 
max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
+            \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+            \n      SubqueryAlias: __scalar_sq_1 
[max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
+            \n        Projection: max(orders.o_custkey), orders.o_custkey, 
__always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, 
__always_true:Boolean]\
+            \n          Aggregate: groupBy=[[orders.o_custkey, Boolean(true) 
AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, 
__always_true:Boolean, max(orders.o_custkey):Int64;N]\
+            \n            TableScan: orders [o_orderkey:Int64, 
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
 
         assert_multi_rules_optimized_plan_eq_display_indent(
             vec![Arc::new(ScalarSubqueryToJoin::new())],
@@ -863,13 +915,13 @@ mod tests {
             .build()?;
 
         let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
-        \n  Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) 
AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, 
max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
-        \n    Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey 
[c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, 
o_custkey:Int64;N]\
-        \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
-        \n      SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, 
o_custkey:Int64]\
-        \n        Projection: max(orders.o_custkey), orders.o_custkey 
[max(orders.o_custkey):Int64;N, o_custkey:Int64]\
-        \n          Aggregate: groupBy=[[orders.o_custkey]], 
aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
-        \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, 
o_orderstatus:Utf8, o_totalprice:Float64;N]";
+            \n  Filter: customer.c_custkey = 
__scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) 
[c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, 
o_custkey:Int64;N, __always_true:Boolean;N]\
+            \n    Left Join:  Filter: customer.c_custkey = 
__scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, 
max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
+            \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+            \n      SubqueryAlias: __scalar_sq_1 
[max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
+            \n        Projection: max(orders.o_custkey), orders.o_custkey, 
__always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, 
__always_true:Boolean]\
+            \n          Aggregate: groupBy=[[orders.o_custkey, Boolean(true) 
AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, 
__always_true:Boolean, max(orders.o_custkey):Int64;N]\
+            \n            TableScan: orders [o_orderkey:Int64, 
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
 
         assert_multi_rules_optimized_plan_eq_display_indent(
             vec![Arc::new(ScalarSubqueryToJoin::new())],
@@ -903,13 +955,13 @@ mod tests {
             .build()?;
 
         let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
-        \n  Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) 
OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, 
max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
-        \n    Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey 
[c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, 
o_custkey:Int64;N]\
-        \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
-        \n      SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, 
o_custkey:Int64]\
-        \n        Projection: max(orders.o_custkey), orders.o_custkey 
[max(orders.o_custkey):Int64;N, o_custkey:Int64]\
-        \n          Aggregate: groupBy=[[orders.o_custkey]], 
aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
-        \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, 
o_orderstatus:Utf8, o_totalprice:Float64;N]";
+            \n  Filter: customer.c_custkey = 
__scalar_sq_1.max(orders.o_custkey) OR customer.c_custkey = Int32(1) 
[c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, 
o_custkey:Int64;N, __always_true:Boolean;N]\
+            \n    Left Join:  Filter: customer.c_custkey = 
__scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, 
max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
+            \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+            \n      SubqueryAlias: __scalar_sq_1 
[max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
+            \n        Projection: max(orders.o_custkey), orders.o_custkey, 
__always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, 
__always_true:Boolean]\
+            \n          Aggregate: groupBy=[[orders.o_custkey, Boolean(true) 
AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, 
__always_true:Boolean, max(orders.o_custkey):Int64;N]\
+            \n            TableScan: orders [o_orderkey:Int64, 
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
 
         assert_multi_rules_optimized_plan_eq_display_indent(
             vec![Arc::new(ScalarSubqueryToJoin::new())],
@@ -936,13 +988,13 @@ mod tests {
             .build()?;
 
         let expected = "Projection: test.c [c:UInt32]\
-        \n  Filter: test.c < __scalar_sq_1.min(sq.c) [a:UInt32, b:UInt32, 
c:UInt32, min(sq.c):UInt32;N, a:UInt32;N]\
-        \n    Left Join:  Filter: test.a = __scalar_sq_1.a [a:UInt32, 
b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N]\
-        \n      TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
-        \n      SubqueryAlias: __scalar_sq_1 [min(sq.c):UInt32;N, a:UInt32]\
-        \n        Projection: min(sq.c), sq.a [min(sq.c):UInt32;N, a:UInt32]\
-        \n          Aggregate: groupBy=[[sq.a]], aggr=[[min(sq.c)]] [a:UInt32, 
min(sq.c):UInt32;N]\
-        \n            TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
+            \n  Filter: test.c < __scalar_sq_1.min(sq.c) [a:UInt32, b:UInt32, 
c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N]\
+            \n    Left Join:  Filter: test.a = __scalar_sq_1.a [a:UInt32, 
b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N]\
+            \n      TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
+            \n      SubqueryAlias: __scalar_sq_1 [min(sq.c):UInt32;N, 
a:UInt32, __always_true:Boolean]\
+            \n        Projection: min(sq.c), sq.a, __always_true 
[min(sq.c):UInt32;N, a:UInt32, __always_true:Boolean]\
+            \n          Aggregate: groupBy=[[sq.a, Boolean(true) AS 
__always_true]], aggr=[[min(sq.c)]] [a:UInt32, __always_true:Boolean, 
min(sq.c):UInt32;N]\
+            \n            TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
 
         assert_multi_rules_optimized_plan_eq_display_indent(
             vec![Arc::new(ScalarSubqueryToJoin::new())],
@@ -1051,18 +1103,18 @@ mod tests {
             .build()?;
 
         let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
-        \n  Filter: customer.c_custkey BETWEEN 
__scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) 
[c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, 
o_custkey:Int64;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
-        \n    Left Join:  Filter: customer.c_custkey = __scalar_sq_2.o_custkey 
[c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, 
o_custkey:Int64;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
-        \n      Left Join:  Filter: customer.c_custkey = 
__scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, 
min(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
-        \n        TableScan: customer [c_custkey:Int64, c_name:Utf8]\
-        \n        SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N, 
o_custkey:Int64]\
-        \n          Projection: min(orders.o_custkey), orders.o_custkey 
[min(orders.o_custkey):Int64;N, o_custkey:Int64]\
-        \n            Aggregate: groupBy=[[orders.o_custkey]], 
aggr=[[min(orders.o_custkey)]] [o_custkey:Int64, min(orders.o_custkey):Int64;N]\
-        \n              TableScan: orders [o_orderkey:Int64, o_custkey:Int64, 
o_orderstatus:Utf8, o_totalprice:Float64;N]\
-        \n      SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, 
o_custkey:Int64]\
-        \n        Projection: max(orders.o_custkey), orders.o_custkey 
[max(orders.o_custkey):Int64;N, o_custkey:Int64]\
-        \n          Aggregate: groupBy=[[orders.o_custkey]], 
aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
-        \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, 
o_orderstatus:Utf8, o_totalprice:Float64;N]";
+            \n  Filter: customer.c_custkey BETWEEN 
__scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) 
[c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, 
o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, 
o_custkey:Int64;N, __always_true:Boolean;N]\
+            \n    Left Join:  Filter: customer.c_custkey = 
__scalar_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, 
min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, 
max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
+            \n      Left Join:  Filter: customer.c_custkey = 
__scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, 
min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
+            \n        TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+            \n        SubqueryAlias: __scalar_sq_1 
[min(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
+            \n          Projection: min(orders.o_custkey), orders.o_custkey, 
__always_true [min(orders.o_custkey):Int64;N, o_custkey:Int64, 
__always_true:Boolean]\
+            \n            Aggregate: groupBy=[[orders.o_custkey, Boolean(true) 
AS __always_true]], aggr=[[min(orders.o_custkey)]] [o_custkey:Int64, 
__always_true:Boolean, min(orders.o_custkey):Int64;N]\
+            \n              TableScan: orders [o_orderkey:Int64, 
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
+            \n      SubqueryAlias: __scalar_sq_2 
[max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
+            \n        Projection: max(orders.o_custkey), orders.o_custkey, 
__always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, 
__always_true:Boolean]\
+            \n          Aggregate: groupBy=[[orders.o_custkey, Boolean(true) 
AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, 
__always_true:Boolean, max(orders.o_custkey):Int64;N]\
+            \n            TableScan: orders [o_orderkey:Int64, 
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
 
         assert_multi_rules_optimized_plan_eq_display_indent(
             vec![Arc::new(ScalarSubqueryToJoin::new())],
diff --git a/datafusion/optimizer/src/utils.rs 
b/datafusion/optimizer/src/utils.rs
index c734d908f6..41c40ec06d 100644
--- a/datafusion/optimizer/src/utils.rs
+++ b/datafusion/optimizer/src/utils.rs
@@ -79,6 +79,50 @@ pub fn is_restrict_null_predicate<'a>(
         return Ok(true);
     }
 
+    // If result is single `true`, return false;
+    // If result is single `NULL` or `false`, return true;
+    Ok(
+        match evaluate_expr_with_null_column(predicate, 
join_cols_of_predicate)? {
+            ColumnarValue::Array(array) => {
+                if array.len() == 1 {
+                    let boolean_array = as_boolean_array(&array)?;
+                    boolean_array.is_null(0) || !boolean_array.value(0)
+                } else {
+                    false
+                }
+            }
+            ColumnarValue::Scalar(scalar) => matches!(
+                scalar,
+                ScalarValue::Boolean(None) | ScalarValue::Boolean(Some(false))
+            ),
+        },
+    )
+}
+
+/// Determines if an expression evaluates to null when `null_columns` are null.
+/// `c0 + 8` returns true
+/// `c0 IS NULL` returns false
+/// `CASE WHEN c0 > 1 THEN 0 ELSE 1 END` returns false
+pub fn evaluates_to_null<'a>(
+    predicate: Expr,
+    null_columns: impl IntoIterator<Item = &'a Column>,
+) -> Result<bool> {
+    if matches!(predicate, Expr::Column(_)) {
+        return Ok(true);
+    }
+
+    Ok(
+        match evaluate_expr_with_null_column(predicate, null_columns)? {
+            ColumnarValue::Array(_) => false,
+            ColumnarValue::Scalar(scalar) => scalar.is_null(),
+        },
+    )
+}
+
+fn evaluate_expr_with_null_column<'a>(
+    predicate: Expr,
+    null_columns: impl IntoIterator<Item = &'a Column>,
+) -> Result<ColumnarValue> {
     static DUMMY_COL_NAME: &str = "?";
     let schema = Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Null, 
true)]);
     let input_schema = DFSchema::try_from(schema.clone())?;
@@ -87,37 +131,15 @@ pub fn is_restrict_null_predicate<'a>(
     let execution_props = ExecutionProps::default();
     let null_column = Column::from_name(DUMMY_COL_NAME);
 
-    let join_cols_to_replace = join_cols_of_predicate
+    let join_cols_to_replace = null_columns
         .into_iter()
         .map(|column| (column, &null_column))
         .collect::<HashMap<_, _>>();
 
     let replaced_predicate = replace_col(predicate, &join_cols_to_replace)?;
     let coerced_predicate = coerce(replaced_predicate, &input_schema)?;
-    let phys_expr =
-        create_physical_expr(&coerced_predicate, &input_schema, 
&execution_props)?;
-
-    let result_type = phys_expr.data_type(&schema)?;
-    if !matches!(&result_type, DataType::Boolean) {
-        return Ok(false);
-    }
-
-    // If result is single `true`, return false;
-    // If result is single `NULL` or `false`, return true;
-    Ok(match phys_expr.evaluate(&input_batch)? {
-        ColumnarValue::Array(array) => {
-            if array.len() == 1 {
-                let boolean_array = as_boolean_array(&array)?;
-                boolean_array.is_null(0) || !boolean_array.value(0)
-            } else {
-                false
-            }
-        }
-        ColumnarValue::Scalar(scalar) => matches!(
-            scalar,
-            ScalarValue::Boolean(None) | ScalarValue::Boolean(Some(false))
-        ),
-    })
+    create_physical_expr(&coerced_predicate, &input_schema, &execution_props)?
+        .evaluate(&input_batch)
 }
 
 fn coerce(expr: Expr, schema: &DFSchema) -> Result<Expr> {
diff --git a/datafusion/sqllogictest/test_files/subquery.slt 
b/datafusion/sqllogictest/test_files/subquery.slt
index aaccaaa43c..a0ac15b740 100644
--- a/datafusion/sqllogictest/test_files/subquery.slt
+++ b/datafusion/sqllogictest/test_files/subquery.slt
@@ -921,7 +921,7 @@ query TT
 explain SELECT t1_id, (SELECT count(*) + 2 as cnt_plus_2 FROM t2 WHERE 
t2.t2_int = t1.t1_int having count(*) = 0) from t1
 ----
 logical_plan
-01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN 
Int64(2) WHEN __scalar_sq_1.count(Int64(1)) != Int64(0) THEN NULL ELSE 
__scalar_sq_1.cnt_plus_2 END AS cnt_plus_2
+01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN 
Int64(2) WHEN __scalar_sq_1.count(Int64(1)) != Int64(0) THEN Int64(NULL) ELSE 
__scalar_sq_1.cnt_plus_2 END AS cnt_plus_2
 02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int
 03)----TableScan: t1 projection=[t1_id, t1_int]
 04)----SubqueryAlias: __scalar_sq_1
@@ -995,7 +995,7 @@ select t1.t1_int from t1 where (
 ----
 logical_plan
 01)Projection: t1.t1_int
-02)--Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) WHEN 
__scalar_sq_1.count(Int64(1)) != Int64(0) THEN NULL ELSE 
__scalar_sq_1.cnt_plus_two END = Int64(2)
+02)--Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) WHEN 
__scalar_sq_1.count(Int64(1)) != Int64(0) THEN Int64(NULL) ELSE 
__scalar_sq_1.cnt_plus_two END = Int64(2)
 03)----Projection: t1.t1_int, __scalar_sq_1.cnt_plus_two, 
__scalar_sq_1.count(Int64(1)), __scalar_sq_1.__always_true
 04)------Left Join: t1.t1_int = __scalar_sq_1.t2_int
 05)--------TableScan: t1 projection=[t1_int]
@@ -1049,6 +1049,46 @@ false
 true
 true
 
+query IT rowsort
+SELECT t1_id, (SELECT case when max(t2.t2_id) > 1 then 'a' else 'b' end FROM 
t2 WHERE t2.t2_int = t1.t1_int) x from t1
+----
+11 a
+22 b
+33 a
+44 b
+
+query IB rowsort
+SELECT t1_id, (SELECT max(t2.t2_id) is null FROM t2 WHERE t2.t2_int = 
t1.t1_int) x from t1
+----
+11 false
+22 true
+33 false
+44 true
+
+query TT
+explain SELECT t1_id, (SELECT max(t2.t2_id) is null FROM t2 WHERE t2.t2_int = 
t1.t1_int) x from t1
+----
+logical_plan
+01)Projection: t1.t1_id, __scalar_sq_1.__always_true IS NULL OR 
__scalar_sq_1.__always_true IS NOT NULL AND __scalar_sq_1.max(t2.t2_id) IS NULL 
AS x
+02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int
+03)----TableScan: t1 projection=[t1_id, t1_int]
+04)----SubqueryAlias: __scalar_sq_1
+05)------Projection: max(t2.t2_id) IS NULL, t2.t2_int, Boolean(true) AS 
__always_true
+06)--------Aggregate: groupBy=[[t2.t2_int]], aggr=[[max(t2.t2_id)]]
+07)----------TableScan: t2 projection=[t2_id, t2_int]
+
+query TT
+explain SELECT t1_id, (SELECT max(t2.t2_id) FROM t2 WHERE t2.t2_int = 
t1.t1_int) x from t1
+----
+logical_plan
+01)Projection: t1.t1_id, __scalar_sq_1.max(t2.t2_id) AS x
+02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int
+03)----TableScan: t1 projection=[t1_id, t1_int]
+04)----SubqueryAlias: __scalar_sq_1
+05)------Projection: max(t2.t2_id), t2.t2_int
+06)--------Aggregate: groupBy=[[t2.t2_int]], aggr=[[max(t2.t2_id)]]
+07)----------TableScan: t2 projection=[t2_id, t2_int]
+
 # in_subquery_to_join_with_correlated_outer_filter_disjunction
 query TT
 explain select t1.t1_id,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to