This is an automated email from the ASF dual-hosted git repository.

ozankabak pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new e7ac843415 Bug-fix: MemoryExec sort expressions do NOT refer to the projected schema (#12876)
e7ac843415 is described below

commit e7ac8434153560816220f0e1492057e61b7ad983
Author: Berkay Şahin <[email protected]>
AuthorDate: Sat Oct 12 09:02:44 2024 +0300

    Bug-fix: MemoryExec sort expressions do NOT refer to the projected schema (#12876)
    
    * Update memory.rs
    
    * add assert
    
    * Update memory.rs
    
    * Update memory.rs
    
    * Update memory.rs
    
    * address review
    
    * Update memory.rs
    
    * Update memory.rs
    
    * final fix
    
    * Fix comments in test_utils.rs
    
    ---------
    
    Co-authored-by: Mehmet Ozan Kabak <[email protected]>
---
 datafusion/core/src/datasource/memory.rs           |  6 +--
 datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs |  3 +-
 .../fuzz_cases/sort_preserving_repartition_fuzz.rs |  3 +-
 datafusion/core/tests/fuzz_cases/window_fuzz.rs    |  4 +-
 datafusion/core/tests/memory_limit/mod.rs          |  2 +-
 .../physical-plan/src/joins/nested_loop_join.rs    |  2 +-
 datafusion/physical-plan/src/joins/test_utils.rs   | 17 ++++---
 datafusion/physical-plan/src/memory.rs             | 58 ++++++++++++++++++++--
 datafusion/physical-plan/src/repartition/mod.rs    |  3 +-
 datafusion/physical-plan/src/union.rs              |  4 +-
 10 files changed, 78 insertions(+), 24 deletions(-)

diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs
index 24a4938e7b..3c2d1b0205 100644
--- a/datafusion/core/src/datasource/memory.rs
+++ b/datafusion/core/src/datasource/memory.rs
@@ -37,14 +37,14 @@ use crate::physical_planner::create_physical_sort_exprs;
 
 use arrow::datatypes::SchemaRef;
 use arrow::record_batch::RecordBatch;
+use datafusion_catalog::Session;
 use datafusion_common::{not_impl_err, plan_err, Constraints, DFSchema, SchemaExt};
 use datafusion_execution::TaskContext;
 use datafusion_expr::dml::InsertOp;
+use datafusion_expr::SortExpr;
 use datafusion_physical_plan::metrics::MetricsSet;
 
 use async_trait::async_trait;
-use datafusion_catalog::Session;
-use datafusion_expr::SortExpr;
 use futures::StreamExt;
 use log::debug;
 use parking_lot::Mutex;
@@ -241,7 +241,7 @@ impl TableProvider for MemTable {
                     )
                 })
                 .collect::<Result<Vec<_>>>()?;
-            exec = exec.with_sort_information(file_sort_order);
+            exec = exec.try_with_sort_information(file_sort_order)?;
         }
 
         Ok(Arc::new(exec))
diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
index b085250141..64a7514ebd 100644
--- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
@@ -395,7 +395,8 @@ async fn run_aggregate_test(input1: Vec<RecordBatch>, group_by_columns: Vec<&str
     let running_source = Arc::new(
         MemoryExec::try_new(&[input1.clone()], schema.clone(), None)
             .unwrap()
-            .with_sort_information(vec![sort_keys]),
+            .try_with_sort_information(vec![sort_keys])
+            .unwrap(),
     );
 
     let aggregate_expr =
diff --git a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs
index 0cd702372f..a72affc2b0 100644
--- a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs
@@ -358,7 +358,8 @@ mod sp_repartition_fuzz_tests {
         let running_source = Arc::new(
             MemoryExec::try_new(&[input1.clone()], schema.clone(), None)
                 .unwrap()
-                .with_sort_information(vec![sort_keys.clone()]),
+                .try_with_sort_information(vec![sort_keys.clone()])
+                .unwrap(),
         );
         let hash_exprs = vec![col("c", &schema).unwrap()];
 
diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
index b9881c9f23..feffb11bf7 100644
--- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
@@ -647,7 +647,7 @@ async fn run_window_test(
     ];
     let mut exec1 = Arc::new(
         MemoryExec::try_new(&[vec![concat_input_record]], schema.clone(), None)?
-            .with_sort_information(vec![source_sort_keys.clone()]),
+            .try_with_sort_information(vec![source_sort_keys.clone()])?,
     ) as _;
     // Table is ordered according to ORDER BY a, b, c In linear test we use PARTITION BY b, ORDER BY a
     // For WindowAggExec  to produce correct result it need table to be ordered by b,a. Hence add a sort.
@@ -673,7 +673,7 @@ async fn run_window_test(
     )?) as _;
     let exec2 = Arc::new(
         MemoryExec::try_new(&[input1.clone()], schema.clone(), None)?
-            .with_sort_information(vec![source_sort_keys.clone()]),
+            .try_with_sort_information(vec![source_sort_keys.clone()])?,
     );
     let running_window_exec = Arc::new(BoundedWindowAggExec::try_new(
         vec![create_window_expr(
diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs
index ec66df45c7..fc2fb9afb5 100644
--- a/datafusion/core/tests/memory_limit/mod.rs
+++ b/datafusion/core/tests/memory_limit/mod.rs
@@ -840,7 +840,7 @@ impl TableProvider for SortedTableProvider {
     ) -> Result<Arc<dyn ExecutionPlan>> {
         let mem_exec =
             MemoryExec::try_new(&self.batches, self.schema(), projection.cloned())?
-                .with_sort_information(self.sort_information.clone());
+                .try_with_sort_information(self.sort_information.clone())?;
 
         Ok(Arc::new(mem_exec))
     }
diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs
index 029003374a..6068e75263 100644
--- a/datafusion/physical-plan/src/joins/nested_loop_join.rs
+++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs
@@ -780,7 +780,7 @@ mod tests {
                 };
                 sort_info.push(sort_expr);
             }
-            exec = exec.with_sort_information(vec![sort_info]);
+            exec = exec.try_with_sort_information(vec![sort_info]).unwrap();
         }
 
         Arc::new(exec)
diff --git a/datafusion/physical-plan/src/joins/test_utils.rs b/datafusion/physical-plan/src/joins/test_utils.rs
index 264f297ffb..090d60f0ba 100644
--- a/datafusion/physical-plan/src/joins/test_utils.rs
+++ b/datafusion/physical-plan/src/joins/test_utils.rs
@@ -289,7 +289,7 @@ macro_rules! join_expr_tests {
                     ScalarValue::$SCALAR(Some(10 as $type)),
                     (Operator::Gt, Operator::Lt),
                 ),
-                // left_col - 1 > right_col + 5 AND left_col + 3 < right_col + 10
+                // left_col - 1 > right_col + 3 AND left_col + 3 < right_col + 15
                 1 => gen_conjunctive_numerical_expr(
                     left_col,
                     right_col,
@@ -300,9 +300,9 @@ macro_rules! join_expr_tests {
                         Operator::Plus,
                     ),
                     ScalarValue::$SCALAR(Some(1 as $type)),
-                    ScalarValue::$SCALAR(Some(5 as $type)),
                     ScalarValue::$SCALAR(Some(3 as $type)),
-                    ScalarValue::$SCALAR(Some(10 as $type)),
+                    ScalarValue::$SCALAR(Some(3 as $type)),
+                    ScalarValue::$SCALAR(Some(15 as $type)),
                     (Operator::Gt, Operator::Lt),
                 ),
                 // left_col - 1 > right_col + 5 AND left_col - 3 < right_col + 10
@@ -353,7 +353,8 @@ macro_rules! join_expr_tests {
                     ScalarValue::$SCALAR(Some(3 as $type)),
                     (Operator::Gt, Operator::Lt),
                 ),
-                // left_col - 2 >= right_col - 5 AND left_col - 7 <= right_col - 3
+                // left_col - 2 >= right_col + 5 AND left_col + 7 <= right_col - 3
+                // (filters all input rows)
                 5 => gen_conjunctive_numerical_expr(
                     left_col,
                     right_col,
@@ -369,7 +370,7 @@ macro_rules! join_expr_tests {
                     ScalarValue::$SCALAR(Some(3 as $type)),
                     (Operator::GtEq, Operator::LtEq),
                 ),
-                // left_col - 28 >= right_col - 11 AND left_col - 21 <= right_col - 39
+                // left_col + 28 > right_col - 11 AND left_col + 21 <= right_col + 39
                 6 => gen_conjunctive_numerical_expr(
                     left_col,
                     right_col,
@@ -385,7 +386,7 @@ macro_rules! join_expr_tests {
                     ScalarValue::$SCALAR(Some(39 as $type)),
                     (Operator::Gt, Operator::LtEq),
                 ),
-                // left_col - 28 >= right_col - 11 AND left_col - 21 <= right_col + 39
+                // left_col + 28 >= right_col - 11 AND left_col - 21 <= right_col + 39
                 7 => gen_conjunctive_numerical_expr(
                     left_col,
                     right_col,
@@ -526,10 +527,10 @@ pub fn create_memory_table(
 ) -> Result<(Arc<dyn ExecutionPlan>, Arc<dyn ExecutionPlan>)> {
     let left_schema = left_partition[0].schema();
     let left = MemoryExec::try_new(&[left_partition], left_schema, None)?
-        .with_sort_information(left_sorted);
+        .try_with_sort_information(left_sorted)?;
     let right_schema = right_partition[0].schema();
     let right = MemoryExec::try_new(&[right_partition], right_schema, None)?
-        .with_sort_information(right_sorted);
+        .try_with_sort_information(right_sorted)?;
     Ok((Arc::new(left), Arc::new(right)))
 }
 
diff --git a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs
index 3aa445d295..456f0ef2dc 100644
--- a/datafusion/physical-plan/src/memory.rs
+++ b/datafusion/physical-plan/src/memory.rs
@@ -33,6 +33,9 @@ use arrow::record_batch::RecordBatch;
 use datafusion_common::{internal_err, project_schema, Result};
 use datafusion_execution::memory_pool::MemoryReservation;
 use datafusion_execution::TaskContext;
+use datafusion_physical_expr::equivalence::ProjectionMapping;
+use datafusion_physical_expr::expressions::Column;
+use datafusion_physical_expr::utils::collect_columns;
 use datafusion_physical_expr::{EquivalenceProperties, LexOrdering};
 
 use futures::Stream;
@@ -206,16 +209,63 @@ impl MemoryExec {
     /// where both `a ASC` and `b DESC` can describe the table ordering. With
     /// [`EquivalenceProperties`], we can keep track of these equivalences
     /// and treat `a ASC` and `b DESC` as the same ordering requirement.
-    pub fn with_sort_information(mut self, sort_information: Vec<LexOrdering>) -> Self {
-        self.sort_information = sort_information;
+    ///
+    /// Note that if there is an internal projection, that projection will also
+    /// be applied to the given `sort_information`.
+    pub fn try_with_sort_information(
+        mut self,
+        mut sort_information: Vec<LexOrdering>,
+    ) -> Result<Self> {
+        // All sort expressions must refer to the original schema
+        let fields = self.schema.fields();
+        let ambiguous_column = sort_information
+            .iter()
+            .flatten()
+            .flat_map(|expr| collect_columns(&expr.expr))
+            .find(|col| {
+                fields
+                    .get(col.index())
+                    .map(|field| field.name() != col.name())
+                    .unwrap_or(true)
+            });
+        if let Some(col) = ambiguous_column {
+            return internal_err!(
+                "Column {:?} is not found in the original schema of the MemoryExec",
+                col
+            );
+        }
+
+        // If there is a projection on the source, we also need to project orderings
+        if let Some(projection) = &self.projection {
+            let base_eqp = EquivalenceProperties::new_with_orderings(
+                self.original_schema(),
+                &sort_information,
+            );
+            let proj_exprs = projection
+                .iter()
+                .map(|idx| {
+                    let base_schema = self.original_schema();
+                    let name = base_schema.field(*idx).name();
+                    (Arc::new(Column::new(name, *idx)) as _, name.to_string())
+                })
+                .collect::<Vec<_>>();
+            let projection_mapping =
+                ProjectionMapping::try_new(&proj_exprs, &self.original_schema())?;
+            sort_information = base_eqp
+                .project(&projection_mapping, self.schema())
+                .oeq_class
+                .orderings;
+        }
 
+        self.sort_information = sort_information;
         // We need to update equivalence properties when updating sort information.
         let eq_properties = EquivalenceProperties::new_with_orderings(
             self.schema(),
             &self.sort_information,
         );
         self.cache = self.cache.with_eq_properties(eq_properties);
-        self
+
+        Ok(self)
     }
 
     pub fn original_schema(&self) -> SchemaRef {
@@ -347,7 +397,7 @@ mod tests {
 
         let sort_information = vec![sort1.clone(), sort2.clone()];
         let mem_exec = MemoryExec::try_new(&[vec![]], schema, None)?
-            .with_sort_information(sort_information);
+            .try_with_sort_information(sort_information)?;
 
         assert_eq!(
             mem_exec.properties().output_ordering().unwrap(),
diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs
index d9368cf86d..902d9f4477 100644
--- a/datafusion/physical-plan/src/repartition/mod.rs
+++ b/datafusion/physical-plan/src/repartition/mod.rs
@@ -1677,7 +1677,8 @@ mod test {
         Arc::new(
             MemoryExec::try_new(&[vec![]], Arc::clone(schema), None)
                 .unwrap()
-                .with_sort_information(vec![sort_exprs]),
+                .try_with_sort_information(vec![sort_exprs])
+                .unwrap(),
         )
     }
 }
diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs
index 1cf22060b6..108e42e7be 100644
--- a/datafusion/physical-plan/src/union.rs
+++ b/datafusion/physical-plan/src/union.rs
@@ -809,11 +809,11 @@ mod tests {
                 .collect::<Vec<_>>();
             let child1 = Arc::new(
                 MemoryExec::try_new(&[], Arc::clone(&schema), None)?
-                    .with_sort_information(first_orderings),
+                    .try_with_sort_information(first_orderings)?,
             );
             let child2 = Arc::new(
                 MemoryExec::try_new(&[], Arc::clone(&schema), None)?
-                    .with_sort_information(second_orderings),
+                    .try_with_sort_information(second_orderings)?,
             );
 
             let mut union_expected_eq = EquivalenceProperties::new(Arc::clone(&schema));


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to