This is an automated email from the ASF dual-hosted git repository.
akurmustafa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 8d3504cbe7 Implement IGNORE NULLS for LAST_VALUE (#9801)
8d3504cbe7 is described below
commit 8d3504cbe7823d29116db237f4a4cb0bb6a5989d
Author: Huaxin Gao <[email protected]>
AuthorDate: Tue Mar 26 23:21:08 2024 -0700
Implement IGNORE NULLS for LAST_VALUE (#9801)
* Implement IGNORE NULLS for LAST_VALUE
* address comments
---------
Co-authored-by: Huaxin Gao <[email protected]>
---
datafusion/core/tests/sql/aggregates.rs | 80 ----------------------
datafusion/physical-expr/src/aggregate/build_in.rs | 17 +++--
.../physical-expr/src/aggregate/first_last.rs | 49 +++++++++++--
datafusion/sqllogictest/test_files/aggregate.slt | 52 ++++++++++++++
4 files changed, 104 insertions(+), 94 deletions(-)
diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs
index 14bc7a3d4f..84b791a3de 100644
--- a/datafusion/core/tests/sql/aggregates.rs
+++ b/datafusion/core/tests/sql/aggregates.rs
@@ -321,83 +321,3 @@ async fn test_accumulator_row_accumulator() -> Result<()> {
Ok(())
}
-
-#[tokio::test]
-async fn test_first_value() -> Result<()> {
- let session_ctx = SessionContext::new();
- session_ctx
- .sql("CREATE TABLE abc AS VALUES (null,2,3), (4,5,6)")
- .await?
- .collect()
- .await?;
-
- let results1 = session_ctx
- .sql("SELECT FIRST_VALUE(column1) ignore nulls FROM abc")
- .await?
- .collect()
- .await?;
- let expected1 = [
- "+--------------------------+",
- "| FIRST_VALUE(abc.column1) |",
- "+--------------------------+",
- "| 4 |",
- "+--------------------------+",
- ];
- assert_batches_eq!(expected1, &results1);
-
- let results2 = session_ctx
- .sql("SELECT FIRST_VALUE(column1) respect nulls FROM abc")
- .await?
- .collect()
- .await?;
- let expected2 = [
- "+--------------------------+",
- "| FIRST_VALUE(abc.column1) |",
- "+--------------------------+",
- "| |",
- "+--------------------------+",
- ];
- assert_batches_eq!(expected2, &results2);
-
- Ok(())
-}
-
-#[tokio::test]
-async fn test_first_value_with_sort() -> Result<()> {
- let session_ctx = SessionContext::new();
- session_ctx
- .sql("CREATE TABLE abc AS VALUES (null,2,3), (null,1,6), (4, 5, 5), (1, 4, 7), (2, 3, 8)")
- .await?
- .collect()
- .await?;
-
- let results1 = session_ctx
- .sql("SELECT FIRST_VALUE(column1 ORDER BY column2) ignore nulls FROM abc")
- .await?
- .collect()
- .await?;
- let expected1 = [
- "+--------------------------+",
- "| FIRST_VALUE(abc.column1) |",
- "+--------------------------+",
- "| 2 |",
- "+--------------------------+",
- ];
- assert_batches_eq!(expected1, &results1);
-
- let results2 = session_ctx
- .sql("SELECT FIRST_VALUE(column1 ORDER BY column2) respect nulls FROM abc")
- .await?
- .collect()
- .await?;
- let expected2 = [
- "+--------------------------+",
- "| FIRST_VALUE(abc.column1) |",
- "+--------------------------+",
- "| |",
- "+--------------------------+",
- ];
- assert_batches_eq!(expected2, &results2);
-
- Ok(())
-}
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs
index 846431034c..cee6798638 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -370,13 +370,16 @@ pub fn create_aggregate_expr(
)
.with_ignore_nulls(ignore_nulls),
),
- (AggregateFunction::LastValue, _) => Arc::new(expressions::LastValue::new(
- input_phy_exprs[0].clone(),
- name,
- input_phy_types[0].clone(),
- ordering_req.to_vec(),
- ordering_types,
- )),
+ (AggregateFunction::LastValue, _) => Arc::new(
+ expressions::LastValue::new(
+ input_phy_exprs[0].clone(),
+ name,
+ input_phy_types[0].clone(),
+ ordering_req.to_vec(),
+ ordering_types,
+ )
+ .with_ignore_nulls(ignore_nulls),
+ ),
(AggregateFunction::NthValue, _) => {
let expr = &input_phy_exprs[0];
let Some(n) = input_phy_exprs[1]
diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs
index 17dd3ef120..6d6e32a149 100644
--- a/datafusion/physical-expr/src/aggregate/first_last.rs
+++ b/datafusion/physical-expr/src/aggregate/first_last.rs
@@ -393,6 +393,7 @@ pub struct LastValue {
expr: Arc<dyn PhysicalExpr>,
ordering_req: LexOrdering,
requirement_satisfied: bool,
+ ignore_nulls: bool,
}
impl LastValue {
@@ -412,9 +413,15 @@ impl LastValue {
expr,
ordering_req,
requirement_satisfied,
+ ignore_nulls: false,
}
}
+ pub fn with_ignore_nulls(mut self, ignore_nulls: bool) -> Self {
+ self.ignore_nulls = ignore_nulls;
+ self
+ }
+
/// Returns the name of the aggregate expression.
pub fn name(&self) -> &str {
&self.name
@@ -483,6 +490,7 @@ impl AggregateExpr for LastValue {
&self.input_data_type,
&self.order_by_data_types,
self.ordering_req.clone(),
+ self.ignore_nulls,
)
.map(|acc| {
Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _
@@ -528,6 +536,7 @@ impl AggregateExpr for LastValue {
&self.input_data_type,
&self.order_by_data_types,
self.ordering_req.clone(),
+ self.ignore_nulls,
)
.map(|acc| {
Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _
@@ -561,6 +570,8 @@ struct LastValueAccumulator {
ordering_req: LexOrdering,
// Stores whether incoming data already satisfies the ordering requirement.
requirement_satisfied: bool,
+ // Ignore null values.
+ ignore_nulls: bool,
}
impl LastValueAccumulator {
@@ -569,6 +580,7 @@ impl LastValueAccumulator {
data_type: &DataType,
ordering_dtypes: &[DataType],
ordering_req: LexOrdering,
+ ignore_nulls: bool,
) -> Result<Self> {
let orderings = ordering_dtypes
.iter()
@@ -581,6 +593,7 @@ impl LastValueAccumulator {
orderings,
ordering_req,
requirement_satisfied,
+ ignore_nulls,
})
}
@@ -597,7 +610,17 @@ impl LastValueAccumulator {
};
if self.requirement_satisfied {
// Get last entry according to the order of data:
- return Ok((!value.is_empty()).then_some(value.len() - 1));
+ if self.ignore_nulls {
+ // If ignoring nulls, find the last non-null value.
+ for i in (0..value.len()).rev() {
+ if !value.is_null(i) {
+ return Ok(Some(i));
+ }
+ }
+ return Ok(None);
+ } else {
+ return Ok((!value.is_empty()).then_some(value.len() - 1));
+ }
}
let sort_columns = ordering_values
.iter()
@@ -611,8 +634,20 @@ impl LastValueAccumulator {
}
})
.collect::<Vec<_>>();
- let indices = lexsort_to_indices(&sort_columns, Some(1))?;
- Ok((!indices.is_empty()).then_some(indices.value(0) as _))
+
+ if self.ignore_nulls {
+ let indices = lexsort_to_indices(&sort_columns, None)?;
+ // If ignoring nulls, find the last non-null value.
+ for index in indices.iter().flatten() {
+ if !value.is_null(index as usize) {
+ return Ok(Some(index as usize));
+ }
+ }
+ Ok(None)
+ } else {
+ let indices = lexsort_to_indices(&sort_columns, Some(1))?;
+ Ok((!indices.is_empty()).then_some(indices.value(0) as _))
+ }
}
fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self {
@@ -746,7 +781,7 @@ mod tests {
let mut first_accumulator =
FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?;
let mut last_accumulator =
- LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
+ LastValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?;
// first value in the tuple is start of the range (inclusive),
// second value in the tuple is end of the range (exclusive)
let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)];
@@ -814,13 +849,13 @@ mod tests {
// LastValueAccumulator
let mut last_accumulator =
- LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
+ LastValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?;
last_accumulator.update_batch(&[arrs[0].clone()])?;
let state1 = last_accumulator.state()?;
let mut last_accumulator =
- LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
+ LastValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?;
last_accumulator.update_batch(&[arrs[1].clone()])?;
let state2 = last_accumulator.state()?;
@@ -836,7 +871,7 @@ mod tests {
}
let mut last_accumulator =
- LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
+ LastValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?;
last_accumulator.merge_batch(&states)?;
let merged_state = last_accumulator.state()?;
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index 19bcf6024b..4929ab485d 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -3376,3 +3376,55 @@ SELECT FIRST_VALUE(column1 ORDER BY column2) IGNORE NULLS FROM t;
statement ok
DROP TABLE t;
+
+# Test for ignore null in LAST_VALUE
+statement ok
+CREATE TABLE t AS VALUES (3), (4), (null::bigint);
+
+query I
+SELECT LAST_VALUE(column1) FROM t;
+----
+NULL
+
+query I
+SELECT LAST_VALUE(column1) RESPECT NULLS FROM t;
+----
+NULL
+
+query I
+SELECT LAST_VALUE(column1) IGNORE NULLS FROM t;
+----
+4
+
+statement ok
+DROP TABLE t;
+
+# Test for ignore null with ORDER BY in LAST_VALUE
+statement ok
+CREATE TABLE t AS VALUES (3, 3), (4, 4), (null::bigint, 1), (null::bigint, 2);
+
+query I
+SELECT column1 FROM t ORDER BY column2 DESC;
+----
+4
+3
+NULL
+NULL
+
+query I
+SELECT LAST_VALUE(column1 ORDER BY column2 DESC) FROM t;
+----
+NULL
+
+query I
+SELECT LAST_VALUE(column1 ORDER BY column2 DESC) RESPECT NULLS FROM t;
+----
+NULL
+
+query I
+SELECT LAST_VALUE(column1 ORDER BY column2 DESC) IGNORE NULLS FROM t;
+----
+3
+
+statement ok
+DROP TABLE t;