This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new ef90bac1 feat: Support count AggregateUDF for window function (#736)
ef90bac1 is described below

commit ef90bac1964edfba4cdb7284c4e3507dce633517
Author: Huaxin Gao <[email protected]>
AuthorDate: Thu Aug 1 08:01:38 2024 -0700

    feat: Support count AggregateUDF for window function (#736)
    
    * feat: Support count AggregateUDF for window function
    
    * fix style
    
    * fix style
    
    * address comments
    
    * fix style
    
    * fix import order
    
    * remove unused import
    
    * look for AggregateUDF from function registry
    
    * formatting
    
    * style
---
 native/core/src/execution/datafusion/planner.rs         | 17 +++++++++++++++--
 .../scala/org/apache/comet/serde/QueryPlanSerde.scala   |  5 +----
 .../scala/org/apache/comet/exec/CometExecSuite.scala    |  2 +-
 3 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index c283cebb..6d6102ae 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -56,7 +56,7 @@ use datafusion_common::{
     JoinType as DFJoinType, ScalarValue,
 };
 use datafusion_expr::expr::find_df_window_func;
-use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits};
+use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition};
 use datafusion_physical_expr::window::WindowExpr;
 use datafusion_physical_expr_common::aggregate::create_aggregate_expr;
 use itertools::Itertools;
@@ -1483,7 +1483,7 @@ impl PhysicalPlanner {
             ));
         }
 
-        let window_func = match find_df_window_func(&window_func_name) {
+        let window_func = match self.find_df_window_function(&window_func_name) {
             Some(f) => f,
             _ => {
                 return Err(ExecutionError::GeneralError(format!(
@@ -1599,6 +1599,19 @@ impl PhysicalPlanner {
         }
     }
 
+    /// Find a window function by name, falling back to an aggregate UDF registered in the session state.
+    fn find_df_window_function(&self, name: &str) -> Option<WindowFunctionDefinition> {
+        if let Some(f) = find_df_window_func(name) {
+            Some(f)
+        } else {
+            let registry = &self.session_ctx.state();
+            registry
+                .udaf(name)
+                .map(WindowFunctionDefinition::AggregateUDF)
+                .ok()
+        }
+    }
+
     /// Create a DataFusion physical partitioning from Spark physical partitioning
     fn create_partitioning(
         &self,
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index e06405a1..e5acd245 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -208,10 +208,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
       expr match {
         case agg: AggregateExpression =>
           agg.aggregateFunction match {
-            // TODO add support for Count (this was removed when upgrading
-            // to DataFusion 40 because it is no longer a built-in window function)
-            // https://github.com/apache/datafusion-comet/issues/645
-            case _: Min | _: Max =>
+            case _: Min | _: Max | _: Count =>
               Some(agg)
             case _ =>
               withInfo(windowExpr, "Unsupported aggregate", expr)
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 40ec349e..5cbc4975 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -1487,7 +1487,7 @@ class CometExecSuite extends CometTestBase {
         SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled) {
         withParquetTable((0 until 10).map(i => (i, 10 - i)), "t1") { // TODO: test nulls
           val aggregateFunctions =
-            List("MAX(_1)", "MIN(_1)") // TODO: Test all the aggregates
+            List("COUNT(_1)", "COUNT(*)", "MAX(_1)", "MIN(_1)") // TODO: Test all the aggregates
 
           aggregateFunctions.foreach { function =>
             val queries = Seq(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to