This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3432fd8dba5 [SPARK-46468][SQL] Handle COUNT bug for EXISTS subqueries with Aggregate without grouping keys
3432fd8dba5 is described below

commit 3432fd8dba5bec623b14a4ec4306290eced6c93c
Author: Andrey Gubichev <andrey.gubic...@databricks.com>
AuthorDate: Fri Dec 22 09:32:22 2023 +0800

    [SPARK-46468][SQL] Handle COUNT bug for EXISTS subqueries with Aggregate without grouping keys
    
    ### What changes were proposed in this pull request?
    
    An Aggregate with no grouping keys always returns exactly one row (whose
    values may be NULL), so an EXISTS over such a subquery should always return
    true.
    This reverts some changes made when EXISTS/IN subqueries were migrated to the
    DecorrelateInnerQuery framework. In particular, the static detection of
    potentially count-bug-susceptible aggregates (`isCountBugFree`) is removed:
    an empty grouping key alone now triggers the count bug treatment. Scalar
    subqueries still have extra checks that evaluate the aggregate on an empty
    input. I suspect the same correctness problem was present in the legacy
    framework, so one test was added to the legacy section of
    exists-count-bug.sql.
    
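    ### Why are the changes needed?
    
    To fix a correctness bug. Previously, an aggregate whose default result on
    empty input is NULL (e.g. `max`) was treated as count-bug free, so it skipped
    the count bug handling. As a result, EXISTS over a correlated Aggregate
    without grouping keys could return false when the correlated filter removed
    every input row, e.g. `EXISTS (SELECT max(tt2.id) FROM EMP AS tt2 WHERE
    tt1.emp_name IS NULL)`. That is wrong, because the aggregate still produces
    one row.
    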
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Query tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #44451 from agubichev/SPARK-46468_count.
    
    Authored-by: Andrey Gubichev <andrey.gubic...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../catalyst/optimizer/DecorrelateInnerQuery.scala | 22 +-------------
 .../exists-subquery/exists-aggregate.sql.out       | 29 ++++++++++++++++++
 .../exists-subquery/exists-count-bug.sql.out       | 34 ++++++++++++++++++++++
 .../subquery/exists-subquery/exists-aggregate.sql  |  9 ++++++
 .../subquery/exists-subquery/exists-count-bug.sql  |  5 ++++
 .../sql-tests/results/join-lateral.sql.out         |  1 +
 .../exists-subquery/exists-aggregate.sql.out       | 22 ++++++++++++++
 .../exists-subquery/exists-count-bug.sql.out       | 17 +++++++++++
 8 files changed, 118 insertions(+), 21 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
index feb01d1ce3f..eca392fd84c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
@@ -22,7 +22,6 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE
@@ -462,22 +461,6 @@ object DecorrelateInnerQuery extends PredicateHelper {
       p.mapChildren(rewriteDomainJoins(outerPlan, _, conditions))
   }
 
-  private def isCountBugFree(aggregateExpressions: Seq[NamedExpression]): 
Boolean = {
-    // The COUNT bug only appears if an aggregate expression returns a 
non-NULL result on an empty
-    // input.
-    // Typical example (hence the name) is COUNT(*) that returns 0 from an 
empty result.
-    // However, SUM(x) IS NULL is another case that returns 0, and in general 
any IS/NOT IS and CASE
-    // expressions are suspect (and the combination of those).
-    // For now we conservatively accept only those expressions that are 
guaranteed to be safe.
-    aggregateExpressions.forall {
-      case _ : AttributeReference => true
-      case Alias(_: AttributeReference, _) => true
-      case Alias(_: Literal, _) => true
-      case Alias(a: AggregateExpression, _) if 
a.aggregateFunction.defaultResult == None => true
-      case _ => false
-    }
-  }
-
   def apply(
       innerPlan: LogicalPlan,
       outerPlan: LogicalPlan,
@@ -727,8 +710,6 @@ object DecorrelateInnerQuery extends PredicateHelper {
           case a @ Aggregate(groupingExpressions, aggregateExpressions, child) 
=>
             val outerReferences = collectOuterReferences(a.expressions)
             val newOuterReferences = parentOuterReferences ++ outerReferences
-            val countBugSusceptible = groupingExpressions.isEmpty &&
-              !isCountBugFree(aggregateExpressions)
             val (newChild, joinCond, outerReferenceMap) =
               decorrelate(child, newOuterReferences, aggregated = true, 
underSetOp)
             // Replace all outer references in grouping and aggregate 
expressions, and keep
@@ -791,8 +772,7 @@ object DecorrelateInnerQuery extends PredicateHelper {
             // | 0 | 2    | true       | 2                              |
             // | 0 | null | null       | 0                              |  
<--- correct result
             // +---+------+------------+--------------------------------+
-            // TODO(a.gubichev): retire the 'handleCountBug' parameter.
-            if (countBugSusceptible && handleCountBug) {
+            if (groupingExpressions.isEmpty && handleCountBug) {
               // Evaluate the aggregate expressions with zero tuples.
               val resultMap = 
RewriteCorrelatedScalarSubquery.evalAggregateOnZeroTups(newAggregate)
               val alwaysTrue = Alias(Literal.TrueLiteral, "alwaysTrue")()
diff --git 
a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/exists-subquery/exists-aggregate.sql.out
 
b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/exists-subquery/exists-aggregate.sql.out
index f026a330773..d486ff4fb03 100644
--- 
a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/exists-subquery/exists-aggregate.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/exists-subquery/exists-aggregate.sql.out
@@ -345,3 +345,32 @@ Project [emp_name#x, bonus_amt#x]
             +- Project [emp_name#x, bonus_amt#x]
                +- SubqueryAlias BONUS
                   +- LocalRelation [emp_name#x, bonus_amt#x]
+
+
+-- !query
+SELECT tt1.emp_name
+FROM EMP as tt1
+WHERE EXISTS (
+  select max(tt2.id)
+  from EMP as tt2
+  where tt1.emp_name is null
+)
+-- !query analysis
+Project [emp_name#x]
++- Filter exists#x [emp_name#x]
+   :  +- Aggregate [max(id#x) AS max(id)#x]
+   :     +- Filter isnull(outer(emp_name#x))
+   :        +- SubqueryAlias tt2
+   :           +- SubqueryAlias emp
+   :              +- View (`EMP`, 
[id#x,emp_name#x,hiredate#x,salary#x,dept_id#x])
+   :                 +- Project [cast(id#x as int) AS id#x, cast(emp_name#x as 
string) AS emp_name#x, cast(hiredate#x as date) AS hiredate#x, cast(salary#x as 
double) AS salary#x, cast(dept_id#x as int) AS dept_id#x]
+   :                    +- Project [id#x, emp_name#x, hiredate#x, salary#x, 
dept_id#x]
+   :                       +- SubqueryAlias EMP
+   :                          +- LocalRelation [id#x, emp_name#x, hiredate#x, 
salary#x, dept_id#x]
+   +- SubqueryAlias tt1
+      +- SubqueryAlias emp
+         +- View (`EMP`, [id#x,emp_name#x,hiredate#x,salary#x,dept_id#x])
+            +- Project [cast(id#x as int) AS id#x, cast(emp_name#x as string) 
AS emp_name#x, cast(hiredate#x as date) AS hiredate#x, cast(salary#x as double) 
AS salary#x, cast(dept_id#x as int) AS dept_id#x]
+               +- Project [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x]
+                  +- SubqueryAlias EMP
+                     +- LocalRelation [id#x, emp_name#x, hiredate#x, salary#x, 
dept_id#x]
diff --git 
a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/exists-subquery/exists-count-bug.sql.out
 
b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/exists-subquery/exists-count-bug.sql.out
index b5e609dedd7..a4dc454572f 100644
--- 
a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/exists-subquery/exists-count-bug.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/exists-subquery/exists-count-bug.sql.out
@@ -143,6 +143,23 @@ Project [c1#x, c2#x]
             +- LocalRelation [col1#x, col2#x]
 
 
+-- !query
+select * from t1 where exists (select count(*) from t2 where t1.c1 = 100)
+-- !query analysis
+Project [c1#x, c2#x]
++- Filter exists#x [c1#x]
+   :  +- Aggregate [count(1) AS count(1)#xL]
+   :     +- Filter (outer(c1#x) = 100)
+   :        +- SubqueryAlias t2
+   :           +- View (`t2`, [c1#x,c2#x])
+   :              +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) 
AS c2#x]
+   :                 +- LocalRelation [col1#x, col2#x]
+   +- SubqueryAlias t1
+      +- View (`t1`, [c1#x,c2#x])
+         +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+            +- LocalRelation [col1#x, col2#x]
+
+
 -- !query
 set 
spark.sql.optimizer.decorrelateExistsSubqueryLegacyIncorrectCountHandling.enabled
 = true
 -- !query analysis
@@ -240,6 +257,23 @@ Project [c1#x, c2#x]
             +- LocalRelation [col1#x, col2#x]
 
 
+-- !query
+select * from t1 where exists (select count(*) from t2 where t1.c1 = 100)
+-- !query analysis
+Project [c1#x, c2#x]
++- Filter exists#x [c1#x]
+   :  +- Aggregate [count(1) AS count(1)#xL]
+   :     +- Filter (outer(c1#x) = 100)
+   :        +- SubqueryAlias t2
+   :           +- View (`t2`, [c1#x,c2#x])
+   :              +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) 
AS c2#x]
+   :                 +- LocalRelation [col1#x, col2#x]
+   +- SubqueryAlias t1
+      +- View (`t1`, [c1#x,c2#x])
+         +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+            +- LocalRelation [col1#x, col2#x]
+
+
 -- !query
 set 
spark.sql.optimizer.decorrelateExistsSubqueryLegacyIncorrectCountHandling.enabled
 = false
 -- !query analysis
diff --git 
a/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-aggregate.sql
 
b/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-aggregate.sql
index 1c4ef982c66..17672f9738f 100644
--- 
a/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-aggregate.sql
+++ 
b/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-aggregate.sql
@@ -125,3 +125,12 @@ FROM BONUS
 WHERE EXISTS(SELECT RANK() OVER (PARTITION BY hiredate ORDER BY salary) AS s
                     FROM EMP, DEPT where EMP.dept_id = DEPT.dept_id
                         AND DEPT.dept_name < BONUS.emp_name);
+
+-- SPARK-46468: Aggregate always returns 1 row, so EXISTS is always true.
+SELECT tt1.emp_name
+FROM EMP as tt1
+WHERE EXISTS (
+  select max(tt2.id)
+  from EMP as tt2
+  where tt1.emp_name is null
+);
\ No newline at end of file
diff --git 
a/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-count-bug.sql
 
b/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-count-bug.sql
index 1af7a235683..3075fef70ad 100644
--- 
a/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-count-bug.sql
+++ 
b/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-count-bug.sql
@@ -20,6 +20,9 @@ select * from t1 where
  not exists(select count(*) - 1 from t2 where t2.c1 = t1.c1)) AND
  exists(select count(*) from t2 where t2.c1 = t1.c2);
 
+select * from t1 where exists (select count(*) from t2 where t1.c1 = 100);
+
+
 -- With legacy behavior flag set, some answers are not correct.
 set 
spark.sql.optimizer.decorrelateExistsSubqueryLegacyIncorrectCountHandling.enabled
 = true;
 select * from t1 where exists (select count(*) from t2 where t2.c1 = t1.c1);
@@ -34,4 +37,6 @@ select * from t1 where
  exists(select count(*) + 1 from t2 where t2.c1 = t1.c1) OR
  not exists (select count(*) - 1 from t2 where t2.c1 = t1.c1);
 
+select * from t1 where exists (select count(*) from t2 where t1.c1 = 100);
+
 set 
spark.sql.optimizer.decorrelateExistsSubqueryLegacyIncorrectCountHandling.enabled
 = false;
diff --git a/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out 
b/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out
index 5c50d91a5a8..0cbfe9ef081 100644
--- a/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out
@@ -759,6 +759,7 @@ SELECT * FROM t1, LATERAL (SELECT SUM(cnt) FROM (SELECT 
COUNT(*) cnt FROM t2 WHE
 struct<c1:int,c2:int,sum(cnt):bigint>
 -- !query output
 0      1       2
+1      2       NULL
 
 
 -- !query
diff --git 
a/sql/core/src/test/resources/sql-tests/results/subquery/exists-subquery/exists-aggregate.sql.out
 
b/sql/core/src/test/resources/sql-tests/results/subquery/exists-subquery/exists-aggregate.sql.out
index ef720d927d8..af907b67df2 100644
--- 
a/sql/core/src/test/resources/sql-tests/results/subquery/exists-subquery/exists-aggregate.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/results/subquery/exists-subquery/exists-aggregate.sql.out
@@ -197,3 +197,25 @@ emp 3      300.0
 emp 4  100.0
 emp 5  1000.0
 emp 6 - no dept        500.0
+
+
+-- !query
+SELECT tt1.emp_name
+FROM EMP as tt1
+WHERE EXISTS (
+  select max(tt2.id)
+  from EMP as tt2
+  where tt1.emp_name is null
+)
+-- !query schema
+struct<emp_name:string>
+-- !query output
+emp 1
+emp 1
+emp 2
+emp 3
+emp 4
+emp 5
+emp 6 - no dept
+emp 7
+emp 8
diff --git 
a/sql/core/src/test/resources/sql-tests/results/subquery/exists-subquery/exists-count-bug.sql.out
 
b/sql/core/src/test/resources/sql-tests/results/subquery/exists-subquery/exists-count-bug.sql.out
index 2e5b31747cb..ee9deab84d2 100644
--- 
a/sql/core/src/test/resources/sql-tests/results/subquery/exists-subquery/exists-count-bug.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/results/subquery/exists-subquery/exists-count-bug.sql.out
@@ -81,6 +81,15 @@ struct<c1:int,c2:int>
 1      2
 
 
+-- !query
+select * from t1 where exists (select count(*) from t2 where t1.c1 = 100)
+-- !query schema
+struct<c1:int,c2:int>
+-- !query output
+0      1
+1      2
+
+
 -- !query
 set 
spark.sql.optimizer.decorrelateExistsSubqueryLegacyIncorrectCountHandling.enabled
 = true
 -- !query schema
@@ -134,6 +143,14 @@ struct<c1:int,c2:int>
 1      2
 
 
+-- !query
+select * from t1 where exists (select count(*) from t2 where t1.c1 = 100)
+-- !query schema
+struct<c1:int,c2:int>
+-- !query output
+
+
+
 -- !query
 set 
spark.sql.optimizer.decorrelateExistsSubqueryLegacyIncorrectCountHandling.enabled
 = false
 -- !query schema


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to