This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 99ef98931 Apply workaround for #5444 to DataFrame::describe (#5468)
99ef98931 is described below

commit 99ef989312292b04efdf06b178617e9008b46c84
Author: zhenxing jiang <[email protected]>
AuthorDate: Tue Mar 7 01:04:41 2023 +0800

    Apply workaround for #5444 to DataFrame::describe (#5468)
---
 datafusion/core/src/dataframe.rs   | 54 ++++++++++++++++++++++++--------------
 datafusion/core/tests/dataframe.rs | 30 ++++++++++-----------
 2 files changed, 49 insertions(+), 35 deletions(-)

diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index 82f8deb3c..1fbf19c33 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -20,7 +20,7 @@
 use std::any::Any;
 use std::sync::Arc;
 
-use arrow::array::{ArrayRef, Int64Array, StringArray};
+use arrow::array::{Array, ArrayRef, Int64Array, StringArray};
 use arrow::compute::{cast, concat};
 use arrow::datatypes::{DataType, Field};
 use async_trait::async_trait;
@@ -329,10 +329,10 @@ impl DataFrame {
         let supported_describe_functions =
             vec!["count", "null_count", "mean", "std", "min", "max", "median"];
 
-        let fields_iter = self.schema().fields().iter();
+        let original_schema_fields = self.schema().fields().iter();
 
         //define describe column
-        let mut describe_schemas = fields_iter
+        let mut describe_schemas = original_schema_fields
             .clone()
             .map(|field| {
                 if field.data_type().is_numeric() {
@@ -344,24 +344,38 @@ impl DataFrame {
             .collect::<Vec<_>>();
         describe_schemas.insert(0, Field::new("describe", DataType::Utf8, 
false));
 
+        //count aggregation
+        let cnt = self.clone().aggregate(
+            vec![],
+            original_schema_fields
+                .clone()
+                .map(|f| count(col(f.name())))
+                .collect::<Vec<_>>(),
+        )?;
+        // The optimization of AggregateStatistics will rewrite the physical 
plan
+        // for the count function and ignore alias functions,
+        // as shown in https://github.com/apache/arrow-datafusion/issues/5444.
+        // This logic should be removed when #5444 is fixed.
+        let cnt = cnt.clone().select(
+            cnt.schema()
+                .fields()
+                .iter()
+                .zip(original_schema_fields.clone())
+                .map(|(count_field, orgin_field)| {
+                    col(count_field.name()).alias(orgin_field.name())
+                })
+                .collect::<Vec<_>>(),
+        )?;
+        //should be removed when #5444 is fixed
         //collect recordBatch
         let describe_record_batch = vec![
             // count aggregation
-            self.clone()
-                .aggregate(
-                    vec![],
-                    fields_iter
-                        .clone()
-                        .map(|f| count(col(f.name())).alias(f.name()))
-                        .collect::<Vec<_>>(),
-                )?
-                .collect()
-                .await?,
+            cnt.collect().await?,
             // null_count aggregation
             self.clone()
                 .aggregate(
                     vec![],
-                    fields_iter
+                    original_schema_fields
                         .clone()
                         .map(|f| count(is_null(col(f.name()))).alias(f.name()))
                         .collect::<Vec<_>>(),
@@ -372,7 +386,7 @@ impl DataFrame {
             self.clone()
                 .aggregate(
                     vec![],
-                    fields_iter
+                    original_schema_fields
                         .clone()
                         .filter(|f| f.data_type().is_numeric())
                         .map(|f| avg(col(f.name())).alias(f.name()))
@@ -384,7 +398,7 @@ impl DataFrame {
             self.clone()
                 .aggregate(
                     vec![],
-                    fields_iter
+                    original_schema_fields
                         .clone()
                         .filter(|f| f.data_type().is_numeric())
                         .map(|f| stddev(col(f.name())).alias(f.name()))
@@ -396,7 +410,7 @@ impl DataFrame {
             self.clone()
                 .aggregate(
                     vec![],
-                    fields_iter
+                    original_schema_fields
                         .clone()
                         .filter(|f| {
                             !matches!(f.data_type(), DataType::Binary | 
DataType::Boolean)
@@ -410,7 +424,7 @@ impl DataFrame {
             self.clone()
                 .aggregate(
                     vec![],
-                    fields_iter
+                    original_schema_fields
                         .clone()
                         .filter(|f| {
                             !matches!(f.data_type(), DataType::Binary | 
DataType::Boolean)
@@ -424,7 +438,7 @@ impl DataFrame {
             self.clone()
                 .aggregate(
                     vec![],
-                    fields_iter
+                    original_schema_fields
                         .clone()
                         .filter(|f| f.data_type().is_numeric())
                         .map(|f| median(col(f.name())).alias(f.name()))
@@ -435,7 +449,7 @@ impl DataFrame {
         ];
 
         let mut array_ref_vec: Vec<ArrayRef> = vec![];
-        for field in fields_iter {
+        for field in original_schema_fields {
             let mut array_datas = vec![];
             for record_batch in describe_record_batch.iter() {
                 let column = 
record_batch.get(0).unwrap().column_by_name(field.name());
diff --git a/datafusion/core/tests/dataframe.rs 
b/datafusion/core/tests/dataframe.rs
index ede74b227..453b8f5cb 100644
--- a/datafusion/core/tests/dataframe.rs
+++ b/datafusion/core/tests/dataframe.rs
@@ -40,26 +40,26 @@ async fn describe() -> Result<()> {
     let ctx = SessionContext::new();
     let testdata = datafusion::test_util::parquet_test_data();
 
-    let filename = &format!("{testdata}/alltypes_plain.parquet");
-
     let df = ctx
-        .read_parquet(filename, ParquetReadOptions::default())
+        .read_parquet(
+            &format!("{testdata}/alltypes_tiny_pages.parquet"),
+            ParquetReadOptions::default(),
+        )
         .await?;
-
     let describe_record_batch = 
df.describe().await.unwrap().collect().await.unwrap();
     #[rustfmt::skip]
         let expected = vec![
-        
"+------------+--------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+-------------------+-----------------+------------+---------------------+",
-        "| describe   | id                 | bool_col | tinyint_col        | 
smallint_col       | int_col            | bigint_col         | float_col        
  | double_col        | date_string_col | string_col | timestamp_col       |",
-        
"+------------+--------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+-------------------+-----------------+------------+---------------------+",
-        "| count      | 8.0                | 8        | 8.0                | 
8.0                | 8.0                | 8.0                | 8.0              
  | 8.0               | 8               | 8          | 8                   |",
-        "| null_count | 8.0                | 8        | 8.0                | 
8.0                | 8.0                | 8.0                | 8.0              
  | 8.0               | 8               | 8          | 8                   |",
-        "| mean       | 3.5                | null     | 0.5                | 
0.5                | 0.5                | 5.0                | 
0.550000011920929  | 5.05              | null            | null       | null    
            |",
-        "| std        | 2.4494897427831783 | null     | 0.5345224838248488 | 
0.5345224838248488 | 0.5345224838248488 | 5.3452248382484875 | 
0.5879747449513427 | 5.398677086630973 | null            | null       | null    
            |",
-        "| min        | 0.0                | null     | 0.0                | 
0.0                | 0.0                | 0.0                | 0.0              
  | 0.0               | null            | null       | 2009-01-01T00:00:00 |",
-        "| max        | 7.0                | null     | 1.0                | 
1.0                | 1.0                | 10.0               | 
1.100000023841858  | 10.1              | null            | null       | 
2009-04-01T00:01:00 |",
-        "| median     | 3.0                | null     | 0.0                | 
0.0                | 0.0                | 5.0                | 
0.550000011920929  | 5.05              | null            | null       | null    
            |",
-        
"+------------+--------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+-------------------+-----------------+------------+---------------------+",
+        
"+------------+-------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+------------+-------------------------+--------------------+-------------------+",
+        "| describe   | id                | bool_col | tinyint_col        | 
smallint_col       | int_col            | bigint_col         | float_col        
  | double_col         | date_string_col | string_col | timestamp_col           
| year               | month             |",
+        
"+------------+-------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+------------+-------------------------+--------------------+-------------------+",
+        "| count      | 7300.0            | 7300     | 7300.0             | 
7300.0             | 7300.0             | 7300.0             | 7300.0           
  | 7300.0             | 7300            | 7300       | 7300                    
| 7300.0             | 7300.0            |",
+        "| null_count | 7300.0            | 7300     | 7300.0             | 
7300.0             | 7300.0             | 7300.0             | 7300.0           
  | 7300.0             | 7300            | 7300       | 7300                    
| 7300.0             | 7300.0            |",
+        "| mean       | 3649.5            | null     | 4.5                | 
4.5                | 4.5                | 45.0               | 
4.949999964237213  | 45.45000000000001  | null            | null       | null   
                 | 2009.5             | 6.526027397260274 |",
+        "| std        | 2107.472815166704 | null     | 2.8724780750809518 | 
2.8724780750809518 | 2.8724780750809518 | 28.724780750809533 | 
3.1597258182544645 | 29.012028558317645 | null            | null       | null   
                 | 0.5000342500942125 | 3.44808750051728  |",
+        "| min        | 0.0               | null     | 0.0                | 
0.0                | 0.0                | 0.0                | 0.0              
  | 0.0                | 01/01/09        | 0          | 2008-12-31T23:00:00     
| 2009.0             | 1.0               |",
+        "| max        | 7299.0            | null     | 9.0                | 
9.0                | 9.0                | 90.0               | 
9.899999618530273  | 90.89999999999999  | 12/31/10        | 9          | 
2010-12-31T04:09:13.860 | 2010.0             | 12.0              |",
+        "| median     | 3649.0            | null     | 4.0                | 
4.0                | 4.0                | 45.0               | 
4.949999809265137  | 45.45              | null            | null       | null   
                 | 2009.0             | 7.0               |",
+        
"+------------+-------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+------------+-------------------------+--------------------+-------------------+",
     ];
     assert_batches_eq!(expected, &describe_record_batch);
 

Reply via email to