This is an automated email from the ASF dual-hosted git repository.
paleolimbot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git
The following commit(s) were added to refs/heads/main by this push:
new fb158760 feat(rust/sedona-expr): Add GroupsAccumulator to framework
and implementation for ST_Envelope_Agg (#510)
fb158760 is described below
commit fb158760fa1630ae7cb3e48759884bbedeb8508a
Author: Dewey Dunnington <[email protected]>
AuthorDate: Thu Jan 15 20:54:30 2026 -0600
feat(rust/sedona-expr): Add GroupsAccumulator to framework and
implementation for ST_Envelope_Agg (#510)
Co-authored-by: Copilot <[email protected]>
---
python/sedonadb/tests/functions/test_aggregate.py | 114 ++++++++++-
rust/sedona-expr/src/aggregate_udf.rs | 61 +++++-
rust/sedona-functions/src/st_envelope_agg.rs | 232 +++++++++++++++++++++-
rust/sedona-testing/src/testers.rs | 147 +++++++++++---
4 files changed, 506 insertions(+), 48 deletions(-)
diff --git a/python/sedonadb/tests/functions/test_aggregate.py
b/python/sedonadb/tests/functions/test_aggregate.py
index 6f6344d1..fe345dca 100644
--- a/python/sedonadb/tests/functions/test_aggregate.py
+++ b/python/sedonadb/tests/functions/test_aggregate.py
@@ -16,14 +16,126 @@
# under the License.
import pytest
+import shapely
from sedonadb.testing import PostGIS, SedonaDB
+# Aggregate functions don't have a suffix in PostGIS
def agg_fn_suffix(eng):
- """Return the appropriate suffix for the aggregate function for the given
engine."""
return "" if isinstance(eng, PostGIS) else "_Agg"
+# ST_Envelope is not an aggregate function in PostGIS but we can check
+# behaviour using ST_Envelope(ST_Collect(...))
+def call_st_envelope_agg(eng, arg):
+ if isinstance(eng, PostGIS):
+ return f"ST_Envelope(ST_Collect({arg}))"
+ else:
+ return f"ST_Envelope_Agg({arg})"
+
+
+@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
+def test_st_envelope_agg_points(eng):
+ eng = eng.create_or_skip()
+
+ eng.assert_query_result(
+ f"""SELECT {call_st_envelope_agg(eng, "ST_GeomFromText(geom)")} FROM (
+ VALUES
+ ('POINT (1 2)'),
+ ('POINT (3 4)'),
+ (NULL)
+ ) AS t(geom)""",
+ "POLYGON ((1 2, 1 4, 3 4, 3 2, 1 2))",
+ )
+
+
+@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
+def test_st_envelope_agg_all_null(eng):
+ eng = eng.create_or_skip()
+
+ eng.assert_query_result(
+ f"""SELECT {call_st_envelope_agg(eng, "ST_GeomFromText(geom)")} FROM (
+ VALUES
+ (NULL),
+ (NULL),
+ (NULL)
+ ) AS t(geom)""",
+ None,
+ )
+
+
+@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
+def test_st_envelope_agg_zero_input(eng):
+ eng = eng.create_or_skip()
+
+ eng.assert_query_result(
+ f"""SELECT {call_st_envelope_agg(eng, "ST_GeomFromText(geom)")} AS
empty FROM (
+ VALUES
+ ('POINT (1 2)')
+ ) AS t(geom) WHERE false""",
+ None,
+ )
+
+
+@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
+def test_st_envelope_agg_single_point(eng):
+ eng = eng.create_or_skip()
+
+ eng.assert_query_result(
+ f"""SELECT {call_st_envelope_agg(eng, "ST_GeomFromText(geom)")} FROM (
+ VALUES ('POINT (5 5)')
+ ) AS t(geom)""",
+ "POINT (5 5)",
+ )
+
+
+@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
+def test_st_envelope_agg_collinear_points(eng):
+ eng = eng.create_or_skip()
+
+ eng.assert_query_result(
+ f"""SELECT {call_st_envelope_agg(eng, "ST_GeomFromText(geom)")} FROM (
+ VALUES
+ ('POINT (0 0)'),
+ ('POINT (0 1)'),
+ ('POINT (0 2)')
+ ) AS t(geom)""",
+ "LINESTRING (0 0, 0 2)",
+ )
+
+
+@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
+def test_st_envelope_agg_many_groups(eng, con):
+ eng = eng.create_or_skip()
+ num_groups = 1000
+
+ df_points = con.sql("""
+ SELECT id, geometry FROM sd_random_geometry('{"target_rows": 100000,
"seed": 9728}')
+ """)
+ eng.create_table_arrow("df_points", df_points.to_arrow_table())
+
+ result = eng.execute_and_collect(
+ f"""
+ SELECT
+ (id % {num_groups})::INTEGER AS id_mod,
+ {call_st_envelope_agg(eng, "geometry")} AS envelope
+ FROM df_points
+ GROUP BY id_mod
+ ORDER BY id_mod
+ """,
+ )
+
+ df_points_geopandas = df_points.to_pandas()
+ expected = (
+ df_points_geopandas.groupby(df_points_geopandas["id"] %
num_groups)["geometry"]
+ .apply(lambda group: shapely.box(*group.total_bounds))
+ .reset_index(name="envelope")
+ .rename(columns={"id": "id_mod"})
+ )
+
+ eng.assert_result(result, expected)
+
+
@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
def test_st_collect_points(eng):
eng = eng.create_or_skip()
diff --git a/rust/sedona-expr/src/aggregate_udf.rs
b/rust/sedona-expr/src/aggregate_udf.rs
index 3ec999d0..b8625e54 100644
--- a/rust/sedona-expr/src/aggregate_udf.rs
+++ b/rust/sedona-expr/src/aggregate_udf.rs
@@ -20,7 +20,7 @@ use arrow_schema::{DataType, FieldRef};
use datafusion_common::{not_impl_err, Result};
use datafusion_expr::{
function::{AccumulatorArgs, StateFieldsArgs},
- Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
+ Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator,
Signature, Volatility,
};
use sedona_common::sedona_internal_err;
use sedona_schema::datatypes::SedonaType;
@@ -102,6 +102,18 @@ impl SedonaAggregateUDF {
&self.kernels
}
+ fn accumulator_arg_types(args: &AccumulatorArgs) ->
Result<Vec<SedonaType>> {
+ let arg_fields = args
+ .exprs
+ .iter()
+ .map(|expr| expr.return_field(args.schema))
+ .collect::<Result<Vec<_>>>()?;
+ arg_fields
+ .iter()
+ .map(|field| SedonaType::from_storage_field(field))
+ .collect()
+ }
+
fn dispatch_impl(&self, args: &[SedonaType]) -> Result<(&dyn
SedonaAccumulator, SedonaType)> {
// Resolve kernels in reverse so that more recently added ones are
resolved first
for kernel in self.kernels.iter().rev() {
@@ -154,16 +166,27 @@ impl AggregateUDFImpl for SedonaAggregateUDF {
sedona_internal_err!("return_type() should not be called (use
return_field())")
}
+ fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
+ if let Ok(arg_types) = Self::accumulator_arg_types(&args) {
+ if let Ok((accumulator, _)) = self.dispatch_impl(&arg_types) {
+ return accumulator.groups_accumulator_supported(&arg_types);
+ }
+ }
+
+ false
+ }
+
+ fn create_groups_accumulator(
+ &self,
+ args: AccumulatorArgs,
+ ) -> Result<Box<dyn GroupsAccumulator>> {
+ let arg_types = Self::accumulator_arg_types(&args)?;
+ let (accumulator, output_type) = self.dispatch_impl(&arg_types)?;
+ accumulator.groups_accumulator(&arg_types, &output_type)
+ }
+
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn
Accumulator>> {
- let arg_fields = acc_args
- .exprs
- .iter()
- .map(|expr| expr.return_field(acc_args.schema))
- .collect::<Result<Vec<_>>>()?;
- let arg_types = arg_fields
- .iter()
- .map(|field| SedonaType::from_storage_field(field))
- .collect::<Result<Vec<_>>>()?;
+ let arg_types = Self::accumulator_arg_types(&acc_args)?;
let (accumulator, output_type) = self.dispatch_impl(&arg_types)?;
accumulator.accumulator(&arg_types, &output_type)
}
@@ -190,6 +213,24 @@ pub trait SedonaAccumulator: Debug {
output_type: &SedonaType,
) -> Result<Box<dyn Accumulator>>;
+ /// Given input data types, check if this implementation supports
GroupsAccumulator
+ fn groups_accumulator_supported(&self, _args: &[SedonaType]) -> bool {
+ false
+ }
+
+ /// Given input data types, resolve a [GroupsAccumulator]
+ ///
+ /// A GroupsAccumulator is an important optimization for aggregating many
small groups,
+ /// particularly when such an aggregation is cheap. See the DataFusion
documentation
+ /// for details.
+ fn groups_accumulator(
+ &self,
+ _args: &[SedonaType],
+ _output_type: &SedonaType,
+ ) -> Result<Box<dyn GroupsAccumulator>> {
+ sedona_internal_err!("groups_accumulator not supported for {self:?}")
+ }
+
/// The fields representing the underlying serialized state of the
Accumulator
fn state_fields(&self, args: &[SedonaType]) -> Result<Vec<FieldRef>>;
}
diff --git a/rust/sedona-functions/src/st_envelope_agg.rs
b/rust/sedona-functions/src/st_envelope_agg.rs
index c077ac01..a5692001 100644
--- a/rust/sedona-functions/src/st_envelope_agg.rs
+++ b/rust/sedona-functions/src/st_envelope_agg.rs
@@ -14,24 +14,26 @@
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
-use std::{sync::Arc, vec};
+use std::{iter::zip, sync::Arc, vec};
use crate::executor::WkbExecutor;
use crate::st_envelope::write_envelope;
-use arrow_array::ArrayRef;
+use arrow_array::{builder::BinaryBuilder, Array, ArrayRef, BooleanArray};
use arrow_schema::FieldRef;
use datafusion_common::{
error::{DataFusionError, Result},
ScalarValue,
};
use datafusion_expr::{
- scalar_doc_sections::DOC_SECTION_OTHER, Accumulator, ColumnarValue,
Documentation, Volatility,
+ scalar_doc_sections::DOC_SECTION_OTHER, Accumulator, ColumnarValue,
Documentation, EmitTo,
+ GroupsAccumulator, Volatility,
};
use sedona_common::sedona_internal_err;
use sedona_expr::aggregate_udf::{SedonaAccumulator, SedonaAggregateUDF};
use sedona_geometry::{
bounds::geo_traits_update_xy_bounds,
interval::{Interval, IntervalTrait},
+ wkb_factory::WKB_MIN_PROBABLE_BYTES,
};
use sedona_schema::{
datatypes::{SedonaType, WKB_GEOMETRY},
@@ -70,6 +72,18 @@ impl SedonaAccumulator for STEnvelopeAgg {
matcher.match_args(args)
}
+ fn groups_accumulator_supported(&self, _args: &[SedonaType]) -> bool {
+ true
+ }
+
+ fn groups_accumulator(
+ &self,
+ args: &[SedonaType],
+ _output_type: &SedonaType,
+ ) -> Result<Box<dyn GroupsAccumulator>> {
+ Ok(Box::new(BoundsGroupsAccumulator2D::new(args[0].clone())))
+ }
+
fn accumulator(
&self,
args: &[SedonaType],
@@ -178,12 +192,167 @@ impl Accumulator for BoundsAccumulator2D {
}
}
+#[derive(Debug)]
+struct BoundsGroupsAccumulator2D {
+ input_type: SedonaType,
+ xs: Vec<Interval>,
+ ys: Vec<Interval>,
+ offset: usize,
+}
+
+impl BoundsGroupsAccumulator2D {
+ pub fn new(input_type: SedonaType) -> Self {
+ Self {
+ input_type,
+ xs: Vec::new(),
+ ys: Vec::new(),
+ offset: 0,
+ }
+ }
+
+ fn execute_update(
+ &mut self,
+ values: &[ArrayRef],
+ group_indices: &[usize],
+ opt_filter: Option<&BooleanArray>,
+ total_num_groups: usize,
+ ) -> Result<()> {
+ // Check some of our assumptions about how this will be called
+ debug_assert_eq!(self.offset, 0);
+ debug_assert_eq!(values.len(), 1);
+ debug_assert_eq!(values[0].len(), group_indices.len());
+ if let Some(filter) = opt_filter {
+ debug_assert_eq!(values[0].len(), filter.len());
+ }
+
+ let arg_types = [self.input_type.clone()];
+ let args = [ColumnarValue::Array(values[0].clone())];
+ let executor = WkbExecutor::new(&arg_types, &args);
+ self.xs.resize(total_num_groups, Interval::empty());
+ self.ys.resize(total_num_groups, Interval::empty());
+ let mut i = 0;
+
+ if let Some(filter) = opt_filter {
+ let mut filter_iter = filter.iter();
+ executor.execute_wkb_void(|maybe_item| {
+ if filter_iter.next().unwrap().unwrap_or(false) {
+ let group_id = group_indices[i];
+ i += 1;
+ if let Some(item) = maybe_item {
+ geo_traits_update_xy_bounds(
+ item,
+ &mut self.xs[group_id],
+ &mut self.ys[group_id],
+ )
+ .map_err(|e| DataFusionError::External(Box::new(e)))?;
+ }
+ } else {
+ i += 1;
+ }
+
+ Ok(())
+ })?;
+ } else {
+ executor.execute_wkb_void(|maybe_item| {
+ let group_id = group_indices[i];
+ i += 1;
+ if let Some(item) = maybe_item {
+ geo_traits_update_xy_bounds(
+ item,
+ &mut self.xs[group_id],
+ &mut self.ys[group_id],
+ )
+ .map_err(|e| DataFusionError::External(Box::new(e)))?;
+ }
+
+ Ok(())
+ })?;
+ }
+
+ Ok(())
+ }
+
+ fn emit_wkb_result(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
+ let emit_size = match emit_to {
+ EmitTo::All => self.xs.len(),
+ EmitTo::First(n) => n,
+ };
+
+ let mut builder =
+ BinaryBuilder::with_capacity(emit_size, emit_size *
WKB_MIN_PROBABLE_BYTES);
+
+ let emit_range = self.offset..(self.offset + emit_size);
+ for (x, y) in zip(&self.xs[emit_range.clone()],
&self.ys[emit_range.clone()]) {
+ let written = write_envelope(&(*x).into(), y, &mut builder)?;
+ if written {
+ builder.append_value([]);
+ } else {
+ builder.append_null();
+ }
+ }
+
+ match emit_to {
+ EmitTo::All => {
+ self.xs = Vec::new();
+ self.ys = Vec::new();
+ self.offset = 0;
+ }
+ EmitTo::First(n) => {
+ self.offset += n;
+ }
+ }
+
+ Ok(Arc::new(builder.finish()))
+ }
+}
+
+impl GroupsAccumulator for BoundsGroupsAccumulator2D {
+ fn update_batch(
+ &mut self,
+ values: &[ArrayRef],
+ group_indices: &[usize],
+ opt_filter: Option<&BooleanArray>,
+ total_num_groups: usize,
+ ) -> Result<()> {
+ self.execute_update(values, group_indices, opt_filter,
total_num_groups)
+ }
+
+ fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
+ Ok(vec![self.emit_wkb_result(emit_to)?])
+ }
+
+ fn merge_batch(
+ &mut self,
+ values: &[ArrayRef],
+ group_indices: &[usize],
+ opt_filter: Option<&arrow_array::BooleanArray>,
+ total_num_groups: usize,
+ ) -> Result<()> {
+ // In this case, our state is identical to our input values
+ self.execute_update(values, group_indices, opt_filter,
total_num_groups)
+ }
+
+ fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
+ self.emit_wkb_result(emit_to)
+ }
+
+ fn size(&self) -> usize {
+ size_of::<BoundsGroupsAccumulator2D>()
+ + self.xs.capacity() * size_of::<Interval>()
+ + self.ys.capacity() * size_of::<Interval>()
+ }
+}
+
#[cfg(test)]
mod test {
use datafusion_expr::AggregateUDF;
use rstest::rstest;
use sedona_schema::datatypes::WKB_VIEW_GEOMETRY;
- use sedona_testing::{compare::assert_scalar_equal_wkb_geometry,
testers::AggregateUdfTester};
+ use sedona_testing::{
+ compare::{assert_array_equal, assert_scalar_equal_wkb_geometry},
+ create::create_array,
+ testers::AggregateUdfTester,
+ };
use super::*;
@@ -245,4 +414,59 @@ mod test {
Some("LINESTRING (0 1, 1 1)"),
);
}
+
+ #[test]
+ fn udf_grouped_accumulate() {
+ let tester = AggregateUdfTester::new(st_envelope_agg_udf().into(),
vec![WKB_GEOMETRY]);
+ assert_eq!(tester.return_type().unwrap(), WKB_GEOMETRY);
+
+ // Six elements, four groups, with one all null group and one
partially null group
+ let group_indices = vec![0, 3, 1, 1, 0, 2];
+ let array0 = create_array(
+ &[Some("POINT (0 1)"), None, Some("POINT (2 3)")],
+ &WKB_GEOMETRY,
+ );
+ let array1 = create_array(
+ &[Some("POINT (4 5)"), None, Some("POINT (6 7)")],
+ &WKB_GEOMETRY,
+ );
+ let batches = vec![array0, array1];
+
+ let expected = create_array(
+ &[
+ // First element only + a null
+ Some("POINT (0 1)"),
+ // Middle two elements
+ Some("POLYGON((2 3, 2 5, 4 5, 4 3, 2 3))"),
+ // Last element only
+ Some("POINT (6 7)"),
+ // Only null
+ None,
+ ],
+ &WKB_GEOMETRY,
+ );
+ let result = tester
+ .aggregate_groups(&batches, group_indices.clone(), None, vec![])
+ .unwrap();
+ assert_array_equal(&result, &expected);
+
+ // We should get the same answer even with a sequence of partial emits
+ let result = tester
+ .aggregate_groups(&batches, group_indices.clone(), None, vec![1,
1, 1, 1])
+ .unwrap();
+ assert_array_equal(&result, &expected);
+
+ // Also check with a filter (in this case, filter out all values except
+ // the middle two elements).
+ let filter = vec![false, false, true, true, false, false];
+ let expected = create_array(
+ &[None, Some("POLYGON((2 3, 2 5, 4 5, 4 3, 2 3))"), None, None],
+ &WKB_GEOMETRY,
+ );
+
+ let result = tester
+ .aggregate_groups(&batches, group_indices.clone(), Some(&filter),
vec![])
+ .unwrap();
+ assert_array_equal(&result, &expected);
+ }
}
diff --git a/rust/sedona-testing/src/testers.rs
b/rust/sedona-testing/src/testers.rs
index cb56bbb3..3e939b32 100644
--- a/rust/sedona-testing/src/testers.rs
+++ b/rust/sedona-testing/src/testers.rs
@@ -16,13 +16,15 @@
// under the License.
use std::{iter::zip, sync::Arc};
-use arrow_array::{ArrayRef, RecordBatch};
+use arrow_array::{ArrayRef, BooleanArray, RecordBatch};
use arrow_schema::{DataType, FieldRef, Schema};
-use datafusion_common::{config::ConfigOptions, Result, ScalarValue};
+use datafusion_common::{
+ arrow::compute::kernels::concat::concat, config::ConfigOptions, Result,
ScalarValue,
+};
use datafusion_expr::{
function::{AccumulatorArgs, StateFieldsArgs},
- Accumulator, AggregateUDF, ColumnarValue, Expr, Literal, ReturnFieldArgs,
ScalarFunctionArgs,
- ScalarUDF,
+ Accumulator, AggregateUDF, ColumnarValue, EmitTo, Expr, GroupsAccumulator,
Literal,
+ ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
};
use datafusion_physical_expr::{expressions::Column, PhysicalExpr};
use sedona_common::sedona_internal_err;
@@ -49,23 +51,34 @@ use crate::{
pub struct AggregateUdfTester {
udf: AggregateUDF,
arg_types: Vec<SedonaType>,
+ mock_schema: Schema,
+ mock_exprs: Vec<Arc<dyn PhysicalExpr>>,
}
impl AggregateUdfTester {
/// Create a new tester
pub fn new(udf: AggregateUDF, arg_types: Vec<SedonaType>) -> Self {
- Self { udf, arg_types }
+ let arg_fields = arg_types
+ .iter()
+ .map(|sedona_type| sedona_type.to_storage_field("",
true).map(Arc::new))
+ .collect::<Result<Vec<_>>>()
+ .unwrap();
+ let mock_schema = Schema::new(arg_fields);
+
+ let mock_exprs = (0..arg_types.len())
+ .map(|i| -> Arc<dyn PhysicalExpr> { Arc::new(Column::new("col",
i)) })
+ .collect::<Vec<_>>();
+ Self {
+ udf,
+ arg_types,
+ mock_schema,
+ mock_exprs,
+ }
}
/// Compute the return type
pub fn return_type(&self) -> Result<SedonaType> {
- let arg_fields = self
- .arg_types
- .iter()
- .map(|arg_type| arg_type.to_storage_field("", true).map(Arc::new))
- .collect::<Result<Vec<_>>>()?;
-
- let out_field = self.udf.return_field(&arg_fields)?;
+ let out_field = self.udf.return_field(&self.mock_schema.fields)?;
SedonaType::from_storage_field(&out_field)
}
@@ -105,44 +118,112 @@ impl AggregateUdfTester {
state_accumulator.evaluate()
}
+ /// Perform a simple grouped aggregation
+ ///
+ /// Each batch in batches is accumulated with its own groups accumulator
+ /// and serialized into its own state, after which the state resulting
+ /// from each batch is merged into the final groups accumulator. This
+ /// has the effect of testing the pieces of a groups accumulator in a
+ /// predictable/debug-friendly (if artificial) way.
+ pub fn aggregate_groups(
+ &self,
+ batches: &Vec<ArrayRef>,
+ group_indices: Vec<usize>,
+ opt_filter: Option<&Vec<bool>>,
+ emit_sizes: Vec<usize>,
+ ) -> Result<ArrayRef> {
+ let state_schema = Arc::new(Schema::new(self.state_fields()?));
+ let mut state_accumulator = self.new_groups_accumulator()?;
+ let total_num_groups = group_indices.iter().max().unwrap_or(&0) + 1;
+
+ // Check input
+ let total_input_rows: usize = batches.iter().map(|a| a.len()).sum();
+ assert_eq!(total_input_rows, group_indices.len());
+ if let Some(filter) = opt_filter {
+ assert_eq!(total_input_rows, filter.len());
+ }
+ if !emit_sizes.is_empty() {
+ assert_eq!(emit_sizes.iter().sum::<usize>(), total_num_groups);
+ }
+
+ let mut offset = 0;
+ for batch in batches {
+ let mut batch_accumulator = self.new_groups_accumulator()?;
+ let opt_filter_array = opt_filter.map(|filter_vec| {
+ filter_vec[offset..(offset + batch.len())]
+ .iter()
+ .collect::<BooleanArray>()
+ });
+ batch_accumulator.update_batch(
+ std::slice::from_ref(batch),
+ &group_indices[offset..(offset + batch.len())],
+ opt_filter_array.as_ref(),
+ total_num_groups,
+ )?;
+ offset += batch.len();
+
+ // For the state accumulator the input is ordered such that
+ // each row is group i for i in (0..total_num_groups)
+ let state_batch = RecordBatch::try_new(
+ state_schema.clone(),
+ batch_accumulator.state(datafusion_expr::EmitTo::All)?,
+ )?;
+ state_accumulator.merge_batch(
+ state_batch.columns(),
+ &(0..total_num_groups).collect::<Vec<_>>(),
+ None,
+ total_num_groups,
+ )?;
+ }
+
+ if emit_sizes.is_empty() {
+ state_accumulator.evaluate(datafusion_expr::EmitTo::All)
+ } else {
+ let arrays = emit_sizes
+ .iter()
+ .map(|emit_size|
state_accumulator.evaluate(EmitTo::First(*emit_size)))
+ .collect::<Result<Vec<_>>>()?;
+ let arrays_ref = arrays.iter().map(|a|
a.as_ref()).collect::<Vec<_>>();
+ Ok(concat(&arrays_ref)?)
+ }
+ }
+
fn new_accumulator(&self) -> Result<Box<dyn Accumulator>> {
- let mock_schema = Schema::new(self.arg_fields());
- let exprs = (0..self.arg_types.len())
- .map(|i| -> Arc<dyn PhysicalExpr> { Arc::new(Column::new("col",
i)) })
- .collect::<Vec<_>>();
- let accumulator_args = AccumulatorArgs {
- return_field: self.udf.return_field(mock_schema.fields())?,
- schema: &mock_schema,
+ let accumulator_args = self.accumulator_args()?;
+ self.udf.accumulator(accumulator_args)
+ }
+
+ fn new_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
+ assert!(self
+ .udf
+ .groups_accumulator_supported(self.accumulator_args()?));
+ self.udf.create_groups_accumulator(self.accumulator_args()?)
+ }
+
+ fn accumulator_args(&self) -> Result<AccumulatorArgs<'_>> {
+ Ok(AccumulatorArgs {
+ return_field: self.udf.return_field(self.mock_schema.fields())?,
+ schema: &self.mock_schema,
ignore_nulls: true,
order_bys: &[],
is_reversed: false,
name: "",
is_distinct: false,
- exprs: &exprs,
+ exprs: &self.mock_exprs,
expr_fields: &[],
- };
-
- self.udf.accumulator(accumulator_args)
+ })
}
fn state_fields(&self) -> Result<Vec<FieldRef>> {
let state_field_args = StateFieldsArgs {
name: "",
- input_fields: &self.arg_fields(),
- return_field: self.udf.return_field(&self.arg_fields())?,
+ input_fields: self.mock_schema.fields(),
+ return_field: self.udf.return_field(self.mock_schema.fields())?,
ordering_fields: &[],
is_distinct: false,
};
self.udf.state_fields(state_field_args)
}
-
- fn arg_fields(&self) -> Vec<FieldRef> {
- self.arg_types
- .iter()
- .map(|sedona_type| sedona_type.to_storage_field("",
true).map(Arc::new))
- .collect::<Result<Vec<_>>>()
- .unwrap()
- }
}
/// Low-level tester for scalar functions