jaylmiller commented on code in PR #5554:
URL: https://github.com/apache/arrow-datafusion/pull/5554#discussion_r1133280240


##########
datafusion/physical-expr/src/aggregate/count_distinct.rs:
##########
@@ -192,17 +230,150 @@ impl Accumulator for DistinctCountAccumulator {
         }
     }
 }
+/// Special case accumulator for counting distinct values in a dict
+struct CountDistinctDictAccumulator<K>
+where
+    K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync,
+{
+    /// `K` is required when casting to dict array
+    _dt: core::marker::PhantomData<K>,
+    /// lazily initialized state that holds a boolean for each index.
+    /// the bool at each index indicates whether the value for that index has been seen yet.
+    state: Option<Vec<bool>>,
+}
+
+impl<K> std::fmt::Debug for CountDistinctDictAccumulator<K>
+where
+    K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync,
+{
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("CountDistinctDictAccumulator")
+            .field("state", &self.state)
+            .finish()
+    }
+}
+impl<K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync>
+    CountDistinctDictAccumulator<K>
+{
+    fn new() -> Self {
+        Self {
+            _dt: core::marker::PhantomData,
+            state: None,
+        }
+    }
+}
+impl<K> Accumulator for CountDistinctDictAccumulator<K>
+where
+    K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync,
+{
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        if let Some(state) = &self.state {
+            let bools = state
+                .iter()
+                .map(|b| ScalarValue::Boolean(Some(*b)))
+                .collect();
+            Ok(vec![ScalarValue::List(
+                Some(bools),
+                Box::new(Field::new("item", DataType::Boolean, false)),
+            )])
+        } else {
+            // empty state
+            Ok(vec![])
+        }
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        if values.is_empty() {
+            return Ok(());
+        }
+        let arr = as_dictionary_array::<K>(&values[0])?;
+        let nvalues = arr.values().len();
+        if let Some(state) = &self.state {
+            if state.len() != nvalues {
+                return Err(DataFusionError::Internal(
+                    "Accumulator update_batch got invalid value".to_string(),
+                ));
+            }
+        } else {
+            // init state
+            self.state = Some((0..nvalues).map(|_| false).collect());
+        }
+        for idx in arr.keys_iter().flatten() {
+            self.state.as_mut().unwrap()[idx] = true;
+        }

Review Comment:
   I was under the impression that this is how the DataFusion 
accumulators work. I'm fairly new to DataFusion internals, though, so I 
might be wrong about that! 
   
   I'll look into this more; that is a good point, and a large assumption on my 
end 😅



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to