This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 34d9bb5e6 Add window func related logic plan to proto ability. (#4485)
34d9bb5e6 is described below
commit 34d9bb5e64e01e1baca4f636c855082f4cadc270
Author: Yang Jiang <[email protected]>
AuthorDate: Mon Dec 5 08:33:36 2022 +0800
Add window func related logic plan to proto ability. (#4485)
* Add window func related logic plan to proto ability.
Signed-off-by: yangjiang <[email protected]>
* add test.
Signed-off-by: yangjiang <[email protected]>
* more functional
Signed-off-by: yangjiang <[email protected]>
Signed-off-by: yangjiang <[email protected]>
---
datafusion/proto/src/from_proto.rs | 24 +++++++-------
datafusion/proto/src/lib.rs | 67 +++++++++++++++++++++++++++++++++++++-
2 files changed, 79 insertions(+), 12 deletions(-)
diff --git a/datafusion/proto/src/from_proto.rs
b/datafusion/proto/src/from_proto.rs
index 1496050d2..95d605d37 100644
--- a/datafusion/proto/src/from_proto.rs
+++ b/datafusion/proto/src/from_proto.rs
@@ -806,11 +806,15 @@ pub fn parse_expr(
.ok_or_else(|| Error::unknown("BuiltInWindowFunction",
*i))?
.into();
+ let args = parse_optional_expr(&expr.expr, registry)?
+ .map(|e| vec![e])
+ .unwrap_or_else(Vec::new);
+
Ok(Expr::WindowFunction {
fun:
datafusion_expr::window_function::WindowFunction::BuiltInWindowFunction(
built_in_function,
),
- args: vec![parse_required_expr(&expr.expr, registry,
"expr")?],
+ args,
partition_by,
order_by,
window_frame,
@@ -1240,16 +1244,14 @@ impl TryFrom<protobuf::WindowFrameBound> for
WindowFrameBound {
})?;
match bound_type {
protobuf::WindowFrameBoundType::CurrentRow => Ok(Self::CurrentRow),
- protobuf::WindowFrameBoundType::Preceding => {
- // FIXME implement bound value parsing
- // https://github.com/apache/arrow-datafusion/issues/361
- Ok(Self::Preceding(ScalarValue::UInt64(Some(1))))
- }
- protobuf::WindowFrameBoundType::Following => {
- // FIXME implement bound value parsing
- // https://github.com/apache/arrow-datafusion/issues/361
- Ok(Self::Following(ScalarValue::UInt64(Some(1))))
- }
+ protobuf::WindowFrameBoundType::Preceding => match
bound.bound_value {
+ Some(x) => Ok(Self::Preceding(ScalarValue::try_from(&x)?)),
+ None => Ok(Self::Preceding(ScalarValue::UInt64(None))),
+ },
+ protobuf::WindowFrameBoundType::Following => match
bound.bound_value {
+ Some(x) => Ok(Self::Following(ScalarValue::try_from(&x)?)),
+ None => Ok(Self::Following(ScalarValue::UInt64(None))),
+ },
}
}
}
diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs
index bf4b777ff..12c2a5e78 100644
--- a/datafusion/proto/src/lib.rs
+++ b/datafusion/proto/src/lib.rs
@@ -70,7 +70,6 @@ mod roundtrip_tests {
};
use datafusion::test_util::{TestTableFactory, TestTableProvider};
use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
- use datafusion_expr::create_udaf;
use datafusion_expr::expr::{Between, BinaryExpr, Case, Cast, GroupingSet,
Like};
use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNode};
use datafusion_expr::{
@@ -78,6 +77,9 @@ mod roundtrip_tests {
BuiltinScalarFunction::{Sqrt, Substr},
Expr, LogicalPlan, Operator, Volatility,
};
+ use datafusion_expr::{
+ create_udaf, WindowFrame, WindowFrameBound, WindowFrameUnits,
WindowFunction,
+ };
use prost::Message;
use std::any::Any;
use std::collections::HashMap;
@@ -1331,4 +1333,67 @@ mod roundtrip_tests {
roundtrip_expr_test(test_expr, ctx.clone());
roundtrip_expr_test(test_expr_with_count, ctx);
}
+ #[test]
+ fn roundtrip_window() {
+ let ctx = SessionContext::new();
+
+ // 1. Built-in window function (Rank), no args, with window_frame set to None.
+ let test_expr1 = Expr::WindowFunction {
+ fun: WindowFunction::BuiltInWindowFunction(
+ datafusion_expr::window_function::BuiltInWindowFunction::Rank,
+ ),
+ args: vec![],
+ partition_by: vec![col("col1")],
+ order_by: vec![col("col2")],
+ window_frame: None,
+ };
+
+ // 2. Same Rank expression as case 1, but with window_frame set to Some(WindowFrame::default()).
+ let test_expr2 = Expr::WindowFunction {
+ fun: WindowFunction::BuiltInWindowFunction(
+ datafusion_expr::window_function::BuiltInWindowFunction::Rank,
+ ),
+ args: vec![],
+ partition_by: vec![col("col1")],
+ order_by: vec![col("col2")],
+ window_frame: Some(WindowFrame::default()),
+ };
+
+ // 3. Rank with an explicit frame in Range units: Preceding(UInt64 2) to Following(UInt64 2).
+ let range_number_frame = WindowFrame {
+ units: WindowFrameUnits::Range,
+ start_bound:
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))),
+ end_bound:
WindowFrameBound::Following(ScalarValue::UInt64(Some(2))),
+ };
+
+ let test_expr3 = Expr::WindowFunction {
+ fun: WindowFunction::BuiltInWindowFunction(
+ datafusion_expr::window_function::BuiltInWindowFunction::Rank,
+ ),
+ args: vec![],
+ partition_by: vec![col("col1")],
+ order_by: vec![col("col2")],
+ window_frame: Some(range_number_frame),
+ };
+
+ // 4. AggregateFunction (Max over col1) used as the window function, with a frame in Rows units: Preceding(UInt64 2) to Following(UInt64 2).
+ let row_number_frame = WindowFrame {
+ units: WindowFrameUnits::Rows,
+ start_bound:
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))),
+ end_bound:
WindowFrameBound::Following(ScalarValue::UInt64(Some(2))),
+ };
+
+ let test_expr4 = Expr::WindowFunction {
+ fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
+ args: vec![col("col1")],
+ partition_by: vec![col("col1")],
+ order_by: vec![col("col2")],
+ window_frame: Some(row_number_frame),
+ };
+
+ roundtrip_expr_test(test_expr1, ctx.clone());
+ roundtrip_expr_test(test_expr2, ctx.clone());
+ roundtrip_expr_test(test_expr3, ctx.clone());
+ roundtrip_expr_test(test_expr4, ctx);
+ }
}