This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 61ed374e96 First and Last Accumulators should update with state row
excluding is_set flag (#7565)
61ed374e96 is described below
commit 61ed374e96ac10136c8ddf2b80260d6ea8b6d30b
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Sat Sep 16 03:22:09 2023 -0700
First and Last Accumulators should update with state row excluding is_set
flag (#7565)
* First and Last Accumulators should update with state row excluding is_set
flag
* Add test
* Update datafusion/physical-expr/src/aggregate/first_last.rs
Co-authored-by: Andrew Lamb <[email protected]>
* Update datafusion/physical-expr/src/aggregate/first_last.rs
* Remove
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
.../physical-expr/src/aggregate/first_last.rs | 101 +++++++++++++++++----
1 file changed, 83 insertions(+), 18 deletions(-)
diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs
index 02bb466d44..6ae7b4895a 100644
--- a/datafusion/physical-expr/src/aggregate/first_last.rs
+++ b/datafusion/physical-expr/src/aggregate/first_last.rs
@@ -165,8 +165,6 @@ struct FirstValueAccumulator {
orderings: Vec<ScalarValue>,
// Stores the applicable ordering requirement.
ordering_req: LexOrdering,
- // Whether merge_batch() is called before
- is_merge_called: bool,
}
impl FirstValueAccumulator {
@@ -185,7 +183,6 @@ impl FirstValueAccumulator {
is_set: false,
orderings,
ordering_req,
- is_merge_called: false,
})
}
@@ -201,9 +198,7 @@ impl Accumulator for FirstValueAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
let mut result = vec![self.first.clone()];
result.extend(self.orderings.iter().cloned());
- if !self.is_merge_called {
- result.push(ScalarValue::Boolean(Some(self.is_set)));
- }
+ result.push(ScalarValue::Boolean(Some(self.is_set)));
Ok(result)
}
@@ -218,7 +213,6 @@ impl Accumulator for FirstValueAccumulator {
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
- self.is_merge_called = true;
// FIRST_VALUE(first1, first2, first3, ...)
// last index contains is_set flag.
let is_set_idx = states.len() - 1;
@@ -237,13 +231,17 @@ impl Accumulator for FirstValueAccumulator {
};
if !ordered_states[0].is_empty() {
let first_row = get_row_at_idx(&ordered_states, 0)?;
- let first_ordering = &first_row[1..];
+ // When collecting orderings, we exclude the is_set flag from the state.
+ let first_ordering = &first_row[1..is_set_idx];
let sort_options = get_sort_options(&self.ordering_req);
// Either there is no existing value, or there is an earlier version in new data.
if !self.is_set
|| compare_rows(first_ordering, &self.orderings,
&sort_options)?.is_lt()
{
- self.update_with_new_row(&first_row);
+ // Update with first value in the state. Note that we should exclude the
+ // is_set flag from the state. Otherwise, we will end up with a state
+ // containing two is_set flags.
+ self.update_with_new_row(&first_row[0..is_set_idx]);
}
}
Ok(())
@@ -390,8 +388,6 @@ struct LastValueAccumulator {
orderings: Vec<ScalarValue>,
// Stores the applicable ordering requirement.
ordering_req: LexOrdering,
- // Whether merge_batch() is called before
- is_merge_called: bool,
}
impl LastValueAccumulator {
@@ -410,7 +406,6 @@ impl LastValueAccumulator {
is_set: false,
orderings,
ordering_req,
- is_merge_called: false,
})
}
@@ -426,9 +421,7 @@ impl Accumulator for LastValueAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
let mut result = vec![self.last.clone()];
result.extend(self.orderings.clone());
- if !self.is_merge_called {
- result.push(ScalarValue::Boolean(Some(self.is_set)));
- }
+ result.push(ScalarValue::Boolean(Some(self.is_set)));
Ok(result)
}
@@ -442,7 +435,6 @@ impl Accumulator for LastValueAccumulator {
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
- self.is_merge_called = true;
// LAST_VALUE(last1, last2, last3, ...)
// last index contains is_set flag.
let is_set_idx = states.len() - 1;
@@ -463,14 +455,18 @@ impl Accumulator for LastValueAccumulator {
if !ordered_states[0].is_empty() {
let last_idx = ordered_states[0].len() - 1;
let last_row = get_row_at_idx(&ordered_states, last_idx)?;
- let last_ordering = &last_row[1..];
+ // When collecting orderings, we exclude the is_set flag from the state.
+ let last_ordering = &last_row[1..is_set_idx];
let sort_options = get_sort_options(&self.ordering_req);
// Either there is no existing value, or there is a newer (latest)
// version in the new data:
if !self.is_set
|| compare_rows(last_ordering, &self.orderings,
&sort_options)?.is_gt()
{
- self.update_with_new_row(&last_row);
+ // Update with last value in the state. Note that we should exclude the
+ // is_set flag from the state. Otherwise, we will end up with a state
+ // containing two is_set flags.
+ self.update_with_new_row(&last_row[0..is_set_idx]);
}
}
Ok(())
@@ -531,6 +527,7 @@ mod tests {
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Accumulator;
+ use arrow::compute::concat;
use std::sync::Arc;
#[test]
@@ -562,4 +559,72 @@ mod tests {
assert_eq!(last_accumulator.evaluate()?, ScalarValue::Int64(Some(12)));
Ok(())
}
+
+ #[test]
+ fn test_first_last_state_after_merge() -> Result<()> {
+ let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)];
+ // create 3 ArrayRefs between each interval e.g from 0 to 9, 1 to 10, 2 to 12
+ let arrs = ranges
+ .into_iter()
+ .map(|(start, end)| {
+ Arc::new((start..end).collect::<Int64Array>()) as ArrayRef
+ })
+ .collect::<Vec<_>>();
+
+ // FirstValueAccumulator
+ let mut first_accumulator =
+ FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
+
+ first_accumulator.update_batch(&[arrs[0].clone()])?;
+ let state1 = first_accumulator.state()?;
+
+ let mut first_accumulator =
+ FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
+ first_accumulator.update_batch(&[arrs[1].clone()])?;
+ let state2 = first_accumulator.state()?;
+
+ assert_eq!(state1.len(), state2.len());
+
+ let mut states = vec![];
+
+ for idx in 0..state1.len() {
+ states.push(concat(&[&state1[idx].to_array(),
&state2[idx].to_array()])?);
+ }
+
+ let mut first_accumulator =
+ FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
+ first_accumulator.merge_batch(&states)?;
+
+ let merged_state = first_accumulator.state()?;
+ assert_eq!(merged_state.len(), state1.len());
+
+ // LastValueAccumulator
+ let mut last_accumulator =
+ LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
+
+ last_accumulator.update_batch(&[arrs[0].clone()])?;
+ let state1 = last_accumulator.state()?;
+
+ let mut last_accumulator =
+ LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
+ last_accumulator.update_batch(&[arrs[1].clone()])?;
+ let state2 = last_accumulator.state()?;
+
+ assert_eq!(state1.len(), state2.len());
+
+ let mut states = vec![];
+
+ for idx in 0..state1.len() {
+ states.push(concat(&[&state1[idx].to_array(),
&state2[idx].to_array()])?);
+ }
+
+ let mut last_accumulator =
+ LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
+ last_accumulator.merge_batch(&states)?;
+
+ let merged_state = last_accumulator.state()?;
+ assert_eq!(merged_state.len(), state1.len());
+
+ Ok(())
+ }
}