This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 34d9bb5e6 Add window func related logic plan to proto ability. (#4485)
34d9bb5e6 is described below

commit 34d9bb5e64e01e1baca4f636c855082f4cadc270
Author: Yang Jiang <[email protected]>
AuthorDate: Mon Dec 5 08:33:36 2022 +0800

    Add window func related logic plan to proto ability. (#4485)
    
    * Add window func related logic plan to proto ability.
    
    Signed-off-by: yangjiang <[email protected]>
    
    * add test.
    
    Signed-off-by: yangjiang <[email protected]>
    
    * more functional
    
    Signed-off-by: yangjiang <[email protected]>
    
    Signed-off-by: yangjiang <[email protected]>
---
 datafusion/proto/src/from_proto.rs | 24 +++++++-------
 datafusion/proto/src/lib.rs        | 67 +++++++++++++++++++++++++++++++++++++-
 2 files changed, 79 insertions(+), 12 deletions(-)

diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs
index 1496050d2..95d605d37 100644
--- a/datafusion/proto/src/from_proto.rs
+++ b/datafusion/proto/src/from_proto.rs
@@ -806,11 +806,15 @@ pub fn parse_expr(
                         .ok_or_else(|| Error::unknown("BuiltInWindowFunction", 
*i))?
                         .into();
 
+                    let args = parse_optional_expr(&expr.expr, registry)?
+                        .map(|e| vec![e])
+                        .unwrap_or_else(Vec::new);
+
                     Ok(Expr::WindowFunction {
                         fun: 
datafusion_expr::window_function::WindowFunction::BuiltInWindowFunction(
                             built_in_function,
                         ),
-                        args: vec![parse_required_expr(&expr.expr, registry, 
"expr")?],
+                        args,
                         partition_by,
                         order_by,
                         window_frame,
@@ -1240,16 +1244,14 @@ impl TryFrom<protobuf::WindowFrameBound> for WindowFrameBound {
                 })?;
         match bound_type {
             protobuf::WindowFrameBoundType::CurrentRow => Ok(Self::CurrentRow),
-            protobuf::WindowFrameBoundType::Preceding => {
-                // FIXME implement bound value parsing
-                // https://github.com/apache/arrow-datafusion/issues/361
-                Ok(Self::Preceding(ScalarValue::UInt64(Some(1))))
-            }
-            protobuf::WindowFrameBoundType::Following => {
-                // FIXME implement bound value parsing
-                // https://github.com/apache/arrow-datafusion/issues/361
-                Ok(Self::Following(ScalarValue::UInt64(Some(1))))
-            }
+            protobuf::WindowFrameBoundType::Preceding => match 
bound.bound_value {
+                Some(x) => Ok(Self::Preceding(ScalarValue::try_from(&x)?)),
+                None => Ok(Self::Preceding(ScalarValue::UInt64(None))),
+            },
+            protobuf::WindowFrameBoundType::Following => match 
bound.bound_value {
+                Some(x) => Ok(Self::Following(ScalarValue::try_from(&x)?)),
+                None => Ok(Self::Following(ScalarValue::UInt64(None))),
+            },
         }
     }
 }
diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs
index bf4b777ff..12c2a5e78 100644
--- a/datafusion/proto/src/lib.rs
+++ b/datafusion/proto/src/lib.rs
@@ -70,7 +70,6 @@ mod roundtrip_tests {
     };
     use datafusion::test_util::{TestTableFactory, TestTableProvider};
     use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
-    use datafusion_expr::create_udaf;
     use datafusion_expr::expr::{Between, BinaryExpr, Case, Cast, GroupingSet, 
Like};
     use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNode};
     use datafusion_expr::{
@@ -78,6 +77,9 @@ mod roundtrip_tests {
         BuiltinScalarFunction::{Sqrt, Substr},
         Expr, LogicalPlan, Operator, Volatility,
     };
+    use datafusion_expr::{
+        create_udaf, WindowFrame, WindowFrameBound, WindowFrameUnits, 
WindowFunction,
+    };
     use prost::Message;
     use std::any::Any;
     use std::collections::HashMap;
@@ -1331,4 +1333,67 @@ mod roundtrip_tests {
         roundtrip_expr_test(test_expr, ctx.clone());
         roundtrip_expr_test(test_expr_with_count, ctx);
     }
+    #[test]
+    fn roundtrip_window() {
+        let ctx = SessionContext::new();
+
+        // 1. without window_frame
+        let test_expr1 = Expr::WindowFunction {
+            fun: WindowFunction::BuiltInWindowFunction(
+                datafusion_expr::window_function::BuiltInWindowFunction::Rank,
+            ),
+            args: vec![],
+            partition_by: vec![col("col1")],
+            order_by: vec![col("col2")],
+            window_frame: None,
+        };
+
+        // 2. with default window_frame
+        let test_expr2 = Expr::WindowFunction {
+            fun: WindowFunction::BuiltInWindowFunction(
+                datafusion_expr::window_function::BuiltInWindowFunction::Rank,
+            ),
+            args: vec![],
+            partition_by: vec![col("col1")],
+            order_by: vec![col("col2")],
+            window_frame: Some(WindowFrame::default()),
+        };
+
+        // 3. with window_frame with row numbers
+        let range_number_frame = WindowFrame {
+            units: WindowFrameUnits::Range,
+            start_bound: 
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))),
+            end_bound: 
WindowFrameBound::Following(ScalarValue::UInt64(Some(2))),
+        };
+
+        let test_expr3 = Expr::WindowFunction {
+            fun: WindowFunction::BuiltInWindowFunction(
+                datafusion_expr::window_function::BuiltInWindowFunction::Rank,
+            ),
+            args: vec![],
+            partition_by: vec![col("col1")],
+            order_by: vec![col("col2")],
+            window_frame: Some(range_number_frame),
+        };
+
+        // 4. test with AggregateFunction
+        let row_number_frame = WindowFrame {
+            units: WindowFrameUnits::Rows,
+            start_bound: 
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))),
+            end_bound: 
WindowFrameBound::Following(ScalarValue::UInt64(Some(2))),
+        };
+
+        let test_expr4 = Expr::WindowFunction {
+            fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
+            args: vec![col("col1")],
+            partition_by: vec![col("col1")],
+            order_by: vec![col("col2")],
+            window_frame: Some(row_number_frame),
+        };
+
+        roundtrip_expr_test(test_expr1, ctx.clone());
+        roundtrip_expr_test(test_expr2, ctx.clone());
+        roundtrip_expr_test(test_expr3, ctx.clone());
+        roundtrip_expr_test(test_expr4, ctx);
+    }
 }

Reply via email to