jcsherin commented on code in PR #12857: URL: https://github.com/apache/datafusion/pull/12857#discussion_r1797116039
########## datafusion/functions-window/src/lead_lag.rs: ########## @@ -15,125 +15,200 @@ // specific language governing permissions and limitations // under the License. -//! Defines physical expression for `lead` and `lag` that can evaluated -//! at runtime during query execution -use crate::window::BuiltInWindowFunctionExpr; -use crate::PhysicalExpr; -use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field}; -use arrow_array::Array; +//! `lead` and `lag` window function implementations + +use crate::utils::{get_casted_value, get_scalar_value_from_args, get_signed_integer}; +use datafusion_common::arrow::array::ArrayRef; +use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::Field; use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; -use datafusion_expr::PartitionEvaluator; +use datafusion_expr::{ + Literal, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature, Volatility, + WindowUDFImpl, +}; +use datafusion_functions_window_common::field::WindowUDFFieldArgs; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use std::any::Any; use std::cmp::min; use std::collections::VecDeque; use std::ops::{Neg, Range}; use std::sync::Arc; -/// window shift expression +get_or_init_udwf!( + Lag, + lag, + "Returns the row value that precedes the current row by a specified \ + offset within partition. If no such row exists, then returns the \ + default value.", + WindowShift::lag +); +get_or_init_udwf!( + Lead, + lead, + "Returns the value from a row that follows the current row by a \ + specified offset within the partition. 
If no such row exists, then \ + returns the default value.", + WindowShift::lead +); + +/// Create an expression to represent the `lag` window function +/// +/// returns value evaluated at the row that is offset rows before the current row within the partition; +/// if there is no such row, instead return default (which must be of the same type as value). +/// Both offset and default are evaluated with respect to the current row. +/// If omitted, offset defaults to 1 and default to null +pub fn lag( + arg: datafusion_expr::Expr, + shift_offset: Option<i64>, + default_value: Option<ScalarValue>, +) -> datafusion_expr::Expr { + let shift_offset_lit = shift_offset + .map(|v| v.lit()) + .unwrap_or(ScalarValue::Null.lit()); + let default_lit = default_value.unwrap_or(ScalarValue::Null).lit(); + + lag_udwf().call(vec![arg, shift_offset_lit, default_lit]) +} + +/// Create an expression to represent the `lead` window function +/// +/// returns value evaluated at the row that is offset rows after the current row within the partition; +/// if there is no such row, instead return default (which must be of the same type as value). +/// Both offset and default are evaluated with respect to the current row. 
+/// If omitted, offset defaults to 1 and default to null +pub fn lead( + arg: datafusion_expr::Expr, + shift_offset: Option<i64>, + default_value: Option<ScalarValue>, +) -> datafusion_expr::Expr { + let shift_offset_lit = shift_offset + .map(|v| v.lit()) + .unwrap_or(ScalarValue::Null.lit()); + let default_lit = default_value.unwrap_or(ScalarValue::Null).lit(); + + lead_udwf().call(vec![arg, shift_offset_lit, default_lit]) +} + #[derive(Debug)] -pub struct WindowShift { - name: String, - /// Output data type - data_type: DataType, - shift_offset: i64, - expr: Arc<dyn PhysicalExpr>, - default_value: ScalarValue, - ignore_nulls: bool, +enum WindowShiftKind { + Lag, + Lead, } -impl WindowShift { - /// Get shift_offset of window shift expression - pub fn get_shift_offset(&self) -> i64 { - self.shift_offset +impl WindowShiftKind { + fn name(&self) -> &'static str { + match self { + WindowShiftKind::Lag => "lag", + WindowShiftKind::Lead => "lead", + } } - /// Get the default_value for window shift expression. - pub fn get_default_value(&self) -> ScalarValue { - self.default_value.clone() + /// In [`WindowShiftEvaluator`] a positive offset is used to signal + /// computation of `lag()`. So here we negate the input offset + /// value for computing `lead()`. 
+ fn shift_offset(&self, value: Option<i64>) -> i64 { + match self { + WindowShiftKind::Lag => value.unwrap_or(1), + WindowShiftKind::Lead => value.map(|v| v.neg()).unwrap_or(-1), + } } } -/// lead() window function -pub fn lead( - name: String, - data_type: DataType, - expr: Arc<dyn PhysicalExpr>, - shift_offset: Option<i64>, - default_value: ScalarValue, - ignore_nulls: bool, -) -> WindowShift { - WindowShift { - name, - data_type, - shift_offset: shift_offset.map(|v| v.neg()).unwrap_or(-1), - expr, - default_value, - ignore_nulls, - } +/// window shift expression +#[derive(Debug)] +pub struct WindowShift { + signature: Signature, + kind: WindowShiftKind, } -/// lag() window function -pub fn lag( - name: String, - data_type: DataType, - expr: Arc<dyn PhysicalExpr>, - shift_offset: Option<i64>, - default_value: ScalarValue, - ignore_nulls: bool, -) -> WindowShift { - WindowShift { - name, - data_type, - shift_offset: shift_offset.unwrap_or(1), - expr, - default_value, - ignore_nulls, +impl WindowShift { + fn new(kind: WindowShiftKind) -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Any(1), + TypeSignature::Any(2), + TypeSignature::Any(3), + ], + Volatility::Immutable, + ), + kind, + } + } + + pub fn lag() -> Self { + Self::new(WindowShiftKind::Lag) + } + + pub fn lead() -> Self { + Self::new(WindowShiftKind::Lead) } } -impl BuiltInWindowFunctionExpr for WindowShift { - /// Return a reference to Any that can be used for downcasting +impl WindowUDFImpl for WindowShift { fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result<Field> { - let nullable = true; - Ok(Field::new(&self.name, self.data_type.clone(), nullable)) + fn name(&self) -> &str { + self.kind.name() } - fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> { - vec![Arc::clone(&self.expr)] + fn signature(&self) -> &Signature { + &self.signature } - fn name(&self) -> &str { - &self.name - } + fn partition_evaluator( + &self, + partition_evaluator_args: 
PartitionEvaluatorArgs, + ) -> Result<Box<dyn PartitionEvaluator>> { + let shift_offset = + get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1)? + .map(get_signed_integer) + .map_or(Ok(None), |v| v.map(Some)) + .map(|n| self.kind.shift_offset(n)) + .map(|offset| { + if partition_evaluator_args.is_reversed() { + -offset + } else { + offset + } + })?; + let return_type = partition_evaluator_args + .input_types() + .first() + .unwrap_or(&DataType::Null); + let default_value = get_casted_value( + get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 2)?, + return_type, + )?; Review Comment: Parsing of the `shift_offset` and `default_value` arguments remains the same as in the built-in window functions. https://github.com/apache/datafusion/blob/a8d3fae21d12a9caafad77a11b8a22bcfe9555fd/datafusion/physical-plan/src/windows/mod.rs#L254-L285 -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For additional commands, e-mail: github-h...@datafusion.apache.org