This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 821d410fc0 feat(spark): Implement collect_list/collect_set aggregate functions (#19699)
821d410fc0 is described below

commit 821d410fc0da0770a834bcb83bf486dee28ec0b2
Author: cht42 <[email protected]>
AuthorDate: Sat Jan 10 06:17:45 2026 +0400

    feat(spark): Implement collect_list/collect_set aggregate functions (#19699)
    
    ## Which issue does this PR close?
    
    - Part of #15914
    - Closes #17923
    - Closes #17924
    
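    ## Rationale for this change
    
    Spark's `collect_list`/`collect_set` differ from DataFusion's `array_agg`:
    they ignore NULL inputs and return an empty list instead of NULL when all
    inputs are NULL, so dedicated Spark-compatible implementations are needed.
    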
    ## What changes are included in this PR?
    
    Implementation of spark `collect_list` and `collect_set` aggregate
    functions.
    
    ## Are these changes tested?
    
    Yes, with new sqllogictest cases in `spark/aggregate/collect.slt`.
    
    ## Are there any user-facing changes?
    
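    Yes: Spark-compatible `collect_list` and `collect_set` aggregate functions
    are added, and `DistinctArrayAggAccumulator` in
    `datafusion-functions-aggregate` is now public.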
---
 Cargo.lock                                         |   1 +
 datafusion/functions-aggregate/src/array_agg.rs    |   2 +-
 datafusion/spark/Cargo.toml                        |   1 +
 datafusion/spark/src/function/aggregate/collect.rs | 200 +++++++++++++++++++++
 datafusion/spark/src/function/aggregate/mod.rs     |  19 +-
 .../test_files/spark/aggregate/collect.slt         |  93 ++++++++++
 6 files changed, 314 insertions(+), 2 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index aefd6a63f5..ad347e1072 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2570,6 +2570,7 @@ dependencies = [
  "datafusion-execution",
  "datafusion-expr",
  "datafusion-functions",
+ "datafusion-functions-aggregate",
  "datafusion-functions-nested",
  "log",
  "percent-encoding",
diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs
index 9b2e7429ab..c07958a858 100644
--- a/datafusion/functions-aggregate/src/array_agg.rs
+++ b/datafusion/functions-aggregate/src/array_agg.rs
@@ -415,7 +415,7 @@ impl Accumulator for ArrayAggAccumulator {
 }
 
 #[derive(Debug)]
-struct DistinctArrayAggAccumulator {
+pub struct DistinctArrayAggAccumulator {
     values: HashSet<ScalarValue>,
     datatype: DataType,
     sort_options: Option<SortOptions>,
diff --git a/datafusion/spark/Cargo.toml b/datafusion/spark/Cargo.toml
index 673b62c5c3..0dc35f4a87 100644
--- a/datafusion/spark/Cargo.toml
+++ b/datafusion/spark/Cargo.toml
@@ -48,6 +48,7 @@ datafusion-common = { workspace = true }
 datafusion-execution = { workspace = true }
 datafusion-expr = { workspace = true }
 datafusion-functions = { workspace = true, features = ["crypto_expressions"] }
+datafusion-functions-aggregate = { workspace = true }
 datafusion-functions-nested = { workspace = true }
 log = { workspace = true }
 percent-encoding = "2.3.2"
diff --git a/datafusion/spark/src/function/aggregate/collect.rs b/datafusion/spark/src/function/aggregate/collect.rs
new file mode 100644
index 0000000000..50497e2826
--- /dev/null
+++ b/datafusion/spark/src/function/aggregate/collect.rs
@@ -0,0 +1,200 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::ArrayRef;
+use arrow::datatypes::{DataType, Field, FieldRef};
+use datafusion_common::utils::SingleRowListArrayBuilder;
+use datafusion_common::{Result, ScalarValue};
+use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
+use datafusion_expr::utils::format_state_name;
+use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
+use datafusion_functions_aggregate::array_agg::{
+    ArrayAggAccumulator, DistinctArrayAggAccumulator,
+};
+use std::{any::Any, sync::Arc};
+
+// Spark implementations of the collect_list/collect_set aggregate functions.
+// They differ from DataFusion's ArrayAgg in the following ways:
+// - NULL inputs are ignored
+// - an empty list (rather than NULL) is returned when all inputs are NULL
+// - ordering is not supported
+
+/// Spark-compatible `collect_list` aggregate: collects the non-NULL input
+/// values of each group into a list, keeping duplicates.
+///
+/// See <https://spark.apache.org/docs/latest/api/sql/index.html#collect_list>.
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SparkCollectList {
+    /// Accepts exactly one argument of any type.
+    signature: Signature,
+}
+
+impl Default for SparkCollectList {
+    /// Same as [`SparkCollectList::new`].
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkCollectList {
+    /// Creates the UDF with a one-argument, any-type, `Immutable` signature.
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::any(1, Volatility::Immutable),
+        }
+    }
+}
+
+/// DataFusion UDAF implementation backing Spark's `collect_list`.
+impl AggregateUDFImpl for SparkCollectList {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "collect_list"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    /// Returns a `List` whose (nullable) items have the type of the single
+    /// argument. Indexing `arg_types[0]` presumably relies on the one-argument
+    /// signature having been enforced before this is called.
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::List(Arc::new(Field::new_list_field(
+            arg_types[0].clone(),
+            true,
+        ))))
+    }
+
+    /// Declares one list-typed intermediate state column whose items have the
+    /// input type.
+    // NOTE(review): this is expected to match what ArrayAggAccumulator::state()
+    // emits — re-check if the inner accumulator changes.
+    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
+        Ok(vec![
+            Field::new_list(
+                format_state_name(args.name, "collect_list"),
+                
Field::new_list_field(args.input_fields[0].data_type().clone(), true),
+                true,
+            )
+            .into(),
+        ])
+    }
+
+    /// Builds a NULL-skipping `ArrayAggAccumulator` wrapped so that a NULL
+    /// final result is reported as an empty list.
+    // NOTE(review): `acc_args.is_distinct` is never read here, so
+    // `collect_list(DISTINCT x)` may keep duplicates unless the planner
+    // rewrites it — confirm whether DISTINCT should be rejected or routed to
+    // DistinctArrayAggAccumulator.
+    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn 
Accumulator>> {
+        let field = &acc_args.expr_fields[0];
+        let data_type = field.data_type().clone();
+        // Spark semantics: NULL inputs are skipped, regardless of any
+        // IGNORE/RESPECT NULLS setting passed in `acc_args`.
+        let ignore_nulls = true;
+        Ok(Box::new(NullToEmptyListAccumulator::new(
+            ArrayAggAccumulator::try_new(&data_type, ignore_nulls)?,
+            data_type,
+        )))
+    }
+}
+
+/// Spark-compatible `collect_set` aggregate: collects the distinct non-NULL
+/// input values of each group into a list. Element order is not guaranteed
+/// (the sqllogictests wrap the result in `array_sort` for stable output).
+///
+/// See <https://spark.apache.org/docs/latest/api/sql/index.html#collect_set>.
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SparkCollectSet {
+    /// Accepts exactly one argument of any type.
+    signature: Signature,
+}
+
+impl Default for SparkCollectSet {
+    /// Same as [`SparkCollectSet::new`].
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkCollectSet {
+    /// Creates the UDF with a one-argument, any-type, `Immutable` signature.
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::any(1, Volatility::Immutable),
+        }
+    }
+}
+
+/// DataFusion UDAF implementation backing Spark's `collect_set`.
+impl AggregateUDFImpl for SparkCollectSet {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "collect_set"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    /// Returns a `List` (not a dedicated set type) whose nullable items have
+    /// the type of the single argument.
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::List(Arc::new(Field::new_list_field(
+            arg_types[0].clone(),
+            true,
+        ))))
+    }
+
+    /// Declares one list-typed intermediate state column whose items have the
+    /// input type.
+    // NOTE(review): this is expected to match what
+    // DistinctArrayAggAccumulator::state() emits — re-check if it changes.
+    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
+        Ok(vec![
+            Field::new_list(
+                format_state_name(args.name, "collect_set"),
+                
Field::new_list_field(args.input_fields[0].data_type().clone(), true),
+                true,
+            )
+            .into(),
+        ])
+    }
+
+    /// Builds a NULL-skipping `DistinctArrayAggAccumulator` wrapped so that a
+    /// NULL final result is reported as an empty list.
+    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn 
Accumulator>> {
+        let field = &acc_args.expr_fields[0];
+        let data_type = field.data_type().clone();
+        // Spark semantics: NULL inputs are skipped.
+        let ignore_nulls = true;
+        // The `None` argument presumably fills the accumulator's
+        // `sort_options` field, leaving the set's element order unspecified.
+        Ok(Box::new(NullToEmptyListAccumulator::new(
+            DistinctArrayAggAccumulator::try_new(&data_type, None, 
ignore_nulls)?,
+            data_type,
+        )))
+    }
+}
+
+/// Wrapper accumulator that returns an empty list instead of NULL whenever the
+/// inner accumulator evaluates to NULL (e.g. when all inputs are NULL).
+/// This implements Spark's behavior for collect_list and collect_set.
+///
+/// Every other `Accumulator` call is delegated to `inner` unchanged.
+#[derive(Debug)]
+struct NullToEmptyListAccumulator<T: Accumulator> {
+    /// Accumulator that performs the actual collection.
+    inner: T,
+    /// Element type of the collected list; used to build the empty list.
+    data_type: DataType,
+}
+
+impl<T: Accumulator> NullToEmptyListAccumulator<T> {
+    /// Wraps `inner`; `data_type` is the element type of the resulting list.
+    pub fn new(inner: T, data_type: DataType) -> Self {
+        Self { inner, data_type }
+    }
+}
+
+impl<T: Accumulator> Accumulator for NullToEmptyListAccumulator<T> {
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        self.inner.update_batch(values)
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        self.inner.merge_batch(states)
+    }
+
+    // Partial state is forwarded untouched: the NULL -> [] substitution is
+    // applied only in `evaluate`, so states handed to `merge_batch` keep the
+    // inner accumulator's own format.
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        self.inner.state()
+    }
+
+    fn evaluate(&mut self) -> Result<ScalarValue> {
+        let result = self.inner.evaluate()?;
+        if result.is_null() {
+            // Build a zero-length list scalar whose items have `data_type`.
+            let empty_array = arrow::array::new_empty_array(&self.data_type);
+            Ok(SingleRowListArrayBuilder::new(empty_array).build_list_scalar())
+        } else {
+            Ok(result)
+        }
+    }
+
+    fn size(&self) -> usize {
+        // Inner accumulator's reported size plus that of the stored DataType.
+        self.inner.size() + self.data_type.size()
+    }
+}
diff --git a/datafusion/spark/src/function/aggregate/mod.rs b/datafusion/spark/src/function/aggregate/mod.rs
index 3db72669d4..d6a2fe7a85 100644
--- a/datafusion/spark/src/function/aggregate/mod.rs
+++ b/datafusion/spark/src/function/aggregate/mod.rs
@@ -19,6 +19,7 @@ use datafusion_expr::AggregateUDF;
 use std::sync::Arc;
 
 pub mod avg;
+pub mod collect;
 pub mod try_sum;
 
 pub mod expr_fn {
@@ -30,6 +31,16 @@ pub mod expr_fn {
         "Returns the sum of values for a column, or NULL if overflow occurs",
         arg1
     ));
+    // collect_list(arg1): NULL-skipping aggregate into a list (duplicates kept).
+    export_functions!((
+        collect_list,
+        "Returns a list created from the values in a column",
+        arg1
+    ));
+    // collect_set(arg1): NULL-skipping aggregate of distinct values; the
+    // result type is a List with unspecified element order.
+    export_functions!((
+        collect_set,
+        "Returns a set created from the values in a column",
+        arg1
+    ));
 }
 
 // TODO: try use something like datafusion_functions_aggregate::create_func!()
@@ -39,7 +50,13 @@ pub fn avg() -> Arc<AggregateUDF> {
 pub fn try_sum() -> Arc<AggregateUDF> {
     Arc::new(AggregateUDF::new_from_impl(try_sum::SparkTrySum::new()))
 }
+/// Returns the Spark-compatible `collect_list` aggregate UDF
+/// (backed by `collect::SparkCollectList`).
+pub fn collect_list() -> Arc<AggregateUDF> {
+    Arc::new(AggregateUDF::new_from_impl(collect::SparkCollectList::new()))
+}
+/// Returns the Spark-compatible `collect_set` aggregate UDF
+/// (backed by `collect::SparkCollectSet`).
+pub fn collect_set() -> Arc<AggregateUDF> {
+    Arc::new(AggregateUDF::new_from_impl(collect::SparkCollectSet::new()))
+}
 
 pub fn functions() -> Vec<Arc<AggregateUDF>> {
-    vec![avg(), try_sum()]
+    // Spark aggregate UDFs exported by this module.
+    vec![avg(), try_sum(), collect_list(), collect_set()]
 }
diff --git a/datafusion/sqllogictest/test_files/spark/aggregate/collect.slt b/datafusion/sqllogictest/test_files/spark/aggregate/collect.slt
new file mode 100644
index 0000000000..2bd80e2e13
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/spark/aggregate/collect.slt
@@ -0,0 +1,93 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+
+#   http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+query ?
+SELECT collect_list(a) FROM (VALUES (1), (2), (3)) AS t(a);
+----
+[1, 2, 3]
+
+query ?
+SELECT collect_list(a) FROM (VALUES (1), (2), (2), (3), (1)) AS t(a);
+----
+[1, 2, 2, 3, 1]
+
+query ?
+SELECT collect_list(a) FROM (VALUES (1), (NULL), (3)) AS t(a);
+----
+[1, 3]
+
+query ?
+SELECT collect_list(a) FROM (VALUES (CAST(NULL AS INT)), (NULL), (NULL)) AS 
t(a);
+----
+[]
+
+query I?
+SELECT g, collect_list(a)
+FROM (VALUES (1, 10), (1, 20), (2, 30), (2, 30), (1, 10)) AS t(g, a)
+GROUP BY g
+ORDER BY g;
+----
+1 [10, 20, 10]
+2 [30, 30]
+
+query I?
+SELECT g, collect_list(a)
+FROM (VALUES (1, 10), (1, NULL), (2, 20), (2, NULL)) AS t(g, a)
+GROUP BY g
+ORDER BY g;
+----
+1 [10]
+2 [20]
+
+# collect_set returns elements in no guaranteed order, so wrap it in array_sort for deterministic output
+query ?
+SELECT array_sort(collect_set(a)) FROM (VALUES (1), (2), (3)) AS t(a);
+----
+[1, 2, 3]
+
+query ?
+SELECT array_sort(collect_set(a)) FROM (VALUES (1), (2), (2), (3), (1)) AS 
t(a);
+----
+[1, 2, 3]
+
+query ?
+SELECT array_sort(collect_set(a)) FROM (VALUES (1), (NULL), (3)) AS t(a);
+----
+[1, 3]
+
+query ?
+SELECT array_sort(collect_set(a)) FROM (VALUES (CAST(NULL AS INT)), (NULL), 
(NULL)) AS t(a);
+----
+[]
+
+query I?
+SELECT g, array_sort(collect_set(a))
+FROM (VALUES (1, 10), (1, 20), (2, 30), (2, 30), (1, 10)) AS t(g, a)
+GROUP BY g
+ORDER BY g;
+----
+1 [10, 20]
+2 [30]
+
+query I?
+SELECT g, array_sort(collect_set(a))
+FROM (VALUES (1, 10), (1, NULL), (1, NULL), (2, 20), (2, NULL)) AS t(g, a)
+GROUP BY g
+ORDER BY g;
+----
+1 [10]
+2 [20]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to