This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1469e0eea526 [SPARK-55533][SQL] Support IGNORE NULLS / RESPECT NULLS for collect_set
1469e0eea526 is described below

commit 1469e0eea526bea54acd59684cdf5615af8d71da
Author: Kent Yao <[email protected]>
AuthorDate: Sun Feb 15 11:16:47 2026 -0800

    [SPARK-55533][SQL] Support IGNORE NULLS / RESPECT NULLS for collect_set
    
    ### What changes were proposed in this pull request?
    
    This PR adds `IGNORE NULLS` / `RESPECT NULLS` support to `collect_set`,
    mirroring the existing `collect_list` behavior added in SPARK-55256.
    
    - `collect_set(expr)` — default, skips nulls (unchanged behavior)
    - `collect_set(expr) IGNORE NULLS` — explicitly skips nulls
    - `collect_set(expr) RESPECT NULLS` — includes null in the result set
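    
    For illustration, a minimal sketch of the three forms (the inline table
    and its values are hypothetical; element order in a `collect_set` result
    is non-deterministic, so the expected contents are shown as comments):
    
    ```scala
    // Hypothetical spark-shell session: aggregate over an inline table
    // t(col) holding the values 1, NULL, 1.
    val from = "FROM VALUES (1), (NULL), (1) AS t(col)"
    spark.sql(s"SELECT collect_set(col) $from").show()               // [1]
    spark.sql(s"SELECT collect_set(col) IGNORE NULLS $from").show()  // [1]
    spark.sql(s"SELECT collect_set(col) RESPECT NULLS $from").show() // [1, NULL]
    ```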
    
    ### Why are the changes needed?
    
    For consistency: `collect_list`/`array_agg` already support this syntax,
    but `collect_set` does not. Users who want to include null in a collected
    set currently have no way to do so.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. `collect_set` now accepts `IGNORE NULLS` and `RESPECT NULLS` clauses
    in SQL.
    
    ### How was this patch tested?
    
    Added 3 new tests in `DataFrameAggregateSuite`:
    - `collect_set skips nulls by default`
    - `collect_set with IGNORE NULLS explicitly skips nulls`
    - `collect_set with RESPECT NULLS preserves null in set`
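    
    A condensed sketch of what the three tests assert (the exact code is in
    the diff below; `sort_array` makes the collected set deterministic and
    places nulls first when sorting in ascending order):
    
    ```scala
    val df = Seq((1, Some(2)), (2, None), (3, Some(2))).toDF("a", "b")
    df.selectExpr("sort_array(collect_set(b))")               // Seq(2)
    df.selectExpr("sort_array(collect_set(b) IGNORE NULLS)")  // Seq(2)
    df.selectExpr("sort_array(collect_set(b) RESPECT NULLS)") // Seq(null, 2)
    ```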
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Yes, co-authored with GitHub Copilot.
    
    Closes #54329 from yaooqinn/SPARK-55533.
    
    Authored-by: Kent Yao <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 .../sql/catalyst/analysis/FunctionResolution.scala |  1 +
 .../catalyst/expressions/aggregate/collect.scala   | 34 ++++++++++++++++++++--
 .../explain-results/function_collect_set.explain   |  2 +-
 .../scalar-subquery/scalar-subquery-select.sql.out |  2 +-
 .../analyzer-results/udf/udf-window.sql.out        |  2 +-
 .../sql-tests/analyzer-results/window.sql.out      |  2 +-
 .../sql-tests/results/explain-aqe.sql.out          |  8 ++---
 .../resources/sql-tests/results/explain.sql.out    |  8 ++---
 .../apache/spark/sql/DataFrameAggregateSuite.scala | 22 ++++++++++++++
 9 files changed, 66 insertions(+), 15 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala
index 63a75a8aa2b8..19638a8e5966 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala
@@ -260,6 +260,7 @@ class FunctionResolution(
       case last: Last => last.copy(ignoreNulls = ignoreNulls)
       case anyValue: AnyValue => anyValue.copy(ignoreNulls = ignoreNulls)
       case collectList: CollectList => collectList.copy(ignoreNulls = ignoreNulls)
+      case collectSet: CollectSet => collectSet.copy(ignoreNulls = ignoreNulls)
       case _ if ignoreNulls =>
         // Only fail for IGNORE NULLS; RESPECT NULLS is the default behavior
         throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index 29163d08297d..a864bf108436 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -166,6 +166,9 @@ case class CollectList(
 
 /**
  * Collect a set of unique elements.
+ *
+ * @param ignoreNulls when true (IGNORE NULLS), null values are excluded from the result array.
+ *                    When false (RESPECT NULLS), null values are included in the result array.
  */
 @ExpressionDescription(
   usage = "_FUNC_(expr) - Collects and returns a set of unique elements.",
@@ -184,16 +187,33 @@ case class CollectList(
 case class CollectSet(
     child: Expression,
     mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0)
+    inputAggBufferOffset: Int = 0,
+    ignoreNulls: Boolean = true)
   extends Collect[mutable.HashSet[Any]] with QueryErrorsBase with UnaryLike[Expression] {
 
-  def this(child: Expression) = this(child, 0, 0)
+  def this(child: Expression) = this(child, 0, 0, true)
+
+  // Buffer can contain nulls when ignoreNulls is false (RESPECT NULLS)
+  override protected def bufferContainsNull: Boolean = !ignoreNulls
 
   override lazy val bufferElementType = child.dataType match {
     case BinaryType => ArrayType(ByteType)
     case other => other
   }
 
+  override def update(
+      buffer: mutable.HashSet[Any],
+      input: InternalRow): mutable.HashSet[Any] = {
+    val value = child.eval(input)
+    if (value != null) {
+      buffer += convertToBufferElement(value)
+    } else if (!ignoreNulls) {
+      // RESPECT NULLS: preserve null value in result
+      buffer += null
+    }
+    buffer
+  }
+
   override def convertToBufferElement(value: Any): Any = child.dataType match {
     /*
      * collect_set() of BinaryType should not return duplicate elements,
@@ -207,7 +227,10 @@ case class CollectSet(
   override def eval(buffer: mutable.HashSet[Any]): Any = {
     val array = child.dataType match {
       case BinaryType =>
-        buffer.iterator.map(_.asInstanceOf[ArrayData].toByteArray()).toArray[Any]
+        buffer.iterator.map {
+          case null => null
+          case v => v.asInstanceOf[ArrayData].toByteArray()
+        }.toArray[Any]
       case _ => buffer.toArray
     }
     new GenericArrayData(array)
@@ -238,6 +261,11 @@ case class CollectSet(
 
   override def createAggregationBuffer(): mutable.HashSet[Any] = mutable.HashSet.empty
 
+  override def toString: String = {
+    val ignoreNullsStr = if (ignoreNulls) "" else " respect nulls"
+    s"$prettyName($child)$ignoreNullsStr"
+  }
+
   override protected def withNewChildInternal(newChild: Expression): CollectSet =
     copy(child = newChild)
 }
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_collect_set.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_collect_set.explain
index 18246a74ccc9..ca3cd174afc6 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_collect_set.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_collect_set.explain
@@ -1,2 +1,2 @@
-Aggregate [collect_set(a#0, 0, 0) AS collect_set(a)#0]
+Aggregate [collect_set(a#0, 0, 0, true) AS collect_set(a)#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-select.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-select.sql.out
index 9f3552e6d6e2..44512caf1def 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-select.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-select.sql.out
@@ -488,7 +488,7 @@ Project [t1a#x, scalar-subquery#x [t1a#x] AS count_t2#xL, scalar-subquery#x [t1a
 :  :              +- Project [t2a#x, t2b#x, t2c#x, t2d#xL, t2e#x, t2f#x, t2g#x, t2h#x, t2i#x]
 :  :                 +- SubqueryAlias t2
 :  :                    +- LocalRelation [t2a#x, t2b#x, t2c#x, t2d#xL, t2e#x, t2f#x, t2g#x, t2h#x, t2i#x]
-:  :- Aggregate [sort_array(collect_set(t2d#xL, 0, 0), true) AS sort_array(collect_set(t2d), true)#x]
+:  :- Aggregate [sort_array(collect_set(t2d#xL, 0, 0, true), true) AS sort_array(collect_set(t2d), true)#x]
 :  :  +- Filter (t2a#x = outer(t1a#x))
 :  :     +- SubqueryAlias t2
 :  :        +- View (`t2`, [t2a#x, t2b#x, t2c#x, t2d#xL, t2e#x, t2f#x, t2g#x, t2h#x, t2i#x])
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-window.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-window.sql.out
index 11164ececc93..1ceb67f58d69 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-window.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-window.sql.out
@@ -389,7 +389,7 @@ Project [udf(val)#x, cate#x, max#x, min#x, min#x, count#xL, sum#xL, avg#x, stdde
 +- Sort [cate#x ASC NULLS FIRST, cast(udf(cast(val#x as string)) as int) ASC NULLS FIRST], true
    +- Project [udf(val)#x, cate#x, max#x, min#x, min#x, count#xL, sum#xL, avg#x, stddev#x, first_value#x, first_value_ignore_null#x, first_value_contain_null#x, any_value#x, any_value_ignore_null#x, any_value_contain_null#x, last_value#x, last_value_ignore_null#x, last_value_contain_null#x, rank#x, dense_rank#x, cume_dist#x, percent_rank#x, ntile#x, row_number#x, var_pop#x, var_samp#x, approx_count_distinct#xL, covar_pop#x, corr#x, stddev_samp#x, stddev_pop#x, collect_list#x, collect_set [...]
       +- Project [udf(val)#x, cate#x, _w0#x, _w1#x, _w2#x, _w3#x, _w4#x, max#x, min#x, min#x, count#xL, sum#xL, avg#x, stddev#x, first_value#x, first_value_ignore_null#x, first_value_contain_null#x, any_value#x, any_value_ignore_null#x, any_value_contain_null#x, last_value#x, last_value_ignore_null#x, last_value_contain_null#x, rank#x, dense_rank#x, cume_dist#x, percent_rank#x, ntile#x, row_number#x, var_pop#x, var_samp#x, approx_count_distinct#xL, covar_pop#x, corr#x, stddev_samp#x, std [...]
-         +- Window [max(_w0#x) windowspecdefinition(_w1#x, _w0#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS max#x, min(_w0#x) windowspecdefinition(_w1#x, _w0#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, min(_w0#x) windowspecdefinition(_w1#x, _w0#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, count(_w0#x) windowspecdefinition(_w1#x, [...]
+         +- Window [max(_w0#x) windowspecdefinition(_w1#x, _w0#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS max#x, min(_w0#x) windowspecdefinition(_w1#x, _w0#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, min(_w0#x) windowspecdefinition(_w1#x, _w0#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, count(_w0#x) windowspecdefinition(_w1#x, [...]
             +- Project [cast(udf(cast(val#x as string)) as int) AS udf(val)#x, cate#x, cast(udf(cast(val#x as string)) as int) AS _w0#x, cast(udf(cast(cate#x as string)) as string) AS _w1#x, cast(cast(udf(cast(val#x as string)) as int) as double) AS _w2#x, cast(cast(udf(cast(val_long#xL as string)) as bigint) as double) AS _w3#x, cast(udf(cast(val_double#x as string)) as double) AS _w4#x, val#x]
                +- SubqueryAlias testdata
                   +- View (`testData`, [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x])
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/window.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/window.sql.out
index b585e01a75de..a441256d3bf0 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/window.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/window.sql.out
@@ -585,7 +585,7 @@ ORDER BY cate, val
 Sort [cate#x ASC NULLS FIRST, val#x ASC NULLS FIRST], true
 +- Project [val#x, cate#x, max#x, min#x, min#x, count#xL, sum#xL, avg#x, stddev#x, first_value#x, first_value_ignore_null#x, first_value_contain_null#x, any_value#x, any_value_ignore_null#x, any_value_contain_null#x, last_value#x, last_value_ignore_null#x, last_value_contain_null#x, rank#x, dense_rank#x, cume_dist#x, percent_rank#x, ntile#x, row_number#x, var_pop#x, var_samp#x, approx_count_distinct#xL, covar_pop#x, corr#x, stddev_samp#x, stddev_pop#x, collect_list#x, collect_set#x, skew [...]
    +- Project [val#x, cate#x, _w0#x, _w1#x, val_double#x, max#x, min#x, min#x, count#xL, sum#xL, avg#x, stddev#x, first_value#x, first_value_ignore_null#x, first_value_contain_null#x, any_value#x, any_value_ignore_null#x, any_value_contain_null#x, last_value#x, last_value_ignore_null#x, last_value_contain_null#x, rank#x, dense_rank#x, cume_dist#x, percent_rank#x, ntile#x, row_number#x, var_pop#x, var_samp#x, approx_count_distinct#xL, covar_pop#x, corr#x, stddev_samp#x, stddev_pop#x, coll [...]
-      +- Window [max(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS max#x, min(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, min(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, count(val#x) windowspecdefinition(cate#x [...]
+      +- Window [max(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS max#x, min(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, min(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, count(val#x) windowspecdefinition(cate#x [...]
         +- Project [val#x, cate#x, cast(val#x as double) AS _w0#x, cast(val_long#xL as double) AS _w1#x, val_double#x]
            +- SubqueryAlias testdata
               +- View (`testData`, [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x])
diff --git a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out
index be22c74f43b0..69ad095bc1a5 100644
--- a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out
@@ -1102,7 +1102,7 @@ ReadSchema: struct<key:int,val:string>
 (2) ObjectHashAggregate
 Input [2]: [key#x, val#x]
 Keys [1]: [key#x]
-Functions [1]: [partial_collect_set(val#x, 0, 0)]
+Functions [1]: [partial_collect_set(val#x, 0, 0, true)]
 Aggregate Attributes [1]: [buf#x]
 Results [2]: [key#x, buf#x]
 
@@ -1113,9 +1113,9 @@ Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [plan_id=x]
 (4) ObjectHashAggregate
 Input [2]: [key#x, buf#x]
 Keys [1]: [key#x]
-Functions [1]: [collect_set(val#x, 0, 0)]
-Aggregate Attributes [1]: [collect_set(val#x, 0, 0)#x]
-Results [2]: [key#x, sort_array(collect_set(val#x, 0, 0)#x, true)[0] AS sort_array(collect_set(val), true)[0]#x]
+Functions [1]: [collect_set(val#x, 0, 0, true)]
+Aggregate Attributes [1]: [collect_set(val#x)#x]
+Results [2]: [key#x, sort_array(collect_set(val#x)#x, true)[0] AS sort_array(collect_set(val), true)[0]#x]
 
 (5) AdaptiveSparkPlan
 Output [2]: [key#x, sort_array(collect_set(val), true)[0]#x]
diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out
index e6db39f7913c..4d640fb2ad14 100644
--- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out
@@ -999,7 +999,7 @@ Input [2]: [key#x, val#x]
 (3) ObjectHashAggregate
 Input [2]: [key#x, val#x]
 Keys [1]: [key#x]
-Functions [1]: [partial_collect_set(val#x, 0, 0)]
+Functions [1]: [partial_collect_set(val#x, 0, 0, true)]
 Aggregate Attributes [1]: [buf#x]
 Results [2]: [key#x, buf#x]
 
@@ -1010,9 +1010,9 @@ Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [plan_id=x]
 (5) ObjectHashAggregate
 Input [2]: [key#x, buf#x]
 Keys [1]: [key#x]
-Functions [1]: [collect_set(val#x, 0, 0)]
-Aggregate Attributes [1]: [collect_set(val#x, 0, 0)#x]
-Results [2]: [key#x, sort_array(collect_set(val#x, 0, 0)#x, true)[0] AS sort_array(collect_set(val), true)[0]#x]
+Functions [1]: [collect_set(val#x, 0, 0, true)]
+Aggregate Attributes [1]: [collect_set(val#x)#x]
+Results [2]: [key#x, sort_array(collect_set(val#x)#x, true)[0] AS sort_array(collect_set(val), true)[0]#x]
 
 
 -- !query
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index ffc85fef1d10..6c14a869f201 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -635,6 +635,28 @@ class DataFrameAggregateSuite extends QueryTest
     checkAnswer(df.selectExpr("collect_list(b) RESPECT NULLS"), Seq(Row(Seq(2, null, 4))))
   }
 
+  test("collect_set skips nulls by default") {
+    val df = Seq((1, Some(2)), (2, None), (3, Some(2))).toDF("a", "b")
+
+    checkAnswer(df.selectExpr("sort_array(collect_set(b))"), Seq(Row(Seq(2))))
+    checkAnswer(df.select(sort_array(collect_set($"b"))), Seq(Row(Seq(2))))
+  }
+
+  test("collect_set with IGNORE NULLS explicitly skips nulls") {
+    val df = Seq((1, Some(2)), (2, None), (3, Some(4))).toDF("a", "b")
+
+    checkAnswer(
+      df.selectExpr("sort_array(collect_set(b) IGNORE NULLS)"), Seq(Row(Seq(2, 
4))))
+  }
+
+  test("collect_set with RESPECT NULLS preserves null in set") {
+    val df = Seq((1, Some(2)), (2, None), (3, Some(2))).toDF("a", "b")
+
+    // RESPECT NULLS preserves null value in the set
+    checkAnswer(
+      df.selectExpr("sort_array(collect_set(b) RESPECT NULLS)"), 
Seq(Row(Seq(null, 2))))
+  }
+
   test("collect functions structs") {
     val df = Seq((1, 2, 2), (2, 2, 2), (3, 4, 1))
       .toDF("a", "x", "y")

