This is an automated email from the ASF dual-hosted git repository.
comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 0c0fce36cb LEAD/LAG calculate default value once (#9485)
0c0fce36cb is described below
commit 0c0fce36cb0a85c9c9b73571bdfbef0179791722
Author: comphead <[email protected]>
AuthorDate: Thu Mar 7 08:08:33 2024 -0800
LEAD/LAG calculate default value once (#9485)
* LEAD/LAG calculate default value once
* refmt
---
datafusion/physical-expr/src/window/lead_lag.rs | 67 +++++++------------------
datafusion/physical-plan/src/windows/mod.rs | 17 ++++++-
datafusion/proto/src/physical_plan/to_proto.rs | 8 +--
3 files changed, 38 insertions(+), 54 deletions(-)
diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs
index e496c7343f..da9410cdbe 100644
--- a/datafusion/physical-expr/src/window/lead_lag.rs
+++ b/datafusion/physical-expr/src/window/lead_lag.rs
@@ -21,7 +21,6 @@
use crate::window::BuiltInWindowFunctionExpr;
use crate::PhysicalExpr;
use arrow::array::ArrayRef;
-use arrow::compute::cast;
use arrow::datatypes::{DataType, Field};
use arrow_array::Array;
use datafusion_common::{
@@ -42,7 +41,7 @@ pub struct WindowShift {
data_type: DataType,
shift_offset: i64,
expr: Arc<dyn PhysicalExpr>,
- default_value: Option<ScalarValue>,
+ default_value: ScalarValue,
ignore_nulls: bool,
}
@@ -53,7 +52,7 @@ impl WindowShift {
}
/// Get the default_value for window shift expression.
- pub fn get_default_value(&self) -> Option<ScalarValue> {
+ pub fn get_default_value(&self) -> ScalarValue {
self.default_value.clone()
}
}
@@ -64,7 +63,7 @@ pub fn lead(
data_type: DataType,
expr: Arc<dyn PhysicalExpr>,
shift_offset: Option<i64>,
- default_value: Option<ScalarValue>,
+ default_value: ScalarValue,
ignore_nulls: bool,
) -> WindowShift {
WindowShift {
@@ -83,7 +82,7 @@ pub fn lag(
data_type: DataType,
expr: Arc<dyn PhysicalExpr>,
shift_offset: Option<i64>,
- default_value: Option<ScalarValue>,
+ default_value: ScalarValue,
ignore_nulls: bool,
) -> WindowShift {
WindowShift {
@@ -139,7 +138,7 @@ impl BuiltInWindowFunctionExpr for WindowShift {
#[derive(Debug)]
pub(crate) struct WindowShiftEvaluator {
shift_offset: i64,
- default_value: Option<ScalarValue>,
+ default_value: ScalarValue,
ignore_nulls: bool,
// VecDeque contains offset values that between non-null entries
non_null_offsets: VecDeque<usize>,
@@ -152,29 +151,11 @@ impl WindowShiftEvaluator {
}
}
-fn create_empty_array(
- value: Option<&ScalarValue>,
- data_type: &DataType,
- size: usize,
-) -> Result<ArrayRef> {
- use arrow::array::new_null_array;
- let array = value
- .as_ref()
- .map(|scalar| scalar.to_array_of_size(size))
- .transpose()?
- .unwrap_or_else(|| new_null_array(data_type, size));
- if array.data_type() != data_type {
- cast(&array, data_type).map_err(|e| arrow_datafusion_err!(e))
- } else {
- Ok(array)
- }
-}
-
// TODO: change the original arrow::compute::kernels::window::shift impl to support an optional default value
fn shift_with_default_value(
array: &ArrayRef,
offset: i64,
- value: Option<&ScalarValue>,
+ default_value: &ScalarValue,
) -> Result<ArrayRef> {
use arrow::compute::concat;
@@ -182,7 +163,7 @@ fn shift_with_default_value(
if offset == 0 {
Ok(array.clone())
} else if offset == i64::MIN || offset.abs() >= value_len {
- create_empty_array(value, array.data_type(), array.len())
+ default_value.to_array_of_size(value_len as usize)
} else {
let slice_offset = (-offset).clamp(0, value_len) as usize;
let length = array.len() - offset.unsigned_abs() as usize;
@@ -190,7 +171,8 @@ fn shift_with_default_value(
// Generate array with remaining `null` items
let nulls = offset.unsigned_abs() as usize;
- let default_values = create_empty_array(value, slice.data_type(), nulls)?;
+ let default_values = default_value.to_array_of_size(nulls)?;
+
// Concatenate both arrays, add nulls after if shift > 0 else before
if offset > 0 {
concat(&[default_values.as_ref(), slice.as_ref()])
@@ -236,9 +218,7 @@ impl PartitionEvaluator for WindowShiftEvaluator {
values: &[ArrayRef],
range: &Range<usize>,
) -> Result<ScalarValue> {
- // TODO: do not recalculate default value every call
let array = &values[0];
- let dtype = array.data_type();
let len = array.len();
// LAG mode
@@ -334,10 +314,10 @@ impl PartitionEvaluator for WindowShiftEvaluator {
// - ignore nulls mode and current value is null and is within window bounds
// .unwrap() is safe here as there is a none check in front
#[allow(clippy::unnecessary_unwrap)]
- if idx.is_none() || (self.ignore_nulls && array.is_null(idx.unwrap())) {
- get_default_value(self.default_value.as_ref(), dtype)
- } else {
+ if !(idx.is_none() || (self.ignore_nulls && array.is_null(idx.unwrap()))) {
ScalarValue::try_from_array(array, idx.unwrap())
+ } else {
+ Ok(self.default_value.clone())
}
}
@@ -353,7 +333,7 @@ impl PartitionEvaluator for WindowShiftEvaluator {
}
// LEAD, LAG window functions take single column, values will have size 1
let value = &values[0];
- shift_with_default_value(value, self.shift_offset, self.default_value.as_ref())
+ shift_with_default_value(value, self.shift_offset, &self.default_value)
}
fn supports_bounded_execution(&self) -> bool {
@@ -361,17 +341,6 @@ impl PartitionEvaluator for WindowShiftEvaluator {
}
}
-fn get_default_value(
- default_value: Option<&ScalarValue>,
- dtype: &DataType,
-) -> Result<ScalarValue> {
- match default_value {
- Some(v) if !v.data_type().is_null() => v.cast_to(dtype),
- // If None or Null datatype
- _ => ScalarValue::try_from(dtype),
- }
-}
-
#[cfg(test)]
mod tests {
use super::*;
@@ -400,10 +369,10 @@ mod tests {
test_i32_result(
lead(
"lead".to_owned(),
- DataType::Float32,
+ DataType::Int32,
Arc::new(Column::new("c3", 0)),
None,
- None,
+ ScalarValue::Null.cast_to(&DataType::Int32)?,
false,
),
[
@@ -423,10 +392,10 @@ mod tests {
test_i32_result(
lag(
"lead".to_owned(),
- DataType::Float32,
+ DataType::Int32,
Arc::new(Column::new("c3", 0)),
None,
- None,
+ ScalarValue::Null.cast_to(&DataType::Int32)?,
false,
),
[
@@ -449,7 +418,7 @@ mod tests {
DataType::Int32,
Arc::new(Column::new("c3", 0)),
None,
- Some(ScalarValue::Int32(Some(100))),
+ ScalarValue::Int32(Some(100)),
false,
),
[
diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs
index 54731f0d81..f91b525d60 100644
--- a/datafusion/physical-plan/src/windows/mod.rs
+++ b/datafusion/physical-plan/src/windows/mod.rs
@@ -156,6 +156,17 @@ fn get_scalar_value_from_args(
})
}
+fn get_casted_value(
+ default_value: Option<ScalarValue>,
+ dtype: &DataType,
+) -> Result<ScalarValue> {
+ match default_value {
+ Some(v) if !v.data_type().is_null() => v.cast_to(dtype),
+ // If None or Null datatype
+ _ => ScalarValue::try_from(dtype),
+ }
+}
+
fn create_built_in_window_expr(
fun: &BuiltInWindowFunction,
args: &[Arc<dyn PhysicalExpr>],
@@ -204,7 +215,8 @@ fn create_built_in_window_expr(
let shift_offset = get_scalar_value_from_args(args, 1)?
.map(|v| v.try_into())
.and_then(|v| v.ok());
- let default_value = get_scalar_value_from_args(args, 2)?;
+ let default_value =
+ get_casted_value(get_scalar_value_from_args(args, 2)?, data_type)?;
Arc::new(lag(
name,
data_type.clone(),
@@ -219,7 +231,8 @@ fn create_built_in_window_expr(
let shift_offset = get_scalar_value_from_args(args, 1)?
.map(|v| v.try_into())
.and_then(|v| v.ok());
- let default_value = get_scalar_value_from_args(args, 2)?;
+ let default_value =
+ get_casted_value(get_scalar_value_from_args(args, 2)?, data_type)?;
Arc::new(lead(
name,
data_type.clone(),
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs
index ce3df8183d..c464571893 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -167,9 +167,11 @@ impl TryFrom<Arc<dyn WindowExpr>> for protobuf::PhysicalWindowExprNode {
window_shift_expr.get_shift_offset(),
)))),
);
- if let Some(default_value) = window_shift_expr.get_default_value() {
- args.insert(2, Arc::new(Literal::new(default_value)));
- }
+ args.insert(
+ 2,
+ Arc::new(Literal::new(window_shift_expr.get_default_value())),
+ );
+
if window_shift_expr.get_shift_offset() >= 0 {
protobuf::BuiltInWindowFunction::Lag
} else {