This is an automated email from the ASF dual-hosted git repository.

yao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 04b821c69e85 [SPARK-55256][SQL] Support IGNORE NULLS / RESPECT NULLS for array_agg and collect_list
04b821c69e85 is described below

commit 04b821c69e85be5f51a1270b3a9a4155afdb5334
Author: Kent Yao <[email protected]>
AuthorDate: Fri Jan 30 22:33:51 2026 +0800

    [SPARK-55256][SQL] Support IGNORE NULLS / RESPECT NULLS for array_agg and collect_list
    
    ### What changes were proposed in this pull request?
    
    This PR adds support for the IGNORE NULLS and RESPECT NULLS clauses to the `array_agg` and `collect_list` aggregate functions.
    
    ### Why are the changes needed?
    
    The SQL standard and many databases (PostgreSQL, Snowflake, DuckDB, etc.) support the IGNORE NULLS / RESPECT NULLS syntax for aggregate functions. Spark currently supports this syntax only for window functions such as `first`, `last`, `lead`, `lag`, and `nth_value`.
    
    By adding this support to `array_agg` and `collect_list`, users can explicitly control whether null values should be included in the resulting array:
    - `array_agg(col) IGNORE NULLS` - skips null values (default behavior)
    - `array_agg(col) RESPECT NULLS` - includes null values in the result
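
    For illustration, a minimal sketch of the two behaviors, mirroring the new tests in `DataFrameAggregateSuite` (the session and column name are assumptions, e.g. a spark-shell with a `SparkSession` named `spark`):

    ```scala
    // Illustrative only, not part of this patch.
    import spark.implicits._

    val df = Seq(Some(2), None, Some(4)).toDF("b")

    df.selectExpr("array_agg(b)").collect()                // array [2, 4]: default, nulls skipped
    df.selectExpr("array_agg(b) IGNORE NULLS").collect()   // array [2, 4]: nulls skipped explicitly
    df.selectExpr("array_agg(b) RESPECT NULLS").collect()  // array [2, null, 4]: nulls preserved
    ```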
    
    ### Implementation Details
    
    1. Added an `ignoreNulls: Boolean = true` parameter to the `CollectList` class
    2. `array_agg` now reuses `CollectList`, since the two functions have identical behavior
    3. Changed `UnresolvedFunction.ignoreNulls` from `Boolean` to `Option[Boolean]` to distinguish between:
       - `None`: no clause specified (use the function's default)
       - `Some(true)`: IGNORE NULLS explicitly specified
       - `Some(false)`: RESPECT NULLS explicitly specified
    4. Consolidated the ignoreNulls resolution logic in `FunctionResolution` into the shared `resolveIgnoreNulls` and `applyIgnoreNulls` methods
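
    In condensed form, the tri-state resolution added in `FunctionResolution.scala` looks like this (simplified from the diff below; error handling and the full case list are abbreviated):

    ```scala
    // None => no clause given: keep the function's default null handling.
    // Some(true) / Some(false) => IGNORE NULLS / RESPECT NULLS was written explicitly.
    private def resolveIgnoreNulls[T <: Expression](func: T, ignoreNulls: Option[Boolean]): T =
      ignoreNulls.map(applyIgnoreNulls(func, _)).getOrElse(func)

    private def applyIgnoreNulls[T <: Expression](func: T, ignoreNulls: Boolean): T = {
      val result = func match {
        case cl: CollectList => cl.copy(ignoreNulls = ignoreNulls)
        // ... NthValue, Lead, Lag, First, Last and AnyValue are handled the same way ...
        case _ if ignoreNulls =>
          // Only IGNORE NULLS fails here; RESPECT NULLS is the default behavior.
          throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(func.prettyName, "IGNORE NULLS")
        case _ => func
      }
      result.asInstanceOf[T]
    }
    ```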
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Users can now use IGNORE NULLS / RESPECT NULLS with `array_agg` and `collect_list`:
    ```sql
    SELECT array_agg(col) IGNORE NULLS FROM table;
    SELECT collect_list(col) RESPECT NULLS OVER (PARTITION BY id) FROM table;
    ```
    
    ### How was this patch tested?
    
    Added unit tests in DataFrameAggregateSuite:
    - `array_agg and collect_list skip nulls by default`
    - `array_agg with IGNORE NULLS explicitly skips nulls`
    - `array_agg with RESPECT NULLS preserves nulls`
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Yes, GitHub Copilot was used to assist with this implementation.
    
    Closes #54034 from yaooqinn/SPARK-55256-ignore-nulls-array-agg.
    
    Authored-by: Kent Yao <[email protected]>
    Signed-off-by: Kent Yao <[email protected]>
---
 .../sql/catalyst/analysis/FunctionResolution.scala | 83 ++++++++++++----------
 .../spark/sql/catalyst/analysis/unresolved.scala   |  2 +-
 .../catalyst/expressions/aggregate/collect.scala   | 38 ++++++++--
 .../spark/sql/catalyst/expressions/literals.scala  |  2 +-
 .../spark/sql/catalyst/parser/AstBuilder.scala     |  2 +-
 .../explain-results/function_array_agg.explain     |  2 +-
 .../explain-results/function_collect_list.explain  |  2 +-
 .../spark/sql/classic/DataFrameWriterV2.scala      |  2 +-
 .../generators-resolution-edge-cases.sql.out       |  6 +-
 .../sql-tests/analyzer-results/group-by.sql.out    |  4 +-
 .../scalar-subquery/scalar-subquery-select.sql.out |  2 +-
 .../analyzer-results/udf/udf-window.sql.out        |  2 +-
 .../sql-tests/analyzer-results/window.sql.out      |  2 +-
 .../apache/spark/sql/DataFrameAggregateSuite.scala | 25 +++++++
 14 files changed, 119 insertions(+), 55 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala
index 29f4db65def0..63a75a8aa2b8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala
@@ -152,18 +152,8 @@ class FunctionResolution(
             wf.prettyName,
             "FILTER clause"
           )
-        } else if (u.ignoreNulls) {
-          wf match {
-            case nthValue: NthValue =>
-              nthValue.copy(ignoreNulls = u.ignoreNulls)
-            case _ =>
-              throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
-                wf.prettyName,
-                "IGNORE NULLS"
-              )
-          }
         } else {
-          wf
+          resolveIgnoreNulls(wf, u.ignoreNulls)
         }
       case owf: FrameLessOffsetWindowFunction =>
         if (u.isDistinct) {
@@ -176,15 +166,8 @@ class FunctionResolution(
             owf.prettyName,
             "FILTER clause"
           )
-        } else if (u.ignoreNulls) {
-          owf match {
-            case lead: Lead =>
-              lead.copy(ignoreNulls = u.ignoreNulls)
-            case lag: Lag =>
-              lag.copy(ignoreNulls = u.ignoreNulls)
-          }
         } else {
-          owf
+          resolveIgnoreNulls(owf, u.ignoreNulls)
         }
       // We get an aggregate function, we need to wrap it in an AggregateExpression.
       case agg: AggregateFunction =>
@@ -216,21 +199,8 @@ class FunctionResolution(
             )
           case _ =>
         }
-        if (u.ignoreNulls) {
-          val aggFunc = newAgg match {
-            case first: First => first.copy(ignoreNulls = u.ignoreNulls)
-            case last: Last => last.copy(ignoreNulls = u.ignoreNulls)
-            case any_value: AnyValue => any_value.copy(ignoreNulls = u.ignoreNulls)
-            case _ =>
-              throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
-                newAgg.prettyName,
-                "IGNORE NULLS"
-              )
-          }
-          aggFunc.toAggregateExpression(u.isDistinct, u.filter)
-        } else {
-          newAgg.toAggregateExpression(u.isDistinct, u.filter)
-        }
+        val aggFunc = resolveIgnoreNulls(newAgg, u.ignoreNulls)
+        aggFunc.toAggregateExpression(u.isDistinct, u.filter)
      // This function is not an aggregate function, just return the resolved one.
       case other =>
         checkUnsupportedAggregateClause(other, u)
@@ -258,7 +228,8 @@ class FunctionResolution(
         "FILTER clause"
       )
     }
-    if (u.ignoreNulls) {
+    // Only fail for IGNORE NULLS; RESPECT NULLS is the default behavior
+    if (u.ignoreNulls.contains(true)) {
       throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
         func.prettyName,
         "IGNORE NULLS"
@@ -266,6 +237,42 @@ class FunctionResolution(
     }
   }
 
+  /**
+   * Resolves the IGNORE NULLS / RESPECT NULLS clause for a function.
+   * If ignoreNulls is defined, applies it to the function; otherwise returns unchanged.
+   */
+  private def resolveIgnoreNulls[T <: Expression](func: T, ignoreNulls: Option[Boolean]): T = {
+    ignoreNulls.map(applyIgnoreNulls(func, _)).getOrElse(func)
+  }
+
+  /**
+   * Applies the IGNORE NULLS / RESPECT NULLS clause to functions that support it.
+   * Returns the modified function if supported, throws error otherwise.
+   */
+  private def applyIgnoreNulls[T <: Expression](func: T, ignoreNulls: Boolean): T = {
+    val result = func match {
+      // Window functions
+      case nthValue: NthValue => nthValue.copy(ignoreNulls = ignoreNulls)
+      case lead: Lead => lead.copy(ignoreNulls = ignoreNulls)
+      case lag: Lag => lag.copy(ignoreNulls = ignoreNulls)
+      // Aggregate functions
+      case first: First => first.copy(ignoreNulls = ignoreNulls)
+      case last: Last => last.copy(ignoreNulls = ignoreNulls)
+      case anyValue: AnyValue => anyValue.copy(ignoreNulls = ignoreNulls)
+      case collectList: CollectList => collectList.copy(ignoreNulls = ignoreNulls)
+      case _ if ignoreNulls =>
+        // Only fail for IGNORE NULLS; RESPECT NULLS is the default behavior
+        throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
+          func.prettyName,
+          "IGNORE NULLS"
+        )
+      case _ =>
+        // RESPECT NULLS is the default, silently return unchanged
+        func
+    }
+    result.asInstanceOf[T]
+  }
+
   private def resolveV2Function(
       unbound: UnboundFunction,
       arguments: Seq[Expression],
@@ -312,7 +319,8 @@ class FunctionResolution(
         scalarFunc.name(),
         "FILTER clause"
       )
-    } else if (u.ignoreNulls) {
+    } else if (u.ignoreNulls.contains(true)) {
+      // Only fail for IGNORE NULLS; RESPECT NULLS is the default behavior
       throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
         scalarFunc.name(),
         "IGNORE NULLS"
@@ -326,7 +334,8 @@ class FunctionResolution(
       aggFunc: V2AggregateFunction[_, _],
       arguments: Seq[Expression],
       u: UnresolvedFunction): Expression = {
-    if (u.ignoreNulls) {
+    // Only fail for IGNORE NULLS; RESPECT NULLS is the default behavior
+    if (u.ignoreNulls.contains(true)) {
       throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
         aggFunc.name(),
         "IGNORE NULLS"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 3bdeb6d71884..fffbc7511a1d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -369,7 +369,7 @@ case class UnresolvedFunction(
     arguments: Seq[Expression],
     isDistinct: Boolean,
     filter: Option[Expression] = None,
-    ignoreNulls: Boolean = false,
+    ignoreNulls: Option[Boolean] = None,
     orderingWithinGroup: Seq[SortOrder] = Seq.empty,
     isInternal: Boolean = false)
   extends Expression with Unevaluable {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index 015bd1e3e142..29163d08297d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -48,12 +48,17 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper
 
   override def nullable: Boolean = false
 
-  override def dataType: DataType = ArrayType(child.dataType, false)
+  // Subclasses can override bufferContainsNull to indicate if the result array contains nulls
+  override def dataType: DataType = ArrayType(child.dataType, bufferContainsNull)
 
   override def defaultResult: Option[Literal] = Option(Literal.create(Array(), dataType))
 
   protected def convertToBufferElement(value: Any): Any
 
+  // Subclasses can override this to allow nulls in buffer
+  // (e.g., CollectList with ignoreNulls=false)
+  protected def bufferContainsNull: Boolean = false
+
   override def update(buffer: T, input: InternalRow): T = {
     val value = child.eval(input)
 
@@ -72,7 +77,7 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper
   protected val bufferElementType: DataType
 
   private lazy val projection = UnsafeProjection.create(
-    Array[DataType](ArrayType(elementType = bufferElementType, containsNull = false)))
+    Array[DataType](ArrayType(elementType = bufferElementType, containsNull = bufferContainsNull)))
   private lazy val row = new UnsafeRow(1)
 
   override def serialize(obj: T): Array[Byte] = {
@@ -90,6 +95,9 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper
 
 /**
  * Collect a list of elements.
+ *
+ * @param ignoreNulls when true (IGNORE NULLS), null values are excluded from the result array.
+ *                    When false (RESPECT NULLS), null values are included in the result array.
  */
 @ExpressionDescription(
   usage = "_FUNC_(expr) - Collects and returns a list of non-unique elements.",
@@ -107,15 +115,32 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper
 case class CollectList(
     child: Expression,
     mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]]
+    inputAggBufferOffset: Int = 0,
+    ignoreNulls: Boolean = true) extends Collect[mutable.ArrayBuffer[Any]]
   with UnaryLike[Expression] {
 
-  def this(child: Expression) = this(child, 0, 0)
+  def this(child: Expression) = this(child, 0, 0, true)
+
+  // Buffer can contain nulls when ignoreNulls is false (RESPECT NULLS)
+  override protected def bufferContainsNull: Boolean = !ignoreNulls
 
   override lazy val bufferElementType = child.dataType
 
   override def convertToBufferElement(value: Any): Any = InternalRow.copyValue(value)
 
+  override def update(
+      buffer: mutable.ArrayBuffer[Any],
+      input: InternalRow): mutable.ArrayBuffer[Any] = {
+    val value = child.eval(input)
+    if (value != null) {
+      buffer += convertToBufferElement(value)
+    } else if (!ignoreNulls) {
+      // RESPECT NULLS: preserve null values in result
+      buffer += null
+    }
+    buffer
+  }
+
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)
 
@@ -130,6 +155,11 @@ case class CollectList(
     new GenericArrayData(buffer.toArray)
   }
 
+  override def toString: String = {
+    val ignoreNullsStr = if (ignoreNulls) "" else " respect nulls"
+    s"$prettyName($child)$ignoreNullsStr"
+  }
+
   override protected def withNewChildInternal(newChild: Expression): CollectList =
     copy(child = newChild)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 94dbce19d4f0..6448194f9705 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -281,7 +281,7 @@ object Literal {
         assert(u.nameParts.length == 1)
         assert(!u.isDistinct)
         assert(u.filter.isEmpty)
-        assert(!u.ignoreNulls)
+        assert(u.ignoreNulls.isEmpty)
         assert(u.orderingWithinGroup.isEmpty)
         assert(!u.isInternal)
         FunctionRegistry.builtin.lookupFunction(FunctionIdentifier(u.nameParts.head), u.arguments)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index fbecd1933ec5..0e9844d7f1a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -3253,7 +3253,7 @@ class AstBuilder extends DataTypeAstBuilder
     val order = ctx.sortItem.asScala.map(visitSortItem)
     val filter = Option(ctx.where).map(expression(_))
     val ignoreNulls =
-      Option(ctx.nullsOption).map(_.getType == SqlBaseParser.IGNORE).getOrElse(false)
+      Option(ctx.nullsOption).map(_.getType == SqlBaseParser.IGNORE)
 
     // Is this an IDENTIFIER clause instead of a function call?
     if (ctx.functionName.identFunc != null &&
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_array_agg.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_array_agg.explain
index 102f736c62ef..27da00bd0c86 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_array_agg.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_array_agg.explain
@@ -1,2 +1,2 @@
-Aggregate [collect_list(a#0, 0, 0) AS collect_list(a)#0]
+Aggregate [collect_list(a#0, 0, 0, true) AS collect_list(a)#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_collect_list.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_collect_list.explain
index 102f736c62ef..27da00bd0c86 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_collect_list.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_collect_list.explain
@@ -1,2 +1,2 @@
-Aggregate [collect_list(a#0, 0, 0) AS collect_list(a)#0]
+Aggregate [collect_list(a#0, 0, 0, true) AS collect_list(a)#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala
index 19cc9b76beae..1309b346ffa5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala
@@ -253,7 +253,7 @@ private object PartitionTransform {
     private val NAMES = Seq(name)
 
     def unapply(e: Expression): Option[Seq[Expression]] = e match {
-      case UnresolvedFunction(NAMES, children, false, None, false, Nil, true) => Option(children)
+      case UnresolvedFunction(NAMES, children, false, None, None, Nil, true) => Option(children)
       case _ => None
     }
   }
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/generators-resolution-edge-cases.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/generators-resolution-edge-cases.sql.out
index 2cbb1fb5d382..5a87586b150a 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/generators-resolution-edge-cases.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/generators-resolution-edge-cases.sql.out
@@ -71,7 +71,7 @@ SELECT explode(collect_list(id)) AS val FROM range(5)
 -- !query analysis
 Project [val#xL]
 +- Generate explode(_gen_input_0#x), false, [val#xL]
-   +- Aggregate [collect_list(id#xL, 0, 0) AS _gen_input_0#x]
+   +- Aggregate [collect_list(id#xL, 0, 0, true) AS _gen_input_0#x]
       +- Range (0, 5, step=1)
 
 
@@ -80,7 +80,7 @@ SELECT explode(collect_list(id)) AS val, count(*) AS cnt FROM range(3)
 -- !query analysis
 Project [val#xL, cnt#xL]
 +- Generate explode(_gen_input_0#x), false, [val#xL]
-   +- Aggregate [collect_list(id#xL, 0, 0) AS _gen_input_0#x, count(1) AS cnt#xL]
+   +- Aggregate [collect_list(id#xL, 0, 0, true) AS _gen_input_0#x, count(1) AS cnt#xL]
       +- Range (0, 3, step=1)
 
 
@@ -204,7 +204,7 @@ GROUP BY ALL
 -- !query analysis
 Project [val#x]
 +- Generate explode(_gen_input_0#x), false, [val#x]
-   +- Aggregate [collect_list(a#x, 0, 0) AS _gen_input_0#x]
+   +- Aggregate [collect_list(a#x, 0, 0, true) AS _gen_input_0#x]
       +- SubqueryAlias t
          +- Project [col1#x AS a#x]
             +- LocalRelation [col1#x]
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out
index a2d3f4cdb016..80d63901af4c 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out
@@ -1044,7 +1044,7 @@ SELECT
 FROM VALUES
   (1), (2), (1) AS tab(col)
 -- !query analysis
-Aggregate [collect_list(col#x, 0, 0) AS collect_list(col)#x, collect_list(col#x, 0, 0) AS collect_list(col)#x]
+Aggregate [collect_list(col#x, 0, 0, true) AS collect_list(col)#x, collect_list(col#x, 0, 0, true) AS collect_list(col)#x]
 +- SubqueryAlias tab
    +- LocalRelation [col#x]
 
@@ -1058,7 +1058,7 @@ FROM VALUES
   (1,4),(2,3),(1,4),(2,4) AS v(a,b)
 GROUP BY a
 -- !query analysis
-Aggregate [a#x], [a#x, collect_list(b#x, 0, 0) AS collect_list(b)#x, collect_list(b#x, 0, 0) AS collect_list(b)#x]
+Aggregate [a#x], [a#x, collect_list(b#x, 0, 0, true) AS collect_list(b)#x, collect_list(b#x, 0, 0, true) AS collect_list(b)#x]
 +- SubqueryAlias v
    +- LocalRelation [a#x, b#x]
 
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-select.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-select.sql.out
index f64b3736b552..9f3552e6d6e2 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-select.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-select.sql.out
@@ -480,7 +480,7 @@ Project [t1a#x, scalar-subquery#x [t1a#x] AS count_t2#xL, scalar-subquery#x [t1a
 :  :              +- Project [t2a#x, t2b#x, t2c#x, t2d#xL, t2e#x, t2f#x, t2g#x, t2h#x, t2i#x]
 :  :                 +- SubqueryAlias t2
 :  :                    +- LocalRelation [t2a#x, t2b#x, t2c#x, t2d#xL, t2e#x, t2f#x, t2g#x, t2h#x, t2i#x]
-:  :- Aggregate [collect_list(t2d#xL, 0, 0) AS collect_list(t2d)#x]
+:  :- Aggregate [collect_list(t2d#xL, 0, 0, true) AS collect_list(t2d)#x]
 :  :  +- Filter (t2a#x = outer(t1a#x))
 :  :     +- SubqueryAlias t2
 :  :        +- View (`t2`, [t2a#x, t2b#x, t2c#x, t2d#xL, t2e#x, t2f#x, t2g#x, t2h#x, t2i#x])
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-window.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-window.sql.out
index aa3f40b62ccd..11164ececc93 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-window.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-window.sql.out
@@ -389,7 +389,7 @@ Project [udf(val)#x, cate#x, max#x, min#x, min#x, count#xL, sum#xL, avg#x, stdde
 +- Sort [cate#x ASC NULLS FIRST, cast(udf(cast(val#x as string)) as int) ASC NULLS FIRST], true
    +- Project [udf(val)#x, cate#x, max#x, min#x, min#x, count#xL, sum#xL, avg#x, stddev#x, first_value#x, first_value_ignore_null#x, first_value_contain_null#x, any_value#x, any_value_ignore_null#x, any_value_contain_null#x, last_value#x, last_value_ignore_null#x, last_value_contain_null#x, rank#x, dense_rank#x, cume_dist#x, percent_rank#x, ntile#x, row_number#x, var_pop#x, var_samp#x, approx_count_distinct#xL, covar_pop#x, corr#x, stddev_samp#x, stddev_pop#x, collect_list#x, collect_set [...]
       +- Project [udf(val)#x, cate#x, _w0#x, _w1#x, _w2#x, _w3#x, _w4#x, max#x, min#x, min#x, count#xL, sum#xL, avg#x, stddev#x, first_value#x, first_value_ignore_null#x, first_value_contain_null#x, any_value#x, any_value_ignore_null#x, any_value_contain_null#x, last_value#x, last_value_ignore_null#x, last_value_contain_null#x, rank#x, dense_rank#x, cume_dist#x, percent_rank#x, ntile#x, row_number#x, var_pop#x, var_samp#x, approx_count_distinct#xL, covar_pop#x, corr#x, stddev_samp#x, std [...]
-         +- Window [max(_w0#x) windowspecdefinition(_w1#x, _w0#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS max#x, min(_w0#x) windowspecdefinition(_w1#x, _w0#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, min(_w0#x) windowspecdefinition(_w1#x, _w0#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, count(_w0#x) windowspecdefinition(_w1#x, [...]
+         +- Window [max(_w0#x) windowspecdefinition(_w1#x, _w0#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS max#x, min(_w0#x) windowspecdefinition(_w1#x, _w0#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, min(_w0#x) windowspecdefinition(_w1#x, _w0#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, count(_w0#x) windowspecdefinition(_w1#x, [...]
            +- Project [cast(udf(cast(val#x as string)) as int) AS udf(val)#x, cate#x, cast(udf(cast(val#x as string)) as int) AS _w0#x, cast(udf(cast(cate#x as string)) as string) AS _w1#x, cast(cast(udf(cast(val#x as string)) as int) as double) AS _w2#x, cast(cast(udf(cast(val_long#xL as string)) as bigint) as double) AS _w3#x, cast(udf(cast(val_double#x as string)) as double) AS _w4#x, val#x]
               +- SubqueryAlias testdata
                  +- View (`testData`, [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x])
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/window.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/window.sql.out
index 00cf492574cf..b585e01a75de 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/window.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/window.sql.out
@@ -585,7 +585,7 @@ ORDER BY cate, val
 Sort [cate#x ASC NULLS FIRST, val#x ASC NULLS FIRST], true
 +- Project [val#x, cate#x, max#x, min#x, min#x, count#xL, sum#xL, avg#x, stddev#x, first_value#x, first_value_ignore_null#x, first_value_contain_null#x, any_value#x, any_value_ignore_null#x, any_value_contain_null#x, last_value#x, last_value_ignore_null#x, last_value_contain_null#x, rank#x, dense_rank#x, cume_dist#x, percent_rank#x, ntile#x, row_number#x, var_pop#x, var_samp#x, approx_count_distinct#xL, covar_pop#x, corr#x, stddev_samp#x, stddev_pop#x, collect_list#x, collect_set#x, skew [...]
    +- Project [val#x, cate#x, _w0#x, _w1#x, val_double#x, max#x, min#x, min#x, count#xL, sum#xL, avg#x, stddev#x, first_value#x, first_value_ignore_null#x, first_value_contain_null#x, any_value#x, any_value_ignore_null#x, any_value_contain_null#x, last_value#x, last_value_ignore_null#x, last_value_contain_null#x, rank#x, dense_rank#x, cume_dist#x, percent_rank#x, ntile#x, row_number#x, var_pop#x, var_samp#x, approx_count_distinct#xL, covar_pop#x, corr#x, stddev_samp#x, stddev_pop#x, coll [...]
-      +- Window [max(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS max#x, min(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, min(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, count(val#x) windowspecdefinition(cate#x [...]
+      +- Window [max(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS max#x, min(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, min(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS min#x, count(val#x) windowspecdefinition(cate#x [...]
         +- Project [val#x, cate#x, cast(val#x as double) AS _w0#x, cast(val_long#xL as double) AS _w1#x, val_double#x]
            +- SubqueryAlias testdata
               +- View (`testData`, [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x])
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 1a56b01a851d..0a9e08f5c57e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -610,6 +610,31 @@ class DataFrameAggregateSuite extends QueryTest
     )
   }
 
+  test("SPARK-55256: array_agg and collect_list skip nulls by default") {
+    val df = Seq((1, Some(2)), (2, None), (3, Some(4))).toDF("a", "b")
+
+    // Both functions skip nulls by default
+    checkAnswer(df.selectExpr("array_agg(b)"), Seq(Row(Seq(2, 4))))
+    checkAnswer(df.select(array_agg($"b")), Seq(Row(Seq(2, 4))))
+    checkAnswer(df.selectExpr("collect_list(b)"), Seq(Row(Seq(2, 4))))
+    checkAnswer(df.select(collect_list($"b")), Seq(Row(Seq(2, 4))))
+  }
+
+  test("SPARK-55256: array_agg with IGNORE NULLS explicitly skips nulls") {
+    val df = Seq((1, Some(2)), (2, None), (3, Some(4))).toDF("a", "b")
+
+    checkAnswer(df.selectExpr("array_agg(b) IGNORE NULLS"), Seq(Row(Seq(2, 
4))))
+    checkAnswer(df.selectExpr("collect_list(b) IGNORE NULLS"), Seq(Row(Seq(2, 
4))))
+  }
+
+  test("SPARK-55256: array_agg with RESPECT NULLS preserves nulls") {
+    val df = Seq((1, Some(2)), (2, None), (3, Some(4))).toDF("a", "b")
+
+    // RESPECT NULLS preserves null values in the result
+    checkAnswer(df.selectExpr("array_agg(b) RESPECT NULLS"), Seq(Row(Seq(2, 
null, 4))))
+    checkAnswer(df.selectExpr("collect_list(b) RESPECT NULLS"), Seq(Row(Seq(2, 
null, 4))))
+  }
+
   test("collect functions structs") {
     val df = Seq((1, 2, 2), (2, 2, 2), (3, 4, 1))
       .toDF("a", "x", "y")

