This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new af0e8a95ca Optimize `COUNT( DISTINCT ...)` for strings (up to 9x
faster) (#8849)
af0e8a95ca is described below
commit af0e8a95ca60a231ee4e7665a14645db55a6b97a
Author: Jay Zhan <[email protected]>
AuthorDate: Mon Jan 29 20:28:20 2024 +0800
Optimize `COUNT( DISTINCT ...)` for strings (up to 9x faster) (#8849)
* chkp
Signed-off-by: jayzhan211 <[email protected]>
* chkp
Signed-off-by: jayzhan211 <[email protected]>
* draft
Signed-off-by: jayzhan211 <[email protected]>
* iter done
Signed-off-by: jayzhan211 <[email protected]>
* short string test
Signed-off-by: jayzhan211 <[email protected]>
* add test
Signed-off-by: jayzhan211 <[email protected]>
* remove unused
Signed-off-by: jayzhan211 <[email protected]>
* to_string directly
Signed-off-by: jayzhan211 <[email protected]>
* rewrite evaluate
Signed-off-by: jayzhan211 <[email protected]>
* return Vec<String>
Signed-off-by: jayzhan211 <[email protected]>
* fmt
Signed-off-by: jayzhan211 <[email protected]>
* add more queries
Signed-off-by: jayzhan211 <[email protected]>
* add group by query and rewrite evalute with state()
Signed-off-by: jayzhan211 <[email protected]>
* move evaluate back
Signed-off-by: jayzhan211 <[email protected]>
* upd test
Signed-off-by: jayzhan211 <[email protected]>
* add row sort
Signed-off-by: jayzhan211 <[email protected]>
* Update benchmarks/queries/clickbench/README.md
* Rework set to avoid copies
* Simplify offset construction
* fmt
* Improve comments
* Improve comments
* add fuzz test
Signed-off-by: jayzhan211 <[email protected]>
* Add support for LargeStringArray
* refine fuzz test
* Add tests for size accounting
* Split into new module
* Remove use of Mutex
* revert changes
* Use reference rather than owned ArrayRef
---------
Signed-off-by: jayzhan211 <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
---
benchmarks/queries/clickbench/README.md | 3 +-
datafusion-cli/Cargo.lock | 1 +
.../tests/fuzz_cases/distinct_count_string_fuzz.rs | 211 +++++++++
datafusion/core/tests/fuzz_cases/mod.rs | 1 +
datafusion/physical-expr/Cargo.toml | 1 +
.../{count_distinct.rs => count_distinct/mod.rs} | 46 +-
.../src/aggregate/count_distinct/strings.rs | 490 +++++++++++++++++++++
datafusion/sqllogictest/test_files/aggregate.slt | 57 +++
datafusion/sqllogictest/test_files/clickbench.slt | 3 +
9 files changed, 792 insertions(+), 21 deletions(-)
diff --git a/benchmarks/queries/clickbench/README.md
b/benchmarks/queries/clickbench/README.md
index e03b7d519d..ef540ccf9c 100644
--- a/benchmarks/queries/clickbench/README.md
+++ b/benchmarks/queries/clickbench/README.md
@@ -29,7 +29,6 @@ SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT
"MobilePhone"), COUNT(DIST
FROM hits;
```
-
### Q1: Data Exploration
**Question**: "How many distinct "hit color", "browser country" and "language"
are there in the dataset?"
@@ -42,7 +41,7 @@ SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT
"BrowserCountry"), COUNT(DISTI
FROM hits;
```
-### Q2: Top 10 anaylsis
+### Q2: Top 10 analysis
**Question**: "Find the top 10 "browser country" by number of distinct "social
network"s,
including the distinct counts of "hit color", "browser language",
diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index a718f7591a..6b881e3105 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -1255,6 +1255,7 @@ dependencies = [
"blake3",
"chrono",
"datafusion-common",
+ "datafusion-execution",
"datafusion-expr",
"half",
"hashbrown 0.14.3",
diff --git a/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs
b/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs
new file mode 100644
index 0000000000..343a175647
--- /dev/null
+++ b/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs
@@ -0,0 +1,211 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Compare DistinctCount for string with naive HashSet and Short String
Optimized HashSet
+
+use std::sync::Arc;
+
+use arrow::array::ArrayRef;
+use arrow::record_batch::RecordBatch;
+use arrow_array::{Array, GenericStringArray, OffsetSizeTrait, UInt32Array};
+
+use arrow_array::cast::AsArray;
+use datafusion::datasource::MemTable;
+use rand::rngs::StdRng;
+use rand::{thread_rng, Rng, SeedableRng};
+use std::collections::HashSet;
+use tokio::task::JoinSet;
+
+use datafusion::prelude::{SessionConfig, SessionContext};
+use test_utils::stagger_batch;
+
+#[tokio::test(flavor = "multi_thread")]
+async fn distinct_count_string_test() {
+ // max length of generated strings
+ let mut join_set = JoinSet::new();
+ let mut rng = thread_rng();
+ for null_pct in [0.0, 0.01, 0.1, 0.5] {
+ for _ in 0..100 {
+ let max_len = rng.gen_range(1..50);
+ let num_strings = rng.gen_range(1..100);
+ let num_distinct_strings = if num_strings > 1 {
+ rng.gen_range(1..num_strings)
+ } else {
+ num_strings
+ };
+ let generator = BatchGenerator {
+ max_len,
+ num_strings,
+ num_distinct_strings,
+ null_pct,
+ rng: StdRng::from_seed(rng.gen()),
+ };
+ join_set.spawn(async move {
run_distinct_count_test(generator).await });
+ }
+ }
+ while let Some(join_handle) = join_set.join_next().await {
+ // propagate errors
+ join_handle.unwrap();
+ }
+}
+
+/// Run COUNT DISTINCT using SQL and compare the result to computing the
+/// distinct count using HashSet<String>
+async fn run_distinct_count_test(mut generator: BatchGenerator) {
+ let input = generator.make_input_batches();
+
+ let schema = input[0].schema();
+ let session_config = SessionConfig::new().with_batch_size(50);
+ let ctx = SessionContext::new_with_config(session_config);
+
+ // split input into two partitions
+ let partition_len = input.len() / 2;
+ let partitions = vec![
+ input[0..partition_len].to_vec(),
+ input[partition_len..].to_vec(),
+ ];
+
+ let provider = MemTable::try_new(schema, partitions).unwrap();
+ ctx.register_table("t", Arc::new(provider)).unwrap();
+ // input has two columns, a and b. The result is the number of distinct
+ // values in each column.
+ //
+ // Note, we need at least two count distinct aggregates to trigger the
+ // count distinct aggregate. Otherwise, the optimizer will rewrite the
+ // `COUNT(DISTINCT a)` to `COUNT(*) from (SELECT DISTINCT a FROM t)`
+ let results = ctx
+ .sql("SELECT COUNT(DISTINCT a), COUNT(DISTINCT b) FROM t")
+ .await
+ .unwrap()
+ .collect()
+ .await
+ .unwrap();
+
+ // count the distinct (non-null) strings in the first input column (a)
+ let expected_a = extract_distinct_strings::<i32>(&input, 0).len();
+ let result_a = extract_i64(&results, 0);
+ assert_eq!(expected_a, result_a);
+
+ // count the distinct (non-null) strings in the second input column (b)
+ let expected_b = extract_distinct_strings::<i64>(&input, 1).len();
+ let result_b = extract_i64(&results, 1);
+ assert_eq!(expected_b, result_b);
+}
+
+/// Return all (non null) distinct strings from column col_idx
+fn extract_distinct_strings<O: OffsetSizeTrait>(
+ results: &[RecordBatch],
+ col_idx: usize,
+) -> Vec<String> {
+ results
+ .iter()
+ .flat_map(|batch| {
+ let array = batch.column(col_idx).as_string::<O>();
+ // remove nulls via 'flatten'
+ array.iter().flatten().map(|s| s.to_string())
+ })
+ .collect::<HashSet<_>>()
+ .into_iter()
+ .collect()
+}
+
+// extract the value from the Int64 column in col_idx in batch and return
+// it as a usize
+fn extract_i64(results: &[RecordBatch], col_idx: usize) -> usize {
+ assert_eq!(results.len(), 1);
+ let array = results[0]
+ .column(col_idx)
+ .as_any()
+ .downcast_ref::<arrow::array::Int64Array>()
+ .unwrap();
+ assert_eq!(array.len(), 1);
+ assert!(!array.is_null(0));
+ array.value(0).try_into().unwrap()
+}
+
+struct BatchGenerator {
+ /// The maximum length of the strings
+ max_len: usize,
+ /// the total number of strings in the output
+ num_strings: usize,
+ /// The number of distinct strings in the columns
+ num_distinct_strings: usize,
+ /// The percentage of nulls in the columns
+ null_pct: f64,
+ /// Random number generator
+ rng: StdRng,
+}
+
+impl BatchGenerator {
+ /// Make batches of random strings of random length in columns "a" and "b":
+ ///
+ /// * "a" is a StringArray
+ /// * "b" is a LargeStringArray
+ fn make_input_batches(&mut self) -> Vec<RecordBatch> {
+ // use a random number generator to pick a random sized output
+
+ let batch = RecordBatch::try_from_iter(vec![
+ ("a", self.gen_data::<i32>()),
+ ("b", self.gen_data::<i64>()),
+ ])
+ .unwrap();
+
+ stagger_batch(batch)
+ }
+
+ /// Creates a StringArray or LargeStringArray with random strings according
+ /// to the parameters of the BatchGenerator
+ fn gen_data<O: OffsetSizeTrait>(&mut self) -> ArrayRef {
+ // table of strings from which to draw
+ let distinct_strings: GenericStringArray<O> =
(0..self.num_distinct_strings)
+ .map(|_| Some(random_string(&mut self.rng, self.max_len)))
+ .collect();
+
+ // pick num_strings randomly from the distinct string table
+ let indicies: UInt32Array = (0..self.num_strings)
+ .map(|_| {
+ if self.rng.gen::<f64>() < self.null_pct {
+ None
+ } else if self.num_distinct_strings > 1 {
+ let range = 1..(self.num_distinct_strings as u32);
+ Some(self.rng.gen_range(range))
+ } else {
+ Some(0)
+ }
+ })
+ .collect();
+
+ let options = None;
+ arrow::compute::take(&distinct_strings, &indicies, options).unwrap()
+ }
+}
+
+/// Return a string of random characters of length 1..=max_len
+fn random_string(rng: &mut StdRng, max_len: usize) -> String {
+ // pick characters at random (not just ascii)
+ match max_len {
+ 0 => "".to_string(),
+ 1 => String::from(rng.gen::<char>()),
+ _ => {
+ let len = rng.gen_range(1..=max_len);
+ rng.sample_iter::<char, _>(rand::distributions::Standard)
+ .take(len)
+ .map(char::from)
+ .collect::<String>()
+ }
+ }
+}
diff --git a/datafusion/core/tests/fuzz_cases/mod.rs
b/datafusion/core/tests/fuzz_cases/mod.rs
index 83ec928ae2..69241571b4 100644
--- a/datafusion/core/tests/fuzz_cases/mod.rs
+++ b/datafusion/core/tests/fuzz_cases/mod.rs
@@ -16,6 +16,7 @@
// under the License.
mod aggregate_fuzz;
+mod distinct_count_string_fuzz;
mod join_fuzz;
mod merge_fuzz;
mod sort_fuzz;
diff --git a/datafusion/physical-expr/Cargo.toml
b/datafusion/physical-expr/Cargo.toml
index d237c68657..61eba042f9 100644
--- a/datafusion/physical-expr/Cargo.toml
+++ b/datafusion/physical-expr/Cargo.toml
@@ -54,6 +54,7 @@ blake2 = { version = "^0.10.2", optional = true }
blake3 = { version = "1.0", optional = true }
chrono = { workspace = true }
datafusion-common = { workspace = true }
+datafusion-execution = { workspace = true }
datafusion-expr = { workspace = true }
half = { version = "2.1", default-features = false }
hashbrown = { version = "0.14", features = ["raw"] }
diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs
b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
similarity index 98%
rename from datafusion/physical-expr/src/aggregate/count_distinct.rs
rename to datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
index ef1a248d5f..891ef85880 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
@@ -15,34 +15,37 @@
// specific language governing permissions and limitations
// under the License.
-use arrow::datatypes::{DataType, Field, TimeUnit};
-use arrow_array::types::{
- ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
- Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type,
Int8Type,
- Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
Time64NanosecondType,
- TimestampMicrosecondType, TimestampMillisecondType,
TimestampNanosecondType,
- TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
-};
-use arrow_array::PrimitiveArray;
+mod strings;
use std::any::Any;
use std::cmp::Eq;
+use std::collections::HashSet;
use std::fmt::Debug;
use std::hash::Hash;
use std::sync::Arc;
use ahash::RandomState;
use arrow::array::{Array, ArrayRef};
-use std::collections::HashSet;
+use arrow::datatypes::{DataType, Field, TimeUnit};
+use arrow_array::types::{
+ ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
+ Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type,
Int8Type,
+ Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
Time64NanosecondType,
+ TimestampMicrosecondType, TimestampMillisecondType,
TimestampNanosecondType,
+ TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
+};
+use arrow_array::PrimitiveArray;
-use crate::aggregate::utils::{down_cast_any_ref, Hashable};
-use crate::expressions::format_state_name;
-use crate::{AggregateExpr, PhysicalExpr};
use datafusion_common::cast::{as_list_array, as_primitive_array};
use datafusion_common::utils::array_into_list_array;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Accumulator;
+use crate::aggregate::count_distinct::strings::StringDistinctCountAccumulator;
+use crate::aggregate::utils::{down_cast_any_ref, Hashable};
+use crate::expressions::format_state_name;
+use crate::{AggregateExpr, PhysicalExpr};
+
type DistinctScalarValues = ScalarValue;
/// Expression for a COUNT(DISTINCT) aggregation.
@@ -61,10 +64,10 @@ impl DistinctCount {
pub fn new(
input_data_type: DataType,
expr: Arc<dyn PhysicalExpr>,
- name: String,
+ name: impl Into<String>,
) -> Self {
Self {
- name,
+ name: name.into(),
state_data_type: input_data_type,
expr,
}
@@ -152,6 +155,9 @@ impl AggregateExpr for DistinctCount {
Float32 => float_distinct_count_accumulator!(Float32Type),
Float64 => float_distinct_count_accumulator!(Float64Type),
+ Utf8 => Ok(Box::new(StringDistinctCountAccumulator::<i32>::new())),
+ LargeUtf8 =>
Ok(Box::new(StringDistinctCountAccumulator::<i64>::new())),
+
_ => Ok(Box::new(DistinctCountAccumulator {
values: HashSet::default(),
state_data_type: self.state_data_type.clone(),
@@ -244,7 +250,7 @@ impl Accumulator for DistinctCountAccumulator {
assert_eq!(states.len(), 1, "array_agg states must be singleton!");
let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&states[0])?;
for scalars in scalar_vec.into_iter() {
- self.values.extend(scalars)
+ self.values.extend(scalars);
}
Ok(())
}
@@ -440,9 +446,6 @@ where
#[cfg(test)]
mod tests {
- use crate::expressions::NoOp;
-
- use super::*;
use arrow::array::{
ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array,
Int32Array,
Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array,
UInt8Array,
@@ -454,10 +457,15 @@ mod tests {
};
use arrow_array::Decimal256Array;
use arrow_buffer::i256;
+
use datafusion_common::cast::{as_boolean_array, as_list_array,
as_primitive_array};
use datafusion_common::internal_err;
use datafusion_common::DataFusionError;
+ use crate::expressions::NoOp;
+
+ use super::*;
+
macro_rules! state_to_vec_primitive {
($LIST:expr, $DATA_TYPE:ident) => {{
let arr = ScalarValue::raw_data($LIST).unwrap();
diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/strings.rs
b/datafusion/physical-expr/src/aggregate/count_distinct/strings.rs
new file mode 100644
index 0000000000..d7a9ea5c37
--- /dev/null
+++ b/datafusion/physical-expr/src/aggregate/count_distinct/strings.rs
@@ -0,0 +1,490 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Specialized implementation of `COUNT DISTINCT` for `StringArray` and
`LargeStringArray`
+
+use ahash::RandomState;
+use arrow_array::cast::AsArray;
+use arrow_array::{Array, ArrayRef, GenericStringArray, OffsetSizeTrait};
+use arrow_buffer::{BufferBuilder, OffsetBuffer, ScalarBuffer};
+use datafusion_common::cast::as_list_array;
+use datafusion_common::hash_utils::create_hashes;
+use datafusion_common::utils::array_into_list_array;
+use datafusion_common::ScalarValue;
+use datafusion_execution::memory_pool::proxy::RawTableAllocExt;
+use datafusion_expr::Accumulator;
+use std::fmt::Debug;
+use std::mem;
+use std::ops::Range;
+use std::sync::Arc;
+
+#[derive(Debug)]
+pub(super) struct StringDistinctCountAccumulator<O:
OffsetSizeTrait>(SSOStringHashSet<O>);
+impl<O: OffsetSizeTrait> StringDistinctCountAccumulator<O> {
+ pub(super) fn new() -> Self {
+ Self(SSOStringHashSet::<O>::new())
+ }
+}
+
+impl<O: OffsetSizeTrait> Accumulator for StringDistinctCountAccumulator<O> {
+ fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
+ // take the state out of the string set and replace with default
+ let set = std::mem::take(&mut self.0);
+ let arr = set.into_state();
+ let list = Arc::new(array_into_list_array(arr));
+ Ok(vec![ScalarValue::List(list)])
+ }
+
+ fn update_batch(&mut self, values: &[ArrayRef]) ->
datafusion_common::Result<()> {
+ if values.is_empty() {
+ return Ok(());
+ }
+
+ self.0.insert(&values[0]);
+
+ Ok(())
+ }
+
+ fn merge_batch(&mut self, states: &[ArrayRef]) ->
datafusion_common::Result<()> {
+ if states.is_empty() {
+ return Ok(());
+ }
+ assert_eq!(
+ states.len(),
+ 1,
+ "count_distinct states must be single array"
+ );
+
+ let arr = as_list_array(&states[0])?;
+ arr.iter().try_for_each(|maybe_list| {
+ if let Some(list) = maybe_list {
+ self.0.insert(&list);
+ };
+ Ok(())
+ })
+ }
+
+ fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
+ Ok(ScalarValue::Int64(Some(self.0.len() as i64)))
+ }
+
+ fn size(&self) -> usize {
+ // Size of accumulator
+ // + SSOStringHashSet size
+ std::mem::size_of_val(self) + self.0.size()
+ }
+}
+
+/// Maximum size of a string that can be inlined in the hash table
+const SHORT_STRING_LEN: usize = mem::size_of::<usize>();
+
+/// Entry that is stored in a `SSOStringHashSet` that represents a string
+/// that is either stored inline or in the buffer
+///
+/// This helps the case where there are many short (at most 8 bytes) strings
+/// that are the same (e.g. "MA", "CA", "NY", "TX", etc)
+///
+/// ```text
+///                                                                ┌──────────────────┐
+///                                                                │...               │
+///                                                                │TheQuickBrownFox  │
+///                                                  ─ ─ ─ ─ ─ ─ ─▶│...               │
+///                                                 │              │                  │
+///                                                                └──────────────────┘
+///                                                 │               buffer of u8
+///
+///                                                 │
+///                        ┌────────────────┬───────────────┬───────────────┐
+/// Storing                │                │ starting byte │  length, in   │
+/// "TheQuickBrownFox"     │   hash value   │   offset in   │  bytes (not   │
+/// (long string)          │                │    buffer     │  characters)  │
+///                        └────────────────┴───────────────┴───────────────┘
+///                             8 bytes          8 bytes         4 or 8
+///
+///
+///                        ┌───────────────┬─┬─┬─┬─┬─┬─┬─┬─┬───────────────┐
+/// Storing "foobar"       │               │ │ │ │ │ │ │ │ │  length, in   │
+/// (short string)         │  hash value   │?│?│f│o│o│b│a│r│  bytes (not   │
+///                        │               │ │ │ │ │ │ │ │ │  characters)  │
+///                        └───────────────┴─┴─┴─┴─┴─┴─┴─┴─┴───────────────┘
+///                             8 bytes         8 bytes          4 or 8
+/// ```
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
+struct SSOStringHeader {
+ /// hash of the string value (stored to avoid recomputing it in hash table
+ /// check)
+ hash: u64,
+ /// if len <= SHORT_STRING_LEN: the string data inlined
+ /// if len > SHORT_STRING_LEN, the offset of where the data starts
+ offset_or_inline: usize,
+ /// length of the string, in bytes
+ len: usize,
+}
+
+impl SSOStringHeader {
+ /// returns self.offset_or_inline..self.offset_or_inline + self.len
+ fn range(&self) -> Range<usize> {
+ self.offset_or_inline..self.offset_or_inline + self.len
+ }
+}
+
+/// HashSet optimized for storing `String` and `LargeString` values
+/// and producing the final set as a GenericStringArray with minimal copies.
+///
+/// Equivalent to `HashSet<String>` but with better performance for arrow data.
+struct SSOStringHashSet<O> {
+ /// Underlying hash set for each distinct string
+ map: hashbrown::raw::RawTable<SSOStringHeader>,
+ /// Total size of the map in bytes
+ map_size: usize,
+ /// In progress arrow `Buffer` containing all string values
+ buffer: BufferBuilder<u8>,
+ /// Offsets into `buffer` for each distinct string value. These offsets
+ /// are used directly to create the final `GenericStringArray`
+ offsets: Vec<O>,
+ /// random state used to generate hashes
+ random_state: RandomState,
+ /// buffer that stores hash values (reused across batches to save
allocations)
+ hashes_buffer: Vec<u64>,
+}
+
+impl<O: OffsetSizeTrait> Default for SSOStringHashSet<O> {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl<O: OffsetSizeTrait> SSOStringHashSet<O> {
+ fn new() -> Self {
+ Self {
+ map: hashbrown::raw::RawTable::new(),
+ map_size: 0,
+ buffer: BufferBuilder::new(0),
+ offsets: vec![O::default()], // first offset is always 0
+ random_state: RandomState::new(),
+ hashes_buffer: vec![],
+ }
+ }
+
+ fn insert(&mut self, values: &ArrayRef) {
+ // step 1: compute hashes for the strings
+ let batch_hashes = &mut self.hashes_buffer;
+ batch_hashes.clear();
+ batch_hashes.resize(values.len(), 0);
+ create_hashes(&[values.clone()], &self.random_state, batch_hashes)
+ // hash is supported for all string types and create_hashes only
+ // returns errors for unsupported types
+ .unwrap();
+
+ // step 2: insert each string into the set, if not already present
+ let values = values.as_string::<O>();
+
+ // Ensure lengths are equivalent (to guard unsafe values calls below)
+ assert_eq!(values.len(), batch_hashes.len());
+
+ for (value, &hash) in values.iter().zip(batch_hashes.iter()) {
+ // count distinct ignores nulls
+ let Some(value) = value else {
+ continue;
+ };
+
+ // from here on only use bytes (not str/chars) for value
+ let value = value.as_bytes();
+
+ // value is a "small" string
+ if value.len() <= SHORT_STRING_LEN {
+ let inline = value.iter().fold(0usize, |acc, &x| acc << 8 | x
as usize);
+
+ // is value already present in the set?
+ let entry = self.map.get_mut(hash, |header| {
+ // compare value if hashes match
+ if header.len != value.len() {
+ return false;
+ }
+ // value is stored inline so no need to consult buffer
+ // (this is the "small string optimization")
+ inline == header.offset_or_inline
+ });
+
+ // if no existing entry, make a new one
+ if entry.is_none() {
+ // Put the small values into buffer and offsets so it appears
+ // in the output array, but store the actual bytes inline for
+ // comparison
+ self.buffer.append_slice(value);
+
self.offsets.push(O::from_usize(self.buffer.len()).unwrap());
+ let new_header = SSOStringHeader {
+ hash,
+ len: value.len(),
+ offset_or_inline: inline,
+ };
+ self.map.insert_accounted(
+ new_header,
+ |header| header.hash,
+ &mut self.map_size,
+ );
+ }
+ }
+ // value is not a "small" string
+ else {
+ // Check if the value is already present in the set
+ let entry = self.map.get_mut(hash, |header| {
+ // compare value if hashes match
+ if header.len != value.len() {
+ return false;
+ }
+ // Need to compare the bytes in the buffer
+ // SAFETY: buffer is only appended to, and we correctly
inserted values and offsets
+ let existing_value =
+ unsafe {
self.buffer.as_slice().get_unchecked(header.range()) };
+ value == existing_value
+ });
+
+ // if no existing entry, make a new one
+ if entry.is_none() {
+ // Put the value into buffer and offsets so it
+ // appears in the output array, and store that offset
+ // so the bytes can be compared if needed
+ let offset = self.buffer.len(); // offset of start of data
+ self.buffer.append_slice(value);
+
self.offsets.push(O::from_usize(self.buffer.len()).unwrap());
+
+ let new_header = SSOStringHeader {
+ hash,
+ len: value.len(),
+ offset_or_inline: offset,
+ };
+ self.map.insert_accounted(
+ new_header,
+ |header| header.hash,
+ &mut self.map_size,
+ );
+ }
+ }
+ }
+ }
+
+ /// Converts this set into a `StringArray` or `LargeStringArray` with each
+ /// distinct string value without any copies
+ fn into_state(self) -> ArrayRef {
+ let Self {
+ map: _,
+ map_size: _,
+ offsets,
+ mut buffer,
+ random_state: _,
+ hashes_buffer: _,
+ } = self;
+
+ let offsets: ScalarBuffer<O> = offsets.into();
+ let values = buffer.finish();
+ let nulls = None; // count distinct ignores nulls so intermediate
state never has nulls
+
+ // SAFETY: all the values that went in were valid utf8 so are all the
values that come out
+ let array = unsafe {
+ GenericStringArray::new_unchecked(OffsetBuffer::new(offsets),
values, nulls)
+ };
+ Arc::new(array)
+ }
+
+ fn len(&self) -> usize {
+ self.map.len()
+ }
+
+ /// Return the total size, in bytes, of memory used to store the data in
+ /// this set, not including `self`
+ fn size(&self) -> usize {
+ self.map_size
+ + self.buffer.capacity() * std::mem::size_of::<u8>()
+ + self.offsets.capacity() * std::mem::size_of::<O>()
+ + self.hashes_buffer.capacity() * std::mem::size_of::<u64>()
+ }
+}
+
+impl<O: OffsetSizeTrait> Debug for SSOStringHashSet<O> {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("SSOStringHashSet")
+ .field("map", &"<map>")
+ .field("map_size", &self.map_size)
+ .field("buffer", &self.buffer)
+ .field("random_state", &self.random_state)
+ .field("hashes_buffer", &self.hashes_buffer)
+ .finish()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use arrow::array::ArrayRef;
+ use arrow_array::StringArray;
+ #[test]
+ fn string_set_empty() {
+ for values in [StringArray::new_null(0), StringArray::new_null(11)] {
+ let mut set = SSOStringHashSet::<i32>::new();
+ let array: ArrayRef = Arc::new(values);
+ set.insert(&array);
+ assert_set(set, &[]);
+ }
+ }
+
+ #[test]
+ fn string_set_basic_i32() {
+ test_string_set_basic::<i32>();
+ }
+ #[test]
+ fn string_set_basic_i64() {
+ test_string_set_basic::<i64>();
+ }
+ fn test_string_set_basic<O: OffsetSizeTrait>() {
+ // basic test for mixed small and large string values
+ let values = GenericStringArray::<O>::from(vec![
+ Some("a"),
+ Some("b"),
+ Some("CXCCCCCCCC"), // 10 bytes
+ Some(""),
+ Some("cbcxx"), // 5 bytes
+ None,
+ Some("AAAAAAAA"), // 8 bytes
+ Some("BBBBBQBBB"), // 9 bytes
+ Some("a"),
+ Some("cbcxx"),
+ Some("b"),
+ Some("cbcxx"),
+ Some(""),
+ None,
+ Some("BBBBBQBBB"),
+ Some("BBBBBQBBB"),
+ Some("AAAAAAAA"),
+ Some("CXCCCCCCCC"),
+ ]);
+
+ let mut set = SSOStringHashSet::<O>::new();
+ let array: ArrayRef = Arc::new(values);
+ set.insert(&array);
+ assert_set(
+ set,
+ &[
+ Some(""),
+ Some("AAAAAAAA"),
+ Some("BBBBBQBBB"),
+ Some("CXCCCCCCCC"),
+ Some("a"),
+ Some("b"),
+ Some("cbcxx"),
+ ],
+ );
+ }
+
+ #[test]
+ fn string_set_non_utf8_32() {
+ test_string_set_non_utf8::<i32>();
+ }
+ #[test]
+ fn string_set_non_utf8_64() {
+ test_string_set_non_utf8::<i64>();
+ }
+ fn test_string_set_non_utf8<O: OffsetSizeTrait>() {
+ // test for small and large string values containing multi-byte (non-ASCII) UTF-8
+ let values = GenericStringArray::<O>::from(vec![
+ Some("a"),
+ Some("✨🔥"),
+ Some("🔥"),
+ Some("✨✨✨"),
+ Some("foobarbaz"),
+ Some("🔥"),
+ Some("✨🔥"),
+ ]);
+
+ let mut set = SSOStringHashSet::<O>::new();
+ let array: ArrayRef = Arc::new(values);
+ set.insert(&array);
+ assert_set(
+ set,
+ &[
+ Some("a"),
+ Some("foobarbaz"),
+ Some("✨✨✨"),
+ Some("✨🔥"),
+ Some("🔥"),
+ ],
+ );
+ }
+
+ // asserts that the set contains the expected strings
+ fn assert_set<O: OffsetSizeTrait>(
+ set: SSOStringHashSet<O>,
+ expected: &[Option<&str>],
+ ) {
+ let strings = set.into_state();
+ let strings = strings.as_string::<O>();
+ let mut state = strings.into_iter().collect::<Vec<_>>();
+ state.sort();
+ assert_eq!(state, expected);
+ }
+
+ // inserting new strings increases the reported memory; re-inserting existing strings does not
+ #[test]
+ fn test_string_set_memory_usage() {
+ let strings1 = GenericStringArray::<i32>::from(vec![
+ Some("a"),
+ Some("b"),
+ Some("CXCCCCCCCC"), // 10 bytes
+ Some("AAAAAAAA"), // 8 bytes
+ Some("BBBBBQBBB"), // 9 bytes
+ ]);
+ let total_strings1_len = strings1
+ .iter()
+ .map(|s| s.map(|s| s.len()).unwrap_or(0))
+ .sum::<usize>();
+ let values1: ArrayRef =
Arc::new(GenericStringArray::<i32>::from(strings1));
+
+ // Much larger strings in strings2
+ let strings2 = GenericStringArray::<i32>::from(vec![
+ "FOO".repeat(1000),
+ "BAR".repeat(2000),
+ "BAZ".repeat(3000),
+ ]);
+ let total_strings2_len = strings2
+ .iter()
+ .map(|s| s.map(|s| s.len()).unwrap_or(0))
+ .sum::<usize>();
+ let values2: ArrayRef =
Arc::new(GenericStringArray::<i32>::from(strings2));
+
+ let mut set = SSOStringHashSet::<i32>::new();
+ let size_empty = set.size();
+
+ set.insert(&values1);
+ let size_after_values1 = set.size();
+ assert!(size_empty < size_after_values1);
+ assert!(
+ size_after_values1 > total_strings1_len,
+ "expect {size_after_values1} to be more than {total_strings1_len}"
+ );
+ assert!(size_after_values1 < total_strings1_len + total_strings2_len);
+
+ // inserting the same strings should not affect the size
+ set.insert(&values1);
+ assert_eq!(set.size(), size_after_values1);
+
+ // inserting the large strings should increase the reported size
+ set.insert(&values2);
+ let size_after_values2 = set.size();
+ assert!(size_after_values2 > size_after_values1);
+ assert!(size_after_values2 > total_strings1_len + total_strings2_len);
+ }
+}
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt
b/datafusion/sqllogictest/test_files/aggregate.slt
index 5cd728c434..136fb39c67 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -3069,6 +3069,62 @@ select count(*) from (select count(*) a, count(*) b from
(select 1));
----
1
+# Distinct Count for string
+# (test for the specialized implementation of distinct count for strings)
+
+# Multi-byte UTF-8 strings exercise the string to &[u8] conversion; include them to prevent regressions
+statement ok
+create table distinct_count_string_table as values
+ (1, 'a', 'longstringtest_a', '台灣'),
+ (2, 'b', 'longstringtest_b1', '日本'),
+ (2, 'b', 'longstringtest_b2', '中國'),
+ (3, 'c', 'longstringtest_c1', '美國'),
+ (3, 'c', 'longstringtest_c2', '歐洲'),
+ (3, 'c', 'longstringtest_c3', '韓國')
+;
+
+# run through update_batch
+query IIII
+select count(distinct column1), count(distinct column2), count(distinct
column3), count(distinct column4) from distinct_count_string_table;
+----
+3 3 6 6
+
+# run through merge_batch
+query IIII rowsort
+select count(distinct column1), count(distinct column2), count(distinct
column3), count(distinct column4) from distinct_count_string_table group by
column1;
+----
+1 1 1 1
+1 1 2 2
+1 1 3 3
+
+
+# test with LargeUtf8 strings as well
+statement ok
+create table distinct_count_long_string_table as
+SELECT column1,
+ arrow_cast(column2, 'LargeUtf8') as column2,
+ arrow_cast(column3, 'LargeUtf8') as column3,
+ arrow_cast(column4, 'LargeUtf8') as column4
+FROM distinct_count_string_table;
+
+# run through update_batch
+query IIII
+select count(distinct column1), count(distinct column2), count(distinct
column3), count(distinct column4) from distinct_count_long_string_table;
+----
+3 3 6 6
+
+# run through merge_batch
+query IIII rowsort
+select count(distinct column1), count(distinct column2), count(distinct
column3), count(distinct column4) from distinct_count_long_string_table group
by column1;
+----
+1 1 1 1
+1 1 2 2
+1 1 3 3
+
+statement ok
+drop table distinct_count_string_table;
+
+
# rule `aggregate_statistics` should not optimize MIN/MAX to wrong values on
empty relation
statement ok
@@ -3122,3 +3178,4 @@ NULL
statement ok
DROP TABLE t;
+
diff --git a/datafusion/sqllogictest/test_files/clickbench.slt
b/datafusion/sqllogictest/test_files/clickbench.slt
index 21befd7822..b61bee6708 100644
--- a/datafusion/sqllogictest/test_files/clickbench.slt
+++ b/datafusion/sqllogictest/test_files/clickbench.slt
@@ -273,3 +273,6 @@ SELECT "WindowClientWidth", "WindowClientHeight", COUNT(*)
AS PageViews FROM hit
query PI
SELECT DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) AS M, COUNT(*)
AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >=
'2013-07-14' AND "EventDate"::INT::DATE <= '2013-07-15' AND "IsRefresh" = 0 AND
"DontCountHits" = 0 GROUP BY DATE_TRUNC('minute',
to_timestamp_seconds("EventTime")) ORDER BY DATE_TRUNC('minute', M) LIMIT 10
OFFSET 1000;
----
+
+query
+drop table hits;
\ No newline at end of file