This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new a3db19136f Don't store hashes in GroupOrdering (#7029)
a3db19136f is described below

commit a3db19136f0e49ac63f4f64ba57583b28fed96af
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Wed Jul 19 15:40:55 2023 -0400

    Don't store hashes in GroupOrdering (#7029)
    
    * Don't store hashes in GroupOrdering
    
    * Update group IDs
    
    * Review feedback
---
 .../src/physical_plan/aggregates/order/full.rs     | 29 ++++-----------------
 .../core/src/physical_plan/aggregates/order/mod.rs |  9 +++----
 .../src/physical_plan/aggregates/order/partial.rs  | 30 +++++-----------------
 .../core/src/physical_plan/aggregates/row_hash.rs  | 22 ++++++++--------
 4 files changed, 26 insertions(+), 64 deletions(-)

diff --git a/datafusion/core/src/physical_plan/aggregates/order/full.rs b/datafusion/core/src/physical_plan/aggregates/order/full.rs
index d95433a998..69b308da7c 100644
--- a/datafusion/core/src/physical_plan/aggregates/order/full.rs
+++ b/datafusion/core/src/physical_plan/aggregates/order/full.rs
@@ -15,8 +15,6 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use datafusion_execution::memory_pool::proxy::VecAllocExt;
-
 use crate::physical_expr::EmitTo;
 
 /// Tracks grouping state when the data is ordered entirely by its
@@ -58,8 +56,6 @@ use crate::physical_expr::EmitTo;
 #[derive(Debug)]
 pub(crate) struct GroupOrderingFull {
     state: State,
-    /// Hash values for groups in 0..current
-    hashes: Vec<u64>,
 }
 
 #[derive(Debug)]
@@ -79,7 +75,6 @@ impl GroupOrderingFull {
     pub fn new() -> Self {
         Self {
             state: State::Start,
-            hashes: vec![],
         }
     }
 
@@ -101,19 +96,17 @@ impl GroupOrderingFull {
     }
 
     /// remove the first n groups from the internal state, shifting
-    /// all existing indexes down by `n`. Returns stored hash values
-    pub fn remove_groups(&mut self, n: usize) -> &[u64] {
+    /// all existing indexes down by `n`
+    pub fn remove_groups(&mut self, n: usize) {
         match &mut self.state {
             State::Start => panic!("invalid state: start"),
             State::InProgress { current } => {
                 // shift down by n
                 assert!(*current >= n);
                 *current -= n;
-                self.hashes.drain(0..n);
             }
             State::Complete { .. } => panic!("invalid state: complete"),
-        };
-        &self.hashes
+        }
     }
 
     /// Note that the input is complete so any outstanding groups are done as well
@@ -123,20 +116,8 @@ impl GroupOrderingFull {
 
     /// Called when new groups are added in a batch. See documentation
     /// on [`super::GroupOrdering::new_groups`]
-    pub fn new_groups(
-        &mut self,
-        group_indices: &[usize],
-        batch_hashes: &[u64],
-        total_num_groups: usize,
-    ) {
+    pub fn new_groups(&mut self, total_num_groups: usize) {
         assert_ne!(total_num_groups, 0);
-        assert_eq!(group_indices.len(), batch_hashes.len());
-
-        // copy any hash values
-        self.hashes.resize(total_num_groups, 0);
-        for (&group_index, &hash) in group_indices.iter().zip(batch_hashes.iter()) {
-            self.hashes[group_index] = hash;
-        }
 
         // Update state
         let max_group_index = total_num_groups - 1;
@@ -158,6 +139,6 @@ impl GroupOrderingFull {
     }
 
     pub(crate) fn size(&self) -> usize {
-        std::mem::size_of::<Self>() + self.hashes.allocated_size()
+        std::mem::size_of::<Self>()
     }
 }
diff --git a/datafusion/core/src/physical_plan/aggregates/order/mod.rs b/datafusion/core/src/physical_plan/aggregates/order/mod.rs
index 4e1da35319..81bf38aac3 100644
--- a/datafusion/core/src/physical_plan/aggregates/order/mod.rs
+++ b/datafusion/core/src/physical_plan/aggregates/order/mod.rs
@@ -84,9 +84,9 @@ impl GroupOrdering {
 
     /// remove the first n groups from the internal state, shifting
     /// all existing indexes down by `n`. Returns stored hash values
-    pub fn remove_groups(&mut self, n: usize) -> &[u64] {
+    pub fn remove_groups(&mut self, n: usize) {
         match self {
-            GroupOrdering::None => &[],
+            GroupOrdering::None => {}
             GroupOrdering::Partial(partial) => partial.remove_groups(n),
             GroupOrdering::Full(full) => full.remove_groups(n),
         }
@@ -106,7 +106,6 @@ impl GroupOrdering {
         &mut self,
         batch_group_values: &[ArrayRef],
         group_indices: &[usize],
-        batch_hashes: &[u64],
         total_num_groups: usize,
     ) -> Result<()> {
         match self {
@@ -115,13 +114,11 @@ impl GroupOrdering {
                 partial.new_groups(
                     batch_group_values,
                     group_indices,
-                    batch_hashes,
                     total_num_groups,
                 )?;
             }
-
             GroupOrdering::Full(full) => {
-                full.new_groups(group_indices, batch_hashes, total_num_groups);
+                full.new_groups(total_num_groups);
             }
         };
         Ok(())
diff --git a/datafusion/core/src/physical_plan/aggregates/order/partial.rs b/datafusion/core/src/physical_plan/aggregates/order/partial.rs
index be8cd59671..ac32c69fd5 100644
--- a/datafusion/core/src/physical_plan/aggregates/order/partial.rs
+++ b/datafusion/core/src/physical_plan/aggregates/order/partial.rs
@@ -71,9 +71,6 @@ pub(crate) struct GroupOrderingPartial {
     /// Converter for the sort key (used on the group columns
     /// specified in `order_indexes`)
     row_converter: RowConverter,
-
-    /// Hash values for groups in 0..completed
-    hashes: Vec<u64>,
 }
 
 #[derive(Debug, Default)]
@@ -127,7 +124,6 @@ impl GroupOrderingPartial {
             state: State::Start,
             order_indices: order_indices.to_vec(),
             row_converter: RowConverter::new(fields)?,
-            hashes: vec![],
         })
     }
 
@@ -167,8 +163,8 @@ impl GroupOrderingPartial {
     }
 
     /// remove the first n groups from the internal state, shifting
-    /// all existing indexes down by `n`. Returns stored hash values
-    pub fn remove_groups(&mut self, n: usize) -> &[u64] {
+    /// all existing indexes down by `n`
+    pub fn remove_groups(&mut self, n: usize) {
         match &mut self.state {
             State::Taken => unreachable!("State previously taken"),
             State::Start => panic!("invalid state: start"),
@@ -182,12 +178,9 @@ impl GroupOrderingPartial {
                 *current -= n;
                 assert!(*current_sort >= n);
                 *current_sort -= n;
-                // Note sort_key stays the same, we are just translating group indexes
-                self.hashes.drain(0..n);
             }
             State::Complete { .. } => panic!("invalid state: complete"),
-        };
-        &self.hashes
+        }
     }
 
     /// Note that the input is complete so any outstanding groups are done as well
@@ -204,18 +197,15 @@ impl GroupOrderingPartial {
         &mut self,
         batch_group_values: &[ArrayRef],
         group_indices: &[usize],
-        batch_hashes: &[u64],
         total_num_groups: usize,
     ) -> Result<()> {
         assert!(total_num_groups > 0);
         assert!(!batch_group_values.is_empty());
-        assert_eq!(group_indices.len(), batch_hashes.len());
 
         let max_group_index = total_num_groups - 1;
 
         // compute the sort key values for each group
         let sort_keys = self.compute_sort_keys(batch_group_values)?;
-        assert_eq!(sort_keys.num_rows(), batch_hashes.len());
 
         let old_state = std::mem::take(&mut self.state);
         let (mut current_sort, mut sort_key) = match &old_state {
@@ -231,16 +221,9 @@ impl GroupOrderingPartial {
             }
         };
 
-        // copy any hash values, and find latest sort key
-        self.hashes.resize(total_num_groups, 0);
-        let iter = group_indices
-            .iter()
-            .zip(batch_hashes.iter())
-            .zip(sort_keys.iter());
-
-        for ((&group_index, &hash), group_sort_key) in iter {
-            self.hashes[group_index] = hash;
-
+        // Find latest sort key
+        let iter = group_indices.iter().zip(sort_keys.iter());
+        for (&group_index, group_sort_key) in iter {
             // Does this group have seen a new sort_key?
             if sort_key != group_sort_key {
                 current_sort = group_index;
@@ -262,6 +245,5 @@ impl GroupOrderingPartial {
         std::mem::size_of::<Self>()
             + self.order_indices.allocated_size()
             + self.row_converter.size()
-            + self.hashes.allocated_size()
     }
 }
diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
index b48e8f38e9..59ffbe5cf1 100644
--- a/datafusion/core/src/physical_plan/aggregates/row_hash.rs
+++ b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
@@ -485,7 +485,6 @@ impl GroupedHashAggregateStream {
             self.group_ordering.new_groups(
                 group_values,
                 group_indices,
-                batch_hashes,
                 total_num_groups,
             )?;
         }
@@ -624,15 +623,18 @@ impl GroupedHashAggregateStream {
                 }
                 std::mem::swap(&mut new_group_values, &mut self.group_values);
 
-                // rebuild hash table (maybe we should remove the
-                // entries for each group that was emitted rather than
-                // rebuilding the whole thing
-
-                let hashes = self.group_ordering.remove_groups(n);
-                assert_eq!(hashes.len(), self.group_values.num_rows());
-                self.map.clear();
-                for (idx, &hash) in hashes.iter().enumerate() {
-                    self.map.insert(hash, (hash, idx), |(hash, _)| *hash);
+                self.group_ordering.remove_groups(n);
+                // SAFETY: self.map outlives iterator and is not modified concurrently
+                unsafe {
+                    for bucket in self.map.iter() {
+                        // Decrement group index by n
+                        match bucket.as_ref().1.checked_sub(n) {
+                            // Group index was >= n, shift value down
+                            Some(sub) => bucket.as_mut().1 = sub,
+                            // Group index was < n, so remove from table
+                            None => self.map.erase(bucket),
+                        }
+                    }
                 }
             }
         };

Reply via email to