avantgardnerio commented on code in PR #7192: URL: https://github.com/apache/arrow-datafusion/pull/7192#discussion_r1305926379
########## datafusion/core/src/physical_plan/aggregates/priority_map.rs:
##########
@@ -0,0 +1,969 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! A memory-conscious aggregation implementation that limits group buckets to a fixed number
+
+use crate::physical_plan::aggregates::group_values::primitive::HashValue;
+use crate::physical_plan::aggregates::{
+    aggregate_expressions, evaluate_group_by, evaluate_many, AggregateExec,
+    PhysicalGroupBy,
+};
+use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
+use ahash::RandomState;
+use arrow::datatypes::i256;
+use arrow::util::pretty::print_batches;
+use arrow_array::builder::PrimitiveBuilder;
+use arrow_array::cast::AsArray;
+use arrow_array::{
+    downcast_primitive, Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray, RecordBatch,
+    StringArray,
+};
+use arrow_schema::{DataType, SchemaRef};
+use datafusion_common::DataFusionError;
+use datafusion_common::Result;
+use datafusion_execution::TaskContext;
+use datafusion_physical_expr::aggregate::utils::adjust_output_array;
+use datafusion_physical_expr::PhysicalExpr;
+use futures::stream::{Stream, StreamExt};
+use half::f16;
+use hashbrown::raw::RawTable;
+use itertools::Itertools;
+use log::{trace, Level};
+use std::cmp::Ordering;
+use std::fmt::{Debug, Formatter};
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+pub struct GroupedTopKAggregateStream {
+    partition: usize,
+    row_count: usize,
+    started: bool,
+    schema: SchemaRef,
+    input: SendableRecordBatchStream,
+    aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
+    group_by: PhysicalGroupBy,
+    map: Box<dyn LimitedHashTable>,
+    heap: Box<dyn LimitedHeap>,
+    limit: usize,
+}
+
+unsafe impl Send for GroupedTopKAggregateStream {}
+
+impl GroupedTopKAggregateStream {
+    pub fn new(
+        aggr: &AggregateExec,
+        context: Arc<TaskContext>,
+        partition: usize,
+        limit: usize,
+    ) -> Result<Self> {
+        let agg_schema = Arc::clone(&aggr.schema);
+        let group_by = aggr.group_by.clone();
+        let input = aggr.input.execute(partition, Arc::clone(&context))?;
+        let aggregate_arguments =
+            aggregate_expressions(&aggr.aggr_expr, &aggr.mode, group_by.expr.len())?;
+        let (val_field, desc) = aggr
+            .get_minmax_desc()
+            .ok_or_else(|| DataFusionError::Internal("Min/max required".to_string()))?;
+
+        let (expr, _) = &aggr.group_expr().expr()[0];
+        let kt = expr.data_type(&aggr.input().schema())?;
+        let map = new_map(limit, kt)?;
+
+        let vt = val_field.data_type().clone();
+        let heap = new_heap(limit, desc, vt)?;
+
+        Ok(GroupedTopKAggregateStream {
+            partition,
+            started: false,
+            row_count: 0,
+            schema: agg_schema,
+            input,
+            aggregate_arguments,
+            group_by,
+            map,
+            heap,
+            limit,
+        })
+    }
+}
+
+/// A trait to allow comparison of floating point numbers
+pub trait Comparable {
+    fn comp(&self, other: &Self) -> Ordering;
+}
+
+impl Comparable for Option<String> {
+    fn comp(&self, other: &Self) -> Ordering {
+        self.cmp(other)
+    }
+}
+
+impl HashValue for Option<String> {
+    fn hash(&self, state: &RandomState) -> u64 {
+        state.hash_one(self)
+    }
+}
+
+macro_rules! compare_float {
+    ($($t:ty),+) => {
+        $(impl Comparable for Option<$t> {
+            fn comp(&self, other: &Self) -> Ordering {
+                match (self, other) {
+                    (Some(me), Some(other)) => me.total_cmp(other),
+                    (Some(_), None) => Ordering::Greater,
+                    (None, Some(_)) => Ordering::Less,
+                    (None, None) => Ordering::Equal,
+                }
+            }
+        })+
+
+        $(impl Comparable for $t {
+            fn comp(&self, other: &Self) -> Ordering {
+                self.total_cmp(other)
+            }
+        })+
+
+        $(impl HashValue for Option<$t> {
+            fn hash(&self, state: &RandomState) -> u64 {
+                self.map(|me| me.hash(state)).unwrap_or(0)
+            }
+        })+
+    };
+}
+
+macro_rules! compare_integer {
+    ($($t:ty),+) => {
+        $(impl Comparable for Option<$t> {
+            fn comp(&self, other: &Self) -> Ordering {
+                self.cmp(other)
+            }
+        })+
+
+        $(impl Comparable for $t {
+            fn comp(&self, other: &Self) -> Ordering {
+                self.cmp(other)
+            }
+        })+
+
+        $(impl HashValue for Option<$t> {
+            fn hash(&self, state: &RandomState) -> u64 {
+                self.map(|me| me.hash(state)).unwrap_or(0)
+            }
+        })+
+    };
+}
+
+compare_integer!(i8, i16, i32, i64, i128, i256);
+compare_integer!(u8, u16, u32, u64);
+compare_float!(f16, f32, f64);
+
+pub fn new_map(limit: usize, kt: DataType) -> Result<Box<dyn LimitedHashTable>> {
+    macro_rules! downcast_helper {
+        ($kt:ty, $d:ident) => {
+            return Ok(Box::new(PrimitiveHashTable::<$kt>::new(limit)))
+        };
+    }
+
+    downcast_primitive! {
+        kt => (downcast_helper, kt),
+        DataType::Utf8 => return Ok(Box::new(StringHashTable::new(limit))),
+        _ => {} // TODO: OwnedRow, etc
+    }
+
+    Err(DataFusionError::Execution(format!(
+        "Can't create HashTable for type: {kt:?}"
+    )))
+}
+
+pub fn new_heap(limit: usize, desc: bool, vt: DataType) -> Result<Box<dyn LimitedHeap>> {
+    macro_rules! downcast_helper {
+        ($vt:ty, $d:ident) => {
+            return Ok(Box::new(PrimitiveHeap::<$vt>::new(limit, desc, vt)))
+        };
+    }
+
+    downcast_primitive! {
+        vt => (downcast_helper, vt),
+        _ => {} // TODO: OwnedRow
+    }
+
+    Err(DataFusionError::Execution(format!(
+        "Can't group type: {vt:?}"
+    )))
+}
+
+pub trait LimitedHashTable {
+    fn set_batch(&mut self, ids: ArrayRef);
+    fn len(&self) -> usize;
+    unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]);
+    unsafe fn heap_idx_at(&self, map_idx: usize) -> usize;
+    fn drain(&mut self) -> (ArrayRef, Vec<usize>);
+
+    unsafe fn find_or_insert(
+        &mut self,
+        row_idx: usize,
+        replace_idx: usize,
+        map: &mut Vec<(usize, usize)>,
+    ) -> (usize, bool);
+}
+
+pub trait LimitedHeap {
+    fn set_batch(&mut self, vals: ArrayRef);
+    fn is_worse(&self, idx: usize) -> bool;
+    fn worst_map_idx(&self) -> usize;
+    fn renumber(&mut self, heap_to_map: &[(usize, usize)]);
+    fn insert(&mut self, row_idx: usize, map_idx: usize, map: &mut Vec<(usize, usize)>);
+    fn replace_if_best(
+        &mut self,
+        heap_idx: usize,
+        row_idx: usize,
+        map: &mut Vec<(usize, usize)>,
+    );
+    fn take_all(&mut self, heap_idxs: Vec<usize>) -> ArrayRef;
+}
+
+impl RecordBatchStream for GroupedTopKAggregateStream {
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+}
+
+pub struct PrimitiveHeap<VAL: ArrowPrimitiveType>
+where
+    <VAL as ArrowPrimitiveType>::Native: Comparable,
+{
+    owned: ArrayRef,
+    heap: CustomHeap<VAL::Native>,
+    desc: bool,
+    data_type: DataType,
+}
+
+impl<VAL: ArrowPrimitiveType> PrimitiveHeap<VAL>
+where
+    <VAL as ArrowPrimitiveType>::Native: Comparable,
+{
+    pub fn new(limit: usize, desc: bool, data_type: DataType) -> Self {
+        let owned: ArrayRef = Arc::new(PrimitiveArray::<VAL>::builder(0).finish());
+        Self {
+            owned,
+            heap: CustomHeap::new(limit, desc),
+            desc,
+            data_type,
+        }
+    }
+}
+
+impl<VAL: ArrowPrimitiveType> LimitedHeap for PrimitiveHeap<VAL>
+where
+    <VAL as ArrowPrimitiveType>::Native: Comparable,
+{
+    fn set_batch(&mut self, vals: ArrayRef) {
+        self.owned = vals;
+    }
+
+    fn is_worse(&self, row_idx: usize) -> bool {
+        if !self.heap.is_full() {
+            return false;
+        }
+        let vals = self.owned.as_primitive::<VAL>();
+        let new_val = vals.value(row_idx);
+        let worst_val = self.heap.worst_val().expect("Missing root");
+        (!self.desc && new_val > *worst_val) || (self.desc && new_val < *worst_val)
+    }
+
+    fn worst_map_idx(&self) -> usize {
+        self.heap.worst_map_idx()
+    }
+
+    fn renumber(&mut self, heap_to_map: &[(usize, usize)]) {
+        self.heap.renumber(heap_to_map);
+    }
+
+    fn insert(&mut self, row_idx: usize, map_idx: usize, map: &mut Vec<(usize, usize)>) {
+        let vals = self.owned.as_primitive::<VAL>();
+        let new_val = vals.value(row_idx);
+
+        if self.heap.is_full() {
+            self.heap.replace_root(new_val, map_idx, map);
+        } else {
+            self.heap.append(new_val, map_idx, map);
+        }
+    }
+
+    fn replace_if_best(
+        &mut self,
+        heap_idx: usize,
+        row_idx: usize,
+        map: &mut Vec<(usize, usize)>,
+    ) {
+        let vals = self.owned.as_primitive::<VAL>();
+        let new_val = vals.value(row_idx);
+        self.heap.replace_if_best(heap_idx, new_val, map);
+    }
+
+    fn take_all(&mut self, heap_idxs: Vec<usize>) -> ArrayRef {
+        let vals = self.heap.take_all(heap_idxs);
+        let vals = Arc::new(PrimitiveArray::<VAL>::from_iter_values(vals));
+        adjust_output_array(&self.data_type, vals).expect("Type is incorrect")
+    }
+}
+
+pub struct StringHashTable {
+    owned: ArrayRef,
+    map: CustomMap<Option<String>>,
+    rnd: RandomState,
+}
+
+impl StringHashTable {
+    pub fn new(limit: usize) -> Self {
+        let vals: Vec<&str> = Vec::new();
+        let owned = Arc::new(StringArray::from(vals));
+        Self {
+            owned,
+            map: CustomMap::new(limit, limit * 10),
+            rnd: ahash::RandomState::default(),
+        }
+    }
+}
+
+impl LimitedHashTable for StringHashTable {
+    fn set_batch(&mut self, ids: ArrayRef) {
+        self.owned = ids;
+    }
+
+    fn len(&self) -> usize {
+        self.map.len()
+    }
+
+    unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) {
+        self.map.update_heap_idx(mapper);
+    }
+
+    unsafe fn heap_idx_at(&self, map_idx: usize) -> usize {
+        self.map.heap_idx_at(map_idx)
+    }
+
+    fn drain(&mut self) -> (ArrayRef, Vec<usize>) {
+        let (ids, heap_idxs) = self.map.drain();
+        let ids = Arc::new(StringArray::from(ids));
+        (ids, heap_idxs)
+    }
+
+    unsafe fn find_or_insert(
+        &mut self,
+        row_idx: usize,
+        replace_idx: usize,
+        mapper: &mut Vec<(usize, usize)>,
+    ) -> (usize, bool) {
+        let ids = self
+            .owned
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .expect("StringArray required");
+        let id = if ids.is_null(row_idx) {
+            None
+        } else {
+            Some(ids.value(row_idx))
+        };
+
+        let hash = self.rnd.hash_one(id);
+        if let Some(map_idx) = self
+            .map
+            .find(hash, |mi| id == mi.as_ref().map(|id| id.as_str()))
+        {
+            return (map_idx, false);
+        }
+
+        // we're full and this is a better value, so remove the worst
+        let heap_idx = self.map.remove_if_full(replace_idx);
+
+        // add the new group
+        let id = id.map(|id| id.to_string());
+        let map_idx = self.map.insert(hash, id, heap_idx, mapper);
+        (map_idx, true)
+    }
+}
+
+pub struct PrimitiveHashTable<VAL: ArrowPrimitiveType>
+where
+    Option<<VAL as ArrowPrimitiveType>::Native>: Comparable,
+{
+    owned: ArrayRef,
+    map: CustomMap<Option<VAL::Native>>,
+    rnd: RandomState,
+}
+
+impl<VAL: ArrowPrimitiveType> PrimitiveHashTable<VAL>
+where
+    Option<<VAL as ArrowPrimitiveType>::Native>: Comparable,
+    Option<<VAL as ArrowPrimitiveType>::Native>: HashValue,
+{
+    pub fn new(limit: usize) -> Self {
+        let owned = Arc::new(PrimitiveArray::<VAL>::builder(0).finish());
+        Self {
+            owned,
+            map: CustomMap::new(limit, limit * 10),
+            rnd: ahash::RandomState::default(),
+        }
+    }
+}
+
+impl<VAL: ArrowPrimitiveType> LimitedHashTable for PrimitiveHashTable<VAL>
+where
+    Option<<VAL as ArrowPrimitiveType>::Native>: Comparable,
+    Option<<VAL as ArrowPrimitiveType>::Native>: HashValue,
+{
+    fn set_batch(&mut self, ids: ArrayRef) {
+        self.owned = ids;
+    }
+
+    fn len(&self) -> usize {
+        self.map.len()
+    }
+
+    unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) {
+        self.map.update_heap_idx(mapper);
+    }
+
+    unsafe fn heap_idx_at(&self, map_idx: usize) -> usize {
+        self.map.heap_idx_at(map_idx)
+    }
+
+    fn drain(&mut self) -> (ArrayRef, Vec<usize>) {
+        let (ids, heap_idxs) = self.map.drain();
+        let mut builder: PrimitiveBuilder<VAL> = PrimitiveArray::builder(ids.len());
+        for id in ids.into_iter() {
+            match id {
+                None => builder.append_null(),
+                Some(id) => builder.append_value(id),
+            }
+        }
+        let ids = Arc::new(builder.finish());
+        (ids, heap_idxs)
+    }
+
+    unsafe fn find_or_insert(
+        &mut self,
+        row_idx: usize,
+        replace_idx: usize,
+        mapper: &mut Vec<(usize, usize)>,
+    ) -> (usize, bool) {
+        let ids = self.owned.as_primitive::<VAL>();
+        let id: Option<VAL::Native> = if ids.is_null(row_idx) {
+            None
+        } else {
+            Some(ids.value(row_idx))
+        };
+
+        let hash: u64 = id.hash(&self.rnd);
+        if let Some(map_idx) = self.map.find(hash, |mi| id == *mi) {
+            return (map_idx, false);
+        }
+
+        // we're full and this is a better value, so remove the worst
+        let heap_idx = self.map.remove_if_full(replace_idx);
+
+        // add the new group
+        let map_idx = self.map.insert(hash, id, heap_idx, mapper);
+        (map_idx, true)
+    }
+}
+
+pub trait ValueType: Comparable + Clone + Debug {}
+
+impl<T> ValueType for T where T: Comparable + Clone + Debug {}
+
+pub trait KeyType: Clone + Comparable + Debug {}
+
+impl<T> KeyType for T where T: Clone + Comparable + Debug {}
+
+impl GroupedTopKAggregateStream {
+    fn intern(&mut self, ids: ArrayRef, vals: ArrayRef) -> Result<()> {
+        let mut mapper = Vec::with_capacity(self.limit);
+        let len = ids.len();
+        self.map.set_batch(ids);
+        self.heap.set_batch(vals.clone());
+
+        let has_nulls = vals.null_count() > 0;
+        for row_idx in 0..len {
+            if has_nulls && vals.is_null(row_idx) {
+                continue;
+            }
+            self.insert(row_idx, &mut mapper)?;
+        }
+        Ok(())
+    }
+
+    pub fn insert(
+        &mut self,
+        row_idx: usize,
+        map: &mut Vec<(usize, usize)>,
+    ) -> Result<()> {
+        assert!(self.map.len() <= self.limit, "Overflow");
+
+        // if we're full, and the new val is worse than all our values, just bail
+        if self.heap.is_worse(row_idx) {
+            return Ok(());
+        }
+
+        // handle new groups we haven't seen yet
+        map.clear();
+        let replace_idx = self.heap.worst_map_idx();
+        let (map_idx, did_insert) =
+            unsafe { self.map.find_or_insert(row_idx, replace_idx, map) };
+        if did_insert {
+            self.heap.renumber(map);

Review Comment:
   As `HeapItem`s get `swap()`ed or the `HashTable`s grow, we need to track the mapping from old indexes to new ones and update the other struct. That updating is relatively safe, since we are only shuffling indexes, but it is slow because it happens continually with the index-based heap. We could probably speed this up significantly by switching to a raw-pointer-based heap, so the pointers never change (unlike these indexes).
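   For illustration, here is a minimal, hypothetical sketch of the bookkeeping an index-based heap forces (the `IndexHeap`, `sift_down`, and `mapper` names are illustrative only, not the PR's `CustomHeap`/`CustomMap` API): every swap inside the heap must be reported so the hash table can re-point its stored heap indexes.

   ```rust
   // Hypothetical index-based min-heap over (value, map_idx) pairs. Each swap is
   // recorded in `mapper` as (map_idx, new_heap_idx) so the owning hash table can
   // patch the heap indexes it stores for those groups afterwards.
   struct IndexHeap {
       items: Vec<(i64, usize)>, // (value, index into the hash table)
   }

   impl IndexHeap {
       fn sift_down(&mut self, mut idx: usize, mapper: &mut Vec<(usize, usize)>) {
           loop {
               let (l, r) = (2 * idx + 1, 2 * idx + 2);
               let mut smallest = idx;
               if l < self.items.len() && self.items[l].0 < self.items[smallest].0 {
                   smallest = l;
               }
               if r < self.items.len() && self.items[r].0 < self.items[smallest].0 {
                   smallest = r;
               }
               if smallest == idx {
                   return;
               }
               self.items.swap(idx, smallest);
               // both moved entries must be renumbered in the hash table
               mapper.push((self.items[idx].1, idx));
               mapper.push((self.items[smallest].1, smallest));
               idx = smallest;
           }
       }
   }

   fn main() {
       let mut heap = IndexHeap {
           items: vec![(9, 0), (1, 1), (3, 2)],
       };
       let mut mapper = Vec::new();
       heap.sift_down(0, &mut mapper);
       // mapper lists every (map_idx, new_heap_idx) the table must re-point;
       // a heap of stable raw pointers would make this list empty.
       println!("{mapper:?}");
   }
   ```

   With pinned heap nodes and raw pointers, the table could hold stable addresses instead of indexes and skip the `mapper` round-trip entirely, at the cost of more unsafe code.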