This is an automated email from the ASF dual-hosted git repository.

akurmustafa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 8d3504cbe7 Implement IGNORE NULLS for LAST_VALUE (#9801)
8d3504cbe7 is described below

commit 8d3504cbe7823d29116db237f4a4cb0bb6a5989d
Author: Huaxin Gao <[email protected]>
AuthorDate: Tue Mar 26 23:21:08 2024 -0700

    Implement IGNORE NULLS for LAST_VALUE (#9801)
    
    * Implement IGNORE NULLS for LAST_VALUE
    
    * address comments
    
    ---------
    
    Co-authored-by: Huaxin Gao <[email protected]>
---
 datafusion/core/tests/sql/aggregates.rs            | 80 ----------------------
 datafusion/physical-expr/src/aggregate/build_in.rs | 17 +++--
 .../physical-expr/src/aggregate/first_last.rs      | 49 +++++++++++--
 datafusion/sqllogictest/test_files/aggregate.slt   | 52 ++++++++++++++
 4 files changed, 104 insertions(+), 94 deletions(-)

diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs
index 14bc7a3d4f..84b791a3de 100644
--- a/datafusion/core/tests/sql/aggregates.rs
+++ b/datafusion/core/tests/sql/aggregates.rs
@@ -321,83 +321,3 @@ async fn test_accumulator_row_accumulator() -> Result<()> {
 
     Ok(())
 }
-
-#[tokio::test]
-async fn test_first_value() -> Result<()> {
-    let session_ctx = SessionContext::new();
-    session_ctx
-        .sql("CREATE TABLE abc AS VALUES (null,2,3), (4,5,6)")
-        .await?
-        .collect()
-        .await?;
-
-    let results1 = session_ctx
-        .sql("SELECT FIRST_VALUE(column1) ignore nulls FROM abc")
-        .await?
-        .collect()
-        .await?;
-    let expected1 = [
-        "+--------------------------+",
-        "| FIRST_VALUE(abc.column1) |",
-        "+--------------------------+",
-        "| 4                        |",
-        "+--------------------------+",
-    ];
-    assert_batches_eq!(expected1, &results1);
-
-    let results2 = session_ctx
-        .sql("SELECT FIRST_VALUE(column1) respect nulls FROM abc")
-        .await?
-        .collect()
-        .await?;
-    let expected2 = [
-        "+--------------------------+",
-        "| FIRST_VALUE(abc.column1) |",
-        "+--------------------------+",
-        "|                          |",
-        "+--------------------------+",
-    ];
-    assert_batches_eq!(expected2, &results2);
-
-    Ok(())
-}
-
-#[tokio::test]
-async fn test_first_value_with_sort() -> Result<()> {
-    let session_ctx = SessionContext::new();
-    session_ctx
-        .sql("CREATE TABLE abc AS VALUES (null,2,3), (null,1,6), (4, 5, 5), (1, 4, 7), (2, 3, 8)")
-        .await?
-        .collect()
-        .await?;
-
-    let results1 = session_ctx
-        .sql("SELECT FIRST_VALUE(column1 ORDER BY column2) ignore nulls FROM abc")
-        .await?
-        .collect()
-        .await?;
-    let expected1 = [
-        "+--------------------------+",
-        "| FIRST_VALUE(abc.column1) |",
-        "+--------------------------+",
-        "| 2                        |",
-        "+--------------------------+",
-    ];
-    assert_batches_eq!(expected1, &results1);
-
-    let results2 = session_ctx
-        .sql("SELECT FIRST_VALUE(column1 ORDER BY column2) respect nulls FROM abc")
-        .await?
-        .collect()
-        .await?;
-    let expected2 = [
-        "+--------------------------+",
-        "| FIRST_VALUE(abc.column1) |",
-        "+--------------------------+",
-        "|                          |",
-        "+--------------------------+",
-    ];
-    assert_batches_eq!(expected2, &results2);
-
-    Ok(())
-}
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs
index 846431034c..cee6798638 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -370,13 +370,16 @@ pub fn create_aggregate_expr(
             )
             .with_ignore_nulls(ignore_nulls),
         ),
-        (AggregateFunction::LastValue, _) => Arc::new(expressions::LastValue::new(
-            input_phy_exprs[0].clone(),
-            name,
-            input_phy_types[0].clone(),
-            ordering_req.to_vec(),
-            ordering_types,
-        )),
+        (AggregateFunction::LastValue, _) => Arc::new(
+            expressions::LastValue::new(
+                input_phy_exprs[0].clone(),
+                name,
+                input_phy_types[0].clone(),
+                ordering_req.to_vec(),
+                ordering_types,
+            )
+            .with_ignore_nulls(ignore_nulls),
+        ),
         (AggregateFunction::NthValue, _) => {
             let expr = &input_phy_exprs[0];
             let Some(n) = input_phy_exprs[1]
diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs
index 17dd3ef120..6d6e32a149 100644
--- a/datafusion/physical-expr/src/aggregate/first_last.rs
+++ b/datafusion/physical-expr/src/aggregate/first_last.rs
@@ -393,6 +393,7 @@ pub struct LastValue {
     expr: Arc<dyn PhysicalExpr>,
     ordering_req: LexOrdering,
     requirement_satisfied: bool,
+    ignore_nulls: bool,
 }
 
 impl LastValue {
@@ -412,9 +413,15 @@ impl LastValue {
             expr,
             ordering_req,
             requirement_satisfied,
+            ignore_nulls: false,
         }
     }
 
+    pub fn with_ignore_nulls(mut self, ignore_nulls: bool) -> Self {
+        self.ignore_nulls = ignore_nulls;
+        self
+    }
+
     /// Returns the name of the aggregate expression.
     pub fn name(&self) -> &str {
         &self.name
@@ -483,6 +490,7 @@ impl AggregateExpr for LastValue {
             &self.input_data_type,
             &self.order_by_data_types,
             self.ordering_req.clone(),
+            self.ignore_nulls,
         )
         .map(|acc| {
             Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _
@@ -528,6 +536,7 @@ impl AggregateExpr for LastValue {
             &self.input_data_type,
             &self.order_by_data_types,
             self.ordering_req.clone(),
+            self.ignore_nulls,
         )
         .map(|acc| {
             Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _
@@ -561,6 +570,8 @@ struct LastValueAccumulator {
     ordering_req: LexOrdering,
     // Stores whether incoming data already satisfies the ordering requirement.
     requirement_satisfied: bool,
+    // Ignore null values.
+    ignore_nulls: bool,
 }
 
 impl LastValueAccumulator {
@@ -569,6 +580,7 @@ impl LastValueAccumulator {
         data_type: &DataType,
         ordering_dtypes: &[DataType],
         ordering_req: LexOrdering,
+        ignore_nulls: bool,
     ) -> Result<Self> {
         let orderings = ordering_dtypes
             .iter()
@@ -581,6 +593,7 @@ impl LastValueAccumulator {
             orderings,
             ordering_req,
             requirement_satisfied,
+            ignore_nulls,
         })
     }
 
@@ -597,7 +610,17 @@ impl LastValueAccumulator {
         };
         if self.requirement_satisfied {
             // Get last entry according to the order of data:
-            return Ok((!value.is_empty()).then_some(value.len() - 1));
+            if self.ignore_nulls {
+                // If ignoring nulls, find the last non-null value.
+                for i in (0..value.len()).rev() {
+                    if !value.is_null(i) {
+                        return Ok(Some(i));
+                    }
+                }
+                return Ok(None);
+            } else {
+                return Ok((!value.is_empty()).then_some(value.len() - 1));
+            }
         }
         let sort_columns = ordering_values
             .iter()
@@ -611,8 +634,20 @@ impl LastValueAccumulator {
                 }
             })
             .collect::<Vec<_>>();
-        let indices = lexsort_to_indices(&sort_columns, Some(1))?;
-        Ok((!indices.is_empty()).then_some(indices.value(0) as _))
+
+        if self.ignore_nulls {
+            let indices = lexsort_to_indices(&sort_columns, None)?;
+            // If ignoring nulls, find the last non-null value.
+            for index in indices.iter().flatten() {
+                if !value.is_null(index as usize) {
+                    return Ok(Some(index as usize));
+                }
+            }
+            Ok(None)
+        } else {
+            let indices = lexsort_to_indices(&sort_columns, Some(1))?;
+            Ok((!indices.is_empty()).then_some(indices.value(0) as _))
+        }
     }
 
     fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self {
@@ -746,7 +781,7 @@ mod tests {
         let mut first_accumulator =
             FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?;
         let mut last_accumulator =
-            LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
+            LastValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?;
         // first value in the tuple is start of the range (inclusive),
         // second value in the tuple is end of the range (exclusive)
         let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)];
@@ -814,13 +849,13 @@ mod tests {
 
         // LastValueAccumulator
         let mut last_accumulator =
-            LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
+            LastValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?;
 
         last_accumulator.update_batch(&[arrs[0].clone()])?;
         let state1 = last_accumulator.state()?;
 
         let mut last_accumulator =
-            LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
+            LastValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?;
         last_accumulator.update_batch(&[arrs[1].clone()])?;
         let state2 = last_accumulator.state()?;
 
@@ -836,7 +871,7 @@ mod tests {
         }
 
         let mut last_accumulator =
-            LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
+            LastValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?;
         last_accumulator.merge_batch(&states)?;
 
         let merged_state = last_accumulator.state()?;
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index 19bcf6024b..4929ab485d 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -3376,3 +3376,55 @@ SELECT FIRST_VALUE(column1 ORDER BY column2) IGNORE NULLS FROM t;
 
 statement ok
 DROP TABLE t;
+
+# Test for ignore null in LAST_VALUE
+statement ok
+CREATE TABLE t AS VALUES (3), (4), (null::bigint);
+
+query I
+SELECT LAST_VALUE(column1) FROM t;
+----
+NULL
+
+query I
+SELECT LAST_VALUE(column1) RESPECT NULLS FROM t;
+----
+NULL
+
+query I
+SELECT LAST_VALUE(column1) IGNORE NULLS FROM t;
+----
+4
+
+statement ok
+DROP TABLE t;
+
+# Test for ignore null with ORDER BY in LAST_VALUE
+statement ok
+CREATE TABLE t AS VALUES  (3, 3), (4, 4), (null::bigint, 1), (null::bigint, 2);
+
+query I
+SELECT column1 FROM t ORDER BY column2 DESC;
+----
+4
+3
+NULL
+NULL
+
+query I
+SELECT LAST_VALUE(column1 ORDER BY column2 DESC) FROM t;
+----
+NULL
+
+query I
+SELECT LAST_VALUE(column1 ORDER BY column2 DESC) RESPECT NULLS FROM t;
+----
+NULL
+
+query I
+SELECT LAST_VALUE(column1 ORDER BY column2 DESC) IGNORE NULLS FROM t;
+----
+3
+
+statement ok
+DROP TABLE t;

Reply via email to