This is an automated email from the ASF dual-hosted git repository.

liurenjie1024 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-rust.git


The following commit(s) were added to refs/heads/main by this push:
     new 0fc0b0f  feat: add `UnboundPredicate::negate()` (#228)
0fc0b0f is described below

commit 0fc0b0f98ae1acf420b4537b27f5412ce887f8d1
Author: Scott Donnelly <[email protected]>
AuthorDate: Thu Mar 7 02:51:01 2024 +0000

    feat: add `UnboundPredicate::negate()` (#228)
    
    Issue: #150
---
 crates/iceberg/src/expr/mod.rs       |  35 ++++++-
 crates/iceberg/src/expr/predicate.rs | 172 ++++++++++++++++++++++++++++++++---
 crates/iceberg/src/expr/term.rs      |  93 ++++++++++++++++++-
 crates/iceberg/src/spec/datatypes.rs |   2 +-
 crates/iceberg/src/spec/values.rs    |   2 +-
 5 files changed, 284 insertions(+), 20 deletions(-)

diff --git a/crates/iceberg/src/expr/mod.rs b/crates/iceberg/src/expr/mod.rs
index ef3d2a6..c08c836 100644
--- a/crates/iceberg/src/expr/mod.rs
+++ b/crates/iceberg/src/expr/mod.rs
@@ -30,7 +30,7 @@ pub use predicate::*;
 /// The discriminant of this enum is used for determining the type of the operator, see
 /// [`PredicateOperator::is_unary`], [`PredicateOperator::is_binary`], [`PredicateOperator::is_set`]
 #[allow(missing_docs)]
-#[derive(Debug, Clone, Copy)]
+#[derive(Debug, Clone, Copy, PartialEq)]
 #[repr(u16)]
 pub enum PredicateOperator {
     // Unary operators
@@ -112,6 +112,39 @@ impl PredicateOperator {
     pub fn is_set(self) -> bool {
         (self as u16) > (PredicateOperator::NotStartsWith as u16)
     }
+
+    /// Returns the operator that is the logical negation of `self`.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use iceberg::expr::PredicateOperator;
+    /// assert!(PredicateOperator::IsNull.negate() == PredicateOperator::NotNull);
+    /// assert!(PredicateOperator::IsNan.negate() == PredicateOperator::NotNan);
+    /// assert!(PredicateOperator::LessThan.negate() == PredicateOperator::GreaterThanOrEq);
+    /// assert!(PredicateOperator::GreaterThan.negate() == PredicateOperator::LessThanOrEq);
+    /// assert!(PredicateOperator::Eq.negate() == PredicateOperator::NotEq);
+    /// assert!(PredicateOperator::In.negate() == PredicateOperator::NotIn);
+    /// assert!(PredicateOperator::StartsWith.negate() == PredicateOperator::NotStartsWith);
+    /// ```
+    pub fn negate(self) -> PredicateOperator {
+        match self {
+            PredicateOperator::IsNull => PredicateOperator::NotNull,
+            PredicateOperator::NotNull => PredicateOperator::IsNull,
+            PredicateOperator::IsNan => PredicateOperator::NotNan,
+            PredicateOperator::NotNan => PredicateOperator::IsNan,
+            PredicateOperator::LessThan => PredicateOperator::GreaterThanOrEq,
+            PredicateOperator::LessThanOrEq => PredicateOperator::GreaterThan,
+            PredicateOperator::GreaterThan => PredicateOperator::LessThanOrEq,
+            PredicateOperator::GreaterThanOrEq => PredicateOperator::LessThan,
+            PredicateOperator::Eq => PredicateOperator::NotEq,
+            PredicateOperator::NotEq => PredicateOperator::Eq,
+            PredicateOperator::In => PredicateOperator::NotIn,
+            PredicateOperator::NotIn => PredicateOperator::In,
+            PredicateOperator::StartsWith => PredicateOperator::NotStartsWith,
+            PredicateOperator::NotStartsWith => PredicateOperator::StartsWith,
+        }
+    }
 }
 
 #[cfg(test)]
diff --git a/crates/iceberg/src/expr/predicate.rs b/crates/iceberg/src/expr/predicate.rs
index c9c047e..66a3956 100644
--- a/crates/iceberg/src/expr/predicate.rs
+++ b/crates/iceberg/src/expr/predicate.rs
@@ -21,11 +21,13 @@
 
 use crate::expr::{BoundReference, PredicateOperator, Reference};
 use crate::spec::Datum;
+use itertools::Itertools;
 use std::collections::HashSet;
 use std::fmt::{Debug, Display, Formatter};
 use std::ops::Not;
 
 /// Logical expression, such as `AND`, `OR`, `NOT`.
+#[derive(PartialEq)]
 pub struct LogicalExpression<T, const N: usize> {
     inputs: [Box<T>; N],
 }
@@ -54,6 +56,7 @@ impl<T, const N: usize> LogicalExpression<T, N> {
 }
 
 /// Unary predicate, for example, `a IS NULL`.
+#[derive(PartialEq)]
 pub struct UnaryExpression<T> {
     /// Operator of this predicate, must be single operand operator.
     op: PredicateOperator,
@@ -84,6 +87,7 @@ impl<T> UnaryExpression<T> {
 }
 
 /// Binary predicate, for example, `a > 10`.
+#[derive(PartialEq)]
 pub struct BinaryExpression<T> {
     /// Operator of this predicate, must be binary operator, such as `=`, `>`, `<`, etc.
     op: PredicateOperator,
@@ -117,6 +121,7 @@ impl<T: Display> Display for BinaryExpression<T> {
 }
 
 /// Set predicates, for example, `a in (1, 2, 3)`.
+#[derive(PartialEq)]
 pub struct SetExpression<T> {
     /// Operator of this predicate, must be set operator, such as `IN`, `NOT IN`, etc.
     op: PredicateOperator,
@@ -136,8 +141,22 @@ impl<T: Debug> Debug for SetExpression<T> {
     }
 }
 
+impl<T> SetExpression<T> {
+    pub(crate) fn new(op: PredicateOperator, term: T, literals: HashSet<Datum>) -> Self {
+        Self { op, term, literals }
+    }
+}
+
+impl<T: Display> Display for SetExpression<T> {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        let literals = self.literals.iter().join(", ");
+
+        write!(f, "{} {} ({})", self.term, self.op, literals)
+    }
+}
+
 /// Unbound predicate expression before binding to a schema.
-#[derive(Debug)]
+#[derive(Debug, PartialEq)]
 pub enum Predicate {
     /// And predicate, for example, `a > 10 AND b < 20`.
     And(LogicalExpression<Predicate, 2>),
@@ -166,23 +185,13 @@ impl Display for Predicate {
                 write!(f, "NOT ({})", expr.inputs()[0])
             }
             Predicate::Unary(expr) => {
-                write!(f, "{}", expr.term)
+                write!(f, "{}", expr)
             }
             Predicate::Binary(expr) => {
-                write!(f, "{} {} {}", expr.term, expr.op, expr.literal)
+                write!(f, "{}", expr)
             }
             Predicate::Set(expr) => {
-                write!(
-                    f,
-                    "{} {} ({})",
-                    expr.term,
-                    expr.op,
-                    expr.literals
-                        .iter()
-                        .map(|l| format!("{:?}", l))
-                        .collect::<Vec<String>>()
-                        .join(", ")
-                )
+                write!(f, "{}", expr)
             }
         }
     }
@@ -230,6 +239,54 @@ impl Predicate {
     pub fn or(self, other: Predicate) -> Predicate {
         Predicate::Or(LogicalExpression::new([Box::new(self), Box::new(other)]))
     }
+
+    /// Returns a predicate representing the negation ('NOT') of this one,
+    /// by using inverse predicates rather than wrapping in a `NOT`.
+    /// Used for `NOT` elimination.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use std::ops::Bound::Unbounded;
+    /// use iceberg::expr::BoundPredicate::Unary;
+    /// use iceberg::expr::{LogicalExpression, Predicate, Reference};
+    /// use iceberg::spec::Datum;
+    /// let expr1 = Reference::new("a").less_than(Datum::long(10));
+    /// let expr2 = Reference::new("b").less_than(Datum::long(5)).and(Reference::new("c").less_than(Datum::long(10)));
+    ///
+    /// let result = expr1.negate();
+    /// assert_eq!(&format!("{result}"), "a >= 10");
+    ///
+    /// let result = expr2.negate();
+    /// assert_eq!(&format!("{result}"), "(b >= 5) OR (c >= 10)");
+    /// ```
+    pub fn negate(self) -> Predicate {
+        match self {
+            Predicate::And(expr) => Predicate::Or(LogicalExpression::new(
+                expr.inputs.map(|expr| Box::new(expr.negate())),
+            )),
+            Predicate::Or(expr) => Predicate::And(LogicalExpression::new(
+                expr.inputs.map(|expr| Box::new(expr.negate())),
+            )),
+            Predicate::Not(expr) => {
+                let LogicalExpression { inputs: [input_0] } = expr;
+                *input_0
+            }
+            Predicate::Unary(expr) => {
+                Predicate::Unary(UnaryExpression::new(expr.op.negate(), expr.term))
+            }
+            Predicate::Binary(expr) => Predicate::Binary(BinaryExpression::new(
+                expr.op.negate(),
+                expr.term,
+                expr.literal,
+            )),
+            Predicate::Set(expr) => Predicate::Set(SetExpression::new(
+                expr.op.negate(),
+                expr.term,
+                expr.literals,
+            )),
+        }
+    }
 }
 
 impl Not for Predicate {
@@ -271,6 +328,91 @@ pub enum BoundPredicate {
     Unary(UnaryExpression<BoundReference>),
     /// Binary expression, for example, `a > 10`.
     Binary(BinaryExpression<BoundReference>),
-    /// Set predicates, for example, `a in (1, 2, 3)`.
+    /// Set predicates, for example, `a IN (1, 2, 3)`.
     Set(SetExpression<BoundReference>),
 }
+
+#[cfg(test)]
+mod tests {
+    use crate::expr::Reference;
+    use crate::spec::Datum;
+    use std::collections::HashSet;
+    use std::ops::Not;
+
+    #[test]
+    fn test_predicate_negate_and() {
+        let expression = Reference::new("b")
+            .less_than(Datum::long(5))
+            .and(Reference::new("c").less_than(Datum::long(10)));
+
+        let expected = Reference::new("b")
+            .greater_than_or_equal_to(Datum::long(5))
+            .or(Reference::new("c").greater_than_or_equal_to(Datum::long(10)));
+
+        let result = expression.negate();
+
+        assert_eq!(result, expected);
+    }
+
+    #[test]
+    fn test_predicate_negate_or() {
+        let expression = Reference::new("b")
+            .greater_than_or_equal_to(Datum::long(5))
+            .or(Reference::new("c").greater_than_or_equal_to(Datum::long(10)));
+
+        let expected = Reference::new("b")
+            .less_than(Datum::long(5))
+            .and(Reference::new("c").less_than(Datum::long(10)));
+
+        let result = expression.negate();
+
+        assert_eq!(result, expected);
+    }
+
+    #[test]
+    fn test_predicate_negate_not() {
+        let expression = Reference::new("b")
+            .greater_than_or_equal_to(Datum::long(5))
+            .not();
+
+        let expected = Reference::new("b").greater_than_or_equal_to(Datum::long(5));
+
+        let result = expression.negate();
+
+        assert_eq!(result, expected);
+    }
+
+    #[test]
+    fn test_predicate_negate_unary() {
+        let expression = Reference::new("b").is_not_null();
+
+        let expected = Reference::new("b").is_null();
+
+        let result = expression.negate();
+
+        assert_eq!(result, expected);
+    }
+
+    #[test]
+    fn test_predicate_negate_binary() {
+        let expression = Reference::new("a").less_than(Datum::long(5));
+
+        let expected = Reference::new("a").greater_than_or_equal_to(Datum::long(5));
+
+        let result = expression.negate();
+
+        assert_eq!(result, expected);
+    }
+
+    #[test]
+    fn test_predicate_negate_set() {
+        let expression = Reference::new("a").is_in(HashSet::from([Datum::long(5), Datum::long(6)]));
+
+        let expected =
+            Reference::new("a").is_not_in(HashSet::from([Datum::long(5), Datum::long(6)]));
+
+        let result = expression.negate();
+
+        assert_eq!(result, expected);
+    }
+}
diff --git a/crates/iceberg/src/expr/term.rs b/crates/iceberg/src/expr/term.rs
index a4338a3..6be502f 100644
--- a/crates/iceberg/src/expr/term.rs
+++ b/crates/iceberg/src/expr/term.rs
@@ -17,8 +17,9 @@
 
 //! Term definition.
 
-use crate::expr::{BinaryExpression, Predicate, PredicateOperator};
+use crate::expr::{BinaryExpression, Predicate, PredicateOperator, SetExpression, UnaryExpression};
 use crate::spec::{Datum, NestedField, NestedFieldRef};
+use std::collections::HashSet;
 use std::fmt::{Display, Formatter};
 
 /// Unbound term before binding to a schema.
@@ -26,7 +27,7 @@ pub type Term = Reference;
 
 /// A named reference in an unbound expression.
 /// For example, `a` in `a > 10`.
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, PartialEq)]
 pub struct Reference {
     name: String,
 }
@@ -63,6 +64,94 @@ impl Reference {
             datum,
         ))
     }
+
+    /// Creates a greater-than-or-equal-to expression. For example, `a >= 10`.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    ///
+    /// use iceberg::expr::Reference;
+    /// use iceberg::spec::Datum;
+    /// let expr = Reference::new("a").greater_than_or_equal_to(Datum::long(10));
+    ///
+    /// assert_eq!(&format!("{expr}"), "a >= 10");
+    /// ```
+    pub fn greater_than_or_equal_to(self, datum: Datum) -> Predicate {
+        Predicate::Binary(BinaryExpression::new(
+            PredicateOperator::GreaterThanOrEq,
+            self,
+            datum,
+        ))
+    }
+
+    /// Creates an is-null expression. For example, `a IS NULL`.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    ///
+    /// use iceberg::expr::Reference;
+    /// use iceberg::spec::Datum;
+    /// let expr = Reference::new("a").is_null();
+    ///
+    /// assert_eq!(&format!("{expr}"), "a IS NULL");
+    /// ```
+    pub fn is_null(self) -> Predicate {
+        Predicate::Unary(UnaryExpression::new(PredicateOperator::IsNull, self))
+    }
+
+    /// Creates an is-not-null expression. For example, `a IS NOT NULL`.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    ///
+    /// use iceberg::expr::Reference;
+    /// use iceberg::spec::Datum;
+    /// let expr = Reference::new("a").is_not_null();
+    ///
+    /// assert_eq!(&format!("{expr}"), "a IS NOT NULL");
+    /// ```
+    pub fn is_not_null(self) -> Predicate {
+        Predicate::Unary(UnaryExpression::new(PredicateOperator::NotNull, self))
+    }
+
+    /// Creates an is-in expression. For example, `a IN (5, 6)`.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    ///
+    /// use std::collections::HashSet;
+    /// use iceberg::expr::Reference;
+    /// use iceberg::spec::Datum;
+    /// let expr = Reference::new("a").is_in(HashSet::from([Datum::long(5), Datum::long(6)]));
+    ///
+    /// let as_string = format!("{expr}");
+    /// assert!(&as_string == "a IN (5, 6)" || &as_string == "a IN (6, 5)");
+    /// ```
+    pub fn is_in(self, literals: HashSet<Datum>) -> Predicate {
+        Predicate::Set(SetExpression::new(PredicateOperator::In, self, literals))
+    }
+
+    /// Creates an is-not-in expression. For example, `a NOT IN (5, 6)`.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    ///
+    /// use std::collections::HashSet;
+    /// use iceberg::expr::Reference;
+    /// use iceberg::spec::Datum;
+    /// let expr = Reference::new("a").is_not_in(HashSet::from([Datum::long(5), Datum::long(6)]));
+    ///
+    /// let as_string = format!("{expr}");
+    /// assert!(&as_string == "a NOT IN (5, 6)" || &as_string == "a NOT IN (6, 5)");
+    /// ```
+    pub fn is_not_in(self, literals: HashSet<Datum>) -> Predicate {
+        Predicate::Set(SetExpression::new(PredicateOperator::NotIn, self, literals))
+    }
 }
 
 impl Display for Reference {
diff --git a/crates/iceberg/src/spec/datatypes.rs b/crates/iceberg/src/spec/datatypes.rs
index 636f14e..8f404e9 100644
--- a/crates/iceberg/src/spec/datatypes.rs
+++ b/crates/iceberg/src/spec/datatypes.rs
@@ -162,7 +162,7 @@ impl From<MapType> for Type {
 }
 
 /// Primitive data types
-#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
+#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Hash)]
 #[serde(rename_all = "lowercase", remote = "Self")]
 pub enum PrimitiveType {
     /// True or False
diff --git a/crates/iceberg/src/spec/values.rs b/crates/iceberg/src/spec/values.rs
index f202914..550b48d 100644
--- a/crates/iceberg/src/spec/values.rs
+++ b/crates/iceberg/src/spec/values.rs
@@ -84,7 +84,7 @@ pub enum PrimitiveLiteral {
 ///
 /// By default, we decouple the type and value of a literal, so we can use avoid the cost of storing extra type info
 /// for each literal. But associate type with literal can be useful in some cases, for example, in unbound expression.
-#[derive(Debug)]
+#[derive(Debug, PartialEq, Hash, Eq)]
 pub struct Datum {
     r#type: PrimitiveType,
     literal: PrimitiveLiteral,

Reply via email to