This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new c07c26cd1e Fix incorrect results in `BitAnd` GroupsAccumulator (#6957)
c07c26cd1e is described below

commit c07c26cd1e237cda4f8db332f6b7acec3ab4055c
Author: Andrew Lamb <[email protected]>
AuthorDate: Thu Jul 13 13:19:34 2023 -0400

    Fix incorrect results in `BitAnd` GroupsAccumulator (#6957)
    
    Fix accumulator
---
 .../tests/sqllogictests/test_files/aggregate.slt   | 184 +++++++++++++--------
 .../physical-expr/src/aggregate/bit_and_or_xor.rs  |  83 ++++------
 .../src/aggregate/groups_accumulator/prim_op.rs    |  12 +-
 3 files changed, 160 insertions(+), 119 deletions(-)

diff --git a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
index 95cf51d571..72b9e8400b 100644
--- a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
@@ -1420,65 +1420,95 @@ select var(sq.column1), var_pop(sq.column1), 
stddev(sq.column1), stddev_pop(sq.c
 2 1 1.414213562373 1
 
 
-# sum / count for all nulls
-statement ok
-create table the_nulls as values (null::bigint, 1), (null::bigint, 1), 
(null::bigint, 2);
 
-# counts should be zeros (even for nulls)
-query II
-SELECT count(column1), column2 from the_nulls group by column2 order by 
column2;
-----
-0 1
-0 2
-
-# sums should be null
-query II
-SELECT sum(column1), column2 from the_nulls group by column2 order by column2;
+# aggregates on empty tables
+statement ok
+CREATE TABLE empty (column1 bigint, column2 int);
+
+# no group by column
+query IIRIIIII
+SELECT
+  count(column1), -- counts should be zero, even for nulls
+  sum(column1),   -- other aggregates should be null
+  avg(column1),
+  min(column1),
+  max(column1),
+  bit_and(column1),
+  bit_or(column1),
+  bit_xor(column1)
+FROM empty
+----
+0 NULL NULL NULL NULL NULL NULL NULL
+
+# Same query but with grouping (no groups, so no output)
+query IIRIIIIII
+SELECT
+  count(column1),
+  sum(column1),
+  avg(column1),
+  min(column1),
+  max(column1),
+  bit_and(column1),
+  bit_or(column1),
+  bit_xor(column1),
+  column2
+FROM empty
+GROUP BY column2
+ORDER BY column2;
 ----
-NULL 1
-NULL 2
 
-# avg should be null
-query RI
-SELECT avg(column1), column2 from the_nulls group by column2 order by column2;
-----
-NULL 1
-NULL 2
 
-# bit_and should be null
-query II
-SELECT bit_and(column1), column2 from the_nulls group by column2 order by 
column2;
-----
-NULL 1
-NULL 2
+statement ok
+drop table empty
 
-# bit_or should be null
-query II
-SELECT bit_or(column1), column2 from the_nulls group by column2 order by 
column2;
-----
-NULL 1
-NULL 2
+# aggregates on all nulls
+statement ok
+CREATE TABLE the_nulls
+AS VALUES
+  (null::bigint, 1),
+  (null::bigint, 1),
+  (null::bigint, 2);
 
-# bit_xor should be null
 query II
-SELECT bit_xor(column1), column2 from the_nulls group by column2 order by 
column2;
+select * from the_nulls
 ----
 NULL 1
-NULL 2
-
-# min should be null
-query II
-SELECT min(column1), column2 from the_nulls group by column2 order by column2;
-----
 NULL 1
 NULL 2
 
-# max should be null
-query II
-SELECT max(column1), column2 from the_nulls group by column2 order by column2;
-----
-NULL 1
-NULL 2
+# no group by column
+query IIRIIIII
+SELECT
+  count(column1), -- counts should be zero, even for nulls
+  sum(column1),   -- other aggregates should be null
+  avg(column1),
+  min(column1),
+  max(column1),
+  bit_and(column1),
+  bit_or(column1),
+  bit_xor(column1)
+FROM the_nulls
+----
+0 NULL NULL NULL NULL NULL NULL NULL
+
+# Same query but with grouping
+query IIRIIIIII
+SELECT
+  count(column1), -- counts should be zero, even for nulls
+  sum(column1),   -- other aggregates should be null
+  avg(column1),
+  min(column1),
+  max(column1),
+  bit_and(column1),
+  bit_or(column1),
+  bit_xor(column1),
+  column2
+FROM the_nulls
+GROUP BY column2
+ORDER BY column2;
+----
+0 NULL NULL NULL NULL NULL NULL NULL 1
+0 NULL NULL NULL NULL NULL NULL NULL 2
 
 
 statement ok
@@ -1489,29 +1519,49 @@ create table bit_aggregate_functions (
   c1 SMALLINT NOT NULL,
   c2 SMALLINT NOT NULL,
   c3 SMALLINT,
+  tag varchar
 )
 as values
-  (5, 10, 11),
-  (33, 11, null),
-  (9, 12, null);
-
-# query_bit_and
-query III
-SELECT bit_and(c1), bit_and(c2), bit_and(c3) FROM bit_aggregate_functions
-----
-1 8 11
-
-# query_bit_or
-query III
-SELECT bit_or(c1), bit_or(c2), bit_or(c3) FROM bit_aggregate_functions
-----
-45 15 11
+  (5,  10, 11,   'A'),
+  (33, 11, null, 'B'),
+  (9,  12, null, 'A');
+
+# query_bit_and, query_bit_or, query_bit_xor
+query IIIIIIIII
+SELECT
+  bit_and(c1),
+  bit_and(c2),
+  bit_and(c3),
+  bit_or(c1),
+  bit_or(c2),
+  bit_or(c3),
+  bit_xor(c1),
+  bit_xor(c2),
+  bit_xor(c3)
+FROM bit_aggregate_functions
+----
+1 8 11 45 15 11 45 13 11
+
+# query_bit_and, query_bit_or, query_bit_xor, with group
+query IIIIIIIIIT
+SELECT
+  bit_and(c1),
+  bit_and(c2),
+  bit_and(c3),
+  bit_or(c1),
+  bit_or(c2),
+  bit_or(c3),
+  bit_xor(c1),
+  bit_xor(c2),
+  bit_xor(c3),
+  tag
+FROM bit_aggregate_functions
+GROUP BY tag
+ORDER BY tag
+----
+1 8 11 13 14 11 12 6 11 A
+33 11 NULL 33 11 NULL 33 11 NULL B
 
-# query_bit_xor
-query III
-SELECT bit_xor(c1), bit_xor(c2), bit_xor(c3) FROM bit_aggregate_functions
-----
-45 13 11
 
 statement ok
 create table bool_aggregate_functions (
diff --git a/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs b/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs
index ab37e5891e..6a2d509389 100644
--- a/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs
+++ b/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs
@@ -49,15 +49,16 @@ use arrow::compute::{bit_and, bit_or, bit_xor};
 use datafusion_row::accessor::RowAccessor;
 
 /// Creates a [`PrimitiveGroupsAccumulator`] with the specified
-/// [`ArrowPrimitiveType`] which applies `$FN` to each element
+/// [`ArrowPrimitiveType`] that initializes each accumulator to $START
+/// and applies `$FN` to each element
 ///
 /// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType
-macro_rules! instantiate_primitive_accumulator {
-    ($SELF:expr, $PRIMTYPE:ident, $FN:expr) => {{
-        Ok(Box::new(PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(
-            &$SELF.data_type,
-            $FN,
-        )))
+macro_rules! instantiate_accumulator {
+    ($SELF:expr, $START:expr, $PRIMTYPE:ident, $FN:expr) => {{
+        Ok(Box::new(
+            PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(&$SELF.data_type, $FN)
+                .with_starting_value($START),
+        ))
     }};
 }
 
@@ -279,35 +280,31 @@ impl AggregateExpr for BitAnd {
         use std::ops::BitAndAssign;
         match self.data_type {
             DataType::Int8 => {
-                instantiate_primitive_accumulator!(self, Int8Type, |x, y| x
-                    .bitand_assign(y))
+                instantiate_accumulator!(self, -1, Int8Type, |x, y| 
x.bitand_assign(y))
             }
             DataType::Int16 => {
-                instantiate_primitive_accumulator!(self, Int16Type, |x, y| x
-                    .bitand_assign(y))
+                instantiate_accumulator!(self, -1, Int16Type, |x, y| 
x.bitand_assign(y))
             }
             DataType::Int32 => {
-                instantiate_primitive_accumulator!(self, Int32Type, |x, y| x
-                    .bitand_assign(y))
+                instantiate_accumulator!(self, -1, Int32Type, |x, y| 
x.bitand_assign(y))
             }
             DataType::Int64 => {
-                instantiate_primitive_accumulator!(self, Int64Type, |x, y| x
-                    .bitand_assign(y))
+                instantiate_accumulator!(self, -1, Int64Type, |x, y| 
x.bitand_assign(y))
             }
             DataType::UInt8 => {
-                instantiate_primitive_accumulator!(self, UInt8Type, |x, y| x
+                instantiate_accumulator!(self, u8::MAX, UInt8Type, |x, y| x
                     .bitand_assign(y))
             }
             DataType::UInt16 => {
-                instantiate_primitive_accumulator!(self, UInt16Type, |x, y| x
+                instantiate_accumulator!(self, u16::MAX, UInt16Type, |x, y| x
                     .bitand_assign(y))
             }
             DataType::UInt32 => {
-                instantiate_primitive_accumulator!(self, UInt32Type, |x, y| x
+                instantiate_accumulator!(self, u32::MAX, UInt32Type, |x, y| x
                     .bitand_assign(y))
             }
             DataType::UInt64 => {
-                instantiate_primitive_accumulator!(self, UInt64Type, |x, y| x
+                instantiate_accumulator!(self, u64::MAX, UInt64Type, |x, y| x
                     .bitand_assign(y))
             }
 
@@ -517,36 +514,28 @@ impl AggregateExpr for BitOr {
         use std::ops::BitOrAssign;
         match self.data_type {
             DataType::Int8 => {
-                instantiate_primitive_accumulator!(self, Int8Type, |x, y| x
-                    .bitor_assign(y))
+                instantiate_accumulator!(self, 0, Int8Type, |x, y| 
x.bitor_assign(y))
             }
             DataType::Int16 => {
-                instantiate_primitive_accumulator!(self, Int16Type, |x, y| x
-                    .bitor_assign(y))
+                instantiate_accumulator!(self, 0, Int16Type, |x, y| 
x.bitor_assign(y))
             }
             DataType::Int32 => {
-                instantiate_primitive_accumulator!(self, Int32Type, |x, y| x
-                    .bitor_assign(y))
+                instantiate_accumulator!(self, 0, Int32Type, |x, y| 
x.bitor_assign(y))
             }
             DataType::Int64 => {
-                instantiate_primitive_accumulator!(self, Int64Type, |x, y| x
-                    .bitor_assign(y))
+                instantiate_accumulator!(self, 0, Int64Type, |x, y| 
x.bitor_assign(y))
             }
             DataType::UInt8 => {
-                instantiate_primitive_accumulator!(self, UInt8Type, |x, y| x
-                    .bitor_assign(y))
+                instantiate_accumulator!(self, 0, UInt8Type, |x, y| 
x.bitor_assign(y))
             }
             DataType::UInt16 => {
-                instantiate_primitive_accumulator!(self, UInt16Type, |x, y| x
-                    .bitor_assign(y))
+                instantiate_accumulator!(self, 0, UInt16Type, |x, y| 
x.bitor_assign(y))
             }
             DataType::UInt32 => {
-                instantiate_primitive_accumulator!(self, UInt32Type, |x, y| x
-                    .bitor_assign(y))
+                instantiate_accumulator!(self, 0, UInt32Type, |x, y| 
x.bitor_assign(y))
             }
             DataType::UInt64 => {
-                instantiate_primitive_accumulator!(self, UInt64Type, |x, y| x
-                    .bitor_assign(y))
+                instantiate_accumulator!(self, 0, UInt64Type, |x, y| 
x.bitor_assign(y))
             }
 
             _ => Err(DataFusionError::NotImplemented(format!(
@@ -756,36 +745,28 @@ impl AggregateExpr for BitXor {
         use std::ops::BitXorAssign;
         match self.data_type {
             DataType::Int8 => {
-                instantiate_primitive_accumulator!(self, Int8Type, |x, y| x
-                    .bitxor_assign(y))
+                instantiate_accumulator!(self, 0, Int8Type, |x, y| 
x.bitxor_assign(y))
             }
             DataType::Int16 => {
-                instantiate_primitive_accumulator!(self, Int16Type, |x, y| x
-                    .bitxor_assign(y))
+                instantiate_accumulator!(self, 0, Int16Type, |x, y| 
x.bitxor_assign(y))
             }
             DataType::Int32 => {
-                instantiate_primitive_accumulator!(self, Int32Type, |x, y| x
-                    .bitxor_assign(y))
+                instantiate_accumulator!(self, 0, Int32Type, |x, y| 
x.bitxor_assign(y))
             }
             DataType::Int64 => {
-                instantiate_primitive_accumulator!(self, Int64Type, |x, y| x
-                    .bitxor_assign(y))
+                instantiate_accumulator!(self, 0, Int64Type, |x, y| 
x.bitxor_assign(y))
             }
             DataType::UInt8 => {
-                instantiate_primitive_accumulator!(self, UInt8Type, |x, y| x
-                    .bitxor_assign(y))
+                instantiate_accumulator!(self, 0, UInt8Type, |x, y| 
x.bitxor_assign(y))
             }
             DataType::UInt16 => {
-                instantiate_primitive_accumulator!(self, UInt16Type, |x, y| x
-                    .bitxor_assign(y))
+                instantiate_accumulator!(self, 0, UInt16Type, |x, y| 
x.bitxor_assign(y))
             }
             DataType::UInt32 => {
-                instantiate_primitive_accumulator!(self, UInt32Type, |x, y| x
-                    .bitxor_assign(y))
+                instantiate_accumulator!(self, 0, UInt32Type, |x, y| 
x.bitxor_assign(y))
             }
             DataType::UInt64 => {
-                instantiate_primitive_accumulator!(self, UInt64Type, |x, y| x
-                    .bitxor_assign(y))
+                instantiate_accumulator!(self, 0, UInt64Type, |x, y| 
x.bitxor_assign(y))
             }
 
             _ => Err(DataFusionError::NotImplemented(format!(
diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs
index 8603010789..a49651a5e3 100644
--- a/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs
+++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs
@@ -47,6 +47,9 @@ where
     /// The output type (needed for Decimal precision and scale)
     data_type: DataType,
 
+    /// The starting value for new groups
+    starting_value: T::Native,
+
     /// Track nulls in the input / filters
     null_state: NullState,
 
@@ -64,9 +67,16 @@ where
             values: vec![],
             data_type: data_type.clone(),
             null_state: NullState::new(),
+            starting_value: T::default_value(),
             prim_fn,
         }
     }
+
+    /// Set the starting values for new groups
+    pub fn with_starting_value(mut self, starting_value: T::Native) -> Self {
+        self.starting_value = starting_value;
+        self
+    }
 }
 
 impl<T, F> GroupsAccumulator for PrimitiveGroupsAccumulator<T, F>
@@ -85,7 +95,7 @@ where
         let values = values[0].as_primitive::<T>();
 
         // update values
-        self.values.resize(total_num_groups, T::default_value());
+        self.values.resize(total_num_groups, self.starting_value);
 
         // NullState dispatches / handles tracking nulls and groups that saw no values
         self.null_state.accumulate(

Reply via email to