jayzhan211 commented on code in PR #8849:
URL: https://github.com/apache/arrow-datafusion/pull/8849#discussion_r1451371442


##########
datafusion/physical-expr/src/aggregate/count_distinct.rs:
##########
@@ -438,6 +443,207 @@ where
     }
 }
 
+#[derive(Debug)]
+struct StringDistinctCountAccumulator(SSOStringHashSet);
+impl StringDistinctCountAccumulator {
+    fn new() -> Self {
+        Self(SSOStringHashSet::new())
+    }
+}
+
+impl Accumulator for StringDistinctCountAccumulator {
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        let arr = StringArray::from_iter_values(self.0.iter());
+        let list = Arc::new(array_into_list_array(Arc::new(arr)));
+        Ok(vec![ScalarValue::List(list)])
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        if values.is_empty() {
+            return Ok(());
+        }
+
+        let arr = as_string_array(&values[0])?;
+        arr.iter().for_each(|value| {
+            if let Some(value) = value {
+                self.0.insert(value);
+            }
+        });
+
+        Ok(())
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        if states.is_empty() {
+            return Ok(());
+        }
+        assert_eq!(
+            states.len(),
+            1,
+            "count_distinct states must be single array"
+        );
+
+        let arr = as_list_array(&states[0])?;
+        arr.iter().try_for_each(|maybe_list| {
+            if let Some(list) = maybe_list {
+                let list = as_string_array(&list)?;
+
+                list.iter().for_each(|value| {
+                    if let Some(value) = value {
+                        self.0.insert(value);
+                    }
+                })
+            };
+            Ok(())
+        })
+    }
+
+    fn evaluate(&self) -> Result<ScalarValue> {
+        Ok(ScalarValue::Int64(Some(self.0.len() as i64)))
+    }
+
+    fn size(&self) -> usize {
+        // Size of accumulator
+        // + SSOStringHashSet size
+        std::mem::size_of_val(self) + self.0.size()
+    }
+}
+
+const SHORT_STRING_LEN: usize = mem::size_of::<usize>();
+
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
+struct SSOStringHeader {
+    /// hash of the string value (used when resizing table)
+    hash: u64,
+
+    len: usize,
+    offset_or_inline: usize,
+}
+
+impl SSOStringHeader {
+    fn evaluate(&self) -> (bool, usize) {
+        if self.len <= SHORT_STRING_LEN {
+            (true, self.offset_or_inline)
+        } else {
+            (false, self.offset_or_inline)
+        }
+    }
+}
+
+// Short String Optimized HashSet for String
+// Equivalent to HashSet<String> but with better memory usage
+#[derive(Default)]
+struct SSOStringHashSet {
+    header_set: HashSet<SSOStringHeader>,
+    long_string_map: hashbrown::raw::RawTable<SSOStringHeader>,
+    map_size: usize,
+    buffer: BufferBuilder<u8>,
+    state: RandomState,
+}
+
+impl SSOStringHashSet {
+    fn new() -> Self {
+        Self::default()
+    }
+
+    fn insert(&mut self, value: &str) {
+        let value_len = value.len();
+        let value_bytes = value.as_bytes();
+
+        if value_len <= SHORT_STRING_LEN {
+            let inline = value_bytes
+                .iter()
+                .fold(0usize, |acc, &x| acc << 8 | x as usize);
+            let short_string_header = SSOStringHeader {
+                // no need for short string cases
+                hash: 0,
+                len: value_len,
+                offset_or_inline: inline,
+            };
+            self.header_set.insert(short_string_header);
+        } else {
+            let hash = self.state.hash_one(value_bytes);
+
+            let entry = self.long_string_map.get_mut(hash, |header| {
+                // if hash matches, check if the bytes match
+                let offset = header.offset_or_inline;
+                let len = header.len;
+
+                // SAFETY: buffer is only appended to, and we correctly inserted values
+                let existing_value =
+                    unsafe { self.buffer.as_slice().get_unchecked(offset..offset + len) };
+
+                value_bytes == existing_value
+            });
+
+            if entry.is_none() {
+                let offset = self.buffer.len();
+                self.buffer.append_slice(value_bytes);
+                let header = SSOStringHeader {
+                    hash,
+                    len: value_len,
+                    offset_or_inline: offset,
+                };
+                self.long_string_map.insert_accounted(
+                    header,
+                    |header| header.hash,
+                    &mut self.map_size,
+                );
+                self.header_set.insert(header);
+            }
+        }
+    }
+
+    fn iter(&self) -> impl Iterator<Item = String> + '_ {

Review Comment:
   I could not find a way to return `&str` here, which would avoid the `to_string` call.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to