This is an automated email from the ASF dual-hosted git repository.

kosiew pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 4925e6c67e fix: dataframe function count_all with alias (#17282)
4925e6c67e is described below

commit 4925e6c67e45da65c0156a3ad0b826c5e0c7f72c
Author: Loakesh Indiran <66478092+loak...@users.noreply.github.com>
AuthorDate: Sun Aug 24 19:49:04 2025 +0530

    fix: dataframe function count_all with alias (#17282)
    
    * fix: dataframe function count_all with alias
---
 datafusion/core/src/physical_planner.rs            | 35 ++++++++++++++++++----
 .../core/tests/dataframe/dataframe_functions.rs    | 25 ++++++++++++++++
 2 files changed, 54 insertions(+), 6 deletions(-)

diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs
index 1021abc9e4..0ce5621ac8 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -1782,21 +1782,24 @@ pub fn create_aggregate_expr_and_maybe_filter(
     physical_input_schema: &Schema,
     execution_props: &ExecutionProps,
 ) -> Result<AggregateExprWithOptionalArgs> {
-    // unpack (nested) aliased logical expressions, e.g. "sum(col) as total"
+    // Unpack (potentially nested) aliased logical expressions, e.g. "sum(col) as total"
+    // Some functions like `count_all()` create internal aliases,
+    // Unwrap all alias layers to get to the underlying aggregate function
     let (name, human_display, e) = match e {
-        Expr::Alias(Alias { expr, name, .. }) => {
-            (Some(name.clone()), String::default(), expr.as_ref())
+        Expr::Alias(Alias { name, .. }) => {
+            let unaliased = e.clone().unalias_nested().data;
+            (Some(name.clone()), e.human_display().to_string(), unaliased)
         }
         Expr::AggregateFunction(_) => (
             Some(e.schema_name().to_string()),
             e.human_display().to_string(),
-            e,
+            e.clone(),
         ),
-        _ => (None, String::default(), e),
+        _ => (None, String::default(), e.clone()),
     };
 
     create_aggregate_expr_with_name_and_maybe_filter(
-        e,
+        &e,
         name,
         human_display,
         logical_input_schema,
@@ -2416,6 +2419,7 @@ mod tests {
     use datafusion_expr::{
         col, lit, LogicalPlanBuilder, Operator, UserDefinedLogicalNodeCore,
     };
+    use datafusion_functions_aggregate::count::count_all;
     use datafusion_functions_aggregate::expr_fn::sum;
     use datafusion_physical_expr::expressions::{BinaryExpr, IsNotNullExpr};
     use datafusion_physical_expr::EquivalenceProperties;
@@ -2876,6 +2880,25 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn test_aggregate_count_all_with_alias() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("c1", DataType::Utf8, false),
+            Field::new("c2", DataType::UInt32, false),
+        ]));
+
+        let logical_plan = scan_empty(None, schema.as_ref(), None)?
+            .aggregate(Vec::<Expr>::new(), vec![count_all().alias("total_rows")])?
+            .build()?;
+
+        let physical_plan = plan(&logical_plan).await?;
+        assert_eq!(
+            "total_rows",
+            physical_plan.schema().field(0).name().as_str()
+        );
+        Ok(())
+    }
+
     #[tokio::test]
     async fn test_explain() {
         let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs
index be49b88a99..b664fccdfa 100644
--- a/datafusion/core/tests/dataframe/dataframe_functions.rs
+++ b/datafusion/core/tests/dataframe/dataframe_functions.rs
@@ -1316,3 +1316,28 @@ async fn test_count_wildcard() -> Result<()> {
 
     Ok(())
 }
+
+/// Call count wildcard with alias from dataframe API
+#[tokio::test]
+async fn test_count_wildcard_with_alias() -> Result<()> {
+    let df = create_test_table().await?;
+    let result_df = df.aggregate(vec![], vec![count_all().alias("total_count")])?;
+
+    let schema = result_df.schema();
+    assert_eq!(schema.fields().len(), 1);
+    assert_eq!(schema.field(0).name(), "total_count");
+    assert_eq!(*schema.field(0).data_type(), DataType::Int64);
+
+    let batches = result_df.collect().await?;
+    assert_eq!(batches.len(), 1);
+    assert_eq!(batches[0].num_rows(), 1);
+
+    let count_array = batches[0]
+        .column(0)
+        .as_any()
+        .downcast_ref::<arrow::array::Int64Array>()
+        .unwrap();
+    assert_eq!(count_array.value(0), 4);
+
+    Ok(())
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to