martin-g commented on code in PR #18937:
URL: https://github.com/apache/datafusion/pull/18937#discussion_r2565639564
##########
datafusion/functions-aggregate/src/variance.rs:
##########
@@ -583,10 +1169,53 @@ impl GroupsAccumulator for VarianceGroupsAccumulator {
#[cfg(test)]
mod tests {
+ use arrow::array::Decimal128Builder;
use datafusion_expr::EmitTo;
+ use std::sync::Arc;
use super::*;
+ #[test]
+ fn variance_population_accepts_decimal() -> Result<()> {
+ let variance = VariancePopulation::new();
+ variance.return_type(&[DataType::Decimal128(10, 3)])?;
+ Ok(())
+ }
+
+ #[test]
+ fn variance_decimal_input() -> Result<()> {
+ let mut builder = Decimal128Builder::with_capacity(20);
+ for i in 0..10 {
+ builder.append_value(110000 + i);
+ }
+ for i in 0..10 {
+ builder.append_value(-((100000 + i) as i128));
+ }
+ let decimal_array = builder.finish().with_precision_and_scale(10,
3).unwrap();
+ let array: ArrayRef = Arc::new(decimal_array);
+
+ let mut pop_acc = VarianceAccumulator::try_new(StatsType::Population)?;
Review Comment:
Shouldn't this test use `DecimalVarianceAccumulator` instead of `VarianceAccumulator`, since the input is a decimal array?
##########
datafusion/functions-aggregate/src/variance.rs:
##########
@@ -55,6 +67,538 @@ make_udaf_expr_and_func!(
var_pop_udaf
);
+fn variance_signature() -> Signature {
+ Signature::one_of(
+ vec![
+ TypeSignature::Numeric(1),
+ TypeSignature::Coercible(vec![Coercion::new_exact(
+ TypeSignatureClass::Decimal,
+ )]),
+ ],
+ Volatility::Immutable,
+ )
+}
+
+const DECIMAL_VARIANCE_BINARY_SIZE: i32 = 32;
+
+fn decimal_overflow_err() -> DataFusionError {
+ DataFusionError::Execution("Decimal variance overflow".to_string())
+}
+
+fn i256_to_f64_lossy(value: i256) -> f64 {
+ const SCALE: f64 = 18446744073709551616.0; // 2^64
+ let mut abs = value;
+ let negative = abs < i256::ZERO;
+ if negative {
+ abs = abs.neg();
+ }
+ let bytes = abs.to_le_bytes();
+ let mut result = 0f64;
+ for chunk in bytes.chunks_exact(8).rev() {
+ let chunk_val = u64::from_le_bytes(chunk.try_into().unwrap());
+ result = result * SCALE + chunk_val as f64;
+ }
+ if negative {
+ -result
+ } else {
+ result
+ }
+}
+
+fn decimal_scale(dt: &DataType) -> Option<i8> {
+ match dt {
+ DataType::Decimal32(_, scale)
+ | DataType::Decimal64(_, scale)
+ | DataType::Decimal128(_, scale)
+ | DataType::Decimal256(_, scale) => Some(*scale),
+ _ => None,
+ }
+}
+
+fn decimal_variance_state_fields(name: &str) -> Vec<FieldRef> {
+ vec![
+ Field::new(format_state_name(name, "count"), DataType::UInt64, true),
+ Field::new(
+ format_state_name(name, "sum"),
+ DataType::FixedSizeBinary(DECIMAL_VARIANCE_BINARY_SIZE),
+ true,
+ ),
+ Field::new(
+ format_state_name(name, "sum_squares"),
+ DataType::FixedSizeBinary(DECIMAL_VARIANCE_BINARY_SIZE),
+ true,
+ ),
+ ]
+ .into_iter()
+ .map(Arc::new)
+ .collect()
+}
+
+fn is_numeric_or_decimal(data_type: &DataType) -> bool {
+ data_type.is_numeric()
+ || matches!(
+ data_type,
+ DataType::Decimal32(_, _)
+ | DataType::Decimal64(_, _)
+ | DataType::Decimal128(_, _)
+ | DataType::Decimal256(_, _)
+ )
+}
+
+fn i256_from_bytes(bytes: &[u8]) -> Result<i256> {
+ if bytes.len() != DECIMAL_VARIANCE_BINARY_LEN {
+ return exec_err!(
+ "Decimal variance state expected {} bytes got {}",
+ DECIMAL_VARIANCE_BINARY_LEN,
+ bytes.len()
+ );
+ }
+ let mut buffer = [0u8; DECIMAL_VARIANCE_BINARY_LEN];
+ buffer.copy_from_slice(bytes);
+ Ok(i256::from_le_bytes(buffer))
+}
+
+const DECIMAL_VARIANCE_BINARY_LEN: usize = DECIMAL_VARIANCE_BINARY_SIZE as
usize;
+
+fn i256_to_scalar(value: i256) -> ScalarValue {
+ ScalarValue::FixedSizeBinary(
+ DECIMAL_VARIANCE_BINARY_SIZE,
+ Some(value.to_le_bytes().to_vec()),
+ )
+}
+
+fn create_decimal_variance_accumulator(
+ data_type: &DataType,
+ stats_type: StatsType,
+) -> Result<Option<Box<dyn Accumulator>>> {
+ let accumulator = match data_type {
+ DataType::Decimal32(_, scale) =>
Some(Box::new(DecimalVarianceAccumulator::<
+ Decimal32Type,
+ >::try_new(
+ *scale, stats_type
+ )?) as Box<dyn Accumulator>),
+ DataType::Decimal64(_, scale) =>
Some(Box::new(DecimalVarianceAccumulator::<
+ Decimal64Type,
+ >::try_new(
+ *scale, stats_type
+ )?) as Box<dyn Accumulator>),
+ DataType::Decimal128(_, scale) =>
Some(Box::new(DecimalVarianceAccumulator::<
+ Decimal128Type,
+ >::try_new(
+ *scale, stats_type
+ )?) as Box<dyn Accumulator>),
+ DataType::Decimal256(_, scale) =>
Some(Box::new(DecimalVarianceAccumulator::<
+ Decimal256Type,
+ >::try_new(
+ *scale, stats_type
+ )?) as Box<dyn Accumulator>),
+ _ => None,
+ };
+ Ok(accumulator)
+}
+
+fn create_decimal_variance_groups_accumulator(
+ data_type: &DataType,
+ stats_type: StatsType,
+) -> Result<Option<Box<dyn GroupsAccumulator>>> {
+ let accumulator = match data_type {
+ DataType::Decimal32(_, scale) => Some(Box::new(
+ DecimalVarianceGroupsAccumulator::<Decimal32Type>::new(*scale,
stats_type),
+ ) as Box<dyn GroupsAccumulator>),
+ DataType::Decimal64(_, scale) => Some(Box::new(
+ DecimalVarianceGroupsAccumulator::<Decimal64Type>::new(*scale,
stats_type),
+ ) as Box<dyn GroupsAccumulator>),
+ DataType::Decimal128(_, scale) => Some(Box::new(
+ DecimalVarianceGroupsAccumulator::<Decimal128Type>::new(*scale,
stats_type),
+ ) as Box<dyn GroupsAccumulator>),
+ DataType::Decimal256(_, scale) => Some(Box::new(
+ DecimalVarianceGroupsAccumulator::<Decimal256Type>::new(*scale,
stats_type),
+ ) as Box<dyn GroupsAccumulator>),
+ _ => None,
+ };
+ Ok(accumulator)
+}
+
+trait DecimalNative: Copy {
+ fn to_i256(self) -> i256;
+}
+
+impl DecimalNative for i32 {
+ fn to_i256(self) -> i256 {
+ i256::from(self)
+ }
+}
+
+impl DecimalNative for i64 {
+ fn to_i256(self) -> i256 {
+ i256::from(self)
+ }
+}
+
+impl DecimalNative for i128 {
+ fn to_i256(self) -> i256 {
+ i256::from_i128(self)
+ }
+}
+
+impl DecimalNative for i256 {
+ fn to_i256(self) -> i256 {
+ self
+ }
+}
+
+#[derive(Clone, Debug, Default)]
+struct DecimalVarianceState {
+ count: u64,
+ sum: i256,
+ sum_squares: i256,
+}
+
+impl DecimalVarianceState {
+ fn update(&mut self, value: i256) -> Result<()> {
+ self.count =
self.count.checked_add(1).ok_or_else(decimal_overflow_err)?;
+ self.sum = self
+ .sum
+ .checked_add(value)
+ .ok_or_else(decimal_overflow_err)?;
+ let square =
value.checked_mul(value).ok_or_else(decimal_overflow_err)?;
+ self.sum_squares = self
+ .sum_squares
+ .checked_add(square)
+ .ok_or_else(decimal_overflow_err)?;
+ Ok(())
+ }
+
+ fn retract(&mut self, value: i256) -> Result<()> {
+ if self.count == 0 {
+ return exec_err!("Decimal variance retract underflow");
+ }
+ self.count -= 1;
+ self.sum = self
+ .sum
+ .checked_sub(value)
+ .ok_or_else(decimal_overflow_err)?;
+ let square =
value.checked_mul(value).ok_or_else(decimal_overflow_err)?;
+ self.sum_squares = self
+ .sum_squares
+ .checked_sub(square)
+ .ok_or_else(decimal_overflow_err)?;
+ Ok(())
+ }
+
+ fn merge(&mut self, other: &Self) -> Result<()> {
+ self.count = self
+ .count
+ .checked_add(other.count)
+ .ok_or_else(decimal_overflow_err)?;
+ self.sum = self
+ .sum
+ .checked_add(other.sum)
+ .ok_or_else(decimal_overflow_err)?;
+ self.sum_squares = self
+ .sum_squares
+ .checked_add(other.sum_squares)
+ .ok_or_else(decimal_overflow_err)?;
+ Ok(())
+ }
+
+ fn variance(&self, stats_type: StatsType, scale: i8) ->
Result<Option<f64>> {
+ if self.count == 0 {
+ return Ok(None);
+ }
+ if matches!(stats_type, StatsType::Sample) && self.count <= 1 {
+ return Ok(None);
+ }
+
+ let count_i256 = i256::from_i128(self.count as i128);
+ let scaled_sum_squares = self
+ .sum_squares
+ .checked_mul(count_i256)
+ .ok_or_else(decimal_overflow_err)?;
+ let sum_squared = self
+ .sum
+ .checked_mul(self.sum)
+ .ok_or_else(decimal_overflow_err)?;
+ let numerator = scaled_sum_squares
+ .checked_sub(sum_squared)
+ .ok_or_else(decimal_overflow_err)?;
+
+ let numerator = if numerator < i256::ZERO {
+ i256::ZERO
+ } else {
+ numerator
+ };
+
+ let denominator_counts = match stats_type {
+ StatsType::Population => {
+ let count = self.count as f64;
+ count * count
+ }
+ StatsType::Sample => {
+ let count = self.count as f64;
+ count * ((self.count - 1) as f64)
+ }
+ };
+
+ if denominator_counts == 0.0 {
+ return Ok(None);
+ }
+
+ let numerator_f64 = i256_to_f64_lossy(numerator);
+ let scale_factor = 10f64.powi(2 * scale as i32);
+ Ok(Some(numerator_f64 / (denominator_counts * scale_factor)))
+ }
+
+ fn to_scalar_state(&self) -> Vec<ScalarValue> {
+ vec![
+ ScalarValue::from(self.count),
+ i256_to_scalar(self.sum),
+ i256_to_scalar(self.sum_squares),
+ ]
+ }
+}
+
+#[derive(Debug)]
+struct DecimalVarianceAccumulator<T>
+where
+ T: DecimalType + ArrowNumericType + Debug,
+ T::Native: DecimalNative,
+{
+ state: DecimalVarianceState,
+ scale: i8,
+ stats_type: StatsType,
+ _marker: PhantomData<T>,
+}
+
+impl<T> DecimalVarianceAccumulator<T>
+where
+ T: DecimalType + ArrowNumericType + Debug,
+ T::Native: DecimalNative,
+{
+ fn try_new(scale: i8, stats_type: StatsType) -> Result<Self> {
+ if scale > DECIMAL256_MAX_SCALE {
+ return exec_err!(
+ "Decimal variance does not support scale {} greater than {}",
+ scale,
+ DECIMAL256_MAX_SCALE
+ );
+ }
+ Ok(Self {
+ state: DecimalVarianceState::default(),
+ scale,
+ stats_type,
+ _marker: PhantomData,
+ })
+ }
+
+ fn convert_array(values: &ArrayRef) -> &PrimitiveArray<T> {
+ values.as_primitive::<T>()
+ }
+}
+
+impl<T> Accumulator for DecimalVarianceAccumulator<T>
+where
+ T: DecimalType + ArrowNumericType + Debug,
+ T::Native: DecimalNative,
+{
+ fn state(&mut self) -> Result<Vec<ScalarValue>> {
+ Ok(self.state.to_scalar_state())
+ }
+
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ let array = Self::convert_array(&values[0]);
+ for value in array.iter().flatten() {
+ self.state.update(value.to_i256())?;
+ }
+ Ok(())
+ }
+
+ fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ let array = Self::convert_array(&values[0]);
+ for value in array.iter().flatten() {
+ self.state.retract(value.to_i256())?;
+ }
+ Ok(())
+ }
+
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ let counts = downcast_value!(states[0], UInt64Array);
+ let sums = downcast_value!(states[1], FixedSizeBinaryArray);
+ let sum_squares = downcast_value!(states[2], FixedSizeBinaryArray);
+
+ for i in 0..counts.len() {
+ if counts.is_null(i) {
+ continue;
+ }
+ let count = counts.value(i);
+ if count == 0 {
+ continue;
+ }
+ let sum = i256_from_bytes(sums.value(i))?;
+ let sum_sq = i256_from_bytes(sum_squares.value(i))?;
+ let other = DecimalVarianceState {
+ count,
+ sum,
+ sum_squares: sum_sq,
+ };
+ self.state.merge(&other)?;
+ }
+ Ok(())
+ }
+
+ fn evaluate(&mut self) -> Result<ScalarValue> {
+ match self.state.variance(self.stats_type, self.scale)? {
+ Some(v) => Ok(ScalarValue::Float64(Some(v))),
+ None => Ok(ScalarValue::Float64(None)),
+ }
+ }
+
+ fn size(&self) -> usize {
+ size_of_val(self)
+ }
+
+ fn supports_retract_batch(&self) -> bool {
+ true
+ }
+}
+
+#[derive(Debug)]
+struct DecimalVarianceGroupsAccumulator<T>
+where
+ T: DecimalType + ArrowNumericType + Debug,
+ T::Native: DecimalNative,
+{
+ states: Vec<DecimalVarianceState>,
+ scale: i8,
+ stats_type: StatsType,
+ _marker: PhantomData<T>,
+}
+
+impl<T> DecimalVarianceGroupsAccumulator<T>
+where
+ T: DecimalType + ArrowNumericType + Debug,
+ T::Native: DecimalNative,
+{
+ fn new(scale: i8, stats_type: StatsType) -> Self {
+ Self {
+ states: Vec::new(),
+ scale,
+ stats_type,
+ _marker: PhantomData,
+ }
+ }
+
+ fn resize(&mut self, total_num_groups: usize) {
+ if self.states.len() < total_num_groups {
+ self.states
+ .resize(total_num_groups, DecimalVarianceState::default());
+ }
+ }
+}
+
+impl<T> GroupsAccumulator for DecimalVarianceGroupsAccumulator<T>
+where
+ T: DecimalType + ArrowNumericType + Debug,
+ T::Native: DecimalNative,
+{
+ fn update_batch(
+ &mut self,
+ values: &[ArrayRef],
+ group_indices: &[usize],
+ opt_filter: Option<&BooleanArray>,
+ total_num_groups: usize,
+ ) -> Result<()> {
+ let array = values[0].as_primitive::<T>();
+ self.resize(total_num_groups);
+ for (row, group_index) in group_indices.iter().enumerate() {
+ if let Some(filter) = opt_filter {
+ if !filter.value(row) {
Review Comment:
```suggestion
if !filter.is_valid(row) || !filter.value(row) {
```
##########
datafusion/functions-aggregate/src/variance.rs:
##########
@@ -583,10 +1169,53 @@ impl GroupsAccumulator for VarianceGroupsAccumulator {
#[cfg(test)]
mod tests {
+ use arrow::array::Decimal128Builder;
use datafusion_expr::EmitTo;
+ use std::sync::Arc;
use super::*;
+ #[test]
+ fn variance_population_accepts_decimal() -> Result<()> {
+ let variance = VariancePopulation::new();
+ variance.return_type(&[DataType::Decimal128(10, 3)])?;
+ Ok(())
+ }
+
+ #[test]
+ fn variance_decimal_input() -> Result<()> {
+ let mut builder = Decimal128Builder::with_capacity(20);
+ for i in 0..10 {
+ builder.append_value(110000 + i);
+ }
+ for i in 0..10 {
+ builder.append_value(-((100000 + i) as i128));
+ }
+ let decimal_array = builder.finish().with_precision_and_scale(10,
3).unwrap();
+ let array: ArrayRef = Arc::new(decimal_array);
+
+ let mut pop_acc = VarianceAccumulator::try_new(StatsType::Population)?;
+ let pop_input = [Arc::clone(&array)];
+ pop_acc.update_batch(&pop_input)?;
+ assert_variance(pop_acc.evaluate()?, 11025.9450285);
+
+ let mut sample_acc = VarianceAccumulator::try_new(StatsType::Sample)?;
Review Comment:
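Same here: shouldn't the sample-variance accumulator also be a `DecimalVarianceAccumulator` rather than `VarianceAccumulator`?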
##########
datafusion/functions-aggregate/src/variance.rs:
##########
@@ -55,6 +67,538 @@ make_udaf_expr_and_func!(
var_pop_udaf
);
+fn variance_signature() -> Signature {
+ Signature::one_of(
+ vec![
+ TypeSignature::Numeric(1),
+ TypeSignature::Coercible(vec![Coercion::new_exact(
+ TypeSignatureClass::Decimal,
+ )]),
+ ],
+ Volatility::Immutable,
+ )
+}
+
+const DECIMAL_VARIANCE_BINARY_SIZE: i32 = 32;
+
+fn decimal_overflow_err() -> DataFusionError {
+ DataFusionError::Execution("Decimal variance overflow".to_string())
+}
+
+fn i256_to_f64_lossy(value: i256) -> f64 {
+ const SCALE: f64 = 18446744073709551616.0; // 2^64
+ let mut abs = value;
+ let negative = abs < i256::ZERO;
+ if negative {
+ abs = abs.neg();
+ }
+ let bytes = abs.to_le_bytes();
+ let mut result = 0f64;
+ for chunk in bytes.chunks_exact(8).rev() {
+ let chunk_val = u64::from_le_bytes(chunk.try_into().unwrap());
+ result = result * SCALE + chunk_val as f64;
+ }
+ if negative {
+ -result
+ } else {
+ result
+ }
+}
+
+fn decimal_scale(dt: &DataType) -> Option<i8> {
+ match dt {
+ DataType::Decimal32(_, scale)
+ | DataType::Decimal64(_, scale)
+ | DataType::Decimal128(_, scale)
+ | DataType::Decimal256(_, scale) => Some(*scale),
+ _ => None,
+ }
+}
+
+fn decimal_variance_state_fields(name: &str) -> Vec<FieldRef> {
+ vec![
+ Field::new(format_state_name(name, "count"), DataType::UInt64, true),
+ Field::new(
+ format_state_name(name, "sum"),
+ DataType::FixedSizeBinary(DECIMAL_VARIANCE_BINARY_SIZE),
+ true,
+ ),
+ Field::new(
+ format_state_name(name, "sum_squares"),
+ DataType::FixedSizeBinary(DECIMAL_VARIANCE_BINARY_SIZE),
+ true,
+ ),
+ ]
+ .into_iter()
+ .map(Arc::new)
+ .collect()
+}
+
+fn is_numeric_or_decimal(data_type: &DataType) -> bool {
+ data_type.is_numeric()
+ || matches!(
+ data_type,
+ DataType::Decimal32(_, _)
+ | DataType::Decimal64(_, _)
+ | DataType::Decimal128(_, _)
+ | DataType::Decimal256(_, _)
+ )
+}
+
+fn i256_from_bytes(bytes: &[u8]) -> Result<i256> {
+ if bytes.len() != DECIMAL_VARIANCE_BINARY_LEN {
+ return exec_err!(
+ "Decimal variance state expected {} bytes got {}",
+ DECIMAL_VARIANCE_BINARY_LEN,
+ bytes.len()
+ );
+ }
+ let mut buffer = [0u8; DECIMAL_VARIANCE_BINARY_LEN];
+ buffer.copy_from_slice(bytes);
+ Ok(i256::from_le_bytes(buffer))
+}
+
+const DECIMAL_VARIANCE_BINARY_LEN: usize = DECIMAL_VARIANCE_BINARY_SIZE as
usize;
+
+fn i256_to_scalar(value: i256) -> ScalarValue {
+ ScalarValue::FixedSizeBinary(
+ DECIMAL_VARIANCE_BINARY_SIZE,
+ Some(value.to_le_bytes().to_vec()),
+ )
+}
+
+fn create_decimal_variance_accumulator(
+ data_type: &DataType,
+ stats_type: StatsType,
+) -> Result<Option<Box<dyn Accumulator>>> {
+ let accumulator = match data_type {
+ DataType::Decimal32(_, scale) =>
Some(Box::new(DecimalVarianceAccumulator::<
+ Decimal32Type,
+ >::try_new(
+ *scale, stats_type
+ )?) as Box<dyn Accumulator>),
+ DataType::Decimal64(_, scale) =>
Some(Box::new(DecimalVarianceAccumulator::<
+ Decimal64Type,
+ >::try_new(
+ *scale, stats_type
+ )?) as Box<dyn Accumulator>),
+ DataType::Decimal128(_, scale) =>
Some(Box::new(DecimalVarianceAccumulator::<
+ Decimal128Type,
+ >::try_new(
+ *scale, stats_type
+ )?) as Box<dyn Accumulator>),
+ DataType::Decimal256(_, scale) =>
Some(Box::new(DecimalVarianceAccumulator::<
+ Decimal256Type,
+ >::try_new(
+ *scale, stats_type
+ )?) as Box<dyn Accumulator>),
+ _ => None,
+ };
+ Ok(accumulator)
+}
+
+fn create_decimal_variance_groups_accumulator(
+ data_type: &DataType,
+ stats_type: StatsType,
+) -> Result<Option<Box<dyn GroupsAccumulator>>> {
+ let accumulator = match data_type {
+ DataType::Decimal32(_, scale) => Some(Box::new(
+ DecimalVarianceGroupsAccumulator::<Decimal32Type>::new(*scale,
stats_type),
+ ) as Box<dyn GroupsAccumulator>),
+ DataType::Decimal64(_, scale) => Some(Box::new(
+ DecimalVarianceGroupsAccumulator::<Decimal64Type>::new(*scale,
stats_type),
+ ) as Box<dyn GroupsAccumulator>),
+ DataType::Decimal128(_, scale) => Some(Box::new(
+ DecimalVarianceGroupsAccumulator::<Decimal128Type>::new(*scale,
stats_type),
+ ) as Box<dyn GroupsAccumulator>),
+ DataType::Decimal256(_, scale) => Some(Box::new(
+ DecimalVarianceGroupsAccumulator::<Decimal256Type>::new(*scale,
stats_type),
+ ) as Box<dyn GroupsAccumulator>),
+ _ => None,
+ };
+ Ok(accumulator)
+}
+
+trait DecimalNative: Copy {
+ fn to_i256(self) -> i256;
+}
+
+impl DecimalNative for i32 {
+ fn to_i256(self) -> i256 {
+ i256::from(self)
+ }
+}
+
+impl DecimalNative for i64 {
+ fn to_i256(self) -> i256 {
+ i256::from(self)
+ }
+}
+
+impl DecimalNative for i128 {
+ fn to_i256(self) -> i256 {
+ i256::from_i128(self)
+ }
+}
+
+impl DecimalNative for i256 {
+ fn to_i256(self) -> i256 {
+ self
+ }
+}
+
+#[derive(Clone, Debug, Default)]
+struct DecimalVarianceState {
+ count: u64,
+ sum: i256,
+ sum_squares: i256,
+}
+
+impl DecimalVarianceState {
+ fn update(&mut self, value: i256) -> Result<()> {
+ self.count =
self.count.checked_add(1).ok_or_else(decimal_overflow_err)?;
+ self.sum = self
+ .sum
+ .checked_add(value)
+ .ok_or_else(decimal_overflow_err)?;
+ let square =
value.checked_mul(value).ok_or_else(decimal_overflow_err)?;
+ self.sum_squares = self
+ .sum_squares
+ .checked_add(square)
+ .ok_or_else(decimal_overflow_err)?;
+ Ok(())
+ }
+
+ fn retract(&mut self, value: i256) -> Result<()> {
+ if self.count == 0 {
+ return exec_err!("Decimal variance retract underflow");
+ }
+ self.count -= 1;
+ self.sum = self
+ .sum
+ .checked_sub(value)
+ .ok_or_else(decimal_overflow_err)?;
+ let square =
value.checked_mul(value).ok_or_else(decimal_overflow_err)?;
+ self.sum_squares = self
+ .sum_squares
+ .checked_sub(square)
+ .ok_or_else(decimal_overflow_err)?;
+ Ok(())
+ }
+
+ fn merge(&mut self, other: &Self) -> Result<()> {
+ self.count = self
+ .count
+ .checked_add(other.count)
+ .ok_or_else(decimal_overflow_err)?;
+ self.sum = self
+ .sum
+ .checked_add(other.sum)
+ .ok_or_else(decimal_overflow_err)?;
+ self.sum_squares = self
+ .sum_squares
+ .checked_add(other.sum_squares)
+ .ok_or_else(decimal_overflow_err)?;
+ Ok(())
+ }
+
+ fn variance(&self, stats_type: StatsType, scale: i8) ->
Result<Option<f64>> {
+ if self.count == 0 {
+ return Ok(None);
+ }
+ if matches!(stats_type, StatsType::Sample) && self.count <= 1 {
+ return Ok(None);
+ }
+
+ let count_i256 = i256::from_i128(self.count as i128);
+ let scaled_sum_squares = self
+ .sum_squares
+ .checked_mul(count_i256)
+ .ok_or_else(decimal_overflow_err)?;
+ let sum_squared = self
+ .sum
+ .checked_mul(self.sum)
+ .ok_or_else(decimal_overflow_err)?;
+ let numerator = scaled_sum_squares
+ .checked_sub(sum_squared)
+ .ok_or_else(decimal_overflow_err)?;
+
+ let numerator = if numerator < i256::ZERO {
+ i256::ZERO
+ } else {
+ numerator
+ };
+
+ let denominator_counts = match stats_type {
+ StatsType::Population => {
+ let count = self.count as f64;
+ count * count
+ }
+ StatsType::Sample => {
+ let count = self.count as f64;
+ count * ((self.count - 1) as f64)
+ }
+ };
+
+ if denominator_counts == 0.0 {
+ return Ok(None);
+ }
+
+ let numerator_f64 = i256_to_f64_lossy(numerator);
+ let scale_factor = 10f64.powi(2 * scale as i32);
+ Ok(Some(numerator_f64 / (denominator_counts * scale_factor)))
+ }
+
+ fn to_scalar_state(&self) -> Vec<ScalarValue> {
+ vec![
+ ScalarValue::from(self.count),
+ i256_to_scalar(self.sum),
+ i256_to_scalar(self.sum_squares),
+ ]
+ }
+}
+
+#[derive(Debug)]
+struct DecimalVarianceAccumulator<T>
+where
+ T: DecimalType + ArrowNumericType + Debug,
+ T::Native: DecimalNative,
+{
+ state: DecimalVarianceState,
+ scale: i8,
+ stats_type: StatsType,
+ _marker: PhantomData<T>,
+}
+
+impl<T> DecimalVarianceAccumulator<T>
+where
+ T: DecimalType + ArrowNumericType + Debug,
+ T::Native: DecimalNative,
+{
+ fn try_new(scale: i8, stats_type: StatsType) -> Result<Self> {
+ if scale > DECIMAL256_MAX_SCALE {
+ return exec_err!(
+ "Decimal variance does not support scale {} greater than {}",
+ scale,
+ DECIMAL256_MAX_SCALE
+ );
+ }
+ Ok(Self {
+ state: DecimalVarianceState::default(),
+ scale,
+ stats_type,
+ _marker: PhantomData,
+ })
+ }
+
+ fn convert_array(values: &ArrayRef) -> &PrimitiveArray<T> {
+ values.as_primitive::<T>()
+ }
+}
+
+impl<T> Accumulator for DecimalVarianceAccumulator<T>
+where
+ T: DecimalType + ArrowNumericType + Debug,
+ T::Native: DecimalNative,
+{
+ fn state(&mut self) -> Result<Vec<ScalarValue>> {
+ Ok(self.state.to_scalar_state())
+ }
+
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ let array = Self::convert_array(&values[0]);
+ for value in array.iter().flatten() {
+ self.state.update(value.to_i256())?;
+ }
+ Ok(())
+ }
+
+ fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ let array = Self::convert_array(&values[0]);
+ for value in array.iter().flatten() {
+ self.state.retract(value.to_i256())?;
+ }
+ Ok(())
+ }
+
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ let counts = downcast_value!(states[0], UInt64Array);
+ let sums = downcast_value!(states[1], FixedSizeBinaryArray);
+ let sum_squares = downcast_value!(states[2], FixedSizeBinaryArray);
+
+ for i in 0..counts.len() {
+ if counts.is_null(i) {
+ continue;
+ }
+ let count = counts.value(i);
+ if count == 0 {
+ continue;
+ }
+ let sum = i256_from_bytes(sums.value(i))?;
+ let sum_sq = i256_from_bytes(sum_squares.value(i))?;
+ let other = DecimalVarianceState {
+ count,
+ sum,
+ sum_squares: sum_sq,
+ };
+ self.state.merge(&other)?;
+ }
+ Ok(())
+ }
+
+ fn evaluate(&mut self) -> Result<ScalarValue> {
+ match self.state.variance(self.stats_type, self.scale)? {
+ Some(v) => Ok(ScalarValue::Float64(Some(v))),
+ None => Ok(ScalarValue::Float64(None)),
+ }
+ }
+
+ fn size(&self) -> usize {
+ size_of_val(self)
+ }
+
+ fn supports_retract_batch(&self) -> bool {
+ true
+ }
+}
+
+#[derive(Debug)]
+struct DecimalVarianceGroupsAccumulator<T>
+where
+ T: DecimalType + ArrowNumericType + Debug,
+ T::Native: DecimalNative,
+{
+ states: Vec<DecimalVarianceState>,
+ scale: i8,
+ stats_type: StatsType,
+ _marker: PhantomData<T>,
+}
+
+impl<T> DecimalVarianceGroupsAccumulator<T>
+where
+ T: DecimalType + ArrowNumericType + Debug,
+ T::Native: DecimalNative,
+{
+ fn new(scale: i8, stats_type: StatsType) -> Self {
Review Comment:
Why is the scale not checked here (`scale > DECIMAL256_MAX_SCALE`), as it is in `DecimalVarianceAccumulator::try_new`?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]