Dandandan commented on code in PR #8849:
URL: https://github.com/apache/arrow-datafusion/pull/8849#discussion_r1459253196
##########
datafusion/physical-expr/src/aggregate/count_distinct.rs:
##########
@@ -438,6 +443,212 @@ where
}
}
+/// Specialized COUNT(DISTINCT) accumulator for a single string column,
+/// backed by a short-string-optimized hash set.
+#[derive(Debug)]
+struct StringDistinctCountAccumulator(SSOStringHashSet);
+
+impl StringDistinctCountAccumulator {
+    /// Creates an accumulator with an empty distinct-string set.
+    fn new() -> Self {
+        Self(SSOStringHashSet::default())
+    }
+}
+
+impl Accumulator for StringDistinctCountAccumulator {
+    /// Serializes the distinct strings seen so far as a single-element
+    /// list array, suitable for a later `merge_batch`.
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        let distinct_values = self.0.state();
+        let list_arr = array_into_list_array(Arc::new(distinct_values));
+        Ok(vec![ScalarValue::List(Arc::new(list_arr))])
+    }
+
+    /// Inserts every non-null string from the first input column into the set.
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        if values.is_empty() {
+            return Ok(());
+        }
+
+        let strings = as_string_array(&values[0])?;
+        // `flatten()` skips the nulls (None entries) for us.
+        for s in strings.iter().flatten() {
+            self.0.insert(s);
+        }
+        Ok(())
+    }
+
+    /// Folds the distinct sets produced by other accumulators' `state()`
+    /// into this accumulator's set.
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        if states.is_empty() {
+            return Ok(());
+        }
+        assert_eq!(
+            states.len(),
+            1,
+            "count_distinct states must be single array"
+        );
+
+        let lists = as_list_array(&states[0])?;
+        for maybe_list in lists.iter() {
+            if let Some(list) = maybe_list {
+                let strings = as_string_array(&list)?;
+                for s in strings.iter().flatten() {
+                    self.0.insert(s);
+                }
+            }
+        }
+        Ok(())
+    }
+
+    /// Returns the number of distinct non-null strings observed.
+    fn evaluate(&self) -> Result<ScalarValue> {
+        Ok(ScalarValue::Int64(Some(self.0.len() as i64)))
+    }
+
+    /// Memory footprint: the accumulator struct itself plus the set's
+    /// own accounting (buffer, map, headers).
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self) + self.0.size()
+    }
+}
+
+/// Longest string (in bytes) that can be stored inline in a header's
+/// `offset_or_inline` field instead of the shared buffer.
+const SHORT_STRING_LEN: usize = mem::size_of::<usize>();
+
+/// Per-string entry of [`SSOStringHashSet`]; fixed-size regardless of
+/// the string's length.
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
+struct SSOStringHeader {
+    /// hash of the string value (used when resizing table)
+    hash: u64,
+    /// length of the string, in bytes
+    len: usize,
+    /// short strings (`len <= SHORT_STRING_LEN`) are stored inline here;
+    /// long strings store their byte offset into the shared buffer
+    offset_or_inline: usize,
+}
+
+impl SSOStringHeader {
+    /// Reconstructs the original string described by this header.
+    ///
+    /// Short strings (`len <= SHORT_STRING_LEN`) are unpacked from the
+    /// inline `offset_or_inline` bytes; long strings are sliced out of
+    /// `buffer` at `offset_or_inline .. offset_or_inline + len`.
+    fn evaluate(&self, buffer: &[u8]) -> String {
+        if self.len <= SHORT_STRING_LEN {
+            // BUG FIX: the previous code returned
+            // `self.offset_or_inline.to_string()`, i.e. the *decimal
+            // rendering* of the packed usize ("abc" came back as a
+            // number), not the original string bytes.
+            // NOTE(review): this assumes `insert` packs inline bytes
+            // big-endian into the low-order bytes (fold `acc << 8 | b`)
+            // — confirm against the insert path, which is not shown here.
+            let inline = self.offset_or_inline.to_be_bytes();
+            // SAFETY: only valid utf8 byte sequences were inserted inline
+            unsafe {
+                std::str::from_utf8_unchecked(
+                    &inline[SHORT_STRING_LEN - self.len..],
+                )
+            }
+            .to_string()
+        } else {
+            let offset = self.offset_or_inline;
+            // SAFETY: buffer is only appended to, and we correctly inserted values
+            unsafe {
+                std::str::from_utf8_unchecked(
+                    buffer.get_unchecked(offset..offset + self.len),
+                )
+            }
+            .to_string()
+        }
+    }
+}
+
+// Short-String-Optimized (SSO) HashSet for strings
+// Equivalent to HashSet<String> but with better memory usage:
+// strings no longer than SHORT_STRING_LEN bytes are stored inline in
+// their header, longer strings are appended to one shared byte buffer.
+#[derive(Default)]
+struct SSOStringHashSet {
+    /// Core of the HashSet, it stores both the short and long string headers
+    header_set: HashSet<SSOStringHeader>,
+    /// Used to check if the long string already exists
+    /// (keyed by hash; probed with a byte-equality check against `buffer`)
+    long_string_map: hashbrown::raw::RawTable<SSOStringHeader>,
+    /// Total size of the map in bytes
+    map_size: usize,
+    /// Buffer containing all long strings
+    buffer: BufferBuilder<u8>,
+    /// The random state used to generate hashes
+    state: RandomState,
+    /// Used for capacity calculation, equivalent to the sum of all string lengths
+    size_hint: usize,
+}
+
+impl SSOStringHashSet {
+ fn new() -> Self {
+ Self::default()
+ }
+
+ fn insert(&mut self, value: &str) {
+ let value_len = value.len();
+ self.size_hint += value_len;
+ let value_bytes = value.as_bytes();
+
+ if value_len <= SHORT_STRING_LEN {
Review Comment:
Is this significantly faster than hashing the bytes and using one `RawTable`?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]