This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 90d302aeb788 [SPARK-48557][SQL] Support scalar subquery with group-by 
on column equal to constant
90d302aeb788 is described below

commit 90d302aeb7887a86730948b29b32dd82ec62586b
Author: Jack Chen <[email protected]>
AuthorDate: Mon Jun 17 14:28:22 2024 +0800

    [SPARK-48557][SQL] Support scalar subquery with group-by on column equal to 
constant
    
    ### What changes were proposed in this pull request?
    We can allow scalar subqueries that have `group by a` when there is an 
equality predicate such as `a = 1`, because such a predicate guarantees that 
the group-by produces at most one row. (This builds on top of 
https://github.com/apache/spark/pull/46839 and enables shapes that were 
unsupported prior to that PR as well.)
    
    ### Why are the changes needed?
    Support valid subquery shapes.
    
    ### Does this PR introduce _any_ user-facing change?
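    Yes. Scalar subqueries of this shape, which previously failed analysis 
with UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY, 
are now supported.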
    
    ### How was this patch tested?
    Unit tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #46902 from jchen5/subq-gby-eq.
    
    Authored-by: Jack Chen <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/catalyst/expressions/subquery.scala  | 12 +++--
 .../org/apache/spark/sql/internal/SQLConf.scala    |  9 ++++
 .../scalar-subquery-group-by.sql.out               | 52 +++++++++++++---------
 .../scalar-subquery/scalar-subquery-group-by.sql   |  8 ++--
 .../scalar-subquery-group-by.sql.out               | 40 ++++++++---------
 5 files changed, 73 insertions(+), 48 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index 9f914865b3a2..75ca4930cf8c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -257,12 +257,16 @@ object SubExprUtils extends PredicateHelper {
    * We can derive these from correlated equality predicates, though we need 
to take care about
    * propagating this through operators like OUTER JOIN or UNION.
    *
-   * Positive examples: x = outer(a) AND y = outer(b)
+   * Positive examples:
+   * - x = outer(a) AND y = outer(b)
+   * - x = 1
+   * - x = outer(a) + 1
+   *
    * Negative examples:
    * - x <= outer(a)
    * - x + y = outer(a)
    * - x = outer(a) OR y = outer(b)
-   * - y = outer(b) + 1 (this and similar expressions could be supported, but 
very carefully)
+   * - y + outer(b) = 1 (this and similar expressions could be supported, but 
very carefully)
    * - An equality under the right side of a LEFT OUTER JOIN, e.g.
    *   select *, (select count(*) from y left join
    *     (select * from z where z1 = x1) sub on y2 = z2 group by z1) from x;
@@ -274,7 +278,9 @@ object SubExprUtils extends PredicateHelper {
     plan match {
       case Filter(cond, child) =>
         val correlated = AttributeSet(splitConjunctivePredicates(cond)
-          .filter(containsOuter) // TODO: can remove this line to allow e.g. 
where x = 1 group by x
+          .filter(
+            
SQLConf.get.getConf(SQLConf.SCALAR_SUBQUERY_ALLOW_GROUP_BY_COLUMN_EQUAL_TO_CONSTANT)
+            || containsOuter(_))
           .filter(DecorrelateInnerQuery.canPullUpOverAgg)
           .flatMap(_.references))
         correlated ++ getCorrelatedEquivalentInnerColumns(child)
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index ade0ba52cf9e..25a2441e05fe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4940,6 +4940,15 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val SCALAR_SUBQUERY_ALLOW_GROUP_BY_COLUMN_EQUAL_TO_CONSTANT =
+    
buildConf("spark.sql.analyzer.scalarSubqueryAllowGroupByColumnEqualToConstant")
+      .internal()
+      .doc("When set to true, allow scalar subqueries with group-by on a 
column that also " +
+        " has an equality filter with a constant (SPARK-48557).")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val ALLOW_SUBQUERY_EXPRESSIONS_IN_LAMBDAS_AND_HIGHER_ORDER_FUNCTIONS =
     
buildConf("spark.sql.analyzer.allowSubqueryExpressionsInLambdasOrHigherOrderFunctions")
       .internal()
diff --git 
a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out
 
b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out
index d9eff3459235..671557aa3956 100644
--- 
a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out
@@ -77,6 +77,38 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
 }
 
 
+-- !query
+select *, (select count(*) from y where x1 = y1 and y2 = 1 group by y2) from x
+-- !query analysis
+Project [x1#x, x2#x, scalar-subquery#x [x1#x] AS scalarsubquery(x1)#xL]
+:  +- Aggregate [y2#x], [count(1) AS count(1)#xL]
+:     +- Filter ((outer(x1#x) = y1#x) AND (y2#x = 1))
+:        +- SubqueryAlias y
+:           +- View (`y`, [y1#x, y2#x])
+:              +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS 
y2#x]
+:                 +- LocalRelation [col1#x, col2#x]
++- SubqueryAlias x
+   +- View (`x`, [x1#x, x2#x])
+      +- Project [cast(col1#x as int) AS x1#x, cast(col2#x as int) AS x2#x]
+         +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+select *, (select count(*) from y where x1 = y1 and y2 = x1 + 1 group by y2) 
from x
+-- !query analysis
+Project [x1#x, x2#x, scalar-subquery#x [x1#x && x1#x] AS scalarsubquery(x1, 
x1)#xL]
+:  +- Aggregate [y2#x], [count(1) AS count(1)#xL]
+:     +- Filter ((outer(x1#x) = y1#x) AND (y2#x = (outer(x1#x) + 1)))
+:        +- SubqueryAlias y
+:           +- View (`y`, [y1#x, y2#x])
+:              +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS 
y2#x]
+:                 +- LocalRelation [col1#x, col2#x]
++- SubqueryAlias x
+   +- View (`x`, [x1#x, x2#x])
+      +- Project [cast(col1#x as int) AS x1#x, cast(col2#x as int) AS x2#x]
+         +- LocalRelation [col1#x, col2#x]
+
+
 -- !query
 select * from x where (select count(*) from y where y1 > x1 group by y1) = 1
 -- !query analysis
@@ -117,26 +149,6 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
 }
 
 
--- !query
-select *, (select count(*) from y where x1 = y1 and y2 = 1 group by y2) from x
--- !query analysis
-org.apache.spark.sql.catalyst.ExtendedAnalysisException
-{
-  "errorClass" : 
"UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY",
-  "sqlState" : "0A000",
-  "messageParameters" : {
-    "value" : "y2"
-  },
-  "queryContext" : [ {
-    "objectType" : "",
-    "objectName" : "",
-    "startIndex" : 11,
-    "stopIndex" : 71,
-    "fragment" : "(select count(*) from y where x1 = y1 and y2 = 1 group by 
y2)"
-  } ]
-}
-
-
 -- !query
 select *, (select count(*) from (select * from y where y1 = x1 union all 
select * from y) sub group by y1) from x
 -- !query analysis
diff --git 
a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-group-by.sql
 
b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-group-by.sql
index 627b27ad285b..6787fac75b39 100644
--- 
a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-group-by.sql
+++ 
b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-group-by.sql
@@ -11,13 +11,15 @@ select * from x where (select count(*) from y where y1 = x1 
group by y1) = 1;
 select * from x where (select count(*) from y where y1 = x1 group by x1) = 1;
 select * from x where (select count(*) from y where y1 > x1 group by x1) = 1;
 
+-- Group-by column equal to constant - legal
+select *, (select count(*) from y where x1 = y1 and y2 = 1 group by y2) from x;
+-- Group-by column equal to expression with constants and outer refs - legal
+select *, (select count(*) from y where x1 = y1 and y2 = x1 + 1 group by y2) 
from x;
+
 -- Illegal queries
 select * from x where (select count(*) from y where y1 > x1 group by y1) = 1;
 select *, (select count(*) from y where y1 + y2 = x1 group by y1) from x;
 
--- Equality with literal - disallowed currently but can actually be allowed
-select *, (select count(*) from y where x1 = y1 and y2 = 1 group by y2) from x;
-
 -- Certain other operators like OUTER JOIN or UNION between the correlating 
filter and the group-by also can cause the scalar subquery to return multiple 
values and hence make the query illegal.
 select *, (select count(*) from (select * from y where y1 = x1 union all 
select * from y) sub group by y1) from x;
 select *, (select count(*) from y left join (select * from z where z1 = x1) 
sub on y2 = z2 group by z1) from x; -- The correlation below the join is 
unsupported in Spark anyway, but when we do support it this query should still 
be disallowed.
diff --git 
a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out
 
b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out
index c044e59a26fd..85ebd91c28c9 100644
--- 
a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out
@@ -75,6 +75,24 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
 }
 
 
+-- !query
+select *, (select count(*) from y where x1 = y1 and y2 = 1 group by y2) from x
+-- !query schema
+struct<x1:int,x2:int,scalarsubquery(x1):bigint>
+-- !query output
+1      1       NULL
+2      2       NULL
+
+
+-- !query
+select *, (select count(*) from y where x1 = y1 and y2 = x1 + 1 group by y2) 
from x
+-- !query schema
+struct<x1:int,x2:int,scalarsubquery(x1, x1):bigint>
+-- !query output
+1      1       NULL
+2      2       NULL
+
+
 -- !query
 select * from x where (select count(*) from y where y1 > x1 group by y1) = 1
 -- !query schema
@@ -119,28 +137,6 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
 }
 
 
--- !query
-select *, (select count(*) from y where x1 = y1 and y2 = 1 group by y2) from x
--- !query schema
-struct<>
--- !query output
-org.apache.spark.sql.catalyst.ExtendedAnalysisException
-{
-  "errorClass" : 
"UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY",
-  "sqlState" : "0A000",
-  "messageParameters" : {
-    "value" : "y2"
-  },
-  "queryContext" : [ {
-    "objectType" : "",
-    "objectName" : "",
-    "startIndex" : 11,
-    "stopIndex" : 71,
-    "fragment" : "(select count(*) from y where x1 = y1 and y2 = 1 group by 
y2)"
-  } ]
-}
-
-
 -- !query
 select *, (select count(*) from (select * from y where y1 = x1 union all 
select * from y) sub group by y1) from x
 -- !query schema


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to