This is an automated email from the ASF dual-hosted git repository.
liurenjie1024 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-rust.git
The following commit(s) were added to refs/heads/main by this push:
new 83cdff4 feat: Implement binding expression (#231)
83cdff4 is described below
commit 83cdff48f3ff4ab0bc6d2b39ce41a65ace4ee26b
Author: Renjie Liu <[email protected]>
AuthorDate: Tue Mar 12 11:37:39 2024 +0800
feat: Implement binding expression (#231)
* feat: Implement binding expression
---
Cargo.toml | 2 +
crates/iceberg/Cargo.toml | 2 +
crates/iceberg/src/expr/mod.rs | 10 +
crates/iceberg/src/expr/predicate.rs | 481 ++++++++++++++++++++++++++++++++++-
crates/iceberg/src/expr/term.rs | 120 ++++++++-
crates/iceberg/src/spec/datatypes.rs | 9 +
crates/iceberg/src/spec/schema.rs | 52 ++++
crates/iceberg/src/spec/values.rs | 19 +-
8 files changed, 668 insertions(+), 27 deletions(-)
diff --git a/Cargo.toml b/Cargo.toml
index c482859..809fc4f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -36,6 +36,7 @@ rust-version = "1.75.0"
[workspace.dependencies]
anyhow = "1.0.72"
apache-avro = "0.16"
+array-init = "2"
arrow-arith = { version = ">=46" }
arrow-array = { version = ">=46" }
arrow-schema = { version = ">=46" }
@@ -48,6 +49,7 @@ chrono = "0.4"
derive_builder = "0.20.0"
either = "1"
env_logger = "0.11.0"
+fnv = "1"
futures = "0.3"
iceberg = { version = "0.2.0", path = "./crates/iceberg" }
iceberg-catalog-rest = { version = "0.2.0", path = "./crates/catalog/rest" }
diff --git a/crates/iceberg/Cargo.toml b/crates/iceberg/Cargo.toml
index c872874..5aea856 100644
--- a/crates/iceberg/Cargo.toml
+++ b/crates/iceberg/Cargo.toml
@@ -31,6 +31,7 @@ keywords = ["iceberg"]
[dependencies]
anyhow = { workspace = true }
apache-avro = { workspace = true }
+array-init = { workspace = true }
arrow-arith = { workspace = true }
arrow-array = { workspace = true }
arrow-schema = { workspace = true }
@@ -42,6 +43,7 @@ bytes = { workspace = true }
chrono = { workspace = true }
derive_builder = { workspace = true }
either = { workspace = true }
+fnv = { workspace = true }
futures = { workspace = true }
itertools = { workspace = true }
lazy_static = { workspace = true }
diff --git a/crates/iceberg/src/expr/mod.rs b/crates/iceberg/src/expr/mod.rs
index c08c836..567cf7e 100644
--- a/crates/iceberg/src/expr/mod.rs
+++ b/crates/iceberg/src/expr/mod.rs
@@ -23,6 +23,8 @@ use std::fmt::{Display, Formatter};
pub use term::*;
mod predicate;
+
+use crate::spec::SchemaRef;
pub use predicate::*;
/// Predicate operators used in expressions.
@@ -147,6 +149,14 @@ impl PredicateOperator {
}
}
+/// Binds an expression to a schema.
+///
+/// Binding resolves the named references inside an expression against a concrete
+/// schema and produces the expression's bound counterpart (`Self::Bound`).
+pub trait Bind {
+    /// The type of the bound result.
+    type Bound;
+
+    /// Binds this expression to `schema`.
+    ///
+    /// `case_sensitive` selects whether referenced field names must match the schema's
+    /// names exactly or may differ in letter case.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error when the expression cannot be bound, e.g. when it references a
+    /// field that `schema` does not contain.
+    fn bind(self, schema: SchemaRef, case_sensitive: bool) -> crate::Result<Self::Bound>;
+}
+
#[cfg(test)]
mod tests {
use crate::expr::PredicateOperator;
diff --git a/crates/iceberg/src/expr/predicate.rs
b/crates/iceberg/src/expr/predicate.rs
index 66a3956..4ab9aae 100644
--- a/crates/iceberg/src/expr/predicate.rs
+++ b/crates/iceberg/src/expr/predicate.rs
@@ -19,13 +19,18 @@
//! Predicate expressions are used to filter data, and evaluates to a boolean
value. For example,
//! `a > 10` is a predicate expression, and it evaluates to `true` if `a` is
greater than `10`,
-use crate::expr::{BoundReference, PredicateOperator, Reference};
-use crate::spec::Datum;
-use itertools::Itertools;
-use std::collections::HashSet;
use std::fmt::{Debug, Display, Formatter};
use std::ops::Not;
+use array_init::array_init;
+use fnv::FnvHashSet;
+use itertools::Itertools;
+
+use crate::error::Result;
+use crate::expr::{Bind, BoundReference, PredicateOperator, Reference};
+use crate::spec::{Datum, SchemaRef};
+use crate::{Error, ErrorKind};
+
/// Logical expression, such as `AND`, `OR`, `NOT`.
#[derive(PartialEq)]
pub struct LogicalExpression<T, const N: usize> {
@@ -55,6 +60,24 @@ impl<T, const N: usize> LogicalExpression<T, N> {
}
}
+impl<T: Bind, const N: usize> Bind for LogicalExpression<T, N> {
+    type Bound = LogicalExpression<T::Bound, N>;
+
+    /// Binds every operand to `schema`, returning the first binding error encountered.
+    fn bind(self, schema: SchemaRef, case_sensitive: bool) -> Result<Self::Bound> {
+        // Bind into `Option` slots so that `?` can bail out early on a failing operand.
+        let mut slots: [Option<Box<T::Bound>>; N] = array_init(|_| None);
+        for (slot, input) in slots.iter_mut().zip(self.inputs) {
+            *slot = Some(Box::new(input.bind(schema.clone(), case_sensitive)?));
+        }
+
+        // The loop visited all `N` operands, so every slot is `Some` here.
+        let bound_inputs = array_init::from_iter(slots.into_iter().map(Option::unwrap))
+            .expect("LogicalExpression::bind: all operand slots are filled");
+        Ok(LogicalExpression::new(bound_inputs))
+    }
+}
+
/// Unary predicate, for example, `a IS NULL`.
#[derive(PartialEq)]
pub struct UnaryExpression<T> {
@@ -79,6 +102,15 @@ impl<T: Display> Display for UnaryExpression<T> {
}
}
+impl<T: Bind> Bind for UnaryExpression<T> {
+    type Bound = UnaryExpression<T::Bound>;
+
+    /// Binds the single term and rebuilds the expression with the same operator.
+    fn bind(self, schema: SchemaRef, case_sensitive: bool) -> Result<Self::Bound> {
+        let op = self.op;
+        self.term
+            .bind(schema, case_sensitive)
+            .map(|bound_term| UnaryExpression::new(op, bound_term))
+    }
+}
+
impl<T> UnaryExpression<T> {
pub(crate) fn new(op: PredicateOperator, term: T) -> Self {
debug_assert!(op.is_unary());
@@ -120,6 +152,15 @@ impl<T: Display> Display for BinaryExpression<T> {
}
}
+impl<T: Bind> Bind for BinaryExpression<T> {
+    type Bound = BinaryExpression<T::Bound>;
+
+    /// Binds the term; the literal is carried over unchanged.
+    ///
+    /// Converting the literal to the bound field's type is done afterwards by
+    /// `Predicate::bind`.
+    fn bind(self, schema: SchemaRef, case_sensitive: bool) -> Result<Self::Bound> {
+        // `schema` is owned and not used again, so it can be moved rather than cloned.
+        let bound_term = self.term.bind(schema, case_sensitive)?;
+        Ok(BinaryExpression::new(self.op, bound_term, self.literal))
+    }
+}
+
/// Set predicates, for example, `a in (1, 2, 3)`.
#[derive(PartialEq)]
pub struct SetExpression<T> {
@@ -128,7 +169,7 @@ pub struct SetExpression<T> {
/// Term of this predicate, for example, `a` in `a in (1, 2, 3)`.
term: T,
/// Literals of this predicate, for example, `(1, 2, 3)` in `a in (1, 2,
3)`.
- literals: HashSet<Datum>,
+ literals: FnvHashSet<Datum>,
}
impl<T: Debug> Debug for SetExpression<T> {
@@ -141,12 +182,22 @@ impl<T: Debug> Debug for SetExpression<T> {
}
}
-impl<T: Debug> SetExpression<T> {
- pub(crate) fn new(op: PredicateOperator, term: T, literals:
HashSet<Datum>) -> Self {
+impl<T> SetExpression<T> {
+ /// Creates a set expression, e.g. `a IN (1, 2, 3)`.
+ ///
+ /// `op` must be a set operator; this is only checked by a `debug_assert!`.
+ pub(crate) fn new(op: PredicateOperator, term: T, literals:
FnvHashSet<Datum>) -> Self {
+ debug_assert!(op.is_set());
Self { op, term, literals }
}
}
+impl<T: Bind> Bind for SetExpression<T> {
+    type Bound = SetExpression<T::Bound>;
+
+    /// Binds the term; the literal set is carried over unchanged.
+    ///
+    /// Converting the literals to the bound field's type is done afterwards by
+    /// `Predicate::bind`.
+    fn bind(self, schema: SchemaRef, case_sensitive: bool) -> Result<Self::Bound> {
+        // `schema` is owned and not used again, so it can be moved rather than cloned.
+        let bound_term = self.term.bind(schema, case_sensitive)?;
+        Ok(SetExpression::new(self.op, bound_term, self.literals))
+    }
+}
+
impl<T: Display + Debug> Display for SetExpression<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let mut literal_strs = self.literals.iter().map(|l| format!("{}", l));
@@ -172,6 +223,146 @@ pub enum Predicate {
Set(SetExpression<Reference>),
}
+impl Bind for Predicate {
+    type Bound = BoundPredicate;
+
+    /// Binds this predicate to `schema`, simplifying it where the schema alone decides
+    /// the outcome.
+    ///
+    /// Simplifications applied after binding:
+    /// * `AND`, `OR` and `NOT` fold `AlwaysTrue` / `AlwaysFalse` operands.
+    /// * `IS NULL` on a required field becomes `AlwaysFalse`; `IS NOT NULL` on a required
+    ///   field becomes `AlwaysTrue`.
+    /// * An empty `IN` becomes `AlwaysFalse` and an empty `NOT IN` becomes `AlwaysTrue`;
+    ///   with exactly one literal they become `=` and `!=` respectively.
+    ///
+    /// # Errors
+    ///
+    /// Fails if a reference cannot be bound, a literal cannot be converted to the
+    /// referenced field's type, `IS NAN` / `NOT NAN` targets a non-floating-point field,
+    /// or a unary/set expression carries an operator of the wrong kind.
+    fn bind(self, schema: SchemaRef, case_sensitive: bool) -> Result<BoundPredicate> {
+        match self {
+            Predicate::And(expr) => {
+                let bound_expr = expr.bind(schema, case_sensitive)?;
+
+                let [left, right] = bound_expr.inputs;
+                Ok(match (left, right) {
+                    // `x AND false` and `false AND x` are always false.
+                    (_, r) if matches!(&*r, &BoundPredicate::AlwaysFalse) => {
+                        BoundPredicate::AlwaysFalse
+                    }
+                    (l, _) if matches!(&*l, &BoundPredicate::AlwaysFalse) => {
+                        BoundPredicate::AlwaysFalse
+                    }
+                    // `x AND true` and `true AND x` reduce to `x`.
+                    (left, r) if matches!(&*r, &BoundPredicate::AlwaysTrue) => *left,
+                    (l, right) if matches!(&*l, &BoundPredicate::AlwaysTrue) => *right,
+                    (left, right) => BoundPredicate::And(LogicalExpression::new([left, right])),
+                })
+            }
+            Predicate::Not(expr) => {
+                let bound_expr = expr.bind(schema, case_sensitive)?;
+                let [inner] = bound_expr.inputs;
+                Ok(match inner {
+                    e if matches!(&*e, &BoundPredicate::AlwaysTrue) => BoundPredicate::AlwaysFalse,
+                    e if matches!(&*e, &BoundPredicate::AlwaysFalse) => BoundPredicate::AlwaysTrue,
+                    e => BoundPredicate::Not(LogicalExpression::new([e])),
+                })
+            }
+            Predicate::Or(expr) => {
+                let bound_expr = expr.bind(schema, case_sensitive)?;
+                let [left, right] = bound_expr.inputs;
+                Ok(match (left, right) {
+                    // `x OR true` and `true OR x` are always true.
+                    (l, r)
+                        if matches!(&*r, &BoundPredicate::AlwaysTrue)
+                            || matches!(&*l, &BoundPredicate::AlwaysTrue) =>
+                    {
+                        BoundPredicate::AlwaysTrue
+                    }
+                    // `x OR false` and `false OR x` reduce to `x`.
+                    (left, r) if matches!(&*r, &BoundPredicate::AlwaysFalse) => *left,
+                    (l, right) if matches!(&*l, &BoundPredicate::AlwaysFalse) => *right,
+                    (left, right) => BoundPredicate::Or(LogicalExpression::new([left, right])),
+                })
+            }
+            Predicate::Unary(expr) => {
+                let bound_expr = expr.bind(schema, case_sensitive)?;
+
+                match &bound_expr.op {
+                    &PredicateOperator::IsNull => {
+                        // A required field never holds null.
+                        if bound_expr.term.field().required {
+                            return Ok(BoundPredicate::AlwaysFalse);
+                        }
+                    }
+                    &PredicateOperator::NotNull => {
+                        if bound_expr.term.field().required {
+                            return Ok(BoundPredicate::AlwaysTrue);
+                        }
+                    }
+                    &PredicateOperator::IsNan | &PredicateOperator::NotNan => {
+                        if !bound_expr.term.field().field_type.is_floating_type() {
+                            return Err(Error::new(
+                                ErrorKind::DataInvalid,
+                                format!(
+                                    "Expecting floating point type, but found {}",
+                                    bound_expr.term.field().field_type
+                                ),
+                            ));
+                        }
+                    }
+                    op => {
+                        return Err(Error::new(
+                            ErrorKind::Unexpected,
+                            format!("Expecting unary operator, but found {op}"),
+                        ))
+                    }
+                }
+
+                Ok(BoundPredicate::Unary(bound_expr))
+            }
+            Predicate::Binary(expr) => {
+                let bound_expr = expr.bind(schema, case_sensitive)?;
+                // Coerce the literal to the bound field's type; fails on a type mismatch.
+                let bound_literal = bound_expr.literal.to(&bound_expr.term.field().field_type)?;
+                Ok(BoundPredicate::Binary(BinaryExpression::new(
+                    bound_expr.op,
+                    bound_expr.term,
+                    bound_literal,
+                )))
+            }
+            Predicate::Set(expr) => {
+                let bound_expr = expr.bind(schema, case_sensitive)?;
+                // Coerce every literal to the bound field's type; the first failure aborts.
+                let bound_literals = bound_expr
+                    .literals
+                    .into_iter()
+                    .map(|l| l.to(&bound_expr.term.field().field_type))
+                    .collect::<Result<FnvHashSet<Datum>>>()?;
+
+                match &bound_expr.op {
+                    &PredicateOperator::In => {
+                        if bound_literals.is_empty() {
+                            return Ok(BoundPredicate::AlwaysFalse);
+                        }
+                        if bound_literals.len() == 1 {
+                            return Ok(BoundPredicate::Binary(BinaryExpression::new(
+                                PredicateOperator::Eq,
+                                bound_expr.term,
+                                bound_literals.into_iter().next().unwrap(),
+                            )));
+                        }
+                    }
+                    &PredicateOperator::NotIn => {
+                        if bound_literals.is_empty() {
+                            return Ok(BoundPredicate::AlwaysTrue);
+                        }
+                        if bound_literals.len() == 1 {
+                            return Ok(BoundPredicate::Binary(BinaryExpression::new(
+                                PredicateOperator::NotEq,
+                                bound_expr.term,
+                                bound_literals.into_iter().next().unwrap(),
+                            )));
+                        }
+                    }
+                    op => {
+                        return Err(Error::new(
+                            ErrorKind::Unexpected,
+                            format!("Expecting set operator, but found {op}"),
+                        ))
+                    }
+                }
+
+                Ok(BoundPredicate::Set(SetExpression::new(
+                    bound_expr.op,
+                    bound_expr.term,
+                    bound_literals,
+                )))
+            }
+        }
+    }
+}
+
impl Display for Predicate {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
@@ -292,7 +483,11 @@ impl Predicate {
impl Not for Predicate {
type Output = Predicate;
- /// Create a predicate which is the reverse of this predicate. For
example: `NOT (a > 10)`
+ /// Create a predicate which is the reverse of this predicate. For
example: `NOT (a > 10)`.
+ ///
+ /// This is different from [`Predicate::negate()`] since it doesn't
rewrite expression, but
+ /// just adds a `NOT` operator.
+ ///
/// # Example
///
///```rust
@@ -332,12 +527,46 @@ pub enum BoundPredicate {
Set(SetExpression<BoundReference>),
}
+impl Display for BoundPredicate {
+    /// Renders constants as `True` / `False`, wraps each operand of `AND` / `OR` / `NOT`
+    /// in parentheses, and delegates leaf expressions to their own `Display` impls.
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        match self {
+            BoundPredicate::AlwaysTrue => f.write_str("True"),
+            BoundPredicate::AlwaysFalse => f.write_str("False"),
+            BoundPredicate::And(expr) => {
+                let operands = expr.inputs();
+                write!(f, "({}) AND ({})", operands[0], operands[1])
+            }
+            BoundPredicate::Or(expr) => {
+                let operands = expr.inputs();
+                write!(f, "({}) OR ({})", operands[0], operands[1])
+            }
+            BoundPredicate::Not(expr) => {
+                let operands = expr.inputs();
+                write!(f, "NOT ({})", operands[0])
+            }
+            BoundPredicate::Unary(leaf) => write!(f, "{leaf}"),
+            BoundPredicate::Binary(leaf) => write!(f, "{leaf}"),
+            BoundPredicate::Set(leaf) => write!(f, "{leaf}"),
+        }
+    }
+}
+
#[cfg(test)]
mod tests {
+ use std::ops::Not;
+ use std::sync::Arc;
+
+ use crate::expr::Bind;
use crate::expr::Reference;
use crate::spec::Datum;
- use std::collections::HashSet;
- use std::ops::Not;
+ use crate::spec::{NestedField, PrimitiveType, Schema, SchemaRef, Type};
#[test]
fn test_predicate_negate_and() {
@@ -406,13 +635,239 @@ mod tests {
#[test]
fn test_predicate_negate_set() {
- let expression =
Reference::new("a").is_in(HashSet::from([Datum::long(5), Datum::long(6)]));
+ let expression = Reference::new("a").is_in([Datum::long(5),
Datum::long(6)]);
- let expected =
- Reference::new("a").is_not_in(HashSet::from([Datum::long(5),
Datum::long(6)]));
+ let expected = Reference::new("a").is_not_in([Datum::long(5),
Datum::long(6)]);
+ // Negating an IN predicate must yield NOT IN over the same literal set.
let result = expression.negate();
assert_eq!(result, expected);
}
+
+    /// Schema with `foo` (optional string), `bar` (required int, identifier field) and
+    /// `baz` (optional boolean).
+    fn table_schema_simple() -> SchemaRef {
+        Arc::new(
+            Schema::builder()
+                .with_schema_id(1)
+                .with_identifier_field_ids(vec![2])
+                .with_fields(vec![
+                    NestedField::optional(1, "foo", Type::Primitive(PrimitiveType::String)).into(),
+                    NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(),
+                    NestedField::optional(3, "baz", Type::Primitive(PrimitiveType::Boolean)).into(),
+                ])
+                .build()
+                .unwrap(),
+        )
+    }
+
+    /// Binds `expr` case-sensitively to [`table_schema_simple`] and renders the result.
+    fn bind_to_string(expr: crate::expr::Predicate) -> String {
+        expr.bind(table_schema_simple(), true)
+            .expect("binding should succeed")
+            .to_string()
+    }
+
+    /// Asserts that binding `expr` case-sensitively to [`table_schema_simple`] fails.
+    fn assert_bind_fails(expr: crate::expr::Predicate) {
+        assert!(expr.bind(table_schema_simple(), true).is_err());
+    }
+
+    #[test]
+    fn test_bind_is_null() {
+        let expr = Reference::new("foo").is_null();
+        assert_eq!(bind_to_string(expr), "foo IS NULL");
+    }
+
+    #[test]
+    fn test_bind_is_null_required() {
+        let expr = Reference::new("bar").is_null();
+        assert_eq!(bind_to_string(expr), "False");
+    }
+
+    #[test]
+    fn test_bind_is_not_null() {
+        let expr = Reference::new("foo").is_not_null();
+        assert_eq!(bind_to_string(expr), "foo IS NOT NULL");
+    }
+
+    #[test]
+    fn test_bind_is_not_null_required() {
+        let expr = Reference::new("bar").is_not_null();
+        assert_eq!(bind_to_string(expr), "True");
+    }
+
+    #[test]
+    fn test_bind_less_than() {
+        let expr = Reference::new("bar").less_than(Datum::int(10));
+        assert_eq!(bind_to_string(expr), "bar < 10");
+    }
+
+    #[test]
+    fn test_bind_less_than_wrong_type() {
+        assert_bind_fails(Reference::new("bar").less_than(Datum::string("abcd")));
+    }
+
+    #[test]
+    fn test_bind_greater_than_or_eq() {
+        let expr = Reference::new("bar").greater_than_or_equal_to(Datum::int(10));
+        assert_eq!(bind_to_string(expr), "bar >= 10");
+    }
+
+    #[test]
+    fn test_bind_greater_than_or_eq_wrong_type() {
+        assert_bind_fails(Reference::new("bar").greater_than_or_equal_to(Datum::string("abcd")));
+    }
+
+    #[test]
+    fn test_bind_in() {
+        let expr = Reference::new("bar").is_in([Datum::int(10), Datum::int(20)]);
+        let bound = bind_to_string(expr);
+        // Literal order in a hash set is unspecified, so accept either rendering.
+        assert!(
+            bound == "bar IN (10, 20)" || bound == "bar IN (20, 10)",
+            "unexpected rendering: {bound}"
+        );
+    }
+
+    #[test]
+    fn test_bind_in_empty() {
+        let expr = Reference::new("bar").is_in(vec![]);
+        assert_eq!(bind_to_string(expr), "False");
+    }
+
+    #[test]
+    fn test_bind_in_one_literal() {
+        let expr = Reference::new("bar").is_in(vec![Datum::int(10)]);
+        assert_eq!(bind_to_string(expr), "bar = 10");
+    }
+
+    #[test]
+    fn test_bind_in_wrong_type() {
+        assert_bind_fails(Reference::new("bar").is_in(vec![Datum::int(10), Datum::string("abcd")]));
+    }
+
+    #[test]
+    fn test_bind_not_in() {
+        let expr = Reference::new("bar").is_not_in([Datum::int(10), Datum::int(20)]);
+        let bound = bind_to_string(expr);
+        // Literal order in a hash set is unspecified, so accept either rendering.
+        assert!(
+            bound == "bar NOT IN (10, 20)" || bound == "bar NOT IN (20, 10)",
+            "unexpected rendering: {bound}"
+        );
+    }
+
+    #[test]
+    fn test_bind_not_in_empty() {
+        let expr = Reference::new("bar").is_not_in(vec![]);
+        assert_eq!(bind_to_string(expr), "True");
+    }
+
+    #[test]
+    fn test_bind_not_in_one_literal() {
+        let expr = Reference::new("bar").is_not_in(vec![Datum::int(10)]);
+        assert_eq!(bind_to_string(expr), "bar != 10");
+    }
+
+    #[test]
+    fn test_bind_not_in_wrong_type() {
+        assert_bind_fails(Reference::new("bar").is_not_in([Datum::int(10), Datum::string("abcd")]));
+    }
+
+    #[test]
+    fn test_bind_and() {
+        let expr = Reference::new("bar")
+            .less_than(Datum::int(10))
+            .and(Reference::new("foo").is_null());
+        assert_eq!(bind_to_string(expr), "(bar < 10) AND (foo IS NULL)");
+    }
+
+    #[test]
+    fn test_bind_and_always_false() {
+        // `bar` is required, so `bar IS NULL` folds to False and takes the AND with it.
+        let expr = Reference::new("foo")
+            .less_than(Datum::string("abcd"))
+            .and(Reference::new("bar").is_null());
+        assert_eq!(bind_to_string(expr), "False");
+    }
+
+    #[test]
+    fn test_bind_and_always_true() {
+        let expr = Reference::new("foo")
+            .less_than(Datum::string("abcd"))
+            .and(Reference::new("bar").is_not_null());
+        assert_eq!(bind_to_string(expr), r#"foo < "abcd""#);
+    }
+
+    #[test]
+    fn test_bind_or() {
+        let expr = Reference::new("bar")
+            .less_than(Datum::int(10))
+            .or(Reference::new("foo").is_null());
+        assert_eq!(bind_to_string(expr), "(bar < 10) OR (foo IS NULL)");
+    }
+
+    #[test]
+    fn test_bind_or_always_true() {
+        let expr = Reference::new("foo")
+            .less_than(Datum::string("abcd"))
+            .or(Reference::new("bar").is_not_null());
+        assert_eq!(bind_to_string(expr), "True");
+    }
+
+    #[test]
+    fn test_bind_or_always_false() {
+        let expr = Reference::new("foo")
+            .less_than(Datum::string("abcd"))
+            .or(Reference::new("bar").is_null());
+        assert_eq!(bind_to_string(expr), r#"foo < "abcd""#);
+    }
+
+    #[test]
+    fn test_bind_not() {
+        let expr = !Reference::new("bar").less_than(Datum::int(10));
+        assert_eq!(bind_to_string(expr), "NOT (bar < 10)");
+    }
+
+    #[test]
+    fn test_bind_not_always_true() {
+        let expr = !Reference::new("bar").is_not_null();
+        assert_eq!(bind_to_string(expr), "False");
+    }
+
+    #[test]
+    fn test_bind_not_always_false() {
+        let expr = !Reference::new("bar").is_null();
+        assert_eq!(bind_to_string(expr), "True");
+    }
}
diff --git a/crates/iceberg/src/expr/term.rs b/crates/iceberg/src/expr/term.rs
index 6be502f..e39c97e 100644
--- a/crates/iceberg/src/expr/term.rs
+++ b/crates/iceberg/src/expr/term.rs
@@ -17,11 +17,15 @@
//! Term definition.
-use crate::expr::{BinaryExpression, Predicate, PredicateOperator,
SetExpression, UnaryExpression};
-use crate::spec::{Datum, NestedField, NestedFieldRef};
-use std::collections::HashSet;
use std::fmt::{Display, Formatter};
+use fnv::FnvHashSet;
+
+use crate::expr::Bind;
+use crate::expr::{BinaryExpression, Predicate, PredicateOperator,
SetExpression, UnaryExpression};
+use crate::spec::{Datum, NestedField, NestedFieldRef, SchemaRef};
+use crate::{Error, ErrorKind};
+
/// Unbound term before binding to a schema.
pub type Term = Reference;
@@ -123,16 +127,20 @@ impl Reference {
///
/// ```rust
///
- /// use std::collections::HashSet;
+ /// use fnv::FnvHashSet;
/// use iceberg::expr::Reference;
/// use iceberg::spec::Datum;
- /// let expr = Reference::new("a").is_in(HashSet::from([Datum::long(5),
Datum::long(6)]));
+ /// let expr = Reference::new("a").is_in([Datum::long(5), Datum::long(6)]);
///
/// let as_string = format!("{expr}");
/// assert!(&as_string == "a IN (5, 6)" || &as_string == "a IN (6, 5)");
/// ```
- pub fn is_in(self, literals: HashSet<Datum>) -> Predicate {
- Predicate::Set(SetExpression::new(PredicateOperator::In, self,
literals))
+    pub fn is_in(self, literals: impl IntoIterator<Item = Datum>) -> Predicate {
+        // Duplicate literals collapse into a single set entry.
+        let literal_set: FnvHashSet<Datum> = literals.into_iter().collect();
+        Predicate::Set(SetExpression::new(PredicateOperator::In, self, literal_set))
+    }
/// Creates an is-not-in expression. For example, `a IS NOT IN (5, 6)`.
@@ -141,16 +149,20 @@ impl Reference {
///
/// ```rust
///
- /// use std::collections::HashSet;
+ /// use fnv::FnvHashSet;
/// use iceberg::expr::Reference;
/// use iceberg::spec::Datum;
- /// let expr =
Reference::new("a").is_not_in(HashSet::from([Datum::long(5), Datum::long(6)]));
+ /// let expr = Reference::new("a").is_not_in([Datum::long(5),
Datum::long(6)]);
///
/// let as_string = format!("{expr}");
/// assert!(&as_string == "a NOT IN (5, 6)" || &as_string == "a NOT IN (6,
5)");
/// ```
- pub fn is_not_in(self, literals: HashSet<Datum>) -> Predicate {
- Predicate::Set(SetExpression::new(PredicateOperator::NotIn, self,
literals))
+    pub fn is_not_in(self, literals: impl IntoIterator<Item = Datum>) -> Predicate {
+        // Duplicate literals collapse into a single set entry.
+        let literal_set: FnvHashSet<Datum> = literals.into_iter().collect();
+        Predicate::Set(SetExpression::new(PredicateOperator::NotIn, self, literal_set))
+    }
}
@@ -160,8 +172,28 @@ impl Display for Reference {
}
}
+impl Bind for Reference {
+    type Bound = BoundReference;
+
+    /// Resolves the referenced name to a field of `schema`.
+    ///
+    /// The name is looked up with `Schema::field_by_name` when `case_sensitive` is
+    /// `true`, and with `Schema::field_by_name_case_insensitive` otherwise. The bound
+    /// reference keeps the name exactly as written in this reference.
+    ///
+    /// # Errors
+    ///
+    /// Returns an `ErrorKind::DataInvalid` error when no field matches.
+    fn bind(self, schema: SchemaRef, case_sensitive: bool) -> crate::Result<Self::Bound> {
+        let matched = if case_sensitive {
+            schema.field_by_name(&self.name)
+        } else {
+            schema.field_by_name_case_insensitive(&self.name)
+        };
+
+        match matched {
+            Some(field) => Ok(BoundReference::new(self.name, field.clone())),
+            None => Err(Error::new(
+                ErrorKind::DataInvalid,
+                format!("Field {} not found in schema", self.name),
+            )),
+        }
+    }
+}
+
/// A named reference in a bound expression after binding to a schema.
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BoundReference {
// This maybe different from [`name`] filed in [`NestedField`] since this
contains full path.
// For example, if the field is `a.b.c`, then `field.name` is `c`, but
`original_name` is `a.b.c`.
@@ -192,3 +224,67 @@ impl Display for BoundReference {
/// Bound term after binding to a schema.
pub type BoundTerm = BoundReference;
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use crate::expr::{Bind, BoundReference, Reference};
+    use crate::spec::{NestedField, PrimitiveType, Schema, SchemaRef, Type};
+
+    /// Schema with `foo` (optional string), `bar` (required int, identifier field) and
+    /// `baz` (optional boolean).
+    fn table_schema_simple() -> SchemaRef {
+        Arc::new(
+            Schema::builder()
+                .with_schema_id(1)
+                .with_identifier_field_ids(vec![2])
+                .with_fields(vec![
+                    NestedField::optional(1, "foo", Type::Primitive(PrimitiveType::String)).into(),
+                    NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(),
+                    NestedField::optional(3, "baz", Type::Primitive(PrimitiveType::Boolean)).into(),
+                ])
+                .build()
+                .unwrap(),
+        )
+    }
+
+    #[test]
+    fn test_bind_reference() {
+        let schema = table_schema_simple();
+        let reference = Reference::new("bar").bind(schema, true).unwrap();
+
+        let expected_ref = BoundReference::new(
+            "bar",
+            NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(),
+        );
+
+        assert_eq!(expected_ref, reference);
+    }
+
+    #[test]
+    fn test_bind_reference_case_insensitive() {
+        let schema = table_schema_simple();
+        let reference = Reference::new("BAR").bind(schema, false).unwrap();
+
+        // The bound reference keeps the name as written, but points at the `bar` field.
+        let expected_ref = BoundReference::new(
+            "BAR",
+            NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(),
+        );
+
+        assert_eq!(expected_ref, reference);
+    }
+
+    #[test]
+    fn test_bind_reference_case_sensitive_mismatch_fails() {
+        // With case-sensitive binding, a differently-cased name must not resolve.
+        let schema = table_schema_simple();
+        let result = Reference::new("BAR").bind(schema, true);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_bind_reference_failure() {
+        let schema = table_schema_simple();
+        let result = Reference::new("bar_not_exist").bind(schema, true);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_bind_reference_case_insensitive_failure() {
+        let schema = table_schema_simple();
+        let result = Reference::new("bar_non_exist").bind(schema, false);
+        assert!(result.is_err());
+    }
+}
diff --git a/crates/iceberg/src/spec/datatypes.rs
b/crates/iceberg/src/spec/datatypes.rs
index 8f404e9..6ea4175 100644
--- a/crates/iceberg/src/spec/datatypes.rs
+++ b/crates/iceberg/src/spec/datatypes.rs
@@ -135,6 +135,15 @@ impl Type {
ensure_data_valid!(precision > 0 && precision <=
MAX_DECIMAL_PRECISION, "Decimals with precision larger than
{MAX_DECIMAL_PRECISION} are not supported: {precision}",);
Ok(Type::Primitive(PrimitiveType::Decimal { precision, scale }))
}
+
+    /// Returns `true` if this is the primitive `float` or `double` type.
+    // Plain `#[inline]` leaves the decision to the optimizer; forcing it is unjustified.
+    #[inline]
+    pub fn is_floating_type(&self) -> bool {
+        matches!(
+            self,
+            Type::Primitive(PrimitiveType::Float | PrimitiveType::Double)
+        )
+    }
}
impl From<PrimitiveType> for Type {
diff --git a/crates/iceberg/src/spec/schema.rs
b/crates/iceberg/src/spec/schema.rs
index 34e383f..975a2a9 100644
--- a/crates/iceberg/src/spec/schema.rs
+++ b/crates/iceberg/src/spec/schema.rs
@@ -51,6 +51,7 @@ pub struct Schema {
id_to_field: HashMap<i32, NestedFieldRef>,
name_to_id: HashMap<String, i32>,
+ lowercase_name_to_id: HashMap<String, i32>,
id_to_name: HashMap<i32, String>,
}
@@ -117,6 +118,11 @@ impl SchemaBuilder {
index.indexes()
};
+ let lowercase_name_to_id = name_to_id
+ .iter()
+ .map(|(k, v)| (k.to_lowercase(), *v))
+ .collect();
+
Ok(Schema {
r#struct,
schema_id: self.schema_id,
@@ -127,6 +133,7 @@ impl SchemaBuilder {
id_to_field,
name_to_id,
+ lowercase_name_to_id,
id_to_name,
})
}
@@ -212,6 +219,15 @@ impl Schema {
.and_then(|id| self.field_by_id(*id))
}
+    /// Gets a field by name, ignoring letter case.
+    ///
+    /// Both full (dotted) names and short names are accepted; the lookup lowercases
+    /// `field_name` and consults the lowercased name index. Returns `None` if no field
+    /// matches.
+    ///
+    /// NOTE(review): names that differ only by case share one lowercased key in the
+    /// index, so only one of them can be found here — confirm this is acceptable.
+    pub fn field_by_name_case_insensitive(&self, field_name: &str) -> Option<&NestedFieldRef> {
+        let lowered = field_name.to_lowercase();
+        let id = self.lowercase_name_to_id.get(&lowered)?;
+        self.field_by_id(*id)
+    }
+
/// Get field by alias.
pub fn field_by_alias(&self, alias: &str) -> Option<&NestedFieldRef> {
self.alias_to_id
@@ -1032,6 +1048,42 @@ table {
assert_eq!(&expected_name_to_id, &schema.name_to_id);
}
+    #[test]
+    fn test_schema_index_by_name_case_insensitive() {
+        // Full and short field names of the nested schema spelled with arbitrary casing,
+        // each paired with the field id it must resolve to.
+        let mixed_case_names = [
+            ("fOo", 1),
+            ("Bar", 2),
+            ("BAz", 3),
+            ("quX", 4),
+            ("quX.ELEment", 5),
+            ("qUUx", 6),
+            ("QUUX.KEY", 7),
+            ("QUUX.Value", 8),
+            ("qUUX.VALUE.Key", 9),
+            ("qUux.VaLue.Value", 10),
+            ("lOCAtION", 11),
+            ("LOCAtioN.ELeMENt", 12),
+            ("LoCATion.element.LATitude", 13),
+            ("locatION.ElemeNT.LONgitude", 14),
+            ("LOCAtiON.LATITUDE", 13),
+            ("LOCATION.LONGITUDE", 14),
+            ("PERSon", 15),
+            ("PERSON.Name", 16),
+            ("peRSON.AGe", 17),
+        ];
+
+        let schema = table_schema_nested();
+        for (name, expected_id) in mixed_case_names {
+            let resolved_id = schema.field_by_name_case_insensitive(name).map(|f| f.id);
+            assert_eq!(Some(expected_id), resolved_id);
+        }
+    }
+
#[test]
fn test_schema_find_column_name() {
let expected_column_name = HashMap::from([
diff --git a/crates/iceberg/src/spec/values.rs
b/crates/iceberg/src/spec/values.rs
index 113620f..00f2e57 100644
--- a/crates/iceberg/src/spec/values.rs
+++ b/crates/iceberg/src/spec/values.rs
@@ -106,7 +106,7 @@ impl Display for Datum {
(_, PrimitiveLiteral::TimestampTZ(val)) => {
write!(f, "{}", microseconds_to_datetimetz(*val))
}
- (_, PrimitiveLiteral::String(val)) => write!(f, "{}", val),
+ (_, PrimitiveLiteral::String(val)) => write!(f, r#""{}""#, val),
(_, PrimitiveLiteral::UUID(val)) => write!(f, "{}", val),
(_, PrimitiveLiteral::Fixed(val)) => display_bytes(val, f),
(_, PrimitiveLiteral::Binary(val)) => display_bytes(val, f),
@@ -529,7 +529,7 @@ impl Datum {
/// use iceberg::spec::Datum;
/// let t = Datum::string("ss");
///
- /// assert_eq!(&format!("{t}"), "ss");
+ /// assert_eq!(&format!("{t}"), r#""ss""#);
/// ```
pub fn string<S: ToString>(s: S) -> Self {
Self {
@@ -658,6 +658,21 @@ impl Datum {
unreachable!("Decimal type must be primitive.")
}
}
+
+    /// Converts this datum to `target_type`.
+    ///
+    /// Only the identity conversion is supported for now: the datum is returned unchanged
+    /// when `target_type` is a primitive type equal to the datum's own type.
+    ///
+    /// # Errors
+    ///
+    /// Returns an `ErrorKind::DataInvalid` error for any other target type.
+    pub fn to(self, target_type: &Type) -> Result<Datum> {
+        // TODO: We should allow more type conversions
+        if matches!(target_type, Type::Primitive(primitive) if primitive == &self.r#type) {
+            return Ok(self);
+        }
+
+        Err(Error::new(
+            ErrorKind::DataInvalid,
+            format!(
+                "Can't convert datum from {} type to {} type.",
+                self.r#type, target_type
+            ),
+        ))
+    }
}
/// Values present in iceberg type