This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 44c344d  [SPARK-34876][SQL] Fill defaultResult of non-nullable aggregates
44c344d is described below

commit 44c344db190976a6375ad5a263dbf20c5135c023
Author: Tanel Kiis <tanel.k...@gmail.com>
AuthorDate: Mon Mar 29 11:47:08 2021 +0900

    [SPARK-34876][SQL] Fill defaultResult of non-nullable aggregates
    
    ### What changes were proposed in this pull request?
    
    Filled the `defaultResult` field on non-nullable aggregates
    
    ### Why are the changes needed?
    
    The `defaultResult` field defaults to `None`, and in some situations (such as correlated scalar subqueries) it is used as the value of the aggregation.
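
    As a minimal, illustrative sketch (not part of this patch) of what filling `defaultResult` means for a non-nullable, array-producing aggregate such as `collect_list`, assuming catalyst's `Literal.create` API; the value and data type below are examples only:

    ```scala
    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.types.{ArrayType, LongType}

    // On empty input a collect_list-style aggregate should yield an empty array
    // rather than NULL, so its default result is an empty-array literal.
    // ArrayType(LongType) is only an illustrative element type here.
    val emptyInputDefault: Option[Literal] =
      Option(Literal.create(Array.empty[Long], ArrayType(LongType)))
    ```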
    
    The UT result before the fix:
    ```
    -- !query
    SELECT t1a,
       (SELECT count(t2d) FROM t2 WHERE t2a = t1a) count_t2,
       (SELECT count_if(t2d > 0) FROM t2 WHERE t2a = t1a) count_if_t2,
       (SELECT approx_count_distinct(t2d) FROM t2 WHERE t2a = t1a) approx_count_distinct_t2,
       (SELECT collect_list(t2d) FROM t2 WHERE t2a = t1a) collect_list_t2,
       (SELECT collect_set(t2d) FROM t2 WHERE t2a = t1a) collect_set_t2,
        (SELECT hex(count_min_sketch(t2d, 0.5d, 0.5d, 1)) FROM t2 WHERE t2a = t1a) collect_set_t2
    FROM t1
    -- !query schema
    struct<t1a:string,count_t2:bigint,count_if_t2:bigint,approx_count_distinct_t2:bigint,collect_list_t2:array<bigint>,collect_set_t2:array<bigint>,collect_set_t2:string>
    -- !query output
    val1a       0       0       NULL    NULL    NULL    NULL
    val1a       0       0       NULL    NULL    NULL    NULL
    val1a       0       0       NULL    NULL    NULL    NULL
    val1a       0       0       NULL    NULL    NULL    NULL
    val1b       6       6       3       [19,119,319,19,19,19]   [19,119,319]    0000000100000000000000060000000100000004000000005D8D6AB90000000000000000000000000000000400000000000000010000000000000001
    val1c       2       2       2       [219,19]        [219,19]        0000000100000000000000020000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000001
    val1d       0       0       NULL    NULL    NULL    NULL
    val1d       0       0       NULL    NULL    NULL    NULL
    val1d       0       0       NULL    NULL    NULL    NULL
    val1e       1       1       1       [19]    [19]    0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000
    val1e       1       1       1       [19]    [19]    0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000
    val1e       1       1       1       [19]    [19]    0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    Bugfix
    
    ### How was this patch tested?
    
    UT
    
    Closes #31973 from tanelk/SPARK-34876_non_nullable_agg_subquery.
    
    Authored-by: Tanel Kiis <tanel.k...@gmail.com>
    Signed-off-by: HyukjinKwon <gurwls...@apache.org>
    (cherry picked from commit 4b9e94c44412f399ba19e0ea90525d346942bf71)
    Signed-off-by: HyukjinKwon <gurwls...@apache.org>
---
 .../expressions/aggregate/CountMinSketchAgg.scala  |  5 +++-
 .../aggregate/HyperLogLogPlusPlus.scala            |  2 ++
 .../catalyst/expressions/aggregate/collect.scala   |  2 ++
 .../expressions/aggregate/interfaces.scala         |  3 +--
 .../scalar-subquery/scalar-subquery-select.sql     | 10 ++++++++
 .../scalar-subquery/scalar-subquery-select.sql.out | 28 +++++++++++++++++++++-
 6 files changed, 46 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
index 787b218..7e8abe0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
-import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription}
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.sketch.CountMinSketch
@@ -135,6 +135,9 @@ case class CountMinSketchAgg(
 
   override def dataType: DataType = BinaryType
 
+  override def defaultResult: Option[Literal] =
+    Option(Literal.create(eval(createAggregationBuffer()), dataType))
+
   override def children: Seq[Expression] =
     Seq(child, epsExpression, confidenceExpression, seedExpression)
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
index b3cc9a3..b6a4d11 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
@@ -90,6 +90,8 @@ case class HyperLogLogPlusPlus(
 
   override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
 
+  override def defaultResult: Option[Literal] = Option(Literal.create(0L, dataType))
+
   val hllppHelper = new HyperLogLogPlusPlusHelper(relativeSD)
 
   /** Allocate enough words to store all registers. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index 0a3d876..68fbfa3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -47,6 +47,8 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper
   // actual order of input rows.
   override lazy val deterministic: Boolean = false
 
+  override def defaultResult: Option[Literal] = Option(Literal.create(Array(), dataType))
+
   protected def convertToBufferElement(value: Any): Any
 
   override def update(buffer: T, input: InternalRow): T = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 222ad6f..5f091cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -203,8 +203,7 @@ abstract class AggregateFunction extends Expression {
   def inputAggBufferAttributes: Seq[AttributeReference]
 
   /**
-   * Result of the aggregate function when the input is empty. This is currently only used for the
-   * proper rewriting of distinct aggregate functions.
+   * Result of the aggregate function when the input is empty.
    */
   def defaultResult: Option[Literal] = None
 
diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql
index eabbd0a..81712bf 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql
@@ -128,3 +128,13 @@ WHERE  NOT EXISTS (SELECT (SELECT max(t2b)
                                  ON     t2a = t1a
                                  WHERE  t2c = t3c)
                    AND    t3a = t1a);
+
+-- SPARK-34876: Non-nullable aggregates should not return NULL in a correlated subquery
+SELECT t1a,
+    (SELECT count(t2d) FROM t2 WHERE t2a = t1a) count_t2,
+    (SELECT count_if(t2d > 0) FROM t2 WHERE t2a = t1a) count_if_t2,
+    (SELECT approx_count_distinct(t2d) FROM t2 WHERE t2a = t1a) approx_count_distinct_t2,
+    (SELECT collect_list(t2d) FROM t2 WHERE t2a = t1a) collect_list_t2,
+    (SELECT collect_set(t2d) FROM t2 WHERE t2a = t1a) collect_set_t2,
+    (SELECT hex(count_min_sketch(t2d, 0.5d, 0.5d, 1)) FROM t2 WHERE t2a = t1a) collect_set_t2
+FROM t1;
\ No newline at end of file
diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out
index 184b8da..16570c6 100644
--- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 11
+-- Number of queries: 12
 
 
 -- !query
@@ -196,3 +196,29 @@ val1d      NULL
 val1e  10
 val1e  10
 val1e  10
+
+
+-- !query
+SELECT t1a,
+    (SELECT count(t2d) FROM t2 WHERE t2a = t1a) count_t2,
+    (SELECT count_if(t2d > 0) FROM t2 WHERE t2a = t1a) count_if_t2,
+    (SELECT approx_count_distinct(t2d) FROM t2 WHERE t2a = t1a) approx_count_distinct_t2,
+    (SELECT collect_list(t2d) FROM t2 WHERE t2a = t1a) collect_list_t2,
+    (SELECT collect_set(t2d) FROM t2 WHERE t2a = t1a) collect_set_t2,
+    (SELECT hex(count_min_sketch(t2d, 0.5d, 0.5d, 1)) FROM t2 WHERE t2a = t1a) collect_set_t2
+FROM t1
+-- !query schema
+struct<t1a:string,count_t2:bigint,count_if_t2:bigint,approx_count_distinct_t2:bigint,collect_list_t2:array<bigint>,collect_set_t2:array<bigint>,collect_set_t2:string>
+-- !query output
+val1a  0       0       0       []      []      0000000100000000000000000000000100000004000000005D8D6AB90000000000000000000000000000000000000000000000000000000000000000
+val1a  0       0       0       []      []      0000000100000000000000000000000100000004000000005D8D6AB90000000000000000000000000000000000000000000000000000000000000000
+val1a  0       0       0       []      []      0000000100000000000000000000000100000004000000005D8D6AB90000000000000000000000000000000000000000000000000000000000000000
+val1a  0       0       0       []      []      0000000100000000000000000000000100000004000000005D8D6AB90000000000000000000000000000000000000000000000000000000000000000
+val1b  6       6       3       [19,119,319,19,19,19]   [19,119,319]    0000000100000000000000060000000100000004000000005D8D6AB90000000000000000000000000000000400000000000000010000000000000001
+val1c  2       2       2       [219,19]        [219,19]        0000000100000000000000020000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000001
+val1d  0       0       0       []      []      0000000100000000000000000000000100000004000000005D8D6AB90000000000000000000000000000000000000000000000000000000000000000
+val1d  0       0       0       []      []      0000000100000000000000000000000100000004000000005D8D6AB90000000000000000000000000000000000000000000000000000000000000000
+val1d  0       0       0       []      []      0000000100000000000000000000000100000004000000005D8D6AB90000000000000000000000000000000000000000000000000000000000000000
+val1e  1       1       1       [19]    [19]    0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000
+val1e  1       1       1       [19]    [19]    0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000
+val1e  1       1       1       [19]    [19]    0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000
