This is an automated email from the ASF dual-hosted git repository.
yao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 04b821c69e85 [SPARK-55256][SQL] Support IGNORE NULLS / RESPECT NULLS for array_agg and collect_list
04b821c69e85 is described below
commit 04b821c69e85be5f51a1270b3a9a4155afdb5334
Author: Kent Yao <[email protected]>
AuthorDate: Fri Jan 30 22:33:51 2026 +0800
[SPARK-55256][SQL] Support IGNORE NULLS / RESPECT NULLS for array_agg and collect_list
### What changes were proposed in this pull request?
This PR adds support for the IGNORE NULLS and RESPECT NULLS clauses for
`array_agg` and `collect_list` aggregate functions.
### Why are the changes needed?
The SQL standard and many databases (PostgreSQL, Snowflake, DuckDB, etc.)
support the IGNORE NULLS / RESPECT NULLS syntax for aggregate functions.
Currently, Spark only supports this syntax for window functions like `first`,
`last`, `lead`, `lag`, and `nth_value`.
By adding this support to `array_agg` and `collect_list`, users can
explicitly control whether null values are included in the resulting
array (see the example after this list):
- `array_agg(col) IGNORE NULLS` - skips null values (default behavior)
- `array_agg(col) RESPECT NULLS` - includes null values in the result
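A minimal spark-shell illustration (assumes the implicit `spark` session and example data; the expected arrays mirror the tests added by this PR):
```scala
import spark.implicits._

// Example data with one null value
Seq(Some(1), None, Some(2)).toDF("c").createOrReplaceTempView("t")

spark.sql("SELECT array_agg(c) IGNORE NULLS FROM t").show()  // array: [1, 2]
spark.sql("SELECT array_agg(c) RESPECT NULLS FROM t").show() // array: [1, null, 2]
```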
### Implementation Details
1. Added `ignoreNulls: Boolean = true` parameter to `CollectList` class
2. `array_agg` now uses `CollectList` as they have identical behavior
3. Changed `UnresolvedFunction.ignoreNulls` from `Boolean` to
`Option[Boolean]` to distinguish between:
- `None`: no clause specified (use function's default)
- `Some(true)`: IGNORE NULLS explicitly specified
- `Some(false)`: RESPECT NULLS explicitly specified
4. Consolidated ignoreNulls resolution logic in `FunctionResolution` with
shared `resolveIgnoreNulls` and `applyIgnoreNulls` methods, sketched below
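A minimal sketch of the resulting three-state logic, with hypothetical simplified signatures (the real methods pattern-match on Catalyst expressions, as the diff below shows):
```scala
// Sketch only: the parser produces Option[Boolean] for the optional clause,
// and resolution falls back to the function's own default when it is None.
def effectiveIgnoreNulls(clause: Option[Boolean], functionDefault: Boolean): Boolean =
  clause.getOrElse(functionDefault)

effectiveIgnoreNulls(None, functionDefault = true)        // no clause     -> true (default)
effectiveIgnoreNulls(Some(true), functionDefault = true)  // IGNORE NULLS  -> true
effectiveIgnoreNulls(Some(false), functionDefault = true) // RESPECT NULLS -> false
```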
### Does this PR introduce _any_ user-facing change?
Yes. Users can now use IGNORE NULLS / RESPECT NULLS with array_agg and
collect_list:
```sql
SELECT array_agg(col) IGNORE NULLS FROM table;
SELECT collect_list(col) RESPECT NULLS OVER (PARTITION BY id) FROM table;
```
### How was this patch tested?
Added unit tests in `DataFrameAggregateSuite`:
- `array_agg and collect_list skip nulls by default`
- `array_agg with IGNORE NULLS explicitly skips nulls`
- `array_agg with RESPECT NULLS preserves nulls`
### Was this patch authored or co-authored using generative AI tooling?
Yes, GitHub Copilot was used to assist with this implementation.
Closes #54034 from yaooqinn/SPARK-55256-ignore-nulls-array-agg.
Authored-by: Kent Yao <[email protected]>
Signed-off-by: Kent Yao <[email protected]>
---
.../sql/catalyst/analysis/FunctionResolution.scala | 83 ++++++++++++----------
.../spark/sql/catalyst/analysis/unresolved.scala | 2 +-
.../catalyst/expressions/aggregate/collect.scala | 38 ++++++++--
.../spark/sql/catalyst/expressions/literals.scala | 2 +-
.../spark/sql/catalyst/parser/AstBuilder.scala | 2 +-
.../explain-results/function_array_agg.explain | 2 +-
.../explain-results/function_collect_list.explain | 2 +-
.../spark/sql/classic/DataFrameWriterV2.scala | 2 +-
.../generators-resolution-edge-cases.sql.out | 6 +-
.../sql-tests/analyzer-results/group-by.sql.out | 4 +-
.../scalar-subquery/scalar-subquery-select.sql.out | 2 +-
.../analyzer-results/udf/udf-window.sql.out | 2 +-
.../sql-tests/analyzer-results/window.sql.out | 2 +-
.../apache/spark/sql/DataFrameAggregateSuite.scala | 25 +++++++
14 files changed, 119 insertions(+), 55 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala
index 29f4db65def0..63a75a8aa2b8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala
@@ -152,18 +152,8 @@ class FunctionResolution(
wf.prettyName,
"FILTER clause"
)
- } else if (u.ignoreNulls) {
- wf match {
- case nthValue: NthValue =>
- nthValue.copy(ignoreNulls = u.ignoreNulls)
- case _ =>
- throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
- wf.prettyName,
- "IGNORE NULLS"
- )
- }
} else {
- wf
+ resolveIgnoreNulls(wf, u.ignoreNulls)
}
case owf: FrameLessOffsetWindowFunction =>
if (u.isDistinct) {
@@ -176,15 +166,8 @@ class FunctionResolution(
owf.prettyName,
"FILTER clause"
)
- } else if (u.ignoreNulls) {
- owf match {
- case lead: Lead =>
- lead.copy(ignoreNulls = u.ignoreNulls)
- case lag: Lag =>
- lag.copy(ignoreNulls = u.ignoreNulls)
- }
} else {
- owf
+ resolveIgnoreNulls(owf, u.ignoreNulls)
}
// We get an aggregate function, we need to wrap it in an AggregateExpression.
case agg: AggregateFunction =>
@@ -216,21 +199,8 @@ class FunctionResolution(
)
case _ =>
}
- if (u.ignoreNulls) {
- val aggFunc = newAgg match {
- case first: First => first.copy(ignoreNulls = u.ignoreNulls)
- case last: Last => last.copy(ignoreNulls = u.ignoreNulls)
- case any_value: AnyValue => any_value.copy(ignoreNulls = u.ignoreNulls)
- case _ =>
- throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
- newAgg.prettyName,
- "IGNORE NULLS"
- )
- }
- aggFunc.toAggregateExpression(u.isDistinct, u.filter)
- } else {
- newAgg.toAggregateExpression(u.isDistinct, u.filter)
- }
+ val aggFunc = resolveIgnoreNulls(newAgg, u.ignoreNulls)
+ aggFunc.toAggregateExpression(u.isDistinct, u.filter)
// This function is not an aggregate function, just return the resolved one.
case other =>
checkUnsupportedAggregateClause(other, u)
@@ -258,7 +228,8 @@ class FunctionResolution(
"FILTER clause"
)
}
- if (u.ignoreNulls) {
+ // Only fail for IGNORE NULLS; RESPECT NULLS is the default behavior
+ if (u.ignoreNulls.contains(true)) {
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
func.prettyName,
"IGNORE NULLS"
@@ -266,6 +237,42 @@ class FunctionResolution(
}
}
+ /**
+ * Resolves the IGNORE NULLS / RESPECT NULLS clause for a function.
+ * If ignoreNulls is defined, applies it to the function; otherwise returns unchanged.
+ */
+ private def resolveIgnoreNulls[T <: Expression](func: T, ignoreNulls: Option[Boolean]): T = {
+ ignoreNulls.map(applyIgnoreNulls(func, _)).getOrElse(func)
+ }
+
+ /**
+ * Applies the IGNORE NULLS / RESPECT NULLS clause to functions that support it.
+ * Returns the modified function if supported, throws error otherwise.
+ */
+ private def applyIgnoreNulls[T <: Expression](func: T, ignoreNulls: Boolean): T = {
+ val result = func match {
+ // Window functions
+ case nthValue: NthValue => nthValue.copy(ignoreNulls = ignoreNulls)
+ case lead: Lead => lead.copy(ignoreNulls = ignoreNulls)
+ case lag: Lag => lag.copy(ignoreNulls = ignoreNulls)
+ // Aggregate functions
+ case first: First => first.copy(ignoreNulls = ignoreNulls)
+ case last: Last => last.copy(ignoreNulls = ignoreNulls)
+ case anyValue: AnyValue => anyValue.copy(ignoreNulls = ignoreNulls)
+ case collectList: CollectList => collectList.copy(ignoreNulls = ignoreNulls)
+ case _ if ignoreNulls =>
+ // Only fail for IGNORE NULLS; RESPECT NULLS is the default behavior
+ throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
+ func.prettyName,
+ "IGNORE NULLS"
+ )
+ case _ =>
+ // RESPECT NULLS is the default, silently return unchanged
+ func
+ }
+ result.asInstanceOf[T]
+ }
+
private def resolveV2Function(
unbound: UnboundFunction,
arguments: Seq[Expression],
@@ -312,7 +319,8 @@ class FunctionResolution(
scalarFunc.name(),
"FILTER clause"
)
- } else if (u.ignoreNulls) {
+ } else if (u.ignoreNulls.contains(true)) {
+ // Only fail for IGNORE NULLS; RESPECT NULLS is the default behavior
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
scalarFunc.name(),
"IGNORE NULLS"
@@ -326,7 +334,8 @@ class FunctionResolution(
aggFunc: V2AggregateFunction[_, _],
arguments: Seq[Expression],
u: UnresolvedFunction): Expression = {
- if (u.ignoreNulls) {
+ // Only fail for IGNORE NULLS; RESPECT NULLS is the default behavior
+ if (u.ignoreNulls.contains(true)) {
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
aggFunc.name(),
"IGNORE NULLS"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 3bdeb6d71884..fffbc7511a1d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -369,7 +369,7 @@ case class UnresolvedFunction(
arguments: Seq[Expression],
isDistinct: Boolean,
filter: Option[Expression] = None,
- ignoreNulls: Boolean = false,
+ ignoreNulls: Option[Boolean] = None,
orderingWithinGroup: Seq[SortOrder] = Seq.empty,
isInternal: Boolean = false)
extends Expression with Unevaluable {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index 015bd1e3e142..29163d08297d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -48,12 +48,17 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper
override def nullable: Boolean = false
- override def dataType: DataType = ArrayType(child.dataType, false)
+ // Subclasses can override bufferContainsNull to indicate if the result array contains nulls
+ override def dataType: DataType = ArrayType(child.dataType, bufferContainsNull)
override def defaultResult: Option[Literal] = Option(Literal.create(Array(), dataType))
protected def convertToBufferElement(value: Any): Any
+ // Subclasses can override this to allow nulls in buffer
+ // (e.g., CollectList with ignoreNulls=false)
+ protected def bufferContainsNull: Boolean = false
+
override def update(buffer: T, input: InternalRow): T = {
val value = child.eval(input)
@@ -72,7 +77,7 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper
protected val bufferElementType: DataType
private lazy val projection = UnsafeProjection.create(
- Array[DataType](ArrayType(elementType = bufferElementType, containsNull = false)))
+ Array[DataType](ArrayType(elementType = bufferElementType, containsNull = bufferContainsNull)))
private lazy val row = new UnsafeRow(1)
override def serialize(obj: T): Array[Byte] = {
@@ -90,6 +95,9 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper
/**
* Collect a list of elements.
+ *
+ * @param ignoreNulls when true (IGNORE NULLS), null values are excluded from the result array.
+ * When false (RESPECT NULLS), null values are included in the result array.
*/
@ExpressionDescription(
usage = "_FUNC_(expr) - Collects and returns a list of non-unique elements.",
@@ -107,15 +115,32 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper
case class CollectList(
child: Expression,
mutableAggBufferOffset: Int = 0,
- inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]]
+ inputAggBufferOffset: Int = 0,
+ ignoreNulls: Boolean = true) extends Collect[mutable.ArrayBuffer[Any]]
with UnaryLike[Expression] {
- def this(child: Expression) = this(child, 0, 0)
+ def this(child: Expression) = this(child, 0, 0, true)
+
+ // Buffer can contain nulls when ignoreNulls is false (RESPECT NULLS)
+ override protected def bufferContainsNull: Boolean = !ignoreNulls
override lazy val bufferElementType = child.dataType
override def convertToBufferElement(value: Any): Any =
InternalRow.copyValue(value)
+ override def update(
+ buffer: mutable.ArrayBuffer[Any],
+ input: InternalRow): mutable.ArrayBuffer[Any] = {
+ val value = child.eval(input)
+ if (value != null) {
+ buffer += convertToBufferElement(value)
+ } else if (!ignoreNulls) {
+ // RESPECT NULLS: preserve null values in result
+ buffer += null
+ }
+ buffer
+ }
+
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
@@ -130,6 +155,11 @@ case class CollectList(
new GenericArrayData(buffer.toArray)
}
+ override def toString: String = {
+ val ignoreNullsStr = if (ignoreNulls) "" else " respect nulls"
+ s"$prettyName($child)$ignoreNullsStr"
+ }
+
override protected def withNewChildInternal(newChild: Expression): CollectList =
copy(child = newChild)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 94dbce19d4f0..6448194f9705 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -281,7 +281,7 @@ object Literal {
assert(u.nameParts.length == 1)
assert(!u.isDistinct)
assert(u.filter.isEmpty)
- assert(!u.ignoreNulls)
+ assert(u.ignoreNulls.isEmpty)
assert(u.orderingWithinGroup.isEmpty)
assert(!u.isInternal)
FunctionRegistry.builtin.lookupFunction(FunctionIdentifier(u.nameParts.head), u.arguments)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index fbecd1933ec5..0e9844d7f1a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -3253,7 +3253,7 @@ class AstBuilder extends DataTypeAstBuilder
val order = ctx.sortItem.asScala.map(visitSortItem)
val filter = Option(ctx.where).map(expression(_))
val ignoreNulls =
- Option(ctx.nullsOption).map(_.getType == SqlBaseParser.IGNORE).getOrElse(false)
+ Option(ctx.nullsOption).map(_.getType == SqlBaseParser.IGNORE)
// Is this an IDENTIFIER clause instead of a function call?
if (ctx.functionName.identFunc != null &&
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_array_agg.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_array_agg.explain
index 102f736c62ef..27da00bd0c86 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_array_agg.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_array_agg.explain
@@ -1,2 +1,2 @@
-Aggregate [collect_list(a#0, 0, 0) AS collect_list(a)#0]
+Aggregate [collect_list(a#0, 0, 0, true) AS collect_list(a)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_collect_list.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_collect_list.explain
index 102f736c62ef..27da00bd0c86 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_collect_list.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_collect_list.explain
@@ -1,2 +1,2 @@
-Aggregate [collect_list(a#0, 0, 0) AS collect_list(a)#0]
+Aggregate [collect_list(a#0, 0, 0, true) AS collect_list(a)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala
index 19cc9b76beae..1309b346ffa5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala
@@ -253,7 +253,7 @@ private object PartitionTransform {
private val NAMES = Seq(name)
def unapply(e: Expression): Option[Seq[Expression]] = e match {
- case UnresolvedFunction(NAMES, children, false, None, false, Nil, true) => Option(children)
+ case UnresolvedFunction(NAMES, children, false, None, None, Nil, true) => Option(children)
case _ => None
}
}
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/generators-resolution-edge-cases.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/generators-resolution-edge-cases.sql.out
index 2cbb1fb5d382..5a87586b150a 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/generators-resolution-edge-cases.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/generators-resolution-edge-cases.sql.out
@@ -71,7 +71,7 @@ SELECT explode(collect_list(id)) AS val FROM range(5)
-- !query analysis
Project [val#xL]
+- Generate explode(_gen_input_0#x), false, [val#xL]
- +- Aggregate [collect_list(id#xL, 0, 0) AS _gen_input_0#x]
+ +- Aggregate [collect_list(id#xL, 0, 0, true) AS _gen_input_0#x]
+- Range (0, 5, step=1)
@@ -80,7 +80,7 @@ SELECT explode(collect_list(id)) AS val, count(*) AS cnt FROM range(3)
-- !query analysis
Project [val#xL, cnt#xL]
+- Generate explode(_gen_input_0#x), false, [val#xL]
- +- Aggregate [collect_list(id#xL, 0, 0) AS _gen_input_0#x, count(1) AS cnt#xL]
+ +- Aggregate [collect_list(id#xL, 0, 0, true) AS _gen_input_0#x, count(1) AS cnt#xL]
+- Range (0, 3, step=1)
@@ -204,7 +204,7 @@ GROUP BY ALL
-- !query analysis
Project [val#x]
+- Generate explode(_gen_input_0#x), false, [val#x]
- +- Aggregate [collect_list(a#x, 0, 0) AS _gen_input_0#x]
+ +- Aggregate [collect_list(a#x, 0, 0, true) AS _gen_input_0#x]
+- SubqueryAlias t
+- Project [col1#x AS a#x]
+- LocalRelation [col1#x]
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out
index a2d3f4cdb016..80d63901af4c 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out
@@ -1044,7 +1044,7 @@ SELECT
FROM VALUES
(1), (2), (1) AS tab(col)
-- !query analysis
-Aggregate [collect_list(col#x, 0, 0) AS collect_list(col)#x, collect_list(col#x, 0, 0) AS collect_list(col)#x]
+Aggregate [collect_list(col#x, 0, 0, true) AS collect_list(col)#x, collect_list(col#x, 0, 0, true) AS collect_list(col)#x]
+- SubqueryAlias tab
+- LocalRelation [col#x]
@@ -1058,7 +1058,7 @@ FROM VALUES
(1,4),(2,3),(1,4),(2,4) AS v(a,b)
GROUP BY a
-- !query analysis
-Aggregate [a#x], [a#x, collect_list(b#x, 0, 0) AS collect_list(b)#x, collect_list(b#x, 0, 0) AS collect_list(b)#x]
+Aggregate [a#x], [a#x, collect_list(b#x, 0, 0, true) AS collect_list(b)#x, collect_list(b#x, 0, 0, true) AS collect_list(b)#x]
+- SubqueryAlias v
+- LocalRelation [a#x, b#x]
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-select.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-select.sql.out
index f64b3736b552..9f3552e6d6e2 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-select.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-select.sql.out
@@ -480,7 +480,7 @@ Project [t1a#x, scalar-subquery#x [t1a#x] AS count_t2#xL, scalar-subquery#x [t1a
: : +- Project [t2a#x, t2b#x, t2c#x, t2d#xL, t2e#x, t2f#x, t2g#x, t2h#x, t2i#x]
: : +- SubqueryAlias t2
: : +- LocalRelation [t2a#x, t2b#x, t2c#x, t2d#xL, t2e#x, t2f#x, t2g#x, t2h#x, t2i#x]
-: :- Aggregate [collect_list(t2d#xL, 0, 0) AS collect_list(t2d)#x]
+: :- Aggregate [collect_list(t2d#xL, 0, 0, true) AS collect_list(t2d)#x]
: : +- Filter (t2a#x = outer(t1a#x))
: : +- SubqueryAlias t2
: : +- View (`t2`, [t2a#x, t2b#x, t2c#x, t2d#xL, t2e#x, t2f#x, t2g#x, t2h#x, t2i#x])
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-window.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-window.sql.out
index aa3f40b62ccd..11164ececc93 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-window.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-window.sql.out
@@ -389,7 +389,7 @@ Project [udf(val)#x, cate#x, max#x, min#x, min#x, count#xL, sum#xL, avg#x, stdde
+- Sort [cate#x ASC NULLS FIRST, cast(udf(cast(val#x as string)) as int) ASC NULLS FIRST], true
+- Project [udf(val)#x, cate#x, max#x, min#x, min#x, count#xL, sum#xL, avg#x, stddev#x, first_value#x, first_value_ignore_null#x, first_value_contain_null#x, any_value#x, any_value_ignore_null#x, any_value_contain_null#x, last_value#x, last_value_ignore_null#x, last_value_contain_null#x, rank#x, dense_rank#x, cume_dist#x, percent_rank#x, ntile#x, row_number#x, var_pop#x, var_samp#x, approx_count_distinct#xL, covar_pop#x, corr#x, stddev_samp#x, stddev_pop#x, collect_list#x, collect_set [...]
+- Project [udf(val)#x, cate#x, _w0#x, _w1#x, _w2#x, _w3#x, _w4#x, max#x, min#x, min#x, count#xL, sum#xL, avg#x, stddev#x, first_value#x, first_value_ignore_null#x, first_value_contain_null#x, any_value#x, any_value_ignore_null#x, any_value_contain_null#x, last_value#x, last_value_ignore_null#x, last_value_contain_null#x, rank#x, dense_rank#x, cume_dist#x, percent_rank#x, ntile#x, row_number#x, var_pop#x, var_samp#x, approx_count_distinct#xL, covar_pop#x, corr#x, stddev_samp#x, std [...]
- +- Window [max(_w0#x) windowspecdefinition(_w1#x, _w0#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS max#x, min(_w0#x) windowspecdefinition(_w1#x, _w0#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, min(_w0#x) windowspecdefinition(_w1#x, _w0#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, count(_w0#x) windowspecdefinition(_w1#x, [...]
+ +- Window [max(_w0#x) windowspecdefinition(_w1#x, _w0#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS max#x, min(_w0#x) windowspecdefinition(_w1#x, _w0#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, min(_w0#x) windowspecdefinition(_w1#x, _w0#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, count(_w0#x) windowspecdefinition(_w1#x, [...]
+- Project [cast(udf(cast(val#x as string)) as int) AS udf(val)#x, cate#x, cast(udf(cast(val#x as string)) as int) AS _w0#x, cast(udf(cast(cate#x as string)) as string) AS _w1#x, cast(cast(udf(cast(val#x as string)) as int) as double) AS _w2#x, cast(cast(udf(cast(val_long#xL as string)) as bigint) as double) AS _w3#x, cast(udf(cast(val_double#x as string)) as double) AS _w4#x, val#x]
+- SubqueryAlias testdata
+- View (`testData`, [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x])
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/window.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/window.sql.out
index 00cf492574cf..b585e01a75de 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/window.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/window.sql.out
@@ -585,7 +585,7 @@ ORDER BY cate, val
Sort [cate#x ASC NULLS FIRST, val#x ASC NULLS FIRST], true
+- Project [val#x, cate#x, max#x, min#x, min#x, count#xL, sum#xL, avg#x, stddev#x, first_value#x, first_value_ignore_null#x, first_value_contain_null#x, any_value#x, any_value_ignore_null#x, any_value_contain_null#x, last_value#x, last_value_ignore_null#x, last_value_contain_null#x, rank#x, dense_rank#x, cume_dist#x, percent_rank#x, ntile#x, row_number#x, var_pop#x, var_samp#x, approx_count_distinct#xL, covar_pop#x, corr#x, stddev_samp#x, stddev_pop#x, collect_list#x, collect_set#x, skew [...]
+- Project [val#x, cate#x, _w0#x, _w1#x, val_double#x, max#x, min#x, min#x, count#xL, sum#xL, avg#x, stddev#x, first_value#x, first_value_ignore_null#x, first_value_contain_null#x, any_value#x, any_value_ignore_null#x, any_value_contain_null#x, last_value#x, last_value_ignore_null#x, last_value_contain_null#x, rank#x, dense_rank#x, cume_dist#x, percent_rank#x, ntile#x, row_number#x, var_pop#x, var_samp#x, approx_count_distinct#xL, covar_pop#x, corr#x, stddev_samp#x, stddev_pop#x, coll [...]
- +- Window [max(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS max#x, min(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, min(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, count(val#x) windowspecdefinition(cate#x [...]
+ +- Window [max(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS max#x, min(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, min(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, count(val#x) windowspecdefinition(cate#x [...]
+- Project [val#x, cate#x, cast(val#x as double) AS _w0#x, cast(val_long#xL as double) AS _w1#x, val_double#x]
+- SubqueryAlias testdata
+- View (`testData`, [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x])
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 1a56b01a851d..0a9e08f5c57e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -610,6 +610,31 @@ class DataFrameAggregateSuite extends QueryTest
)
}
+ test("SPARK-55256: array_agg and collect_list skip nulls by default") {
+ val df = Seq((1, Some(2)), (2, None), (3, Some(4))).toDF("a", "b")
+
+ // Both functions skip nulls by default
+ checkAnswer(df.selectExpr("array_agg(b)"), Seq(Row(Seq(2, 4))))
+ checkAnswer(df.select(array_agg($"b")), Seq(Row(Seq(2, 4))))
+ checkAnswer(df.selectExpr("collect_list(b)"), Seq(Row(Seq(2, 4))))
+ checkAnswer(df.select(collect_list($"b")), Seq(Row(Seq(2, 4))))
+ }
+
+ test("SPARK-55256: array_agg with IGNORE NULLS explicitly skips nulls") {
+ val df = Seq((1, Some(2)), (2, None), (3, Some(4))).toDF("a", "b")
+
+ checkAnswer(df.selectExpr("array_agg(b) IGNORE NULLS"), Seq(Row(Seq(2,
4))))
+ checkAnswer(df.selectExpr("collect_list(b) IGNORE NULLS"), Seq(Row(Seq(2,
4))))
+ }
+
+ test("SPARK-55256: array_agg with RESPECT NULLS preserves nulls") {
+ val df = Seq((1, Some(2)), (2, None), (3, Some(4))).toDF("a", "b")
+
+ // RESPECT NULLS preserves null values in the result
+ checkAnswer(df.selectExpr("array_agg(b) RESPECT NULLS"), Seq(Row(Seq(2,
null, 4))))
+ checkAnswer(df.selectExpr("collect_list(b) RESPECT NULLS"), Seq(Row(Seq(2,
null, 4))))
+ }
+
test("collect functions structs") {
val df = Seq((1, 2, 2), (2, 2, 2), (3, 4, 1))
.toDF("a", "x", "y")
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]