jayzhan211 commented on code in PR #8849:
URL: https://github.com/apache/arrow-datafusion/pull/8849#discussion_r1451368176


##########
datafusion/physical-expr/src/aggregate/count_distinct.rs:
##########
@@ -438,6 +443,207 @@ where
     }
 }
 
+#[derive(Debug)]
+struct StringDistinctCountAccumulator(SSOStringHashSet);
+impl StringDistinctCountAccumulator {
+    fn new() -> Self {
+        Self(SSOStringHashSet::new())
+    }
+}
+
+impl Accumulator for StringDistinctCountAccumulator {
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        let arr = StringArray::from_iter_values(self.0.iter());
+        let list = Arc::new(array_into_list_array(Arc::new(arr)));
+        Ok(vec![ScalarValue::List(list)])
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        if values.is_empty() {
+            return Ok(());
+        }
+
+        let arr = as_string_array(&values[0])?;
+        arr.iter().for_each(|value| {
+            if let Some(value) = value {
+                self.0.insert(value);
+            }
+        });
+
+        Ok(())
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        if states.is_empty() {
+            return Ok(());
+        }
+        assert_eq!(
+            states.len(),
+            1,
+            "count_distinct states must be single array"
+        );
+
+        let arr = as_list_array(&states[0])?;
+        arr.iter().try_for_each(|maybe_list| {
+            if let Some(list) = maybe_list {
+                let list = as_string_array(&list)?;
+
+                list.iter().for_each(|value| {
+                    if let Some(value) = value {
+                        self.0.insert(value);
+                    }
+                })
+            };
+            Ok(())
+        })
+    }
+
+    fn evaluate(&self) -> Result<ScalarValue> {
+        Ok(ScalarValue::Int64(Some(self.0.len() as i64)))
+    }
+
+    fn size(&self) -> usize {
+        // Size of accumulator
+        // + SSOStringHashSet size
+        std::mem::size_of_val(self) + self.0.size()
+    }
+}
+
+const SHORT_STRING_LEN: usize = mem::size_of::<usize>();
+
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]

Review Comment:
   Derive `Copy` here, since all of the fields are native types.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to