This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 978ec2dc8 refactor: make GeometricMean not to have update and merge (#5469)
978ec2dc8 is described below
commit 978ec2dc80b4956ee0c19850653448aa8dc2be1f
Author: Alex Huang <[email protected]>
AuthorDate: Mon Mar 6 13:09:04 2023 +0100
refactor: make GeometricMean not to have update and merge (#5469)
---
datafusion-examples/examples/simple_udaf.rs | 65 ++++++++++-------------------
1 file changed, 22 insertions(+), 43 deletions(-)
diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/simple_udaf.rs
index d171f6579..b858ce7eb 100644
--- a/datafusion-examples/examples/simple_udaf.rs
+++ b/datafusion-examples/examples/simple_udaf.rs
@@ -65,40 +65,6 @@ impl GeometricMean {
pub fn new() -> Self {
GeometricMean { n: 0, prod: 1.0 }
}
-
- // this function receives one entry per argument of this accumulator.
- // DataFusion calls this function on every row, and expects this function to update the accumulator's state.
- fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
- // this is a one-argument UDAF, and thus we use `0`.
- let value = &values[0];
- match value {
- // here we map `ScalarValue` to our internal state. `Float64` indicates that this function
- // only accepts Float64 as its argument (DataFusion does try to coerce arguments to this type)
- //
- // Note that `.map` here ensures that we ignore Nulls.
- ScalarValue::Float64(e) => e.map(|value| {
- self.prod *= value;
- self.n += 1;
- }),
- _ => unreachable!(""),
- };
- Ok(())
- }
-
- // this function receives states from other accumulators (Vec<ScalarValue>)
- // and updates the accumulator.
- fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {
- let prod = &states[0];
- let n = &states[1];
- match (prod, n) {
- (ScalarValue::Float64(Some(prod)), ScalarValue::UInt32(Some(n)))
=> {
- self.prod *= prod;
- self.n += n;
- }
- _ => unreachable!(""),
- };
- Ok(())
- }
}
// UDAFs are built using the trait `Accumulator`, that offers DataFusion the necessary functions
@@ -128,28 +94,41 @@ impl Accumulator for GeometricMean {
if values.is_empty() {
return Ok(());
}
- (0..values[0].len()).try_for_each(|index| {
- let v = values
- .iter()
- .map(|array| ScalarValue::try_from_array(array, index))
- .collect::<Result<Vec<_>>>()?;
- self.update(&v)
+ let arr = &values[0];
+ (0..arr.len()).try_for_each(|index| {
+ let v = ScalarValue::try_from_array(arr, index)?;
+
+ if let ScalarValue::Float64(Some(value)) = v {
+ self.prod *= value;
+ self.n += 1;
+ } else {
+ unreachable!("")
+ }
+ Ok(())
})
}
// Optimization hint: this trait also supports `update_batch` and `merge_batch`,
// that can be used to perform these operations on arrays instead of single values.
- // By default, these methods call `update` and `merge` row by row
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
if states.is_empty() {
return Ok(());
}
- (0..states[0].len()).try_for_each(|index| {
+ let arr = &states[0];
+ (0..arr.len()).try_for_each(|index| {
let v = states
.iter()
.map(|array| ScalarValue::try_from_array(array, index))
.collect::<Result<Vec<_>>>()?;
- self.merge(&v)
+ if let (ScalarValue::Float64(Some(prod)),
ScalarValue::UInt32(Some(n))) =
+ (&v[0], &v[1])
+ {
+ self.prod *= prod;
+ self.n += n;
+ } else {
+ unreachable!("")
+ }
+ Ok(())
})
}