Dandandan commented on code in PR #16136: URL: https://github.com/apache/datafusion/pull/16136#discussion_r2100975676
########## datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/mod.rs: ########## @@ -74,21 +77,21 @@ macro_rules! hash_float { hash_float!(f16, f32, f64); -/// A [`GroupValues`] storing a single column of primitive values +/// A [`GroupValues`] storing a single column of normal primitive values (bits <= 64) /// /// This specialization is significantly faster than using the more general /// purpose `Row`s format pub struct GroupValuesPrimitive<T: ArrowPrimitiveType> { /// The data type of the output array data_type: DataType, - /// Stores the `(group_index, hash)` based on the hash of its value + /// Stores the `(group_index, group_value)` /// - /// We also store `hash` is for reducing cost of rehashing. Such cost - /// is obvious in high cardinality group by situation. + /// We directly store copy of `group_value` for not only efficient + /// rehashing, but also efficient probing. /// More details can see: - /// <https://github.com/apache/datafusion/issues/15961> + /// <https://github.com/apache/datafusion/issues/16136> /// - map: HashTable<(usize, u64)>, + map: HashTable<(usize, T::Native)>, Review Comment: Something like: ```diff diff --git i/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/mod.rs w/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/mod.rs index 693cc997fa..600287486d 100644 --- i/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/mod.rs +++ w/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/mod.rs @@ -25,7 +25,6 @@ use arrow::array::{ use arrow::datatypes::{i256, DataType}; use arrow::record_batch::RecordBatch; use datafusion_common::Result; -use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_expr::EmitTo; use half::f16; use hashbrown::hash_table::HashTable; @@ -33,7 +32,6 @@ use std::mem::size_of; use std::sync::Arc; mod large_primitive; -pub use 
large_primitive::GroupValuesLargePrimitive; /// A trait to allow hashing of floating point numbers pub(crate) trait HashValue { @@ -94,8 +92,6 @@ pub struct GroupValuesPrimitive<T: ArrowPrimitiveType> { map: HashTable<(usize, T::Native)>, /// The group index of the null value if any null_group: Option<usize>, - /// The values for each group index - values: Vec<T::Native>, /// The random state used to generate hashes random_state: RandomState, } @@ -106,7 +102,6 @@ impl<T: ArrowPrimitiveType> GroupValuesPrimitive<T> { Self { data_type, map: HashTable::with_capacity(128), - values: Vec::with_capacity(128), null_group: None, random_state: Default::default(), } @@ -124,13 +119,14 @@ where for v in cols[0].as_primitive::<T>() { let group_id = match v { None => *self.null_group.get_or_insert_with(|| { - let group_id = self.values.len(); - self.values.push(Default::default()); + let group_id = self.map.len(); group_id }), Some(key) => { let state = &self.random_state; let hash = key.hash(state); + let group_id = self.map.len(); + let insert = self.map.entry( hash, |&(_, v)| v.is_eq(key), @@ -140,10 +136,8 @@ where match insert { hashbrown::hash_table::Entry::Occupied(o) => o.get().0, hashbrown::hash_table::Entry::Vacant(v) => { - let g = self.values.len(); - v.insert((g, key)); - self.values.push(key); - g + v.insert((group_id, key)); + group_id } } } @@ -155,21 +149,19 @@ where fn size(&self) -> usize { self.map.capacity() * size_of::<(usize, T::Native)>() - + self.values.allocated_size() } fn is_empty(&self) -> bool { - self.values.is_empty() + self.map.is_empty() } fn len(&self) -> usize { - self.values.len() + self.map.len() } fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> { - emit_internal::<T, T::Native>( + emit_internal::<T>( emit_to, - &mut self.values, &mut self.null_group, &mut self.map, self.data_type.clone(), @@ -178,22 +170,19 @@ where fn clear_shrink(&mut self, batch: &RecordBatch) { let count = batch.num_rows(); - self.values.clear(); - 
self.values.shrink_to(count); self.map.clear(); self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared } } -pub(crate) fn emit_internal<T: ArrowPrimitiveType, K>( +pub(crate) fn emit_internal<T: ArrowPrimitiveType>( emit_to: EmitTo, - values: &mut Vec<T::Native>, null_group: &mut Option<usize>, - map: &mut HashTable<(usize, K)>, + map: &mut HashTable<(usize, T::Native)>, data_type: DataType, ) -> Result<Vec<ArrayRef>> { fn build_primitive<T: ArrowPrimitiveType>( - values: Vec<T::Native>, + values: HashTable<(usize, T::Native)>, null_idx: Option<usize>, ) -> PrimitiveArray<T> { let nulls = null_idx.map(|null_idx| { @@ -204,39 +193,39 @@ pub(crate) fn emit_internal<T: ArrowPrimitiveType, K>( // NOTE: The inner builder must be constructed as there is at least one null buffer.finish().unwrap() }); - PrimitiveArray::<T>::new(values.into(), nulls) + PrimitiveArray::<T>::new(values.iter().map(|x|x.1).collect::<Vec<_>>().into(), nulls) } let array: PrimitiveArray<T> = match emit_to { EmitTo::All => { - map.clear(); - build_primitive(std::mem::take(values), null_group.take()) + build_primitive(std::mem::take(map), null_group.take()) } EmitTo::First(n) => { - map.retain(|entry| { - // Decrement group index by n - let group_idx = entry.0; - match group_idx.checked_sub(n) { - // Group index was >= n, shift value down - Some(sub) => { - entry.0 = sub; - true - } - // Group index was < n, so remove from table - None => false, - } - }); - let null_group = match null_group { - Some(v) if *v >= n => { - *v -= n; - None - } - Some(_) => null_group.take(), - None => None, - }; - let mut split = values.split_off(n); - std::mem::swap(values, &mut split); - build_primitive(split, null_group) + todo!(""); + // map.retain(|entry| { + // // Decrement group index by n + // let group_idx = entry.0; + // match group_idx.checked_sub(n) { + // // Group index was >= n, shift value down + // Some(sub) => { + // entry.0 = sub; + // true + // } + // // Group index was 
< n, so remove from table + // None => false, + // } + // }); + // let null_group = match null_group { + // Some(v) if *v >= n => { + // *v -= n; + // None + // } + // Some(_) => null_group.take(), + // None => None, + // }; + // let mut split = values.split_off(n); + // std::mem::swap(values, &mut split); + // build_primitive(split, null_group) } }; ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscribe@datafusion.apache.org For queries about this service, please contact Infrastructure at: users@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscribe@datafusion.apache.org For additional commands, e-mail: github-help@datafusion.apache.org