This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 96aa2a677 add a describe method on DataFrame like Polars (#5226)
96aa2a677 is described below

commit 96aa2a677f1f2b2bb8dd803841d71f2a08301961
Author: jiangzhx <[email protected]>
AuthorDate: Wed Mar 1 00:56:20 2023 +0800

    add a describe method on DataFrame like Polars (#5226)
    
    * add describe method like polars
    
    * clippy fix
    
    * commit suggestion
    
    * fix typos
---
 datafusion-examples/examples/dataframe.rs |   9 ++
 datafusion/core/src/dataframe.rs          | 158 +++++++++++++++++++++++++++++-
 datafusion/core/tests/dataframe.rs        |  30 +++++-
 3 files changed, 193 insertions(+), 4 deletions(-)

diff --git a/datafusion-examples/examples/dataframe.rs b/datafusion-examples/examples/dataframe.rs
index f52ff8925..027ff9970 100644
--- a/datafusion-examples/examples/dataframe.rs
+++ b/datafusion-examples/examples/dataframe.rs
@@ -49,6 +49,15 @@ async fn main() -> Result<()> {
     let csv_df = example_read_csv_file_with_schema().await;
     csv_df.show().await?;
 
+    // Read a Parquet file and print its summary statistics
+    let parquet_df = ctx
+        .read_parquet(
+            &format!("{testdata}/alltypes_plain.parquet"),
+            ParquetReadOptions::default(),
+        )
+        .await?;
+    parquet_df.describe().await?.show().await?;
+
     Ok(())
 }
 
diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index 1d5396219..5f6fa6b4b 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -20,19 +20,22 @@
 use std::any::Any;
 use std::sync::Arc;
 
-use arrow::array::Int64Array;
+use arrow::array::{ArrayRef, Int64Array, StringArray};
+use arrow::compute::{cast, concat};
+use arrow::datatypes::{DataType, Field};
 use async_trait::async_trait;
 use datafusion_common::DataFusionError;
 use parquet::file::properties::WriterProperties;
 
+use datafusion_common::from_slice::FromSlice;
 use datafusion_common::{Column, DFSchema, ScalarValue};
-use datafusion_expr::TableProviderFilterPushDown;
+use datafusion_expr::{TableProviderFilterPushDown, UNNAMED_TABLE};
 
 use crate::arrow::datatypes::Schema;
 use crate::arrow::datatypes::SchemaRef;
 use crate::arrow::record_batch::RecordBatch;
 use crate::arrow::util::pretty;
-use crate::datasource::{MemTable, TableProvider};
+use crate::datasource::{provider_as_source, MemTable, TableProvider};
 use crate::error::Result;
 use crate::execution::{
     context::{SessionState, TaskContext},
@@ -302,6 +305,155 @@ impl DataFrame {
         ))
     }
 
+    /// Summary statistics for a DataFrame: `count`, `null_count`, `max` and `min` of each
+    /// column. Numeric columns are reported as Float64 and all other columns as Utf8; `max`
+    /// and `min` show the string "null" for Binary and Boolean columns. Mirrors pandas' layout.
+    ///
+    /// ```
+    /// # use datafusion::prelude::*;
+    /// # use datafusion::error::Result;
+    /// # #[tokio::main]
+    /// # async fn main() -> Result<()> {
+    /// let ctx = SessionContext::new();
+    /// let df = ctx.read_csv("tests/tpch-csv/customer.csv", CsvReadOptions::new()).await?;
+    /// df.describe().await?.show().await?;
+    ///
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub async fn describe(self) -> Result<Self> {
+        // Statistic names: the row labels, in the same order as the aggregations below
+        let supported_describe_functions = ["count", "null_count", "max", "min"];
+
+        let fields_iter = self.schema().fields().iter();
+
+        // Output schema: a "describe" label column, then one column per input field
+        let mut describe_schemas = fields_iter
+            .clone()
+            .map(|field| {
+                if field.data_type().is_numeric() {
+                    Field::new(field.name(), DataType::Float64, true)
+                } else {
+                    Field::new(field.name(), DataType::Utf8, true)
+                }
+            })
+            .collect::<Vec<_>>();
+        describe_schemas.insert(0, Field::new("describe", DataType::Utf8, false));
+
+        // One ungrouped aggregation per statistic (read back via `get(0)` below)
+        let describe_record_batch = vec![
+            // count aggregation
+            self.clone()
+                .aggregate(
+                    vec![],
+                    fields_iter
+                        .clone()
+                        .map(|f| datafusion_expr::count(col(f.name())).alias(f.name()))
+                        .collect::<Vec<_>>(),
+                )?
+                .collect()
+                .await?,
+            // null_count aggregation: sum a 0/1 flag (count(is_null(x)) counts every row)
+            self.clone()
+                .aggregate(
+                    vec![],
+                    fields_iter
+                        .clone()
+                        .map(|f| {
+                            let null_flag = datafusion_expr::is_null(col(f.name()));
+                            datafusion_expr::when(null_flag, datafusion_expr::lit(1))
+                                .otherwise(datafusion_expr::lit(0))
+                                .map(|e| datafusion_expr::sum(e).alias(f.name()))
+                        })
+                        .collect::<Result<Vec<_>>>()?,
+                )?
+                .collect()
+                .await?,
+            // max aggregation
+            self.clone()
+                .aggregate(
+                    vec![],
+                    fields_iter
+                        .clone()
+                        .filter(|f| {
+                            !matches!(f.data_type(), DataType::Binary | DataType::Boolean)
+                        })
+                        .map(|f| datafusion_expr::max(col(f.name())).alias(f.name()))
+                        .collect::<Vec<_>>(),
+                )?
+                .collect()
+                .await?,
+            // min aggregation
+            self.clone()
+                .aggregate(
+                    vec![],
+                    fields_iter
+                        .clone()
+                        .filter(|f| {
+                            !matches!(f.data_type(), DataType::Binary | DataType::Boolean)
+                        })
+                        .map(|f| datafusion_expr::min(col(f.name())).alias(f.name()))
+                        .collect::<Vec<_>>(),
+                )?
+                .collect()
+                .await?,
+        ];
+
+        let mut array_ref_vec: Vec<ArrayRef> = vec![];
+        for field in fields_iter {
+            let mut array_datas = vec![];
+            for record_batch in describe_record_batch.iter() {
+                let column = record_batch.get(0).unwrap().column_by_name(field.name());
+                match column {
+                    Some(c) => {
+                        if field.data_type().is_numeric() {
+                            array_datas.push(cast(c, &DataType::Float64)?);
+                        } else {
+                            array_datas.push(cast(c, &DataType::Utf8)?);
+                        }
+                    }
+                    // None: the field was filtered out of max/min (Binary/Boolean)
+                    None => {
+                        array_datas.push(Arc::new(StringArray::from_slice(["null"])));
+                    }
+                }
+            }
+
+            array_ref_vec.push(concat(
+                array_datas
+                    .iter()
+                    .map(|af| af.as_ref())
+                    .collect::<Vec<_>>()
+                    .as_slice(),
+            )?);
+        }
+
+        // Prepend the label column holding the statistic names
+        array_ref_vec.insert(
+            0,
+            Arc::new(StringArray::from_slice(
+                supported_describe_functions,
+            )),
+        );
+
+        let describe_record_batch =
+            RecordBatch::try_new(Arc::new(Schema::new(describe_schemas)), array_ref_vec)?;
+
+        let provider = MemTable::try_new(
+            describe_record_batch.schema(),
+            vec![vec![describe_record_batch]],
+        )?;
+        Ok(DataFrame::new(
+            self.session_state,
+            LogicalPlanBuilder::scan(
+                UNNAMED_TABLE,
+                provider_as_source(Arc::new(provider)),
+                None,
+            )?
+            .build()?,
+        ))
+    }
+
     /// Sort the DataFrame by the specified sorting expressions. Any expression can be turned into
     /// a sort expression by calling its [sort](../logical_plan/enum.Expr.html#method.sort) method.
     ///
diff --git a/datafusion/core/tests/dataframe.rs b/datafusion/core/tests/dataframe.rs
index a9e28848c..6cb92ca51 100644
--- a/datafusion/core/tests/dataframe.rs
+++ b/datafusion/core/tests/dataframe.rs
@@ -29,12 +29,40 @@ use std::sync::Arc;
 use datafusion::dataframe::DataFrame;
 use datafusion::error::Result;
 use datafusion::execution::context::SessionContext;
-use datafusion::prelude::CsvReadOptions;
 use datafusion::prelude::JoinType;
+use datafusion::prelude::{CsvReadOptions, ParquetReadOptions};
 use datafusion::{assert_batches_eq, assert_batches_sorted_eq};
 use datafusion_expr::expr::{GroupingSet, Sort};
 use datafusion_expr::{avg, col, count, lit, sum, Expr, ExprSchemable};
 
+#[tokio::test]
+async fn describe() -> Result<()> {
+    let ctx = SessionContext::new();
+    let testdata = datafusion::test_util::parquet_test_data();
+
+    let filename = &format!("{testdata}/alltypes_plain.parquet");
+
+    let df = ctx
+        .read_parquet(filename, ParquetReadOptions::default())
+        .await?;
+
+    let describe_record_batch = df.describe().await?.collect().await?;
+    #[rustfmt::skip]
+    let expected = vec![
+        "+------------+-----+----------+-------------+--------------+---------+------------+-------------------+------------+-----------------+------------+---------------------+",
+        "| describe   | id  | bool_col | tinyint_col | smallint_col | int_col | bigint_col | float_col         | double_col | date_string_col | string_col | timestamp_col       |",
+        "+------------+-----+----------+-------------+--------------+---------+------------+-------------------+------------+-----------------+------------+---------------------+",
+        "| count      | 8.0 | 8        | 8.0         | 8.0          | 8.0     | 8.0        | 8.0               | 8.0        | 8               | 8          | 8                   |",
+        "| null_count | 0.0 | 0        | 0.0         | 0.0          | 0.0     | 0.0        | 0.0               | 0.0        | 0               | 0          | 0                   |",
+        "| max        | 7.0 | null     | 1.0         | 1.0          | 1.0     | 10.0       | 1.100000023841858 | 10.1       | null            | null       | 2009-04-01T00:01:00 |",
+        "| min        | 0.0 | null     | 0.0         | 0.0          | 0.0     | 0.0        | 0.0               | 0.0        | null            | null       | 2009-01-01T00:00:00 |",
+        "+------------+-----+----------+-------------+--------------+---------+------------+-------------------+------------+-----------------+------------+---------------------+",
+    ];
+    assert_batches_eq!(expected, &describe_record_batch);
+
+    Ok(())
+}
+
+
 #[tokio::test]
 async fn join() -> Result<()> {
     let schema1 = Arc::new(Schema::new(vec![

Reply via email to