This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch branch-2.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.4 by this push: new b83ab63 [SPARK-34876][SQL] Fill defaultResult of non-nullable aggregates b83ab63 is described below commit b83ab63384e4a6f235f7d3c7a53ed098b8e1b87e Author: Tanel Kiis <tanel.k...@gmail.com> AuthorDate: Mon Mar 29 11:47:08 2021 +0900 [SPARK-34876][SQL] Fill defaultResult of non-nullable aggregates ### What changes were proposed in this pull request? Filled the `defaultResult` field on non-nullable aggregates ### Why are the changes needed? The `defaultResult` defaults to `None` and in some situations (like correlated scalar subqueries) it is used for the value of the aggregation. The UT result before the fix: ``` -- !query SELECT t1a, (SELECT count(t2d) FROM t2 WHERE t2a = t1a) count_t2, (SELECT count_if(t2d > 0) FROM t2 WHERE t2a = t1a) count_if_t2, (SELECT approx_count_distinct(t2d) FROM t2 WHERE t2a = t1a) approx_count_distinct_t2, (SELECT collect_list(t2d) FROM t2 WHERE t2a = t1a) collect_list_t2, (SELECT collect_set(t2d) FROM t2 WHERE t2a = t1a) collect_set_t2, (SELECT hex(count_min_sketch(t2d, 0.5d, 0.5d, 1)) FROM t2 WHERE t2a = t1a) collect_set_t2 FROM t1 -- !query schema struct<t1a:string,count_t2:bigint,count_if_t2:bigint,approx_count_distinct_t2:bigint,collect_list_t2:array<bigint>,collect_set_t2:array<bigint>,collect_set_t2:string> -- !query output val1a 0 0 NULL NULL NULL NULL val1a 0 0 NULL NULL NULL NULL val1a 0 0 NULL NULL NULL NULL val1a 0 0 NULL NULL NULL NULL val1b 6 6 3 [19,119,319,19,19,19] [19,119,319] 0000000100000000000000060000000100000004000000005D8D6AB90000000000000000000000000000000400000000000000010000000000000001 val1c 2 2 2 [219,19] [219,19] 0000000100000000000000020000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000001 val1d 0 0 NULL NULL NULL NULL val1d 0 0 NULL NULL NULL NULL val1d 0 0 NULL NULL NULL NULL val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000 val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000 val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000 ``` ### Does this PR introduce _any_ user-facing change? Bugfix ### How was this patch tested? UT Closes #31973 from tanelk/SPARK-34876_non_nullable_agg_subquery. Authored-by: Tanel Kiis <tanel.k...@gmail.com> Signed-off-by: HyukjinKwon <gurwls...@apache.org> (cherry picked from commit 4b9e94c44412f399ba19e0ea90525d346942bf71) Signed-off-by: HyukjinKwon <gurwls...@apache.org> --- .../expressions/aggregate/CountMinSketchAgg.scala | 5 +++- .../aggregate/HyperLogLogPlusPlus.scala | 2 ++ .../catalyst/expressions/aggregate/collect.scala | 2 ++ .../expressions/aggregate/interfaces.scala | 3 +-- .../scalar-subquery/scalar-subquery-select.sql | 10 ++++++++ .../scalar-subquery/scalar-subquery-select.sql.out | 28 +++++++++++++++++++++- 6 files changed, 46 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala index dae88c7..b756de3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} -import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription} +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.sketch.CountMinSketch @@ -133,6 +133,9 @@ case class CountMinSketchAgg( override def dataType: DataType = BinaryType + override def defaultResult: Option[Literal] = + Option(Literal.create(eval(createAggregationBuffer()), dataType)) + override def children: Seq[Expression] = Seq(child, epsExpression, confidenceExpression, seedExpression) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala index 0b73788..85435f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala @@ -84,6 +84,8 @@ case class HyperLogLogPlusPlus( override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + override def defaultResult: Option[Literal] = Option(Literal.create(0L, dataType)) + val hllppHelper = new HyperLogLogPlusPlusHelper(relativeSD) /** Allocate enough words to store all registers. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 8dc3171..fed6246 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -47,6 +47,8 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper // actual order of input rows. override lazy val deterministic: Boolean = false + override def defaultResult: Option[Literal] = Option(Literal.create(Array(), dataType)) + protected def convertToBufferElement(value: Any): Any override def update(buffer: T, input: InternalRow): T = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 56c2ee6..ef68587 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -182,8 +182,7 @@ abstract class AggregateFunction extends Expression { def inputAggBufferAttributes: Seq[AttributeReference] /** - * Result of the aggregate function when the input is empty. This is currently only used for the - * proper rewriting of distinct aggregate functions. + * Result of the aggregate function when the input is empty. */ def defaultResult: Option[Literal] = None diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql index eabbd0a..81712bf 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql @@ -128,3 +128,13 @@ WHERE NOT EXISTS (SELECT (SELECT max(t2b) ON t2a = t1a WHERE t2c = t3c) AND t3a = t1a); + +-- SPARK-34876: Non-nullable aggregates should not return NULL in a correlated subquery +SELECT t1a, + (SELECT count(t2d) FROM t2 WHERE t2a = t1a) count_t2, + (SELECT count_if(t2d > 0) FROM t2 WHERE t2a = t1a) count_if_t2, + (SELECT approx_count_distinct(t2d) FROM t2 WHERE t2a = t1a) approx_count_distinct_t2, + (SELECT collect_list(t2d) FROM t2 WHERE t2a = t1a) collect_list_t2, + (SELECT collect_set(t2d) FROM t2 WHERE t2a = t1a) collect_set_t2, + (SELECT hex(count_min_sketch(t2d, 0.5d, 0.5d, 1)) FROM t2 WHERE t2a = t1a) collect_set_t2 +FROM t1; \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out index 807bb47..38b7e8b 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 11 +-- Number of queries: 12 -- !query 0 @@ -196,3 +196,29 @@ val1d NULL val1e 10 val1e 10 val1e 10 + + +-- !query +SELECT t1a, + (SELECT count(t2d) FROM t2 WHERE t2a = t1a) count_t2, + (SELECT count_if(t2d > 0) FROM t2 WHERE t2a = t1a) count_if_t2, + (SELECT approx_count_distinct(t2d) FROM t2 WHERE t2a = t1a) approx_count_distinct_t2, + (SELECT collect_list(t2d) FROM t2 WHERE t2a = t1a) collect_list_t2, + (SELECT collect_set(t2d) FROM t2 WHERE t2a = t1a) collect_set_t2, + (SELECT hex(count_min_sketch(t2d, 0.5d, 0.5d, 1)) FROM t2 WHERE t2a = t1a) collect_set_t2 +FROM t1 +-- !query schema +struct<t1a:string,count_t2:bigint,count_if_t2:bigint,approx_count_distinct_t2:bigint,collect_list_t2:array<bigint>,collect_set_t2:array<bigint>,collect_set_t2:string> +-- !query output +val1a 0 0 0 [] [] 0000000100000000000000000000000100000004000000005D8D6AB90000000000000000000000000000000000000000000000000000000000000000 +val1a 0 0 0 [] [] 0000000100000000000000000000000100000004000000005D8D6AB90000000000000000000000000000000000000000000000000000000000000000 +val1a 0 0 0 [] [] 0000000100000000000000000000000100000004000000005D8D6AB90000000000000000000000000000000000000000000000000000000000000000 +val1a 0 0 0 [] [] 0000000100000000000000000000000100000004000000005D8D6AB90000000000000000000000000000000000000000000000000000000000000000 +val1b 6 6 3 [19,119,319,19,19,19] [19,119,319] 0000000100000000000000060000000100000004000000005D8D6AB90000000000000000000000000000000400000000000000010000000000000001 +val1c 2 2 2 [219,19] [219,19] 0000000100000000000000020000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000001 +val1d 0 0 0 [] [] 0000000100000000000000000000000100000004000000005D8D6AB90000000000000000000000000000000000000000000000000000000000000000 +val1d 0 0 0 [] [] 0000000100000000000000000000000100000004000000005D8D6AB90000000000000000000000000000000000000000000000000000000000000000 +val1d 0 0 0 [] [] 0000000100000000000000000000000100000004000000005D8D6AB90000000000000000000000000000000000000000000000000000000000000000 +val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000 +val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000 +val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000 --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org