neilconway commented on code in PR #21363:
URL: https://github.com/apache/datafusion/pull/21363#discussion_r3047256748


##########
datafusion/optimizer/src/decorrelate_predicate_subquery.rs:
##########
@@ -69,53 +70,113 @@ impl OptimizerRule for DecorrelatePredicateSubquery {
             })?
             .data;
 
-        let LogicalPlan::Filter(filter) = plan else {
-            return Ok(Transformed::no(plan));
-        };
-
-        if !has_subquery(&filter.predicate) {
-            return Ok(Transformed::no(LogicalPlan::Filter(filter)));
-        }
+        match plan {
+            LogicalPlan::Filter(filter) => {
+                if !has_subquery(&filter.predicate) {
+                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
+                }
 
-        let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) =
-            split_conjunction_owned(filter.predicate)
-                .into_iter()
-                .partition(has_subquery);
+                let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) =
+                    split_conjunction_owned(filter.predicate)
+                        .into_iter()
+                        .partition(has_subquery);
+
+                assert_or_internal_err!(
+                    !with_subqueries.is_empty(),
+                    "can not find expected subqueries in 
DecorrelatePredicateSubquery"
+                );
+
+                // iterate through all exists clauses in predicate, turning 
each into a join
+                let mut cur_input = Arc::unwrap_or_clone(filter.input);
+                for subquery_expr in with_subqueries {
+                    match extract_subquery_info(subquery_expr) {
+                        // The subquery expression is at the top level of the 
filter
+                        SubqueryPredicate::Top(subquery) => {
+                            match build_join_top(
+                                &subquery,
+                                &cur_input,
+                                config.alias_generator(),
+                            )? {
+                                Some(plan) => cur_input = plan,
+                                // If the subquery can not be converted to a 
Join, reconstruct the subquery expression and add it to the Filter
+                                None => other_exprs.push(subquery.expr()),
+                            }
+                        }
+                        // The subquery expression is embedded within another 
expression
+                        SubqueryPredicate::Embedded(expr) => {
+                            let (plan, expr_without_subqueries) =
+                                rewrite_inner_subqueries(cur_input, expr, 
config)?;
+                            cur_input = plan;
+                            other_exprs.push(expr_without_subqueries);
+                        }
+                    }
+                }
 
-        assert_or_internal_err!(
-            !with_subqueries.is_empty(),
-            "can not find expected subqueries in DecorrelatePredicateSubquery"
-        );
+                let expr = conjunction(other_exprs);
+                if let Some(expr) = expr {
+                    let new_filter = Filter::try_new(expr, 
Arc::new(cur_input))?;
+                    cur_input = LogicalPlan::Filter(new_filter);
+                }
+                Ok(Transformed::yes(cur_input))
+            }
+            LogicalPlan::Projection(projection) => {
+                // Skip if no predicate subqueries in any projection expression
+                if !projection.expr.iter().any(has_subquery) {
+                    return 
Ok(Transformed::no(LogicalPlan::Projection(projection)));
+                }
 
-        // iterate through all exists clauses in predicate, turning each into 
a join
-        let mut cur_input = Arc::unwrap_or_clone(filter.input);
-        for subquery_expr in with_subqueries {
-            match extract_subquery_info(subquery_expr) {
-                // The subquery expression is at the top level of the filter
-                SubqueryPredicate::Top(subquery) => {
-                    match build_join_top(&subquery, &cur_input, 
config.alias_generator())?
-                    {
-                        Some(plan) => cur_input = plan,
-                        // If the subquery can not be converted to a Join, 
reconstruct the subquery expression and add it to the Filter
-                        None => other_exprs.push(subquery.expr()),
+                // Keep an Arc clone of the original input so we can 
reconstruct
+                // the Projection if decorrelation fails for any expression.
+                let original_input = Arc::clone(&projection.input);
+                let mut cur_input = Arc::unwrap_or_clone(projection.input);
+                let mut new_exprs = Vec::with_capacity(projection.expr.len());
+
+                for expr in &projection.expr {
+                    if has_subquery(expr) {
+                        let (plan, rewritten) =
+                            rewrite_inner_subqueries(cur_input, expr.clone(), 
config)?;
+                        cur_input = plan;
+                        new_exprs.push(rewritten);
+                    } else {
+                        new_exprs.push(expr.clone());
                     }
                 }
-                // The subquery expression is embedded within another 
expression
-                SubqueryPredicate::Embedded(expr) => {
-                    let (plan, expr_without_subqueries) =
-                        rewrite_inner_subqueries(cur_input, expr, config)?;
-                    cur_input = plan;
-                    other_exprs.push(expr_without_subqueries);
+
+                // If any expression still contains a subquery after rewriting,
+                // decorrelation failed — bail out and return the original plan
+                // unchanged (same pattern as ScalarSubqueryToJoin).
+                if new_exprs.iter().any(has_subquery) {
+                    let original = Projection::try_new_with_schema(
+                        projection.expr,
+                        original_input,
+                        projection.schema,
+                    )?;
+                    return 
Ok(Transformed::no(LogicalPlan::Projection(original)));
                 }

Review Comment:
   I wonder if it would be a bit cleaner to check whether rewriting failed on a
per-expr basis. That way we can bail out early, as soon as one expression fails,
when some subqueries can be decorrelated but others can't. E.g.,
   
   ```
   for expr in &projection.expr {
       if has_subquery(expr) {
           let (plan, rewritten) = rewrite_inner_subqueries(...)?;
           if has_subquery(&rewritten) { /* bail out */ }
       }
   }
   ```



##########
datafusion/optimizer/src/decorrelate_predicate_subquery.rs:
##########
@@ -2114,4 +2177,249 @@ mod tests {
         "
         )
     }
+
+    // -----------------------------------------------------------------------
+    // Tests for InSubquery / Exists in Projection expressions
+    // -----------------------------------------------------------------------
+
+    /// IN subquery inside CASE WHEN in a projection expression
+    #[test]
+    fn in_subquery_in_case_projection() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let sq = test_subquery_with_name("sq")?;
+
+        let case_expr =
+            when(in_subquery(col("c"), sq), lit("yes")).otherwise(lit("no"))?;
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .project(vec![col("a"), case_expr])?
+            .build()?;
+
+        assert_optimized_plan_equal!(
+            plan,
+            @r#"
+        Projection: test.a, CASE WHEN __correlated_sq_1.mark THEN Utf8("yes") 
ELSE Utf8("no") END AS CASE WHEN IN THEN Utf8("yes") ELSE Utf8("no") END 
[a:UInt32, CASE WHEN IN THEN Utf8("yes") ELSE Utf8("no") END:Utf8]
+          LeftMark Join:  Filter: test.c = __correlated_sq_1.c [a:UInt32, 
b:UInt32, c:UInt32, mark:Boolean]
+            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
+            Projection: __correlated_sq_1.c [c:UInt32]
+              SubqueryAlias: __correlated_sq_1 [c:UInt32]
+                Projection: sq.c [c:UInt32]
+                  TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
+        "#
+        )
+    }
+
+    /// EXISTS subquery inside CASE WHEN in a projection expression
+    #[test]
+    fn exists_in_case_projection() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let sq = Arc::new(
+            LogicalPlanBuilder::from(scan_tpch_table("orders"))
+                .filter(
+                    col("orders.o_custkey").eq(out_ref_col(DataType::UInt32, 
"test.a")),
+                )?
+                .project(vec![lit(1)])?
+                .build()?,
+        );
+
+        let case_expr =
+            when(exists(sq), lit("has_orders")).otherwise(lit("no_orders"))?;
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .project(vec![col("a"), case_expr])?
+            .build()?;
+
+        assert_optimized_plan_equal!(
+            plan,
+            @r#"
+        Projection: test.a, CASE WHEN __correlated_sq_1.mark THEN 
Utf8("has_orders") ELSE Utf8("no_orders") END AS CASE WHEN EXISTS THEN 
Utf8("has_orders") ELSE Utf8("no_orders") END [a:UInt32, CASE WHEN EXISTS THEN 
Utf8("has_orders") ELSE Utf8("no_orders") END:Utf8]
+          LeftMark Join:  Filter: __correlated_sq_1.o_custkey = test.a 
[a:UInt32, b:UInt32, c:UInt32, mark:Boolean]
+            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
+            Projection: __correlated_sq_1.o_custkey [o_custkey:Int64]
+              SubqueryAlias: __correlated_sq_1 [Int32(1):Int32, 
o_custkey:Int64]
+                Projection: Int32(1), orders.o_custkey [Int32(1):Int32, 
o_custkey:Int64]
+                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, 
o_orderstatus:Utf8, o_totalprice:Float64;N]
+        "#
+        )
+    }
+
+    /// NOT IN subquery inside CASE WHEN in a projection expression
+    #[test]
+    fn not_in_subquery_in_case_projection() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let sq = test_subquery_with_name("sq")?;
+
+        let case_expr = when(not_in_subquery(col("c"), sq), lit("excluded"))
+            .otherwise(lit("included"))?;
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .project(vec![col("a"), case_expr])?
+            .build()?;
+
+        assert_optimized_plan_equal!(
+            plan,
+            @r#"
+        Projection: test.a, CASE WHEN NOT __correlated_sq_1.mark THEN 
Utf8("excluded") ELSE Utf8("included") END AS CASE WHEN NOT IN THEN 
Utf8("excluded") ELSE Utf8("included") END [a:UInt32, CASE WHEN NOT IN THEN 
Utf8("excluded") ELSE Utf8("included") END:Utf8]
+          LeftMark Join:  Filter: test.c = __correlated_sq_1.c [a:UInt32, 
b:UInt32, c:UInt32, mark:Boolean]
+            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
+            Projection: __correlated_sq_1.c [c:UInt32]
+              SubqueryAlias: __correlated_sq_1 [c:UInt32]
+                Projection: sq.c [c:UInt32]
+                  TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
+        "#
+        )
+    }
+
+    /// IN subquery as bare boolean in SELECT (no CASE wrapper)
+    #[test]
+    fn in_subquery_bare_bool_projection() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let sq = test_subquery_with_name("sq")?;
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .project(vec![col("a"), in_subquery(col("c"), sq)])?
+            .build()?;
+
+        assert_optimized_plan_equal!(
+            plan,
+            @r"
+        Projection: test.a, __correlated_sq_1.mark AS IN [a:UInt32, IN:Boolean]
+          LeftMark Join:  Filter: test.c = __correlated_sq_1.c [a:UInt32, 
b:UInt32, c:UInt32, mark:Boolean]
+            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
+            Projection: __correlated_sq_1.c [c:UInt32]
+              SubqueryAlias: __correlated_sq_1 [c:UInt32]
+                Projection: sq.c [c:UInt32]
+                  TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
+        "
+        )
+    }
+
+    /// Correlated IN subquery inside CASE WHEN in a projection expression
+    #[test]
+    fn correlated_in_subquery_in_case_projection() -> Result<()> {
+        let orders = Arc::new(
+            LogicalPlanBuilder::from(scan_tpch_table("orders"))
+                .filter(
+                    col("orders.o_custkey")
+                        .eq(out_ref_col(DataType::Int64, 
"customer.c_custkey")),
+                )?
+                .project(vec![col("orders.o_custkey")])?
+                .build()?,
+        );
+
+        let case_expr = when(
+            in_subquery(col("customer.c_custkey"), orders),
+            lit("active"),
+        )
+        .otherwise(lit("inactive"))?;
+
+        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+            .project(vec![col("customer.c_custkey"), case_expr])?
+            .build()?;
+
+        assert_optimized_plan_equal!(
+            plan,
+            @r#"
+        Projection: customer.c_custkey, CASE WHEN __correlated_sq_1.mark THEN 
Utf8("active") ELSE Utf8("inactive") END AS CASE WHEN IN THEN Utf8("active") 
ELSE Utf8("inactive") END [c_custkey:Int64, CASE WHEN IN THEN Utf8("active") 
ELSE Utf8("inactive") END:Utf8]
+          LeftMark Join:  Filter: customer.c_custkey = 
__correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, mark:Boolean]
+            TableScan: customer [c_custkey:Int64, c_name:Utf8]
+            Projection: __correlated_sq_1.o_custkey [o_custkey:Int64]
+              SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
+                Projection: orders.o_custkey [o_custkey:Int64]
+                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, 
o_orderstatus:Utf8, o_totalprice:Float64;N]
+        "#
+        )
+    }
+
+    /// Multiple subqueries in one projection expression
+    #[test]
+    fn multiple_subqueries_in_one_projection_expr() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let sq1 = test_subquery_with_name("sq_1")?;
+        let sq2 = test_subquery_with_name("sq_2")?;
+
+        // CASE WHEN a IN (sq1) THEN 'a_match'
+        //      WHEN b IN (sq2) THEN 'b_match'
+        //      ELSE 'none' END
+        let case_expr = when(in_subquery(col("a"), sq1), lit("a_match"))
+            .when(in_subquery(col("b"), sq2), lit("b_match"))
+            .otherwise(lit("none"))?;
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .project(vec![col("a"), case_expr])?
+            .build()?;
+
+        assert_optimized_plan_equal!(
+            plan,
+            @r#"
+        Projection: test.a, CASE WHEN __correlated_sq_1.mark THEN 
Utf8("a_match") WHEN __correlated_sq_2.mark THEN Utf8("b_match") ELSE 
Utf8("none") END AS CASE WHEN IN THEN Utf8("a_match") WHEN IN THEN 
Utf8("b_match") ELSE Utf8("none") END [a:UInt32, CASE WHEN IN THEN 
Utf8("a_match") WHEN IN THEN Utf8("b_match") ELSE Utf8("none") END:Utf8]
+          LeftMark Join:  Filter: test.b = __correlated_sq_2.c [a:UInt32, 
b:UInt32, c:UInt32, mark:Boolean, mark:Boolean]
+            LeftMark Join:  Filter: test.a = __correlated_sq_1.c [a:UInt32, 
b:UInt32, c:UInt32, mark:Boolean]
+              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
+              Projection: __correlated_sq_1.c [c:UInt32]
+                SubqueryAlias: __correlated_sq_1 [c:UInt32]
+                  Projection: sq_1.c [c:UInt32]
+                    TableScan: sq_1 [a:UInt32, b:UInt32, c:UInt32]
+            Projection: __correlated_sq_2.c [c:UInt32]
+              SubqueryAlias: __correlated_sq_2 [c:UInt32]
+                Projection: sq_2.c [c:UInt32]
+                  TableScan: sq_2 [a:UInt32, b:UInt32, c:UInt32]
+        "#
+        )
+    }
+
+    /// Projection with no subquery is not modified
+    #[test]
+    fn projection_without_subquery_unchanged() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .project(vec![col("a"), col("b")])?
+            .build()?;
+
+        assert_optimized_plan_equal!(
+            plan,
+            @r"
+        Projection: test.a, test.b [a:UInt32, b:UInt32]
+          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
+        "
+        )
+    }
+
+    /// When a correlated IN subquery inside a projection cannot be 
decorrelated

Review Comment:
   It might be useful to extend this test to include one subquery that can be
decorrelated and one that can't. That would verify that the former is not
rewritten either, since we bail out for the whole projection.



##########
datafusion/optimizer/src/decorrelate_predicate_subquery.rs:
##########
@@ -69,53 +70,113 @@ impl OptimizerRule for DecorrelatePredicateSubquery {
             })?
             .data;
 
-        let LogicalPlan::Filter(filter) = plan else {
-            return Ok(Transformed::no(plan));
-        };
-
-        if !has_subquery(&filter.predicate) {
-            return Ok(Transformed::no(LogicalPlan::Filter(filter)));
-        }
+        match plan {
+            LogicalPlan::Filter(filter) => {
+                if !has_subquery(&filter.predicate) {
+                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
+                }
 
-        let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) =
-            split_conjunction_owned(filter.predicate)
-                .into_iter()
-                .partition(has_subquery);
+                let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) =
+                    split_conjunction_owned(filter.predicate)
+                        .into_iter()
+                        .partition(has_subquery);
+
+                assert_or_internal_err!(
+                    !with_subqueries.is_empty(),
+                    "can not find expected subqueries in 
DecorrelatePredicateSubquery"
+                );
+
+                // iterate through all exists clauses in predicate, turning 
each into a join
+                let mut cur_input = Arc::unwrap_or_clone(filter.input);
+                for subquery_expr in with_subqueries {
+                    match extract_subquery_info(subquery_expr) {
+                        // The subquery expression is at the top level of the 
filter
+                        SubqueryPredicate::Top(subquery) => {
+                            match build_join_top(
+                                &subquery,
+                                &cur_input,
+                                config.alias_generator(),
+                            )? {
+                                Some(plan) => cur_input = plan,
+                                // If the subquery can not be converted to a 
Join, reconstruct the subquery expression and add it to the Filter
+                                None => other_exprs.push(subquery.expr()),
+                            }
+                        }
+                        // The subquery expression is embedded within another 
expression
+                        SubqueryPredicate::Embedded(expr) => {
+                            let (plan, expr_without_subqueries) =
+                                rewrite_inner_subqueries(cur_input, expr, 
config)?;
+                            cur_input = plan;
+                            other_exprs.push(expr_without_subqueries);
+                        }
+                    }
+                }
 
-        assert_or_internal_err!(
-            !with_subqueries.is_empty(),
-            "can not find expected subqueries in DecorrelatePredicateSubquery"
-        );
+                let expr = conjunction(other_exprs);
+                if let Some(expr) = expr {
+                    let new_filter = Filter::try_new(expr, 
Arc::new(cur_input))?;
+                    cur_input = LogicalPlan::Filter(new_filter);
+                }
+                Ok(Transformed::yes(cur_input))
+            }
+            LogicalPlan::Projection(projection) => {
+                // Skip if no predicate subqueries in any projection expression

Review Comment:
   Nit: prefix this comment with "Optimization: ..." to make it clear that this
early-exit check isn't necessary for correctness.



##########
datafusion/optimizer/src/decorrelate_predicate_subquery.rs:
##########
@@ -69,53 +70,113 @@ impl OptimizerRule for DecorrelatePredicateSubquery {
             })?
             .data;
 
-        let LogicalPlan::Filter(filter) = plan else {
-            return Ok(Transformed::no(plan));
-        };
-
-        if !has_subquery(&filter.predicate) {
-            return Ok(Transformed::no(LogicalPlan::Filter(filter)));
-        }
+        match plan {
+            LogicalPlan::Filter(filter) => {
+                if !has_subquery(&filter.predicate) {
+                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
+                }
 
-        let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) =
-            split_conjunction_owned(filter.predicate)
-                .into_iter()
-                .partition(has_subquery);
+                let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) =
+                    split_conjunction_owned(filter.predicate)
+                        .into_iter()
+                        .partition(has_subquery);
+
+                assert_or_internal_err!(
+                    !with_subqueries.is_empty(),
+                    "can not find expected subqueries in 
DecorrelatePredicateSubquery"
+                );
+
+                // iterate through all exists clauses in predicate, turning 
each into a join
+                let mut cur_input = Arc::unwrap_or_clone(filter.input);
+                for subquery_expr in with_subqueries {
+                    match extract_subquery_info(subquery_expr) {
+                        // The subquery expression is at the top level of the 
filter
+                        SubqueryPredicate::Top(subquery) => {
+                            match build_join_top(
+                                &subquery,
+                                &cur_input,
+                                config.alias_generator(),
+                            )? {
+                                Some(plan) => cur_input = plan,
+                                // If the subquery can not be converted to a 
Join, reconstruct the subquery expression and add it to the Filter
+                                None => other_exprs.push(subquery.expr()),
+                            }
+                        }
+                        // The subquery expression is embedded within another 
expression
+                        SubqueryPredicate::Embedded(expr) => {
+                            let (plan, expr_without_subqueries) =
+                                rewrite_inner_subqueries(cur_input, expr, 
config)?;
+                            cur_input = plan;
+                            other_exprs.push(expr_without_subqueries);
+                        }
+                    }
+                }
 
-        assert_or_internal_err!(
-            !with_subqueries.is_empty(),
-            "can not find expected subqueries in DecorrelatePredicateSubquery"
-        );
+                let expr = conjunction(other_exprs);
+                if let Some(expr) = expr {
+                    let new_filter = Filter::try_new(expr, 
Arc::new(cur_input))?;
+                    cur_input = LogicalPlan::Filter(new_filter);
+                }
+                Ok(Transformed::yes(cur_input))
+            }
+            LogicalPlan::Projection(projection) => {
+                // Skip if no predicate subqueries in any projection expression
+                if !projection.expr.iter().any(has_subquery) {
+                    return 
Ok(Transformed::no(LogicalPlan::Projection(projection)));
+                }
 
-        // iterate through all exists clauses in predicate, turning each into 
a join
-        let mut cur_input = Arc::unwrap_or_clone(filter.input);
-        for subquery_expr in with_subqueries {
-            match extract_subquery_info(subquery_expr) {
-                // The subquery expression is at the top level of the filter
-                SubqueryPredicate::Top(subquery) => {
-                    match build_join_top(&subquery, &cur_input, 
config.alias_generator())?
-                    {
-                        Some(plan) => cur_input = plan,
-                        // If the subquery can not be converted to a Join, 
reconstruct the subquery expression and add it to the Filter
-                        None => other_exprs.push(subquery.expr()),
+                // Keep an Arc clone of the original input so we can 
reconstruct
+                // the Projection if decorrelation fails for any expression.
+                let original_input = Arc::clone(&projection.input);
+                let mut cur_input = Arc::unwrap_or_clone(projection.input);

Review Comment:
   I think if we do
   
   ```
   let mut cur_input = projection.input.as_ref().clone();
   ```
   
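   then `projection` still owns its input. We can then get rid of
`original_input` and simplify the early return to just hand back the original
`LogicalPlan::Projection(projection)` (the `ScalarSubqueryToJoin` code path does this).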



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to