alamb commented on code in PR #3903:
URL: https://github.com/apache/arrow-datafusion/pull/3903#discussion_r1005869701
##########
datafusion/optimizer/src/utils.rs:
##########
@@ -99,27 +99,176 @@ fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs:
Vec<&'a Expr>) -> Vec<&
/// assert_eq!(split_conjunction_owned(expr), split);
/// ```
pub fn split_conjunction_owned(expr: Expr) -> Vec<Expr> {
- split_conjunction_owned_impl(expr, vec![])
+ split_binary_owned(expr, Operator::And)
}
-fn split_conjunction_owned_impl(expr: Expr, mut exprs: Vec<Expr>) -> Vec<Expr>
{
+/// Splits an owned binary operator tree [`Expr`] such as `A <OP> B <OP> C` =>
`[A, B, C]`
+///
+/// This is often used to "split" expressions such as `col1 = 5
+/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`];
+///
+/// # Example
+/// ```
+/// # use datafusion_expr::{col, lit, Operator};
+/// # use datafusion_optimizer::utils::split_binary_owned;
+/// # use std::ops::Add;
+/// // a=1 + b=2
+/// let expr = col("a").eq(lit(1)).add(col("b").eq(lit(2)));
+///
+/// // [a=1, b=2]
+/// let split = vec![
+/// col("a").eq(lit(1)),
+/// col("b").eq(lit(2)),
+/// ];
+///
+/// // use split_binary_owned to split them
+/// assert_eq!(split_binary_owned(expr, Operator::Plus), split);
+/// ```
+pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec<Expr> {
+ split_binary_owned_impl(expr, op, vec![])
+}
+
+fn split_binary_owned_impl(
+ expr: Expr,
+ operator: Operator,
+ mut exprs: Vec<Expr>,
+) -> Vec<Expr> {
match expr {
- Expr::BinaryExpr(BinaryExpr {
- right,
- op: Operator::And,
- left,
- }) => {
- let exprs = split_conjunction_owned_impl(*left, exprs);
- split_conjunction_owned_impl(*right, exprs)
+ Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => {
+ let exprs = split_binary_owned_impl(*left, operator, exprs);
+ split_binary_owned_impl(*right, operator, exprs)
}
- Expr::Alias(expr, _) => split_conjunction_owned_impl(*expr, exprs),
+ Expr::Alias(expr, _) => split_binary_owned_impl(*expr, operator,
exprs),
other => {
exprs.push(other);
exprs
}
}
}
+/// Splits an binary operator tree [`Expr`] such as `A <OP> B <OP> C` => `[A,
B, C]`
+///
+/// See [`split_binary_owned`] for more details and an example.
+pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> {
+ split_binary_impl(expr, op, vec![])
+}
+
+fn split_binary_impl<'a>(
+ expr: &'a Expr,
+ operator: Operator,
+ mut exprs: Vec<&'a Expr>,
+) -> Vec<&'a Expr> {
+ match expr {
+ Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator =>
{
+ let exprs = split_binary_impl(left, operator, exprs);
+ split_binary_impl(right, operator, exprs)
+ }
+ Expr::Alias(expr, _) => split_binary_impl(expr, operator, exprs),
+ other => {
+ exprs.push(other);
+ exprs
+ }
+ }
+}
+
+/// Given a list of lists of [`Expr`]s, returns a list of lists of
+/// [`Expr`]s of expressions where there is one expression from each
+/// from each of the input expressions
+///
+/// For example, given the input `[[a, b], [c], [d, e]]` returns
+/// `[a, c, d], [a, c, e], [b, c, d], [b, c, e]]`.
+fn permutations(mut exprs: VecDeque<Vec<&Expr>>) -> Vec<Vec<&Expr>> {
+ let first = if let Some(first) = exprs.pop_front() {
+ first
+ } else {
+ return vec![];
+ };
+
+ // base case:
+ if exprs.is_empty() {
+ first.into_iter().map(|e| vec![e]).collect()
+ } else {
+ first
+ .iter()
+ .flat_map(|expr| {
+ permutations(exprs.clone())
+ .into_iter()
+ .map(|expr_list| {
+ // Create [expr, ...] for each permutation
+ std::iter::once(expr.clone())
+ .chain(expr_list.into_iter())
+ .collect::<Vec<&Expr>>()
+ })
+ .collect::<Vec<Vec<&Expr>>>()
+ })
+ .collect()
+ }
+}
+
+const MAX_CNF_REWRITE_CONJUNCTS: usize = 10;
+
+/// Tries to convert an expression to conjunctive normal form (CNF).
+///
+/// Does not convert the expression if the total number of conjuncts
+/// (exprs ANDed together) would exceed [`MAX_CNF_REWRITE_CONJUNCTS`].
+///
+/// The following expression is in CNF:
+/// `(a OR b) AND (c OR d)`
+///
+/// The following is not in CNF:
+/// `(a AND b) OR c`.
+///
+/// But could be rewrite to a CNF expression:
+/// `(a OR c) AND (b OR c)`.
+///
+///
+/// # Example
+/// ```
+/// # use datafusion_expr::{col, lit};
+/// # use datafusion_optimizer::utils::cnf_rewrite;
+/// // (a=1 AND b=2)OR c = 3
+/// let expr1 = col("a").eq(lit(1)).and(col("b").eq(lit(2)));
+/// let expr2 = col("c").eq(lit(3));
+/// let expr = expr1.or(expr2);
+///
+/// //(a=1 or c=3)AND(b=2 or c=3)
+/// let expr1 = col("a").eq(lit(1)).or(col("c").eq(lit(3)));
+/// let expr2 = col("b").eq(lit(2)).or(col("c").eq(lit(3)));
+/// let expect = expr1.and(expr2);
+/// assert_eq!(expect, cnf_rewrite(expr));
+/// ```
+pub fn cnf_rewrite(expr: Expr) -> Expr {
+ // Find all exprs joined by OR
+ let disjuncts = split_binary(&expr, Operator::Or);
+
+ // For each expr, split now on AND
+ // A OR B OR C --> split each A, B and C
+ let disjunct_conjuncts: VecDeque<Vec<&Expr>> = disjuncts
+ .into_iter()
+ .map(|e| split_binary(e, Operator::And))
+ .collect::<VecDeque<_>>();
+
+ // Decide if we want to distribute the clauses. Heuristic is
+ // chosen to avoid creating huge predicates
+ let num_conjuncts = disjunct_conjuncts
+ .iter()
+ .fold(1usize, |sz, exprs| sz.saturating_mul(exprs.len()));
+
+ if disjunct_conjuncts.iter().any(|exprs| exprs.len() > 1)
+ && num_conjuncts < MAX_CNF_REWRITE_CONJUNCTS
+ {
+ let or_clauses = permutations(disjunct_conjuncts)
+ .into_iter()
+ // form the OR clauses( A OR B OR C ..)
+ .map(|exprs| disjunction(exprs.into_iter().cloned()).unwrap());
+ conjunction(or_clauses).unwrap()
+ }
+ // otherwise return the original expression
+ else {
+ expr
+ }
+}
Review Comment:
@Ted-Jiang I got nerdsniped working on this PR -- I hope you don't mind that I
rewrote your logic into something that avoids cloning unless it is actually
doing the rewrite
##########
datafusion/optimizer/src/utils.rs:
##########
@@ -655,4 +822,135 @@ mod tests {
"mismatch rewriting expr_from: {expr_from} to {rewrite_to}"
)
}
+
+ #[test]
+ fn test_permutations() {
+ assert_eq!(make_permutations(vec![]), vec![] as Vec<Vec<Expr>>)
+ }
+
+ #[test]
+ fn test_permutations_one() {
+ // [[a]] --> [[a]]
+ assert_eq!(
+ make_permutations(vec![vec![col("a")]]),
+ vec![vec![col("a")]]
+ )
+ }
+
+ #[test]
+ fn test_permutations_two() {
+ // [[a, b]] --> [[a], [b]]
+ assert_eq!(
+ make_permutations(vec![vec![col("a"), col("b")]]),
+ vec![vec![col("a")], vec![col("b")]]
+ )
+ }
+
+ #[test]
+ fn test_permutations_two_and_one() {
+ // [[a, b], [c]] --> [[a, c], [b, c]]
+ assert_eq!(
+ make_permutations(vec![vec![col("a"), col("b")], vec![col("c")]]),
+ vec![vec![col("a"), col("c")], vec![col("b"), col("c")]]
+ )
+ }
+
+ #[test]
+ fn test_permutations_two_and_one_and_two() {
+ // [[a, b], [c], [d, e]] --> [[a, c, d], [a, c, e], [b, c, d], [b, c,
e]]
+ assert_eq!(
+ make_permutations(vec![
+ vec![col("a"), col("b")],
+ vec![col("c")],
+ vec![col("d"), col("e")]
+ ]),
+ vec![
+ vec![col("a"), col("c"), col("d")],
+ vec![col("a"), col("c"), col("e")],
+ vec![col("b"), col("c"), col("d")],
+ vec![col("b"), col("c"), col("e")],
+ ]
+ )
+ }
+
+ /// call permutations with owned `Expr`s for easier testing
+ fn make_permutations(exprs: impl IntoIterator<Item = Vec<Expr>>) ->
Vec<Vec<Expr>> {
+ let exprs = exprs.into_iter().collect::<Vec<_>>();
+
+ let exprs: VecDeque<Vec<&Expr>> = exprs
+ .iter()
+ .map(|exprs| exprs.iter().collect::<Vec<&Expr>>())
+ .collect();
+
+ permutations(exprs)
+ .into_iter()
+ // copy &Expr --> Expr
+ .map(|exprs| exprs.into_iter().cloned().collect())
+ .collect()
+ }
+
+ #[test]
+ fn test_rewrite_cnf() {
Review Comment:
These tests are all the same
##########
datafusion/optimizer/src/utils.rs:
##########
@@ -99,27 +99,176 @@ fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs:
Vec<&'a Expr>) -> Vec<&
/// assert_eq!(split_conjunction_owned(expr), split);
/// ```
pub fn split_conjunction_owned(expr: Expr) -> Vec<Expr> {
- split_conjunction_owned_impl(expr, vec![])
+ split_binary_owned(expr, Operator::And)
}
-fn split_conjunction_owned_impl(expr: Expr, mut exprs: Vec<Expr>) -> Vec<Expr>
{
+/// Splits an owned binary operator tree [`Expr`] such as `A <OP> B <OP> C` =>
`[A, B, C]`
+///
+/// This is often used to "split" expressions such as `col1 = 5
+/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`];
+///
+/// # Example
+/// ```
+/// # use datafusion_expr::{col, lit, Operator};
+/// # use datafusion_optimizer::utils::split_binary_owned;
+/// # use std::ops::Add;
+/// // a=1 + b=2
+/// let expr = col("a").eq(lit(1)).add(col("b").eq(lit(2)));
+///
+/// // [a=1, b=2]
+/// let split = vec![
+/// col("a").eq(lit(1)),
+/// col("b").eq(lit(2)),
+/// ];
+///
+/// // use split_binary_owned to split them
+/// assert_eq!(split_binary_owned(expr, Operator::Plus), split);
+/// ```
+pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec<Expr> {
+ split_binary_owned_impl(expr, op, vec![])
+}
+
+fn split_binary_owned_impl(
+ expr: Expr,
+ operator: Operator,
+ mut exprs: Vec<Expr>,
+) -> Vec<Expr> {
match expr {
- Expr::BinaryExpr(BinaryExpr {
- right,
- op: Operator::And,
- left,
- }) => {
- let exprs = split_conjunction_owned_impl(*left, exprs);
- split_conjunction_owned_impl(*right, exprs)
+ Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => {
+ let exprs = split_binary_owned_impl(*left, operator, exprs);
+ split_binary_owned_impl(*right, operator, exprs)
}
- Expr::Alias(expr, _) => split_conjunction_owned_impl(*expr, exprs),
+ Expr::Alias(expr, _) => split_binary_owned_impl(*expr, operator,
exprs),
other => {
exprs.push(other);
exprs
}
}
}
+/// Splits an binary operator tree [`Expr`] such as `A <OP> B <OP> C` => `[A,
B, C]`
+///
+/// See [`split_binary_owned`] for more details and an example.
+pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> {
+ split_binary_impl(expr, op, vec![])
+}
+
+fn split_binary_impl<'a>(
+ expr: &'a Expr,
+ operator: Operator,
+ mut exprs: Vec<&'a Expr>,
+) -> Vec<&'a Expr> {
+ match expr {
+ Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator =>
{
+ let exprs = split_binary_impl(left, operator, exprs);
+ split_binary_impl(right, operator, exprs)
+ }
+ Expr::Alias(expr, _) => split_binary_impl(expr, operator, exprs),
+ other => {
+ exprs.push(other);
+ exprs
+ }
+ }
+}
+
+/// Given a list of lists of [`Expr`]s, returns a list of lists of
+/// [`Expr`]s of expressions where there is one expression from each
+/// from each of the input expressions
+///
+/// For example, given the input `[[a, b], [c], [d, e]]` returns
+/// `[a, c, d], [a, c, e], [b, c, d], [b, c, e]]`.
+fn permutations(mut exprs: VecDeque<Vec<&Expr>>) -> Vec<Vec<&Expr>> {
+ let first = if let Some(first) = exprs.pop_front() {
+ first
+ } else {
+ return vec![];
+ };
+
+ // base case:
+ if exprs.is_empty() {
+ first.into_iter().map(|e| vec![e]).collect()
+ } else {
+ first
+ .iter()
+ .flat_map(|expr| {
+ permutations(exprs.clone())
+ .into_iter()
+ .map(|expr_list| {
+ // Create [expr, ...] for each permutation
+ std::iter::once(expr.clone())
+ .chain(expr_list.into_iter())
+ .collect::<Vec<&Expr>>()
+ })
+ .collect::<Vec<Vec<&Expr>>>()
+ })
+ .collect()
+ }
+}
+
+const MAX_CNF_REWRITE_CONJUNCTS: usize = 10;
Review Comment:
By limiting the number of exprs that can be created, I also avoided having
to explicitly disable the cnf rewrite
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]