This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new abc14eb8f fix: correct CountWildcardRule and move analyzer into a new
directory. (#5671)
abc14eb8f is described below
commit abc14eb8f2e9afc3ea5c3ad0be79a5efa3e268ce
Author: jakevin <[email protected]>
AuthorDate: Wed Mar 22 03:50:21 2023 +0800
fix: correct CountWildcardRule and move analyzer into a new directory.
(#5671)
* refactor: move analyzer to new dir and polish CountWildcardRule.
* fix typo
* correct rule.
---
datafusion/core/tests/dataframe.rs | 6 +-
.../src/{ => analyzer}/count_wildcard_rule.rs | 72 +++++++++++-----------
.../optimizer/src/{analyzer.rs => analyzer/mod.rs} | 4 +-
datafusion/optimizer/src/eliminate_filter.rs | 2 +-
datafusion/optimizer/src/lib.rs | 1 -
5 files changed, 43 insertions(+), 42 deletions(-)
diff --git a/datafusion/core/tests/dataframe.rs
b/datafusion/core/tests/dataframe.rs
index 9da5bf2b8..23c6623ab 100644
--- a/datafusion/core/tests/dataframe.rs
+++ b/datafusion/core/tests/dataframe.rs
@@ -51,16 +51,18 @@ async fn count_wildcard() -> Result<()> {
let sql_results = ctx
.sql("select count(*) from alltypes_tiny_pages")
.await?
+ .select(vec![count(Expr::Wildcard)])?
.explain(false, false)?
.collect()
.await?;
+ // add `.select(vec![count(Expr::Wildcard)])?` to make sure we can analyze
all nodes instead of just the top node.
let df_results = ctx
.table("alltypes_tiny_pages")
.await?
.aggregate(vec![], vec![count(Expr::Wildcard)])?
- .explain(false, false)
- .unwrap()
+ .select(vec![count(Expr::Wildcard)])?
+ .explain(false, false)?
.collect()
.await?;
diff --git a/datafusion/optimizer/src/count_wildcard_rule.rs
b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
similarity index 61%
rename from datafusion/optimizer/src/count_wildcard_rule.rs
rename to datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
index 416bd0337..4b4c603bc 100644
--- a/datafusion/optimizer/src/count_wildcard_rule.rs
+++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
@@ -15,15 +15,17 @@
// specific language governing permissions and limitations
// under the License.
-use crate::analyzer::AnalyzerRule;
use datafusion_common::config::ConfigOptions;
use datafusion_common::Result;
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::{aggregate_function, lit, Aggregate, Expr, LogicalPlan,
Window};
-use std::ops::Deref;
-use std::sync::Arc;
+use crate::analyzer::AnalyzerRule;
+use crate::rewrite::TreeNodeRewritable;
+
+/// Rewrite `Count(Expr::Wildcard)` to `Count(Expr::Literal)`.
+/// Resolves issue: <https://github.com/apache/arrow-datafusion/issues/5473>.
pub struct CountWildcardRule {}
impl Default for CountWildcardRule {
@@ -39,35 +41,7 @@ impl CountWildcardRule {
}
impl AnalyzerRule for CountWildcardRule {
fn analyze(&self, plan: &LogicalPlan, _: &ConfigOptions) ->
Result<LogicalPlan> {
- let new_plan = match plan {
- LogicalPlan::Window(window) => {
- let inputs = plan.inputs();
- let window_expr = window.clone().window_expr;
- let window_expr = handle_wildcard(window_expr).unwrap();
- LogicalPlan::Window(Window {
- input: Arc::new(inputs.get(0).unwrap().deref().clone()),
- window_expr,
- schema: plan.schema().clone(),
- })
- }
-
- LogicalPlan::Aggregate(aggregate) => {
- let inputs = plan.inputs();
- let aggr_expr = aggregate.clone().aggr_expr;
- let aggr_expr = handle_wildcard(aggr_expr).unwrap();
- LogicalPlan::Aggregate(
- Aggregate::try_new_with_schema(
- Arc::new(inputs.get(0).unwrap().deref().clone()),
- aggregate.clone().group_expr,
- aggr_expr,
- plan.schema().clone(),
- )
- .unwrap(),
- )
- }
- _ => plan.clone(),
- };
- Ok(new_plan)
+ plan.clone().transform_down(&analyze_internal)
}
fn name(&self) -> &str {
@@ -75,9 +49,34 @@ impl AnalyzerRule for CountWildcardRule {
}
}
-//handle Count(Expr:Wildcard) with DataFrame API
-pub fn handle_wildcard(exprs: Vec<Expr>) -> Result<Vec<Expr>> {
- let exprs: Vec<Expr> = exprs
+fn analyze_internal(plan: LogicalPlan) -> Result<Option<LogicalPlan>> {
+ match plan {
+ LogicalPlan::Window(window) => {
+ let window_expr = handle_wildcard(&window.window_expr);
+ Ok(Some(LogicalPlan::Window(Window {
+ input: window.input.clone(),
+ window_expr,
+ schema: window.schema,
+ })))
+ }
+ LogicalPlan::Aggregate(agg) => {
+ let aggr_expr = handle_wildcard(&agg.aggr_expr);
+ Ok(Some(LogicalPlan::Aggregate(
+ Aggregate::try_new_with_schema(
+ agg.input.clone(),
+ agg.group_expr.clone(),
+ aggr_expr,
+ agg.schema,
+ )?,
+ )))
+ }
+ _ => Ok(None),
+ }
+}
+
+// Handle `Count(Expr::Wildcard)` produced by the DataFrame API.
+pub fn handle_wildcard(exprs: &[Expr]) -> Vec<Expr> {
+ exprs
.iter()
.map(|expr| match expr {
Expr::AggregateFunction(AggregateFunction {
@@ -96,6 +95,5 @@ pub fn handle_wildcard(exprs: Vec<Expr>) -> Result<Vec<Expr>>
{
},
_ => expr.clone(),
})
- .collect();
- Ok(exprs)
+ .collect()
}
diff --git a/datafusion/optimizer/src/analyzer.rs
b/datafusion/optimizer/src/analyzer/mod.rs
similarity index 98%
rename from datafusion/optimizer/src/analyzer.rs
rename to datafusion/optimizer/src/analyzer/mod.rs
index e999eb241..0982198bb 100644
--- a/datafusion/optimizer/src/analyzer.rs
+++ b/datafusion/optimizer/src/analyzer/mod.rs
@@ -15,7 +15,9 @@
// specific language governing permissions and limitations
// under the License.
-use crate::count_wildcard_rule::CountWildcardRule;
+mod count_wildcard_rule;
+
+use crate::analyzer::count_wildcard_rule::CountWildcardRule;
use crate::rewrite::TreeNodeRewritable;
use datafusion_common::config::ConfigOptions;
use datafusion_common::{DataFusionError, Result};
diff --git a/datafusion/optimizer/src/eliminate_filter.rs
b/datafusion/optimizer/src/eliminate_filter.rs
index c5dc1711c..c97906a81 100644
--- a/datafusion/optimizer/src/eliminate_filter.rs
+++ b/datafusion/optimizer/src/eliminate_filter.rs
@@ -27,7 +27,7 @@ use datafusion_expr::{
use crate::{OptimizerConfig, OptimizerRule};
-/// Optimization rule that elimanate the scalar value (true/false) filter with
an [LogicalPlan::EmptyRelation]
+/// Optimization rule that eliminates the scalar value (true/false) filter with
an [LogicalPlan::EmptyRelation]
#[derive(Default)]
pub struct EliminateFilter;
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index 2a97c96e7..4be7bb370 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -45,7 +45,6 @@ pub mod type_coercion;
pub mod unwrap_cast_in_comparison;
pub mod utils;
-pub mod count_wildcard_rule;
#[cfg(test)]
pub mod test;