This is an automated email from the ASF dual-hosted git repository.

ozankabak pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 259e12c6fe Determine causal window frames to produce early results. 
(#8842)
259e12c6fe is described below

commit 259e12c6fe3d3db971b4a20aee9b68b4a49ad5c6
Author: Mustafa Akur <[email protected]>
AuthorDate: Mon Jan 15 21:23:07 2024 +0300

    Determine causal window frames to produce early results. (#8842)
    
    * add handling for primary key window frame
    
    * Add check for window end
    
    * Add uniqueness check
    
    * Minor changes
    
    * Update signature of WindowFrame::new
    
    * Make is_causal window state
    
    * Minor changes
    
    * Address reviews
    
    * Minor changes
    
    * Add new test
    
    * Minor changes
    
    * Minor changes
    
    * Remove string handling
    
    * Review Part 2
    
    * Improve comments
    
    ---------
    
    Co-authored-by: Mehmet Ozan Kabak <[email protected]>
---
 datafusion-examples/examples/advanced_udwf.rs      |   2 +-
 datafusion-examples/examples/simple_udwf.rs        |   2 +-
 datafusion/common/src/scalar.rs                    |   1 -
 datafusion/core/src/dataframe/mod.rs               |   2 +-
 .../core/src/physical_optimizer/test_utils.rs      |   2 +-
 datafusion/core/tests/dataframe/mod.rs             |  10 +-
 datafusion/core/tests/fuzz_cases/window_fuzz.rs    |  95 +++++++++--
 datafusion/expr/src/udwf.rs                        |   2 +-
 datafusion/expr/src/utils.rs                       |  20 +--
 datafusion/expr/src/window_frame.rs                | 185 ++++++++++++++++-----
 datafusion/expr/src/window_state.rs                |  30 ++--
 .../optimizer/src/analyzer/count_wildcard_rule.rs  |  12 +-
 datafusion/optimizer/src/push_down_projection.rs   |   4 +-
 datafusion/physical-expr/src/window/built_in.rs    |   9 +-
 datafusion/physical-expr/src/window/window_expr.rs |   5 +-
 .../src/windows/bounded_window_agg_exec.rs         |  30 ++--
 datafusion/physical-plan/src/windows/mod.rs        |   2 +-
 datafusion/proto/src/logical_plan/from_proto.rs    |   6 +-
 .../proto/tests/cases/roundtrip_logical_plan.rs    |  24 +--
 .../proto/tests/cases/roundtrip_physical_plan.rs   |  22 +--
 datafusion/sql/src/expr/function.rs                |  23 ++-
 datafusion/sqllogictest/test_files/window.slt      |  35 ++++
 datafusion/substrait/src/logical_plan/consumer.rs  |   8 +-
 23 files changed, 382 insertions(+), 149 deletions(-)

diff --git a/datafusion-examples/examples/advanced_udwf.rs 
b/datafusion-examples/examples/advanced_udwf.rs
index f46031434f..826abc28e1 100644
--- a/datafusion-examples/examples/advanced_udwf.rs
+++ b/datafusion-examples/examples/advanced_udwf.rs
@@ -220,7 +220,7 @@ async fn main() -> Result<()> {
         vec![col("speed")],                 // smooth_it(speed)
         vec![col("car")],                   // PARTITION BY car
         vec![col("time").sort(true, true)], // ORDER BY time ASC
-        WindowFrame::new(false),
+        WindowFrame::new(None),
     );
     let df = ctx.table("cars").await?.window(vec![window_expr])?;
 
diff --git a/datafusion-examples/examples/simple_udwf.rs 
b/datafusion-examples/examples/simple_udwf.rs
index 0d04c093e1..a6149d661e 100644
--- a/datafusion-examples/examples/simple_udwf.rs
+++ b/datafusion-examples/examples/simple_udwf.rs
@@ -123,7 +123,7 @@ async fn main() -> Result<()> {
         vec![col("speed")],                 // smooth_it(speed)
         vec![col("car")],                   // PARTITION BY car
         vec![col("time").sort(true, true)], // ORDER BY time ASC
-        WindowFrame::new(false),
+        WindowFrame::new(None),
     );
     let df = ctx.table("cars").await?.window(vec![window_expr])?;
 
diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 8820ca9942..cc5b70796e 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -808,7 +808,6 @@ impl ScalarValue {
 
     /// Create a zero value in the given type.
     pub fn new_zero(datatype: &DataType) -> Result<ScalarValue> {
-        assert!(datatype.is_primitive());
         Ok(match datatype {
             DataType::Boolean => ScalarValue::Boolean(Some(false)),
             DataType::Int8 => ScalarValue::Int8(Some(0)),
diff --git a/datafusion/core/src/dataframe/mod.rs 
b/datafusion/core/src/dataframe/mod.rs
index 9d00dc72f1..285e761e2e 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -1548,7 +1548,7 @@ mod tests {
             vec![col("aggregate_test_100.c1")],
             vec![col("aggregate_test_100.c2")],
             vec![],
-            WindowFrame::new(false),
+            WindowFrame::new(None),
         ));
         let t2 = t.select(vec![col("c1"), first_row])?;
         let plan = t2.plan.clone();
diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs 
b/datafusion/core/src/physical_optimizer/test_utils.rs
index debafefe39..68ac9598a3 100644
--- a/datafusion/core/src/physical_optimizer/test_utils.rs
+++ b/datafusion/core/src/physical_optimizer/test_utils.rs
@@ -239,7 +239,7 @@ pub fn bounded_window_exec(
                 &[col(col_name, &schema).unwrap()],
                 &[],
                 &sort_exprs,
-                Arc::new(WindowFrame::new(true)),
+                Arc::new(WindowFrame::new(Some(false))),
                 schema.as_ref(),
             )
             .unwrap()],
diff --git a/datafusion/core/tests/dataframe/mod.rs 
b/datafusion/core/tests/dataframe/mod.rs
index cca23ac684..588b4647e5 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -174,11 +174,11 @@ async fn test_count_wildcard_on_window() -> Result<()> {
             vec![wildcard()],
             vec![],
             vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))],
-            WindowFrame {
-                units: WindowFrameUnits::Range,
-                start_bound: 
WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
-                end_bound: 
WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
-            },
+            WindowFrame::new_bounds(
+                WindowFrameUnits::Range,
+                WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
+                WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
+            ),
         ))])?
         .explain(false, false)?
         .collect()
diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs 
b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
index 3037b4857a..6e5c5f8eb9 100644
--- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
@@ -139,6 +139,89 @@ async fn window_bounded_window_random_comparison() -> 
Result<()> {
     Ok(())
 }
 
+// This tests whether we can generate bounded window results for each input
+// batch immediately for causal window frames.
+#[tokio::test(flavor = "multi_thread", worker_threads = 16)]
+async fn bounded_window_causal_non_causal() -> Result<()> {
+    let session_config = SessionConfig::new();
+    let ctx = SessionContext::new_with_config(session_config);
+    let mut batches = make_staggered_batches::<true>(1000, 10, 23_u64);
+    // Remove empty batches:
+    batches.retain(|batch| batch.num_rows() > 0);
+    let schema = batches[0].schema();
+    let memory_exec = Arc::new(MemoryExec::try_new(
+        &[batches.clone()],
+        schema.clone(),
+        None,
+    )?);
+    let window_fn = 
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count);
+    let fn_name = "COUNT".to_string();
+    let args = vec![col("x", &schema)?];
+    let partitionby_exprs = vec![];
+    let orderby_exprs = vec![];
+    // Window frame starts with "UNBOUNDED PRECEDING":
+    let start_bound = WindowFrameBound::Preceding(ScalarValue::UInt64(None));
+
+    // Simulate cases of the following form:
+    // COUNT(x) OVER (
+    //     ROWS BETWEEN UNBOUNDED PRECEDING AND <end_bound> PRECEDING/FOLLOWING
+    // )
+    for is_preceding in [false, true] {
+        for end_bound in [0, 1, 2, 3] {
+            let end_bound = if is_preceding {
+                
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(end_bound)))
+            } else {
+                
WindowFrameBound::Following(ScalarValue::UInt64(Some(end_bound)))
+            };
+            let window_frame = WindowFrame::new_bounds(
+                WindowFrameUnits::Rows,
+                start_bound.clone(),
+                end_bound,
+            );
+            let causal = window_frame.is_causal();
+
+            let window_expr = create_window_expr(
+                &window_fn,
+                fn_name.clone(),
+                &args,
+                &partitionby_exprs,
+                &orderby_exprs,
+                Arc::new(window_frame),
+                schema.as_ref(),
+            )?;
+            let running_window_exec = Arc::new(BoundedWindowAggExec::try_new(
+                vec![window_expr],
+                memory_exec.clone(),
+                vec![],
+                InputOrderMode::Linear,
+            )?);
+            let task_ctx = ctx.task_ctx();
+            let mut collected_results = collect(running_window_exec, 
task_ctx).await?;
+            collected_results.retain(|batch| batch.num_rows() > 0);
+            let input_batch_sizes = batches
+                .iter()
+                .map(|batch| batch.num_rows())
+                .collect::<Vec<_>>();
+            let result_batch_sizes = collected_results
+                .iter()
+                .map(|batch| batch.num_rows())
+                .collect::<Vec<_>>();
+            if causal {
+                // For causal window frames, we can generate results 
immediately
+                // for each input batch. Hence, batch sizes should match.
+                assert_eq!(input_batch_sizes, result_batch_sizes);
+            } else {
+                // For non-causal window frames, we cannot generate results
+                // immediately for each input batch. Hence, batch sizes 
shouldn't
+                // match.
+                assert_ne!(input_batch_sizes, result_batch_sizes);
+            }
+        }
+    }
+
+    Ok(())
+}
+
 fn get_random_function(
     schema: &SchemaRef,
     rng: &mut StdRng,
@@ -343,11 +426,7 @@ fn get_random_window_frame(rng: &mut StdRng, is_linear: 
bool) -> WindowFrame {
             } else {
                 
WindowFrameBound::Following(ScalarValue::Int32(Some(end_bound.val)))
             };
-            let mut window_frame = WindowFrame {
-                units,
-                start_bound,
-                end_bound,
-            };
+            let mut window_frame = WindowFrame::new_bounds(units, start_bound, 
end_bound);
             // with 10% use unbounded preceding in tests
             if rng.gen_range(0..10) == 0 {
                 window_frame.start_bound =
@@ -375,11 +454,7 @@ fn get_random_window_frame(rng: &mut StdRng, is_linear: 
bool) -> WindowFrame {
                     end_bound.val as u64,
                 )))
             };
-            let mut window_frame = WindowFrame {
-                units,
-                start_bound,
-                end_bound,
-            };
+            let mut window_frame = WindowFrame::new_bounds(units, start_bound, 
end_bound);
             // with 10% use unbounded preceding in tests
             if rng.gen_range(0..10) == 0 {
                 window_frame.start_bound =
diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs
index 9b8f94f4b0..9534834088 100644
--- a/datafusion/expr/src/udwf.rs
+++ b/datafusion/expr/src/udwf.rs
@@ -222,7 +222,7 @@ where
 ///     vec![col("speed")],                 // smooth_it(speed)
 ///     vec![col("car")],                   // PARTITION BY car
 ///     vec![col("time").sort(true, true)], // ORDER BY time ASC
-///     WindowFrame::new(false),
+///     WindowFrame::new(None),
 /// );
 /// ```
 pub trait WindowUDFImpl: Debug + Send + Sync {
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 914b354d29..40c2c47053 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -1252,28 +1252,28 @@ mod tests {
             vec![col("name")],
             vec![],
             vec![],
-            WindowFrame::new(false),
+            WindowFrame::new(None),
         ));
         let max2 = Expr::WindowFunction(expr::WindowFunction::new(
             
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
             vec![col("name")],
             vec![],
             vec![],
-            WindowFrame::new(false),
+            WindowFrame::new(None),
         ));
         let min3 = Expr::WindowFunction(expr::WindowFunction::new(
             
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min),
             vec![col("name")],
             vec![],
             vec![],
-            WindowFrame::new(false),
+            WindowFrame::new(None),
         ));
         let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
             
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
             vec![col("age")],
             vec![],
             vec![],
-            WindowFrame::new(false),
+            WindowFrame::new(None),
         ));
         let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
         let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
@@ -1295,28 +1295,28 @@ mod tests {
             vec![col("name")],
             vec![],
             vec![age_asc.clone(), name_desc.clone()],
-            WindowFrame::new(true),
+            WindowFrame::new(Some(false)),
         ));
         let max2 = Expr::WindowFunction(expr::WindowFunction::new(
             
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
             vec![col("name")],
             vec![],
             vec![],
-            WindowFrame::new(false),
+            WindowFrame::new(None),
         ));
         let min3 = Expr::WindowFunction(expr::WindowFunction::new(
             
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min),
             vec![col("name")],
             vec![],
             vec![age_asc.clone(), name_desc.clone()],
-            WindowFrame::new(true),
+            WindowFrame::new(Some(false)),
         ));
         let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
             
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
             vec![col("age")],
             vec![],
             vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()],
-            WindowFrame::new(true),
+            WindowFrame::new(Some(false)),
         ));
         // FIXME use as_ref
         let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
@@ -1350,7 +1350,7 @@ mod tests {
                     Expr::Sort(expr::Sort::new(Box::new(col("age")), true, 
true)),
                     Expr::Sort(expr::Sort::new(Box::new(col("name")), false, 
true)),
                 ],
-                WindowFrame::new(true),
+                WindowFrame::new(Some(false)),
             )),
             Expr::WindowFunction(expr::WindowFunction::new(
                 
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
@@ -1361,7 +1361,7 @@ mod tests {
                     Expr::Sort(expr::Sort::new(Box::new(col("age")), true, 
true)),
                     Expr::Sort(expr::Sort::new(Box::new(col("created_at")), 
false, true)),
                 ],
-                WindowFrame::new(true),
+                WindowFrame::new(Some(false)),
             )),
         ];
         let expected = vec![
diff --git a/datafusion/expr/src/window_frame.rs 
b/datafusion/expr/src/window_frame.rs
index 2701ca1ecf..928cb4fa2b 100644
--- a/datafusion/expr/src/window_frame.rs
+++ b/datafusion/expr/src/window_frame.rs
@@ -23,28 +23,76 @@
 //! - An ending frame boundary,
 //! - An EXCLUDE clause.
 
+use std::convert::{From, TryFrom};
+use std::fmt::{self, Formatter};
+use std::hash::Hash;
+
 use crate::expr::Sort;
 use crate::Expr;
+
 use datafusion_common::{plan_err, sql_err, DataFusionError, Result, 
ScalarValue};
 use sqlparser::ast;
 use sqlparser::parser::ParserError::ParserError;
-use std::convert::{From, TryFrom};
-use std::fmt;
-use std::hash::Hash;
 
-/// The frame-spec determines which output rows are read by an aggregate 
window function.
-///
-/// The ending frame boundary can be omitted (if the BETWEEN and AND keywords 
that surround the
-/// starting frame boundary are also omitted), in which case the ending frame 
boundary defaults to
-/// CURRENT ROW.
-#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+/// The frame specification determines which output rows are read by an 
aggregate
+/// window function. The ending frame boundary can be omitted if the `BETWEEN`
+/// and `AND` keywords that surround the starting frame boundary are also 
omitted,
+/// in which case the ending frame boundary defaults to `CURRENT ROW`.
+#[derive(Clone, PartialEq, Eq, Hash)]
 pub struct WindowFrame {
-    /// A frame type - either ROWS, RANGE or GROUPS
+    /// Frame type - either `ROWS`, `RANGE` or `GROUPS`
     pub units: WindowFrameUnits,
-    /// A starting frame boundary
+    /// Starting frame boundary
     pub start_bound: WindowFrameBound,
-    /// An ending frame boundary
+    /// Ending frame boundary
     pub end_bound: WindowFrameBound,
+    /// Flag indicating whether the frame is causal (i.e. computing the result
+    /// for the current row doesn't depend on any subsequent rows).
+    ///
+    /// Example causal window frames:
+    /// ```text
+    ///                +--------------+
+    ///      Future    |              |
+    ///         |      |              |
+    ///         |      |              |
+    ///    Current Row |+------------+|  ---
+    ///         |      |              |   |
+    ///         |      |              |   |
+    ///         |      |              |   |  Window Frame 1
+    ///       Past     |              |   |
+    ///                |              |   |
+    ///                |              |  ---
+    ///                +--------------+
+    ///
+    ///                +--------------+
+    ///      Future    |              |
+    ///         |      |              |
+    ///         |      |              |
+    ///    Current Row |+------------+|
+    ///         |      |              |
+    ///         |      |              | ---
+    ///         |      |              |  |
+    ///       Past     |              |  |  Window Frame 2
+    ///                |              |  |
+    ///                |              | ---
+    ///                +--------------+
+    /// ```
+    /// Example non-causal window frame:
+    /// ```text
+    ///                +--------------+
+    ///      Future    |              |
+    ///         |      |              |
+    ///         |      |              | ---
+    ///    Current Row |+------------+|  |
+    ///         |      |              |  |  Window Frame 3
+    ///         |      |              |  |
+    ///         |      |              | ---
+    ///       Past     |              |
+    ///                |              |
+    ///                |              |
+    ///                +--------------+
+    /// ```
+    causal: bool,
 }
 
 impl fmt::Display for WindowFrame {
@@ -58,6 +106,17 @@ impl fmt::Display for WindowFrame {
     }
 }
 
+impl fmt::Debug for WindowFrame {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        write!(
+            f,
+            "WindowFrame {{ units: {:?}, start_bound: {:?}, end_bound: {:?} 
}}",
+            self.units, self.start_bound, self.end_bound
+        )?;
+        Ok(())
+    }
+}
+
 impl TryFrom<ast::WindowFrame> for WindowFrame {
     type Error = DataFusionError;
 
@@ -81,35 +140,40 @@ impl TryFrom<ast::WindowFrame> for WindowFrame {
                 )?
             }
         };
-        Ok(Self {
-            units: value.units.into(),
-            start_bound,
-            end_bound,
-        })
+        let units = value.units.into();
+        Ok(Self::new_bounds(units, start_bound, end_bound))
     }
 }
 
 impl WindowFrame {
-    /// Creates a new, default window frame (with the meaning of default 
depending on whether the
-    /// frame contains an `ORDER BY` clause.
-    pub fn new(has_order_by: bool) -> Self {
-        if has_order_by {
-            // This window frame covers the table (or partition if `PARTITION 
BY` is used)
-            // from beginning to the `CURRENT ROW` (with same rank). It is 
used when the `OVER`
-            // clause contains an `ORDER BY` clause but no frame.
-            WindowFrame {
-                units: WindowFrameUnits::Range,
+    /// Creates a new, default window frame. The meaning of "default" depends
+    /// on whether the frame contains an `ORDER BY` clause and, if so, whether
+    /// this ordering is strict (i.e. has no ties).
+    pub fn new(order_by: Option<bool>) -> Self {
+        if let Some(strict) = order_by {
+            // This window frame covers the table (or partition if `PARTITION 
BY`
+            // is used) from beginning to the `CURRENT ROW` (with same rank). 
It
+            // is used when the `OVER` clause contains an `ORDER BY` clause but
+            // no frame.
+            Self {
+                units: if strict {
+                    WindowFrameUnits::Rows
+                } else {
+                    WindowFrameUnits::Range
+                },
                 start_bound: WindowFrameBound::Preceding(ScalarValue::Null),
                 end_bound: WindowFrameBound::CurrentRow,
+                causal: strict,
             }
         } else {
-            // This window frame covers the whole table (or partition if 
`PARTITION BY` is used).
-            // It is used when the `OVER` clause does not contain an `ORDER 
BY` clause and there is
-            // no frame.
-            WindowFrame {
+            // This window frame covers the whole table (or partition if 
`PARTITION BY`
+            // is used). It is used when the `OVER` clause does not contain an
+            // `ORDER BY` clause and there is no frame.
+            Self {
                 units: WindowFrameUnits::Rows,
                 start_bound: 
WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
                 end_bound: 
WindowFrameBound::Following(ScalarValue::UInt64(None)),
+                causal: false,
             }
         }
     }
@@ -119,27 +183,68 @@ impl WindowFrame {
     /// `2 ROWS PRECEDING AND 3 ROWS FOLLOWING`
     pub fn reverse(&self) -> Self {
         let start_bound = match &self.end_bound {
-            WindowFrameBound::Preceding(elem) => {
-                WindowFrameBound::Following(elem.clone())
+            WindowFrameBound::Preceding(value) => {
+                WindowFrameBound::Following(value.clone())
             }
-            WindowFrameBound::Following(elem) => {
-                WindowFrameBound::Preceding(elem.clone())
+            WindowFrameBound::Following(value) => {
+                WindowFrameBound::Preceding(value.clone())
             }
             WindowFrameBound::CurrentRow => WindowFrameBound::CurrentRow,
         };
         let end_bound = match &self.start_bound {
-            WindowFrameBound::Preceding(elem) => {
-                WindowFrameBound::Following(elem.clone())
+            WindowFrameBound::Preceding(value) => {
+                WindowFrameBound::Following(value.clone())
             }
-            WindowFrameBound::Following(elem) => {
-                WindowFrameBound::Preceding(elem.clone())
+            WindowFrameBound::Following(value) => {
+                WindowFrameBound::Preceding(value.clone())
             }
             WindowFrameBound::CurrentRow => WindowFrameBound::CurrentRow,
         };
-        WindowFrame {
-            units: self.units,
+        Self::new_bounds(self.units, start_bound, end_bound)
+    }
+
+    /// Returns whether the window frame is causal.
+    pub fn is_causal(&self) -> bool {
+        self.causal
+    }
+
+    /// Initializes window frame from units (type), start bound and end bound.
+    pub fn new_bounds(
+        units: WindowFrameUnits,
+        start_bound: WindowFrameBound,
+        end_bound: WindowFrameBound,
+    ) -> Self {
+        let causal = match units {
+            WindowFrameUnits::Rows => match &end_bound {
+                WindowFrameBound::Following(value) => {
+                    if value.is_null() {
+                        // Unbounded following
+                        false
+                    } else {
+                        let zero = ScalarValue::new_zero(&value.data_type());
+                        zero.map(|zero| value.eq(&zero)).unwrap_or(false)
+                    }
+                }
+                _ => true,
+            },
+            WindowFrameUnits::Range | WindowFrameUnits::Groups => match 
&end_bound {
+                WindowFrameBound::Preceding(value) => {
+                    if value.is_null() {
+                        // Unbounded preceding
+                        true
+                    } else {
+                        let zero = ScalarValue::new_zero(&value.data_type());
+                        zero.map(|zero| value.gt(&zero)).unwrap_or(false)
+                    }
+                }
+                _ => false,
+            },
+        };
+        Self {
+            units,
             start_bound,
             end_bound,
+            causal,
         }
     }
 }
diff --git a/datafusion/expr/src/window_state.rs 
b/datafusion/expr/src/window_state.rs
index de88396d9b..d6c5a07385 100644
--- a/datafusion/expr/src/window_state.rs
+++ b/datafusion/expr/src/window_state.rs
@@ -682,11 +682,11 @@ mod tests {
 
     #[test]
     fn test_window_frame_group_boundaries() -> Result<()> {
-        let window_frame = Arc::new(WindowFrame {
-            units: WindowFrameUnits::Groups,
-            start_bound: 
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))),
-            end_bound: 
WindowFrameBound::Following(ScalarValue::UInt64(Some(1))),
-        });
+        let window_frame = Arc::new(WindowFrame::new_bounds(
+            WindowFrameUnits::Groups,
+            WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))),
+            WindowFrameBound::Following(ScalarValue::UInt64(Some(1))),
+        ));
         let expected_results = vec![
             (Range { start: 0, end: 2 }, 0),
             (Range { start: 0, end: 4 }, 1),
@@ -703,11 +703,11 @@ mod tests {
 
     #[test]
     fn test_window_frame_group_boundaries_both_following() -> Result<()> {
-        let window_frame = Arc::new(WindowFrame {
-            units: WindowFrameUnits::Groups,
-            start_bound: 
WindowFrameBound::Following(ScalarValue::UInt64(Some(1))),
-            end_bound: 
WindowFrameBound::Following(ScalarValue::UInt64(Some(2))),
-        });
+        let window_frame = Arc::new(WindowFrame::new_bounds(
+            WindowFrameUnits::Groups,
+            WindowFrameBound::Following(ScalarValue::UInt64(Some(1))),
+            WindowFrameBound::Following(ScalarValue::UInt64(Some(2))),
+        ));
         let expected_results = vec![
             (Range::<usize> { start: 1, end: 4 }, 0),
             (Range::<usize> { start: 2, end: 5 }, 1),
@@ -724,11 +724,11 @@ mod tests {
 
     #[test]
     fn test_window_frame_group_boundaries_both_preceding() -> Result<()> {
-        let window_frame = Arc::new(WindowFrame {
-            units: WindowFrameUnits::Groups,
-            start_bound: 
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))),
-            end_bound: 
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))),
-        });
+        let window_frame = Arc::new(WindowFrame::new_bounds(
+            WindowFrameUnits::Groups,
+            WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))),
+            WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))),
+        ));
         let expected_results = vec![
             (Range::<usize> { start: 0, end: 0 }, 0),
             (Range::<usize> { start: 0, end: 1 }, 1),
diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs 
b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
index 953716713e..35a8597832 100644
--- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
+++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
@@ -346,13 +346,11 @@ mod tests {
                 vec![wildcard()],
                 vec![],
                 vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))],
-                WindowFrame {
-                    units: WindowFrameUnits::Range,
-                    start_bound: 
WindowFrameBound::Preceding(ScalarValue::UInt32(Some(
-                        6,
-                    ))),
-                    end_bound: 
WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
-                },
+                WindowFrame::new_bounds(
+                    WindowFrameUnits::Range,
+                    WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
+                    WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
+                ),
             ))])?
             .project(vec![count(wildcard())])?
             .build()?;
diff --git a/datafusion/optimizer/src/push_down_projection.rs 
b/datafusion/optimizer/src/push_down_projection.rs
index 4ee4f7e417..6a003ecb5f 100644
--- a/datafusion/optimizer/src/push_down_projection.rs
+++ b/datafusion/optimizer/src/push_down_projection.rs
@@ -586,7 +586,7 @@ mod tests {
             vec![col("test.a")],
             vec![col("test.b")],
             vec![],
-            WindowFrame::new(false),
+            WindowFrame::new(None),
         ));
 
         let max2 = Expr::WindowFunction(expr::WindowFunction::new(
@@ -594,7 +594,7 @@ mod tests {
             vec![col("test.b")],
             vec![],
             vec![],
-            WindowFrame::new(false),
+            WindowFrame::new(None),
         ));
         let col1 = col(max1.display_name()?);
         let col2 = col(max2.display_name()?);
diff --git a/datafusion/physical-expr/src/window/built_in.rs 
b/datafusion/physical-expr/src/window/built_in.rs
index 665ceb70d6..c3c7400026 100644
--- a/datafusion/physical-expr/src/window/built_in.rs
+++ b/datafusion/physical-expr/src/window/built_in.rs
@@ -206,6 +206,7 @@ impl WindowExpr for BuiltInWindowExpr {
             let record_batch = &partition_batch_state.record_batch;
             let num_rows = record_batch.num_rows();
             let mut row_wise_results: Vec<ScalarValue> = vec![];
+            let mut is_causal = self.window_frame.is_causal();
             for idx in state.last_calculated_index..num_rows {
                 let frame_range = if evaluator.uses_window_frame() {
                     state
@@ -224,11 +225,15 @@ impl WindowExpr for BuiltInWindowExpr {
                             idx,
                         )
                 } else {
+                    is_causal = false;
                     evaluator.get_range(idx, num_rows)
                 }?;
 
-                // Exit if the range extends all the way:
-                if frame_range.end == num_rows && 
!partition_batch_state.is_end {
+                // Exit if the range is non-causal and extends all the way:
+                if frame_range.end == num_rows
+                    && !is_causal
+                    && !partition_batch_state.is_end
+                {
                     break;
                 }
                 // Update last range
diff --git a/datafusion/physical-expr/src/window/window_expr.rs 
b/datafusion/physical-expr/src/window/window_expr.rs
index 548fae75bd..e2714dc42b 100644
--- a/datafusion/physical-expr/src/window/window_expr.rs
+++ b/datafusion/physical-expr/src/window/window_expr.rs
@@ -231,12 +231,13 @@ pub trait AggregateWindowExpr: WindowExpr {
         // We iterate on each row to perform a running calculation.
         let length = values[0].len();
         let mut row_wise_results: Vec<ScalarValue> = vec![];
+        let is_causal = self.get_window_frame().is_causal();
         while idx < length {
             // Start search from the last_range. This squeezes searched range.
             let cur_range =
                 window_frame_ctx.calculate_range(&order_bys, last_range, 
length, idx)?;
-            // Exit if the range extends all the way:
-            if cur_range.end == length && not_end {
+            // Exit if the range is non-causal and extends all the way:
+            if cur_range.end == length && !is_causal && not_end {
                 break;
             }
             let value = self.get_aggregate_result_inside_range(
diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs 
b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
index 0871ec0d7f..9d247d689c 100644
--- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
+++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
@@ -1170,33 +1170,33 @@ mod tests {
                 last_value_func,
                 &[],
                 &[],
-                Arc::new(WindowFrame {
-                    units: WindowFrameUnits::Rows,
-                    start_bound: 
WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
-                    end_bound: WindowFrameBound::CurrentRow,
-                }),
+                Arc::new(WindowFrame::new_bounds(
+                    WindowFrameUnits::Rows,
+                    WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
+                    WindowFrameBound::CurrentRow,
+                )),
             )) as _,
             // NTH_VALUE(a, -1)
             Arc::new(BuiltInWindowExpr::new(
                 nth_value_func1,
                 &[],
                 &[],
-                Arc::new(WindowFrame {
-                    units: WindowFrameUnits::Rows,
-                    start_bound: 
WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
-                    end_bound: WindowFrameBound::CurrentRow,
-                }),
+                Arc::new(WindowFrame::new_bounds(
+                    WindowFrameUnits::Rows,
+                    WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
+                    WindowFrameBound::CurrentRow,
+                )),
             )) as _,
             // NTH_VALUE(a, -2)
             Arc::new(BuiltInWindowExpr::new(
                 nth_value_func2,
                 &[],
                 &[],
-                Arc::new(WindowFrame {
-                    units: WindowFrameUnits::Rows,
-                    start_bound: 
WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
-                    end_bound: WindowFrameBound::CurrentRow,
-                }),
+                Arc::new(WindowFrame::new_bounds(
+                    WindowFrameUnits::Rows,
+                    WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
+                    WindowFrameBound::CurrentRow,
+                )),
             )) as _,
         ];
         let physical_plan = BoundedWindowAggExec::try_new(
diff --git a/datafusion/physical-plan/src/windows/mod.rs 
b/datafusion/physical-plan/src/windows/mod.rs
index fec168fabf..a85e5cc31c 100644
--- a/datafusion/physical-plan/src/windows/mod.rs
+++ b/datafusion/physical-plan/src/windows/mod.rs
@@ -654,7 +654,7 @@ mod tests {
                 &[col("a", &schema)?],
                 &[],
                 &[],
-                Arc::new(WindowFrame::new(false)),
+                Arc::new(WindowFrame::new(None)),
                 schema.as_ref(),
             )?],
             blocking_exec,
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs 
b/datafusion/proto/src/logical_plan/from_proto.rs
index d15cf1db92..2d9c7be46b 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -879,11 +879,7 @@ impl TryFrom<protobuf::WindowFrame> for WindowFrame {
             })
             .transpose()?
             .unwrap_or(WindowFrameBound::CurrentRow);
-        Ok(Self {
-            units,
-            start_bound,
-            end_bound,
-        })
+        Ok(WindowFrame::new_bounds(units, start_bound, end_bound))
     }
 }
 
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs 
b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index 03daf535f2..ed21124a9e 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -1671,7 +1671,7 @@ fn roundtrip_window() {
         vec![],
         vec![col("col1")],
         vec![col("col2")],
-        WindowFrame::new(true),
+        WindowFrame::new(Some(false)),
     ));
 
     // 2. with default window_frame
@@ -1682,15 +1682,15 @@ fn roundtrip_window() {
         vec![],
         vec![col("col1")],
         vec![col("col2")],
-        WindowFrame::new(true),
+        WindowFrame::new(Some(false)),
     ));
 
     // 3. with window_frame with row numbers
-    let range_number_frame = WindowFrame {
-        units: WindowFrameUnits::Range,
-        start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))),
-        end_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some(2))),
-    };
+    let range_number_frame = WindowFrame::new_bounds(
+        WindowFrameUnits::Range,
+        WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))),
+        WindowFrameBound::Following(ScalarValue::UInt64(Some(2))),
+    );
 
     let test_expr3 = Expr::WindowFunction(expr::WindowFunction::new(
         WindowFunctionDefinition::BuiltInWindowFunction(
@@ -1703,11 +1703,11 @@ fn roundtrip_window() {
     ));
 
     // 4. test with AggregateFunction
-    let row_number_frame = WindowFrame {
-        units: WindowFrameUnits::Rows,
-        start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))),
-        end_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some(2))),
-    };
+    let row_number_frame = WindowFrame::new_bounds(
+        WindowFrameUnits::Rows,
+        WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))),
+        WindowFrameBound::Following(ScalarValue::UInt64(Some(2))),
+    );
 
     let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new(
         WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs 
b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 9ee8d0d51d..3a13dc887f 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -255,11 +255,11 @@ fn roundtrip_window() -> Result<()> {
     let field_b = Field::new("b", DataType::Int64, false);
     let schema = Arc::new(Schema::new(vec![field_a, field_b]));
 
-    let window_frame = WindowFrame {
-        units: datafusion_expr::WindowFrameUnits::Range,
-        start_bound: WindowFrameBound::Preceding(ScalarValue::Int64(None)),
-        end_bound: WindowFrameBound::CurrentRow,
-    };
+    let window_frame = WindowFrame::new_bounds(
+        datafusion_expr::WindowFrameUnits::Range,
+        WindowFrameBound::Preceding(ScalarValue::Int64(None)),
+        WindowFrameBound::CurrentRow,
+    );
 
     let builtin_window_expr = Arc::new(BuiltInWindowExpr::new(
         Arc::new(NthValue::first(
@@ -286,14 +286,14 @@ fn roundtrip_window() -> Result<()> {
         )),
         &[],
         &[],
-        Arc::new(WindowFrame::new(false)),
+        Arc::new(WindowFrame::new(None)),
     ));
 
-    let window_frame = WindowFrame {
-        units: datafusion_expr::WindowFrameUnits::Range,
-        start_bound: WindowFrameBound::CurrentRow,
-        end_bound: WindowFrameBound::Preceding(ScalarValue::Int64(None)),
-    };
+    let window_frame = WindowFrame::new_bounds(
+        datafusion_expr::WindowFrameUnits::Range,
+        WindowFrameBound::CurrentRow,
+        WindowFrameBound::Preceding(ScalarValue::Int64(None)),
+    );
 
     let sliding_aggr_window_expr = Arc::new(SlidingAggregateWindowExpr::new(
         Arc::new(Sum::new(
diff --git a/datafusion/sql/src/expr/function.rs 
b/datafusion/sql/src/expr/function.rs
index 395f10b6f7..30f8605c39 100644
--- a/datafusion/sql/src/expr/function.rs
+++ b/datafusion/sql/src/expr/function.rs
@@ -17,7 +17,8 @@
 
 use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
 use datafusion_common::{
-    not_impl_err, plan_datafusion_err, plan_err, DFSchema, DataFusionError, 
Result,
+    not_impl_err, plan_datafusion_err, plan_err, DFSchema, DataFusionError, 
Dependency,
+    Result,
 };
 use datafusion_expr::expr::ScalarFunction;
 use datafusion_expr::function::suggest_valid_function;
@@ -102,6 +103,22 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 // Numeric literals in window function ORDER BY are treated as 
constants
                 false,
             )?;
+
+            let func_deps = schema.functional_dependencies();
+            // Find whether ties are possible in the given ordering:
+            let is_ordering_strict = order_by.iter().any(|orderby_expr| {
+                if let Expr::Sort(sort_expr) = orderby_expr {
+                    if let Expr::Column(col) = sort_expr.expr.as_ref() {
+                        let idx = schema.index_of_column(col).unwrap();
+                        return func_deps.iter().any(|dep| {
+                            dep.source_indices == vec![idx]
+                                && dep.mode == Dependency::Single
+                        });
+                    }
+                }
+                false
+            });
+
             let window_frame = window
                 .window_frame
                 .as_ref()
@@ -115,8 +132,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             let window_frame = if let Some(window_frame) = window_frame {
                 regularize_window_order_by(&window_frame, &mut order_by)?;
                 window_frame
+            } else if is_ordering_strict {
+                WindowFrame::new(Some(true))
             } else {
-                WindowFrame::new(!order_by.is_empty())
+                WindowFrame::new((!order_by.is_empty()).then_some(false))
             };
 
             if let Ok(fun) = self.find_window_func(&name) {
diff --git a/datafusion/sqllogictest/test_files/window.slt 
b/datafusion/sqllogictest/test_files/window.slt
index 100c214383..f8337e21d7 100644
--- a/datafusion/sqllogictest/test_files/window.slt
+++ b/datafusion/sqllogictest/test_files/window.slt
@@ -3871,3 +3871,38 @@ FROM (SELECT c1, c2, ROW_NUMBER() OVER(PARTITION BY c1) 
as rn
     LIMIT 5)
 GROUP BY rn
 ORDER BY rn;
+
+# create a table for testing
+statement ok
+CREATE TABLE table_with_pk (
+          sn INT PRIMARY KEY,
+          ts TIMESTAMP,
+          currency VARCHAR(3),
+          amount FLOAT
+        ) as VALUES
+          (0, '2022-01-01 06:00:00'::timestamp, 'EUR', 30.0),
+          (1, '2022-01-01 08:00:00'::timestamp, 'EUR', 50.0),
+          (2, '2022-01-01 11:30:00'::timestamp, 'TRY', 75.0),
+          (3, '2022-01-02 12:00:00'::timestamp, 'EUR', 200.0),
+          (4, '2022-01-03 10:00:00'::timestamp, 'TRY', 100.0)
+
+# An OVER clause of the form `OVER (ORDER BY <expr>)` gets treated as if it was
+# `OVER (ORDER BY <expr> RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)`.
+# However, if we know that <expr> contains a unique column (e.g. a PRIMARY 
KEY),
+# it can be treated as `OVER (ORDER BY <expr> ROWS BETWEEN UNBOUNDED PRECEDING
+# AND CURRENT ROW)` where window frame units change from `RANGE` to `ROWS`. 
This
+# conversion makes the window frame manifestly causal by eliminating the 
possibility
+# of ties explicitly (see window frame documentation for a discussion of 
causality
+# in this context). The Query below should have `ROWS` in its window frame.
+query TT
+EXPLAIN SELECT *, SUM(amount) OVER (ORDER BY sn) as sum1 FROM table_with_pk;
+----
+logical_plan
+Projection: table_with_pk.sn, table_with_pk.ts, table_with_pk.currency, 
table_with_pk.amount, SUM(table_with_pk.amount) ORDER BY [table_with_pk.sn ASC 
NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum1
+--WindowAggr: windowExpr=[[SUM(CAST(table_with_pk.amount AS Float64)) ORDER BY 
[table_with_pk.sn ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT 
ROW]]
+----TableScan: table_with_pk projection=[sn, ts, currency, amount]
+physical_plan
+ProjectionExec: expr=[sn@0 as sn, ts@1 as ts, currency@2 as currency, amount@3 
as amount, SUM(table_with_pk.amount) ORDER BY [table_with_pk.sn ASC NULLS LAST] 
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum1]
+--BoundedWindowAggExec: wdw=[SUM(table_with_pk.amount) ORDER BY 
[table_with_pk.sn ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT 
ROW: Ok(Field { name: "SUM(table_with_pk.amount) ORDER BY [table_with_pk.sn ASC 
NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: 
Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), 
frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), 
end_bound: CurrentRow }], mode=[Sorted]
+----SortExec: expr=[sn@0 ASC NULLS LAST]
+------MemoryExec: partitions=1, partition_sizes=[1]
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs 
b/datafusion/substrait/src/logical_plan/consumer.rs
index a4ec3e7722..7687aff2bc 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -973,11 +973,11 @@ pub async fn from_substrait_rex(
                 )
                 .await?,
                 order_by,
-                window_frame: datafusion::logical_expr::WindowFrame {
+                window_frame: 
datafusion::logical_expr::WindowFrame::new_bounds(
                     units,
-                    start_bound: from_substrait_bound(&window.lower_bound, 
true)?,
-                    end_bound: from_substrait_bound(&window.upper_bound, 
false)?,
-                },
+                    from_substrait_bound(&window.lower_bound, true)?,
+                    from_substrait_bound(&window.upper_bound, false)?,
+                ),
             })))
         }
         Some(RexType::Subquery(subquery)) => match 
&subquery.as_ref().subquery_type {


Reply via email to