This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 345117baf0 Support vectorized append and compare for multi group by
(#12996)
345117baf0 is described below
commit 345117baf012318c08bbd0bd33fdd42fc15f735e
Author: kamille <[email protected]>
AuthorDate: Thu Nov 7 00:10:06 2024 +0800
Support vectorized append and compare for multi group by (#12996)
* simple support vectorized append.
* fix tests.
* some logs.
* add `append_n` in `MaybeNullBufferBuilder`.
* impl basic append_batch
* fix equal to.
* define `GroupIndexContext`.
* define the structs useful in vectorizing.
* re-define some structs for vectorized operations.
* impl some vectorized logics.
* impl checking hashmap stage.
* fix compile.
* tmp
* define and impl `vectorized_compare`.
* fix compile.
* impl `vectorized_equal_to`.
* impl `vectorized_append`.
* finish the basic vectorized ops logic.
* impl `take_n`.
* fix `renaming clear` and `groups fill`.
* fix infinite loop due to rehashing.
* fix vectorized append.
* add counter.
* use extend rather than resize.
* remove dbg!.
* remove reserve.
* refactor the code to make it simpler and more performant.
* clear `scalarized_indices` in `intern` to avoid some corner case.
* fix `scalarized_equal_to`.
* fallback to total scalarized `GroupValuesColumn` in streaming aggregation.
* add unit test for `VectorizedGroupValuesColumn`.
* add unit test for emitting first n in `VectorizedGroupValuesColumn`.
* sort out test code for group columns and add vectorized tests for primitives.
* add vectorized test for byte builder.
* add vectorized test for byte view builder.
* add test for the all nulls or not nulls branches in vectorized.
* fix clippy.
* fix fmt.
* fix compile in rust 1.79.
* improve comments.
* fix doc.
* add more comments to explain the really complex vectorized intern process.
* add comments to explain why we still need origin `GroupValuesColumn`.
* remove some stale comments.
* fix clippy.
* add comments for `vectorized_equal_to` and `vectorized_append`.
* fix clippy.
* use zip to simplify codes.
* use izip to simplify codes.
* Update
datafusion/physical-plan/src/aggregates/group_values/group_column.rs
Co-authored-by: Jay Zhan <[email protected]>
* first_n attempt
Signed-off-by: jayzhan211 <[email protected]>
* add test
Signed-off-by: jayzhan211 <[email protected]>
* improve hashtable modifying in emit first n test.
* add `emit_group_index_list_buffer` to avoid allocating new `Vec` to store
the remaining group indices.
* make comments in VectorizedGroupValuesColumn::intern simpler and clearer.
* define `VectorizedOperationBuffers` to hold buffers used in vectorized
operations to make code clearer.
* unify `VectorizedGroupValuesColumn` and `GroupValuesColumn`.
* fix fmt.
* fix comments.
* fix clippy.
---------
Signed-off-by: jayzhan211 <[email protected]>
Co-authored-by: Jay Zhan <[email protected]>
---
datafusion/common/src/utils/memory.rs | 2 +-
.../tests/user_defined/user_defined_aggregates.rs | 1 +
.../src/aggregates/group_values/column.rs | 1498 ++++++++++++++++++--
.../src/aggregates/group_values/group_column.rs | 985 +++++++++++--
.../src/aggregates/group_values/mod.rs | 15 +-
.../src/aggregates/group_values/null_builder.rs | 18 +
.../physical-plan/src/aggregates/row_hash.rs | 2 +-
.../proto/tests/cases/roundtrip_logical_plan.rs | 1 +
.../tests/cases/roundtrip_logical_plan.rs | 1 +
9 files changed, 2296 insertions(+), 227 deletions(-)
diff --git a/datafusion/common/src/utils/memory.rs
b/datafusion/common/src/utils/memory.rs
index d5ce59e342..bb68d59eed 100644
--- a/datafusion/common/src/utils/memory.rs
+++ b/datafusion/common/src/utils/memory.rs
@@ -102,7 +102,7 @@ pub fn estimate_memory_size<T>(num_elements: usize,
fixed_size: usize) -> Result
#[cfg(test)]
mod tests {
- use std::collections::HashSet;
+ use std::{collections::HashSet, mem::size_of};
use super::estimate_memory_size;
diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs
b/datafusion/core/tests/user_defined/user_defined_aggregates.rs
index 497addd230..99c0061537 100644
--- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs
+++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs
@@ -19,6 +19,7 @@
//! user defined aggregate functions
use std::hash::{DefaultHasher, Hash, Hasher};
+use std::mem::{size_of, size_of_val};
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
diff --git a/datafusion/physical-plan/src/aggregates/group_values/column.rs
b/datafusion/physical-plan/src/aggregates/group_values/column.rs
index 958a4b58d8..8100bb876d 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/column.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/column.rs
@@ -15,6 +15,8 @@
// specific language governing permissions and limitations
// under the License.
+use std::mem::{self, size_of};
+
use crate::aggregates::group_values::group_column::{
ByteGroupValueBuilder, ByteViewGroupValueBuilder, GroupColumn,
PrimitiveGroupValueBuilder,
@@ -35,29 +37,100 @@ use datafusion_common::{not_impl_err, DataFusionError,
Result};
use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
use datafusion_expr::EmitTo;
use datafusion_physical_expr::binary_map::OutputType;
+
use hashbrown::raw::RawTable;
-use std::mem::size_of;
-/// A [`GroupValues`] that stores multiple columns of group values.
+const NON_INLINED_FLAG: u64 = 0x8000000000000000;
+const VALUE_MASK: u64 = 0x7FFFFFFFFFFFFFFF;
+
+/// The view of indices pointing to the actual values in `GroupValues`
+///
+/// If the view represents only a single `group index`,
+/// the view's value is just that `group index`, and we call it an `inlined view`.
+///
+/// If the view represents multiple `group indices`,
+/// the view's value is actually an index pointing to those `group indices`,
+/// and we call it a `non-inlined view`.
+///
+/// The view (a u64) has the following format:
+/// +---------------------+---------------------------------------------+
+/// | inlined flag(1bit) | group index / index to group indices(63bit) |
+/// +---------------------+---------------------------------------------+
///
+/// `inlined flag`: 1 represents `non-inlined`, and 0 represents `inlined`
///
-pub struct GroupValuesColumn {
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+struct GroupIndexView(u64);
+
+impl GroupIndexView {
+ #[inline]
+ pub fn is_non_inlined(&self) -> bool {
+ (self.0 & NON_INLINED_FLAG) > 0
+ }
+
+ #[inline]
+ pub fn new_inlined(group_index: u64) -> Self {
+ Self(group_index)
+ }
+
+ #[inline]
+ pub fn new_non_inlined(list_offset: u64) -> Self {
+ let non_inlined_value = list_offset | NON_INLINED_FLAG;
+ Self(non_inlined_value)
+ }
+
+ #[inline]
+ pub fn value(&self) -> u64 {
+ self.0 & VALUE_MASK
+ }
+}
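The bit manipulation above is easiest to see with concrete values. A standalone sketch (an editorial illustration, not part of this commit) using the same flag and mask constants:

```rust
// Hypothetical, self-contained mirror of the encoding described above.
const NON_INLINED_FLAG: u64 = 0x8000_0000_0000_0000;
const VALUE_MASK: u64 = 0x7FFF_FFFF_FFFF_FFFF;

fn main() {
    // Inlined view: the value is the group index itself (flag bit is 0).
    let inlined = 42u64;
    assert_eq!(inlined & NON_INLINED_FLAG, 0);
    assert_eq!(inlined & VALUE_MASK, 42);

    // Non-inlined view: the value is an offset into the list of group
    // indices (flag bit is 1).
    let non_inlined = 7u64 | NON_INLINED_FLAG;
    assert!((non_inlined & NON_INLINED_FLAG) > 0);
    assert_eq!(non_inlined & VALUE_MASK, 7);
}
```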
+
+/// A [`GroupValues`] that stores multiple columns of group values,
+/// and supports vectorized operations on them
+///
+pub struct GroupValuesColumn<const STREAMING: bool> {
/// The output schema
schema: SchemaRef,
/// Logically maps group values to a group_index in
/// [`Self::group_values`] and in each accumulator
///
- /// Uses the raw API of hashbrown to avoid actually storing the
- /// keys (group values) in the table
+ /// It is a `hashtable` based on `hashbrown`.
+ ///
+ /// Key and value in the `hashtable`:
+    /// - The `key` is the `hash value(u64)` of the `group value`
+    /// - The `value` is the `group value(s)` that share the same `hash value`
///
- /// keys: u64 hashes of the GroupValue
- /// values: (hash, group_index)
- map: RawTable<(u64, usize)>,
+    /// We don't really store the actual `group values` in the `hashtable`;
+    /// instead we store the `group indices` pointing to values in `GroupValues`,
+    /// and we use [`GroupIndexView`] to represent such `group indices` in the table.
+    ///
+ map: RawTable<(u64, GroupIndexView)>,
/// The size of `map` in bytes
map_size: usize,
+ /// The lists for group indices with the same hash value
+ ///
+    /// Hash value collisions are possible, so we chain the `group indices`
+    /// that share the same hash value.
+    ///
+    /// The chained indices look like:
+    ///   `latest group index -> older group index -> even older group index -> ...`
+ ///
+ group_index_lists: Vec<Vec<usize>>,
+
+ /// When emitting first n, we need to decrease/erase group indices in
+ /// `map` and `group_index_lists`.
+ ///
+ /// This buffer is used to temporarily store the remaining group indices in
+ /// a specific list in `group_index_lists`.
+ emit_group_index_list_buffer: Vec<usize>,
+
+ /// Buffers for `vectorized_append` and `vectorized_equal_to`
+ vectorized_operation_buffers: VectorizedOperationBuffers,
+
/// The actual group by values, stored column-wise. Compare from
/// the left to right, each column is stored as [`GroupColumn`].
///
@@ -75,13 +148,52 @@ pub struct GroupValuesColumn {
random_state: RandomState,
}
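To picture the map layout described above, here is a rough, self-contained sketch using std collections as stand-ins for `RawTable` and `GroupIndexView` (the names `View` and `candidate_groups` are invented for illustration): the read path goes hash, to view, to candidate group indices.

```rust
use std::collections::HashMap;

// Stand-in for `GroupIndexView`.
#[derive(Clone, Copy)]
enum View {
    Inlined(usize),    // exactly one group has this hash
    NonInlined(usize), // offset into `group_index_lists`
}

// Collect every group index that might match the given hash.
fn candidate_groups(
    map: &HashMap<u64, View>,
    group_index_lists: &[Vec<usize>],
    hash: u64,
) -> Vec<usize> {
    match map.get(&hash) {
        Some(View::Inlined(group_index)) => vec![*group_index],
        Some(View::NonInlined(offset)) => group_index_lists[*offset].clone(),
        None => vec![],
    }
}

fn main() {
    let mut map = HashMap::new();
    // Groups 1 and 4 collided on hash 0xBEEF, so they are chained in a list.
    let group_index_lists = vec![vec![1, 4]];
    map.insert(0xFEED_u64, View::Inlined(0));
    map.insert(0xBEEF_u64, View::NonInlined(0));

    assert_eq!(candidate_groups(&map, &group_index_lists, 0xFEED), vec![0]);
    assert_eq!(candidate_groups(&map, &group_index_lists, 0xBEEF), vec![1, 4]);
    assert!(candidate_groups(&map, &group_index_lists, 0xDEAD).is_empty());
}
```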
-impl GroupValuesColumn {
+/// Buffers to store intermediate results in `vectorized_append`
+/// and `vectorized_equal_to`, for reducing memory allocation
+#[derive(Default)]
+struct VectorizedOperationBuffers {
+ /// The `vectorized append` row indices buffer
+ append_row_indices: Vec<usize>,
+
+ /// The `vectorized_equal_to` row indices buffer
+ equal_to_row_indices: Vec<usize>,
+
+ /// The `vectorized_equal_to` group indices buffer
+ equal_to_group_indices: Vec<usize>,
+
+ /// The `vectorized_equal_to` result buffer
+ equal_to_results: Vec<bool>,
+
+    /// The buffer for storing row indices found not equal to
+    /// existing groups in `group_values` during `vectorized_equal_to`.
+    /// We will perform `scalarized_intern_remaining` for such rows.
+ remaining_row_indices: Vec<usize>,
+}
+
+impl VectorizedOperationBuffers {
+ fn clear(&mut self) {
+ self.append_row_indices.clear();
+ self.equal_to_row_indices.clear();
+ self.equal_to_group_indices.clear();
+ self.equal_to_results.clear();
+ self.remaining_row_indices.clear();
+ }
+}
+
+impl<const STREAMING: bool> GroupValuesColumn<STREAMING> {
+ // ========================================================================
+ // Initialization functions
+ // ========================================================================
+
/// Create a new instance of GroupValuesColumn if supported for the
specified schema
pub fn try_new(schema: SchemaRef) -> Result<Self> {
let map = RawTable::with_capacity(0);
Ok(Self {
schema,
map,
+ group_index_lists: Vec::new(),
+ emit_group_index_list_buffer: Vec::new(),
+            vectorized_operation_buffers: VectorizedOperationBuffers::default(),
map_size: 0,
group_values: vec![],
hashes_buffer: Default::default(),
@@ -89,41 +201,600 @@ impl GroupValuesColumn {
})
}
- /// Returns true if [`GroupValuesColumn`] supported for the specified
schema
- pub fn supported_schema(schema: &Schema) -> bool {
- schema
- .fields()
+ // ========================================================================
+ // Scalarized intern
+ // ========================================================================
+
+ /// Scalarized intern
+ ///
+    /// This is used only for `streaming aggregation`, because `streaming aggregation`
+    /// depends on the order between `input rows` and their corresponding `group indices`.
+    ///
+    /// For example, assume the `input rows` in `cols` contain 4 new rows
+    /// (not equal to the `existing rows` in `group_values`, so new groups
+    /// need to be created for them):
+ ///
+ /// ```text
+ /// row1 (hash collision with the exist rows)
+ /// row2
+ /// row3 (hash collision with the exist rows)
+ /// row4
+ /// ```
+ ///
+ /// # In `scalarized_intern`, their `group indices` will be
+ ///
+ /// ```text
+ /// row1 --> 0
+ /// row2 --> 1
+ /// row3 --> 2
+ /// row4 --> 3
+ /// ```
+ ///
+    /// The `group indices` order agrees with the input order, and `streaming aggregation`
+    /// depends on this.
+ ///
+    /// # However, in `vectorized_intern`, their `group indices` will be
+ ///
+ /// ```text
+ /// row1 --> 2
+ /// row2 --> 0
+ /// row3 --> 3
+ /// row4 --> 1
+ /// ```
+ ///
+    /// The `group indices` order disagrees with the input order, and this will lead to errors
+    /// in `streaming aggregation`.
+ ///
+ fn scalarized_intern(
+ &mut self,
+ cols: &[ArrayRef],
+ groups: &mut Vec<usize>,
+ ) -> Result<()> {
+ let n_rows = cols[0].len();
+
+ // tracks to which group each of the input rows belongs
+ groups.clear();
+
+ // 1.1 Calculate the group keys for the group values
+ let batch_hashes = &mut self.hashes_buffer;
+ batch_hashes.clear();
+ batch_hashes.resize(n_rows, 0);
+ create_hashes(cols, &self.random_state, batch_hashes)?;
+
+ for (row, &target_hash) in batch_hashes.iter().enumerate() {
+ let entry = self
+ .map
+ .get_mut(target_hash, |(exist_hash, group_idx_view)| {
+ // It is ensured to be inlined in `scalarized_intern`
+ debug_assert!(!group_idx_view.is_non_inlined());
+
+                    // Somewhat surprisingly, this closure can be called even if the
+                    // hash doesn't match, so check the hash first with an integer
+                    // comparison to avoid the more expensive comparison with the
+                    // group value. https://github.com/apache/datafusion/pull/11718
+ if target_hash != *exist_hash {
+ return false;
+ }
+
+ fn check_row_equal(
+ array_row: &dyn GroupColumn,
+ lhs_row: usize,
+ array: &ArrayRef,
+ rhs_row: usize,
+ ) -> bool {
+ array_row.equal_to(lhs_row, array, rhs_row)
+ }
+
+                    for (i, group_val) in self.group_values.iter().enumerate() {
+ if !check_row_equal(
+ group_val.as_ref(),
+ group_idx_view.value() as usize,
+ &cols[i],
+ row,
+ ) {
+ return false;
+ }
+ }
+
+ true
+ });
+
+ let group_idx = match entry {
+ // Existing group_index for this group value
+                Some((_hash, group_idx_view)) => group_idx_view.value() as usize,
+ // 1.2 Need to create new entry for the group
+ None => {
+ // Add new entry to aggr_state and save newly created index
+ // let group_idx = group_values.num_rows();
+ // group_values.push(group_rows.row(row));
+
+ let mut checklen = 0;
+ let group_idx = self.group_values[0].len();
+ for (i, group_value) in
self.group_values.iter_mut().enumerate() {
+ group_value.append_val(&cols[i], row);
+ let len = group_value.len();
+ if i == 0 {
+ checklen = len;
+ } else {
+ debug_assert_eq!(checklen, len);
+ }
+ }
+
+ // for hasher function, use precomputed hash value
+ self.map.insert_accounted(
+ (target_hash, GroupIndexView::new_inlined(group_idx as
u64)),
+ |(hash, _group_index)| *hash,
+ &mut self.map_size,
+ );
+ group_idx
+ }
+ };
+ groups.push(group_idx);
+ }
+
+ Ok(())
+ }
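The ordering property that `scalarized_intern` preserves can be shown with a tiny standalone model (hypothetical helper, not the commit's code): group indices are handed out in first-appearance order of the input rows, which is exactly what streaming group-by expects.

```rust
use std::collections::HashMap;

/// Assign group indices in first-seen order, as the scalarized path does.
fn intern(rows: &[&str]) -> Vec<usize> {
    let mut seen: HashMap<&str, usize> = HashMap::new();
    let mut groups = Vec::with_capacity(rows.len());
    for &row in rows {
        let next = seen.len();
        let idx = *seen.entry(row).or_insert(next);
        groups.push(idx);
    }
    groups
}

fn main() {
    // New rows get indices that follow their input order (0, 1, 2, ...).
    assert_eq!(intern(&["a", "b", "a", "c"]), vec![0, 1, 0, 2]);
}
```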
+
+ // ========================================================================
+ // Vectorized intern
+ // ========================================================================
+
+ /// Vectorized intern
+ ///
+    /// This is used in `non-streaming aggregation`, which does not require any order between
+    /// rows in `cols` and their corresponding groups in `group_values`.
+    ///
+    /// The vectorized approach can offer higher performance by avoiding row-by-row
+    /// downcasts of `cols` and by enabling further optimizations (like SIMD).
+ ///
+ fn vectorized_intern(
+ &mut self,
+ cols: &[ArrayRef],
+ groups: &mut Vec<usize>,
+ ) -> Result<()> {
+ let n_rows = cols[0].len();
+
+ // tracks to which group each of the input rows belongs
+ groups.clear();
+ groups.resize(n_rows, usize::MAX);
+
+ let mut batch_hashes = mem::take(&mut self.hashes_buffer);
+ batch_hashes.clear();
+ batch_hashes.resize(n_rows, 0);
+ create_hashes(cols, &self.random_state, &mut batch_hashes)?;
+
+ // General steps for one round `vectorized equal_to & append`:
+        //   1. Collect the vectorized context by checking hash values of `cols` in `map`,
+        //      mainly filling `vectorized_append_row_indices`, `vectorized_equal_to_row_indices`
+        //      and `vectorized_equal_to_group_indices`
+        //
+        //   2. Perform `vectorized_append` for `vectorized_append_row_indices`.
+        //      `vectorized_append` must be performed before `vectorized_equal_to`,
+        //      because some `group indices` in `vectorized_equal_to_group_indices`
+        //      may still point to values not yet present in `group_values` until the append is done.
+        //
+        //   3. Perform `vectorized_equal_to` for `vectorized_equal_to_row_indices`
+        //      and `vectorized_equal_to_group_indices`. If some rows in the input `cols`
+        //      are found not equal to the existing rows in `group_values`, place them in
+        //      `remaining_row_indices` and afterwards perform `scalarized_intern_remaining`
+        //      for them, similar to `scalarized_intern`.
+        //
+        //   4. Perform `scalarized_intern_remaining` for the rows mentioned above; for the
+        //      situations in which this happens, see the comments of `scalarized_intern_remaining`.
+ //
+
+        // 1. Collect vectorized context by checking hash values of `cols` in `map`
+ self.collect_vectorized_process_context(&batch_hashes, groups);
+
+ // 2. Perform `vectorized_append`
+ self.vectorized_append(cols);
+
+ // 3. Perform `vectorized_equal_to`
+ self.vectorized_equal_to(cols, groups);
+
+        // 4. Perform scalarized intern for remaining rows
+        //    (about the remaining rows, see the comments for `remaining_row_indices`)
+ self.scalarized_intern_remaining(cols, &batch_hashes, groups);
+
+ self.hashes_buffer = batch_hashes;
+
+ Ok(())
+ }
+
+ /// Collect vectorized context by checking hash values of `cols` in `map`
+ ///
+ /// 1. If bucket not found
+ /// - Build and insert the `new inlined group index view`
+ /// and its hash value to `map`
+ /// - Add row index to `vectorized_append_row_indices`
+ /// - Set group index to row in `groups`
+ ///
+    /// 2. If bucket found
+    ///    - Add row index to `vectorized_equal_to_row_indices`
+    ///    - Check if the `group index view` is `inlined` or `non_inlined`:
+    ///      if it is inlined, add to `vectorized_equal_to_group_indices` directly;
+    ///      otherwise, get all group indices from `group_index_lists` and add them.
+ ///
+ fn collect_vectorized_process_context(
+ &mut self,
+ batch_hashes: &[u64],
+ groups: &mut [usize],
+ ) {
+ self.vectorized_operation_buffers.append_row_indices.clear();
+ self.vectorized_operation_buffers
+ .equal_to_row_indices
+ .clear();
+ self.vectorized_operation_buffers
+ .equal_to_group_indices
+ .clear();
+
+ let mut group_values_len = self.group_values[0].len();
+ for (row, &target_hash) in batch_hashes.iter().enumerate() {
+ let entry = self
+ .map
+                .get(target_hash, |(exist_hash, _)| target_hash == *exist_hash);
+
+ let Some((_, group_index_view)) = entry else {
+ // 1. Bucket not found case
+ // Build `new inlined group index view`
+ let current_group_idx = group_values_len;
+ let group_index_view =
+ GroupIndexView::new_inlined(current_group_idx as u64);
+
+ // Insert the `group index view` and its hash into `map`
+ // for hasher function, use precomputed hash value
+ self.map.insert_accounted(
+ (target_hash, group_index_view),
+ |(hash, _)| *hash,
+ &mut self.map_size,
+ );
+
+ // Add row index to `vectorized_append_row_indices`
+ self.vectorized_operation_buffers
+ .append_row_indices
+ .push(row);
+
+ // Set group index to row in `groups`
+ groups[row] = current_group_idx;
+
+ group_values_len += 1;
+ continue;
+ };
+
+ // 2. bucket found
+ // Check if the `group index view` is `inlined` or `non_inlined`
+ if group_index_view.is_non_inlined() {
+                // Non-inlined case, where the view's value is an offset into `group_index_lists`.
+                // We use it to get the `group_index_list`, and add the related `rows` and `group_indices`
+                // into `vectorized_equal_to_row_indices` and `vectorized_equal_to_group_indices`.
+ let list_offset = group_index_view.value() as usize;
+ let group_index_list = &self.group_index_lists[list_offset];
+ for &group_index in group_index_list {
+ self.vectorized_operation_buffers
+ .equal_to_row_indices
+ .push(row);
+ self.vectorized_operation_buffers
+ .equal_to_group_indices
+ .push(group_index);
+ }
+ } else {
+ let group_index = group_index_view.value() as usize;
+ self.vectorized_operation_buffers
+ .equal_to_row_indices
+ .push(row);
+ self.vectorized_operation_buffers
+ .equal_to_group_indices
+ .push(group_index);
+ }
+ }
+ }
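A compressed model of this partitioning step (a std `HashMap` standing in for the raw table, with invented toy data): hashes missing from the map go to the append list and get a fresh group index immediately, while hashes that hit a bucket go to the equal-to lists for later verification. The real code also inserts the new view into the map at that point, so later duplicates of a new row take the equal-to path.

```rust
use std::collections::HashMap;

fn main() {
    // Pretend the map already knows one group (hash 10 -> group 0).
    let map: HashMap<u64, Vec<usize>> = HashMap::from([(10u64, vec![0usize])]);
    let batch_hashes = [10u64, 20, 10, 30];

    let mut next_group = 1usize; // group 0 already exists
    let mut groups = vec![usize::MAX; batch_hashes.len()];
    let mut append_row_indices = Vec::new();
    let mut equal_to_row_indices = Vec::new();
    let mut equal_to_group_indices = Vec::new();

    for (row, hash) in batch_hashes.iter().enumerate() {
        match map.get(hash) {
            // Bucket found: every chained group index must be checked later.
            Some(candidates) => {
                for &group_index in candidates {
                    equal_to_row_indices.push(row);
                    equal_to_group_indices.push(group_index);
                }
            }
            // Bucket not found: this row starts a brand-new group
            // (the real code also inserts the new view into the map here).
            None => {
                groups[row] = next_group;
                append_row_indices.push(row);
                next_group += 1;
            }
        }
    }

    assert_eq!(append_row_indices, vec![1, 3]); // rows with hashes 20 and 30
    assert_eq!(equal_to_row_indices, vec![0, 2]); // rows that hit the existing bucket
    assert_eq!(equal_to_group_indices, vec![0, 0]);
    assert_eq!(groups, vec![usize::MAX, 1, usize::MAX, 2]);
}
```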
+
+    /// Perform `vectorized_append` for `rows` in `vectorized_append_row_indices`
+ fn vectorized_append(&mut self, cols: &[ArrayRef]) {
+ if self
+ .vectorized_operation_buffers
+ .append_row_indices
+ .is_empty()
+ {
+ return;
+ }
+
+ let iter = self.group_values.iter_mut().zip(cols.iter());
+ for (group_column, col) in iter {
+ group_column.vectorized_append(
+ col,
+ &self.vectorized_operation_buffers.append_row_indices,
+ );
+ }
+ }
+
+ /// Perform `vectorized_equal_to`
+ ///
+    /// 1. Perform `vectorized_equal_to` for the `rows` in `vectorized_equal_to_row_indices`
+    ///    and the `group_indices` in `vectorized_equal_to_group_indices`.
+    ///
+    /// 2. Check `equal_to_results`:
+    ///
+    ///    If a `row` is found equal, set its `group_index` in `groups`.
+    ///
+    ///    If a `row` is found not equal, just add it to `remaining_row_indices`,
+    ///    and perform `scalarized_intern_remaining` for it afterwards.
+    ///    Usually, such `rows` having the same hash but a different value from the
+    ///    `existing rows` are very few.
+ fn vectorized_equal_to(&mut self, cols: &[ArrayRef], groups: &mut [usize])
{
+ assert_eq!(
+ self.vectorized_operation_buffers
+ .equal_to_group_indices
+ .len(),
+ self.vectorized_operation_buffers.equal_to_row_indices.len()
+ );
+
+ self.vectorized_operation_buffers
+ .remaining_row_indices
+ .clear();
+
+ if self
+ .vectorized_operation_buffers
+ .equal_to_group_indices
+ .is_empty()
+ {
+ return;
+ }
+
+        // 1. Perform `vectorized_equal_to` for the `rows` in `vectorized_equal_to_row_indices`
+        //    and the `group_indices` in `vectorized_equal_to_group_indices`
+ let mut equal_to_results =
+ mem::take(&mut self.vectorized_operation_buffers.equal_to_results);
+ equal_to_results.clear();
+ equal_to_results.resize(
+ self.vectorized_operation_buffers
+ .equal_to_group_indices
+ .len(),
+ true,
+ );
+
+ for (col_idx, group_col) in self.group_values.iter().enumerate() {
+ group_col.vectorized_equal_to(
+ &self.vectorized_operation_buffers.equal_to_group_indices,
+ &cols[col_idx],
+ &self.vectorized_operation_buffers.equal_to_row_indices,
+ &mut equal_to_results,
+ );
+ }
+
+        // 2. Check `equal_to_results`; if a `row` is found not equal, just add it
+        //    to `remaining_row_indices`, and perform `scalarized_intern_remaining` for it after.
+ let mut current_row_equal_to_result = false;
+ for (idx, &row) in self
+ .vectorized_operation_buffers
+ .equal_to_row_indices
.iter()
- .map(|f| f.data_type())
- .all(Self::supported_type)
+ .enumerate()
+ {
+ let equal_to_result = equal_to_results[idx];
+
+ // Equal to case, set the `group_indices` to `rows` in `groups`
+ if equal_to_result {
+ groups[row] =
+
self.vectorized_operation_buffers.equal_to_group_indices[idx];
+ }
+ current_row_equal_to_result |= equal_to_result;
+
+            // Look ahead to the next row to check whether we have checked all results
+            // of the current row
+ let next_row = self
+ .vectorized_operation_buffers
+ .equal_to_row_indices
+ .get(idx + 1)
+ .unwrap_or(&usize::MAX);
+
+ // Have checked all results of current row, check the total result
+ if row != *next_row {
+ // Not equal to case, add `row` to `scalarized_indices`
+ if !current_row_equal_to_result {
+ self.vectorized_operation_buffers
+ .remaining_row_indices
+ .push(row);
+ }
+
+ // Init the total result for checking next row
+ current_row_equal_to_result = false;
+ }
+ }
+
+ self.vectorized_operation_buffers.equal_to_results = equal_to_results;
}
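The per-row bookkeeping above is the subtle part: results for the same row, possibly checked against several candidate groups, are OR-ed together, and a row with no match at all falls into the remaining list. A reduced, self-contained sketch of just that loop (invented toy data, not the commit's types):

```rust
fn main() {
    // Candidate pairs produced by the collect step:
    // row 0 is checked against groups 3 and 5, row 2 against group 7.
    let equal_to_row_indices = [0usize, 0, 2];
    let equal_to_group_indices = [3usize, 5, 7];
    // Combined result of the per-column equal-to checks for each pair.
    let equal_to_results = [false, true, false];

    let mut groups = vec![usize::MAX; 3];
    let mut remaining_row_indices = Vec::new();

    let mut current_row_matched = false;
    for (idx, &row) in equal_to_row_indices.iter().enumerate() {
        if equal_to_results[idx] {
            groups[row] = equal_to_group_indices[idx];
        }
        current_row_matched |= equal_to_results[idx];

        // Peek at the next pair: if it belongs to a different row,
        // we have seen every candidate for the current row.
        let next_row = equal_to_row_indices.get(idx + 1).unwrap_or(&usize::MAX);
        if row != *next_row {
            if !current_row_matched {
                remaining_row_indices.push(row);
            }
            current_row_matched = false;
        }
    }

    assert_eq!(groups[0], 5); // row 0 matched its second candidate
    assert_eq!(remaining_row_indices, vec![2]); // row 2 matched nothing
}
```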
- /// Returns true if the specified data type is supported by
[`GroupValuesColumn`]
- ///
- /// In order to be supported, there must be a specialized implementation of
- /// [`GroupColumn`] for the data type, instantiated in [`Self::intern`]
- fn supported_type(data_type: &DataType) -> bool {
- matches!(
- *data_type,
- DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Float32
- | DataType::Float64
- | DataType::Utf8
- | DataType::LargeUtf8
- | DataType::Binary
- | DataType::LargeBinary
- | DataType::Date32
- | DataType::Date64
- | DataType::Utf8View
- | DataType::BinaryView
- )
+    /// It is possible that some `input rows` have the same
+    /// hash values as the `existing rows`, but actually have
+    /// different values.
+    ///
+    /// We can find them in `vectorized_equal_to`, and put them
+    /// into `remaining_row_indices`. For these `input rows`,
+    /// we will perform a scalarized intern similar to `scalarized_intern`
+    /// in [`GroupValuesColumn`].
+ ///
+ /// This design can make the process simple and still efficient enough:
+ ///
+ /// # About making the process simple
+ ///
+    /// Some corner cases become really easy to solve, like the following:
+ ///
+ /// ```text
+ /// input row1 (same hash value with exist rows, but value different)
+ /// input row1
+ /// ...
+ /// input row1
+ /// ```
+ ///
+    /// After performing `vectorized_equal_to`, we will find multiple `input rows`
+    /// not equal to the `existing rows`. However, such `input rows` are repeated, and only
+    /// one new group should be created for them.
+    ///
+    /// If we don't fall back to a scalarized intern, it is really hard for us to
+    /// distinguish such `repeated rows` in the `input rows`. And if we just fall back,
+    /// it is really easy to solve, and the performance is at least not worse than before.
+ ///
+ /// # About performance
+ ///
+    /// Hash collisions may be infrequent, so the fallback should rarely happen.
+    /// In most situations, `remaining_row_indices` will be found to be empty after
+    /// performing `vectorized_equal_to`.
+ ///
+ fn scalarized_intern_remaining(
+ &mut self,
+ cols: &[ArrayRef],
+ batch_hashes: &[u64],
+ groups: &mut [usize],
+ ) {
+ if self
+ .vectorized_operation_buffers
+ .remaining_row_indices
+ .is_empty()
+ {
+ return;
+ }
+
+ let mut map = mem::take(&mut self.map);
+
+ for &row in &self.vectorized_operation_buffers.remaining_row_indices {
+ let target_hash = batch_hashes[row];
+ let entry = map.get_mut(target_hash, |(exist_hash, _)| {
+                // Somewhat surprisingly, this closure can be called even if the
+                // hash doesn't match, so check the hash first with an integer
+                // comparison to avoid the more expensive comparison with the
+                // group value. https://github.com/apache/datafusion/pull/11718
+ target_hash == *exist_hash
+ });
+
+            // Only `rows` having the same hash value as `existing rows` but a different value
+            // will be processed in `scalarized_intern_remaining`.
+            // So the related `buckets` in `map` are ensured to be `Some`.
+ let Some((_, group_index_view)) = entry else {
+ unreachable!()
+ };
+
+ // Perform scalarized equal to
+            if self.scalarized_equal_to_remaining(group_index_view, cols, row, groups) {
+ // Found the row actually exists in group values,
+ // don't need to create new group for it.
+ continue;
+ }
+
+ // Insert the `row` to `group_values` before checking `next row`
+ let group_idx = self.group_values[0].len();
+ let mut checklen = 0;
+ for (i, group_value) in self.group_values.iter_mut().enumerate() {
+ group_value.append_val(&cols[i], row);
+ let len = group_value.len();
+ if i == 0 {
+ checklen = len;
+ } else {
+ debug_assert_eq!(checklen, len);
+ }
+ }
+
+ // Check if the `view` is `inlined` or `non-inlined`
+ if group_index_view.is_non_inlined() {
+                // Non-inlined case, get `group_index_list` from `group_index_lists`,
+                // then add the new `group` with the same hash value into it.
+                let list_offset = group_index_view.value() as usize;
+                let group_index_list = &mut self.group_index_lists[list_offset];
+ group_index_list.push(group_idx);
+ } else {
+ // Inlined case
+ let list_offset = self.group_index_lists.len();
+
+                // Create a new `group_index_list` including
+                // the `existing group index` + the `new group index`.
+                // Add the new `group_index_list` into `group_index_lists`.
+ let exist_group_index = group_index_view.value() as usize;
+ let new_group_index_list = vec![exist_group_index, group_idx];
+ self.group_index_lists.push(new_group_index_list);
+
+ // Update the `group_index_view` to non-inlined
+ let new_group_index_view =
+ GroupIndexView::new_non_inlined(list_offset as u64);
+ *group_index_view = new_group_index_view;
+ }
+
+ groups[row] = group_idx;
+ }
+
+ self.map = map;
+ }
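The promotion from an `inlined` to a `non-inlined` view when a colliding new group is created can be summarized by this schematic helper (std types only; `View` and `chain_new_group` are illustrative names, not the commit's API):

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
enum View {
    Inlined(usize),
    NonInlined(usize), // offset into `group_index_lists`
}

/// Record `new_group_idx` under a bucket that currently holds `view`.
fn chain_new_group(
    view: &mut View,
    group_index_lists: &mut Vec<Vec<usize>>,
    new_group_idx: usize,
) {
    match *view {
        // Already chained: just push the new index onto the existing list.
        View::NonInlined(offset) => group_index_lists[offset].push(new_group_idx),
        // Inlined: build a fresh list holding the old and the new index,
        // then switch the view to point at it.
        View::Inlined(existing) => {
            group_index_lists.push(vec![existing, new_group_idx]);
            *view = View::NonInlined(group_index_lists.len() - 1);
        }
    }
}

fn main() {
    let mut lists: Vec<Vec<usize>> = Vec::new();
    let mut view = View::Inlined(3);

    chain_new_group(&mut view, &mut lists, 9);
    assert_eq!(view, View::NonInlined(0));
    assert_eq!(lists[0], vec![3, 9]);

    chain_new_group(&mut view, &mut lists, 12);
    assert_eq!(lists[0], vec![3, 9, 12]);
}
```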
+
+ fn scalarized_equal_to_remaining(
+ &self,
+ group_index_view: &GroupIndexView,
+ cols: &[ArrayRef],
+ row: usize,
+ groups: &mut [usize],
+ ) -> bool {
+ // Check if this row exists in `group_values`
+ fn check_row_equal(
+ array_row: &dyn GroupColumn,
+ lhs_row: usize,
+ array: &ArrayRef,
+ rhs_row: usize,
+ ) -> bool {
+ array_row.equal_to(lhs_row, array, rhs_row)
+ }
+
+ if group_index_view.is_non_inlined() {
+ let list_offset = group_index_view.value() as usize;
+ let group_index_list = &self.group_index_lists[list_offset];
+
+ for &group_idx in group_index_list {
+ let mut check_result = true;
+ for (i, group_val) in self.group_values.iter().enumerate() {
+ if !check_row_equal(group_val.as_ref(), group_idx,
&cols[i], row) {
+ check_result = false;
+ break;
+ }
+ }
+
+ if check_result {
+ groups[row] = group_idx;
+ return true;
+ }
+ }
+
+ // All groups unmatched, return false result
+ false
+ } else {
+ let group_idx = group_index_view.value() as usize;
+ for (i, group_val) in self.group_values.iter().enumerate() {
+ if !check_row_equal(group_val.as_ref(), group_idx, &cols[i],
row) {
+ return false;
+ }
+ }
+
+ groups[row] = group_idx;
+ true
+ }
+ }
+
+    /// Return the group indices for the given hash, along with its `group_index_view`
+    /// (which indicates whether it is non-inlined)
+ #[cfg(test)]
+    fn get_indices_by_hash(&self, hash: u64) -> Option<(Vec<usize>, GroupIndexView)> {
+ let entry = self.map.get(hash, |(exist_hash, _)| hash == *exist_hash);
+
+ match entry {
+ Some((_, group_index_view)) => {
+ if group_index_view.is_non_inlined() {
+ let list_offset = group_index_view.value() as usize;
+ Some((
+ self.group_index_lists[list_offset].clone(),
+ *group_index_view,
+ ))
+ } else {
+ let group_index = group_index_view.value() as usize;
+ Some((vec![group_index], *group_index_view))
+ }
+ }
+ None => None,
+ }
}
}
@@ -146,10 +817,8 @@ macro_rules! instantiate_primitive {
};
}
-impl GroupValues for GroupValuesColumn {
+impl<const STREAMING: bool> GroupValues for GroupValuesColumn<STREAMING> {
fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) ->
Result<()> {
- let n_rows = cols[0].len();
-
if self.group_values.is_empty() {
let mut v = Vec::with_capacity(cols.len());
@@ -204,77 +873,11 @@ impl GroupValues for GroupValuesColumn {
self.group_values = v;
}
- // tracks to which group each of the input rows belongs
- groups.clear();
-
- // 1.1 Calculate the group keys for the group values
- let batch_hashes = &mut self.hashes_buffer;
- batch_hashes.clear();
- batch_hashes.resize(n_rows, 0);
- create_hashes(cols, &self.random_state, batch_hashes)?;
-
- for (row, &target_hash) in batch_hashes.iter().enumerate() {
- let entry = self.map.get_mut(target_hash, |(exist_hash,
group_idx)| {
- // Somewhat surprisingly, this closure can be called even if
the
- // hash doesn't match, so check the hash first with an integer
- // comparison first avoid the more expensive comparison with
- // group value. https://github.com/apache/datafusion/pull/11718
- if target_hash != *exist_hash {
- return false;
- }
-
- fn check_row_equal(
- array_row: &dyn GroupColumn,
- lhs_row: usize,
- array: &ArrayRef,
- rhs_row: usize,
- ) -> bool {
- array_row.equal_to(lhs_row, array, rhs_row)
- }
-
- for (i, group_val) in self.group_values.iter().enumerate() {
- if !check_row_equal(group_val.as_ref(), *group_idx,
&cols[i], row) {
- return false;
- }
- }
-
- true
- });
-
- let group_idx = match entry {
- // Existing group_index for this group value
- Some((_hash, group_idx)) => *group_idx,
- // 1.2 Need to create new entry for the group
- None => {
- // Add new entry to aggr_state and save newly created index
- // let group_idx = group_values.num_rows();
- // group_values.push(group_rows.row(row));
-
- let mut checklen = 0;
- let group_idx = self.group_values[0].len();
- for (i, group_value) in
self.group_values.iter_mut().enumerate() {
- group_value.append_val(&cols[i], row);
- let len = group_value.len();
- if i == 0 {
- checklen = len;
- } else {
- debug_assert_eq!(checklen, len);
- }
- }
-
- // for hasher function, use precomputed hash value
- self.map.insert_accounted(
- (target_hash, group_idx),
- |(hash, _group_index)| *hash,
- &mut self.map_size,
- );
- group_idx
- }
- };
- groups.push(group_idx);
+ if !STREAMING {
+ self.vectorized_intern(cols, groups)
+ } else {
+ self.scalarized_intern(cols, groups)
}
-
- Ok(())
}
fn size(&self) -> usize {
@@ -297,7 +900,7 @@ impl GroupValues for GroupValuesColumn {
fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
let mut output = match emit_to {
EmitTo::All => {
- let group_values = std::mem::take(&mut self.group_values);
+ let group_values = mem::take(&mut self.group_values);
debug_assert!(self.group_values.is_empty());
group_values
@@ -311,20 +914,74 @@ impl GroupValues for GroupValuesColumn {
.iter_mut()
.map(|v| v.take_n(n))
.collect::<Vec<_>>();
+ let mut next_new_list_offset = 0;
// SAFETY: self.map outlives iterator and is not modified
concurrently
unsafe {
for bucket in self.map.iter() {
- // Decrement group index by n
- match bucket.as_ref().1.checked_sub(n) {
+                    // In the non-streaming case, we need to check if the `group index view`
+                    // is `inlined` or `non-inlined`
+ if !STREAMING && bucket.as_ref().1.is_non_inlined() {
+ // Non-inlined case
+                        // We take the `group_index_list` from `group_index_lists`
+
+                        // New list offsets are assigned incrementally
+ self.emit_group_index_list_buffer.clear();
+                        let list_offset = bucket.as_ref().1.value() as usize;
+                        for group_index in self.group_index_lists[list_offset].iter()
+                        {
+                            if let Some(remaining) = group_index.checked_sub(n) {
+                                self.emit_group_index_list_buffer.push(remaining);
+                            }
+                        }
+
+ // The possible results:
+                        //   - `new_group_index_list` is empty, we should erase this bucket
+                        //   - only one value in `new_group_index_list`, switch the `view` to `inlined`
+                        //   - still multiple values in `new_group_index_list`, build and set the new `non-inlined view`
+ if self.emit_group_index_list_buffer.is_empty() {
+ self.map.erase(bucket);
+                        } else if self.emit_group_index_list_buffer.len() == 1 {
+                            let group_index =
+                                self.emit_group_index_list_buffer.first().unwrap();
+                            bucket.as_mut().1 =
+                                GroupIndexView::new_inlined(*group_index as u64);
+                        } else {
+                            let group_index_list =
+                                &mut self.group_index_lists[next_new_list_offset];
+                            group_index_list.clear();
+                            group_index_list
+                                .extend(self.emit_group_index_list_buffer.iter());
+                            bucket.as_mut().1 = GroupIndexView::new_non_inlined(
+                                next_new_list_offset as u64,
+                            );
+ next_new_list_offset += 1;
+ }
+
+ continue;
+ }
+
+                    // In the `streaming case`, the `group index view` is ensured to be `inlined`
+ debug_assert!(!bucket.as_ref().1.is_non_inlined());
+
+                    // Inlined case, we just decrement the group index by n
+ let group_index = bucket.as_ref().1.value() as usize;
+ match group_index.checked_sub(n) {
// Group index was >= n, shift value down
- Some(sub) => bucket.as_mut().1 = sub,
+ Some(sub) => {
+ bucket.as_mut().1 =
+ GroupIndexView::new_inlined(sub as u64)
+ }
// Group index was < n, so remove from table
None => self.map.erase(bucket),
}
}
}
+ if !STREAMING {
+ self.group_index_lists.truncate(next_new_list_offset);
+ }
+
output
}
};
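The index arithmetic in the `EmitTo::First(n)` branch boils down to: subtract `n` from every surviving group index, drop the ones that were emitted, and shrink the bucket's view accordingly (erased, inlined, or still non-inlined). A toy model of that rule, detached from the hash table (the `View`/`adjust` names are invented for illustration):

```rust
#[derive(Debug, PartialEq)]
enum View {
    Inlined(usize),
    NonInlined(Vec<usize>), // list stored inline here for brevity
    Erased,
}

/// Shift a bucket's group indices down by `n` after the first `n` groups
/// have been emitted, mirroring the match in `emit` above.
fn adjust(view: View, n: usize) -> View {
    let survivors: Vec<usize> = match view {
        View::Inlined(idx) => idx.checked_sub(n).into_iter().collect(),
        View::NonInlined(list) => {
            list.into_iter().filter_map(|idx| idx.checked_sub(n)).collect()
        }
        View::Erased => vec![],
    };

    match survivors.len() {
        0 => View::Erased,                // everything emitted: erase the bucket
        1 => View::Inlined(survivors[0]), // one index left: switch back to inlined
        _ => View::NonInlined(survivors), // several left: stay non-inlined
    }
}

fn main() {
    assert_eq!(adjust(View::Inlined(7), 4), View::Inlined(3));
    assert_eq!(adjust(View::Inlined(2), 4), View::Erased);
    assert_eq!(adjust(View::NonInlined(vec![3, 5, 9]), 4), View::NonInlined(vec![1, 5]));
    assert_eq!(adjust(View::NonInlined(vec![3, 6]), 4), View::Inlined(2));
}
```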
@@ -354,5 +1011,610 @@ impl GroupValues for GroupValuesColumn {
self.map_size = self.map.capacity() * size_of::<(u64, usize)>();
self.hashes_buffer.clear();
self.hashes_buffer.shrink_to(count);
+
+ // Such structures are only used in `non-streaming` case
+ if !STREAMING {
+ self.group_index_lists.clear();
+ self.emit_group_index_list_buffer.clear();
+ self.vectorized_operation_buffers.clear();
+ }
+ }
+}
+
+/// Returns true if [`GroupValuesColumn`] is supported for the specified schema
+pub fn supported_schema(schema: &Schema) -> bool {
+ schema
+ .fields()
+ .iter()
+ .map(|f| f.data_type())
+ .all(supported_type)
+}
+
+/// Returns true if the specified data type is supported by [`GroupValuesColumn`]
+///
+/// In order to be supported, there must be a specialized implementation of
+/// [`GroupColumn`] for the data type, instantiated in [`GroupValuesColumn::intern`]
+fn supported_type(data_type: &DataType) -> bool {
+ matches!(
+ *data_type,
+ DataType::Int8
+ | DataType::Int16
+ | DataType::Int32
+ | DataType::Int64
+ | DataType::UInt8
+ | DataType::UInt16
+ | DataType::UInt32
+ | DataType::UInt64
+ | DataType::Float32
+ | DataType::Float64
+ | DataType::Utf8
+ | DataType::LargeUtf8
+ | DataType::Binary
+ | DataType::LargeBinary
+ | DataType::Date32
+ | DataType::Date64
+ | DataType::Utf8View
+ | DataType::BinaryView
+ )
+}
+
+#[cfg(test)]
+mod tests {
+ use std::{collections::HashMap, sync::Arc};
+
+ use arrow::{compute::concat_batches, util::pretty::pretty_format_batches};
+ use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray,
StringViewArray};
+ use arrow_schema::{DataType, Field, Schema, SchemaRef};
+ use datafusion_common::utils::proxy::RawTableAllocExt;
+ use datafusion_expr::EmitTo;
+
+ use crate::aggregates::group_values::{column::GroupValuesColumn,
GroupValues};
+
+ use super::GroupIndexView;
+
+ #[test]
+ fn test_intern_for_vectorized_group_values() {
+ let data_set = VectorizedTestDataSet::new();
+ let mut group_values =
+ GroupValuesColumn::<false>::try_new(data_set.schema()).unwrap();
+
+ data_set.load_to_group_values(&mut group_values);
+ let actual_batch = group_values.emit(EmitTo::All).unwrap();
+ let actual_batch = RecordBatch::try_new(data_set.schema(),
actual_batch).unwrap();
+
+ check_result(&actual_batch, &data_set.expected_batch);
+ }
+
+ #[test]
+ fn test_emit_first_n_for_vectorized_group_values() {
+ let data_set = VectorizedTestDataSet::new();
+ let mut group_values =
+ GroupValuesColumn::<false>::try_new(data_set.schema()).unwrap();
+
+ // 1~num_rows times to emit the groups
+ let num_rows = data_set.expected_batch.num_rows();
+ let schema = data_set.schema();
+ for times_to_take in 1..=num_rows {
+ // Write data after emitting
+ data_set.load_to_group_values(&mut group_values);
+
+ // Emit `times_to_take` times, collect and concat the sub-results
to total result,
+ // then check it
+ let suggest_num_emit = data_set.expected_batch.num_rows() /
times_to_take;
+ let mut num_remaining_rows = num_rows;
+ let mut actual_sub_batches = Vec::new();
+
+ for nth_time in 0..times_to_take {
+ let num_emit = if nth_time == times_to_take - 1 {
+ num_remaining_rows
+ } else {
+ suggest_num_emit
+ };
+
+ let sub_batch =
group_values.emit(EmitTo::First(num_emit)).unwrap();
+ let sub_batch =
+ RecordBatch::try_new(Arc::clone(&schema),
sub_batch).unwrap();
+ actual_sub_batches.push(sub_batch);
+
+ num_remaining_rows -= num_emit;
+ }
+ assert!(num_remaining_rows == 0);
+
+ let actual_batch = concat_batches(&schema,
&actual_sub_batches).unwrap();
+ check_result(&actual_batch, &data_set.expected_batch);
+ }
+ }
+
+ #[test]
+ fn test_hashtable_modifying_in_emit_first_n() {
+ // Situations should be covered:
+ // 1. Erase inlined group index view
+ // 2. Erase whole non-inlined group index view
+ // 3. Erase + decrease group indices in non-inlined group index view
+ // + view still non-inlined after decreasing
+ // 4. Erase + decrease group indices in non-inlined group index view
+ // + view switch to inlined after decreasing
+ // 5. Only decrease group index in inlined group index view
+ // 6. Only decrease group indices in non-inlined group index view
+ // 7. Erase all things
+
+ let field = Field::new("item", DataType::Int32, true);
+ let schema = Arc::new(Schema::new_with_metadata(vec![field],
HashMap::new()));
+ let mut group_values =
GroupValuesColumn::<false>::try_new(schema).unwrap();
+
+ // Insert group index views and check if success to insert
+ insert_inline_group_index_view(&mut group_values, 0, 0);
+ insert_non_inline_group_index_view(&mut group_values, 1, vec![1, 2]);
+ insert_non_inline_group_index_view(&mut group_values, 2, vec![3, 4,
5]);
+ insert_inline_group_index_view(&mut group_values, 3, 6);
+ insert_non_inline_group_index_view(&mut group_values, 4, vec![7, 8]);
+ insert_non_inline_group_index_view(&mut group_values, 5, vec![9, 10,
11]);
+
+ assert_eq!(
+ group_values.get_indices_by_hash(0).unwrap(),
+ (vec![0], GroupIndexView::new_inlined(0))
+ );
+ assert_eq!(
+ group_values.get_indices_by_hash(1).unwrap(),
+ (vec![1, 2], GroupIndexView::new_non_inlined(0))
+ );
+ assert_eq!(
+ group_values.get_indices_by_hash(2).unwrap(),
+ (vec![3, 4, 5], GroupIndexView::new_non_inlined(1))
+ );
+ assert_eq!(
+ group_values.get_indices_by_hash(3).unwrap(),
+ (vec![6], GroupIndexView::new_inlined(6))
+ );
+ assert_eq!(
+ group_values.get_indices_by_hash(4).unwrap(),
+ (vec![7, 8], GroupIndexView::new_non_inlined(2))
+ );
+ assert_eq!(
+ group_values.get_indices_by_hash(5).unwrap(),
+ (vec![9, 10, 11], GroupIndexView::new_non_inlined(3))
+ );
+ assert_eq!(group_values.map.len(), 6);
+
+ // Emit first 4 to test cases 1~3, 5~6
+ let _ = group_values.emit(EmitTo::First(4)).unwrap();
+ assert!(group_values.get_indices_by_hash(0).is_none());
+ assert!(group_values.get_indices_by_hash(1).is_none());
+ assert_eq!(
+ group_values.get_indices_by_hash(2).unwrap(),
+ (vec![0, 1], GroupIndexView::new_non_inlined(0))
+ );
+ assert_eq!(
+ group_values.get_indices_by_hash(3).unwrap(),
+ (vec![2], GroupIndexView::new_inlined(2))
+ );
+ assert_eq!(
+ group_values.get_indices_by_hash(4).unwrap(),
+ (vec![3, 4], GroupIndexView::new_non_inlined(1))
+ );
+ assert_eq!(
+ group_values.get_indices_by_hash(5).unwrap(),
+ (vec![5, 6, 7], GroupIndexView::new_non_inlined(2))
+ );
+ assert_eq!(group_values.map.len(), 4);
+
+ // Emit first 1 to test case 4, and cases 5~6 again
+ let _ = group_values.emit(EmitTo::First(1)).unwrap();
+ assert_eq!(
+ group_values.get_indices_by_hash(2).unwrap(),
+ (vec![0], GroupIndexView::new_inlined(0))
+ );
+ assert_eq!(
+ group_values.get_indices_by_hash(3).unwrap(),
+ (vec![1], GroupIndexView::new_inlined(1))
+ );
+ assert_eq!(
+ group_values.get_indices_by_hash(4).unwrap(),
+ (vec![2, 3], GroupIndexView::new_non_inlined(0))
+ );
+ assert_eq!(
+ group_values.get_indices_by_hash(5).unwrap(),
+ (vec![4, 5, 6], GroupIndexView::new_non_inlined(1))
+ );
+ assert_eq!(group_values.map.len(), 4);
+
+ // Emit first 5 to test cases 1~3 again
+ let _ = group_values.emit(EmitTo::First(5)).unwrap();
+ assert_eq!(
+ group_values.get_indices_by_hash(5).unwrap(),
+ (vec![0, 1], GroupIndexView::new_non_inlined(0))
+ );
+ assert_eq!(group_values.map.len(), 1);
+
+ // Emit first 1 to test cases 4 again
+ let _ = group_values.emit(EmitTo::First(1)).unwrap();
+ assert_eq!(
+ group_values.get_indices_by_hash(5).unwrap(),
+ (vec![0], GroupIndexView::new_inlined(0))
+ );
+ assert_eq!(group_values.map.len(), 1);
+
+ // Emit first 1 to test cases 7
+ let _ = group_values.emit(EmitTo::First(1)).unwrap();
+ assert!(group_values.map.is_empty());
+ }
+
+ /// Test data set for [`GroupValuesColumn::vectorized_intern`]
+ ///
+    /// Define the test data and support loading it into
+    /// [`GroupValuesColumn::vectorized_intern`] for testing
+ ///
+    /// The situations covered:
+ ///
+ /// Array type:
+ /// - Primitive array
+ /// - String(byte) array
+ /// - String view(byte view) array
+ ///
+    /// Repetition and nullability in a single batch:
+    ///   - All not null rows
+    ///   - Mixed null + not null rows
+    ///   - All null rows
+    ///   - All not null rows (repeated)
+    ///   - Null + not null rows (repeated)
+    ///   - All null rows (repeated)
+ ///
+ /// If group exists in `map`:
+ /// - Group exists in inlined group view
+ /// - Group exists in non-inlined group view
+ /// - Group not exist + bucket not found in `map`
+    ///   - Group not exist + not equal to inlined group view (tested in hash collision)
+    ///   - Group not exist + not equal to non-inlined group view (tested in hash collision)
+ ///
+ struct VectorizedTestDataSet {
+ test_batches: Vec<Vec<ArrayRef>>,
+ expected_batch: RecordBatch,
+ }
+
+ impl VectorizedTestDataSet {
+ fn new() -> Self {
+ // Intern batch 1
+ let col1 = Int64Array::from(vec![
+ // Repeated rows in batch
+ Some(42), // all not nulls + repeated rows + exist in map
case
+ None, // mixed + repeated rows + exist in map case
+ None, // mixed + repeated rows + not exist in map case
+ Some(1142), // mixed + repeated rows + not exist in map case
+ None, // all nulls + repeated rows + exist in map case
+ Some(42),
+ None,
+ None,
+ Some(1142),
+ None,
+ // Unique rows in batch
+ Some(4211), // all not nulls + unique rows + exist in map case
+ None, // mixed + unique rows + exist in map case
+ None, // mixed + unique rows + not exist in map case
+ Some(4212), // mixed + unique rows + not exist in map case
+ ]);
+
+ let col2 = StringArray::from(vec![
+ // Repeated rows in batch
+ Some("string1"), // all not nulls + repeated rows + exist in
map case
+ None, // mixed + repeated rows + exist in map case
+ Some("string2"), // mixed + repeated rows + not exist in map
case
+ None, // mixed + repeated rows + not exist in map
case
+ None, // all nulls + repeated rows + exist in map
case
+ Some("string1"),
+ None,
+ Some("string2"),
+ None,
+ None,
+ // Unique rows in batch
+ Some("string3"), // all not nulls + unique rows + exist in map
case
+ None, // mixed + unique rows + exist in map case
+ Some("string4"), // mixed + unique rows + not exist in map case
+ None, // mixed + unique rows + not exist in map case
+ ]);
+
+ let col3 = StringViewArray::from(vec![
+ // Repeated rows in batch
+ Some("stringview1"), // all not nulls + repeated rows + exist
in map case
+ Some("stringview2"), // mixed + repeated rows + exist in map
case
+ None, // mixed + repeated rows + not exist in
map case
+ None, // mixed + repeated rows + not exist in
map case
+ None, // all nulls + repeated rows + exist in
map case
+ Some("stringview1"),
+ Some("stringview2"),
+ None,
+ None,
+ None,
+ // Unique rows in batch
+ Some("stringview3"), // all not nulls + unique rows + exist in
map case
+ Some("stringview4"), // mixed + unique rows + exist in map case
+ None, // mixed + unique rows + not exist in map
case
+ None, // mixed + unique rows + not exist in map
case
+ ]);
+ let batch1 = vec![
+ Arc::new(col1) as _,
+ Arc::new(col2) as _,
+ Arc::new(col3) as _,
+ ];
+
+ // Intern batch 2
+ let col1 = Int64Array::from(vec![
+ // Repeated rows in batch
+ Some(42), // all not nulls + repeated rows + exist in map
case
+ None, // mixed + repeated rows + exist in map case
+ None, // mixed + repeated rows + not exist in map case
+ Some(21142), // mixed + repeated rows + not exist in map case
+ None, // all nulls + repeated rows + exist in map case
+ Some(42),
+ None,
+ None,
+ Some(21142),
+ None,
+ // Unique rows in batch
+ Some(4211), // all not nulls + unique rows + exist in map case
+ None, // mixed + unique rows + exist in map case
+ None, // mixed + unique rows + not exist in map case
+ Some(24212), // mixed + unique rows + not exist in map case
+ ]);
+
+ let col2 = StringArray::from(vec![
+ // Repeated rows in batch
+ Some("string1"), // all not nulls + repeated rows + exist in
map case
+ None, // mixed + repeated rows + exist in map case
+ Some("2string2"), // mixed + repeated rows + not exist in map
case
+ None, // mixed + repeated rows + not exist in map
case
+ None, // all nulls + repeated rows + exist in map
case
+ Some("string1"),
+ None,
+ Some("2string2"),
+ None,
+ None,
+ // Unique rows in batch
+ Some("string3"), // all not nulls + unique rows + exist in map
case
+ None, // mixed + unique rows + exist in map case
+ Some("2string4"), // mixed + unique rows + not exist in map
case
+ None, // mixed + unique rows + not exist in map case
+ ]);
+
+ let col3 = StringViewArray::from(vec![
+ // Repeated rows in batch
+ Some("stringview1"), // all not nulls + repeated rows + exist
in map case
+ Some("stringview2"), // mixed + repeated rows + exist in map
case
+ None, // mixed + repeated rows + not exist in
map case
+ None, // mixed + repeated rows + not exist in
map case
+ None, // all nulls + repeated rows + exist in
map case
+ Some("stringview1"),
+ Some("stringview2"),
+ None,
+ None,
+ None,
+ // Unique rows in batch
+ Some("stringview3"), // all not nulls + unique rows + exist in
map case
+ Some("stringview4"), // mixed + unique rows + exist in map case
+ None, // mixed + unique rows + not exist in map
case
+ None, // mixed + unique rows + not exist in map
case
+ ]);
+ let batch2 = vec![
+ Arc::new(col1) as _,
+ Arc::new(col2) as _,
+ Arc::new(col3) as _,
+ ];
+
+ // Intern batch 3
+ let col1 = Int64Array::from(vec![
+ // Repeated rows in batch
+ Some(42), // all not nulls + repeated rows + exist in map
case
+ None, // mixed + repeated rows + exist in map case
+ None, // mixed + repeated rows + not exist in map case
+ Some(31142), // mixed + repeated rows + not exist in map case
+ None, // all nulls + repeated rows + exist in map case
+ Some(42),
+ None,
+ None,
+ Some(31142),
+ None,
+ // Unique rows in batch
+ Some(4211), // all not nulls + unique rows + exist in map case
+ None, // mixed + unique rows + exist in map case
+ None, // mixed + unique rows + not exist in map case
+ Some(34212), // mixed + unique rows + not exist in map case
+ ]);
+
+ let col2 = StringArray::from(vec![
+ // Repeated rows in batch
+ Some("string1"), // all not nulls + repeated rows + exist in
map case
+ None, // mixed + repeated rows + exist in map case
+ Some("3string2"), // mixed + repeated rows + not exist in map
case
+ None, // mixed + repeated rows + not exist in map
case
+ None, // all nulls + repeated rows + exist in map
case
+ Some("string1"),
+ None,
+ Some("3string2"),
+ None,
+ None,
+ // Unique rows in batch
+ Some("string3"), // all not nulls + unique rows + exist in map
case
+ None, // mixed + unique rows + exist in map case
+ Some("3string4"), // mixed + unique rows + not exist in map
case
+ None, // mixed + unique rows + not exist in map case
+ ]);
+
+ let col3 = StringViewArray::from(vec![
+ // Repeated rows in batch
+ Some("stringview1"), // all not nulls + repeated rows + exist
in map case
+ Some("stringview2"), // mixed + repeated rows + exist in map
case
+ None, // mixed + repeated rows + not exist in
map case
+ None, // mixed + repeated rows + not exist in
map case
+ None, // all nulls + repeated rows + exist in
map case
+ Some("stringview1"),
+ Some("stringview2"),
+ None,
+ None,
+ None,
+ // Unique rows in batch
+ Some("stringview3"), // all not nulls + unique rows + exist in
map case
+ Some("stringview4"), // mixed + unique rows + exist in map case
+ None, // mixed + unique rows + not exist in map
case
+ None, // mixed + unique rows + not exist in map
case
+ ]);
+ let batch3 = vec![
+ Arc::new(col1) as _,
+ Arc::new(col2) as _,
+ Arc::new(col3) as _,
+ ];
+
+ // Expected batch
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("a", DataType::Int64, true),
+ Field::new("b", DataType::Utf8, true),
+ Field::new("c", DataType::Utf8View, true),
+ ]));
+
+ let col1 = Int64Array::from(vec![
+ // Repeated rows in batch
+ Some(42),
+ None,
+ None,
+ Some(1142),
+ None,
+ Some(21142),
+ None,
+ Some(31142),
+ None,
+ // Unique rows in batch
+ Some(4211),
+ None,
+ None,
+ Some(4212),
+ None,
+ Some(24212),
+ None,
+ Some(34212),
+ ]);
+
+ let col2 = StringArray::from(vec![
+ // Repeated rows in batch
+ Some("string1"),
+ None,
+ Some("string2"),
+ None,
+ Some("2string2"),
+ None,
+ Some("3string2"),
+ None,
+ None,
+ // Unique rows in batch
+ Some("string3"),
+ None,
+ Some("string4"),
+ None,
+ Some("2string4"),
+ None,
+ Some("3string4"),
+ None,
+ ]);
+
+ let col3 = StringViewArray::from(vec![
+ // Repeated rows in batch
+ Some("stringview1"),
+ Some("stringview2"),
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ // Unique rows in batch
+ Some("stringview3"),
+ Some("stringview4"),
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ ]);
+ let expected_batch = vec![
+ Arc::new(col1) as _,
+ Arc::new(col2) as _,
+ Arc::new(col3) as _,
+ ];
+ let expected_batch = RecordBatch::try_new(schema,
expected_batch).unwrap();
+
+ Self {
+ test_batches: vec![batch1, batch2, batch3],
+ expected_batch,
+ }
+ }
+
+ fn load_to_group_values(&self, group_values: &mut impl GroupValues) {
+ for batch in self.test_batches.iter() {
+ group_values.intern(batch, &mut vec![]).unwrap();
+ }
+ }
+
+ fn schema(&self) -> SchemaRef {
+ self.expected_batch.schema()
+ }
+ }
+
+ fn check_result(actual_batch: &RecordBatch, expected_batch: &RecordBatch) {
+ let formatted_actual_batch =
pretty_format_batches(&[actual_batch.clone()])
+ .unwrap()
+ .to_string();
+ let mut formatted_actual_batch_sorted: Vec<&str> =
+ formatted_actual_batch.trim().lines().collect();
+ formatted_actual_batch_sorted.sort_unstable();
+
+ let formatted_expected_batch =
pretty_format_batches(&[expected_batch.clone()])
+ .unwrap()
+ .to_string();
+ let mut formatted_expected_batch_sorted: Vec<&str> =
+ formatted_expected_batch.trim().lines().collect();
+ formatted_expected_batch_sorted.sort_unstable();
+
+ for (i, (actual_line, expected_line)) in formatted_actual_batch_sorted
+ .iter()
+ .zip(&formatted_expected_batch_sorted)
+ .enumerate()
+ {
+ assert_eq!(
+ (i, actual_line),
+ (i, expected_line),
+ "Inconsistent result\n\n\
+ Actual batch:\n{}\n\
+ Expected batch:\n{}\n\
+ ",
+ formatted_actual_batch,
+ formatted_expected_batch,
+ );
+ }
+ }
+
+ fn insert_inline_group_index_view(
+ group_values: &mut GroupValuesColumn<false>,
+ hash_key: u64,
+ group_index: u64,
+ ) {
+ let group_index_view = GroupIndexView::new_inlined(group_index);
+ group_values.map.insert_accounted(
+ (hash_key, group_index_view),
+ |(hash, _)| *hash,
+ &mut group_values.map_size,
+ );
+ }
+
+ fn insert_non_inline_group_index_view(
+ group_values: &mut GroupValuesColumn<false>,
+ hash_key: u64,
+ group_indices: Vec<usize>,
+ ) {
+ let list_offset = group_values.group_index_lists.len();
+ let group_index_view = GroupIndexView::new_non_inlined(list_offset as
u64);
+ group_values.group_index_lists.push(group_indices);
+ group_values.map.insert_accounted(
+ (hash_key, group_index_view),
+ |(hash, _)| *hash,
+ &mut group_values.map_size,
+ );
}
}
diff --git
a/datafusion/physical-plan/src/aggregates/group_values/group_column.rs
b/datafusion/physical-plan/src/aggregates/group_values/group_column.rs
index bba59b6d0c..1f59c617d8 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/group_column.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/group_column.rs
@@ -29,13 +29,16 @@ use arrow::datatypes::ByteArrayType;
use arrow::datatypes::ByteViewType;
use arrow::datatypes::DataType;
use arrow::datatypes::GenericBinaryType;
+use arrow_array::GenericByteArray;
use arrow_array::GenericByteViewArray;
use arrow_buffer::Buffer;
use datafusion_common::utils::proxy::VecAllocExt;
+use itertools::izip;
use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder;
use arrow_array::types::GenericStringType;
use datafusion_physical_expr_common::binary_map::{OutputType,
INITIAL_BUFFER_CAPACITY};
+use std::iter;
use std::marker::PhantomData;
use std::mem::{replace, size_of};
use std::sync::Arc;
@@ -56,14 +59,40 @@ pub trait GroupColumn: Send + Sync {
///
/// Note that this comparison returns true if both elements are NULL
fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) ->
bool;
+
/// Appends the row at `row` in `array` to this builder
fn append_val(&mut self, array: &ArrayRef, row: usize);
+
+    /// The vectorized version of `equal_to`
+    ///
+    /// When the nth row stored in this builder at `lhs_row`
+    /// is found equal to the row in `array` at `rhs_row`,
+    /// the `true` result is recorded at the corresponding
+    /// position in `equal_to_results`.
+    ///
+    /// And if the nth result in `equal_to_results` is already
+    /// `false`, the check for the nth row will be skipped.
+ ///
+ fn vectorized_equal_to(
+ &self,
+ lhs_rows: &[usize],
+ array: &ArrayRef,
+ rhs_rows: &[usize],
+ equal_to_results: &mut [bool],
+ );
+
+    /// The vectorized version of `append_val`
+ fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]);
+
/// Returns the number of rows stored in this builder
fn len(&self) -> usize;
+
/// Returns the number of bytes used by this [`GroupColumn`]
fn size(&self) -> usize;
+
/// Builds a new array from all of the stored rows
fn build(self: Box<Self>) -> ArrayRef;
+
/// Builds a new array from the first `n` stored rows, shifting the
/// remaining rows to the start of the builder
fn take_n(&mut self, n: usize) -> ArrayRef;
@@ -128,6 +157,89 @@ impl<T: ArrowPrimitiveType, const NULLABLE: bool>
GroupColumn
}
}
+ fn vectorized_equal_to(
+ &self,
+ lhs_rows: &[usize],
+ array: &ArrayRef,
+ rhs_rows: &[usize],
+ equal_to_results: &mut [bool],
+ ) {
+ let array = array.as_primitive::<T>();
+
+ let iter = izip!(
+ lhs_rows.iter(),
+ rhs_rows.iter(),
+ equal_to_results.iter_mut(),
+ );
+
+ for (&lhs_row, &rhs_row, equal_to_result) in iter {
+ // Has found not equal to in previous column, don't need to check
+ if !*equal_to_result {
+ continue;
+ }
+
+            // Perf: skip null check (by short circuit) if input is not nullable
+ if NULLABLE {
+ let exist_null = self.nulls.is_null(lhs_row);
+ let input_null = array.is_null(rhs_row);
+ if let Some(result) = nulls_equal_to(exist_null, input_null) {
+ *equal_to_result = result;
+ continue;
+ }
+ // Otherwise, we need to check their values
+ }
+
+            *equal_to_result = self.group_values[lhs_row] == array.value(rhs_row);
+ }
+ }
+
+ fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]) {
+ let arr = array.as_primitive::<T>();
+
+ let null_count = array.null_count();
+ let num_rows = array.len();
+ let all_null_or_non_null = if null_count == 0 {
+ Some(true)
+ } else if null_count == num_rows {
+ Some(false)
+ } else {
+ None
+ };
+
+ match (NULLABLE, all_null_or_non_null) {
+ (true, None) => {
+ for &row in rows {
+ if array.is_null(row) {
+ self.nulls.append(true);
+ self.group_values.push(T::default_value());
+ } else {
+ self.nulls.append(false);
+ self.group_values.push(arr.value(row));
+ }
+ }
+ }
+
+ (true, Some(true)) => {
+ self.nulls.append_n(rows.len(), false);
+ for &row in rows {
+ self.group_values.push(arr.value(row));
+ }
+ }
+
+ (true, Some(false)) => {
+ self.nulls.append_n(rows.len(), true);
+ self.group_values
+ .extend(iter::repeat(T::default_value()).take(rows.len()));
+ }
+
+ (false, _) => {
+ for &row in rows {
+ self.group_values.push(arr.value(row));
+ }
+ }
+ }
+ }
+
fn len(&self) -> usize {
self.group_values.len()
}
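
The `vectorized_append` above first classifies the input as all valid, all null, or mixed via `null_count`, and only the mixed case pays a per-row null check. A simplified, Arrow-free sketch of that dispatch (with `Vec<Option<i64>>` standing in for the input array and plain `Vec`s for the builder state) is:

    /// Append the selected `rows` of `input` into `values`/`nulls`
    /// (`true` in `nulls` marks a null slot), specializing the
    /// all-valid and all-null cases.
    fn append_rows(
        values: &mut Vec<i64>,
        nulls: &mut Vec<bool>,
        input: &[Option<i64>],
        rows: &[usize],
    ) {
        let null_count = input.iter().filter(|v| v.is_none()).count();
        let all_valid_or_all_null = if null_count == 0 {
            Some(true) // no nulls at all
        } else if null_count == input.len() {
            Some(false) // every row is null
        } else {
            None // mixed: must check per row
        };

        match all_valid_or_all_null {
            Some(true) => {
                nulls.extend(std::iter::repeat(false).take(rows.len()));
                values.extend(rows.iter().map(|&r| input[r].unwrap()));
            }
            Some(false) => {
                nulls.extend(std::iter::repeat(true).take(rows.len()));
                values.extend(std::iter::repeat(0).take(rows.len()));
            }
            None => {
                for &r in rows {
                    nulls.push(input[r].is_none());
                    values.push(input[r].unwrap_or(0));
                }
            }
        }
    }

In the real builder the same idea additionally short-circuits entirely when the column is statically non-nullable (the `NULLABLE` const generic).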
@@ -200,6 +312,14 @@ where
}
}
+    fn equal_to_inner<B>(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool
+ where
+ B: ByteArrayType,
+ {
+ let array = array.as_bytes::<B>();
+ self.do_equal_to_inner(lhs_row, array, rhs_row)
+ }
+
fn append_val_inner<B>(&mut self, array: &ArrayRef, row: usize)
where
B: ByteArrayType,
@@ -212,17 +332,93 @@ where
self.offsets.push(O::usize_as(offset));
} else {
self.nulls.append(false);
- let value: &[u8] = arr.value(row).as_ref();
- self.buffer.append_slice(value);
- self.offsets.push(O::usize_as(self.buffer.len()));
+ self.do_append_val_inner(arr, row);
}
}
-    fn equal_to_inner<B>(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool
- where
+ fn vectorized_equal_to_inner<B>(
+ &self,
+ lhs_rows: &[usize],
+ array: &ArrayRef,
+ rhs_rows: &[usize],
+ equal_to_results: &mut [bool],
+ ) where
B: ByteArrayType,
{
let array = array.as_bytes::<B>();
+
+ let iter = izip!(
+ lhs_rows.iter(),
+ rhs_rows.iter(),
+ equal_to_results.iter_mut(),
+ );
+
+ for (&lhs_row, &rhs_row, equal_to_result) in iter {
+            // Already marked not-equal, no need to check further
+ if !*equal_to_result {
+ continue;
+ }
+
+ *equal_to_result = self.do_equal_to_inner(lhs_row, array, rhs_row);
+ }
+ }
+
+ fn vectorized_append_inner<B>(&mut self, array: &ArrayRef, rows: &[usize])
+ where
+ B: ByteArrayType,
+ {
+ let arr = array.as_bytes::<B>();
+ let null_count = array.null_count();
+ let num_rows = array.len();
+ let all_null_or_non_null = if null_count == 0 {
+ Some(true)
+ } else if null_count == num_rows {
+ Some(false)
+ } else {
+ None
+ };
+
+ match all_null_or_non_null {
+ None => {
+ for &row in rows {
+ if arr.is_null(row) {
+ self.nulls.append(true);
+ // nulls need a zero length in the offset buffer
+ let offset = self.buffer.len();
+ self.offsets.push(O::usize_as(offset));
+ } else {
+ self.nulls.append(false);
+ self.do_append_val_inner(arr, row);
+ }
+ }
+ }
+
+ Some(true) => {
+ self.nulls.append_n(rows.len(), false);
+ for &row in rows {
+ self.do_append_val_inner(arr, row);
+ }
+ }
+
+ Some(false) => {
+ self.nulls.append_n(rows.len(), true);
+
+ let new_len = self.offsets.len() + rows.len();
+ let offset = self.buffer.len();
+ self.offsets.resize(new_len, O::usize_as(offset));
+ }
+ }
+ }
+
+ fn do_equal_to_inner<B>(
+ &self,
+ lhs_row: usize,
+ array: &GenericByteArray<B>,
+ rhs_row: usize,
+ ) -> bool
+ where
+ B: ByteArrayType,
+ {
let exist_null = self.nulls.is_null(lhs_row);
let input_null = array.is_null(rhs_row);
if let Some(result) = nulls_equal_to(exist_null, input_null) {
@@ -232,6 +428,15 @@ where
self.value(lhs_row) == (array.value(rhs_row).as_ref() as &[u8])
}
+    fn do_append_val_inner<B>(&mut self, array: &GenericByteArray<B>, row: usize)
+ where
+ B: ByteArrayType,
+ {
+ let value: &[u8] = array.value(row).as_ref();
+ self.buffer.append_slice(value);
+ self.offsets.push(O::usize_as(self.buffer.len()));
+ }
+
/// return the current value of the specified row irrespective of null
pub fn value(&self, row: usize) -> &[u8] {
let l = self.offsets[row].as_usize();
@@ -287,6 +492,63 @@ where
};
}
+ fn vectorized_equal_to(
+ &self,
+ lhs_rows: &[usize],
+ array: &ArrayRef,
+ rhs_rows: &[usize],
+ equal_to_results: &mut [bool],
+ ) {
+        // Sanity-check the array type
+ match self.output_type {
+ OutputType::Binary => {
+ debug_assert!(matches!(
+ array.data_type(),
+ DataType::Binary | DataType::LargeBinary
+ ));
+ self.vectorized_equal_to_inner::<GenericBinaryType<O>>(
+ lhs_rows,
+ array,
+ rhs_rows,
+ equal_to_results,
+ );
+ }
+ OutputType::Utf8 => {
+ debug_assert!(matches!(
+ array.data_type(),
+ DataType::Utf8 | DataType::LargeUtf8
+ ));
+ self.vectorized_equal_to_inner::<GenericStringType<O>>(
+ lhs_rows,
+ array,
+ rhs_rows,
+ equal_to_results,
+ );
+ }
+ _ => unreachable!("View types should use `ArrowBytesViewMap`"),
+ }
+ }
+
+ fn vectorized_append(&mut self, column: &ArrayRef, rows: &[usize]) {
+ match self.output_type {
+ OutputType::Binary => {
+ debug_assert!(matches!(
+ column.data_type(),
+ DataType::Binary | DataType::LargeBinary
+ ));
+                self.vectorized_append_inner::<GenericBinaryType<O>>(column, rows)
+ }
+ OutputType::Utf8 => {
+ debug_assert!(matches!(
+ column.data_type(),
+ DataType::Utf8 | DataType::LargeUtf8
+ ));
+                self.vectorized_append_inner::<GenericStringType<O>>(column, rows)
+ }
+ _ => unreachable!("View types should use `ArrowBytesViewMap`"),
+ };
+ }
+
fn len(&self) -> usize {
self.offsets.len() - 1
}
@@ -446,10 +708,12 @@ impl<B: ByteViewType> ByteViewGroupValueBuilder<B> {
self
}
- fn append_val_inner(&mut self, array: &ArrayRef, row: usize)
- where
- B: ByteViewType,
- {
+    fn equal_to_inner(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool {
+ let array = array.as_byte_view::<B>();
+ self.do_equal_to_inner(lhs_row, array, rhs_row)
+ }
+
+ fn append_val_inner(&mut self, array: &ArrayRef, row: usize) {
let arr = array.as_byte_view::<B>();
// Null row case, set and return
@@ -461,7 +725,80 @@ impl<B: ByteViewType> ByteViewGroupValueBuilder<B> {
// Not null row case
self.nulls.append(false);
- let value: &[u8] = arr.value(row).as_ref();
+ self.do_append_val_inner(arr, row);
+ }
+
+ fn vectorized_equal_to_inner(
+ &self,
+ lhs_rows: &[usize],
+ array: &ArrayRef,
+ rhs_rows: &[usize],
+ equal_to_results: &mut [bool],
+ ) {
+ let array = array.as_byte_view::<B>();
+
+ let iter = izip!(
+ lhs_rows.iter(),
+ rhs_rows.iter(),
+ equal_to_results.iter_mut(),
+ );
+
+ for (&lhs_row, &rhs_row, equal_to_result) in iter {
+            // Already marked not-equal, no need to check further
+ if !*equal_to_result {
+ continue;
+ }
+
+ *equal_to_result = self.do_equal_to_inner(lhs_row, array, rhs_row);
+ }
+ }
+
+ fn vectorized_append_inner(&mut self, array: &ArrayRef, rows: &[usize]) {
+ let arr = array.as_byte_view::<B>();
+ let null_count = array.null_count();
+ let num_rows = array.len();
+ let all_null_or_non_null = if null_count == 0 {
+ Some(true)
+ } else if null_count == num_rows {
+ Some(false)
+ } else {
+ None
+ };
+
+ match all_null_or_non_null {
+ None => {
+ for &row in rows {
+                    // Non-null rows are appended; null rows just get an empty view
+ if arr.is_valid(row) {
+ self.nulls.append(false);
+ self.do_append_val_inner(arr, row);
+ } else {
+ self.nulls.append(true);
+ self.views.push(0);
+ }
+ }
+ }
+
+ Some(true) => {
+ self.nulls.append_n(rows.len(), false);
+ for &row in rows {
+ self.do_append_val_inner(arr, row);
+ }
+ }
+
+ Some(false) => {
+ self.nulls.append_n(rows.len(), true);
+ let new_len = self.views.len() + rows.len();
+ self.views.resize(new_len, 0);
+ }
+ }
+ }
+
+    fn do_append_val_inner(&mut self, array: &GenericByteViewArray<B>, row: usize)
+ where
+ B: ByteViewType,
+ {
+ let value: &[u8] = array.value(row).as_ref();
let value_len = value.len();
let view = if value_len <= 12 {
@@ -497,9 +834,12 @@ impl<B: ByteViewType> ByteViewGroupValueBuilder<B> {
}
}
-    fn equal_to_inner(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool {
- let array = array.as_byte_view::<B>();
-
+ fn do_equal_to_inner(
+ &self,
+ lhs_row: usize,
+ array: &GenericByteViewArray<B>,
+ rhs_row: usize,
+ ) -> bool {
// Check if nulls equal firstly
let exist_null = self.nulls.is_null(lhs_row);
let input_null = array.is_null(rhs_row);
@@ -777,6 +1117,20 @@ impl<B: ByteViewType> GroupColumn for ByteViewGroupValueBuilder<B> {
self.append_val_inner(array, row)
}
+ fn vectorized_equal_to(
+ &self,
+ group_indices: &[usize],
+ array: &ArrayRef,
+ rows: &[usize],
+ equal_to_results: &mut [bool],
+ ) {
+        self.vectorized_equal_to_inner(group_indices, array, rows, equal_to_results);
+ }
+
+ fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]) {
+ self.vectorized_append_inner(array, rows);
+ }
+
fn len(&self) -> usize {
self.views.len()
}
@@ -826,7 +1180,7 @@ mod tests {
array::AsArray,
datatypes::{Int64Type, StringViewType},
};
- use arrow_array::{ArrayRef, Int64Array, StringArray, StringViewArray};
+    use arrow_array::{Array, ArrayRef, Int64Array, StringArray, StringViewArray};
use arrow_buffer::{BooleanBufferBuilder, NullBuffer};
use datafusion_physical_expr::binary_map::OutputType;
@@ -836,53 +1190,68 @@ mod tests {
use super::{ByteGroupValueBuilder, GroupColumn};
+ // ========================================================================
+ // Tests for primitive builders
+ // ========================================================================
#[test]
- fn test_take_n() {
- let mut builder = ByteGroupValueBuilder::<i32>::new(OutputType::Utf8);
-        let array = Arc::new(StringArray::from(vec![Some("a"), None])) as ArrayRef;
- // a, null, null
- builder.append_val(&array, 0);
- builder.append_val(&array, 1);
- builder.append_val(&array, 1);
-
- // (a, null) remaining: null
- let output = builder.take_n(2);
- assert_eq!(&output, &array);
+ fn test_nullable_primitive_equal_to() {
+        let append = |builder: &mut PrimitiveGroupValueBuilder<Int64Type, true>,
+ builder_array: &ArrayRef,
+ append_rows: &[usize]| {
+ for &index in append_rows {
+ builder.append_val(builder_array, index);
+ }
+ };
- // null, a, null, a
- builder.append_val(&array, 0);
- builder.append_val(&array, 1);
- builder.append_val(&array, 0);
+ let equal_to = |builder: &PrimitiveGroupValueBuilder<Int64Type, true>,
+ lhs_rows: &[usize],
+ input_array: &ArrayRef,
+ rhs_rows: &[usize],
+ equal_to_results: &mut Vec<bool>| {
+ let iter = lhs_rows.iter().zip(rhs_rows.iter());
+ for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() {
+                equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row);
+ }
+ };
- // (null, a) remaining: (null, a)
- let output = builder.take_n(2);
-        let array = Arc::new(StringArray::from(vec![None, Some("a")])) as ArrayRef;
- assert_eq!(&output, &array);
+ test_nullable_primitive_equal_to_internal(append, equal_to);
+ }
- let array = Arc::new(StringArray::from(vec![
- Some("a"),
- None,
- Some("longstringfortest"),
- ])) as ArrayRef;
+ #[test]
+ fn test_nullable_primitive_vectorized_equal_to() {
+        let append = |builder: &mut PrimitiveGroupValueBuilder<Int64Type, true>,
+ builder_array: &ArrayRef,
+ append_rows: &[usize]| {
+ builder.vectorized_append(builder_array, append_rows);
+ };
- // null, a, longstringfortest, null, null
- builder.append_val(&array, 2);
- builder.append_val(&array, 1);
- builder.append_val(&array, 1);
+ let equal_to = |builder: &PrimitiveGroupValueBuilder<Int64Type, true>,
+ lhs_rows: &[usize],
+ input_array: &ArrayRef,
+ rhs_rows: &[usize],
+ equal_to_results: &mut Vec<bool>| {
+ builder.vectorized_equal_to(
+ lhs_rows,
+ input_array,
+ rhs_rows,
+ equal_to_results,
+ );
+ };
- // (null, a, longstringfortest, null) remaining: (null)
- let output = builder.take_n(4);
- let array = Arc::new(StringArray::from(vec![
- None,
- Some("a"),
- Some("longstringfortest"),
- None,
- ])) as ArrayRef;
- assert_eq!(&output, &array);
+ test_nullable_primitive_equal_to_internal(append, equal_to);
}
- #[test]
- fn test_nullable_primitive_equal_to() {
+    fn test_nullable_primitive_equal_to_internal<A, E>(mut append: A, mut equal_to: E)
+ where
+        A: FnMut(&mut PrimitiveGroupValueBuilder<Int64Type, true>, &ArrayRef, &[usize]),
+ E: FnMut(
+ &PrimitiveGroupValueBuilder<Int64Type, true>,
+ &[usize],
+ &ArrayRef,
+ &[usize],
+ &mut Vec<bool>,
+ ),
+ {
// Will cover such cases:
// - exist null, input not null
// - exist null, input null; values not equal
@@ -901,12 +1270,7 @@ mod tests {
Some(2),
Some(3),
])) as ArrayRef;
- builder.append_val(&builder_array, 0);
- builder.append_val(&builder_array, 1);
- builder.append_val(&builder_array, 2);
- builder.append_val(&builder_array, 3);
- builder.append_val(&builder_array, 4);
- builder.append_val(&builder_array, 5);
+ append(&mut builder, &builder_array, &[0, 1, 2, 3, 4, 5]);
// Define input array
let (_nulls, values, _) =
@@ -925,16 +1289,82 @@ mod tests {
        let input_array = Arc::new(Int64Array::new(values, Some(nulls))) as ArrayRef;
// Check
- assert!(!builder.equal_to(0, &input_array, 0));
- assert!(builder.equal_to(1, &input_array, 1));
- assert!(builder.equal_to(2, &input_array, 2));
- assert!(!builder.equal_to(3, &input_array, 3));
- assert!(!builder.equal_to(4, &input_array, 4));
- assert!(builder.equal_to(5, &input_array, 5));
+ let mut equal_to_results = vec![true; builder.len()];
+ equal_to(
+ &builder,
+ &[0, 1, 2, 3, 4, 5],
+ &input_array,
+ &[0, 1, 2, 3, 4, 5],
+ &mut equal_to_results,
+ );
+
+ assert!(!equal_to_results[0]);
+ assert!(equal_to_results[1]);
+ assert!(equal_to_results[2]);
+ assert!(!equal_to_results[3]);
+ assert!(!equal_to_results[4]);
+ assert!(equal_to_results[5]);
}
#[test]
fn test_not_nullable_primitive_equal_to() {
+        let append = |builder: &mut PrimitiveGroupValueBuilder<Int64Type, false>,
+ builder_array: &ArrayRef,
+ append_rows: &[usize]| {
+ for &index in append_rows {
+ builder.append_val(builder_array, index);
+ }
+ };
+
+ let equal_to = |builder: &PrimitiveGroupValueBuilder<Int64Type, false>,
+ lhs_rows: &[usize],
+ input_array: &ArrayRef,
+ rhs_rows: &[usize],
+ equal_to_results: &mut Vec<bool>| {
+ let iter = lhs_rows.iter().zip(rhs_rows.iter());
+ for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() {
+                equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row);
+ }
+ };
+
+ test_not_nullable_primitive_equal_to_internal(append, equal_to);
+ }
+
+ #[test]
+ fn test_not_nullable_primitive_vectorized_equal_to() {
+        let append = |builder: &mut PrimitiveGroupValueBuilder<Int64Type, false>,
+ builder_array: &ArrayRef,
+ append_rows: &[usize]| {
+ builder.vectorized_append(builder_array, append_rows);
+ };
+
+ let equal_to = |builder: &PrimitiveGroupValueBuilder<Int64Type, false>,
+ lhs_rows: &[usize],
+ input_array: &ArrayRef,
+ rhs_rows: &[usize],
+ equal_to_results: &mut Vec<bool>| {
+ builder.vectorized_equal_to(
+ lhs_rows,
+ input_array,
+ rhs_rows,
+ equal_to_results,
+ );
+ };
+
+ test_not_nullable_primitive_equal_to_internal(append, equal_to);
+ }
+
+    fn test_not_nullable_primitive_equal_to_internal<A, E>(mut append: A, mut equal_to: E)
+ where
+        A: FnMut(&mut PrimitiveGroupValueBuilder<Int64Type, false>, &ArrayRef, &[usize]),
+ E: FnMut(
+ &PrimitiveGroupValueBuilder<Int64Type, false>,
+ &[usize],
+ &ArrayRef,
+ &[usize],
+ &mut Vec<bool>,
+ ),
+ {
// Will cover such cases:
// - values equal
// - values not equal
@@ -943,19 +1373,244 @@ mod tests {
        let mut builder = PrimitiveGroupValueBuilder::<Int64Type, false>::new();
let builder_array =
Arc::new(Int64Array::from(vec![Some(0), Some(1)])) as ArrayRef;
- builder.append_val(&builder_array, 0);
- builder.append_val(&builder_array, 1);
+ append(&mut builder, &builder_array, &[0, 1]);
// Define input array
        let input_array = Arc::new(Int64Array::from(vec![Some(0), Some(2)])) as ArrayRef;
// Check
- assert!(builder.equal_to(0, &input_array, 0));
- assert!(!builder.equal_to(1, &input_array, 1));
+ let mut equal_to_results = vec![true; builder.len()];
+ equal_to(
+ &builder,
+ &[0, 1],
+ &input_array,
+ &[0, 1],
+ &mut equal_to_results,
+ );
+
+ assert!(equal_to_results[0]);
+ assert!(!equal_to_results[1]);
}
#[test]
- fn test_byte_array_equal_to() {
+ fn test_nullable_primitive_vectorized_operation_special_case() {
+ // Test the special `all nulls` or `not nulls` input array case
+ // for vectorized append and equal to
+
+ let mut builder = PrimitiveGroupValueBuilder::<Int64Type, true>::new();
+
+ // All nulls input array
+ let all_nulls_input_array = Arc::new(Int64Array::from(vec![
+ Option::<i64>::None,
+ None,
+ None,
+ None,
+ None,
+ ])) as _;
+ builder.vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]);
+
+ let mut equal_to_results = vec![true; all_nulls_input_array.len()];
+ builder.vectorized_equal_to(
+ &[0, 1, 2, 3, 4],
+ &all_nulls_input_array,
+ &[0, 1, 2, 3, 4],
+ &mut equal_to_results,
+ );
+
+ assert!(equal_to_results[0]);
+ assert!(equal_to_results[1]);
+ assert!(equal_to_results[2]);
+ assert!(equal_to_results[3]);
+ assert!(equal_to_results[4]);
+
+ // All not nulls input array
+ let all_not_nulls_input_array = Arc::new(Int64Array::from(vec![
+ Some(1),
+ Some(2),
+ Some(3),
+ Some(4),
+ Some(5),
+ ])) as _;
+        builder.vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]);
+
+ let mut equal_to_results = vec![true; all_not_nulls_input_array.len()];
+ builder.vectorized_equal_to(
+ &[5, 6, 7, 8, 9],
+ &all_not_nulls_input_array,
+ &[0, 1, 2, 3, 4],
+ &mut equal_to_results,
+ );
+
+ assert!(equal_to_results[0]);
+ assert!(equal_to_results[1]);
+ assert!(equal_to_results[2]);
+ assert!(equal_to_results[3]);
+ assert!(equal_to_results[4]);
+ }
+
+ // ========================================================================
+ // Tests for byte builders
+ // ========================================================================
+ #[test]
+ fn test_byte_take_n() {
+ let mut builder = ByteGroupValueBuilder::<i32>::new(OutputType::Utf8);
+        let array = Arc::new(StringArray::from(vec![Some("a"), None])) as ArrayRef;
+ // a, null, null
+ builder.append_val(&array, 0);
+ builder.append_val(&array, 1);
+ builder.append_val(&array, 1);
+
+ // (a, null) remaining: null
+ let output = builder.take_n(2);
+ assert_eq!(&output, &array);
+
+ // null, a, null, a
+ builder.append_val(&array, 0);
+ builder.append_val(&array, 1);
+ builder.append_val(&array, 0);
+
+ // (null, a) remaining: (null, a)
+ let output = builder.take_n(2);
+        let array = Arc::new(StringArray::from(vec![None, Some("a")])) as ArrayRef;
+ assert_eq!(&output, &array);
+
+ let array = Arc::new(StringArray::from(vec![
+ Some("a"),
+ None,
+ Some("longstringfortest"),
+ ])) as ArrayRef;
+
+ // null, a, longstringfortest, null, null
+ builder.append_val(&array, 2);
+ builder.append_val(&array, 1);
+ builder.append_val(&array, 1);
+
+ // (null, a, longstringfortest, null) remaining: (null)
+ let output = builder.take_n(4);
+ let array = Arc::new(StringArray::from(vec![
+ None,
+ Some("a"),
+ Some("longstringfortest"),
+ None,
+ ])) as ArrayRef;
+ assert_eq!(&output, &array);
+ }
+
+ #[test]
+ fn test_byte_equal_to() {
+ let append = |builder: &mut ByteGroupValueBuilder<i32>,
+ builder_array: &ArrayRef,
+ append_rows: &[usize]| {
+ for &index in append_rows {
+ builder.append_val(builder_array, index);
+ }
+ };
+
+ let equal_to = |builder: &ByteGroupValueBuilder<i32>,
+ lhs_rows: &[usize],
+ input_array: &ArrayRef,
+ rhs_rows: &[usize],
+ equal_to_results: &mut Vec<bool>| {
+ let iter = lhs_rows.iter().zip(rhs_rows.iter());
+ for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() {
+                equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row);
+ }
+ };
+
+ test_byte_equal_to_internal(append, equal_to);
+ }
+
+ #[test]
+ fn test_byte_vectorized_equal_to() {
+ let append = |builder: &mut ByteGroupValueBuilder<i32>,
+ builder_array: &ArrayRef,
+ append_rows: &[usize]| {
+ builder.vectorized_append(builder_array, append_rows);
+ };
+
+ let equal_to = |builder: &ByteGroupValueBuilder<i32>,
+ lhs_rows: &[usize],
+ input_array: &ArrayRef,
+ rhs_rows: &[usize],
+ equal_to_results: &mut Vec<bool>| {
+ builder.vectorized_equal_to(
+ lhs_rows,
+ input_array,
+ rhs_rows,
+ equal_to_results,
+ );
+ };
+
+ test_byte_equal_to_internal(append, equal_to);
+ }
+
+ #[test]
+ fn test_byte_vectorized_operation_special_case() {
+ // Test the special `all nulls` or `not nulls` input array case
+ // for vectorized append and equal to
+
+ let mut builder = ByteGroupValueBuilder::<i32>::new(OutputType::Utf8);
+
+ // All nulls input array
+ let all_nulls_input_array = Arc::new(StringArray::from(vec![
+ Option::<&str>::None,
+ None,
+ None,
+ None,
+ None,
+ ])) as _;
+ builder.vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]);
+
+ let mut equal_to_results = vec![true; all_nulls_input_array.len()];
+ builder.vectorized_equal_to(
+ &[0, 1, 2, 3, 4],
+ &all_nulls_input_array,
+ &[0, 1, 2, 3, 4],
+ &mut equal_to_results,
+ );
+
+ assert!(equal_to_results[0]);
+ assert!(equal_to_results[1]);
+ assert!(equal_to_results[2]);
+ assert!(equal_to_results[3]);
+ assert!(equal_to_results[4]);
+
+ // All not nulls input array
+ let all_not_nulls_input_array = Arc::new(StringArray::from(vec![
+ Some("string1"),
+ Some("string2"),
+ Some("string3"),
+ Some("string4"),
+ Some("string5"),
+ ])) as _;
+        builder.vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]);
+
+ let mut equal_to_results = vec![true; all_not_nulls_input_array.len()];
+ builder.vectorized_equal_to(
+ &[5, 6, 7, 8, 9],
+ &all_not_nulls_input_array,
+ &[0, 1, 2, 3, 4],
+ &mut equal_to_results,
+ );
+
+ assert!(equal_to_results[0]);
+ assert!(equal_to_results[1]);
+ assert!(equal_to_results[2]);
+ assert!(equal_to_results[3]);
+ assert!(equal_to_results[4]);
+ }
+
+ fn test_byte_equal_to_internal<A, E>(mut append: A, mut equal_to: E)
+ where
+ A: FnMut(&mut ByteGroupValueBuilder<i32>, &ArrayRef, &[usize]),
+ E: FnMut(
+ &ByteGroupValueBuilder<i32>,
+ &[usize],
+ &ArrayRef,
+ &[usize],
+ &mut Vec<bool>,
+ ),
+ {
// Will cover such cases:
// - exist null, input not null
// - exist null, input null; values not equal
@@ -974,12 +1629,7 @@ mod tests {
Some("bar"),
Some("baz"),
])) as ArrayRef;
- builder.append_val(&builder_array, 0);
- builder.append_val(&builder_array, 1);
- builder.append_val(&builder_array, 2);
- builder.append_val(&builder_array, 3);
- builder.append_val(&builder_array, 4);
- builder.append_val(&builder_array, 5);
+ append(&mut builder, &builder_array, &[0, 1, 2, 3, 4, 5]);
// Define input array
let (offsets, buffer, _nulls) = StringArray::from(vec![
@@ -1005,14 +1655,26 @@ mod tests {
            Arc::new(StringArray::new(offsets, buffer, Some(nulls))) as ArrayRef;
// Check
- assert!(!builder.equal_to(0, &input_array, 0));
- assert!(builder.equal_to(1, &input_array, 1));
- assert!(builder.equal_to(2, &input_array, 2));
- assert!(!builder.equal_to(3, &input_array, 3));
- assert!(!builder.equal_to(4, &input_array, 4));
- assert!(builder.equal_to(5, &input_array, 5));
+ let mut equal_to_results = vec![true; builder.len()];
+ equal_to(
+ &builder,
+ &[0, 1, 2, 3, 4, 5],
+ &input_array,
+ &[0, 1, 2, 3, 4, 5],
+ &mut equal_to_results,
+ );
+
+ assert!(!equal_to_results[0]);
+ assert!(equal_to_results[1]);
+ assert!(equal_to_results[2]);
+ assert!(!equal_to_results[3]);
+ assert!(!equal_to_results[4]);
+ assert!(equal_to_results[5]);
}
+ // ========================================================================
+ // Tests for byte view builders
+ // ========================================================================
#[test]
fn test_byte_view_append_val() {
let mut builder =
@@ -1033,12 +1695,126 @@ mod tests {
let output = Box::new(builder).build();
// should be 2 output buffers to hold all the data
- assert_eq!(output.as_string_view().data_buffers().len(), 2,);
+ assert_eq!(output.as_string_view().data_buffers().len(), 2);
assert_eq!(&output, &builder_array)
}
#[test]
fn test_byte_view_equal_to() {
+ let append = |builder: &mut ByteViewGroupValueBuilder<StringViewType>,
+ builder_array: &ArrayRef,
+ append_rows: &[usize]| {
+ for &index in append_rows {
+ builder.append_val(builder_array, index);
+ }
+ };
+
+ let equal_to = |builder: &ByteViewGroupValueBuilder<StringViewType>,
+ lhs_rows: &[usize],
+ input_array: &ArrayRef,
+ rhs_rows: &[usize],
+ equal_to_results: &mut Vec<bool>| {
+ let iter = lhs_rows.iter().zip(rhs_rows.iter());
+ for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() {
+                equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row);
+ }
+ };
+
+ test_byte_view_equal_to_internal(append, equal_to);
+ }
+
+ #[test]
+ fn test_byte_view_vectorized_equal_to() {
+ let append = |builder: &mut ByteViewGroupValueBuilder<StringViewType>,
+ builder_array: &ArrayRef,
+ append_rows: &[usize]| {
+ builder.vectorized_append(builder_array, append_rows);
+ };
+
+ let equal_to = |builder: &ByteViewGroupValueBuilder<StringViewType>,
+ lhs_rows: &[usize],
+ input_array: &ArrayRef,
+ rhs_rows: &[usize],
+ equal_to_results: &mut Vec<bool>| {
+ builder.vectorized_equal_to(
+ lhs_rows,
+ input_array,
+ rhs_rows,
+ equal_to_results,
+ );
+ };
+
+ test_byte_view_equal_to_internal(append, equal_to);
+ }
+
+ #[test]
+ fn test_byte_view_vectorized_operation_special_case() {
+ // Test the special `all nulls` or `not nulls` input array case
+ // for vectorized append and equal to
+
+ let mut builder =
+            ByteViewGroupValueBuilder::<StringViewType>::new().with_max_block_size(60);
+
+ // All nulls input array
+ let all_nulls_input_array = Arc::new(StringViewArray::from(vec![
+ Option::<&str>::None,
+ None,
+ None,
+ None,
+ None,
+ ])) as _;
+ builder.vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]);
+
+ let mut equal_to_results = vec![true; all_nulls_input_array.len()];
+ builder.vectorized_equal_to(
+ &[0, 1, 2, 3, 4],
+ &all_nulls_input_array,
+ &[0, 1, 2, 3, 4],
+ &mut equal_to_results,
+ );
+
+ assert!(equal_to_results[0]);
+ assert!(equal_to_results[1]);
+ assert!(equal_to_results[2]);
+ assert!(equal_to_results[3]);
+ assert!(equal_to_results[4]);
+
+ // All not nulls input array
+ let all_not_nulls_input_array = Arc::new(StringViewArray::from(vec![
+ Some("stringview1"),
+ Some("stringview2"),
+ Some("stringview3"),
+ Some("stringview4"),
+ Some("stringview5"),
+ ])) as _;
+        builder.vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]);
+
+ let mut equal_to_results = vec![true; all_not_nulls_input_array.len()];
+ builder.vectorized_equal_to(
+ &[5, 6, 7, 8, 9],
+ &all_not_nulls_input_array,
+ &[0, 1, 2, 3, 4],
+ &mut equal_to_results,
+ );
+
+ assert!(equal_to_results[0]);
+ assert!(equal_to_results[1]);
+ assert!(equal_to_results[2]);
+ assert!(equal_to_results[3]);
+ assert!(equal_to_results[4]);
+ }
+
+ fn test_byte_view_equal_to_internal<A, E>(mut append: A, mut equal_to: E)
+ where
+        A: FnMut(&mut ByteViewGroupValueBuilder<StringViewType>, &ArrayRef, &[usize]),
+ E: FnMut(
+ &ByteViewGroupValueBuilder<StringViewType>,
+ &[usize],
+ &ArrayRef,
+ &[usize],
+ &mut Vec<bool>,
+ ),
+ {
// Will cover such cases:
// - exist null, input not null
// - exist null, input null; values not equal
@@ -1078,15 +1854,7 @@ mod tests {
Some("I am a long string for test eq in completed"),
Some("I am a long string for test eq in progress"),
])) as ArrayRef;
- builder.append_val(&builder_array, 0);
- builder.append_val(&builder_array, 1);
- builder.append_val(&builder_array, 2);
- builder.append_val(&builder_array, 3);
- builder.append_val(&builder_array, 4);
- builder.append_val(&builder_array, 5);
- builder.append_val(&builder_array, 6);
- builder.append_val(&builder_array, 7);
- builder.append_val(&builder_array, 8);
+ append(&mut builder, &builder_array, &[0, 1, 2, 3, 4, 5, 6, 7, 8]);
// Define input array
let (views, buffer, _nulls) = StringViewArray::from(vec![
@@ -1124,18 +1892,27 @@ mod tests {
            Arc::new(StringViewArray::new(views, buffer, Some(nulls))) as ArrayRef;
// Check
- assert!(!builder.equal_to(0, &input_array, 0));
- assert!(builder.equal_to(1, &input_array, 1));
- assert!(builder.equal_to(2, &input_array, 2));
- assert!(!builder.equal_to(3, &input_array, 3));
- assert!(!builder.equal_to(4, &input_array, 4));
- assert!(!builder.equal_to(5, &input_array, 5));
- assert!(builder.equal_to(6, &input_array, 6));
- assert!(!builder.equal_to(7, &input_array, 7));
- assert!(!builder.equal_to(7, &input_array, 8));
- assert!(builder.equal_to(7, &input_array, 9));
- assert!(!builder.equal_to(8, &input_array, 10));
- assert!(builder.equal_to(8, &input_array, 11));
+ let mut equal_to_results = vec![true; input_array.len()];
+ equal_to(
+ &builder,
+ &[0, 1, 2, 3, 4, 5, 6, 7, 7, 7, 8, 8],
+ &input_array,
+ &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
+ &mut equal_to_results,
+ );
+
+ assert!(!equal_to_results[0]);
+ assert!(equal_to_results[1]);
+ assert!(equal_to_results[2]);
+ assert!(!equal_to_results[3]);
+ assert!(!equal_to_results[4]);
+ assert!(!equal_to_results[5]);
+ assert!(equal_to_results[6]);
+ assert!(!equal_to_results[7]);
+ assert!(!equal_to_results[8]);
+ assert!(equal_to_results[9]);
+ assert!(!equal_to_results[10]);
+ assert!(equal_to_results[11]);
}
#[test]
diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs
index fb7b667750..aefd9c1622 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs
@@ -37,6 +37,8 @@ mod bytes_view;
use bytes::GroupValuesByes;
use datafusion_physical_expr::binary_map::OutputType;
+use crate::aggregates::order::GroupOrdering;
+
mod group_column;
mod null_builder;
@@ -105,7 +107,10 @@ pub trait GroupValues: Send {
}
/// Return a specialized implementation of [`GroupValues`] for the given schema.
-pub fn new_group_values(schema: SchemaRef) -> Result<Box<dyn GroupValues>> {
+pub fn new_group_values(
+ schema: SchemaRef,
+ group_ordering: &GroupOrdering,
+) -> Result<Box<dyn GroupValues>> {
if schema.fields.len() == 1 {
let d = schema.fields[0].data_type();
@@ -143,8 +148,12 @@ pub fn new_group_values(schema: SchemaRef) -> Result<Box<dyn GroupValues>> {
}
}
- if GroupValuesColumn::supported_schema(schema.as_ref()) {
- Ok(Box::new(GroupValuesColumn::try_new(schema)?))
+ if column::supported_schema(schema.as_ref()) {
+ if matches!(group_ordering, GroupOrdering::None) {
+ Ok(Box::new(GroupValuesColumn::<false>::try_new(schema)?))
+ } else {
+ Ok(Box::new(GroupValuesColumn::<true>::try_new(schema)?))
+ }
} else {
Ok(Box::new(GroupValuesRows::try_new(schema)?))
}
diff --git a/datafusion/physical-plan/src/aggregates/group_values/null_builder.rs b/datafusion/physical-plan/src/aggregates/group_values/null_builder.rs
index 0249390f38..a584cf58e5 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/null_builder.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/null_builder.rs
@@ -70,6 +70,24 @@ impl MaybeNullBufferBuilder {
}
}
+ pub fn append_n(&mut self, n: usize, is_null: bool) {
+ match self {
+ Self::NoNulls { row_count } if is_null => {
+                // have seen no nulls so far, this is the first null,
+                // need to create the nulls buffer for all currently valid values;
+                // allocate 2x the needed capacity since more values are pushed right after
+ let mut nulls = BooleanBufferBuilder::new(*row_count * 2);
+ nulls.append_n(*row_count, true);
+ nulls.append_n(n, false);
+ *self = Self::Nulls(nulls);
+ }
+ Self::NoNulls { row_count } => {
+ *row_count += n;
+ }
+ Self::Nulls(builder) => builder.append_n(n, !is_null),
+ }
+ }
+
    /// return the number of heap allocated bytes used by this structure to store boolean values
pub fn allocated_size(&self) -> usize {
match self {
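
The interesting case in `append_n` is the first arm: the builder stays in the cheap `NoNulls` representation until the first null shows up, and only then materializes validity bits for everything appended so far. A standalone sketch of the same promotion, with a made-up `SimpleNullBuffer` in place of `MaybeNullBufferBuilder` and a plain `Vec<bool>` instead of `BooleanBufferBuilder`, is:

    /// Hypothetical stand-in (true = valid, matching Arrow validity semantics).
    enum SimpleNullBuffer {
        NoNulls { row_count: usize },
        Nulls(Vec<bool>),
    }

    impl SimpleNullBuffer {
        fn append_n(&mut self, n: usize, is_null: bool) {
            match self {
                // First null ever seen: materialize validity bits for all
                // previously appended (valid) rows, then the n new nulls.
                SimpleNullBuffer::NoNulls { row_count } if is_null => {
                    let mut validity = vec![true; *row_count];
                    validity.extend(std::iter::repeat(false).take(n));
                    *self = SimpleNullBuffer::Nulls(validity);
                }
                // Still no nulls: just bump the logical length.
                SimpleNullBuffer::NoNulls { row_count } => *row_count += n,
                // Already tracking validity: append n bits (valid = !is_null).
                SimpleNullBuffer::Nulls(validity) => {
                    validity.extend(std::iter::repeat(!is_null).take(n));
                }
            }
        }
    }

Note the inverted flag in the last arm: the buffer stores validity bits, so it pushes `!is_null`, matching `builder.append_n(n, !is_null)` in the patch above.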
diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs
index fe05f7375e..0fa9f206f1 100644
--- a/datafusion/physical-plan/src/aggregates/row_hash.rs
+++ b/datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -514,7 +514,7 @@ impl GroupedHashAggregateStream {
ordering.as_ref(),
)?;
- let group_values = new_group_values(group_schema)?;
+ let group_values = new_group_values(group_schema, &group_ordering)?;
timer.done();
let exec_state = ExecutionState::ReadingInput;
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index 5315637174..ecfbaee235 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -29,6 +29,7 @@ use prost::Message;
use std::any::Any;
use std::collections::HashMap;
use std::fmt::{self, Debug, Formatter};
+use std::mem::size_of_val;
use std::sync::Arc;
use std::vec;
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 5687c9af54..d4e2d48885 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -23,6 +23,7 @@ use datafusion_substrait::logical_plan::{
consumer::from_substrait_plan, producer::to_substrait_plan,
};
use std::cmp::Ordering;
+use std::mem::size_of_val;
use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit};
use datafusion::common::{not_impl_err, plan_err, DFSchema, DFSchemaRef};
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]