This is an automated email from the ASF dual-hosted git repository.

github-merge-queue[bot] pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new cab69a1d4a Fix correlated subquery empty defaults for regr_count and approx_distinct (#22319)
cab69a1d4a is described below

commit cab69a1d4aa8dab980e468e2ec8089ec66988fce
Author: Nathan Bezualem <[email protected]>
AuthorDate: Thu May 28 10:37:08 2026 -0400

    Fix correlated subquery empty defaults for regr_count and approx_distinct (#22319)
    
    ## Which issue does this PR close?
    
    - Closes #22317.
    
    ## Rationale for this change
    
    Correlated scalar subqueries with ungrouped aggregates are decorrelated
    into joins. For unmatched outer rows, the rewritten join naturally
    produces NULLs on the right side, so DataFusion has compensation logic
    for aggregates that should return a non-NULL value on empty input.
    
    That compensation previously special-cased `count` by name. As a result,
    other aggregates with non-NULL empty-input results, such as `regr_count`
    and `approx_distinct`, incorrectly returned NULL after decorrelation.
    
    ## What changes are included in this PR?
    
    This PR updates decorrelation to use each aggregate UDF's
    `default_value()` instead of hard-coding `count`.
    
    It also adds empty-input defaults for:
    
    - `regr_count`: `UInt64(0)`
    - `approx_distinct`: `UInt64(0)`
    
    Regression coverage is added for correlated scalar subqueries using
    these aggregates in projection expressions and filters.
    
    ## Are these changes tested?
    
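    Yes. New cases are added to `subquery.slt`, and existing expected outputs
    are updated in `datafusion/core/tests/dataframe/mod.rs` and the
    `scalar_subquery_to_join` optimizer tests to reflect the new defaults.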
    
    ```bash
    cargo fmt --all
    cargo test -p datafusion-sqllogictest --test sqllogictests -- subquery.slt
    ```
    
    ## Are there any user-facing changes?
    
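    Yes. Queries using `regr_count` or `approx_distinct` in correlated
    scalar subqueries now return `0` for unmatched outer rows instead of
    `NULL`, matching the aggregate behavior on empty input. The new default
    is also visible outside subqueries: `approx_distinct` evaluated over an
    empty window frame now returns `0` rather than `NULL`. In addition, the
    output of `approx_distinct` and `regr_count` is now reported as
    non-nullable.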
    
    ---------
    
    Co-authored-by: Nathan Bezualem <[email protected]>
    Co-authored-by: nathanb9 <[email protected]>
---
 datafusion/core/tests/dataframe/mod.rs             |  2 +-
 .../functions-aggregate/src/approx_distinct.rs     |  8 +++
 datafusion/functions-aggregate/src/regr.rs         | 12 +++++
 datafusion/optimizer/src/decorrelate.rs            | 22 +++-----
 .../optimizer/src/scalar_subquery_to_join.rs       |  2 +-
 datafusion/sqllogictest/test_files/subquery.slt    | 62 ++++++++++++++++++++++
 6 files changed, 92 insertions(+), 16 deletions(-)

diff --git a/datafusion/core/tests/dataframe/mod.rs 
b/datafusion/core/tests/dataframe/mod.rs
index 6512d9b432..0ced83f7b9 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -1204,7 +1204,7 @@ async fn window_using_aggregates() -> Result<()> {
     
+-------------+----------+-----------------+---------------+--------+-----+------+----+------+
     | first_value | last_val | approx_distinct | approx_median | median | max 
| min  | c2 | c3   |
     
+-------------+----------+-----------------+---------------+--------+-----+------+----+------+
-    |             |          |                 |               |        |     
|      | 1  | -85  |
+    |             |          | 0               |               |        |     
|      | 1  | -85  |
     | -85         | -101     | 14              | -12.0         | -12.0  | 83  
| -101 | 4  | -54  |
     | -85         | -101     | 17              | -25.0         | -25.0  | 83  
| -101 | 5  | -31  |
     | -85         | -12      | 10              | -32.75        | -34.0  | 83  
| -85  | 3  | 13   |
diff --git a/datafusion/functions-aggregate/src/approx_distinct.rs 
b/datafusion/functions-aggregate/src/approx_distinct.rs
index cc42b6c22b..306ec074d4 100644
--- a/datafusion/functions-aggregate/src/approx_distinct.rs
+++ b/datafusion/functions-aggregate/src/approx_distinct.rs
@@ -381,6 +381,14 @@ impl AggregateUDFImpl for ApproxDistinct {
         Ok(DataType::UInt64)
     }
 
+    fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
+        Ok(ScalarValue::UInt64(Some(0)))
+    }
+
+    fn is_nullable(&self) -> bool {
+        false
+    }
+
     fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
         let data_type = args.input_fields[0].data_type();
         match data_type {
diff --git a/datafusion/functions-aggregate/src/regr.rs 
b/datafusion/functions-aggregate/src/regr.rs
index 3a68672abb..3d5bbf1eda 100644
--- a/datafusion/functions-aggregate/src/regr.rs
+++ b/datafusion/functions-aggregate/src/regr.rs
@@ -457,6 +457,18 @@ impl AggregateUDFImpl for Regr {
         }
     }
 
+    fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
+        if self.regr_type == RegrType::Count {
+            Ok(ScalarValue::UInt64(Some(0)))
+        } else {
+            Ok(ScalarValue::Float64(None))
+        }
+    }
+
+    fn is_nullable(&self) -> bool {
+        self.regr_type != RegrType::Count
+    }
+
     fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn 
Accumulator>> {
         Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?))
     }
diff --git a/datafusion/optimizer/src/decorrelate.rs 
b/datafusion/optimizer/src/decorrelate.rs
index 2a71205c64..9490af0e59 100644
--- a/datafusion/optimizer/src/decorrelate.rs
+++ b/datafusion/optimizer/src/decorrelate.rs
@@ -35,8 +35,8 @@ use datafusion_expr::utils::{
     collect_subquery_cols, conjunction, find_join_exprs, split_conjunction,
 };
 use datafusion_expr::{
-    BinaryExpr, Cast, EmptyRelation, Expr, FetchType, LogicalPlan, 
LogicalPlanBuilder,
-    Operator, expr, lit,
+    BinaryExpr, Cast, EmptyRelation, Expr, ExprSchemable, FetchType, 
LogicalPlan,
+    LogicalPlanBuilder, Operator, expr, lit,
 };
 
 /// This struct rewrite the sub query plan by pull up the correlated
@@ -512,18 +512,12 @@ fn agg_exprs_evaluation_result_on_empty_batch(
         let result_expr = e
             .clone()
             .transform_up(|expr| {
-                let new_expr = match expr {
-                    Expr::AggregateFunction(expr::AggregateFunction { func, .. 
}) => {
-                        if func.name() == "count" {
-                            Transformed::yes(Expr::Literal(
-                                ScalarValue::Int64(Some(0)),
-                                None,
-                            ))
-                        } else {
-                            Transformed::yes(Expr::Literal(ScalarValue::Null, 
None))
-                        }
-                    }
-                    _ => Transformed::no(expr),
+                let new_expr = if let Expr::AggregateFunction(agg) = &expr {
+                    let return_type = expr.get_type(schema.as_ref())?;
+                    let default_value = agg.func.default_value(&return_type)?;
+                    Transformed::yes(Expr::Literal(default_value, None))
+                } else {
+                    Transformed::no(expr)
                 };
                 Ok(new_expr)
             })
diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs 
b/datafusion/optimizer/src/scalar_subquery_to_join.rs
index fee430047a..27da19024c 100644
--- a/datafusion/optimizer/src/scalar_subquery_to_join.rs
+++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs
@@ -819,7 +819,7 @@ mod tests {
         assert_optimized_plan_equal!(
             plan,
             @r#"
-        Projection: customer.c_custkey, CASE WHEN __scalar_sq_1.__always_true 
IS NULL THEN CASE WHEN CAST(NULL AS Boolean) THEN Utf8("a") ELSE Utf8("b") END 
ELSE __scalar_sq_1.CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE 
Utf8("b") END END AS CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE 
Utf8("b") END [c_custkey:Int64, CASE WHEN max(orders.o_totalprice) THEN 
Utf8("a") ELSE Utf8("b") END:Utf8;N]
+        Projection: customer.c_custkey, CASE WHEN __scalar_sq_1.__always_true 
IS NULL THEN CASE WHEN CAST(Float64(NULL) AS Boolean) THEN Utf8("a") ELSE 
Utf8("b") END ELSE __scalar_sq_1.CASE WHEN max(orders.o_totalprice) THEN 
Utf8("a") ELSE Utf8("b") END END AS CASE WHEN max(orders.o_totalprice) THEN 
Utf8("a") ELSE Utf8("b") END [c_custkey:Int64, CASE WHEN 
max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8;N]
           Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey 
[c_custkey:Int64, c_name:Utf8, CASE WHEN max(orders.o_totalprice) THEN 
Utf8("a") ELSE Utf8("b") END:Utf8;N, o_custkey:Int64;N, __always_true:Boolean;N]
             TableScan: customer [c_custkey:Int64, c_name:Utf8]
             SubqueryAlias: __scalar_sq_1 [CASE WHEN max(orders.o_totalprice) 
THEN Utf8("a") ELSE Utf8("b") END:Utf8, o_custkey:Int64, __always_true:Boolean]
diff --git a/datafusion/sqllogictest/test_files/subquery.slt 
b/datafusion/sqllogictest/test_files/subquery.slt
index 25f124f217..dd195b0ff4 100644
--- a/datafusion/sqllogictest/test_files/subquery.slt
+++ b/datafusion/sqllogictest/test_files/subquery.slt
@@ -888,6 +888,68 @@ SELECT t1_id, (SELECT count(*) FROM t2 WHERE t2.t2_int = 
t1.t1_int) as cnt from
 33 3
 44 0
 
+#correlated_scalar_subquery_non_count_agg_empty_defaults
+query III rowsort
+SELECT
+  t1_id,
+  (
+    SELECT regr_count(1.0, 1.0)
+    FROM t2
+    WHERE t2.t2_int = t1.t1_int
+  ) AS r,
+  (
+    SELECT approx_distinct(t2.t2_id)
+    FROM t2
+    WHERE t2.t2_int = t1.t1_int
+  ) AS d
+FROM t1
+----
+11 1 1
+22 0 0
+33 3 3
+44 0 0
+
+query II rowsort
+SELECT
+  t1_id,
+  (
+    SELECT regr_count(1.0, 1.0) + approx_distinct(t2.t2_id)
+    FROM t2
+    WHERE t2.t2_int = t1.t1_int
+  ) AS combined
+FROM t1
+----
+11 2
+22 0
+33 6
+44 0
+
+query I rowsort
+SELECT t1_id
+FROM t1
+WHERE
+  (
+    SELECT approx_distinct(t2.t2_id)
+    FROM t2
+    WHERE t2.t2_int = t1.t1_int
+  ) = 0
+----
+22
+44
+
+query I rowsort
+SELECT t1_id
+FROM t1
+WHERE
+  (
+    SELECT regr_count(1.0, 1.0)
+    FROM t2
+    WHERE t2.t2_int = t1.t1_int
+  ) = 0
+----
+22
+44
+
 #correlated_scalar_subquery_count_agg_with_alias
 query TT
 explain SELECT t1_id, (SELECT count(*) as _cnt FROM t2 WHERE t2.t2_int = 
t1.t1_int) as cnt from t1


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to