alamb commented on code in PR #15468:
URL: https://github.com/apache/datafusion/pull/15468#discussion_r2048964601


##########
datafusion/functions-aggregate/src/average.rs:
##########
@@ -399,6 +410,105 @@ impl<T: DecimalType + ArrowNumericType + Debug> 
Accumulator for DecimalAvgAccumu
     }
 }
 
+/// An accumulator to compute the average for duration values
+#[derive(Debug)]
+struct DurationAvgAccumulator {
+    sum: Option<i64>,
+    count: u64,
+    time_unit: TimeUnit,
+    result_unit: TimeUnit,
+}
+
+impl Accumulator for DurationAvgAccumulator {
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let array = &values[0];
+        self.count += (array.len() - array.null_count()) as u64;
+
+        let sum_value = match self.time_unit {
+            TimeUnit::Second => 
sum(array.as_primitive::<DurationSecondType>()),
+            TimeUnit::Millisecond => 
sum(array.as_primitive::<DurationMillisecondType>()),
+            TimeUnit::Microsecond => 
sum(array.as_primitive::<DurationMicrosecondType>()),
+            TimeUnit::Nanosecond => 
sum(array.as_primitive::<DurationNanosecondType>()),
+        };
+
+        if let Some(x) = sum_value {
+            let v = self.sum.get_or_insert(0);
+            *v += x;
+        }
+        Ok(())
+    }
+
+    fn evaluate(&mut self) -> Result<ScalarValue> {
+        let avg = self.sum.map(|sum| sum / self.count as i64);
+
+        match self.result_unit {
+            TimeUnit::Second => Ok(ScalarValue::DurationSecond(avg)),
+            TimeUnit::Millisecond => Ok(ScalarValue::DurationMillisecond(avg)),
+            TimeUnit::Microsecond => Ok(ScalarValue::DurationMicrosecond(avg)),
+            TimeUnit::Nanosecond => Ok(ScalarValue::DurationNanosecond(avg)),
+        }
+    }
+
+    fn size(&self) -> usize {
+        size_of_val(self)
+    }
+
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        let duration_value = match self.time_unit {
+            TimeUnit::Second => ScalarValue::DurationSecond(self.sum),
+            TimeUnit::Millisecond => 
ScalarValue::DurationMillisecond(self.sum),
+            TimeUnit::Microsecond => 
ScalarValue::DurationMicrosecond(self.sum),
+            TimeUnit::Nanosecond => ScalarValue::DurationNanosecond(self.sum),
+        };
+
+        Ok(vec![ScalarValue::from(self.count), duration_value])
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        self.count += 
sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
+
+        let sum_value = match self.time_unit {
+            TimeUnit::Second => 
sum(states[1].as_primitive::<DurationSecondType>()),
+            TimeUnit::Millisecond => {
+                sum(states[1].as_primitive::<DurationMillisecondType>())
+            }
+            TimeUnit::Microsecond => {
+                sum(states[1].as_primitive::<DurationMicrosecondType>())
+            }
+            TimeUnit::Nanosecond => {
+                sum(states[1].as_primitive::<DurationNanosecondType>())
+            }
+        };
+
+        if let Some(x) = sum_value {
+            let v = self.sum.get_or_insert(0);
+            *v += x;
+        }
+        Ok(())
+    }
+
+    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {

Review Comment:
   I think `retract_batch` is used as part of window functions, so we probably 
need to add some tests for it.



##########
datafusion/sqllogictest/test_files/aggregate.slt:
##########
@@ -4969,6 +4969,25 @@ select count(distinct column1), count(distinct column2) 
from dict_test group by
 statement ok
 drop table dict_test;
 
+# avg_duartion
+
+statement ok
+create table d as values 
+  (arrow_cast(1, 'Duration(Second)'), arrow_cast(2, 'Duration(Millisecond)'), 
arrow_cast(3, 'Duration(Microsecond)'), arrow_cast(4, 'Duration(Nanosecond)'), 
1),
+  (arrow_cast(11, 'Duration(Second)'), arrow_cast(22, 
'Duration(Millisecond)'), arrow_cast(33, 'Duration(Microsecond)'), 
arrow_cast(44, 'Duration(Nanosecond)'), 1);

Review Comment:
   I think we need a few more tests:
   1. At least one other group (maybe add another few rows with column5 and an 
id of `2`)
   2. Test for data with nulls (add a row with all null values to d -- the 
nulls should be ignored and the output should not be NULL)
   3. Test for use of AVG as a window function (to test the `retract_batch` code)
   
   
   For example, to test nulls you might do the following (the code works correctly, I think):
   
   ```sql
   > create table d as values
     (arrow_cast(1, 'Duration(Second)'), arrow_cast(2, 
'Duration(Millisecond)'), arrow_cast(3, 'Duration(Microsecond)'), arrow_cast(4, 
'Duration(Nanosecond)'), 1),
     (arrow_cast(11, 'Duration(Second)'), arrow_cast(22, 
'Duration(Millisecond)'), arrow_cast(33, 'Duration(Microsecond)'), 
arrow_cast(44, 'Duration(Nanosecond)'), 1), (null, null, null, null, 2);
   0 row(s) fetched.
   Elapsed 0.013 seconds.
   
   > select * from d;
   
+-------------------------------+----------------------------------+-------------------------------------+----------------------------------------+---------+
   | column1                       | column2                          | column3 
                            | column4                                | column5 |
   
+-------------------------------+----------------------------------+-------------------------------------+----------------------------------------+---------+
   | 0 days 0 hours 0 mins 1 secs  | 0 days 0 hours 0 mins 0.002 secs | 0 days 
0 hours 0 mins 0.000003 secs | 0 days 0 hours 0 mins 0.000000004 secs | 1       
|
   | 0 days 0 hours 0 mins 11 secs | 0 days 0 hours 0 mins 0.022 secs | 0 days 
0 hours 0 mins 0.000033 secs | 0 days 0 hours 0 mins 0.000000044 secs | 1       
|
   | NULL                          | NULL                             | NULL    
                            | NULL                                   | 2       |
   
+-------------------------------+----------------------------------+-------------------------------------+----------------------------------------+---------+
   3 row(s) fetched.
   Elapsed 0.009 seconds.
   
   > select avg(column1) from d;
   +------------------------------+
   | avg(d.column1)               |
   +------------------------------+
   | 0 days 0 hours 0 mins 6 secs |
   +------------------------------+
   1 row(s) fetched.
   Elapsed 0.011 seconds.
   
   > select avg(column1) from d GROUP BY column5;
   +------------------------------+
   | avg(d.column1)               |
   +------------------------------+
   | 0 days 0 hours 0 mins 6 secs |
   | NULL                         |
   +------------------------------+
   2 row(s) fetched.
   Elapsed 0.005 seconds.
   ```
   
   To test window functions, you could use something like this (you'll probably 
have to add more than 3 rows):
   
   ```sql
   > select column1, avg(column1) OVER (ORDER BY column5,column1 ROWS BETWEEN 1 
PRECEDING AND CURRENT ROW) from d;
   
+-------------------------------+-----------------------------------------------------------------------------------------------------------------------+
   | column1                       | avg(d.column1) ORDER BY [d.column5 ASC 
NULLS LAST, d.column1 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND CURRENT ROW |
   
+-------------------------------+-----------------------------------------------------------------------------------------------------------------------+
   | 0 days 0 hours 0 mins 1 secs  | 0 days 0 hours 0 mins 1 secs               
                                                                           |
   | 0 days 0 hours 0 mins 11 secs | 0 days 0 hours 0 mins 6 secs               
                                                                           |
   | NULL                          | 0 days 0 hours 0 mins 11 secs              
                                                                           |
   
+-------------------------------+-----------------------------------------------------------------------------------------------------------------------+
   3 row(s) fetched.
   Elapsed 0.010 seconds.
   ```



##########
datafusion/sqllogictest/test_files/aggregate.slt:
##########
@@ -4969,6 +4969,25 @@ select count(distinct column1), count(distinct column2) 
from dict_test group by
 statement ok
 drop table dict_test;
 
+# avg_duartion

Review Comment:
   ```suggestion
   # avg_duration
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org
For additional commands, e-mail: github-h...@datafusion.apache.org

Reply via email to