alamb commented on code in PR #8849:
URL: https://github.com/apache/arrow-datafusion/pull/8849#discussion_r1460682982
##########
datafusion/physical-expr/src/aggregate/count_distinct.rs:
##########
@@ -438,11 +446,292 @@ where
}
}
+/// `count_distinct` accumulator for string arrays, backed by an
+/// `SSOStringHashSet`.
+// NOTE(review): the `Mutex` appears to exist only because
+// `Accumulator::state` takes `&self` yet moves the set out (see the TODO in
+// `state`) — confirm before relying on it for cross-thread sharing.
+#[derive(Debug)]
+struct StringDistinctCountAccumulator(Mutex<SSOStringHashSet>);
+impl StringDistinctCountAccumulator {
+    /// Creates an accumulator that has not observed any values yet.
+    fn new() -> Self {
+        let empty_set = SSOStringHashSet::default();
+        Self(Mutex::new(empty_set))
+    }
+}
+
+impl Accumulator for StringDistinctCountAccumulator {
+    /// Returns the distinct strings collected so far as a single
+    /// `ScalarValue::List`.
+    ///
+    /// NOTE: this is destructive. The set is moved out and replaced with an
+    /// empty one, so a later `evaluate` on this accumulator reports the count
+    /// of the (now empty) replacement set.
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        // TODO this should not need a lock/clone (should make
+        // `Accumulator::state` take a mutable reference)
+        // The guard is a temporary, so the lock is released at the end of
+        // this statement; the default set is left behind in its place.
+        let set = std::mem::take(&mut *self.0.lock().unwrap());
+        let arr = set.into_state();
+        let list = Arc::new(array_into_list_array(Arc::new(arr)));
+        Ok(vec![ScalarValue::List(list)])
+    }
+
+    /// Adds the strings of `values[0]` to the set (nulls are ignored by
+    /// `SSOStringHashSet::insert`). An empty `values` slice is a no-op.
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let Some(values) = values.first() else {
+            return Ok(());
+        };
+        // `&mut self` guarantees exclusive access, so no locking is needed.
+        self.0.get_mut().unwrap().insert(Arc::clone(values));
+        Ok(())
+    }
+
+    /// Merges partial states, i.e. list arrays of strings such as those
+    /// produced by `state`, into this accumulator's set.
+    ///
+    /// # Panics
+    ///
+    /// Panics if more than one state array is supplied.
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        if states.is_empty() {
+            return Ok(());
+        }
+        assert_eq!(
+            states.len(),
+            1,
+            "count_distinct states must be single array"
+        );
+
+        let arr = as_list_array(&states[0])?;
+        // Borrow the set once for the whole batch instead of locking per row;
+        // `&mut self` means no other reference can hold the lock.
+        let set = self.0.get_mut().unwrap();
+        // null list entries contribute nothing and are skipped by `flatten`
+        for list in arr.iter().flatten() {
+            set.insert(list);
+        }
+        Ok(())
+    }
+
+    /// Returns the distinct count, as reported by `SSOStringHashSet::len`.
+    fn evaluate(&self) -> Result<ScalarValue> {
+        let distinct = self.0.lock().unwrap().len();
+        Ok(ScalarValue::Int64(Some(distinct as i64)))
+    }
+
+    /// Size of the accumulator itself plus the set's own `size()` estimate.
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self) + self.0.lock().unwrap().size()
+    }
+}
+
+/// Maximum length, in bytes, of a string that is stored inline in an
+/// `SSOStringHeader` instead of in the shared long-string buffer: as many
+/// bytes as fit in a `usize` (8 on 64-bit targets).
+const SHORT_STRING_LEN: usize = mem::size_of::<usize>();
+
+/// Entry that is stored in the actual hash table
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
+struct SSOStringHeader {
+    /// hash of the string value; stored so the table can rehash entries
+    /// (it is the hasher passed to `insert_accounted`) without re-reading
+    /// the string bytes
+    /// TODO can we simply recreate when needed
+    hash: u64,
+    /// length of the string, in bytes
+    len: usize,
+    /// if `len <= SHORT_STRING_LEN`: the string bytes packed into the `usize`
+    /// if `len > SHORT_STRING_LEN`: the byte offset of the string in the buffer
+    offset_or_inline: usize,
+}
+
+impl SSOStringHeader {
+    /// Returns `offset_or_inline..offset_or_inline + len`, the byte range of
+    /// a long string within the shared buffer.
+    ///
+    /// Only meaningful when `len > SHORT_STRING_LEN`; for short strings
+    /// `offset_or_inline` holds packed bytes, not an offset.
+    fn range(&self) -> Range<usize> {
+        self.offset_or_inline..self.offset_or_inline + self.len
+    }
+}
+
+/// Short String Optimized HashSet for String
+///
+/// Equivalent to `HashSet<String>` but with better memory usage: strings of
+/// at most `SHORT_STRING_LEN` bytes are packed directly into their
+/// `SSOStringHeader`, and longer strings are appended once to a single
+/// shared `buffer` and referenced by offset.
+// NOTE(review): `StringDistinctCountAccumulator` derives `Debug`, which
+// requires `SSOStringHashSet: Debug`; this struct only derives `Default`, so
+// presumably a manual `Debug` impl exists outside this hunk — confirm.
+#[derive(Default)]
+struct SSOStringHashSet {
+    /// Store entries for each distinct string
+    map: hashbrown::raw::RawTable<SSOStringHeader>,
+    /// Total size of the map in bytes; passed to `insert_accounted` on every
+    /// insert (TODO)
+    map_size: usize,
+    /// Buffer containing all long strings (`len > SHORT_STRING_LEN`),
+    /// appended to in insertion order
+    buffer: BufferBuilder<u8>,
+    /// The random state used to generate hashes
+    random_state: RandomState,
+    /// Scratch buffer for per-batch hashes, reused across `insert` calls to
+    /// avoid reallocating
+    hashes_buffer: Vec<u64>,
+}
+
+impl SSOStringHashSet {
+    /// Creates an empty set (equivalent to `SSOStringHashSet::default()`).
+    fn new() -> Self {
+        Default::default()
+    }
+
+    /// Inserts every non-null string of `values` into the set; strings that
+    /// are already present are left untouched.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `values` is not an `i32`-offset string array (see the TODO
+    /// below about large strings).
+    fn insert(&mut self, values: ArrayRef) {
+        // step 1: compute hashes for the strings, reusing `hashes_buffer`
+        let batch_hashes = &mut self.hashes_buffer;
+        batch_hashes.clear();
+        batch_hashes.resize(values.len(), 0);
+        create_hashes(&[values.clone()], &self.random_state, batch_hashes)
+            // hash is supported for all string types and create_hashes only
+            // returns errors for unsupported types
+            .expect("create_hashes supports all string types");
+
+        // TODO make this generic (to support large strings)
+        let values = values.as_string::<i32>();
+
+        // step 2: insert each string into the set, if not already present
+        // (one hash per value is what makes the `zip` below line up)
+        assert_eq!(values.len(), batch_hashes.len());
+
+        for (value, &hash) in values.iter().zip(batch_hashes.iter()) {
+            // count distinct ignores nulls
+            let Some(value) = value else {
+                continue;
+            };
+
+            // from here on only use bytes (not str/chars) for value
+            let value = value.as_bytes();
+
+            if value.len() <= SHORT_STRING_LEN {
+                // Pack the bytes into a usize. `len` is compared as well, so
+                // values that pack identically (e.g. "a" and "\0a") stay distinct.
+                let inline = value
+                    .iter()
+                    .fold(0usize, |acc, &x| (acc << 8) | x as usize);
+
+                // Lookup only: a shared borrow of the table is sufficient.
+                let present = self
+                    .map
+                    .get(hash, |header| {
+                        header.len == value.len() && header.offset_or_inline == inline
+                    })
+                    .is_some();
+
+                if !present {
+                    let new_header = SSOStringHeader {
+                        hash,
+                        len: value.len(),
+                        offset_or_inline: inline,
+                    };
+                    self.map.insert_accounted(
+                        new_header,
+                        |header| header.hash,
+                        &mut self.map_size,
+                    );
+                }
+            } else {
+                // handle large strings: compare against bytes stored in `buffer`
+                let buffer = self.buffer.as_slice();
+                let present = self
+                    .map
+                    .get(hash, |header| {
+                        if header.len != value.len() {
+                            return false;
+                        }
+                        // SAFETY: `header.len == value.len() > SHORT_STRING_LEN`, so
+                        // this is a long-string header. Those are only created below,
+                        // right after appending exactly `len` bytes at `offset`, and
+                        // `buffer` is append-only, so `header.range()` is in bounds.
+                        let existing = unsafe { buffer.get_unchecked(header.range()) };
+                        existing == value
+                    })
+                    .is_some();
+
+                if !present {
+                    // long strings are stored as a length/offset into the buffer
+                    let offset = self.buffer.len();
+                    self.buffer.append_slice(value);
+                    let new_header = SSOStringHeader {
+                        hash,
+                        len: value.len(),
+                        offset_or_inline: offset,
+                    };
+                    self.map.insert_accounted(
+                        new_header,
+                        |header| header.hash,
+                        &mut self.map_size,
+                    );
+                }
+            }
+        }
+    }
+
+ /// Converts this set into a StringArray of the distinct string values
+ fn into_state(self) -> StringArray {
+ // The map contains entries that have offsets in some arbitrary order
+ // but the buffer contains the actual strings in the order they were
inserted
+ // so we need to build offsets for the strings in the buffer in order
+ // then append short strings, if any, and then build the StringArray
+ // TODO a picture would be nice here
+ let Self {
+ map,
+ map_size: _,
+ mut buffer,
+ random_state: _,
+ hashes_buffer: _,
+ } = self;
+
+ // Sort all headers so that long strings come first, in offset order
+ // followed by short strings ordered by value
+ let mut headers = map.into_iter().collect::<Vec<_>>();
Review Comment:
fixed in
[a101b62](https://github.com/apache/arrow-datafusion/pull/8849/commits/a101b62557d0c0aaa0402256cc4ca249995d5b04)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]