This is an automated email from the ASF dual-hosted git repository.

berkay pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new b9cef8c590 Preserve constant values across union operations (#13805)
b9cef8c590 is described below

commit b9cef8c59020ab75c47599468c248be2bf6c186a
Author: Goksel Kabadayi <[email protected]>
AuthorDate: Wed Dec 25 17:47:28 2024 +0300

    Preserve constant values across union operations (#13805)
    
    * Add value tracking to ConstExpr for improved union optimization
    
    * Update PartialEq impl
    
    * Minor change
    
    * Add docstring for ConstExpr value
    
    * Improve constant propagation across union partitions
    
    * Add assertion for across_partitions
    
    * fix fmt
    
    * Update properties.rs
    
    * Remove redundant constant removal loop
    
    * Remove unnecessary mut
    
    * Set across_partitions=true when both sides are constant
    
    * Extract and use constant values in filter expressions
    
    * Add initial SLT for constant value tracking across UNION ALL
    
    * Assign values to ConstExpr where possible
    
    * Revert "Set across_partitions=true when both sides are constant"
    
    This reverts commit 3051cd470b0ad4a70cd8bd3518813f5ce0b3a449.
    
    * Temporarily take value from literal
    
    * Lint fixes
    
    * Cargo fmt
    
    * Add get_expr_constant_value
    
    * Make `with_value()` accept optional value
    
    * Add todo
    
    * Move test to union.slt
    
    * Fix changed slt after merge
    
    * Simplify constexpr
    
    * Update properties.rs
    
    ---------
    
    Co-authored-by: berkaysynnada <[email protected]>
---
 datafusion/physical-expr/src/equivalence/class.rs  |  64 ++++--
 datafusion/physical-expr/src/equivalence/mod.rs    |   2 +-
 .../physical-expr/src/equivalence/ordering.rs      |   9 +-
 .../physical-expr/src/equivalence/properties.rs    | 226 +++++++++++++++------
 datafusion/physical-expr/src/lib.rs                |   4 +-
 datafusion/physical-plan/src/filter.rs             |  27 ++-
 datafusion/sqllogictest/test_files/aggregate.slt   |   2 +-
 datafusion/sqllogictest/test_files/union.slt       |  59 ++++++
 8 files changed, 303 insertions(+), 90 deletions(-)

diff --git a/datafusion/physical-expr/src/equivalence/class.rs 
b/datafusion/physical-expr/src/equivalence/class.rs
index 03b3c7761a..9e535a94eb 100644
--- a/datafusion/physical-expr/src/equivalence/class.rs
+++ b/datafusion/physical-expr/src/equivalence/class.rs
@@ -24,7 +24,7 @@ use std::fmt::Display;
 use std::sync::Arc;
 
 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
-use datafusion_common::JoinType;
+use datafusion_common::{JoinType, ScalarValue};
 use datafusion_physical_expr_common::physical_expr::format_physical_expr_list;
 
 use indexmap::{IndexMap, IndexSet};
@@ -55,13 +55,45 @@ use indexmap::{IndexMap, IndexSet};
 /// // create a constant expression from a physical expression
 /// let const_expr = ConstExpr::from(col);
 /// ```
+// TODO: Consider refactoring the `across_partitions` field into a richer enum:
+//
+// ```
+// enum PartitionValues {
+//     Uniform(Option<ScalarValue>),           // Same value across all 
partitions
+//     Heterogeneous(Vec<Option<ScalarValue>>) // Different values per 
partition
+// }
+// ```
+//
+// This would provide more flexible representation of partition values.
+// Note: This is a breaking change for the equivalence API and should be
+// addressed in a separate issue/PR.
 #[derive(Debug, Clone)]
 pub struct ConstExpr {
     /// The  expression that is known to be constant (e.g. a `Column`)
     expr: Arc<dyn PhysicalExpr>,
     /// Does the constant have the same value across all partitions? See
     /// struct docs for more details
-    across_partitions: bool,
+    across_partitions: AcrossPartitions,
+}
+
+#[derive(PartialEq, Clone, Debug)]
+/// Represents whether a constant expression's value is uniform or varies 
across partitions.
+///
+/// The `AcrossPartitions` enum is used to describe the nature of a constant 
expression
+/// in a physical execution plan:
+///
+/// - `Heterogeneous`: The constant expression may have different values for 
different partitions.
+/// - `Uniform(Option<ScalarValue>)`: The constant expression has the same value across all partitions;
+///   the inner value is `None` when that value is not known.
+pub enum AcrossPartitions {
+    Heterogeneous,
+    Uniform(Option<ScalarValue>),
+}
+
+impl Default for AcrossPartitions {
+    fn default() -> Self {
+        Self::Heterogeneous
+    }
 }
 
 impl PartialEq for ConstExpr {
@@ -79,14 +111,14 @@ impl ConstExpr {
         Self {
             expr,
             // By default, assume constant expressions are not same across 
partitions.
-            across_partitions: false,
+            across_partitions: Default::default(),
         }
     }
 
     /// Set the `across_partitions` flag
     ///
     /// See struct docs for more details
-    pub fn with_across_partitions(mut self, across_partitions: bool) -> Self {
+    pub fn with_across_partitions(mut self, across_partitions: 
AcrossPartitions) -> Self {
         self.across_partitions = across_partitions;
         self
     }
@@ -94,8 +126,8 @@ impl ConstExpr {
     /// Is the  expression the same across all partitions?
     ///
     /// See struct docs for more details
-    pub fn across_partitions(&self) -> bool {
-        self.across_partitions
+    pub fn across_partitions(&self) -> AcrossPartitions {
+        self.across_partitions.clone()
     }
 
     pub fn expr(&self) -> &Arc<dyn PhysicalExpr> {
@@ -113,7 +145,7 @@ impl ConstExpr {
         let maybe_expr = f(&self.expr);
         maybe_expr.map(|expr| Self {
             expr,
-            across_partitions: self.across_partitions,
+            across_partitions: self.across_partitions.clone(),
         })
     }
 
@@ -143,14 +175,20 @@ impl ConstExpr {
     }
 }
 
-/// Display implementation for `ConstExpr`
-///
-/// Example `c` or `c(across_partitions)`
 impl Display for ConstExpr {
-    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         write!(f, "{}", self.expr)?;
-        if self.across_partitions {
-            write!(f, "(across_partitions)")?;
+        match &self.across_partitions {
+            AcrossPartitions::Heterogeneous => {
+                write!(f, "(heterogeneous)")?;
+            }
+            AcrossPartitions::Uniform(value) => {
+                if let Some(val) = value {
+                    write!(f, "(uniform: {})", val)?;
+                } else {
+                    write!(f, "(uniform: unknown)")?;
+                }
+            }
         }
         Ok(())
     }
diff --git a/datafusion/physical-expr/src/equivalence/mod.rs 
b/datafusion/physical-expr/src/equivalence/mod.rs
index b35d978045..d4c14f7bc8 100644
--- a/datafusion/physical-expr/src/equivalence/mod.rs
+++ b/datafusion/physical-expr/src/equivalence/mod.rs
@@ -27,7 +27,7 @@ mod ordering;
 mod projection;
 mod properties;
 
-pub use class::{ConstExpr, EquivalenceClass, EquivalenceGroup};
+pub use class::{AcrossPartitions, ConstExpr, EquivalenceClass, 
EquivalenceGroup};
 pub use ordering::OrderingEquivalenceClass;
 pub use projection::ProjectionMapping;
 pub use properties::{
diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs 
b/datafusion/physical-expr/src/equivalence/ordering.rs
index 06f85b657e..24e2fc7dba 100644
--- a/datafusion/physical-expr/src/equivalence/ordering.rs
+++ b/datafusion/physical-expr/src/equivalence/ordering.rs
@@ -262,7 +262,7 @@ mod tests {
     };
     use crate::expressions::{col, BinaryExpr, Column};
     use crate::utils::tests::TestScalarUDF;
-    use crate::{ConstExpr, PhysicalExpr, PhysicalSortExpr};
+    use crate::{AcrossPartitions, ConstExpr, PhysicalExpr, PhysicalSortExpr};
 
     use arrow::datatypes::{DataType, Field, Schema};
     use arrow_schema::SortOptions;
@@ -583,9 +583,10 @@ mod tests {
             let eq_group = EquivalenceGroup::new(eq_group);
             eq_properties.add_equivalence_group(eq_group);
 
-            let constants = constants
-                .into_iter()
-                .map(|expr| 
ConstExpr::from(expr).with_across_partitions(true));
+            let constants = constants.into_iter().map(|expr| {
+                ConstExpr::from(expr)
+                    .with_across_partitions(AcrossPartitions::Uniform(None))
+            });
             eq_properties = eq_properties.with_constants(constants);
 
             let reqs = convert_to_sort_exprs(&reqs);
diff --git a/datafusion/physical-expr/src/equivalence/properties.rs 
b/datafusion/physical-expr/src/equivalence/properties.rs
index a7f27ab736..c3d4581032 100755
--- a/datafusion/physical-expr/src/equivalence/properties.rs
+++ b/datafusion/physical-expr/src/equivalence/properties.rs
@@ -23,7 +23,7 @@ use std::sync::Arc;
 use std::{fmt, mem};
 
 use super::ordering::collapse_lex_ordering;
-use crate::equivalence::class::const_exprs_contains;
+use crate::equivalence::class::{const_exprs_contains, AcrossPartitions};
 use crate::equivalence::{
     collapse_lex_req, EquivalenceClass, EquivalenceGroup, 
OrderingEquivalenceClass,
     ProjectionMapping,
@@ -120,7 +120,7 @@ use itertools::Itertools;
 ///   PhysicalSortExpr::new_default(col_c).desc(),
 /// ]));
 ///
-/// assert_eq!(eq_properties.to_string(), "order: [[a@0 ASC, c@2 DESC]], 
const: [b@1]")
+/// assert_eq!(eq_properties.to_string(), "order: [[a@0 ASC, c@2 DESC]], 
const: [b@1(heterogeneous)]")
 /// ```
 #[derive(Debug, Clone)]
 pub struct EquivalenceProperties {
@@ -217,7 +217,9 @@ impl EquivalenceProperties {
     /// Removes constant expressions that may change across partitions.
     /// This method should be used when data from different partitions are 
merged.
     pub fn clear_per_partition_constants(&mut self) {
-        self.constants.retain(|item| item.across_partitions());
+        self.constants.retain(|item| {
+            matches!(item.across_partitions(), AcrossPartitions::Uniform(_))
+        })
     }
 
     /// Extends this `EquivalenceProperties` by adding the orderings inside the
@@ -257,14 +259,16 @@ impl EquivalenceProperties {
         if self.is_expr_constant(left) {
             // Left expression is constant, add right as constant
             if !const_exprs_contains(&self.constants, right) {
-                self.constants
-                    .push(ConstExpr::from(right).with_across_partitions(true));
+                let const_expr = ConstExpr::from(right)
+                    
.with_across_partitions(self.get_expr_constant_value(left));
+                self.constants.push(const_expr);
             }
         } else if self.is_expr_constant(right) {
             // Right expression is constant, add left as constant
             if !const_exprs_contains(&self.constants, left) {
-                self.constants
-                    .push(ConstExpr::from(left).with_across_partitions(true));
+                let const_expr = ConstExpr::from(left)
+                    
.with_across_partitions(self.get_expr_constant_value(right));
+                self.constants.push(const_expr);
             }
         }
 
@@ -293,30 +297,28 @@ impl EquivalenceProperties {
         mut self,
         constants: impl IntoIterator<Item = ConstExpr>,
     ) -> Self {
-        let (const_exprs, across_partition_flags): (
-            Vec<Arc<dyn PhysicalExpr>>,
-            Vec<bool>,
-        ) = constants
+        let normalized_constants = constants
             .into_iter()
-            .map(|const_expr| {
-                let across_partitions = const_expr.across_partitions();
-                let expr = const_expr.owned_expr();
-                (expr, across_partitions)
+            .filter_map(|c| {
+                let across_partitions = c.across_partitions();
+                let expr = c.owned_expr();
+                let normalized_expr = self.eq_group.normalize_expr(expr);
+
+                if const_exprs_contains(&self.constants, &normalized_expr) {
+                    return None;
+                }
+
+                let const_expr = ConstExpr::from(normalized_expr)
+                    .with_across_partitions(across_partitions);
+
+                Some(const_expr)
             })
-            .unzip();
-        for (expr, across_partitions) in self
-            .eq_group
-            .normalize_exprs(const_exprs)
-            .into_iter()
-            .zip(across_partition_flags)
-        {
-            if !const_exprs_contains(&self.constants, &expr) {
-                let const_expr =
-                    
ConstExpr::from(expr).with_across_partitions(across_partitions);
-                self.constants.push(const_expr);
-            }
-        }
+            .collect::<Vec<_>>();
+
+        // Add all new normalized constants
+        self.constants.extend(normalized_constants);
 
+        // Discover any new orderings based on the constants
         for ordering in self.normalized_oeq_class().iter() {
             if let Err(e) = self.discover_new_orderings(&ordering[0].expr) {
                 log::debug!("error discovering new orderings: {e}");
@@ -551,7 +553,7 @@ impl EquivalenceProperties {
     /// is satisfied based on the orderings within, equivalence classes, and
     /// constant expressions.
     ///
-    /// # Arguments
+    /// # Parameters
     ///
     /// - `req`: A reference to a `PhysicalSortRequirement` for which the 
ordering
     ///   satisfaction check will be done.
@@ -919,7 +921,7 @@ impl EquivalenceProperties {
     /// constants based on the existing constants and the mapping. It ensures
     /// that constants are appropriately propagated through the projection.
     ///
-    /// # Arguments
+    /// # Parameters
     ///
     /// - `mapping`: A reference to a `ProjectionMapping` representing the
     ///   mapping of source expressions to target expressions in the 
projection.
@@ -935,19 +937,31 @@ impl EquivalenceProperties {
             .constants
             .iter()
             .flat_map(|const_expr| {
-                const_expr.map(|expr| self.eq_group.project_expr(mapping, 
expr))
+                const_expr
+                    .map(|expr| self.eq_group.project_expr(mapping, expr))
+                    .map(|projected_expr| {
+                        projected_expr
+                            
.with_across_partitions(const_expr.across_partitions())
+                    })
             })
             .collect::<Vec<_>>();
+
         // Add projection expressions that are known to be constant:
         for (source, target) in mapping.iter() {
             if self.is_expr_constant(source)
                 && !const_exprs_contains(&projected_constants, target)
             {
-                let across_partitions = 
self.is_expr_constant_accross_partitions(source);
-                // Expression evaluates to single value
-                projected_constants.push(
-                    
ConstExpr::from(target).with_across_partitions(across_partitions),
-                );
+                if self.is_expr_constant_accross_partitions(source) {
+                    projected_constants.push(
+                        ConstExpr::from(target)
+                            
.with_across_partitions(self.get_expr_constant_value(source)),
+                    )
+                } else {
+                    projected_constants.push(
+                        ConstExpr::from(target)
+                            
.with_across_partitions(AcrossPartitions::Heterogeneous),
+                    )
+                }
             }
         }
         projected_constants
@@ -1054,7 +1068,7 @@ impl EquivalenceProperties {
     /// This function determines whether the provided expression is constant
     /// based on the known constants.
     ///
-    /// # Arguments
+    /// # Parameters
     ///
     /// - `expr`: A reference to a `Arc<dyn PhysicalExpr>` representing the
     ///   expression to be checked.
@@ -1079,7 +1093,7 @@ impl EquivalenceProperties {
     /// This function determines whether the provided expression is constant
     /// across partitions based on the known constants.
     ///
-    /// # Arguments
+    /// # Parameters
     ///
     /// - `expr`: A reference to a `Arc<dyn PhysicalExpr>` representing the
     ///   expression to be checked.
@@ -1095,18 +1109,57 @@ impl EquivalenceProperties {
         // As an example, assume that we know columns `a` and `b` are constant.
         // Then, `a`, `b` and `a + b` will all return `true` whereas `c` will
         // return `false`.
-        let const_exprs = self.constants.iter().flat_map(|const_expr| {
-            if const_expr.across_partitions() {
-                Some(Arc::clone(const_expr.expr()))
-            } else {
-                None
-            }
-        });
+        let const_exprs = self
+            .constants
+            .iter()
+            .filter_map(|const_expr| {
+                if matches!(
+                    const_expr.across_partitions(),
+                    AcrossPartitions::Uniform { .. }
+                ) {
+                    Some(Arc::clone(const_expr.expr()))
+                } else {
+                    None
+                }
+            })
+            .collect::<Vec<_>>();
         let normalized_constants = self.eq_group.normalize_exprs(const_exprs);
         let normalized_expr = self.eq_group.normalize_expr(Arc::clone(expr));
         is_constant_recurse(&normalized_constants, &normalized_expr)
     }
 
+    /// Retrieves the constant value of a given physical expression, if it 
exists.
+    ///
+    /// Normalizes the input expression and checks if it matches any known 
constants
+    /// in the current context. Returns whether the expression has a uniform 
value,
+    /// varies across partitions, or is not constant.
+    ///
+    /// # Parameters
+    /// - `expr`: A reference to the physical expression to evaluate.
+    ///
+    /// # Returns
+    /// - `AcrossPartitions::Uniform(value)`: If the expression has the same 
value across partitions.
+    /// - `AcrossPartitions::Heterogeneous`: If the expression varies across partitions,
+    ///   or if it is not recognized as a constant at all.
+    pub fn get_expr_constant_value(
+        &self,
+        expr: &Arc<dyn PhysicalExpr>,
+    ) -> AcrossPartitions {
+        let normalized_expr = self.eq_group.normalize_expr(Arc::clone(expr));
+
+        if let Some(lit) = normalized_expr.as_any().downcast_ref::<Literal>() {
+            return AcrossPartitions::Uniform(Some(lit.value().clone()));
+        }
+
+        for const_expr in self.constants.iter() {
+            if normalized_expr.eq(const_expr.expr()) {
+                return const_expr.across_partitions();
+            }
+        }
+
+        AcrossPartitions::Heterogeneous
+    }
+
     /// Retrieves the properties for a given physical expression.
     ///
     /// This function constructs an [`ExprProperties`] object for the given
@@ -1282,7 +1335,7 @@ fn update_properties(
 /// This function determines whether the provided expression is constant
 /// based on the known constants.
 ///
-/// # Arguments
+/// # Parameters
 ///
 /// - `constants`: A `&[Arc<dyn PhysicalExpr>]` containing expressions known to
 ///   be a constant.
@@ -1915,7 +1968,7 @@ impl Hash for ExprWrapper {
 /// *all* output partitions, that is the same as being true for all *input*
 /// partitions
 fn calculate_union_binary(
-    mut lhs: EquivalenceProperties,
+    lhs: EquivalenceProperties,
     mut rhs: EquivalenceProperties,
 ) -> Result<EquivalenceProperties> {
     // Harmonize the schema of the rhs with the schema of the lhs (which is 
the accumulator schema):
@@ -1924,26 +1977,34 @@ fn calculate_union_binary(
     }
 
     // First, calculate valid constants for the union. An expression is 
constant
-    // at the output of the union if it is constant in both sides.
-    let constants: Vec<_> = lhs
+    // at the output of the union if it is constant in both sides with 
matching values.
+    let constants = lhs
         .constants()
         .iter()
-        .filter(|const_expr| const_exprs_contains(rhs.constants(), 
const_expr.expr()))
-        .map(|const_expr| {
-            // TODO: When both sides have a constant column, and the actual
-            // constant value is the same, then the output properties could
-            // reflect the constant is valid across all partitions. However we
-            // don't track the actual value that the ConstExpr takes on, so we
-            // can't determine that yet
-            
ConstExpr::new(Arc::clone(const_expr.expr())).with_across_partitions(false)
+        .filter_map(|lhs_const| {
+            // Find matching constant expression in RHS
+            rhs.constants()
+                .iter()
+                .find(|rhs_const| rhs_const.expr().eq(lhs_const.expr()))
+                .map(|rhs_const| {
+                    let mut const_expr = 
ConstExpr::new(Arc::clone(lhs_const.expr()));
+
+                    // If both sides have the same known value, keep the constant uniform with that value
+                    if let (
+                        AcrossPartitions::Uniform(Some(lhs_val)),
+                        AcrossPartitions::Uniform(Some(rhs_val)),
+                    ) = (lhs_const.across_partitions(), 
rhs_const.across_partitions())
+                    {
+                        if lhs_val == rhs_val {
+                            const_expr = const_expr.with_across_partitions(
+                                AcrossPartitions::Uniform(Some(lhs_val)),
+                            )
+                        }
+                    }
+                    const_expr
+                })
         })
-        .collect();
-
-    // remove any constants that are shared in both outputs (avoid double 
counting them)
-    for c in &constants {
-        lhs = lhs.remove_constant(c);
-        rhs = rhs.remove_constant(c);
-    }
+        .collect::<Vec<_>>();
 
     // Next, calculate valid orderings for the union by searching for prefixes
     // in both sides.
@@ -2210,6 +2271,7 @@ mod tests {
 
     use arrow::datatypes::{DataType, Field, Schema};
     use arrow_schema::{Fields, TimeUnit};
+    use datafusion_common::ScalarValue;
     use datafusion_expr::Operator;
 
     use datafusion_functions::string::concat;
@@ -4133,4 +4195,40 @@ mod tests {
 
         Ok(())
     }
+
+    #[test]
+    fn test_union_constant_value_preservation() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int32, true),
+            Field::new("b", DataType::Int32, true),
+        ]));
+
+        let col_a = col("a", &schema)?;
+        let literal_10 = ScalarValue::Int32(Some(10));
+
+        // Create first input with a=10
+        let const_expr1 = ConstExpr::new(Arc::clone(&col_a))
+            
.with_across_partitions(AcrossPartitions::Uniform(Some(literal_10.clone())));
+        let input1 = EquivalenceProperties::new(Arc::clone(&schema))
+            .with_constants(vec![const_expr1]);
+
+        // Create second input with a=10
+        let const_expr2 = ConstExpr::new(Arc::clone(&col_a))
+            
.with_across_partitions(AcrossPartitions::Uniform(Some(literal_10.clone())));
+        let input2 = EquivalenceProperties::new(Arc::clone(&schema))
+            .with_constants(vec![const_expr2]);
+
+        // Calculate union properties
+        let union_props = calculate_union(vec![input1, input2], schema)?;
+
+        // Verify column 'a' remains constant with value 10
+        let const_a = &union_props.constants()[0];
+        assert!(const_a.expr().eq(&col_a));
+        assert_eq!(
+            const_a.across_partitions(),
+            AcrossPartitions::Uniform(Some(literal_10))
+        );
+
+        Ok(())
+    }
 }
diff --git a/datafusion/physical-expr/src/lib.rs 
b/datafusion/physical-expr/src/lib.rs
index 405b6bbd69..4c55f4ddba 100644
--- a/datafusion/physical-expr/src/lib.rs
+++ b/datafusion/physical-expr/src/lib.rs
@@ -45,7 +45,9 @@ pub mod execution_props {
 
 pub use aggregate::groups_accumulator::{GroupsAccumulatorAdapter, NullState};
 pub use analysis::{analyze, AnalysisContext, ExprBoundaries};
-pub use equivalence::{calculate_union, ConstExpr, EquivalenceProperties};
+pub use equivalence::{
+    calculate_union, AcrossPartitions, ConstExpr, EquivalenceProperties,
+};
 pub use partitioning::{Distribution, Partitioning};
 pub use physical_expr::{
     physical_exprs_bag_equal, physical_exprs_contains, physical_exprs_equal,
diff --git a/datafusion/physical-plan/src/filter.rs 
b/datafusion/physical-plan/src/filter.rs
index 901907cf38..8e7c14f0ba 100644
--- a/datafusion/physical-plan/src/filter.rs
+++ b/datafusion/physical-plan/src/filter.rs
@@ -45,7 +45,8 @@ use datafusion_physical_expr::expressions::BinaryExpr;
 use datafusion_physical_expr::intervals::utils::check_support;
 use datafusion_physical_expr::utils::collect_columns;
 use datafusion_physical_expr::{
-    analyze, split_conjunction, AnalysisContext, ConstExpr, ExprBoundaries, 
PhysicalExpr,
+    analyze, split_conjunction, AcrossPartitions, AnalysisContext, ConstExpr,
+    ExprBoundaries, PhysicalExpr,
 };
 
 use crate::execution_plan::CardinalityEffect;
@@ -218,13 +219,23 @@ impl FilterExec {
                 if binary.op() == &Operator::Eq {
                     // Filter evaluates to single value for all partitions
                     if input_eqs.is_expr_constant(binary.left()) {
+                        let (expr, across_parts) = (
+                            binary.right(),
+                            input_eqs.get_expr_constant_value(binary.right()),
+                        );
                         res_constants.push(
-                            
ConstExpr::from(binary.right()).with_across_partitions(true),
-                        )
+                            ConstExpr::new(Arc::clone(expr))
+                                .with_across_partitions(across_parts),
+                        );
                     } else if input_eqs.is_expr_constant(binary.right()) {
+                        let (expr, across_parts) = (
+                            binary.left(),
+                            input_eqs.get_expr_constant_value(binary.left()),
+                        );
                         res_constants.push(
-                            
ConstExpr::from(binary.left()).with_across_partitions(true),
-                        )
+                            ConstExpr::new(Arc::clone(expr))
+                                .with_across_partitions(across_parts),
+                        );
                     }
                 }
             }
@@ -252,8 +263,12 @@ impl FilterExec {
             .into_iter()
             .filter(|column| 
stats.column_statistics[column.index()].is_singleton())
             .map(|column| {
+                let value = stats.column_statistics[column.index()]
+                    .min_value
+                    .get_value();
                 let expr = Arc::new(column) as _;
-                ConstExpr::new(expr).with_across_partitions(true)
+                ConstExpr::new(expr)
+                    
.with_across_partitions(AcrossPartitions::Uniform(value.cloned()))
             });
         // This is for statistics
         eq_properties = eq_properties.with_constants(constants);
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt 
b/datafusion/sqllogictest/test_files/aggregate.slt
index 94c6297073..cd62e56253 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -259,7 +259,7 @@ logical_plan
 15)------------EmptyRelation
 physical_plan
 01)ProjectionExec: expr=[array_length(array_agg(DISTINCT a.foo)@1) as 
array_length(array_agg(DISTINCT a.foo)), sum(DISTINCT Int64(1))@2 as 
sum(DISTINCT Int64(1))]
-02)--AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], 
aggr=[array_agg(DISTINCT a.foo), sum(DISTINCT Int64(1))]
+02)--AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], 
aggr=[array_agg(DISTINCT a.foo), sum(DISTINCT Int64(1))], ordering_mode=Sorted
 03)----CoalesceBatchesExec: target_batch_size=8192
 04)------RepartitionExec: partitioning=Hash([id@0], 4), input_partitions=5
 05)--------AggregateExec: mode=Partial, gby=[id@0 as id], 
aggr=[array_agg(DISTINCT a.foo), sum(DISTINCT Int64(1))], ordering_mode=Sorted
diff --git a/datafusion/sqllogictest/test_files/union.slt 
b/datafusion/sqllogictest/test_files/union.slt
index d94780744d..7b8992b966 100644
--- a/datafusion/sqllogictest/test_files/union.slt
+++ b/datafusion/sqllogictest/test_files/union.slt
@@ -777,3 +777,62 @@ select make_array(make_array(1)) x UNION ALL SELECT 
make_array(arrow_cast(make_a
 ----
 [[-1]]
 [[1]]
+
+statement ok
+CREATE EXTERNAL TABLE aggregate_test_100 (
+  c1  VARCHAR NOT NULL,
+  c2  TINYINT NOT NULL,
+  c3  SMALLINT NOT NULL,
+  c4  SMALLINT,
+  c5  INT,
+  c6  BIGINT NOT NULL,
+  c7  SMALLINT NOT NULL,
+  c8  INT NOT NULL,
+  c9  BIGINT UNSIGNED NOT NULL,
+  c10 VARCHAR NOT NULL,
+  c11 FLOAT NOT NULL,
+  c12 DOUBLE NOT NULL,
+  c13 VARCHAR NOT NULL
+)
+STORED AS CSV
+LOCATION '../../testing/data/csv/aggregate_test_100.csv'
+OPTIONS ('format.has_header' 'true');
+
+statement ok
+set datafusion.execution.batch_size = 2;
+
+# Constant value tracking across union
+query TT
+explain
+SELECT * FROM(
+(
+    SELECT * FROM aggregate_test_100 WHERE c1='a'
+)
+UNION ALL
+(
+    SELECT * FROM aggregate_test_100 WHERE c1='a'
+))
+ORDER BY c1
+----
+logical_plan
+01)Sort: aggregate_test_100.c1 ASC NULLS LAST
+02)--Union
+03)----Filter: aggregate_test_100.c1 = Utf8("a")
+04)------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, 
c8, c9, c10, c11, c12, c13], partial_filters=[aggregate_test_100.c1 = Utf8("a")]
+05)----Filter: aggregate_test_100.c1 = Utf8("a")
+06)------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, 
c8, c9, c10, c11, c12, c13], partial_filters=[aggregate_test_100.c1 = Utf8("a")]
+physical_plan
+01)CoalescePartitionsExec
+02)--UnionExec
+03)----CoalesceBatchesExec: target_batch_size=2
+04)------FilterExec: c1@0 = a
+05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+06)----------CsvExec: file_groups={1 group: 
[[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, 
c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], has_header=true
+07)----CoalesceBatchesExec: target_batch_size=2
+08)------FilterExec: c1@0 = a
+09)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+10)----------CsvExec: file_groups={1 group: 
[[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, 
c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], has_header=true
+
+# Clean up after the test
+statement ok
+drop table aggregate_test_100;


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to