This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 61ed374e96 First and Last Accumulators should update with state row excluding is_set flag (#7565)
61ed374e96 is described below

commit 61ed374e96ac10136c8ddf2b80260d6ea8b6d30b
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Sat Sep 16 03:22:09 2023 -0700

    First and Last Accumulators should update with state row excluding is_set flag (#7565)
    
    * First and Last Accumulators should update with state row excluding is_set flag
    
    * Add test
    
    * Update datafusion/physical-expr/src/aggregate/first_last.rs
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    * Update datafusion/physical-expr/src/aggregate/first_last.rs
    
    * Remove
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 .../physical-expr/src/aggregate/first_last.rs      | 101 +++++++++++++++++----
 1 file changed, 83 insertions(+), 18 deletions(-)

diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs
index 02bb466d44..6ae7b4895a 100644
--- a/datafusion/physical-expr/src/aggregate/first_last.rs
+++ b/datafusion/physical-expr/src/aggregate/first_last.rs
@@ -165,8 +165,6 @@ struct FirstValueAccumulator {
     orderings: Vec<ScalarValue>,
     // Stores the applicable ordering requirement.
     ordering_req: LexOrdering,
-    // Whether merge_batch() is called before
-    is_merge_called: bool,
 }
 
 impl FirstValueAccumulator {
@@ -185,7 +183,6 @@ impl FirstValueAccumulator {
             is_set: false,
             orderings,
             ordering_req,
-            is_merge_called: false,
         })
     }
 
@@ -201,9 +198,7 @@ impl Accumulator for FirstValueAccumulator {
     fn state(&self) -> Result<Vec<ScalarValue>> {
         let mut result = vec![self.first.clone()];
         result.extend(self.orderings.iter().cloned());
-        if !self.is_merge_called {
-            result.push(ScalarValue::Boolean(Some(self.is_set)));
-        }
+        result.push(ScalarValue::Boolean(Some(self.is_set)));
         Ok(result)
     }
 
@@ -218,7 +213,6 @@ impl Accumulator for FirstValueAccumulator {
     }
 
     fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
-        self.is_merge_called = true;
         // FIRST_VALUE(first1, first2, first3, ...)
         // last index contains is_set flag.
         let is_set_idx = states.len() - 1;
@@ -237,13 +231,17 @@ impl Accumulator for FirstValueAccumulator {
         };
         if !ordered_states[0].is_empty() {
             let first_row = get_row_at_idx(&ordered_states, 0)?;
-            let first_ordering = &first_row[1..];
+            // When collecting orderings, we exclude the is_set flag from the state.
+            let first_ordering = &first_row[1..is_set_idx];
             let sort_options = get_sort_options(&self.ordering_req);
             // Either there is no existing value, or there is an earlier version in new data.
             if !self.is_set
                 || compare_rows(first_ordering, &self.orderings, &sort_options)?.is_lt()
             {
-                self.update_with_new_row(&first_row);
+                // Update with first value in the state. Note that we should exclude the
+                // is_set flag from the state. Otherwise, we will end up with a state
+                // containing two is_set flags.
+                self.update_with_new_row(&first_row[0..is_set_idx]);
             }
         }
         Ok(())
@@ -390,8 +388,6 @@ struct LastValueAccumulator {
     orderings: Vec<ScalarValue>,
     // Stores the applicable ordering requirement.
     ordering_req: LexOrdering,
-    // Whether merge_batch() is called before
-    is_merge_called: bool,
 }
 
 impl LastValueAccumulator {
@@ -410,7 +406,6 @@ impl LastValueAccumulator {
             is_set: false,
             orderings,
             ordering_req,
-            is_merge_called: false,
         })
     }
 
@@ -426,9 +421,7 @@ impl Accumulator for LastValueAccumulator {
     fn state(&self) -> Result<Vec<ScalarValue>> {
         let mut result = vec![self.last.clone()];
         result.extend(self.orderings.clone());
-        if !self.is_merge_called {
-            result.push(ScalarValue::Boolean(Some(self.is_set)));
-        }
+        result.push(ScalarValue::Boolean(Some(self.is_set)));
         Ok(result)
     }
 
@@ -442,7 +435,6 @@ impl Accumulator for LastValueAccumulator {
     }
 
     fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
-        self.is_merge_called = true;
         // LAST_VALUE(last1, last2, last3, ...)
         // last index contains is_set flag.
         let is_set_idx = states.len() - 1;
@@ -463,14 +455,18 @@ impl Accumulator for LastValueAccumulator {
         if !ordered_states[0].is_empty() {
             let last_idx = ordered_states[0].len() - 1;
             let last_row = get_row_at_idx(&ordered_states, last_idx)?;
-            let last_ordering = &last_row[1..];
+            // When collecting orderings, we exclude the is_set flag from the state.
+            let last_ordering = &last_row[1..is_set_idx];
             let sort_options = get_sort_options(&self.ordering_req);
             // Either there is no existing value, or there is a newer (latest)
             // version in the new data:
             if !self.is_set
                 || compare_rows(last_ordering, &self.orderings, &sort_options)?.is_gt()
             {
-                self.update_with_new_row(&last_row);
+                // Update with last value in the state. Note that we should exclude the
+                // is_set flag from the state. Otherwise, we will end up with a state
+                // containing two is_set flags.
+                self.update_with_new_row(&last_row[0..is_set_idx]);
             }
         }
         Ok(())
@@ -531,6 +527,7 @@ mod tests {
     use datafusion_common::{Result, ScalarValue};
     use datafusion_expr::Accumulator;
 
+    use arrow::compute::concat;
     use std::sync::Arc;
 
     #[test]
@@ -562,4 +559,72 @@ mod tests {
         assert_eq!(last_accumulator.evaluate()?, ScalarValue::Int64(Some(12)));
         Ok(())
     }
+
+    #[test]
+    fn test_first_last_state_after_merge() -> Result<()> {
+        let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)];
+        // create 3 ArrayRefs between each interval e.g from 0 to 9, 1 to 10, 2 to 12
+        let arrs = ranges
+            .into_iter()
+            .map(|(start, end)| {
+                Arc::new((start..end).collect::<Int64Array>()) as ArrayRef
+            })
+            .collect::<Vec<_>>();
+
+        // FirstValueAccumulator
+        let mut first_accumulator =
+            FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
+
+        first_accumulator.update_batch(&[arrs[0].clone()])?;
+        let state1 = first_accumulator.state()?;
+
+        let mut first_accumulator =
+            FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
+        first_accumulator.update_batch(&[arrs[1].clone()])?;
+        let state2 = first_accumulator.state()?;
+
+        assert_eq!(state1.len(), state2.len());
+
+        let mut states = vec![];
+
+        for idx in 0..state1.len() {
+            states.push(concat(&[&state1[idx].to_array(), &state2[idx].to_array()])?);
+        }
+
+        let mut first_accumulator =
+            FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
+        first_accumulator.merge_batch(&states)?;
+
+        let merged_state = first_accumulator.state()?;
+        assert_eq!(merged_state.len(), state1.len());
+
+        // LastValueAccumulator
+        let mut last_accumulator =
+            LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
+
+        last_accumulator.update_batch(&[arrs[0].clone()])?;
+        let state1 = last_accumulator.state()?;
+
+        let mut last_accumulator =
+            LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
+        last_accumulator.update_batch(&[arrs[1].clone()])?;
+        let state2 = last_accumulator.state()?;
+
+        assert_eq!(state1.len(), state2.len());
+
+        let mut states = vec![];
+
+        for idx in 0..state1.len() {
+            states.push(concat(&[&state1[idx].to_array(), &state2[idx].to_array()])?);
+        }
+
+        let mut last_accumulator =
+            LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
+        last_accumulator.merge_batch(&states)?;
+
+        let merged_state = last_accumulator.state()?;
+        assert_eq!(merged_state.len(), state1.len());
+
+        Ok(())
+    }
 }

Reply via email to