This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 81aff944bd feat: support UDWFs in Substrait (#11489)
81aff944bd is described below
commit 81aff944bd76b674a22371f7deaa12560d2f629d
Author: Arttu <[email protected]>
AuthorDate: Tue Jul 16 22:54:50 2024 +0200
feat: support UDWFs in Substrait (#11489)
* feat: support UDWFs in Substrait
Previously Substrait consumer would, for window functions, look at:
1. UDAFs
2. built-in window functions
3. built-in aggregate functions
That makes it tough to override the built-in
window function behavior, as it could
only be overridden with a UDAF but some
window functions don't fit nicely into aggregates.
This change adds UDWFs at the top, so the consumer will look at:
1. UDWFs
2. UDAFs
3. built-in window functions
4. built-in aggregate functions
This also paves the way for moving DF's built-in window funcs into UDWFs.
* check udwf first, then udaf
---
datafusion/substrait/src/logical_plan/consumer.rs | 27 ++++++++--------
.../tests/cases/roundtrip_logical_plan.rs | 36 +++++++++++++++++++++-
2 files changed, 50 insertions(+), 13 deletions(-)
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index 991aa61fbf..1365630d50 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -23,8 +23,8 @@ use datafusion::arrow::datatypes::{
};
use datafusion::common::plan_err;
use datafusion::common::{
- not_impl_datafusion_err, not_impl_err, plan_datafusion_err, substrait_datafusion_err,
- substrait_err, DFSchema, DFSchemaRef,
+ not_impl_err, plan_datafusion_err, substrait_datafusion_err, substrait_err, DFSchema,
+ DFSchemaRef,
};
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::expr::{Exists, InSubquery, Sort};
@@ -1182,16 +1182,19 @@ pub async fn from_substrait_rex(
};
let fn_name = substrait_fun_name(fn_name);
- // check udaf first, then built-in functions
- let fun = match ctx.udaf(fn_name) {
- Ok(udaf) => Ok(WindowFunctionDefinition::AggregateUDF(udaf)),
- Err(_) => find_df_window_func(fn_name).ok_or_else(|| {
- not_impl_datafusion_err!(
- "Window function {} is not supported: function anchor = {:?}",
- fn_name,
- window.function_reference
- )
- }),
+ // check udwf first, then udaf, then built-in window and aggregate functions
+ let fun = if let Ok(udwf) = ctx.udwf(fn_name) {
+ Ok(WindowFunctionDefinition::WindowUDF(udwf))
+ } else if let Ok(udaf) = ctx.udaf(fn_name) {
+ Ok(WindowFunctionDefinition::AggregateUDF(udaf))
+ } else if let Some(fun) = find_df_window_func(fn_name) {
+ Ok(fun)
+ } else {
+ not_impl_err!(
+ "Window function {} is not supported: function anchor = {:?}",
+ fn_name,
+ window.function_reference
+ )
}?;
let order_by =
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 5b2d0fbaca..a7653e11d5 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -31,7 +31,8 @@ use datafusion::error::Result;
use datafusion::execution::registry::SerializerRegistry;
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::logical_expr::{
- Extension, LogicalPlan, Repartition, UserDefinedLogicalNode, Volatility,
+ Extension, LogicalPlan, PartitionEvaluator, Repartition, UserDefinedLogicalNode,
+ Volatility,
};
use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST;
use datafusion::prelude::*;
@@ -860,6 +861,39 @@ async fn roundtrip_aggregate_udf() -> Result<()> {
roundtrip_with_ctx("select dummy_agg(a) from data", ctx).await
}
+#[tokio::test]
+async fn roundtrip_window_udf() -> Result<()> {
+ #[derive(Debug)]
+ struct Dummy {}
+
+ impl PartitionEvaluator for Dummy {
+ fn evaluate_all(
+ &mut self,
+ values: &[ArrayRef],
+ _num_rows: usize,
+ ) -> Result<ArrayRef> {
+ Ok(values[0].to_owned())
+ }
+ }
+
+ fn make_partition_evaluator() -> Result<Box<dyn PartitionEvaluator>> {
+ Ok(Box::new(Dummy {}))
+ }
+
+ let dummy_agg = create_udwf(
+ "dummy_window", // name
+ DataType::Int64, // input type
+ Arc::new(DataType::Int64), // return type
+ Volatility::Immutable,
+ Arc::new(make_partition_evaluator),
+ );
+
+ let ctx = create_context().await?;
+ ctx.register_udwf(dummy_agg);
+
+ roundtrip_with_ctx("select dummy_window(a) OVER () from data", ctx).await
+}
+
#[tokio::test]
async fn roundtrip_repartition_roundrobin() -> Result<()> {
let ctx = create_context().await?;
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]