This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new afc9c9daec [MINOR] Remove update state api from PartitionEvaluator 
(#6966)
afc9c9daec is described below

commit afc9c9daecb3277caa62f11d60c2ac23535bbf02
Author: Mustafa Akur <[email protected]>
AuthorDate: Sat Jul 15 13:27:50 2023 +0300

    [MINOR] Remove update state api from PartitionEvaluator (#6966)
    
    * remove update_state api from partition_evaluator
    
    * Resolve linter errors
    
    * Simplifications
    
    * remove row_idx argument from evaluate
    
    * Simplifications
    
    * Update datafusion/expr/src/partition_evaluator.rs
    
    Co-authored-by: Mehmet Ozan Kabak <[email protected]>
    
    * Update comment
    
    Co-authored-by: Mehmet Ozan Kabak <[email protected]>
    
    * Update document
    
    * Use boolean operator instead of bitwise
    
    ---------
    
    Co-authored-by: Mehmet Ozan Kabak <[email protected]>
---
 datafusion/core/tests/fuzz_cases/window_fuzz.rs    |  7 ++++
 .../user_defined/user_defined_window_functions.rs  | 29 +-------------
 datafusion/expr/src/partition_evaluator.rs         | 43 +++------------------
 datafusion/physical-expr/src/window/built_in.rs    | 36 ++++++++++--------
 datafusion/physical-expr/src/window/lead_lag.rs    | 26 +++++--------
 datafusion/physical-expr/src/window/nth_value.rs   | 12 ------
 datafusion/physical-expr/src/window/rank.rs        | 44 ++++++++--------------
 datafusion/physical-expr/src/window/window_expr.rs | 33 ++++------------
 8 files changed, 67 insertions(+), 163 deletions(-)

diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs 
b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
index 77b6e0a5d1..870baf948b 100644
--- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
@@ -208,6 +208,13 @@ fn get_random_function(
                 vec![],
             ),
         );
+        window_fn_map.insert(
+            "dense_rank",
+            (
+                
WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::DenseRank),
+                vec![],
+            ),
+        );
         window_fn_map.insert(
             "lead",
             (
diff --git 
a/datafusion/core/tests/user_defined/user_defined_window_functions.rs 
b/datafusion/core/tests/user_defined/user_defined_window_functions.rs
index 1331347fac..5f99391572 100644
--- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs
+++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs
@@ -32,8 +32,8 @@ use arrow_schema::DataType;
 use datafusion::{assert_batches_eq, prelude::SessionContext};
 use datafusion_common::{Result, ScalarValue};
 use datafusion_expr::{
-    function::PartitionEvaluatorFactory, window_state::WindowAggState,
-    PartitionEvaluator, ReturnTypeFunction, Signature, Volatility, WindowUDF,
+    function::PartitionEvaluatorFactory, PartitionEvaluator, 
ReturnTypeFunction,
+    Signature, Volatility, WindowUDF,
 };
 
 /// A query with a window function evaluated over the entire partition
@@ -195,7 +195,6 @@ async fn test_stateful_udwf() {
         &execute(&ctx, UNBOUNDED_WINDOW_QUERY).await.unwrap()
     );
     assert_eq!(test_state.evaluate_called(), 10);
-    assert_eq!(test_state.update_state_called(), 10);
     assert_eq!(test_state.evaluate_all_called(), 0);
 }
 
@@ -229,7 +228,6 @@ async fn test_stateful_udwf_bounded_window() {
     );
     // Evaluate and update_state is called for each input row
     assert_eq!(test_state.evaluate_called(), 10);
-    assert_eq!(test_state.update_state_called(), 10);
     assert_eq!(test_state.evaluate_all_called(), 0);
 }
 
@@ -388,8 +386,6 @@ struct TestState {
     evaluate_all_called: AtomicUsize,
     /// How many times was `evaluate` called?
     evaluate_called: AtomicUsize,
-    /// How many times was `update_state` called?
-    update_state_called: AtomicUsize,
     /// How many times was `evaluate_all_with_rank` called?
     evaluate_all_with_rank_called: AtomicUsize,
     /// should the functions say they use the window frame?
@@ -451,16 +447,6 @@ impl TestState {
         self.evaluate_called.fetch_add(1, Ordering::SeqCst);
     }
 
-    /// return the update_state_called counter
-    fn update_state_called(&self) -> usize {
-        self.update_state_called.load(Ordering::SeqCst)
-    }
-
-    /// update the update_state_called counter
-    fn inc_update_state_called(&self) {
-        self.update_state_called.fetch_add(1, Ordering::SeqCst);
-    }
-
     /// return the evaluate_all_with_rank_called counter
     fn evaluate_all_with_rank_called(&self) -> usize {
         self.evaluate_all_with_rank_called.load(Ordering::SeqCst)
@@ -555,17 +541,6 @@ impl PartitionEvaluator for OddCounter {
         Ok(Arc::new(array))
     }
 
-    fn update_state(
-        &mut self,
-        _state: &WindowAggState,
-        _idx: usize,
-        _range_columns: &[ArrayRef],
-        _sort_partition_points: &[Range<usize>],
-    ) -> Result<()> {
-        self.test_state.inc_update_state_called();
-        Ok(())
-    }
-
     fn supports_bounded_execution(&self) -> bool {
         self.test_state.supports_bounded_execution
     }
diff --git a/datafusion/expr/src/partition_evaluator.rs 
b/datafusion/expr/src/partition_evaluator.rs
index f0c425ca59..f3fb67c9e2 100644
--- a/datafusion/expr/src/partition_evaluator.rs
+++ b/datafusion/expr/src/partition_evaluator.rs
@@ -69,27 +69,10 @@ use crate::window_state::WindowAggState;
 /// capabilities described by [`supports_bounded_execution`],
 /// [`uses_window_frame`], and [`include_rank`],
 ///
-/// # Stateless `PartitionEvaluator`s
-///
-/// In this case, `PartitionEvaluator` holds no state, and either
-/// [`evaluate_all`] or [`evaluate_all_with_rank`] is called with
-/// values for the entire partition.
-///
-/// # Stateful `PartitionEvaluator`s
-///
-/// In this case, [`Self::evaluate`] is called to calculate the window
-/// function incrementally for each new batch.
-///
-/// For example, when computing `ROW_NUMBER` incrementally,
-/// [`Self::evaluate`] will be called multiple times with
-/// different batches. For all batches after the first, the output
-/// `row_number` must start from last `row_number` produced for the
-/// previous batch. The previous row number is saved and restored as
-/// the state.
-///
 /// When implementing a new `PartitionEvaluator`, implement
 /// corresponding evaluator according to table below.
 ///
+/// # Implementation Table
 ///
 /// 
|[`uses_window_frame`]|[`supports_bounded_execution`]|[`include_rank`]|function_to_implement|
 /// |---|---|----|----|
@@ -105,25 +88,6 @@ use crate::window_state::WindowAggState;
 /// [`include_rank`]: Self::include_rank
 /// [`supports_bounded_execution`]: Self::supports_bounded_execution
 pub trait PartitionEvaluator: Debug + Send {
-    /// Updates the internal state for window function
-    ///
-    /// Only used for stateful evaluation
-    ///
-    /// `state`: is useful to update internal state for window function.
-    /// `idx`: is the index of last row for which result is calculated.
-    /// `range_columns`: is the result of order by column values. It is used 
to calculate rank boundaries
-    /// `sort_partition_points`: is the boundaries of each rank in the 
range_column. It is used to update rank.
-    fn update_state(
-        &mut self,
-        _state: &WindowAggState,
-        _idx: usize,
-        _range_columns: &[ArrayRef],
-        _sort_partition_points: &[Range<usize>],
-    ) -> Result<()> {
-        // If we do not use state, update_state does nothing
-        Ok(())
-    }
-
     /// When the window frame has a fixed beginning (e.g UNBOUNDED
     /// PRECEDING), some functions such as FIRST_VALUE, LAST_VALUE and
     /// NTH_VALUE do not need the (unbounded) input once they have
@@ -220,7 +184,10 @@ pub trait PartitionEvaluator: Debug + Send {
     /// trait.
     ///
     /// Returns a [`ScalarValue`] that is the value of the window
-    /// function within `range` for the entire partition
+    /// function within `range` for the entire partition. Argument
+    /// `values` contains the evaluation result of function arguments
+    /// and evaluation results of ORDER BY expressions. If function has a
+    /// single argument, `values[1..]` will contain ORDER BY expression 
results.
     fn evaluate(
         &mut self,
         _values: &[ArrayRef],
diff --git a/datafusion/physical-expr/src/window/built_in.rs 
b/datafusion/physical-expr/src/window/built_in.rs
index a528676c26..e81ffe59b8 100644
--- a/datafusion/physical-expr/src/window/built_in.rs
+++ b/datafusion/physical-expr/src/window/built_in.rs
@@ -23,7 +23,7 @@ use std::sync::Arc;
 
 use super::BuiltInWindowFunctionExpr;
 use super::WindowExpr;
-use crate::window::window_expr::WindowFn;
+use crate::window::window_expr::{get_orderby_values, WindowFn};
 use crate::window::{PartitionBatches, PartitionWindowAggStates, WindowState};
 use crate::{expressions::PhysicalSortExpr, reverse_order_bys, PhysicalExpr};
 use arrow::array::{new_empty_array, ArrayRef};
@@ -101,14 +101,19 @@ impl WindowExpr for BuiltInWindowExpr {
                 self.order_by.iter().map(|o| o.options).collect();
             let mut row_wise_results = vec![];
 
-            let (values, order_bys) = self.get_values_orderbys(batch)?;
+            let mut values = self.evaluate_args(batch)?;
+            let order_bys = get_orderby_values(self.order_by_columns(batch)?);
+            let n_args = values.len();
+            values.extend(order_bys);
+            let order_bys_ref = &values[n_args..];
+
             let mut window_frame_ctx =
                 WindowFrameContext::new(self.window_frame.clone(), 
sort_options);
             let mut last_range = Range { start: 0, end: 0 };
             // We iterate on each row to calculate window frame range and and 
window function result
             for idx in 0..num_rows {
                 let range = window_frame_ctx.calculate_range(
-                    &order_bys,
+                    order_bys_ref,
                     &last_range,
                     num_rows,
                     idx,
@@ -119,11 +124,11 @@ impl WindowExpr for BuiltInWindowExpr {
             }
             ScalarValue::iter_to_array(row_wise_results.into_iter())
         } else if evaluator.include_rank() {
-            let columns = self.sort_columns(batch)?;
+            let columns = self.order_by_columns(batch)?;
             let sort_partition_points = evaluate_partition_ranges(num_rows, 
&columns)?;
             evaluator.evaluate_all_with_rank(num_rows, &sort_partition_points)
         } else {
-            let (values, _) = self.get_values_orderbys(batch)?;
+            let values = self.evaluate_args(batch)?;
             evaluator.evaluate_all(&values, num_rows)
         }
     }
@@ -157,18 +162,20 @@ impl WindowExpr for BuiltInWindowExpr {
             };
             let state = &mut window_state.state;
 
-            let (values, order_bys) =
-                self.get_values_orderbys(&partition_batch_state.record_batch)?;
+            let batch_ref = &partition_batch_state.record_batch;
+            let mut values = self.evaluate_args(batch_ref)?;
+            let order_bys = if evaluator.uses_window_frame() || 
evaluator.include_rank() {
+                get_orderby_values(self.order_by_columns(batch_ref)?)
+            } else {
+                vec![]
+            };
+            let n_args = values.len();
+            values.extend(order_bys);
+            let order_bys_ref = &values[n_args..];
 
             // We iterate on each row to perform a running calculation.
             let record_batch = &partition_batch_state.record_batch;
             let num_rows = record_batch.num_rows();
-            let sort_partition_points = if evaluator.include_rank() {
-                let columns = self.sort_columns(record_batch)?;
-                evaluate_partition_ranges(num_rows, &columns)?
-            } else {
-                vec![]
-            };
             let mut row_wise_results: Vec<ScalarValue> = vec![];
             for idx in state.last_calculated_index..num_rows {
                 let frame_range = if evaluator.uses_window_frame() {
@@ -181,7 +188,7 @@ impl WindowExpr for BuiltInWindowExpr {
                             )
                         })
                         .calculate_range(
-                            &order_bys,
+                            order_bys_ref,
                             // Start search from the last range
                             &state.window_frame_range,
                             num_rows,
@@ -197,7 +204,6 @@ impl WindowExpr for BuiltInWindowExpr {
                 }
                 // Update last range
                 state.window_frame_range = frame_range;
-                evaluator.update_state(state, idx, &order_bys, 
&sort_partition_points)?;
                 row_wise_results
                     .push(evaluator.evaluate(&values, 
&state.window_frame_range)?);
             }
diff --git a/datafusion/physical-expr/src/window/lead_lag.rs 
b/datafusion/physical-expr/src/window/lead_lag.rs
index 637297b4cf..862648993a 100644
--- a/datafusion/physical-expr/src/window/lead_lag.rs
+++ b/datafusion/physical-expr/src/window/lead_lag.rs
@@ -18,7 +18,6 @@
 //! Defines physical expression for `lead` and `lag` that can evaluated
 //! at runtime during query execution
 
-use crate::window::window_expr::LeadLagState;
 use crate::window::BuiltInWindowFunctionExpr;
 use crate::PhysicalExpr;
 use arrow::array::ArrayRef;
@@ -26,7 +25,6 @@ use arrow::compute::cast;
 use arrow::datatypes::{DataType, Field};
 use datafusion_common::ScalarValue;
 use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::window_state::WindowAggState;
 use datafusion_expr::PartitionEvaluator;
 use std::any::Any;
 use std::cmp::min;
@@ -105,7 +103,6 @@ impl BuiltInWindowFunctionExpr for WindowShift {
 
     fn create_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
         Ok(Box::new(WindowShiftEvaluator {
-            state: LeadLagState { idx: 0 },
             shift_offset: self.shift_offset,
             default_value: self.default_value.clone(),
         }))
@@ -124,7 +121,6 @@ impl BuiltInWindowFunctionExpr for WindowShift {
 
 #[derive(Debug)]
 pub(crate) struct WindowShiftEvaluator {
-    state: LeadLagState,
     shift_offset: i64,
     default_value: Option<ScalarValue>,
 }
@@ -179,17 +175,6 @@ fn shift_with_default_value(
 }
 
 impl PartitionEvaluator for WindowShiftEvaluator {
-    fn update_state(
-        &mut self,
-        _state: &WindowAggState,
-        idx: usize,
-        _range_columns: &[ArrayRef],
-        _sort_partition_points: &[Range<usize>],
-    ) -> Result<()> {
-        self.state.idx = idx;
-        Ok(())
-    }
-
     fn get_range(&self, idx: usize, n_rows: usize) -> Result<Range<usize>> {
         if self.shift_offset > 0 {
             let offset = self.shift_offset as usize;
@@ -206,11 +191,18 @@ impl PartitionEvaluator for WindowShiftEvaluator {
     fn evaluate(
         &mut self,
         values: &[ArrayRef],
-        _range: &Range<usize>,
+        range: &Range<usize>,
     ) -> Result<ScalarValue> {
         let array = &values[0];
         let dtype = array.data_type();
-        let idx = self.state.idx as i64 - self.shift_offset;
+        // LAG mode
+        let idx = if self.shift_offset > 0 {
+            range.end as i64 - self.shift_offset - 1
+        } else {
+            // LEAD mode
+            range.start as i64 - self.shift_offset
+        };
+
         if idx < 0 || idx as usize >= array.len() {
             get_default_value(self.default_value.as_ref(), dtype)
         } else {
diff --git a/datafusion/physical-expr/src/window/nth_value.rs 
b/datafusion/physical-expr/src/window/nth_value.rs
index 2d592bbb6f..0da04274e2 100644
--- a/datafusion/physical-expr/src/window/nth_value.rs
+++ b/datafusion/physical-expr/src/window/nth_value.rs
@@ -145,18 +145,6 @@ pub(crate) struct NthValueEvaluator {
 }
 
 impl PartitionEvaluator for NthValueEvaluator {
-    fn update_state(
-        &mut self,
-        state: &WindowAggState,
-        _idx: usize,
-        _range_columns: &[ArrayRef],
-        _sort_partition_points: &[Range<usize>],
-    ) -> Result<()> {
-        // If we do not use state, update_state does nothing
-        self.state.range.clone_from(&state.window_frame_range);
-        Ok(())
-    }
-
     /// When the window frame has a fixed beginning (e.g UNBOUNDED
     /// PRECEDING), for some functions such as FIRST_VALUE, LAST_VALUE and
     /// NTH_VALUE we can memoize result.  Once result is calculated it
diff --git a/datafusion/physical-expr/src/window/rank.rs 
b/datafusion/physical-expr/src/window/rank.rs
index 527eaab611..48ff872078 100644
--- a/datafusion/physical-expr/src/window/rank.rs
+++ b/datafusion/physical-expr/src/window/rank.rs
@@ -26,7 +26,6 @@ use arrow::array::{Float64Array, UInt64Array};
 use arrow::datatypes::{DataType, Field};
 use datafusion_common::utils::get_row_at_idx;
 use datafusion_common::{DataFusionError, Result, ScalarValue};
-use datafusion_expr::window_state::WindowAggState;
 use datafusion_expr::PartitionEvaluator;
 use std::any::Any;
 use std::iter;
@@ -116,39 +115,26 @@ pub(crate) struct RankEvaluator {
 }
 
 impl PartitionEvaluator for RankEvaluator {
-    fn update_state(
+    /// Evaluates the window function inside the given range.
+    fn evaluate(
         &mut self,
-        state: &WindowAggState,
-        idx: usize,
-        range_columns: &[ArrayRef],
-        sort_partition_points: &[Range<usize>],
-    ) -> Result<()> {
-        // find range inside `sort_partition_points` containing `idx`
-        let chunk_idx = sort_partition_points
-            .iter()
-            .position(|elem| elem.start <= idx && idx < elem.end)
-            .ok_or_else(|| {
-                DataFusionError::Execution(
-                    "Expects sort_partition_points to contain idx".to_string(),
-                )
-            })?;
-        let chunk = &sort_partition_points[chunk_idx];
-        let last_rank_data = get_row_at_idx(range_columns, chunk.end - 1)?;
+        values: &[ArrayRef],
+        range: &Range<usize>,
+    ) -> Result<ScalarValue> {
+        let row_idx = range.start;
+        // There is no argument, values are order by column values (where rank 
is calculated)
+        let range_columns = values;
+        let last_rank_data = get_row_at_idx(range_columns, row_idx)?;
         let empty = self.state.last_rank_data.is_empty();
         if empty || self.state.last_rank_data != last_rank_data {
             self.state.last_rank_data = last_rank_data;
-            self.state.last_rank_boundary = state.offset_pruned_rows + 
chunk.start;
-            self.state.n_rank = 1 + if empty { chunk_idx } else { 
self.state.n_rank };
+            self.state.last_rank_boundary += self.state.current_group_count;
+            self.state.current_group_count = 1;
+            self.state.n_rank += 1;
+        } else {
+            // data is still in the same rank
+            self.state.current_group_count += 1;
         }
-        Ok(())
-    }
-
-    /// evaluate window function result inside given range
-    fn evaluate(
-        &mut self,
-        _values: &[ArrayRef],
-        _range: &Range<usize>,
-    ) -> Result<ScalarValue> {
         match self.rank_type {
             RankType::Basic => Ok(ScalarValue::UInt64(Some(
                 self.state.last_rank_boundary as u64 + 1,
diff --git a/datafusion/physical-expr/src/window/window_expr.rs 
b/datafusion/physical-expr/src/window/window_expr.rs
index 9175d97525..3cc022f88f 100644
--- a/datafusion/physical-expr/src/window/window_expr.rs
+++ b/datafusion/physical-expr/src/window/window_expr.rs
@@ -117,25 +117,6 @@ pub trait WindowExpr: Send + Sync + Debug {
             .collect::<Result<Vec<SortColumn>>>()
     }
 
-    /// Get sort columns that can be used for peer evaluation, empty if absent
-    fn sort_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
-        let order_by_columns = self.order_by_columns(batch)?;
-        Ok(order_by_columns)
-    }
-
-    /// Get values columns (argument of Window Function)
-    /// and order by columns (columns of the ORDER BY expression) used in 
evaluators
-    fn get_values_orderbys(
-        &self,
-        record_batch: &RecordBatch,
-    ) -> Result<(Vec<ArrayRef>, Vec<ArrayRef>)> {
-        let values = self.evaluate_args(record_batch)?;
-        let order_by_columns = self.order_by_columns(record_batch)?;
-        let order_bys: Vec<ArrayRef> =
-            order_by_columns.iter().map(|s| s.values.clone()).collect();
-        Ok((values, order_bys))
-    }
-
     /// Get the window frame of this [WindowExpr].
     fn get_window_frame(&self) -> &Arc<WindowFrame>;
 
@@ -244,7 +225,8 @@ pub trait AggregateWindowExpr: WindowExpr {
         mut idx: usize,
         not_end: bool,
     ) -> Result<ArrayRef> {
-        let (values, order_bys) = self.get_values_orderbys(record_batch)?;
+        let values = self.evaluate_args(record_batch)?;
+        let order_bys = 
get_orderby_values(self.order_by_columns(record_batch)?);
         // We iterate on each row to perform a running calculation.
         let length = values[0].len();
         let mut row_wise_results: Vec<ScalarValue> = vec![];
@@ -276,6 +258,10 @@ pub trait AggregateWindowExpr: WindowExpr {
         }
     }
 }
+/// Get order by expression results inside `order_by_columns`.
+pub(crate) fn get_orderby_values(order_by_columns: Vec<SortColumn>) -> 
Vec<ArrayRef> {
+    order_by_columns.into_iter().map(|s| s.values).collect()
+}
 
 #[derive(Debug)]
 pub enum WindowFn {
@@ -290,6 +276,8 @@ pub struct RankState {
     pub last_rank_data: Vec<ScalarValue>,
     /// The index where last_rank_boundary is started
     pub last_rank_boundary: usize,
+    /// Keep the number of entries in current rank
+    pub current_group_count: usize,
     /// Rank number kept from the start
     pub n_rank: usize,
 }
@@ -323,11 +311,6 @@ pub struct NthValueState {
     pub kind: NthValueKind,
 }
 
-#[derive(Debug, Clone, Default)]
-pub struct LeadLagState {
-    pub idx: usize,
-}
-
 /// Key for IndexMap for each unique partition
 ///
 /// For instance, if window frame is `OVER(PARTITION BY a,b)`,

Reply via email to