alamb commented on code in PR #6904:
URL: https://github.com/apache/arrow-datafusion/pull/6904#discussion_r1259632608
##########
datafusion/physical-expr/src/aggregate/count.rs:
##########
@@ -76,6 +85,109 @@ impl Count {
}
}
+/// An accumulator to compute the counts of [`PrimitiveArray<T>`].
+/// Stores values as native types, and does overflow checking
+///
+/// Unlike most other accumulators, COUNT never produces NULLs. If no
+/// non-null values are seen in any group the output is 0. Thus, this
+/// accumulator has no additional null or seen filter tracking.
+#[derive(Debug)]
+struct CountGroupsAccumulator {
+ /// Count per group (use i64 to make Int64Array)
+ counts: Vec<i64>,
+}
+
+impl CountGroupsAccumulator {
+ pub fn new() -> Self {
+ Self { counts: vec![] }
+ }
+}
+
+impl GroupsAccumulator for CountGroupsAccumulator {
+ fn update_batch(
+ &mut self,
+ values: &[ArrayRef],
+ group_indices: &[usize],
+ opt_filter: Option<&arrow_array::BooleanArray>,
+ total_num_groups: usize,
+ ) -> Result<()> {
+ assert_eq!(values.len(), 1, "single argument to update_batch");
+ let values = values.get(0).unwrap();
+
+ // Add one to each group's counter for each non null, non
+ // filtered value
+ self.counts.resize(total_num_groups, 0);
+ accumulate_indices(
+ group_indices,
+ values.nulls(), // ignore values
+ opt_filter,
+ |group_index| {
Review Comment:
Ah, I was confused -- I used `add_wrapping` (part of `std::ops`) rather than
`wrapping_add` (part of Arrow NativeType) 🤦
I played around with it -- it turns out that for anything with a concrete Rust
type (like i64) I can use `+=`, but for templated code on `ArrowNativeType` I
can't use `+=`; I have to use `wrapping_add`.
This is how the code would look if I use `add_wrapping`, which I think
looks a bit more tortured (though it is consistent). Do you think I should make
the change?
```diff
diff --git a/datafusion/physical-expr/src/aggregate/count.rs
b/datafusion/physical-expr/src/aggregate/count.rs
index 287970de29..cfb1713d8e 100644
--- a/datafusion/physical-expr/src/aggregate/count.rs
+++ b/datafusion/physical-expr/src/aggregate/count.rs
@@ -156,9 +156,10 @@ impl GroupsAccumulator for CountGroupsAccumulator {
.iter()
.zip(group_indices.iter())
.zip(partial_counts.iter())
- .for_each(|((filter_value, &group_index), partial_count)| {
+ .for_each(|((filter_value, &group_index), &partial_count)| {
if let Some(true) = filter_value {
- self.counts[group_index] += partial_count;
+ let count = &mut self.counts[group_index];
+ *count = count.add_wrapping(partial_count);
}
}),
None =>
group_indices.iter().zip(partial_counts.iter()).for_each(
```
##########
datafusion/physical-expr/src/aggregate/count.rs:
##########
@@ -76,6 +85,109 @@ impl Count {
}
}
+/// An accumulator to compute the counts of [`PrimitiveArray<T>`].
+/// Stores values as native types, and does overflow checking
+///
+/// Unlike most other accumulators, COUNT never produces NULLs. If no
+/// non-null values are seen in any group the output is 0. Thus, this
+/// accumulator has no additional null or seen filter tracking.
+#[derive(Debug)]
+struct CountGroupsAccumulator {
+ /// Count per group (use i64 to make Int64Array)
+ counts: Vec<i64>,
+}
+
+impl CountGroupsAccumulator {
+ pub fn new() -> Self {
+ Self { counts: vec![] }
+ }
+}
+
+impl GroupsAccumulator for CountGroupsAccumulator {
+ fn update_batch(
+ &mut self,
+ values: &[ArrayRef],
+ group_indices: &[usize],
+ opt_filter: Option<&arrow_array::BooleanArray>,
+ total_num_groups: usize,
+ ) -> Result<()> {
+ assert_eq!(values.len(), 1, "single argument to update_batch");
+ let values = values.get(0).unwrap();
+
+ // Add one to each group's counter for each non null, non
+ // filtered value
+ self.counts.resize(total_num_groups, 0);
+ accumulate_indices(
+ group_indices,
+ values.nulls(), // ignore values
+ opt_filter,
+ |group_index| {
Review Comment:
Ah, I was confused -- I used `add_wrapping` (part of `std::ops`) rather than
`wrapping_add` (part of Arrow NativeType) 🤦
I played around with it -- it turns out that for anything with a concrete Rust
type (like i64) I can use `+=`, but for templated code on `ArrowNativeType` I
can't use `+=`; I have to use `wrapping_add`.
This is how the code would look if I use `add_wrapping`, which I think
looks a bit more tortured (though it is consistent). Do you think I should make
the change?
```diff
diff --git a/datafusion/physical-expr/src/aggregate/count.rs
b/datafusion/physical-expr/src/aggregate/count.rs
index 287970de29..cfb1713d8e 100644
--- a/datafusion/physical-expr/src/aggregate/count.rs
+++ b/datafusion/physical-expr/src/aggregate/count.rs
@@ -156,9 +156,10 @@ impl GroupsAccumulator for CountGroupsAccumulator {
.iter()
.zip(group_indices.iter())
.zip(partial_counts.iter())
- .for_each(|((filter_value, &group_index), partial_count)| {
+ .for_each(|((filter_value, &group_index), &partial_count)| {
if let Some(true) = filter_value {
- self.counts[group_index] += partial_count;
+ let count = &mut self.counts[group_index];
+ *count = count.add_wrapping(partial_count);
}
}),
None =>
group_indices.iter().zip(partial_counts.iter()).for_each(
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]