This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 9464bf2eb Revert describe count() workaround (#5556)
9464bf2eb is described below
commit 9464bf2eb593ce239acf3823c5ecdc6760b15679
Author: Jeffrey <[email protected]>
AuthorDate: Tue Mar 14 02:41:52 2023 +1100
Revert describe count() workaround (#5556)
---
datafusion/core/src/dataframe.rs | 67 +++++++++++++-------------------------
datafusion/core/tests/dataframe.rs | 8 +++--
2 files changed, 29 insertions(+), 46 deletions(-)
diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index 7cdbaa43c..eab56432e 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -332,45 +332,28 @@ impl DataFrame {
let original_schema_fields = self.schema().fields().iter();
//define describe column
- let mut describe_schemas = original_schema_fields
- .clone()
- .map(|field| {
- if field.data_type().is_numeric() {
- Field::new(field.name(), DataType::Float64, true)
- } else {
- Field::new(field.name(), DataType::Utf8, true)
- }
- })
- .collect::<Vec<_>>();
- describe_schemas.insert(0, Field::new("describe", DataType::Utf8, false));
+ let mut describe_schemas = vec![Field::new("describe", DataType::Utf8, false)];
+ describe_schemas.extend(original_schema_fields.clone().map(|field| {
+ if field.data_type().is_numeric() {
+ Field::new(field.name(), DataType::Float64, true)
+ } else {
+ Field::new(field.name(), DataType::Utf8, true)
+ }
+ }));
- //count aggregation
- let cnt = self.clone().aggregate(
- vec![],
- original_schema_fields
- .clone()
- .map(|f| count(col(f.name())))
- .collect::<Vec<_>>(),
- )?;
- // The optimization of AggregateStatistics will rewrite the physical plan
- // for the count function and ignore alias functions,
- // as shown in https://github.com/apache/arrow-datafusion/issues/5444.
- // This logic should be removed when #5444 is fixed.
- let cnt = cnt.clone().select(
- cnt.schema()
- .fields()
- .iter()
- .zip(original_schema_fields.clone())
- .map(|(count_field, orgin_field)| {
- col(count_field.name()).alias(orgin_field.name())
- })
- .collect::<Vec<_>>(),
- )?;
- //should be removed when #5444 is fixed
//collect recordBatch
let describe_record_batch = vec![
// count aggregation
- cnt.collect().await?,
+ self.clone()
+ .aggregate(
+ vec![],
+ original_schema_fields
+ .clone()
+ .map(|f| count(col(f.name())).alias(f.name()))
+ .collect::<Vec<_>>(),
+ )?
+ .collect()
+ .await?,
// null_count aggregation
self.clone()
.aggregate(
@@ -448,10 +431,14 @@ impl DataFrame {
.await?,
];
- let mut array_ref_vec: Vec<ArrayRef> = vec![];
+ // first column with function names
+ let mut array_ref_vec: Vec<ArrayRef> = vec![Arc::new(StringArray::from_slice(
+ supported_describe_functions.clone(),
+ ))];
for field in original_schema_fields {
let mut array_datas = vec![];
for record_batch in describe_record_batch.iter() {
+ // safe unwrap since aggregate record batches should have at least 1 record
let column =
record_batch.get(0).unwrap().column_by_name(field.name());
match column {
Some(c) => {
@@ -477,14 +464,6 @@ impl DataFrame {
)?);
}
- //insert first column with function names
- array_ref_vec.insert(
- 0,
- Arc::new(StringArray::from_slice(
- supported_describe_functions.clone(),
- )),
- );
-
let describe_record_batch =
RecordBatch::try_new(Arc::new(Schema::new(describe_schemas)),
array_ref_vec)?;
diff --git a/datafusion/core/tests/dataframe.rs b/datafusion/core/tests/dataframe.rs
index b619262e0..19ceebe17 100644
--- a/datafusion/core/tests/dataframe.rs
+++ b/datafusion/core/tests/dataframe.rs
@@ -40,13 +40,17 @@ async fn describe() -> Result<()> {
let ctx = SessionContext::new();
let testdata = datafusion::test_util::parquet_test_data();
- let df = ctx
+ let describe_record_batch = ctx
.read_parquet(
&format!("{testdata}/alltypes_tiny_pages.parquet"),
ParquetReadOptions::default(),
)
+ .await?
+ .describe()
+ .await?
+ .collect()
.await?;
- let describe_record_batch = df.describe().await.unwrap().collect().await.unwrap();
+
#[rustfmt::skip]
let expected = vec![
"+------------+-------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+------------+-------------------------+--------------------+-------------------+",