This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 44127ec96d Fix: handle NULL input in lead/lag window function (#12811)
44127ec96d is described below
commit 44127ec96d7f2be2200af03ad65a2af1b8340b0a
Author: HuSen <[email protected]>
AuthorDate: Thu Oct 17 00:19:25 2024 +0800
Fix: handle NULL input in lead/lag window function (#12811)
---
datafusion/physical-plan/src/windows/mod.rs | 51 +++++++++++++++++++++++----
datafusion/sqllogictest/test_files/window.slt | 50 +++++++++++++++++++++++++-
2 files changed, 94 insertions(+), 7 deletions(-)
diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs
index 6f7d95bf95..e6a773f6b1 100644
--- a/datafusion/physical-plan/src/windows/mod.rs
+++ b/datafusion/physical-plan/src/windows/mod.rs
@@ -217,6 +217,41 @@ fn get_casted_value(
}
}
+/// Rewrites a NULL-typed input expression (1st argument) as a typed NULL
+/// literal of the default value's data type (3rd argument), and returns
+/// that data type as the new return type.
+///
+/// The original expression and `expr_type` are returned unchanged when the
+/// input is not NULL-typed, when no default value is provided, or when a
+/// typed NULL cannot be built for the default's type. A NULL default leaves
+/// the result NULL-typed.
+///
+/// # Panics
+/// Panics if `args` is empty.
+fn rewrite_null_expr_and_data_type(
+    args: &[Arc<dyn PhysicalExpr>],
+    expr_type: &DataType,
+) -> Result<(Arc<dyn PhysicalExpr>, DataType)> {
+    assert!(!args.is_empty());
+    let expr = Arc::clone(&args[0]);
+
+    // Only a NULL-typed input needs rewriting; anything else passes through.
+    if !expr_type.is_null() {
+        return Ok((expr, expr_type.clone()));
+    }
+
+    let rewritten = get_scalar_value_from_args(args, 2)?.and_then(|default_value| {
+        // `data_type()` returns an owned value, so compute it only once.
+        let data_type = default_value.data_type();
+        let typed_null = ScalarValue::try_from(data_type.clone()).ok()?;
+        Some((
+            Arc::new(Literal::new(typed_null)) as Arc<dyn PhysicalExpr>,
+            data_type,
+        ))
+    });
+    Ok(rewritten.unwrap_or_else(|| (expr, expr_type.clone())))
+}
+
fn create_built_in_window_expr(
fun: &BuiltInWindowFunction,
args: &[Arc<dyn PhysicalExpr>],
@@ -252,15 +287,17 @@ fn create_built_in_window_expr(
}
}
BuiltInWindowFunction::Lag => {
- let arg = Arc::clone(&args[0]);
+ // rewrite NULL expression and the return datatype
+ let (arg, out_data_type) =
+ rewrite_null_expr_and_data_type(args, out_data_type)?;
let shift_offset = get_scalar_value_from_args(args, 1)?
.map(get_signed_integer)
.map_or(Ok(None), |v| v.map(Some))?;
let default_value =
- get_casted_value(get_scalar_value_from_args(args, 2)?,
out_data_type)?;
+ get_casted_value(get_scalar_value_from_args(args, 2)?,
&out_data_type)?;
Arc::new(lag(
name,
- out_data_type.clone(),
+ default_value.data_type().clone(),
arg,
shift_offset,
default_value,
@@ -268,15 +305,17 @@ fn create_built_in_window_expr(
))
}
BuiltInWindowFunction::Lead => {
- let arg = Arc::clone(&args[0]);
+ // rewrite NULL expression and the return datatype
+ let (arg, out_data_type) =
+ rewrite_null_expr_and_data_type(args, out_data_type)?;
let shift_offset = get_scalar_value_from_args(args, 1)?
.map(get_signed_integer)
.map_or(Ok(None), |v| v.map(Some))?;
let default_value =
- get_casted_value(get_scalar_value_from_args(args, 2)?,
out_data_type)?;
+ get_casted_value(get_scalar_value_from_args(args, 2)?,
&out_data_type)?;
Arc::new(lead(
name,
- out_data_type.clone(),
+ default_value.data_type().clone(),
arg,
shift_offset,
default_value,
diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt
index 79cb91e183..1b612f9212 100644
--- a/datafusion/sqllogictest/test_files/window.slt
+++ b/datafusion/sqllogictest/test_files/window.slt
@@ -4941,4 +4941,52 @@ NULL
statement ok
DROP TABLE t;
-## end test handle NULL and 0 of NTH_VALUE
\ No newline at end of file
+## end test handle NULL and 0 of NTH_VALUE
+
+## test handle NULL of lead
+
+statement ok
+create table t1(v1 int);
+
+statement ok
+insert into t1 values (1);
+
+query B
+SELECT LEAD(NULL, 0, false) OVER () FROM t1;
+----
+NULL
+
+query B
+SELECT LAG(NULL, 0, false) OVER () FROM t1;
+----
+NULL
+
+query B
+SELECT LEAD(NULL, 1, false) OVER () FROM t1;
+----
+false
+
+query B
+SELECT LAG(NULL, 1, false) OVER () FROM t1;
+----
+false
+
+statement ok
+insert into t1 values (2);
+
+query B
+SELECT LEAD(NULL, 1, false) OVER () FROM t1;
+----
+NULL
+false
+
+query B
+SELECT LAG(NULL, 1, false) OVER () FROM t1;
+----
+false
+NULL
+
+statement ok
+DROP TABLE t1;
+
+## end test handle NULL of lead
\ No newline at end of file
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]