This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 978ec2dc8 refactor: make GeometricMean not to have update and merge (#5469)
978ec2dc8 is described below

commit 978ec2dc80b4956ee0c19850653448aa8dc2be1f
Author: Alex Huang <[email protected]>
AuthorDate: Mon Mar 6 13:09:04 2023 +0100

    refactor: make GeometricMean not to have update and merge (#5469)
---
 datafusion-examples/examples/simple_udaf.rs | 65 ++++++++++-------------------
 1 file changed, 22 insertions(+), 43 deletions(-)

diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/simple_udaf.rs
index d171f6579..b858ce7eb 100644
--- a/datafusion-examples/examples/simple_udaf.rs
+++ b/datafusion-examples/examples/simple_udaf.rs
@@ -65,40 +65,6 @@ impl GeometricMean {
     pub fn new() -> Self {
         GeometricMean { n: 0, prod: 1.0 }
     }
-
-    // this function receives one entry per argument of this accumulator.
-    // DataFusion calls this function on every row, and expects this function to update the accumulator's state.
-    fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
-        // this is a one-argument UDAF, and thus we use `0`.
-        let value = &values[0];
-        match value {
-            // here we map `ScalarValue` to our internal state. `Float64` indicates that this function
-            // only accepts Float64 as its argument (DataFusion does try to coerce arguments to this type)
-            //
-            // Note that `.map` here ensures that we ignore Nulls.
-            ScalarValue::Float64(e) => e.map(|value| {
-                self.prod *= value;
-                self.n += 1;
-            }),
-            _ => unreachable!(""),
-        };
-        Ok(())
-    }
-
-    // this function receives states from other accumulators (Vec<ScalarValue>)
-    // and updates the accumulator.
-    fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {
-        let prod = &states[0];
-        let n = &states[1];
-        match (prod, n) {
-            (ScalarValue::Float64(Some(prod)), ScalarValue::UInt32(Some(n))) => {
-                self.prod *= prod;
-                self.n += n;
-            }
-            _ => unreachable!(""),
-        };
-        Ok(())
-    }
 }
 
 // UDAFs are built using the trait `Accumulator`, that offers DataFusion the necessary functions
@@ -128,28 +94,41 @@ impl Accumulator for GeometricMean {
         if values.is_empty() {
             return Ok(());
         }
-        (0..values[0].len()).try_for_each(|index| {
-            let v = values
-                .iter()
-                .map(|array| ScalarValue::try_from_array(array, index))
-                .collect::<Result<Vec<_>>>()?;
-            self.update(&v)
+        let arr = &values[0];
+        (0..arr.len()).try_for_each(|index| {
+            let v = ScalarValue::try_from_array(arr, index)?;
+
+            if let ScalarValue::Float64(Some(value)) = v {
+                self.prod *= value;
+                self.n += 1;
+            } else {
+                unreachable!("")
+            }
+            Ok(())
         })
     }
 
     // Optimization hint: this trait also supports `update_batch` and `merge_batch`,
     // that can be used to perform these operations on arrays instead of single values.
-    // By default, these methods call `update` and `merge` row by row
     fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
         if states.is_empty() {
             return Ok(());
         }
-        (0..states[0].len()).try_for_each(|index| {
+        let arr = &states[0];
+        (0..arr.len()).try_for_each(|index| {
             let v = states
                 .iter()
                 .map(|array| ScalarValue::try_from_array(array, index))
                 .collect::<Result<Vec<_>>>()?;
-            self.merge(&v)
+            if let (ScalarValue::Float64(Some(prod)), ScalarValue::UInt32(Some(n))) =
+                (&v[0], &v[1])
+            {
+                self.prod *= prod;
+                self.n += n;
+            } else {
+                unreachable!("")
+            }
+            Ok(())
         })
     }
 

Reply via email to