This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 835553ea5f Deprecate ScalarValue bitor, bitand, and bitxor (#6842) (#7351)
835553ea5f is described below

commit 835553ea5fa3216f1911221c212742d8d0b0a621
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Mon Aug 21 22:04:22 2023 +0100

    Deprecate ScalarValue bitor, bitand, and bitxor (#6842) (#7351)
    
    * Deprecate ScalarValue bitor, bitand, and bitxor (#6842)
    
    * Fixes
    
    * Format
    
    * Fix BitAndAccumulator
---
 datafusion/common/src/scalar.rs                    |  25 +-
 .../physical-expr/src/aggregate/bit_and_or_xor.rs  | 484 +++++++++------------
 2 files changed, 231 insertions(+), 278 deletions(-)

diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 715725a34d..73b71722f9 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -46,7 +46,8 @@ use arrow::{
         DECIMAL128_MAX_PRECISION,
     },
 };
-use arrow_array::{timezone::Tz, ArrowNativeTypeOp};
+use arrow_array::timezone::Tz;
+use arrow_array::ArrowNativeTypeOp;
 use chrono::{Datelike, Duration, NaiveDate, NaiveDateTime};
 
 // Constants we use throughout this file:
@@ -1779,6 +1780,25 @@ macro_rules! eq_array_primitive {
 }
 
 impl ScalarValue {
+    /// Create a [`ScalarValue`] with the provided value and datatype
+    ///
+    /// # Panics
+    ///
+    /// Panics if d is not compatible with T
+    pub fn new_primitive<T: ArrowPrimitiveType>(
+        a: Option<T::Native>,
+        d: &DataType,
+    ) -> Self {
+        match a {
+            None => d.try_into().unwrap(),
+            Some(v) => {
+                let array = PrimitiveArray::<T>::new(vec![v].into(), None)
+                    .with_data_type(d.clone());
+                Self::try_from_array(&array, 0).unwrap()
+            }
+        }
+    }
+
     /// Create a decimal Scalar from value/precision and scale.
     pub fn try_new_decimal128(value: i128, precision: u8, scale: i8) -> 
Result<Self> {
         // make sure the precision and scale is valid
@@ -2089,11 +2109,13 @@ impl ScalarValue {
         impl_op!(self, rhs, &)
     }
 
+    #[deprecated(note = "Use arrow kernels or specialization (#6842)")]
     pub fn bitor<T: Borrow<ScalarValue>>(&self, other: T) -> 
Result<ScalarValue> {
         let rhs = other.borrow();
         impl_op!(self, rhs, |)
     }
 
+    #[deprecated(note = "Use arrow kernels or specialization (#6842)")]
     pub fn bitxor<T: Borrow<ScalarValue>>(&self, other: T) -> 
Result<ScalarValue> {
         let rhs = other.borrow();
         impl_op!(self, rhs, ^)
@@ -4059,6 +4081,7 @@ mod tests {
     use arrow::datatypes::ArrowPrimitiveType;
     use arrow::util::pretty::pretty_format_columns;
     use arrow_array::ArrowNumericType;
+    use chrono::NaiveDate;
     use rand::Rng;
 
     use crate::cast::{as_string_array, as_uint32_array, as_uint64_array};
diff --git a/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs b/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs
index 55ca02a147..93b911c939 100644
--- a/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs
+++ b/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs
@@ -19,24 +19,12 @@
 
 use ahash::RandomState;
 use std::any::Any;
-use std::convert::TryFrom;
 use std::sync::Arc;
 
 use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr};
-use arrow::datatypes::{
-    DataType, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, 
UInt32Type,
-    UInt64Type, UInt8Type,
-};
-use arrow::{
-    array::{
-        ArrayRef, Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array,
-        UInt32Array, UInt64Array, UInt8Array,
-    },
-    datatypes::Field,
-};
-use datafusion_common::{
-    downcast_value, internal_err, not_impl_err, DataFusionError, Result, 
ScalarValue,
-};
+use arrow::datatypes::DataType;
+use arrow::{array::ArrayRef, datatypes::Field};
+use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue};
 use datafusion_expr::Accumulator;
 use std::collections::HashSet;
 
@@ -45,81 +33,9 @@ use crate::aggregate::utils::down_cast_any_ref;
 use crate::expressions::format_state_name;
 use arrow::array::Array;
 use arrow::compute::{bit_and, bit_or, bit_xor};
-
-/// Creates a [`PrimitiveGroupsAccumulator`] with the specified
-/// [`ArrowPrimitiveType`] that initailizes each accumulator to $START
-/// and applies `$FN` to each element
-///
-/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType
-macro_rules! instantiate_accumulator {
-    ($SELF:expr, $START:expr, $PRIMTYPE:ident, $FN:expr) => {{
-        Ok(Box::new(
-            PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(&$SELF.data_type, 
$FN)
-                .with_starting_value($START),
-        ))
-    }};
-}
-
-// returns the new value after bit_and/bit_or/bit_xor with the new values, 
taking nullability into account
-macro_rules! typed_bit_and_or_xor_batch {
-    ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{
-        let array = downcast_value!($VALUES, $ARRAYTYPE);
-        let delta = $OP(array);
-        Ok(ScalarValue::$SCALAR(delta))
-    }};
-}
-
-// bit_and/bit_or/bit_xor the array and returns a ScalarValue of its 
corresponding type.
-macro_rules! bit_and_or_xor_batch {
-    ($VALUES:expr, $OP:ident) => {{
-        match $VALUES.data_type() {
-            DataType::Int64 => {
-                typed_bit_and_or_xor_batch!($VALUES, Int64Array, Int64, $OP)
-            }
-            DataType::Int32 => {
-                typed_bit_and_or_xor_batch!($VALUES, Int32Array, Int32, $OP)
-            }
-            DataType::Int16 => {
-                typed_bit_and_or_xor_batch!($VALUES, Int16Array, Int16, $OP)
-            }
-            DataType::Int8 => {
-                typed_bit_and_or_xor_batch!($VALUES, Int8Array, Int8, $OP)
-            }
-            DataType::UInt64 => {
-                typed_bit_and_or_xor_batch!($VALUES, UInt64Array, UInt64, $OP)
-            }
-            DataType::UInt32 => {
-                typed_bit_and_or_xor_batch!($VALUES, UInt32Array, UInt32, $OP)
-            }
-            DataType::UInt16 => {
-                typed_bit_and_or_xor_batch!($VALUES, UInt16Array, UInt16, $OP)
-            }
-            DataType::UInt8 => {
-                typed_bit_and_or_xor_batch!($VALUES, UInt8Array, UInt8, $OP)
-            }
-            e => {
-                return internal_err!(
-                    "Bit and/Bit or/Bit xor is not expected to receive the 
type {e:?}"
-                );
-            }
-        }
-    }};
-}
-
-/// dynamically-typed bit_and(array) -> ScalarValue
-fn bit_and_batch(values: &ArrayRef) -> Result<ScalarValue> {
-    bit_and_or_xor_batch!(values, bit_and)
-}
-
-/// dynamically-typed bit_or(array) -> ScalarValue
-fn bit_or_batch(values: &ArrayRef) -> Result<ScalarValue> {
-    bit_and_or_xor_batch!(values, bit_or)
-}
-
-/// dynamically-typed bit_xor(array) -> ScalarValue
-fn bit_xor_batch(values: &ArrayRef) -> Result<ScalarValue> {
-    bit_and_or_xor_batch!(values, bit_xor)
-}
+use arrow_array::cast::AsArray;
+use arrow_array::{downcast_integer, ArrowNumericType};
+use arrow_buffer::ArrowNativeType;
 
 /// BIT_AND aggregate expression
 #[derive(Debug, Clone)]
@@ -161,7 +77,19 @@ impl AggregateExpr for BitAnd {
     }
 
     fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
-        Ok(Box::new(BitAndAccumulator::try_new(&self.data_type)?))
+        macro_rules! helper {
+            ($t:ty) => {
+                Ok(Box::<BitAndAccumulator<$t>>::default())
+            };
+        }
+        downcast_integer! {
+            &self.data_type => (helper),
+            _ => Err(DataFusionError::NotImplemented(format!(
+                "BitAndAccumulator not supported for {} with {}",
+                self.name(),
+                self.data_type
+            ))),
+        }
     }
 
     fn state_fields(&self) -> Result<Vec<Field>> {
@@ -186,39 +114,22 @@ impl AggregateExpr for BitAnd {
 
     fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
         use std::ops::BitAndAssign;
-        // Note the default value for BitAnd should be all set
-        // (e.g. `0b11...111`) use MAX / -1 here to get appropriate
-        // bit pattern for each type
-        match self.data_type {
-            DataType::Int8 => {
-                instantiate_accumulator!(self, -1, Int8Type, |x, y| 
x.bitand_assign(y))
-            }
-            DataType::Int16 => {
-                instantiate_accumulator!(self, -1, Int16Type, |x, y| 
x.bitand_assign(y))
-            }
-            DataType::Int32 => {
-                instantiate_accumulator!(self, -1, Int32Type, |x, y| 
x.bitand_assign(y))
-            }
-            DataType::Int64 => {
-                instantiate_accumulator!(self, -1, Int64Type, |x, y| 
x.bitand_assign(y))
-            }
-            DataType::UInt8 => {
-                instantiate_accumulator!(self, u8::MAX, UInt8Type, |x, y| x
-                    .bitand_assign(y))
-            }
-            DataType::UInt16 => {
-                instantiate_accumulator!(self, u16::MAX, UInt16Type, |x, y| x
-                    .bitand_assign(y))
-            }
-            DataType::UInt32 => {
-                instantiate_accumulator!(self, u32::MAX, UInt32Type, |x, y| x
-                    .bitand_assign(y))
-            }
-            DataType::UInt64 => {
-                instantiate_accumulator!(self, u64::MAX, UInt64Type, |x, y| x
-                    .bitand_assign(y))
-            }
 
+        // Note the default value for BitAnd should be all set, i.e. `!0`
+        macro_rules! helper {
+            ($t:ty, $dt:expr) => {
+                Ok(Box::new(
+                    PrimitiveGroupsAccumulator::<$t, _>::new($dt, |x, y| {
+                        x.bitand_assign(y)
+                    })
+                    .with_starting_value(!0),
+                ))
+            };
+        }
+
+        let data_type = &self.data_type;
+        downcast_integer! {
+            data_type => (helper, data_type),
             _ => not_impl_err!(
                 "GroupsAccumulator not supported for {} with {}",
                 self.name(),
@@ -246,25 +157,31 @@ impl PartialEq<dyn Any> for BitAnd {
     }
 }
 
-#[derive(Debug)]
-struct BitAndAccumulator {
-    bit_and: ScalarValue,
+struct BitAndAccumulator<T: ArrowNumericType> {
+    value: Option<T::Native>,
+}
+
+impl<T: ArrowNumericType> std::fmt::Debug for BitAndAccumulator<T> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "BitAndAccumulator({})", T::DATA_TYPE)
+    }
 }
 
-impl BitAndAccumulator {
-    /// new bit_and accumulator
-    pub fn try_new(data_type: &DataType) -> Result<Self> {
-        Ok(Self {
-            bit_and: ScalarValue::try_from(data_type)?,
-        })
+impl<T: ArrowNumericType> Default for BitAndAccumulator<T> {
+    fn default() -> Self {
+        Self { value: None }
     }
 }
 
-impl Accumulator for BitAndAccumulator {
+impl<T: ArrowNumericType> Accumulator for BitAndAccumulator<T>
+where
+    T::Native: std::ops::BitAnd<Output = T::Native>,
+{
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
-        let values = &values[0];
-        let delta = &bit_and_batch(values)?;
-        self.bit_and = self.bit_and.bitand(delta)?;
+        if let Some(x) = bit_and(values[0].as_primitive::<T>()) {
+            let v = self.value.get_or_insert(x);
+            *v = *v & x;
+        }
         Ok(())
     }
 
@@ -273,16 +190,15 @@ impl Accumulator for BitAndAccumulator {
     }
 
     fn state(&self) -> Result<Vec<ScalarValue>> {
-        Ok(vec![self.bit_and.clone()])
+        Ok(vec![self.evaluate()?])
     }
 
     fn evaluate(&self) -> Result<ScalarValue> {
-        Ok(self.bit_and.clone())
+        Ok(ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE))
     }
 
     fn size(&self) -> usize {
-        std::mem::size_of_val(self) - std::mem::size_of_val(&self.bit_and)
-            + self.bit_and.size()
+        std::mem::size_of_val(self)
     }
 }
 
@@ -326,7 +242,19 @@ impl AggregateExpr for BitOr {
     }
 
     fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
-        Ok(Box::new(BitOrAccumulator::try_new(&self.data_type)?))
+        macro_rules! helper {
+            ($t:ty) => {
+                Ok(Box::<BitOrAccumulator<$t>>::default())
+            };
+        }
+        downcast_integer! {
+            &self.data_type => (helper),
+            _ => Err(DataFusionError::NotImplemented(format!(
+                "BitOrAccumulator not supported for {} with {}",
+                self.name(),
+                self.data_type
+            ))),
+        }
     }
 
     fn state_fields(&self) -> Result<Vec<Field>> {
@@ -351,32 +279,18 @@ impl AggregateExpr for BitOr {
 
     fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
         use std::ops::BitOrAssign;
-        match self.data_type {
-            DataType::Int8 => {
-                instantiate_accumulator!(self, 0, Int8Type, |x, y| 
x.bitor_assign(y))
-            }
-            DataType::Int16 => {
-                instantiate_accumulator!(self, 0, Int16Type, |x, y| 
x.bitor_assign(y))
-            }
-            DataType::Int32 => {
-                instantiate_accumulator!(self, 0, Int32Type, |x, y| 
x.bitor_assign(y))
-            }
-            DataType::Int64 => {
-                instantiate_accumulator!(self, 0, Int64Type, |x, y| 
x.bitor_assign(y))
-            }
-            DataType::UInt8 => {
-                instantiate_accumulator!(self, 0, UInt8Type, |x, y| 
x.bitor_assign(y))
-            }
-            DataType::UInt16 => {
-                instantiate_accumulator!(self, 0, UInt16Type, |x, y| 
x.bitor_assign(y))
-            }
-            DataType::UInt32 => {
-                instantiate_accumulator!(self, 0, UInt32Type, |x, y| 
x.bitor_assign(y))
-            }
-            DataType::UInt64 => {
-                instantiate_accumulator!(self, 0, UInt64Type, |x, y| 
x.bitor_assign(y))
-            }
+        macro_rules! helper {
+            ($t:ty, $dt:expr) => {
+                Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new(
+                    $dt,
+                    |x, y| x.bitor_assign(y),
+                )))
+            };
+        }
 
+        let data_type = &self.data_type;
+        downcast_integer! {
+            data_type => (helper, data_type),
             _ => not_impl_err!(
                 "GroupsAccumulator not supported for {} with {}",
                 self.name(),
@@ -404,29 +318,35 @@ impl PartialEq<dyn Any> for BitOr {
     }
 }
 
-#[derive(Debug)]
-struct BitOrAccumulator {
-    bit_or: ScalarValue,
+struct BitOrAccumulator<T: ArrowNumericType> {
+    value: Option<T::Native>,
+}
+
+impl<T: ArrowNumericType> std::fmt::Debug for BitOrAccumulator<T> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "BitOrAccumulator({})", T::DATA_TYPE)
+    }
 }
 
-impl BitOrAccumulator {
-    /// new bit_or accumulator
-    pub fn try_new(data_type: &DataType) -> Result<Self> {
-        Ok(Self {
-            bit_or: ScalarValue::try_from(data_type)?,
-        })
+impl<T: ArrowNumericType> Default for BitOrAccumulator<T> {
+    fn default() -> Self {
+        Self { value: None }
     }
 }
 
-impl Accumulator for BitOrAccumulator {
+impl<T: ArrowNumericType> Accumulator for BitOrAccumulator<T>
+where
+    T::Native: std::ops::BitOr<Output = T::Native>,
+{
     fn state(&self) -> Result<Vec<ScalarValue>> {
-        Ok(vec![self.bit_or.clone()])
+        Ok(vec![self.evaluate()?])
     }
 
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
-        let values = &values[0];
-        let delta = &bit_or_batch(values)?;
-        self.bit_or = self.bit_or.bitor(delta)?;
+        if let Some(x) = bit_or(values[0].as_primitive::<T>()) {
+            let v = self.value.get_or_insert(T::Native::usize_as(0));
+            *v = *v | x;
+        }
         Ok(())
     }
 
@@ -435,12 +355,11 @@ impl Accumulator for BitOrAccumulator {
     }
 
     fn evaluate(&self) -> Result<ScalarValue> {
-        Ok(self.bit_or.clone())
+        Ok(ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE))
     }
 
     fn size(&self) -> usize {
-        std::mem::size_of_val(self) - std::mem::size_of_val(&self.bit_or)
-            + self.bit_or.size()
+        std::mem::size_of_val(self)
     }
 }
 
@@ -484,7 +403,19 @@ impl AggregateExpr for BitXor {
     }
 
     fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
-        Ok(Box::new(BitXorAccumulator::try_new(&self.data_type)?))
+        macro_rules! helper {
+            ($t:ty) => {
+                Ok(Box::<BitXorAccumulator<$t>>::default())
+            };
+        }
+        downcast_integer! {
+            &self.data_type => (helper),
+            _ => Err(DataFusionError::NotImplemented(format!(
+                "BitXor not supported for {} with {}",
+                self.name(),
+                self.data_type
+            ))),
+        }
     }
 
     fn state_fields(&self) -> Result<Vec<Field>> {
@@ -509,32 +440,18 @@ impl AggregateExpr for BitXor {
 
     fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
         use std::ops::BitXorAssign;
-        match self.data_type {
-            DataType::Int8 => {
-                instantiate_accumulator!(self, 0, Int8Type, |x, y| 
x.bitxor_assign(y))
-            }
-            DataType::Int16 => {
-                instantiate_accumulator!(self, 0, Int16Type, |x, y| 
x.bitxor_assign(y))
-            }
-            DataType::Int32 => {
-                instantiate_accumulator!(self, 0, Int32Type, |x, y| 
x.bitxor_assign(y))
-            }
-            DataType::Int64 => {
-                instantiate_accumulator!(self, 0, Int64Type, |x, y| 
x.bitxor_assign(y))
-            }
-            DataType::UInt8 => {
-                instantiate_accumulator!(self, 0, UInt8Type, |x, y| 
x.bitxor_assign(y))
-            }
-            DataType::UInt16 => {
-                instantiate_accumulator!(self, 0, UInt16Type, |x, y| 
x.bitxor_assign(y))
-            }
-            DataType::UInt32 => {
-                instantiate_accumulator!(self, 0, UInt32Type, |x, y| 
x.bitxor_assign(y))
-            }
-            DataType::UInt64 => {
-                instantiate_accumulator!(self, 0, UInt64Type, |x, y| 
x.bitxor_assign(y))
-            }
+        macro_rules! helper {
+            ($t:ty, $dt:expr) => {
+                Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new(
+                    $dt,
+                    |x, y| x.bitxor_assign(y),
+                )))
+            };
+        }
 
+        let data_type = &self.data_type;
+        downcast_integer! {
+            data_type => (helper, data_type),
             _ => not_impl_err!(
                 "GroupsAccumulator not supported for {} with {}",
                 self.name(),
@@ -562,29 +479,35 @@ impl PartialEq<dyn Any> for BitXor {
     }
 }
 
-#[derive(Debug)]
-struct BitXorAccumulator {
-    bit_xor: ScalarValue,
+struct BitXorAccumulator<T: ArrowNumericType> {
+    value: Option<T::Native>,
+}
+
+impl<T: ArrowNumericType> std::fmt::Debug for BitXorAccumulator<T> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "BitXorAccumulator({})", T::DATA_TYPE)
+    }
 }
 
-impl BitXorAccumulator {
-    /// new bit_xor accumulator
-    pub fn try_new(data_type: &DataType) -> Result<Self> {
-        Ok(Self {
-            bit_xor: ScalarValue::try_from(data_type)?,
-        })
+impl<T: ArrowNumericType> Default for BitXorAccumulator<T> {
+    fn default() -> Self {
+        Self { value: None }
     }
 }
 
-impl Accumulator for BitXorAccumulator {
+impl<T: ArrowNumericType> Accumulator for BitXorAccumulator<T>
+where
+    T::Native: std::ops::BitXor<Output = T::Native>,
+{
     fn state(&self) -> Result<Vec<ScalarValue>> {
-        Ok(vec![self.bit_xor.clone()])
+        Ok(vec![self.evaluate()?])
     }
 
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
-        let values = &values[0];
-        let delta = &bit_xor_batch(values)?;
-        self.bit_xor = self.bit_xor.bitxor(delta)?;
+        if let Some(x) = bit_xor(values[0].as_primitive::<T>()) {
+            let v = self.value.get_or_insert(T::Native::usize_as(0));
+            *v = *v ^ x;
+        }
         Ok(())
     }
 
@@ -593,12 +516,11 @@ impl Accumulator for BitXorAccumulator {
     }
 
     fn evaluate(&self) -> Result<ScalarValue> {
-        Ok(self.bit_xor.clone())
+        Ok(ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE))
     }
 
     fn size(&self) -> usize {
-        std::mem::size_of_val(self) - std::mem::size_of_val(&self.bit_xor)
-            + self.bit_xor.size()
+        std::mem::size_of_val(self)
     }
 }
 
@@ -642,9 +564,19 @@ impl AggregateExpr for DistinctBitXor {
     }
 
     fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
-        Ok(Box::new(DistinctBitXorAccumulator::try_new(
-            &self.data_type,
-        )?))
+        macro_rules! helper {
+            ($t:ty) => {
+                Ok(Box::<DistinctBitXorAccumulator<$t>>::default())
+            };
+        }
+        downcast_integer! {
+            &self.data_type => (helper),
+            _ => Err(DataFusionError::NotImplemented(format!(
+                "DistinctBitXorAccumulator not supported for {} with {}",
+                self.name(),
+                self.data_type
+            ))),
+        }
     }
 
     fn state_fields(&self) -> Result<Vec<Field>> {
@@ -679,34 +611,39 @@ impl PartialEq<dyn Any> for DistinctBitXor {
     }
 }
 
-#[derive(Debug)]
-struct DistinctBitXorAccumulator {
-    hash_values: HashSet<ScalarValue, RandomState>,
-    data_type: DataType,
+struct DistinctBitXorAccumulator<T: ArrowNumericType> {
+    values: HashSet<T::Native, RandomState>,
 }
 
-impl DistinctBitXorAccumulator {
-    pub fn try_new(data_type: &DataType) -> Result<Self> {
-        Ok(Self {
-            hash_values: HashSet::default(),
-            data_type: data_type.clone(),
-        })
+impl<T: ArrowNumericType> std::fmt::Debug for DistinctBitXorAccumulator<T> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "DistinctBitXorAccumulator({})", T::DATA_TYPE)
     }
 }
 
-impl Accumulator for DistinctBitXorAccumulator {
+impl<T: ArrowNumericType> Default for DistinctBitXorAccumulator<T> {
+    fn default() -> Self {
+        Self {
+            values: HashSet::default(),
+        }
+    }
+}
+
+impl<T: ArrowNumericType> Accumulator for DistinctBitXorAccumulator<T>
+where
+    T::Native: std::ops::BitXor<Output = T::Native> + std::hash::Hash + Eq,
+{
     fn state(&self) -> Result<Vec<ScalarValue>> {
         // 1. Stores aggregate state in `ScalarValue::List`
         // 2. Constructs `ScalarValue::List` state from distinct numeric 
stored in hash set
         let state_out = {
-            let mut distinct_values = Vec::new();
-            self.hash_values
+            let values = self
+                .values
                 .iter()
-                .for_each(|distinct_value| 
distinct_values.push(distinct_value.clone()));
-            vec![ScalarValue::new_list(
-                Some(distinct_values),
-                self.data_type.clone(),
-            )]
+                .map(|x| ScalarValue::new_primitive::<T>(Some(*x), 
&T::DATA_TYPE))
+                .collect();
+
+            vec![ScalarValue::new_list(Some(values), T::DATA_TYPE)]
         };
         Ok(state_out)
     }
@@ -716,14 +653,18 @@ impl Accumulator for DistinctBitXorAccumulator {
             return Ok(());
         }
 
-        let arr = &values[0];
-        (0..values[0].len()).try_for_each(|index| {
-            if !arr.is_null(index) {
-                let v = ScalarValue::try_from_array(arr, index)?;
-                self.hash_values.insert(v);
+        let array = values[0].as_primitive::<T>();
+        match array.nulls().filter(|x| x.null_count() > 0) {
+            Some(n) => {
+                for idx in n.valid_indices() {
+                    self.values.insert(array.value(idx));
+                }
             }
-            Ok(())
-        })
+            None => array.values().iter().for_each(|x| {
+                self.values.insert(*x);
+            }),
+        }
+        Ok(())
     }
 
     fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
@@ -731,36 +672,24 @@ impl Accumulator for DistinctBitXorAccumulator {
             return Ok(());
         }
 
-        let arr = &states[0];
-        (0..arr.len()).try_for_each(|index| {
-            let scalar = ScalarValue::try_from_array(arr, index)?;
-
-            if let ScalarValue::List(Some(scalar), _) = scalar {
-                scalar.iter().for_each(|scalar| {
-                    if !ScalarValue::is_null(scalar) {
-                        self.hash_values.insert(scalar.clone());
-                    }
-                });
-            } else {
-                return internal_err!("Unexpected accumulator state");
-            }
-            Ok(())
-        })
+        for x in states[0].as_list::<i32>().iter().flatten() {
+            self.update_batch(&[x])?
+        }
+        Ok(())
     }
 
     fn evaluate(&self) -> Result<ScalarValue> {
-        let mut bit_xor_value = ScalarValue::try_from(&self.data_type)?;
-        for distinct_value in self.hash_values.iter() {
-            bit_xor_value = bit_xor_value.bitxor(distinct_value)?;
+        let mut acc = T::Native::usize_as(0);
+        for distinct_value in self.values.iter() {
+            acc = acc ^ *distinct_value;
         }
-        Ok(bit_xor_value)
+        let v = (!self.values.is_empty()).then_some(acc);
+        Ok(ScalarValue::new_primitive::<T>(v, &T::DATA_TYPE))
     }
 
     fn size(&self) -> usize {
-        std::mem::size_of_val(self) + 
ScalarValue::size_of_hashset(&self.hash_values)
-            - std::mem::size_of_val(&self.hash_values)
-            + self.data_type.size()
-            - std::mem::size_of_val(&self.data_type)
+        std::mem::size_of_val(self)
+            + self.values.capacity() * std::mem::size_of::<T::Native>()
     }
 }
 
@@ -770,6 +699,7 @@ mod tests {
     use crate::expressions::col;
     use crate::expressions::tests::aggregate;
     use crate::generic_test_op;
+    use arrow::array::*;
     use arrow::datatypes::*;
     use arrow::record_batch::RecordBatch;
     use datafusion_common::Result;

Reply via email to