This is an automated email from the ASF dual-hosted git repository.

akurmustafa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new b5db718776 [MINOR]: Update create_window_expr to refer only input 
schema (#8945)
b5db718776 is described below

commit b5db7187763bc4511aaffdd6d89b2f0908f17938
Author: Mustafa Akur <[email protected]>
AuthorDate: Wed Jan 24 13:26:20 2024 +0300

    [MINOR]: Update create_window_expr to refer only input schema (#8945)
    
    * create_window_expr now receives physical input schema
    
    * Resolve linter errors
    
    * Match argument signature for some window functions
    
    * Remove physical input_schema
---
 datafusion/core/src/physical_planner.rs         |  7 +--
 datafusion/core/tests/fuzz_cases/window_fuzz.rs | 84 ++++++++-----------------
 datafusion/physical-plan/src/windows/mod.rs     |  9 ++-
 3 files changed, 36 insertions(+), 64 deletions(-)

diff --git a/datafusion/core/src/physical_planner.rs 
b/datafusion/core/src/physical_planner.rs
index ed92688559..ac7827fafc 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -86,7 +86,6 @@ use datafusion_expr::expr::{
 };
 use datafusion_expr::expr_rewriter::unnormalize_cols;
 use 
datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
-use datafusion_expr::utils::exprlist_to_fields;
 use datafusion_expr::{
     DescribeTable, DmlStatement, RecursiveQuery, ScalarFunctionDefinition,
     StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp,
@@ -720,16 +719,12 @@ impl DefaultPhysicalPlanner {
                     }
 
                     let logical_input_schema = input.schema();
-                    // Extend the schema to include window expression fields 
as builtin window functions derives its datatype from incoming schema
-                    let mut window_fields = 
logical_input_schema.fields().clone();
-                    
window_fields.extend_from_slice(&exprlist_to_fields(window_expr.iter(), 
input)?);
-                    let extended_schema = 
&DFSchema::new_with_metadata(window_fields, HashMap::new())?;
                     let window_expr = window_expr
                         .iter()
                         .map(|e| {
                             create_window_expr(
                                 e,
-                                extended_schema,
+                                logical_input_schema,
                                 session_state.execution_props(),
                             )
                         })
diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs 
b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
index 4c440d6a5b..7358ec2884 100644
--- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
@@ -22,7 +22,6 @@ use arrow::compute::{concat_batches, SortOptions};
 use arrow::datatypes::SchemaRef;
 use arrow::record_batch::RecordBatch;
 use arrow::util::pretty::pretty_format_batches;
-use arrow_schema::{Field, Schema};
 use datafusion::physical_plan::memory::MemoryExec;
 use datafusion::physical_plan::sorts::sort::SortExec;
 use datafusion::physical_plan::windows::{
@@ -38,7 +37,6 @@ use datafusion_expr::{
 };
 use datafusion_physical_expr::expressions::{cast, col, lit};
 use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
-use itertools::Itertools;
 use test_utils::add_empty_batches;
 
 use hashbrown::HashMap;
@@ -229,14 +227,14 @@ fn get_random_function(
     rng: &mut StdRng,
     is_linear: bool,
 ) -> (WindowFunctionDefinition, Vec<Arc<dyn PhysicalExpr>>, String) {
-    let mut args = if is_linear {
+    let arg = if is_linear {
         // In linear test for the test version with WindowAggExec we use 
insert SortExecs to the plan to be able to generate
         // same result with BoundedWindowAggExec which doesn't use any 
SortExec. To make result
         // non-dependent on table order. We should use column a in the window 
function
         // (Given that we do not use ROWS for the window frame. ROWS also 
introduces dependency to the table order.).
-        vec![col("a", schema).unwrap()]
+        col("a", schema).unwrap()
     } else {
-        vec![col("x", schema).unwrap()]
+        col("x", schema).unwrap()
     };
     let mut window_fn_map = HashMap::new();
     // HashMap values consists of tuple first element is WindowFunction, 
second is additional argument
@@ -245,28 +243,28 @@ fn get_random_function(
         "sum",
         (
             
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
-            vec![],
+            vec![arg.clone()],
         ),
     );
     window_fn_map.insert(
         "count",
         (
             
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count),
-            vec![],
+            vec![arg.clone()],
         ),
     );
     window_fn_map.insert(
         "min",
         (
             
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min),
-            vec![],
+            vec![arg.clone()],
         ),
     );
     window_fn_map.insert(
         "max",
         (
             
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
-            vec![],
+            vec![arg.clone()],
         ),
     );
     if !is_linear {
@@ -307,6 +305,7 @@ fn get_random_function(
                     BuiltInWindowFunction::Lead,
                 ),
                 vec![
+                    arg.clone(),
                     lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))),
                     lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))),
                 ],
@@ -319,6 +318,7 @@ fn get_random_function(
                     BuiltInWindowFunction::Lag,
                 ),
                 vec![
+                    arg.clone(),
                     lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))),
                     lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))),
                 ],
@@ -331,7 +331,7 @@ fn get_random_function(
             WindowFunctionDefinition::BuiltInWindowFunction(
                 BuiltInWindowFunction::FirstValue,
             ),
-            vec![],
+            vec![arg.clone()],
         ),
     );
     window_fn_map.insert(
@@ -340,7 +340,7 @@ fn get_random_function(
             WindowFunctionDefinition::BuiltInWindowFunction(
                 BuiltInWindowFunction::LastValue,
             ),
-            vec![],
+            vec![arg.clone()],
         ),
     );
     window_fn_map.insert(
@@ -349,23 +349,26 @@ fn get_random_function(
             WindowFunctionDefinition::BuiltInWindowFunction(
                 BuiltInWindowFunction::NthValue,
             ),
-            vec![lit(ScalarValue::Int64(Some(rng.gen_range(1..10))))],
+            vec![
+                arg.clone(),
+                lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))),
+            ],
         ),
     );
 
     let rand_fn_idx = rng.gen_range(0..window_fn_map.len());
     let fn_name = window_fn_map.keys().collect::<Vec<_>>()[rand_fn_idx];
-    let (window_fn, new_args) = 
window_fn_map.values().collect::<Vec<_>>()[rand_fn_idx];
+    let (window_fn, args) = 
window_fn_map.values().collect::<Vec<_>>()[rand_fn_idx];
+    let mut args = args.clone();
     if let WindowFunctionDefinition::AggregateFunction(f) = window_fn {
-        let a = args[0].clone();
-        let dt = a.data_type(schema.as_ref()).unwrap();
-        let sig = f.signature();
-        let coerced = coerce_types(f, &[dt], &sig).unwrap();
-        args[0] = cast(a, schema, coerced[0].clone()).unwrap();
-    }
-
-    for new_arg in new_args {
-        args.push(new_arg.clone());
+        if !args.is_empty() {
+            // Coerce the first argument to the type required by the aggregate's signature
+            let a = args[0].clone();
+            let dt = a.data_type(schema.as_ref()).unwrap();
+            let sig = f.signature();
+            let coerced = coerce_types(f, &[dt], &sig).unwrap();
+            args[0] = cast(a, schema, coerced[0].clone()).unwrap();
+        }
     }
 
     (window_fn.clone(), args, fn_name.to_string())
@@ -534,39 +537,6 @@ async fn run_window_test(
         exec1 = Arc::new(SortExec::new(sort_keys.clone(), exec1)) as _;
     }
 
-    // The schema needs to be enriched before the `create_window_expr`
-    // The reason for this is window expressions datatypes are derived from 
the schema
-    // The datafusion code enriches the schema on physical planner and this 
test copies the same behavior manually
-    // Also bunch of functions dont require input arguments thus just send an 
empty vec for such functions
-    let data_types = if [
-        "row_number",
-        "rank",
-        "dense_rank",
-        "percent_rank",
-        "ntile",
-        "cume_dist",
-    ]
-    .contains(&fn_name.as_str())
-    {
-        vec![]
-    } else {
-        args.iter()
-            .map(|e| e.clone().as_ref().data_type(&schema))
-            .collect::<Result<Vec<_>>>()?
-    };
-    let window_expr_return_type = window_fn.return_type(&data_types)?;
-    let mut window_fields = schema
-        .fields()
-        .iter()
-        .map(|f| f.as_ref().clone())
-        .collect_vec();
-    window_fields.extend_from_slice(&[Field::new(
-        &fn_name,
-        window_expr_return_type,
-        true,
-    )]);
-    let extended_schema = Arc::new(Schema::new(window_fields));
-
     let usual_window_exec = Arc::new(
         WindowAggExec::try_new(
             vec![create_window_expr(
@@ -576,7 +546,7 @@ async fn run_window_test(
                 &partitionby_exprs,
                 &orderby_exprs,
                 Arc::new(window_frame.clone()),
-                &extended_schema,
+                schema.as_ref(),
             )
             .unwrap()],
             exec1,
@@ -598,7 +568,7 @@ async fn run_window_test(
                 &partitionby_exprs,
                 &orderby_exprs,
                 Arc::new(window_frame.clone()),
-                extended_schema.as_ref(),
+                schema.as_ref(),
             )
             .unwrap()],
             exec2,
diff --git a/datafusion/physical-plan/src/windows/mod.rs 
b/datafusion/physical-plan/src/windows/mod.rs
index e55cc7fca7..01818405b8 100644
--- a/datafusion/physical-plan/src/windows/mod.rs
+++ b/datafusion/physical-plan/src/windows/mod.rs
@@ -160,7 +160,14 @@ fn create_built_in_window_expr(
     input_schema: &Schema,
     name: String,
 ) -> Result<Arc<dyn BuiltInWindowFunctionExpr>> {
-    let data_type = input_schema.field_with_name(&name)?.data_type();
+    // Resolve each argument's data type against the input schema into an owned Vec
+    let input_types: Vec<_> = args
+        .iter()
+        .map(|arg| arg.data_type(input_schema))
+        .collect::<Result<_>>()?;
+
+    // Derive the window function's output type from its argument types
+    let data_type = &fun.return_type(&input_types)?;
     Ok(match fun {
         BuiltInWindowFunction::RowNumber => Arc::new(RowNumber::new(name, 
data_type)),
         BuiltInWindowFunction::Rank => Arc::new(rank(name, data_type)),

Reply via email to