This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new af0e8a95ca Optimize `COUNT( DISTINCT ...)` for strings (up to 9x faster) (#8849)
af0e8a95ca is described below

commit af0e8a95ca60a231ee4e7665a14645db55a6b97a
Author: Jay Zhan <[email protected]>
AuthorDate: Mon Jan 29 20:28:20 2024 +0800

    Optimize `COUNT( DISTINCT ...)` for strings (up to 9x faster) (#8849)
    
    * chkp
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * chkp
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * draft
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * iter done
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * short string test
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * add test
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * remove unused
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * to_string directly
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * rewrite evaluate
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * return Vec<String>
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fmt
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * add more queries
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * add group by query and rewrite evaluate with state()
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * move evaluate back
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * upd test
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * add row sort
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * Update benchmarks/queries/clickbench/README.md
    
    * Rework set to avoid copies
    
    * Simplify offset construction
    
    * fmt
    
    * Improve comments
    
    * Improve comments
    
    * add fuzz test
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * Add support for LargeStringArray
    
    * refine fuzz test
    
    * Add tests for size accounting
    
    * Split into new module
    
    * Remove use of Mutex
    
    * revert changes
    
    * Use reference rather than owned ArrayRef
    
    ---------
    
    Signed-off-by: jayzhan211 <[email protected]>
    Co-authored-by: Andrew Lamb <[email protected]>
---
 benchmarks/queries/clickbench/README.md            |   3 +-
 datafusion-cli/Cargo.lock                          |   1 +
 .../tests/fuzz_cases/distinct_count_string_fuzz.rs | 211 +++++++++
 datafusion/core/tests/fuzz_cases/mod.rs            |   1 +
 datafusion/physical-expr/Cargo.toml                |   1 +
 .../{count_distinct.rs => count_distinct/mod.rs}   |  46 +-
 .../src/aggregate/count_distinct/strings.rs        | 490 +++++++++++++++++++++
 datafusion/sqllogictest/test_files/aggregate.slt   |  57 +++
 datafusion/sqllogictest/test_files/clickbench.slt  |   3 +
 9 files changed, 792 insertions(+), 21 deletions(-)

diff --git a/benchmarks/queries/clickbench/README.md b/benchmarks/queries/clickbench/README.md
index e03b7d519d..ef540ccf9c 100644
--- a/benchmarks/queries/clickbench/README.md
+++ b/benchmarks/queries/clickbench/README.md
@@ -29,7 +29,6 @@ SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DIST
 FROM hits;
 ```
 
-
 ### Q1: Data Exploration
 
 **Question**: "How many distinct "hit color", "browser country" and "language" are there in the dataset?"
@@ -42,7 +41,7 @@ SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTI
 FROM hits;
 ```
 
-### Q2: Top 10 anaylsis
+### Q2: Top 10 analysis
 
 **Question**: "Find the top 10 "browser country" by number of distinct "social network"s, 
 including the distinct counts of  "hit color", "browser language",
diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index a718f7591a..6b881e3105 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -1255,6 +1255,7 @@ dependencies = [
  "blake3",
  "chrono",
  "datafusion-common",
+ "datafusion-execution",
  "datafusion-expr",
  "half",
  "hashbrown 0.14.3",
diff --git a/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs b/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs
new file mode 100644
index 0000000000..343a175647
--- /dev/null
+++ b/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs
@@ -0,0 +1,211 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Compare DistinctCount for strings against a naive HashSet and the Short String Optimized HashSet
+
+use std::sync::Arc;
+
+use arrow::array::ArrayRef;
+use arrow::record_batch::RecordBatch;
+use arrow_array::{Array, GenericStringArray, OffsetSizeTrait, UInt32Array};
+
+use arrow_array::cast::AsArray;
+use datafusion::datasource::MemTable;
+use rand::rngs::StdRng;
+use rand::{thread_rng, Rng, SeedableRng};
+use std::collections::HashSet;
+use tokio::task::JoinSet;
+
+use datafusion::prelude::{SessionConfig, SessionContext};
+use test_utils::stagger_batch;
+
+#[tokio::test(flavor = "multi_thread")]
+async fn distinct_count_string_test() {
+    // Run many randomized COUNT DISTINCT comparisons concurrently, for
+    // several different fractions of null values in the input.
+    let mut join_set = JoinSet::new();
+    let mut rng = thread_rng();
+    for null_pct in [0.0, 0.01, 0.1, 0.5] {
+        for _ in 0..100 {
+            // max length (in characters) of each generated string
+            let max_len = rng.gen_range(1..50);
+            // total number of rows in each generated column
+            let num_strings = rng.gen_range(1..100);
+            // size of the table of candidate strings the rows are drawn from
+            let num_distinct_strings = if num_strings > 1 {
+                rng.gen_range(1..num_strings)
+            } else {
+                num_strings
+            };
+            let generator = BatchGenerator {
+                max_len,
+                num_strings,
+                num_distinct_strings,
+                null_pct,
+                // each task gets its own RNG, seeded from the thread RNG
+                rng: StdRng::from_seed(rng.gen()),
+            };
+            join_set.spawn(async move { run_distinct_count_test(generator).await });
+        }
+    }
+    while let Some(join_handle) = join_set.join_next().await {
+        // propagate errors (a panicking task surfaces here as an Err)
+        join_handle.unwrap();
+    }
+}
+
+/// Run COUNT DISTINCT using SQL and compare the result to computing the
+/// distinct count using HashSet<String>
+async fn run_distinct_count_test(mut generator: BatchGenerator) {
+    let input = generator.make_input_batches();
+
+    let schema = input[0].schema();
+    let session_config = SessionConfig::new().with_batch_size(50);
+    let ctx = SessionContext::new_with_config(session_config);
+
+    // split input into two partitions
+    let partition_len = input.len() / 2;
+    let partitions = vec![
+        input[0..partition_len].to_vec(),
+        input[partition_len..].to_vec(),
+    ];
+
+    let provider = MemTable::try_new(schema, partitions).unwrap();
+    ctx.register_table("t", Arc::new(provider)).unwrap();
+    // input has two columns, a and b.  The result is the number of distinct
+    // values in each column.
+    //
+    // Note, we need at least two count distinct aggregates to trigger the
+    // count distinct aggregate. Otherwise, the optimizer will rewrite the
+    // `COUNT(DISTINCT a)` to `COUNT(*) from (SELECT DISTINCT a FROM t)`
+    let results = ctx
+        .sql("SELECT COUNT(DISTINCT a), COUNT(DISTINCT b) FROM t")
+        .await
+        .unwrap()
+        .collect()
+        .await
+        .unwrap();
+
+    // count the distinct non-null strings in input column "a" (Utf8) and
+    // compare with the first result column, COUNT(DISTINCT a)
+    let expected_a = extract_distinct_strings::<i32>(&input, 0).len();
+    let result_a = extract_i64(&results, 0);
+    assert_eq!(expected_a, result_a);
+
+    // likewise for input column "b" (LargeUtf8) and COUNT(DISTINCT b)
+    let expected_b = extract_distinct_strings::<i64>(&input, 1).len();
+    let result_b = extract_i64(&results, 1);
+    assert_eq!(expected_b, result_b);
+}
+
+/// Collect the distinct, non-null strings found in column `col_idx` of every
+/// batch in `results`. `O` selects the offset width of the string column
+/// (`i32` for StringArray, `i64` for LargeStringArray).
+fn extract_distinct_strings<O: OffsetSizeTrait>(
+    results: &[RecordBatch],
+    col_idx: usize,
+) -> Vec<String> {
+    let mut distinct = HashSet::new();
+    for batch in results {
+        let column = batch.column(col_idx).as_string::<O>();
+        // `flatten` skips the `None` (null) entries
+        for value in column.iter().flatten() {
+            distinct.insert(value.to_string());
+        }
+    }
+    distinct.into_iter().collect()
+}
+
+/// Return the single value of the Int64 column `col_idx` in the single batch
+/// of `results`, converted to a usize.
+///
+/// Panics unless there is exactly one batch whose column is an Int64 array
+/// holding exactly one non-null value that fits in a usize.
+fn extract_i64(results: &[RecordBatch], col_idx: usize) -> usize {
+    assert_eq!(results.len(), 1);
+    let column = results[0].column(col_idx);
+    let values = column
+        .as_any()
+        .downcast_ref::<arrow::array::Int64Array>()
+        .unwrap();
+    assert_eq!(values.len(), 1);
+    assert!(values.is_valid(0));
+    let value: i64 = values.value(0);
+    value.try_into().unwrap()
+}
+
+/// Parameters and random state used to generate the fuzz test input
+struct BatchGenerator {
+    /// The maximum length of the strings, in characters
+    max_len: usize,
+    /// the total number of strings (rows) in the output
+    num_strings: usize,
+    /// The number of candidate distinct strings each column is drawn from
+    /// (the actual number of distinct values may be smaller)
+    num_distinct_strings: usize,
+    /// The fraction (0.0 to 1.0) of values in the columns that are null
+    null_pct: f64,
+    /// Random number generator
+    rng: StdRng,
+}
+
+impl BatchGenerator {
+    /// Make batches with two columns, "a" and "b", of random strings with
+    /// random lengths:
+    ///
+    /// * "a" is a StringArray
+    /// * "b" is a LargeStringArray
+    ///
+    /// The generated batch is passed through `stagger_batch` to produce the
+    /// returned batches.
+    fn make_input_batches(&mut self) -> Vec<RecordBatch> {
+        let batch = RecordBatch::try_from_iter(vec![
+            ("a", self.gen_data::<i32>()),
+            ("b", self.gen_data::<i64>()),
+        ])
+        .unwrap();
+
+        stagger_batch(batch)
+    }
+
+    /// Creates a StringArray or LargeStringArray with random strings according
+    /// to the parameters of the BatchGenerator
+    fn gen_data<O: OffsetSizeTrait>(&mut self) -> ArrayRef {
+        // table of strings from which to draw
+        let distinct_strings: GenericStringArray<O> = (0..self.num_distinct_strings)
+            .map(|_| Some(random_string(&mut self.rng, self.max_len)))
+            .collect();
+
+        // pick num_strings randomly from the distinct string table, emitting
+        // a null instead with probability null_pct
+        let indices: UInt32Array = (0..self.num_strings)
+            .map(|_| {
+                if self.rng.gen::<f64>() < self.null_pct {
+                    None
+                } else if self.num_distinct_strings > 1 {
+                    // start at 0 so every table entry (including the first)
+                    // can be selected
+                    let range = 0..(self.num_distinct_strings as u32);
+                    Some(self.rng.gen_range(range))
+                } else {
+                    Some(0)
+                }
+            })
+            .collect();
+
+        let options = None;
+        arrow::compute::take(&distinct_strings, &indices, options).unwrap()
+    }
+}
+
+/// Return a string of random characters: empty when `max_len` is 0,
+/// otherwise with a random length in 1..=max_len characters (not bytes)
+fn random_string(rng: &mut StdRng, max_len: usize) -> String {
+    // pick characters at random (not just ascii)
+    match max_len {
+        0 => "".to_string(),
+        1 => String::from(rng.gen::<char>()),
+        _ => {
+            let len = rng.gen_range(1..=max_len);
+            rng.sample_iter::<char, _>(rand::distributions::Standard)
+                .take(len)
+                .collect::<String>()
+        }
+    }
+}
diff --git a/datafusion/core/tests/fuzz_cases/mod.rs b/datafusion/core/tests/fuzz_cases/mod.rs
index 83ec928ae2..69241571b4 100644
--- a/datafusion/core/tests/fuzz_cases/mod.rs
+++ b/datafusion/core/tests/fuzz_cases/mod.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 mod aggregate_fuzz;
+mod distinct_count_string_fuzz;
 mod join_fuzz;
 mod merge_fuzz;
 mod sort_fuzz;
diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml
index d237c68657..61eba042f9 100644
--- a/datafusion/physical-expr/Cargo.toml
+++ b/datafusion/physical-expr/Cargo.toml
@@ -54,6 +54,7 @@ blake2 = { version = "^0.10.2", optional = true }
 blake3 = { version = "1.0", optional = true }
 chrono = { workspace = true }
 datafusion-common = { workspace = true }
+datafusion-execution = { workspace = true }
 datafusion-expr = { workspace = true }
 half = { version = "2.1", default-features = false }
 hashbrown = { version = "0.14", features = ["raw"] }
diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
similarity index 98%
rename from datafusion/physical-expr/src/aggregate/count_distinct.rs
rename to datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
index ef1a248d5f..891ef85880 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
@@ -15,34 +15,37 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow::datatypes::{DataType, Field, TimeUnit};
-use arrow_array::types::{
-    ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
-    Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, 
Int8Type,
-    Time32MillisecondType, Time32SecondType, Time64MicrosecondType, 
Time64NanosecondType,
-    TimestampMicrosecondType, TimestampMillisecondType, 
TimestampNanosecondType,
-    TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
-};
-use arrow_array::PrimitiveArray;
+mod strings;
 
 use std::any::Any;
 use std::cmp::Eq;
+use std::collections::HashSet;
 use std::fmt::Debug;
 use std::hash::Hash;
 use std::sync::Arc;
 
 use ahash::RandomState;
 use arrow::array::{Array, ArrayRef};
-use std::collections::HashSet;
+use arrow::datatypes::{DataType, Field, TimeUnit};
+use arrow_array::types::{
+    ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
+    Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, 
Int8Type,
+    Time32MillisecondType, Time32SecondType, Time64MicrosecondType, 
Time64NanosecondType,
+    TimestampMicrosecondType, TimestampMillisecondType, 
TimestampNanosecondType,
+    TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
+};
+use arrow_array::PrimitiveArray;
 
-use crate::aggregate::utils::{down_cast_any_ref, Hashable};
-use crate::expressions::format_state_name;
-use crate::{AggregateExpr, PhysicalExpr};
 use datafusion_common::cast::{as_list_array, as_primitive_array};
 use datafusion_common::utils::array_into_list_array;
 use datafusion_common::{Result, ScalarValue};
 use datafusion_expr::Accumulator;
 
+use crate::aggregate::count_distinct::strings::StringDistinctCountAccumulator;
+use crate::aggregate::utils::{down_cast_any_ref, Hashable};
+use crate::expressions::format_state_name;
+use crate::{AggregateExpr, PhysicalExpr};
+
 type DistinctScalarValues = ScalarValue;
 
 /// Expression for a COUNT(DISTINCT) aggregation.
@@ -61,10 +64,10 @@ impl DistinctCount {
     pub fn new(
         input_data_type: DataType,
         expr: Arc<dyn PhysicalExpr>,
-        name: String,
+        name: impl Into<String>,
     ) -> Self {
         Self {
-            name,
+            name: name.into(),
             state_data_type: input_data_type,
             expr,
         }
@@ -152,6 +155,9 @@ impl AggregateExpr for DistinctCount {
             Float32 => float_distinct_count_accumulator!(Float32Type),
             Float64 => float_distinct_count_accumulator!(Float64Type),
 
+            Utf8 => Ok(Box::new(StringDistinctCountAccumulator::<i32>::new())),
+            LargeUtf8 => 
Ok(Box::new(StringDistinctCountAccumulator::<i64>::new())),
+
             _ => Ok(Box::new(DistinctCountAccumulator {
                 values: HashSet::default(),
                 state_data_type: self.state_data_type.clone(),
@@ -244,7 +250,7 @@ impl Accumulator for DistinctCountAccumulator {
         assert_eq!(states.len(), 1, "array_agg states must be singleton!");
         let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&states[0])?;
         for scalars in scalar_vec.into_iter() {
-            self.values.extend(scalars)
+            self.values.extend(scalars);
         }
         Ok(())
     }
@@ -440,9 +446,6 @@ where
 
 #[cfg(test)]
 mod tests {
-    use crate::expressions::NoOp;
-
-    use super::*;
     use arrow::array::{
         ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, 
Int32Array,
         Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, 
UInt8Array,
@@ -454,10 +457,15 @@ mod tests {
     };
     use arrow_array::Decimal256Array;
     use arrow_buffer::i256;
+
     use datafusion_common::cast::{as_boolean_array, as_list_array, 
as_primitive_array};
     use datafusion_common::internal_err;
     use datafusion_common::DataFusionError;
 
+    use crate::expressions::NoOp;
+
+    use super::*;
+
     macro_rules! state_to_vec_primitive {
         ($LIST:expr, $DATA_TYPE:ident) => {{
             let arr = ScalarValue::raw_data($LIST).unwrap();
diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/strings.rs b/datafusion/physical-expr/src/aggregate/count_distinct/strings.rs
new file mode 100644
index 0000000000..d7a9ea5c37
--- /dev/null
+++ b/datafusion/physical-expr/src/aggregate/count_distinct/strings.rs
@@ -0,0 +1,490 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Specialized implementation of `COUNT DISTINCT` for `StringArray` and `LargeStringArray`
+
+use ahash::RandomState;
+use arrow_array::cast::AsArray;
+use arrow_array::{Array, ArrayRef, GenericStringArray, OffsetSizeTrait};
+use arrow_buffer::{BufferBuilder, OffsetBuffer, ScalarBuffer};
+use datafusion_common::cast::as_list_array;
+use datafusion_common::hash_utils::create_hashes;
+use datafusion_common::utils::array_into_list_array;
+use datafusion_common::ScalarValue;
+use datafusion_execution::memory_pool::proxy::RawTableAllocExt;
+use datafusion_expr::Accumulator;
+use std::fmt::Debug;
+use std::mem;
+use std::ops::Range;
+use std::sync::Arc;
+
+/// `COUNT(DISTINCT ..)` accumulator for string input: `O = i32` handles
+/// `Utf8` (StringArray) and `O = i64` handles `LargeUtf8` (LargeStringArray).
+/// The distinct values are kept in a `SSOStringHashSet`.
+#[derive(Debug)]
+pub(super) struct StringDistinctCountAccumulator<O: OffsetSizeTrait>(SSOStringHashSet<O>);
+impl<O: OffsetSizeTrait> StringDistinctCountAccumulator<O> {
+    /// Create an accumulator that has not yet seen any values
+    pub(super) fn new() -> Self {
+        Self(SSOStringHashSet::<O>::new())
+    }
+}
+
+impl<O: OffsetSizeTrait> Accumulator for StringDistinctCountAccumulator<O> {
+    /// Returns the distinct strings seen so far as a single
+    /// `ScalarValue::List`. The data is moved out of the accumulator, which
+    /// is left holding an empty set.
+    fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
+        // take the state out of the string set and replace with default
+        let set = std::mem::take(&mut self.0);
+        let arr = set.into_state();
+        let list = Arc::new(array_into_list_array(arr));
+        Ok(vec![ScalarValue::List(list)])
+    }
+
+    /// Adds the non-null strings of `values[0]` to the set of distinct values
+    fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> {
+        if values.is_empty() {
+            return Ok(());
+        }
+
+        self.0.insert(&values[0]);
+
+        Ok(())
+    }
+
+    /// Merges intermediate states produced by `state`: `states[0]` is a list
+    /// array whose non-null elements are string arrays of distinct values
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> {
+        if states.is_empty() {
+            return Ok(());
+        }
+        assert_eq!(
+            states.len(),
+            1,
+            "count_distinct states must be single array"
+        );
+
+        let arr = as_list_array(&states[0])?;
+        // null list entries carry no values, so they are skipped
+        for list in arr.iter().flatten() {
+            self.0.insert(&list);
+        }
+        Ok(())
+    }
+
+    /// Returns the number of distinct non-null strings seen so far
+    fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
+        Ok(ScalarValue::Int64(Some(self.0.len() as i64)))
+    }
+
+    /// Size of the accumulator itself plus the memory used by its
+    /// `SSOStringHashSet`
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self) + self.0.size()
+    }
+}
+
+/// Maximum size, in bytes, of a string that can be inlined in the hash table
+/// (the width of a `usize`)
+const SHORT_STRING_LEN: usize = mem::size_of::<usize>();
+
+/// Entry that is stored in a `SSOStringHashSet` that represents a string
+/// that is either stored inline or in the buffer
+///
+/// This helps the case where there are many short (at most
+/// `SHORT_STRING_LEN` bytes) strings that are the same (e.g. "MA", "CA",
+/// "NY", "TX", etc)
+///
+/// ```text
+///                                                                ┌──────────────────┐
+///                                                                │...               │
+///                                                                │TheQuickBrownFox  │
+///                                                  ─ ─ ─ ─ ─ ─ ─▶│...               │
+///                                                 │              │                  │
+///                                                                └──────────────────┘
+///                                                 │               buffer of u8
+///
+///                                                 │
+///                        ┌────────────────┬───────────────┬───────────────┐
+///  Storing               │                │ starting byte │  length, in   │
+///  "TheQuickBrownFox"    │   hash value   │   offset in   │  bytes (not   │
+///  (long string)         │                │    buffer     │  characters)  │
+///                        └────────────────┴───────────────┴───────────────┘
+///                              8 bytes          8 bytes       4 or 8
+///
+///
+///                         ┌───────────────┬─┬─┬─┬─┬─┬─┬─┬─┬───────────────┐
+/// Storing "foobar"        │               │ │ │ │ │ │ │ │ │  length, in   │
+/// (short string)          │  hash value   │?│?│f│o│o│b│a│r│  bytes (not   │
+///                         │               │ │ │ │ │ │ │ │ │  characters)  │
+///                         └───────────────┴─┴─┴─┴─┴─┴─┴─┴─┴───────────────┘
+///                              8 bytes         8 bytes        4 or 8
+/// ```
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
+struct SSOStringHeader {
+    /// hash of the string value (stored to avoid recomputing it in hash table
+    /// check)
+    hash: u64,
+    /// if len <= SHORT_STRING_LEN: the string bytes packed into this integer
+    /// if len > SHORT_STRING_LEN: the offset in the buffer where the data starts
+    offset_or_inline: usize,
+    /// length of the string, in bytes
+    len: usize,
+}
+
+impl SSOStringHeader {
+    /// returns self.offset_or_inline..self.offset_or_inline + self.len, the
+    /// byte range of the string within the buffer
+    ///
+    /// Only meaningful for long strings (len > SHORT_STRING_LEN); for short
+    /// strings `offset_or_inline` holds the inlined bytes, not an offset.
+    fn range(&self) -> Range<usize> {
+        self.offset_or_inline..self.offset_or_inline + self.len
+    }
+}
+
+/// HashSet optimized for storing `String` and `LargeString` values
+/// and producing the final set as a GenericStringArray with minimal copies.
+///
+/// Equivalent to `HashSet<String>` but with better performance for arrow data.
+struct SSOStringHashSet<O> {
+    /// Underlying hash set for each distinct string
+    map: hashbrown::raw::RawTable<SSOStringHeader>,
+    /// Total allocated size of `map`, in bytes, as tracked by `insert_accounted`
+    map_size: usize,
+    /// In progress arrow `Buffer` containing all string values
+    buffer: BufferBuilder<u8>,
+    /// Offsets into `buffer` for each distinct string value. These offsets
+    /// are used directly to create the final `GenericStringArray`
+    offsets: Vec<O>,
+    /// random state used to generate hashes
+    random_state: RandomState,
+    /// buffer that stores hash values (reused across batches to save allocations)
+    hashes_buffer: Vec<u64>,
+}
+
+impl<O: OffsetSizeTrait> Default for SSOStringHashSet<O> {
+    /// Same as `SSOStringHashSet::new`: a set with no strings in it
+    fn default() -> Self {
+        SSOStringHashSet::<O>::new()
+    }
+}
+
+impl<O: OffsetSizeTrait> SSOStringHashSet<O> {
+    /// Create a set containing no strings
+    fn new() -> Self {
+        Self {
+            map: hashbrown::raw::RawTable::new(),
+            map_size: 0,
+            buffer: BufferBuilder::new(0),
+            offsets: vec![O::default()], // first offset is always 0
+            random_state: RandomState::new(),
+            hashes_buffer: vec![],
+        }
+    }
+
+    /// Inserts each non-null string of `values` into the set, if it is not
+    /// already present.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `values` is not a `GenericStringArray<O>`, or if the total
+    /// number of bytes of distinct strings does not fit in the offset type `O`.
+    fn insert(&mut self, values: &ArrayRef) {
+        // step 1: compute hashes for the strings
+        let batch_hashes = &mut self.hashes_buffer;
+        batch_hashes.clear();
+        batch_hashes.resize(values.len(), 0);
+        create_hashes(&[values.clone()], &self.random_state, batch_hashes)
+            // hash is supported for all string types and create_hashes only
+            // returns errors for unsupported types
+            .unwrap();
+
+        // step 2: insert each string into the set, if not already present
+        let values = values.as_string::<O>();
+
+        // Ensure lengths are equivalent (to guard unsafe values calls below)
+        assert_eq!(values.len(), batch_hashes.len());
+
+        for (value, &hash) in values.iter().zip(batch_hashes.iter()) {
+            // count distinct ignores nulls
+            let Some(value) = value else {
+                continue;
+            };
+
+            // from here on only use bytes (not str/chars) for value
+            let value = value.as_bytes();
+
+            // value is a "small" string
+            if value.len() <= SHORT_STRING_LEN {
+                // pack the bytes into a usize; for a fixed length this packing
+                // is unique, so comparing (len, inline) compares the bytes
+                let inline = value.iter().fold(0usize, |acc, &x| acc << 8 | x as usize);
+
+                // is value already present in the set?
+                let entry = self.map.get(hash, |header| {
+                    // compare value if hashes match
+                    if header.len != value.len() {
+                        return false;
+                    }
+                    // value is stored inline so no need to consult buffer
+                    // (this is the "small string optimization")
+                    inline == header.offset_or_inline
+                });
+
+                // if no existing entry, make a new one
+                if entry.is_none() {
+                    // Put the small value into buffer and offsets so it
+                    // appears in the output array, but store the actual bytes
+                    // inline for comparison
+                    self.buffer.append_slice(value);
+                    self.offsets.push(O::from_usize(self.buffer.len()).unwrap());
+                    let new_header = SSOStringHeader {
+                        hash,
+                        len: value.len(),
+                        offset_or_inline: inline,
+                    };
+                    self.map.insert_accounted(
+                        new_header,
+                        |header| header.hash,
+                        &mut self.map_size,
+                    );
+                }
+            }
+            // value is not a "small" string
+            else {
+                // Check if the value is already present in the set
+                let entry = self.map.get(hash, |header| {
+                    // compare value if hashes match
+                    if header.len != value.len() {
+                        return false;
+                    }
+                    // Need to compare the bytes in the buffer
+                    // SAFETY: buffer is only appended to, and we correctly inserted values and offsets
+                    let existing_value =
+                        unsafe { self.buffer.as_slice().get_unchecked(header.range()) };
+                    value == existing_value
+                });
+
+                // if no existing entry, make a new one
+                if entry.is_none() {
+                    // Put the large value into buffer and offsets so it
+                    // appears in the output array, and record its starting
+                    // offset so the bytes can be compared if needed
+                    let offset = self.buffer.len(); // offset of start of data
+                    self.buffer.append_slice(value);
+                    self.offsets.push(O::from_usize(self.buffer.len()).unwrap());
+
+                    let new_header = SSOStringHeader {
+                        hash,
+                        len: value.len(),
+                        offset_or_inline: offset,
+                    };
+                    self.map.insert_accounted(
+                        new_header,
+                        |header| header.hash,
+                        &mut self.map_size,
+                    );
+                }
+            }
+        }
+    }
+
+    /// Converts this set into a `StringArray` or `LargeStringArray` with each
+    /// distinct string value without any copies
+    fn into_state(self) -> ArrayRef {
+        let Self {
+            map: _,
+            map_size: _,
+            offsets,
+            mut buffer,
+            random_state: _,
+            hashes_buffer: _,
+        } = self;
+
+        let offsets: ScalarBuffer<O> = offsets.into();
+        let values = buffer.finish();
+        let nulls = None; // count distinct ignores nulls so intermediate state never has nulls
+
+        // SAFETY: all the values that went in were valid utf8 so are all the values that come out
+        let array = unsafe {
+            GenericStringArray::new_unchecked(OffsetBuffer::new(offsets), values, nulls)
+        };
+        Arc::new(array)
+    }
+
+    /// Number of distinct non-null strings in the set
+    fn len(&self) -> usize {
+        self.map.len()
+    }
+
+    /// Return the total size, in bytes, of memory used to store the data in
+    /// this set, not including `self`
+    fn size(&self) -> usize {
+        self.map_size
+            + self.buffer.capacity() * std::mem::size_of::<u8>()
+            + self.offsets.capacity() * std::mem::size_of::<O>()
+            + self.hashes_buffer.capacity() * std::mem::size_of::<u64>()
+    }
+}
+
+impl<O: OffsetSizeTrait> Debug for SSOStringHashSet<O> {
+    /// Formats every field; the raw hash table is shown as a `<map>`
+    /// placeholder
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("SSOStringHashSet")
+            .field("map", &"<map>")
+            .field("map_size", &self.map_size)
+            .field("buffer", &self.buffer)
+            .field("offsets", &self.offsets)
+            .field("random_state", &self.random_state)
+            .field("hashes_buffer", &self.hashes_buffer)
+            .finish()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::ArrayRef;
+    use arrow_array::StringArray;
+    #[test]
+    fn string_set_empty() {
+        // an empty array and an all-null array must both produce an empty set
+        let inputs = [StringArray::new_null(0), StringArray::new_null(11)];
+        for input in inputs {
+            let array: ArrayRef = Arc::new(input);
+            let mut set = SSOStringHashSet::<i32>::new();
+            set.insert(&array);
+            assert_set(set, &[]);
+        }
+    }
+
+    #[test]
+    fn string_set_basic_i32() {
+        // StringArray (32-bit offsets)
+        test_string_set_basic::<i32>()
+    }
+    #[test]
+    fn string_set_basic_i64() {
+        // LargeStringArray (64-bit offsets)
+        test_string_set_basic::<i64>()
+    }
+    fn test_string_set_basic<O: OffsetSizeTrait>() {
+        // basic test for mixed small and large string values
+        let values = GenericStringArray::<O>::from(vec![
+            Some("a"),
+            Some("b"),
+            Some("CXCCCCCCCC"), // 10 bytes
+            Some(""),
+            Some("cbcxx"), // 5 bytes
+            None,
+            Some("AAAAAAAA"),  // 8 bytes
+            Some("BBBBBQBBB"), // 9 bytes
+            Some("a"),
+            Some("cbcxx"),
+            Some("b"),
+            Some("cbcxx"),
+            Some(""),
+            None,
+            Some("BBBBBQBBB"),
+            Some("BBBBBQBBB"),
+            Some("AAAAAAAA"),
+            Some("CXCCCCCCCC"),
+        ]);
+
+        let mut set = SSOStringHashSet::<O>::new();
+        let array: ArrayRef = Arc::new(values);
+        set.insert(&array);
+        assert_set(
+            set,
+            &[
+                Some(""),
+                Some("AAAAAAAA"),
+                Some("BBBBBQBBB"),
+                Some("CXCCCCCCCC"),
+                Some("a"),
+                Some("b"),
+                Some("cbcxx"),
+            ],
+        );
+    }
+
+    #[test]
+    fn string_set_non_utf8_32() {
+        test_string_set_non_utf8::<i32>();
+    }
+    #[test]
+    fn string_set_non_utf8_64() {
+        test_string_set_non_utf8::<i64>();
+    }
+    fn test_string_set_non_utf8<O: OffsetSizeTrait>() {
+        // basic test for mixed small and large string values
+        let values = GenericStringArray::<O>::from(vec![
+            Some("a"),
+            Some("✨🔥"),
+            Some("🔥"),
+            Some("✨✨✨"),
+            Some("foobarbaz"),
+            Some("🔥"),
+            Some("✨🔥"),
+        ]);
+
+        let mut set = SSOStringHashSet::<O>::new();
+        let array: ArrayRef = Arc::new(values);
+        set.insert(&array);
+        assert_set(
+            set,
+            &[
+                Some("a"),
+                Some("foobarbaz"),
+                Some("✨✨✨"),
+                Some("✨🔥"),
+                Some("🔥"),
+            ],
+        );
+    }
+
+    // asserts that the set contains the expected strings
+    fn assert_set<O: OffsetSizeTrait>(
+        set: SSOStringHashSet<O>,
+        expected: &[Option<&str>],
+    ) {
+        let strings = set.into_state();
+        let strings = strings.as_string::<O>();
+        let mut state = strings.into_iter().collect::<Vec<_>>();
+        state.sort();
+        assert_eq!(state, expected);
+    }
+
+    // inserting strings into the set does not increase reported memoyr
+    #[test]
+    fn test_string_set_memory_usage() {
+        let strings1 = GenericStringArray::<i32>::from(vec![
+            Some("a"),
+            Some("b"),
+            Some("CXCCCCCCCC"), // 10 bytes
+            Some("AAAAAAAA"),   // 8 bytes
+            Some("BBBBBQBBB"),  // 9 bytes
+        ]);
+        let total_strings1_len = strings1
+            .iter()
+            .map(|s| s.map(|s| s.len()).unwrap_or(0))
+            .sum::<usize>();
+        let values1: ArrayRef = 
Arc::new(GenericStringArray::<i32>::from(strings1));
+
+        // Much larger strings in strings2
+        let strings2 = GenericStringArray::<i32>::from(vec![
+            "FOO".repeat(1000),
+            "BAR".repeat(2000),
+            "BAZ".repeat(3000),
+        ]);
+        let total_strings2_len = strings2
+            .iter()
+            .map(|s| s.map(|s| s.len()).unwrap_or(0))
+            .sum::<usize>();
+        let values2: ArrayRef = 
Arc::new(GenericStringArray::<i32>::from(strings2));
+
+        let mut set = SSOStringHashSet::<i32>::new();
+        let size_empty = set.size();
+
+        set.insert(&values1);
+        let size_after_values1 = set.size();
+        assert!(size_empty < size_after_values1);
+        assert!(
+            size_after_values1 > total_strings1_len,
+            "expect {size_after_values1} to be more than {total_strings1_len}"
+        );
+        assert!(size_after_values1 < total_strings1_len + total_strings2_len);
+
+        // inserting the same strings should not affect the size
+        set.insert(&values1);
+        assert_eq!(set.size(), size_after_values1);
+
+        // inserting the large strings should increase the reported size
+        set.insert(&values2);
+        let size_after_values2 = set.size();
+        assert!(size_after_values2 > size_after_values1);
+        assert!(size_after_values2 > total_strings1_len + total_strings2_len);
+    }
+}
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index 5cd728c434..136fb39c67 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -3069,6 +3069,64 @@ select count(*) from (select count(*) a, count(*) b from (select 1));
 ----
 1
 
+# Distinct Count for string
+# (test for the specialized implementation of distinct count for strings)
+
+# UTF8 string matters for string to &[u8] conversion, add it to prevent regression
+statement ok
+create table distinct_count_string_table as values
+    (1, 'a', 'longstringtest_a', '台灣'),
+    (2, 'b', 'longstringtest_b1', '日本'),
+    (2, 'b', 'longstringtest_b2', '中國'),
+    (3, 'c', 'longstringtest_c1', '美國'),
+    (3, 'c', 'longstringtest_c2', '歐洲'),
+    (3, 'c', 'longstringtest_c3', '韓國')
+;
+
+# run through update_batch
+query IIII
+select count(distinct column1), count(distinct column2), count(distinct column3), count(distinct column4) from distinct_count_string_table;
+----
+3 3 6 6
+
+# run through merge_batch
+query IIII rowsort
+select count(distinct column1), count(distinct column2), count(distinct column3), count(distinct column4) from distinct_count_string_table group by column1;
+----
+1 1 1 1
+1 1 2 2
+1 1 3 3
+
+
+# test with LargeUtf8 (64-bit offset) strings as well
+statement ok
+create table distinct_count_long_string_table as
+SELECT column1,
+  arrow_cast(column2, 'LargeUtf8') as column2,
+  arrow_cast(column3, 'LargeUtf8') as column3,
+  arrow_cast(column4, 'LargeUtf8') as column4
+FROM distinct_count_string_table;
+
+# run through update_batch
+query IIII
+select count(distinct column1), count(distinct column2), count(distinct column3), count(distinct column4) from distinct_count_long_string_table;
+----
+3 3 6 6
+
+# run through merge_batch
+query IIII rowsort
+select count(distinct column1), count(distinct column2), count(distinct column3), count(distinct column4) from distinct_count_long_string_table group by column1;
+----
+1 1 1 1
+1 1 2 2
+1 1 3 3
+
+statement ok
+drop table distinct_count_string_table;
+
+statement ok
+drop table distinct_count_long_string_table;
+
 # rule `aggregate_statistics` should not optimize MIN/MAX to wrong values on empty relation
 
 statement ok
@@ -3122,3 +3180,4 @@ NULL
 
 statement ok
 DROP TABLE t;
+
diff --git a/datafusion/sqllogictest/test_files/clickbench.slt b/datafusion/sqllogictest/test_files/clickbench.slt
index 21befd7822..b61bee6708 100644
--- a/datafusion/sqllogictest/test_files/clickbench.slt
+++ b/datafusion/sqllogictest/test_files/clickbench.slt
@@ -273,3 +273,6 @@ SELECT "WindowClientWidth", "WindowClientHeight", COUNT(*) AS PageViews FROM hit
 query PI
 SELECT DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) AS M, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-14' AND "EventDate"::INT::DATE <= '2013-07-15' AND "IsRefresh" = 0 AND "DontCountHits" = 0 GROUP BY DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) ORDER BY DATE_TRUNC('minute', M) LIMIT 10 OFFSET 1000;
 ----
+
+statement ok
+drop table hits;


Reply via email to