This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 92d046d7c4 Support IGNORE NULLS for FIRST/LAST window function (#9470)
92d046d7c4 is described below
commit 92d046d7c42a9274476f661d33843a8524ae824e
Author: Huaxin Gao <[email protected]>
AuthorDate: Tue Mar 12 02:26:49 2024 -0700
Support IGNORE NULLS for FIRST/LAST window function (#9470)
* Support IGNORE NULLS for FIRST/LAST window function
* fix error
* fix style error
* fix clippy error
* add tests for all NULL values
* address comments
* fix format
* address comments
* fix format
* Fix commented test case
* resolve conflicts
---------
Co-authored-by: Huaxin Gao <[email protected]>
Co-authored-by: Mustafa Akur <[email protected]>
---
datafusion/physical-expr/src/window/nth_value.rs | 58 ++++-
.../src/windows/bounded_window_agg_exec.rs | 12 +-
datafusion/physical-plan/src/windows/mod.rs | 12 +-
.../proto/tests/cases/roundtrip_physical_plan.rs | 1 +
datafusion/sqllogictest/test_files/window.slt | 272 +++++++++++++++++++++
5 files changed, 344 insertions(+), 11 deletions(-)
diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs
index a7bb31b6e1..5c7c891f92 100644
--- a/datafusion/physical-expr/src/window/nth_value.rs
+++ b/datafusion/physical-expr/src/window/nth_value.rs
@@ -42,6 +42,7 @@ pub struct NthValue {
/// Output data type
data_type: DataType,
kind: NthValueKind,
+ ignore_nulls: bool,
}
impl NthValue {
@@ -50,12 +51,14 @@ impl NthValue {
name: impl Into<String>,
expr: Arc<dyn PhysicalExpr>,
data_type: DataType,
+ ignore_nulls: bool,
) -> Self {
Self {
name: name.into(),
expr,
data_type,
kind: NthValueKind::First,
+ ignore_nulls,
}
}
@@ -64,12 +67,14 @@ impl NthValue {
name: impl Into<String>,
expr: Arc<dyn PhysicalExpr>,
data_type: DataType,
+ ignore_nulls: bool,
) -> Self {
Self {
name: name.into(),
expr,
data_type,
kind: NthValueKind::Last,
+ ignore_nulls,
}
}
@@ -79,7 +84,11 @@ impl NthValue {
expr: Arc<dyn PhysicalExpr>,
data_type: DataType,
n: u32,
+ ignore_nulls: bool,
) -> Result<Self> {
+ if ignore_nulls {
+ return exec_err!("NTH_VALUE ignore_nulls is not supported yet");
+ }
match n {
0 => exec_err!("NTH_VALUE expects n to be non-zero"),
_ => Ok(Self {
@@ -87,6 +96,7 @@ impl NthValue {
expr,
data_type,
kind: NthValueKind::Nth(n as i64),
+ ignore_nulls,
}),
}
}
@@ -122,7 +132,10 @@ impl BuiltInWindowFunctionExpr for NthValue {
finalized_result: None,
kind: self.kind,
};
- Ok(Box::new(NthValueEvaluator { state }))
+ Ok(Box::new(NthValueEvaluator {
+ state,
+ ignore_nulls: self.ignore_nulls,
+ }))
}
fn reverse_expr(&self) -> Option<Arc<dyn BuiltInWindowFunctionExpr>> {
@@ -136,6 +149,7 @@ impl BuiltInWindowFunctionExpr for NthValue {
expr: self.expr.clone(),
data_type: self.data_type.clone(),
kind: reversed_kind,
+ ignore_nulls: self.ignore_nulls,
}))
}
}
@@ -144,6 +158,7 @@ impl BuiltInWindowFunctionExpr for NthValue {
#[derive(Debug)]
pub(crate) struct NthValueEvaluator {
state: NthValueState,
+ ignore_nulls: bool,
}
impl PartitionEvaluator for NthValueEvaluator {
@@ -184,7 +199,8 @@ impl PartitionEvaluator for NthValueEvaluator {
}
}
};
- if is_prunable {
+ // Do not memoize results when nulls are ignored.
+ if is_prunable && !self.ignore_nulls {
if self.state.finalized_result.is_none() && !is_reverse_direction {
let result = ScalarValue::try_from_array(out, size - 1)?;
self.state.finalized_result = Some(result);
@@ -210,9 +226,39 @@ impl PartitionEvaluator for NthValueEvaluator {
// We produce None if the window is empty.
return ScalarValue::try_from(arr.data_type());
}
+
+ // Extract valid indices if ignoring nulls.
+ let (slice, valid_indices) = if self.ignore_nulls {
+ let slice = arr.slice(range.start, n_range);
+ let valid_indices =
+ slice.nulls().unwrap().valid_indices().collect::<Vec<_>>();
+ if valid_indices.is_empty() {
+ return ScalarValue::try_from(arr.data_type());
+ }
+ (Some(slice), Some(valid_indices))
+ } else {
+ (None, None)
+ };
match self.state.kind {
- NthValueKind::First => ScalarValue::try_from_array(arr,
range.start),
- NthValueKind::Last => ScalarValue::try_from_array(arr,
range.end - 1),
+ NthValueKind::First => {
+ if let Some(slice) = &slice {
+ let valid_indices = valid_indices.unwrap();
+ ScalarValue::try_from_array(slice, valid_indices[0])
+ } else {
+ ScalarValue::try_from_array(arr, range.start)
+ }
+ }
+ NthValueKind::Last => {
+ if let Some(slice) = &slice {
+ let valid_indices = valid_indices.unwrap();
+ ScalarValue::try_from_array(
+ slice,
+ valid_indices[valid_indices.len() - 1],
+ )
+ } else {
+ ScalarValue::try_from_array(arr, range.end - 1)
+ }
+ }
NthValueKind::Nth(n) => {
match n.cmp(&0) {
Ordering::Greater => {
@@ -295,6 +341,7 @@ mod tests {
"first_value".to_owned(),
Arc::new(Column::new("arr", 0)),
DataType::Int32,
+ false,
);
test_i32_result(first_value, Int32Array::from(vec![1; 8]))?;
Ok(())
@@ -306,6 +353,7 @@ mod tests {
"last_value".to_owned(),
Arc::new(Column::new("arr", 0)),
DataType::Int32,
+ false,
);
test_i32_result(
last_value,
@@ -330,6 +378,7 @@ mod tests {
Arc::new(Column::new("arr", 0)),
DataType::Int32,
1,
+ false,
)?;
test_i32_result(nth_value, Int32Array::from(vec![1; 8]))?;
Ok(())
@@ -342,6 +391,7 @@ mod tests {
Arc::new(Column::new("arr", 0)),
DataType::Int32,
2,
+ false,
)?;
test_i32_result(
nth_value,
diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
index 4cba571054..0349f8f1ee 100644
--- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
+++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
@@ -1179,15 +1179,19 @@ mod tests {
.map(|e| Arc::new(e) as Arc<dyn ExecutionPlan>)?;
let col_a = col("a", &schema)?;
let nth_value_func1 =
- NthValue::nth("nth_value(-1)", col_a.clone(), DataType::Int32, 1)?
+ NthValue::nth("nth_value(-1)", col_a.clone(), DataType::Int32, 1,
false)?
.reverse_expr()
.unwrap();
let nth_value_func2 =
- NthValue::nth("nth_value(-2)", col_a.clone(), DataType::Int32, 2)?
+ NthValue::nth("nth_value(-2)", col_a.clone(), DataType::Int32, 2,
false)?
.reverse_expr()
.unwrap();
- let last_value_func =
- Arc::new(NthValue::last("last", col_a.clone(), DataType::Int32))
as _;
+ let last_value_func = Arc::new(NthValue::last(
+ "last",
+ col_a.clone(),
+ DataType::Int32,
+ false,
+ )) as _;
let window_exprs = vec![
// LAST_VALUE(a)
Arc::new(BuiltInWindowExpr::new(
diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs
index f91b525d60..6712bc855f 100644
--- a/datafusion/physical-plan/src/windows/mod.rs
+++ b/datafusion/physical-plan/src/windows/mod.rs
@@ -250,15 +250,21 @@ fn create_built_in_window_expr(
.try_into()
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
let n: u32 = n as u32;
- Arc::new(NthValue::nth(name, arg, data_type.clone(), n)?)
+ Arc::new(NthValue::nth(
+ name,
+ arg,
+ data_type.clone(),
+ n,
+ ignore_nulls,
+ )?)
}
BuiltInWindowFunction::FirstValue => {
let arg = args[0].clone();
- Arc::new(NthValue::first(name, arg, data_type.clone()))
+ Arc::new(NthValue::first(name, arg, data_type.clone(),
ignore_nulls))
}
BuiltInWindowFunction::LastValue => {
let arg = args[0].clone();
- Arc::new(NthValue::last(name, arg, data_type.clone()))
+ Arc::new(NthValue::last(name, arg, data_type.clone(),
ignore_nulls))
}
})
}
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index a3c0b3eccd..004261eff5 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -271,6 +271,7 @@ fn roundtrip_window() -> Result<()> {
"FIRST_VALUE(a) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE
BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW",
col("a", &schema)?,
DataType::Int64,
+ false,
)),
&[col("b", &schema)?],
&[PhysicalSortExpr {
diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt
index cce67d898d..39c105a4dc 100644
--- a/datafusion/sqllogictest/test_files/window.slt
+++ b/datafusion/sqllogictest/test_files/window.slt
@@ -4307,3 +4307,275 @@ select lag(a) over (order by a ASC NULLS FIRST) as x1
NULL
NULL
NULL
+
+# Test for ignore nulls in FIRST_VALUE
+statement ok
+CREATE TABLE t AS VALUES (null::bigint), (3), (4);
+
+query I
+SELECT FIRST_VALUE(column1) OVER() FROM t;
+----
+NULL
+NULL
+NULL
+
+query I
+SELECT FIRST_VALUE(column1) RESPECT NULLS OVER() FROM t;
+----
+NULL
+NULL
+NULL
+
+query I
+SELECT FIRST_VALUE(column1) IGNORE NULLS OVER() FROM t;
+----
+3
+3
+3
+
+statement ok
+DROP TABLE t;
+
+# Test for ignore nulls with ORDER BY in FIRST_VALUE
+statement ok
+CREATE TABLE t AS VALUES (3, 4), (4, 3), (null::bigint, 1), (null::bigint,
2), (5, 5), (6, 6);
+
+query II
+SELECT column1, column2 FROM t ORDER BY column2;
+----
+NULL 1
+NULL 2
+4 3
+3 4
+5 5
+6 6
+
+query II
+SELECT FIRST_VALUE(column1) OVER(ORDER BY column2), column2 FROM t;
+----
+NULL 1
+NULL 2
+NULL 3
+NULL 4
+NULL 5
+NULL 6
+
+query II
+SELECT FIRST_VALUE(column1) RESPECT NULLS OVER(ORDER BY column2), column2 FROM
t;
+----
+NULL 1
+NULL 2
+NULL 3
+NULL 4
+NULL 5
+NULL 6
+
+query II
+SELECT FIRST_VALUE(column1) IGNORE NULLS OVER(ORDER BY column2), column2 FROM
t;
+----
+NULL 1
+NULL 2
+4 3
+4 4
+4 5
+4 6
+
+query II
+SELECT FIRST_VALUE(column1)OVER(ORDER BY column2 RANGE BETWEEN 1 PRECEDING AND
1 FOLLOWING), column2 FROM t;
+----
+NULL 1
+NULL 2
+NULL 3
+4 4
+3 5
+5 6
+
+query II
+SELECT FIRST_VALUE(column1) IGNORE NULLS OVER(ORDER BY column2 RANGE BETWEEN 1
PRECEDING AND 1 FOLLOWING), column2 FROM t;
+----
+NULL 1
+4 2
+4 3
+4 4
+3 5
+5 6
+
+statement ok
+DROP TABLE t;
+
+# Test for ignore nulls with ORDER BY in FIRST_VALUE with all NULL values
+statement ok
+CREATE TABLE t AS VALUES (null::bigint, 4), (null::bigint, 3), (null::bigint,
1), (null::bigint, 2);
+
+query II
+SELECT FIRST_VALUE(column1) OVER(ORDER BY column2), column2 FROM t;
+----
+NULL 1
+NULL 2
+NULL 3
+NULL 4
+
+query II
+SELECT FIRST_VALUE(column1) RESPECT NULLS OVER(ORDER BY column2), column2 FROM
t;
+----
+NULL 1
+NULL 2
+NULL 3
+NULL 4
+
+query II
+SELECT FIRST_VALUE(column1) IGNORE NULLS OVER(ORDER BY column2), column2 FROM
t;
+----
+NULL 1
+NULL 2
+NULL 3
+NULL 4
+
+statement ok
+DROP TABLE t;
+
+# Test for ignore nulls in LAST_VALUE
+statement ok
+CREATE TABLE t AS VALUES (1), (3), (null::bigint);
+
+query I
+SELECT LAST_VALUE(column1) OVER() FROM t;
+----
+NULL
+NULL
+NULL
+
+query I
+SELECT LAST_VALUE(column1) RESPECT NULLS OVER() FROM t;
+----
+NULL
+NULL
+NULL
+
+query I
+SELECT LAST_VALUE(column1) IGNORE NULLS OVER() FROM t;
+----
+3
+3
+3
+
+statement ok
+DROP TABLE t;
+
+# Test for ignore nulls with ORDER BY in LAST_VALUE
+statement ok
+CREATE TABLE t AS VALUES (3, 4), (4, 3), (null::bigint, 1), (null::bigint,
2), (5, 5), (6, 6);
+
+query II
+SELECT column1, column2 FROM t ORDER BY column2 DESC NULLS LAST;
+----
+6 6
+5 5
+3 4
+4 3
+NULL 2
+NULL 1
+
+query II
+SELECT LAST_VALUE(column1) OVER(ORDER BY column2 DESC NULLS LAST), column2
FROM t;
+----
+6 6
+5 5
+3 4
+4 3
+NULL 2
+NULL 1
+
+query II
+SELECT LAST_VALUE(column1) IGNORE NULLS OVER(ORDER BY column2 DESC NULLS
LAST), column2 FROM t;
+----
+6 6
+5 5
+3 4
+4 3
+4 2
+4 1
+
+query II
+SELECT LAST_VALUE(column1) OVER(ORDER BY column2 DESC NULLS LAST ROWS BETWEEN
UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), column2 FROM t;
+----
+NULL 6
+NULL 5
+NULL 4
+NULL 3
+NULL 2
+NULL 1
+
+query II
+SELECT LAST_VALUE(column1) RESPECT NULLS OVER(ORDER BY column2 DESC NULLS LAST
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), column2 FROM t;
+----
+NULL 6
+NULL 5
+NULL 4
+NULL 3
+NULL 2
+NULL 1
+
+query II
+SELECT LAST_VALUE(column1) IGNORE NULLS OVER(ORDER BY column2 DESC NULLS LAST
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), column2 FROM t;
+----
+4 6
+4 5
+4 4
+4 3
+4 2
+4 1
+
+query II
+SELECT LAST_VALUE(column1) OVER(ORDER BY column2 DESC NULLS LAST RANGE BETWEEN
1 PRECEDING AND 1 FOLLOWING), column2 FROM t;
+----
+5 6
+3 5
+4 4
+NULL 3
+NULL 2
+NULL 1
+
+query II
+SELECT LAST_VALUE(column1) IGNORE NULLS OVER(ORDER BY column2 DESC NULLS LAST
RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING), column2 FROM t;
+----
+5 6
+3 5
+4 4
+4 3
+4 2
+NULL 1
+
+statement ok
+DROP TABLE t;
+
+# Test for ignore nulls with ORDER BY in LAST_VALUE with all NULLs
+statement ok
+CREATE TABLE t AS VALUES (null::bigint, 4), (null::bigint, 3), (null::bigint,
1), (null::bigint, 2);
+
+query II
+SELECT LAST_VALUE(column1) OVER(ORDER BY column2 DESC NULLS LAST ROWS BETWEEN
UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), column2 FROM t;
+----
+NULL 4
+NULL 3
+NULL 2
+NULL 1
+
+query II
+SELECT LAST_VALUE(column1) RESPECT NULLS OVER(ORDER BY column2 DESC NULLS LAST
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), column2 FROM t;
+----
+NULL 4
+NULL 3
+NULL 2
+NULL 1
+
+query II
+SELECT LAST_VALUE(column1) IGNORE NULLS OVER(ORDER BY column2 DESC NULLS LAST
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), column2 FROM t;
+----
+NULL 4
+NULL 3
+NULL 2
+NULL 1
+
+statement ok
+DROP TABLE t;