This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 76ced31429 feat: impl the basic `string_agg` function (#8148)
76ced31429 is described below
commit 76ced31429a4e324f9f57cb3e521e75739171e38
Author: Huaijin <[email protected]>
AuthorDate: Sat Nov 18 18:58:17 2023 +0800
feat: impl the basic `string_agg` function (#8148)
* init impl
* add support for large utf8
* add some tests
* support null
* remove redundant code
* remove redundant code
* add more test
* Update datafusion/physical-expr/src/aggregate/string_agg.rs
Co-authored-by: universalmind303 <[email protected]>
* Update datafusion/physical-expr/src/aggregate/string_agg.rs
Co-authored-by: universalmind303 <[email protected]>
* apply review suggestions
* Update datafusion/physical-expr/src/aggregate/string_agg.rs
Co-authored-by: Andrew Lamb <[email protected]>
* Update datafusion/sqllogictest/test_files/aggregate.slt
Co-authored-by: Andrew Lamb <[email protected]>
* Update datafusion/sqllogictest/test_files/aggregate.slt
Co-authored-by: Andrew Lamb <[email protected]>
* fix ci
---------
Co-authored-by: universalmind303 <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/expr/src/aggregate_function.rs | 8 +
datafusion/expr/src/type_coercion/aggregates.rs | 26 +++
datafusion/physical-expr/src/aggregate/build_in.rs | 16 ++
datafusion/physical-expr/src/aggregate/mod.rs | 1 +
.../physical-expr/src/aggregate/string_agg.rs | 246 +++++++++++++++++++++
datafusion/physical-expr/src/expressions/mod.rs | 1 +
datafusion/proto/proto/datafusion.proto | 1 +
datafusion/proto/src/generated/pbjson.rs | 3 +
datafusion/proto/src/generated/prost.rs | 3 +
datafusion/proto/src/logical_plan/from_proto.rs | 1 +
datafusion/proto/src/logical_plan/to_proto.rs | 4 +
datafusion/sqllogictest/test_files/aggregate.slt | 76 +++++++
12 files changed, 386 insertions(+)
diff --git a/datafusion/expr/src/aggregate_function.rs
b/datafusion/expr/src/aggregate_function.rs
index ea0b018251..4611c7fb10 100644
--- a/datafusion/expr/src/aggregate_function.rs
+++ b/datafusion/expr/src/aggregate_function.rs
@@ -100,6 +100,8 @@ pub enum AggregateFunction {
BoolAnd,
/// Bool Or
BoolOr,
+ /// string_agg
+ StringAgg,
}
impl AggregateFunction {
@@ -141,6 +143,7 @@ impl AggregateFunction {
BitXor => "BIT_XOR",
BoolAnd => "BOOL_AND",
BoolOr => "BOOL_OR",
+ StringAgg => "STRING_AGG",
}
}
}
@@ -171,6 +174,7 @@ impl FromStr for AggregateFunction {
"array_agg" => AggregateFunction::ArrayAgg,
"first_value" => AggregateFunction::FirstValue,
"last_value" => AggregateFunction::LastValue,
+ "string_agg" => AggregateFunction::StringAgg,
// statistical
"corr" => AggregateFunction::Correlation,
"covar" => AggregateFunction::Covariance,
@@ -299,6 +303,7 @@ impl AggregateFunction {
AggregateFunction::FirstValue | AggregateFunction::LastValue => {
Ok(coerced_data_types[0].clone())
}
+ AggregateFunction::StringAgg => Ok(DataType::LargeUtf8),
}
}
}
@@ -408,6 +413,9 @@ impl AggregateFunction {
.collect(),
Volatility::Immutable,
),
+ AggregateFunction::StringAgg => {
+ Signature::uniform(2, STRINGS.to_vec(), Volatility::Immutable)
+ }
}
}
}
diff --git a/datafusion/expr/src/type_coercion/aggregates.rs
b/datafusion/expr/src/type_coercion/aggregates.rs
index 261c406d5d..7128b57597 100644
--- a/datafusion/expr/src/type_coercion/aggregates.rs
+++ b/datafusion/expr/src/type_coercion/aggregates.rs
@@ -298,6 +298,23 @@ pub fn coerce_types(
| AggregateFunction::FirstValue
| AggregateFunction::LastValue => Ok(input_types.to_vec()),
AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
+ AggregateFunction::StringAgg => {
+ if !is_string_agg_supported_arg_type(&input_types[0]) {
+ return plan_err!(
+ "The function {:?} does not support inputs of type {:?}",
+ agg_fun,
+ input_types[0]
+ );
+ }
+ if !is_string_agg_supported_arg_type(&input_types[1]) {
+ return plan_err!(
+ "The function {:?} does not support inputs of type {:?}",
+ agg_fun,
+ input_types[1]
+ );
+ }
+ Ok(vec![LargeUtf8, input_types[1].clone()])
+ }
}
}
@@ -565,6 +582,15 @@ pub fn
is_approx_percentile_cont_supported_arg_type(arg_type: &DataType) -> bool
)
}
+/// Returns `true` when `arg_type` is a [`DataType`] accepted as an argument
+/// of the [`AggregateFunction::StringAgg`] aggregation: `Utf8`, `LargeUtf8`,
+/// or an untyped `Null`.
+pub fn is_string_agg_supported_arg_type(arg_type: &DataType) -> bool {
+    [DataType::Utf8, DataType::LargeUtf8, DataType::Null].contains(arg_type)
+}
+
#[cfg(test)]
mod tests {
use super::*;
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs
b/datafusion/physical-expr/src/aggregate/build_in.rs
index 596197b4ee..c40f0db194 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -369,6 +369,22 @@ pub fn create_aggregate_expr(
ordering_req.to_vec(),
ordering_types,
)),
+ (AggregateFunction::StringAgg, false) => {
+ if !ordering_req.is_empty() {
+ return not_impl_err!(
+ "STRING_AGG(ORDER BY a ASC) order-sensitive aggregations
are not available"
+ );
+ }
+ Arc::new(expressions::StringAgg::new(
+ input_phy_exprs[0].clone(),
+ input_phy_exprs[1].clone(),
+ name,
+ data_type,
+ ))
+ }
+ (AggregateFunction::StringAgg, true) => {
+ return not_impl_err!("STRING_AGG(DISTINCT) aggregations are not
available");
+ }
})
}
diff --git a/datafusion/physical-expr/src/aggregate/mod.rs
b/datafusion/physical-expr/src/aggregate/mod.rs
index 442d018b87..329bb1e641 100644
--- a/datafusion/physical-expr/src/aggregate/mod.rs
+++ b/datafusion/physical-expr/src/aggregate/mod.rs
@@ -43,6 +43,7 @@ pub(crate) mod covariance;
pub(crate) mod first_last;
pub(crate) mod grouping;
pub(crate) mod median;
+pub(crate) mod string_agg;
#[macro_use]
pub(crate) mod min_max;
pub mod build_in;
diff --git a/datafusion/physical-expr/src/aggregate/string_agg.rs
b/datafusion/physical-expr/src/aggregate/string_agg.rs
new file mode 100644
index 0000000000..74c083959e
--- /dev/null
+++ b/datafusion/physical-expr/src/aggregate/string_agg.rs
@@ -0,0 +1,246 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`StringAgg`] expression and [`StringAggAccumulator`] accumulator for the `string_agg` function
+
+use crate::aggregate::utils::down_cast_any_ref;
+use crate::expressions::{format_state_name, Literal};
+use crate::{AggregateExpr, PhysicalExpr};
+use arrow::array::ArrayRef;
+use arrow::datatypes::{DataType, Field};
+use datafusion_common::cast::as_generic_string_array;
+use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue};
+use datafusion_expr::Accumulator;
+use std::any::Any;
+use std::sync::Arc;
+
+/// STRING_AGG aggregate expression: concatenates the non-null string values
+/// produced by `expr`, separated by `delimiter`.
+#[derive(Debug)]
+pub struct StringAgg {
+    /// Output column name; also used to derive the state field name.
+    name: String,
+    /// Type of the output field and of the intermediate state field.
+    data_type: DataType,
+    /// Expression producing the values to concatenate.
+    expr: Arc<dyn PhysicalExpr>,
+    /// Separator expression; `create_accumulator` only accepts a `Literal`.
+    delimiter: Arc<dyn PhysicalExpr>,
+    /// Nullability of the output and state fields (`new` always sets `true`).
+    nullable: bool,
+}
+
+impl StringAgg {
+    /// Builds a `STRING_AGG(expr, delimiter)` aggregate called `name` whose
+    /// result has type `data_type`. The result field is always nullable.
+    pub fn new(
+        expr: Arc<dyn PhysicalExpr>,
+        delimiter: Arc<dyn PhysicalExpr>,
+        name: impl Into<String>,
+        data_type: DataType,
+    ) -> Self {
+        let name: String = name.into();
+        Self {
+            name,
+            data_type,
+            expr,
+            delimiter,
+            nullable: true,
+        }
+    }
+}
+
+impl AggregateExpr for StringAgg {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// Output field: named after the aggregate, typed `data_type`, nullable.
+    fn field(&self) -> Result<Field> {
+        Ok(Field::new(&self.name, self.data_type.clone(), self.nullable))
+    }
+
+    /// Creates a [`StringAggAccumulator`] for this aggregate's delimiter.
+    ///
+    /// # Errors
+    ///
+    /// Returns a "not implemented" error when the delimiter is not a literal,
+    /// or is a literal of a non-string type.
+    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        // Only constant delimiters are supported; the value is resolved once
+        // here instead of being evaluated per row.
+        let literal = match self.delimiter.as_any().downcast_ref::<Literal>() {
+            Some(literal) => literal,
+            None => return not_impl_err!("StringAgg not supported for {}", self.name),
+        };
+        let delimiter = match literal.value() {
+            ScalarValue::Utf8(Some(delimiter))
+            | ScalarValue::LargeUtf8(Some(delimiter)) => delimiter.as_str(),
+            // A NULL delimiter concatenates values with no separator. Treat a
+            // string-typed NULL the same as an untyped one instead of
+            // rejecting it as unsupported.
+            ScalarValue::Null
+            | ScalarValue::Utf8(None)
+            | ScalarValue::LargeUtf8(None) => "",
+            _ => return not_impl_err!("StringAgg not supported for {}", self.name),
+        };
+        Ok(Box::new(StringAggAccumulator::new(delimiter)))
+    }
+
+    /// Single intermediate state: the partially concatenated string.
+    fn state_fields(&self) -> Result<Vec<Field>> {
+        Ok(vec![Field::new(
+            format_state_name(&self.name, "string_agg"),
+            self.data_type.clone(),
+            self.nullable,
+        )])
+    }
+
+    /// Input expressions in argument order: values first, then delimiter.
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.expr.clone(), self.delimiter.clone()]
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+}
+
+impl PartialEq<dyn Any> for StringAgg {
+    /// Equal when `other` is also a `StringAgg` with the same name, output
+    /// type, value expression and delimiter expression; any other type
+    /// compares unequal.
+    fn eq(&self, other: &dyn Any) -> bool {
+        match down_cast_any_ref(other).downcast_ref::<Self>() {
+            Some(that) => {
+                self.name == that.name
+                    && self.data_type == that.data_type
+                    && self.expr.eq(&that.expr)
+                    && self.delimiter.eq(&that.delimiter)
+            }
+            None => false,
+        }
+    }
+}
+
+/// Accumulator for [`StringAgg`]: holds the delimited concatenation of the
+/// non-null input strings seen so far.
+#[derive(Debug)]
+pub(crate) struct StringAggAccumulator {
+    /// Concatenated result. `None` until a value is accumulated, so that
+    /// `evaluate` yields a NULL `LargeUtf8` when no value was seen.
+    values: Option<String>,
+    /// Separator inserted between consecutive values.
+    delimiter: String,
+}
+
+impl StringAggAccumulator {
+    /// Creates an accumulator with no values yet, joining with `delimiter`.
+    pub fn new(delimiter: &str) -> Self {
+        let delimiter = delimiter.to_owned();
+        Self {
+            delimiter,
+            values: None,
+        }
+    }
+}
+
+impl Accumulator for StringAggAccumulator {
+    /// Appends each non-null string of `values[0]` to the running result,
+    /// inserting the delimiter between consecutive values. Null entries are
+    /// skipped, so an all-null batch leaves the state unchanged.
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        // Planning coerces the value argument to `LargeUtf8` (i64 offsets).
+        let strings = as_generic_string_array::<i64>(&values[0])?;
+        for value in strings.iter().flatten() {
+            match self.values.as_mut() {
+                // Whether this is the first value is decided by presence, not
+                // emptiness: an accumulated empty string (e.g. from `''`)
+                // must still be followed by a delimiter.
+                Some(acc) => {
+                    acc.push_str(&self.delimiter);
+                    acc.push_str(value);
+                }
+                None => self.values = Some(value.to_owned()),
+            }
+        }
+        Ok(())
+    }
+
+    /// Merges partial results produced by [`Self::state`]. Each state is an
+    /// already-delimited string (NULL when no value was accumulated), so
+    /// merging concatenates them exactly like ordinary inputs.
+    fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        self.update_batch(values)
+    }
+
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        Ok(vec![self.evaluate()?])
+    }
+
+    /// Returns the concatenated string, or a NULL `LargeUtf8` if no value
+    /// was accumulated.
+    fn evaluate(&self) -> Result<ScalarValue> {
+        Ok(ScalarValue::LargeUtf8(self.values.clone()))
+    }
+
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+            + self.values.as_ref().map(|v| v.capacity()).unwrap_or(0)
+            + self.delimiter.capacity()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::expressions::tests::aggregate;
+    use crate::expressions::{col, create_aggregate_expr, try_cast};
+    use arrow::array::ArrayRef;
+    use arrow::datatypes::*;
+    use arrow::record_batch::RecordBatch;
+    use arrow_array::LargeStringArray;
+    use arrow_array::StringArray;
+    use datafusion_expr::type_coercion::aggregates::coerce_types;
+    use datafusion_expr::AggregateFunction;
+
+    /// Runs `function` over `array` with a literal `delimiter` and asserts the
+    /// single-batch result equals `expected`.
+    ///
+    /// The input column is cast to the type `coerce_types` picks for the first
+    /// argument, as the planner would do before execution.
+    fn assert_string_aggregate(
+        array: ArrayRef,
+        function: AggregateFunction,
+        distinct: bool,
+        expected: ScalarValue,
+        delimiter: String,
+    ) {
+        let input_type = array.data_type().clone();
+        let signature = function.signature();
+        let coerced_types =
+            coerce_types(&function, &[input_type.clone(), DataType::Utf8], &signature)
+                .unwrap();
+        let value_type = coerced_types[0].clone();
+
+        let input_schema = Schema::new(vec![Field::new("a", input_type, true)]);
+        let batch =
+            RecordBatch::try_new(Arc::new(input_schema.clone()), vec![array]).unwrap();
+
+        let input = try_cast(
+            col("a", &input_schema).unwrap(),
+            &input_schema,
+            value_type.clone(),
+        )
+        .unwrap();
+        let delimiter = Arc::new(Literal::new(ScalarValue::Utf8(Some(delimiter))));
+
+        let agg_schema = Schema::new(vec![Field::new("a", value_type, true)]);
+        let agg = create_aggregate_expr(
+            &function,
+            distinct,
+            &[input, delimiter],
+            &[],
+            &agg_schema,
+            "agg",
+        )
+        .unwrap();
+
+        assert_eq!(expected, aggregate(&batch, agg).unwrap());
+    }
+
+    #[test]
+    fn string_agg_utf8() {
+        let values: ArrayRef = Arc::new(StringArray::from(vec!["h", "e", "l", "l", "o"]));
+        assert_string_aggregate(
+            values,
+            AggregateFunction::StringAgg,
+            false,
+            ScalarValue::LargeUtf8(Some("h,e,l,l,o".to_owned())),
+            ",".to_owned(),
+        );
+    }
+
+    #[test]
+    fn string_agg_largeutf8() {
+        let values: ArrayRef =
+            Arc::new(LargeStringArray::from(vec!["h", "e", "l", "l", "o"]));
+        assert_string_aggregate(
+            values,
+            AggregateFunction::StringAgg,
+            false,
+            ScalarValue::LargeUtf8(Some("h|e|l|l|o".to_owned())),
+            "|".to_owned(),
+        );
+    }
+}
diff --git a/datafusion/physical-expr/src/expressions/mod.rs
b/datafusion/physical-expr/src/expressions/mod.rs
index 1919cac979..b6d0ad5b91 100644
--- a/datafusion/physical-expr/src/expressions/mod.rs
+++ b/datafusion/physical-expr/src/expressions/mod.rs
@@ -63,6 +63,7 @@ pub use crate::aggregate::min_max::{MaxAccumulator,
MinAccumulator};
pub use crate::aggregate::regr::{Regr, RegrType};
pub use crate::aggregate::stats::StatsType;
pub use crate::aggregate::stddev::{Stddev, StddevPop};
+pub use crate::aggregate::string_agg::StringAgg;
pub use crate::aggregate::sum::Sum;
pub use crate::aggregate::sum_distinct::DistinctSum;
pub use crate::aggregate::variance::{Variance, VariancePop};
diff --git a/datafusion/proto/proto/datafusion.proto
b/datafusion/proto/proto/datafusion.proto
index 750d12bd77..9d508078c7 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -686,6 +686,7 @@ enum AggregateFunction {
REGR_SXX = 32;
REGR_SYY = 33;
REGR_SXY = 34;
+ STRING_AGG = 35;
}
message AggregateExprNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs
b/datafusion/proto/src/generated/pbjson.rs
index af64bd68de..0a8f415e20 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -474,6 +474,7 @@ impl serde::Serialize for AggregateFunction {
Self::RegrSxx => "REGR_SXX",
Self::RegrSyy => "REGR_SYY",
Self::RegrSxy => "REGR_SXY",
+ Self::StringAgg => "STRING_AGG",
};
serializer.serialize_str(variant)
}
@@ -520,6 +521,7 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction {
"REGR_SXX",
"REGR_SYY",
"REGR_SXY",
+ "STRING_AGG",
];
struct GeneratedVisitor;
@@ -595,6 +597,7 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction {
"REGR_SXX" => Ok(AggregateFunction::RegrSxx),
"REGR_SYY" => Ok(AggregateFunction::RegrSyy),
"REGR_SXY" => Ok(AggregateFunction::RegrSxy),
+ "STRING_AGG" => Ok(AggregateFunction::StringAgg),
_ => Err(serde::de::Error::unknown_variant(value, FIELDS)),
}
}
diff --git a/datafusion/proto/src/generated/prost.rs
b/datafusion/proto/src/generated/prost.rs
index b23f09e91b..84fb84b948 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -2881,6 +2881,7 @@ pub enum AggregateFunction {
RegrSxx = 32,
RegrSyy = 33,
RegrSxy = 34,
+ StringAgg = 35,
}
impl AggregateFunction {
/// String value of the enum field names used in the ProtoBuf definition.
@@ -2926,6 +2927,7 @@ impl AggregateFunction {
AggregateFunction::RegrSxx => "REGR_SXX",
AggregateFunction::RegrSyy => "REGR_SYY",
AggregateFunction::RegrSxy => "REGR_SXY",
+ AggregateFunction::StringAgg => "STRING_AGG",
}
}
/// Creates an enum from field names used in the ProtoBuf definition.
@@ -2968,6 +2970,7 @@ impl AggregateFunction {
"REGR_SXX" => Some(Self::RegrSxx),
"REGR_SYY" => Some(Self::RegrSyy),
"REGR_SXY" => Some(Self::RegrSxy),
+ "STRING_AGG" => Some(Self::StringAgg),
_ => None,
}
}
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs
b/datafusion/proto/src/logical_plan/from_proto.rs
index f59a59f3c0..4ae45fa521 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -597,6 +597,7 @@ impl From<protobuf::AggregateFunction> for
AggregateFunction {
protobuf::AggregateFunction::Median => Self::Median,
protobuf::AggregateFunction::FirstValueAgg => Self::FirstValue,
protobuf::AggregateFunction::LastValueAgg => Self::LastValue,
+ protobuf::AggregateFunction::StringAgg => Self::StringAgg,
}
}
}
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs
b/datafusion/proto/src/logical_plan/to_proto.rs
index 8bf4258236..cf66e3ddd5 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -405,6 +405,7 @@ impl From<&AggregateFunction> for
protobuf::AggregateFunction {
AggregateFunction::Median => Self::Median,
AggregateFunction::FirstValue => Self::FirstValueAgg,
AggregateFunction::LastValue => Self::LastValueAgg,
+ AggregateFunction::StringAgg => Self::StringAgg,
}
}
}
@@ -721,6 +722,9 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
AggregateFunction::LastValue => {
protobuf::AggregateFunction::LastValueAgg
}
+ AggregateFunction::StringAgg => {
+ protobuf::AggregateFunction::StringAgg
+ }
};
let aggregate_expr = protobuf::AggregateExprNode {
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt
b/datafusion/sqllogictest/test_files/aggregate.slt
index a1bb93ed53..0a495dd2b0 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -2987,3 +2987,79 @@ NULL NULL 1 NULL 3 6 0 0 0
NULL NULL 1 NULL 5 15 0 0 0
3 0 2 1 5.5 16.5 0.5 4.5 1.5
3 0 3 1 6 18 2 18 6
+
+statement error
+SELECT STRING_AGG()
+
+statement error
+SELECT STRING_AGG(1,2,3)
+
+statement error
+SELECT STRING_AGG(STRING_AGG('a', ','))
+
+query T
+SELECT STRING_AGG('a', ',')
+----
+a
+
+query TTTT
+SELECT STRING_AGG('a',','), STRING_AGG('a', NULL), STRING_AGG(NULL, ','),
STRING_AGG(NULL, NULL)
+----
+a a NULL NULL
+
+query TT
+select string_agg('', '|'), string_agg('a', '');
+----
+(empty) a
+
+query T
+SELECT STRING_AGG(column1, '|') FROM (values (''), (null), (''));
+----
+|
+
+statement ok
+CREATE TABLE strings(g INTEGER, x VARCHAR, y VARCHAR)
+
+query ITT
+INSERT INTO strings VALUES (1,'a','/'), (1,'b','-'), (2,'i','/'),
(2,NULL,'-'), (2,'j','+'), (3,'p','/'), (4,'x','/'), (4,'y','-'), (4,'z','+')
+----
+9
+
+query IT
+SELECT g, STRING_AGG(x,'|') FROM strings GROUP BY g ORDER BY g
+----
+1 a|b
+2 i|j
+3 p
+4 x|y|z
+
+query T
+SELECT STRING_AGG(x,',') FROM strings WHERE g > 100
+----
+NULL
+
+statement ok
+drop table strings
+
+query T
+WITH my_data as (
+SELECT 'text1'::varchar(1000) as my_column union all
+SELECT 'text1'::varchar(1000) as my_column union all
+SELECT 'text1'::varchar(1000) as my_column
+)
+SELECT string_agg(my_column,', ') as my_string_agg
+FROM my_data
+----
+text1, text1, text1
+
+query T
+WITH my_data as (
+SELECT 1 as dummy, 'text1'::varchar(1000) as my_column union all
+SELECT 1 as dummy, 'text1'::varchar(1000) as my_column union all
+SELECT 1 as dummy, 'text1'::varchar(1000) as my_column
+)
+SELECT string_agg(my_column,', ') as my_string_agg
+FROM my_data
+GROUP BY dummy
+----
+text1, text1, text1