jayzhan211 commented on code in PR #8849:
URL: https://github.com/apache/arrow-datafusion/pull/8849#discussion_r1455424711
##########
datafusion/physical-expr/src/aggregate/count_distinct.rs:
##########
@@ -438,6 +443,206 @@ where
}
}
+#[derive(Debug)]
+struct StringDistinctCountAccumulator(SSOStringHashSet);
+impl StringDistinctCountAccumulator {
+ fn new() -> Self {
+ Self(SSOStringHashSet::new())
+ }
+}
+
+impl Accumulator for StringDistinctCountAccumulator {
+ fn state(&self) -> Result<Vec<ScalarValue>> {
+ let arr = self.0.state();
+ let list = Arc::new(array_into_list_array(Arc::new(arr)));
+ Ok(vec![ScalarValue::List(list)])
+ }
+
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ if values.is_empty() {
+ return Ok(());
+ }
+
+ let arr = as_string_array(&values[0])?;
+ arr.iter().for_each(|value| {
+ if let Some(value) = value {
+ self.0.insert(value);
+ }
+ });
+
+ Ok(())
+ }
+
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ if states.is_empty() {
+ return Ok(());
+ }
+ assert_eq!(
+ states.len(),
+ 1,
+ "count_distinct states must be single array"
+ );
+
+ let arr = as_list_array(&states[0])?;
+ arr.iter().try_for_each(|maybe_list| {
+ if let Some(list) = maybe_list {
+ let list = as_string_array(&list)?;
+
+ list.iter().for_each(|value| {
+ if let Some(value) = value {
+ self.0.insert(value);
+ }
+ })
+ };
+ Ok(())
+ })
+ }
+
+ fn evaluate(&self) -> Result<ScalarValue> {
+ Ok(ScalarValue::Int64(Some(self.0.len() as i64)))
+ }
+
+ fn size(&self) -> usize {
+ // Size of accumulator
+ // + SSOStringHashSet size
+ std::mem::size_of_val(self) + self.0.size()
+ }
+}
+
+const SHORT_STRING_LEN: usize = mem::size_of::<usize>();
+
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
+struct SSOStringHeader {
+ /// hash of the string value (used when resizing table)
+ hash: u64,
+ /// length of the string
+ len: usize,
+ /// short strings are stored inline, long strings are stored in the buffer
+ offset_or_inline: usize,
+}
+
+// Short String Optimized HashSet for String
+// Equivalent to HashSet<String> but with better memory usage
+#[derive(Default)]
+struct SSOStringHashSet {
+ /// Core of the HashSet, it stores both the short and long string headers
+ header_set: HashSet<SSOStringHeader>,
+ /// Used to check if the long string already exists
+ long_string_map: hashbrown::raw::RawTable<SSOStringHeader>,
+ /// Total size of the map in bytes
+ map_size: usize,
+ /// Buffer containing all long strings
+ buffer: BufferBuilder<u8>,
+ /// The random state used to generate hashes
+ state: RandomState,
+ /// Used for capacity calculation, equivalent to the sum of all string lengths
+ size_hint: usize,
+}
+
+impl SSOStringHashSet {
+ fn new() -> Self {
+ Self::default()
+ }
+
+ fn insert(&mut self, value: &str) {
+ let value_len = value.len();
+ self.size_hint += value_len;
+ let value_bytes = value.as_bytes();
+
+ if value_len <= SHORT_STRING_LEN {
+ let inline = value_bytes
+ .iter()
+ .fold(0usize, |acc, &x| acc << 8 | x as usize);
+ let short_string_header = SSOStringHeader {
+ hash: 0, // no need for short string cases
+ len: value_len,
+ offset_or_inline: inline,
+ };
+ self.header_set.insert(short_string_header);
+ } else {
+ let hash = self.state.hash_one(value_bytes);
+
+ let entry = self.long_string_map.get_mut(hash, |header| {
+ // if hash matches, check if the bytes match
+ let offset = header.offset_or_inline;
+ let len = header.len;
+
+ // SAFETY: buffer is only appended to, and we correctly inserted values
+ let existing_value =
+ unsafe {
self.buffer.as_slice().get_unchecked(offset..offset + len) };
+
+ value_bytes == existing_value
+ });
+
+ if entry.is_none() {
+ let offset = self.buffer.len();
+ self.buffer.append_slice(value_bytes);
+ let header = SSOStringHeader {
+ hash,
+ len: value_len,
+ offset_or_inline: offset,
+ };
+ self.long_string_map.insert_accounted(
+ header,
+ |header| header.hash,
+ &mut self.map_size,
+ );
+ self.header_set.insert(header);
+ }
+ }
+ }
+
+ // Returns a StringArray with the current state of the set
+ fn state(&self) -> StringArray {
+ let mut offsets = Vec::with_capacity(self.size_hint + 1);
+ offsets.push(0);
+
+ let mut values = MutableBuffer::new(0);
+ let buffer = self.buffer.as_slice();
+
+ for header in self.header_set.iter() {
+ let s = if header.len <= SHORT_STRING_LEN {
Review Comment:
I removed `evaluate` and now compute the value here directly, which avoids
the conversion from `&[u8]` to a string.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]