This is an automated email from the ASF dual-hosted git repository.

goldmedal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 0b45b9a2dd Improve TableScan with filters pushdown unparsing (joins) (#13132)
0b45b9a2dd is described below

commit 0b45b9a2dd84e30a68d8701b627293e6ac803643
Author: Sergei Grebnov <[email protected]>
AuthorDate: Tue Oct 29 02:20:05 2024 -0700

    Improve TableScan with filters pushdown unparsing (joins) (#13132)
    
    * Improve TableScan with filters pushdown unparsing (joins)
    
    * Fix formatting
    
    * Add test with filters before and after join
---
 datafusion/sql/src/unparser/plan.rs       | 77 +++++++++++++++++++++----
 datafusion/sql/src/unparser/utils.rs      | 93 +++++++++++++++++++++++++++++--
 datafusion/sql/tests/cases/plan_to_sql.rs | 87 +++++++++++++++++++++++++++++
 3 files changed, 240 insertions(+), 17 deletions(-)

diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs
index 7c9054656b..2c38a1d36c 100644
--- a/datafusion/sql/src/unparser/plan.rs
+++ b/datafusion/sql/src/unparser/plan.rs
@@ -27,8 +27,8 @@ use super::{
     },
     utils::{
         find_agg_node_within_select, find_unnest_node_within_select,
-        find_window_nodes_within_select, unproject_sort_expr, unproject_unnest_expr,
-        unproject_window_exprs,
+        find_window_nodes_within_select, try_transform_to_simple_table_scan_with_filters,
+        unproject_sort_expr, unproject_unnest_expr, unproject_window_exprs,
     },
     Unparser,
 };
@@ -39,8 +39,8 @@ use datafusion_common::{
     Column, DataFusionError, Result, TableReference,
 };
 use datafusion_expr::{
-    expr::Alias, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan,
-    LogicalPlanBuilder, Projection, SortExpr, TableScan,
+    expr::Alias, BinaryExpr, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan,
+    LogicalPlanBuilder, Operator, Projection, SortExpr, TableScan,
 };
 use sqlparser::ast::{self, Ident, SetExpr};
 use std::sync::Arc;
@@ -468,22 +468,77 @@ impl Unparser<'_> {
                 self.select_to_sql_recursively(input, query, select, relation)
             }
             LogicalPlan::Join(join) => {
-                let join_constraint = self.join_constraint_to_sql(
-                    join.join_constraint,
-                    &join.on,
-                    join.filter.as_ref(),
+                let mut table_scan_filters = vec![];
+
+                let left_plan =
+                    match try_transform_to_simple_table_scan_with_filters(&join.left)? {
+                        Some((plan, filters)) => {
+                            table_scan_filters.extend(filters);
+                            Arc::new(plan)
+                        }
+                        None => Arc::clone(&join.left),
+                    };
+
+                self.select_to_sql_recursively(
+                    left_plan.as_ref(),
+                    query,
+                    select,
+                    relation,
                 )?;
 
+                let right_plan =
+                    match try_transform_to_simple_table_scan_with_filters(&join.right)? {
+                        Some((plan, filters)) => {
+                            table_scan_filters.extend(filters);
+                            Arc::new(plan)
+                        }
+                        None => Arc::clone(&join.right),
+                    };
+
                 let mut right_relation = RelationBuilder::default();
 
                 self.select_to_sql_recursively(
-                    join.left.as_ref(),
+                    right_plan.as_ref(),
                     query,
                     select,
-                    relation,
+                    &mut right_relation,
                 )?;
+
+                let join_filters = if table_scan_filters.is_empty() {
+                    join.filter.clone()
+                } else {
+                    // Combine `table_scan_filters` into a single filter using `AND`
+                    let Some(combined_filters) =
+                        table_scan_filters.into_iter().reduce(|acc, filter| {
+                            Expr::BinaryExpr(BinaryExpr {
+                                left: Box::new(acc),
+                                op: Operator::And,
+                                right: Box::new(filter),
+                            })
+                        })
+                    else {
+                        return internal_err!("Failed to combine TableScan filters");
+                    };
+
+                    // Combine `join.filter` with `combined_filters` using `AND`
+                    match &join.filter {
+                        Some(filter) => Some(Expr::BinaryExpr(BinaryExpr {
+                            left: Box::new(filter.clone()),
+                            op: Operator::And,
+                            right: Box::new(combined_filters),
+                        })),
+                        None => Some(combined_filters),
+                    }
+                };
+
+                let join_constraint = self.join_constraint_to_sql(
+                    join.join_constraint,
+                    &join.on,
+                    join_filters.as_ref(),
+                )?;
+
                 self.select_to_sql_recursively(
-                    join.right.as_ref(),
+                    right_plan.as_ref(),
                     query,
                     select,
                     &mut right_relation,
diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs
index d3d1bf3513..284956cef1 100644
--- a/datafusion/sql/src/unparser/utils.rs
+++ b/datafusion/sql/src/unparser/utils.rs
@@ -15,20 +15,20 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::cmp::Ordering;
+use std::{cmp::Ordering, sync::Arc, vec};
 
 use datafusion_common::{
     internal_err,
-    tree_node::{Transformed, TreeNode},
-    Column, Result, ScalarValue,
+    tree_node::{Transformed, TransformedResult, TreeNode},
+    Column, DataFusionError, Result, ScalarValue,
 };
 use datafusion_expr::{
-    expr, utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Projection,
-    SortExpr, Unnest, Window,
+    expr, utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan,
+    LogicalPlanBuilder, Projection, SortExpr, Unnest, Window,
 };
 use sqlparser::ast;
 
-use super::{dialect::DateFieldExtractStyle, Unparser};
+use super::{dialect::DateFieldExtractStyle, rewrite::TableAliasRewriter, Unparser};
 
 /// Recursively searches children of [LogicalPlan] to find an Aggregate node if exists
 /// prior to encountering a Join, TableScan, or a nested subquery (derived table factor).
@@ -288,6 +288,87 @@ pub(crate) fn unproject_sort_expr(
     Ok(sort_expr)
 }
 
+/// Iterates through the children of a [LogicalPlan] to find a TableScan node before encountering
+/// a Projection or any unexpected node that indicates the presence of a Projection (SELECT) in the plan.
+/// If a TableScan node is found, returns the TableScan node without filters, along with the collected filters separately.
+/// If the plan contains a Projection, returns None.
+///
+/// Note: If a table alias is present, TableScan filters are rewritten to reference the alias.
+///
+/// LogicalPlan example:
+///   Filter: ta.j1_id < 5
+///     Alias:  ta
+///       TableScan: j1, j1_id > 10
+///
+/// Will return LogicalPlan below:
+///     Alias:  ta
+///       TableScan: j1
+/// And filters: [ta.j1_id < 5, ta.j1_id > 10]
+pub(crate) fn try_transform_to_simple_table_scan_with_filters(
+    plan: &LogicalPlan,
+) -> Result<Option<(LogicalPlan, Vec<Expr>)>> {
+    let mut filters: Vec<Expr> = vec![];
+    let mut plan_stack = vec![plan];
+    let mut table_alias = None;
+
+    while let Some(current_plan) = plan_stack.pop() {
+        match current_plan {
+            LogicalPlan::SubqueryAlias(alias) => {
+                table_alias = Some(alias.alias.clone());
+                plan_stack.push(alias.input.as_ref());
+            }
+            LogicalPlan::Filter(filter) => {
+                filters.push(filter.predicate.clone());
+                plan_stack.push(filter.input.as_ref());
+            }
+            LogicalPlan::TableScan(table_scan) => {
+                let table_schema = table_scan.source.schema();
+                // optional rewriter if table has an alias
+                let mut filter_alias_rewriter =
+                    table_alias.as_ref().map(|alias_name| TableAliasRewriter {
+                        table_schema: &table_schema,
+                        alias_name: alias_name.clone(),
+                    });
+
+                // rewrite filters to use table alias if present
+                let table_scan_filters = table_scan
+                    .filters
+                    .iter()
+                    .cloned()
+                    .map(|expr| {
+                        if let Some(ref mut rewriter) = filter_alias_rewriter {
+                            expr.rewrite(rewriter).data()
+                        } else {
+                            Ok(expr)
+                        }
+                    })
+                    .collect::<Result<Vec<_>, DataFusionError>>()?;
+
+                filters.extend(table_scan_filters);
+
+                let mut builder = LogicalPlanBuilder::scan(
+                    table_scan.table_name.clone(),
+                    Arc::clone(&table_scan.source),
+                    None,
+                )?;
+
+                if let Some(alias) = table_alias.take() {
+                    builder = builder.alias(alias)?;
+                }
+
+                let plan = builder.build()?;
+
+                return Ok(Some((plan, filters)));
+            }
+            _ => {
+                return Ok(None);
+            }
+        }
+    }
+
+    Ok(None)
+}
+
 /// Converts a date_part function to SQL, tailoring it to the supported date field extraction style.
 pub(crate) fn date_part_to_sql(
     unparser: &Unparser,
diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs
index 16941c5d91..ea0ccb8e4b 100644
--- a/datafusion/sql/tests/cases/plan_to_sql.rs
+++ b/datafusion/sql/tests/cases/plan_to_sql.rs
@@ -1008,6 +1008,93 @@ fn test_sort_with_push_down_fetch() -> Result<()> {
     Ok(())
 }
 
+#[test]
+fn test_join_with_table_scan_filters() -> Result<()> {
+    let schema_left = Schema::new(vec![
+        Field::new("id", DataType::Utf8, false),
+        Field::new("name", DataType::Utf8, false),
+    ]);
+
+    let schema_right = Schema::new(vec![
+        Field::new("id", DataType::Utf8, false),
+        Field::new("age", DataType::Utf8, false),
+    ]);
+
+    let left_plan = table_scan_with_filters(
+        Some("left_table"),
+        &schema_left,
+        None,
+        vec![col("name").like(lit("some_name"))],
+    )?
+    .alias("left")?
+    .build()?;
+
+    let right_plan = table_scan_with_filters(
+        Some("right_table"),
+        &schema_right,
+        None,
+        vec![col("age").gt(lit(10))],
+    )?
+    .build()?;
+
+    let join_plan_with_filter = LogicalPlanBuilder::from(left_plan.clone())
+        .join(
+            right_plan.clone(),
+            datafusion_expr::JoinType::Inner,
+            (vec!["left.id"], vec!["right_table.id"]),
+            Some(col("left.id").gt(lit(5))),
+        )?
+        .build()?;
+
+    let sql = plan_to_sql(&join_plan_with_filter)?;
+
+    let expected_sql = r#"SELECT * FROM left_table AS "left" JOIN right_table ON "left".id = right_table.id AND (("left".id > 5) AND ("left"."name" LIKE 'some_name' AND (age > 10)))"#;
+
+    assert_eq!(sql.to_string(), expected_sql);
+
+    let join_plan_no_filter = LogicalPlanBuilder::from(left_plan.clone())
+        .join(
+            right_plan,
+            datafusion_expr::JoinType::Inner,
+            (vec!["left.id"], vec!["right_table.id"]),
+            None,
+        )?
+        .build()?;
+
+    let sql = plan_to_sql(&join_plan_no_filter)?;
+
+    let expected_sql = r#"SELECT * FROM left_table AS "left" JOIN right_table ON "left".id = right_table.id AND ("left"."name" LIKE 'some_name' AND (age > 10))"#;
+
+    assert_eq!(sql.to_string(), expected_sql);
+
+    let right_plan_with_filter = table_scan_with_filters(
+        Some("right_table"),
+        &schema_right,
+        None,
+        vec![col("age").gt(lit(10))],
+    )?
+    .filter(col("right_table.name").eq(lit("before_join_filter_val")))?
+    .build()?;
+
+    let join_plan_multiple_filters = LogicalPlanBuilder::from(left_plan.clone())
+        .join(
+            right_plan_with_filter,
+            datafusion_expr::JoinType::Inner,
+            (vec!["left.id"], vec!["right_table.id"]),
+            Some(col("left.id").gt(lit(5))),
+        )?
+        .filter(col("left.name").eq(lit("after_join_filter_val")))?
+        .build()?;
+
+    let sql = plan_to_sql(&join_plan_multiple_filters)?;
+
+    let expected_sql = r#"SELECT * FROM left_table AS "left" JOIN right_table ON "left".id = right_table.id AND (("left".id > 5) AND (("left"."name" LIKE 'some_name' AND (right_table."name" = 'before_join_filter_val')) AND (age > 10))) WHERE ("left"."name" = 'after_join_filter_val')"#;
+
+    assert_eq!(sql.to_string(), expected_sql);
+
+    Ok(())
+}
+
 #[test]
 fn test_interval_lhs_eq() {
     sql_round_trip(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to