alamb commented on code in PR #15468: URL: https://github.com/apache/datafusion/pull/15468#discussion_r2048964601
########## datafusion/functions-aggregate/src/average.rs: ########## @@ -399,6 +410,105 @@ impl<T: DecimalType + ArrowNumericType + Debug> Accumulator for DecimalAvgAccumu } } +/// An accumulator to compute the average for duration values +#[derive(Debug)] +struct DurationAvgAccumulator { + sum: Option<i64>, + count: u64, + time_unit: TimeUnit, + result_unit: TimeUnit, +} + +impl Accumulator for DurationAvgAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let array = &values[0]; + self.count += (array.len() - array.null_count()) as u64; + + let sum_value = match self.time_unit { + TimeUnit::Second => sum(array.as_primitive::<DurationSecondType>()), + TimeUnit::Millisecond => sum(array.as_primitive::<DurationMillisecondType>()), + TimeUnit::Microsecond => sum(array.as_primitive::<DurationMicrosecondType>()), + TimeUnit::Nanosecond => sum(array.as_primitive::<DurationNanosecondType>()), + }; + + if let Some(x) = sum_value { + let v = self.sum.get_or_insert(0); + *v += x; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result<ScalarValue> { + let avg = self.sum.map(|sum| sum / self.count as i64); + + match self.result_unit { + TimeUnit::Second => Ok(ScalarValue::DurationSecond(avg)), + TimeUnit::Millisecond => Ok(ScalarValue::DurationMillisecond(avg)), + TimeUnit::Microsecond => Ok(ScalarValue::DurationMicrosecond(avg)), + TimeUnit::Nanosecond => Ok(ScalarValue::DurationNanosecond(avg)), + } + } + + fn size(&self) -> usize { + size_of_val(self) + } + + fn state(&mut self) -> Result<Vec<ScalarValue>> { + let duration_value = match self.time_unit { + TimeUnit::Second => ScalarValue::DurationSecond(self.sum), + TimeUnit::Millisecond => ScalarValue::DurationMillisecond(self.sum), + TimeUnit::Microsecond => ScalarValue::DurationMicrosecond(self.sum), + TimeUnit::Nanosecond => ScalarValue::DurationNanosecond(self.sum), + }; + + Ok(vec![ScalarValue::from(self.count), duration_value]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> 
Result<()> { + self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default(); + + let sum_value = match self.time_unit { + TimeUnit::Second => sum(states[1].as_primitive::<DurationSecondType>()), + TimeUnit::Millisecond => { + sum(states[1].as_primitive::<DurationMillisecondType>()) + } + TimeUnit::Microsecond => { + sum(states[1].as_primitive::<DurationMicrosecondType>()) + } + TimeUnit::Nanosecond => { + sum(states[1].as_primitive::<DurationNanosecondType>()) + } + }; + + if let Some(x) = sum_value { + let v = self.sum.get_or_insert(0); + *v += x; + } + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { Review Comment: I think retract batch is used as part of window functions so we probably need to add some tests ########## datafusion/sqllogictest/test_files/aggregate.slt: ########## @@ -4969,6 +4969,25 @@ select count(distinct column1), count(distinct column2) from dict_test group by statement ok drop table dict_test; +# avg_duartion + +statement ok +create table d as values + (arrow_cast(1, 'Duration(Second)'), arrow_cast(2, 'Duration(Millisecond)'), arrow_cast(3, 'Duration(Microsecond)'), arrow_cast(4, 'Duration(Nanosecond)'), 1), + (arrow_cast(11, 'Duration(Second)'), arrow_cast(22, 'Duration(Millisecond)'), arrow_cast(33, 'Duration(Microsecond)'), arrow_cast(44, 'Duration(Nanosecond)'), 1); Review Comment: I think we need a few more tests: 1. At least one other group (maybe add another few rows with column5 and an id of `2`) 1. Test for data with nulls (add a row with all null values to d -- the nulls should be ignored and the output should not be NULL) 2. 
Test for use of AVG as a window function (to test the retract batch code) For example, to test nulls you might do the following (the code works correctly, I think): ```sql > create table d as values (arrow_cast(1, 'Duration(Second)'), arrow_cast(2, 'Duration(Millisecond)'), arrow_cast(3, 'Duration(Microsecond)'), arrow_cast(4, 'Duration(Nanosecond)'), 1), (arrow_cast(11, 'Duration(Second)'), arrow_cast(22, 'Duration(Millisecond)'), arrow_cast(33, 'Duration(Microsecond)'), arrow_cast(44, 'Duration(Nanosecond)'), 1), (null, null, null, null, 2); 0 row(s) fetched. Elapsed 0.013 seconds. > select * from d; +-------------------------------+----------------------------------+-------------------------------------+----------------------------------------+---------+ | column1 | column2 | column3 | column4 | column5 | +-------------------------------+----------------------------------+-------------------------------------+----------------------------------------+---------+ | 0 days 0 hours 0 mins 1 secs | 0 days 0 hours 0 mins 0.002 secs | 0 days 0 hours 0 mins 0.000003 secs | 0 days 0 hours 0 mins 0.000000004 secs | 1 | | 0 days 0 hours 0 mins 11 secs | 0 days 0 hours 0 mins 0.022 secs | 0 days 0 hours 0 mins 0.000033 secs | 0 days 0 hours 0 mins 0.000000044 secs | 1 | | NULL | NULL | NULL | NULL | 2 | +-------------------------------+----------------------------------+-------------------------------------+----------------------------------------+---------+ 3 row(s) fetched. Elapsed 0.009 seconds. > select avg(column1) from d; +------------------------------+ | avg(d.column1) | +------------------------------+ | 0 days 0 hours 0 mins 6 secs | +------------------------------+ 1 row(s) fetched. Elapsed 0.011 seconds. > select avg(column1) from d GROUP BY column5; +------------------------------+ | avg(d.column1) | +------------------------------+ | 0 days 0 hours 0 mins 6 secs | | NULL | +------------------------------+ 2 row(s) fetched. Elapsed 0.005 seconds. 
``` To test window functions, you could use something like this (you'll probably have to add more than 3 rows): ```sql > select column1, avg(column1) OVER (ORDER BY column5,column1 ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) from d; +-------------------------------+-----------------------------------------------------------------------------------------------------------------------+ | column1 | avg(d.column1) ORDER BY [d.column5 ASC NULLS LAST, d.column1 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND CURRENT ROW | +-------------------------------+-----------------------------------------------------------------------------------------------------------------------+ | 0 days 0 hours 0 mins 1 secs | 0 days 0 hours 0 mins 1 secs | | 0 days 0 hours 0 mins 11 secs | 0 days 0 hours 0 mins 6 secs | | NULL | 0 days 0 hours 0 mins 11 secs | +-------------------------------+-----------------------------------------------------------------------------------------------------------------------+ 3 row(s) fetched. Elapsed 0.010 seconds. ``` ########## datafusion/sqllogictest/test_files/aggregate.slt: ########## @@ -4969,6 +4969,25 @@ select count(distinct column1), count(distinct column2) from dict_test group by statement ok drop table dict_test; +# avg_duartion Review Comment: ```suggestion # avg_duration ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For additional commands, e-mail: github-h...@datafusion.apache.org