This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 44127ec96d Fix: handle NULL input in lead/lag window function (#12811)
44127ec96d is described below

commit 44127ec96d7f2be2200af03ad65a2af1b8340b0a
Author: HuSen <[email protected]>
AuthorDate: Thu Oct 17 00:19:25 2024 +0800

    Fix: handle NULL input in lead/lag window function (#12811)
---
 datafusion/physical-plan/src/windows/mod.rs   | 51 +++++++++++++++++++++++----
 datafusion/sqllogictest/test_files/window.slt | 50 +++++++++++++++++++++++++-
 2 files changed, 94 insertions(+), 7 deletions(-)

diff --git a/datafusion/physical-plan/src/windows/mod.rs 
b/datafusion/physical-plan/src/windows/mod.rs
index 6f7d95bf95..e6a773f6b1 100644
--- a/datafusion/physical-plan/src/windows/mod.rs
+++ b/datafusion/physical-plan/src/windows/mod.rs
@@ -217,6 +217,41 @@ fn get_casted_value(
     }
 }
 
+/// Rewrites a NULL input expression (1st argument) as a NULL literal of
+/// the default value's (3rd argument) data type, and returns that data
+/// type as the new return type. A NULL default therefore keeps `Null`.
+///
+/// If the input is not NULL, no default value is provided, or the
+/// default's data type cannot be converted into a `ScalarValue`, the
+/// original expression (1st argument) and return type are returned
+/// unchanged. Panics if `args` is empty.
+fn rewrite_null_expr_and_data_type(
+    args: &[Arc<dyn PhysicalExpr>],
+    expr_type: &DataType,
+) -> Result<(Arc<dyn PhysicalExpr>, DataType)> {
+    assert!(!args.is_empty()); // panics if called with no arguments
+    let expr = Arc::clone(&args[0]);
+
+    // The input expression and the return type are left unchanged
+    // when the input expression is not NULL.
+    if !expr_type.is_null() {
+        return Ok((expr, expr_type.clone()));
+    }
+    // NULL input: adopt the data type of the default value, if one exists.
+    get_scalar_value_from_args(args, 2)?
+        .and_then(|value| {
+            ScalarValue::try_from(value.data_type().clone())
+                .map(|sv| {
+                    Ok((
+                        Arc::new(Literal::new(sv)) as Arc<dyn PhysicalExpr>,
+                        value.data_type().clone(),
+                    ))
+                })
+                .ok() // conversion error: fall back to the unchanged input
+        })
+        .unwrap_or(Ok((expr, expr_type.clone())))
+}
+
 fn create_built_in_window_expr(
     fun: &BuiltInWindowFunction,
     args: &[Arc<dyn PhysicalExpr>],
@@ -252,15 +287,17 @@ fn create_built_in_window_expr(
             }
         }
         BuiltInWindowFunction::Lag => {
-            let arg = Arc::clone(&args[0]);
+            // rewrite NULL expression and the return datatype
+            let (arg, out_data_type) =
+                rewrite_null_expr_and_data_type(args, out_data_type)?;
             let shift_offset = get_scalar_value_from_args(args, 1)?
                 .map(get_signed_integer)
                 .map_or(Ok(None), |v| v.map(Some))?;
             let default_value =
-                get_casted_value(get_scalar_value_from_args(args, 2)?, 
out_data_type)?;
+                get_casted_value(get_scalar_value_from_args(args, 2)?, 
&out_data_type)?;
             Arc::new(lag(
                 name,
-                out_data_type.clone(),
+                default_value.data_type().clone(),
                 arg,
                 shift_offset,
                 default_value,
@@ -268,15 +305,17 @@ fn create_built_in_window_expr(
             ))
         }
         BuiltInWindowFunction::Lead => {
-            let arg = Arc::clone(&args[0]);
+            // rewrite NULL expression and the return datatype
+            let (arg, out_data_type) =
+                rewrite_null_expr_and_data_type(args, out_data_type)?;
             let shift_offset = get_scalar_value_from_args(args, 1)?
                 .map(get_signed_integer)
                 .map_or(Ok(None), |v| v.map(Some))?;
             let default_value =
-                get_casted_value(get_scalar_value_from_args(args, 2)?, 
out_data_type)?;
+                get_casted_value(get_scalar_value_from_args(args, 2)?, 
&out_data_type)?;
             Arc::new(lead(
                 name,
-                out_data_type.clone(),
+                default_value.data_type().clone(),
                 arg,
                 shift_offset,
                 default_value,
diff --git a/datafusion/sqllogictest/test_files/window.slt 
b/datafusion/sqllogictest/test_files/window.slt
index 79cb91e183..1b612f9212 100644
--- a/datafusion/sqllogictest/test_files/window.slt
+++ b/datafusion/sqllogictest/test_files/window.slt
@@ -4941,4 +4941,52 @@ NULL
 statement ok
 DROP TABLE t;
 
-## end test handle NULL and 0 of NTH_VALUE
\ No newline at end of file
+## end test handle NULL and 0 of NTH_VALUE
+
+## test handle NULL of lead
+
+statement ok
+create table t1(v1 int);
+
+statement ok
+insert into t1 values (1);
+
+query B
+SELECT LEAD(NULL, 0, false) OVER () FROM t1;
+----
+NULL
+
+query B
+SELECT LAG(NULL, 0, false) OVER () FROM t1;
+----
+NULL
+
+query B
+SELECT LEAD(NULL, 1, false) OVER () FROM t1;
+----
+false
+
+query B
+SELECT LAG(NULL, 1, false) OVER () FROM t1;
+----
+false
+
+statement ok
+insert into t1 values (2);
+
+query B
+SELECT LEAD(NULL, 1, false) OVER () FROM t1;
+----
+NULL
+false
+
+query B
+SELECT LAG(NULL, 1, false) OVER () FROM t1;
+----
+false
+NULL
+
+statement ok
+DROP TABLE t1;
+
+## end test handle NULL of lead
\ No newline at end of file


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to