This is an automated email from the ASF dual-hosted git repository.

akurmustafa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 7f497b3b23 Add non-column expression equality tracking to filter exec 
(#9819)
7f497b3b23 is described below

commit 7f497b3b23d4aa2cb6336671d09b9c9837ed0d82
Author: Mustafa Akur <[email protected]>
AuthorDate: Fri Mar 29 10:34:49 2024 +0300

    Add non-column expression equality tracking to filter exec (#9819)
    
    * Add non-column expression equality tracking to filter exec
    
    * Minor changes
---
 datafusion/physical-plan/src/filter.rs        | 47 +++++++++++++--------------
 datafusion/physical-plan/src/lib.rs           |  1 -
 datafusion/sqllogictest/test_files/select.slt | 21 ++++++++++++
 3 files changed, 44 insertions(+), 25 deletions(-)

diff --git a/datafusion/physical-plan/src/filter.rs 
b/datafusion/physical-plan/src/filter.rs
index 2996152fb9..a9201f435a 100644
--- a/datafusion/physical-plan/src/filter.rs
+++ b/datafusion/physical-plan/src/filter.rs
@@ -29,7 +29,7 @@ use super::{
 };
 use crate::{
     metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet},
-    Column, DisplayFormatType, ExecutionPlan,
+    DisplayFormatType, ExecutionPlan,
 };
 
 use arrow::compute::filter_record_batch;
@@ -192,9 +192,7 @@ impl FilterExec {
         let mut eq_properties = input.equivalence_properties().clone();
         let (equal_pairs, _) = collect_columns_from_predicate(predicate);
         for (lhs, rhs) in equal_pairs {
-            let lhs_expr = Arc::new(lhs.clone()) as _;
-            let rhs_expr = Arc::new(rhs.clone()) as _;
-            eq_properties.add_equal_conditions(&lhs_expr, &rhs_expr)
+            eq_properties.add_equal_conditions(lhs, rhs)
         }
         // Add the columns that have only one viable value (singleton) after
         // filtering to constants.
@@ -405,34 +403,33 @@ impl RecordBatchStream for FilterExecStream {
 
 /// Return the equals Column-Pairs and Non-equals Column-Pairs
 fn collect_columns_from_predicate(predicate: &Arc<dyn PhysicalExpr>) -> 
EqualAndNonEqual {
-    let mut eq_predicate_columns = Vec::<(&Column, &Column)>::new();
-    let mut ne_predicate_columns = Vec::<(&Column, &Column)>::new();
+    let mut eq_predicate_columns = Vec::<PhysicalExprPairRef>::new();
+    let mut ne_predicate_columns = Vec::<PhysicalExprPairRef>::new();
 
     let predicates = split_conjunction(predicate);
     predicates.into_iter().for_each(|p| {
         if let Some(binary) = p.as_any().downcast_ref::<BinaryExpr>() {
-            if let (Some(left_column), Some(right_column)) = (
-                binary.left().as_any().downcast_ref::<Column>(),
-                binary.right().as_any().downcast_ref::<Column>(),
-            ) {
-                match binary.op() {
-                    Operator::Eq => {
-                        eq_predicate_columns.push((left_column, right_column))
-                    }
-                    Operator::NotEq => {
-                        ne_predicate_columns.push((left_column, right_column))
-                    }
-                    _ => {}
+            match binary.op() {
+                Operator::Eq => {
+                    eq_predicate_columns.push((binary.left(), binary.right()))
+                }
+                Operator::NotEq => {
+                    ne_predicate_columns.push((binary.left(), binary.right()))
                 }
+                _ => {}
             }
         }
     });
 
     (eq_predicate_columns, ne_predicate_columns)
 }
+
+/// Pair of `Arc<dyn PhysicalExpr>`s
+pub type PhysicalExprPairRef<'a> = (&'a Arc<dyn PhysicalExpr>, &'a Arc<dyn 
PhysicalExpr>);
+
 /// The equals Column-Pairs and Non-equals Column-Pairs in the Predicates
 pub type EqualAndNonEqual<'a> =
-    (Vec<(&'a Column, &'a Column)>, Vec<(&'a Column, &'a Column)>);
+    (Vec<PhysicalExprPairRef<'a>>, Vec<PhysicalExprPairRef<'a>>);
 
 #[cfg(test)]
 mod tests {
@@ -482,14 +479,16 @@ mod tests {
         )?;
 
         let (equal_pairs, ne_pairs) = 
collect_columns_from_predicate(&predicate);
+        assert_eq!(2, equal_pairs.len());
+        assert!(equal_pairs[0].0.eq(&col("c2", &schema)?));
+        assert!(equal_pairs[0].1.eq(&lit(4u32)));
 
-        assert_eq!(1, equal_pairs.len());
-        assert_eq!(equal_pairs[0].0.name(), "c2");
-        assert_eq!(equal_pairs[0].1.name(), "c9");
+        assert!(equal_pairs[1].0.eq(&col("c2", &schema)?));
+        assert!(equal_pairs[1].1.eq(&col("c9", &schema)?));
 
         assert_eq!(1, ne_pairs.len());
-        assert_eq!(ne_pairs[0].0.name(), "c1");
-        assert_eq!(ne_pairs[0].1.name(), "c13");
+        assert!(ne_pairs[0].0.eq(&col("c1", &schema)?));
+        assert!(ne_pairs[0].1.eq(&col("c13", &schema)?));
 
         Ok(())
     }
diff --git a/datafusion/physical-plan/src/lib.rs 
b/datafusion/physical-plan/src/lib.rs
index 4b4b37f8b5..3e8e439c9a 100644
--- a/datafusion/physical-plan/src/lib.rs
+++ b/datafusion/physical-plan/src/lib.rs
@@ -33,7 +33,6 @@ use datafusion_common::config::ConfigOptions;
 use datafusion_common::utils::DataPtr;
 use datafusion_common::Result;
 use datafusion_execution::TaskContext;
-use datafusion_physical_expr::expressions::Column;
 use datafusion_physical_expr::{
     EquivalenceProperties, LexOrdering, PhysicalSortExpr, 
PhysicalSortRequirement,
 };
diff --git a/datafusion/sqllogictest/test_files/select.slt 
b/datafusion/sqllogictest/test_files/select.slt
index 3a5c6497eb..ad4b0df1a5 100644
--- a/datafusion/sqllogictest/test_files/select.slt
+++ b/datafusion/sqllogictest/test_files/select.slt
@@ -1386,6 +1386,27 @@ AggregateExec: mode=FinalPartitioned, gby=[c2@0 as c2], 
aggr=[COUNT(*)]
 --------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
 ----------CsvExec: file_groups={1 group: 
[[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2], 
has_header=true
 
+# FilterExec can track equality of non-column expressions.
+# The plan below shouldn't have a SortExec, because column 'a' is ordered and
+# 'CAST(ROUND(b) as INT)' is also ordered after the filter is applied.
+query TT
+EXPLAIN SELECT *
+FROM annotated_data_finite2
+WHERE CAST(ROUND(b) as INT) = a
+ORDER BY CAST(ROUND(b) as INT);
+----
+logical_plan
+Sort: CAST(round(CAST(annotated_data_finite2.b AS Float64)) AS Int32) ASC 
NULLS LAST
+--Filter: CAST(round(CAST(annotated_data_finite2.b AS Float64)) AS Int32) = 
annotated_data_finite2.a
+----TableScan: annotated_data_finite2 projection=[a0, a, b, c, d], 
partial_filters=[CAST(round(CAST(annotated_data_finite2.b AS Float64)) AS 
Int32) = annotated_data_finite2.a]
+physical_plan
+SortPreservingMergeExec: [CAST(round(CAST(b@2 AS Float64)) AS Int32) ASC NULLS 
LAST]
+--CoalesceBatchesExec: target_batch_size=8192
+----FilterExec: CAST(round(CAST(b@2 AS Float64)) AS Int32) = a@1
+------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
+--------CsvExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, 
b, c, d], output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC 
NULLS LAST], has_header=true
+
+
 statement ok
 drop table annotated_data_finite2;
 

Reply via email to