This is an automated email from the ASF dual-hosted git repository.
ozankabak pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 259e12c6fe Determine causal window frames to produce early results. (#8842)
259e12c6fe is described below
commit 259e12c6fe3d3db971b4a20aee9b68b4a49ad5c6
Author: Mustafa Akur <[email protected]>
AuthorDate: Mon Jan 15 21:23:07 2024 +0300
Determine causal window frames to produce early results. (#8842)
* Add handling for primary key window frame
* Add check for window end
* Add uniqueness check
* Minor changes
* Update signature of WindowFrame::new
* Make is_causal window state
* Minor changes
* Address reviews
* Minor changes
* Add new test
* Minor changes
* Minor changes
* Remove string handling
* Review Part 2
* Improve comments
---------
Co-authored-by: Mehmet Ozan Kabak <[email protected]>
---
datafusion-examples/examples/advanced_udwf.rs | 2 +-
datafusion-examples/examples/simple_udwf.rs | 2 +-
datafusion/common/src/scalar.rs | 1 -
datafusion/core/src/dataframe/mod.rs | 2 +-
.../core/src/physical_optimizer/test_utils.rs | 2 +-
datafusion/core/tests/dataframe/mod.rs | 10 +-
datafusion/core/tests/fuzz_cases/window_fuzz.rs | 95 +++++++++--
datafusion/expr/src/udwf.rs | 2 +-
datafusion/expr/src/utils.rs | 20 +--
datafusion/expr/src/window_frame.rs | 185 ++++++++++++++++-----
datafusion/expr/src/window_state.rs | 30 ++--
.../optimizer/src/analyzer/count_wildcard_rule.rs | 12 +-
datafusion/optimizer/src/push_down_projection.rs | 4 +-
datafusion/physical-expr/src/window/built_in.rs | 9 +-
datafusion/physical-expr/src/window/window_expr.rs | 5 +-
.../src/windows/bounded_window_agg_exec.rs | 30 ++--
datafusion/physical-plan/src/windows/mod.rs | 2 +-
datafusion/proto/src/logical_plan/from_proto.rs | 6 +-
.../proto/tests/cases/roundtrip_logical_plan.rs | 24 +--
.../proto/tests/cases/roundtrip_physical_plan.rs | 22 +--
datafusion/sql/src/expr/function.rs | 23 ++-
datafusion/sqllogictest/test_files/window.slt | 35 ++++
datafusion/substrait/src/logical_plan/consumer.rs | 8 +-
23 files changed, 382 insertions(+), 149 deletions(-)
diff --git a/datafusion-examples/examples/advanced_udwf.rs
b/datafusion-examples/examples/advanced_udwf.rs
index f46031434f..826abc28e1 100644
--- a/datafusion-examples/examples/advanced_udwf.rs
+++ b/datafusion-examples/examples/advanced_udwf.rs
@@ -220,7 +220,7 @@ async fn main() -> Result<()> {
vec![col("speed")], // smooth_it(speed)
vec![col("car")], // PARTITION BY car
vec![col("time").sort(true, true)], // ORDER BY time ASC
- WindowFrame::new(false),
+ WindowFrame::new(None),
);
let df = ctx.table("cars").await?.window(vec![window_expr])?;
diff --git a/datafusion-examples/examples/simple_udwf.rs
b/datafusion-examples/examples/simple_udwf.rs
index 0d04c093e1..a6149d661e 100644
--- a/datafusion-examples/examples/simple_udwf.rs
+++ b/datafusion-examples/examples/simple_udwf.rs
@@ -123,7 +123,7 @@ async fn main() -> Result<()> {
vec![col("speed")], // smooth_it(speed)
vec![col("car")], // PARTITION BY car
vec![col("time").sort(true, true)], // ORDER BY time ASC
- WindowFrame::new(false),
+ WindowFrame::new(None),
);
let df = ctx.table("cars").await?.window(vec![window_expr])?;
diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 8820ca9942..cc5b70796e 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -808,7 +808,6 @@ impl ScalarValue {
/// Create a zero value in the given type.
pub fn new_zero(datatype: &DataType) -> Result<ScalarValue> {
- assert!(datatype.is_primitive());
Ok(match datatype {
DataType::Boolean => ScalarValue::Boolean(Some(false)),
DataType::Int8 => ScalarValue::Int8(Some(0)),
diff --git a/datafusion/core/src/dataframe/mod.rs
b/datafusion/core/src/dataframe/mod.rs
index 9d00dc72f1..285e761e2e 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -1548,7 +1548,7 @@ mod tests {
vec![col("aggregate_test_100.c1")],
vec![col("aggregate_test_100.c2")],
vec![],
- WindowFrame::new(false),
+ WindowFrame::new(None),
));
let t2 = t.select(vec![col("c1"), first_row])?;
let plan = t2.plan.clone();
diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs
b/datafusion/core/src/physical_optimizer/test_utils.rs
index debafefe39..68ac9598a3 100644
--- a/datafusion/core/src/physical_optimizer/test_utils.rs
+++ b/datafusion/core/src/physical_optimizer/test_utils.rs
@@ -239,7 +239,7 @@ pub fn bounded_window_exec(
&[col(col_name, &schema).unwrap()],
&[],
&sort_exprs,
- Arc::new(WindowFrame::new(true)),
+ Arc::new(WindowFrame::new(Some(false))),
schema.as_ref(),
)
.unwrap()],
diff --git a/datafusion/core/tests/dataframe/mod.rs
b/datafusion/core/tests/dataframe/mod.rs
index cca23ac684..588b4647e5 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -174,11 +174,11 @@ async fn test_count_wildcard_on_window() -> Result<()> {
vec![wildcard()],
vec![],
vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))],
- WindowFrame {
- units: WindowFrameUnits::Range,
- start_bound:
WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
- end_bound:
WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
- },
+ WindowFrame::new_bounds(
+ WindowFrameUnits::Range,
+ WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
+ WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
+ ),
))])?
.explain(false, false)?
.collect()
diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs
b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
index 3037b4857a..6e5c5f8eb9 100644
--- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
@@ -139,6 +139,89 @@ async fn window_bounded_window_random_comparison() ->
Result<()> {
Ok(())
}
+// This tests whether we can generate bounded window results for each input
+// batch immediately for causal window frames.
+#[tokio::test(flavor = "multi_thread", worker_threads = 16)]
+async fn bounded_window_causal_non_causal() -> Result<()> {
+ let session_config = SessionConfig::new();
+ let ctx = SessionContext::new_with_config(session_config);
+ let mut batches = make_staggered_batches::<true>(1000, 10, 23_u64);
+ // Remove empty batches:
+ batches.retain(|batch| batch.num_rows() > 0);
+ let schema = batches[0].schema();
+ let memory_exec = Arc::new(MemoryExec::try_new(
+ &[batches.clone()],
+ schema.clone(),
+ None,
+ )?);
+ let window_fn =
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count);
+ let fn_name = "COUNT".to_string();
+ let args = vec![col("x", &schema)?];
+ let partitionby_exprs = vec![];
+ let orderby_exprs = vec![];
+ // Window frame starts with "UNBOUNDED PRECEDING":
+ let start_bound = WindowFrameBound::Preceding(ScalarValue::UInt64(None));
+
+ // Simulate cases of the following form:
+ // COUNT(x) OVER (
+ // ROWS BETWEEN UNBOUNDED PRECEDING AND <end_bound> PRECEDING/FOLLOWING
+ // )
+ for is_preceding in [false, true] {
+ for end_bound in [0, 1, 2, 3] {
+ let end_bound = if is_preceding {
+
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(end_bound)))
+ } else {
+
WindowFrameBound::Following(ScalarValue::UInt64(Some(end_bound)))
+ };
+ let window_frame = WindowFrame::new_bounds(
+ WindowFrameUnits::Rows,
+ start_bound.clone(),
+ end_bound,
+ );
+ let causal = window_frame.is_causal();
+
+ let window_expr = create_window_expr(
+ &window_fn,
+ fn_name.clone(),
+ &args,
+ &partitionby_exprs,
+ &orderby_exprs,
+ Arc::new(window_frame),
+ schema.as_ref(),
+ )?;
+ let running_window_exec = Arc::new(BoundedWindowAggExec::try_new(
+ vec![window_expr],
+ memory_exec.clone(),
+ vec![],
+ InputOrderMode::Linear,
+ )?);
+ let task_ctx = ctx.task_ctx();
+ let mut collected_results = collect(running_window_exec,
task_ctx).await?;
+ collected_results.retain(|batch| batch.num_rows() > 0);
+ let input_batch_sizes = batches
+ .iter()
+ .map(|batch| batch.num_rows())
+ .collect::<Vec<_>>();
+ let result_batch_sizes = collected_results
+ .iter()
+ .map(|batch| batch.num_rows())
+ .collect::<Vec<_>>();
+ if causal {
+ // For causal window frames, we can generate results
immediately
+ // for each input batch. Hence, batch sizes should match.
+ assert_eq!(input_batch_sizes, result_batch_sizes);
+ } else {
+ // For non-causal window frames, we cannot generate results
+ // immediately for each input batch. Hence, batch sizes
shouldn't
+ // match.
+ assert_ne!(input_batch_sizes, result_batch_sizes);
+ }
+ }
+ }
+
+ Ok(())
+}
+
fn get_random_function(
schema: &SchemaRef,
rng: &mut StdRng,
@@ -343,11 +426,7 @@ fn get_random_window_frame(rng: &mut StdRng, is_linear:
bool) -> WindowFrame {
} else {
WindowFrameBound::Following(ScalarValue::Int32(Some(end_bound.val)))
};
- let mut window_frame = WindowFrame {
- units,
- start_bound,
- end_bound,
- };
+ let mut window_frame = WindowFrame::new_bounds(units, start_bound,
end_bound);
// with 10% use unbounded preceding in tests
if rng.gen_range(0..10) == 0 {
window_frame.start_bound =
@@ -375,11 +454,7 @@ fn get_random_window_frame(rng: &mut StdRng, is_linear:
bool) -> WindowFrame {
end_bound.val as u64,
)))
};
- let mut window_frame = WindowFrame {
- units,
- start_bound,
- end_bound,
- };
+ let mut window_frame = WindowFrame::new_bounds(units, start_bound,
end_bound);
// with 10% use unbounded preceding in tests
if rng.gen_range(0..10) == 0 {
window_frame.start_bound =
diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs
index 9b8f94f4b0..9534834088 100644
--- a/datafusion/expr/src/udwf.rs
+++ b/datafusion/expr/src/udwf.rs
@@ -222,7 +222,7 @@ where
/// vec![col("speed")], // smooth_it(speed)
/// vec![col("car")], // PARTITION BY car
/// vec![col("time").sort(true, true)], // ORDER BY time ASC
-/// WindowFrame::new(false),
+/// WindowFrame::new(None),
/// );
/// ```
pub trait WindowUDFImpl: Debug + Send + Sync {
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 914b354d29..40c2c47053 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -1252,28 +1252,28 @@ mod tests {
vec![col("name")],
vec![],
vec![],
- WindowFrame::new(false),
+ WindowFrame::new(None),
));
let max2 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
vec![col("name")],
vec![],
vec![],
- WindowFrame::new(false),
+ WindowFrame::new(None),
));
let min3 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min),
vec![col("name")],
vec![],
vec![],
- WindowFrame::new(false),
+ WindowFrame::new(None),
));
let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
vec![col("age")],
vec![],
vec![],
- WindowFrame::new(false),
+ WindowFrame::new(None),
));
let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
@@ -1295,28 +1295,28 @@ mod tests {
vec![col("name")],
vec![],
vec![age_asc.clone(), name_desc.clone()],
- WindowFrame::new(true),
+ WindowFrame::new(Some(false)),
));
let max2 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
vec![col("name")],
vec![],
vec![],
- WindowFrame::new(false),
+ WindowFrame::new(None),
));
let min3 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min),
vec![col("name")],
vec![],
vec![age_asc.clone(), name_desc.clone()],
- WindowFrame::new(true),
+ WindowFrame::new(Some(false)),
));
let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
vec![col("age")],
vec![],
vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()],
- WindowFrame::new(true),
+ WindowFrame::new(Some(false)),
));
// FIXME use as_ref
let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
@@ -1350,7 +1350,7 @@ mod tests {
Expr::Sort(expr::Sort::new(Box::new(col("age")), true,
true)),
Expr::Sort(expr::Sort::new(Box::new(col("name")), false,
true)),
],
- WindowFrame::new(true),
+ WindowFrame::new(Some(false)),
)),
Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
@@ -1361,7 +1361,7 @@ mod tests {
Expr::Sort(expr::Sort::new(Box::new(col("age")), true,
true)),
Expr::Sort(expr::Sort::new(Box::new(col("created_at")),
false, true)),
],
- WindowFrame::new(true),
+ WindowFrame::new(Some(false)),
)),
];
let expected = vec![
diff --git a/datafusion/expr/src/window_frame.rs
b/datafusion/expr/src/window_frame.rs
index 2701ca1ecf..928cb4fa2b 100644
--- a/datafusion/expr/src/window_frame.rs
+++ b/datafusion/expr/src/window_frame.rs
@@ -23,28 +23,76 @@
//! - An ending frame boundary,
//! - An EXCLUDE clause.
+use std::convert::{From, TryFrom};
+use std::fmt::{self, Formatter};
+use std::hash::Hash;
+
use crate::expr::Sort;
use crate::Expr;
+
use datafusion_common::{plan_err, sql_err, DataFusionError, Result,
ScalarValue};
use sqlparser::ast;
use sqlparser::parser::ParserError::ParserError;
-use std::convert::{From, TryFrom};
-use std::fmt;
-use std::hash::Hash;
-/// The frame-spec determines which output rows are read by an aggregate
window function.
-///
-/// The ending frame boundary can be omitted (if the BETWEEN and AND keywords
that surround the
-/// starting frame boundary are also omitted), in which case the ending frame
boundary defaults to
-/// CURRENT ROW.
-#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+/// The frame specification determines which output rows are read by an
aggregate
+/// window function. The ending frame boundary can be omitted if the `BETWEEN`
+/// and `AND` keywords that surround the starting frame boundary are also
omitted,
+/// in which case the ending frame boundary defaults to `CURRENT ROW`.
+#[derive(Clone, PartialEq, Eq, Hash)]
pub struct WindowFrame {
- /// A frame type - either ROWS, RANGE or GROUPS
+ /// Frame type - either `ROWS`, `RANGE` or `GROUPS`
pub units: WindowFrameUnits,
- /// A starting frame boundary
+ /// Starting frame boundary
pub start_bound: WindowFrameBound,
- /// An ending frame boundary
+ /// Ending frame boundary
pub end_bound: WindowFrameBound,
+ /// Flag indicating whether the frame is causal (i.e. computing the result
+ /// for the current row doesn't depend on any subsequent rows).
+ ///
+ /// Example causal window frames:
+ /// ```text
+ /// +--------------+
+ /// Future | |
+ /// | | |
+ /// | | |
+ /// Current Row |+------------+| ---
+ /// | | | |
+ /// | | | |
+ /// | | | | Window Frame 1
+ /// Past | | |
+ /// | | |
+ /// | | ---
+ /// +--------------+
+ ///
+ /// +--------------+
+ /// Future | |
+ /// | | |
+ /// | | |
+ /// Current Row |+------------+|
+ /// | | |
+ /// | | | ---
+ /// | | | |
+ /// Past | | | Window Frame 2
+ /// | | |
+ /// | | ---
+ /// +--------------+
+ /// ```
+ /// Example non-causal window frame:
+ /// ```text
+ /// +--------------+
+ /// Future | |
+ /// | | |
+ /// | | | ---
+ /// Current Row |+------------+| |
+ /// | | | | Window Frame 3
+ /// | | | |
+ /// | | | ---
+ /// Past | |
+ /// | |
+ /// | |
+ /// +--------------+
+ /// ```
+ causal: bool,
}
impl fmt::Display for WindowFrame {
@@ -58,6 +106,17 @@ impl fmt::Display for WindowFrame {
}
}
+impl fmt::Debug for WindowFrame {
+ fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+ write!(
+ f,
+ "WindowFrame {{ units: {:?}, start_bound: {:?}, end_bound: {:?}
}}",
+ self.units, self.start_bound, self.end_bound
+ )?;
+ Ok(())
+ }
+}
+
impl TryFrom<ast::WindowFrame> for WindowFrame {
type Error = DataFusionError;
@@ -81,35 +140,40 @@ impl TryFrom<ast::WindowFrame> for WindowFrame {
)?
}
};
- Ok(Self {
- units: value.units.into(),
- start_bound,
- end_bound,
- })
+ let units = value.units.into();
+ Ok(Self::new_bounds(units, start_bound, end_bound))
}
}
impl WindowFrame {
- /// Creates a new, default window frame (with the meaning of default
depending on whether the
- /// frame contains an `ORDER BY` clause.
- pub fn new(has_order_by: bool) -> Self {
- if has_order_by {
- // This window frame covers the table (or partition if `PARTITION
BY` is used)
- // from beginning to the `CURRENT ROW` (with same rank). It is
used when the `OVER`
- // clause contains an `ORDER BY` clause but no frame.
- WindowFrame {
- units: WindowFrameUnits::Range,
+ /// Creates a new, default window frame (with the meaning of default
+ /// depending on whether the frame contains an `ORDER BY` clause and this
+ /// ordering is strict (i.e. no ties).
+ pub fn new(order_by: Option<bool>) -> Self {
+ if let Some(strict) = order_by {
+ // This window frame covers the table (or partition if `PARTITION
BY`
+ // is used) from beginning to the `CURRENT ROW` (with same rank).
It
+ // is used when the `OVER` clause contains an `ORDER BY` clause but
+ // no frame.
+ Self {
+ units: if strict {
+ WindowFrameUnits::Rows
+ } else {
+ WindowFrameUnits::Range
+ },
start_bound: WindowFrameBound::Preceding(ScalarValue::Null),
end_bound: WindowFrameBound::CurrentRow,
+ causal: strict,
}
} else {
- // This window frame covers the whole table (or partition if
`PARTITION BY` is used).
- // It is used when the `OVER` clause does not contain an `ORDER
BY` clause and there is
- // no frame.
- WindowFrame {
+ // This window frame covers the whole table (or partition if
`PARTITION BY`
+ // is used). It is used when the `OVER` clause does not contain an
+ // `ORDER BY` clause and there is no frame.
+ Self {
units: WindowFrameUnits::Rows,
start_bound:
WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
end_bound:
WindowFrameBound::Following(ScalarValue::UInt64(None)),
+ causal: false,
}
}
}
@@ -119,27 +183,68 @@ impl WindowFrame {
/// `2 ROWS PRECEDING AND 3 ROWS FOLLOWING`
pub fn reverse(&self) -> Self {
let start_bound = match &self.end_bound {
- WindowFrameBound::Preceding(elem) => {
- WindowFrameBound::Following(elem.clone())
+ WindowFrameBound::Preceding(value) => {
+ WindowFrameBound::Following(value.clone())
}
- WindowFrameBound::Following(elem) => {
- WindowFrameBound::Preceding(elem.clone())
+ WindowFrameBound::Following(value) => {
+ WindowFrameBound::Preceding(value.clone())
}
WindowFrameBound::CurrentRow => WindowFrameBound::CurrentRow,
};
let end_bound = match &self.start_bound {
- WindowFrameBound::Preceding(elem) => {
- WindowFrameBound::Following(elem.clone())
+ WindowFrameBound::Preceding(value) => {
+ WindowFrameBound::Following(value.clone())
}
- WindowFrameBound::Following(elem) => {
- WindowFrameBound::Preceding(elem.clone())
+ WindowFrameBound::Following(value) => {
+ WindowFrameBound::Preceding(value.clone())
}
WindowFrameBound::CurrentRow => WindowFrameBound::CurrentRow,
};
- WindowFrame {
- units: self.units,
+ Self::new_bounds(self.units, start_bound, end_bound)
+ }
+
+ /// Get whether window frame is causal
+ pub fn is_causal(&self) -> bool {
+ self.causal
+ }
+
+ /// Initializes window frame from units (type), start bound and end bound.
+ pub fn new_bounds(
+ units: WindowFrameUnits,
+ start_bound: WindowFrameBound,
+ end_bound: WindowFrameBound,
+ ) -> Self {
+ let causal = match units {
+ WindowFrameUnits::Rows => match &end_bound {
+ WindowFrameBound::Following(value) => {
+ if value.is_null() {
+ // Unbounded following
+ false
+ } else {
+ let zero = ScalarValue::new_zero(&value.data_type());
+ zero.map(|zero| value.eq(&zero)).unwrap_or(false)
+ }
+ }
+ _ => true,
+ },
+ WindowFrameUnits::Range | WindowFrameUnits::Groups => match
&end_bound {
+ WindowFrameBound::Preceding(value) => {
+ if value.is_null() {
+ // Unbounded preceding
+ true
+ } else {
+ let zero = ScalarValue::new_zero(&value.data_type());
+ zero.map(|zero| value.gt(&zero)).unwrap_or(false)
+ }
+ }
+ _ => false,
+ },
+ };
+ Self {
+ units,
start_bound,
end_bound,
+ causal,
}
}
}
diff --git a/datafusion/expr/src/window_state.rs
b/datafusion/expr/src/window_state.rs
index de88396d9b..d6c5a07385 100644
--- a/datafusion/expr/src/window_state.rs
+++ b/datafusion/expr/src/window_state.rs
@@ -682,11 +682,11 @@ mod tests {
#[test]
fn test_window_frame_group_boundaries() -> Result<()> {
- let window_frame = Arc::new(WindowFrame {
- units: WindowFrameUnits::Groups,
- start_bound:
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))),
- end_bound:
WindowFrameBound::Following(ScalarValue::UInt64(Some(1))),
- });
+ let window_frame = Arc::new(WindowFrame::new_bounds(
+ WindowFrameUnits::Groups,
+ WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))),
+ WindowFrameBound::Following(ScalarValue::UInt64(Some(1))),
+ ));
let expected_results = vec![
(Range { start: 0, end: 2 }, 0),
(Range { start: 0, end: 4 }, 1),
@@ -703,11 +703,11 @@ mod tests {
#[test]
fn test_window_frame_group_boundaries_both_following() -> Result<()> {
- let window_frame = Arc::new(WindowFrame {
- units: WindowFrameUnits::Groups,
- start_bound:
WindowFrameBound::Following(ScalarValue::UInt64(Some(1))),
- end_bound:
WindowFrameBound::Following(ScalarValue::UInt64(Some(2))),
- });
+ let window_frame = Arc::new(WindowFrame::new_bounds(
+ WindowFrameUnits::Groups,
+ WindowFrameBound::Following(ScalarValue::UInt64(Some(1))),
+ WindowFrameBound::Following(ScalarValue::UInt64(Some(2))),
+ ));
let expected_results = vec![
(Range::<usize> { start: 1, end: 4 }, 0),
(Range::<usize> { start: 2, end: 5 }, 1),
@@ -724,11 +724,11 @@ mod tests {
#[test]
fn test_window_frame_group_boundaries_both_preceding() -> Result<()> {
- let window_frame = Arc::new(WindowFrame {
- units: WindowFrameUnits::Groups,
- start_bound:
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))),
- end_bound:
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))),
- });
+ let window_frame = Arc::new(WindowFrame::new_bounds(
+ WindowFrameUnits::Groups,
+ WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))),
+ WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))),
+ ));
let expected_results = vec![
(Range::<usize> { start: 0, end: 0 }, 0),
(Range::<usize> { start: 0, end: 1 }, 1),
diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
index 953716713e..35a8597832 100644
--- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
+++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
@@ -346,13 +346,11 @@ mod tests {
vec![wildcard()],
vec![],
vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))],
- WindowFrame {
- units: WindowFrameUnits::Range,
- start_bound:
WindowFrameBound::Preceding(ScalarValue::UInt32(Some(
- 6,
- ))),
- end_bound:
WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
- },
+ WindowFrame::new_bounds(
+ WindowFrameUnits::Range,
+ WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
+ WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
+ ),
))])?
.project(vec![count(wildcard())])?
.build()?;
diff --git a/datafusion/optimizer/src/push_down_projection.rs
b/datafusion/optimizer/src/push_down_projection.rs
index 4ee4f7e417..6a003ecb5f 100644
--- a/datafusion/optimizer/src/push_down_projection.rs
+++ b/datafusion/optimizer/src/push_down_projection.rs
@@ -586,7 +586,7 @@ mod tests {
vec![col("test.a")],
vec![col("test.b")],
vec![],
- WindowFrame::new(false),
+ WindowFrame::new(None),
));
let max2 = Expr::WindowFunction(expr::WindowFunction::new(
@@ -594,7 +594,7 @@ mod tests {
vec![col("test.b")],
vec![],
vec![],
- WindowFrame::new(false),
+ WindowFrame::new(None),
));
let col1 = col(max1.display_name()?);
let col2 = col(max2.display_name()?);
diff --git a/datafusion/physical-expr/src/window/built_in.rs
b/datafusion/physical-expr/src/window/built_in.rs
index 665ceb70d6..c3c7400026 100644
--- a/datafusion/physical-expr/src/window/built_in.rs
+++ b/datafusion/physical-expr/src/window/built_in.rs
@@ -206,6 +206,7 @@ impl WindowExpr for BuiltInWindowExpr {
let record_batch = &partition_batch_state.record_batch;
let num_rows = record_batch.num_rows();
let mut row_wise_results: Vec<ScalarValue> = vec![];
+ let mut is_causal = self.window_frame.is_causal();
for idx in state.last_calculated_index..num_rows {
let frame_range = if evaluator.uses_window_frame() {
state
@@ -224,11 +225,15 @@ impl WindowExpr for BuiltInWindowExpr {
idx,
)
} else {
+ is_causal = false;
evaluator.get_range(idx, num_rows)
}?;
- // Exit if the range extends all the way:
- if frame_range.end == num_rows &&
!partition_batch_state.is_end {
+ // Exit if the range is non-causal and extends all the way:
+ if frame_range.end == num_rows
+ && !is_causal
+ && !partition_batch_state.is_end
+ {
break;
}
// Update last range
diff --git a/datafusion/physical-expr/src/window/window_expr.rs
b/datafusion/physical-expr/src/window/window_expr.rs
index 548fae75bd..e2714dc42b 100644
--- a/datafusion/physical-expr/src/window/window_expr.rs
+++ b/datafusion/physical-expr/src/window/window_expr.rs
@@ -231,12 +231,13 @@ pub trait AggregateWindowExpr: WindowExpr {
// We iterate on each row to perform a running calculation.
let length = values[0].len();
let mut row_wise_results: Vec<ScalarValue> = vec![];
+ let is_causal = self.get_window_frame().is_causal();
while idx < length {
// Start search from the last_range. This squeezes searched range.
let cur_range =
window_frame_ctx.calculate_range(&order_bys, last_range,
length, idx)?;
- // Exit if the range extends all the way:
- if cur_range.end == length && not_end {
+ // Exit if the range is non-causal and extends all the way:
+ if cur_range.end == length && !is_causal && not_end {
break;
}
let value = self.get_aggregate_result_inside_range(
diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
index 0871ec0d7f..9d247d689c 100644
--- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
+++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
@@ -1170,33 +1170,33 @@ mod tests {
last_value_func,
&[],
&[],
- Arc::new(WindowFrame {
- units: WindowFrameUnits::Rows,
- start_bound:
WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
- end_bound: WindowFrameBound::CurrentRow,
- }),
+ Arc::new(WindowFrame::new_bounds(
+ WindowFrameUnits::Rows,
+ WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
+ WindowFrameBound::CurrentRow,
+ )),
)) as _,
// NTH_VALUE(a, -1)
Arc::new(BuiltInWindowExpr::new(
nth_value_func1,
&[],
&[],
- Arc::new(WindowFrame {
- units: WindowFrameUnits::Rows,
- start_bound:
WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
- end_bound: WindowFrameBound::CurrentRow,
- }),
+ Arc::new(WindowFrame::new_bounds(
+ WindowFrameUnits::Rows,
+ WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
+ WindowFrameBound::CurrentRow,
+ )),
)) as _,
// NTH_VALUE(a, -2)
Arc::new(BuiltInWindowExpr::new(
nth_value_func2,
&[],
&[],
- Arc::new(WindowFrame {
- units: WindowFrameUnits::Rows,
- start_bound:
WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
- end_bound: WindowFrameBound::CurrentRow,
- }),
+ Arc::new(WindowFrame::new_bounds(
+ WindowFrameUnits::Rows,
+ WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
+ WindowFrameBound::CurrentRow,
+ )),
)) as _,
];
let physical_plan = BoundedWindowAggExec::try_new(
diff --git a/datafusion/physical-plan/src/windows/mod.rs
b/datafusion/physical-plan/src/windows/mod.rs
index fec168fabf..a85e5cc31c 100644
--- a/datafusion/physical-plan/src/windows/mod.rs
+++ b/datafusion/physical-plan/src/windows/mod.rs
@@ -654,7 +654,7 @@ mod tests {
&[col("a", &schema)?],
&[],
&[],
- Arc::new(WindowFrame::new(false)),
+ Arc::new(WindowFrame::new(None)),
schema.as_ref(),
)?],
blocking_exec,
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs
b/datafusion/proto/src/logical_plan/from_proto.rs
index d15cf1db92..2d9c7be46b 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -879,11 +879,7 @@ impl TryFrom<protobuf::WindowFrame> for WindowFrame {
})
.transpose()?
.unwrap_or(WindowFrameBound::CurrentRow);
- Ok(Self {
- units,
- start_bound,
- end_bound,
- })
+ Ok(WindowFrame::new_bounds(units, start_bound, end_bound))
}
}
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index 03daf535f2..ed21124a9e 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -1671,7 +1671,7 @@ fn roundtrip_window() {
vec![],
vec![col("col1")],
vec![col("col2")],
- WindowFrame::new(true),
+ WindowFrame::new(Some(false)),
));
// 2. with default window_frame
@@ -1682,15 +1682,15 @@ fn roundtrip_window() {
vec![],
vec![col("col1")],
vec![col("col2")],
- WindowFrame::new(true),
+ WindowFrame::new(Some(false)),
));
// 3. with window_frame with row numbers
- let range_number_frame = WindowFrame {
- units: WindowFrameUnits::Range,
- start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))),
- end_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some(2))),
- };
+ let range_number_frame = WindowFrame::new_bounds(
+ WindowFrameUnits::Range,
+ WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))),
+ WindowFrameBound::Following(ScalarValue::UInt64(Some(2))),
+ );
let test_expr3 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::BuiltInWindowFunction(
@@ -1703,11 +1703,11 @@ fn roundtrip_window() {
));
// 4. test with AggregateFunction
- let row_number_frame = WindowFrame {
- units: WindowFrameUnits::Rows,
- start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))),
- end_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some(2))),
- };
+ let row_number_frame = WindowFrame::new_bounds(
+ WindowFrameUnits::Rows,
+ WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))),
+ WindowFrameBound::Following(ScalarValue::UInt64(Some(2))),
+ );
let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 9ee8d0d51d..3a13dc887f 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -255,11 +255,11 @@ fn roundtrip_window() -> Result<()> {
let field_b = Field::new("b", DataType::Int64, false);
let schema = Arc::new(Schema::new(vec![field_a, field_b]));
- let window_frame = WindowFrame {
- units: datafusion_expr::WindowFrameUnits::Range,
- start_bound: WindowFrameBound::Preceding(ScalarValue::Int64(None)),
- end_bound: WindowFrameBound::CurrentRow,
- };
+ let window_frame = WindowFrame::new_bounds(
+ datafusion_expr::WindowFrameUnits::Range,
+ WindowFrameBound::Preceding(ScalarValue::Int64(None)),
+ WindowFrameBound::CurrentRow,
+ );
let builtin_window_expr = Arc::new(BuiltInWindowExpr::new(
Arc::new(NthValue::first(
@@ -286,14 +286,14 @@ fn roundtrip_window() -> Result<()> {
)),
&[],
&[],
- Arc::new(WindowFrame::new(false)),
+ Arc::new(WindowFrame::new(None)),
));
- let window_frame = WindowFrame {
- units: datafusion_expr::WindowFrameUnits::Range,
- start_bound: WindowFrameBound::CurrentRow,
- end_bound: WindowFrameBound::Preceding(ScalarValue::Int64(None)),
- };
+ let window_frame = WindowFrame::new_bounds(
+ datafusion_expr::WindowFrameUnits::Range,
+ WindowFrameBound::CurrentRow,
+ WindowFrameBound::Preceding(ScalarValue::Int64(None)),
+ );
let sliding_aggr_window_expr = Arc::new(SlidingAggregateWindowExpr::new(
Arc::new(Sum::new(
diff --git a/datafusion/sql/src/expr/function.rs
b/datafusion/sql/src/expr/function.rs
index 395f10b6f7..30f8605c39 100644
--- a/datafusion/sql/src/expr/function.rs
+++ b/datafusion/sql/src/expr/function.rs
@@ -17,7 +17,8 @@
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
use datafusion_common::{
- not_impl_err, plan_datafusion_err, plan_err, DFSchema, DataFusionError,
Result,
+ not_impl_err, plan_datafusion_err, plan_err, DFSchema, DataFusionError,
Dependency,
+ Result,
};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::function::suggest_valid_function;
@@ -102,6 +103,22 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
// Numeric literals in window function ORDER BY are treated as
constants
false,
)?;
+
+ let func_deps = schema.functional_dependencies();
+ // Find whether ties are possible in the given ordering:
+ let is_ordering_strict = order_by.iter().any(|orderby_expr| {
+ if let Expr::Sort(sort_expr) = orderby_expr {
+ if let Expr::Column(col) = sort_expr.expr.as_ref() {
+ let idx = schema.index_of_column(col).unwrap();
+ return func_deps.iter().any(|dep| {
+ dep.source_indices == vec![idx]
+ && dep.mode == Dependency::Single
+ });
+ }
+ }
+ false
+ });
+
let window_frame = window
.window_frame
.as_ref()
@@ -115,8 +132,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
let window_frame = if let Some(window_frame) = window_frame {
regularize_window_order_by(&window_frame, &mut order_by)?;
window_frame
+ } else if is_ordering_strict {
+ WindowFrame::new(Some(true))
} else {
- WindowFrame::new(!order_by.is_empty())
+ WindowFrame::new((!order_by.is_empty()).then_some(false))
};
if let Ok(fun) = self.find_window_func(&name) {
diff --git a/datafusion/sqllogictest/test_files/window.slt
b/datafusion/sqllogictest/test_files/window.slt
index 100c214383..f8337e21d7 100644
--- a/datafusion/sqllogictest/test_files/window.slt
+++ b/datafusion/sqllogictest/test_files/window.slt
@@ -3871,3 +3871,38 @@ FROM (SELECT c1, c2, ROW_NUMBER() OVER(PARTITION BY c1)
as rn
LIMIT 5)
GROUP BY rn
ORDER BY rn;
+
+# create a table for testing
+statement ok
+CREATE TABLE table_with_pk (
+ sn INT PRIMARY KEY,
+ ts TIMESTAMP,
+ currency VARCHAR(3),
+ amount FLOAT
+ ) as VALUES
+ (0, '2022-01-01 06:00:00'::timestamp, 'EUR', 30.0),
+ (1, '2022-01-01 08:00:00'::timestamp, 'EUR', 50.0),
+ (2, '2022-01-01 11:30:00'::timestamp, 'TRY', 75.0),
+ (3, '2022-01-02 12:00:00'::timestamp, 'EUR', 200.0),
+ (4, '2022-01-03 10:00:00'::timestamp, 'TRY', 100.0)
+
+# An OVER clause of the form `OVER (ORDER BY <expr>)` gets treated as if it were
+# `OVER (ORDER BY <expr> RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)`.
+# However, if we know that <expr> contains a unique column (e.g. a PRIMARY KEY),
+# it can be treated as `OVER (ORDER BY <expr> ROWS BETWEEN UNBOUNDED PRECEDING
+# AND CURRENT ROW)` where window frame units change from `RANGE` to `ROWS`. This
+# conversion makes the window frame manifestly causal by eliminating the possibility
+# of ties explicitly (see window frame documentation for a discussion of causality
+# in this context). The query below should have `ROWS` in its window frame.
+query TT
+EXPLAIN SELECT *, SUM(amount) OVER (ORDER BY sn) as sum1 FROM table_with_pk;
+----
+logical_plan
+Projection: table_with_pk.sn, table_with_pk.ts, table_with_pk.currency,
table_with_pk.amount, SUM(table_with_pk.amount) ORDER BY [table_with_pk.sn ASC
NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum1
+--WindowAggr: windowExpr=[[SUM(CAST(table_with_pk.amount AS Float64)) ORDER BY
[table_with_pk.sn ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT
ROW]]
+----TableScan: table_with_pk projection=[sn, ts, currency, amount]
+physical_plan
+ProjectionExec: expr=[sn@0 as sn, ts@1 as ts, currency@2 as currency, amount@3
as amount, SUM(table_with_pk.amount) ORDER BY [table_with_pk.sn ASC NULLS LAST]
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum1]
+--BoundedWindowAggExec: wdw=[SUM(table_with_pk.amount) ORDER BY
[table_with_pk.sn ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT
ROW: Ok(Field { name: "SUM(table_with_pk.amount) ORDER BY [table_with_pk.sn ASC
NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type:
Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }),
frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)),
end_bound: CurrentRow }], mode=[Sorted]
+----SortExec: expr=[sn@0 ASC NULLS LAST]
+------MemoryExec: partitions=1, partition_sizes=[1]
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs
b/datafusion/substrait/src/logical_plan/consumer.rs
index a4ec3e7722..7687aff2bc 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -973,11 +973,11 @@ pub async fn from_substrait_rex(
)
.await?,
order_by,
- window_frame: datafusion::logical_expr::WindowFrame {
+ window_frame:
datafusion::logical_expr::WindowFrame::new_bounds(
units,
- start_bound: from_substrait_bound(&window.lower_bound,
true)?,
- end_bound: from_substrait_bound(&window.upper_bound,
false)?,
- },
+ from_substrait_bound(&window.lower_bound, true)?,
+ from_substrait_bound(&window.upper_bound, false)?,
+ ),
})))
}
Some(RexType::Subquery(subquery)) => match
&subquery.as_ref().subquery_type {