Rich-T-kid commented on code in PR #22983: URL: https://github.com/apache/datafusion/pull/22983#discussion_r3429146634
########## datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dict.rs: ########## @@ -0,0 +1,894 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::mem::size_of; +use std::sync::Arc; + +use arrow::array::{ + Array, ArrayRef, AsArray, DictionaryArray, Int8Array, Int16Array, Int32Array, + Int64Array, ListBuilder, NullArray, StringBuilder, UInt8Array, UInt16Array, + UInt32Array, UInt64Array, +}; +use arrow::datatypes::{ + ArrowNativeType, DataType, Int8Type, Int16Type, Int32Type, Int64Type, SchemaRef, + UInt8Type, UInt16Type, UInt32Type, UInt64Type, +}; +use arrow::downcast_dictionary_array; +use datafusion_common::hash_utils::{RandomState, combine_hashes, create_hashes}; +use datafusion_common::{Result, internal_datafusion_err}; +use datafusion_execution::memory_pool::proxy::HashTableAllocExt; +use datafusion_expr::EmitTo; +use hashbrown::hash_table::HashTable; + +use crate::aggregates::group_values::GroupValues; + +/// Caches the hashes for one dictionary column's values array. +/// Rebuilt only when the `Arc` pointer changes (i.e. a new values array arrives). +struct ColumnCache { + /// Keeps the values `Arc` alive and is compared with `Arc::ptr_eq` to detect staleness. 
+ values: ArrayRef, + /// `value_hashes[k]` = hash of the value at dictionary index `k`. + value_hashes: Vec<u64>, +} + +impl ColumnCache { + fn empty() -> Self { + Self { + values: Arc::new(NullArray::new(0)), + value_hashes: vec![], + } + } + + fn update(&mut self, new_values: ArrayRef, random_state: &RandomState) -> Result<()> { + if Arc::ptr_eq(&new_values, &self.values) { + return Ok(()); + } + let num_values = new_values.len(); + // Reuse the allocation; only grows capacity when a larger values array arrives. + self.value_hashes.clear(); + self.value_hashes.resize(num_values, 0u64); + create_hashes(&[new_values.clone()], random_state, &mut self.value_hashes)?; + self.values = new_values; + Ok(()) + } + + fn size(&self) -> usize { + self.value_hashes.len() * size_of::<u64>() + } + + fn clear_shrink(&mut self, shrink_to: usize) { + self.values = Arc::new(NullArray::new(0)); + self.value_hashes.clear(); + self.value_hashes.shrink_to(shrink_to); + } +} + +/// [`GroupValues`] for GROUP BY over **two or more** dictionary-typed columns. +pub struct GroupDictionaryColumn { + schema: SchemaRef, + col_caches: Vec<ColumnCache>, + /// `(row_hash, group_id)`. Multiple entries may share the same hash value; + /// byte-level comparison is used to resolve collisions. + map: HashTable<(u64, usize)>, + /// Tracked allocation size of `map` in bytes, updated on every insert and shrink. + map_size: usize, + /// All group rows packed back-to-back into a single contiguous buffer. + /// + /// CSR-style layout: `row_offsets[g]` is the start of group `g` and + /// `row_offsets[g+1]` is its end. The last group has no `g+1` entry; its + /// end is `row_buffer.len()`. + row_buffer: Vec<u8>, + /// `row_offsets[g]` = start byte of group `g` inside `row_buffer`. + row_offsets: Vec<usize>, + /// Reused scratch buffer for encoding the current row. 
+ row_scratch: Vec<u8>, + row_decoder: RowSetDecoder, + random_state: RandomState, +} + +/// Returns `true` when every field in `schema` is `DataType::Dictionary`. +pub fn all_dictionary_schema(schema: &arrow::datatypes::Schema) -> bool { + schema + .fields() + .iter() + .all(|field| matches!(field.data_type(), DataType::Dictionary(_, _))) +} + +fn is_supported_value_type(data_type: &DataType) -> bool { + matches!(data_type, DataType::Utf8) + || matches!(data_type, DataType::List(f) if f.data_type() == &DataType::Utf8) +} + +impl GroupDictionaryColumn { + pub fn new(schema: SchemaRef) -> Result<Self> { + if schema.fields().len() < 2 { + return Err(internal_datafusion_err!( + "GroupDictionaryColumn requires at least 2 columns, got {}", + schema.fields().len() + )); + } + for field in schema.fields() { + match field.data_type() { + DataType::Dictionary(_, value_type) => { + if !is_supported_value_type(value_type) { + return Err(internal_datafusion_err!( + "GroupDictionaryColumn: unsupported dictionary value type \ + '{}' in column '{}'", + value_type, + field.name() + )); + } + } + _ => { + return Err(internal_datafusion_err!( + "GroupDictionaryColumn requires all columns to be Dictionary, \ + but '{}' has type {}", + field.name(), + field.data_type() + )); + } + } + } + let n_cols = schema.fields().len(); + let row_decoder = RowSetDecoder::new(Arc::clone(&schema)); + Ok(Self { + schema, + col_caches: (0..n_cols).map(|_| ColumnCache::empty()).collect(), + map: HashTable::with_capacity(128), + map_size: 0, + row_buffer: Vec::new(), + row_offsets: Vec::new(), + row_scratch: Vec::new(), + row_decoder, + random_state: crate::aggregates::AGGREGATION_HASH_SEED, + }) + } +} + +fn dict_values_array(col: &dyn Array) -> ArrayRef { + downcast_dictionary_array!( + col => col.values().clone(), + _ => unreachable!("schema validated in GroupDictionaryColumn::new") + ) +} + +// Box is required: different key widths (Int8/Int16/Int32/Int64) produce different concrete iterator types. 
+fn fill_keys(col: &dyn Array) -> Box<dyn Iterator<Item = Option<usize>> + '_> { + downcast_dictionary_array!( + col => { + let keys = col.keys(); + Box::new((0..keys.len()).map(move |row_idx| { + if keys.is_valid(row_idx) { + Some(keys.value(row_idx).as_usize()) + } else { + None + } + })) + }, + _ => unreachable!("schema validated in GroupDictionaryColumn::new") + ) +} + +impl GroupValues for GroupDictionaryColumn { + fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> { + debug_assert_eq!(cols.len(), self.schema.fields().len()); + groups.clear(); + + if cols.is_empty() || cols[0].is_empty() { + return Ok(()); + } + let n_rows = cols[0].len(); + + for (col_idx, col) in cols.iter().enumerate() { + self.col_caches[col_idx] + .update(dict_values_array(col.as_ref()), &self.random_state)?; + } + + // Downcast once per column; advance with .next() per row to avoid per-row downcast. + let mut key_iters: Vec<_> = + cols.iter().map(|col| fill_keys(col.as_ref())).collect(); + + groups.reserve(n_rows); + + for _row in 0..n_rows { + let mut hash = 0u64; + self.row_scratch.clear(); + + for (col_idx, key_iter) in key_iters.iter_mut().enumerate() { + let key = key_iter.next().unwrap(); + let cache = &self.col_caches[col_idx]; + let value_hash = key.map_or(0, |key_idx| cache.value_hashes[key_idx]); + hash = combine_hashes(hash, value_hash); + encode_value(key, cache.values.as_ref(), &mut self.row_scratch); + } + + let combined_hash = hash; + let found = { + let row_scratch = self.row_scratch.as_slice(); + let row_buffer = self.row_buffer.as_slice(); + let row_offsets = self.row_offsets.as_slice(); + self.map + .find(combined_hash, |&(stored_hash, group_id)| { + stored_hash == combined_hash && { + let end = row_offsets + .get(group_id + 1) + .copied() + .unwrap_or(row_buffer.len()); // last group has no g+1 entry + row_buffer[row_offsets[group_id]..end] == *row_scratch + } + }) + .map(|&(_, group_id)| group_id) + }; + + let group_id = match found { + 
Some(existing_id) => existing_id, + None => { + let new_id = self.row_offsets.len(); + self.row_offsets.push(self.row_buffer.len()); + self.row_buffer.extend_from_slice(&self.row_scratch); + self.map.insert_accounted( + (combined_hash, new_id), + |(stored_hash, _)| *stored_hash, + &mut self.map_size, + ); + new_id + } + }; + + groups.push(group_id); + } + + Ok(()) + } + + fn size(&self) -> usize { + let cache_bytes: usize = self.col_caches.iter().map(|c| c.size()).sum(); + self.map_size + + self.row_buffer.len() + + self.row_offsets.len() * size_of::<usize>() + + self.row_scratch.capacity() + + cache_bytes + } + + fn is_empty(&self) -> bool { + self.row_offsets.is_empty() + } + + fn len(&self) -> usize { + self.row_offsets.len() + } + + fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> { + let n_total = self.row_offsets.len(); + if n_total == 0 { + return Ok(self.row_decoder.finish()); + } + let n_emit = match emit_to { + EmitTo::All => n_total, + EmitTo::First(n) => n.min(n_total), + }; + + for row_idx in 0..n_emit { + let start = self.row_offsets[row_idx]; + let end = self + .row_offsets + .get(row_idx + 1) + .copied() + .unwrap_or(self.row_buffer.len()); + self.row_decoder.decode(&self.row_buffer[start..end]); + } + let inner = self.row_decoder.finish(); + let arrays: Vec<ArrayRef> = inner + .into_iter() + .zip(self.schema.fields()) + .map(|(values, field)| match field.data_type() { + DataType::Dictionary(key_type, _) => wrap_as_dictionary( + values, + make_sequential_keys(n_emit, key_type), + key_type, + ), + _ => unreachable!("schema validated in GroupDictionaryColumn::new"), + }) + .collect(); + + if n_emit == n_total { + self.row_buffer.clear(); + self.row_offsets.clear(); + self.map.clear(); + self.map_size = 0; + } else { + let retain_start = self.row_offsets[n_emit]; + self.row_offsets.drain(0..n_emit); + for offset in &mut self.row_offsets { + *offset -= retain_start; + } + self.row_buffer.drain(0..retain_start); + // avoiding this somehow 
would be nice. worse case this runs once + // VecDeque? + // Shift remaining group ids in-place; retain gives &mut access so no rehashing occurs. + self.map.retain(|(_, gid)| { + if *gid < n_emit { + return false; + } + *gid -= n_emit; + true + }); + } + + Ok(arrays) + } + + fn clear_shrink(&mut self, num_rows: usize) { + self.map.clear(); + self.map.shrink_to(num_rows, |_| 0); + self.map_size = self.map.capacity() * size_of::<(u64, usize)>(); + self.row_buffer.clear(); + self.row_offsets.clear(); + self.row_offsets.shrink_to(num_rows); + for cache in &mut self.col_caches { + cache.clear_shrink(num_rows); + } + } +} + +// ── encoding / decoding ─────────────────────────────────────────────────────── + +/// Wire format per column: +/// null: `[0x00]` +/// non-null scalar: `[0x01][len: u32 LE][utf8_bytes…]` +/// non-null list: `[0x01][content_len: u32 LE][n: u32 LE][elem…]` +/// where each elem is `[0x00]` (null) or `[0x01][len: u32 LE][utf8_bytes…]` +fn encode_value(key: Option<usize>, values: &dyn Array, buf: &mut Vec<u8>) { + let key_idx = match key { + None => { + buf.push(0); + return; + } + Some(k) => k, + }; + if values.is_null(key_idx) { + buf.push(0); + return; + } + buf.push(1); + match values.data_type() { + DataType::Utf8 => { + let bytes = values.as_string::<i32>().value(key_idx).as_bytes(); Review Comment: This cast isn't needed. `.value().as_bytes()` should be fine, similar to #/21765 -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
