This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new abc14eb8f fix: correct CountWildcardRule and move analyzer into a new directory. (#5671)
abc14eb8f is described below

commit abc14eb8f2e9afc3ea5c3ad0be79a5efa3e268ce
Author: jakevin <[email protected]>
AuthorDate: Wed Mar 22 03:50:21 2023 +0800

    fix: correct CountWildcardRule and move analyzer into a new directory. (#5671)
    
    * refactor: move analyzer to new dir and polish CountWildcardRule.
    
    * fix typo
    
    * correct rule.
---
 datafusion/core/tests/dataframe.rs                 |  6 +-
 .../src/{ => analyzer}/count_wildcard_rule.rs      | 72 +++++++++++-----------
 .../optimizer/src/{analyzer.rs => analyzer/mod.rs} |  4 +-
 datafusion/optimizer/src/eliminate_filter.rs       |  2 +-
 datafusion/optimizer/src/lib.rs                    |  1 -
 5 files changed, 43 insertions(+), 42 deletions(-)

diff --git a/datafusion/core/tests/dataframe.rs b/datafusion/core/tests/dataframe.rs
index 9da5bf2b8..23c6623ab 100644
--- a/datafusion/core/tests/dataframe.rs
+++ b/datafusion/core/tests/dataframe.rs
@@ -51,16 +51,18 @@ async fn count_wildcard() -> Result<()> {
     let sql_results = ctx
         .sql("select count(*) from alltypes_tiny_pages")
         .await?
+        .select(vec![count(Expr::Wildcard)])?
         .explain(false, false)?
         .collect()
         .await?;
 
+    // add `.select(vec![count(Expr::Wildcard)])?` to make sure we can analyze all node instead of just top node.
     let df_results = ctx
         .table("alltypes_tiny_pages")
         .await?
         .aggregate(vec![], vec![count(Expr::Wildcard)])?
-        .explain(false, false)
-        .unwrap()
+        .select(vec![count(Expr::Wildcard)])?
+        .explain(false, false)?
         .collect()
         .await?;
 
diff --git a/datafusion/optimizer/src/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
similarity index 61%
rename from datafusion/optimizer/src/count_wildcard_rule.rs
rename to datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
index 416bd0337..4b4c603bc 100644
--- a/datafusion/optimizer/src/count_wildcard_rule.rs
+++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
@@ -15,15 +15,17 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::analyzer::AnalyzerRule;
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::Result;
 use datafusion_expr::expr::AggregateFunction;
 use datafusion_expr::utils::COUNT_STAR_EXPANSION;
 use datafusion_expr::{aggregate_function, lit, Aggregate, Expr, LogicalPlan, 
Window};
-use std::ops::Deref;
-use std::sync::Arc;
 
+use crate::analyzer::AnalyzerRule;
+use crate::rewrite::TreeNodeRewritable;
+
+/// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`.
+/// Resolve issue: https://github.com/apache/arrow-datafusion/issues/5473.
 pub struct CountWildcardRule {}
 
 impl Default for CountWildcardRule {
@@ -39,35 +41,7 @@ impl CountWildcardRule {
 }
 impl AnalyzerRule for CountWildcardRule {
     fn analyze(&self, plan: &LogicalPlan, _: &ConfigOptions) -> 
Result<LogicalPlan> {
-        let new_plan = match plan {
-            LogicalPlan::Window(window) => {
-                let inputs = plan.inputs();
-                let window_expr = window.clone().window_expr;
-                let window_expr = handle_wildcard(window_expr).unwrap();
-                LogicalPlan::Window(Window {
-                    input: Arc::new(inputs.get(0).unwrap().deref().clone()),
-                    window_expr,
-                    schema: plan.schema().clone(),
-                })
-            }
-
-            LogicalPlan::Aggregate(aggregate) => {
-                let inputs = plan.inputs();
-                let aggr_expr = aggregate.clone().aggr_expr;
-                let aggr_expr = handle_wildcard(aggr_expr).unwrap();
-                LogicalPlan::Aggregate(
-                    Aggregate::try_new_with_schema(
-                        Arc::new(inputs.get(0).unwrap().deref().clone()),
-                        aggregate.clone().group_expr,
-                        aggr_expr,
-                        plan.schema().clone(),
-                    )
-                    .unwrap(),
-                )
-            }
-            _ => plan.clone(),
-        };
-        Ok(new_plan)
+        plan.clone().transform_down(&analyze_internal)
     }
 
     fn name(&self) -> &str {
@@ -75,9 +49,34 @@ impl AnalyzerRule for CountWildcardRule {
     }
 }
 
-//handle Count(Expr:Wildcard) with DataFrame API
-pub fn handle_wildcard(exprs: Vec<Expr>) -> Result<Vec<Expr>> {
-    let exprs: Vec<Expr> = exprs
+fn analyze_internal(plan: LogicalPlan) -> Result<Option<LogicalPlan>> {
+    match plan {
+        LogicalPlan::Window(window) => {
+            let window_expr = handle_wildcard(&window.window_expr);
+            Ok(Some(LogicalPlan::Window(Window {
+                input: window.input.clone(),
+                window_expr,
+                schema: window.schema,
+            })))
+        }
+        LogicalPlan::Aggregate(agg) => {
+            let aggr_expr = handle_wildcard(&agg.aggr_expr);
+            Ok(Some(LogicalPlan::Aggregate(
+                Aggregate::try_new_with_schema(
+                    agg.input.clone(),
+                    agg.group_expr.clone(),
+                    aggr_expr,
+                    agg.schema,
+                )?,
+            )))
+        }
+        _ => Ok(None),
+    }
+}
+
+// handle Count(Expr:Wildcard) with DataFrame API
+pub fn handle_wildcard(exprs: &[Expr]) -> Vec<Expr> {
+    exprs
         .iter()
         .map(|expr| match expr {
             Expr::AggregateFunction(AggregateFunction {
@@ -96,6 +95,5 @@ pub fn handle_wildcard(exprs: Vec<Expr>) -> Result<Vec<Expr>> {
             },
             _ => expr.clone(),
         })
-        .collect();
-    Ok(exprs)
+        .collect()
 }
diff --git a/datafusion/optimizer/src/analyzer.rs b/datafusion/optimizer/src/analyzer/mod.rs
similarity index 98%
rename from datafusion/optimizer/src/analyzer.rs
rename to datafusion/optimizer/src/analyzer/mod.rs
index e999eb241..0982198bb 100644
--- a/datafusion/optimizer/src/analyzer.rs
+++ b/datafusion/optimizer/src/analyzer/mod.rs
@@ -15,7 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::count_wildcard_rule::CountWildcardRule;
+mod count_wildcard_rule;
+
+use crate::analyzer::count_wildcard_rule::CountWildcardRule;
 use crate::rewrite::TreeNodeRewritable;
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::{DataFusionError, Result};
diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs
index c5dc1711c..c97906a81 100644
--- a/datafusion/optimizer/src/eliminate_filter.rs
+++ b/datafusion/optimizer/src/eliminate_filter.rs
@@ -27,7 +27,7 @@ use datafusion_expr::{
 
 use crate::{OptimizerConfig, OptimizerRule};
 
-/// Optimization rule that elimanate the scalar value (true/false) filter with 
an [LogicalPlan::EmptyRelation]
+/// Optimization rule that eliminate the scalar value (true/false) filter with 
an [LogicalPlan::EmptyRelation]
 #[derive(Default)]
 pub struct EliminateFilter;
 
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index 2a97c96e7..4be7bb370 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -45,7 +45,6 @@ pub mod type_coercion;
 pub mod unwrap_cast_in_comparison;
 pub mod utils;
 
-pub mod count_wildcard_rule;
 #[cfg(test)]
 pub mod test;
 

Reply via email to