This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 11ac83615f Add window function support (#5653)
11ac83615f is described below
commit 11ac83615f807a68b31c5bced209c0d42ef40b3c
Author: Nuttiiya Seekhao <[email protected]>
AuthorDate: Wed Mar 22 13:35:22 2023 -0400
Add window function support (#5653)
---
datafusion/substrait/src/logical_plan/consumer.rs | 233 +++++++++++++++++----
datafusion/substrait/src/logical_plan/producer.rs | 218 ++++++++++++++++++-
.../substrait/tests/roundtrip_logical_plan.rs | 5 +
3 files changed, 408 insertions(+), 48 deletions(-)
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs
b/datafusion/substrait/src/logical_plan/consumer.rs
index b7cacf131d..767c4a3937 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -19,10 +19,11 @@ use async_recursion::async_recursion;
use datafusion::arrow::datatypes::DataType;
use datafusion::common::{DFField, DFSchema, DFSchemaRef};
use datafusion::logical_expr::{
- aggregate_function, BinaryExpr, Case, Expr, LogicalPlan, Operator,
+ aggregate_function, window_function::find_df_window_func, BinaryExpr,
Case, Expr,
+ LogicalPlan, Operator,
};
use datafusion::logical_expr::{build_join_schema, LogicalPlanBuilder};
-use datafusion::logical_expr::{expr, Cast};
+use datafusion::logical_expr::{expr, Cast, WindowFrameBound, WindowFrameUnits};
use datafusion::prelude::JoinType;
use datafusion::sql::TableReference;
use datafusion::{
@@ -35,7 +36,10 @@ use substrait::proto::{
aggregate_function::AggregationInvocation,
expression::{
field_reference::ReferenceType::DirectReference, literal::LiteralType,
- reference_segment::ReferenceType::StructField, MaskExpression, RexType,
+ reference_segment::ReferenceType::StructField,
+ window_function::bound as SubstraitBound,
+ window_function::bound::Kind as BoundKind, window_function::Bound,
+ MaskExpression, RexType,
},
extensions::simple_extension_declaration::MappingType,
function_argument::ArgType,
@@ -45,6 +49,7 @@ use substrait::proto::{
sort_field::{SortDirection, SortKind::*},
AggregateFunction, Expression, Plan, Rel, Type,
};
+use substrait::proto::{FunctionArgument, SortField};
use datafusion::logical_expr::expr::Sort;
use std::collections::HashMap;
@@ -139,13 +144,25 @@ pub async fn from_substrait_rel(
match &rel.rel_type {
Some(RelType::Project(p)) => {
if let Some(input) = p.input.as_ref() {
- let input = LogicalPlanBuilder::from(
+ let mut input = LogicalPlanBuilder::from(
from_substrait_rel(ctx, input, extensions).await?,
);
let mut exprs: Vec<Expr> = vec![];
for e in &p.expressions {
- let x = from_substrait_rex(e, input.schema(),
extensions).await?;
- exprs.push(x.as_ref().clone());
+ let x =
+ from_substrait_rex(e, input.clone().schema(),
extensions).await?;
+ // if the expression is WindowFunction, wrap in a Window
relation
+ // before returning and do not add to list of this
Projection's expression list
+ // otherwise, add expression to the Projection's
expression list
+ match &*x {
+ Expr::WindowFunction(_) => {
+ input = input.window(vec![x.as_ref().clone()])?;
+ exprs.push(x.as_ref().clone());
+ }
+ _ => {
+ exprs.push(x.as_ref().clone());
+ }
+ }
}
input.project(exprs)?.build()
} else {
@@ -193,45 +210,8 @@ pub async fn from_substrait_rel(
let input = LogicalPlanBuilder::from(
from_substrait_rel(ctx, input, extensions).await?,
);
- let mut sorts: Vec<Expr> = vec![];
- for s in &sort.sorts {
- let expr = from_substrait_rex(
- s.expr.as_ref().unwrap(),
- input.schema(),
- extensions,
- )
- .await?;
- let asc_nullfirst = match &s.sort_kind {
- Some(k) => match k {
- Direction(d) => {
- let direction : SortDirection = unsafe {
- ::std::mem::transmute(*d)
- };
- match direction {
- SortDirection::AscNullsFirst => Ok((true,
true)),
- SortDirection::AscNullsLast => Ok((true,
false)),
- SortDirection::DescNullsFirst =>
Ok((false, true)),
- SortDirection::DescNullsLast => Ok((false,
false)),
- SortDirection::Clustered =>
-
Err(DataFusionError::NotImplemented("Sort with direction clustered is not yet
supported".to_string()))
- ,
- SortDirection::Unspecified =>
-
Err(DataFusionError::NotImplemented("Unspecified sort direction is
invalid".to_string()))
- }
- }
- ComparisonFunctionReference(_) => {
- Err(DataFusionError::NotImplemented("Sort
using comparison function reference is not supported".to_string()))
- },
- },
- None => Err(DataFusionError::NotImplemented("Sort
without sort kind is invalid".to_string()))
- };
- let (asc, nulls_first) = asc_nullfirst.unwrap();
- sorts.push(Expr::Sort(Sort {
- expr: Box::new(expr.as_ref().clone()),
- asc,
- nulls_first,
- }));
- }
+ let sorts =
+ from_substrait_sorts(&sort.sorts, input.schema(),
extensions).await?;
input.sort(sorts)?.build()
} else {
Err(DataFusionError::NotImplemented(
@@ -452,6 +432,90 @@ fn from_substrait_jointype(join_type: i32) ->
Result<JoinType> {
}
}
+/// Convert Substrait Sorts to DataFusion Exprs
+///
+/// Each Substrait `SortField` becomes an `Expr::Sort` whose `asc` and
+/// `nulls_first` flags are derived from the field's sort direction.
+///
+/// # Errors
+///
+/// Returns an error (instead of panicking) when a sort field carries no
+/// expression, when its sort kind is missing, when the raw direction value is
+/// not a known `SortDirection`, or when the direction is `Clustered`,
+/// `Unspecified`, or a comparison function reference.
+pub async fn from_substrait_sorts(
+    substrait_sorts: &Vec<SortField>,
+    input_schema: &DFSchema,
+    extensions: &HashMap<u32, &String>,
+) -> Result<Vec<Expr>> {
+    let mut sorts: Vec<Expr> = Vec::with_capacity(substrait_sorts.len());
+    for s in substrait_sorts {
+        // A plan coming from another producer may omit the expression;
+        // surface that as an error rather than unwrapping.
+        let sort_expr = s.expr.as_ref().ok_or_else(|| {
+            DataFusionError::Substrait(
+                "Sort field is missing its expression".to_string(),
+            )
+        })?;
+        let expr = from_substrait_rex(sort_expr, input_schema, extensions).await?;
+        let (asc, nulls_first) = match &s.sort_kind {
+            // `from_i32` validates the raw protobuf integer, so an unknown
+            // value yields `None` instead of an invalid enum (the previous
+            // `transmute` was undefined behaviour for such input).
+            Some(Direction(d)) => match SortDirection::from_i32(*d) {
+                Some(SortDirection::AscNullsFirst) => (true, true),
+                Some(SortDirection::AscNullsLast) => (true, false),
+                Some(SortDirection::DescNullsFirst) => (false, true),
+                Some(SortDirection::DescNullsLast) => (false, false),
+                Some(SortDirection::Clustered) => {
+                    return Err(DataFusionError::NotImplemented(
+                        "Sort with direction clustered is not yet supported"
+                            .to_string(),
+                    ))
+                }
+                Some(SortDirection::Unspecified) => {
+                    return Err(DataFusionError::NotImplemented(
+                        "Unspecified sort direction is invalid".to_string(),
+                    ))
+                }
+                None => {
+                    return Err(DataFusionError::Substrait(format!(
+                        "Unknown sort direction: {d}"
+                    )))
+                }
+            },
+            Some(ComparisonFunctionReference(_)) => {
+                return Err(DataFusionError::NotImplemented(
+                    "Sort using comparison function reference is not supported"
+                        .to_string(),
+                ))
+            }
+            None => {
+                return Err(DataFusionError::NotImplemented(
+                    "Sort without sort kind is invalid".to_string(),
+                ))
+            }
+        };
+        sorts.push(Expr::Sort(Sort {
+            expr: Box::new(expr.as_ref().clone()),
+            asc,
+            nulls_first,
+        }));
+    }
+    Ok(sorts)
+}
+
+/// Convert Substrait Expressions to DataFusion Exprs
+///
+/// Converts every expression in `exprs`, in order, with `from_substrait_rex`
+/// against `input_schema`. The first conversion error aborts the whole call.
+pub async fn from_substrait_rex_vec(
+    exprs: &Vec<Expression>,
+    input_schema: &DFSchema,
+    extensions: &HashMap<u32, &String>,
+) -> Result<Vec<Expr>> {
+    let mut converted: Vec<Expr> = Vec::with_capacity(exprs.len());
+    for rex in exprs {
+        let df_expr = from_substrait_rex(rex, input_schema, extensions).await?;
+        // `from_substrait_rex` hands back a shared pointer; store an owned copy.
+        converted.push(Expr::clone(df_expr.as_ref()));
+    }
+    Ok(converted)
+}
+
+/// Convert Substrait FunctionArguments to DataFusion Exprs
+///
+/// Only arguments of the `ArgType::Value` kind are supported; any other kind
+/// (or a missing argument type) produces a `NotImplemented` error. Arguments
+/// are converted in order and the first failure aborts the call.
+pub async fn from_substriat_func_args(
+    arguments: &Vec<FunctionArgument>,
+    input_schema: &DFSchema,
+    extensions: &HashMap<u32, &String>,
+) -> Result<Vec<Expr>> {
+    let mut converted: Vec<Expr> = Vec::with_capacity(arguments.len());
+    for argument in arguments {
+        let value = match &argument.arg_type {
+            Some(ArgType::Value(rex)) => {
+                from_substrait_rex(rex, input_schema, extensions).await?
+            }
+            _ => {
+                return Err(DataFusionError::NotImplemented(
+                    "Aggregated function argument non-Value type not supported"
+                        .to_string(),
+                ))
+            }
+        };
+        converted.push(Expr::clone(value.as_ref()));
+    }
+    Ok(converted)
+}
+
/// Convert Substrait AggregateFunction to DataFusion Expr
pub async fn from_substrait_agg_func(
f: &AggregateFunction,
@@ -740,6 +804,47 @@ pub async fn from_substrait_rex(
"Cast experssion without output type is not
allowed".to_string(),
)),
},
+ Some(RexType::WindowFunction(window)) => {
+ let fun = match extensions.get(&window.function_reference) {
+ Some(function_name) => Ok(find_df_window_func(function_name)),
+ None => Err(DataFusionError::NotImplemented(format!(
+ "Window function not found: function anchor = {:?}",
+ &window.function_reference
+ ))),
+ };
+ let order_by =
+ from_substrait_sorts(&window.sorts, input_schema,
extensions).await?;
+ // Substrait does not encode WindowFrameUnits so we're using a
simple logic to determine the units
+ // If there is no `ORDER BY`, then by default, the frame counts
each row from the lower up to upper boundary
+ // If there is `ORDER BY`, then by default, each frame is a range
starting from unbounded preceding to current row
+ // TODO: Consider the cases where window frame is specified in
query and is different from default
+ let units = if order_by.is_empty() {
+ WindowFrameUnits::Rows
+ } else {
+ WindowFrameUnits::Range
+ };
+ Ok(Arc::new(Expr::WindowFunction(expr::WindowFunction {
+ fun: fun?.unwrap(),
+ args: from_substriat_func_args(
+ &window.arguments,
+ input_schema,
+ extensions,
+ )
+ .await?,
+ partition_by: from_substrait_rex_vec(
+ &window.partitions,
+ input_schema,
+ extensions,
+ )
+ .await?,
+ order_by,
+ window_frame: datafusion::logical_expr::WindowFrame {
+ units,
+ start_bound: from_substrait_bound(&window.lower_bound,
true)?,
+ end_bound: from_substrait_bound(&window.upper_bound,
false)?,
+ },
+ })))
+ }
_ => Err(DataFusionError::NotImplemented(
"unsupported rex_type".to_string(),
)),
@@ -767,6 +872,44 @@ fn from_substrait_type(dt: &substrait::proto::Type) ->
Result<DataType> {
}
}
+/// Convert an optional Substrait window `Bound` into a DataFusion
+/// `WindowFrameBound`.
+///
+/// `is_lower` tells which edge of the frame is being converted. It decides
+/// whether an unbounded (or absent) bound maps to `Preceding(Null)` for the
+/// lower edge or `Following(Null)` for the upper edge.
+///
+/// Returns a `DataFusionError::Substrait` error when a bound is present but
+/// has no kind.
+fn from_substrait_bound(
+    bound: &Option<Bound>,
+    is_lower: bool,
+) -> Result<WindowFrameBound> {
+    // Shared by the "no bound at all" and the explicit `Unbounded` cases.
+    let unbounded = || {
+        if is_lower {
+            WindowFrameBound::Preceding(ScalarValue::Null)
+        } else {
+            WindowFrameBound::Following(ScalarValue::Null)
+        }
+    };
+
+    let kind = match bound {
+        Some(b) => b.kind.as_ref(),
+        None => return Ok(unbounded()),
+    };
+
+    match kind {
+        None => Err(DataFusionError::Substrait(
+            "WindowFunction missing Substrait Bound kind".to_string(),
+        )),
+        Some(BoundKind::CurrentRow(SubstraitBound::CurrentRow {})) => {
+            Ok(WindowFrameBound::CurrentRow)
+        }
+        Some(BoundKind::Preceding(SubstraitBound::Preceding { offset })) => Ok(
+            WindowFrameBound::Preceding(ScalarValue::Int64(Some(*offset))),
+        ),
+        Some(BoundKind::Following(SubstraitBound::Following { offset })) => Ok(
+            WindowFrameBound::Following(ScalarValue::Int64(Some(*offset))),
+        ),
+        Some(BoundKind::Unbounded(SubstraitBound::Unbounded {})) => Ok(unbounded()),
+    }
+}
+
fn from_substrait_null(null_type: &Type) -> Result<ScalarValue> {
if let Some(kind) = &null_type.kind {
match kind {
diff --git a/datafusion/substrait/src/logical_plan/producer.rs
b/datafusion/substrait/src/logical_plan/producer.rs
index 1de26a3433..ecb322edb7 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -20,6 +20,7 @@ use std::{collections::HashMap, mem, sync::Arc};
use datafusion::{
arrow::datatypes::DataType,
error::{DataFusionError, Result},
+ logical_expr::{WindowFrame, WindowFrameBound},
prelude::JoinType,
scalar::ScalarValue,
};
@@ -27,7 +28,7 @@ use datafusion::{
use datafusion::common::DFSchemaRef;
#[allow(unused_imports)]
use datafusion::logical_expr::aggregate_function;
-use datafusion::logical_expr::expr::{BinaryExpr, Case, Cast, Sort};
+use datafusion::logical_expr::expr::{BinaryExpr, Case, Cast, Sort,
WindowFunction};
use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan,
Operator};
use datafusion::prelude::{binary_expr, Expr};
use substrait::proto::{
@@ -38,8 +39,12 @@ use substrait::proto::{
if_then::IfClause,
literal::{Decimal, LiteralType},
mask_expression::{StructItem, StructSelect},
- reference_segment, FieldReference, IfThen, Literal, MaskExpression,
- ReferenceSegment, RexType, ScalarFunction,
+ reference_segment,
+ window_function::bound as SubstraitBound,
+ window_function::bound::Kind as BoundKind,
+ window_function::Bound,
+ FieldReference, IfThen, Literal, MaskExpression, ReferenceSegment,
RexType,
+ ScalarFunction, WindowFunction as SubstraitWindowFunction,
},
extensions::{
self,
@@ -301,6 +306,42 @@ pub fn to_substrait_rel(
// since there is no corresponding relation type in Substrait
to_substrait_rel(alias.input.as_ref(), extension_info)
}
+ LogicalPlan::Window(window) => {
+ let input = to_substrait_rel(window.input.as_ref(),
extension_info)?;
+ // If the input is a Project relation, we can just append the
WindowFunction expressions
+ // before returning
+ // Otherwise, wrap the input in a Project relation before
appending the WindowFunction
+ // expressions
+ let mut project_rel: Box<ProjectRel> = match
&input.as_ref().rel_type {
+ Some(RelType::Project(p)) => Box::new(*p.clone()),
+ _ => {
+ // Create Projection with field referencing all output
fields in the input relation
+ let expressions = (0..window.input.schema().fields().len())
+ .map(substrait_field_ref)
+ .collect::<Result<Vec<_>>>()?;
+ Box::new(ProjectRel {
+ common: None,
+ input: Some(input),
+ expressions,
+ advanced_extension: None,
+ })
+ }
+ };
+ // Parse WindowFunction expression
+ let mut window_exprs = vec![];
+ for expr in &window.window_expr {
+ window_exprs.push(to_substrait_rex(
+ expr,
+ window.input.schema(),
+ extension_info,
+ )?);
+ }
+ // Append parsed WindowFunction expressions
+ project_rel.expressions.extend(window_exprs);
+ Ok(Box::new(Rel {
+ rel_type: Some(RelType::Project(project_rel)),
+ }))
+ }
_ => Err(DataFusionError::NotImplemented(format!(
"Unsupported operator: {plan:?}"
))),
@@ -636,6 +677,47 @@ pub fn to_substrait_rex(
})
}
Expr::Alias(expr, _alias) => to_substrait_rex(expr, schema,
extension_info),
+ Expr::WindowFunction(WindowFunction {
+ fun,
+ args,
+ partition_by,
+ order_by,
+ window_frame,
+ }) => {
+ // function reference
+ let function_name = fun.to_string().to_lowercase();
+ let function_anchor = _register_function(function_name,
extension_info);
+ // arguments
+ let mut arguments: Vec<FunctionArgument> = vec![];
+ for arg in args {
+ arguments.push(FunctionArgument {
+ arg_type: Some(ArgType::Value(to_substrait_rex(
+ arg,
+ schema,
+ extension_info,
+ )?)),
+ });
+ }
+ // partition by expressions
+ let partition_by = partition_by
+ .iter()
+ .map(|e| to_substrait_rex(e, schema, extension_info))
+ .collect::<Result<Vec<_>>>()?;
+ // order by expressions
+ let order_by = order_by
+ .iter()
+ .map(|e| substrait_sort_field(e, schema, extension_info))
+ .collect::<Result<Vec<_>>>()?;
+ // window frame
+ let bounds = to_substrait_bounds(window_frame)?;
+ Ok(make_substrait_window_function(
+ function_anchor,
+ arguments,
+ partition_by,
+ order_by,
+ bounds,
+ ))
+ }
_ => Err(DataFusionError::NotImplemented(format!(
"Unsupported expression: {expr:?}"
))),
@@ -693,6 +775,136 @@ fn to_substrait_type(dt: &DataType) ->
Result<substrait::proto::Type> {
}
}
+/// Assemble a Substrait `Expression` wrapping a window function call.
+///
+/// `bounds` is the `(lower, upper)` pair of frame bounds. `options`,
+/// `output_type` and `args` are left empty, and `phase` and `invocation` are
+/// written as `0`.
+// NOTE(review): `deprecated` is presumably allowed because the `args` field
+// initialised below is deprecated in the Substrait proto; confirm.
+#[allow(deprecated)]
+fn make_substrait_window_function(
+    function_reference: u32,
+    arguments: Vec<FunctionArgument>,
+    partitions: Vec<Expression>,
+    sorts: Vec<SortField>,
+    bounds: (Bound, Bound),
+) -> Expression {
+    let (lower, upper) = bounds;
+    let window_function = SubstraitWindowFunction {
+        function_reference,
+        arguments,
+        partitions,
+        sorts,
+        options: Vec::new(),
+        output_type: None,
+        phase: 0,      // default to AGGREGATION_PHASE_UNSPECIFIED
+        invocation: 0, // TODO: fix
+        lower_bound: Some(lower),
+        upper_bound: Some(upper),
+        args: Vec::new(),
+    };
+    Expression {
+        rex_type: Some(RexType::WindowFunction(window_function)),
+    }
+}
+
+/// Convert a DataFusion `WindowFrameBound` into a Substrait `Bound`.
+///
+/// `CurrentRow` maps to the Substrait current-row bound. `Preceding` and
+/// `Following` map to the matching Substrait bound when they carry a non-null
+/// integer offset, and to `Unbounded` otherwise (for example a null scalar).
+fn to_substrait_bound(bound: &WindowFrameBound) -> Bound {
+    let unbounded = || BoundKind::Unbounded(SubstraitBound::Unbounded {});
+    let kind = match bound {
+        WindowFrameBound::CurrentRow => {
+            BoundKind::CurrentRow(SubstraitBound::CurrentRow {})
+        }
+        WindowFrameBound::Preceding(s) => match to_substrait_bound_offset(s) {
+            Some(offset) => BoundKind::Preceding(SubstraitBound::Preceding { offset }),
+            None => unbounded(),
+        },
+        WindowFrameBound::Following(s) => match to_substrait_bound_offset(s) {
+            Some(offset) => BoundKind::Following(SubstraitBound::Following { offset }),
+            None => unbounded(),
+        },
+    };
+    Bound { kind: Some(kind) }
+}
+
+/// Extract the frame offset carried by a window frame bound's scalar.
+///
+/// Returns `None` for anything that is not a non-null integer scalar; the
+/// caller treats that as an unbounded frame edge.
+fn to_substrait_bound_offset(value: &ScalarValue) -> Option<i64> {
+    match value {
+        ScalarValue::UInt8(Some(v)) => Some(i64::from(*v)),
+        ScalarValue::UInt16(Some(v)) => Some(i64::from(*v)),
+        ScalarValue::UInt32(Some(v)) => Some(i64::from(*v)),
+        // Substrait offsets are i64: saturate instead of letting values above
+        // `i64::MAX` wrap around to a negative offset.
+        ScalarValue::UInt64(Some(v)) => Some(i64::try_from(*v).unwrap_or(i64::MAX)),
+        ScalarValue::Int8(Some(v)) => Some(i64::from(*v)),
+        ScalarValue::Int16(Some(v)) => Some(i64::from(*v)),
+        ScalarValue::Int32(Some(v)) => Some(i64::from(*v)),
+        ScalarValue::Int64(Some(v)) => Some(*v),
+        _ => None,
+    }
+}
+
+/// Convert both edges of a DataFusion `WindowFrame` into Substrait bounds.
+///
+/// Returns the pair as `(start_bound, end_bound)`. The conversion itself
+/// cannot fail; the `Result` wrapper is part of the existing signature.
+fn to_substrait_bounds(window_frame: &WindowFrame) -> Result<(Bound, Bound)> {
+    let lower = to_substrait_bound(&window_frame.start_bound);
+    let upper = to_substrait_bound(&window_frame.end_bound);
+    Ok((lower, upper))
+}
+
fn try_to_substrait_null(v: &ScalarValue) -> Result<LiteralType> {
let default_type_ref = 0;
let default_nullability = r#type::Nullability::Nullable as i32;
diff --git a/datafusion/substrait/tests/roundtrip_logical_plan.rs
b/datafusion/substrait/tests/roundtrip_logical_plan.rs
index 9aa430bb09..936c4670b3 100644
--- a/datafusion/substrait/tests/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/roundtrip_logical_plan.rs
@@ -250,6 +250,11 @@ mod tests {
.await
}
+ /// Exercises window-function support by passing one query to the
+ /// `roundtrip` helper (defined elsewhere in this test module).
+ ///
+ /// The query mixes a ranking window function with PARTITION BY and
+ /// ORDER BY, a plain column, and an aggregate used as a window function
+ /// with PARTITION BY only.
+ #[tokio::test]
+ async fn simple_window_function() -> Result<()> {
+ roundtrip("SELECT RANK() OVER (PARTITION BY a ORDER BY b), d, SUM(b)
OVER (PARTITION BY a) FROM data;").await
+ }
+
async fn assert_expected_plan(sql: &str, expected_plan_str: &str) ->
Result<()> {
let mut ctx = create_context().await?;
let df = ctx.sql(sql).await?;