This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch hash_agg_spike
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git

commit 90b38b0d467eb5f1aaa0dde138035504a0ff64a9
Author: Andrew Lamb <[email protected]>
AuthorDate: Fri Jun 30 08:59:34 2023 -0400

    Refactor Decimal128 averaging code to be vectorizable (and easier to read)
---
 datafusion/physical-expr/src/aggregate/utils.rs | 130 +++++++++++++++++-------
 1 file changed, 96 insertions(+), 34 deletions(-)

diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/physical-expr/src/aggregate/utils.rs
index f6f0086919..5dfd29ec98 100644
--- a/datafusion/physical-expr/src/aggregate/utils.rs
+++ b/datafusion/physical-expr/src/aggregate/utils.rs
@@ -37,45 +37,107 @@ pub fn get_accum_scalar_values_as_arrays(
         .collect::<Vec<_>>())
 }
 
-pub fn calculate_result_decimal_for_avg(
-    lit_value: i128,
-    count: i128,
-    scale: i8,
-    target_type: &DataType,
-) -> Result<ScalarValue> {
-    match target_type {
-        DataType::Decimal128(p, s) => {
-            // Different precision for decimal128 can store different range of value.
-            // For example, the precision is 3, the max of value is `999` and the min
-            // value is `-999`
-            let (target_mul, target_min, target_max) = (
-                10_i128.pow(*s as u32),
-                MIN_DECIMAL_FOR_EACH_PRECISION[*p as usize - 1],
-                MAX_DECIMAL_FOR_EACH_PRECISION[*p as usize - 1],
-            );
-            let lit_scale_mul = 10_i128.pow(scale as u32);
-            if target_mul >= lit_scale_mul {
-                if let Some(value) = lit_value.checked_mul(target_mul / lit_scale_mul) {
-                    let new_value = value / count;
-                    if new_value >= target_min && new_value <= target_max {
-                        Ok(ScalarValue::Decimal128(Some(new_value), *p, *s))
-                    } else {
-                        Err(DataFusionError::Execution(
-                            "Arithmetic Overflow in AvgAccumulator".to_string(),
-                        ))
-                    }
-                } else {
-                    // can't convert the lit decimal to the returned data type
-                    Err(DataFusionError::Execution(
-                        "Arithmetic Overflow in AvgAccumulator".to_string(),
-                    ))
-                }
+/// Computes averages for `Decimal128` values, checking for overflow
+///
+/// This is needed because different precisions for Decimal128 can
+/// store different ranges of values and thus sum/count may not fit in
+/// the target type.
+///
+/// For example, with precision 3 the maximum value is `999` and the
+/// minimum value is `-999`.
+pub(crate) struct Decimal128Averager {
+    /// scale factor for sum values (10^sum_scale)
+    sum_mul: i128,
+    /// scale factor for target (10^target_scale)
+    target_mul: i128,
+    /// The minimum output value possible to represent with the target precision
+    target_min: i128,
+    /// The maximum output value possible to represent with the target precision
+    target_max: i128,
+}
+
+impl Decimal128Averager {
+    /// Create a new `Decimal128Averager`:
+    ///
+    /// * sum_scale: the scale of `sum` values passed to [`Self::avg`]
+    /// * target_precision: the output precision
+    /// * target_scale: the output scale
+    ///
+    /// Errors if the target scale is smaller than `sum_scale` (sums can't be rescaled)
+    pub fn try_new(
+        sum_scale: i8,
+        target_precision: u8,
+        target_scale: i8,
+    ) -> Result<Self> {
+        let sum_mul = 10_i128.pow(sum_scale as u32);
+        let target_mul = 10_i128.pow(target_scale as u32);
+        let target_min = MIN_DECIMAL_FOR_EACH_PRECISION[target_precision as usize - 1];
+        let target_max = MAX_DECIMAL_FOR_EACH_PRECISION[target_precision as usize - 1];
+
+        if target_mul >= sum_mul {
+            Ok(Self {
+                sum_mul,
+                target_mul,
+                target_min,
+                target_max,
+            })
+        } else {
+            // target scale is smaller than the sum scale: sums can't be rescaled to it
+            Err(DataFusionError::Execution(
+                "Arithmetic Overflow in AvgAccumulator".to_string(),
+            ))
+        }
+    }
+
+    /// Returns the `sum`/`count` as a i128 Decimal128 with
+    /// target_scale and target_precision and reporting overflow.
+    ///
+    /// * sum: The total sum value stored as Decimal128 with sum_scale
+    /// (passed to `Self::try_new`)
+    /// * count: total count, stored as a i128 (*NOT* a Decimal128 value); must be non-zero
+    #[inline(always)]
+    pub fn avg(&self, sum: i128, count: i128) -> Result<i128> {
+        if let Some(value) = sum.checked_mul(self.target_mul / self.sum_mul) {
+            let new_value = value / count;
+            if new_value >= self.target_min && new_value <= self.target_max {
+                Ok(new_value)
             } else {
-                // can't convert the lit decimal to the returned data type
                 Err(DataFusionError::Execution(
                     "Arithmetic Overflow in AvgAccumulator".to_string(),
                 ))
             }
+        } else {
+            // rescaling `sum` to the target scale overflowed i128
+            Err(DataFusionError::Execution(
+                "Arithmetic Overflow in AvgAccumulator".to_string(),
+            ))
+        }
+    }
+}
+
+/// Returns `sum`/`count` for decimal values, detecting and reporting overflow.
+///
+/// * sum:  stored as Decimal128 with `sum_scale` scale
+/// * count: stored as a i128 (*NOT* a Decimal128 value)
+/// * sum_scale: the scale of `sum`
+/// * target_type: the output decimal type
+pub fn calculate_result_decimal_for_avg(
+    sum: i128,
+    count: i128,
+    sum_scale: i8,
+    target_type: &DataType,
+) -> Result<ScalarValue> {
+    match target_type {
+        DataType::Decimal128(target_precision, target_scale) => {
+            let new_value =
+                Decimal128Averager::try_new(sum_scale, *target_precision, 
*target_scale)?
+                    .avg(sum, count)?;
+
+            Ok(ScalarValue::Decimal128(
+                Some(new_value),
+                *target_precision,
+                *target_scale,
+            ))
         }
         other => Err(DataFusionError::Internal(format!(
             "Invalid target type in AvgAccumulator {other:?}"

Reply via email to