This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 92d046d7c4 Support IGNORE NULLS for FIRST/LAST window function (#9470)
92d046d7c4 is described below

commit 92d046d7c42a9274476f661d33843a8524ae824e
Author: Huaxin Gao <[email protected]>
AuthorDate: Tue Mar 12 02:26:49 2024 -0700

    Support IGNORE NULLS for FIRST/LAST window function (#9470)
    
    * Support IGNORE NULLS for FIRST/LAST window function
    
    * fix error
    
    * fix style error
    
    * fix clippy error
    
    * add tests for all NULL values
    
    * address comments
    
    * fix format
    
    * address comments
    
    * fix format
    
    * Fix commented test case
    
    * resolve conflicts
    
    ---------
    
    Co-authored-by: Huaxin Gao <[email protected]>
    Co-authored-by: Mustafa Akur <[email protected]>
---
 datafusion/physical-expr/src/window/nth_value.rs   |  58 ++++-
 .../src/windows/bounded_window_agg_exec.rs         |  12 +-
 datafusion/physical-plan/src/windows/mod.rs        |  12 +-
 .../proto/tests/cases/roundtrip_physical_plan.rs   |   1 +
 datafusion/sqllogictest/test_files/window.slt      | 272 +++++++++++++++++++++
 5 files changed, 344 insertions(+), 11 deletions(-)

diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs
index a7bb31b6e1..5c7c891f92 100644
--- a/datafusion/physical-expr/src/window/nth_value.rs
+++ b/datafusion/physical-expr/src/window/nth_value.rs
@@ -42,6 +42,7 @@ pub struct NthValue {
     /// Output data type
     data_type: DataType,
     kind: NthValueKind,
+    ignore_nulls: bool,
 }
 
 impl NthValue {
@@ -50,12 +51,14 @@ impl NthValue {
         name: impl Into<String>,
         expr: Arc<dyn PhysicalExpr>,
         data_type: DataType,
+        ignore_nulls: bool,
     ) -> Self {
         Self {
             name: name.into(),
             expr,
             data_type,
             kind: NthValueKind::First,
+            ignore_nulls,
         }
     }
 
@@ -64,12 +67,14 @@ impl NthValue {
         name: impl Into<String>,
         expr: Arc<dyn PhysicalExpr>,
         data_type: DataType,
+        ignore_nulls: bool,
     ) -> Self {
         Self {
             name: name.into(),
             expr,
             data_type,
             kind: NthValueKind::Last,
+            ignore_nulls,
         }
     }
 
@@ -79,7 +84,11 @@ impl NthValue {
         expr: Arc<dyn PhysicalExpr>,
         data_type: DataType,
         n: u32,
+        ignore_nulls: bool,
     ) -> Result<Self> {
+        if ignore_nulls {
+            return exec_err!("NTH_VALUE ignore_nulls is not supported yet");
+        }
         match n {
             0 => exec_err!("NTH_VALUE expects n to be non-zero"),
             _ => Ok(Self {
@@ -87,6 +96,7 @@ impl NthValue {
                 expr,
                 data_type,
                 kind: NthValueKind::Nth(n as i64),
+                ignore_nulls,
             }),
         }
     }
@@ -122,7 +132,10 @@ impl BuiltInWindowFunctionExpr for NthValue {
             finalized_result: None,
             kind: self.kind,
         };
-        Ok(Box::new(NthValueEvaluator { state }))
+        Ok(Box::new(NthValueEvaluator {
+            state,
+            ignore_nulls: self.ignore_nulls,
+        }))
     }
 
     fn reverse_expr(&self) -> Option<Arc<dyn BuiltInWindowFunctionExpr>> {
@@ -136,6 +149,7 @@ impl BuiltInWindowFunctionExpr for NthValue {
             expr: self.expr.clone(),
             data_type: self.data_type.clone(),
             kind: reversed_kind,
+            ignore_nulls: self.ignore_nulls,
         }))
     }
 }
@@ -144,6 +158,7 @@ impl BuiltInWindowFunctionExpr for NthValue {
 #[derive(Debug)]
 pub(crate) struct NthValueEvaluator {
     state: NthValueState,
+    ignore_nulls: bool,
 }
 
 impl PartitionEvaluator for NthValueEvaluator {
@@ -184,7 +199,8 @@ impl PartitionEvaluator for NthValueEvaluator {
                 }
             }
         };
-        if is_prunable {
+        // Do not memoize results when nulls are ignored.
+        if is_prunable && !self.ignore_nulls {
             if self.state.finalized_result.is_none() && !is_reverse_direction {
                 let result = ScalarValue::try_from_array(out, size - 1)?;
                 self.state.finalized_result = Some(result);
@@ -210,9 +226,39 @@ impl PartitionEvaluator for NthValueEvaluator {
                 // We produce None if the window is empty.
                 return ScalarValue::try_from(arr.data_type());
             }
+
+            // Extract valid indices if ignoring nulls.
+            let (slice, valid_indices) = if self.ignore_nulls {
+                let slice = arr.slice(range.start, n_range);
+                let valid_indices =
+                    slice.nulls().unwrap().valid_indices().collect::<Vec<_>>();
+                if valid_indices.is_empty() {
+                    return ScalarValue::try_from(arr.data_type());
+                }
+                (Some(slice), Some(valid_indices))
+            } else {
+                (None, None)
+            };
             match self.state.kind {
-                NthValueKind::First => ScalarValue::try_from_array(arr, 
range.start),
-                NthValueKind::Last => ScalarValue::try_from_array(arr, 
range.end - 1),
+                NthValueKind::First => {
+                    if let Some(slice) = &slice {
+                        let valid_indices = valid_indices.unwrap();
+                        ScalarValue::try_from_array(slice, valid_indices[0])
+                    } else {
+                        ScalarValue::try_from_array(arr, range.start)
+                    }
+                }
+                NthValueKind::Last => {
+                    if let Some(slice) = &slice {
+                        let valid_indices = valid_indices.unwrap();
+                        ScalarValue::try_from_array(
+                            slice,
+                            valid_indices[valid_indices.len() - 1],
+                        )
+                    } else {
+                        ScalarValue::try_from_array(arr, range.end - 1)
+                    }
+                }
                 NthValueKind::Nth(n) => {
                     match n.cmp(&0) {
                         Ordering::Greater => {
@@ -295,6 +341,7 @@ mod tests {
             "first_value".to_owned(),
             Arc::new(Column::new("arr", 0)),
             DataType::Int32,
+            false,
         );
         test_i32_result(first_value, Int32Array::from(vec![1; 8]))?;
         Ok(())
@@ -306,6 +353,7 @@ mod tests {
             "last_value".to_owned(),
             Arc::new(Column::new("arr", 0)),
             DataType::Int32,
+            false,
         );
         test_i32_result(
             last_value,
@@ -330,6 +378,7 @@ mod tests {
             Arc::new(Column::new("arr", 0)),
             DataType::Int32,
             1,
+            false,
         )?;
         test_i32_result(nth_value, Int32Array::from(vec![1; 8]))?;
         Ok(())
@@ -342,6 +391,7 @@ mod tests {
             Arc::new(Column::new("arr", 0)),
             DataType::Int32,
             2,
+            false,
         )?;
         test_i32_result(
             nth_value,
diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
index 4cba571054..0349f8f1ee 100644
--- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
+++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
@@ -1179,15 +1179,19 @@ mod tests {
         .map(|e| Arc::new(e) as Arc<dyn ExecutionPlan>)?;
         let col_a = col("a", &schema)?;
         let nth_value_func1 =
-            NthValue::nth("nth_value(-1)", col_a.clone(), DataType::Int32, 1)?
+            NthValue::nth("nth_value(-1)", col_a.clone(), DataType::Int32, 1, 
false)?
                 .reverse_expr()
                 .unwrap();
         let nth_value_func2 =
-            NthValue::nth("nth_value(-2)", col_a.clone(), DataType::Int32, 2)?
+            NthValue::nth("nth_value(-2)", col_a.clone(), DataType::Int32, 2, 
false)?
                 .reverse_expr()
                 .unwrap();
-        let last_value_func =
-            Arc::new(NthValue::last("last", col_a.clone(), DataType::Int32)) 
as _;
+        let last_value_func = Arc::new(NthValue::last(
+            "last",
+            col_a.clone(),
+            DataType::Int32,
+            false,
+        )) as _;
         let window_exprs = vec![
             // LAST_VALUE(a)
             Arc::new(BuiltInWindowExpr::new(
diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs
index f91b525d60..6712bc855f 100644
--- a/datafusion/physical-plan/src/windows/mod.rs
+++ b/datafusion/physical-plan/src/windows/mod.rs
@@ -250,15 +250,21 @@ fn create_built_in_window_expr(
                 .try_into()
                 .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
             let n: u32 = n as u32;
-            Arc::new(NthValue::nth(name, arg, data_type.clone(), n)?)
+            Arc::new(NthValue::nth(
+                name,
+                arg,
+                data_type.clone(),
+                n,
+                ignore_nulls,
+            )?)
         }
         BuiltInWindowFunction::FirstValue => {
             let arg = args[0].clone();
-            Arc::new(NthValue::first(name, arg, data_type.clone()))
+            Arc::new(NthValue::first(name, arg, data_type.clone(), 
ignore_nulls))
         }
         BuiltInWindowFunction::LastValue => {
             let arg = args[0].clone();
-            Arc::new(NthValue::last(name, arg, data_type.clone()))
+            Arc::new(NthValue::last(name, arg, data_type.clone(), 
ignore_nulls))
         }
     })
 }
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index a3c0b3eccd..004261eff5 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -271,6 +271,7 @@ fn roundtrip_window() -> Result<()> {
             "FIRST_VALUE(a) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE 
BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW",
             col("a", &schema)?,
             DataType::Int64,
+            false,
         )),
         &[col("b", &schema)?],
         &[PhysicalSortExpr {
diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt
index cce67d898d..39c105a4dc 100644
--- a/datafusion/sqllogictest/test_files/window.slt
+++ b/datafusion/sqllogictest/test_files/window.slt
@@ -4307,3 +4307,275 @@ select lag(a) over (order by a ASC NULLS FIRST) as x1
 NULL
 NULL
 NULL
+
+# Test for ignore nulls in FIRST_VALUE
+statement ok
+CREATE TABLE t AS VALUES (null::bigint), (3), (4);
+
+query I
+SELECT FIRST_VALUE(column1) OVER() FROM t;
+----
+NULL
+NULL
+NULL
+
+query I
+SELECT FIRST_VALUE(column1) RESPECT NULLS OVER() FROM t;
+----
+NULL
+NULL
+NULL
+
+query I
+SELECT FIRST_VALUE(column1) IGNORE NULLS OVER() FROM t;
+----
+3
+3
+3
+
+statement ok
+DROP TABLE t;
+
+# Test for ignore nulls with ORDER BY in FIRST_VALUE
+statement ok
+CREATE TABLE t AS VALUES  (3, 4), (4, 3), (null::bigint, 1), (null::bigint, 
2), (5, 5), (6, 6);
+
+query II
+SELECT column1, column2 FROM t ORDER BY column2;
+----
+NULL 1
+NULL 2
+4 3
+3 4
+5 5
+6 6
+
+query II
+SELECT FIRST_VALUE(column1) OVER(ORDER BY column2), column2 FROM t;
+----
+NULL 1
+NULL 2
+NULL 3
+NULL 4
+NULL 5
+NULL 6
+
+query II
+SELECT FIRST_VALUE(column1) RESPECT NULLS OVER(ORDER BY column2), column2 FROM 
t;
+----
+NULL 1
+NULL 2
+NULL 3
+NULL 4
+NULL 5
+NULL 6
+
+query II
+SELECT FIRST_VALUE(column1) IGNORE NULLS OVER(ORDER BY column2), column2 FROM 
t;
+----
+NULL 1
+NULL 2
+4 3
+4 4
+4 5
+4 6
+
+query II
+SELECT FIRST_VALUE(column1)OVER(ORDER BY column2 RANGE BETWEEN 1 PRECEDING AND 
1 FOLLOWING), column2 FROM t;
+----
+NULL 1
+NULL 2
+NULL 3
+4 4
+3 5
+5 6
+
+query II
+SELECT FIRST_VALUE(column1) IGNORE NULLS OVER(ORDER BY column2 RANGE BETWEEN 1 
PRECEDING AND 1 FOLLOWING), column2 FROM t;
+----
+NULL 1
+4 2
+4 3
+4 4
+3 5
+5 6
+
+statement ok
+DROP TABLE t;
+
+# Test for ignore nulls with ORDER BY in FIRST_VALUE with all NULL values
+statement ok
+CREATE TABLE t AS VALUES  (null::bigint, 4), (null::bigint, 3), (null::bigint, 
1), (null::bigint, 2);
+
+query II
+SELECT FIRST_VALUE(column1) OVER(ORDER BY column2), column2 FROM t;
+----
+NULL 1
+NULL 2
+NULL 3
+NULL 4
+
+query II
+SELECT FIRST_VALUE(column1) RESPECT NULLS OVER(ORDER BY column2), column2 FROM 
t;
+----
+NULL 1
+NULL 2
+NULL 3
+NULL 4
+
+query II
+SELECT FIRST_VALUE(column1) IGNORE NULLS OVER(ORDER BY column2), column2 FROM 
t;
+----
+NULL 1
+NULL 2
+NULL 3
+NULL 4
+
+statement ok
+DROP TABLE t;
+
+# Test for ignore nulls in LAST_VALUE
+statement ok
+CREATE TABLE t AS VALUES (1), (3), (null::bigint);
+
+query I
+SELECT LAST_VALUE(column1) OVER() FROM t;
+----
+NULL
+NULL
+NULL
+
+query I
+SELECT LAST_VALUE(column1) RESPECT NULLS OVER() FROM t;
+----
+NULL
+NULL
+NULL
+
+query I
+SELECT LAST_VALUE(column1) IGNORE NULLS OVER() FROM t;
+----
+3
+3
+3
+
+statement ok
+DROP TABLE t;
+
+# Test for ignore nulls with ORDER BY in LAST_VALUE
+statement ok
+CREATE TABLE t AS VALUES  (3, 4), (4, 3), (null::bigint, 1), (null::bigint, 
2), (5, 5), (6, 6);
+
+query II
+SELECT column1, column2 FROM t ORDER BY column2 DESC NULLS LAST;
+----
+6 6
+5 5
+3 4
+4 3
+NULL 2
+NULL 1
+
+query II
+SELECT LAST_VALUE(column1) OVER(ORDER BY column2 DESC NULLS LAST), column2 
FROM t;
+----
+6 6
+5 5
+3 4
+4 3
+NULL 2
+NULL 1
+
+query II
+SELECT LAST_VALUE(column1) IGNORE NULLS OVER(ORDER BY column2 DESC NULLS 
LAST), column2 FROM t;
+----
+6 6
+5 5
+3 4
+4 3
+4 2
+4 1
+
+query II
+SELECT LAST_VALUE(column1) OVER(ORDER BY column2 DESC NULLS LAST ROWS BETWEEN 
UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), column2 FROM t;
+----
+NULL 6
+NULL 5
+NULL 4
+NULL 3
+NULL 2
+NULL 1
+
+query II
+SELECT LAST_VALUE(column1) RESPECT NULLS OVER(ORDER BY column2 DESC NULLS LAST 
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), column2 FROM t;
+----
+NULL 6
+NULL 5
+NULL 4
+NULL 3
+NULL 2
+NULL 1
+
+query II
+SELECT LAST_VALUE(column1) IGNORE NULLS OVER(ORDER BY column2 DESC NULLS LAST 
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), column2 FROM t;
+----
+4 6
+4 5
+4 4
+4 3
+4 2
+4 1
+
+query II
+SELECT LAST_VALUE(column1) OVER(ORDER BY column2 DESC NULLS LAST RANGE BETWEEN 
1 PRECEDING AND 1 FOLLOWING), column2 FROM t;
+----
+5 6
+3 5
+4 4
+NULL 3
+NULL 2
+NULL 1
+
+query II
+SELECT LAST_VALUE(column1) IGNORE NULLS OVER(ORDER BY column2 DESC NULLS LAST 
RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING), column2 FROM t;
+----
+5 6
+3 5
+4 4
+4 3
+4 2
+NULL 1
+
+statement ok
+DROP TABLE t;
+
+# Test for ignore nulls with ORDER BY in LAST_VALUE with all NULLs
+statement ok
+CREATE TABLE t AS VALUES  (null::bigint, 4), (null::bigint, 3), (null::bigint, 
1), (null::bigint, 2);
+
+query II
+SELECT LAST_VALUE(column1) OVER(ORDER BY column2 DESC NULLS LAST ROWS BETWEEN 
UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), column2 FROM t;
+----
+NULL 4
+NULL 3
+NULL 2
+NULL 1
+
+query II
+SELECT LAST_VALUE(column1) RESPECT NULLS OVER(ORDER BY column2 DESC NULLS LAST 
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), column2 FROM t;
+----
+NULL 4
+NULL 3
+NULL 2
+NULL 1
+
+query II
+SELECT LAST_VALUE(column1) IGNORE NULLS OVER(ORDER BY column2 DESC NULLS LAST 
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), column2 FROM t;
+----
+NULL 4
+NULL 3
+NULL 2
+NULL 1
+
+statement ok
+DROP TABLE t;

Reply via email to