This is an automated email from the ASF dual-hosted git repository.

comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 0c0fce36cb LEAD/LAG calculate default value once (#9485)
0c0fce36cb is described below

commit 0c0fce36cb0a85c9c9b73571bdfbef0179791722
Author: comphead <[email protected]>
AuthorDate: Thu Mar 7 08:08:33 2024 -0800

    LEAD/LAG calculate default value once (#9485)
    
    * LEAD/LAG calculate default value once
    
    * refmt
---
 datafusion/physical-expr/src/window/lead_lag.rs | 67 +++++++------------------
 datafusion/physical-plan/src/windows/mod.rs     | 17 ++++++-
 datafusion/proto/src/physical_plan/to_proto.rs  |  8 +--
 3 files changed, 38 insertions(+), 54 deletions(-)

diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs
index e496c7343f..da9410cdbe 100644
--- a/datafusion/physical-expr/src/window/lead_lag.rs
+++ b/datafusion/physical-expr/src/window/lead_lag.rs
@@ -21,7 +21,6 @@
 use crate::window::BuiltInWindowFunctionExpr;
 use crate::PhysicalExpr;
 use arrow::array::ArrayRef;
-use arrow::compute::cast;
 use arrow::datatypes::{DataType, Field};
 use arrow_array::Array;
 use datafusion_common::{
@@ -42,7 +41,7 @@ pub struct WindowShift {
     data_type: DataType,
     shift_offset: i64,
     expr: Arc<dyn PhysicalExpr>,
-    default_value: Option<ScalarValue>,
+    default_value: ScalarValue,
     ignore_nulls: bool,
 }
 
@@ -53,7 +52,7 @@ impl WindowShift {
     }
 
     /// Get the default_value for window shift expression.
-    pub fn get_default_value(&self) -> Option<ScalarValue> {
+    pub fn get_default_value(&self) -> ScalarValue {
         self.default_value.clone()
     }
 }
@@ -64,7 +63,7 @@ pub fn lead(
     data_type: DataType,
     expr: Arc<dyn PhysicalExpr>,
     shift_offset: Option<i64>,
-    default_value: Option<ScalarValue>,
+    default_value: ScalarValue,
     ignore_nulls: bool,
 ) -> WindowShift {
     WindowShift {
@@ -83,7 +82,7 @@ pub fn lag(
     data_type: DataType,
     expr: Arc<dyn PhysicalExpr>,
     shift_offset: Option<i64>,
-    default_value: Option<ScalarValue>,
+    default_value: ScalarValue,
     ignore_nulls: bool,
 ) -> WindowShift {
     WindowShift {
@@ -139,7 +138,7 @@ impl BuiltInWindowFunctionExpr for WindowShift {
 #[derive(Debug)]
 pub(crate) struct WindowShiftEvaluator {
     shift_offset: i64,
-    default_value: Option<ScalarValue>,
+    default_value: ScalarValue,
     ignore_nulls: bool,
     // VecDeque contains offset values that between non-null entries
     non_null_offsets: VecDeque<usize>,
@@ -152,29 +151,11 @@ impl WindowShiftEvaluator {
     }
 }
 
-fn create_empty_array(
-    value: Option<&ScalarValue>,
-    data_type: &DataType,
-    size: usize,
-) -> Result<ArrayRef> {
-    use arrow::array::new_null_array;
-    let array = value
-        .as_ref()
-        .map(|scalar| scalar.to_array_of_size(size))
-        .transpose()?
-        .unwrap_or_else(|| new_null_array(data_type, size));
-    if array.data_type() != data_type {
-        cast(&array, data_type).map_err(|e| arrow_datafusion_err!(e))
-    } else {
-        Ok(array)
-    }
-}
-
 // TODO: change the original arrow::compute::kernels::window::shift impl to support an optional default value
 fn shift_with_default_value(
     array: &ArrayRef,
     offset: i64,
-    value: Option<&ScalarValue>,
+    default_value: &ScalarValue,
 ) -> Result<ArrayRef> {
     use arrow::compute::concat;
 
@@ -182,7 +163,7 @@ fn shift_with_default_value(
     if offset == 0 {
         Ok(array.clone())
     } else if offset == i64::MIN || offset.abs() >= value_len {
-        create_empty_array(value, array.data_type(), array.len())
+        default_value.to_array_of_size(value_len as usize)
     } else {
         let slice_offset = (-offset).clamp(0, value_len) as usize;
         let length = array.len() - offset.unsigned_abs() as usize;
@@ -190,7 +171,8 @@ fn shift_with_default_value(
 
         // Generate array with remaining `null` items
         let nulls = offset.unsigned_abs() as usize;
-        let default_values = create_empty_array(value, slice.data_type(), nulls)?;
+        let default_values = default_value.to_array_of_size(nulls)?;
+
         // Concatenate both arrays, add nulls after if shift > 0 else before
         if offset > 0 {
             concat(&[default_values.as_ref(), slice.as_ref()])
@@ -236,9 +218,7 @@ impl PartitionEvaluator for WindowShiftEvaluator {
         values: &[ArrayRef],
         range: &Range<usize>,
     ) -> Result<ScalarValue> {
-        // TODO: do not recalculate default value every call
         let array = &values[0];
-        let dtype = array.data_type();
         let len = array.len();
 
         // LAG mode
@@ -334,10 +314,10 @@ impl PartitionEvaluator for WindowShiftEvaluator {
         // - ignore nulls mode and current value is null and is within window bounds
         // .unwrap() is safe here as there is a none check in front
         #[allow(clippy::unnecessary_unwrap)]
-        if idx.is_none() || (self.ignore_nulls && array.is_null(idx.unwrap())) {
-            get_default_value(self.default_value.as_ref(), dtype)
-        } else {
+        if !(idx.is_none() || (self.ignore_nulls && array.is_null(idx.unwrap()))) {
             ScalarValue::try_from_array(array, idx.unwrap())
+        } else {
+            Ok(self.default_value.clone())
         }
     }
 
@@ -353,7 +333,7 @@ impl PartitionEvaluator for WindowShiftEvaluator {
         }
         // LEAD, LAG window functions take single column, values will have size 1
         let value = &values[0];
-        shift_with_default_value(value, self.shift_offset, self.default_value.as_ref())
+        shift_with_default_value(value, self.shift_offset, &self.default_value)
     }
 
     fn supports_bounded_execution(&self) -> bool {
@@ -361,17 +341,6 @@ impl PartitionEvaluator for WindowShiftEvaluator {
     }
 }
 
-fn get_default_value(
-    default_value: Option<&ScalarValue>,
-    dtype: &DataType,
-) -> Result<ScalarValue> {
-    match default_value {
-        Some(v) if !v.data_type().is_null() => v.cast_to(dtype),
-        // If None or Null datatype
-        _ => ScalarValue::try_from(dtype),
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -400,10 +369,10 @@ mod tests {
         test_i32_result(
             lead(
                 "lead".to_owned(),
-                DataType::Float32,
+                DataType::Int32,
                 Arc::new(Column::new("c3", 0)),
                 None,
-                None,
+                ScalarValue::Null.cast_to(&DataType::Int32)?,
                 false,
             ),
             [
@@ -423,10 +392,10 @@ mod tests {
         test_i32_result(
             lag(
                 "lead".to_owned(),
-                DataType::Float32,
+                DataType::Int32,
                 Arc::new(Column::new("c3", 0)),
                 None,
-                None,
+                ScalarValue::Null.cast_to(&DataType::Int32)?,
                 false,
             ),
             [
@@ -449,7 +418,7 @@ mod tests {
                 DataType::Int32,
                 Arc::new(Column::new("c3", 0)),
                 None,
-                Some(ScalarValue::Int32(Some(100))),
+                ScalarValue::Int32(Some(100)),
                 false,
             ),
             [
diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs
index 54731f0d81..f91b525d60 100644
--- a/datafusion/physical-plan/src/windows/mod.rs
+++ b/datafusion/physical-plan/src/windows/mod.rs
@@ -156,6 +156,17 @@ fn get_scalar_value_from_args(
     })
 }
 
+fn get_casted_value(
+    default_value: Option<ScalarValue>,
+    dtype: &DataType,
+) -> Result<ScalarValue> {
+    match default_value {
+        Some(v) if !v.data_type().is_null() => v.cast_to(dtype),
+        // If None or Null datatype
+        _ => ScalarValue::try_from(dtype),
+    }
+}
+
 fn create_built_in_window_expr(
     fun: &BuiltInWindowFunction,
     args: &[Arc<dyn PhysicalExpr>],
@@ -204,7 +215,8 @@ fn create_built_in_window_expr(
             let shift_offset = get_scalar_value_from_args(args, 1)?
                 .map(|v| v.try_into())
                 .and_then(|v| v.ok());
-            let default_value = get_scalar_value_from_args(args, 2)?;
+            let default_value =
+                get_casted_value(get_scalar_value_from_args(args, 2)?, data_type)?;
             Arc::new(lag(
                 name,
                 data_type.clone(),
@@ -219,7 +231,8 @@ fn create_built_in_window_expr(
             let shift_offset = get_scalar_value_from_args(args, 1)?
                 .map(|v| v.try_into())
                 .and_then(|v| v.ok());
-            let default_value = get_scalar_value_from_args(args, 2)?;
+            let default_value =
+                get_casted_value(get_scalar_value_from_args(args, 2)?, data_type)?;
             Arc::new(lead(
                 name,
                 data_type.clone(),
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs
index ce3df8183d..c464571893 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -167,9 +167,11 @@ impl TryFrom<Arc<dyn WindowExpr>> for protobuf::PhysicalWindowExprNode {
                         window_shift_expr.get_shift_offset(),
                     )))),
                 );
-                if let Some(default_value) = window_shift_expr.get_default_value() {
-                    args.insert(2, Arc::new(Literal::new(default_value)));
-                }
+                args.insert(
+                    2,
+                    Arc::new(Literal::new(window_shift_expr.get_default_value())),
+                );
+
                 if window_shift_expr.get_shift_offset() >= 0 {
                     protobuf::BuiltInWindowFunction::Lag
                 } else {

Reply via email to