This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 1469e0eea526 [SPARK-55533][SQL] Support IGNORE NULLS / RESPECT NULLS for collect_set
1469e0eea526 is described below
commit 1469e0eea526bea54acd59684cdf5615af8d71da
Author: Kent Yao <[email protected]>
AuthorDate: Sun Feb 15 11:16:47 2026 -0800
[SPARK-55533][SQL] Support IGNORE NULLS / RESPECT NULLS for collect_set
### What changes were proposed in this pull request?
This PR adds `IGNORE NULLS` / `RESPECT NULLS` support to `collect_set`,
mirroring the existing `collect_list` behavior added in SPARK-55256.
- `collect_set(expr)` — default, skips nulls (unchanged behavior)
- `collect_set(expr) IGNORE NULLS` — explicitly skips nulls
- `collect_set(expr) RESPECT NULLS` — includes null in the result set
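For illustration, a minimal sketch of the three variants, mirroring the new tests below (assumes a SparkSession with `import spark.implicits._` in scope; data is hypothetical):

```scala
// Hypothetical data: column b holds 2, null, 2.
val df = Seq((1, Some(2)), (2, None), (3, Some(2))).toDF("a", "b")

df.selectExpr("sort_array(collect_set(b))")               // [2]       default: nulls skipped
df.selectExpr("sort_array(collect_set(b) IGNORE NULLS)")  // [2]       explicit form of the default
df.selectExpr("sort_array(collect_set(b) RESPECT NULLS)") // [null, 2] null is kept in the set
```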
### Why are the changes needed?
For consistency: `collect_list`/`array_agg` already support this syntax,
but `collect_set` does not. Users who want to include null in a collected
set currently have no way to do so.
### Does this PR introduce _any_ user-facing change?
Yes. `collect_set` now accepts `IGNORE NULLS` and `RESPECT NULLS` clauses
in SQL.
### How was this patch tested?
Added 3 new tests in `DataFrameAggregateSuite`:
- `collect_set skips nulls by default`
- `collect_set with IGNORE NULLS explicitly skips nulls`
- `collect_set with RESPECT NULLS preserves null in set`
### Was this patch authored or co-authored using generative AI tooling?
Yes, co-authored with GitHub Copilot.
Closes #54329 from yaooqinn/SPARK-55533.
Authored-by: Kent Yao <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../sql/catalyst/analysis/FunctionResolution.scala | 1 +
.../catalyst/expressions/aggregate/collect.scala | 34 ++++++++++++++++++++--
.../explain-results/function_collect_set.explain | 2 +-
.../scalar-subquery/scalar-subquery-select.sql.out | 2 +-
.../analyzer-results/udf/udf-window.sql.out | 2 +-
.../sql-tests/analyzer-results/window.sql.out | 2 +-
.../sql-tests/results/explain-aqe.sql.out | 8 ++---
.../resources/sql-tests/results/explain.sql.out | 8 ++---
.../apache/spark/sql/DataFrameAggregateSuite.scala | 22 ++++++++++++++
9 files changed, 66 insertions(+), 15 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala
index 63a75a8aa2b8..19638a8e5966 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala
@@ -260,6 +260,7 @@ class FunctionResolution(
case last: Last => last.copy(ignoreNulls = ignoreNulls)
case anyValue: AnyValue => anyValue.copy(ignoreNulls = ignoreNulls)
case collectList: CollectList => collectList.copy(ignoreNulls = ignoreNulls)
+ case collectSet: CollectSet => collectSet.copy(ignoreNulls = ignoreNulls)
case _ if ignoreNulls =>
// Only fail for IGNORE NULLS; RESPECT NULLS is the default behavior
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index 29163d08297d..a864bf108436 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -166,6 +166,9 @@ case class CollectList(
/**
* Collect a set of unique elements.
+ *
+ * @param ignoreNulls when true (IGNORE NULLS), null values are excluded from the result array.
+ *                    When false (RESPECT NULLS), null values are included in the result array.
*/
@ExpressionDescription(
usage = "_FUNC_(expr) - Collects and returns a set of unique elements.",
@@ -184,16 +187,33 @@ case class CollectList(
case class CollectSet(
child: Expression,
mutableAggBufferOffset: Int = 0,
- inputAggBufferOffset: Int = 0)
+ inputAggBufferOffset: Int = 0,
+ ignoreNulls: Boolean = true)
extends Collect[mutable.HashSet[Any]] with QueryErrorsBase with UnaryLike[Expression] {
- def this(child: Expression) = this(child, 0, 0)
+ def this(child: Expression) = this(child, 0, 0, true)
+
+ // Buffer can contain nulls when ignoreNulls is false (RESPECT NULLS)
+ override protected def bufferContainsNull: Boolean = !ignoreNulls
override lazy val bufferElementType = child.dataType match {
case BinaryType => ArrayType(ByteType)
case other => other
}
+ override def update(
+ buffer: mutable.HashSet[Any],
+ input: InternalRow): mutable.HashSet[Any] = {
+ val value = child.eval(input)
+ if (value != null) {
+ buffer += convertToBufferElement(value)
+ } else if (!ignoreNulls) {
+ // RESPECT NULLS: preserve null value in result
+ buffer += null
+ }
+ buffer
+ }
+
override def convertToBufferElement(value: Any): Any = child.dataType match {
/*
* collect_set() of BinaryType should not return duplicate elements,
@@ -207,7 +227,10 @@ case class CollectSet(
override def eval(buffer: mutable.HashSet[Any]): Any = {
val array = child.dataType match {
case BinaryType =>
- buffer.iterator.map(_.asInstanceOf[ArrayData].toByteArray()).toArray[Any]
+ buffer.iterator.map {
+ case null => null
+ case v => v.asInstanceOf[ArrayData].toByteArray()
+ }.toArray[Any]
case _ => buffer.toArray
}
new GenericArrayData(array)
@@ -238,6 +261,11 @@ case class CollectSet(
override def createAggregationBuffer(): mutable.HashSet[Any] = mutable.HashSet.empty
+ override def toString: String = {
+ val ignoreNullsStr = if (ignoreNulls) "" else " respect nulls"
+ s"$prettyName($child)$ignoreNullsStr"
+ }
+
override protected def withNewChildInternal(newChild: Expression): CollectSet =
copy(child = newChild)
}
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_collect_set.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_collect_set.explain
index 18246a74ccc9..ca3cd174afc6 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_collect_set.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_collect_set.explain
@@ -1,2 +1,2 @@
-Aggregate [collect_set(a#0, 0, 0) AS collect_set(a)#0]
+Aggregate [collect_set(a#0, 0, 0, true) AS collect_set(a)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-select.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-select.sql.out
index 9f3552e6d6e2..44512caf1def 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-select.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-select.sql.out
@@ -488,7 +488,7 @@ Project [t1a#x, scalar-subquery#x [t1a#x] AS count_t2#xL, scalar-subquery#x [t1a
: : +- Project [t2a#x, t2b#x, t2c#x, t2d#xL, t2e#x, t2f#x, t2g#x, t2h#x, t2i#x]
: : +- SubqueryAlias t2
: : +- LocalRelation [t2a#x, t2b#x, t2c#x, t2d#xL, t2e#x, t2f#x, t2g#x, t2h#x, t2i#x]
-: :- Aggregate [sort_array(collect_set(t2d#xL, 0, 0), true) AS sort_array(collect_set(t2d), true)#x]
+: :- Aggregate [sort_array(collect_set(t2d#xL, 0, 0, true), true) AS sort_array(collect_set(t2d), true)#x]
: : +- Filter (t2a#x = outer(t1a#x))
: : +- SubqueryAlias t2
: : +- View (`t2`, [t2a#x, t2b#x, t2c#x, t2d#xL, t2e#x, t2f#x, t2g#x, t2h#x, t2i#x])
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-window.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-window.sql.out
index 11164ececc93..1ceb67f58d69 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-window.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-window.sql.out
@@ -389,7 +389,7 @@ Project [udf(val)#x, cate#x, max#x, min#x, min#x, count#xL, sum#xL, avg#x, stdde
+- Sort [cate#x ASC NULLS FIRST, cast(udf(cast(val#x as string)) as int) ASC NULLS FIRST], true
+- Project [udf(val)#x, cate#x, max#x, min#x, min#x, count#xL, sum#xL, avg#x, stddev#x, first_value#x, first_value_ignore_null#x, first_value_contain_null#x, any_value#x, any_value_ignore_null#x, any_value_contain_null#x, last_value#x, last_value_ignore_null#x, last_value_contain_null#x, rank#x, dense_rank#x, cume_dist#x, percent_rank#x, ntile#x, row_number#x, var_pop#x, var_samp#x, approx_count_distinct#xL, covar_pop#x, corr#x, stddev_samp#x, stddev_pop#x, collect_list#x, collect_set [...]
+- Project [udf(val)#x, cate#x, _w0#x, _w1#x, _w2#x, _w3#x, _w4#x, max#x, min#x, min#x, count#xL, sum#xL, avg#x, stddev#x, first_value#x, first_value_ignore_null#x, first_value_contain_null#x, any_value#x, any_value_ignore_null#x, any_value_contain_null#x, last_value#x, last_value_ignore_null#x, last_value_contain_null#x, rank#x, dense_rank#x, cume_dist#x, percent_rank#x, ntile#x, row_number#x, var_pop#x, var_samp#x, approx_count_distinct#xL, covar_pop#x, corr#x, stddev_samp#x, std [...]
- +- Window [max(_w0#x) windowspecdefinition(_w1#x, _w0#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS max#x, min(_w0#x) windowspecdefinition(_w1#x, _w0#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, min(_w0#x) windowspecdefinition(_w1#x, _w0#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, count(_w0#x) windowspecdefinition(_w1#x, [...]
+ +- Window [max(_w0#x) windowspecdefinition(_w1#x, _w0#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS max#x, min(_w0#x) windowspecdefinition(_w1#x, _w0#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, min(_w0#x) windowspecdefinition(_w1#x, _w0#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, count(_w0#x) windowspecdefinition(_w1#x, [...]
+- Project [cast(udf(cast(val#x as string)) as int) AS udf(val)#x, cate#x, cast(udf(cast(val#x as string)) as int) AS _w0#x, cast(udf(cast(cate#x as string)) as string) AS _w1#x, cast(cast(udf(cast(val#x as string)) as int) as double) AS _w2#x, cast(cast(udf(cast(val_long#xL as string)) as bigint) as double) AS _w3#x, cast(udf(cast(val_double#x as string)) as double) AS _w4#x, val#x]
+- SubqueryAlias testdata
+- View (`testData`, [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x])
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/window.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/window.sql.out
index b585e01a75de..a441256d3bf0 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/window.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/window.sql.out
@@ -585,7 +585,7 @@ ORDER BY cate, val
Sort [cate#x ASC NULLS FIRST, val#x ASC NULLS FIRST], true
+- Project [val#x, cate#x, max#x, min#x, min#x, count#xL, sum#xL, avg#x, stddev#x, first_value#x, first_value_ignore_null#x, first_value_contain_null#x, any_value#x, any_value_ignore_null#x, any_value_contain_null#x, last_value#x, last_value_ignore_null#x, last_value_contain_null#x, rank#x, dense_rank#x, cume_dist#x, percent_rank#x, ntile#x, row_number#x, var_pop#x, var_samp#x, approx_count_distinct#xL, covar_pop#x, corr#x, stddev_samp#x, stddev_pop#x, collect_list#x, collect_set#x, skew [...]
+- Project [val#x, cate#x, _w0#x, _w1#x, val_double#x, max#x, min#x, min#x, count#xL, sum#xL, avg#x, stddev#x, first_value#x, first_value_ignore_null#x, first_value_contain_null#x, any_value#x, any_value_ignore_null#x, any_value_contain_null#x, last_value#x, last_value_ignore_null#x, last_value_contain_null#x, rank#x, dense_rank#x, cume_dist#x, percent_rank#x, ntile#x, row_number#x, var_pop#x, var_samp#x, approx_count_distinct#xL, covar_pop#x, corr#x, stddev_samp#x, stddev_pop#x, coll [...]
- +- Window [max(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS max#x, min(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, min(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, count(val#x) windowspecdefinition(cate#x [...]
+ +- Window [max(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS max#x, min(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, min(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, count(val#x) windowspecdefinition(cate#x [...]
+- Project [val#x, cate#x, cast(val#x as double) AS _w0#x, cast(val_long#xL as double) AS _w1#x, val_double#x]
+- SubqueryAlias testdata
+- View (`testData`, [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x])
diff --git a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out
index be22c74f43b0..69ad095bc1a5 100644
--- a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out
@@ -1102,7 +1102,7 @@ ReadSchema: struct<key:int,val:string>
(2) ObjectHashAggregate
Input [2]: [key#x, val#x]
Keys [1]: [key#x]
-Functions [1]: [partial_collect_set(val#x, 0, 0)]
+Functions [1]: [partial_collect_set(val#x, 0, 0, true)]
Aggregate Attributes [1]: [buf#x]
Results [2]: [key#x, buf#x]
@@ -1113,9 +1113,9 @@ Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [plan_id=x]
(4) ObjectHashAggregate
Input [2]: [key#x, buf#x]
Keys [1]: [key#x]
-Functions [1]: [collect_set(val#x, 0, 0)]
-Aggregate Attributes [1]: [collect_set(val#x, 0, 0)#x]
-Results [2]: [key#x, sort_array(collect_set(val#x, 0, 0)#x, true)[0] AS sort_array(collect_set(val), true)[0]#x]
+Functions [1]: [collect_set(val#x, 0, 0, true)]
+Aggregate Attributes [1]: [collect_set(val#x)#x]
+Results [2]: [key#x, sort_array(collect_set(val#x)#x, true)[0] AS sort_array(collect_set(val), true)[0]#x]
(5) AdaptiveSparkPlan
Output [2]: [key#x, sort_array(collect_set(val), true)[0]#x]
diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out
index e6db39f7913c..4d640fb2ad14 100644
--- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out
@@ -999,7 +999,7 @@ Input [2]: [key#x, val#x]
(3) ObjectHashAggregate
Input [2]: [key#x, val#x]
Keys [1]: [key#x]
-Functions [1]: [partial_collect_set(val#x, 0, 0)]
+Functions [1]: [partial_collect_set(val#x, 0, 0, true)]
Aggregate Attributes [1]: [buf#x]
Results [2]: [key#x, buf#x]
@@ -1010,9 +1010,9 @@ Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [plan_id=x]
(5) ObjectHashAggregate
Input [2]: [key#x, buf#x]
Keys [1]: [key#x]
-Functions [1]: [collect_set(val#x, 0, 0)]
-Aggregate Attributes [1]: [collect_set(val#x, 0, 0)#x]
-Results [2]: [key#x, sort_array(collect_set(val#x, 0, 0)#x, true)[0] AS sort_array(collect_set(val), true)[0]#x]
+Functions [1]: [collect_set(val#x, 0, 0, true)]
+Aggregate Attributes [1]: [collect_set(val#x)#x]
+Results [2]: [key#x, sort_array(collect_set(val#x)#x, true)[0] AS sort_array(collect_set(val), true)[0]#x]
-- !query
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index ffc85fef1d10..6c14a869f201 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -635,6 +635,28 @@ class DataFrameAggregateSuite extends QueryTest
checkAnswer(df.selectExpr("collect_list(b) RESPECT NULLS"), Seq(Row(Seq(2, null, 4))))
}
+ test("collect_set skips nulls by default") {
+ val df = Seq((1, Some(2)), (2, None), (3, Some(2))).toDF("a", "b")
+
+ checkAnswer(df.selectExpr("sort_array(collect_set(b))"), Seq(Row(Seq(2))))
+ checkAnswer(df.select(sort_array(collect_set($"b"))), Seq(Row(Seq(2))))
+ }
+
+ test("collect_set with IGNORE NULLS explicitly skips nulls") {
+ val df = Seq((1, Some(2)), (2, None), (3, Some(4))).toDF("a", "b")
+
+ checkAnswer(
+ df.selectExpr("sort_array(collect_set(b) IGNORE NULLS)"), Seq(Row(Seq(2,
4))))
+ }
+
+ test("collect_set with RESPECT NULLS preserves null in set") {
+ val df = Seq((1, Some(2)), (2, None), (3, Some(2))).toDF("a", "b")
+
+ // RESPECT NULLS preserves null value in the set
+ checkAnswer(
+ df.selectExpr("sort_array(collect_set(b) RESPECT NULLS)"),
Seq(Row(Seq(null, 2))))
+ }
+
test("collect functions structs") {
val df = Seq((1, 2, 2), (2, 2, 2), (3, 4, 1))
.toDF("a", "x", "y")
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]