alamb commented on code in PR #11357:
URL: https://github.com/apache/datafusion/pull/11357#discussion_r1672573026


##########
datafusion/expr/src/expr.rs:
##########
@@ -1401,6 +1401,41 @@ impl Expr {
         .expect("traversal is infallable");
     }
 
+    /// Return all references to columns and their occurrence counts in the 
expression.
+    ///
+    /// # Example
+    /// ```
+    /// # use std::collections::HashMap;
+    /// # use datafusion_common::Column;
+    /// # use datafusion_expr::col;
+    /// // For an expression `a + (b * a)`
+    /// let expr = col("a") + (col("b") * col("a"));
+    /// let mut refs = expr.column_refs_counts();
+    /// // refs contains "a" and "b"
+    /// assert_eq!(refs.len(), 2);
+    /// assert_eq!(*refs.get(&Column::new_unqualified("a")).unwrap(), 2);
+    /// assert_eq!(*refs.get(&Column::new_unqualified("b")).unwrap(), 1);
+    /// ```
+    pub fn column_refs_counts(&self) -> HashMap<&Column, usize> {
+        let mut map = HashMap::new();
+        self.add_column_ref_counts(&mut map);
+        map
+    }
+
+    /// Adds references to all columns and their occurrence counts in the 
expression to
+    /// the map.
+    ///
+    /// See [`Self::column_refs`] for details

Review Comment:
   ```suggestion
       /// See [`Self::column_refs_counts`] for details
   ```



##########
datafusion/optimizer/src/optimize_projections/mod.rs:
##########
@@ -472,11 +471,8 @@ fn merge_consecutive_projections(proj: Projection) -> 
Result<Transformed<Project
 
     // Count usages (referrals) of each projection expression in its input 
fields:
     let mut column_referral_map = HashMap::<&Column, usize>::new();
-    for columns in expr.iter().map(|expr| expr.column_refs()) {
-        for col in columns.into_iter() {
-            *column_referral_map.entry(col).or_default() += 1;
-        }
-    }
+    expr.iter()

Review Comment:
   I am confirming my understanding of this change in behavior. 
   
   The old code would treat a projection with two expressions like `[a + 
a, b]` as having only a single reference to `a`. After the change, it would 
correctly identify 2 occurrences of `a`.



##########
datafusion/optimizer/src/common_subexpr_eliminate.rs:
##########
@@ -954,53 +952,113 @@ impl<'n> ExprIdentifierVisitor<'_, 'n> {
         }
         unreachable!("Enter mark should paired with node number");
     }
+
+    /// Save the current `conditional` status and run `f` with `conditional` 
set to true.
+    fn conditionally<F: FnMut(&mut Self) -> Result<()>>(
+        &mut self,
+        mut f: F,
+    ) -> Result<()> {
+        let conditional = self.conditional;
+        self.conditional = true;
+        f(self)?;
+        self.conditional = conditional;
+
+        Ok(())
+    }
 }
 
 impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> {
     type Node = Expr;
 
     fn f_down(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> {
-        // If an expression can short circuit its children then don't consider 
its
-        // children for CSE 
(https://github.com/apache/arrow-datafusion/issues/8814).
-        // This means that we don't recurse into its children, but handle the 
expression
-        // as a subtree when we calculate its identifier.
-        // TODO: consider surely executed children of "short circuited"s for 
CSE
-        let is_tree = expr.short_circuits();
-        let tnr = if is_tree {
-            TreeNodeRecursion::Jump
-        } else {
-            TreeNodeRecursion::Continue
-        };
-
         self.id_array.push((0, None));
         self.visit_stack
-            .push(VisitRecord::EnterMark(self.down_index, is_tree));
+            .push(VisitRecord::EnterMark(self.down_index));
         self.down_index += 1;
 
-        Ok(tnr)
+        // If an expression can short-circuit then some of its children might 
not be
+        // executed so count the occurrence of subexpressions as conditional 
in all
+        // children.
+        Ok(match expr {
+            // If we are already in a conditionally evaluated subtree then 
continue
+            // traversal.
+            _ if self.conditional => TreeNodeRecursion::Continue,

Review Comment:
   That is a fascinating construct that makes the condition handling uniform 👍 



##########
datafusion/optimizer/src/common_subexpr_eliminate.rs:
##########
@@ -954,53 +952,113 @@ impl<'n> ExprIdentifierVisitor<'_, 'n> {
         }
         unreachable!("Enter mark should paired with node number");
     }
+
+    /// Save the current `conditional` status and run `f` with `conditional` 
set to true.
+    fn conditionally<F: FnMut(&mut Self) -> Result<()>>(
+        &mut self,
+        mut f: F,
+    ) -> Result<()> {
+        let conditional = self.conditional;
+        self.conditional = true;
+        f(self)?;
+        self.conditional = conditional;
+
+        Ok(())
+    }
 }
 
 impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> {
     type Node = Expr;
 
     fn f_down(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> {
-        // If an expression can short circuit its children then don't consider 
its
-        // children for CSE 
(https://github.com/apache/arrow-datafusion/issues/8814).
-        // This means that we don't recurse into its children, but handle the 
expression
-        // as a subtree when we calculate its identifier.
-        // TODO: consider surely executed children of "short circuited"s for 
CSE
-        let is_tree = expr.short_circuits();
-        let tnr = if is_tree {
-            TreeNodeRecursion::Jump
-        } else {
-            TreeNodeRecursion::Continue
-        };
-
         self.id_array.push((0, None));
         self.visit_stack
-            .push(VisitRecord::EnterMark(self.down_index, is_tree));
+            .push(VisitRecord::EnterMark(self.down_index));
         self.down_index += 1;
 
-        Ok(tnr)
+        // If an expression can short-circuit then some of its children might 
not be
+        // executed so count the occurrence of subexpressions as conditional 
in all
+        // children.
+        Ok(match expr {
+            // If we are already in a conditionally evaluated subtree then 
continue
+            // traversal.
+            _ if self.conditional => TreeNodeRecursion::Continue,
+
+            // In case of `ScalarFunction`s we don't know which children are 
surely
+            // executed so start visiting all children conditionally and stop 
the
+            // recursion with `TreeNodeRecursion::Jump`.
+            Expr::ScalarFunction(ScalarFunction { func, args })
+                if func.short_circuits() =>
+            {
+                self.conditionally(|visitor| {
+                    args.iter().try_for_each(|e| e.visit(visitor).map(|_| ()))
+                })?;
+
+                TreeNodeRecursion::Jump
+            }
+
+            // In case of `And` and `Or` the first child is surely executed, 
but we
+            // account subexpressions as conditional in the second.
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: Operator::And | Operator::Or,
+                right,
+            }) => {
+                left.visit(self)?;
+                self.conditionally(|visitor| right.visit(visitor).map(|_| 
()))?;
+
+                TreeNodeRecursion::Jump
+            }
+
+            // In case of `Case` the optional base expression and the first 
when
+            // expressions are surely executed, but we account subexpressions 
as
+            // conditional in the others.
+            Expr::Case(Case {
+                expr,
+                when_then_expr,
+                else_expr,
+            }) => {
+                expr.iter().try_for_each(|e| e.visit(self).map(|_| ()))?;
+                when_then_expr.iter().take(1).try_for_each(|(when, then)| {
+                    when.visit(self)?;
+                    self.conditionally(|visitor| then.visit(visitor).map(|_| 
()))
+                })?;
+                self.conditionally(|visitor| {
+                    when_then_expr.iter().skip(1).try_for_each(|(when, then)| {
+                        when.visit(visitor)?;
+                        then.visit(visitor).map(|_| ())
+                    })?;
+                    else_expr
+                        .iter()
+                        .try_for_each(|e| e.visit(visitor).map(|_| ()))
+                })?;
+
+                TreeNodeRecursion::Jump
+            }
+
+            // In case of non-short-circuit expressions continue the traversal.
+            _ => TreeNodeRecursion::Continue,
+        })
     }
 
     fn f_up(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> {
-        let (down_index, is_tree, sub_expr_id, sub_expr_is_valid) = 
self.pop_enter_mark();
+        let (down_index, sub_expr_id, sub_expr_is_valid) = 
self.pop_enter_mark();
 
-        let (expr_id, is_valid) = if is_tree {
-            (
-                Identifier::new(expr, true, self.random_state),
-                !expr.is_volatile()?,
-            )
-        } else {
-            (
-                Identifier::new(expr, false, 
self.random_state).combine(sub_expr_id),
-                !expr.is_volatile_node() && sub_expr_is_valid,
-            )
-        };
+        let expr_id =
+            Identifier::new(expr, false, 
self.random_state).combine(sub_expr_id);
+        let is_valid = !expr.is_volatile_node() && sub_expr_is_valid;
 
         self.id_array[down_index].0 = self.up_index;
         if is_valid && !self.expr_mask.ignores(expr) {
             self.id_array[down_index].1 = Some(expr_id);
-            let count = self.expr_stats.entry(expr_id).or_insert(0);
-            *count += 1;
-            if *count > 1 {
+            let (count, conditional_count) =
+                self.expr_stats.entry(expr_id).or_insert((0, 0));
+            if self.conditional {
+                *conditional_count += 1;
+            } else {
+                *count += 1;
+            }
+            if *count > 1 || *count == 1 && *conditional_count > 0 {

Review Comment:
   I personally prefer explicit parentheses to avoid confusion
   
   In this case, I think this is the same:
   
   ```suggestion
               if *count > 1 || (*count == 1 && *conditional_count > 0) {
   ```



##########
datafusion/sqllogictest/test_files/tpch/q14.slt.part:
##########
@@ -44,19 +44,20 @@ physical_plan
 01)ProjectionExec: expr=[100 * CAST(sum(CASE WHEN part.p_type LIKE 
Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount 
ELSE Int64(0) END)@0 AS Float64) / CAST(sum(lineitem.l_extendedprice * Int64(1) 
- lineitem.l_discount)@1 AS Float64) as promo_revenue]
 02)--AggregateExec: mode=Final, gby=[], aggr=[sum(CASE WHEN part.p_type LIKE 
Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount 
ELSE Int64(0) END), sum(lineitem.l_extendedprice * Int64(1) - 
lineitem.l_discount)]
 03)----CoalescePartitionsExec
-04)------AggregateExec: mode=Partial, gby=[], aggr=[sum(CASE WHEN part.p_type 
LIKE Utf8("PROMO%")  THEN lineitem.l_extendedprice * Int64(1) - 
lineitem.l_discount ELSE Int64(0) END), sum(lineitem.l_extendedprice * Int64(1) 
- lineitem.l_discount)]
-05)--------CoalesceBatchesExec: target_batch_size=8192
-06)----------HashJoinExec: mode=Partitioned, join_type=Inner, 
on=[(l_partkey@0, p_partkey@0)], projection=[l_extendedprice@1, l_discount@2, 
p_type@4]
-07)------------CoalesceBatchesExec: target_batch_size=8192
-08)--------------RepartitionExec: partitioning=Hash([l_partkey@0], 4), 
input_partitions=4
-09)----------------ProjectionExec: expr=[l_partkey@0 as l_partkey, 
l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount]
-10)------------------CoalesceBatchesExec: target_batch_size=8192
-11)--------------------FilterExec: l_shipdate@3 >= 1995-09-01 AND l_shipdate@3 
< 1995-10-01
-12)----------------------CsvExec: file_groups={4 groups: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749],
 
[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498],
 
[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247],
 
[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]},
 projection=[l_partkey, l_extendedprice, l_discount, l_shipdate], 
has_header=false
-13)------------CoalesceBatchesExec: target_batch_size=8192
-14)--------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), 
input_partitions=4
-15)----------------RepartitionExec: partitioning=RoundRobinBatch(4), 
input_partitions=1
-16)------------------CsvExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, 
projection=[p_partkey, p_type], has_header=false
+04)------AggregateExec: mode=Partial, gby=[], aggr=[sum(CASE WHEN part.p_type 
LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - 
lineitem.l_discount ELSE Int64(0) END), sum(lineitem.l_extendedprice * Int64(1) 
- lineitem.l_discount)]
+05)--------ProjectionExec: expr=[l_extendedprice@0 * (Some(1),20,0 - 
l_discount@1) as __common_expr_1, p_type@2 as p_type]

Review Comment:
   It's interesting here that this plan shows the evaluation done below the 
aggregate but the aggregate doesn't seem to reflect that fact (e.g. the aggr 
expressions don't refer to `__common_expr_1`).



##########
datafusion/sqllogictest/test_files/cse.slt:
##########
@@ -171,3 +175,41 @@ logical_plan
 physical_plan
 01)ProjectionExec: expr=[a@0 = random() AND b@1 = 0 as c1, a@0 = random() AND 
b@1 = 1 as c2, a@0 = 2 + random() OR b@1 = 4 as c3, a@0 = 2 + random() OR b@1 = 
5 as c4, CASE WHEN a@0 = 4 + random() THEN 0 ELSE 1 END as c5, CASE WHEN a@0 = 
4 + random() THEN 0 ELSE 2 END as c6]
 02)--MemoryExec: partitions=1, partition_sizes=[0]
+
+# Surely only once but also conditionally evaluated expressions
+query TT
+EXPLAIN SELECT
+    (a = 1 OR random() = 0) AND a = 1 AS c1,
+    (a = 2 AND random() = 0) OR a = 2 AS c2,
+    CASE WHEN a + 3 = 0 THEN a + 3 ELSE 0 END AS c3,
+    CASE WHEN a + 4 = 0 THEN 0 WHEN a + 4 THEN 0 ELSE 0 END AS c4,
+    CASE WHEN a + 5 = 0 THEN 0 WHEN random() = 0 THEN a + 5 ELSE 0 END AS c5,
+    CASE WHEN a + 6 = 0 THEN 0 ELSE a + 6 END AS c6
+FROM t1
+----
+logical_plan
+01)Projection: (__common_expr_1 OR random() = Float64(0)) AND __common_expr_1 
AS c1, __common_expr_2 AND random() = Float64(0) OR __common_expr_2 AS c2, CASE 
WHEN __common_expr_3 = Float64(0) THEN __common_expr_3 ELSE Float64(0) END AS 
c3, CASE WHEN __common_expr_4 = Float64(0) THEN Int64(0) WHEN 
CAST(__common_expr_4 AS Boolean) THEN Int64(0) ELSE Int64(0) END AS c4, CASE 
WHEN __common_expr_5 = Float64(0) THEN Float64(0) WHEN random() = Float64(0) 
THEN __common_expr_5 ELSE Float64(0) END AS c5, CASE WHEN __common_expr_6 = 
Float64(0) THEN Float64(0) ELSE __common_expr_6 END AS c6

Review Comment:
   ✅ 



##########
datafusion/expr/src/expr.rs:
##########
@@ -1401,6 +1401,41 @@ impl Expr {
         .expect("traversal is infallable");
     }
 
+    /// Return all references to columns and their occurrence counts in the 
expression.

Review Comment:
   I think this new API makes sense to me as a parallel set of APIs for 
`column_refs` / `add_column_refs`



##########
datafusion/sqllogictest/test_files/select.slt:
##########
@@ -1504,21 +1504,25 @@ query TT
 EXPLAIN SELECT y > 0 and 1 / y < 1, x > 0 and y > 0 and 1 / y < 1 / x from t;
 ----
 logical_plan
-01)Projection: t.y > Int32(0) AND Int64(1) / CAST(t.y AS Int64) < Int64(1) AS 
t.y > Int64(0) AND Int64(1) / t.y < Int64(1), t.x > Int32(0) AND t.y > Int32(0) 
AND Int64(1) / CAST(t.y AS Int64) < Int64(1) / CAST(t.x AS Int64) AS t.x > 
Int64(0) AND t.y > Int64(0) AND Int64(1) / t.y < Int64(1) / t.x
-02)--TableScan: t projection=[x, y]
+01)Projection: __common_expr_1 AND Int64(1) / CAST(t.y AS Int64) < Int64(1) AS 
t.y > Int64(0) AND Int64(1) / t.y < Int64(1), t.x > Int32(0) AND 
__common_expr_1 AND Int64(1) / CAST(t.y AS Int64) < Int64(1) / CAST(t.x AS 
Int64) AS t.x > Int64(0) AND t.y > Int64(0) AND Int64(1) / t.y < Int64(1) / t.x
+02)--Projection: t.y > Int32(0) AS __common_expr_1, t.x, t.y

Review Comment:
   👍  I verified that the common expressions do not include the `1 / y` term 
which can potentially generate a runtime error



##########
datafusion/optimizer/src/common_subexpr_eliminate.rs:
##########
@@ -901,15 +903,15 @@ struct ExprIdentifierVisitor<'a, 'n> {
     random_state: &'a RandomState,
     // a flag to indicate that common expression found
     found_common: bool,
+    // if we are in a conditional branch

Review Comment:
   I think it would help to document more what is meant by 'conditional' 
-- maybe like this
   
   ```suggestion
       // if we are in a conditional branch. A conditional
       // branch means that the expression **might** not be executed depending
       // on the runtime values of other expressions, and thus can not be 
extracted 
       // as a common expression.
   ```
   
   



##########
datafusion/optimizer/src/common_subexpr_eliminate.rs:
##########
@@ -937,14 +935,14 @@ impl<'n> ExprIdentifierVisitor<'_, 'n> {
     ///   information up from children to parents via `visit_stack` during the 
first,
     ///   visiting traversal and no need to test the expression's validity 
beforehand with
     ///   an extra traversal).
-    fn pop_enter_mark(&mut self) -> (usize, bool, Option<Identifier<'n>>, 
bool) {
+    fn pop_enter_mark(&mut self) -> (usize, Option<Identifier<'n>>, bool) {

Review Comment:
   Is the information that used to be captured by the 'subtree' boolean flag 
now kept as part of  `ExprStats`?



##########
datafusion/sqllogictest/test_files/cse.slt:
##########
@@ -171,3 +175,41 @@ logical_plan
 physical_plan
 01)ProjectionExec: expr=[a@0 = random() AND b@1 = 0 as c1, a@0 = random() AND 
b@1 = 1 as c2, a@0 = 2 + random() OR b@1 = 4 as c3, a@0 = 2 + random() OR b@1 = 
5 as c4, CASE WHEN a@0 = 4 + random() THEN 0 ELSE 1 END as c5, CASE WHEN a@0 = 
4 + random() THEN 0 ELSE 2 END as c6]
 02)--MemoryExec: partitions=1, partition_sizes=[0]
+

Review Comment:
   Could we maybe add some negative tests if they aren't already handled?
   
   For example, I think these should not be CSE'd:
   
   ```sql
       (random() = 0 OR a = 1) AND a = 1
   ```
   
   ```sql
       (random() = 0 AND a = 1) OR a = 1
   ```
   
   ```sql
       CASE 
         WHEN a + 10 = 0 THEN 0 
         WHEN random() > 0.5 THEN a+10 
         ELSE 0
       END
   ```
   
   ```sql
       CASE 
         WHEN random() > 0.5 THEN 0
         WHEN a + 10 = 0 THEN 0 
         ELSE a + 10
       END
   ```
   
   ```sql
       CASE 
         WHEN a + 10 = 0 THEN 0 
         WHEN random() > 0.5 
         WHEN random() > 0.5 THEN a+10 
         ELSE 0 
   END
   ```



##########
datafusion/optimizer/src/common_subexpr_eliminate.rs:
##########
@@ -954,53 +952,113 @@ impl<'n> ExprIdentifierVisitor<'_, 'n> {
         }
         unreachable!("Enter mark should paired with node number");
     }
+
+    /// Save the current `conditional` status and run `f` with `conditional` 
set to true.
+    fn conditionally<F: FnMut(&mut Self) -> Result<()>>(
+        &mut self,
+        mut f: F,
+    ) -> Result<()> {
+        let conditional = self.conditional;
+        self.conditional = true;
+        f(self)?;
+        self.conditional = conditional;
+
+        Ok(())
+    }
 }
 
 impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> {
     type Node = Expr;
 
     fn f_down(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> {
-        // If an expression can short circuit its children then don't consider 
its
-        // children for CSE 
(https://github.com/apache/arrow-datafusion/issues/8814).
-        // This means that we don't recurse into its children, but handle the 
expression
-        // as a subtree when we calculate its identifier.
-        // TODO: consider surely executed children of "short circuited"s for 
CSE
-        let is_tree = expr.short_circuits();
-        let tnr = if is_tree {
-            TreeNodeRecursion::Jump
-        } else {
-            TreeNodeRecursion::Continue
-        };
-
         self.id_array.push((0, None));
         self.visit_stack
-            .push(VisitRecord::EnterMark(self.down_index, is_tree));
+            .push(VisitRecord::EnterMark(self.down_index));
         self.down_index += 1;
 
-        Ok(tnr)
+        // If an expression can short-circuit then some of its children might 
not be
+        // executed so count the occurrence of subexpressions as conditional 
in all
+        // children.
+        Ok(match expr {
+            // If we are already in a conditionally evaluated subtree then 
continue
+            // traversal.
+            _ if self.conditional => TreeNodeRecursion::Continue,
+
+            // In case of `ScalarFunction`s we don't know which children are 
surely
+            // executed so start visiting all children conditionally and stop 
the
+            // recursion with `TreeNodeRecursion::Jump`.
+            Expr::ScalarFunction(ScalarFunction { func, args })
+                if func.short_circuits() =>
+            {
+                self.conditionally(|visitor| {
+                    args.iter().try_for_each(|e| e.visit(visitor).map(|_| ()))
+                })?;
+
+                TreeNodeRecursion::Jump
+            }
+
+            // In case of `And` and `Or` the first child is surely executed, 
but we
+            // account subexpressions as conditional in the second.
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: Operator::And | Operator::Or,
+                right,
+            }) => {
+                left.visit(self)?;
+                self.conditionally(|visitor| right.visit(visitor).map(|_| 
()))?;

Review Comment:
   the use of `conditionally` makes reading this logic quite elegant. Nice work



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org
For additional commands, e-mail: github-h...@datafusion.apache.org

Reply via email to