This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 01a51ac  Optimize min/max queries with table statistics (#719)
01a51ac is described below

commit 01a51aceab09d96931a151d17257e54a57c3e44f
Author: baishen <[email protected]>
AuthorDate: Fri Aug 6 09:35:03 2021 -0500

    Optimize min/max queries with table statistics (#719)
    
    * support statistics max min
    
    * fix test
    
    * add test
    
    * make pub(crate)
    
    * update arrow version to 5.1
    
    * fix clippy
---
 datafusion/Cargo.toml                              |   4 +-
 datafusion/src/optimizer/aggregate_statistics.rs   | 166 +++++++++++++-
 .../src/physical_plan/expressions/min_max.rs       |   6 +-
 datafusion/src/physical_plan/expressions/mod.rs    |   1 +
 datafusion/src/physical_plan/parquet.rs            | 252 +++++++++++++++++++--
 datafusion/src/scalar.rs                           |   2 +
 6 files changed, 400 insertions(+), 31 deletions(-)

diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml
index 2f1e997..bfb3a93 100644
--- a/datafusion/Cargo.toml
+++ b/datafusion/Cargo.toml
@@ -46,8 +46,8 @@ unicode_expressions = ["unicode-segmentation"]
 [dependencies]
 ahash = "0.7"
 hashbrown = "0.11"
-arrow = { version = "5.0", features = ["prettyprint"] }
-parquet = { version = "5.0", features = ["arrow"] }
+arrow = { version = "5.1", features = ["prettyprint"] }
+parquet = { version = "5.1", features = ["arrow"] }
 sqlparser = "0.9.0"
 paste = "^1.0"
 num_cpus = "1.13.0"
diff --git a/datafusion/src/optimizer/aggregate_statistics.rs b/datafusion/src/optimizer/aggregate_statistics.rs
index a20eafc..e2d9054 100644
--- a/datafusion/src/optimizer/aggregate_statistics.rs
+++ b/datafusion/src/optimizer/aggregate_statistics.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 //! Utilizing exact statistics from sources to avoid scanning data
+use std::collections::HashMap;
 use std::{sync::Arc, vec};
 
 use crate::{
@@ -55,12 +56,40 @@ impl OptimizerRule for AggregateStatistics {
                 // aggregations that can not be replaced
                 // using statistics
                 let mut agg = vec![];
+                let mut max_values = HashMap::new();
+                let mut min_values = HashMap::new();
+
                 // expressions that can be replaced by constants
                 let mut projections = vec![];
                 if let Some(num_rows) = match input.as_ref() {
-                    LogicalPlan::TableScan { source, .. }
-                        if source.has_exact_statistics() =>
-                    {
+                    LogicalPlan::TableScan {
+                        table_name, source, ..
+                    } if source.has_exact_statistics() => {
+                        let schema = source.schema();
+                        let fields = schema.fields();
+                        if let Some(column_statistics) =
+                            source.statistics().column_statistics
+                        {
+                            if fields.len() == column_statistics.len() {
+                                for (i, field) in fields.iter().enumerate() {
+                                    if let Some(max_value) =
+                                        column_statistics[i].max_value.clone()
+                                    {
+                                        let max_key =
+                                            format!("{}.{}", table_name, 
field.name());
+                                        max_values.insert(max_key, max_value);
+                                    }
+                                    if let Some(min_value) =
+                                        column_statistics[i].min_value.clone()
+                                    {
+                                        let min_key =
+                                            format!("{}.{}", table_name, 
field.name());
+                                        min_values.insert(min_key, min_value);
+                                    }
+                                }
+                            }
+                        }
+
                         source.statistics().num_rows
                     }
                     _ => None,
@@ -81,6 +110,60 @@ impl OptimizerRule for AggregateStatistics {
                                     "COUNT(Uint8(1))".to_string(),
                                 ));
                             }
+                            Expr::AggregateFunction {
+                                fun: AggregateFunction::Max,
+                                args,
+                                ..
+                            } => match &args[0] {
+                                Expr::Column(c) => match 
max_values.get(&c.flat_name()) {
+                                    Some(max_value) => {
+                                        if !max_value.is_null() {
+                                            let name = format!("MAX({})", 
c.name);
+                                            projections.push(Expr::Alias(
+                                                Box::new(Expr::Literal(
+                                                    max_value.clone(),
+                                                )),
+                                                name,
+                                            ));
+                                        } else {
+                                            agg.push(expr.clone());
+                                        }
+                                    }
+                                    None => {
+                                        agg.push(expr.clone());
+                                    }
+                                },
+                                _ => {
+                                    agg.push(expr.clone());
+                                }
+                            },
+                            Expr::AggregateFunction {
+                                fun: AggregateFunction::Min,
+                                args,
+                                ..
+                            } => match &args[0] {
+                                Expr::Column(c) => match 
min_values.get(&c.flat_name()) {
+                                    Some(min_value) => {
+                                        if !min_value.is_null() {
+                                            let name = format!("MIN({})", 
c.name);
+                                            projections.push(Expr::Alias(
+                                                Box::new(Expr::Literal(
+                                                    min_value.clone(),
+                                                )),
+                                                name,
+                                            ));
+                                        } else {
+                                            agg.push(expr.clone());
+                                        }
+                                    }
+                                    None => {
+                                        agg.push(expr.clone());
+                                    }
+                                },
+                                _ => {
+                                    agg.push(expr.clone());
+                                }
+                            },
                             _ => {
                                 agg.push(expr.clone());
                             }
@@ -159,13 +242,18 @@ mod tests {
     use crate::logical_plan::LogicalPlan;
     use crate::optimizer::aggregate_statistics::AggregateStatistics;
     use crate::optimizer::optimizer::OptimizerRule;
+    use crate::scalar::ScalarValue;
     use crate::{
-        datasource::{datasource::Statistics, TableProvider},
+        datasource::{
+            datasource::{ColumnStatistics, Statistics},
+            TableProvider,
+        },
         logical_plan::Expr,
     };
 
     struct TestTableProvider {
         num_rows: usize,
+        column_statistics: Vec<ColumnStatistics>,
         is_exact: bool,
     }
 
@@ -186,11 +274,11 @@ mod tests {
         ) -> Result<std::sync::Arc<dyn crate::physical_plan::ExecutionPlan>> {
             unimplemented!()
         }
-        fn statistics(&self) -> crate::datasource::datasource::Statistics {
+        fn statistics(&self) -> Statistics {
             Statistics {
                 num_rows: Some(self.num_rows),
                 total_byte_size: None,
-                column_statistics: None,
+                column_statistics: Some(self.column_statistics.clone()),
             }
         }
         fn has_exact_statistics(&self) -> bool {
@@ -206,6 +294,7 @@ mod tests {
             "test",
             Arc::new(TestTableProvider {
                 num_rows: 100,
+                column_statistics: Vec::new(),
                 is_exact: true,
             }),
         )
@@ -231,6 +320,7 @@ mod tests {
             "test",
             Arc::new(TestTableProvider {
                 num_rows: 100,
+                column_statistics: Vec::new(),
                 is_exact: false,
             }),
         )
@@ -256,6 +346,7 @@ mod tests {
             "test",
             Arc::new(TestTableProvider {
                 num_rows: 100,
+                column_statistics: Vec::new(),
                 is_exact: true,
             }),
         )
@@ -282,6 +373,7 @@ mod tests {
             "test",
             Arc::new(TestTableProvider {
                 num_rows: 100,
+                column_statistics: Vec::new(),
                 is_exact: true,
             }),
         )
@@ -307,6 +399,7 @@ mod tests {
             "test",
             Arc::new(TestTableProvider {
                 num_rows: 100,
+                column_statistics: Vec::new(),
                 is_exact: true,
             }),
         )
@@ -325,6 +418,67 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn optimize_max_min_using_statistics() -> Result<()> {
+        use crate::execution::context::ExecutionContext;
+        let mut ctx = ExecutionContext::new();
+
+        let column_statistic = ColumnStatistics {
+            null_count: None,
+            max_value: Some(ScalarValue::from(100_i64)),
+            min_value: Some(ScalarValue::from(1_i64)),
+            distinct_count: None,
+        };
+        let column_statistics = vec![column_statistic];
+
+        ctx.register_table(
+            "test",
+            Arc::new(TestTableProvider {
+                num_rows: 100,
+                column_statistics,
+                is_exact: true,
+            }),
+        )
+        .unwrap();
+
+        let plan = ctx
+            .create_logical_plan("select max(a), min(a) from test")
+            .unwrap();
+        let expected = "\
+            Projection: #MAX(test.a), #MIN(test.a)\
+            \n  Projection: Int64(100) AS MAX(a), Int64(1) AS MIN(a)\
+            \n    EmptyRelation";
+
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
+    #[test]
+    fn optimize_max_min_not_using_statistics() -> Result<()> {
+        use crate::execution::context::ExecutionContext;
+        let mut ctx = ExecutionContext::new();
+        ctx.register_table(
+            "test",
+            Arc::new(TestTableProvider {
+                num_rows: 100,
+                column_statistics: Vec::new(),
+                is_exact: true,
+            }),
+        )
+        .unwrap();
+
+        let plan = ctx
+            .create_logical_plan("select max(a), min(a) from test")
+            .unwrap();
+        let expected = "\
+            Projection: #MAX(test.a), #MIN(test.a)\
+            \n  Aggregate: groupBy=[[]], aggr=[[MAX(#test.a), MIN(#test.a)]]\
+            \n    TableScan: test projection=None";
+
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
     fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
         let opt = AggregateStatistics::new();
         let optimized_plan = opt.optimize(plan, 
&ExecutionProps::new()).unwrap();
diff --git a/datafusion/src/physical_plan/expressions/min_max.rs b/datafusion/src/physical_plan/expressions/min_max.rs
index 6bb4c5b..21cf95d 100644
--- a/datafusion/src/physical_plan/expressions/min_max.rs
+++ b/datafusion/src/physical_plan/expressions/min_max.rs
@@ -314,8 +314,9 @@ fn max(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> {
     min_max!(lhs, rhs, max)
 }
 
+/// An accumulator to compute the maximum value
 #[derive(Debug)]
-struct MaxAccumulator {
+pub(crate) struct MaxAccumulator {
     max: ScalarValue,
 }
 
@@ -419,8 +420,9 @@ impl AggregateExpr for Min {
     }
 }
 
+/// An accumulator to compute the minimum value
 #[derive(Debug)]
-struct MinAccumulator {
+pub(crate) struct MinAccumulator {
     min: ScalarValue,
 }
 
diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs
index bd3dab6..d60a871 100644
--- a/datafusion/src/physical_plan/expressions/mod.rs
+++ b/datafusion/src/physical_plan/expressions/mod.rs
@@ -62,6 +62,7 @@ pub use is_null::{is_null, IsNullExpr};
 pub use lead_lag::{lag, lead};
 pub use literal::{lit, Literal};
 pub use min_max::{Max, Min};
+pub(crate) use min_max::{MaxAccumulator, MinAccumulator};
 pub use negative::{negative, NegativeExpr};
 pub use not::{not, NotExpr};
 pub use nth_value::NthValue;
diff --git a/datafusion/src/physical_plan/parquet.rs b/datafusion/src/physical_plan/parquet.rs
index f606b53..ec5611f 100644
--- a/datafusion/src/physical_plan/parquet.rs
+++ b/datafusion/src/physical_plan/parquet.rs
@@ -36,7 +36,7 @@ use crate::{
 
 use arrow::{
     array::ArrayRef,
-    datatypes::{Schema, SchemaRef},
+    datatypes::{DataType, Schema, SchemaRef},
     error::{ArrowError, Result as ArrowResult},
     record_batch::RecordBatch,
 };
@@ -62,6 +62,8 @@ use async_trait::async_trait;
 use futures::stream::{Stream, StreamExt};
 
 use super::SQLMetric;
+use crate::physical_plan::expressions::{MaxAccumulator, MinAccumulator};
+use crate::physical_plan::Accumulator;
 
 /// Execution plan for scanning one or more Parquet partitions
 #[derive(Debug, Clone)]
@@ -173,8 +175,12 @@ impl ParquetExec {
         let filenames: Vec<String> = filenames.iter().map(|s| 
s.to_string()).collect();
         let chunks = split_files(&filenames, max_concurrency);
         let mut num_rows = 0;
+        let mut num_fields = 0;
+        let mut fields = Vec::new();
         let mut total_byte_size = 0;
         let mut null_counts = Vec::new();
+        let mut max_values: Vec<Option<MaxAccumulator>> = Vec::new();
+        let mut min_values: Vec<Option<MinAccumulator>> = Vec::new();
         let mut limit_exhausted = false;
         for chunk in chunks {
             let mut filenames: Vec<String> =
@@ -188,11 +194,23 @@ impl ParquetExec {
                 let meta_data = arrow_reader.get_metadata();
                 // collect all the unique schemas in this data set
                 let schema = arrow_reader.get_schema()?;
-                let num_fields = schema.fields().len();
                 if schemas.is_empty() || schema != schemas[0] {
+                    fields = schema.fields().to_vec();
+                    num_fields = schema.fields().len();
+                    null_counts = vec![0; num_fields];
+                    max_values = schema
+                        .fields()
+                        .iter()
+                        .map(|field| 
MaxAccumulator::try_new(field.data_type()).ok())
+                        .collect::<Vec<_>>();
+                    min_values = schema
+                        .fields()
+                        .iter()
+                        .map(|field| 
MinAccumulator::try_new(field.data_type()).ok())
+                        .collect::<Vec<_>>();
                     schemas.push(schema);
-                    null_counts = vec![0; num_fields]
                 }
+
                 for row_group_meta in meta_data.row_groups() {
                     num_rows += row_group_meta.num_rows();
                     total_byte_size += row_group_meta.total_byte_size();
@@ -207,20 +225,167 @@ impl ParquetExec {
                     for (i, cnt) in columns_null_counts.enumerate() {
                         null_counts[i] += cnt
                     }
+
+                    for (i, column) in 
row_group_meta.columns().iter().enumerate() {
+                        if let Some(stat) = column.statistics() {
+                            match stat {
+                                ParquetStatistics::Boolean(s) => {
+                                    if let DataType::Boolean = 
fields[i].data_type() {
+                                        if s.has_min_max_set() {
+                                            if let Some(max_value) = &mut 
max_values[i] {
+                                                match max_value.update(&[
+                                                    
ScalarValue::Boolean(Some(*s.max())),
+                                                ]) {
+                                                    Ok(_) => {}
+                                                    Err(_) => {
+                                                        max_values[i] = None;
+                                                    }
+                                                }
+                                            }
+                                            if let Some(min_value) = &mut 
min_values[i] {
+                                                match min_value.update(&[
+                                                    
ScalarValue::Boolean(Some(*s.min())),
+                                                ]) {
+                                                    Ok(_) => {}
+                                                    Err(_) => {
+                                                        min_values[i] = None;
+                                                    }
+                                                }
+                                            }
+                                        }
+                                    }
+                                }
+                                ParquetStatistics::Int32(s) => {
+                                    if let DataType::Int32 = 
fields[i].data_type() {
+                                        if s.has_min_max_set() {
+                                            if let Some(max_value) = &mut 
max_values[i] {
+                                                match max_value.update(&[
+                                                    
ScalarValue::Int32(Some(*s.max())),
+                                                ]) {
+                                                    Ok(_) => {}
+                                                    Err(_) => {
+                                                        max_values[i] = None;
+                                                    }
+                                                }
+                                            }
+                                            if let Some(min_value) = &mut 
min_values[i] {
+                                                match min_value.update(&[
+                                                    
ScalarValue::Int32(Some(*s.min())),
+                                                ]) {
+                                                    Ok(_) => {}
+                                                    Err(_) => {
+                                                        min_values[i] = None;
+                                                    }
+                                                }
+                                            }
+                                        }
+                                    }
+                                }
+                                ParquetStatistics::Int64(s) => {
+                                    if let DataType::Int64 = 
fields[i].data_type() {
+                                        if s.has_min_max_set() {
+                                            if let Some(max_value) = &mut 
max_values[i] {
+                                                match max_value.update(&[
+                                                    
ScalarValue::Int64(Some(*s.max())),
+                                                ]) {
+                                                    Ok(_) => {}
+                                                    Err(_) => {
+                                                        max_values[i] = None;
+                                                    }
+                                                }
+                                            }
+                                            if let Some(min_value) = &mut 
min_values[i] {
+                                                match min_value.update(&[
+                                                    
ScalarValue::Int64(Some(*s.min())),
+                                                ]) {
+                                                    Ok(_) => {}
+                                                    Err(_) => {
+                                                        min_values[i] = None;
+                                                    }
+                                                }
+                                            }
+                                        }
+                                    }
+                                }
+                                ParquetStatistics::Float(s) => {
+                                    if let DataType::Float32 = 
fields[i].data_type() {
+                                        if s.has_min_max_set() {
+                                            if let Some(max_value) = &mut 
max_values[i] {
+                                                match max_value.update(&[
+                                                    
ScalarValue::Float32(Some(*s.max())),
+                                                ]) {
+                                                    Ok(_) => {}
+                                                    Err(_) => {
+                                                        max_values[i] = None;
+                                                    }
+                                                }
+                                            }
+                                            if let Some(min_value) = &mut 
min_values[i] {
+                                                match min_value.update(&[
+                                                    
ScalarValue::Float32(Some(*s.min())),
+                                                ]) {
+                                                    Ok(_) => {}
+                                                    Err(_) => {
+                                                        min_values[i] = None;
+                                                    }
+                                                }
+                                            }
+                                        }
+                                    }
+                                }
+                                ParquetStatistics::Double(s) => {
+                                    if let DataType::Float64 = 
fields[i].data_type() {
+                                        if s.has_min_max_set() {
+                                            if let Some(max_value) = &mut 
max_values[i] {
+                                                match max_value.update(&[
+                                                    
ScalarValue::Float64(Some(*s.max())),
+                                                ]) {
+                                                    Ok(_) => {}
+                                                    Err(_) => {
+                                                        max_values[i] = None;
+                                                    }
+                                                }
+                                            }
+                                            if let Some(min_value) = &mut 
min_values[i] {
+                                                match min_value.update(&[
+                                                    
ScalarValue::Float64(Some(*s.min())),
+                                                ]) {
+                                                    Ok(_) => {}
+                                                    Err(_) => {
+                                                        min_values[i] = None;
+                                                    }
+                                                }
+                                            }
+                                        }
+                                    }
+                                }
+                                _ => {}
+                            }
+                        }
+                    }
+
                     if limit.map(|x| num_rows >= x as i64).unwrap_or(false) {
                         limit_exhausted = true;
                         break;
                     }
                 }
             }
-
-            let column_stats = null_counts
-                .iter()
-                .map(|null_count| ColumnStatistics {
-                    null_count: Some(*null_count as usize),
-                    max_value: None,
-                    min_value: None,
-                    distinct_count: None,
+            let column_stats = (0..num_fields)
+                .map(|i| {
+                    let max_value = match &max_values[i] {
+                        Some(max_value) => max_value.evaluate().ok(),
+                        None => None,
+                    };
+                    let min_value = match &min_values[i] {
+                        Some(min_value) => min_value.evaluate().ok(),
+                        None => None,
+                    };
+                    ColumnStatistics {
+                        null_count: Some(null_counts[i] as usize),
+                        max_value,
+                        min_value,
+                        distinct_count: None,
+                    }
                 })
                 .collect();
 
@@ -301,7 +466,17 @@ impl ParquetExec {
         let mut num_rows: Option<usize> = None;
         let mut total_byte_size: Option<usize> = None;
         let mut null_counts: Vec<usize> = vec![0; schema.fields().len()];
-        let mut has_null_counts = false;
+        let mut has_statistics = false;
+        let mut max_values = schema
+            .fields()
+            .iter()
+            .map(|field| MaxAccumulator::try_new(field.data_type()).ok())
+            .collect::<Vec<_>>();
+        let mut min_values = schema
+            .fields()
+            .iter()
+            .map(|field| MinAccumulator::try_new(field.data_type()).ok())
+            .collect::<Vec<_>>();
         for part in &partitions {
             if let Some(n) = part.statistics.num_rows {
                 num_rows = Some(num_rows.unwrap_or(0) + n)
@@ -312,22 +487,57 @@ impl ParquetExec {
             if let Some(x) = &part.statistics.column_statistics {
                 let part_nulls: Vec<Option<usize>> =
                     x.iter().map(|c| c.null_count).collect();
-                has_null_counts = true;
+                has_statistics = true;
+
+                let part_max_values: Vec<Option<ScalarValue>> =
+                    x.iter().map(|c| c.max_value.clone()).collect();
+                let part_min_values: Vec<Option<ScalarValue>> =
+                    x.iter().map(|c| c.min_value.clone()).collect();
 
                 for &i in projection.iter() {
                     null_counts[i] = part_nulls[i].unwrap_or(0);
+                    if let Some(part_max_value) = part_max_values[i].clone() {
+                        if let Some(max_value) = &mut max_values[i] {
+                            match max_value.update(&[part_max_value]) {
+                                Ok(_) => {}
+                                Err(_) => {
+                                    max_values[i] = None;
+                                }
+                            }
+                        }
+                    }
+                    if let Some(part_min_value) = part_min_values[i].clone() {
+                        if let Some(min_value) = &mut min_values[i] {
+                            match min_value.update(&[part_min_value]) {
+                                Ok(_) => {}
+                                Err(_) => {
+                                    min_values[i] = None;
+                                }
+                            }
+                        }
+                    }
                 }
             }
         }
-        let column_stats = if has_null_counts {
+
+        let column_stats = if has_statistics {
             Some(
-                null_counts
-                    .iter()
-                    .map(|null_count| ColumnStatistics {
-                        null_count: Some(*null_count),
-                        distinct_count: None,
-                        max_value: None,
-                        min_value: None,
+                (0..schema.fields().len())
+                    .map(|i| {
+                        let max_value = match &max_values[i] {
+                            Some(max_value) => max_value.evaluate().ok(),
+                            None => None,
+                        };
+                        let min_value = match &min_values[i] {
+                            Some(min_value) => min_value.evaluate().ok(),
+                            None => None,
+                        };
+                        ColumnStatistics {
+                            null_count: Some(null_counts[i] as usize),
+                            max_value,
+                            min_value,
+                            distinct_count: None,
+                        }
                     })
                     .collect(),
             )
diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs
index 90c9bf7..3896055 100644
--- a/datafusion/src/scalar.rs
+++ b/datafusion/src/scalar.rs
@@ -402,6 +402,8 @@ impl ScalarValue {
                 | ScalarValue::Int64(None)
                 | ScalarValue::Float32(None)
                 | ScalarValue::Float64(None)
+                | ScalarValue::Date32(None)
+                | ScalarValue::Date64(None)
                 | ScalarValue::Utf8(None)
                 | ScalarValue::LargeUtf8(None)
                 | ScalarValue::List(None, _)

Reply via email to