This is an automated email from the ASF dual-hosted git repository. alamb pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push: new 292eb954f Support using var/var_pop/stddev/stddev_pop in window expressions with custom frames (#4848) 292eb954f is described below commit 292eb954fc0bad3a1febc597233ba26cb60bda3e Author: Jon Mease <jonmme...@gmail.com> AuthorDate: Tue Jan 10 01:37:41 2023 -0500 Support using var/var_pop/stddev/stddev_pop in window expressions with custom frames (#4848) * Wire up retract_batch for Stddev/StddevPop/Variance/VariancePop * Add test for Stddev/StddevPop/Variance/VariancePop with window frame --- datafusion/core/tests/sql/window.rs | 28 ++++++++++++++++++++++ datafusion/physical-expr/src/aggregate/stddev.rs | 12 ++++++++++ datafusion/physical-expr/src/aggregate/variance.rs | 10 ++++++++ 3 files changed, 50 insertions(+) diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 0c3ecfa59..1167d57a4 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -524,6 +524,34 @@ async fn window_frame_rows_preceding() -> Result<()> { Ok(()) } +#[tokio::test] +async fn window_frame_rows_preceding_stddev_variance() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT \ + VAR(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\ + VAR_POP(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\ + STDDEV(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\ + STDDEV_POP(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)\ + FROM aggregate_test_100 \ + ORDER BY c9 \ + LIMIT 5"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+---------------------------------+------------------------------------+-------------------------------+----------------------------------+", + "| VARIANCE(aggregate_test_100.c4) | VARIANCEPOP(aggregate_test_100.c4) | STDDEV(aggregate_test_100.c4) | STDDEVPOP(aggregate_test_100.c4) |", + 
"+---------------------------------+------------------------------------+-------------------------------+----------------------------------+", + "| 46721.33333333174 | 31147.555555554496 | 216.15118166073427 | 176.4867007894773 |", + "| 2639429.333333332 | 1759619.5555555548 | 1624.6320609089714 | 1326.5065229977404 |", + "| 746202.3333333324 | 497468.2222222216 | 863.8300372951455 | 705.3142719541563 |", + "| 768422.9999999981 | 512281.9999999988 | 876.5973990378925 | 715.7387791645767 |", + "| 66526.3333333288 | 44350.88888888587 | 257.9269922542594 | 210.5965073045749 |", + "+---------------------------------+------------------------------------+-------------------------------+----------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + #[tokio::test] async fn window_frame_rows_preceding_with_partition_unique_order_by() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/physical-expr/src/aggregate/stddev.rs b/datafusion/physical-expr/src/aggregate/stddev.rs index 4c9e46644..dab84b14a 100644 --- a/datafusion/physical-expr/src/aggregate/stddev.rs +++ b/datafusion/physical-expr/src/aggregate/stddev.rs @@ -73,6 +73,10 @@ impl AggregateExpr for Stddev { Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?)) } + fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> { + Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?)) + } + fn state_fields(&self) -> Result<Vec<Field>> { Ok(vec![ Field::new( @@ -128,6 +132,10 @@ impl AggregateExpr for StddevPop { Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?)) } + fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> { + Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?)) + } + fn state_fields(&self) -> Result<Vec<Field>> { Ok(vec![ Field::new( @@ -184,6 +192,10 @@ impl Accumulator for StddevAccumulator { self.variance.update_batch(values) } + fn retract_batch(&mut self, values: &[ArrayRef]) -> 
Result<()> { + self.variance.retract_batch(values) + } + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { self.variance.merge_batch(states) } diff --git a/datafusion/physical-expr/src/aggregate/variance.rs b/datafusion/physical-expr/src/aggregate/variance.rs index 289513744..657103e43 100644 --- a/datafusion/physical-expr/src/aggregate/variance.rs +++ b/datafusion/physical-expr/src/aggregate/variance.rs @@ -79,6 +79,10 @@ impl AggregateExpr for Variance { Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?)) } + fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> { + Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?)) + } + fn state_fields(&self) -> Result<Vec<Field>> { Ok(vec![ Field::new( @@ -136,6 +140,12 @@ impl AggregateExpr for VariancePop { )?)) } + fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> { + Ok(Box::new(VarianceAccumulator::try_new( + StatsType::Population, + )?)) + } + fn state_fields(&self) -> Result<Vec<Field>> { Ok(vec![ Field::new(