This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 11ac83615f Add window function support (#5653)
11ac83615f is described below

commit 11ac83615f807a68b31c5bced209c0d42ef40b3c
Author: Nuttiiya Seekhao <[email protected]>
AuthorDate: Wed Mar 22 13:35:22 2023 -0400

    Add window function support (#5653)
---
 datafusion/substrait/src/logical_plan/consumer.rs  | 233 +++++++++++++++++----
 datafusion/substrait/src/logical_plan/producer.rs  | 218 ++++++++++++++++++-
 .../substrait/tests/roundtrip_logical_plan.rs      |   5 +
 3 files changed, 408 insertions(+), 48 deletions(-)

diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index b7cacf131d..767c4a3937 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -19,10 +19,11 @@ use async_recursion::async_recursion;
 use datafusion::arrow::datatypes::DataType;
 use datafusion::common::{DFField, DFSchema, DFSchemaRef};
 use datafusion::logical_expr::{
-    aggregate_function, BinaryExpr, Case, Expr, LogicalPlan, Operator,
+    aggregate_function, window_function::find_df_window_func, BinaryExpr, 
Case, Expr,
+    LogicalPlan, Operator,
 };
 use datafusion::logical_expr::{build_join_schema, LogicalPlanBuilder};
-use datafusion::logical_expr::{expr, Cast};
+use datafusion::logical_expr::{expr, Cast, WindowFrameBound, WindowFrameUnits};
 use datafusion::prelude::JoinType;
 use datafusion::sql::TableReference;
 use datafusion::{
@@ -35,7 +36,10 @@ use substrait::proto::{
     aggregate_function::AggregationInvocation,
     expression::{
         field_reference::ReferenceType::DirectReference, literal::LiteralType,
-        reference_segment::ReferenceType::StructField, MaskExpression, RexType,
+        reference_segment::ReferenceType::StructField,
+        window_function::bound as SubstraitBound,
+        window_function::bound::Kind as BoundKind, window_function::Bound,
+        MaskExpression, RexType,
     },
     extensions::simple_extension_declaration::MappingType,
     function_argument::ArgType,
@@ -45,6 +49,7 @@ use substrait::proto::{
     sort_field::{SortDirection, SortKind::*},
     AggregateFunction, Expression, Plan, Rel, Type,
 };
+use substrait::proto::{FunctionArgument, SortField};
 
 use datafusion::logical_expr::expr::Sort;
 use std::collections::HashMap;
@@ -139,13 +144,25 @@ pub async fn from_substrait_rel(
     match &rel.rel_type {
         Some(RelType::Project(p)) => {
             if let Some(input) = p.input.as_ref() {
-                let input = LogicalPlanBuilder::from(
+                let mut input = LogicalPlanBuilder::from(
                     from_substrait_rel(ctx, input, extensions).await?,
                 );
                 let mut exprs: Vec<Expr> = vec![];
                 for e in &p.expressions {
-                    let x = from_substrait_rex(e, input.schema(), 
extensions).await?;
-                    exprs.push(x.as_ref().clone());
+                    let x =
+                        from_substrait_rex(e, input.clone().schema(), 
extensions).await?;
+                    // if the expression is WindowFunction, wrap in a Window 
relation
+                    //   before returning and do not add to list of this 
Projection's expression list
+                    // otherwise, add expression to the Projection's 
expression list
+                    match &*x {
+                        Expr::WindowFunction(_) => {
+                            input = input.window(vec![x.as_ref().clone()])?;
+                            exprs.push(x.as_ref().clone());
+                        }
+                        _ => {
+                            exprs.push(x.as_ref().clone());
+                        }
+                    }
                 }
                 input.project(exprs)?.build()
             } else {
@@ -193,45 +210,8 @@ pub async fn from_substrait_rel(
                 let input = LogicalPlanBuilder::from(
                     from_substrait_rel(ctx, input, extensions).await?,
                 );
-                let mut sorts: Vec<Expr> = vec![];
-                for s in &sort.sorts {
-                    let expr = from_substrait_rex(
-                        s.expr.as_ref().unwrap(),
-                        input.schema(),
-                        extensions,
-                    )
-                    .await?;
-                    let asc_nullfirst = match &s.sort_kind {
-                        Some(k) => match k {
-                            Direction(d) => {
-                                let direction : SortDirection = unsafe {
-                                    ::std::mem::transmute(*d)
-                                };
-                                match direction {
-                                    SortDirection::AscNullsFirst => Ok((true, 
true)),
-                                    SortDirection::AscNullsLast => Ok((true, 
false)),
-                                    SortDirection::DescNullsFirst => 
Ok((false, true)),
-                                    SortDirection::DescNullsLast => Ok((false, 
false)),
-                                    SortDirection::Clustered =>
-                                        
Err(DataFusionError::NotImplemented("Sort with direction clustered is not yet 
supported".to_string()))
-                                    ,
-                                    SortDirection::Unspecified =>
-                                        
Err(DataFusionError::NotImplemented("Unspecified sort direction is 
invalid".to_string()))
-                                }
-                            }
-                            ComparisonFunctionReference(_) => {
-                                Err(DataFusionError::NotImplemented("Sort 
using comparison function reference is not supported".to_string()))
-                            },
-                        },
-                        None => Err(DataFusionError::NotImplemented("Sort 
without sort kind is invalid".to_string()))
-                    };
-                    let (asc, nulls_first) = asc_nullfirst.unwrap();
-                    sorts.push(Expr::Sort(Sort {
-                        expr: Box::new(expr.as_ref().clone()),
-                        asc,
-                        nulls_first,
-                    }));
-                }
+                let sorts =
+                    from_substrait_sorts(&sort.sorts, input.schema(), 
extensions).await?;
                 input.sort(sorts)?.build()
             } else {
                 Err(DataFusionError::NotImplemented(
@@ -452,6 +432,90 @@ fn from_substrait_jointype(join_type: i32) -> Result<JoinType> {
     }
 }
 
+/// Convert Substrait Sorts to DataFusion Exprs
+///
+/// Each `SortField` becomes an `Expr::Sort` whose expression is resolved against
+/// `input_schema`; the output keeps the order of `substrait_sorts`.
+///
+/// # Errors
+///
+/// Returns an error (instead of panicking) when a sort field has no expression,
+/// has no sort kind, carries a direction value that is not a known
+/// `SortDirection`, or uses a direction / comparison function that is not
+/// supported. Errors from converting the sort expression are propagated.
+pub async fn from_substrait_sorts(
+    substrait_sorts: &Vec<SortField>,
+    input_schema: &DFSchema,
+    extensions: &HashMap<u32, &String>,
+) -> Result<Vec<Expr>> {
+    let mut sorts: Vec<Expr> = Vec::with_capacity(substrait_sorts.len());
+    for s in substrait_sorts {
+        let sort_expr = s.expr.as_ref().ok_or_else(|| {
+            DataFusionError::Substrait("Sort field is missing its expression".to_string())
+        })?;
+        let expr = from_substrait_rex(sort_expr, input_schema, extensions).await?;
+        let (asc, nulls_first) = match &s.sort_kind {
+            // The direction arrives as a raw protobuf integer. Decode it with the
+            // checked conversion so an out-of-range value becomes an error rather
+            // than an invalid enum value.
+            Some(Direction(d)) => match SortDirection::from_i32(*d) {
+                Some(SortDirection::AscNullsFirst) => Ok((true, true)),
+                Some(SortDirection::AscNullsLast) => Ok((true, false)),
+                Some(SortDirection::DescNullsFirst) => Ok((false, true)),
+                Some(SortDirection::DescNullsLast) => Ok((false, false)),
+                Some(SortDirection::Clustered) => Err(DataFusionError::NotImplemented(
+                    "Sort with direction clustered is not yet supported".to_string(),
+                )),
+                Some(SortDirection::Unspecified) => Err(DataFusionError::NotImplemented(
+                    "Unspecified sort direction is invalid".to_string(),
+                )),
+                None => Err(DataFusionError::Substrait(format!(
+                    "Unknown sort direction value: {d}"
+                ))),
+            },
+            Some(ComparisonFunctionReference(_)) => Err(DataFusionError::NotImplemented(
+                "Sort using comparison function reference is not supported".to_string(),
+            )),
+            None => Err(DataFusionError::NotImplemented(
+                "Sort without sort kind is invalid".to_string(),
+            )),
+        }?;
+        sorts.push(Expr::Sort(Sort {
+            expr: Box::new(expr.as_ref().clone()),
+            asc,
+            nulls_first,
+        }));
+    }
+    Ok(sorts)
+}
+
+/// Convert a list of Substrait Expressions into DataFusion Exprs, keeping their order.
+///
+/// Every expression is resolved against `input_schema`; the first conversion
+/// failure is returned and the remaining expressions are not converted.
+pub async fn from_substrait_rex_vec(
+    exprs: &Vec<Expression>,
+    input_schema: &DFSchema,
+    extensions: &HashMap<u32, &String>,
+) -> Result<Vec<Expr>> {
+    let mut converted: Vec<Expr> = Vec::new();
+    for substrait_expr in exprs.iter() {
+        let df_expr = from_substrait_rex(substrait_expr, input_schema, extensions).await?;
+        // `from_substrait_rex` hands back a shared pointer; store an owned copy
+        converted.push((*df_expr).clone());
+    }
+    Ok(converted)
+}
+
+/// Convert Substrait FunctionArguments to DataFusion Exprs
+///
+/// Only `ArgType::Value` arguments are supported; any other argument kind, or an
+/// argument with no kind at all, yields a `NotImplemented` error.
+pub async fn from_substriat_func_args(
+    arguments: &Vec<FunctionArgument>,
+    input_schema: &DFSchema,
+    extensions: &HashMap<u32, &String>,
+) -> Result<Vec<Expr>> {
+    let mut converted: Vec<Expr> = Vec::new();
+    for argument in arguments.iter() {
+        let value = if let Some(ArgType::Value(e)) = &argument.arg_type {
+            from_substrait_rex(e, input_schema, extensions).await?
+        } else {
+            return Err(DataFusionError::NotImplemented(
+                "Aggregated function argument non-Value type not supported".to_string(),
+            ));
+        };
+        converted.push((*value).clone());
+    }
+    Ok(converted)
+}
+
 /// Convert Substrait AggregateFunction to DataFusion Expr
 pub async fn from_substrait_agg_func(
     f: &AggregateFunction,
@@ -740,6 +804,47 @@ pub async fn from_substrait_rex(
                 "Cast experssion without output type is not 
allowed".to_string(),
             )),
         },
+        Some(RexType::WindowFunction(window)) => {
+            let order_by =
+                from_substrait_sorts(&window.sorts, input_schema, extensions).await?;
+            // Substrait does not encode WindowFrameUnits so we're using a simple logic to determine the units
+            // If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary
+            // If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row
+            // TODO: Consider the cases where window frame is specified in query and is different from default
+            let units = if order_by.is_empty() {
+                WindowFrameUnits::Rows
+            } else {
+                WindowFrameUnits::Range
+            };
+            // Resolve the function anchor to the name registered in the plan's
+            // extensions, then to a DataFusion window function. Either lookup can
+            // fail for a plan produced elsewhere, so report an error rather than
+            // panic.
+            let function_name =
+                extensions.get(&window.function_reference).ok_or_else(|| {
+                    DataFusionError::NotImplemented(format!(
+                        "Window function not found: function anchor = {:?}",
+                        &window.function_reference
+                    ))
+                })?;
+            let fun = find_df_window_func(function_name).ok_or_else(|| {
+                DataFusionError::NotImplemented(format!(
+                    "Unsupported window function: {function_name}"
+                ))
+            })?;
+            Ok(Arc::new(Expr::WindowFunction(expr::WindowFunction {
+                fun,
+                args: from_substriat_func_args(
+                    &window.arguments,
+                    input_schema,
+                    extensions,
+                )
+                .await?,
+                partition_by: from_substrait_rex_vec(
+                    &window.partitions,
+                    input_schema,
+                    extensions,
+                )
+                .await?,
+                order_by,
+                window_frame: datafusion::logical_expr::WindowFrame {
+                    units,
+                    start_bound: from_substrait_bound(&window.lower_bound, true)?,
+                    end_bound: from_substrait_bound(&window.upper_bound, false)?,
+                },
+            })))
+        }
         _ => Err(DataFusionError::NotImplemented(
             "unsupported rex_type".to_string(),
         )),
@@ -767,6 +872,44 @@ fn from_substrait_type(dt: &substrait::proto::Type) -> Result<DataType> {
     }
 }
 
+/// Convert an optional Substrait window `Bound` into a DataFusion `WindowFrameBound`.
+///
+/// `is_lower` selects the direction used for an unbounded or absent bound:
+/// `Preceding(Null)` when true, `Following(Null)` otherwise. A bound that is
+/// present but has no `kind` is reported as a `Substrait` error.
+fn from_substrait_bound(
+    bound: &Option<Bound>,
+    is_lower: bool,
+) -> Result<WindowFrameBound> {
+    // Shared by the explicit `Unbounded` kind and by a missing bound
+    let unbounded = || {
+        if is_lower {
+            WindowFrameBound::Preceding(ScalarValue::Null)
+        } else {
+            WindowFrameBound::Following(ScalarValue::Null)
+        }
+    };
+    let kind = match bound {
+        None => return Ok(unbounded()),
+        Some(b) => b.kind.as_ref(),
+    };
+    match kind {
+        Some(BoundKind::CurrentRow(SubstraitBound::CurrentRow {})) => {
+            Ok(WindowFrameBound::CurrentRow)
+        }
+        Some(BoundKind::Preceding(SubstraitBound::Preceding { offset })) => {
+            Ok(WindowFrameBound::Preceding(ScalarValue::Int64(Some(*offset))))
+        }
+        Some(BoundKind::Following(SubstraitBound::Following { offset })) => {
+            Ok(WindowFrameBound::Following(ScalarValue::Int64(Some(*offset))))
+        }
+        Some(BoundKind::Unbounded(SubstraitBound::Unbounded {})) => Ok(unbounded()),
+        None => Err(DataFusionError::Substrait(
+            "WindowFunction missing Substrait Bound kind".to_string(),
+        )),
+    }
+}
+
 fn from_substrait_null(null_type: &Type) -> Result<ScalarValue> {
     if let Some(kind) = &null_type.kind {
         match kind {
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index 1de26a3433..ecb322edb7 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -20,6 +20,7 @@ use std::{collections::HashMap, mem, sync::Arc};
 use datafusion::{
     arrow::datatypes::DataType,
     error::{DataFusionError, Result},
+    logical_expr::{WindowFrame, WindowFrameBound},
     prelude::JoinType,
     scalar::ScalarValue,
 };
@@ -27,7 +28,7 @@ use datafusion::{
 use datafusion::common::DFSchemaRef;
 #[allow(unused_imports)]
 use datafusion::logical_expr::aggregate_function;
-use datafusion::logical_expr::expr::{BinaryExpr, Case, Cast, Sort};
+use datafusion::logical_expr::expr::{BinaryExpr, Case, Cast, Sort, 
WindowFunction};
 use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, 
Operator};
 use datafusion::prelude::{binary_expr, Expr};
 use substrait::proto::{
@@ -38,8 +39,12 @@ use substrait::proto::{
         if_then::IfClause,
         literal::{Decimal, LiteralType},
         mask_expression::{StructItem, StructSelect},
-        reference_segment, FieldReference, IfThen, Literal, MaskExpression,
-        ReferenceSegment, RexType, ScalarFunction,
+        reference_segment,
+        window_function::bound as SubstraitBound,
+        window_function::bound::Kind as BoundKind,
+        window_function::Bound,
+        FieldReference, IfThen, Literal, MaskExpression, ReferenceSegment, 
RexType,
+        ScalarFunction, WindowFunction as SubstraitWindowFunction,
     },
     extensions::{
         self,
@@ -301,6 +306,42 @@ pub fn to_substrait_rel(
             // since there is no corresponding relation type in Substrait
             to_substrait_rel(alias.input.as_ref(), extension_info)
         }
+        LogicalPlan::Window(window) => {
+            let input = to_substrait_rel(window.input.as_ref(), 
extension_info)?;
+            // If the input is a Project relation, we can just append the 
WindowFunction expressions
+            // before returning
+            // Otherwise, wrap the input in a Project relation before 
appending the WindowFunction
+            // expressions
+            let mut project_rel: Box<ProjectRel> = match 
&input.as_ref().rel_type {
+                Some(RelType::Project(p)) => Box::new(*p.clone()),
+                _ => {
+                    // Create Projection with field referencing all output 
fields in the input relation
+                    let expressions = (0..window.input.schema().fields().len())
+                        .map(substrait_field_ref)
+                        .collect::<Result<Vec<_>>>()?;
+                    Box::new(ProjectRel {
+                        common: None,
+                        input: Some(input),
+                        expressions,
+                        advanced_extension: None,
+                    })
+                }
+            };
+            // Parse WindowFunction expression
+            let mut window_exprs = vec![];
+            for expr in &window.window_expr {
+                window_exprs.push(to_substrait_rex(
+                    expr,
+                    window.input.schema(),
+                    extension_info,
+                )?);
+            }
+            // Append parsed WindowFunction expressions
+            project_rel.expressions.extend(window_exprs);
+            Ok(Box::new(Rel {
+                rel_type: Some(RelType::Project(project_rel)),
+            }))
+        }
         _ => Err(DataFusionError::NotImplemented(format!(
             "Unsupported operator: {plan:?}"
         ))),
@@ -636,6 +677,47 @@ pub fn to_substrait_rex(
             })
         }
         Expr::Alias(expr, _alias) => to_substrait_rex(expr, schema, 
extension_info),
+        Expr::WindowFunction(WindowFunction {
+            fun,
+            args,
+            partition_by,
+            order_by,
+            window_frame,
+        }) => {
+            // function reference
+            let function_name = fun.to_string().to_lowercase();
+            let function_anchor = _register_function(function_name, 
extension_info);
+            // arguments
+            let mut arguments: Vec<FunctionArgument> = vec![];
+            for arg in args {
+                arguments.push(FunctionArgument {
+                    arg_type: Some(ArgType::Value(to_substrait_rex(
+                        arg,
+                        schema,
+                        extension_info,
+                    )?)),
+                });
+            }
+            // partition by expressions
+            let partition_by = partition_by
+                .iter()
+                .map(|e| to_substrait_rex(e, schema, extension_info))
+                .collect::<Result<Vec<_>>>()?;
+            // order by expressions
+            let order_by = order_by
+                .iter()
+                .map(|e| substrait_sort_field(e, schema, extension_info))
+                .collect::<Result<Vec<_>>>()?;
+            // window frame
+            let bounds = to_substrait_bounds(window_frame)?;
+            Ok(make_substrait_window_function(
+                function_anchor,
+                arguments,
+                partition_by,
+                order_by,
+                bounds,
+            ))
+        }
         _ => Err(DataFusionError::NotImplemented(format!(
             "Unsupported expression: {expr:?}"
         ))),
@@ -693,6 +775,136 @@ fn to_substrait_type(dt: &DataType) -> Result<substrait::proto::Type> {
     }
 }
 
+/// Assemble a Substrait window-function `Expression` from already-converted parts.
+///
+/// `bounds` is the `(lower, upper)` pair of window frame bounds; both are always
+/// written to the resulting message.
+// NOTE(review): `deprecated` is presumably allowed because of the `args` field
+// set below - confirm against the Substrait proto definition.
+#[allow(deprecated)]
+fn make_substrait_window_function(
+    function_reference: u32,
+    arguments: Vec<FunctionArgument>,
+    partitions: Vec<Expression>,
+    sorts: Vec<SortField>,
+    bounds: (Bound, Bound),
+) -> Expression {
+    let (lower, upper) = bounds;
+    let window_function = SubstraitWindowFunction {
+        function_reference,
+        arguments,
+        partitions,
+        sorts,
+        options: vec![],
+        output_type: None,
+        phase: 0,      // default to AGGREGATION_PHASE_UNSPECIFIED
+        invocation: 0, // TODO: fix
+        lower_bound: Some(lower),
+        upper_bound: Some(upper),
+        args: vec![],
+    };
+    Expression {
+        rex_type: Some(RexType::WindowFunction(window_function)),
+    }
+}
+
+/// Extract the integer offset carried by a window frame bound's `ScalarValue`.
+///
+/// Returns `None` for any value that is not a non-null integer (for example
+/// `ScalarValue::Null`); `to_substrait_bound` maps that case to an unbounded frame.
+fn window_frame_offset(value: &ScalarValue) -> Option<i64> {
+    match value {
+        ScalarValue::UInt8(Some(v)) => Some(i64::from(*v)),
+        ScalarValue::UInt16(Some(v)) => Some(i64::from(*v)),
+        ScalarValue::UInt32(Some(v)) => Some(i64::from(*v)),
+        // A plain `as i64` would wrap values above `i64::MAX` into negative
+        // offsets; clamp to the largest representable offset instead.
+        ScalarValue::UInt64(Some(v)) => Some((*v).min(i64::MAX as u64) as i64),
+        ScalarValue::Int8(Some(v)) => Some(i64::from(*v)),
+        ScalarValue::Int16(Some(v)) => Some(i64::from(*v)),
+        ScalarValue::Int32(Some(v)) => Some(i64::from(*v)),
+        ScalarValue::Int64(Some(v)) => Some(*v),
+        _ => None,
+    }
+}
+
+/// Convert a DataFusion `WindowFrameBound` into a Substrait window `Bound`.
+///
+/// `CurrentRow` maps to the Substrait current-row bound. `Preceding` and
+/// `Following` map to the matching Substrait bound when they carry a non-null
+/// integer offset, and to `Unbounded` otherwise.
+fn to_substrait_bound(bound: &WindowFrameBound) -> Bound {
+    let kind = match bound {
+        WindowFrameBound::CurrentRow => {
+            BoundKind::CurrentRow(SubstraitBound::CurrentRow {})
+        }
+        WindowFrameBound::Preceding(s) => match window_frame_offset(s) {
+            Some(offset) => BoundKind::Preceding(SubstraitBound::Preceding { offset }),
+            None => BoundKind::Unbounded(SubstraitBound::Unbounded {}),
+        },
+        WindowFrameBound::Following(s) => match window_frame_offset(s) {
+            Some(offset) => BoundKind::Following(SubstraitBound::Following { offset }),
+            None => BoundKind::Unbounded(SubstraitBound::Unbounded {}),
+        },
+    };
+    Bound { kind: Some(kind) }
+}
+
+/// Convert a DataFusion `WindowFrame` into its Substrait `(lower, upper)` bounds.
+///
+/// The start bound becomes the first element and the end bound the second. The
+/// conversion itself always succeeds; it returns `Result` so callers can use `?`.
+fn to_substrait_bounds(window_frame: &WindowFrame) -> Result<(Bound, Bound)> {
+    let lower = to_substrait_bound(&window_frame.start_bound);
+    let upper = to_substrait_bound(&window_frame.end_bound);
+    Ok((lower, upper))
+}
+
 fn try_to_substrait_null(v: &ScalarValue) -> Result<LiteralType> {
     let default_type_ref = 0;
     let default_nullability = r#type::Nullability::Nullable as i32;
diff --git a/datafusion/substrait/tests/roundtrip_logical_plan.rs b/datafusion/substrait/tests/roundtrip_logical_plan.rs
index 9aa430bb09..936c4670b3 100644
--- a/datafusion/substrait/tests/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/roundtrip_logical_plan.rs
@@ -250,6 +250,11 @@ mod tests {
         .await
     }
 
+    /// Runs the `roundtrip` helper on a query that combines a ranking window
+    /// function, a plain column and an aggregate used as a window function.
+    #[tokio::test]
+    async fn simple_window_function() -> Result<()> {
+        let sql = "SELECT RANK() OVER (PARTITION BY a ORDER BY b), d, SUM(b) OVER (PARTITION BY a) FROM data;";
+        roundtrip(sql).await
+    }
+
     async fn assert_expected_plan(sql: &str, expected_plan_str: &str) -> 
Result<()> {
         let mut ctx = create_context().await?;
         let df = ctx.sql(sql).await?;

Reply via email to