This is an automated email from the ASF dual-hosted git repository.
fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-go.git
The following commit(s) were added to refs/heads/main by this push:
new 704a6e7 feat(exprs): Adding BooleanExpressions and Predicates (#91)
704a6e7 is described below
commit 704a6e78c13ea63f1ff4bb387f7d4b365b5f0f82
Author: Matt Topol <[email protected]>
AuthorDate: Tue Jun 25 09:51:07 2024 -0700
feat(exprs): Adding BooleanExpressions and Predicates (#91)
* feat(exprs): Adding BooleanExpressions and Predicates
* exclude the generated file from rat check
---
dev/.rat-excludes | 1 +
exprs.go | 968 ++++++++++++++++++++++++++++++++++++++++++++++++++++
exprs_test.go | 742 ++++++++++++++++++++++++++++++++++++++++
operation_string.go | 41 +++
predicates.go | 138 ++++++++
schema.go | 67 ++++
utils.go | 127 +++++++
7 files changed, 2084 insertions(+)
diff --git a/dev/.rat-excludes b/dev/.rat-excludes
index b968f68..e947260 100644
--- a/dev/.rat-excludes
+++ b/dev/.rat-excludes
@@ -5,3 +5,4 @@ NOTICE
go.sum
build
rat-results.txt
+operation_string.go
\ No newline at end of file
diff --git a/exprs.go b/exprs.go
new file mode 100644
index 0000000..1123b8b
--- /dev/null
+++ b/exprs.go
@@ -0,0 +1,968 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package iceberg
+
+import (
+ "fmt"
+
+ "github.com/google/uuid"
+)
+
//go:generate stringer -type=Operation -linecomment

// Operation is an enum used for constants to define what operation a given
// expression or predicate is going to execute.
//
// String() for this type is generated by stringer (see the go:generate
// directive above) with -linecomment, so the trailing comment on each
// constant below IS that constant's string form. Keep those trailing
// comments unchanged.
type Operation int

const (
	// do not change the order of these enum constants.
	// they are grouped for quick validation of operation type by
	// using <= and >= of the first/last operation in a group

	OpTrue  Operation = iota // True
	OpFalse                  // False
	// unary ops: OpIsNull through OpNotNan (range-checked by UnaryPredicate)
	OpIsNull  // IsNull
	OpNotNull // NotNull
	OpIsNan   // IsNaN
	OpNotNan  // NotNaN
	// literal ops: OpLT through OpNotStartsWith (range-checked by LiteralPredicate)
	OpLT            // LessThan
	OpLTEQ          // LessThanEqual
	OpGT            // GreaterThan
	OpGTEQ          // GreaterThanEqual
	OpEQ            // Equal
	OpNEQ           // NotEqual
	OpStartsWith    // StartsWith
	OpNotStartsWith // NotStartsWith
	// set ops: OpIn through OpNotIn (range-checked by SetPredicate)
	OpIn    // In
	OpNotIn // NotIn
	// boolean ops
	OpNot // Not
	OpAnd // And
	OpOr  // Or
)
+
+// Negate returns the inverse operation for a given op
+func (op Operation) Negate() Operation {
+ switch op {
+ case OpIsNull:
+ return OpNotNull
+ case OpNotNull:
+ return OpIsNull
+ case OpIsNan:
+ return OpNotNan
+ case OpNotNan:
+ return OpIsNan
+ case OpLT:
+ return OpGTEQ
+ case OpLTEQ:
+ return OpGT
+ case OpGT:
+ return OpLTEQ
+ case OpGTEQ:
+ return OpLT
+ case OpEQ:
+ return OpNEQ
+ case OpNEQ:
+ return OpEQ
+ case OpIn:
+ return OpNotIn
+ case OpNotIn:
+ return OpIn
+ case OpStartsWith:
+ return OpNotStartsWith
+ case OpNotStartsWith:
+ return OpStartsWith
+ default:
+ panic("no negation for operation " + op.String())
+ }
+}
+
+// FlipLR returns the correct operation to use if the left and right operands
+// are flipped.
+func (op Operation) FlipLR() Operation {
+ switch op {
+ case OpLT:
+ return OpGT
+ case OpLTEQ:
+ return OpGTEQ
+ case OpGT:
+ return OpLT
+ case OpGTEQ:
+ return OpLTEQ
+ case OpAnd:
+ return OpAnd
+ case OpOr:
+ return OpOr
+ default:
+ panic("no left-right flip for operation: " + op.String())
+ }
+}
+
+// BooleanExpression represents a full expression which will evaluate to a
+// boolean value such as GreaterThan or StartsWith, etc.
+type BooleanExpression interface {
+ fmt.Stringer
+ Op() Operation
+ Negate() BooleanExpression
+ Equals(BooleanExpression) bool
+}
+
+// AlwaysTrue is the boolean expression "True"
+type AlwaysTrue struct{}
+
+func (AlwaysTrue) String() string { return "AlwaysTrue()" }
+func (AlwaysTrue) Op() Operation { return OpTrue }
+func (AlwaysTrue) Negate() BooleanExpression { return AlwaysFalse{} }
+func (AlwaysTrue) Equals(other BooleanExpression) bool {
+ _, ok := other.(AlwaysTrue)
+ return ok
+}
+
+// AlwaysFalse is the boolean expression "False"
+type AlwaysFalse struct{}
+
+func (AlwaysFalse) String() string { return "AlwaysFalse()" }
+func (AlwaysFalse) Op() Operation { return OpFalse }
+func (AlwaysFalse) Negate() BooleanExpression { return AlwaysTrue{} }
+func (AlwaysFalse) Equals(other BooleanExpression) bool {
+ _, ok := other.(AlwaysFalse)
+ return ok
+}
+
+type NotExpr struct {
+ child BooleanExpression
+}
+
+// NewNot creates a BooleanExpression representing a "Not" operation on the
given
+// argument. It will optimize slightly though:
+//
+// If the argument is AlwaysTrue or AlwaysFalse, the appropriate inverse
expression
+// will be returned directly. If the argument is itself a NotExpr, then the
child
+// will be returned rather than NotExpr(NotExpr(child)).
+func NewNot(child BooleanExpression) BooleanExpression {
+ if child == nil {
+ panic(fmt.Errorf("%w: cannot create NotExpr with nil child",
+ ErrInvalidArgument))
+ }
+
+ switch t := child.(type) {
+ case NotExpr:
+ return t.child
+ case AlwaysTrue:
+ return AlwaysFalse{}
+ case AlwaysFalse:
+ return AlwaysTrue{}
+ }
+
+ return NotExpr{child: child}
+}
+
+func (n NotExpr) String() string { return "Not(child=" +
n.child.String() + ")" }
+func (NotExpr) Op() Operation { return OpNot }
+func (n NotExpr) Negate() BooleanExpression { return n.child }
+func (n NotExpr) Equals(other BooleanExpression) bool {
+ rhs, ok := other.(NotExpr)
+ if !ok {
+ return false
+ }
+ return n.child.Equals(rhs.child)
+}
+
+type AndExpr struct {
+ left, right BooleanExpression
+}
+
+func newAnd(left, right BooleanExpression) BooleanExpression {
+ if left == nil || right == nil {
+ panic(fmt.Errorf("%w: cannot construct AndExpr with nil
arguments",
+ ErrInvalidArgument))
+ }
+
+ switch {
+ case left == AlwaysFalse{} || right == AlwaysFalse{}:
+ return AlwaysFalse{}
+ case left == AlwaysTrue{}:
+ return right
+ case right == AlwaysTrue{}:
+ return left
+ }
+
+ return AndExpr{left: left, right: right}
+}
+
+// NewAnd will construct a new AndExpr, allowing the caller to provide
potentially
+// more than just two arguments which will be folded to create an appropriate
expression
+// tree. i.e. NewAnd(a, b, c, d) becomes AndExpr(a, AndExpr(b, AndExpr(c, d)))
+//
+// Slight optimizations are performed on creation if either argument is
AlwaysFalse
+// or AlwaysTrue by performing reductions. If any argument is AlwaysFalse,
then everything
+// will get folded to a return of AlwaysFalse. If an argument is AlwaysTrue,
then the other
+// argument will be returned directly rather than creating an AndExpr.
+//
+// Will panic if any argument is nil
+func NewAnd(left, right BooleanExpression, addl ...BooleanExpression)
BooleanExpression {
+ folded := newAnd(left, right)
+ for _, a := range addl {
+ folded = newAnd(folded, a)
+ }
+ return folded
+}
+
+func (a AndExpr) String() string {
+ return "And(left=" + a.left.String() + ", right=" + a.right.String() +
")"
+}
+
+func (AndExpr) Op() Operation { return OpAnd }
+
+func (a AndExpr) Equals(other BooleanExpression) bool {
+ rhs, ok := other.(AndExpr)
+ if !ok {
+ return false
+ }
+
+ return (a.left.Equals(rhs.left) && a.right.Equals(rhs.right)) ||
+ (a.left.Equals(rhs.right) && a.right.Equals(rhs.left))
+}
+
+func (a AndExpr) Negate() BooleanExpression {
+ return NewOr(a.left.Negate(), a.right.Negate())
+}
+
+type OrExpr struct {
+ left, right BooleanExpression
+}
+
+func newOr(left, right BooleanExpression) BooleanExpression {
+ if left == nil || right == nil {
+ panic(fmt.Errorf("%w: cannot construct OrExpr with nil
arguments",
+ ErrInvalidArgument))
+ }
+
+ switch {
+ case left == AlwaysTrue{} || right == AlwaysTrue{}:
+ return AlwaysTrue{}
+ case left == AlwaysFalse{}:
+ return right
+ case right == AlwaysFalse{}:
+ return left
+ }
+
+ return OrExpr{left: left, right: right}
+}
+
+// NewOr will construct a new OrExpr, allowing the caller to provide
potentially
+// more than just two arguments which will be folded to create an appropriate
expression
+// tree. i.e. NewOr(a, b, c, d) becomes OrExpr(a, OrExpr(b, OrExpr(c, d)))
+//
+// Slight optimizations are performed on creation if either argument is
AlwaysFalse
+// or AlwaysTrue by performing reductions. If any argument is AlwaysTrue, then
everything
+// will get folded to a return of AlwaysTrue. If an argument is AlwaysFalse,
then the other
+// argument will be returned directly rather than creating an OrExpr.
+//
+// Will panic if any argument is nil
+func NewOr(left, right BooleanExpression, addl ...BooleanExpression)
BooleanExpression {
+ folded := newOr(left, right)
+ for _, a := range addl {
+ folded = newOr(folded, a)
+ }
+ return folded
+}
+
+func (o OrExpr) String() string {
+ return "Or(left=" + o.left.String() + ", right=" + o.right.String() +
")"
+}
+
+func (OrExpr) Op() Operation { return OpOr }
+
+func (o OrExpr) Equals(other BooleanExpression) bool {
+ rhs, ok := other.(OrExpr)
+ if !ok {
+ return false
+ }
+
+ return (o.left.Equals(rhs.left) && o.right.Equals(rhs.right)) ||
+ (o.left.Equals(rhs.right) && o.right.Equals(rhs.left))
+}
+
+func (o OrExpr) Negate() BooleanExpression {
+ return NewAnd(o.left.Negate(), o.right.Negate())
+}
+
+// A Term is a simple expression that evaluates to a value
+type Term interface {
+ fmt.Stringer
+ // requiring this method ensures that only types we define can be used
+ // as a term.
+ isTerm()
+}
+
+// UnboundTerm is an expression that evaluates to a value that isn't yet bound
+// to a schema, thus it isn't yet known what the type will be.
+type UnboundTerm interface {
+ Term
+
+ Equals(UnboundTerm) bool
+ Bind(schema *Schema, caseSensitive bool) (BoundTerm, error)
+}
+
+// BoundTerm is a simple expression (typically a reference) that evaluates to a
+// value and has been bound to a schema.
+type BoundTerm interface {
+ Term
+
+ Equals(BoundTerm) bool
+ Ref() BoundReference
+ Type() Type
+
+ evalToLiteral(structLike) Literal
+ evalIsNull(structLike) bool
+}
+
+// unbound is a generic interface representing something that is not yet bound
+// to a particular type.
+type unbound[B any] interface {
+ Bind(schema *Schema, caseSensitive bool) (B, error)
+}
+
+// An UnboundPredicate represents a boolean predicate expression which has not
+// yet been bound to a schema. Binding it will produce a BooleanExpression.
+//
+// BooleanExpression is used for the binding result because we may optimize and
+// return AlwaysTrue / AlwaysFalse in some scenarios during binding which are
+// not considered to be "Bound" as they do not have a bound Term or Reference.
+type UnboundPredicate interface {
+ BooleanExpression
+ unbound[BooleanExpression]
+ Term() UnboundTerm
+}
+
+// BoundPredicate is a boolean predicate expression which has been bound to a
schema.
+// The underlying reference and term can be retrieved from it.
+type BoundPredicate interface {
+ BooleanExpression
+ Ref() BoundReference
+ Term() BoundTerm
+}
+
+// Reference is a field name not yet bound to a particular field in a schema
+type Reference string
+
+func (r Reference) String() string {
+ return "Reference(name='" + string(r) + "')"
+}
+
+func (Reference) isTerm() {}
+func (r Reference) Equals(other UnboundTerm) bool {
+ rhs, ok := other.(Reference)
+ if !ok {
+ return false
+ }
+
+ return r == rhs
+}
+
+func (r Reference) Bind(s *Schema, caseSensitive bool) (BoundTerm, error) {
+ var (
+ field NestedField
+ found bool
+ )
+
+ if caseSensitive {
+ field, found = s.FindFieldByName(string(r))
+ } else {
+ field, found = s.FindFieldByNameCaseInsensitive(string(r))
+ }
+ if !found {
+ return nil, fmt.Errorf("%w: could not bind reference '%s',
caseSensitive=%t",
+ ErrInvalidSchema, string(r), caseSensitive)
+ }
+
+ acc, ok := s.accessorForField(field.ID)
+ if !ok {
+ return nil, ErrInvalidSchema
+ }
+
+ return createBoundRef(field, acc), nil
+}
+
+// BoundReference is a named reference that has been bound to a particular
field
+// in a given schema.
+type BoundReference interface {
+ BoundTerm
+
+ Field() NestedField
+}
+
+type boundRef[T LiteralType] struct {
+ field NestedField
+ acc accessor
+}
+
+func createBoundRef(field NestedField, acc accessor) BoundReference {
+ switch field.Type.(type) {
+ case BooleanType:
+ return &boundRef[bool]{field: field, acc: acc}
+ case Int32Type:
+ return &boundRef[int32]{field: field, acc: acc}
+ case Int64Type:
+ return &boundRef[int64]{field: field, acc: acc}
+ case Float32Type:
+ return &boundRef[float32]{field: field, acc: acc}
+ case Float64Type:
+ return &boundRef[float64]{field: field, acc: acc}
+ case DateType:
+ return &boundRef[Date]{field: field, acc: acc}
+ case TimeType:
+ return &boundRef[Time]{field: field, acc: acc}
+ case TimestampType, TimestampTzType:
+ return &boundRef[Timestamp]{field: field, acc: acc}
+ case StringType:
+ return &boundRef[string]{field: field, acc: acc}
+ case FixedType, BinaryType:
+ return &boundRef[[]byte]{field: field, acc: acc}
+ case DecimalType:
+ return &boundRef[Decimal]{field: field, acc: acc}
+ case UUIDType:
+ return &boundRef[uuid.UUID]{field: field, acc: acc}
+ }
+ panic("unhandled bound reference type: " + field.Type.String())
+}
+
+func (*boundRef[T]) isTerm() {}
+
+func (b *boundRef[T]) String() string {
+ return fmt.Sprintf("BoundReference(field=%s, accessor=%s)", b.field,
&b.acc)
+}
+
+func (b *boundRef[T]) Equals(other BoundTerm) bool {
+ rhs, ok := other.(*boundRef[T])
+ if !ok {
+ return false
+ }
+
+ return b.field.Equals(rhs.field)
+}
+
+func (b *boundRef[T]) Ref() BoundReference { return b }
+func (b *boundRef[T]) Field() NestedField { return b.field }
+func (b *boundRef[T]) Type() Type { return b.field.Type }
+
+func (b *boundRef[T]) eval(st structLike) Optional[T] {
+ switch v := b.acc.Get(st).(type) {
+ case nil:
+ return Optional[T]{}
+ case T:
+ return Optional[T]{Valid: true, Val: v}
+ }
+ panic("unexpected type returned for bound ref")
+}
+
+func (b *boundRef[T]) evalToLiteral(st structLike) Literal {
+ v := b.eval(st)
+ lit := NewLiteral[T](v.Val)
+ if !lit.Type().Equals(b.field.Type) {
+ lit, _ = lit.To(b.field.Type)
+ }
+ return lit
+}
+
+func (b *boundRef[T]) evalIsNull(st structLike) bool {
+ v := b.eval(st)
+ return !v.Valid
+}
+
+// UnaryPredicate creates and returns an unbound predicate for the provided
unary operation.
+// Will panic if op is not a unary operation.
+func UnaryPredicate(op Operation, t UnboundTerm) UnboundPredicate {
+ if op < OpIsNull || op > OpNotNan {
+ panic(fmt.Errorf("%w: invalid operation for unary predicate:
%s",
+ ErrInvalidArgument, op))
+ }
+
+ if t == nil {
+ panic(fmt.Errorf("%w: cannot create unary predicate with nil
term",
+ ErrInvalidArgument))
+ }
+
+ return &unboundUnaryPredicate{op: op, term: t}
+}
+
+type unboundUnaryPredicate struct {
+ op Operation
+ term UnboundTerm
+}
+
+func (up *unboundUnaryPredicate) String() string {
+ return fmt.Sprintf("%s(term=%s)", up.op, up.term)
+}
+
+func (up *unboundUnaryPredicate) Equals(other BooleanExpression) bool {
+ rhs, ok := other.(*unboundUnaryPredicate)
+ if !ok {
+ return false
+ }
+
+ return up.op == rhs.op && up.term.Equals(rhs.term)
+}
+
+func (up *unboundUnaryPredicate) Op() Operation { return up.op }
+func (up *unboundUnaryPredicate) Negate() BooleanExpression {
+ return &unboundUnaryPredicate{op: up.op.Negate(), term: up.term}
+}
+
+func (up *unboundUnaryPredicate) Term() UnboundTerm { return up.term }
+func (up *unboundUnaryPredicate) Bind(schema *Schema, caseSensitive bool)
(BooleanExpression, error) {
+ bound, err := up.term.Bind(schema, caseSensitive)
+ if err != nil {
+ return nil, err
+ }
+
+ // fast case optimizations
+ switch up.op {
+ case OpIsNull:
+ if bound.Ref().Field().Required {
+ return AlwaysFalse{}, nil
+ }
+ case OpNotNull:
+ if bound.Ref().Field().Required {
+ return AlwaysTrue{}, nil
+ }
+ case OpIsNan:
+ if !bound.Type().Equals(PrimitiveTypes.Float32) &&
!bound.Type().Equals(PrimitiveTypes.Float64) {
+ return AlwaysFalse{}, nil
+ }
+ case OpNotNan:
+ if !bound.Type().Equals(PrimitiveTypes.Float32) &&
!bound.Type().Equals(PrimitiveTypes.Float64) {
+ return AlwaysTrue{}, nil
+ }
+ }
+
+ return createBoundUnaryPredicate(up.op, bound), nil
+}
+
+// BoundUnaryPredicate is a bound predicate expression that has no arguments
+type BoundUnaryPredicate interface {
+ BoundPredicate
+
+ AsUnbound(Reference) UnboundPredicate
+}
+
+type bound[T LiteralType] interface {
+ BoundTerm
+
+ eval(structLike) Optional[T]
+}
+
+func newBoundUnaryPred[T LiteralType](op Operation, term BoundTerm)
BoundUnaryPredicate {
+ return &boundUnaryPredicate[T]{op: op, term: term.(bound[T])}
+}
+
+func createBoundUnaryPredicate(op Operation, term BoundTerm)
BoundUnaryPredicate {
+ switch term.Type().(type) {
+ case BooleanType:
+ return newBoundUnaryPred[bool](op, term)
+ case Int32Type:
+ return newBoundUnaryPred[int32](op, term)
+ case Int64Type:
+ return newBoundUnaryPred[int64](op, term)
+ case Float32Type:
+ return newBoundUnaryPred[float32](op, term)
+ case Float64Type:
+ return newBoundUnaryPred[float64](op, term)
+ case DateType:
+ return newBoundUnaryPred[Date](op, term)
+ case TimeType:
+ return newBoundUnaryPred[Time](op, term)
+ case TimestampType, TimestampTzType:
+ return newBoundUnaryPred[Timestamp](op, term)
+ case StringType:
+ return newBoundUnaryPred[string](op, term)
+ case FixedType, BinaryType:
+ return newBoundUnaryPred[[]byte](op, term)
+ case DecimalType:
+ return newBoundUnaryPred[Decimal](op, term)
+ case UUIDType:
+ return newBoundUnaryPred[uuid.UUID](op, term)
+ }
+ panic("unhandled bound reference type: " + term.Type().String())
+}
+
+type boundUnaryPredicate[T LiteralType] struct {
+ op Operation
+ term bound[T]
+}
+
+func (bp *boundUnaryPredicate[T]) AsUnbound(r Reference) UnboundPredicate {
+ return &unboundUnaryPredicate{op: bp.op, term: r}
+}
+
+func (bp *boundUnaryPredicate[T]) Equals(other BooleanExpression) bool {
+ rhs, ok := other.(*boundUnaryPredicate[T])
+ if !ok {
+ return false
+ }
+
+ return bp.op == rhs.op && bp.term.Equals(rhs.term)
+}
+
+func (bp *boundUnaryPredicate[T]) Op() Operation { return bp.op }
+func (bp *boundUnaryPredicate[T]) Negate() BooleanExpression {
+ return &boundUnaryPredicate[T]{op: bp.op.Negate(), term: bp.term}
+}
+
+func (bp *boundUnaryPredicate[T]) Term() BoundTerm { return bp.term }
+func (bp *boundUnaryPredicate[T]) Ref() BoundReference { return bp.term.Ref() }
+func (bp *boundUnaryPredicate[T]) String() string {
+ return fmt.Sprintf("Bound%s(term=%s)", bp.op, bp.term)
+}
+
+// LiteralPredicate constructs an unbound predicate for an operation that
requires
+// a single literal argument, such as LessThan or StartsWith.
+//
+// Panics if the operation provided is not a valid Literal operation,
+// if the term is nil or if the literal is nil.
+func LiteralPredicate(op Operation, t UnboundTerm, lit Literal)
UnboundPredicate {
+ switch {
+ case op < OpLT || op > OpNotStartsWith:
+ panic(fmt.Errorf("%w: invalid operation for LiteralPredicate:
%s",
+ ErrInvalidArgument, op))
+ case t == nil:
+ panic(fmt.Errorf("%w: cannot create literal predicate with nil
term",
+ ErrInvalidArgument))
+ case lit == nil:
+ panic(fmt.Errorf("%w: cannot create literal predicate with nil
literal",
+ ErrInvalidArgument))
+ }
+
+ return &unboundLiteralPredicate{op: op, term: t, lit: lit}
+}
+
+type unboundLiteralPredicate struct {
+ op Operation
+ term UnboundTerm
+ lit Literal
+}
+
+func (ul *unboundLiteralPredicate) String() string {
+ return fmt.Sprintf("%s(term=%s, literal=%s)", ul.op, ul.term, ul.lit)
+}
+
+func (ul *unboundLiteralPredicate) Equals(other BooleanExpression) bool {
+ rhs, ok := other.(*unboundLiteralPredicate)
+ if !ok {
+ return false
+ }
+
+ return ul.op == rhs.op && ul.term.Equals(rhs.term) &&
ul.lit.Equals(rhs.lit)
+}
+
+func (ul *unboundLiteralPredicate) Op() Operation { return ul.op }
+func (ul *unboundLiteralPredicate) Negate() BooleanExpression {
+ return &unboundLiteralPredicate{op: ul.op.Negate(), term: ul.term, lit:
ul.lit}
+}
+func (ul *unboundLiteralPredicate) Term() UnboundTerm { return ul.term }
+func (ul *unboundLiteralPredicate) Bind(schema *Schema, caseSensitive bool)
(BooleanExpression, error) {
+ bound, err := ul.term.Bind(schema, caseSensitive)
+ if err != nil {
+ return nil, err
+ }
+
+ if (ul.op == OpStartsWith || ul.op == OpNotStartsWith) &&
+ !bound.Type().Equals(PrimitiveTypes.String) {
+ return nil, fmt.Errorf("%w: StartsWith and NotStartsWith must
bind to String type, not %s",
+ ErrType, bound.Type())
+ }
+
+ lit, err := ul.lit.To(bound.Type())
+ if err != nil {
+ return nil, err
+ }
+
+ switch lit.(type) {
+ case AboveMaxLiteral:
+ switch ul.op {
+ case OpLT, OpLTEQ, OpNEQ:
+ return AlwaysTrue{}, nil
+ case OpGT, OpGTEQ, OpEQ:
+ return AlwaysFalse{}, nil
+ }
+ case BelowMinLiteral:
+ switch ul.op {
+ case OpLT, OpLTEQ, OpEQ:
+ return AlwaysFalse{}, nil
+ case OpGT, OpGTEQ, OpNEQ:
+ return AlwaysTrue{}, nil
+ }
+ }
+
+ return createBoundLiteralPredicate(ul.op, bound, lit)
+}
+
+// BoundLiteralPredicate represents a bound boolean expression that utilizes a
single
+// literal as an argument, such as Equals or StartsWith.
+type BoundLiteralPredicate interface {
+ BoundPredicate
+
+ Literal() Literal
+ AsUnbound(Reference, Literal) UnboundPredicate
+}
+
+func newBoundLiteralPredicate[T LiteralType](op Operation, term BoundTerm, lit
Literal) BoundPredicate {
+ return &boundLiteralPredicate[T]{op: op, term: term.(bound[T]),
+ lit: lit.(TypedLiteral[T])}
+}
+
+func createBoundLiteralPredicate(op Operation, term BoundTerm, lit Literal)
(BoundPredicate, error) {
+ finalLit, err := lit.To(term.Type())
+ if err != nil {
+ return nil, err
+ }
+
+ switch term.Type().(type) {
+ case BooleanType:
+ return newBoundLiteralPredicate[bool](op, term, finalLit), nil
+ case Int32Type:
+ return newBoundLiteralPredicate[int32](op, term, finalLit), nil
+ case Int64Type:
+ return newBoundLiteralPredicate[int64](op, term, finalLit), nil
+ case Float32Type:
+ return newBoundLiteralPredicate[float32](op, term, finalLit),
nil
+ case Float64Type:
+ return newBoundLiteralPredicate[float64](op, term, finalLit),
nil
+ case DateType:
+ return newBoundLiteralPredicate[Date](op, term, finalLit), nil
+ case TimeType:
+ return newBoundLiteralPredicate[Time](op, term, finalLit), nil
+ case TimestampType, TimestampTzType:
+ return newBoundLiteralPredicate[Timestamp](op, term, finalLit),
nil
+ case StringType:
+ return newBoundLiteralPredicate[string](op, term, finalLit), nil
+ case FixedType, BinaryType:
+ return newBoundLiteralPredicate[[]byte](op, term, finalLit), nil
+ case DecimalType:
+ return newBoundLiteralPredicate[Decimal](op, term, finalLit),
nil
+ case UUIDType:
+ return newBoundLiteralPredicate[uuid.UUID](op, term, finalLit),
nil
+ }
+ return nil, fmt.Errorf("%w: could not create bound literal predicate
for term type %s",
+ ErrInvalidArgument, term.Type())
+}
+
+type boundLiteralPredicate[T LiteralType] struct {
+ op Operation
+ term bound[T]
+ lit TypedLiteral[T]
+}
+
+func (blp *boundLiteralPredicate[T]) Equals(other BooleanExpression) bool {
+ rhs, ok := other.(*boundLiteralPredicate[T])
+ if !ok {
+ return false
+ }
+
+ return blp.op == rhs.op && blp.term.Equals(rhs.term) &&
blp.lit.Equals(rhs.lit)
+}
+
+func (blp *boundLiteralPredicate[T]) Op() Operation { return blp.op }
+func (blp *boundLiteralPredicate[T]) Negate() BooleanExpression {
+ return &boundLiteralPredicate[T]{op: blp.op.Negate(), term: blp.term,
lit: blp.lit}
+}
+func (blp *boundLiteralPredicate[T]) Term() BoundTerm { return blp.term }
+func (blp *boundLiteralPredicate[T]) Ref() BoundReference { return
blp.term.Ref() }
+func (blp *boundLiteralPredicate[T]) String() string {
+ return fmt.Sprintf("Bound%s(term=%s, literal=%s)", blp.op, blp.term,
blp.lit)
+}
+func (blp *boundLiteralPredicate[T]) Literal() Literal { return blp.lit }
+func (blp *boundLiteralPredicate[T]) AsUnbound(r Reference, l Literal)
UnboundPredicate {
+ return &unboundLiteralPredicate{op: blp.op, term: r, lit: l}
+}
+
+// SetPredicate creates a boolean expression representing a predicate that
uses a set
+// of literals as the argument, like In or NotIn. Duplicate literals will be
folded
+// into a set, only maintaining the unique literals.
+//
+// Will panic if op is not a valid Set operation
+func SetPredicate(op Operation, t UnboundTerm, lits []Literal)
BooleanExpression {
+ if op < OpIn || op > OpNotIn {
+ panic(fmt.Errorf("%w: invalid operation for SetPredicate: %s",
+ ErrInvalidArgument, op))
+ }
+
+ if t == nil {
+ panic(fmt.Errorf("%w: cannot create set predicate with nil
term",
+ ErrInvalidArgument))
+ }
+
+ switch len(lits) {
+ case 0:
+ if op == OpIn {
+ return AlwaysFalse{}
+ } else if op == OpNotIn {
+ return AlwaysTrue{}
+ }
+ case 1:
+ if op == OpIn {
+ return LiteralPredicate(OpEQ, t, lits[0])
+ } else if op == OpNotIn {
+ return LiteralPredicate(OpNEQ, t, lits[0])
+ }
+ }
+
+ return &unboundSetPredicate{op: op, term: t, lits:
newLiteralSet(lits...)}
+}
+
+type unboundSetPredicate struct {
+ op Operation
+ term UnboundTerm
+ lits Set[Literal]
+}
+
+func (usp *unboundSetPredicate) String() string {
+ return fmt.Sprintf("%s(term=%s, {%v})", usp.op, usp.term,
usp.lits.Members())
+}
+
+func (usp *unboundSetPredicate) Equals(other BooleanExpression) bool {
+ rhs, ok := other.(*unboundSetPredicate)
+ if !ok {
+ return false
+ }
+
+ return usp.op == rhs.op && usp.term.Equals(rhs.term) &&
+ usp.lits.Equals(rhs.lits)
+}
+
+func (usp *unboundSetPredicate) Op() Operation { return usp.op }
+func (usp *unboundSetPredicate) Negate() BooleanExpression {
+ return &unboundSetPredicate{op: usp.op.Negate(), term: usp.term, lits:
usp.lits}
+}
+
+func (usp *unboundSetPredicate) Term() UnboundTerm { return usp.term }
+func (usp *unboundSetPredicate) Bind(schema *Schema, caseSensitive bool)
(BooleanExpression, error) {
+ bound, err := usp.term.Bind(schema, caseSensitive)
+ if err != nil {
+ return nil, err
+ }
+
+ return createBoundSetPredicate(usp.op, bound, usp.lits)
+}
+
+// BoundSetPredicate is a bound expression that utilizes a set of literals
such as In or NotIn
+type BoundSetPredicate interface {
+ BoundPredicate
+
+ Literals() Set[Literal]
+ AsUnbound(Reference, []Literal) UnboundPredicate
+}
+
+func createBoundSetPredicate(op Operation, term BoundTerm, lits Set[Literal])
(BooleanExpression, error) {
+ boundType := term.Type()
+
+ typedSet := newLiteralSet()
+ for _, v := range lits.Members() {
+ casted, err := v.To(boundType)
+ if err != nil {
+ return nil, err
+ }
+ typedSet.Add(casted)
+ }
+
+ switch typedSet.Len() {
+ case 0:
+ if op == OpIn {
+ return AlwaysFalse{}, nil
+ } else if op == OpNotIn {
+ return AlwaysTrue{}, nil
+ }
+ case 1:
+ if op == OpIn {
+ return createBoundLiteralPredicate(OpEQ, term,
typedSet.Members()[0])
+ } else if op == OpNotIn {
+ return createBoundLiteralPredicate(OpNEQ, term,
typedSet.Members()[0])
+ }
+ }
+
+ switch term.Type().(type) {
+ case BooleanType:
+ return newBoundSetPredicate[bool](op, term, typedSet), nil
+ case Int32Type:
+ return newBoundSetPredicate[int32](op, term, typedSet), nil
+ case Int64Type:
+ return newBoundSetPredicate[int64](op, term, typedSet), nil
+ case Float32Type:
+ return newBoundSetPredicate[float32](op, term, typedSet), nil
+ case Float64Type:
+ return newBoundSetPredicate[float64](op, term, typedSet), nil
+ case DateType:
+ return newBoundSetPredicate[Date](op, term, typedSet), nil
+ case TimeType:
+ return newBoundSetPredicate[Time](op, term, typedSet), nil
+ case TimestampType, TimestampTzType:
+ return newBoundSetPredicate[Timestamp](op, term, typedSet), nil
+ case StringType:
+ return newBoundSetPredicate[string](op, term, typedSet), nil
+ case BinaryType, FixedType:
+ return newBoundSetPredicate[[]byte](op, term, typedSet), nil
+ case DecimalType:
+ return newBoundSetPredicate[Decimal](op, term, typedSet), nil
+ case UUIDType:
+ return newBoundSetPredicate[uuid.UUID](op, term, typedSet), nil
+ }
+
+ return nil, fmt.Errorf("%w: invalid bound type for set predicate - %s",
+ ErrType, term.Type())
+}
+
+func newBoundSetPredicate[T LiteralType](op Operation, term BoundTerm, lits
Set[Literal]) *boundSetPredicate[T] {
+ return &boundSetPredicate[T]{op: op, term: term.(bound[T]), lits: lits}
+}
+
+type boundSetPredicate[T LiteralType] struct {
+ op Operation
+ term bound[T]
+ lits Set[Literal]
+}
+
+func (bsp *boundSetPredicate[T]) Equals(other BooleanExpression) bool {
+ rhs, ok := other.(*boundSetPredicate[T])
+ if !ok {
+ return false
+ }
+
+ return bsp.op == rhs.op && bsp.term.Equals(rhs.term) &&
+ bsp.lits.Equals(rhs.lits)
+}
+
+func (bsp *boundSetPredicate[T]) Op() Operation { return bsp.op }
+func (bsp *boundSetPredicate[T]) Negate() BooleanExpression {
+ return &boundSetPredicate[T]{op: bsp.op.Negate(), term: bsp.term,
+ lits: bsp.lits}
+}
+func (bsp *boundSetPredicate[T]) Term() BoundTerm { return bsp.term }
+func (bsp *boundSetPredicate[T]) Ref() BoundReference { return bsp.term.Ref() }
+func (bsp *boundSetPredicate[T]) String() string {
+ return fmt.Sprintf("Bound%s(term=%s, {%v})", bsp.op, bsp.term,
bsp.lits.Members())
+}
+func (bsp *boundSetPredicate[T]) AsUnbound(r Reference, lits []Literal)
UnboundPredicate {
+ return &unboundSetPredicate{op: bsp.op, term: r, lits:
newLiteralSet(lits...)}
+}
+func (bsp *boundSetPredicate[T]) Literals() Set[Literal] {
+ return bsp.lits
+}
diff --git a/exprs_test.go b/exprs_test.go
new file mode 100644
index 0000000..3ea5257
--- /dev/null
+++ b/exprs_test.go
@@ -0,0 +1,742 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package iceberg_test
+
+import (
+ "math"
+ "strconv"
+ "testing"
+
+ "github.com/apache/iceberg-go"
+ "github.com/google/uuid"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+type ExprA struct{}
+
+func (ExprA) String() string { return "ExprA" }
+func (ExprA) Op() iceberg.Operation { return iceberg.OpFalse }
+func (ExprA) Negate() iceberg.BooleanExpression { return ExprB{} }
+func (ExprA) Equals(o iceberg.BooleanExpression) bool {
+ _, ok := o.(ExprA)
+ return ok
+}
+
+type ExprB struct{}
+
+func (ExprB) String() string { return "ExprB" }
+func (ExprB) Op() iceberg.Operation { return iceberg.OpTrue }
+func (ExprB) Negate() iceberg.BooleanExpression { return ExprA{} }
+func (ExprB) Equals(o iceberg.BooleanExpression) bool {
+ _, ok := o.(ExprB)
+ return ok
+}
+
+func TestUnaryExpr(t *testing.T) {
+ assert.PanicsWithError(t, "invalid argument: invalid operation for
unary predicate: LessThan", func() {
+ iceberg.UnaryPredicate(iceberg.OpLT, iceberg.Reference("a"))
+ })
+
+ assert.PanicsWithError(t, "invalid argument: cannot create unary
predicate with nil term", func() {
+ iceberg.UnaryPredicate(iceberg.OpIsNull, nil)
+ })
+
+ t.Run("negate", func(t *testing.T) {
+ n := iceberg.IsNull(iceberg.Reference("a")).Negate()
+ exp := iceberg.NotNull(iceberg.Reference("a"))
+
+ assert.Equal(t, exp, n)
+ assert.True(t, exp.Equals(n))
+ assert.True(t, n.Equals(exp))
+ })
+
+ sc := iceberg.NewSchema(1, iceberg.NestedField{
+ ID: 2, Name: "a", Type: iceberg.PrimitiveTypes.Int32})
+ sc2 := iceberg.NewSchema(1, iceberg.NestedField{
+ ID: 2, Name: "a", Type: iceberg.PrimitiveTypes.Float64})
+ sc3 := iceberg.NewSchema(1, iceberg.NestedField{
+ ID: 2, Name: "a", Type: iceberg.PrimitiveTypes.Int32, Required:
true})
+ sc4 := iceberg.NewSchema(1, iceberg.NestedField{
+ ID: 2, Name: "a", Type: iceberg.PrimitiveTypes.Float32,
Required: true})
+
+ t.Run("isnull and notnull", func(t *testing.T) {
+ t.Run("bind", func(t *testing.T) {
+ n, err :=
iceberg.IsNull(iceberg.Reference("a")).Bind(sc, true)
+ require.NoError(t, err)
+
+ assert.Equal(t, iceberg.OpIsNull, n.Op())
+ assert.Implements(t,
(*iceberg.BoundUnaryPredicate)(nil), n)
+ p := n.(iceberg.BoundUnaryPredicate)
+ assert.IsType(t, iceberg.PrimitiveTypes.Int32,
p.Term().Type())
+ assert.Same(t, p.Ref(), p.Term().Ref())
+ assert.Same(t, p.Ref(), p.Ref().Ref())
+
+ f := p.Ref().Field()
+ assert.True(t, f.Equals(sc.Field(0)))
+ })
+
+ t.Run("negate and bind", func(t *testing.T) {
+ n1, err :=
iceberg.IsNull(iceberg.Reference("a")).Bind(sc, true)
+ require.NoError(t, err)
+
+ n2, err :=
iceberg.NotNull(iceberg.Reference("a")).Bind(sc, true)
+ require.NoError(t, err)
+
+ assert.True(t, n1.Negate().Equals(n2))
+ assert.True(t, n2.Negate().Equals(n1))
+ })
+
+ t.Run("null bind required", func(t *testing.T) {
+ n1, err :=
iceberg.IsNull(iceberg.Reference("a")).Bind(sc3, true)
+ require.NoError(t, err)
+
+ n2, err :=
iceberg.NotNull(iceberg.Reference("a")).Bind(sc3, true)
+ require.NoError(t, err)
+
+ assert.True(t, n1.Equals(iceberg.AlwaysFalse{}))
+ assert.True(t, n2.Equals(iceberg.AlwaysTrue{}))
+ })
+ })
+
+ t.Run("isnan notnan", func(t *testing.T) {
+ t.Run("negate and bind", func(t *testing.T) {
+ n1, err :=
iceberg.IsNaN(iceberg.Reference("a")).Bind(sc2, true)
+ require.NoError(t, err)
+
+ n2, err :=
iceberg.NotNaN(iceberg.Reference("a")).Bind(sc2, true)
+ require.NoError(t, err)
+
+ assert.True(t, n1.Negate().Equals(n2))
+ assert.True(t, n2.Negate().Equals(n1))
+ })
+
+ t.Run("bind float", func(t *testing.T) {
+ n, err :=
iceberg.IsNaN(iceberg.Reference("a")).Bind(sc4, true)
+ require.NoError(t, err)
+
+ assert.Equal(t, iceberg.OpIsNan, n.Op())
+ assert.Implements(t,
(*iceberg.BoundUnaryPredicate)(nil), n)
+ p := n.(iceberg.BoundUnaryPredicate)
+ assert.IsType(t, iceberg.PrimitiveTypes.Float32,
p.Term().Type())
+
+ n2, err :=
iceberg.NotNaN(iceberg.Reference("a")).Bind(sc4, true)
+ require.NoError(t, err)
+
+ assert.Equal(t, iceberg.OpNotNan, n2.Op())
+ assert.Implements(t,
(*iceberg.BoundUnaryPredicate)(nil), n2)
+ p2 := n2.(iceberg.BoundUnaryPredicate)
+ assert.IsType(t, iceberg.PrimitiveTypes.Float32,
p2.Term().Type())
+ })
+
+ t.Run("bind double", func(t *testing.T) {
+ n, err :=
iceberg.IsNaN(iceberg.Reference("a")).Bind(sc2, true)
+ require.NoError(t, err)
+
+ assert.Equal(t, iceberg.OpIsNan, n.Op())
+ assert.Implements(t,
(*iceberg.BoundUnaryPredicate)(nil), n)
+ p := n.(iceberg.BoundUnaryPredicate)
+ assert.IsType(t, iceberg.PrimitiveTypes.Float64,
p.Term().Type())
+
+ n2, err :=
iceberg.NotNaN(iceberg.Reference("a")).Bind(sc2, true)
+ require.NoError(t, err)
+
+ assert.Equal(t, iceberg.OpNotNan, n2.Op())
+ assert.Implements(t,
(*iceberg.BoundUnaryPredicate)(nil), n2)
+ p2 := n2.(iceberg.BoundUnaryPredicate)
+ assert.IsType(t, iceberg.PrimitiveTypes.Float64,
p2.Term().Type())
+ })
+
+ t.Run("bind non floating", func(t *testing.T) {
+ n1, err :=
iceberg.IsNaN(iceberg.Reference("a")).Bind(sc, true)
+ require.NoError(t, err)
+
+ n2, err :=
iceberg.NotNaN(iceberg.Reference("a")).Bind(sc, true)
+ require.NoError(t, err)
+
+ assert.True(t, n1.Equals(iceberg.AlwaysFalse{}))
+ assert.True(t, n2.Equals(iceberg.AlwaysTrue{}))
+ })
+ })
+}
+
+func TestRefBindingCaseSensitive(t *testing.T) {
+ ref1, ref2 := iceberg.Reference("foo"), iceberg.Reference("Foo")
+
+ bound1, err := ref1.Bind(tableSchemaSimple, true)
+ require.NoError(t, err)
+ assert.True(t, bound1.Type().Equals(iceberg.PrimitiveTypes.String))
+
+ _, err = ref2.Bind(tableSchemaSimple, true)
+ assert.ErrorIs(t, err, iceberg.ErrInvalidSchema)
+ assert.ErrorContains(t, err, "could not bind reference 'Foo',
caseSensitive=true")
+
+ bound2, err := ref2.Bind(tableSchemaSimple, false)
+ require.NoError(t, err)
+ assert.True(t, bound1.Equals(bound2))
+
+ _, err = iceberg.Reference("foot").Bind(tableSchemaSimple, false)
+ assert.ErrorIs(t, err, iceberg.ErrInvalidSchema)
+ assert.ErrorContains(t, err, "could not bind reference 'foot',
caseSensitive=false")
+}
+
+func TestRefTypes(t *testing.T) {
+ sc := iceberg.NewSchema(1,
+ iceberg.NestedField{ID: 1, Name: "a", Type:
iceberg.PrimitiveTypes.Bool},
+ iceberg.NestedField{ID: 2, Name: "b", Type:
iceberg.PrimitiveTypes.Int32},
+ iceberg.NestedField{ID: 3, Name: "c", Type:
iceberg.PrimitiveTypes.Int64},
+ iceberg.NestedField{ID: 4, Name: "d", Type:
iceberg.PrimitiveTypes.Float32},
+ iceberg.NestedField{ID: 5, Name: "e", Type:
iceberg.PrimitiveTypes.Float64},
+ iceberg.NestedField{ID: 6, Name: "f", Type:
iceberg.PrimitiveTypes.Date},
+ iceberg.NestedField{ID: 7, Name: "g", Type:
iceberg.PrimitiveTypes.Time},
+ iceberg.NestedField{ID: 8, Name: "h", Type:
iceberg.PrimitiveTypes.Timestamp},
+ iceberg.NestedField{ID: 9, Name: "i", Type:
iceberg.DecimalTypeOf(9, 2)},
+ iceberg.NestedField{ID: 10, Name: "j", Type:
iceberg.PrimitiveTypes.String},
+ iceberg.NestedField{ID: 11, Name: "k", Type:
iceberg.PrimitiveTypes.Binary},
+ iceberg.NestedField{ID: 12, Name: "l", Type:
iceberg.PrimitiveTypes.UUID},
+ iceberg.NestedField{ID: 13, Name: "m", Type:
iceberg.FixedTypeOf(5)})
+
+ t.Run("bind term", func(t *testing.T) {
+ for i := 0; i < sc.NumFields(); i++ {
+ fld := sc.Field(i)
+ t.Run(fld.Type.String(), func(t *testing.T) {
+ ref, err :=
iceberg.Reference(fld.Name).Bind(sc, true)
+ require.NoError(t, err)
+
+ assert.True(t, ref.Type().Equals(fld.Type))
+ assert.True(t, fld.Equals(ref.Ref().Field()))
+ })
+ }
+ })
+
+ t.Run("bind unary", func(t *testing.T) {
+ for i := 0; i < sc.NumFields(); i++ {
+ fld := sc.Field(i)
+ t.Run(fld.Type.String(), func(t *testing.T) {
+ b, err :=
iceberg.IsNull(iceberg.Reference(fld.Name)).Bind(sc, true)
+ require.NoError(t, err)
+
+ assert.True(t,
b.(iceberg.BoundUnaryPredicate).Ref().Type().Equals(fld.Type))
+
+ un :=
b.(iceberg.BoundUnaryPredicate).AsUnbound(iceberg.Reference("foo"))
+ assert.Equal(t, b.Op(), un.Op())
+ })
+ }
+ })
+
+ t.Run("bind literal", func(t *testing.T) {
+ t.Run("bool", func(t *testing.T) {
+ b1, err := iceberg.EqualTo(iceberg.Reference("a"),
true).Bind(sc, true)
+ require.NoError(t, err)
+ assert.Equal(t, iceberg.OpEQ, b1.Op())
+ assert.True(t,
b1.(iceberg.BoundLiteralPredicate).Ref().Type().Equals(iceberg.PrimitiveTypes.Bool))
+ })
+
+ for i := 1; i < 9; i++ {
+ fld := sc.Field(i)
+ t.Run(fld.Type.String(), func(t *testing.T) {
+ b, err :=
iceberg.EqualTo(iceberg.Reference(fld.Name), int32(5)).Bind(sc, true)
+ require.NoError(t, err)
+
+ assert.Equal(t, iceberg.OpEQ, b.Op())
+ assert.True(t,
b.(iceberg.BoundLiteralPredicate).Literal().Type().Equals(fld.Type))
+ assert.True(t,
b.(iceberg.BoundLiteralPredicate).Ref().Type().Equals(fld.Type))
+ })
+ }
+
+ t.Run("string-binary", func(t *testing.T) {
+ str, err := iceberg.EqualTo(iceberg.Reference("j"),
"foobar").Bind(sc, true)
+ require.NoError(t, err)
+
+ bin, err := iceberg.EqualTo(iceberg.Reference("k"),
[]byte("foobar")).Bind(sc, true)
+ require.NoError(t, err)
+
+ assert.Equal(t, iceberg.OpEQ, str.Op())
+ assert.True(t,
str.(iceberg.BoundLiteralPredicate).Literal().Type().Equals(iceberg.PrimitiveTypes.String))
+ assert.Equal(t, iceberg.OpEQ, bin.Op())
+ assert.True(t,
bin.(iceberg.BoundLiteralPredicate).Literal().Type().Equals(iceberg.PrimitiveTypes.Binary))
+ })
+
+ t.Run("fixed", func(t *testing.T) {
+ fx, err := iceberg.EqualTo(iceberg.Reference("m"),
[]byte{0, 1, 2, 3, 4}).Bind(sc, true)
+ require.NoError(t, err)
+
+ assert.Equal(t, iceberg.OpEQ, fx.Op())
+ assert.True(t,
fx.(iceberg.BoundLiteralPredicate).Literal().Type().Equals(iceberg.FixedTypeOf(5)))
+ })
+
+ t.Run("uuid", func(t *testing.T) {
+ uid, err := iceberg.EqualTo(iceberg.Reference("l"),
uuid.New().String()).Bind(sc, true)
+ require.NoError(t, err)
+
+ assert.Equal(t, iceberg.OpEQ, uid.Op())
+ assert.True(t,
uid.(iceberg.BoundLiteralPredicate).Literal().Type().Equals(iceberg.PrimitiveTypes.UUID))
+ })
+ })
+
+ t.Run("bind set", func(t *testing.T) {
+ t.Run("bool", func(t *testing.T) {
+ b, err := iceberg.IsIn(iceberg.Reference("a"), true,
false).(iceberg.UnboundPredicate).Bind(sc, true)
+ require.NoError(t, err)
+
+ assert.Equal(t, iceberg.OpIn, b.Op())
+ })
+
+ for i := 1; i < 9; i++ {
+ fld := sc.Field(i)
+ t.Run(fld.Type.String(), func(t *testing.T) {
+ b, err :=
iceberg.IsIn(iceberg.Reference(fld.Name), int32(10), int32(5),
int32(5)).(iceberg.UnboundPredicate).
+ Bind(sc, true)
+ require.NoError(t, err)
+
+ assert.Equal(t, iceberg.OpIn, b.Op())
+ assert.True(t,
b.(iceberg.BoundSetPredicate).Ref().Type().Equals(fld.Type))
+ for _, v := range
b.(iceberg.BoundSetPredicate).Literals().Members() {
+ assert.True(t,
v.Type().Equals(fld.Type))
+ }
+ })
+ }
+
+ t.Run("string-binary", func(t *testing.T) {
+ str, err := iceberg.IsIn(iceberg.Reference("j"),
"hello", "foobar").(iceberg.UnboundPredicate).
+ Bind(sc, true)
+ require.NoError(t, err)
+
+ bin, err := iceberg.IsIn(iceberg.Reference("k"),
[]byte("baz"), []byte("foobar")).(iceberg.UnboundPredicate).
+ Bind(sc, true)
+ require.NoError(t, err)
+
+ assert.Equal(t, iceberg.OpIn, str.Op())
+ assert.Equal(t, iceberg.OpIn, bin.Op())
+
+ assert.True(t,
str.(iceberg.BoundSetPredicate).Ref().Type().Equals(iceberg.PrimitiveTypes.String))
+ for _, v := range
str.(iceberg.BoundSetPredicate).Literals().Members() {
+ assert.True(t,
v.Type().Equals(iceberg.PrimitiveTypes.String))
+ }
+
+ assert.True(t,
bin.(iceberg.BoundSetPredicate).Ref().Type().Equals(iceberg.PrimitiveTypes.Binary))
+ for _, v := range
bin.(iceberg.BoundSetPredicate).Literals().Members() {
+ assert.True(t,
v.Type().Equals(iceberg.PrimitiveTypes.Binary))
+ }
+ })
+
+ t.Run("fixed", func(t *testing.T) {
+ fx, err := iceberg.IsIn(iceberg.Reference("m"),
[]byte{4, 5, 6, 7, 8}, []byte{0, 1, 2, 3, 4}).(iceberg.UnboundPredicate).
+ Bind(sc, true)
+ require.NoError(t, err)
+
+ assert.Equal(t, iceberg.OpIn, fx.Op())
+ assert.True(t,
fx.(iceberg.BoundSetPredicate).Ref().Type().Equals(iceberg.FixedTypeOf(5)))
+ for _, v := range
fx.(iceberg.BoundSetPredicate).Literals().Members() {
+ assert.True(t,
v.Type().Equals(iceberg.FixedTypeOf(5)))
+ }
+ })
+
+ t.Run("uuid", func(t *testing.T) {
+ uid, err := iceberg.IsIn(iceberg.Reference("l"),
uuid.New().String(), uuid.New().String()).(iceberg.UnboundPredicate).
+ Bind(sc, true)
+ require.NoError(t, err)
+
+ assert.Equal(t, iceberg.OpIn, uid.Op())
+ assert.True(t,
uid.(iceberg.BoundSetPredicate).Ref().Type().Equals(iceberg.PrimitiveTypes.UUID))
+ for _, v := range
uid.(iceberg.BoundSetPredicate).Literals().Members() {
+ assert.True(t,
v.Type().Equals(iceberg.PrimitiveTypes.UUID))
+ }
+ })
+ })
+}
+
+func TestInNotInSimplifications(t *testing.T) {
+ assert.PanicsWithError(t, "invalid argument: invalid operation for
SetPredicate: LessThan",
+ func() { iceberg.SetPredicate(iceberg.OpLT,
iceberg.Reference("x"), nil) })
+ assert.PanicsWithError(t, "invalid argument: cannot create set
predicate with nil term",
+ func() { iceberg.SetPredicate(iceberg.OpIn, nil, nil) })
+ assert.NotPanics(t, func() { iceberg.SetPredicate(iceberg.OpIn,
iceberg.Reference("x"), nil) })
+
+ t.Run("in to eq", func(t *testing.T) {
+ a := iceberg.IsIn(iceberg.Reference("x"), 34.56)
+ b := iceberg.EqualTo(iceberg.Reference("x"), 34.56)
+ assert.True(t, a.Equals(b))
+ })
+
+ t.Run("notin to notequal", func(t *testing.T) {
+ a := iceberg.NotIn(iceberg.Reference("x"), 34.56)
+ b := iceberg.NotEqualTo(iceberg.Reference("x"), 34.56)
+ assert.True(t, a.Equals(b))
+ })
+
+ t.Run("empty", func(t *testing.T) {
+ a := iceberg.IsIn[float32](iceberg.Reference("x"))
+ b := iceberg.NotIn[float32](iceberg.Reference("x"))
+
+ assert.Equal(t, iceberg.AlwaysFalse{}, a)
+ assert.Equal(t, iceberg.AlwaysTrue{}, b)
+ })
+
+ t.Run("bind and negate", func(t *testing.T) {
+ inexp := iceberg.IsIn(iceberg.Reference("foo"), "hello",
"world")
+ notin := iceberg.NotIn(iceberg.Reference("foo"), "hello",
"world")
+ assert.True(t, inexp.Negate().Equals(notin))
+ assert.True(t, notin.Negate().Equals(inexp))
+ assert.Equal(t, iceberg.OpIn, inexp.Op())
+ assert.Equal(t, iceberg.OpNotIn, notin.Op())
+
+ boundin, err :=
inexp.(iceberg.UnboundPredicate).Bind(tableSchemaSimple, true)
+ require.NoError(t, err)
+
+ boundnot, err :=
notin.(iceberg.UnboundPredicate).Bind(tableSchemaSimple, true)
+ require.NoError(t, err)
+
+ assert.True(t, boundin.Negate().Equals(boundnot))
+ assert.True(t, boundnot.Negate().Equals(boundin))
+ })
+
+ t.Run("bind dedup", func(t *testing.T) {
+ isin := iceberg.IsIn(iceberg.Reference("foo"), "hello",
"world", "world")
+ bound, err :=
isin.(iceberg.UnboundPredicate).Bind(tableSchemaSimple, true)
+ require.NoError(t, err)
+
+ assert.Implements(t, (*iceberg.BoundSetPredicate)(nil), bound)
+ bsp := bound.(iceberg.BoundSetPredicate)
+ assert.Equal(t, 2, bsp.Literals().Len())
+ assert.True(t,
bsp.Literals().Contains(iceberg.NewLiteral("hello")))
+ assert.True(t,
bsp.Literals().Contains(iceberg.NewLiteral("world")))
+ })
+
+ t.Run("bind dedup to eq", func(t *testing.T) {
+ isin := iceberg.IsIn(iceberg.Reference("foo"), "world", "world")
+ bound, err :=
isin.(iceberg.UnboundPredicate).Bind(tableSchemaSimple, true)
+ require.NoError(t, err)
+
+ assert.Equal(t, iceberg.OpEQ, bound.Op())
+ assert.Equal(t, iceberg.NewLiteral("world"),
+ bound.(iceberg.BoundLiteralPredicate).Literal())
+ })
+}
+
+func TestLiteralPredicateErrors(t *testing.T) {
+ assert.PanicsWithError(t, "invalid argument: invalid operation for
LiteralPredicate: In",
+ func() { iceberg.LiteralPredicate(iceberg.OpIn,
iceberg.Reference("foo"), iceberg.NewLiteral("hello")) })
+ assert.PanicsWithError(t, "invalid argument: cannot create literal
predicate with nil term",
+ func() { iceberg.LiteralPredicate(iceberg.OpLT, nil,
iceberg.NewLiteral("hello")) })
+ assert.PanicsWithError(t, "invalid argument: cannot create literal
predicate with nil literal",
+ func() { iceberg.LiteralPredicate(iceberg.OpLT,
iceberg.Reference("foo"), nil) })
+}
+
+func TestNegations(t *testing.T) {
+ ref := iceberg.Reference("foo")
+
+ tests := []struct {
+ name string
+ ex1, ex2 iceberg.UnboundPredicate
+ }{
+ {"equal-not", iceberg.EqualTo(ref, "hello"),
iceberg.NotEqualTo(ref, "hello")},
+ {"greater-equal-less", iceberg.GreaterThanEqual(ref, "hello"),
iceberg.LessThan(ref, "hello")},
+ {"greater-less-equal", iceberg.GreaterThan(ref, "hello"),
iceberg.LessThanEqual(ref, "hello")},
+ {"starts-with", iceberg.StartsWith(ref, "hello"),
iceberg.NotStartsWith(ref, "hello")},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ assert.False(t, tt.ex1.Equals(tt.ex2))
+ assert.False(t, tt.ex2.Equals(tt.ex1))
+ assert.True(t, tt.ex1.Negate().Equals(tt.ex2))
+ assert.True(t, tt.ex2.Negate().Equals(tt.ex1))
+
+ b1, err := tt.ex1.Bind(tableSchemaSimple, true)
+ require.NoError(t, err)
+ b2, err := tt.ex2.Bind(tableSchemaSimple, true)
+ require.NoError(t, err)
+
+ assert.False(t, b1.Equals(b2))
+ assert.False(t, b2.Equals(b1))
+ assert.True(t, b1.Negate().Equals(b2))
+ assert.True(t, b2.Negate().Equals(b1))
+ })
+ }
+}
+
+func TestBoolExprEQ(t *testing.T) {
+ tests := []struct {
+ exp, testexpra, testexprb iceberg.BooleanExpression
+ }{
+ {iceberg.NewAnd(ExprA{}, ExprB{}),
+ iceberg.NewAnd(ExprA{}, ExprB{}),
+ iceberg.NewOr(ExprA{}, ExprB{})},
+ {iceberg.NewOr(ExprA{}, ExprB{}),
+ iceberg.NewOr(ExprA{}, ExprB{}),
+ iceberg.NewAnd(ExprA{}, ExprB{})},
+ {iceberg.NewAnd(ExprA{}, ExprB{}),
+ iceberg.NewAnd(ExprB{}, ExprA{}),
+ iceberg.NewOr(ExprB{}, ExprA{})},
+ {iceberg.NewOr(ExprA{}, ExprB{}),
+ iceberg.NewOr(ExprB{}, ExprA{}),
+ iceberg.NewAnd(ExprB{}, ExprA{})},
+ {iceberg.NewNot(ExprA{}), iceberg.NewNot(ExprA{}), ExprB{}},
+ {ExprA{}, ExprA{}, ExprB{}},
+ {ExprB{}, ExprB{}, ExprA{}},
+ {iceberg.IsIn(iceberg.Reference("foo"), "hello", "world"),
+ iceberg.IsIn(iceberg.Reference("foo"), "hello",
"world"),
+ iceberg.IsIn(iceberg.Reference("not_foo"), "hello",
"world")},
+ {iceberg.IsIn(iceberg.Reference("foo"), "hello", "world"),
+ iceberg.IsIn(iceberg.Reference("foo"), "hello",
"world"),
+ iceberg.IsIn(iceberg.Reference("foo"), "goodbye",
"world")},
+ }
+
+ for i, tt := range tests {
+ t.Run(strconv.Itoa(i), func(t *testing.T) {
+ assert.True(t, tt.exp.Equals(tt.testexpra))
+ assert.False(t, tt.exp.Equals(tt.testexprb))
+ })
+ }
+}
+
+func TestBoolExprNegate(t *testing.T) {
+ tests := []struct {
+ lhs, rhs iceberg.BooleanExpression
+ }{
+ {iceberg.NewAnd(ExprA{}, ExprB{}), iceberg.NewOr(ExprB{},
ExprA{})},
+ {iceberg.NewOr(ExprB{}, ExprA{}), iceberg.NewAnd(ExprA{},
ExprB{})},
+ {iceberg.NewNot(ExprA{}), ExprA{}},
+ {iceberg.IsIn(iceberg.Reference("foo"), "hello", "world"),
+ iceberg.NotIn(iceberg.Reference("foo"), "hello",
"world")},
+ {iceberg.NotIn(iceberg.Reference("foo"), "hello", "world"),
+ iceberg.IsIn(iceberg.Reference("foo"), "hello",
"world")},
+ {iceberg.GreaterThan(iceberg.Reference("foo"), int32(5)),
+ iceberg.LessThanEqual(iceberg.Reference("foo"),
int32(5))},
+ {iceberg.LessThan(iceberg.Reference("foo"), int32(5)),
+ iceberg.GreaterThanEqual(iceberg.Reference("foo"),
int32(5))},
+ {iceberg.EqualTo(iceberg.Reference("foo"), int32(5)),
+ iceberg.NotEqualTo(iceberg.Reference("foo"), int32(5))},
+ {ExprA{}, ExprB{}},
+ }
+
+ for _, tt := range tests {
+ assert.True(t, tt.lhs.Negate().Equals(tt.rhs))
+ }
+}
+
+func TestBoolExprPanics(t *testing.T) {
+ assert.PanicsWithError(t, "invalid argument: cannot construct AndExpr
with nil arguments",
+ func() { iceberg.NewAnd(nil, ExprA{}) })
+ assert.PanicsWithError(t, "invalid argument: cannot construct AndExpr
with nil arguments",
+ func() { iceberg.NewAnd(ExprA{}, nil) })
+ assert.PanicsWithError(t, "invalid argument: cannot construct AndExpr
with nil arguments",
+ func() { iceberg.NewAnd(ExprA{}, ExprA{}, nil) })
+
+ assert.PanicsWithError(t, "invalid argument: cannot construct OrExpr
with nil arguments",
+ func() { iceberg.NewOr(nil, ExprA{}) })
+ assert.PanicsWithError(t, "invalid argument: cannot construct OrExpr
with nil arguments",
+ func() { iceberg.NewOr(ExprA{}, nil) })
+ assert.PanicsWithError(t, "invalid argument: cannot construct OrExpr
with nil arguments",
+ func() { iceberg.NewOr(ExprA{}, ExprA{}, nil) })
+
+ assert.PanicsWithError(t, "invalid argument: cannot create NotExpr with
nil child",
+ func() { iceberg.NewNot(nil) })
+}
+
+func TestExprFolding(t *testing.T) {
+ tests := []struct {
+ lhs, rhs iceberg.BooleanExpression
+ }{
+ {iceberg.NewAnd(ExprA{}, ExprB{}, ExprA{}),
+ iceberg.NewAnd(iceberg.NewAnd(ExprA{}, ExprB{}),
ExprA{})},
+ {iceberg.NewOr(ExprA{}, ExprB{}, ExprA{}),
+ iceberg.NewOr(iceberg.NewOr(ExprA{}, ExprB{}),
ExprA{})},
+ {iceberg.NewNot(iceberg.NewNot(ExprA{})), ExprA{}},
+ }
+
+ for _, tt := range tests {
+ assert.True(t, tt.lhs.Equals(tt.rhs))
+ }
+}
+
+func TestBaseAlwaysTrueAlwaysFalse(t *testing.T) {
+ tests := []struct {
+ lhs, rhs iceberg.BooleanExpression
+ }{
+ {iceberg.NewAnd(iceberg.AlwaysTrue{}, ExprB{}), ExprB{}},
+ {iceberg.NewAnd(iceberg.AlwaysFalse{}, ExprB{}),
iceberg.AlwaysFalse{}},
+ {iceberg.NewAnd(ExprB{}, iceberg.AlwaysTrue{}), ExprB{}},
+ {iceberg.NewOr(iceberg.AlwaysTrue{}, ExprB{}),
iceberg.AlwaysTrue{}},
+ {iceberg.NewOr(iceberg.AlwaysFalse{}, ExprB{}), ExprB{}},
+ {iceberg.NewOr(ExprA{}, iceberg.AlwaysFalse{}), ExprA{}},
+ {iceberg.NewNot(iceberg.NewNot(ExprA{})), ExprA{}},
+ {iceberg.NewNot(iceberg.AlwaysTrue{}), iceberg.AlwaysFalse{}},
+ {iceberg.NewNot(iceberg.AlwaysFalse{}), iceberg.AlwaysTrue{}},
+ }
+
+ for _, tt := range tests {
+ assert.True(t, tt.lhs.Equals(tt.rhs))
+ }
+}
+
+func TestNegateAlways(t *testing.T) {
+ assert.Equal(t, iceberg.OpTrue, iceberg.AlwaysTrue{}.Op())
+ assert.Equal(t, iceberg.OpFalse, iceberg.AlwaysFalse{}.Op())
+
+ assert.Equal(t, iceberg.AlwaysTrue{}, iceberg.AlwaysFalse{}.Negate())
+ assert.Equal(t, iceberg.AlwaysFalse{}, iceberg.AlwaysTrue{}.Negate())
+}
+
+func TestBoundReferenceToString(t *testing.T) {
+ ref, err := iceberg.Reference("foo").Bind(tableSchemaSimple, true)
+ require.NoError(t, err)
+
+ assert.Equal(t, "BoundReference(field=1: foo: optional string,
accessor=Accessor(position=0, inner=<nil>))",
+ ref.String())
+}
+
+func TestToString(t *testing.T) {
+ schema := iceberg.NewSchema(1,
+ iceberg.NestedField{ID: 1, Name: "a", Type:
iceberg.PrimitiveTypes.String},
+ iceberg.NestedField{ID: 2, Name: "b", Type:
iceberg.PrimitiveTypes.String},
+ iceberg.NestedField{ID: 3, Name: "c", Type:
iceberg.PrimitiveTypes.String},
+ iceberg.NestedField{ID: 4, Name: "d", Type:
iceberg.PrimitiveTypes.Int32},
+ iceberg.NestedField{ID: 5, Name: "e", Type:
iceberg.PrimitiveTypes.Int32},
+ iceberg.NestedField{ID: 6, Name: "f", Type:
iceberg.PrimitiveTypes.Int32},
+ iceberg.NestedField{ID: 7, Name: "g", Type:
iceberg.PrimitiveTypes.Float32},
+ iceberg.NestedField{ID: 8, Name: "h", Type:
iceberg.DecimalTypeOf(8, 4)},
+ iceberg.NestedField{ID: 9, Name: "i", Type:
iceberg.PrimitiveTypes.UUID},
+ iceberg.NestedField{ID: 10, Name: "j", Type:
iceberg.PrimitiveTypes.Bool},
+ iceberg.NestedField{ID: 11, Name: "k", Type:
iceberg.PrimitiveTypes.Bool},
+ iceberg.NestedField{ID: 12, Name: "l", Type:
iceberg.PrimitiveTypes.Binary})
+
+ null := iceberg.IsNull(iceberg.Reference("a"))
+ nan := iceberg.IsNaN(iceberg.Reference("g"))
+ boundNull, _ := null.Bind(schema, true)
+ boundNan, _ := nan.Bind(schema, true)
+
+ equal := iceberg.EqualTo(iceberg.Reference("c"), "a")
+ grtequal := iceberg.GreaterThanEqual(iceberg.Reference("a"), "a")
+ greater := iceberg.GreaterThan(iceberg.Reference("a"), "a")
+ startsWith := iceberg.StartsWith(iceberg.Reference("b"), "foo")
+
+ boundEqual, _ := equal.Bind(schema, true)
+ boundGrtEqual, _ := grtequal.Bind(schema, true)
+ boundGreater, _ := greater.Bind(schema, true)
+ boundStarts, _ := startsWith.Bind(schema, true)
+
+ tests := []struct {
+ e iceberg.BooleanExpression
+ expected string
+ }{
+ {iceberg.NewAnd(null, nan),
+ "And(left=IsNull(term=Reference(name='a')),
right=IsNaN(term=Reference(name='g')))"},
+ {iceberg.NewOr(null, nan),
+ "Or(left=IsNull(term=Reference(name='a')),
right=IsNaN(term=Reference(name='g')))"},
+ {iceberg.NewNot(null),
+ "Not(child=IsNull(term=Reference(name='a')))"},
+ {iceberg.AlwaysTrue{}, "AlwaysTrue()"},
+ {iceberg.AlwaysFalse{}, "AlwaysFalse()"},
+ {boundNull,
+ "BoundIsNull(term=BoundReference(field=1: a: optional
string, accessor=Accessor(position=0, inner=<nil>)))"},
+ {boundNull.Negate(),
+ "BoundNotNull(term=BoundReference(field=1: a: optional
string, accessor=Accessor(position=0, inner=<nil>)))"},
+ {boundNan,
+ "BoundIsNaN(term=BoundReference(field=7: g: optional
float, accessor=Accessor(position=6, inner=<nil>)))"},
+ {boundNan.Negate(),
+ "BoundNotNaN(term=BoundReference(field=7: g: optional
float, accessor=Accessor(position=6, inner=<nil>)))"},
+ {equal,
+ "Equal(term=Reference(name='c'), literal=a)"},
+ {equal.Negate(),
+ "NotEqual(term=Reference(name='c'), literal=a)"},
+ {grtequal,
+ "GreaterThanEqual(term=Reference(name='a'),
literal=a)"},
+ {grtequal.Negate(),
+ "LessThan(term=Reference(name='a'), literal=a)"},
+ {greater,
+ "GreaterThan(term=Reference(name='a'), literal=a)"},
+ {greater.Negate(),
+ "LessThanEqual(term=Reference(name='a'), literal=a)"},
+ {startsWith,
+ "StartsWith(term=Reference(name='b'), literal=foo)"},
+ {startsWith.Negate(),
+ "NotStartsWith(term=Reference(name='b'), literal=foo)"},
+ {boundEqual,
+ "BoundEqual(term=BoundReference(field=3: c: optional
string, accessor=Accessor(position=2, inner=<nil>)), literal=a)"},
+ {boundEqual.Negate(),
+ "BoundNotEqual(term=BoundReference(field=3: c: optional
string, accessor=Accessor(position=2, inner=<nil>)), literal=a)"},
+ {boundGreater,
+ "BoundGreaterThan(term=BoundReference(field=1: a:
optional string, accessor=Accessor(position=0, inner=<nil>)), literal=a)"},
+ {boundGreater.Negate(),
+ "BoundLessThanEqual(term=BoundReference(field=1: a:
optional string, accessor=Accessor(position=0, inner=<nil>)), literal=a)"},
+ {boundGrtEqual,
+ "BoundGreaterThanEqual(term=BoundReference(field=1: a:
optional string, accessor=Accessor(position=0, inner=<nil>)), literal=a)"},
+ {boundGrtEqual.Negate(),
+ "BoundLessThan(term=BoundReference(field=1: a: optional
string, accessor=Accessor(position=0, inner=<nil>)), literal=a)"},
+ {boundStarts,
+ "BoundStartsWith(term=BoundReference(field=2: b:
optional string, accessor=Accessor(position=1, inner=<nil>)), literal=foo)"},
+ {boundStarts.Negate(),
+ "BoundNotStartsWith(term=BoundReference(field=2: b:
optional string, accessor=Accessor(position=1, inner=<nil>)), literal=foo)"},
+ }
+
+ for _, tt := range tests {
+ assert.Equal(t, tt.expected, tt.e.String())
+ }
+}
+
+func TestBindAboveBelowIntMax(t *testing.T) {
+ sc := iceberg.NewSchema(1,
+ iceberg.NestedField{ID: 1, Name: "a", Type:
iceberg.PrimitiveTypes.Int32},
+ iceberg.NestedField{ID: 2, Name: "b", Type:
iceberg.PrimitiveTypes.Float32},
+ )
+
+ ref, ref2 := iceberg.Reference("a"), iceberg.Reference("b")
+ above, below := int64(math.MaxInt32)+1, int64(math.MinInt32)-1
+ above2, below2 := float64(math.MaxFloat32)+1e37,
float64(-math.MaxFloat32)-1e37
+
+ tests := []struct {
+ pred iceberg.UnboundPredicate
+ exp iceberg.BooleanExpression
+ }{
+ {iceberg.EqualTo(ref, above), iceberg.AlwaysFalse{}},
+ {iceberg.EqualTo(ref, below), iceberg.AlwaysFalse{}},
+ {iceberg.NotEqualTo(ref, above), iceberg.AlwaysTrue{}},
+ {iceberg.NotEqualTo(ref, below), iceberg.AlwaysTrue{}},
+ {iceberg.LessThan(ref, above), iceberg.AlwaysTrue{}},
+ {iceberg.LessThan(ref, below), iceberg.AlwaysFalse{}},
+ {iceberg.LessThanEqual(ref, above), iceberg.AlwaysTrue{}},
+ {iceberg.LessThanEqual(ref, below), iceberg.AlwaysFalse{}},
+ {iceberg.GreaterThan(ref, above), iceberg.AlwaysFalse{}},
+ {iceberg.GreaterThan(ref, below), iceberg.AlwaysTrue{}},
+ {iceberg.GreaterThanEqual(ref, above), iceberg.AlwaysFalse{}},
+ {iceberg.GreaterThanEqual(ref, below), iceberg.AlwaysTrue{}},
+
+ {iceberg.EqualTo(ref2, above2), iceberg.AlwaysFalse{}},
+ {iceberg.EqualTo(ref2, below2), iceberg.AlwaysFalse{}},
+ {iceberg.NotEqualTo(ref2, above2), iceberg.AlwaysTrue{}},
+ {iceberg.NotEqualTo(ref2, below2), iceberg.AlwaysTrue{}},
+ {iceberg.LessThan(ref2, above2), iceberg.AlwaysTrue{}},
+ {iceberg.LessThan(ref2, below2), iceberg.AlwaysFalse{}},
+ {iceberg.LessThanEqual(ref2, above2), iceberg.AlwaysTrue{}},
+ {iceberg.LessThanEqual(ref2, below2), iceberg.AlwaysFalse{}},
+ {iceberg.GreaterThan(ref2, above2), iceberg.AlwaysFalse{}},
+ {iceberg.GreaterThan(ref2, below2), iceberg.AlwaysTrue{}},
+ {iceberg.GreaterThanEqual(ref2, above2), iceberg.AlwaysFalse{}},
+ {iceberg.GreaterThanEqual(ref2, below2), iceberg.AlwaysTrue{}},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.pred.String(), func(t *testing.T) {
+ b, err := tt.pred.Bind(sc, true)
+ require.NoError(t, err)
+ assert.Equal(t, tt.exp, b)
+ })
+ }
+}
diff --git a/operation_string.go b/operation_string.go
new file mode 100644
index 0000000..3af65e3
--- /dev/null
+++ b/operation_string.go
@@ -0,0 +1,41 @@
+// Code generated by "stringer -type=Operation -linecomment"; DO NOT EDIT.
+
+package iceberg
+
+import "strconv"
+
// _ is a compile-time guard generated by stringer. Each index expression below
// compiles only while the Operation constants keep the values they had when
// this file was generated.
func _() {
	// An "invalid array index" compiler error signifies that the constant values have changed.
	// Re-run the stringer command to generate them again.
	var x [1]struct{}
	_ = x[OpTrue-0]
	_ = x[OpFalse-1]
	_ = x[OpIsNull-2]
	_ = x[OpNotNull-3]
	_ = x[OpIsNan-4]
	_ = x[OpNotNan-5]
	_ = x[OpLT-6]
	_ = x[OpLTEQ-7]
	_ = x[OpGT-8]
	_ = x[OpGTEQ-9]
	_ = x[OpEQ-10]
	_ = x[OpNEQ-11]
	_ = x[OpStartsWith-12]
	_ = x[OpNotStartsWith-13]
	_ = x[OpIn-14]
	_ = x[OpNotIn-15]
	_ = x[OpNot-16]
	_ = x[OpAnd-17]
	_ = x[OpOr-18]
}
+
// _Operation_name concatenates the display name of every Operation.
// _Operation_index[i] and _Operation_index[i+1] delimit the name of
// Operation i within it.
const _Operation_name = "TrueFalseIsNullNotNullIsNaNNotNaNLessThanLessThanEqualGreaterThanGreaterThanEqualEqualNotEqualStartsWithNotStartsWithInNotInNotAndOr"

var _Operation_index = [...]uint8{0, 4, 9, 15, 22, 27, 33, 41, 54, 65, 81, 86, 94, 104, 117, 119, 124, 127, 130, 132}
+
// String returns the display name of the operation. It returns
// "Operation(N)" for values outside the generated range.
func (i Operation) String() string {
	if i < 0 || i >= Operation(len(_Operation_index)-1) {
		return "Operation(" + strconv.FormatInt(int64(i), 10) + ")"
	}
	return _Operation_name[_Operation_index[i]:_Operation_index[i+1]]
}
diff --git a/predicates.go b/predicates.go
new file mode 100644
index 0000000..24ace71
--- /dev/null
+++ b/predicates.go
@@ -0,0 +1,138 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package iceberg
+
+// IsNull builds an unbound OpIsNull predicate over the term t.
+// It is shorthand for UnaryPredicate(OpIsNull, t).
+//
+// Will panic if t is nil.
+func IsNull(t UnboundTerm) UnboundPredicate {
+	return UnaryPredicate(OpIsNull, t)
+}
+
+// NotNull builds an unbound OpNotNull predicate over the term t.
+// It is shorthand for UnaryPredicate(OpNotNull, t).
+//
+// Will panic if t is nil.
+func NotNull(t UnboundTerm) UnboundPredicate {
+	return UnaryPredicate(OpNotNull, t)
+}
+
+// IsNaN builds an unbound OpIsNan predicate over the term t.
+// It is shorthand for UnaryPredicate(OpIsNan, t).
+//
+// Will panic if t is nil.
+func IsNaN(t UnboundTerm) UnboundPredicate {
+	return UnaryPredicate(OpIsNan, t)
+}
+
+// NotNaN builds an unbound OpNotNan predicate over the term t.
+// It is shorthand for UnaryPredicate(OpNotNan, t).
+//
+// Will panic if t is nil.
+func NotNaN(t UnboundTerm) UnboundPredicate {
+	return UnaryPredicate(OpNotNan, t)
+}
+
+// IsIn is a convenience wrapper for constructing an unbound set predicate for
+// OpIn, converting each of vals with NewLiteral. It returns a
+// BooleanExpression instead of an UnboundPredicate because, depending on the
+// arguments, it can automatically reduce to AlwaysFalse or AlwaysTrue (for
+// example, when given no values). It may also reduce to EqualTo if only one
+// value is provided.
+//
+// Will panic if t is nil.
+func IsIn[T LiteralType](t UnboundTerm, vals ...T) BooleanExpression {
+	lits := make([]Literal, 0, len(vals))
+	for _, v := range vals {
+		lits = append(lits, NewLiteral(v))
+	}
+	return SetPredicate(OpIn, t, lits)
+}
+
+// NotIn is a convenience wrapper for constructing an unbound set predicate for
+// OpNotIn, converting each of vals with NewLiteral. It returns a
+// BooleanExpression instead of an UnboundPredicate because, depending on the
+// arguments, it can automatically reduce to AlwaysFalse or AlwaysTrue (for
+// example, when given no values). It may also reduce to NotEqualTo if only one
+// value is provided.
+//
+// Will panic if t is nil.
+func NotIn[T LiteralType](t UnboundTerm, vals ...T) BooleanExpression {
+	lits := make([]Literal, 0, len(vals))
+	for _, v := range vals {
+		lits = append(lits, NewLiteral(v))
+	}
+	return SetPredicate(OpNotIn, t, lits)
+}
+
+// EqualTo is a convenience wrapper for calling
+// LiteralPredicate(OpEQ, t, NewLiteral(v)).
+//
+// Will panic if t is nil.
+func EqualTo[T LiteralType](t UnboundTerm, v T) UnboundPredicate {
+	return LiteralPredicate(OpEQ, t, NewLiteral(v))
+}
+
+// NotEqualTo is a convenience wrapper for calling
+// LiteralPredicate(OpNEQ, t, NewLiteral(v)).
+//
+// Will panic if t is nil.
+func NotEqualTo[T LiteralType](t UnboundTerm, v T) UnboundPredicate {
+	return LiteralPredicate(OpNEQ, t, NewLiteral(v))
+}
+
+// GreaterThanEqual is a convenience wrapper for calling
+// LiteralPredicate(OpGTEQ, t, NewLiteral(v)).
+//
+// Will panic if t is nil.
+func GreaterThanEqual[T LiteralType](t UnboundTerm, v T) UnboundPredicate {
+	return LiteralPredicate(OpGTEQ, t, NewLiteral(v))
+}
+
+// GreaterThan is a convenience wrapper for calling
+// LiteralPredicate(OpGT, t, NewLiteral(v)).
+//
+// Will panic if t is nil.
+func GreaterThan[T LiteralType](t UnboundTerm, v T) UnboundPredicate {
+	return LiteralPredicate(OpGT, t, NewLiteral(v))
+}
+
+// LessThanEqual is a convenience wrapper for calling
+// LiteralPredicate(OpLTEQ, t, NewLiteral(v)).
+//
+// Will panic if t is nil.
+func LessThanEqual[T LiteralType](t UnboundTerm, v T) UnboundPredicate {
+	return LiteralPredicate(OpLTEQ, t, NewLiteral(v))
+}
+
+// LessThan is a convenience wrapper for calling
+// LiteralPredicate(OpLT, t, NewLiteral(v)).
+//
+// Will panic if t is nil.
+func LessThan[T LiteralType](t UnboundTerm, v T) UnboundPredicate {
+	return LiteralPredicate(OpLT, t, NewLiteral(v))
+}
+
+// StartsWith is a convenience wrapper for calling
+// LiteralPredicate(OpStartsWith, t, NewLiteral(v)).
+//
+// Will panic if t is nil.
+func StartsWith(t UnboundTerm, v string) UnboundPredicate {
+	return LiteralPredicate(OpStartsWith, t, NewLiteral(v))
+}
+
+// NotStartsWith is a convenience wrapper for calling
+// LiteralPredicate(OpNotStartsWith, t, NewLiteral(v)).
+//
+// Will panic if t is nil.
+func NotStartsWith(t UnboundTerm, v string) UnboundPredicate {
+	return LiteralPredicate(OpNotStartsWith, t, NewLiteral(v))
+}
diff --git a/schema.go b/schema.go
index 8edce1f..44fbb0e 100644
--- a/schema.go
+++ b/schema.go
@@ -44,6 +44,7 @@ type Schema struct {
idToField atomic.Pointer[map[int]NestedField]
nameToID atomic.Pointer[map[string]int]
nameToIDLower atomic.Pointer[map[string]int]
+ idToAccessor atomic.Pointer[map[int]accessor]
}
// NewSchema constructs a new schema with the provided ID
@@ -135,6 +136,21 @@ func (s *Schema) lazyNameToIDLower() (map[string]int,
error) {
return out, nil
}
+// lazyIdToAccessor returns the field-ID -> accessor index for s. It builds the
+// index with buildAccessors on first use and caches it in s.idToAccessor.
+// Concurrent first calls may each build the index (Load and Store are not a
+// single atomic step), in which case the last Store wins. On a build error the
+// error is returned and nothing is cached.
+func (s *Schema) lazyIdToAccessor() (map[int]accessor, error) {
+	if cached := s.idToAccessor.Load(); cached != nil {
+		return *cached, nil
+	}
+
+	built, err := buildAccessors(s)
+	if err != nil {
+		return nil, err
+	}
+	s.idToAccessor.Store(&built)
+	return built, nil
+}
+
func (s *Schema) Type() string { return "struct" }
// AsStruct returns a Struct with the same fields as the schema which can
@@ -255,6 +271,16 @@ func (s *Schema) FindTypeByNameCaseInsensitive(name
string) (Type, bool) {
return f.Type, true
}
+// accessorForField returns the positional accessor for the field with the
+// given ID. It reports false when the schema's accessor index has no entry for
+// id, or when that index could not be built.
+func (s *Schema) accessorForField(id int) (accessor, bool) {
+	index, err := s.lazyIdToAccessor()
+	if err != nil {
+		// The (accessor, bool) signature cannot carry the error, so a failed
+		// build is reported as "not found".
+		return accessor{}, false
+	}
+
+	found, ok := index[id]
+	return found, ok
+}
+
// Equals compares the fields and identifierIDs, but does not compare
// the schema ID itself.
func (s *Schema) Equals(other *Schema) bool {
@@ -858,3 +884,44 @@ func (findLastFieldID) Map(_ MapType, keyResult,
valueResult int) int {
}
func (findLastFieldID) Primitive(PrimitiveType) int { return 0 }
+
+// buildPosAccessors is a schema visitor that maps field IDs to accessors
+// describing the chain of struct positions that leads to each field. A field
+// gets its own entry only when nothing beneath it produced accessors
+// (primitives, lists, maps); fields inside lists and maps get no entry.
+type buildPosAccessors struct{}
+
+// Schema returns the accessors computed for the schema's top-level struct.
+func (buildPosAccessors) Schema(_ *Schema, structResult map[int]accessor) map[int]accessor {
+	return structResult
+}
+
+// Struct assigns each field its position within st. A field whose result
+// already contains nested accessors contributes those, each wrapped so the
+// lookup first selects this position; any other field is addressed directly.
+func (buildPosAccessors) Struct(st StructType, fieldResults []map[int]accessor) map[int]accessor {
+	out := map[int]accessor{}
+	for i, field := range st.FieldList {
+		nested := fieldResults[i]
+		if len(nested) == 0 {
+			out[field.ID] = accessor{pos: i}
+			continue
+		}
+
+		for id, child := range nested {
+			child := child // per-iteration copy so each &child is distinct
+			out[id] = accessor{pos: i, inner: &child}
+		}
+	}
+	return out
+}
+
+// Field passes the child's accessors through unchanged.
+func (buildPosAccessors) Field(_ NestedField, fieldResult map[int]accessor) map[int]accessor {
+	return fieldResult
+}
+
+// List yields no accessors, so nothing nested inside a list gets an entry.
+func (buildPosAccessors) List(ListType, map[int]accessor) map[int]accessor {
+	return map[int]accessor{}
+}
+
+// Map yields no accessors, so nothing nested inside a map gets an entry.
+func (buildPosAccessors) Map(_ MapType, _, _ map[int]accessor) map[int]accessor {
+	return map[int]accessor{}
+}
+
+// Primitive yields no accessors; the enclosing Struct addresses the field.
+func (buildPosAccessors) Primitive(PrimitiveType) map[int]accessor {
+	return map[int]accessor{}
+}
+
+func buildAccessors(schema *Schema) (map[int]accessor, error) {
+ return Visit(schema, buildPosAccessors{})
+}
diff --git a/utils.go b/utils.go
index fd669f8..c70c2bb 100644
--- a/utils.go
+++ b/utils.go
@@ -19,6 +19,9 @@ package iceberg
import (
"cmp"
+ "fmt"
+ "hash/maphash"
+ "maps"
"runtime/debug"
"strings"
)
@@ -52,3 +55,127 @@ func max[T cmp.Ordered](vals ...T) T {
}
return out
}
+
+// Optional represents a typed value that could be null.
+//
+// NOTE(review): Valid presumably reports whether Val holds a value (false
+// meaning null), mirroring database/sql's Null* types — confirm with callers.
+type Optional[T any] struct {
+ Val T
+ Valid bool
+}
+
+// structLike represents a single row in a record, or a nested struct value
+// within one, whose columns are addressed by position (accessor.Get descends
+// through nested values by asserting them to structLike).
+type structLike interface {
+ // Size returns the number of columns in this row
+ Size() int
+ // Get returns the value in the requested column,
+ // will panic if pos is out of bounds.
+ Get(pos int) any
+ // Set changes the value in the column indicated,
+ // will panic if pos is out of bounds.
+ Set(pos int, val any)
+}
+
+// accessor locates a possibly nested field inside a structLike.
+type accessor struct {
+ // pos is the column index to read at this level.
+ pos int
+ // inner, when non-nil, continues the lookup inside the struct value found
+ // at pos.
+ inner *accessor
+}
+
+// String renders the accessor chain, for example
+// "Accessor(position=1, inner=Accessor(position=0, inner=<nil>))".
+//
+// A nil accessor renders as "<nil>", which is how the end of every chain
+// appears. Handling nil explicitly avoids dereferencing a nil receiver and
+// relying on fmt to recover that panic for every leaf accessor. The output is
+// unchanged.
+func (a *accessor) String() string {
+	if a == nil {
+		return "<nil>"
+	}
+	return fmt.Sprintf("Accessor(position=%d, inner=%s)", a.pos, a.inner.String())
+}
+
+// Get returns the value this accessor addresses in s. It reads s at a.pos and
+// then follows the inner chain, reading each nested struct at the next
+// position.
+//
+// If an intermediate struct value is nil (a null nested struct), Get returns
+// nil rather than panicking on the type assertion, because every field inside
+// a null struct is itself null. It still panics if a position is out of
+// bounds, or if a non-nil intermediate value does not implement structLike.
+func (a *accessor) Get(s structLike) any {
+	val := s.Get(a.pos)
+	for next := a.inner; next != nil; next = next.inner {
+		if val == nil {
+			return nil
+		}
+		val = val.(structLike).Get(next.pos)
+	}
+	return val
+}
+
+// Set is a minimal generic set of elements of type E (implemented in this
+// package by literalSet).
+type Set[E any] interface {
+ // Add inserts the given elements.
+ Add(...E)
+ // Contains reports whether the element is in the set.
+ Contains(E) bool
+ // Members returns the elements; literalSet returns them in map-iteration
+ // (unspecified) order.
+ Members() []E
+ // Equals reports whether the other set holds the same elements.
+ Equals(Set[E]) bool
+ // Len returns the number of stored entries.
+ Len() int
+}
+
+// lzseed seeds the hashes literalSet uses to key byte-slice literals.
+var lzseed = maphash.MakeSeed()
+
+// literalSetBytesKey is the map key literalSet uses for FixedLiteral and
+// BinaryLiteral values. Byte slices are not comparable, so they are keyed by a
+// hash of their contents. The fixed flag keeps a FixedLiteral and a
+// BinaryLiteral with identical bytes from colliding and overwriting each other.
+type literalSetBytesKey struct {
+	fixed bool
+	hash  uint64
+}
+
+// literalSet implements Set[Literal]. Comparable literals are stored as their
+// own map key with a zero value. Byte-slice literals are stored under a
+// literalSetBytesKey, with the original literal kept in orig.
+//
+// NOTE: two distinct byte-slice literals of the same kind whose 64-bit hashes
+// collide share one entry, so the later replaces the earlier. Contains
+// re-checks equality against orig, so it never reports a false positive.
+type literalSet map[any]struct{ orig Literal }
+
+// newLiteralSet returns a literalSet containing vals.
+func newLiteralSet(vals ...Literal) Set[Literal] {
+	s := literalSet{}
+	s.Add(vals...)
+	return s
+}
+
+// bytesKey computes the map key for a byte-slice literal of the given kind.
+func (literalSet) bytesKey(fixed bool, b []byte) literalSetBytesKey {
+	return literalSetBytesKey{fixed: fixed, hash: maphash.Bytes(lzseed, b)}
+}
+
+// addliteral inserts a single literal, choosing the key form by its type.
+func (l literalSet) addliteral(v Literal) {
+	switch v := v.(type) {
+	case FixedLiteral:
+		l[l.bytesKey(true, []byte(v))] = struct{ orig Literal }{v}
+	case BinaryLiteral:
+		l[l.bytesKey(false, []byte(v))] = struct{ orig Literal }{v}
+	default:
+		l[v] = struct{ orig Literal }{}
+	}
+}
+
+// Add inserts every literal in lits.
+func (l literalSet) Add(lits ...Literal) {
+	for _, v := range lits {
+		l.addliteral(v)
+	}
+}
+
+// Contains reports whether lit is in the set.
+func (l literalSet) Contains(lit Literal) bool {
+	var key literalSetBytesKey
+	switch lit := lit.(type) {
+	case FixedLiteral:
+		key = l.bytesKey(true, []byte(lit))
+	case BinaryLiteral:
+		key = l.bytesKey(false, []byte(lit))
+	default:
+		_, ok := l[lit]
+		return ok
+	}
+
+	// The key is only a hash, so confirm the stored literal really matches.
+	v, ok := l[key]
+	return ok && lit.Equals(v.orig)
+}
+
+// Members returns every literal in the set, in unspecified order.
+func (l literalSet) Members() []Literal {
+	result := make([]Literal, 0, len(l))
+	for k, v := range l {
+		if lit, ok := k.(Literal); ok {
+			result = append(result, lit)
+		} else {
+			result = append(result, v.orig)
+		}
+	}
+	return result
+}
+
+// Equals reports whether other is a literalSet with the same entries. Sets of
+// any other implementation are never equal.
+func (l literalSet) Equals(other Set[Literal]) bool {
+	rhs, ok := other.(literalSet)
+	if !ok {
+		return false
+	}
+	return maps.EqualFunc(l, rhs, func(v1, v2 struct{ orig Literal }) bool {
+		if v1.orig == nil || v2.orig == nil {
+			// Entries keyed by the literal itself carry no orig.
+			return v1.orig == nil && v2.orig == nil
+		}
+		return v1.orig.Equals(v2.orig)
+	})
+}
+
+// Len returns the number of entries in the set.
+func (l literalSet) Len() int { return len(l) }