This is an automated email from the ASF dual-hosted git repository.

paleolimbot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git


The following commit(s) were added to refs/heads/main by this push:
     new fb158760 feat(rust/sedona-expr): Add GroupsAccumulator to framework and implementation for ST_Envelope_Agg (#510)
fb158760 is described below

commit fb158760fa1630ae7cb3e48759884bbedeb8508a
Author: Dewey Dunnington <[email protected]>
AuthorDate: Thu Jan 15 20:54:30 2026 -0600

    feat(rust/sedona-expr): Add GroupsAccumulator to framework and implementation for ST_Envelope_Agg (#510)
    
    Co-authored-by: Copilot <[email protected]>
---
 python/sedonadb/tests/functions/test_aggregate.py | 114 ++++++++++-
 rust/sedona-expr/src/aggregate_udf.rs             |  61 +++++-
 rust/sedona-functions/src/st_envelope_agg.rs      | 232 +++++++++++++++++++++-
 rust/sedona-testing/src/testers.rs                | 147 +++++++++++---
 4 files changed, 506 insertions(+), 48 deletions(-)

diff --git a/python/sedonadb/tests/functions/test_aggregate.py b/python/sedonadb/tests/functions/test_aggregate.py
index 6f6344d1..fe345dca 100644
--- a/python/sedonadb/tests/functions/test_aggregate.py
+++ b/python/sedonadb/tests/functions/test_aggregate.py
@@ -16,14 +16,126 @@
 # under the License.
 
 import pytest
+import shapely
 from sedonadb.testing import PostGIS, SedonaDB
 
 
+# Aggregate functions don't have a suffix in PostGIS
 def agg_fn_suffix(eng):
-    """Return the appropriate suffix for the aggregate function for the given engine."""
     return "" if isinstance(eng, PostGIS) else "_Agg"
 
 
+# ST_Envelope is not an aggregate function in PostGIS but we can check
+# behaviour using ST_Envelope(ST_Collect(...))
+def call_st_envelope_agg(eng, arg):
+    if isinstance(eng, PostGIS):
+        return f"ST_Envelope(ST_Collect({arg}))"
+    else:
+        return f"ST_Envelope_Agg({arg})"
+
+
+@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
+def test_st_envelope_agg_points(eng):
+    eng = eng.create_or_skip()
+
+    eng.assert_query_result(
+        f"""SELECT {call_st_envelope_agg(eng, "ST_GeomFromText(geom)")} FROM (
+            VALUES
+                ('POINT (1 2)'),
+                ('POINT (3 4)'),
+                (NULL)
+        ) AS t(geom)""",
+        "POLYGON ((1 2, 1 4, 3 4, 3 2, 1 2))",
+    )
+
+
+@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
+def test_st_envelope_agg_all_null(eng):
+    eng = eng.create_or_skip()
+
+    eng.assert_query_result(
+        f"""SELECT {call_st_envelope_agg(eng, "ST_GeomFromText(geom)")} FROM (
+            VALUES
+                (NULL),
+                (NULL),
+                (NULL)
+        ) AS t(geom)""",
+        None,
+    )
+
+
+@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
+def test_st_envelope_agg_zero_input(eng):
+    eng = eng.create_or_skip()
+
+    eng.assert_query_result(
+        f"""SELECT {call_st_envelope_agg(eng, "ST_GeomFromText(geom)")} AS empty FROM (
+            VALUES
+                ('POINT (1 2)')
+        ) AS t(geom) WHERE false""",
+        None,
+    )
+
+
+@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
+def test_st_envelope_agg_single_point(eng):
+    eng = eng.create_or_skip()
+
+    eng.assert_query_result(
+        f"""SELECT {call_st_envelope_agg(eng, "ST_GeomFromText(geom)")} FROM (
+            VALUES ('POINT (5 5)')
+        ) AS t(geom)""",
+        "POINT (5 5)",
+    )
+
+
+@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
+def test_st_envelope_agg_collinear_points(eng):
+    eng = eng.create_or_skip()
+
+    eng.assert_query_result(
+        f"""SELECT {call_st_envelope_agg(eng, "ST_GeomFromText(geom)")} FROM (
+            VALUES
+                ('POINT (0 0)'),
+                ('POINT (0 1)'),
+                ('POINT (0 2)')
+        ) AS t(geom)""",
+        "LINESTRING (0 0, 0 2)",
+    )
+
+
+@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
+def test_st_envelope_agg_many_groups(eng, con):
+    eng = eng.create_or_skip()
+    num_groups = 1000
+
+    df_points = con.sql("""
+        SELECT id, geometry FROM sd_random_geometry('{"target_rows": 100000, "seed": 9728}')
+    """)
+    eng.create_table_arrow("df_points", df_points.to_arrow_table())
+
+    result = eng.execute_and_collect(
+        f"""
+        SELECT
+            (id % {num_groups})::INTEGER AS id_mod,
+            {call_st_envelope_agg(eng, "geometry")} AS envelope
+        FROM df_points
+        GROUP BY id_mod
+        ORDER BY id_mod
+        """,
+    )
+
+    df_points_geopandas = df_points.to_pandas()
+    expected = (
+        df_points_geopandas.groupby(df_points_geopandas["id"] % num_groups)["geometry"]
+        .apply(lambda group: shapely.box(*group.total_bounds))
+        .reset_index(name="envelope")
+        .rename(columns={"id": "id_mod"})
+    )
+
+    eng.assert_result(result, expected)
+
+
 @pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
 def test_st_collect_points(eng):
     eng = eng.create_or_skip()
diff --git a/rust/sedona-expr/src/aggregate_udf.rs b/rust/sedona-expr/src/aggregate_udf.rs
index 3ec999d0..b8625e54 100644
--- a/rust/sedona-expr/src/aggregate_udf.rs
+++ b/rust/sedona-expr/src/aggregate_udf.rs
@@ -20,7 +20,7 @@ use arrow_schema::{DataType, FieldRef};
 use datafusion_common::{not_impl_err, Result};
 use datafusion_expr::{
     function::{AccumulatorArgs, StateFieldsArgs},
-    Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
+    Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature, Volatility,
 };
 use sedona_common::sedona_internal_err;
 use sedona_schema::datatypes::SedonaType;
@@ -102,6 +102,18 @@ impl SedonaAggregateUDF {
         &self.kernels
     }
 
+    fn accumulator_arg_types(args: &AccumulatorArgs) -> Result<Vec<SedonaType>> {
+        let arg_fields = args
+            .exprs
+            .iter()
+            .map(|expr| expr.return_field(args.schema))
+            .collect::<Result<Vec<_>>>()?;
+        arg_fields
+            .iter()
+            .map(|field| SedonaType::from_storage_field(field))
+            .collect()
+    }
+
     fn dispatch_impl(&self, args: &[SedonaType]) -> Result<(&dyn SedonaAccumulator, SedonaType)> {
         // Resolve kernels in reverse so that more recently added ones are resolved first
         for kernel in self.kernels.iter().rev() {
@@ -154,16 +166,27 @@ impl AggregateUDFImpl for SedonaAggregateUDF {
         sedona_internal_err!("return_type() should not be called (use return_field())")
     }
 
+    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
+        if let Ok(arg_types) = Self::accumulator_arg_types(&args) {
+            if let Ok((accumulator, _)) = self.dispatch_impl(&arg_types) {
+                return accumulator.groups_accumulator_supported(&arg_types);
+            }
+        }
+
+        false
+    }
+
+    fn create_groups_accumulator(
+        &self,
+        args: AccumulatorArgs,
+    ) -> Result<Box<dyn GroupsAccumulator>> {
+        let arg_types = Self::accumulator_arg_types(&args)?;
+        let (accumulator, output_type) = self.dispatch_impl(&arg_types)?;
+        accumulator.groups_accumulator(&arg_types, &output_type)
+    }
+
     fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
-        let arg_fields = acc_args
-            .exprs
-            .iter()
-            .map(|expr| expr.return_field(acc_args.schema))
-            .collect::<Result<Vec<_>>>()?;
-        let arg_types = arg_fields
-            .iter()
-            .map(|field| SedonaType::from_storage_field(field))
-            .collect::<Result<Vec<_>>>()?;
+        let arg_types = Self::accumulator_arg_types(&acc_args)?;
         let (accumulator, output_type) = self.dispatch_impl(&arg_types)?;
         accumulator.accumulator(&arg_types, &output_type)
     }
@@ -190,6 +213,24 @@ pub trait SedonaAccumulator: Debug {
         output_type: &SedonaType,
     ) -> Result<Box<dyn Accumulator>>;
 
+    /// Given input data types, check if this implementation supports GroupsAccumulator
+    fn groups_accumulator_supported(&self, _args: &[SedonaType]) -> bool {
+        false
+    }
+
+    /// Given input data types, resolve a [GroupsAccumulator]
+    ///
+    /// A GroupsAccumulator is an important optimization for aggregating many small groups,
+    /// particularly when such an aggregation is cheap. See the DataFusion documentation
+    /// for details.
+    fn groups_accumulator(
+        &self,
+        _args: &[SedonaType],
+        _output_type: &SedonaType,
+    ) -> Result<Box<dyn GroupsAccumulator>> {
+        sedona_internal_err!("groups_accumulator not supported for {self:?}")
+    }
+
     /// The fields representing the underlying serialized state of the Accumulator
     fn state_fields(&self, args: &[SedonaType]) -> Result<Vec<FieldRef>>;
 }
diff --git a/rust/sedona-functions/src/st_envelope_agg.rs b/rust/sedona-functions/src/st_envelope_agg.rs
index c077ac01..a5692001 100644
--- a/rust/sedona-functions/src/st_envelope_agg.rs
+++ b/rust/sedona-functions/src/st_envelope_agg.rs
@@ -14,24 +14,26 @@
 // KIND, either express or implied.  See the License for the
 // specific language governing permissions and limitations
 // under the License.
-use std::{sync::Arc, vec};
+use std::{iter::zip, sync::Arc, vec};
 
 use crate::executor::WkbExecutor;
 use crate::st_envelope::write_envelope;
-use arrow_array::ArrayRef;
+use arrow_array::{builder::BinaryBuilder, Array, ArrayRef, BooleanArray};
 use arrow_schema::FieldRef;
 use datafusion_common::{
     error::{DataFusionError, Result},
     ScalarValue,
 };
 use datafusion_expr::{
-    scalar_doc_sections::DOC_SECTION_OTHER, Accumulator, ColumnarValue, Documentation, Volatility,
+    scalar_doc_sections::DOC_SECTION_OTHER, Accumulator, ColumnarValue, Documentation, EmitTo,
+    GroupsAccumulator, Volatility,
 };
 use sedona_common::sedona_internal_err;
 use sedona_expr::aggregate_udf::{SedonaAccumulator, SedonaAggregateUDF};
 use sedona_geometry::{
     bounds::geo_traits_update_xy_bounds,
     interval::{Interval, IntervalTrait},
+    wkb_factory::WKB_MIN_PROBABLE_BYTES,
 };
 use sedona_schema::{
     datatypes::{SedonaType, WKB_GEOMETRY},
@@ -70,6 +72,18 @@ impl SedonaAccumulator for STEnvelopeAgg {
         matcher.match_args(args)
     }
 
+    fn groups_accumulator_supported(&self, _args: &[SedonaType]) -> bool {
+        true
+    }
+
+    fn groups_accumulator(
+        &self,
+        args: &[SedonaType],
+        _output_type: &SedonaType,
+    ) -> Result<Box<dyn GroupsAccumulator>> {
+        Ok(Box::new(BoundsGroupsAccumulator2D::new(args[0].clone())))
+    }
+
     fn accumulator(
         &self,
         args: &[SedonaType],
@@ -178,12 +192,167 @@ impl Accumulator for BoundsAccumulator2D {
     }
 }
 
+#[derive(Debug)]
+struct BoundsGroupsAccumulator2D {
+    input_type: SedonaType,
+    xs: Vec<Interval>,
+    ys: Vec<Interval>,
+    offset: usize,
+}
+
+impl BoundsGroupsAccumulator2D {
+    pub fn new(input_type: SedonaType) -> Self {
+        Self {
+            input_type,
+            xs: Vec::new(),
+            ys: Vec::new(),
+            offset: 0,
+        }
+    }
+
+    fn execute_update(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        // Check some of our assumptions about how this will be called
+        debug_assert_eq!(self.offset, 0);
+        debug_assert_eq!(values.len(), 1);
+        debug_assert_eq!(values[0].len(), group_indices.len());
+        if let Some(filter) = opt_filter {
+            debug_assert_eq!(values[0].len(), filter.len());
+        }
+
+        let arg_types = [self.input_type.clone()];
+        let args = [ColumnarValue::Array(values[0].clone())];
+        let executor = WkbExecutor::new(&arg_types, &args);
+        self.xs.resize(total_num_groups, Interval::empty());
+        self.ys.resize(total_num_groups, Interval::empty());
+        let mut i = 0;
+
+        if let Some(filter) = opt_filter {
+            let mut filter_iter = filter.iter();
+            executor.execute_wkb_void(|maybe_item| {
+                if filter_iter.next().unwrap().unwrap_or(false) {
+                    let group_id = group_indices[i];
+                    i += 1;
+                    if let Some(item) = maybe_item {
+                        geo_traits_update_xy_bounds(
+                            item,
+                            &mut self.xs[group_id],
+                            &mut self.ys[group_id],
+                        )
+                        .map_err(|e| DataFusionError::External(Box::new(e)))?;
+                    }
+                } else {
+                    i += 1;
+                }
+
+                Ok(())
+            })?;
+        } else {
+            executor.execute_wkb_void(|maybe_item| {
+                let group_id = group_indices[i];
+                i += 1;
+                if let Some(item) = maybe_item {
+                    geo_traits_update_xy_bounds(
+                        item,
+                        &mut self.xs[group_id],
+                        &mut self.ys[group_id],
+                    )
+                    .map_err(|e| DataFusionError::External(Box::new(e)))?;
+                }
+
+                Ok(())
+            })?;
+        }
+
+        Ok(())
+    }
+
+    fn emit_wkb_result(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
+        let emit_size = match emit_to {
+            EmitTo::All => self.xs.len(),
+            EmitTo::First(n) => n,
+        };
+
+        let mut builder =
+            BinaryBuilder::with_capacity(emit_size, emit_size * WKB_MIN_PROBABLE_BYTES);
+
+        let emit_range = self.offset..(self.offset + emit_size);
+        for (x, y) in zip(&self.xs[emit_range.clone()], &self.ys[emit_range.clone()]) {
+            let written = write_envelope(&(*x).into(), y, &mut builder)?;
+            if written {
+                builder.append_value([]);
+            } else {
+                builder.append_null();
+            }
+        }
+
+        match emit_to {
+            EmitTo::All => {
+                self.xs = Vec::new();
+                self.ys = Vec::new();
+                self.offset = 0;
+            }
+            EmitTo::First(n) => {
+                self.offset += n;
+            }
+        }
+
+        Ok(Arc::new(builder.finish()))
+    }
+}
+
+impl GroupsAccumulator for BoundsGroupsAccumulator2D {
+    fn update_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        self.execute_update(values, group_indices, opt_filter, total_num_groups)
+    }
+
+    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
+        Ok(vec![self.emit_wkb_result(emit_to)?])
+    }
+
+    fn merge_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&arrow_array::BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        // In this case, our state is identical to our input values
+        self.execute_update(values, group_indices, opt_filter, total_num_groups)
+    }
+
+    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
+        self.emit_wkb_result(emit_to)
+    }
+
+    fn size(&self) -> usize {
+        size_of::<BoundsGroupsAccumulator2D>()
+            + self.xs.capacity() * size_of::<Interval>()
+            + self.ys.capacity() * size_of::<Interval>()
+    }
+}
+
 #[cfg(test)]
 mod test {
     use datafusion_expr::AggregateUDF;
     use rstest::rstest;
     use sedona_schema::datatypes::WKB_VIEW_GEOMETRY;
-    use sedona_testing::{compare::assert_scalar_equal_wkb_geometry, testers::AggregateUdfTester};
+    use sedona_testing::{
+        compare::{assert_array_equal, assert_scalar_equal_wkb_geometry},
+        create::create_array,
+        testers::AggregateUdfTester,
+    };
 
     use super::*;
 
@@ -245,4 +414,59 @@ mod test {
             Some("LINESTRING (0 1, 1 1)"),
         );
     }
+
+    #[test]
+    fn udf_grouped_accumulate() {
+        let tester = AggregateUdfTester::new(st_envelope_agg_udf().into(), vec![WKB_GEOMETRY]);
+        assert_eq!(tester.return_type().unwrap(), WKB_GEOMETRY);
+
+        // Six elements, four groups, with one all null group and one partially null group
+        let group_indices = vec![0, 3, 1, 1, 0, 2];
+        let array0 = create_array(
+            &[Some("POINT (0 1)"), None, Some("POINT (2 3)")],
+            &WKB_GEOMETRY,
+        );
+        let array1 = create_array(
+            &[Some("POINT (4 5)"), None, Some("POINT (6 7)")],
+            &WKB_GEOMETRY,
+        );
+        let batches = vec![array0, array1];
+
+        let expected = create_array(
+            &[
+                // First element only + a null
+                Some("POINT (0 1)"),
+                // Middle two elements
+                Some("POLYGON((2 3, 2 5, 4 5, 4 3, 2 3))"),
+                // Last element only
+                Some("POINT (6 7)"),
+                // Only null
+                None,
+            ],
+            &WKB_GEOMETRY,
+        );
+        let result = tester
+            .aggregate_groups(&batches, group_indices.clone(), None, vec![])
+            .unwrap();
+        assert_array_equal(&result, &expected);
+
+        // We should get the same answer even with a sequence of partial emits
+        let result = tester
+            .aggregate_groups(&batches, group_indices.clone(), None, vec![1, 1, 1, 1])
+            .unwrap();
+        assert_array_equal(&result, &expected);
+
+        // Also check with a filter (in this case, filter out all values except
+        // the middle two elements).
+        let filter = vec![false, false, true, true, false, false];
+        let expected = create_array(
+            &[None, Some("POLYGON((2 3, 2 5, 4 5, 4 3, 2 3))"), None, None],
+            &WKB_GEOMETRY,
+        );
+
+        let result = tester
+            .aggregate_groups(&batches, group_indices.clone(), Some(&filter), vec![])
+            .unwrap();
+        assert_array_equal(&result, &expected);
+    }
 }
diff --git a/rust/sedona-testing/src/testers.rs b/rust/sedona-testing/src/testers.rs
index cb56bbb3..3e939b32 100644
--- a/rust/sedona-testing/src/testers.rs
+++ b/rust/sedona-testing/src/testers.rs
@@ -16,13 +16,15 @@
 // under the License.
 use std::{iter::zip, sync::Arc};
 
-use arrow_array::{ArrayRef, RecordBatch};
+use arrow_array::{ArrayRef, BooleanArray, RecordBatch};
 use arrow_schema::{DataType, FieldRef, Schema};
-use datafusion_common::{config::ConfigOptions, Result, ScalarValue};
+use datafusion_common::{
+    arrow::compute::kernels::concat::concat, config::ConfigOptions, Result, ScalarValue,
+};
 use datafusion_expr::{
     function::{AccumulatorArgs, StateFieldsArgs},
-    Accumulator, AggregateUDF, ColumnarValue, Expr, Literal, ReturnFieldArgs, ScalarFunctionArgs,
-    ScalarUDF,
+    Accumulator, AggregateUDF, ColumnarValue, EmitTo, Expr, GroupsAccumulator, Literal,
+    ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
 };
 use datafusion_physical_expr::{expressions::Column, PhysicalExpr};
 use sedona_common::sedona_internal_err;
@@ -49,23 +51,34 @@ use crate::{
 pub struct AggregateUdfTester {
     udf: AggregateUDF,
     arg_types: Vec<SedonaType>,
+    mock_schema: Schema,
+    mock_exprs: Vec<Arc<dyn PhysicalExpr>>,
 }
 
 impl AggregateUdfTester {
     /// Create a new tester
     pub fn new(udf: AggregateUDF, arg_types: Vec<SedonaType>) -> Self {
-        Self { udf, arg_types }
+        let arg_fields = arg_types
+            .iter()
+            .map(|sedona_type| sedona_type.to_storage_field("", true).map(Arc::new))
+            .collect::<Result<Vec<_>>>()
+            .unwrap();
+        let mock_schema = Schema::new(arg_fields);
+
+        let mock_exprs = (0..arg_types.len())
+            .map(|i| -> Arc<dyn PhysicalExpr> { Arc::new(Column::new("col", i)) })
+            .collect::<Vec<_>>();
+        Self {
+            udf,
+            arg_types,
+            mock_schema,
+            mock_exprs,
+        }
     }
 
     /// Compute the return type
     pub fn return_type(&self) -> Result<SedonaType> {
-        let arg_fields = self
-            .arg_types
-            .iter()
-            .map(|arg_type| arg_type.to_storage_field("", true).map(Arc::new))
-            .collect::<Result<Vec<_>>>()?;
-
-        let out_field = self.udf.return_field(&arg_fields)?;
+        let out_field = self.udf.return_field(&self.mock_schema.fields)?;
         SedonaType::from_storage_field(&out_field)
     }
 
@@ -105,44 +118,112 @@ impl AggregateUdfTester {
         state_accumulator.evaluate()
     }
 
+    /// Perform a simple grouped aggregation
+    ///
+    /// Each batch in batches is accumulated with its own groups accumulator
+    /// and serialized into its own state, after which the state resulting
+    /// from each batch is merged into the final groups accumulator. This
+    /// has the effect of testing the pieces of a groups accumulator in a
+    /// predictable/debug-friendly (if artificial) way.
+    pub fn aggregate_groups(
+        &self,
+        batches: &Vec<ArrayRef>,
+        group_indices: Vec<usize>,
+        opt_filter: Option<&Vec<bool>>,
+        emit_sizes: Vec<usize>,
+    ) -> Result<ArrayRef> {
+        let state_schema = Arc::new(Schema::new(self.state_fields()?));
+        let mut state_accumulator = self.new_groups_accumulator()?;
+        let total_num_groups = group_indices.iter().max().unwrap_or(&0) + 1;
+
+        // Check input
+        let total_input_rows: usize = batches.iter().map(|a| a.len()).sum();
+        assert_eq!(total_input_rows, group_indices.len());
+        if let Some(filter) = opt_filter {
+            assert_eq!(total_input_rows, filter.len());
+        }
+        if !emit_sizes.is_empty() {
+            assert_eq!(emit_sizes.iter().sum::<usize>(), total_num_groups);
+        }
+
+        let mut offset = 0;
+        for batch in batches {
+            let mut batch_accumulator = self.new_groups_accumulator()?;
+            let opt_filter_array = opt_filter.map(|filter_vec| {
+                filter_vec[offset..(offset + batch.len())]
+                    .iter()
+                    .collect::<BooleanArray>()
+            });
+            batch_accumulator.update_batch(
+                std::slice::from_ref(batch),
+                &group_indices[offset..(offset + batch.len())],
+                opt_filter_array.as_ref(),
+                total_num_groups,
+            )?;
+            offset += batch.len();
+
+            // For the state accumulator the input is ordered such that
+            // each row is group i for i in (0..total_num_groups)
+            let state_batch = RecordBatch::try_new(
+                state_schema.clone(),
+                batch_accumulator.state(datafusion_expr::EmitTo::All)?,
+            )?;
+            state_accumulator.merge_batch(
+                state_batch.columns(),
+                &(0..total_num_groups).collect::<Vec<_>>(),
+                None,
+                total_num_groups,
+            )?;
+        }
+
+        if emit_sizes.is_empty() {
+            state_accumulator.evaluate(datafusion_expr::EmitTo::All)
+        } else {
+            let arrays = emit_sizes
+                .iter()
+                .map(|emit_size| state_accumulator.evaluate(EmitTo::First(*emit_size)))
+                .collect::<Result<Vec<_>>>()?;
+            let arrays_ref = arrays.iter().map(|a| a.as_ref()).collect::<Vec<_>>();
+            Ok(concat(&arrays_ref)?)
+        }
+    }
+
     fn new_accumulator(&self) -> Result<Box<dyn Accumulator>> {
-        let mock_schema = Schema::new(self.arg_fields());
-        let exprs = (0..self.arg_types.len())
-            .map(|i| -> Arc<dyn PhysicalExpr> { Arc::new(Column::new("col", i)) })
-            .collect::<Vec<_>>();
-        let accumulator_args = AccumulatorArgs {
-            return_field: self.udf.return_field(mock_schema.fields())?,
-            schema: &mock_schema,
+        let accumulator_args = self.accumulator_args()?;
+        self.udf.accumulator(accumulator_args)
+    }
+
+    fn new_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
+        assert!(self
+            .udf
+            .groups_accumulator_supported(self.accumulator_args()?));
+        self.udf.create_groups_accumulator(self.accumulator_args()?)
+    }
+
+    fn accumulator_args(&self) -> Result<AccumulatorArgs<'_>> {
+        Ok(AccumulatorArgs {
+            return_field: self.udf.return_field(self.mock_schema.fields())?,
+            schema: &self.mock_schema,
             ignore_nulls: true,
             order_bys: &[],
             is_reversed: false,
             name: "",
             is_distinct: false,
-            exprs: &exprs,
+            exprs: &self.mock_exprs,
             expr_fields: &[],
-        };
-
-        self.udf.accumulator(accumulator_args)
+        })
     }
 
     fn state_fields(&self) -> Result<Vec<FieldRef>> {
         let state_field_args = StateFieldsArgs {
             name: "",
-            input_fields: &self.arg_fields(),
-            return_field: self.udf.return_field(&self.arg_fields())?,
+            input_fields: self.mock_schema.fields(),
+            return_field: self.udf.return_field(self.mock_schema.fields())?,
             ordering_fields: &[],
             is_distinct: false,
         };
         self.udf.state_fields(state_field_args)
     }
-
-    fn arg_fields(&self) -> Vec<FieldRef> {
-        self.arg_types
-            .iter()
-            .map(|sedona_type| sedona_type.to_storage_field("", true).map(Arc::new))
-            .collect::<Result<Vec<_>>>()
-            .unwrap()
-    }
 }
 
 /// Low-level tester for scalar functions

Reply via email to