metesynnada commented on code in PR #4553:
URL: https://github.com/apache/arrow-datafusion/pull/4553#discussion_r1044165304


##########
datafusion/expr/src/window_function.rs:
##########
@@ -79,6 +74,7 @@ impl fmt::Display for WindowFunction {
         match self {
             WindowFunction::AggregateFunction(fun) => fun.fmt(f),
             WindowFunction::BuiltInWindowFunction(fun) => fun.fmt(f),
+            WindowFunction::AggregateUDF(fun) => std::fmt::Debug::fmt(fun, f),

Review Comment:
   LGTM.



##########
datafusion/expr/src/window_function.rs:
##########
@@ -35,24 +36,18 @@ pub enum WindowFunction {
     AggregateFunction(AggregateFunction),
     /// window function that leverages a built-in window function
     BuiltInWindowFunction(BuiltInWindowFunction),
+    AggregateUDF(Arc<AggregateUDF>),
 }
 
-impl FromStr for WindowFunction {
-    type Err = DataFusionError;
-    fn from_str(name: &str) -> Result<WindowFunction> {
-        let name = name.to_lowercase();
-        if let Ok(aggregate) = AggregateFunction::from_str(name.as_str()) {
-            Ok(WindowFunction::AggregateFunction(aggregate))
-        } else if let Ok(built_in_function) =
-            BuiltInWindowFunction::from_str(name.as_str())
-        {
-            Ok(WindowFunction::BuiltInWindowFunction(built_in_function))
-        } else {
-            Err(DataFusionError::Plan(format!(
-                "There is no window function named {}",
-                name
-            )))
-        }
+/// Find DataFusion's built-in window function by name.
+pub fn find_df_window_func(name: &str) -> Option<WindowFunction> {

Review Comment:
   This makes sense, since you do not have access to the `SessionState` here. 
I assume you have already considered that option.



##########
datafusion/sql/src/planner.rs:
##########
@@ -5273,6 +5288,27 @@ mod tests {
         quick_test(sql, expected);
     }
 
+    #[test]
+    fn udaf_as_window_func() -> Result<()> {
+        let my_max = create_udaf(
+            "my_max",
+            DataType::Int32,
+            Arc::new(DataType::Int32),
+            Volatility::Immutable,
+            Arc::new(|_| 
Ok(Box::new(MaxAccumulator::try_new(&DataType::Int32)?))),
+            Arc::new(vec![DataType::Int32]),
+        );
+
+        let mut context = MockContextProvider::default();

Review Comment:
   I think you could also add a test that uses the usual `SessionContext` 
and registers the function with `ctx.register_udaf(my_max)`.



##########
datafusion/sql/src/planner.rs:
##########
@@ -2408,6 +2406,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         }
     }
 
+    fn find_window_func(&self, name: &str) -> Result<WindowFunction> {
+        window_function::find_df_window_func(name)

Review Comment:
   I think this is OK: if the name is not found among the built-in window functions, fall back to searching the UDAFs.



##########
datafusion/expr/src/window_function.rs:
##########
@@ -153,6 +149,9 @@ pub fn return_type(
         WindowFunction::BuiltInWindowFunction(fun) => {
             return_type_for_built_in(fun, input_expr_types)
         }
+        WindowFunction::AggregateUDF(fun) => {
+            Ok((*(fun.return_type)(input_expr_types)?).clone())

Review Comment:
   LGTM.



##########
datafusion/core/src/physical_plan/windows/mod.rs:
##########
@@ -180,6 +188,81 @@ mod tests {
         Ok((csv, schema))
     }
 
+    #[tokio::test]
+    async fn window_function_with_udaf() -> Result<()> {
+        #[derive(Debug)]
+        struct MyCount(i64);
+
+        impl Accumulator for MyCount {
+            fn state(&self) -> Result<Vec<AggregateState>> {
+                Ok(vec![AggregateState::Scalar(ScalarValue::Int64(Some(
+                    self.0,
+                )))])
+            }
+
+            fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+                let array = &values[0];
+                self.0 += (array.len() - array.data().null_count()) as i64;
+                Ok(())
+            }
+
+            fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+                let counts = downcast_value!(states[0], Int64Array);
+                if let Some(c) = &arrow::compute::sum(counts) {
+                    self.0 += *c;
+                }
+                Ok(())
+            }
+
+            fn evaluate(&self) -> Result<ScalarValue> {
+                Ok(ScalarValue::Int64(Some(self.0)))
+            }
+
+            fn size(&self) -> usize {
+                std::mem::size_of_val(self)
+            }
+        }
+
+        let my_count = create_udaf(
+            "my_count",
+            DataType::Int64,
+            Arc::new(DataType::Int64),
+            Volatility::Immutable,
+            Arc::new(|_| Ok(Box::new(MyCount(0)))),
+            Arc::new(vec![DataType::Int64]),
+        );
+
+        let session_ctx = SessionContext::new();
+        let task_ctx = session_ctx.task_ctx();
+        let (input, schema) = create_test_schema(1)?;
+
+        let window_exec = Arc::new(WindowAggExec::try_new(

Review Comment:
   Cool test.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to