avantgardnerio commented on code in PR #7192:
URL: https://github.com/apache/arrow-datafusion/pull/7192#discussion_r1305932651
########## datafusion/core/src/physical_plan/aggregates/priority_map.rs: ########## @@ -0,0 +1,969 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! A memory-conscious aggregation implementation that limits group buckets to a fixed number + +use crate::physical_plan::aggregates::group_values::primitive::HashValue; +use crate::physical_plan::aggregates::{ + aggregate_expressions, evaluate_group_by, evaluate_many, AggregateExec, + PhysicalGroupBy, +}; +use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream}; +use ahash::RandomState; +use arrow::datatypes::i256; +use arrow::util::pretty::print_batches; +use arrow_array::builder::PrimitiveBuilder; +use arrow_array::cast::AsArray; +use arrow_array::{ + downcast_primitive, Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray, RecordBatch, + StringArray, +}; +use arrow_schema::{DataType, SchemaRef}; +use datafusion_common::DataFusionError; +use datafusion_common::Result; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::aggregate::utils::adjust_output_array; +use datafusion_physical_expr::PhysicalExpr; +use futures::stream::{Stream, StreamExt}; +use half::f16; +use hashbrown::raw::RawTable; +use itertools::Itertools; +use log::{trace, Level}; +use std::cmp::Ordering; +use std::fmt::{Debug, Formatter}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +pub struct GroupedTopKAggregateStream { + partition: usize, + row_count: usize, + started: bool, + schema: SchemaRef, + input: SendableRecordBatchStream, + aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>, + group_by: PhysicalGroupBy, + map: Box<dyn LimitedHashTable>, + heap: Box<dyn LimitedHeap>, + limit: usize, +} + +unsafe impl Send for GroupedTopKAggregateStream {} + +impl GroupedTopKAggregateStream { + pub fn new( + aggr: &AggregateExec, + context: Arc<TaskContext>, + partition: usize, + limit: usize, + ) -> Result<Self> { + let agg_schema = Arc::clone(&aggr.schema); + let group_by = aggr.group_by.clone(); + let input = aggr.input.execute(partition, Arc::clone(&context))?; + let aggregate_arguments = + aggregate_expressions(&aggr.aggr_expr, &aggr.mode, group_by.expr.len())?; + let (val_field, desc) = aggr + .get_minmax_desc() + .ok_or_else(|| DataFusionError::Internal("Min/max required".to_string()))?; + + let (expr, _) = &aggr.group_expr().expr()[0]; + let kt = expr.data_type(&aggr.input().schema())?; + let map = new_map(limit, kt)?; + + let vt = val_field.data_type().clone(); + let heap = new_heap(limit, desc, vt)?; + + Ok(GroupedTopKAggregateStream { + partition, + started: false, + row_count: 0, + schema: agg_schema, + input, + aggregate_arguments, + group_by, + map, + heap, + limit, + }) + } +} + +/// A trait to allow comparison of floating point numbers +pub trait 
Comparable { + fn comp(&self, other: &Self) -> Ordering; +} + +impl Comparable for Option<String> { + fn comp(&self, other: &Self) -> Ordering { + self.cmp(other) + } +} + +impl HashValue for Option<String> { + fn hash(&self, state: &RandomState) -> u64 { + state.hash_one(self) + } +} + +macro_rules! compare_float { + ($($t:ty),+) => { + $(impl Comparable for Option<$t> { + fn comp(&self, other: &Self) -> Ordering { + match (self, other) { + (Some(me), Some(other)) => me.total_cmp(other), + (Some(_), None) => Ordering::Greater, + (None, Some(_)) => Ordering::Less, + (None, None) => Ordering::Equal, + } + } + })+ + + $(impl Comparable for $t { + fn comp(&self, other: &Self) -> Ordering { + self.total_cmp(other) + } + })+ + + $(impl HashValue for Option<$t> { + fn hash(&self, state: &RandomState) -> u64 { + self.map(|me| me.hash(state)).unwrap_or(0) + } + })+ + }; +} + +macro_rules! compare_integer { + ($($t:ty),+) => { + $(impl Comparable for Option<$t> { + fn comp(&self, other: &Self) -> Ordering { + self.cmp(other) + } + })+ + + $(impl Comparable for $t { + fn comp(&self, other: &Self) -> Ordering { + self.cmp(other) + } + })+ + + $(impl HashValue for Option<$t> { + fn hash(&self, state: &RandomState) -> u64 { + self.map(|me| me.hash(state)).unwrap_or(0) + } + })+ + }; +} + +compare_integer!(i8, i16, i32, i64, i128, i256); +compare_integer!(u8, u16, u32, u64); +compare_float!(f16, f32, f64); + +pub fn new_map(limit: usize, kt: DataType) -> Result<Box<dyn LimitedHashTable>> { + macro_rules! downcast_helper { + ($kt:ty, $d:ident) => { + return Ok(Box::new(PrimitiveHashTable::<$kt>::new(limit))) + }; + } + + downcast_primitive! { + kt => (downcast_helper, kt), + DataType::Utf8 => return Ok(Box::new(StringHashTable::new(limit))), + _ => {} // TODO: OwnedRow, etc + } + + Err(DataFusionError::Execution(format!( + "Can't create HashTable for type: {kt:?}" + ))) +} + +pub fn new_heap(limit: usize, desc: bool, vt: DataType) -> Result<Box<dyn LimitedHeap>> { + macro_rules! downcast_helper { + ($vt:ty, $d:ident) => { + return Ok(Box::new(PrimitiveHeap::<$vt>::new(limit, desc, vt))) + }; + } + + downcast_primitive! 
{ + vt => (downcast_helper, vt), + _ => {} // TODO: OwnedRow + } + + Err(DataFusionError::Execution(format!( + "Can't group type: {vt:?}" + ))) +} + +pub trait LimitedHashTable { + fn set_batch(&mut self, ids: ArrayRef); + fn len(&self) -> usize; + unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]); + unsafe fn heap_idx_at(&self, map_idx: usize) -> usize; + fn drain(&mut self) -> (ArrayRef, Vec<usize>); + + unsafe fn find_or_insert( + &mut self, + row_idx: usize, + replace_idx: usize, + map: &mut Vec<(usize, usize)>, + ) -> (usize, bool); +} + +pub trait LimitedHeap { + fn set_batch(&mut self, vals: ArrayRef); + fn is_worse(&self, idx: usize) -> bool; + fn worst_map_idx(&self) -> usize; + fn renumber(&mut self, heap_to_map: &[(usize, usize)]); + fn insert(&mut self, row_idx: usize, map_idx: usize, map: &mut Vec<(usize, usize)>); + fn replace_if_best( + &mut self, + heap_idx: usize, + row_idx: usize, + map: &mut Vec<(usize, usize)>, + ); + fn take_all(&mut self, heap_idxs: Vec<usize>) -> ArrayRef; +} + +impl RecordBatchStream for GroupedTopKAggregateStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +pub struct PrimitiveHeap<VAL: ArrowPrimitiveType> +where + <VAL as ArrowPrimitiveType>::Native: Comparable, +{ + owned: ArrayRef, + heap: CustomHeap<VAL::Native>, + desc: bool, + data_type: DataType, +} + +impl<VAL: ArrowPrimitiveType> PrimitiveHeap<VAL> +where + <VAL as ArrowPrimitiveType>::Native: Comparable, +{ + pub fn new(limit: usize, desc: bool, data_type: DataType) -> Self { + let owned: ArrayRef = Arc::new(PrimitiveArray::<VAL>::builder(0).finish()); + Self { + owned, + heap: CustomHeap::new(limit, desc), + desc, + data_type, + } + } +} + +impl<VAL: ArrowPrimitiveType> LimitedHeap for PrimitiveHeap<VAL> +where + <VAL as ArrowPrimitiveType>::Native: Comparable, +{ + fn set_batch(&mut self, vals: ArrayRef) { + self.owned = vals; + } + + fn is_worse(&self, row_idx: usize) -> bool { + if !self.heap.is_full() { + return false; + } + let vals = self.owned.as_primitive::<VAL>(); + let new_val = vals.value(row_idx); + let worst_val = self.heap.worst_val().expect("Missing root"); + (!self.desc && new_val > *worst_val) || (self.desc && new_val < *worst_val) + } + + fn worst_map_idx(&self) -> usize { + self.heap.worst_map_idx() + } + + fn renumber(&mut self, heap_to_map: &[(usize, usize)]) { + self.heap.renumber(heap_to_map); + } + + fn insert(&mut self, row_idx: usize, map_idx: usize, map: &mut Vec<(usize, usize)>) { + let vals = self.owned.as_primitive::<VAL>(); + let new_val = vals.value(row_idx); + + if self.heap.is_full() { + self.heap.replace_root(new_val, map_idx, map); + } else { + self.heap.append(new_val, map_idx, map); + } + } + + fn replace_if_best( + &mut self, + heap_idx: usize, + row_idx: usize, + map: &mut Vec<(usize, usize)>, + ) { + let vals = self.owned.as_primitive::<VAL>(); + let new_val = vals.value(row_idx); + self.heap.replace_if_best(heap_idx, new_val, map); + } + + fn take_all(&mut self, heap_idxs: Vec<usize>) -> ArrayRef { + let vals = self.heap.take_all(heap_idxs); + let vals = Arc::new(PrimitiveArray::<VAL>::from_iter_values(vals)); + adjust_output_array(&self.data_type, vals).expect("Type is incorrect") + } +} + +pub struct StringHashTable { + owned: ArrayRef, + map: CustomMap<Option<String>>, + rnd: RandomState, +} + +impl StringHashTable { + pub fn new(limit: usize) -> Self { + let vals: Vec<&str> = Vec::new(); + let owned = Arc::new(StringArray::from(vals)); + Self { + owned, + map: CustomMap::new(limit, limit * 10), + rnd: 
ahash::RandomState::default(), + } + } +} + +impl LimitedHashTable for StringHashTable { + fn set_batch(&mut self, ids: ArrayRef) { + self.owned = ids; + } + + fn len(&self) -> usize { + self.map.len() + } + + unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) { + self.map.update_heap_idx(mapper); + } + + unsafe fn heap_idx_at(&self, map_idx: usize) -> usize { + self.map.heap_idx_at(map_idx) + } + + fn drain(&mut self) -> (ArrayRef, Vec<usize>) { + let (ids, heap_idxs) = self.map.drain(); + let ids = Arc::new(StringArray::from(ids)); + (ids, heap_idxs) + } + + unsafe fn find_or_insert( + &mut self, + row_idx: usize, + replace_idx: usize, + mapper: &mut Vec<(usize, usize)>, + ) -> (usize, bool) { + let ids = self + .owned + .as_any() + .downcast_ref::<StringArray>() + .expect("StringArray required"); + let id = if ids.is_null(row_idx) { + None + } else { + Some(ids.value(row_idx)) + }; + + let hash = self.rnd.hash_one(id); + if let Some(map_idx) = self + .map + .find(hash, |mi| id == mi.as_ref().map(|id| id.as_str())) + { + return (map_idx, false); + } + + // we're full and this is a better value, so remove the worst + let heap_idx = self.map.remove_if_full(replace_idx); + + // add the new group + let id = id.map(|id| id.to_string()); + let map_idx = self.map.insert(hash, id, heap_idx, mapper); + (map_idx, true) + } +} + +pub struct PrimitiveHashTable<VAL: ArrowPrimitiveType> +where + Option<<VAL as ArrowPrimitiveType>::Native>: Comparable, +{ + owned: ArrayRef, + map: CustomMap<Option<VAL::Native>>, + rnd: RandomState, +} + +impl<VAL: ArrowPrimitiveType> PrimitiveHashTable<VAL> +where + Option<<VAL as ArrowPrimitiveType>::Native>: Comparable, + Option<<VAL as ArrowPrimitiveType>::Native>: HashValue, +{ + pub fn new(limit: usize) -> Self { + let owned = Arc::new(PrimitiveArray::<VAL>::builder(0).finish()); + Self { + owned, + map: CustomMap::new(limit, limit * 10), + rnd: ahash::RandomState::default(), + } + } +} + +impl<VAL: ArrowPrimitiveType> LimitedHashTable for PrimitiveHashTable<VAL> +where + Option<<VAL as ArrowPrimitiveType>::Native>: Comparable, + Option<<VAL as ArrowPrimitiveType>::Native>: HashValue, +{ + fn set_batch(&mut self, ids: ArrayRef) { + self.owned = ids; + } + + fn len(&self) -> usize { + self.map.len() + } + + unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) { + self.map.update_heap_idx(mapper); + } + + unsafe fn heap_idx_at(&self, map_idx: usize) -> usize { + self.map.heap_idx_at(map_idx) + } + + fn drain(&mut self) -> (ArrayRef, Vec<usize>) { + let (ids, heap_idxs) = self.map.drain(); + let mut builder: PrimitiveBuilder<VAL> = PrimitiveArray::builder(ids.len()); + for id in ids.into_iter() { + match id { + None => builder.append_null(), + Some(id) => builder.append_value(id), + } + } + let ids = Arc::new(builder.finish()); + (ids, heap_idxs) + } + + unsafe fn find_or_insert( + &mut self, + row_idx: usize, + replace_idx: usize, + mapper: &mut Vec<(usize, usize)>, + ) -> (usize, bool) { + let ids = self.owned.as_primitive::<VAL>(); + let id: Option<VAL::Native> = if ids.is_null(row_idx) { + None + } else { + Some(ids.value(row_idx)) + }; + + let hash: u64 = id.hash(&self.rnd); + if let Some(map_idx) = self.map.find(hash, |mi| id == *mi) { + return (map_idx, false); + } + + // we're full and this is a better value, so remove the worst + let heap_idx = self.map.remove_if_full(replace_idx); + + // add the new group + let map_idx = self.map.insert(hash, id, heap_idx, mapper); + (map_idx, true) + } +} + +pub trait ValueType: Comparable + Clone + 
Debug {} + +impl<T> ValueType for T where T: Comparable + Clone + Debug {} + +pub trait KeyType: Clone + Comparable + Debug {} + +impl<T> KeyType for T where T: Clone + Comparable + Debug {} + +impl GroupedTopKAggregateStream { + fn intern(&mut self, ids: ArrayRef, vals: ArrayRef) -> Result<()> { + let mut mapper = Vec::with_capacity(self.limit); + let len = ids.len(); + self.map.set_batch(ids); + self.heap.set_batch(vals.clone()); + + let has_nulls = vals.null_count() > 0; + for row_idx in 0..len { + if has_nulls && vals.is_null(row_idx) { + continue; + } + self.insert(row_idx, &mut mapper)?; + } + Ok(()) + } + + pub fn insert( + &mut self, + row_idx: usize, + map: &mut Vec<(usize, usize)>, + ) -> Result<()> { + assert!(self.map.len() <= self.limit, "Overflow"); + + // if we're full, and the new val is worse than all our values, just bail + if self.heap.is_worse(row_idx) { + return Ok(()); + } + + // handle new groups we haven't seen yet + map.clear(); + let replace_idx = self.heap.worst_map_idx(); + let (map_idx, did_insert) = + unsafe { self.map.find_or_insert(row_idx, replace_idx, map) }; + if did_insert { + self.heap.renumber(map); + map.clear(); + self.heap.insert(row_idx, map_idx, map); + unsafe { self.map.update_heap_idx(map) }; + return Ok(()); + }; + + // this is a value for an existing group + map.clear(); + let heap_idx = unsafe { self.map.heap_idx_at(map_idx) }; + self.heap.replace_if_best(heap_idx, row_idx, map); + unsafe { self.map.update_heap_idx(map) }; + + Ok(()) + } + + fn emit(&mut self) -> Result<Vec<ArrayRef>> { + let (ids, heap_idxs) = self.map.drain(); + let vals = self.heap.take_all(heap_idxs); + Ok(vec![ids, vals]) + } +} + +impl Stream for GroupedTopKAggregateStream { + type Item = Result<RecordBatch>; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Option<Self::Item>> { + while let Poll::Ready(res) = self.input.poll_next_unpin(cx) { + match res { + // got a batch, convert to rows and append to our TreeMap + Some(Ok(batch)) => { + self.started = true; + trace!( + "partition {} has {} rows and got batch with {} rows", + self.partition, + self.row_count, + batch.num_rows() + ); + if log::log_enabled!(Level::Trace) && batch.num_rows() < 20 { + print_batches(&[batch.clone()])?; + } + self.row_count += batch.num_rows(); + let batches = &[batch]; + let group_by_values = + evaluate_group_by(&self.group_by, batches.first().unwrap())?; + assert_eq!( + group_by_values.len(), + 1, + "Exactly 1 group value required" + ); + assert_eq!( + group_by_values[0].len(), + 1, + "Exactly 1 group value required" + ); + let group_by_values = group_by_values[0][0].clone(); + let input_values = evaluate_many( + &self.aggregate_arguments, + batches.first().unwrap(), + )?; + assert_eq!(input_values.len(), 1, "Exactly 1 input required"); + assert_eq!(input_values[0].len(), 1, "Exactly 1 input required"); + let input_values = input_values[0][0].clone(); + + // iterate over each column of group_by values + (*self).intern(group_by_values, input_values)?; + } + // inner is done, emit all rows and switch to producing output + None => { + if self.map.len() == 0 { + trace!("partition {} emit None", self.partition); + return Poll::Ready(None); + } + let cols = self.emit()?; + let batch = RecordBatch::try_new(self.schema.clone(), cols)?; + trace!( + "partition {} emit batch with {} rows", + self.partition, + batch.num_rows() + ); + if log::log_enabled!(Level::Trace) { + print_batches(&[batch.clone()])?; + } + return Poll::Ready(Some(Ok(batch))); + } + // inner had 
error, return to caller + Some(Err(e)) => { + return Poll::Ready(Some(Err(e))); + } + } + } + Poll::Pending + } +} + +struct CustomHeap<VAL: ValueType> { + desc: bool, + len: usize, + limit: usize, + heap: Vec<Option<HeapItem<VAL>>>, +} + +impl<VAL: ValueType> CustomHeap<VAL> { + pub fn new(limit: usize, desc: bool) -> Self { + Self { + desc, + limit, + len: 0, + heap: (0..=limit).map(|_| None).collect::<Vec<_>>(), + } + } + + pub fn worst_val(&self) -> Option<&VAL> { + let root = self.heap.first()?; + let hi = match root { + None => return None, + Some(hi) => hi, + }; + Some(&hi.val) + } + + pub fn worst_map_idx(&self) -> usize { + self.heap[0].as_ref().map(|hi| hi.map_idx).unwrap_or(0) + } + + #[allow(dead_code)] + pub fn len(&self) -> usize { + self.len + } + + #[allow(dead_code)] + pub fn is_full(&self) -> bool { + self.len >= self.limit + } + + pub fn append( + &mut self, + new_val: VAL, + map_idx: usize, + mapper: &mut Vec<(usize, usize)>, + ) { + let hi = HeapItem::new(new_val, map_idx); + self.heap[self.len] = Some(hi); + self.heapify_up(self.len, mapper); + self.len += 1; + } + + pub fn take_all(&mut self, indexes: Vec<usize>) -> Vec<VAL> { + indexes + .iter() + .map(|i| { + let hi: HeapItem<VAL> = self.heap[*i].take().expect("No heap item"); + hi.val + }) + .collect() + } + + pub fn replace_root( + &mut self, + new_val: VAL, + map_idx: usize, + mapper: &mut Vec<(usize, usize)>, + ) { + let hi = self.heap[0].as_mut().expect("No root"); + hi.val = new_val; + hi.map_idx = map_idx; + self.heapify_down(0, mapper); + } + + pub fn replace_if_best( + &mut self, + heap_idx: usize, + new_val: VAL, + mapper: &mut Vec<(usize, usize)>, + ) { + let existing = self.heap[heap_idx].as_mut().expect("Missing heap item"); + if (!self.desc && new_val.comp(&existing.val) != Ordering::Less) + || (self.desc && new_val.comp(&existing.val) != Ordering::Greater) + { + return; + } + existing.val = new_val; + self.heapify_down(heap_idx, mapper); + } + + pub fn renumber(&mut self, heap_to_map: &[(usize, usize)]) { + for (heap_idx, map_idx) in heap_to_map.iter() { + if let Some(Some(hi)) = self.heap.get_mut(*heap_idx) { + hi.map_idx = *map_idx; + } + } + } + + fn heapify_up(&mut self, mut idx: usize, mapper: &mut Vec<(usize, usize)>) { + let desc = self.desc; + while idx != 0 { + let parent_idx = (idx - 1) / 2; + let node = self.heap[idx].as_ref().expect("No heap item"); + let parent = self.heap[parent_idx].as_ref().expect("No heap item"); + if (!desc && node.val.comp(&parent.val) != Ordering::Greater) + || (desc && node.val.comp(&parent.val) != Ordering::Less) + { + return; + } + self.swap(idx, parent_idx, mapper); + idx = parent_idx; + } + } + + fn swap(&mut self, a_idx: usize, b_idx: usize, mapper: &mut Vec<(usize, usize)>) { + let a_hi = self.heap[a_idx].take().expect("Missing heap entry"); + let b_hi = self.heap[b_idx].take().expect("Missing heap entry"); + + mapper.push((a_hi.map_idx, b_idx)); + mapper.push((b_hi.map_idx, a_idx)); + + self.heap[a_idx] = Some(b_hi); + self.heap[b_idx] = Some(a_hi); + } + + fn heapify_down(&mut self, node_idx: usize, mapper: &mut Vec<(usize, usize)>) { + let left_child = node_idx * 2 + 1; + let desc = self.desc; + let entry = self.heap.get(node_idx).expect("Missing node!"); + let entry = entry.as_ref().expect("Missing node!"); + let mut best_idx = node_idx; + let mut best_val = &entry.val; + for child_idx in left_child..=left_child + 1 { + if let Some(Some(child)) = self.heap.get(child_idx) { + if (!desc && child.val.comp(best_val) == Ordering::Greater) + || (desc && 
child.val.comp(best_val) == Ordering::Less)
+                {
+                    best_val = &child.val;
+                    best_idx = child_idx;
+                }
+            }
+        }
+        if best_val.comp(&entry.val) != Ordering::Equal {
+            self.swap(best_idx, node_idx, mapper);
+            self.heapify_down(best_idx, mapper);
+        }
+    }
+
+    fn _tree_print(&self, idx: usize, builder: &mut ptree::TreeBuilder) -> bool {
+        let hi = self.heap.get(idx);
+        let hi = match hi {
+            None => return true,
+            Some(hi) => hi,
+        };
+        let mut valid = true;
+        if let Some(hi) = hi {
+            let label = format!("val={:?} idx={}, bucket={}", hi.val, idx, hi.map_idx);
+            builder.begin_child(label);
+            valid &= self._tree_print(idx * 2 + 1, builder); // left
+            valid &= self._tree_print(idx * 2 + 2, builder); // right
+            builder.end_child();
+            if idx != 0 {
+                let parent_idx = (idx - 1) / 2;
+                let parent = self.heap[parent_idx].as_ref().expect("Missing parent");
+                if (!self.desc && hi.val.comp(&parent.val) == Ordering::Greater)
+                    || (self.desc && hi.val.comp(&parent.val) == Ordering::Less)
+                {
+                    return false;
+                }
+            }
+        } else {
+            builder.add_empty_child("None".to_string());
+        }
+        valid
+    }
+
+    #[allow(dead_code)]
+    pub fn tree_print(&self) {
+        let mut builder = ptree::TreeBuilder::new("BinaryHeap".to_string());
+        let valid = self._tree_print(0, &mut builder);
+        let mut actual = Vec::new();
+        ptree::write_tree(&builder.build(), &mut actual).unwrap();
+        println!("{}", String::from_utf8(actual).unwrap());
+        if !valid {
+            panic!("Heap invariant violated");
+        }
+    }
+}
+
+struct CustomMap<ID: KeyType> {
+    map: RawTable<MapItem<ID>>,
+    limit: usize,
+}
+
+impl<ID: KeyType> CustomMap<ID> {
+    pub fn new(limit: usize, capacity: usize) -> Self {
+        Self {
+            map: RawTable::with_capacity(capacity),
+            limit,
+        }
+    }
+
+    pub fn find(&self, hash: u64, mut eq: impl FnMut(&ID) -> bool) -> Option<usize> {
+        let bucket = self.map.find(hash, |mi| eq(&mi.id))?;
+        let idx = unsafe { self.map.bucket_index(&bucket) };
+        Some(idx)
+    }
+
+    pub unsafe fn heap_idx_at(&self, map_idx: usize) -> usize {
+        let bucket = unsafe { self.map.bucket(map_idx) };
+        unsafe { bucket.as_ref().heap_idx }
+    }
+
+    pub unsafe fn remove_if_full(&mut self, replace_idx: usize) -> usize {
+        if self.map.len() >= self.limit {
+            unsafe { self.map.erase(self.map.bucket(replace_idx)) };
+            0 // if full, always replace top node
+        } else {
+            self.map.len() // if we're not full, always append to end
+        }
+    }
+
+    unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) {
+        for (m, h) in mapper {
+            unsafe { self.map.bucket(*m).as_mut().heap_idx = *h }
+        }
+    }
+
+    pub fn insert(
+        &mut self,
+        hash: u64,
+        id: ID,
+        heap_idx: usize,
+        mapper: &mut Vec<(usize, usize)>,
+    ) -> usize {
+        let mi = MapItem::new(hash, id, heap_idx);
+        let bucket = self.map.try_insert_no_grow(hash, mi);
+        let bucket = match bucket {
+            Ok(bucket) => bucket,
+            Err(new_item) => {

Review Comment:
   We use `insert_no_grow` so we can catch the events that would trigger re-indexing of buckets, so we can rebuild our indexes. Since `capacity = limit * 10`, we have a low fill factor, so this should happen infrequently. Even at `limit * 10` this is still considerably less memory than before, since `limit` will typically be many orders of magnitude smaller than the group-by cardinality.
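   For illustration only (not the elided `Err` arm above): a minimal sketch of the pattern this comment describes, assuming hashbrown's raw-table API (`try_insert_no_grow`, `insert`, `bucket_index`); the `IndexedSet` type and its fields are hypothetical.

```rust
// Sketch only: assumes `hashbrown` with the `raw` feature enabled.
use hashbrown::raw::RawTable;

/// Hypothetical container: external structures (e.g. a heap) hold bucket
/// indexes into `table`, so any rehash invalidates them.
struct IndexedSet {
    table: RawTable<(u64, String)>, // (cached hash, key)
}

impl IndexedSet {
    fn new(capacity: usize) -> Self {
        // Over-allocate (e.g. `limit * 10`) so the grow path below stays rare.
        Self { table: RawTable::with_capacity(capacity) }
    }

    /// Inserts `key`, returning its bucket index and a flag telling the
    /// caller whether the table grew (and therefore re-indexed its buckets).
    fn insert(&mut self, hash: u64, key: String) -> (usize, bool) {
        match self.table.try_insert_no_grow(hash, (hash, key)) {
            // Fast path: no growth, previously handed-out bucket indexes stay valid.
            Ok(bucket) => (unsafe { self.table.bucket_index(&bucket) }, false),
            // Out of room: grow via `insert`, which rehashes every bucket,
            // so the caller must rebuild any saved bucket indexes.
            Err((hash, key)) => {
                let bucket = self.table.insert(hash, (hash, key), |(h, _)| *h);
                (unsafe { self.table.bucket_index(&bucket) }, true)
            }
        }
    }
}
```

   When the flag comes back `true`, the caller would walk the table and refresh whatever bucket indexes it handed out; the PR keeps a `heap_idx` inside each bucket and patches it through `update_heap_idx` for the same reason.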
