This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new e161cd65bc fix `NamedStructField should be rewritten in 
OperatorToFunction` in subquery regression (change `ApplyFunctionRewrites` to 
use TreeNode API) (#10032)
e161cd65bc is described below

commit e161cd65bc910a166ead9d93a17295c25cc08a3c
Author: Andrew Lamb <[email protected]>
AuthorDate: Fri Apr 12 07:23:45 2024 -0400

    fix `NamedStructField should be rewritten in OperatorToFunction` in 
subquery regression (change `ApplyFunctionRewrites` to use TreeNode API) (#10032)
    
    * fix NamedStructField should be rewritten in OperatorToFunction in subquery
    
    * Use TreeNode rewriter
---
 .../optimizer/src/analyzer/function_rewrite.rs     | 99 ++++++++--------------
 datafusion/optimizer/src/utils.rs                  | 44 ++++++++++
 datafusion/sqllogictest/test_files/subquery.slt    | 55 ++++++++++++
 3 files changed, 133 insertions(+), 65 deletions(-)

diff --git a/datafusion/optimizer/src/analyzer/function_rewrite.rs 
b/datafusion/optimizer/src/analyzer/function_rewrite.rs
index 78f65c5b82..deb493e099 100644
--- a/datafusion/optimizer/src/analyzer/function_rewrite.rs
+++ b/datafusion/optimizer/src/analyzer/function_rewrite.rs
@@ -19,11 +19,13 @@
 
 use super::AnalyzerRule;
 use datafusion_common::config::ConfigOptions;
-use datafusion_common::tree_node::{Transformed, TreeNodeRewriter};
+use datafusion_common::tree_node::{Transformed, TreeNode};
 use datafusion_common::{DFSchema, Result};
-use datafusion_expr::expr_rewriter::{rewrite_preserving_name, FunctionRewrite};
+
+use crate::utils::NamePreserver;
+use datafusion_expr::expr_rewriter::FunctionRewrite;
 use datafusion_expr::utils::merge_schema;
-use datafusion_expr::{Expr, LogicalPlan};
+use datafusion_expr::LogicalPlan;
 use std::sync::Arc;
 
 /// Analyzer rule that invokes [`FunctionRewrite`]s on expressions
@@ -37,36 +39,18 @@ impl ApplyFunctionRewrites {
     pub fn new(function_rewrites: Vec<Arc<dyn FunctionRewrite + Send + Sync>>) 
-> Self {
         Self { function_rewrites }
     }
-}
-
-impl AnalyzerRule for ApplyFunctionRewrites {
-    fn name(&self) -> &str {
-        "apply_function_rewrites"
-    }
-
-    fn analyze(&self, plan: LogicalPlan, options: &ConfigOptions) -> 
Result<LogicalPlan> {
-        self.analyze_internal(&plan, options)
-    }
-}
 
-impl ApplyFunctionRewrites {
-    fn analyze_internal(
+    /// Rewrite a single plan, and all its expressions using the provided 
rewriters
+    fn rewrite_plan(
         &self,
-        plan: &LogicalPlan,
+        plan: LogicalPlan,
         options: &ConfigOptions,
-    ) -> Result<LogicalPlan> {
-        // optimize child plans first
-        let new_inputs = plan
-            .inputs()
-            .iter()
-            .map(|p| self.analyze_internal(p, options))
-            .collect::<Result<Vec<_>>>()?;
-
+    ) -> Result<Transformed<LogicalPlan>> {
         // get schema representing all available input fields. This is used 
for data type
         // resolution only, so order does not matter here
-        let mut schema = merge_schema(new_inputs.iter().collect());
+        let mut schema = merge_schema(plan.inputs());
 
-        if let LogicalPlan::TableScan(ts) = plan {
+        if let LogicalPlan::TableScan(ts) = &plan {
             let source_schema = DFSchema::try_from_qualified_schema(
                 ts.table_name.clone(),
                 &ts.source.schema(),
@@ -74,49 +58,34 @@ impl ApplyFunctionRewrites {
             schema.merge(&source_schema);
         }
 
-        let mut expr_rewrite = OperatorToFunctionRewriter {
-            function_rewrites: &self.function_rewrites,
-            options,
-            schema: &schema,
-        };
+        let name_preserver = NamePreserver::new(&plan);
+
+        plan.map_expressions(|expr| {
+            let original_name = name_preserver.save(&expr)?;
 
-        let new_expr = plan
-            .expressions()
-            .into_iter()
-            .map(|expr| {
-                // ensure names don't change:
-                // https://github.com/apache/arrow-datafusion/issues/3555
-                rewrite_preserving_name(expr, &mut expr_rewrite)
-            })
-            .collect::<Result<Vec<_>>>()?;
+            // recursively transform the expression, applying the rewrites at 
each step
+            let result = expr.transform_up(&|expr| {
+                let mut result = Transformed::no(expr);
+                for rewriter in self.function_rewrites.iter() {
+                    result = result.transform_data(|expr| {
+                        rewriter.rewrite(expr, &schema, options)
+                    })?;
+                }
+                Ok(result)
+            })?;
 
-        plan.with_new_exprs(new_expr, new_inputs)
+            result.map_data(|expr| original_name.restore(expr))
+        })
     }
 }
-struct OperatorToFunctionRewriter<'a> {
-    function_rewrites: &'a [Arc<dyn FunctionRewrite + Send + Sync>],
-    options: &'a ConfigOptions,
-    schema: &'a DFSchema,
-}
-
-impl<'a> TreeNodeRewriter for OperatorToFunctionRewriter<'a> {
-    type Node = Expr;
 
-    fn f_up(&mut self, mut expr: Expr) -> Result<Transformed<Expr>> {
-        // apply transforms one by one
-        let mut transformed = false;
-        for rewriter in self.function_rewrites.iter() {
-            let result = rewriter.rewrite(expr, self.schema, self.options)?;
-            if result.transformed {
-                transformed = true;
-            }
-            expr = result.data
-        }
+impl AnalyzerRule for ApplyFunctionRewrites {
+    fn name(&self) -> &str {
+        "apply_function_rewrites"
+    }
 
-        Ok(if transformed {
-            Transformed::yes(expr)
-        } else {
-            Transformed::no(expr)
-        })
+    fn analyze(&self, plan: LogicalPlan, options: &ConfigOptions) -> 
Result<LogicalPlan> {
+        plan.transform_up_with_subqueries(&|plan| self.rewrite_plan(plan, 
options))
+            .map(|res| res.data)
     }
 }
diff --git a/datafusion/optimizer/src/utils.rs 
b/datafusion/optimizer/src/utils.rs
index 560c63b188..f0605018e6 100644
--- a/datafusion/optimizer/src/utils.rs
+++ b/datafusion/optimizer/src/utils.rs
@@ -288,3 +288,47 @@ pub fn only_or_err<T>(slice: &[T]) -> Result<&T> {
 pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema {
     expr_utils::merge_schema(inputs)
 }
+
+/// Handles ensuring the name of rewritten expressions is not changed.
+///
+/// For example, if an expression `1 + 2` is rewritten to `3`, the name of the
+/// expression should be preserved: `3 as "1 + 2"`
+///
+/// See <https://github.com/apache/arrow-datafusion/issues/3555> for details
+pub struct NamePreserver {
+    use_alias: bool,
+}
+
+/// If the name of an expression is remembered, it will be preserved when
+/// rewriting the expression
+pub struct SavedName(Option<String>);
+
+impl NamePreserver {
+    /// Create a new NamePreserver for rewriting the `expr` that is part of 
the specified plan
+    pub fn new(plan: &LogicalPlan) -> Self {
+        Self {
+            use_alias: !matches!(plan, LogicalPlan::Filter(_) | 
LogicalPlan::Join(_)),
+        }
+    }
+
+    pub fn save(&self, expr: &Expr) -> Result<SavedName> {
+        let original_name = if self.use_alias {
+            Some(expr.name_for_alias()?)
+        } else {
+            None
+        };
+
+        Ok(SavedName(original_name))
+    }
+}
+
+impl SavedName {
+    /// Ensures the name of the rewritten expression is preserved
+    pub fn restore(self, expr: Expr) -> Result<Expr> {
+        let Self(original_name) = self;
+        match original_name {
+            Some(name) => expr.alias_if_changed(name),
+            None => Ok(expr),
+        }
+    }
+}
diff --git a/datafusion/sqllogictest/test_files/subquery.slt 
b/datafusion/sqllogictest/test_files/subquery.slt
index cc6428e514..1ae89c9159 100644
--- a/datafusion/sqllogictest/test_files/subquery.slt
+++ b/datafusion/sqllogictest/test_files/subquery.slt
@@ -1060,3 +1060,58 @@ logical_plan
 Projection: t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2), t.a / 
Int64(2)Int64(2)t.a AS t.a / Int64(2) + Int64(1)
 --Projection: t.a / Int64(2) AS t.a / Int64(2)Int64(2)t.a
 ----TableScan: t projection=[a]
+
+###
+## Ensure that operators are rewritten in subqueries
+###
+
+statement ok
+create table foo(x int) as values (1);
+
+# Show input data
+query ?
+select struct(1, 'b')
+----
+{c0: 1, c1: b}
+
+
+query T
+select (select struct(1, 'b')['c1']);
+----
+b
+
+query T
+select 'foo' || (select struct(1, 'b')['c1']);
+----
+foob
+
+query I
+SELECT  * FROM (VALUES (1), (2))
+WHERE column1  IN (SELECT struct(1, 'b')['c0']);
+----
+1
+
+# also add an expression so the subquery is the output expr
+query I
+SELECT  * FROM (VALUES (1), (2))
+WHERE 1+2 = 3 AND column1  IN (SELECT struct(1, 'b')['c0']);
+----
+1
+
+
+query I
+SELECT  * FROM foo
+WHERE EXISTS (SELECT * FROM (values (1)) WHERE column1 = foo.x AND struct(1, 
'b')['c0'] = 1);
+----
+1
+
+# also add an expression so the subquery is the output expr
+query I
+SELECT  * FROM foo
+WHERE 1+2 = 3 AND EXISTS (SELECT * FROM (values (1)) WHERE column1 = foo.x AND 
struct(1, 'b')['c0'] = 1);
+----
+1
+
+
+statement ok
+drop table foo;

Reply via email to