This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 99ef98931 Apply workaround for #5444 to DataFrame::describe (#5468)
99ef98931 is described below
commit 99ef989312292b04efdf06b178617e9008b46c84
Author: zhenxing jiang <[email protected]>
AuthorDate: Tue Mar 7 01:04:41 2023 +0800
Apply workaround for #5444 to DataFrame::describe (#5468)
---
datafusion/core/src/dataframe.rs | 54 ++++++++++++++++++++++++--------------
datafusion/core/tests/dataframe.rs | 30 ++++++++++-----------
2 files changed, 49 insertions(+), 35 deletions(-)
diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index 82f8deb3c..1fbf19c33 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -20,7 +20,7 @@
use std::any::Any;
use std::sync::Arc;
-use arrow::array::{ArrayRef, Int64Array, StringArray};
+use arrow::array::{Array, ArrayRef, Int64Array, StringArray};
use arrow::compute::{cast, concat};
use arrow::datatypes::{DataType, Field};
use async_trait::async_trait;
@@ -329,10 +329,10 @@ impl DataFrame {
let supported_describe_functions =
vec!["count", "null_count", "mean", "std", "min", "max", "median"];
- let fields_iter = self.schema().fields().iter();
+ let original_schema_fields = self.schema().fields().iter();
//define describe column
- let mut describe_schemas = fields_iter
+ let mut describe_schemas = original_schema_fields
.clone()
.map(|field| {
if field.data_type().is_numeric() {
@@ -344,24 +344,38 @@ impl DataFrame {
.collect::<Vec<_>>();
describe_schemas.insert(0, Field::new("describe", DataType::Utf8, false));
+ //count aggregation
+ let cnt = self.clone().aggregate(
+ vec![],
+ original_schema_fields
+ .clone()
+ .map(|f| count(col(f.name())))
+ .collect::<Vec<_>>(),
+ )?;
+ // The optimization of AggregateStatistics will rewrite the physical plan
+ // for the count function and ignore alias functions,
+ // as shown in https://github.com/apache/arrow-datafusion/issues/5444.
+ // This logic should be removed when #5444 is fixed.
+ let cnt = cnt.clone().select(
+ cnt.schema()
+ .fields()
+ .iter()
+ .zip(original_schema_fields.clone())
+ .map(|(count_field, orgin_field)| {
+ col(count_field.name()).alias(orgin_field.name())
+ })
+ .collect::<Vec<_>>(),
+ )?;
+ //should be removed when #5444 is fixed
//collect recordBatch
let describe_record_batch = vec![
// count aggregation
- self.clone()
- .aggregate(
- vec![],
- fields_iter
- .clone()
- .map(|f| count(col(f.name())).alias(f.name()))
- .collect::<Vec<_>>(),
- )?
- .collect()
- .await?,
+ cnt.collect().await?,
// null_count aggregation
self.clone()
.aggregate(
vec![],
- fields_iter
+ original_schema_fields
.clone()
.map(|f| count(is_null(col(f.name()))).alias(f.name()))
.collect::<Vec<_>>(),
@@ -372,7 +386,7 @@ impl DataFrame {
self.clone()
.aggregate(
vec![],
- fields_iter
+ original_schema_fields
.clone()
.filter(|f| f.data_type().is_numeric())
.map(|f| avg(col(f.name())).alias(f.name()))
@@ -384,7 +398,7 @@ impl DataFrame {
self.clone()
.aggregate(
vec![],
- fields_iter
+ original_schema_fields
.clone()
.filter(|f| f.data_type().is_numeric())
.map(|f| stddev(col(f.name())).alias(f.name()))
@@ -396,7 +410,7 @@ impl DataFrame {
self.clone()
.aggregate(
vec![],
- fields_iter
+ original_schema_fields
.clone()
.filter(|f| {
!matches!(f.data_type(), DataType::Binary | DataType::Boolean)
@@ -410,7 +424,7 @@ impl DataFrame {
self.clone()
.aggregate(
vec![],
- fields_iter
+ original_schema_fields
.clone()
.filter(|f| {
!matches!(f.data_type(), DataType::Binary | DataType::Boolean)
@@ -424,7 +438,7 @@ impl DataFrame {
self.clone()
.aggregate(
vec![],
- fields_iter
+ original_schema_fields
.clone()
.filter(|f| f.data_type().is_numeric())
.map(|f| median(col(f.name())).alias(f.name()))
@@ -435,7 +449,7 @@ impl DataFrame {
];
let mut array_ref_vec: Vec<ArrayRef> = vec![];
- for field in fields_iter {
+ for field in original_schema_fields {
let mut array_datas = vec![];
for record_batch in describe_record_batch.iter() {
let column =
record_batch.get(0).unwrap().column_by_name(field.name());
diff --git a/datafusion/core/tests/dataframe.rs b/datafusion/core/tests/dataframe.rs
index ede74b227..453b8f5cb 100644
--- a/datafusion/core/tests/dataframe.rs
+++ b/datafusion/core/tests/dataframe.rs
@@ -40,26 +40,26 @@ async fn describe() -> Result<()> {
let ctx = SessionContext::new();
let testdata = datafusion::test_util::parquet_test_data();
- let filename = &format!("{testdata}/alltypes_plain.parquet");
-
let df = ctx
- .read_parquet(filename, ParquetReadOptions::default())
+ .read_parquet(
+ &format!("{testdata}/alltypes_tiny_pages.parquet"),
+ ParquetReadOptions::default(),
+ )
.await?;
-
let describe_record_batch =
df.describe().await.unwrap().collect().await.unwrap();
#[rustfmt::skip]
let expected = vec![
-
"+------------+--------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+-------------------+-----------------+------------+---------------------+",
- "| describe | id | bool_col | tinyint_col |
smallint_col | int_col | bigint_col | float_col
| double_col | date_string_col | string_col | timestamp_col |",
-
"+------------+--------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+-------------------+-----------------+------------+---------------------+",
- "| count | 8.0 | 8 | 8.0 |
8.0 | 8.0 | 8.0 | 8.0
| 8.0 | 8 | 8 | 8 |",
- "| null_count | 8.0 | 8 | 8.0 |
8.0 | 8.0 | 8.0 | 8.0
| 8.0 | 8 | 8 | 8 |",
- "| mean | 3.5 | null | 0.5 |
0.5 | 0.5 | 5.0 |
0.550000011920929 | 5.05 | null | null | null
|",
- "| std | 2.4494897427831783 | null | 0.5345224838248488 |
0.5345224838248488 | 0.5345224838248488 | 5.3452248382484875 |
0.5879747449513427 | 5.398677086630973 | null | null | null
|",
- "| min | 0.0 | null | 0.0 |
0.0 | 0.0 | 0.0 | 0.0
| 0.0 | null | null | 2009-01-01T00:00:00 |",
- "| max | 7.0 | null | 1.0 |
1.0 | 1.0 | 10.0 |
1.100000023841858 | 10.1 | null | null |
2009-04-01T00:01:00 |",
- "| median | 3.0 | null | 0.0 |
0.0 | 0.0 | 5.0 |
0.550000011920929 | 5.05 | null | null | null
|",
-
"+------------+--------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+-------------------+-----------------+------------+---------------------+",
+
"+------------+-------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+------------+-------------------------+--------------------+-------------------+",
+ "| describe | id | bool_col | tinyint_col |
smallint_col | int_col | bigint_col | float_col
| double_col | date_string_col | string_col | timestamp_col
| year | month |",
+
"+------------+-------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+------------+-------------------------+--------------------+-------------------+",
+ "| count | 7300.0 | 7300 | 7300.0 |
7300.0 | 7300.0 | 7300.0 | 7300.0
| 7300.0 | 7300 | 7300 | 7300
| 7300.0 | 7300.0 |",
+ "| null_count | 7300.0 | 7300 | 7300.0 |
7300.0 | 7300.0 | 7300.0 | 7300.0
| 7300.0 | 7300 | 7300 | 7300
| 7300.0 | 7300.0 |",
+ "| mean | 3649.5 | null | 4.5 |
4.5 | 4.5 | 45.0 |
4.949999964237213 | 45.45000000000001 | null | null | null
| 2009.5 | 6.526027397260274 |",
+ "| std | 2107.472815166704 | null | 2.8724780750809518 |
2.8724780750809518 | 2.8724780750809518 | 28.724780750809533 |
3.1597258182544645 | 29.012028558317645 | null | null | null
| 0.5000342500942125 | 3.44808750051728 |",
+ "| min | 0.0 | null | 0.0 |
0.0 | 0.0 | 0.0 | 0.0
| 0.0 | 01/01/09 | 0 | 2008-12-31T23:00:00
| 2009.0 | 1.0 |",
+ "| max | 7299.0 | null | 9.0 |
9.0 | 9.0 | 90.0 |
9.899999618530273 | 90.89999999999999 | 12/31/10 | 9 |
2010-12-31T04:09:13.860 | 2010.0 | 12.0 |",
+ "| median | 3649.0 | null | 4.0 |
4.0 | 4.0 | 45.0 |
4.949999809265137 | 45.45 | null | null | null
| 2009.0 | 7.0 |",
+
"+------------+-------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+------------+-------------------------+--------------------+-------------------+",
];
assert_batches_eq!(expected, &describe_record_batch);