This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 79ecf949a Linear search support for Window Group queries (#5286)
79ecf949a is described below
commit 79ecf949a489485444825cc19af7ec0bea09452e
Author: Mustafa Akur <[email protected]>
AuthorDate: Fri Feb 17 17:38:54 2023 +0300
Linear search support for Window Group queries (#5286)
* add naive linear search
* Add last range to decrease search size
* minor changes
* add low, high arguments
* Go back to old API, improve comments, refactors
* Linear Groups implementation
* Resolve linter errors
* remove old unit tests
* simplifications
* Add unit tests
* Remove sort options from GROUPS calculations, various code simplifications
and comment clarifications
* New TODOs to fix
* Address reviews
* Fix error
* Prehandle range current row and unbounded following case
* Fix error
* Move a check from execution to planning, reduce code duplication
* Incorporate review suggestion (with cargo fmt fix)
---------
Co-authored-by: Mehmet Ozan Kabak <[email protected]>
---
datafusion/core/src/physical_plan/planner.rs | 2 +-
datafusion/core/tests/sql/window.rs | 61 +-
datafusion/expr/src/window_frame.rs | 30 +
datafusion/physical-expr/src/window/built_in.rs | 32 +-
datafusion/physical-expr/src/window/window_expr.rs | 27 +-
.../physical-expr/src/window/window_frame_state.rs | 972 +++++++--------------
datafusion/proto/src/logical_plan/from_proto.rs | 24 +-
datafusion/sql/src/expr/function.rs | 14 +-
datafusion/sql/tests/integration_test.rs | 21 -
9 files changed, 420 insertions(+), 763 deletions(-)
diff --git a/datafusion/core/src/physical_plan/planner.rs
b/datafusion/core/src/physical_plan/planner.rs
index b6269a560..d0ee38ac9 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -1594,7 +1594,7 @@ pub fn create_window_expr_with_name(
})
.collect::<Result<Vec<_>>>()?;
if !is_window_valid(window_frame) {
- return Err(DataFusionError::Execution(format!(
+ return Err(DataFusionError::Plan(format!(
"Invalid window frame: start bound ({}) cannot be
larger than end bound ({})",
window_frame.start_bound, window_frame.end_bound
)));
diff --git a/datafusion/core/tests/sql/window.rs
b/datafusion/core/tests/sql/window.rs
index 99b9743f0..7ef4af23a 100644
--- a/datafusion/core/tests/sql/window.rs
+++ b/datafusion/core/tests/sql/window.rs
@@ -982,19 +982,20 @@ async fn window_frame_groups_multiple_order_columns() ->
Result<()> {
async fn window_frame_groups_without_order_by() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
- // execute the query
- let df = ctx
+ // Try executing an erroneous query (the ORDER BY clause is missing in the
+ // window frame):
+ let err = ctx
.sql(
"SELECT
SUM(c4) OVER(PARTITION BY c2 GROUPS BETWEEN 1 PRECEDING AND 1
FOLLOWING)
FROM aggregate_test_100
ORDER BY c9;",
)
- .await?;
- let err = df.collect().await.unwrap_err();
+ .await
+ .unwrap_err();
assert_contains!(
err.to_string(),
- "Execution error: GROUPS mode requires an ORDER BY clause".to_owned()
+ "Error during planning: GROUPS mode requires an ORDER BY
clause".to_owned()
);
Ok(())
}
@@ -1034,7 +1035,7 @@ async fn window_frame_creation() -> Result<()> {
let results = df.collect().await;
assert_eq!(
results.err().unwrap().to_string(),
- "Execution error: Invalid window frame: start bound (1 PRECEDING)
cannot be larger than end bound (2 PRECEDING)"
+ "Error during planning: Invalid window frame: start bound (1
PRECEDING) cannot be larger than end bound (2 PRECEDING)"
);
let df = ctx
@@ -1047,7 +1048,20 @@ async fn window_frame_creation() -> Result<()> {
let results = df.collect().await;
assert_eq!(
results.err().unwrap().to_string(),
- "Execution error: Invalid window frame: start bound (2 FOLLOWING)
cannot be larger than end bound (1 FOLLOWING)"
+ "Error during planning: Invalid window frame: start bound (2
FOLLOWING) cannot be larger than end bound (1 FOLLOWING)"
+ );
+
+ let err = ctx
+ .sql(
+ "SELECT
+ COUNT(c1) OVER(GROUPS BETWEEN CURRENT ROW AND UNBOUNDED
FOLLOWING)
+ FROM aggregate_test_100;",
+ )
+ .await
+ .unwrap_err();
+ assert_contains!(
+ err.to_string(),
+ "Error during planning: GROUPS mode requires an ORDER BY clause"
);
Ok(())
@@ -1123,6 +1137,39 @@ async fn test_window_row_number_aggregate() ->
Result<()> {
Ok(())
}
+#[tokio::test]
+async fn test_window_range_equivalent_frames() -> Result<()> {
+ let config = SessionConfig::new();
+ let ctx = SessionContext::with_config(config);
+ register_aggregate_csv(&ctx).await?;
+ let sql = "SELECT
+ c9,
+ COUNT(*) OVER(ORDER BY c9, c1 RANGE BETWEEN CURRENT ROW AND CURRENT
ROW) AS cnt1,
+ COUNT(*) OVER(ORDER BY c9, c1 RANGE UNBOUNDED PRECEDING) AS cnt2,
+ COUNT(*) OVER(ORDER BY c9, c1 RANGE CURRENT ROW) AS cnt3,
+ COUNT(*) OVER(RANGE BETWEEN CURRENT ROW AND CURRENT ROW) AS cnt4,
+ COUNT(*) OVER(RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS
cnt5,
+ COUNT(*) OVER(RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) AS
cnt6
+ FROM aggregate_test_100
+ ORDER BY c9
+ LIMIT 5";
+
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+-----------+------+------+------+------+------+------+",
+ "| c9 | cnt1 | cnt2 | cnt3 | cnt4 | cnt5 | cnt6 |",
+ "+-----------+------+------+------+------+------+------+",
+ "| 28774375 | 1 | 1 | 1 | 100 | 100 | 100 |",
+ "| 63044568 | 1 | 2 | 1 | 100 | 100 | 100 |",
+ "| 141047417 | 1 | 3 | 1 | 100 | 100 | 100 |",
+ "| 141680161 | 1 | 4 | 1 | 100 | 100 | 100 |",
+ "| 145294611 | 1 | 5 | 1 | 100 | 100 | 100 |",
+ "+-----------+------+------+------+------+------+------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
#[tokio::test]
async fn test_window_cume_dist() -> Result<()> {
let config = SessionConfig::new();
diff --git a/datafusion/expr/src/window_frame.rs
b/datafusion/expr/src/window_frame.rs
index bf74d02b7..c25d2491e 100644
--- a/datafusion/expr/src/window_frame.rs
+++ b/datafusion/expr/src/window_frame.rs
@@ -144,6 +144,36 @@ impl WindowFrame {
}
}
+/// Construct equivalent explicit window frames for implicit corner cases.
+/// With this processing, we may assume in downstream code that RANGE/GROUPS
+/// frames contain an appropriate ORDER BY clause.
+pub fn regularize(mut frame: WindowFrame, order_bys: usize) ->
Result<WindowFrame> {
+ if frame.units == WindowFrameUnits::Range && order_bys != 1 {
+ // Normally, RANGE frames require an ORDER BY clause with exactly one
+ // column. However, an ORDER BY clause may be absent in two edge cases.
+ if (frame.start_bound.is_unbounded()
+ || frame.start_bound == WindowFrameBound::CurrentRow)
+ && (frame.end_bound == WindowFrameBound::CurrentRow
+ || frame.end_bound.is_unbounded())
+ {
+ if order_bys == 0 {
+ frame.units = WindowFrameUnits::Rows;
+ frame.start_bound =
+ WindowFrameBound::Preceding(ScalarValue::UInt64(None));
+ frame.end_bound =
WindowFrameBound::Following(ScalarValue::UInt64(None));
+ }
+ } else {
+ return Err(DataFusionError::Plan(format!(
+ "With window frame of type RANGE, the ORDER BY expression must
be of length 1, got {}", order_bys)));
+ }
+ } else if frame.units == WindowFrameUnits::Groups && order_bys == 0 {
+ return Err(DataFusionError::Plan(
+ "GROUPS mode requires an ORDER BY clause".to_string(),
+ ));
+ };
+ Ok(frame)
+}
+
/// There are five ways to describe starting and ending frame boundaries:
///
/// 1. UNBOUNDED PRECEDING
diff --git a/datafusion/physical-expr/src/window/built_in.rs
b/datafusion/physical-expr/src/window/built_in.rs
index b53164f66..70ddb2c76 100644
--- a/datafusion/physical-expr/src/window/built_in.rs
+++ b/datafusion/physical-expr/src/window/built_in.rs
@@ -104,17 +104,15 @@ impl WindowExpr for BuiltInWindowExpr {
let mut row_wise_results = vec![];
let (values, order_bys) = self.get_values_orderbys(batch)?;
- let mut window_frame_ctx =
WindowFrameContext::new(&self.window_frame);
- let range = Range { start: 0, end: 0 };
+ let mut window_frame_ctx = WindowFrameContext::new(
+ &self.window_frame,
+ sort_options,
+ Range { start: 0, end: 0 },
+ );
// We iterate on each row to calculate window frame range and and
window function result
for idx in 0..num_rows {
- let range = window_frame_ctx.calculate_range(
- &order_bys,
- &sort_options,
- num_rows,
- idx,
- &range,
- )?;
+ let range =
+ window_frame_ctx.calculate_range(&order_bys, num_rows,
idx)?;
let value = evaluator.evaluate_inside_range(&values, &range)?;
row_wise_results.push(value);
}
@@ -168,7 +166,13 @@ impl WindowExpr for BuiltInWindowExpr {
// We iterate on each row to perform a running calculation.
let record_batch = &partition_batch_state.record_batch;
let num_rows = record_batch.num_rows();
- let mut window_frame_ctx =
WindowFrameContext::new(&self.window_frame);
+ let last_range = state.window_frame_range.clone();
+ let mut window_frame_ctx = WindowFrameContext::new(
+ &self.window_frame,
+ sort_options.clone(),
+ // Start search from the last range
+ last_range,
+ );
let sort_partition_points = if evaluator.include_rank() {
let columns = self.sort_columns(record_batch)?;
self.evaluate_partition_points(num_rows, &columns)?
@@ -179,13 +183,7 @@ impl WindowExpr for BuiltInWindowExpr {
let mut last_range = state.window_frame_range.clone();
for idx in state.last_calculated_index..num_rows {
state.window_frame_range = if self.expr.uses_window_frame() {
- window_frame_ctx.calculate_range(
- &order_bys,
- &sort_options,
- num_rows,
- idx,
- &state.window_frame_range,
- )
+ window_frame_ctx.calculate_range(&order_bys, num_rows, idx)
} else {
evaluator.get_range(state, num_rows)
}?;
diff --git a/datafusion/physical-expr/src/window/window_expr.rs
b/datafusion/physical-expr/src/window/window_expr.rs
index 065d26fef..96e22976b 100644
--- a/datafusion/physical-expr/src/window/window_expr.rs
+++ b/datafusion/physical-expr/src/window/window_expr.rs
@@ -162,18 +162,10 @@ pub trait AggregateWindowExpr: WindowExpr {
/// Evaluates the window function against the batch.
fn aggregate_evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
- let mut window_frame_ctx =
WindowFrameContext::new(self.get_window_frame());
let mut accumulator = self.get_accumulator()?;
let mut last_range = Range { start: 0, end: 0 };
let mut idx = 0;
- self.get_result_column(
- &mut accumulator,
- batch,
- &mut window_frame_ctx,
- &mut last_range,
- &mut idx,
- false,
- )
+ self.get_result_column(&mut accumulator, batch, &mut last_range, &mut
idx, false)
}
/// Statefully evaluates the window function against the batch. Maintains
@@ -207,11 +199,9 @@ pub trait AggregateWindowExpr: WindowExpr {
let mut state = &mut window_state.state;
let record_batch = &partition_batch_state.record_batch;
- let mut window_frame_ctx =
WindowFrameContext::new(self.get_window_frame());
let out_col = self.get_result_column(
accumulator,
record_batch,
- &mut window_frame_ctx,
&mut state.window_frame_range,
&mut state.last_calculated_index,
!partition_batch_state.is_end,
@@ -230,7 +220,6 @@ pub trait AggregateWindowExpr: WindowExpr {
&self,
accumulator: &mut Box<dyn Accumulator>,
record_batch: &RecordBatch,
- window_frame_ctx: &mut WindowFrameContext,
last_range: &mut Range<usize>,
idx: &mut usize,
not_end: bool,
@@ -240,15 +229,15 @@ pub trait AggregateWindowExpr: WindowExpr {
let length = values[0].len();
let sort_options: Vec<SortOptions> =
self.order_by().iter().map(|o| o.options).collect();
+ let mut window_frame_ctx = WindowFrameContext::new(
+ self.get_window_frame(),
+ sort_options,
+ // Start search from the last range
+ last_range.clone(),
+ );
let mut row_wise_results: Vec<ScalarValue> = vec![];
while *idx < length {
- let cur_range = window_frame_ctx.calculate_range(
- &order_bys,
- &sort_options,
- length,
- *idx,
- last_range,
- )?;
+ let cur_range = window_frame_ctx.calculate_range(&order_bys,
length, *idx)?;
// Exit if the range extends all the way:
if cur_range.end == length && not_end {
break;
diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs
b/datafusion/physical-expr/src/window/window_frame_state.rs
index 9cde3cbdf..64abacde4 100644
--- a/datafusion/physical-expr/src/window/window_frame_state.rs
+++ b/datafusion/physical-expr/src/window/window_frame_state.rs
@@ -15,14 +15,12 @@
// specific language governing permissions and limitations
// under the License.
-//! This module provides utilities for window frame index calculations
depending on the window frame mode:
-//! RANGE, ROWS, GROUPS.
+//! This module provides utilities for window frame index calculations
+//! depending on the window frame mode: RANGE, ROWS, GROUPS.
use arrow::array::ArrayRef;
use arrow::compute::kernels::sort::SortOptions;
-use datafusion_common::utils::{
- compare_rows, find_bisect_point, get_row_at_idx, search_in_slice,
-};
+use datafusion_common::utils::{compare_rows, get_row_at_idx, search_in_slice};
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits};
use std::cmp::min;
@@ -34,14 +32,18 @@ use std::sync::Arc;
/// This object stores the window frame state for use in incremental
calculations.
#[derive(Debug)]
pub enum WindowFrameContext<'a> {
- // ROWS-frames are inherently stateless:
+ /// ROWS frames are inherently stateless.
Rows(&'a Arc<WindowFrame>),
- // RANGE-frames will soon have a stateful implementation that is more
efficient than a stateless one:
+ /// RANGE frames are stateful, they store indices specifying where the
+ /// previous search left off. This amortizes the overall cost to O(n)
+ /// where n denotes the row count.
Range {
window_frame: &'a Arc<WindowFrame>,
state: WindowFrameStateRange,
},
- // GROUPS-frames have a stateful implementation that is more efficient
than a stateless one:
+ /// GROUPS frames are stateful, they store group boundaries and indices
+ /// specifying where the previous search left off. This amortizes the
+ /// overall cost to O(n) where n denotes the row count.
Groups {
window_frame: &'a Arc<WindowFrame>,
state: WindowFrameStateGroups,
@@ -49,13 +51,17 @@ pub enum WindowFrameContext<'a> {
}
impl<'a> WindowFrameContext<'a> {
- /// Create a new default state for the given window frame.
- pub fn new(window_frame: &'a Arc<WindowFrame>) -> Self {
+ /// Create a new state object for the given window frame.
+ pub fn new(
+ window_frame: &'a Arc<WindowFrame>,
+ sort_options: Vec<SortOptions>,
+ last_range: Range<usize>,
+ ) -> Self {
match window_frame.units {
WindowFrameUnits::Rows => WindowFrameContext::Rows(window_frame),
WindowFrameUnits::Range => WindowFrameContext::Range {
window_frame,
- state: WindowFrameStateRange::default(),
+ state: WindowFrameStateRange::new(sort_options, last_range),
},
WindowFrameUnits::Groups => WindowFrameContext::Groups {
window_frame,
@@ -68,30 +74,23 @@ impl<'a> WindowFrameContext<'a> {
pub fn calculate_range(
&mut self,
range_columns: &[ArrayRef],
- sort_options: &[SortOptions],
length: usize,
idx: usize,
- last_range: &Range<usize>,
) -> Result<Range<usize>> {
match *self {
WindowFrameContext::Rows(window_frame) => {
Self::calculate_range_rows(window_frame, length, idx)
}
- // sort_options is used in RANGE mode calculations because the
ordering and the position of the nulls
- // have impact on the range calculations and comparison of the
rows.
+ // Sort options are used in RANGE mode calculations because the
+ // ordering and the position of NULLs affect range calculations and
+ // row comparisons.
WindowFrameContext::Range {
window_frame,
ref mut state,
- } => state.calculate_range(
- window_frame,
- range_columns,
- sort_options,
- length,
- idx,
- last_range,
- ),
- // sort_options is not used in GROUPS mode calculations as the
inequality of two rows is the indicator
- // of a group change, and the ordering and the position of the
nulls do not have impact on inequality.
+ } => state.calculate_range(window_frame, range_columns, length,
idx),
+ // Sort options are not used in GROUPS mode calculations, as the
+ // inequality of two rows indicates a group change, and the ordering
+ // or position of NULLs does not affect inequality.
WindowFrameContext::Groups {
window_frame,
ref mut state,
@@ -159,22 +158,37 @@ impl<'a> WindowFrameContext<'a> {
}
}
-/// This structure encapsulates all the state information we require as we
-/// scan ranges of data while processing window frames. Currently we calculate
-/// things from scratch every time, but we will make this incremental in the
future.
+/// This structure encapsulates all the state information we require as we scan
+/// ranges of data while processing RANGE frames. Attribute `last_range` stores
+/// the resulting indices from the previous search. Since the indices only
+/// advance forward, we start from `last_range` subsequently. Thus, the overall
+/// time complexity of linear search amortizes to O(n) where n denotes the
total
+/// row count.
+/// Attribute `sort_options` stores the column ordering specified by the ORDER
+/// BY clause. This information is used to calculate the range.
#[derive(Debug, Default)]
-pub struct WindowFrameStateRange {}
+pub struct WindowFrameStateRange {
+ last_range: Range<usize>,
+ sort_options: Vec<SortOptions>,
+}
impl WindowFrameStateRange {
+ /// Create a new object to store the search state.
+ fn new(sort_options: Vec<SortOptions>, last_range: Range<usize>) -> Self {
+ Self {
+ // Stores the search range we calculate for future use.
+ last_range,
+ sort_options,
+ }
+ }
+
/// This function calculates beginning/ending indices for the frame of the
current row.
fn calculate_range(
&mut self,
window_frame: &Arc<WindowFrame>,
range_columns: &[ArrayRef],
- sort_options: &[SortOptions],
length: usize,
idx: usize,
- last_range: &Range<usize>,
) -> Result<Range<usize>> {
let start = match window_frame.start_bound {
WindowFrameBound::Preceding(ref n) => {
@@ -184,35 +198,23 @@ impl WindowFrameStateRange {
} else {
self.calculate_index_of_row::<true, true>(
range_columns,
- sort_options,
idx,
Some(n),
- last_range,
- length,
- )?
- }
- }
- WindowFrameBound::CurrentRow => {
- if range_columns.is_empty() {
- 0
- } else {
- self.calculate_index_of_row::<true, true>(
- range_columns,
- sort_options,
- idx,
- None,
- last_range,
length,
)?
}
}
+ WindowFrameBound::CurrentRow =>
self.calculate_index_of_row::<true, true>(
+ range_columns,
+ idx,
+ None,
+ length,
+ )?,
WindowFrameBound::Following(ref n) => self
.calculate_index_of_row::<true, false>(
range_columns,
- sort_options,
idx,
Some(n),
- last_range,
length,
)?,
};
@@ -220,26 +222,16 @@ impl WindowFrameStateRange {
WindowFrameBound::Preceding(ref n) => self
.calculate_index_of_row::<false, true>(
range_columns,
- sort_options,
idx,
Some(n),
- last_range,
length,
)?,
- WindowFrameBound::CurrentRow => {
- if range_columns.is_empty() {
- length
- } else {
- self.calculate_index_of_row::<false, false>(
- range_columns,
- sort_options,
- idx,
- None,
- last_range,
- length,
- )?
- }
- }
+ WindowFrameBound::CurrentRow =>
self.calculate_index_of_row::<false, false>(
+ range_columns,
+ idx,
+ None,
+ length,
+ )?,
WindowFrameBound::Following(ref n) => {
if n.is_null() {
// UNBOUNDED FOLLOWING
@@ -247,15 +239,16 @@ impl WindowFrameStateRange {
} else {
self.calculate_index_of_row::<false, false>(
range_columns,
- sort_options,
idx,
Some(n),
- last_range,
length,
)?
}
}
};
+ // Store the resulting range so we can start from here subsequently:
+ self.last_range.start = start;
+ self.last_range.end = end;
Ok(Range { start, end })
}
@@ -265,17 +258,20 @@ impl WindowFrameStateRange {
fn calculate_index_of_row<const SIDE: bool, const SEARCH_SIDE: bool>(
&mut self,
range_columns: &[ArrayRef],
- sort_options: &[SortOptions],
idx: usize,
delta: Option<&ScalarValue>,
- last_range: &Range<usize>,
length: usize,
) -> Result<usize> {
let current_row_values = get_row_at_idx(range_columns, idx)?;
let end_range = if let Some(delta) = delta {
- let is_descending: bool = sort_options
+ let is_descending: bool = self
+ .sort_options
.first()
- .ok_or_else(|| DataFusionError::Internal("Array is
empty".to_string()))?
+ .ok_or_else(|| {
+ DataFusionError::Internal(
+ "Sort options unexpectedly absent in a window
frame".to_string(),
+ )
+ })?
.descending;
current_row_values
@@ -285,7 +281,7 @@ impl WindowFrameStateRange {
return Ok(value.clone());
}
if SEARCH_SIDE == is_descending {
- // TODO: Handle positive overflows
+ // TODO: Handle positive overflows.
value.add(delta)
} else if value.is_unsigned() && value < delta {
// NOTE: This gets a polymorphic zero without having
long coercion code for ScalarValue.
@@ -293,7 +289,7 @@ impl WindowFrameStateRange {
// change the following statement to use that.
value.sub(value)
} else {
- // TODO: Handle negative overflows
+ // TODO: Handle negative overflows.
value.sub(delta)
}
})
@@ -302,12 +298,12 @@ impl WindowFrameStateRange {
current_row_values
};
let search_start = if SIDE {
- last_range.start
+ self.last_range.start
} else {
- last_range.end
+ self.last_range.end
};
let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| {
- let cmp = compare_rows(current, target, sort_options)?;
+ let cmp = compare_rows(current, target, &self.sort_options)?;
Ok(if SIDE { cmp.is_lt() } else { cmp.is_le() })
};
search_in_slice(range_columns, &end_range, compare_fn, search_start,
length)
@@ -340,16 +336,15 @@ impl WindowFrameStateRange {
// scan groups of data while processing window frames.
#[derive(Debug, Default)]
pub struct WindowFrameStateGroups {
- current_group_idx: u64,
+ /// A tuple containing group values and the row index where the group ends.
+ /// Example: [[1, 1], [1, 1], [2, 1], [2, 1], ...] would correspond to
+ /// [([1, 1], 2), ([2, 1], 4), ...].
group_start_indices: VecDeque<(Vec<ScalarValue>, usize)>,
- previous_row_values: Option<Vec<ScalarValue>>,
- reached_end: bool,
- window_frame_end_idx: u64,
- window_frame_start_idx: u64,
+ /// The group index to which the row index belongs.
+ current_group_idx: usize,
}
impl WindowFrameStateGroups {
- /// This function calculates beginning/ending indices for the frame of the
current row.
fn calculate_range(
&mut self,
window_frame: &Arc<WindowFrame>,
@@ -357,662 +352,287 @@ impl WindowFrameStateGroups {
length: usize,
idx: usize,
) -> Result<Range<usize>> {
- if range_columns.is_empty() {
- return Err(DataFusionError::Execution(
- "GROUPS mode requires an ORDER BY clause".to_string(),
- ));
- }
let start = match window_frame.start_bound {
- // UNBOUNDED PRECEDING
- WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => 0,
- WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => self
- .calculate_index_of_group::<true, true>(range_columns, idx, n,
length)?,
- WindowFrameBound::CurrentRow =>
self.calculate_index_of_group::<true, true>(
+ WindowFrameBound::Preceding(ref n) => {
+ if n.is_null() {
+ // UNBOUNDED PRECEDING
+ 0
+ } else {
+ self.calculate_index_of_row::<true, true>(
+ range_columns,
+ idx,
+ Some(n),
+ length,
+ )?
+ }
+ }
+ WindowFrameBound::CurrentRow =>
self.calculate_index_of_row::<true, true>(
range_columns,
idx,
- 0,
+ None,
length,
)?,
- WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => self
- .calculate_index_of_group::<true, false>(range_columns, idx,
n, length)?,
- // UNBOUNDED FOLLOWING
- WindowFrameBound::Following(ScalarValue::UInt64(None)) => {
- return Err(DataFusionError::Internal(format!(
- "Frame start cannot be UNBOUNDED FOLLOWING
'{window_frame:?}'"
- )))
- }
- // ERRONEOUS FRAMES
- WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) =>
{
- return Err(DataFusionError::Internal(
- "Groups should be Uint".to_string(),
- ))
- }
- };
- let end = match window_frame.end_bound {
- // UNBOUNDED PRECEDING
- WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => {
- return Err(DataFusionError::Internal(format!(
- "Frame end cannot be UNBOUNDED PRECEDING
'{window_frame:?}'"
- )))
- }
- WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => self
- .calculate_index_of_group::<false, true>(range_columns, idx,
n, length)?,
- WindowFrameBound::CurrentRow => self
- .calculate_index_of_group::<false, false>(
+ WindowFrameBound::Following(ref n) => self
+ .calculate_index_of_row::<true, false>(
range_columns,
idx,
- 0,
+ Some(n),
length,
)?,
- WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => self
- .calculate_index_of_group::<false, false>(
+ };
+ let end = match window_frame.end_bound {
+ WindowFrameBound::Preceding(ref n) => self
+ .calculate_index_of_row::<false, true>(
range_columns,
idx,
- n,
+ Some(n),
length,
)?,
- // UNBOUNDED FOLLOWING
- WindowFrameBound::Following(ScalarValue::UInt64(None)) => length,
- // ERRONEOUS FRAMES
- WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) =>
{
- return Err(DataFusionError::Internal(
- "Groups should be Uint".to_string(),
- ))
+ WindowFrameBound::CurrentRow =>
self.calculate_index_of_row::<false, false>(
+ range_columns,
+ idx,
+ None,
+ length,
+ )?,
+ WindowFrameBound::Following(ref n) => {
+ if n.is_null() {
+ // UNBOUNDED FOLLOWING
+ length
+ } else {
+ self.calculate_index_of_row::<false, false>(
+ range_columns,
+ idx,
+ Some(n),
+ length,
+ )?
+ }
}
};
Ok(Range { start, end })
}
- /// This function does the heavy lifting when finding group boundaries. It
is meant to be
- /// called twice, in succession, to get window frame start and end indices
(with `BISECT_SIDE`
- /// supplied as false and true, respectively).
- fn calculate_index_of_group<const BISECT_SIDE: bool, const SEARCH_SIDE:
bool>(
+ /// This function does the heavy lifting when finding range boundaries. It
is meant to be
+ /// called twice, in succession, to get window frame start and end indices
(with `SIDE`
+ /// supplied as true and false, respectively). Generic argument
`SEARCH_SIDE` determines
+ /// the sign of `delta` (where true/false represents negative/positive
respectively).
+ fn calculate_index_of_row<const SIDE: bool, const SEARCH_SIDE: bool>(
&mut self,
range_columns: &[ArrayRef],
idx: usize,
- delta: u64,
+ delta: Option<&ScalarValue>,
length: usize,
) -> Result<usize> {
- let current_row_values = range_columns
- .iter()
- .map(|col| ScalarValue::try_from_array(col, idx))
- .collect::<Result<Vec<ScalarValue>>>()?;
-
- if BISECT_SIDE {
- // When we call this function to get the window frame start index,
it tries to initialize
- // the internal grouping state if this is not already done before.
This initialization takes
- // place only when the window frame start index is greater than or
equal to zero. In this
- // case, the current row is stored in group_start_indices, with
row values as the group
- // identifier and row index as the start index of the group.
- if !self.initialized() {
- self.initialize::<SEARCH_SIDE>(delta, range_columns)?;
- }
- } else if !self.reached_end {
- // When we call this function to get the window frame end index,
it extends the window
- // frame one by one until the current row's window frame end index
is reached by finding
- // the next group.
-
self.extend_window_frame_if_necessary::<SEARCH_SIDE>(range_columns, delta)?;
- }
- // We keep track of previous row values, so that a group change can be
identified.
- // If there is a group change, the window frame is advanced and
shifted by one group.
- let group_change = match &self.previous_row_values {
- None => false,
- Some(values) => ¤t_row_values != values,
- };
- if self.previous_row_values.is_none() || group_change {
- self.previous_row_values = Some(current_row_values);
- }
- if group_change {
- self.current_group_idx += 1;
- self.advance_one_group::<SEARCH_SIDE>(range_columns)?;
- self.shift_one_group::<SEARCH_SIDE>(delta);
- }
- Ok(if self.group_start_indices.is_empty() {
- if self.reached_end {
- length
+ let delta = if let Some(delta) = delta {
+ if let ScalarValue::UInt64(Some(value)) = delta {
+ *value as usize
} else {
- 0
- }
- } else if BISECT_SIDE {
- match self.group_start_indices.get(0) {
- Some(&(_, idx)) => idx,
- None => 0,
+ return Err(DataFusionError::Internal(
+ "Unexpectedly got a non-UInt64 value in a GROUPS mode
window frame"
+ .to_string(),
+ ));
}
- } else {
- match (self.reached_end, self.group_start_indices.back()) {
- (false, Some(&(_, idx))) => idx,
- _ => length,
- }
- })
- }
-
- fn extend_window_frame_if_necessary<const SEARCH_SIDE: bool>(
- &mut self,
- range_columns: &[ArrayRef],
- delta: u64,
- ) -> Result<()> {
- let current_window_frame_end_idx = if !SEARCH_SIDE {
- self.current_group_idx + delta + 1
- } else if self.current_group_idx >= delta {
- self.current_group_idx - delta + 1
} else {
0
};
- if current_window_frame_end_idx == 0 {
- // the end index of the window frame is still before the first
index
- return Ok(());
+ let mut group_start = 0;
+ let last_group = self.group_start_indices.back();
+ if let Some((_, group_end)) = last_group {
+ // Start searching from the last group boundary:
+ group_start = *group_end;
}
- if self.group_start_indices.is_empty() {
- self.initialize_window_frame_start(range_columns)?;
- }
- while !self.reached_end
- && self.window_frame_end_idx <= current_window_frame_end_idx
- {
- self.advance_one_group::<SEARCH_SIDE>(range_columns)?;
+
+ // Advance groups until `idx` is inside a group:
+ while idx > group_start {
+ let group_row = get_row_at_idx(range_columns, group_start)?;
+ // Find end boundary of the group (search right boundary):
+ let group_end = search_in_slice(
+ range_columns,
+ &group_row,
+ check_equality,
+ group_start,
+ length,
+ )?;
+ self.group_start_indices.push_back((group_row, group_end));
+ group_start = group_end;
}
- Ok(())
- }
- fn initialize<const SEARCH_SIDE: bool>(
- &mut self,
- delta: u64,
- range_columns: &[ArrayRef],
- ) -> Result<()> {
- if !SEARCH_SIDE {
- self.window_frame_start_idx = self.current_group_idx + delta;
- self.initialize_window_frame_start(range_columns)
- } else if self.current_group_idx >= delta {
- self.window_frame_start_idx = self.current_group_idx - delta;
- self.initialize_window_frame_start(range_columns)
- } else {
- Ok(())
+ // Update the group index `idx` belongs to:
+ while self.current_group_idx < self.group_start_indices.len()
+ && idx >= self.group_start_indices[self.current_group_idx].1
+ {
+ self.current_group_idx += 1;
}
- }
- fn initialize_window_frame_start(
- &mut self,
- range_columns: &[ArrayRef],
- ) -> Result<()> {
- let mut group_values = range_columns
- .iter()
- .map(|col| ScalarValue::try_from_array(col, 0))
- .collect::<Result<Vec<ScalarValue>>>()?;
- let mut start_idx: usize = 0;
- for _ in 0..self.window_frame_start_idx {
- let next_group_and_start_index =
- WindowFrameStateGroups::find_next_group_and_start_index(
- range_columns,
- &group_values,
- start_idx,
- )?;
- if let Some(entry) = next_group_and_start_index {
- (group_values, start_idx) = entry;
+ // Find the group index of the frame boundary:
+ let group_idx = if SEARCH_SIDE {
+ if self.current_group_idx > delta {
+ self.current_group_idx - delta
} else {
- // not enough groups to generate a window frame
- self.window_frame_end_idx = self.window_frame_start_idx;
- self.reached_end = true;
- return Ok(());
+ 0
}
- }
- self.group_start_indices
- .push_back((group_values, start_idx));
- self.window_frame_end_idx = self.window_frame_start_idx + 1;
- Ok(())
- }
-
- fn initialized(&self) -> bool {
- self.reached_end || !self.group_start_indices.is_empty()
- }
-
- /// This function advances the window frame by one group.
- fn advance_one_group<const SEARCH_SIDE: bool>(
- &mut self,
- range_columns: &[ArrayRef],
- ) -> Result<()> {
- let last_group_values = self.group_start_indices.back();
- let last_group_values = if let Some(values) = last_group_values {
- values
} else {
- return Ok(());
+ self.current_group_idx + delta
};
- let next_group_and_start_index =
- WindowFrameStateGroups::find_next_group_and_start_index(
+
+ // Extend `group_start_indices` until it includes at least `group_idx`:
+ while self.group_start_indices.len() <= group_idx && group_start <
length {
+ let group_row = get_row_at_idx(range_columns, group_start)?;
+ // Find end boundary of the group (search right boundary):
+ let group_end = search_in_slice(
range_columns,
- &last_group_values.0,
- last_group_values.1,
+ &group_row,
+ check_equality,
+ group_start,
+ length,
)?;
- if let Some(entry) = next_group_and_start_index {
- self.group_start_indices.push_back(entry);
- self.window_frame_end_idx += 1;
- } else {
- // not enough groups to proceed
- self.reached_end = true;
+ self.group_start_indices.push_back((group_row, group_end));
+ group_start = group_end;
}
- Ok(())
- }
- /// This function drops the oldest group from the window frame.
- fn shift_one_group<const SEARCH_SIDE: bool>(&mut self, delta: u64) {
- let current_window_frame_start_idx = if !SEARCH_SIDE {
- self.current_group_idx + delta
- } else if self.current_group_idx >= delta {
- self.current_group_idx - delta
- } else {
- 0
- };
- if current_window_frame_start_idx > self.window_frame_start_idx {
- self.group_start_indices.pop_front();
- self.window_frame_start_idx += 1;
- }
- }
-
- /// This function finds the next group and its start index for a given
group and start index.
- /// It utilizes an exponentially growing step size to find the group
boundary.
- // TODO: For small group sizes, proceeding one-by-one to find the group
change can be more efficient.
- // Statistics about previous group sizes can be used to choose one-by-one
vs. exponentially growing,
- // or even to set the base step_size when exponentially growing. We can
also create a benchmark
- // implementation to get insights about the crossover point.
- fn find_next_group_and_start_index(
- range_columns: &[ArrayRef],
- current_row_values: &[ScalarValue],
- idx: usize,
- ) -> Result<Option<(Vec<ScalarValue>, usize)>> {
- let mut step_size: usize = 1;
- let data_size: usize = range_columns
- .get(0)
- .ok_or_else(|| {
- DataFusionError::Internal("Column array shouldn't be
empty".to_string())
- })?
- .len();
- let mut low = idx;
- let mut high = idx + step_size;
- while high < data_size {
- let val = range_columns
- .iter()
- .map(|arr| ScalarValue::try_from_array(arr, high))
- .collect::<Result<Vec<ScalarValue>>>()?;
- if val == current_row_values {
- low = high;
- step_size *= 2;
- high += step_size;
- } else {
- break;
+ // Calculate index of the group boundary:
+ Ok(match (SIDE, SEARCH_SIDE) {
+ // Window frame start:
+ (true, _) => {
+ let group_idx = min(group_idx, self.group_start_indices.len());
+ if group_idx > 0 {
+ // Normally, start at the boundary of the previous group.
+ self.group_start_indices[group_idx - 1].1
+ } else {
+ // If previous group is out of the table, start at zero.
+ 0
+ }
}
- }
- low = find_bisect_point(
- range_columns,
- current_row_values,
- |current, to_compare| Ok(current == to_compare),
- low,
- min(high, data_size),
- )?;
- if low == data_size {
- return Ok(None);
- }
- let val = range_columns
- .iter()
- .map(|arr| ScalarValue::try_from_array(arr, low))
- .collect::<Result<Vec<ScalarValue>>>()?;
- Ok(Some((val, low)))
+ // Window frame end, PRECEDING n
+ (false, true) => {
+ if self.current_group_idx >= delta {
+ let group_idx = self.current_group_idx - delta;
+ self.group_start_indices[group_idx].1
+ } else {
+ // Group is out of the table, therefore end at zero.
+ 0
+ }
+ }
+ // Window frame end, FOLLOWING n
+ (false, false) => {
+ let group_idx = min(
+ self.current_group_idx + delta,
+ self.group_start_indices.len() - 1,
+ );
+ self.group_start_indices[group_idx].1
+ }
+ })
}
}
+fn check_equality(current: &[ScalarValue], target: &[ScalarValue]) ->
Result<bool> {
+ Ok(current == target)
+}
+
#[cfg(test)]
mod tests {
- use arrow::array::Float64Array;
- use datafusion_common::ScalarValue;
+ use crate::window::window_frame_state::WindowFrameStateGroups;
+ use arrow::array::{ArrayRef, Float64Array};
+ use arrow_schema::SortOptions;
+ use datafusion_common::from_slice::FromSlice;
+ use datafusion_common::{Result, ScalarValue};
+ use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits};
+ use std::ops::Range;
use std::sync::Arc;
- use crate::from_slice::FromSlice;
-
- use super::*;
-
- struct TestData {
- arrays: Vec<ArrayRef>,
- group_indices: [usize; 6],
- num_groups: usize,
- num_rows: usize,
- next_group_indices: [usize; 5],
- }
-
- fn test_data() -> TestData {
- let num_groups: usize = 5;
- let num_rows: usize = 6;
- let group_indices = [0, 1, 2, 2, 4, 5];
- let next_group_indices = [1, 2, 4, 4, 5];
-
- let arrays: Vec<ArrayRef> = vec![
- Arc::new(Float64Array::from_slice([5.0, 7.0, 8.0, 8., 9., 10.])),
- Arc::new(Float64Array::from_slice([2.0, 3.0, 3.0, 3., 4.0, 5.0])),
- Arc::new(Float64Array::from_slice([5.0, 7.0, 8.0, 8., 10., 11.0])),
- Arc::new(Float64Array::from_slice([15.0, 13.0, 8.0, 8., 5., 0.0])),
- ];
- TestData {
- arrays,
- group_indices,
- num_groups,
- num_rows,
- next_group_indices,
- }
- }
-
- #[test]
- fn test_find_next_group_and_start_index() {
- let test_data = test_data();
- for (current_idx, next_idx) in
test_data.next_group_indices.iter().enumerate() {
- let current_row_values = test_data
- .arrays
- .iter()
- .map(|col| ScalarValue::try_from_array(col, current_idx))
- .collect::<Result<Vec<ScalarValue>>>()
- .unwrap();
- let next_row_values = test_data
- .arrays
- .iter()
- .map(|col| ScalarValue::try_from_array(col, *next_idx))
- .collect::<Result<Vec<ScalarValue>>>()
- .unwrap();
- let res = WindowFrameStateGroups::find_next_group_and_start_index(
- &test_data.arrays,
- &current_row_values,
- current_idx,
- )
- .unwrap();
- assert_eq!(res, Some((next_row_values, *next_idx)));
- }
- let current_idx = test_data.num_rows - 1;
- let current_row_values = test_data
- .arrays
- .iter()
- .map(|col| ScalarValue::try_from_array(col, current_idx))
- .collect::<Result<Vec<ScalarValue>>>()
- .unwrap();
- let res = WindowFrameStateGroups::find_next_group_and_start_index(
- &test_data.arrays,
- &current_row_values,
- current_idx,
- )
- .unwrap();
- assert_eq!(res, None);
- }
-
- #[test]
- fn test_window_frame_groups_preceding_delta_greater_than_partition_size() {
- const START: bool = true;
- const END: bool = false;
- const PRECEDING: bool = true;
- const DELTA: u64 = 10;
-
- let test_data = test_data();
- let mut window_frame_groups = WindowFrameStateGroups::default();
- window_frame_groups
- .initialize::<PRECEDING>(DELTA, &test_data.arrays)
- .unwrap();
- assert_eq!(window_frame_groups.window_frame_start_idx, 0);
- assert_eq!(window_frame_groups.window_frame_end_idx, 0);
- assert!(!window_frame_groups.reached_end);
- assert_eq!(window_frame_groups.group_start_indices.len(), 0);
-
- window_frame_groups
- .extend_window_frame_if_necessary::<PRECEDING>(&test_data.arrays,
DELTA)
- .unwrap();
- assert_eq!(window_frame_groups.window_frame_start_idx, 0);
- assert_eq!(window_frame_groups.window_frame_end_idx, 0);
- assert!(!window_frame_groups.reached_end);
- assert_eq!(window_frame_groups.group_start_indices.len(), 0);
-
- for idx in 0..test_data.num_rows {
- let start = window_frame_groups
- .calculate_index_of_group::<START, PRECEDING>(
- &test_data.arrays,
- idx,
- DELTA,
- test_data.num_rows,
- )
- .unwrap();
- assert_eq!(start, 0);
- let end = window_frame_groups
- .calculate_index_of_group::<END, PRECEDING>(
- &test_data.arrays,
- idx,
- DELTA,
- test_data.num_rows,
- )
- .unwrap();
- assert_eq!(end, 0);
- }
- }
-
- #[test]
- fn test_window_frame_groups_following_delta_greater_than_partition_size() {
- const START: bool = true;
- const END: bool = false;
- const FOLLOWING: bool = false;
- const DELTA: u64 = 10;
-
- let test_data = test_data();
- let mut window_frame_groups = WindowFrameStateGroups::default();
- window_frame_groups
- .initialize::<FOLLOWING>(DELTA, &test_data.arrays)
- .unwrap();
- assert_eq!(window_frame_groups.window_frame_start_idx, DELTA);
- assert_eq!(window_frame_groups.window_frame_end_idx, DELTA);
- assert!(window_frame_groups.reached_end);
- assert_eq!(window_frame_groups.group_start_indices.len(), 0);
-
- window_frame_groups
- .extend_window_frame_if_necessary::<FOLLOWING>(&test_data.arrays,
DELTA)
- .unwrap();
- assert_eq!(window_frame_groups.window_frame_start_idx, DELTA);
- assert_eq!(window_frame_groups.window_frame_end_idx, DELTA);
- assert!(window_frame_groups.reached_end);
- assert_eq!(window_frame_groups.group_start_indices.len(), 0);
+ fn get_test_data() -> (Vec<ArrayRef>, Vec<SortOptions>) {
+ let range_columns: Vec<ArrayRef> =
vec![Arc::new(Float64Array::from_slice([
+ 5.0, 7.0, 8.0, 8.0, 9., 10., 10., 10., 11.,
+ ]))];
+ let sort_options = vec![SortOptions {
+ descending: false,
+ nulls_first: false,
+ }];
- for idx in 0..test_data.num_rows {
- let start = window_frame_groups
- .calculate_index_of_group::<START, FOLLOWING>(
- &test_data.arrays,
- idx,
- DELTA,
- test_data.num_rows,
- )
- .unwrap();
- assert_eq!(start, test_data.num_rows);
- let end = window_frame_groups
- .calculate_index_of_group::<END, FOLLOWING>(
- &test_data.arrays,
- idx,
- DELTA,
- test_data.num_rows,
- )
- .unwrap();
- assert_eq!(end, test_data.num_rows);
- }
+ (range_columns, sort_options)
}
- #[test]
- fn
test_window_frame_groups_preceding_and_following_delta_greater_than_partition_size(
- ) {
- const START: bool = true;
- const END: bool = false;
- const FOLLOWING: bool = false;
- const PRECEDING: bool = true;
- const DELTA: u64 = 10;
-
- let test_data = test_data();
+ fn assert_expected(
+ expected_results: Vec<(Range<usize>, usize)>,
+ window_frame: &Arc<WindowFrame>,
+ ) -> Result<()> {
let mut window_frame_groups = WindowFrameStateGroups::default();
- window_frame_groups
- .initialize::<PRECEDING>(DELTA, &test_data.arrays)
- .unwrap();
- assert_eq!(window_frame_groups.window_frame_start_idx, 0);
- assert_eq!(window_frame_groups.window_frame_end_idx, 0);
- assert!(!window_frame_groups.reached_end);
- assert_eq!(window_frame_groups.group_start_indices.len(), 0);
-
- window_frame_groups
- .extend_window_frame_if_necessary::<FOLLOWING>(&test_data.arrays,
DELTA)
- .unwrap();
- assert_eq!(window_frame_groups.window_frame_start_idx, 0);
- assert_eq!(
- window_frame_groups.window_frame_end_idx,
- test_data.num_groups as u64
- );
- assert!(window_frame_groups.reached_end);
- assert_eq!(
- window_frame_groups.group_start_indices.len(),
- test_data.num_groups
- );
-
- for idx in 0..test_data.num_rows {
- let start = window_frame_groups
- .calculate_index_of_group::<START, PRECEDING>(
- &test_data.arrays,
- idx,
- DELTA,
- test_data.num_rows,
- )
- .unwrap();
- assert_eq!(start, 0);
- let end = window_frame_groups
- .calculate_index_of_group::<END, FOLLOWING>(
- &test_data.arrays,
- idx,
- DELTA,
- test_data.num_rows,
- )
- .unwrap();
- assert_eq!(end, test_data.num_rows);
+ let (range_columns, _) = get_test_data();
+ let n_row = range_columns[0].len();
+ for (idx, (expected_range, expected_group_idx)) in
+ expected_results.into_iter().enumerate()
+ {
+ let range = window_frame_groups.calculate_range(
+ window_frame,
+ &range_columns,
+ n_row,
+ idx,
+ )?;
+ assert_eq!(range, expected_range);
+ assert_eq!(window_frame_groups.current_group_idx,
expected_group_idx);
}
+ Ok(())
}
#[test]
- fn test_window_frame_groups_preceding_and_following_1() {
- const START: bool = true;
- const END: bool = false;
- const FOLLOWING: bool = false;
- const PRECEDING: bool = true;
- const DELTA: u64 = 1;
-
- let test_data = test_data();
- let mut window_frame_groups = WindowFrameStateGroups::default();
- window_frame_groups
- .initialize::<PRECEDING>(DELTA, &test_data.arrays)
- .unwrap();
- assert_eq!(window_frame_groups.window_frame_start_idx, 0);
- assert_eq!(window_frame_groups.window_frame_end_idx, 0);
- assert!(!window_frame_groups.reached_end);
- assert_eq!(window_frame_groups.group_start_indices.len(), 0);
-
- window_frame_groups
- .extend_window_frame_if_necessary::<FOLLOWING>(&test_data.arrays,
DELTA)
- .unwrap();
- assert_eq!(window_frame_groups.window_frame_start_idx, 0);
- assert_eq!(window_frame_groups.window_frame_end_idx, 2 * DELTA + 1);
- assert!(!window_frame_groups.reached_end);
- assert_eq!(
- window_frame_groups.group_start_indices.len(),
- 2 * DELTA as usize + 1
- );
-
- for idx in 0..test_data.num_rows {
- let start_idx = if idx < DELTA as usize {
- 0
- } else {
- test_data.group_indices[idx] - DELTA as usize
- };
- let start = window_frame_groups
- .calculate_index_of_group::<START, PRECEDING>(
- &test_data.arrays,
- idx,
- DELTA,
- test_data.num_rows,
- )
- .unwrap();
- assert_eq!(start, test_data.group_indices[start_idx]);
- let mut end_idx = if idx >= test_data.num_groups {
- test_data.num_rows
- } else {
- test_data.next_group_indices[idx]
- };
- for _ in 0..DELTA {
- end_idx = if end_idx >= test_data.num_groups {
- test_data.num_rows
- } else {
- test_data.next_group_indices[end_idx]
- };
- }
- let end = window_frame_groups
- .calculate_index_of_group::<END, FOLLOWING>(
- &test_data.arrays,
- idx,
- DELTA,
- test_data.num_rows,
- )
- .unwrap();
- assert_eq!(end, end_idx);
- }
+ fn test_window_frame_group_boundaries() -> Result<()> {
+ let window_frame = Arc::new(WindowFrame {
+ units: WindowFrameUnits::Groups,
+ start_bound:
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))),
+ end_bound:
WindowFrameBound::Following(ScalarValue::UInt64(Some(1))),
+ });
+ let expected_results = vec![
+ (Range { start: 0, end: 2 }, 0),
+ (Range { start: 0, end: 4 }, 1),
+ (Range { start: 1, end: 5 }, 2),
+ (Range { start: 1, end: 5 }, 2),
+ (Range { start: 2, end: 8 }, 3),
+ (Range { start: 4, end: 9 }, 4),
+ (Range { start: 4, end: 9 }, 4),
+ (Range { start: 4, end: 9 }, 4),
+ (Range { start: 5, end: 9 }, 5),
+ ];
+ assert_expected(expected_results, &window_frame)
}
#[test]
- fn test_window_frame_groups_preceding_1_and_unbounded_following() {
- const START: bool = true;
- const PRECEDING: bool = true;
- const DELTA: u64 = 1;
-
- let test_data = test_data();
- let mut window_frame_groups = WindowFrameStateGroups::default();
- window_frame_groups
- .initialize::<PRECEDING>(DELTA, &test_data.arrays)
- .unwrap();
- assert_eq!(window_frame_groups.window_frame_start_idx, 0);
- assert_eq!(window_frame_groups.window_frame_end_idx, 0);
- assert!(!window_frame_groups.reached_end);
- assert_eq!(window_frame_groups.group_start_indices.len(), 0);
-
- for idx in 0..test_data.num_rows {
- let start_idx = if idx < DELTA as usize {
- 0
- } else {
- test_data.group_indices[idx] - DELTA as usize
- };
- let start = window_frame_groups
- .calculate_index_of_group::<START, PRECEDING>(
- &test_data.arrays,
- idx,
- DELTA,
- test_data.num_rows,
- )
- .unwrap();
- assert_eq!(start, test_data.group_indices[start_idx]);
- }
+ fn test_window_frame_group_boundaries_both_following() -> Result<()> {
+ let window_frame = Arc::new(WindowFrame {
+ units: WindowFrameUnits::Groups,
+ start_bound:
WindowFrameBound::Following(ScalarValue::UInt64(Some(1))),
+ end_bound:
WindowFrameBound::Following(ScalarValue::UInt64(Some(2))),
+ });
+ let expected_results = vec![
+ (Range::<usize> { start: 1, end: 4 }, 0),
+ (Range::<usize> { start: 2, end: 5 }, 1),
+ (Range::<usize> { start: 4, end: 8 }, 2),
+ (Range::<usize> { start: 4, end: 8 }, 2),
+ (Range::<usize> { start: 5, end: 9 }, 3),
+ (Range::<usize> { start: 8, end: 9 }, 4),
+ (Range::<usize> { start: 8, end: 9 }, 4),
+ (Range::<usize> { start: 8, end: 9 }, 4),
+ (Range::<usize> { start: 9, end: 9 }, 5),
+ ];
+ assert_expected(expected_results, &window_frame)
}
#[test]
- fn test_window_frame_groups_current_and_unbounded_following() {
- const START: bool = true;
- const PRECEDING: bool = true;
- const DELTA: u64 = 0;
-
- let test_data = test_data();
- let mut window_frame_groups = WindowFrameStateGroups::default();
- window_frame_groups
- .initialize::<PRECEDING>(DELTA, &test_data.arrays)
- .unwrap();
- assert_eq!(window_frame_groups.window_frame_start_idx, 0);
- assert_eq!(window_frame_groups.window_frame_end_idx, 1);
- assert!(!window_frame_groups.reached_end);
- assert_eq!(window_frame_groups.group_start_indices.len(), 1);
-
- for idx in 0..test_data.num_rows {
- let start = window_frame_groups
- .calculate_index_of_group::<START, PRECEDING>(
- &test_data.arrays,
- idx,
- DELTA,
- test_data.num_rows,
- )
- .unwrap();
- assert_eq!(start, test_data.group_indices[idx]);
- }
+ fn test_window_frame_group_boundaries_both_preceding() -> Result<()> {
+ let window_frame = Arc::new(WindowFrame {
+ units: WindowFrameUnits::Groups,
+ start_bound:
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))),
+ end_bound:
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))),
+ });
+ let expected_results = vec![
+ (Range::<usize> { start: 0, end: 0 }, 0),
+ (Range::<usize> { start: 0, end: 1 }, 1),
+ (Range::<usize> { start: 0, end: 2 }, 2),
+ (Range::<usize> { start: 0, end: 2 }, 2),
+ (Range::<usize> { start: 1, end: 4 }, 3),
+ (Range::<usize> { start: 2, end: 5 }, 4),
+ (Range::<usize> { start: 2, end: 5 }, 4),
+ (Range::<usize> { start: 2, end: 5 }, 4),
+ (Range::<usize> { start: 4, end: 8 }, 5),
+ ];
+ assert_expected(expected_results, &window_frame)
}
}
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs
b/datafusion/proto/src/logical_plan/from_proto.rs
index a74874586..498563b2a 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -43,9 +43,10 @@ use datafusion_expr::{
regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim,
sha224, sha256,
sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr,
substring, tan, to_hex, to_timestamp_micros, to_timestamp_millis,
- to_timestamp_seconds, translate, trim, trunc, upper, uuid,
AggregateFunction,
- Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case,
Cast, Expr,
- GetIndexedField, GroupingSet,
+ to_timestamp_seconds, translate, trim, trunc, upper, uuid,
+ window_frame::regularize,
+ AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction,
BuiltinScalarFunction,
+ Case, Cast, Expr, GetIndexedField, GroupingSet,
GroupingSet::GroupingSets,
JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame,
WindowFrameBound,
WindowFrameUnits,
@@ -907,16 +908,15 @@ pub fn parse_expr(
.window_frame
.as_ref()
.map::<Result<WindowFrame, _>, _>(|window_frame| {
- let window_frame: WindowFrame =
window_frame.clone().try_into()?;
- if WindowFrameUnits::Range == window_frame.units
- && order_by.len() != 1
- {
- Err(proto_error("With window frame of type RANGE, the
order by expression must be of length 1"))
- } else {
- Ok(window_frame)
- }
+ let window_frame = window_frame.clone().try_into()?;
+ regularize(window_frame, order_by.len())
})
-
.transpose()?.ok_or_else(||{DataFusionError::Execution("expects
somothing".to_string())})?;
+ .transpose()?
+ .ok_or_else(|| {
+ DataFusionError::Execution(
+ "missing window frame during
deserialization".to_string(),
+ )
+ })?;
match window_function {
window_expr_node::WindowFunction::AggrFunction(i) => {
diff --git a/datafusion/sql/src/expr/function.rs
b/datafusion/sql/src/expr/function.rs
index 1845d5947..c5f23213a 100644
--- a/datafusion/sql/src/expr/function.rs
+++ b/datafusion/sql/src/expr/function.rs
@@ -19,9 +19,10 @@ use crate::planner::{ContextProvider, PlannerContext,
SqlToRel};
use crate::utils::normalize_ident;
use datafusion_common::{DFSchema, DataFusionError, Result};
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
+use datafusion_expr::window_frame::regularize;
use datafusion_expr::{
expr, window_function, AggregateFunction, BuiltinScalarFunction, Expr,
WindowFrame,
- WindowFrameUnits, WindowFunction,
+ WindowFunction,
};
use sqlparser::ast::{
Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr,
@@ -65,15 +66,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
.window_frame
.as_ref()
.map(|window_frame| {
- let window_frame: WindowFrame =
window_frame.clone().try_into()?;
- if WindowFrameUnits::Range == window_frame.units
- && order_by.len() != 1
- {
- Err(DataFusionError::Plan(format!(
- "With window frame of type RANGE, the order by
expression must be of length 1, got {}", order_by.len())))
- } else {
- Ok(window_frame)
- }
+ let window_frame = window_frame.clone().try_into()?;
+ regularize(window_frame, order_by.len())
})
.transpose()?;
let window_frame = if let Some(window_frame) = window_frame {
diff --git a/datafusion/sql/tests/integration_test.rs
b/datafusion/sql/tests/integration_test.rs
index c75f93d36..44c0559ef 100644
--- a/datafusion/sql/tests/integration_test.rs
+++ b/datafusion/sql/tests/integration_test.rs
@@ -2101,27 +2101,6 @@ fn over_order_by_with_window_frame_single_end() {
quick_test(sql, expected);
}
-#[test]
-fn over_order_by_with_window_frame_range_order_by_check() {
- let sql = "SELECT order_id, MAX(qty) OVER (RANGE UNBOUNDED PRECEDING) from
orders";
- let err = logical_plan(sql).expect_err("query should have failed");
- assert_eq!(
- "Plan(\"With window frame of type RANGE, the order by expression
must be of length 1, got 0\")",
- format!("{err:?}")
- );
-}
-
-#[test]
-fn over_order_by_with_window_frame_range_order_by_check_2() {
- let sql =
- "SELECT order_id, MAX(qty) OVER (ORDER BY order_id, qty RANGE
UNBOUNDED PRECEDING) from orders";
- let err = logical_plan(sql).expect_err("query should have failed");
- assert_eq!(
- "Plan(\"With window frame of type RANGE, the order by expression
must be of length 1, got 2\")",
- format!("{err:?}")
- );
-}
-
#[test]
fn over_order_by_with_window_frame_single_end_groups() {
let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id GROUPS 3
PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders";