This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8a972c2fe87 [SPARK-45507][SQL] Correctness fix for nested correlated scalar subqueries with COUNT aggregates
8a972c2fe87 is described below

commit 8a972c2fe8730e41193986a273aa92b234e5beb8
Author: Andy Lam <andy....@databricks.com>
AuthorDate: Thu Oct 19 10:36:32 2023 +0800

    [SPARK-45507][SQL] Correctness fix for nested correlated scalar subqueries with COUNT aggregates
    
    ### What changes were proposed in this pull request?
    
    We want to use the count bug handling in `DecorrelateInnerQuery` to detect
    potential count bugs in scalar subqueries. It is always safe to use
    `DecorrelateInnerQuery` to handle count bugs, but for efficiency reasons, like
    for the common case of COUNT on top of the scalar subquery, we would like to
    avoid an extra left outer join. This PR therefore introduces a simple check to
    detect such cases before `decorrelate()` - if true, then don't do count bug
    handling in `decorrelate()`, an [...]
    
    ### Why are the changes needed?
    
    This PR fixes correctness issues for correlated scalar subqueries 
pertaining to the COUNT bug. Examples can be found in the JIRA ticket.
    
    ### Does this PR introduce _any_ user-facing change?
    
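    Yes. Query results for affected correlated scalar subqueries with COUNT
    aggregates change (they are now correct). The previous behavior can be
    restored with the internal config
    `spark.sql.legacy.scalarSubqueryCountBugBehavior=true`.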
    
    ### How was this patch tested?
    
    Added SQL end-to-end tests in `nested-scalar-subquery-count-bug.sql`
    (with golden files under `analyzer-results/` and `results/`).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #43341 from andylam-db/multiple-count-bug.
    
    Authored-by: Andy Lam <andy....@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/optimizer/subquery.scala    |  44 +++++-
 .../org/apache/spark/sql/internal/SQLConf.scala    |   9 ++
 .../nested-scalar-subquery-count-bug.sql.out       | 166 +++++++++++++++++++++
 .../nested-scalar-subquery-count-bug.sql           |  34 +++++
 .../nested-scalar-subquery-count-bug.sql.out       | 125 ++++++++++++++++
 5 files changed, 372 insertions(+), 6 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 5b95ee1df1b..1f1a16e9093 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -426,17 +426,49 @@ object PullupCorrelatedPredicates extends 
Rule[LogicalPlan] with PredicateHelper
     plan.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
       case ScalarSubquery(sub, children, exprId, conditions, hint, 
mayHaveCountBugOld)
         if children.nonEmpty =>
-        val (newPlan, newCond) = decorrelate(sub, plan)
-        val mayHaveCountBug = if (mayHaveCountBugOld.isEmpty) {
+
+        def mayHaveCountBugAgg(a: Aggregate): Boolean = {
+          a.groupingExpressions.isEmpty && 
a.aggregateExpressions.exists(_.exists {
+            case a: AggregateExpression => 
a.aggregateFunction.defaultResult.isDefined
+            case _ => false
+          })
+        }
+
+        // The below logic controls handling count bug for scalar subqueries in
+        // [[DecorrelateInnerQuery]], and if we don't handle it here, we 
handle it in
+        // [[RewriteCorrelatedScalarSubquery#constructLeftJoins]]. Note that 
handling it in
+        // [[DecorrelateInnerQuery]] is always correct, and turning it off to 
handle it in
+        // constructLeftJoins is an optimization, so that additional, 
redundant left outer joins are
+        // not introduced.
+        val handleCountBugInDecorrelate = 
SQLConf.get.decorrelateInnerQueryEnabled &&
+          !conf.getConf(SQLConf.LEGACY_SCALAR_SUBQUERY_COUNT_BUG_HANDLING) && 
!(sub match {
+          // Handle count bug only if there exists lower level Aggs with count 
bugs. It does not
+          // matter if the top level agg is count bug vulnerable or not, 
because:
+          // 1. If the top level agg is count bug vulnerable, it can be 
handled in
+          // constructLeftJoins, unless there are lower aggs that are count 
bug vulnerable.
+          // E.g. COUNT(COUNT + COUNT)
+          // 2. If the top level agg is not count bug vulnerable, it can be 
count bug vulnerable if
+          // there are lower aggs that are count bug vulnerable. E.g. 
SUM(COUNT)
+          case agg: Aggregate => !agg.child.exists {
+            case lowerAgg: Aggregate => mayHaveCountBugAgg(lowerAgg)
+            case _ => false
+          }
+          case _ => false
+        })
+        val (newPlan, newCond) = decorrelate(sub, plan, 
handleCountBugInDecorrelate)
+        val mayHaveCountBug = if (mayHaveCountBugOld.isDefined) {
+          // For idempotency, we must save this variable the first time this 
rule is run, because
+          // decorrelation introduces a GROUP BY is if one wasn't already 
present.
+          mayHaveCountBugOld.get
+        } else if (handleCountBugInDecorrelate) {
+          // Count bug was already handled in the above decorrelate function 
call.
+          false
+        } else {
           // Check whether the pre-rewrite subquery had empty 
groupingExpressions. If yes, it may
           // be subject to the COUNT bug. If it has non-empty 
groupingExpressions, there is
           // no COUNT bug.
           val (topPart, havingNode, aggNode) = splitSubquery(sub)
           (aggNode.isDefined && aggNode.get.groupingExpressions.isEmpty)
-        } else {
-          // For idempotency, we must save this variable the first time this 
rule is run, because
-          // decorrelation introduces a GROUP BY is if one wasn't already 
present.
-          mayHaveCountBugOld.get
         }
         ScalarSubquery(newPlan, children, exprId, getJoinCondition(newCond, 
conditions),
           hint, Some(mayHaveCountBug))
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index a2401c4917c..e66eadaa914 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4507,6 +4507,15 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val LEGACY_SCALAR_SUBQUERY_COUNT_BUG_HANDLING =
+    buildConf("spark.sql.legacy.scalarSubqueryCountBugBehavior")
+      .internal()
+      .doc("When set to true, restores legacy behavior of potential incorrect 
count bug " +
+        "handling for scalar subqueries.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(false)
+
   /**
    * Holds information about keys that have been deprecated.
    *
diff --git 
a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/nested-scalar-subquery-count-bug.sql.out
 
b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/nested-scalar-subquery-count-bug.sql.out
new file mode 100644
index 00000000000..aec952887db
--- /dev/null
+++ 
b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/nested-scalar-subquery-count-bug.sql.out
@@ -0,0 +1,166 @@
+-- Automatically generated by SQLQueryTestSuite
+-- !query
+CREATE OR REPLACE VIEW t1(a1, a2) as values (0, 1), (1, 2)
+-- !query analysis
+CreateViewCommand `spark_catalog`.`default`.`t1`, [(a1,None), (a2,None)], 
values (0, 1), (1, 2), false, true, PersistedView, true
+   +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+CREATE OR REPLACE VIEW t2(b1, b2) as values (0, 2), (0, 3)
+-- !query analysis
+CreateViewCommand `spark_catalog`.`default`.`t2`, [(b1,None), (b2,None)], 
values (0, 2), (0, 3), false, true, PersistedView, true
+   +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+CREATE OR REPLACE VIEW t3(c1, c2) as values (0, 2), (0, 3)
+-- !query analysis
+CreateViewCommand `spark_catalog`.`default`.`t3`, [(c1,None), (c2,None)], 
values (0, 2), (0, 3), false, true, PersistedView, true
+   +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+set spark.sql.optimizer.decorrelateInnerQuery.enabled=true
+-- !query analysis
+SetCommand (spark.sql.optimizer.decorrelateInnerQuery.enabled,Some(true))
+
+
+-- !query
+set spark.sql.legacy.scalarSubqueryCountBugBehavior=false
+-- !query analysis
+SetCommand (spark.sql.legacy.scalarSubqueryCountBugBehavior,Some(false))
+
+
+-- !query
+select ( select sum(cnt) from (select count(*) cnt from t2 where t1.a1 = 
t2.b1) ) a from t1 order by a desc
+-- !query analysis
+Sort [a#xL DESC NULLS LAST], true
++- Project [scalar-subquery#x [a1#x] AS a#xL]
+   :  +- Aggregate [sum(cnt#xL) AS sum(cnt)#xL]
+   :     +- SubqueryAlias __auto_generated_subquery_name
+   :        +- Aggregate [count(1) AS cnt#xL]
+   :           +- Filter (outer(a1#x) = b1#x)
+   :              +- SubqueryAlias spark_catalog.default.t2
+   :                 +- View (`spark_catalog`.`default`.`t2`, [b1#x,b2#x])
+   :                    +- Project [cast(col1#x as int) AS b1#x, cast(col2#x 
as int) AS b2#x]
+   :                       +- LocalRelation [col1#x, col2#x]
+   +- SubqueryAlias spark_catalog.default.t1
+      +- View (`spark_catalog`.`default`.`t1`, [a1#x,a2#x])
+         +- Project [cast(col1#x as int) AS a1#x, cast(col2#x as int) AS a2#x]
+            +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+select ( select count(*) from (select count(*) cnt from t2 where t1.a1 = 
t2.b1) ) a from t1 order by a desc
+-- !query analysis
+Sort [a#xL DESC NULLS LAST], true
++- Project [scalar-subquery#x [a1#x] AS a#xL]
+   :  +- Aggregate [count(1) AS count(1)#xL]
+   :     +- SubqueryAlias __auto_generated_subquery_name
+   :        +- Aggregate [count(1) AS cnt#xL]
+   :           +- Filter (outer(a1#x) = b1#x)
+   :              +- SubqueryAlias spark_catalog.default.t2
+   :                 +- View (`spark_catalog`.`default`.`t2`, [b1#x,b2#x])
+   :                    +- Project [cast(col1#x as int) AS b1#x, cast(col2#x 
as int) AS b2#x]
+   :                       +- LocalRelation [col1#x, col2#x]
+   +- SubqueryAlias spark_catalog.default.t1
+      +- View (`spark_catalog`.`default`.`t1`, [a1#x,a2#x])
+         +- Project [cast(col1#x as int) AS a1#x, cast(col2#x as int) AS a2#x]
+            +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+select (
+  select SUM(l.cnt + r.cnt)
+  from (select count(*) cnt from t2 where t1.a1 = t2.b1 having cnt = 0) l
+  join (select count(*) cnt from t3 where t1.a1 = t3.c1 having cnt = 0) r
+  on l.cnt = r.cnt
+) a from t1 order by a desc
+-- !query analysis
+Sort [a#xL DESC NULLS LAST], true
++- Project [scalar-subquery#x [a1#x && a1#x] AS a#xL]
+   :  +- Aggregate [sum((cnt#xL + cnt#xL)) AS sum((cnt + cnt))#xL]
+   :     +- Join Inner, (cnt#xL = cnt#xL)
+   :        :- SubqueryAlias l
+   :        :  +- Filter (cnt#xL = cast(0 as bigint))
+   :        :     +- Aggregate [count(1) AS cnt#xL]
+   :        :        +- Filter (outer(a1#x) = b1#x)
+   :        :           +- SubqueryAlias spark_catalog.default.t2
+   :        :              +- View (`spark_catalog`.`default`.`t2`, 
[b1#x,b2#x])
+   :        :                 +- Project [cast(col1#x as int) AS b1#x, 
cast(col2#x as int) AS b2#x]
+   :        :                    +- LocalRelation [col1#x, col2#x]
+   :        +- SubqueryAlias r
+   :           +- Filter (cnt#xL = cast(0 as bigint))
+   :              +- Aggregate [count(1) AS cnt#xL]
+   :                 +- Filter (outer(a1#x) = c1#x)
+   :                    +- SubqueryAlias spark_catalog.default.t3
+   :                       +- View (`spark_catalog`.`default`.`t3`, 
[c1#x,c2#x])
+   :                          +- Project [cast(col1#x as int) AS c1#x, 
cast(col2#x as int) AS c2#x]
+   :                             +- LocalRelation [col1#x, col2#x]
+   +- SubqueryAlias spark_catalog.default.t1
+      +- View (`spark_catalog`.`default`.`t1`, [a1#x,a2#x])
+         +- Project [cast(col1#x as int) AS a1#x, cast(col2#x as int) AS a2#x]
+            +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+select (
+  select sum(l.cnt + r.cnt)
+  from (select count(*) cnt from t2 where t1.a1 = t2.b1) l
+  join (select count(*) cnt from t3 where t1.a1 = t3.c1) r
+  on l.cnt = r.cnt
+) a from t1 order by a desc
+-- !query analysis
+Sort [a#xL DESC NULLS LAST], true
++- Project [scalar-subquery#x [a1#x && a1#x] AS a#xL]
+   :  +- Aggregate [sum((cnt#xL + cnt#xL)) AS sum((cnt + cnt))#xL]
+   :     +- Join Inner, (cnt#xL = cnt#xL)
+   :        :- SubqueryAlias l
+   :        :  +- Aggregate [count(1) AS cnt#xL]
+   :        :     +- Filter (outer(a1#x) = b1#x)
+   :        :        +- SubqueryAlias spark_catalog.default.t2
+   :        :           +- View (`spark_catalog`.`default`.`t2`, [b1#x,b2#x])
+   :        :              +- Project [cast(col1#x as int) AS b1#x, 
cast(col2#x as int) AS b2#x]
+   :        :                 +- LocalRelation [col1#x, col2#x]
+   :        +- SubqueryAlias r
+   :           +- Aggregate [count(1) AS cnt#xL]
+   :              +- Filter (outer(a1#x) = c1#x)
+   :                 +- SubqueryAlias spark_catalog.default.t3
+   :                    +- View (`spark_catalog`.`default`.`t3`, [c1#x,c2#x])
+   :                       +- Project [cast(col1#x as int) AS c1#x, 
cast(col2#x as int) AS c2#x]
+   :                          +- LocalRelation [col1#x, col2#x]
+   +- SubqueryAlias spark_catalog.default.t1
+      +- View (`spark_catalog`.`default`.`t1`, [a1#x,a2#x])
+         +- Project [cast(col1#x as int) AS a1#x, cast(col2#x as int) AS a2#x]
+            +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+reset spark.sql.optimizer.decorrelateInnerQuery.enabled
+-- !query analysis
+ResetCommand spark.sql.optimizer.decorrelateInnerQuery.enabled
+
+
+-- !query
+reset spark.sql.legacy.scalarSubqueryCountBugBehavior
+-- !query analysis
+ResetCommand spark.sql.legacy.scalarSubqueryCountBugBehavior
+
+
+-- !query
+DROP VIEW t1
+-- !query analysis
+DropTableCommand `spark_catalog`.`default`.`t1`, false, true, false
+
+
+-- !query
+DROP VIEW t2
+-- !query analysis
+DropTableCommand `spark_catalog`.`default`.`t2`, false, true, false
+
+
+-- !query
+DROP VIEW t3
+-- !query analysis
+DropTableCommand `spark_catalog`.`default`.`t3`, false, true, false
diff --git 
a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/nested-scalar-subquery-count-bug.sql
 
b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/nested-scalar-subquery-count-bug.sql
new file mode 100644
index 00000000000..86476389a85
--- /dev/null
+++ 
b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/nested-scalar-subquery-count-bug.sql
@@ -0,0 +1,34 @@
+CREATE OR REPLACE VIEW t1(a1, a2) as values (0, 1), (1, 2);
+CREATE OR REPLACE VIEW t2(b1, b2) as values (0, 2), (0, 3);
+CREATE OR REPLACE VIEW t3(c1, c2) as values (0, 2), (0, 3);
+
+set spark.sql.optimizer.decorrelateInnerQuery.enabled=true;
+set spark.sql.legacy.scalarSubqueryCountBugBehavior=false;
+
+-- test for count bug in nested aggregates in correlated scalar subqueries
+select ( select sum(cnt) from (select count(*) cnt from t2 where t1.a1 = 
t2.b1) ) a from t1 order by a desc;
+
+-- test for count bug in nested counts in correlated scalar subqueries
+select ( select count(*) from (select count(*) cnt from t2 where t1.a1 = 
t2.b1) ) a from t1 order by a desc;
+
+-- test for count bug in correlated scalar subqueries with nested aggregates 
with multiple counts
+select (
+  select SUM(l.cnt + r.cnt)
+  from (select count(*) cnt from t2 where t1.a1 = t2.b1 having cnt = 0) l
+  join (select count(*) cnt from t3 where t1.a1 = t3.c1 having cnt = 0) r
+  on l.cnt = r.cnt
+) a from t1 order by a desc;
+
+-- same as above, without HAVING clause
+select (
+  select sum(l.cnt + r.cnt)
+  from (select count(*) cnt from t2 where t1.a1 = t2.b1) l
+  join (select count(*) cnt from t3 where t1.a1 = t3.c1) r
+  on l.cnt = r.cnt
+) a from t1 order by a desc;
+
+reset spark.sql.optimizer.decorrelateInnerQuery.enabled;
+reset spark.sql.legacy.scalarSubqueryCountBugBehavior;
+DROP VIEW t1;
+DROP VIEW t2;
+DROP VIEW t3;
diff --git 
a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/nested-scalar-subquery-count-bug.sql.out
 
b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/nested-scalar-subquery-count-bug.sql.out
new file mode 100644
index 00000000000..c524d315baf
--- /dev/null
+++ 
b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/nested-scalar-subquery-count-bug.sql.out
@@ -0,0 +1,125 @@
+-- Automatically generated by SQLQueryTestSuite
+-- !query
+CREATE OR REPLACE VIEW t1(a1, a2) as values (0, 1), (1, 2)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE OR REPLACE VIEW t2(b1, b2) as values (0, 2), (0, 3)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE OR REPLACE VIEW t3(c1, c2) as values (0, 2), (0, 3)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+set spark.sql.optimizer.decorrelateInnerQuery.enabled=true
+-- !query schema
+struct<key:string,value:string>
+-- !query output
+spark.sql.optimizer.decorrelateInnerQuery.enabled      true
+
+
+-- !query
+set spark.sql.legacy.scalarSubqueryCountBugBehavior=false
+-- !query schema
+struct<key:string,value:string>
+-- !query output
+spark.sql.legacy.scalarSubqueryCountBugBehavior        false
+
+
+-- !query
+select ( select sum(cnt) from (select count(*) cnt from t2 where t1.a1 = 
t2.b1) ) a from t1 order by a desc
+-- !query schema
+struct<a:bigint>
+-- !query output
+2
+0
+
+
+-- !query
+select ( select count(*) from (select count(*) cnt from t2 where t1.a1 = 
t2.b1) ) a from t1 order by a desc
+-- !query schema
+struct<a:bigint>
+-- !query output
+1
+1
+
+
+-- !query
+select (
+  select SUM(l.cnt + r.cnt)
+  from (select count(*) cnt from t2 where t1.a1 = t2.b1 having cnt = 0) l
+  join (select count(*) cnt from t3 where t1.a1 = t3.c1 having cnt = 0) r
+  on l.cnt = r.cnt
+) a from t1 order by a desc
+-- !query schema
+struct<a:bigint>
+-- !query output
+0
+NULL
+
+
+-- !query
+select (
+  select sum(l.cnt + r.cnt)
+  from (select count(*) cnt from t2 where t1.a1 = t2.b1) l
+  join (select count(*) cnt from t3 where t1.a1 = t3.c1) r
+  on l.cnt = r.cnt
+) a from t1 order by a desc
+-- !query schema
+struct<a:bigint>
+-- !query output
+4
+0
+
+
+-- !query
+reset spark.sql.optimizer.decorrelateInnerQuery.enabled
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+reset spark.sql.legacy.scalarSubqueryCountBugBehavior
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP VIEW t1
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP VIEW t2
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP VIEW t3
+-- !query schema
+struct<>
+-- !query output
+


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to