This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new e161cd65bc fix `NamedStructField should be rewritten in
OperatorToFunction` in subquery regression (change `ApplyFunctionRewrites` to
use TreeNode API (#10032)
e161cd65bc is described below
commit e161cd65bc910a166ead9d93a17295c25cc08a3c
Author: Andrew Lamb <[email protected]>
AuthorDate: Fri Apr 12 07:23:45 2024 -0400
fix `NamedStructField should be rewritten in OperatorToFunction` in
subquery regression (change `ApplyFunctionRewrites` to use TreeNode API (#10032)
* fix NamedStructField should be rewritten in OperatorToFunction in subquery
* Use TreeNode rewriter
---
.../optimizer/src/analyzer/function_rewrite.rs | 99 ++++++++--------------
datafusion/optimizer/src/utils.rs | 44 ++++++++++
datafusion/sqllogictest/test_files/subquery.slt | 55 ++++++++++++
3 files changed, 133 insertions(+), 65 deletions(-)
diff --git a/datafusion/optimizer/src/analyzer/function_rewrite.rs b/datafusion/optimizer/src/analyzer/function_rewrite.rs
index 78f65c5b82..deb493e099 100644
--- a/datafusion/optimizer/src/analyzer/function_rewrite.rs
+++ b/datafusion/optimizer/src/analyzer/function_rewrite.rs
@@ -19,11 +19,13 @@
use super::AnalyzerRule;
use datafusion_common::config::ConfigOptions;
-use datafusion_common::tree_node::{Transformed, TreeNodeRewriter};
+use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{DFSchema, Result};
-use datafusion_expr::expr_rewriter::{rewrite_preserving_name, FunctionRewrite};
+
+use crate::utils::NamePreserver;
+use datafusion_expr::expr_rewriter::FunctionRewrite;
use datafusion_expr::utils::merge_schema;
-use datafusion_expr::{Expr, LogicalPlan};
+use datafusion_expr::LogicalPlan;
use std::sync::Arc;
/// Analyzer rule that invokes [`FunctionRewrite`]s on expressions
@@ -37,36 +39,18 @@ impl ApplyFunctionRewrites {
pub fn new(function_rewrites: Vec<Arc<dyn FunctionRewrite + Send + Sync>>) -> Self {
Self { function_rewrites }
}
-}
-
-impl AnalyzerRule for ApplyFunctionRewrites {
- fn name(&self) -> &str {
- "apply_function_rewrites"
- }
-
- fn analyze(&self, plan: LogicalPlan, options: &ConfigOptions) -> Result<LogicalPlan> {
- self.analyze_internal(&plan, options)
- }
-}
-impl ApplyFunctionRewrites {
- fn analyze_internal(
+ /// Rewrite a single plan, and all its expressions using the provided rewriters
+ fn rewrite_plan(
&self,
- plan: &LogicalPlan,
+ plan: LogicalPlan,
options: &ConfigOptions,
- ) -> Result<LogicalPlan> {
- // optimize child plans first
- let new_inputs = plan
- .inputs()
- .iter()
- .map(|p| self.analyze_internal(p, options))
- .collect::<Result<Vec<_>>>()?;
-
+ ) -> Result<Transformed<LogicalPlan>> {
// get schema representing all available input fields. This is used for data type
// resolution only, so order does not matter here
- let mut schema = merge_schema(new_inputs.iter().collect());
+ let mut schema = merge_schema(plan.inputs());
- if let LogicalPlan::TableScan(ts) = plan {
+ if let LogicalPlan::TableScan(ts) = &plan {
let source_schema = DFSchema::try_from_qualified_schema(
ts.table_name.clone(),
&ts.source.schema(),
@@ -74,49 +58,34 @@ impl ApplyFunctionRewrites {
schema.merge(&source_schema);
}
- let mut expr_rewrite = OperatorToFunctionRewriter {
- function_rewrites: &self.function_rewrites,
- options,
- schema: &schema,
- };
+ let name_preserver = NamePreserver::new(&plan);
+
+ plan.map_expressions(|expr| {
+ let original_name = name_preserver.save(&expr)?;
- let new_expr = plan
- .expressions()
- .into_iter()
- .map(|expr| {
- // ensure names don't change:
- // https://github.com/apache/arrow-datafusion/issues/3555
- rewrite_preserving_name(expr, &mut expr_rewrite)
- })
- .collect::<Result<Vec<_>>>()?;
+ // recursively transform the expression, applying the rewrites at each step
+ let result = expr.transform_up(&|expr| {
+ let mut result = Transformed::no(expr);
+ for rewriter in self.function_rewrites.iter() {
+ result = result.transform_data(|expr| {
+ rewriter.rewrite(expr, &schema, options)
+ })?;
+ }
+ Ok(result)
+ })?;
- plan.with_new_exprs(new_expr, new_inputs)
+ result.map_data(|expr| original_name.restore(expr))
+ })
}
}
-struct OperatorToFunctionRewriter<'a> {
- function_rewrites: &'a [Arc<dyn FunctionRewrite + Send + Sync>],
- options: &'a ConfigOptions,
- schema: &'a DFSchema,
-}
-
-impl<'a> TreeNodeRewriter for OperatorToFunctionRewriter<'a> {
- type Node = Expr;
- fn f_up(&mut self, mut expr: Expr) -> Result<Transformed<Expr>> {
- // apply transforms one by one
- let mut transformed = false;
- for rewriter in self.function_rewrites.iter() {
- let result = rewriter.rewrite(expr, self.schema, self.options)?;
- if result.transformed {
- transformed = true;
- }
- expr = result.data
- }
+impl AnalyzerRule for ApplyFunctionRewrites {
+ fn name(&self) -> &str {
+ "apply_function_rewrites"
+ }
- Ok(if transformed {
- Transformed::yes(expr)
- } else {
- Transformed::no(expr)
- })
+ fn analyze(&self, plan: LogicalPlan, options: &ConfigOptions) -> Result<LogicalPlan> {
+ plan.transform_up_with_subqueries(&|plan| self.rewrite_plan(plan, options))
+ .map(|res| res.data)
}
}
diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs
index 560c63b188..f0605018e6 100644
--- a/datafusion/optimizer/src/utils.rs
+++ b/datafusion/optimizer/src/utils.rs
@@ -288,3 +288,47 @@ pub fn only_or_err<T>(slice: &[T]) -> Result<&T> {
pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema {
expr_utils::merge_schema(inputs)
}
+
+/// Handles ensuring the name of rewritten expressions is not changed.
+///
+/// For example, if an expression `1 + 2` is rewritten to `3`, the name of the
+/// expression should be preserved: `3 as "1 + 2"`
+///
+/// See <https://github.com/apache/arrow-datafusion/issues/3555> for details
+pub struct NamePreserver {
+ use_alias: bool,
+}
+
+/// If the name of an expression is remembered, it will be preserved when
+/// rewriting the expression
+pub struct SavedName(Option<String>);
+
+impl NamePreserver {
+ /// Create a new NamePreserver for rewriting the `expr` that is part of the specified plan
+ pub fn new(plan: &LogicalPlan) -> Self {
+ Self {
+ use_alias: !matches!(plan, LogicalPlan::Filter(_) | LogicalPlan::Join(_)),
+ }
+ }
+
+ pub fn save(&self, expr: &Expr) -> Result<SavedName> {
+ let original_name = if self.use_alias {
+ Some(expr.name_for_alias()?)
+ } else {
+ None
+ };
+
+ Ok(SavedName(original_name))
+ }
+}
+
+impl SavedName {
+ /// Ensures the name of the rewritten expression is preserved
+ pub fn restore(self, expr: Expr) -> Result<Expr> {
+ let Self(original_name) = self;
+ match original_name {
+ Some(name) => expr.alias_if_changed(name),
+ None => Ok(expr),
+ }
+ }
+}
diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt
index cc6428e514..1ae89c9159 100644
--- a/datafusion/sqllogictest/test_files/subquery.slt
+++ b/datafusion/sqllogictest/test_files/subquery.slt
@@ -1060,3 +1060,58 @@ logical_plan
Projection: t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2), t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2) + Int64(1)
--Projection: t.a / Int64(2) AS t.a / Int64(2)Int64(2)t.a
----TableScan: t projection=[a]
+
+###
+## Ensure that operators are rewritten in subqueries
+###
+
+statement ok
+create table foo(x int) as values (1);
+
+# Show input data
+query ?
+select struct(1, 'b')
+----
+{c0: 1, c1: b}
+
+
+query T
+select (select struct(1, 'b')['c1']);
+----
+b
+
+query T
+select 'foo' || (select struct(1, 'b')['c1']);
+----
+foob
+
+query I
+SELECT * FROM (VALUES (1), (2))
+WHERE column1 IN (SELECT struct(1, 'b')['c0']);
+----
+1
+
+# also add an expression so the subquery is the output expr
+query I
+SELECT * FROM (VALUES (1), (2))
+WHERE 1+2 = 3 AND column1 IN (SELECT struct(1, 'b')['c0']);
+----
+1
+
+
+query I
+SELECT * FROM foo
+WHERE EXISTS (SELECT * FROM (values (1)) WHERE column1 = foo.x AND struct(1, 'b')['c0'] = 1);
+----
+1
+
+# also add an expression so the subquery is the output expr
+query I
+SELECT * FROM foo
+WHERE 1+2 = 3 AND EXISTS (SELECT * FROM (values (1)) WHERE column1 = foo.x AND struct(1, 'b')['c0'] = 1);
+----
+1
+
+
+statement ok
+drop table foo;