alamb commented on code in PR #11357: URL: https://github.com/apache/datafusion/pull/11357#discussion_r1672573026
########## datafusion/expr/src/expr.rs: ########## @@ -1401,6 +1401,41 @@ impl Expr { .expect("traversal is infallable"); } + /// Return all references to columns and their occurrence counts in the expression. + /// + /// # Example + /// ``` + /// # use std::collections::HashMap; + /// # use datafusion_common::Column; + /// # use datafusion_expr::col; + /// // For an expression `a + (b * a)` + /// let expr = col("a") + (col("b") * col("a")); + /// let mut refs = expr.column_refs_counts(); + /// // refs contains "a" and "b" + /// assert_eq!(refs.len(), 2); + /// assert_eq!(*refs.get(&Column::new_unqualified("a")).unwrap(), 2); + /// assert_eq!(*refs.get(&Column::new_unqualified("b")).unwrap(), 1); + /// ``` + pub fn column_refs_counts(&self) -> HashMap<&Column, usize> { + let mut map = HashMap::new(); + self.add_column_ref_counts(&mut map); + map + } + + /// Adds references to all columns and their occurrence counts in the expression to + /// the map. + /// + /// See [`Self::column_refs`] for details Review Comment: ```suggestion /// See [`Self::column_refs_counts`] for details ``` ########## datafusion/optimizer/src/optimize_projections/mod.rs: ########## @@ -472,11 +471,8 @@ fn merge_consecutive_projections(proj: Projection) -> Result<Transformed<Project // Count usages (referrals) of each projection expression in its input fields: let mut column_referral_map = HashMap::<&Column, usize>::new(); - for columns in expr.iter().map(|expr| expr.column_refs()) { - for col in columns.into_iter() { - *column_referral_map.entry(col).or_default() += 1; - } - } + expr.iter() Review Comment: I am confirming my understanding of this change in behavior. The old code would treat a projection with two expressions like `[a + a, b]` as having only a single reference to `a`. After the change, it would correctly identify 2 occurrences of `a`.
########## datafusion/optimizer/src/common_subexpr_eliminate.rs: ########## @@ -954,53 +952,113 @@ impl<'n> ExprIdentifierVisitor<'_, 'n> { } unreachable!("Enter mark should paired with node number"); } + + /// Save the current `conditional` status and run `f` with `conditional` set to true. + fn conditionally<F: FnMut(&mut Self) -> Result<()>>( + &mut self, + mut f: F, + ) -> Result<()> { + let conditional = self.conditional; + self.conditional = true; + f(self)?; + self.conditional = conditional; + + Ok(()) + } } impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> { type Node = Expr; fn f_down(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> { - // If an expression can short circuit its children then don't consider its - // children for CSE (https://github.com/apache/arrow-datafusion/issues/8814). - // This means that we don't recurse into its children, but handle the expression - // as a subtree when we calculate its identifier. - // TODO: consider surely executed children of "short circuited"s for CSE - let is_tree = expr.short_circuits(); - let tnr = if is_tree { - TreeNodeRecursion::Jump - } else { - TreeNodeRecursion::Continue - }; - self.id_array.push((0, None)); self.visit_stack - .push(VisitRecord::EnterMark(self.down_index, is_tree)); + .push(VisitRecord::EnterMark(self.down_index)); self.down_index += 1; - Ok(tnr) + // If an expression can short-circuit then some of its children might not be + // executed so count the occurrence of subexpressions as conditional in all + // children. + Ok(match expr { + // If we are already in a conditionally evaluated subtree then continue + // traversal. 
+ _ if self.conditional => TreeNodeRecursion::Continue, Review Comment: That is a fascinating construct that makes the condition handling uniform 👍 ########## datafusion/optimizer/src/common_subexpr_eliminate.rs: ########## @@ -954,53 +952,113 @@ impl<'n> ExprIdentifierVisitor<'_, 'n> { } unreachable!("Enter mark should paired with node number"); } + + /// Save the current `conditional` status and run `f` with `conditional` set to true. + fn conditionally<F: FnMut(&mut Self) -> Result<()>>( + &mut self, + mut f: F, + ) -> Result<()> { + let conditional = self.conditional; + self.conditional = true; + f(self)?; + self.conditional = conditional; + + Ok(()) + } } impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> { type Node = Expr; fn f_down(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> { - // If an expression can short circuit its children then don't consider its - // children for CSE (https://github.com/apache/arrow-datafusion/issues/8814). - // This means that we don't recurse into its children, but handle the expression - // as a subtree when we calculate its identifier. - // TODO: consider surely executed children of "short circuited"s for CSE - let is_tree = expr.short_circuits(); - let tnr = if is_tree { - TreeNodeRecursion::Jump - } else { - TreeNodeRecursion::Continue - }; - self.id_array.push((0, None)); self.visit_stack - .push(VisitRecord::EnterMark(self.down_index, is_tree)); + .push(VisitRecord::EnterMark(self.down_index)); self.down_index += 1; - Ok(tnr) + // If an expression can short-circuit then some of its children might not be + // executed so count the occurrence of subexpressions as conditional in all + // children. + Ok(match expr { + // If we are already in a conditionally evaluated subtree then continue + // traversal. 
+ _ if self.conditional => TreeNodeRecursion::Continue, + + // In case of `ScalarFunction`s we don't know which children are surely + // executed so start visiting all children conditionally and stop the + // recursion with `TreeNodeRecursion::Jump`. + Expr::ScalarFunction(ScalarFunction { func, args }) + if func.short_circuits() => + { + self.conditionally(|visitor| { + args.iter().try_for_each(|e| e.visit(visitor).map(|_| ())) + })?; + + TreeNodeRecursion::Jump + } + + // In case of `And` and `Or` the first child is surely executed, but we + // account subexpressions as conditional in the second. + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::And | Operator::Or, + right, + }) => { + left.visit(self)?; + self.conditionally(|visitor| right.visit(visitor).map(|_| ()))?; + + TreeNodeRecursion::Jump + } + + // In case of `Case` the optional base expression and the first when + // expressions are surely executed, but we account subexpressions as + // conditional in the others. + Expr::Case(Case { + expr, + when_then_expr, + else_expr, + }) => { + expr.iter().try_for_each(|e| e.visit(self).map(|_| ()))?; + when_then_expr.iter().take(1).try_for_each(|(when, then)| { + when.visit(self)?; + self.conditionally(|visitor| then.visit(visitor).map(|_| ())) + })?; + self.conditionally(|visitor| { + when_then_expr.iter().skip(1).try_for_each(|(when, then)| { + when.visit(visitor)?; + then.visit(visitor).map(|_| ()) + })?; + else_expr + .iter() + .try_for_each(|e| e.visit(visitor).map(|_| ())) + })?; + + TreeNodeRecursion::Jump + } + + // In case of non-short-circuit expressions continue the traversal. 
+ _ => TreeNodeRecursion::Continue, + }) } fn f_up(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> { - let (down_index, is_tree, sub_expr_id, sub_expr_is_valid) = self.pop_enter_mark(); + let (down_index, sub_expr_id, sub_expr_is_valid) = self.pop_enter_mark(); - let (expr_id, is_valid) = if is_tree { - ( - Identifier::new(expr, true, self.random_state), - !expr.is_volatile()?, - ) - } else { - ( - Identifier::new(expr, false, self.random_state).combine(sub_expr_id), - !expr.is_volatile_node() && sub_expr_is_valid, - ) - }; + let expr_id = + Identifier::new(expr, false, self.random_state).combine(sub_expr_id); + let is_valid = !expr.is_volatile_node() && sub_expr_is_valid; self.id_array[down_index].0 = self.up_index; if is_valid && !self.expr_mask.ignores(expr) { self.id_array[down_index].1 = Some(expr_id); - let count = self.expr_stats.entry(expr_id).or_insert(0); - *count += 1; - if *count > 1 { + let (count, conditional_count) = + self.expr_stats.entry(expr_id).or_insert((0, 0)); + if self.conditional { + *conditional_count += 1; + } else { + *count += 1; + } + if *count > 1 || *count == 1 && *conditional_count > 0 { Review Comment: I personally prefer explicit parentheses to avoid confusion. In this case, I think this is the same: ```suggestion if *count > 1 || (*count == 1 && *conditional_count > 0) { ``` ########## datafusion/sqllogictest/test_files/tpch/q14.slt.part: ########## @@ -44,19 +44,20 @@ physical_plan 01)ProjectionExec: expr=[100 * CAST(sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END)@0 AS Float64) / CAST(sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 AS Float64) as promo_revenue] 02)--AggregateExec: mode=Final, gby=[], aggr=[sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] 03)----CoalescePartitionsExec 
-04)------AggregateExec: mode=Partial, gby=[], aggr=[sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] -05)--------CoalesceBatchesExec: target_batch_size=8192 -06)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], projection=[l_extendedprice@1, l_discount@2, p_type@4] -07)------------CoalesceBatchesExec: target_batch_size=8192 -08)--------------RepartitionExec: partitioning=Hash([l_partkey@0], 4), input_partitions=4 -09)----------------ProjectionExec: expr=[l_partkey@0 as l_partkey, l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount] -10)------------------CoalesceBatchesExec: target_batch_size=8192 -11)--------------------FilterExec: l_shipdate@3 >= 1995-09-01 AND l_shipdate@3 < 1995-10-01 -12)----------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_partkey, l_extendedprice, l_discount, l_shipdate], has_header=false -13)------------CoalesceBatchesExec: target_batch_size=8192 -14)--------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 -15)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -16)------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_type], has_header=false +04)------AggregateExec: mode=Partial, gby=[], aggr=[sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE 
Int64(0) END), sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +05)--------ProjectionExec: expr=[l_extendedprice@0 * (Some(1),20,0 - l_discount@1) as __common_expr_1, p_type@2 as p_type] Review Comment: It's interesting here that this plan shows the evaluation done below the aggregate, but the aggregate doesn't seem to reflect that fact (e.g. the aggr exprs don't refer to `__common_expr_1`) ########## datafusion/sqllogictest/test_files/cse.slt: ########## @@ -171,3 +175,41 @@ logical_plan physical_plan 01)ProjectionExec: expr=[a@0 = random() AND b@1 = 0 as c1, a@0 = random() AND b@1 = 1 as c2, a@0 = 2 + random() OR b@1 = 4 as c3, a@0 = 2 + random() OR b@1 = 5 as c4, CASE WHEN a@0 = 4 + random() THEN 0 ELSE 1 END as c5, CASE WHEN a@0 = 4 + random() THEN 0 ELSE 2 END as c6] 02)--MemoryExec: partitions=1, partition_sizes=[0] + +# Surely only once but also conditionally evaluated expressions +query TT +EXPLAIN SELECT + (a = 1 OR random() = 0) AND a = 1 AS c1, + (a = 2 AND random() = 0) OR a = 2 AS c2, + CASE WHEN a + 3 = 0 THEN a + 3 ELSE 0 END AS c3, + CASE WHEN a + 4 = 0 THEN 0 WHEN a + 4 THEN 0 ELSE 0 END AS c4, + CASE WHEN a + 5 = 0 THEN 0 WHEN random() = 0 THEN a + 5 ELSE 0 END AS c5, + CASE WHEN a + 6 = 0 THEN 0 ELSE a + 6 END AS c6 +FROM t1 +---- +logical_plan +01)Projection: (__common_expr_1 OR random() = Float64(0)) AND __common_expr_1 AS c1, __common_expr_2 AND random() = Float64(0) OR __common_expr_2 AS c2, CASE WHEN __common_expr_3 = Float64(0) THEN __common_expr_3 ELSE Float64(0) END AS c3, CASE WHEN __common_expr_4 = Float64(0) THEN Int64(0) WHEN CAST(__common_expr_4 AS Boolean) THEN Int64(0) ELSE Int64(0) END AS c4, CASE WHEN __common_expr_5 = Float64(0) THEN Float64(0) WHEN random() = Float64(0) THEN __common_expr_5 ELSE Float64(0) END AS c5, CASE WHEN __common_expr_6 = Float64(0) THEN Float64(0) ELSE __common_expr_6 END AS c6 Review Comment: ✅ ########## datafusion/expr/src/expr.rs: ########## @@ -1401,6 +1401,41 @@ impl Expr { 
.expect("traversal is infallable"); } + /// Return all references to columns and their occurrence counts in the expression. Review Comment: I think this new API makes sense to me as a parallel set of APIs for `column_refs` / `add_column_refs` ########## datafusion/sqllogictest/test_files/select.slt: ########## @@ -1504,21 +1504,25 @@ query TT EXPLAIN SELECT y > 0 and 1 / y < 1, x > 0 and y > 0 and 1 / y < 1 / x from t; ---- logical_plan -01)Projection: t.y > Int32(0) AND Int64(1) / CAST(t.y AS Int64) < Int64(1) AS t.y > Int64(0) AND Int64(1) / t.y < Int64(1), t.x > Int32(0) AND t.y > Int32(0) AND Int64(1) / CAST(t.y AS Int64) < Int64(1) / CAST(t.x AS Int64) AS t.x > Int64(0) AND t.y > Int64(0) AND Int64(1) / t.y < Int64(1) / t.x -02)--TableScan: t projection=[x, y] +01)Projection: __common_expr_1 AND Int64(1) / CAST(t.y AS Int64) < Int64(1) AS t.y > Int64(0) AND Int64(1) / t.y < Int64(1), t.x > Int32(0) AND __common_expr_1 AND Int64(1) / CAST(t.y AS Int64) < Int64(1) / CAST(t.x AS Int64) AS t.x > Int64(0) AND t.y > Int64(0) AND Int64(1) / t.y < Int64(1) / t.x +02)--Projection: t.y > Int32(0) AS __common_expr_1, t.x, t.y Review Comment: 👍 I verified that the common expressions do not include the `1 / y` term, which can potentially generate a runtime error ########## datafusion/optimizer/src/common_subexpr_eliminate.rs: ########## @@ -901,15 +903,15 @@ struct ExprIdentifierVisitor<'a, 'n> { random_state: &'a RandomState, // a flag to indicate that common expression found found_common: bool, + // if we are in a conditional branch Review Comment: I think it would help to document in more detail what 'conditional' means -- maybe like this: ```suggestion // if we are in a conditional branch. A conditional // branch means that the expression **might** not be executed depending // on the runtime values of other expressions, and thus cannot be extracted // as a common expression. 
``` ########## datafusion/optimizer/src/common_subexpr_eliminate.rs: ########## @@ -937,14 +935,14 @@ impl<'n> ExprIdentifierVisitor<'_, 'n> { /// information up from children to parents via `visit_stack` during the first, /// visiting traversal and no need to test the expression's validity beforehand with /// an extra traversal). - fn pop_enter_mark(&mut self) -> (usize, bool, Option<Identifier<'n>>, bool) { + fn pop_enter_mark(&mut self) -> (usize, Option<Identifier<'n>>, bool) { Review Comment: Is the information that used to be captured by the 'subtree' boolean flag now kept as part of `ExprStats`? ########## datafusion/sqllogictest/test_files/cse.slt: ########## @@ -171,3 +175,41 @@ logical_plan physical_plan 01)ProjectionExec: expr=[a@0 = random() AND b@1 = 0 as c1, a@0 = random() AND b@1 = 1 as c2, a@0 = 2 + random() OR b@1 = 4 as c3, a@0 = 2 + random() OR b@1 = 5 as c4, CASE WHEN a@0 = 4 + random() THEN 0 ELSE 1 END as c5, CASE WHEN a@0 = 4 + random() THEN 0 ELSE 2 END as c6] 02)--MemoryExec: partitions=1, partition_sizes=[0] + Review Comment: Could we maybe add some negative tests, if they aren't already handled? For example, I think these should not be CSE'd: ```sql (random() = 0 OR a = 1) AND a = 1 ``` ```sql (random() = 0 AND a = 1) OR a = 1 ``` ```sql CASE WHEN a + 10 = 0 THEN 0 WHEN random() > 0.5 THEN a+10 ELSE 0 END ``` ```sql CASE WHEN random() > 0.5 THEN 0 WHEN a + 10 = 0 THEN 0 ELSE a + 10 END ``` ```sql CASE WHEN a + 10 = 0 THEN 0 WHEN random() > 0.5 WHEN random() > 0.5 THEN a+10 ELSE 0 END ``` ########## datafusion/optimizer/src/common_subexpr_eliminate.rs: ########## @@ -954,53 +952,113 @@ impl<'n> ExprIdentifierVisitor<'_, 'n> { } unreachable!("Enter mark should paired with node number"); } + /// Save the current `conditional` status and run `f` with `conditional` set to true. 
+ fn conditionally<F: FnMut(&mut Self) -> Result<()>>( + &mut self, + mut f: F, + ) -> Result<()> { + let conditional = self.conditional; + self.conditional = true; + f(self)?; + self.conditional = conditional; + + Ok(()) + } } impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> { type Node = Expr; fn f_down(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> { - // If an expression can short circuit its children then don't consider its - // children for CSE (https://github.com/apache/arrow-datafusion/issues/8814). - // This means that we don't recurse into its children, but handle the expression - // as a subtree when we calculate its identifier. - // TODO: consider surely executed children of "short circuited"s for CSE - let is_tree = expr.short_circuits(); - let tnr = if is_tree { - TreeNodeRecursion::Jump - } else { - TreeNodeRecursion::Continue - }; - self.id_array.push((0, None)); self.visit_stack - .push(VisitRecord::EnterMark(self.down_index, is_tree)); + .push(VisitRecord::EnterMark(self.down_index)); self.down_index += 1; - Ok(tnr) + // If an expression can short-circuit then some of its children might not be + // executed so count the occurrence of subexpressions as conditional in all + // children. + Ok(match expr { + // If we are already in a conditionally evaluated subtree then continue + // traversal. + _ if self.conditional => TreeNodeRecursion::Continue, + + // In case of `ScalarFunction`s we don't know which children are surely + // executed so start visiting all children conditionally and stop the + // recursion with `TreeNodeRecursion::Jump`. + Expr::ScalarFunction(ScalarFunction { func, args }) + if func.short_circuits() => + { + self.conditionally(|visitor| { + args.iter().try_for_each(|e| e.visit(visitor).map(|_| ())) + })?; + + TreeNodeRecursion::Jump + } + + // In case of `And` and `Or` the first child is surely executed, but we + // account subexpressions as conditional in the second. 
+ Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::And | Operator::Or, + right, + }) => { + left.visit(self)?; + self.conditionally(|visitor| right.visit(visitor).map(|_| ()))?; Review Comment: the use of `conditionally` makes reading this logic quite elegant. Nice work -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For additional commands, e-mail: github-h...@datafusion.apache.org