jaylmiller commented on code in PR #5554:
URL: https://github.com/apache/arrow-datafusion/pull/5554#discussion_r1135773543


##########
datafusion/physical-expr/src/aggregate/count_distinct.rs:
##########
@@ -158,38 +219,98 @@ impl Accumulator for DistinctCountAccumulator {
         })
     }
     fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
-        if states.is_empty() {
-            return Ok(());
+        merge_values(&mut self.values, states)
+    }
+
+    fn evaluate(&self) -> Result<ScalarValue> {
+        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
+    }
+
+    fn size(&self) -> usize {
+        let values_size = match &self.state_data_type {
+            DataType::Boolean | DataType::Null => 
values_fixed_size(&self.values),
+            d if d.is_primitive() => values_fixed_size(&self.values),
+            _ => values_full_size(&self.values),
+        };
+        std::mem::size_of_val(self) + values_size + 
std::mem::size_of::<DataType>()
+    }
+}
+/// Special case accumulator for counting distinct values in a dict
+struct CountDistinctDictAccumulator<K>
+where
+    K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync,
+{
+    /// `K` is required when casting to dict array
+    _dt: core::marker::PhantomData<K>,
+    values_datatype: DataType,
+    values: ValueSet,
+}
+
+impl<K> std::fmt::Debug for CountDistinctDictAccumulator<K>
+where
+    K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync,
+{
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("CountDistinctDictAccumulator")
+            .field("values", &self.values)
+            .field("values_datatype", &self.values_datatype)
+            .finish()
+    }
+}
+impl<K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync>
+    CountDistinctDictAccumulator<K>
+{
+    fn new(values_datatype: DataType) -> Self {
+        Self {
+            _dt: core::marker::PhantomData,
+            values: Default::default(),
+            values_datatype,
         }
-        let arr = &states[0];
-        (0..arr.len()).try_for_each(|index| {
-            let scalar = ScalarValue::try_from_array(arr, index)?;
+    }
+}
+impl<K> Accumulator for CountDistinctDictAccumulator<K>
+where
+    K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync,
+{
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        values_to_state(&self.values, &self.values_datatype)
+    }
 
-            if let ScalarValue::List(Some(scalar), _) = scalar {
-                scalar.iter().for_each(|scalar| {
-                    if !ScalarValue::is_null(scalar) {
-                        self.values.insert(scalar.clone());
-                    }
-                });
-            } else {
-                return Err(DataFusionError::Internal(
-                    "Unexpected accumulator state".into(),
-                ));
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {

Review Comment:
   I sort of assumed that `update_batch` is called frequently enough that the 
extra code was worth it, but someone more experienced with DataFusion would 
know better than me about this.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to