This is an automated email from the ASF dual-hosted git repository.
fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-go.git
The following commit(s) were added to refs/heads/main by this push:
new fad5d82 feat(visitors): Implement basic boolean expression visitors
(#108)
fad5d82 is described below
commit fad5d8228bef04d756df000678772301abb91ce3
Author: Matt Topol <[email protected]>
AuthorDate: Fri Jul 19 07:59:01 2024 -0700
feat(visitors): Implement basic boolean expression visitors (#108)
* feat: expression evaluator
* Implement manifest evaluator and tests
* fix comparison of metadata
---
errors.go | 17 +-
exprs.go | 33 +-
literals.go | 336 +++++++++++++++
literals_test.go | 277 +++++++++++-
partitions.go | 2 +-
schema.go | 96 ++++-
table/metadata.go | 66 +++
table/metadata_test.go | 5 +-
table/refs.go | 5 +
table/sorting.go | 6 +
table/table.go | 4 +-
utils.go | 19 +-
visitors.go | 864 ++++++++++++++++++++++++++++++++++++++
visitors_test.go | 1085 ++++++++++++++++++++++++++++++++++++++++++++++++
14 files changed, 2792 insertions(+), 23 deletions(-)
diff --git a/errors.go b/errors.go
index 71f9f63..f4fc986 100644
--- a/errors.go
+++ b/errors.go
@@ -20,12 +20,13 @@ package iceberg
import "errors"
var (
- ErrInvalidTypeString = errors.New("invalid type")
- ErrNotImplemented = errors.New("not implemented")
- ErrInvalidArgument = errors.New("invalid argument")
- ErrInvalidSchema = errors.New("invalid schema")
- ErrInvalidTransform = errors.New("invalid transform syntax")
- ErrType = errors.New("type error")
- ErrBadCast = errors.New("could not cast value")
- ErrBadLiteral = errors.New("invalid literal value")
+ ErrInvalidTypeString = errors.New("invalid type")
+ ErrNotImplemented = errors.New("not implemented")
+ ErrInvalidArgument = errors.New("invalid argument")
+ ErrInvalidSchema = errors.New("invalid schema")
+ ErrInvalidTransform = errors.New("invalid transform syntax")
+ ErrType = errors.New("type error")
+ ErrBadCast = errors.New("could not cast value")
+ ErrBadLiteral = errors.New("invalid literal value")
+ ErrInvalidBinSerialization = errors.New("invalid binary serialization")
)
diff --git a/exprs.go b/exprs.go
index 1123b8b..bc451ca 100644
--- a/exprs.go
+++ b/exprs.go
@@ -19,6 +19,7 @@ package iceberg
import (
"fmt"
+ "reflect"
"github.com/google/uuid"
)
@@ -332,7 +333,7 @@ type BoundTerm interface {
Ref() BoundReference
Type() Type
- evalToLiteral(structLike) Literal
+ evalToLiteral(structLike) Optional[Literal]
evalIsNull(structLike) bool
}
@@ -409,6 +410,7 @@ type BoundReference interface {
BoundTerm
Field() NestedField
+ Pos() int
}
type boundRef[T LiteralType] struct {
@@ -446,6 +448,8 @@ func createBoundRef(field NestedField, acc accessor)
BoundReference {
panic("unhandled bound reference type: " + field.Type.String())
}
+func (b *boundRef[T]) Pos() int { return b.acc.pos }
+
func (*boundRef[T]) isTerm() {}
func (b *boundRef[T]) String() string {
@@ -471,17 +475,32 @@ func (b *boundRef[T]) eval(st structLike) Optional[T] {
return Optional[T]{}
case T:
return Optional[T]{Valid: true, Val: v}
+ default:
+ var z T
+ typ, val := reflect.TypeOf(z), reflect.ValueOf(v)
+ if !val.CanConvert(typ) {
+ panic(fmt.Errorf("%w: cannot convert value '%+v' to
expected type %s",
+ ErrInvalidSchema, val.Interface(),
typ.String()))
+ }
+
+ return Optional[T]{
+ Valid: true,
+ Val: val.Convert(typ).Interface().(T),
+ }
}
- panic("unexpected type returned for bound ref")
}
-func (b *boundRef[T]) evalToLiteral(st structLike) Literal {
+func (b *boundRef[T]) evalToLiteral(st structLike) Optional[Literal] {
v := b.eval(st)
+ if !v.Valid {
+ return Optional[Literal]{}
+ }
+
lit := NewLiteral[T](v.Val)
if !lit.Type().Equals(b.field.Type) {
lit, _ = lit.To(b.field.Type)
}
- return lit
+ return Optional[Literal]{Val: lit, Valid: true}
}
func (b *boundRef[T]) evalIsNull(st structLike) bool {
@@ -538,11 +557,11 @@ func (up *unboundUnaryPredicate) Bind(schema *Schema,
caseSensitive bool) (Boole
// fast case optimizations
switch up.op {
case OpIsNull:
- if bound.Ref().Field().Required {
+ if bound.Ref().Field().Required &&
!schema.FieldHasOptionalParent(bound.Ref().Field().ID) {
return AlwaysFalse{}, nil
}
case OpNotNull:
- if bound.Ref().Field().Required {
+ if bound.Ref().Field().Required &&
!schema.FieldHasOptionalParent(bound.Ref().Field().ID) {
return AlwaysTrue{}, nil
}
case OpIsNan:
@@ -686,7 +705,7 @@ func (ul *unboundLiteralPredicate) Bind(schema *Schema,
caseSensitive bool) (Boo
}
if (ul.op == OpStartsWith || ul.op == OpNotStartsWith) &&
- !bound.Type().Equals(PrimitiveTypes.String) {
+ !(bound.Type().Equals(PrimitiveTypes.String) ||
bound.Type().Equals(PrimitiveTypes.Binary)) {
return nil, fmt.Errorf("%w: StartsWith and NotStartsWith must
bind to String type, not %s",
ErrType, bound.Type())
}
diff --git a/literals.go b/literals.go
index 9b42e59..d88f4aa 100644
--- a/literals.go
+++ b/literals.go
@@ -20,12 +20,16 @@ package iceberg
import (
"bytes"
"cmp"
+ "encoding"
+ "encoding/binary"
"errors"
"fmt"
"math"
+ "math/big"
"reflect"
"strconv"
"time"
+ "unsafe"
"github.com/apache/arrow/go/v16/arrow"
"github.com/apache/arrow/go/v16/arrow/decimal128"
@@ -50,6 +54,7 @@ type Comparator[T LiteralType] func(v1, v2 T) int
// equality against other literals.
type Literal interface {
fmt.Stringer
+ encoding.BinaryMarshaler
Type() Type
To(Type) (Literal, error)
@@ -97,6 +102,80 @@ func NewLiteral[T LiteralType](val T) Literal {
panic("can't happen due to literal type constraint")
}
+// LiteralFromBytes uses the defined Iceberg spec for how to serialize a value of
+// the provided type and returns the appropriate Literal value from it.
+//
+// If you already have a value of the desired Literal type, you could
alternatively
+// call UnmarshalBinary on it yourself manually.
+//
+// This is primarily used for retrieving stat values.
+func LiteralFromBytes(typ Type, data []byte) (Literal, error) {
+ if data == nil {
+ return nil, ErrInvalidBinSerialization
+ }
+
+ switch t := typ.(type) {
+ case BooleanType:
+ var v BoolLiteral
+ err := v.UnmarshalBinary(data)
+ return v, err
+ case Int32Type:
+ var v Int32Literal
+ err := v.UnmarshalBinary(data)
+ return v, err
+ case Int64Type:
+ var v Int64Literal
+ err := v.UnmarshalBinary(data)
+ return v, err
+ case Float32Type:
+ var v Float32Literal
+ err := v.UnmarshalBinary(data)
+ return v, err
+ case Float64Type:
+ var v Float64Literal
+ err := v.UnmarshalBinary(data)
+ return v, err
+ case StringType:
+ var v StringLiteral
+ err := v.UnmarshalBinary(data)
+ return v, err
+ case BinaryType:
+ var v BinaryLiteral
+ err := v.UnmarshalBinary(data)
+ return v, err
+ case FixedType:
+ if len(data) != t.Len() {
+ return nil, fmt.Errorf("%w: expected length %d for type
%s, got %d",
+ ErrInvalidBinSerialization, t.Len(), t,
len(data))
+ }
+ var v FixedLiteral
+ err := v.UnmarshalBinary(data)
+ return v, err
+ case DecimalType:
+ v := DecimalLiteral{Scale: t.scale}
+ err := v.UnmarshalBinary(data)
+ return v, err
+ case DateType:
+ var v DateLiteral
+ err := v.UnmarshalBinary(data)
+ return v, err
+ case TimeType:
+ var v TimeLiteral
+ err := v.UnmarshalBinary(data)
+ return v, err
+ case TimestampType, TimestampTzType:
+ var v TimestampLiteral
+ err := v.UnmarshalBinary(data)
+ return v, err
+ case UUIDType:
+ var v UUIDLiteral
+ err := v.UnmarshalBinary(data)
+ return v, err
+ }
+
+ return nil, ErrType
+}
+
// convenience to avoid repreating this pattern for primitive types
func literalEq[L interface {
comparable
@@ -130,6 +209,11 @@ type aboveMaxLiteral[T int32 | int64 | float32 | float64]
struct {
value T
}
+func (ab aboveMaxLiteral[T]) MarshalBinary() (data []byte, err error) {
+ return nil, fmt.Errorf("%w: cannot marshal above max literal",
+ ErrInvalidBinSerialization)
+}
+
func (ab aboveMaxLiteral[T]) aboveMax() {}
func (ab aboveMaxLiteral[T]) Type() Type {
@@ -168,6 +252,11 @@ type belowMinLiteral[T int32 | int64 | float32 | float64]
struct {
value T
}
+func (bm belowMinLiteral[T]) MarshalBinary() (data []byte, err error) {
+ return nil, fmt.Errorf("%w: cannot marshal above max literal",
+ ErrInvalidBinSerialization)
+}
+
func (bm belowMinLiteral[T]) belowMin() {}
func (bm belowMinLiteral[T]) Type() Type {
@@ -263,6 +352,27 @@ func (b BoolLiteral) Equals(l Literal) bool {
return literalEq(b, l)
}
+var (
+ falseBin, trueBin = [1]byte{0x0}, [1]byte{0x1}
+)
+
+func (b BoolLiteral) MarshalBinary() (data []byte, err error) {
+ // stored as 0x00 for false, and anything non-zero for True
+ if b {
+ return trueBin[:], nil
+ }
+ return falseBin[:], nil
+}
+
+func (b *BoolLiteral) UnmarshalBinary(data []byte) error {
+ // stored as 0x00 for false and anything non-zero for True
+ if len(data) < 1 {
+ return fmt.Errorf("%w: expected at least 1 byte for bool",
ErrInvalidBinSerialization)
+ }
+ *b = data[0] != 0
+ return nil
+}
+
type Int32Literal int32
func (Int32Literal) Comparator() Comparator[int32] { return cmp.Compare[int32]
}
@@ -306,6 +416,24 @@ func (i Int32Literal) Equals(other Literal) bool {
return literalEq(i, other)
}
+func (i Int32Literal) MarshalBinary() (data []byte, err error) {
+ // stored as 4 bytes in little endian order
+ data = make([]byte, 4)
+ binary.LittleEndian.PutUint32(data, uint32(i))
+ return
+}
+
+func (i *Int32Literal) UnmarshalBinary(data []byte) error {
+ // stored as 4 bytes little endian
+ if len(data) != 4 {
+ return fmt.Errorf("%w: expected 4 bytes for int32 value, got
%d",
+ ErrInvalidBinSerialization, len(data))
+ }
+
+ *i = Int32Literal(binary.LittleEndian.Uint32(data))
+ return nil
+}
+
type Int64Literal int64
func (Int64Literal) Comparator() Comparator[int64] { return cmp.Compare[int64]
}
@@ -349,10 +477,28 @@ func (i Int64Literal) To(t Type) (Literal, error) {
return nil, fmt.Errorf("%w: Int64Literal to %s", ErrBadCast, t)
}
+
func (i Int64Literal) Equals(other Literal) bool {
return literalEq(i, other)
}
+func (i Int64Literal) MarshalBinary() (data []byte, err error) {
+ // stored as 8 byte little-endian
+ data = make([]byte, 8)
+ binary.LittleEndian.PutUint64(data, uint64(i))
+ return
+}
+
+func (i *Int64Literal) UnmarshalBinary(data []byte) error {
+ // stored as 8 byte little-endian
+ if len(data) != 8 {
+ return fmt.Errorf("%w: expected 8 bytes for int64 value, got
%d",
+ ErrInvalidBinSerialization, len(data))
+ }
+ *i = Int64Literal(binary.LittleEndian.Uint64(data))
+ return nil
+}
+
type Float32Literal float32
func (Float32Literal) Comparator() Comparator[float32] { return
cmp.Compare[float32] }
@@ -375,10 +521,28 @@ func (f Float32Literal) To(t Type) (Literal, error) {
return nil, fmt.Errorf("%w: Float32Literal to %s", ErrBadCast, t)
}
+
func (f Float32Literal) Equals(other Literal) bool {
return literalEq(f, other)
}
+func (f Float32Literal) MarshalBinary() (data []byte, err error) {
+ // stored as 4 bytes little endian
+ data = make([]byte, 4)
+ binary.LittleEndian.PutUint32(data, math.Float32bits(float32(f)))
+ return
+}
+
+func (f *Float32Literal) UnmarshalBinary(data []byte) error {
+ // stored as 4 bytes little endian
+ if len(data) != 4 {
+ return fmt.Errorf("%w: expected 4 bytes for float32 value, got
%d",
+ ErrInvalidBinSerialization, len(data))
+ }
+ *f =
Float32Literal(math.Float32frombits(binary.LittleEndian.Uint32(data)))
+ return nil
+}
+
type Float64Literal float64
func (Float64Literal) Comparator() Comparator[float64] { return
cmp.Compare[float64] }
@@ -406,10 +570,28 @@ func (f Float64Literal) To(t Type) (Literal, error) {
return nil, fmt.Errorf("%w: Float64Literal to %s", ErrBadCast, t)
}
+
func (f Float64Literal) Equals(other Literal) bool {
return literalEq(f, other)
}
+func (f Float64Literal) MarshalBinary() (data []byte, err error) {
+ // stored as 8 bytes little endian
+ data = make([]byte, 8)
+ binary.LittleEndian.PutUint64(data, math.Float64bits(float64(f)))
+ return
+}
+
+func (f *Float64Literal) UnmarshalBinary(data []byte) error {
+ // stored as 8 bytes in little endian
+ if len(data) != 8 {
+ return fmt.Errorf("%w: expected 8 bytes for float64 value, got
%d",
+ ErrInvalidBinSerialization, len(data))
+ }
+ *f =
Float64Literal(math.Float64frombits(binary.LittleEndian.Uint64(data)))
+ return nil
+}
+
type DateLiteral Date
func (DateLiteral) Comparator() Comparator[Date] { return cmp.Compare[Date] }
@@ -430,6 +612,23 @@ func (d DateLiteral) Equals(other Literal) bool {
return literalEq(d, other)
}
+func (d DateLiteral) MarshalBinary() (data []byte, err error) {
+ // stored as 4 byte little endian
+ data = make([]byte, 4)
+ binary.LittleEndian.PutUint32(data, uint32(d))
+ return
+}
+
+func (d *DateLiteral) UnmarshalBinary(data []byte) error {
+ // stored as 4 byte little endian
+ if len(data) != 4 {
+ return fmt.Errorf("%w: expected 4 bytes for date value, got %d",
+ ErrInvalidBinSerialization, len(data))
+ }
+ *d = DateLiteral(binary.LittleEndian.Uint32(data))
+ return nil
+}
+
type TimeLiteral Time
func (TimeLiteral) Comparator() Comparator[Time] { return cmp.Compare[Time] }
@@ -451,6 +650,23 @@ func (t TimeLiteral) Equals(other Literal) bool {
return literalEq(t, other)
}
+func (t TimeLiteral) MarshalBinary() (data []byte, err error) {
+ // stored as 8 byte little-endian
+ data = make([]byte, 8)
+ binary.LittleEndian.PutUint64(data, uint64(t))
+ return
+}
+
+func (t *TimeLiteral) UnmarshalBinary(data []byte) error {
+ // stored as 8 byte little-endian representing microseconds from
midnight
+ if len(data) != 8 {
+ return fmt.Errorf("%w: expected 8 bytes for time value, got %d",
+ ErrInvalidBinSerialization, len(data))
+ }
+ *t = TimeLiteral(binary.LittleEndian.Uint64(data))
+ return nil
+}
+
type TimestampLiteral Timestamp
func (TimestampLiteral) Comparator() Comparator[Timestamp] { return
cmp.Compare[Timestamp] }
@@ -475,6 +691,23 @@ func (t TimestampLiteral) Equals(other Literal) bool {
return literalEq(t, other)
}
+func (t TimestampLiteral) MarshalBinary() (data []byte, err error) {
+ // stored as 8 byte little endian
+ data = make([]byte, 8)
+ binary.LittleEndian.PutUint64(data, uint64(t))
+ return
+}
+
+func (t *TimestampLiteral) UnmarshalBinary(data []byte) error {
+ // stored as 8 byte little endian value representing microseconds since
epoch
+ if len(data) != 8 {
+ return fmt.Errorf("%w: expected 8 bytes for timestamp value,
got %d",
+ ErrInvalidBinSerialization, len(data))
+ }
+ *t = TimestampLiteral(binary.LittleEndian.Uint64(data))
+ return nil
+}
+
type StringLiteral string
func (StringLiteral) Comparator() Comparator[string] { return
cmp.Compare[string] }
@@ -575,6 +808,14 @@ func (s StringLiteral) To(typ Type) (Literal, error) {
ErrBadCast, s, typ, err.Error())
}
return BoolLiteral(val), nil
+ case BinaryType:
+ return BinaryLiteral(s), nil
+ case FixedType:
+ if len(s) != t.len {
+ return nil, fmt.Errorf("%w: cast '%s' to %s - wrong
length",
+ ErrBadCast, s, t)
+ }
+ return FixedLiteral(s), nil
}
return nil, fmt.Errorf("%w: StringLiteral to %s", ErrBadCast, typ)
}
@@ -583,6 +824,21 @@ func (s StringLiteral) Equals(other Literal) bool {
return literalEq(s, other)
}
+func (s StringLiteral) MarshalBinary() (data []byte, err error) {
+ // stored as UTF-8 bytes without length
+ // avoid copying by just returning a slice of the raw bytes
+ data = unsafe.Slice(unsafe.StringData(string(s)), len(s))
+ return
+}
+
+func (s *StringLiteral) UnmarshalBinary(data []byte) error {
+ // stored as UTF-8 bytes without length
+ // avoid copy, but this means that the passed in slice is being given
+ // to the literal for ownership
+ *s = StringLiteral(unsafe.String(unsafe.SliceData(data), len(data)))
+ return nil
+}
+
type BinaryLiteral []byte
func (BinaryLiteral) Comparator() Comparator[[]byte] {
@@ -622,6 +878,18 @@ func (b BinaryLiteral) Equals(other Literal) bool {
return bytes.Equal([]byte(b), rhs)
}
+func (b BinaryLiteral) MarshalBinary() (data []byte, err error) {
+ // stored directly as is
+ data = b
+ return
+}
+
+func (b *BinaryLiteral) UnmarshalBinary(data []byte) error {
+ // stored directly as is
+ *b = BinaryLiteral(data)
+ return nil
+}
+
type FixedLiteral []byte
func (FixedLiteral) Comparator() Comparator[[]byte] { return bytes.Compare }
@@ -660,6 +928,18 @@ func (f FixedLiteral) Equals(other Literal) bool {
return bytes.Equal([]byte(f), rhs)
}
+func (f FixedLiteral) MarshalBinary() (data []byte, err error) {
+ // stored directly as is
+ data = f
+ return
+}
+
+func (f *FixedLiteral) UnmarshalBinary(data []byte) error {
+ // stored directly as is
+ *f = FixedLiteral(data)
+ return nil
+}
+
type UUIDLiteral uuid.UUID
func (UUIDLiteral) Comparator() Comparator[uuid.UUID] {
@@ -699,6 +979,20 @@ func (u UUIDLiteral) Equals(other Literal) bool {
return uuid.UUID(u) == uuid.UUID(rhs)
}
+func (u UUIDLiteral) MarshalBinary() (data []byte, err error) {
+ return uuid.UUID(u).MarshalBinary()
+}
+
+func (u *UUIDLiteral) UnmarshalBinary(data []byte) error {
+ // stored as 16-byte big-endian value
+ out, err := uuid.FromBytes(data)
+ if err != nil {
+ return err
+ }
+ *u = UUIDLiteral(out)
+ return nil
+}
+
type DecimalLiteral Decimal
func (DecimalLiteral) Comparator() Comparator[Decimal] {
@@ -778,3 +1072,45 @@ func (d DecimalLiteral) Equals(other Literal) bool {
}
return d.Val == rescaled
}
+
+func (d DecimalLiteral) MarshalBinary() (data []byte, err error) {
+ // stored as unscaled value in two's complement big-endian values
+ // using the minimum number of bytes for the values
+ n := decimal128.Num(d.Val).BigInt()
+ // bytes gives absolute value as big-endian bytes
+ data = n.Bytes()
+ if n.Sign() < 0 {
+ // convert to 2's complement for negative value
+ for i, v := range data {
+ data[i] = ^v
+ }
+ data[len(data)-1] += 1
+ }
+ return
+}
+
+func (d *DecimalLiteral) UnmarshalBinary(data []byte) error {
+ // stored as unscaled value in two's complement
+ // big-endian values using the minimum number of bytes
+ if len(data) == 0 {
+ d.Val = decimal128.Num{}
+ return nil
+ }
+
+ if int8(data[0]) >= 0 {
+ // not negative
+ d.Val = decimal128.FromBigInt((&big.Int{}).SetBytes(data))
+ return nil
+ }
+
+ // convert two's complement and remember it's negative
+ out := make([]byte, len(data))
+ for i, b := range data {
+ out[i] = ^b
+ }
+ out[len(out)-1] += 1
+
+ value := (&big.Int{}).SetBytes(out)
+ d.Val = decimal128.FromBigInt(value.Neg(value))
+ return nil
+}
diff --git a/literals_test.go b/literals_test.go
index f7a483b..1f9baa9 100644
--- a/literals_test.go
+++ b/literals_test.go
@@ -641,7 +641,7 @@ func TestInvalidDateTimeLiteralConversions(t *testing.T) {
func TestInvalidStringLiteralConversions(t *testing.T) {
testInvalidLiteralConversions(t, iceberg.NewLiteral("abc"),
[]iceberg.Type{
- iceberg.FixedTypeOf(1), iceberg.PrimitiveTypes.Binary,
+ iceberg.FixedTypeOf(1),
})
}
@@ -727,3 +727,278 @@ func TestStringLiteralToIntMaxMinValue(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, iceberg.Int32BelowMinLiteral(), below)
}
+
+func TestUnmarshalBinary(t *testing.T) {
+ tests := []struct {
+ typ iceberg.Type
+ data []byte
+ result iceberg.Literal
+ }{
+ {iceberg.PrimitiveTypes.Bool, []byte{0x0},
iceberg.BoolLiteral(false)},
+ {iceberg.PrimitiveTypes.Bool, []byte{0x1},
iceberg.BoolLiteral(true)},
+ {iceberg.PrimitiveTypes.Int32, []byte{0xd2, 0x04, 0x00, 0x00},
iceberg.Int32Literal(1234)},
+ {iceberg.PrimitiveTypes.Int64, []byte{0xd2, 0x04, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00},
+ iceberg.Int64Literal(1234)},
+ {iceberg.PrimitiveTypes.Float32, []byte{0x00, 0x00, 0x90,
0xc0}, iceberg.Float32Literal(-4.5)},
+ {iceberg.PrimitiveTypes.Float64, []byte{0x8d, 0x97, 0x6e, 0x12,
0x83, 0xc0, 0xf3, 0x3f},
+ iceberg.Float64Literal(1.2345)},
+ {iceberg.PrimitiveTypes.Date, []byte{0xe8, 0x03, 0x00, 0x00},
iceberg.DateLiteral(1000)},
+ {iceberg.PrimitiveTypes.Date, []byte{0xd2, 0x04, 0x00, 0x00},
iceberg.DateLiteral(1234)},
+ {iceberg.PrimitiveTypes.Time, []byte{0x10, 0x27, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00},
+ iceberg.TimeLiteral(10000)},
+ {iceberg.PrimitiveTypes.Time, []byte{0x00, 0xe8, 0x76, 0x48,
0x17, 0x00, 0x00, 0x00},
+ iceberg.TimeLiteral(100000000000)},
+ {iceberg.PrimitiveTypes.TimestampTz, []byte{0x80, 0x1a, 0x06,
0x00, 0x00, 0x00, 0x00, 0x00},
+ iceberg.TimestampLiteral(400000)},
+ {iceberg.PrimitiveTypes.TimestampTz, []byte{0x00, 0xe8, 0x76,
0x48, 0x17, 0x00, 0x00, 0x00},
+ iceberg.TimestampLiteral(100000000000)},
+ {iceberg.PrimitiveTypes.Timestamp, []byte{0x80, 0x1a, 0x06,
0x00, 0x00, 0x00, 0x00, 0x00},
+ iceberg.TimestampLiteral(400000)},
+ {iceberg.PrimitiveTypes.Timestamp, []byte{0x00, 0xe8, 0x76,
0x48, 0x17, 0x00, 0x00, 0x00},
+ iceberg.TimestampLiteral(100000000000)},
+ {iceberg.PrimitiveTypes.String, []byte("ABC"),
iceberg.StringLiteral("ABC")},
+ {iceberg.PrimitiveTypes.String, []byte("foo"),
iceberg.StringLiteral("foo")},
+ {iceberg.PrimitiveTypes.UUID,
+ []byte{0xf7, 0x9c, 0x3e, 0x09, 0x67, 0x7c, 0x4b, 0xbd,
0xa4, 0x79, 0x3f, 0x34, 0x9c, 0xb7, 0x85, 0xe7},
+ iceberg.UUIDLiteral(uuid.UUID{0xf7, 0x9c, 0x3e, 0x09,
0x67, 0x7c, 0x4b, 0xbd, 0xa4, 0x79, 0x3f, 0x34, 0x9c, 0xb7, 0x85, 0xe7})},
+ {iceberg.FixedTypeOf(3), []byte("foo"),
iceberg.FixedLiteral([]byte("foo"))},
+ {iceberg.PrimitiveTypes.Binary, []byte("foo"),
iceberg.BinaryLiteral([]byte("foo"))},
+ {iceberg.DecimalTypeOf(5, 2), []byte{0x30, 0x39},
+ iceberg.DecimalLiteral{Scale: 2, Val:
decimal128.FromU64(12345)}},
+ {iceberg.DecimalTypeOf(7, 4), []byte{0x12, 0xd6, 0x87},
+ iceberg.DecimalLiteral{Scale: 4, Val:
decimal128.FromU64(1234567)}},
+ {iceberg.DecimalTypeOf(7, 4), []byte{0xff, 0xed, 0x29, 0x79},
+ iceberg.DecimalLiteral{Scale: 4, Val:
decimal128.FromI64(-1234567)}},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.typ.String(), func(t *testing.T) {
+ lit, err := iceberg.LiteralFromBytes(tt.typ, tt.data)
+ require.NoError(t, err)
+
+ assert.Truef(t, tt.result.Equals(lit), "expected: %s,
got: %s", tt.result, lit)
+ })
+ }
+}
+
+func TestRoundTripLiteralBinary(t *testing.T) {
+ tests := []struct {
+ typ iceberg.Type
+ b []byte
+ result iceberg.Literal
+ }{
+ {iceberg.PrimitiveTypes.Bool, []byte{0x0},
iceberg.BoolLiteral(false)},
+ {iceberg.PrimitiveTypes.Bool, []byte{0x1},
iceberg.BoolLiteral(true)},
+ {iceberg.PrimitiveTypes.Int32, []byte{0xd2, 0x04, 0x00, 0x00},
iceberg.Int32Literal(1234)},
+ {iceberg.PrimitiveTypes.Int64, []byte{0xd2, 0x04, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00},
+ iceberg.Int64Literal(1234)},
+ {iceberg.PrimitiveTypes.Float32, []byte{0x00, 0x00, 0x90,
0xc0}, iceberg.Float32Literal(-4.5)},
+ {iceberg.PrimitiveTypes.Float32, []byte{0x19, 0x04, 0x9e,
0x3f}, iceberg.Float32Literal(1.2345)},
+ {iceberg.PrimitiveTypes.Float64, []byte{0x8d, 0x97, 0x6e, 0x12,
0x83, 0xc0, 0xf3, 0x3f},
+ iceberg.Float64Literal(1.2345)},
+ {iceberg.PrimitiveTypes.Date, []byte{0xe8, 0x03, 0x00, 0x00},
iceberg.DateLiteral(1000)},
+ {iceberg.PrimitiveTypes.Date, []byte{0xd2, 0x04, 0x00, 0x00},
iceberg.DateLiteral(1234)},
+ {iceberg.PrimitiveTypes.Time, []byte{0x10, 0x27, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00},
+ iceberg.TimeLiteral(10000)},
+ {iceberg.PrimitiveTypes.Time, []byte{0x00, 0xe8, 0x76, 0x48,
0x17, 0x00, 0x00, 0x00},
+ iceberg.TimeLiteral(100000000000)},
+ {iceberg.PrimitiveTypes.TimestampTz, []byte{0x80, 0x1a, 0x06,
0x00, 0x00, 0x00, 0x00, 0x00},
+ iceberg.TimestampLiteral(400000)},
+ {iceberg.PrimitiveTypes.TimestampTz, []byte{0x00, 0xe8, 0x76,
0x48, 0x17, 0x00, 0x00, 0x00},
+ iceberg.TimestampLiteral(100000000000)},
+ {iceberg.PrimitiveTypes.Timestamp, []byte{0x80, 0x1a, 0x06,
0x00, 0x00, 0x00, 0x00, 0x00},
+ iceberg.TimestampLiteral(400000)},
+ {iceberg.PrimitiveTypes.Timestamp, []byte{0x00, 0xe8, 0x76,
0x48, 0x17, 0x00, 0x00, 0x00},
+ iceberg.TimestampLiteral(100000000000)},
+ {iceberg.PrimitiveTypes.String, []byte("ABC"),
iceberg.StringLiteral("ABC")},
+ {iceberg.PrimitiveTypes.String, []byte("foo"),
iceberg.StringLiteral("foo")},
+ {iceberg.PrimitiveTypes.UUID,
+ []byte{0xf7, 0x9c, 0x3e, 0x09, 0x67, 0x7c, 0x4b, 0xbd,
0xa4, 0x79, 0x3f, 0x34, 0x9c, 0xb7, 0x85, 0xe7},
+ iceberg.UUIDLiteral(uuid.UUID{0xf7, 0x9c, 0x3e, 0x09,
0x67, 0x7c, 0x4b, 0xbd, 0xa4, 0x79, 0x3f, 0x34, 0x9c, 0xb7, 0x85, 0xe7})},
+ {iceberg.FixedTypeOf(3), []byte("foo"),
iceberg.FixedLiteral([]byte("foo"))},
+ {iceberg.PrimitiveTypes.Binary, []byte("foo"),
iceberg.BinaryLiteral([]byte("foo"))},
+ {iceberg.DecimalTypeOf(5, 2), []byte{0x30, 0x39},
+ iceberg.DecimalLiteral{Scale: 2, Val:
decimal128.FromU64(12345)}},
+ // decimal on 3-bytes to test that we use the minimum number of
bytes and not a power of 2
+ // 1234567 is 00010010|11010110|10000111 in binary
+ // 00010010 -> 18, 11010110 -> 214, 10000111 -> 135
+ {iceberg.DecimalTypeOf(7, 4), []byte{0x12, 0xd6, 0x87},
+ iceberg.DecimalLiteral{Scale: 4, Val:
decimal128.FromU64(1234567)}},
+ // negative decimal to test two's complement
+ // -1234567 is 11101101|00101001|01111001 in binary
+ // 11101101 -> 237, 00101001 -> 41, 01111001 -> 121
+ {iceberg.DecimalTypeOf(7, 4), []byte{0xed, 0x29, 0x79},
+ iceberg.DecimalLiteral{Scale: 4, Val:
decimal128.FromI64(-1234567)}},
+ // test empty byte in decimal
+ // 11 is 00001011 in binary
+ // 00001011 -> 11
+ {iceberg.DecimalTypeOf(10, 3), []byte{0x0b},
iceberg.DecimalLiteral{Scale: 3, Val: decimal128.FromU64(11)}},
+ {iceberg.DecimalTypeOf(4, 2), []byte{0x04, 0xd2},
iceberg.DecimalLiteral{Scale: 2, Val: decimal128.FromU64(1234)}},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.result.String(), func(t *testing.T) {
+ lit, err := iceberg.LiteralFromBytes(tt.typ, tt.b)
+ require.NoError(t, err)
+
+ assert.True(t, lit.Equals(tt.result))
+
+ data, err := lit.MarshalBinary()
+ require.NoError(t, err)
+
+ assert.Equal(t, tt.b, data)
+ })
+ }
+}
+
+func TestLargeDecimalRoundTrip(t *testing.T) {
+ tests := []struct {
+ typ iceberg.DecimalType
+ b []byte
+ val string
+ }{
+ {iceberg.DecimalTypeOf(38, 21),
+ []byte{0x09, 0x49, 0xb0, 0xf7, 0x13, 0xe9, 0x18, 0x30,
0x73, 0xb9, 0x1e,
+ 0x7e, 0xa2, 0xb3, 0x6a, 0x83},
+ "12345678912345678.123456789123456789123"},
+ {iceberg.DecimalTypeOf(38, 22),
+ []byte{0x09, 0x49, 0xb0, 0xf7, 0x13, 0xe9, 0x16, 0xbb,
0x01, 0x2f,
+ 0x4c, 0xc3, 0x2b, 0x42, 0x29, 0x22},
+ "1234567891234567.1234567891234567891234"},
+ {iceberg.DecimalTypeOf(38, 23),
+ []byte{0x09, 0x49, 0xb0, 0xf7, 0x13, 0xe9, 0x0a, 0x42,
0xa1, 0xad,
+ 0xe5, 0x2b, 0x33, 0x15, 0x9b, 0x59},
+ "123456789123456.12345678912345678912345"},
+ {iceberg.DecimalTypeOf(38, 24),
+ []byte{0x09, 0x49, 0xb0, 0xf7, 0x13, 0xe8, 0xa2, 0xbb,
0xe9, 0x67,
+ 0xba, 0x86, 0x77, 0xd8, 0x11, 0x80},
+ "12345678912345.123456789123456789123456"},
+ {iceberg.DecimalTypeOf(38, 25),
+ []byte{0x09, 0x49, 0xb0, 0xf7, 0x13, 0xe5, 0x6b, 0x3a,
0xd2, 0x78,
+ 0xdd, 0x04, 0xc8, 0x70, 0xaf, 0x07},
+ "1234567891234.1234567891234567891234567"},
+ {iceberg.DecimalTypeOf(38, 26),
+ []byte{0x09, 0x49, 0xb0, 0xf7, 0x13, 0xcd, 0x85, 0xc5,
0x03, 0x38, 0x37,
+ 0x3c, 0x38, 0x66, 0xd6, 0x4e},
+ "123456789123.12345678912345678912345678"},
+ {iceberg.DecimalTypeOf(38, 27),
+ []byte{0x09, 0x49, 0xb0, 0xf7, 0x13, 0x31, 0x46, 0xfd,
0xc7, 0x79,
+ 0xca, 0x39, 0x7c, 0x04, 0x5f, 0x15},
+ "12345678912.123456789123456789123456789"},
+ {iceberg.DecimalTypeOf(38, 28),
+ []byte{0x09, 0x49, 0xb0, 0xf7, 0x10, 0x52, 0x01, 0x72,
0x11, 0xda,
+ 0x08, 0x5b, 0x08, 0x2b, 0xb6, 0xd3},
+ "1234567891.1234567891234567891234567891"},
+ {iceberg.DecimalTypeOf(38, 29),
+ []byte{0x09, 0x49, 0xb0, 0xf7, 0x13, 0xe9, 0x18, 0x5b,
0x37, 0xc1,
+ 0x78, 0x0b, 0x91, 0xb5, 0x24, 0x40},
+ "123456789.12345678912345678912345678912"},
+ {iceberg.DecimalTypeOf(38, 30),
+ []byte{0x09, 0x49, 0xb0, 0xed, 0x1e, 0xdf, 0x80, 0x03,
0x47, 0x3b,
+ 0x16, 0x9b, 0xf1, 0x13, 0x6a, 0x83},
+ "12345678.123456789123456789123456789123"},
+ {iceberg.DecimalTypeOf(38, 31),
+ []byte{0x09, 0x49, 0xb0, 0x96, 0x2b, 0xac, 0x29, 0x64,
0x28, 0x70,
+ 0x36, 0x29, 0xea, 0xc2, 0x29, 0x22},
+ "1234567.1234567891234567891234567891234"},
+ {iceberg.DecimalTypeOf(38, 32),
+ []byte{0x09, 0x49, 0xad, 0xae, 0xe3, 0x68, 0xe7, 0x4f,
0xb5, 0x14,
+ 0xbc, 0xdc, 0x2b, 0x95, 0x9b, 0x59},
+ "123456.12345678912345678912345678912345"},
+ {iceberg.DecimalTypeOf(38, 33),
+ []byte{0x09, 0x49, 0x95, 0x94, 0x3e, 0x35, 0x93, 0xde,
0xb9, 0x2e,
+ 0xef, 0x53, 0xb3, 0xd8, 0x11, 0x80},
+ "12345.123456789123456789123456789123456"},
+ {iceberg.DecimalTypeOf(38, 34),
+ []byte{0x09, 0x48, 0xd5, 0xd7, 0x90, 0x78, 0xdf, 0x08,
0x1a, 0xf6,
+ 0x43, 0x09, 0x06, 0x70, 0xaf, 0x07},
+ "1234.1234567891234567891234567891234567"},
+ {iceberg.DecimalTypeOf(38, 35),
+ []byte{0x09, 0x43, 0x45, 0x82, 0x85, 0xc7, 0x56, 0x66,
0x24, 0x4d,
+ 0x16, 0x82, 0x40, 0x66, 0xd6, 0x4e},
+ "123.12345678912345678912345678912345678"},
+ {iceberg.DecimalTypeOf(21, 16),
+ []byte{0x06, 0xb1, 0x3a, 0xe3, 0xc4, 0x4e, 0x94, 0xaf,
0x07},
+ "12345.1234567891234567"},
+ {iceberg.DecimalTypeOf(22, 17),
+ []byte{0x42, 0xec, 0x4c, 0xe5, 0xab, 0x11, 0xce, 0xd6,
0x4e},
+ "12345.12345678912345678"},
+ {iceberg.DecimalTypeOf(23, 18),
+ []byte{0x02, 0x9d, 0x3b, 0x00, 0xf8, 0xae, 0xb2, 0x14,
0x5f, 0x15},
+ "12345.123456789123456789"},
+ {iceberg.DecimalTypeOf(24, 19),
+ []byte{0x1a, 0x24, 0x4e, 0x09, 0xb6, 0xd2, 0xf4, 0xcb,
0xb6, 0xd3},
+ "12345.1234567891234567891"},
+ {iceberg.DecimalTypeOf(25, 20),
+ []byte{0x01, 0x05, 0x6b, 0x0c, 0x61, 0x24, 0x3d, 0x8f,
0xf5, 0x24, 0x40},
+ "12345.12345678912345678912"},
+ {iceberg.DecimalTypeOf(26, 21),
+ []byte{0x0a, 0x36, 0x2e, 0x7b, 0xcb, 0x6a, 0x67, 0x9f,
0x93, 0x6a, 0x83},
+ "12345.123456789123456789123"},
+ {iceberg.DecimalTypeOf(27, 22),
+ []byte{0x66, 0x1d, 0xd0, 0xd5, 0xf2, 0x28, 0x0c, 0x3b,
0xc2, 0x29, 0x22},
+ "12345.1234567891234567891234"},
+ {iceberg.DecimalTypeOf(28, 23),
+ []byte{0x03, 0xfd, 0x2a, 0x28, 0x5b, 0x75, 0x90, 0x7a,
0x55, 0x95, 0x9b, 0x59},
+ "12345.12345678912345678912345"},
+ {iceberg.DecimalTypeOf(29, 24),
+ []byte{0x27, 0xe3, 0xa5, 0x93, 0x92, 0x97, 0xa4, 0xc7,
0x57, 0xd8, 0x11, 0x80},
+ "12345.123456789123456789123456"},
+ {iceberg.DecimalTypeOf(30, 25),
+ []byte{0x01, 0x8e, 0xe4, 0x77, 0xc3, 0xb9, 0xec, 0x6f,
0xc9, 0x6e, 0x70, 0xaf, 0x07},
+ "12345.1234567891234567891234567"},
+ {iceberg.DecimalTypeOf(31, 26),
+ []byte{0x0f, 0x94, 0xec, 0xad, 0xa5, 0x43, 0x3c, 0x5d,
0xde, 0x50, 0x66, 0xd6, 0x4e},
+ "12345.12345678912345678912345678"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.val, func(t *testing.T) {
+ lit, err := iceberg.LiteralFromBytes(tt.typ, tt.b)
+ require.NoError(t, err)
+
+ v, err := decimal128.FromString(tt.val,
int32(tt.typ.Precision()), int32(tt.typ.Scale()))
+ require.NoError(t, err)
+
+ assert.True(t, lit.Equals(iceberg.DecimalLiteral{Scale:
tt.typ.Scale(), Val: v}))
+
+ data, err := lit.MarshalBinary()
+ require.NoError(t, err)
+
+ assert.Equal(t, tt.b, data)
+ })
+ }
+}
+
+func TestDecimalMaxMinRoundTrip(t *testing.T) {
+ tests := []struct {
+ typ iceberg.DecimalType
+ v string
+ }{
+ {iceberg.DecimalTypeOf(6, 2), "9999.99"},
+ {iceberg.DecimalTypeOf(10, 10), ".9999999999"},
+ {iceberg.DecimalTypeOf(2, 1), "9.9"},
+ {iceberg.DecimalTypeOf(38, 37),
"9.9999999999999999999999999999999999999"},
+ {iceberg.DecimalTypeOf(20, 1), "9999999999999999999.9"},
+ {iceberg.DecimalTypeOf(6, 2), "-9999.99"},
+ {iceberg.DecimalTypeOf(10, 10), "-.9999999999"},
+ {iceberg.DecimalTypeOf(2, 1), "-9.9"},
+ {iceberg.DecimalTypeOf(38, 37),
"-9.9999999999999999999999999999999999999"},
+ {iceberg.DecimalTypeOf(20, 1), "-9999999999999999999.9"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.v, func(t *testing.T) {
+ v, err := decimal128.FromString(tt.v,
int32(tt.typ.Precision()), int32(tt.typ.Scale()))
+ require.NoError(t, err)
+
+ lit := iceberg.DecimalLiteral{Scale: tt.typ.Scale(),
Val: v}
+ b, err := lit.MarshalBinary()
+ require.NoError(t, err)
+ val, err := iceberg.LiteralFromBytes(tt.typ, b)
+ require.NoError(t, err)
+
+ assert.True(t, val.Equals(lit))
+ })
+ }
+}
diff --git a/partitions.go b/partitions.go
index c24f082..321af2e 100644
--- a/partitions.go
+++ b/partitions.go
@@ -113,7 +113,7 @@ func (ps *PartitionSpec) CompatibleWith(other
*PartitionSpec) bool {
// Equals returns true iff the field lists are the same AND the spec id
// is the same between this partition spec and the provided one.
-func (ps *PartitionSpec) Equals(other PartitionSpec) bool {
+func (ps PartitionSpec) Equals(other PartitionSpec) bool {
return ps.id == other.id && slices.Equal(ps.fields, other.fields)
}
diff --git a/schema.go b/schema.go
index 44fbb0e..a204b54 100644
--- a/schema.go
+++ b/schema.go
@@ -21,6 +21,7 @@ import (
"encoding/json"
"fmt"
"strings"
+ "sync"
"sync/atomic"
"golang.org/x/exp/maps"
@@ -45,6 +46,8 @@ type Schema struct {
nameToID atomic.Pointer[map[string]int]
nameToIDLower atomic.Pointer[map[string]int]
idToAccessor atomic.Pointer[map[int]accessor]
+
+ lazyIDToParent func() (map[int]int, error)
}
// NewSchema constructs a new schema with the provided ID
@@ -57,7 +60,11 @@ func NewSchema(id int, fields ...NestedField) *Schema {
// and fields, along with a slice of field IDs to be listed as identifier
// fields.
func NewSchemaWithIdentifiers(id int, identifierIDs []int, fields
...NestedField) *Schema {
- return &Schema{ID: id, fields: fields, IdentifierFieldIDs:
identifierIDs}
+ s := &Schema{ID: id, fields: fields, IdentifierFieldIDs: identifierIDs}
+ s.lazyIDToParent = sync.OnceValues(func() (map[int]int, error) {
+ return IndexParents(s)
+ })
+ return s
}
func (s *Schema) String() string {
@@ -171,6 +178,12 @@ func (s *Schema) UnmarshalJSON(b []byte) error {
return err
}
+ if s.lazyIDToParent == nil {
+ s.lazyIDToParent = sync.OnceValues(func() (map[int]int, error) {
+ return IndexParents(s)
+ })
+ }
+
s.fields = aux.Fields
if s.IdentifierFieldIDs == nil {
s.IdentifierFieldIDs = []int{}
@@ -346,6 +359,27 @@ func (s *Schema) Select(caseSensitive bool, names
...string) (*Schema, error) {
return PruneColumns(s, ids, true)
}
+func (s *Schema) FieldHasOptionalParent(id int) bool {
+ idToParent, _ := s.lazyIDToParent()
+ idToField, _ := s.lazyIDToField()
+
+ f, ok := idToField[id]
+ if !ok {
+ return false
+ }
+
+ for {
+ parent, ok := idToParent[f.ID]
+ if !ok {
+ return false
+ }
+
+ if f = idToField[parent]; !f.Required {
+ return true
+ }
+ }
+}
+
// SchemaVisitor is an interface that can be implemented to allow for
// easy traversal and processing of a schema.
//
@@ -885,6 +919,66 @@ func (findLastFieldID) Map(_ MapType, keyResult,
valueResult int) int {
func (findLastFieldID) Primitive(PrimitiveType) int { return 0 }
+// IndexParents generates an index of field IDs to their parent field
+// IDs. Root fields are not indexed.
+func IndexParents(schema *Schema) (map[int]int, error) {
+ indexer := &indexParents{
+ idToParent: make(map[int]int),
+ idStack: make([]int, 0),
+ }
+ return Visit(schema, indexer)
+}
+
+type indexParents struct {
+ idToParent map[int]int
+ idStack []int
+}
+
+func (i *indexParents) BeforeField(field NestedField) {
+ i.idStack = append(i.idStack, field.ID)
+}
+
+func (i *indexParents) AfterField(field NestedField) {
+ i.idStack = i.idStack[:len(i.idStack)-1]
+}
+
+func (i *indexParents) Schema(schema *Schema, _ map[int]int) map[int]int {
+ return i.idToParent
+}
+
+func (i *indexParents) Struct(st StructType, _ []map[int]int) map[int]int {
+ var parent int
+ stackLen := len(i.idStack)
+ if stackLen > 0 {
+ parent = i.idStack[stackLen-1]
+ for _, f := range st.FieldList {
+ i.idToParent[f.ID] = parent
+ }
+ }
+
+ return i.idToParent
+}
+
+func (i *indexParents) Field(NestedField, map[int]int) map[int]int {
+ return i.idToParent
+}
+
+func (i *indexParents) List(list ListType, _ map[int]int) map[int]int {
+ i.idToParent[list.ElementID] = i.idStack[len(i.idStack)-1]
+ return i.idToParent
+}
+
+func (i *indexParents) Map(mapType MapType, _, _ map[int]int) map[int]int {
+ parent := i.idStack[len(i.idStack)-1]
+ i.idToParent[mapType.KeyID] = parent
+ i.idToParent[mapType.ValueID] = parent
+ return i.idToParent
+}
+
+func (i *indexParents) Primitive(PrimitiveType) map[int]int {
+ return i.idToParent
+}
+
type buildPosAccessors struct{}
func (buildPosAccessors) Schema(_ *Schema, structResult map[int]accessor)
map[int]accessor {
diff --git a/table/metadata.go b/table/metadata.go
index 957e163..47b3ffe 100644
--- a/table/metadata.go
+++ b/table/metadata.go
@@ -22,6 +22,8 @@ import (
"errors"
"fmt"
"io"
+ "maps"
+ "slices"
"github.com/apache/iceberg-go"
@@ -88,6 +90,8 @@ type Metadata interface {
// to be used for arbitrary metadata. For example,
commit.retry.num-retries
// is used to control the number of commit retries.
Properties() iceberg.Properties
+
+ Equals(Metadata) bool
}
var (
@@ -134,6 +138,12 @@ func ParseMetadataBytes(b []byte) (Metadata, error) {
return ret, json.Unmarshal(b, ret)
}
+func sliceEqualHelper[T interface{ Equals(T) bool }](s1, s2 []T) bool {
+ return slices.EqualFunc(s1, s2, func(t1, t2 T) bool {
+ return t1.Equals(t2)
+ })
+}
+
// https://iceberg.apache.org/spec/#iceberg-table-spec
type commonMetadata struct {
FormatVersion int `json:"format-version"`
@@ -156,6 +166,42 @@ type commonMetadata struct {
Refs map[string]SnapshotRef `json:"refs"`
}
+func (c *commonMetadata) Equals(other *commonMetadata) bool {
+ switch {
+ case c.LastPartitionID == nil && other.LastPartitionID != nil:
+ fallthrough
+ case c.LastPartitionID != nil && other.LastPartitionID == nil:
+ fallthrough
+ case c.CurrentSnapshotID == nil && other.CurrentSnapshotID != nil:
+ fallthrough
+ case c.CurrentSnapshotID != nil && other.CurrentSnapshotID == nil:
+ return false
+ }
+
+ switch {
+ case !sliceEqualHelper(c.SchemaList, other.SchemaList):
+ fallthrough
+ case !sliceEqualHelper(c.SnapshotList, other.SnapshotList):
+ fallthrough
+ case !sliceEqualHelper(c.Specs, other.Specs):
+ fallthrough
+ case !maps.Equal(c.Props, other.Props):
+ fallthrough
+ case !maps.EqualFunc(c.Refs, other.Refs, func(sr1, sr2 SnapshotRef)
bool { return sr1.Equals(sr2) }):
+ return false
+ }
+
+ return c.FormatVersion == other.FormatVersion && c.UUID == other.UUID &&
+ ((c.LastPartitionID == other.LastPartitionID) ||
(*c.LastPartitionID == *other.LastPartitionID)) &&
+ ((c.CurrentSnapshotID == other.CurrentSnapshotID) ||
(*c.CurrentSnapshotID == *other.CurrentSnapshotID)) &&
+ c.Loc == other.Loc && c.LastUpdatedMS == other.LastUpdatedMS &&
+ c.LastColumnId == other.LastColumnId && c.CurrentSchemaID ==
other.CurrentSchemaID &&
+ c.DefaultSpecID == other.DefaultSpecID && c.DefaultSortOrderID
== other.DefaultSortOrderID &&
+ slices.Equal(c.SnapshotLog, other.SnapshotLog) &&
slices.Equal(c.MetadataLog, other.MetadataLog) &&
+ sliceEqualHelper(c.SortOrderList, other.SortOrderList)
+
+}
+
func (c *commonMetadata) TableUUID() uuid.UUID { return c.UUID }
func (c *commonMetadata) Location() string { return c.Loc }
func (c *commonMetadata) LastUpdatedMillis() int64 { return c.LastUpdatedMS }
@@ -331,6 +377,16 @@ type MetadataV1 struct {
commonMetadata
}
+func (m *MetadataV1) Equals(other Metadata) bool {
+ rhs, ok := other.(*MetadataV1)
+ if !ok {
+ return false
+ }
+
+ return m.Schema.Equals(&rhs.Schema) && slices.Equal(m.Partition,
rhs.Partition) &&
+ m.commonMetadata.Equals(&rhs.commonMetadata)
+}
+
func (m *MetadataV1) preValidate() {
if len(m.SchemaList) == 0 {
m.SchemaList = []*iceberg.Schema{&m.Schema}
@@ -388,6 +444,16 @@ type MetadataV2 struct {
commonMetadata
}
+func (m *MetadataV2) Equals(other Metadata) bool {
+ rhs, ok := other.(*MetadataV2)
+ if !ok {
+ return false
+ }
+
+ return m.LastSequenceNumber == rhs.LastSequenceNumber &&
+ m.commonMetadata.Equals(&rhs.commonMetadata)
+}
+
func (m *MetadataV2) UnmarshalJSON(b []byte) error {
type Alias MetadataV2
aux := (*Alias)(m)
diff --git a/table/metadata_test.go b/table/metadata_test.go
index 080688f..e268d88 100644
--- a/table/metadata_test.go
+++ b/table/metadata_test.go
@@ -19,6 +19,7 @@ package table_test
import (
"encoding/json"
+ "slices"
"testing"
"github.com/apache/iceberg-go"
@@ -131,7 +132,9 @@ func TestMetadataV1Parsing(t *testing.T) {
iceberg.NestedField{ID: 3, Name: "z", Type:
iceberg.PrimitiveTypes.Int64, Required: true},
)
- assert.Equal(t, []*iceberg.Schema{expected}, meta.Schemas())
+ assert.True(t, slices.EqualFunc([]*iceberg.Schema{expected},
meta.Schemas(), func(s1, s2 *iceberg.Schema) bool {
+ return s1.Equals(s2)
+ }))
assert.Zero(t, data.SchemaList[0].ID)
assert.True(t, meta.CurrentSchema().Equals(expected))
assert.Equal(t, []iceberg.PartitionSpec{
diff --git a/table/refs.go b/table/refs.go
index cf63efc..f0eb697 100644
--- a/table/refs.go
+++ b/table/refs.go
@@ -20,6 +20,7 @@ package table
import (
"encoding/json"
"errors"
+ "reflect"
)
const MainBranch = "main"
@@ -45,6 +46,10 @@ type SnapshotRef struct {
MaxRefAgeMs *int64 `json:"max-ref-age-ms,omitempty"`
}
+func (s *SnapshotRef) Equals(rhs SnapshotRef) bool {
+ return reflect.DeepEqual(s, &rhs)
+}
+
func (s *SnapshotRef) UnmarshalJSON(b []byte) error {
type Alias SnapshotRef
aux := (*Alias)(s)
diff --git a/table/sorting.go b/table/sorting.go
index 89bc76c..425a92e 100644
--- a/table/sorting.go
+++ b/table/sorting.go
@@ -21,6 +21,7 @@ import (
"encoding/json"
"errors"
"fmt"
+ "slices"
"strings"
"github.com/apache/iceberg-go"
@@ -134,6 +135,11 @@ type SortOrder struct {
Fields []SortField `json:"fields"`
}
+func (s SortOrder) Equals(rhs SortOrder) bool {
+ return s.OrderID == rhs.OrderID &&
+ slices.Equal(s.Fields, rhs.Fields)
+}
+
func (s SortOrder) String() string {
var b strings.Builder
fmt.Fprintf(&b, "%d: ", s.OrderID)
diff --git a/table/table.go b/table/table.go
index 80b68b3..ff370ec 100644
--- a/table/table.go
+++ b/table/table.go
@@ -18,8 +18,6 @@
package table
import (
- "reflect"
-
"github.com/apache/iceberg-go"
"github.com/apache/iceberg-go/io"
"golang.org/x/exp/slices"
@@ -37,7 +35,7 @@ type Table struct {
func (t Table) Equals(other Table) bool {
return slices.Equal(t.identifier, other.identifier) &&
t.metadataLocation == other.metadataLocation &&
- reflect.DeepEqual(t.metadata, other.metadata)
+ t.metadata.Equals(other.metadata)
}
func (t Table) Identifier() Identifier { return t.identifier }
diff --git a/utils.go b/utils.go
index c70c2bb..c0a00fe 100644
--- a/utils.go
+++ b/utils.go
@@ -85,7 +85,7 @@ func (a *accessor) String() string {
func (a *accessor) Get(s structLike) any {
val, inner := s.Get(a.pos), a
- for inner.inner != nil {
+ for val != nil && inner.inner != nil {
inner = inner.inner
val = val.(structLike).Get(inner.pos)
}
@@ -98,6 +98,7 @@ type Set[E any] interface {
Members() []E
Equals(Set[E]) bool
Len() int
+ All(func(E) bool) bool
}
var lzseed = maphash.MakeSeed()
@@ -179,3 +180,19 @@ func (l literalSet) Equals(other Set[Literal]) bool {
}
func (l literalSet) Len() int { return len(l) }
+
+func (l literalSet) All(fn func(Literal) bool) bool {
+ for k, v := range l {
+ var e Literal
+ if k, ok := k.(Literal); ok {
+ e = k
+ } else {
+ e = v.orig
+ }
+
+ if !fn(e) {
+ return false
+ }
+ }
+ return true
+}
diff --git a/visitors.go b/visitors.go
new file mode 100644
index 0000000..3428b2c
--- /dev/null
+++ b/visitors.go
@@ -0,0 +1,864 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package iceberg
+
+import (
+ "fmt"
+ "math"
+ "strings"
+
+ "github.com/google/uuid"
+)
+
+// BooleanExprVisitor is an interface for recursively visiting the nodes of a
+// boolean expression
+type BooleanExprVisitor[T any] interface {
+ VisitTrue() T
+ VisitFalse() T
+ VisitNot(childREsult T) T
+ VisitAnd(left, right T) T
+ VisitOr(left, right T) T
+ VisitUnbound(UnboundPredicate) T
+ VisitBound(BoundPredicate) T
+}
+
+// BoundBooleanExprVisitor builds on BooleanExprVisitor by adding interface
+// methods for visiting bound expressions. Because we do casting of literals
+// during binding, you can assume that the BoundTerm and the Literal passed
+// to a method have the same type.
+type BoundBooleanExprVisitor[T any] interface {
+ BooleanExprVisitor[T]
+
+ VisitIn(BoundTerm, Set[Literal]) T
+ VisitNotIn(BoundTerm, Set[Literal]) T
+ VisitIsNan(BoundTerm) T
+ VisitNotNan(BoundTerm) T
+ VisitIsNull(BoundTerm) T
+ VisitNotNull(BoundTerm) T
+ VisitEqual(BoundTerm, Literal) T
+ VisitNotEqual(BoundTerm, Literal) T
+ VisitGreaterEqual(BoundTerm, Literal) T
+ VisitGreater(BoundTerm, Literal) T
+ VisitLessEqual(BoundTerm, Literal) T
+ VisitLess(BoundTerm, Literal) T
+ VisitStartsWith(BoundTerm, Literal) T
+ VisitNotStartsWith(BoundTerm, Literal) T
+}
+
+// VisitExpr is a convenience function to use a given visitor to visit all
parts of
+// a boolean expression in-order. Values returned from the methods are passed
to the
+// subsequent methods, effectively "bubbling up" the results.
+func VisitExpr[T any](expr BooleanExpression, visitor BooleanExprVisitor[T])
(res T, err error) {
+ defer func() {
+ if r := recover(); r != nil {
+ switch e := r.(type) {
+ case string:
+ err = fmt.Errorf("error encountered during
visitExpr: %s", e)
+ case error:
+ err = e
+ }
+ }
+ }()
+
+ return visitBoolExpr(expr, visitor), err
+}
+
+func visitBoolExpr[T any](e BooleanExpression, visitor BooleanExprVisitor[T])
T {
+ switch e := e.(type) {
+ case AlwaysFalse:
+ return visitor.VisitFalse()
+ case AlwaysTrue:
+ return visitor.VisitTrue()
+ case AndExpr:
+ left, right := visitBoolExpr(e.left, visitor),
visitBoolExpr(e.right, visitor)
+ return visitor.VisitAnd(left, right)
+ case OrExpr:
+ left, right := visitBoolExpr(e.left, visitor),
visitBoolExpr(e.right, visitor)
+ return visitor.VisitOr(left, right)
+ case NotExpr:
+ child := visitBoolExpr(e.child, visitor)
+ return visitor.VisitNot(child)
+ case UnboundPredicate:
+ return visitor.VisitUnbound(e)
+ case BoundPredicate:
+ return visitor.VisitBound(e)
+ }
+ panic(fmt.Errorf("%w: VisitBooleanExpression type %s",
ErrNotImplemented, e))
+}
+
+// VisitBoundPredicate uses a BoundBooleanExprVisitor to call the appropriate
method
+// based on the type of operation in the predicate. This is a convenience
function
+// for implementing the VisitBound method of a BoundBooleanExprVisitor by
simply calling
+// iceberg.VisitBoundPredicate(pred, this).
+func VisitBoundPredicate[T any](e BoundPredicate, visitor
BoundBooleanExprVisitor[T]) T {
+ switch e.Op() {
+ case OpIn:
+ return visitor.VisitIn(e.Term(),
e.(BoundSetPredicate).Literals())
+ case OpNotIn:
+ return visitor.VisitNotIn(e.Term(),
e.(BoundSetPredicate).Literals())
+ case OpIsNan:
+ return visitor.VisitIsNan(e.Term())
+ case OpNotNan:
+ return visitor.VisitNotNan(e.Term())
+ case OpIsNull:
+ return visitor.VisitIsNull(e.Term())
+ case OpNotNull:
+ return visitor.VisitNotNull(e.Term())
+ case OpEQ:
+ return visitor.VisitEqual(e.Term(),
e.(BoundLiteralPredicate).Literal())
+ case OpNEQ:
+ return visitor.VisitNotEqual(e.Term(),
e.(BoundLiteralPredicate).Literal())
+ case OpGTEQ:
+ return visitor.VisitGreaterEqual(e.Term(),
e.(BoundLiteralPredicate).Literal())
+ case OpGT:
+ return visitor.VisitGreater(e.Term(),
e.(BoundLiteralPredicate).Literal())
+ case OpLTEQ:
+ return visitor.VisitLessEqual(e.Term(),
e.(BoundLiteralPredicate).Literal())
+ case OpLT:
+ return visitor.VisitLess(e.Term(),
e.(BoundLiteralPredicate).Literal())
+ case OpStartsWith:
+ return visitor.VisitStartsWith(e.Term(),
e.(BoundLiteralPredicate).Literal())
+ case OpNotStartsWith:
+ return visitor.VisitNotStartsWith(e.Term(),
e.(BoundLiteralPredicate).Literal())
+ }
+ panic(fmt.Errorf("%w: unhandled bound predicate type: %s",
ErrNotImplemented, e))
+}
+
+// BindExpr recursively binds each portion of an expression using the provided
schema.
+// Because the expression can end up being simplified to just
AlwaysTrue/AlwaysFalse,
+// this returns a BooleanExpression.
+func BindExpr(s *Schema, expr BooleanExpression, caseSensitive bool)
(BooleanExpression, error) {
+ return VisitExpr(expr, &bindVisitor{schema: s, caseSensitive:
caseSensitive})
+}
+
+type bindVisitor struct {
+ schema *Schema
+ caseSensitive bool
+}
+
+func (*bindVisitor) VisitTrue() BooleanExpression { return AlwaysTrue{} }
+func (*bindVisitor) VisitFalse() BooleanExpression { return AlwaysFalse{} }
+func (*bindVisitor) VisitNot(child BooleanExpression) BooleanExpression {
+ return NewNot(child)
+}
+func (*bindVisitor) VisitAnd(left, right BooleanExpression) BooleanExpression {
+ return NewAnd(left, right)
+}
+func (*bindVisitor) VisitOr(left, right BooleanExpression) BooleanExpression {
+ return NewOr(left, right)
+}
+func (b *bindVisitor) VisitUnbound(pred UnboundPredicate) BooleanExpression {
+ expr, err := pred.Bind(b.schema, b.caseSensitive)
+ if err != nil {
+ panic(err)
+ }
+ return expr
+}
+func (*bindVisitor) VisitBound(pred BoundPredicate) BooleanExpression {
+ panic(fmt.Errorf("%w: found already bound predicate: %s",
ErrInvalidArgument, pred))
+}
+
+// ExpressionEvaluator returns a function which can be used to evaluate a
given expression
+// as long as a structlike value is passed which operates like and matches the
passed in
+// schema.
+func ExpressionEvaluator(s *Schema, unbound BooleanExpression, caseSensitive
bool) (func(structLike) (bool, error), error) {
+ bound, err := BindExpr(s, unbound, caseSensitive)
+ if err != nil {
+ return nil, err
+ }
+
+ return (&exprEvaluator{bound: bound}).Eval, nil
+}
+
+type exprEvaluator struct {
+ bound BooleanExpression
+ st structLike
+}
+
+func (e *exprEvaluator) Eval(st structLike) (bool, error) {
+ e.st = st
+ return VisitExpr(e.bound, e)
+}
+
+func (e *exprEvaluator) VisitUnbound(UnboundPredicate) bool {
+ panic("found unbound predicate when evaluating expression")
+}
+
+func (e *exprEvaluator) VisitBound(pred BoundPredicate) bool {
+ return VisitBoundPredicate(pred, e)
+}
+
+func (*exprEvaluator) VisitTrue() bool { return true }
+func (*exprEvaluator) VisitFalse() bool { return false }
+func (*exprEvaluator) VisitNot(child bool) bool { return !child }
+func (*exprEvaluator) VisitAnd(left, right bool) bool { return left && right }
+func (*exprEvaluator) VisitOr(left, right bool) bool { return left || right }
+
+func (e *exprEvaluator) VisitIn(term BoundTerm, literals Set[Literal]) bool {
+ v := term.evalToLiteral(e.st)
+ if !v.Valid {
+ return false
+ }
+
+ return literals.Contains(v.Val)
+}
+
+func (e *exprEvaluator) VisitNotIn(term BoundTerm, literals Set[Literal]) bool
{
+ return !e.VisitIn(term, literals)
+}
+
+func (e *exprEvaluator) VisitIsNan(term BoundTerm) bool {
+ switch term.Type().(type) {
+ case Float32Type:
+ v := term.(bound[float32]).eval(e.st)
+ if !v.Valid {
+ break
+ }
+ return math.IsNaN(float64(v.Val))
+ case Float64Type:
+ v := term.(bound[float64]).eval(e.st)
+ if !v.Valid {
+ break
+ }
+ return math.IsNaN(v.Val)
+ }
+
+ return false
+}
+
+func (e *exprEvaluator) VisitNotNan(term BoundTerm) bool {
+ return !e.VisitIsNan(term)
+}
+
+func (e *exprEvaluator) VisitIsNull(term BoundTerm) bool {
+ return term.evalIsNull(e.st)
+}
+
+func (e *exprEvaluator) VisitNotNull(term BoundTerm) bool {
+ return !term.evalIsNull(e.st)
+}
+
+func nullsFirstCmp[T LiteralType](cmp Comparator[T], v1, v2 Optional[T]) int {
+ if !v1.Valid {
+ if !v2.Valid {
+ // both are null
+ return 0
+ }
+ // v1 is null, v2 is not
+ return -1
+ }
+
+ if !v2.Valid {
+ return 1
+ }
+
+ return cmp(v1.Val, v2.Val)
+}
+
+func typedCmp[T LiteralType](st structLike, term BoundTerm, lit Literal) int {
+ v := term.(bound[T]).eval(st)
+ var l Optional[T]
+
+ rhs := lit.(TypedLiteral[T])
+ if lit != nil {
+ l.Valid = true
+ l.Val = rhs.Value()
+ }
+
+ return nullsFirstCmp(rhs.Comparator(), v, l)
+}
+
+func doCmp(st structLike, term BoundTerm, lit Literal) int {
+ // we already properly cast and converted everything during binding
+ // so we can type assert based on the term type
+ switch term.Type().(type) {
+ case BooleanType:
+ return typedCmp[bool](st, term, lit)
+ case Int32Type:
+ return typedCmp[int32](st, term, lit)
+ case Int64Type:
+ return typedCmp[int64](st, term, lit)
+ case Float32Type:
+ return typedCmp[float32](st, term, lit)
+ case Float64Type:
+ return typedCmp[float64](st, term, lit)
+ case DateType:
+ return typedCmp[Date](st, term, lit)
+ case TimeType:
+ return typedCmp[Time](st, term, lit)
+ case TimestampType, TimestampTzType:
+ return typedCmp[Timestamp](st, term, lit)
+ case BinaryType, FixedType:
+ return typedCmp[[]byte](st, term, lit)
+ case StringType:
+ return typedCmp[string](st, term, lit)
+ case UUIDType:
+ return typedCmp[uuid.UUID](st, term, lit)
+ case DecimalType:
+ return typedCmp[Decimal](st, term, lit)
+ }
+ panic(ErrType)
+}
+
+func (e *exprEvaluator) VisitEqual(term BoundTerm, lit Literal) bool {
+ return doCmp(e.st, term, lit) == 0
+}
+
+func (e *exprEvaluator) VisitNotEqual(term BoundTerm, lit Literal) bool {
+ return doCmp(e.st, term, lit) != 0
+}
+
+func (e *exprEvaluator) VisitGreater(term BoundTerm, lit Literal) bool {
+ return doCmp(e.st, term, lit) > 0
+}
+
+func (e *exprEvaluator) VisitGreaterEqual(term BoundTerm, lit Literal) bool {
+ return doCmp(e.st, term, lit) >= 0
+}
+
+func (e *exprEvaluator) VisitLess(term BoundTerm, lit Literal) bool {
+ return doCmp(e.st, term, lit) < 0
+}
+
+func (e *exprEvaluator) VisitLessEqual(term BoundTerm, lit Literal) bool {
+ return doCmp(e.st, term, lit) <= 0
+}
+
+func (e *exprEvaluator) VisitStartsWith(term BoundTerm, lit Literal) bool {
+ var value, prefix string
+
+ switch lit.(type) {
+ case TypedLiteral[string]:
+ val := term.(bound[string]).eval(e.st)
+ if !val.Valid {
+ return false
+ }
+ prefix, value = lit.(StringLiteral).Value(), val.Val
+ case TypedLiteral[[]byte]:
+ val := term.(bound[[]byte]).eval(e.st)
+ if !val.Valid {
+ return false
+ }
+ prefix, value = string(lit.(TypedLiteral[[]byte]).Value()),
string(val.Val)
+ }
+
+ return strings.HasPrefix(value, prefix)
+}
+
+func (e *exprEvaluator) VisitNotStartsWith(term BoundTerm, lit Literal) bool {
+ return !e.VisitStartsWith(term, lit)
+}
+
+// RewriteNotExpr rewrites a boolean expression to remove "Not" nodes from the
expression
+// tree. This is because Projections assume there are no "not" nodes.
+//
+// Not nodes will be replaced with simply calling `Negate` on the child in the
tree.
+func RewriteNotExpr(expr BooleanExpression) (BooleanExpression, error) {
+ return VisitExpr(expr, rewriteNotVisitor{})
+}
+
+type rewriteNotVisitor struct{}
+
+func (rewriteNotVisitor) VisitTrue() BooleanExpression { return AlwaysTrue{} }
+func (rewriteNotVisitor) VisitFalse() BooleanExpression { return AlwaysFalse{}
}
+func (rewriteNotVisitor) VisitNot(child BooleanExpression) BooleanExpression {
+ return child.Negate()
+}
+
+func (rewriteNotVisitor) VisitAnd(left, right BooleanExpression)
BooleanExpression {
+ return NewAnd(left, right)
+}
+
+func (rewriteNotVisitor) VisitOr(left, right BooleanExpression)
BooleanExpression {
+ return NewOr(left, right)
+}
+
+func (rewriteNotVisitor) VisitUnbound(pred UnboundPredicate) BooleanExpression
{
+ return pred
+}
+
+func (rewriteNotVisitor) VisitBound(pred BoundPredicate) BooleanExpression {
+ return pred
+}
+
+const (
+ rowsMightMatch, rowsMustMatch = true, true
+ rowsCannotMatch, rowsMightNotMatch = false, false
+ inPredicateLimit = 200
+)
+
+// NewManifestEvaluator returns a function that can be used to evaluate
whether a particular
+// manifest file has rows that might or might not match a given partition
filter by using
+// the stats provided in the partitions
(UpperBound/LowerBound/ContainsNull/ContainsNaN).
+func NewManifestEvaluator(spec PartitionSpec, schema *Schema, partitionFilter
BooleanExpression, caseSensitive bool) (func(ManifestFile) (bool, error),
error) {
+ partType := spec.PartitionType(schema)
+ partSchema := NewSchema(0, partType.FieldList...)
+ filter, err := RewriteNotExpr(partitionFilter)
+ if err != nil {
+ return nil, err
+ }
+
+ boundFilter, err := BindExpr(partSchema, filter, caseSensitive)
+ if err != nil {
+ return nil, err
+ }
+
+ return (&manifestEvalVisitor{partitionFilter: boundFilter}).Eval, nil
+}
+
+type manifestEvalVisitor struct {
+ partitionFields []FieldSummary
+ partitionFilter BooleanExpression
+}
+
+func (m *manifestEvalVisitor) Eval(manifest ManifestFile) (bool, error) {
+ if parts := manifest.Partitions(); len(parts) > 0 {
+ m.partitionFields = parts
+ return VisitExpr(m.partitionFilter, m)
+ }
+
+ return rowsMightMatch, nil
+}
+
+func allBoundCmp[T LiteralType](bound Literal, set Set[Literal], want int)
bool {
+ val := bound.(TypedLiteral[T])
+ cmp := val.Comparator()
+
+ return set.All(func(e Literal) bool {
+ return cmp(val.Value(), e.(TypedLiteral[T]).Value()) == want
+ })
+}
+
+func allBoundCheck(bound Literal, set Set[Literal], want int) bool {
+ switch bound.Type().(type) {
+ case BooleanType:
+ return allBoundCmp[bool](bound, set, want)
+ case Int32Type:
+ return allBoundCmp[int32](bound, set, want)
+ case Int64Type:
+ return allBoundCmp[int64](bound, set, want)
+ case Float32Type:
+ return allBoundCmp[float32](bound, set, want)
+ case Float64Type:
+ return allBoundCmp[float64](bound, set, want)
+ case DateType:
+ return allBoundCmp[Date](bound, set, want)
+ case TimeType:
+ return allBoundCmp[Time](bound, set, want)
+ case TimestampType, TimestampTzType:
+ return allBoundCmp[Timestamp](bound, set, want)
+ case BinaryType, FixedType:
+ return allBoundCmp[[]byte](bound, set, want)
+ case StringType:
+ return allBoundCmp[string](bound, set, want)
+ case UUIDType:
+ return allBoundCmp[uuid.UUID](bound, set, want)
+ case DecimalType:
+ return allBoundCmp[Decimal](bound, set, want)
+ }
+ panic(ErrType)
+}
+
+func (m *manifestEvalVisitor) VisitIn(term BoundTerm, literals Set[Literal])
bool {
+ pos := term.Ref().Pos()
+ field := m.partitionFields[pos]
+
+ if field.LowerBound == nil {
+ return rowsCannotMatch
+ }
+
+ if literals.Len() > inPredicateLimit {
+ return rowsMightMatch
+ }
+
+ lower, err := LiteralFromBytes(term.Type(), *field.LowerBound)
+ if err != nil {
+ panic(err)
+ }
+
+ if allBoundCheck(lower, literals, 1) {
+ return rowsCannotMatch
+ }
+
+ if field.UpperBound != nil {
+ upper, err := LiteralFromBytes(term.Type(), *field.UpperBound)
+ if err != nil {
+ panic(err)
+ }
+
+ if allBoundCheck(upper, literals, -1) {
+ return rowsCannotMatch
+ }
+ }
+
+ return rowsMightMatch
+}
+
+func (m *manifestEvalVisitor) VisitNotIn(term BoundTerm, literals
Set[Literal]) bool {
+ // because the bounds are not necessarily a min or max value, this
cannot be answered using them
+ // notIn(col, {X, ...}) with (X, Y) doesn't guarantee that X is a value
in col
+ return rowsMightMatch
+}
+
+func (m *manifestEvalVisitor) VisitIsNan(term BoundTerm) bool {
+ pos := term.Ref().Pos()
+ field := m.partitionFields[pos]
+
+ if field.ContainsNaN != nil && !*field.ContainsNaN {
+ return rowsCannotMatch
+ }
+
+ return rowsMightMatch
+}
+
+func (m *manifestEvalVisitor) VisitNotNan(term BoundTerm) bool {
+ pos := term.Ref().Pos()
+ field := m.partitionFields[pos]
+
+ if field.ContainsNaN != nil && *field.ContainsNaN &&
!field.ContainsNull && field.LowerBound == nil {
+ return rowsCannotMatch
+ }
+
+ return rowsMightMatch
+}
+
+func (m *manifestEvalVisitor) VisitIsNull(term BoundTerm) bool {
+ pos := term.Ref().Pos()
+ field := m.partitionFields[pos]
+
+ if !field.ContainsNull {
+ return rowsCannotMatch
+ }
+
+ return rowsMightMatch
+}
+
+func (m *manifestEvalVisitor) VisitNotNull(term BoundTerm) bool {
+ pos := term.Ref().Pos()
+ field := m.partitionFields[pos]
+
+ // ContainsNull encodes whether at least one partition value is null
+ // lowerBound is null if all partition values are null
+ allNull := field.ContainsNull && field.LowerBound == nil
+ if allNull && (term.Ref().Type().Equals(PrimitiveTypes.Float32) ||
term.Ref().Type().Equals(PrimitiveTypes.Float64)) {
+ // floating point types may include NaN values, which we check
separately
+ // in case bounds don't include NaN values, ContainsNaN needs
to be checked
+ allNull = field.ContainsNaN != nil && !*field.ContainsNaN
+ }
+
+ if allNull {
+ return rowsCannotMatch
+ }
+
+ return rowsMightMatch
+}
+
+func getCmp[T LiteralType](b TypedLiteral[T]) func(Literal, Literal) int {
+ cmp := b.Comparator()
+ return func(l1, l2 Literal) int {
+ return cmp(l1.(TypedLiteral[T]).Value(),
l2.(TypedLiteral[T]).Value())
+ }
+}
+
+func getCmpLiteral(boundary Literal) func(Literal, Literal) int {
+ switch l := boundary.(type) {
+ case TypedLiteral[bool]:
+ return getCmp(l)
+ case TypedLiteral[int32]:
+ return getCmp(l)
+ case TypedLiteral[int64]:
+ return getCmp(l)
+ case TypedLiteral[float32]:
+ return getCmp(l)
+ case TypedLiteral[float64]:
+ return getCmp(l)
+ case TypedLiteral[Date]:
+ return getCmp(l)
+ case TypedLiteral[Time]:
+ return getCmp(l)
+ case TypedLiteral[Timestamp]:
+ return getCmp(l)
+ case TypedLiteral[[]byte]:
+ return getCmp(l)
+ case TypedLiteral[string]:
+ return getCmp(l)
+ case TypedLiteral[uuid.UUID]:
+ return getCmp(l)
+ case TypedLiteral[Decimal]:
+ return getCmp(l)
+ }
+ panic(ErrType)
+}
+
+func (m *manifestEvalVisitor) VisitEqual(term BoundTerm, lit Literal) bool {
+ pos := term.Ref().Pos()
+ field := m.partitionFields[pos]
+
+ if field.LowerBound == nil || field.UpperBound == nil {
+ // values are all null and literal cannot contain null
+ return rowsCannotMatch
+ }
+
+ lower, err := LiteralFromBytes(term.Ref().Type(), *field.LowerBound)
+ if err != nil {
+ panic(err)
+ }
+
+ cmp := getCmpLiteral(lower)
+ if cmp(lower, lit) == 1 {
+ return rowsCannotMatch
+ }
+
+ upper, err := LiteralFromBytes(term.Ref().Type(), *field.UpperBound)
+ if err != nil {
+ panic(err)
+ }
+
+ if cmp(lit, upper) == 1 {
+ return rowsCannotMatch
+ }
+
+ return rowsMightMatch
+}
+
+func (m *manifestEvalVisitor) VisitNotEqual(term BoundTerm, lit Literal) bool {
+ // because bounds are not necessarily a min or max, this cannot be
answered
+ // using them. notEq(col, X) with (X, Y) doesn't guarantee X is a value
in col
+ return rowsMightMatch
+}
+
+func (m *manifestEvalVisitor) VisitGreaterEqual(term BoundTerm, lit Literal)
bool {
+ pos := term.Ref().Pos()
+ field := m.partitionFields[pos]
+
+ if field.UpperBound == nil {
+ return rowsCannotMatch
+ }
+
+ upper, err := LiteralFromBytes(term.Ref().Type(), *field.UpperBound)
+ if err != nil {
+ panic(err)
+ }
+
+ if getCmpLiteral(upper)(lit, upper) == 1 {
+ return rowsCannotMatch
+ }
+
+ return rowsMightMatch
+}
+
+func (m *manifestEvalVisitor) VisitGreater(term BoundTerm, lit Literal) bool {
+ pos := term.Ref().Pos()
+ field := m.partitionFields[pos]
+
+ if field.UpperBound == nil {
+ return rowsCannotMatch
+ }
+
+ upper, err := LiteralFromBytes(term.Ref().Type(), *field.UpperBound)
+ if err != nil {
+ panic(err)
+ }
+
+ if getCmpLiteral(upper)(lit, upper) >= 0 {
+ return rowsCannotMatch
+ }
+
+ return rowsMightMatch
+}
+
+func (m *manifestEvalVisitor) VisitLessEqual(term BoundTerm, lit Literal) bool
{
+ pos := term.Ref().Pos()
+ field := m.partitionFields[pos]
+
+ if field.LowerBound == nil {
+ return rowsCannotMatch
+ }
+
+ lower, err := LiteralFromBytes(term.Ref().Type(), *field.LowerBound)
+ if err != nil {
+ panic(err)
+ }
+
+ if getCmpLiteral(lower)(lit, lower) == -1 {
+ return rowsCannotMatch
+ }
+
+ return rowsMightMatch
+}
+
+func (m *manifestEvalVisitor) VisitLess(term BoundTerm, lit Literal) bool {
+ pos := term.Ref().Pos()
+ field := m.partitionFields[pos]
+
+ if field.LowerBound == nil {
+ return rowsCannotMatch
+ }
+
+ lower, err := LiteralFromBytes(term.Ref().Type(), *field.LowerBound)
+ if err != nil {
+ panic(err)
+ }
+
+ if getCmpLiteral(lower)(lit, lower) <= 0 {
+ return rowsCannotMatch
+ }
+
+ return rowsMightMatch
+}
+
+// VisitStartsWith reports whether the manifest might contain rows whose
+// value for term starts with the prefix held by lit (a string or binary
+// literal; any other literal type panics on the type assertion).
+//
+// Both bounds of the partition field summary are truncated to the prefix
+// length and compared against the prefix: a truncated lower bound above the
+// prefix, a truncated upper bound below it, or a missing bound all yield
+// rowsCannotMatch. It panics if a bound cannot be decoded as the term's type.
+func (m *manifestEvalVisitor) VisitStartsWith(term BoundTerm, lit Literal) bool {
+	ref := term.Ref()
+	summary := m.partitionFields[ref.Pos()]
+
+	var prefix string
+	if strLit, isStr := lit.(TypedLiteral[string]); isStr {
+		prefix = strLit.Value()
+	} else {
+		prefix = string(lit.(TypedLiteral[[]byte]).Value())
+	}
+	n := len(prefix)
+
+	// clip returns the bound's value cut down so that its length is not
+	// greater than the length of prefix. A bound that is neither a string
+	// nor a binary literal leaves the caller-supplied fallback in place.
+	clip := func(bound Literal, fallback string) string {
+		var s string
+		switch b := bound.(type) {
+		case TypedLiteral[string]:
+			s = b.Value()
+		case TypedLiteral[[]byte]:
+			s = string(b.Value())
+		default:
+			return fallback
+		}
+		if len(s) > n {
+			s = s[:n]
+		}
+		return s
+	}
+
+	if summary.LowerBound == nil {
+		return rowsCannotMatch
+	}
+
+	lower, err := LiteralFromBytes(ref.Type(), *summary.LowerBound)
+	if err != nil {
+		panic(err)
+	}
+
+	clippedLower := clip(lower, "")
+	if clippedLower > prefix {
+		return rowsCannotMatch
+	}
+
+	if summary.UpperBound == nil {
+		return rowsCannotMatch
+	}
+
+	upper, err := LiteralFromBytes(ref.Type(), *summary.UpperBound)
+	if err != nil {
+		panic(err)
+	}
+
+	if clip(upper, clippedLower) < prefix {
+		return rowsCannotMatch
+	}
+
+	return rowsMightMatch
+}
+
+// VisitNotStartsWith reports whether the manifest might contain rows whose
+// value for term does not start with the prefix held by lit.
+//
+// The answer is rowsMightMatch whenever the summary contains nulls or lacks
+// either bound. Otherwise rowsCannotMatch is returned only when the lower
+// and upper bounds BOTH start with the prefix. It panics if a bound cannot
+// be decoded as the term's type or if the literal types are not all string
+// or all binary.
+func (m *manifestEvalVisitor) VisitNotStartsWith(term BoundTerm, lit Literal) bool {
+	ref := term.Ref()
+	summary := m.partitionFields[ref.Pos()]
+
+	if summary.ContainsNull || summary.LowerBound == nil || summary.UpperBound == nil {
+		return rowsMightMatch
+	}
+
+	lower, err := LiteralFromBytes(ref.Type(), *summary.LowerBound)
+	if err != nil {
+		panic(err)
+	}
+
+	upper, err := LiteralFromBytes(ref.Type(), *summary.UpperBound)
+	if err != nil {
+		panic(err)
+	}
+
+	var prefix, lo, hi string
+	if strLit, isStr := lit.(TypedLiteral[string]); isStr {
+		prefix = strLit.Value()
+		lo = lower.(TypedLiteral[string]).Value()
+		hi = upper.(TypedLiteral[string]).Value()
+	} else {
+		prefix = string(lit.(TypedLiteral[[]byte]).Value())
+		lo = string(lower.(TypedLiteral[[]byte]).Value())
+		hi = string(upper.(TypedLiteral[[]byte]).Value())
+	}
+
+	// a bound shorter than the prefix can't start with the prefix
+	n := len(prefix)
+	hasPrefix := func(s string) bool {
+		return len(s) >= n && s[:n] == prefix
+	}
+
+	if hasPrefix(lo) && hasPrefix(hi) {
+		return rowsCannotMatch
+	}
+
+	return rowsMightMatch
+}
+
+// VisitTrue always reports rowsMightMatch.
+func (m *manifestEvalVisitor) VisitTrue() bool {
+	return rowsMightMatch
+}
+
+// VisitFalse always reports rowsCannotMatch.
+func (m *manifestEvalVisitor) VisitFalse() bool {
+	return rowsCannotMatch
+}
+
+// VisitUnbound always panics: this visitor only accepts bound predicates.
+func (m *manifestEvalVisitor) VisitUnbound(UnboundPredicate) bool {
+	panic("need bound predicate")
+}
+
+// VisitBound hands pred to VisitBoundPredicate with this visitor as the
+// handler and returns its result.
+func (m *manifestEvalVisitor) VisitBound(pred BoundPredicate) bool {
+	return VisitBoundPredicate(pred, m)
+}
+
+// VisitNot, VisitAnd and VisitOr combine already-evaluated child results
+// with plain boolean negation, conjunction and disjunction.
+func (m *manifestEvalVisitor) VisitNot(child bool) bool { return !child }
+func (m *manifestEvalVisitor) VisitAnd(left, right bool) bool { return left && right }
+func (m *manifestEvalVisitor) VisitOr(left, right bool) bool { return left || right }
diff --git a/visitors_test.go b/visitors_test.go
new file mode 100644
index 0000000..688c1cc
--- /dev/null
+++ b/visitors_test.go
@@ -0,0 +1,1085 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package iceberg_test
+
+import (
+ "math"
+ "strings"
+ "testing"
+
+ "github.com/apache/arrow/go/v16/arrow/decimal128"
+ "github.com/apache/iceberg-go"
+ "github.com/google/uuid"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// ExampleVisitor is a test visitor that appends a label to visitHistory for
+// every callback invoked on it and returns the accumulated history.
+type ExampleVisitor struct {
+	visitHistory []string
+}
+
+// record appends label to the visit history and returns the updated history.
+func (e *ExampleVisitor) record(label string) []string {
+	e.visitHistory = append(e.visitHistory, label)
+	return e.visitHistory
+}
+
+// VisitTrue records "TRUE".
+func (e *ExampleVisitor) VisitTrue() []string {
+	return e.record("TRUE")
+}
+
+// VisitFalse records "FALSE".
+func (e *ExampleVisitor) VisitFalse() []string {
+	return e.record("FALSE")
+}
+
+// VisitNot records "NOT"; the child result is ignored.
+func (e *ExampleVisitor) VisitNot([]string) []string {
+	return e.record("NOT")
+}
+
+// VisitAnd records "AND"; both child results are ignored.
+func (e *ExampleVisitor) VisitAnd(_, _ []string) []string {
+	return e.record("AND")
+}
+
+// VisitOr records "OR"; both child results are ignored.
+func (e *ExampleVisitor) VisitOr(_, _ []string) []string {
+	return e.record("OR")
+}
+
+// VisitUnbound records the upper-cased operation name of pred.
+func (e *ExampleVisitor) VisitUnbound(pred iceberg.UnboundPredicate) []string {
+	return e.record(strings.ToUpper(pred.Op().String()))
+}
+
+// VisitBound records the upper-cased operation name of pred.
+func (e *ExampleVisitor) VisitBound(pred iceberg.BoundPredicate) []string {
+	return e.record(strings.ToUpper(pred.Op().String()))
+}
+
+// FooBoundExprVisitor extends ExampleVisitor with one callback per bound
+// predicate operation; each callback appends a fixed label to the embedded
+// visit history and returns the accumulated history.
+type FooBoundExprVisitor struct {
+	ExampleVisitor
+}
+
+// mark appends label to the embedded visit history and returns the updated
+// history.
+func (e *FooBoundExprVisitor) mark(label string) []string {
+	e.visitHistory = append(e.visitHistory, label)
+	return e.visitHistory
+}
+
+// VisitBound routes pred to the matching per-operation callback below via
+// iceberg.VisitBoundPredicate.
+func (e *FooBoundExprVisitor) VisitBound(pred iceberg.BoundPredicate) []string {
+	return iceberg.VisitBoundPredicate(pred, e)
+}
+
+// VisitUnbound always panics: this visitor expects bound predicates only.
+func (e *FooBoundExprVisitor) VisitUnbound(pred iceberg.UnboundPredicate) []string {
+	panic("found unbound predicate when evaluating")
+}
+
+func (e *FooBoundExprVisitor) VisitIn(iceberg.BoundTerm, iceberg.Set[iceberg.Literal]) []string {
+	return e.mark("IN")
+}
+
+func (e *FooBoundExprVisitor) VisitNotIn(iceberg.BoundTerm, iceberg.Set[iceberg.Literal]) []string {
+	return e.mark("NOT_IN")
+}
+
+func (e *FooBoundExprVisitor) VisitIsNan(iceberg.BoundTerm) []string {
+	return e.mark("IS_NAN")
+}
+
+func (e *FooBoundExprVisitor) VisitNotNan(iceberg.BoundTerm) []string {
+	return e.mark("NOT_NAN")
+}
+
+func (e *FooBoundExprVisitor) VisitIsNull(iceberg.BoundTerm) []string {
+	return e.mark("IS_NULL")
+}
+
+func (e *FooBoundExprVisitor) VisitNotNull(iceberg.BoundTerm) []string {
+	return e.mark("NOT_NULL")
+}
+
+func (e *FooBoundExprVisitor) VisitEqual(iceberg.BoundTerm, iceberg.Literal) []string {
+	return e.mark("EQUAL")
+}
+
+func (e *FooBoundExprVisitor) VisitNotEqual(iceberg.BoundTerm, iceberg.Literal) []string {
+	return e.mark("NOT_EQUAL")
+}
+
+func (e *FooBoundExprVisitor) VisitGreaterEqual(iceberg.BoundTerm, iceberg.Literal) []string {
+	return e.mark("GREATER_THAN_OR_EQUAL")
+}
+
+func (e *FooBoundExprVisitor) VisitGreater(iceberg.BoundTerm, iceberg.Literal) []string {
+	return e.mark("GREATER_THAN")
+}
+
+func (e *FooBoundExprVisitor) VisitLessEqual(iceberg.BoundTerm, iceberg.Literal) []string {
+	return e.mark("LESS_THAN_OR_EQUAL")
+}
+
+func (e *FooBoundExprVisitor) VisitLess(iceberg.BoundTerm, iceberg.Literal) []string {
+	return e.mark("LESS_THAN")
+}
+
+func (e *FooBoundExprVisitor) VisitStartsWith(iceberg.BoundTerm, iceberg.Literal) []string {
+	return e.mark("STARTS_WITH")
+}
+
+func (e *FooBoundExprVisitor) VisitNotStartsWith(iceberg.BoundTerm, iceberg.Literal) []string {
+	return e.mark("NOT_STARTS_WITH")
+}
+
+// TestBooleanExprVisitor walks a nested and/or/not expression with an
+// ExampleVisitor and checks the order in which the callbacks were recorded.
+func TestBooleanExprVisitor(t *testing.T) {
+	refA, refB := iceberg.Reference("a"), iceberg.Reference("b")
+
+	expr := iceberg.NewAnd(
+		iceberg.NewOr(
+			iceberg.NewNot(iceberg.EqualTo(refA, int32(1))),
+			iceberg.NewNot(iceberg.NotEqualTo(refB, int32(0))),
+			iceberg.EqualTo(refA, int32(1)),
+			iceberg.NotEqualTo(refB, int32(0)),
+		),
+		iceberg.NewNot(iceberg.EqualTo(refA, int32(1))),
+		iceberg.NotEqualTo(refB, int32(0)))
+
+	want := []string{
+		"EQUAL",
+		"NOT",
+		"NOTEQUAL",
+		"NOT",
+		"OR",
+		"EQUAL",
+		"OR",
+		"NOTEQUAL",
+		"OR",
+		"EQUAL",
+		"NOT",
+		"AND",
+		"NOTEQUAL",
+		"AND",
+	}
+
+	visitor := &ExampleVisitor{visitHistory: make([]string, 0)}
+	got, err := iceberg.VisitExpr(expr, visitor)
+	require.NoError(t, err)
+	assert.Equal(t, want, got)
+}
+
+// TestBindVisitorAlready checks that binding an expression that is already
+// bound fails with ErrInvalidArgument and names the offending predicate.
+func TestBindVisitorAlready(t *testing.T) {
+	const wantMsg = "found already bound predicate: BoundEqual(term=BoundReference(field=1: foo: optional string, accessor=Accessor(position=0, inner=<nil>)), literal=hello)"
+
+	unbound := iceberg.EqualTo(iceberg.Reference("foo"), "hello")
+	bound, err := unbound.Bind(tableSchemaSimple, false)
+	require.NoError(t, err)
+
+	_, err = iceberg.BindExpr(tableSchemaSimple, bound, true)
+	assert.ErrorIs(t, err, iceberg.ErrInvalidArgument)
+	assert.ErrorContains(t, err, wantMsg)
+}
+
+// TestAlwaysExprBinding checks the result of binding AlwaysTrue/AlwaysFalse
+// and and/or combinations of them against the expected constant expression.
+func TestAlwaysExprBinding(t *testing.T) {
+	cases := []struct {
+		expr     iceberg.BooleanExpression
+		expected iceberg.BooleanExpression
+	}{
+		{expr: iceberg.AlwaysTrue{}, expected: iceberg.AlwaysTrue{}},
+		{expr: iceberg.AlwaysFalse{}, expected: iceberg.AlwaysFalse{}},
+		{
+			expr:     iceberg.NewAnd(iceberg.AlwaysTrue{}, iceberg.AlwaysFalse{}),
+			expected: iceberg.AlwaysFalse{},
+		},
+		{
+			expr:     iceberg.NewOr(iceberg.AlwaysTrue{}, iceberg.AlwaysFalse{}),
+			expected: iceberg.AlwaysTrue{},
+		},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.expr.String(), func(t *testing.T) {
+			got, err := iceberg.BindExpr(tableSchemaSimple, tc.expr, true)
+			require.NoError(t, err)
+			assert.Equal(t, tc.expected, got)
+		})
+	}
+}
+
+// TestBoundBoolExprVisitor binds each expression against tableSchemaNested,
+// visits it with a FooBoundExprVisitor and checks the recorded callback
+// labels.
+func TestBoundBoolExprVisitor(t *testing.T) {
+	foo, bar := iceberg.Reference("foo"), iceberg.Reference("bar")
+
+	cases := []struct {
+		expr     iceberg.BooleanExpression
+		expected []string
+	}{
+		{
+			expr: iceberg.NewAnd(
+				iceberg.IsIn(foo, "foo", "bar"),
+				iceberg.IsIn(bar, int32(1), int32(2))),
+			expected: []string{"IN", "IN", "AND"},
+		},
+		{
+			expr: iceberg.NewOr(
+				iceberg.NewNot(iceberg.IsIn(foo, "foo", "bar")),
+				iceberg.NewNot(iceberg.IsIn(bar, int32(1), int32(2)))),
+			expected: []string{"IN", "NOT", "IN", "NOT", "OR"},
+		},
+		{expr: iceberg.EqualTo(bar, int32(1)), expected: []string{"EQUAL"}},
+		{expr: iceberg.NotEqualTo(foo, "foo"), expected: []string{"NOT_EQUAL"}},
+		{expr: iceberg.AlwaysTrue{}, expected: []string{"TRUE"}},
+		{expr: iceberg.AlwaysFalse{}, expected: []string{"FALSE"}},
+		{expr: iceberg.NotIn(foo, "bar", "foo"), expected: []string{"NOT_IN"}},
+		{expr: iceberg.IsNull(foo), expected: []string{"IS_NULL"}},
+		{expr: iceberg.NotNull(foo), expected: []string{"NOT_NULL"}},
+		{expr: iceberg.GreaterThan(foo, "foo"), expected: []string{"GREATER_THAN"}},
+		{expr: iceberg.GreaterThanEqual(foo, "foo"), expected: []string{"GREATER_THAN_OR_EQUAL"}},
+		{expr: iceberg.LessThan(foo, "foo"), expected: []string{"LESS_THAN"}},
+		{expr: iceberg.LessThanEqual(foo, "foo"), expected: []string{"LESS_THAN_OR_EQUAL"}},
+		{expr: iceberg.StartsWith(foo, "foo"), expected: []string{"STARTS_WITH"}},
+		{expr: iceberg.NotStartsWith(foo, "foo"), expected: []string{"NOT_STARTS_WITH"}},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.expr.String(), func(t *testing.T) {
+			bound, err := iceberg.BindExpr(tableSchemaNested, tc.expr, true)
+			require.NoError(t, err)
+
+			visitor := &FooBoundExprVisitor{
+				ExampleVisitor: ExampleVisitor{visitHistory: []string{}},
+			}
+			got, err := iceberg.VisitExpr(bound, visitor)
+			require.NoError(t, err)
+			assert.Equal(t, tc.expected, got)
+		})
+	}
+}
+
+// rowTester is a slice-backed row used by the evaluator tests; values are
+// addressed by position.
+type rowTester []any
+
+// Size returns the number of values in the row.
+func (r rowTester) Size() int { return len(r) }
+
+// Get returns the value at pos; it panics if pos is out of range.
+func (r rowTester) Get(pos int) any { return r[pos] }
+
+// Set overwrites the value at pos in the shared backing slice.
+func (r rowTester) Set(pos int, val any) {
+	r[pos] = val
+}
+
+// rowOf builds a rowTester from its arguments, which may themselves be
+// rowTester values to form nested rows.
+func rowOf(vals ...any) rowTester {
+	return rowTester(vals)
+}
+
+// testSchema is the schema used by TestExprEvaluator. Top-level fields, in
+// order: x (int32), y (float64), z (int32), s1 (struct nesting
+// s2.s3.s4.i, an int32), s5 (struct nesting s6.f, a float32) and s (string).
+var testSchema = iceberg.NewSchema(1,
+	iceberg.NestedField{ID: 13, Name: "x",
+		Type: iceberg.PrimitiveTypes.Int32, Required: true},
+	iceberg.NestedField{ID: 14, Name: "y",
+		Type: iceberg.PrimitiveTypes.Float64, Required: true},
+	iceberg.NestedField{ID: 15, Name: "z",
+		Type: iceberg.PrimitiveTypes.Int32},
+	iceberg.NestedField{ID: 16, Name: "s1",
+		Type: &iceberg.StructType{
+			FieldList: []iceberg.NestedField{{
+				ID: 17, Name: "s2", Required: true,
+				Type: &iceberg.StructType{
+					FieldList: []iceberg.NestedField{{
+						ID: 18, Name: "s3", Required: true,
+						Type: &iceberg.StructType{
+							FieldList: []iceberg.NestedField{{
+								ID: 19, Name: "s4", Required: true,
+								Type: &iceberg.StructType{
+									FieldList: []iceberg.NestedField{{
+										ID: 20, Name: "i", Required: true,
+										Type: iceberg.PrimitiveTypes.Int32,
+									}},
+								},
+							}},
+						},
+					}},
+				},
+			}},
+		}},
+	iceberg.NestedField{ID: 21, Name: "s5", Type: &iceberg.StructType{
+		FieldList: []iceberg.NestedField{{
+			ID: 22, Name: "s6", Required: true, Type: &iceberg.StructType{
+				FieldList: []iceberg.NestedField{{
+					ID: 23, Name: "f", Required: true, Type: iceberg.PrimitiveTypes.Float32,
+				}},
+			},
+		}},
+	}},
+	iceberg.NestedField{ID: 24, Name: "s", Type: iceberg.PrimitiveTypes.String})
+
+// TestExprEvaluator builds an evaluator for each expression against
+// testSchema and checks its result for a set of rows. Each case's str is
+// used as the assertion message, so it must agree with the expected result.
+func TestExprEvaluator(t *testing.T) {
+	type testCase struct {
+		str    string
+		row    rowTester
+		result bool
+	}
+
+	tests := []struct {
+		exp   iceberg.BooleanExpression
+		cases []testCase
+	}{
+		{iceberg.AlwaysTrue{}, []testCase{{"always true", rowOf(), true}}},
+		{iceberg.AlwaysFalse{}, []testCase{{"always false", rowOf(), false}}},
+		{iceberg.LessThan(iceberg.Reference("x"), int32(7)), []testCase{
+			{"7 < 7 => false", rowOf(7, 8, nil, nil), false},
+			{"6 < 7 => true", rowOf(6, 8, nil, nil), true},
+		}},
+		{iceberg.LessThan(iceberg.Reference("s1.s2.s3.s4.i"), int32(7)), []testCase{
+			{"7 < 7 => false", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(7))))), false},
+			{"6 < 7 => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(6))))), true},
+			{"nil < 7 => true", rowOf(7, 8, nil, nil), true},
+		}},
+		{iceberg.LessThanEqual(iceberg.Reference("x"), int32(7)), []testCase{
+			{"7 <= 7 => true", rowOf(7, 8, nil), true},
+			{"6 <= 7 => true", rowOf(6, 8, nil), true},
+			{"8 <= 7 => false", rowOf(8, 8, nil), false},
+		}},
+		{iceberg.LessThanEqual(iceberg.Reference("s1.s2.s3.s4.i"), int32(7)), []testCase{
+			{"7 <= 7 => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(7))))), true},
+			{"6 <= 7 => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(6))))), true},
+			{"8 <= 7 => false", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(8))))), false},
+		}},
+		{iceberg.GreaterThan(iceberg.Reference("x"), int32(7)), []testCase{
+			{"7 > 7 => false", rowOf(7, 8, nil), false},
+			{"6 > 7 => false", rowOf(6, 8, nil), false},
+			{"8 > 7 => true", rowOf(8, 8, nil), true},
+		}},
+		{iceberg.GreaterThan(iceberg.Reference("s1.s2.s3.s4.i"), int32(7)), []testCase{
+			{"7 > 7 => false", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(7))))), false},
+			{"6 > 7 => false", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(6))))), false},
+			{"8 > 7 => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(8))))), true},
+		}},
+		{iceberg.GreaterThanEqual(iceberg.Reference("x"), int32(7)), []testCase{
+			{"7 >= 7 => true", rowOf(7, 8, nil), true},
+			{"6 >= 7 => false", rowOf(6, 8, nil), false},
+			{"8 >= 7 => true", rowOf(8, 8, nil), true},
+		}},
+		{iceberg.GreaterThanEqual(iceberg.Reference("s1.s2.s3.s4.i"), int32(7)), []testCase{
+			{"7 >= 7 => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(7))))), true},
+			{"6 >= 7 => false", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(6))))), false},
+			{"8 >= 7 => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(8))))), true},
+		}},
+		{iceberg.EqualTo(iceberg.Reference("x"), int32(7)), []testCase{
+			{"7 == 7 => true", rowOf(7, 8, nil), true},
+			{"6 == 7 => false", rowOf(6, 8, nil), false},
+		}},
+		{iceberg.EqualTo(iceberg.Reference("s1.s2.s3.s4.i"), int32(7)), []testCase{
+			{"7 == 7 => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(7))))), true},
+			{"6 == 7 => false", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(6))))), false},
+		}},
+		{iceberg.NotEqualTo(iceberg.Reference("x"), int32(7)), []testCase{
+			{"7 != 7 => false", rowOf(7, 8, nil), false},
+			{"6 != 7 => true", rowOf(6, 8, nil), true},
+		}},
+		{iceberg.NotEqualTo(iceberg.Reference("s1.s2.s3.s4.i"), int32(7)), []testCase{
+			{"7 != 7 => false", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(7))))), false},
+			{"6 != 7 => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(6))))), true},
+		}},
+		{iceberg.IsNull(iceberg.Reference("z")), []testCase{
+			{"nil is null", rowOf(1, 2, nil), true},
+			{"3 is not null", rowOf(1, 2, 3), false},
+		}},
+		{iceberg.IsNull(iceberg.Reference("s1.s2.s3.s4.i")), []testCase{
+			{"3 is not null", rowOf(1, 2, 3, rowOf(rowOf(rowOf(rowOf(3))))), false},
+		}},
+		{iceberg.NotNull(iceberg.Reference("z")), []testCase{
+			{"nil is null", rowOf(1, 2, nil), false},
+			{"3 is not null", rowOf(1, 2, 3), true},
+		}},
+		{iceberg.NotNull(iceberg.Reference("s1.s2.s3.s4.i")), []testCase{
+			{"3 is not null", rowOf(1, 2, 3, rowOf(rowOf(rowOf(rowOf(3))))), true},
+		}},
+		{iceberg.IsNaN(iceberg.Reference("y")), []testCase{
+			{"NaN is NaN", rowOf(1, math.NaN(), 3), true},
+			{"2 is not NaN", rowOf(1, 2.0, 3), false},
+		}},
+		{iceberg.IsNaN(iceberg.Reference("s5.s6.f")), []testCase{
+			{"NaN is NaN", rowOf(1, 2, 3, nil, rowOf(rowOf(math.NaN()))), true},
+			{"4 is not NaN", rowOf(1, 2, 3, nil, rowOf(rowOf(4.0))), false},
+			{"nil is not NaN", rowOf(1, 2, 3, nil, nil), false},
+		}},
+		{iceberg.NotNaN(iceberg.Reference("y")), []testCase{
+			{"NaN is NaN", rowOf(1, math.NaN(), 3), false},
+			{"2 is not NaN", rowOf(1, 2.0, 3), true},
+		}},
+		{iceberg.NotNaN(iceberg.Reference("s5.s6.f")), []testCase{
+			{"NaN is NaN", rowOf(1, 2, 3, nil, rowOf(rowOf(math.NaN()))), false},
+			{"4 is not NaN", rowOf(1, 2, 3, nil, rowOf(rowOf(4.0))), true},
+		}},
+		{iceberg.NewAnd(iceberg.EqualTo(iceberg.Reference("x"), int32(7)), iceberg.NotNull(iceberg.Reference("z"))), []testCase{
+			{"7, 3 => true", rowOf(7, 0, 3), true},
+			{"8, 3 => false", rowOf(8, 0, 3), false},
+			{"7, null => false", rowOf(7, 0, nil), false},
+			{"8, null => false", rowOf(8, 0, nil), false},
+		}},
+		{iceberg.NewAnd(iceberg.EqualTo(iceberg.Reference("s1.s2.s3.s4.i"), int32(7)),
+			iceberg.NotNull(iceberg.Reference("s1.s2.s3.s4.i"))), []testCase{
+			{"7, 7 => true", rowOf(5, 0, 3, rowOf(rowOf(rowOf(rowOf(7))))), true},
+			{"8, 8 => false", rowOf(7, 0, 3, rowOf(rowOf(rowOf(rowOf(8))))), false},
+			{"7, null => false", rowOf(5, 0, 3, nil), false},
+			{"8, notnull => false", rowOf(7, 0, 3, rowOf(rowOf(rowOf(rowOf(8))))), false},
+		}},
+		{iceberg.NewOr(iceberg.EqualTo(iceberg.Reference("x"), int32(7)), iceberg.NotNull(iceberg.Reference("z"))), []testCase{
+			{"7, 3 => true", rowOf(7, 0, 3), true},
+			{"8, 3 => true", rowOf(8, 0, 3), true},
+			{"7, null => true", rowOf(7, 0, nil), true},
+			{"8, null => false", rowOf(8, 0, nil), false},
+		}},
+		{iceberg.NewOr(iceberg.EqualTo(iceberg.Reference("s1.s2.s3.s4.i"), int32(7)),
+			iceberg.NotNull(iceberg.Reference("s1.s2.s3.s4.i"))), []testCase{
+			{"7, 7 => true", rowOf(5, 0, 3, rowOf(rowOf(rowOf(rowOf(7))))), true},
+			{"8, notnull => true", rowOf(7, 0, 3, rowOf(rowOf(rowOf(rowOf(8))))), true},
+			{"7, null => false", rowOf(5, 0, 3, nil), false},
+			{"8, notnull => true", rowOf(7, 0, 3, rowOf(rowOf(rowOf(rowOf(8))))), true},
+		}},
+		{iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("x"), int32(7))), []testCase{
+			{"not(7 == 7) => false", rowOf(7), false},
+			{"not(8 == 7) => true", rowOf(8), true},
+		}},
+		{iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("s1.s2.s3.s4.i"), int32(7))), []testCase{
+			{"not(7 == 7) => false", rowOf(7, nil, nil, rowOf(rowOf(rowOf(rowOf(7))))), false},
+			{"not(8 == 7) => true", rowOf(7, nil, nil, rowOf(rowOf(rowOf(rowOf(8))))), true},
+		}},
+		{iceberg.IsIn(iceberg.Reference("x"), int64(7), 8, math.MaxInt64), []testCase{
+			{"7 in [7, 8, Int64Max] => true", rowOf(7, 8, nil), true},
+			{"9 in [7, 8, Int64Max] => false", rowOf(9, 8, nil), false},
+			{"8 in [7, 8, Int64Max] => true", rowOf(8, 8, nil), true},
+		}},
+		{iceberg.IsIn(iceberg.Reference("x"), int64(math.MaxInt64), math.MaxInt32, math.MinInt64), []testCase{
+			{"Int32Max in [Int64Max, Int32Max, Int64Min] => true", rowOf(math.MaxInt32, 7.0, nil), true},
+			{"6 in [Int64Max, Int32Max, Int64Min] => false", rowOf(6, 6.9, nil), false},
+		}},
+		{iceberg.IsIn(iceberg.Reference("y"), float64(7), 8, 9.1), []testCase{
+			{"7.0 in [7, 8, 9.1] => true", rowOf(0, 7.0, nil), true},
+			{"9.1 in [7, 8, 9.1] => true", rowOf(7, 9.1, nil), true},
+			{"6.8 in [7, 8, 9.1] => false", rowOf(7, 6.8, nil), false},
+		}},
+		{iceberg.IsIn(iceberg.Reference("s1.s2.s3.s4.i"), int32(7), 8, 9), []testCase{
+			{"7 in [7, 8, 9] => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(7))))), true},
+			// description corrected: 6 is not in the set, so the result is false
+			{"6 in [7, 8, 9] => false", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(6))))), false},
+			{"nil in [7, 8, 9] => false", rowOf(7, 8, nil, nil), false},
+		}},
+		{iceberg.NotIn(iceberg.Reference("x"), int64(7), 8, math.MaxInt64), []testCase{
+			{"7 not in [7, 8, Int64Max] => false", rowOf(7, 8, nil), false},
+			{"9 not in [7, 8, Int64Max] => true", rowOf(9, 8, nil), true},
+			{"8 not in [7, 8, Int64Max] => false", rowOf(8, 8, nil), false},
+		}},
+		{iceberg.NotIn(iceberg.Reference("x"), int64(math.MaxInt64), math.MaxInt32, math.MinInt64), []testCase{
+			{"Int32Max not in [Int64Max, Int32Max, Int64Min] => false", rowOf(math.MaxInt32, 7.0, nil), false},
+			{"6 not in [Int64Max, Int32Max, Int64Min] => true", rowOf(6, 6.9, nil), true},
+		}},
+		{iceberg.NotIn(iceberg.Reference("y"), float64(7), 8, 9.1), []testCase{
+			{"7.0 not in [7, 8, 9.1] => false", rowOf(0, 7.0, nil), false},
+			{"9.1 not in [7, 8, 9.1] => false", rowOf(7, 9.1, nil), false},
+			{"6.8 not in [7, 8, 9.1] => true", rowOf(7, 6.8, nil), true},
+		}},
+		{iceberg.NotIn(iceberg.Reference("s1.s2.s3.s4.i"), int32(7), 8, 9), []testCase{
+			{"7 not in [7, 8, 9] => false", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(7))))), false},
+			// description corrected: 6 is not in the set, so the result is true
+			{"6 not in [7, 8, 9] => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(6))))), true},
+		}},
+		{iceberg.EqualTo(iceberg.Reference("s"), "abc"), []testCase{
+			{"abc == abc => true", rowOf(1, 2, nil, nil, nil, "abc"), true},
+			{"abd == abc => false", rowOf(1, 2, nil, nil, nil, "abd"), false},
+		}},
+		{iceberg.StartsWith(iceberg.Reference("s"), "abc"), []testCase{
+			{"abc startsWith abc => true", rowOf(1, 2, nil, nil, nil, "abc"), true},
+			{"xabc startsWith abc => false", rowOf(1, 2, nil, nil, nil, "xabc"), false},
+			{"Abc startsWith abc => false", rowOf(1, 2, nil, nil, nil, "Abc"), false},
+			{"a startsWith abc => false", rowOf(1, 2, nil, nil, nil, "a"), false},
+			{"abcd startsWith abc => true", rowOf(1, 2, nil, nil, nil, "abcd"), true},
+			{"nil startsWith abc => false", rowOf(1, 2, nil, nil, nil, nil), false},
+		}},
+		{iceberg.NotStartsWith(iceberg.Reference("s"), "abc"), []testCase{
+			{"abc not startsWith abc => false", rowOf(1, 2, nil, nil, nil, "abc"), false},
+			{"xabc not startsWith abc => true", rowOf(1, 2, nil, nil, nil, "xabc"), true},
+			{"Abc not startsWith abc => true", rowOf(1, 2, nil, nil, nil, "Abc"), true},
+			{"a not startsWith abc => true", rowOf(1, 2, nil, nil, nil, "a"), true},
+			{"abcd not startsWith abc => false", rowOf(1, 2, nil, nil, nil, "abcd"), false},
+			{"nil not startsWith abc => true", rowOf(1, 2, nil, nil, nil, nil), true},
+		}},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.exp.String(), func(t *testing.T) {
+			ev, err := iceberg.ExpressionEvaluator(testSchema, tt.exp, true)
+			require.NoError(t, err)
+
+			for _, c := range tt.cases {
+				res, err := ev(c.row)
+				require.NoError(t, err)
+
+				assert.Equal(t, c.result, res, c.str)
+			}
+		})
+	}
+}
+
+// TestEvaluatorCmpTypes evaluates comparison predicates over a single row
+// whose columns cover a range of primitive types, passing literals whose Go
+// type often differs from the column type (for example int64 against float
+// columns and strings against date, time, timestamp, decimal and uuid
+// columns).
+func TestEvaluatorCmpTypes(t *testing.T) {
+	sc := iceberg.NewSchema(1,
+		iceberg.NestedField{ID: 1, Name: "a", Type: iceberg.PrimitiveTypes.Bool},
+		iceberg.NestedField{ID: 2, Name: "b", Type: iceberg.PrimitiveTypes.Int32},
+		iceberg.NestedField{ID: 3, Name: "c", Type: iceberg.PrimitiveTypes.Int64},
+		iceberg.NestedField{ID: 4, Name: "d", Type: iceberg.PrimitiveTypes.Float32},
+		iceberg.NestedField{ID: 5, Name: "e", Type: iceberg.PrimitiveTypes.Float64},
+		iceberg.NestedField{ID: 6, Name: "f", Type: iceberg.PrimitiveTypes.Date},
+		iceberg.NestedField{ID: 7, Name: "g", Type: iceberg.PrimitiveTypes.Time},
+		iceberg.NestedField{ID: 8, Name: "h", Type: iceberg.PrimitiveTypes.Timestamp},
+		iceberg.NestedField{ID: 9, Name: "i", Type: iceberg.DecimalTypeOf(9, 2)},
+		iceberg.NestedField{ID: 10, Name: "j", Type: iceberg.PrimitiveTypes.String},
+		iceberg.NestedField{ID: 11, Name: "k", Type: iceberg.PrimitiveTypes.Binary},
+		iceberg.NestedField{ID: 12, Name: "l", Type: iceberg.PrimitiveTypes.UUID},
+		iceberg.NestedField{ID: 13, Name: "m", Type: iceberg.FixedTypeOf(5)})
+
+	// one value per schema field, in field order (a through m)
+	rowData := rowOf(true,
+		5, 5, float32(5.0), float64(5.0),
+		29, 51661919000, 1503066061919234,
+		iceberg.Decimal{Scale: 2, Val: decimal128.FromI64(3456)},
+		"abcdef", []byte{0x01, 0x02, 0x03},
+		uuid.New(), []byte{0xDE, 0xAD, 0xBE, 0xEF, 0x0})
+
+	tests := []struct {
+		ref iceberg.BooleanExpression
+		exp bool
+	}{
+		{iceberg.EqualTo(iceberg.Reference("a"), true), true},
+		{iceberg.EqualTo(iceberg.Reference("a"), false), false},
+		{iceberg.EqualTo(iceberg.Reference("c"), int64(5)), true},
+		{iceberg.EqualTo(iceberg.Reference("c"), int64(6)), false},
+		{iceberg.EqualTo(iceberg.Reference("d"), int64(5)), true},
+		{iceberg.EqualTo(iceberg.Reference("d"), int64(6)), false},
+		{iceberg.EqualTo(iceberg.Reference("e"), int64(5)), true},
+		{iceberg.EqualTo(iceberg.Reference("e"), int64(6)), false},
+		{iceberg.EqualTo(iceberg.Reference("f"), "1970-01-30"), true},
+		{iceberg.EqualTo(iceberg.Reference("f"), "1970-01-31"), false},
+		{iceberg.EqualTo(iceberg.Reference("g"), "14:21:01.919"), true},
+		{iceberg.EqualTo(iceberg.Reference("g"), "14:21:02.919"), false},
+		{iceberg.EqualTo(iceberg.Reference("h"), "2017-08-18T14:21:01.919234"), true},
+		{iceberg.EqualTo(iceberg.Reference("h"), "2017-08-19T14:21:01.919234"), false},
+		{iceberg.LessThan(iceberg.Reference("i"), "32.22"), false},
+		{iceberg.GreaterThan(iceberg.Reference("i"), "32.22"), true},
+		{iceberg.LessThanEqual(iceberg.Reference("j"), "abcd"), false},
+		{iceberg.GreaterThan(iceberg.Reference("j"), "abcde"), true},
+		{iceberg.GreaterThan(iceberg.Reference("k"), []byte{0x00}), true},
+		{iceberg.LessThan(iceberg.Reference("k"), []byte{0x00}), false},
+		// a freshly generated UUID is expected not to equal the row's UUID
+		{iceberg.EqualTo(iceberg.Reference("l"), uuid.New().String()), false},
+		{iceberg.EqualTo(iceberg.Reference("l"), rowData[11].(uuid.UUID)), true},
+		{iceberg.EqualTo(iceberg.Reference("m"), []byte{0xDE, 0xAD, 0xBE, 0xEF, 0x1}), false},
+		{iceberg.EqualTo(iceberg.Reference("m"), []byte{0xDE, 0xAD, 0xBE, 0xEF, 0x0}), true},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.ref.String(), func(t *testing.T) {
+			ev, err := iceberg.ExpressionEvaluator(sc, tt.ref, true)
+			require.NoError(t, err)
+
+			res, err := ev(rowData)
+			require.NoError(t, err)
+			assert.Equal(t, tt.exp, res)
+		})
+	}
+}
+
+func TestManifestEvaluator(t *testing.T) {
+ const (
+ IntMinValue, IntMaxValue = 30, 79
+ )
+
+ var (
+ IntMin, IntMax = []byte{byte(IntMinValue), 0x00, 0x00,
0x00}, []byte{byte(IntMaxValue), 0x00, 0x00, 0x00}
+ StringMin, StringMax = []byte("a"), []byte("z")
+ FloatMin, _ = iceberg.Float32Literal(0).MarshalBinary()
+ FloatMax, _ =
iceberg.Float32Literal(20).MarshalBinary()
+ DblMin, _ = iceberg.Float64Literal(0).MarshalBinary()
+ DblMax, _ =
iceberg.Float64Literal(20).MarshalBinary()
+ NanTrue, NanFalse = true, false
+
+ testSchema = iceberg.NewSchema(1,
+ iceberg.NestedField{ID: 1, Name: "id",
+ Type: iceberg.PrimitiveTypes.Int32, Required:
true},
+ iceberg.NestedField{ID: 2, Name:
"all_nulls_missing_nan",
+ Type: iceberg.PrimitiveTypes.String, Required:
false},
+ iceberg.NestedField{ID: 3, Name: "some_nulls",
+ Type: iceberg.PrimitiveTypes.String, Required:
false},
+ iceberg.NestedField{ID: 4, Name: "no_nulls",
+ Type: iceberg.PrimitiveTypes.String, Required:
false},
+ iceberg.NestedField{ID: 5, Name: "float",
+ Type: iceberg.PrimitiveTypes.Float32, Required:
false},
+ iceberg.NestedField{ID: 6, Name: "all_nulls_double",
+ Type: iceberg.PrimitiveTypes.Float64, Required:
false},
+ iceberg.NestedField{ID: 7, Name: "all_nulls_no_nans",
+ Type: iceberg.PrimitiveTypes.Float32, Required:
false},
+ iceberg.NestedField{ID: 8, Name: "all_nans",
+ Type: iceberg.PrimitiveTypes.Float64, Required:
false},
+ iceberg.NestedField{ID: 9, Name: "both_nan_and_null",
+ Type: iceberg.PrimitiveTypes.Float32, Required:
false},
+ iceberg.NestedField{ID: 10, Name: "no_nan_or_null",
+ Type: iceberg.PrimitiveTypes.Float64, Required:
false},
+ iceberg.NestedField{ID: 11, Name:
"all_nulls_missing_nan_float",
+ Type: iceberg.PrimitiveTypes.Float32, Required:
false},
+ iceberg.NestedField{ID: 12, Name:
"all_same_value_or_null",
+ Type: iceberg.PrimitiveTypes.String, Required:
false},
+ iceberg.NestedField{ID: 13, Name:
"no_nulls_same_value_a",
+ Type: iceberg.PrimitiveTypes.Binary, Required:
false},
+ )
+ )
+
+ partFields := make([]iceberg.PartitionField, 0, testSchema.NumFields())
+ for _, f := range testSchema.Fields() {
+ partFields = append(partFields, iceberg.PartitionField{
+ Name: f.Name,
+ SourceID: f.ID,
+ FieldID: f.ID,
+ Transform: iceberg.IdentityTransform{},
+ })
+ }
+
+ spec := iceberg.NewPartitionSpec(partFields...)
+ manifestNoStats := iceberg.NewManifestV1Builder("", 0, 0, 0).Build()
+ manifest := iceberg.NewManifestV1Builder("", 0, 0, 0).Partitions(
+ []iceberg.FieldSummary{
+ { // id
+ ContainsNull: false,
+ ContainsNaN: nil,
+ LowerBound: &IntMin,
+ UpperBound: &IntMax,
+ },
+ { // all_nulls_missing_nan
+ ContainsNull: true,
+ ContainsNaN: nil,
+ LowerBound: nil,
+ UpperBound: nil,
+ },
+ { // some_nulls
+ ContainsNull: true,
+ ContainsNaN: nil,
+ LowerBound: &StringMin,
+ UpperBound: &StringMax,
+ },
+ { // no_nulls
+ ContainsNull: false,
+ ContainsNaN: nil,
+ LowerBound: &StringMin,
+ UpperBound: &StringMax,
+ },
+ { // float
+ ContainsNull: true,
+ ContainsNaN: nil,
+ LowerBound: &FloatMin,
+ UpperBound: &FloatMax,
+ },
+ { // all_nulls_double
+ ContainsNull: true,
+ ContainsNaN: nil,
+ LowerBound: nil,
+ UpperBound: nil,
+ },
+ { // all_nulls_no_nans
+ ContainsNull: true,
+ ContainsNaN: &NanFalse,
+ LowerBound: nil,
+ UpperBound: nil,
+ },
+ { // all_nans
+ ContainsNull: false,
+ ContainsNaN: &NanTrue,
+ LowerBound: nil,
+ UpperBound: nil,
+ },
+ { // both_nan_and_null
+ ContainsNull: true,
+ ContainsNaN: &NanTrue,
+ LowerBound: nil,
+ UpperBound: nil,
+ },
+ { // no_nan_or_null
+ ContainsNull: false,
+ ContainsNaN: &NanFalse,
+ LowerBound: &DblMin,
+ UpperBound: &DblMax,
+ },
+ { // all_nulls_missing_nan_float
+ ContainsNull: true,
+ ContainsNaN: nil,
+ LowerBound: nil,
+ UpperBound: nil,
+ },
+ { // all_same_value_or_null
+ ContainsNull: true,
+ ContainsNaN: nil,
+ LowerBound: &StringMin,
+ UpperBound: &StringMin,
+ },
+ { // no_nulls_same_value_a
+ ContainsNull: false,
+ ContainsNaN: nil,
+ LowerBound: &StringMin,
+ UpperBound: &StringMin,
+ },
+ }).Build()
+
+ t.Run("all nulls", func(t *testing.T) {
+ tests := []struct {
+ field string
+ expected bool
+ msg string
+ }{
+ {"all_nulls_missing_nan", false, "should skip: all
nulls column with non-floating type contains all null"},
+ {"all_nulls_missing_nan_float", true, "should read: no
NaN information may indicate presence of NaN value"},
+ {"some_nulls", true, "should read: column with some
nulls contains a non-null value"},
+ {"no_nulls", true, "should read: non-null column
contains a non-null value"},
+ }
+
+ for _, tt := range tests {
+ eval, err := iceberg.NewManifestEvaluator(spec,
testSchema,
+ iceberg.NotNull(iceberg.Reference(tt.field)),
true)
+ require.NoError(t, err)
+
+ result, err := eval(manifest)
+ require.NoError(t, err)
+ assert.Equal(t, tt.expected, result, tt.msg)
+ }
+ })
+
+ t.Run("no nulls", func(t *testing.T) {
+ tests := []struct {
+ field string
+ expected bool
+ msg string
+ }{
+ {"all_nulls_missing_nan", true, "should read: at least
one null value in all null column"},
+ {"some_nulls", true, "should read: column with some
nulls contains a null value"},
+ {"no_nulls", false, "should skip: non-null column
contains no null values"},
+ {"both_nan_and_null", true, "should read:
both_nan_and_null column contains no null values"},
+ }
+
+ for _, tt := range tests {
+ eval, err := iceberg.NewManifestEvaluator(spec,
testSchema,
+ iceberg.IsNull(iceberg.Reference(tt.field)),
true)
+ require.NoError(t, err)
+
+ result, err := eval(manifest)
+ require.NoError(t, err)
+ assert.Equal(t, tt.expected, result, tt.msg)
+ }
+ })
+
+ t.Run("is nan", func(t *testing.T) {
+ tests := []struct {
+ field string
+ expected bool
+ msg string
+ }{
+ {"float", true, "should read: no information on if
there are nan values in float column"},
+ {"all_nulls_double", true, "should read: no NaN
information may indicate presence of NaN value"},
+ {"all_nulls_missing_nan_float", true, "should read: no
NaN information may indicate presence of NaN value"},
+ {"all_nulls_no_nans", false, "should skip: no nan
column doesn't contain nan value"},
+ {"all_nans", true, "should read: all_nans column
contains nan value"},
+ {"both_nan_and_null", true, "should read:
both_nan_and_null column contains nan value"},
+ {"no_nan_or_null", false, "should skip: no_nan_or_null
column doesn't contain nan value"},
+ }
+
+ for _, tt := range tests {
+ eval, err := iceberg.NewManifestEvaluator(spec,
testSchema,
+ iceberg.IsNaN(iceberg.Reference(tt.field)),
true)
+ require.NoError(t, err)
+
+ result, err := eval(manifest)
+ require.NoError(t, err)
+ assert.Equal(t, tt.expected, result, tt.msg)
+ }
+ })
+
+ t.Run("not nan", func(t *testing.T) {
+ tests := []struct {
+ field string
+ expected bool
+ msg string
+ }{
+ {"float", true, "should read: no information on if
there are nan values in float column"},
+ {"all_nulls_double", true, "should read: all null
column contains non nan value"},
+ {"all_nulls_no_nans", true, "should read: no_nans
column contains non nan value"},
+ {"all_nans", false, "should skip: all nans
columndoesn't contain non nan value"},
+ {"both_nan_and_null", true, "should read:
both_nan_and_null nans column contains non nan value"},
+ {"no_nan_or_null", true, "should read: no_nan_or_null
column contains non nan value"},
+ }
+
+ for _, tt := range tests {
+ eval, err := iceberg.NewManifestEvaluator(spec,
testSchema,
+ iceberg.NotNaN(iceberg.Reference(tt.field)),
true)
+ require.NoError(t, err)
+
+ result, err := eval(manifest)
+ require.NoError(t, err)
+ assert.Equal(t, tt.expected, result, tt.msg)
+ }
+ })
+
+ t.Run("test missing stats", func(t *testing.T) {
+ exprs := []iceberg.BooleanExpression{
+ iceberg.LessThan(iceberg.Reference("id"), int32(5)),
+ iceberg.LessThanEqual(iceberg.Reference("id"),
int32(30)),
+ iceberg.EqualTo(iceberg.Reference("id"), int32(70)),
+ iceberg.GreaterThan(iceberg.Reference("id"), int32(78)),
+ iceberg.GreaterThanEqual(iceberg.Reference("id"),
int32(90)),
+ iceberg.NotEqualTo(iceberg.Reference("id"), int32(101)),
+ iceberg.IsNull(iceberg.Reference("id")),
+ iceberg.NotNull(iceberg.Reference("id")),
+ iceberg.IsNaN(iceberg.Reference("float")),
+ iceberg.NotNaN(iceberg.Reference("float")),
+ }
+
+ for _, tt := range exprs {
+ eval, err := iceberg.NewManifestEvaluator(spec,
testSchema, tt, true)
+ require.NoError(t, err)
+
+ result, err := eval(manifestNoStats)
+ require.NoError(t, err)
+ assert.Truef(t, result, "should read when missing stats
for expr: %s", tt)
+ }
+ })
+
+ t.Run("test exprs", func(t *testing.T) {
+ tests := []struct {
+ expr iceberg.BooleanExpression
+ expect bool
+ msg string
+ }{
+
{iceberg.NewNot(iceberg.LessThan(iceberg.Reference("id"),
int32(IntMinValue-25))),
+ true, "should read: not(false)"},
+
{iceberg.NewNot(iceberg.GreaterThan(iceberg.Reference("id"),
int32(IntMinValue-25))),
+ false, "should skip: not(true)"},
+ {iceberg.NewAnd(
+ iceberg.LessThan(iceberg.Reference("id"),
int32(IntMinValue-25)),
+
iceberg.GreaterThanEqual(iceberg.Reference("id"), int32(IntMinValue-30))),
+ false, "should skip: and(false, true)"},
+ {iceberg.NewAnd(
+ iceberg.LessThan(iceberg.Reference("id"),
int32(IntMinValue-25)),
+
iceberg.GreaterThanEqual(iceberg.Reference("id"), int32(IntMaxValue+1))),
+ false, "should skip: and(false, false)"},
+ {iceberg.NewAnd(
+ iceberg.GreaterThan(iceberg.Reference("id"),
int32(IntMinValue-25)),
+ iceberg.LessThanEqual(iceberg.Reference("id"),
int32(IntMinValue))),
+ true, "should read: and(true, true)"},
+ {iceberg.NewOr(
+ iceberg.LessThan(iceberg.Reference("id"),
int32(IntMinValue-25)),
+
iceberg.GreaterThanEqual(iceberg.Reference("id"), int32(IntMaxValue+1))),
+ false, "should skip: or(false, false)"},
+ {iceberg.NewOr(
+ iceberg.LessThan(iceberg.Reference("id"),
int32(IntMinValue-25)),
+
iceberg.GreaterThanEqual(iceberg.Reference("id"), int32(IntMaxValue-19))),
+ true, "should read: or(false, true)"},
+ {iceberg.LessThan(iceberg.Reference("some_nulls"),
"1"), false,
+ "should not read: id range below lower bound"},
+ {iceberg.LessThan(iceberg.Reference("some_nulls"),
"b"), true,
+ "should read: lower bound in range"},
+ {iceberg.LessThan(iceberg.Reference("float"), 15.50),
true,
+ "should read: lower bound in range"},
+ {iceberg.LessThan(iceberg.Reference("no_nan_or_null"),
15.50), true,
+ "should read: lower bound in range"},
+
{iceberg.LessThanEqual(iceberg.Reference("no_nulls_same_value_a"), "a"), true,
+ "should read: lower bound in range"},
+ {iceberg.LessThan(iceberg.Reference("id"),
int32(IntMinValue-25)), false,
+ "should not read: id range below lower bound (5
< 30)"},
+ {iceberg.LessThan(iceberg.Reference("id"),
int32(IntMinValue)), false,
+ "should not read: id range below lower bound
(30 is not < 30)"},
+ {iceberg.LessThan(iceberg.Reference("id"),
int32(IntMinValue+1)), true,
+ "should read: one possible id"},
+ {iceberg.LessThan(iceberg.Reference("id"),
int32(IntMaxValue)), true,
+ "should read: many possible ids"},
+ {iceberg.LessThanEqual(iceberg.Reference("id"),
int32(IntMinValue-25)), false,
+ "should not read: id range below lower bound (5
< 30)"},
+ {iceberg.LessThanEqual(iceberg.Reference("id"),
int32(IntMinValue-1)), false,
+ "should not read: id range below lower bound 29
< 30"},
+ {iceberg.LessThanEqual(iceberg.Reference("id"),
int32(IntMinValue)), true,
+ "should read: one possible id"},
+ {iceberg.LessThanEqual(iceberg.Reference("id"),
int32(IntMaxValue)), true,
+ "should read: many possible ids"},
+ {iceberg.GreaterThan(iceberg.Reference("id"),
int32(IntMaxValue+6)), false,
+ "should not read: id range above upper bound
(85 < 79)"},
+ {iceberg.GreaterThan(iceberg.Reference("id"),
int32(IntMaxValue)), false,
+ "should not read: id range above upper bound
(79 is not > 79)"},
+ {iceberg.GreaterThan(iceberg.Reference("id"),
int32(IntMaxValue-1)), true,
+ "should read: one possible id"},
+ {iceberg.GreaterThan(iceberg.Reference("id"),
int32(IntMaxValue-4)), true,
+ "should read: many possible ids"},
+ {iceberg.GreaterThanEqual(iceberg.Reference("id"),
int32(IntMaxValue+6)), false,
+ "should not read: id range is above upper bound
(85 < 79)"},
+ {iceberg.GreaterThanEqual(iceberg.Reference("id"),
int32(IntMaxValue+1)), false,
+ "should not read: id range above upper bound
(80 > 79)"},
+ {iceberg.GreaterThanEqual(iceberg.Reference("id"),
int32(IntMaxValue)), true,
+ "should read: one possible id"},
+ {iceberg.GreaterThanEqual(iceberg.Reference("id"),
int32(IntMaxValue)), true,
+ "should read: many possible ids"},
+ {iceberg.EqualTo(iceberg.Reference("id"),
int32(IntMinValue-25)), false,
+ "should not read: id below lower bound"},
+ {iceberg.EqualTo(iceberg.Reference("id"),
int32(IntMinValue-1)), false,
+ "should not read: id below lower bound"},
+ {iceberg.EqualTo(iceberg.Reference("id"),
int32(IntMinValue)), true,
+ "should read: id equal to lower bound"},
+ {iceberg.EqualTo(iceberg.Reference("id"),
int32(IntMaxValue-4)), true,
+ "should read: id between lower and upper
bounds"},
+ {iceberg.EqualTo(iceberg.Reference("id"),
int32(IntMaxValue)), true,
+ "should read: id equal to upper bound"},
+ {iceberg.EqualTo(iceberg.Reference("id"),
int32(IntMaxValue+1)), false,
+ "should not read: id above upper bound"},
+ {iceberg.EqualTo(iceberg.Reference("id"),
int32(IntMaxValue+6)), false,
+ "should not read: id above upper bound"},
+ {iceberg.NotEqualTo(iceberg.Reference("id"),
int32(IntMinValue-25)), true,
+ "should read: id below lower bound"},
+ {iceberg.NotEqualTo(iceberg.Reference("id"),
int32(IntMinValue-1)), true,
+ "should read: id below lower bound"},
+ {iceberg.NotEqualTo(iceberg.Reference("id"),
int32(IntMinValue)), true,
+ "should read: id equal to lower bound"},
+ {iceberg.NotEqualTo(iceberg.Reference("id"),
int32(IntMaxValue-4)), true,
+ "should read: id between lower and upper
bounds"},
+ {iceberg.NotEqualTo(iceberg.Reference("id"),
int32(IntMaxValue)), true,
+ "should read: id equal to upper bound"},
+ {iceberg.NotEqualTo(iceberg.Reference("id"),
int32(IntMaxValue+1)), true,
+ "should read: id above upper bound"},
+ {iceberg.NotEqualTo(iceberg.Reference("id"),
int32(IntMaxValue+6)), true,
+ "should read: id above upper bound"},
+
{iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("id"),
int32(IntMinValue-25))), true,
+ "should read: id below lower bound"},
+
{iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("id"),
int32(IntMinValue-1))), true,
+ "should read: id below lower bound"},
+
{iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("id"), int32(IntMinValue))),
true,
+ "should read: id equal to lower bound"},
+
{iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("id"),
int32(IntMaxValue-4))), true,
+ "should read: id between lower and upper
bounds"},
+
{iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("id"), int32(IntMaxValue))),
true,
+ "should read: id equal to upper bound"},
+
{iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("id"),
int32(IntMaxValue+1))), true,
+ "should read: id above upper bound"},
+
{iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("id"),
int32(IntMaxValue+6))), true,
+ "should read: id above upper bound"},
+ {iceberg.IsIn(iceberg.Reference("id"),
int32(IntMinValue-25), IntMinValue-24), false,
+ "should not read: id below lower bound (5 < 30,
6 < 30)"},
+ {iceberg.IsIn(iceberg.Reference("id"),
int32(IntMinValue-2), IntMinValue-1), false,
+ "should not read: id below lower bound (28 <
30, 29 < 30)"},
+ {iceberg.IsIn(iceberg.Reference("id"),
int32(IntMinValue-1), IntMinValue), true,
+ "should read: id equal to lower bound (30 ==
30)"},
+ {iceberg.IsIn(iceberg.Reference("id"),
int32(IntMaxValue-4), IntMaxValue-3), true,
+ "should read: id between lower and upper bounds
(30 < 75 < 79, 30 < 76 < 79)"},
+ {iceberg.IsIn(iceberg.Reference("id"),
int32(IntMaxValue), IntMaxValue+1), true,
+ "should read: id equal to upper bound (79 ==
79)"},
+ {iceberg.IsIn(iceberg.Reference("id"),
int32(IntMaxValue+1), IntMaxValue+2), false,
+ "should not read: id above upper bound (80 >
79, 81 > 79)"},
+ {iceberg.IsIn(iceberg.Reference("id"),
int32(IntMaxValue+6), IntMaxValue+7), false,
+ "should not read: id above upper bound (85 >
79, 86 > 79)"},
+
{iceberg.IsIn(iceberg.Reference("all_nulls_missing_nan"), "abc", "def"), false,
+ "should skip: in on all nulls column"},
+ {iceberg.IsIn(iceberg.Reference("some_nulls"), "abc",
"def"), true,
+ "should read: in on some nulls column"},
+ {iceberg.IsIn(iceberg.Reference("no_nulls"), "abc",
"def"), true,
+ "should read: in on no nulls column"},
+
{iceberg.IsIn(iceberg.Reference("no_nulls_same_value_a"), "a", "b"), true,
+ "should read: in on no nulls column"},
+ {iceberg.IsIn(iceberg.Reference("float"), 0, -5.5),
true,
+ "should read: float equal to lower bound"},
+ {iceberg.IsIn(iceberg.Reference("no_nan_or_null"), 0,
-5.5), true,
+ "should read: float equal to lower bound"},
+ {iceberg.NotIn(iceberg.Reference("id"),
int32(IntMinValue-25), IntMinValue-24), true,
+ "should read: id below lower bound (5 < 30, 6 <
30)"},
+ {iceberg.NotIn(iceberg.Reference("id"),
int32(IntMinValue-2), IntMinValue-1), true,
+ "should read: id below lower bound (28 < 30, 29
< 30)"},
+ {iceberg.NotIn(iceberg.Reference("id"),
int32(IntMinValue-1), IntMinValue), true,
+ "should read: id equal to lower bound (30 ==
30)"},
+ {iceberg.NotIn(iceberg.Reference("id"),
int32(IntMaxValue-4), IntMaxValue-3), true,
+ "should read: id between lower and upper bounds
(30 < 75 < 79, 30 < 76 < 79)"},
+ {iceberg.NotIn(iceberg.Reference("id"),
int32(IntMaxValue), IntMaxValue+1), true,
+ "should read: id equal to upper bound (79 ==
79)"},
+ {iceberg.NotIn(iceberg.Reference("id"),
int32(IntMaxValue+1), IntMaxValue+2), true,
+ "should read: id above upper bound (80 > 79, 81
> 79)"},
+ {iceberg.NotIn(iceberg.Reference("id"),
int32(IntMaxValue+6), IntMaxValue+7), true,
+ "should read: id above upper bound (85 > 79, 86
> 79)"},
+
{iceberg.NotIn(iceberg.Reference("all_nulls_missing_nan"), "abc", "def"), true,
+ "should read: notIn on all nulls column"},
+ {iceberg.NotIn(iceberg.Reference("some_nulls"), "abc",
"def"), true,
+ "should read: notIn on some nulls column"},
+ {iceberg.NotIn(iceberg.Reference("no_nulls"), "abc",
"def"), true,
+ "should read: notIn on no nulls column"},
+ {iceberg.StartsWith(iceberg.Reference("some_nulls"),
"a"), true,
+ "should read: range matches"},
+ {iceberg.StartsWith(iceberg.Reference("some_nulls"),
"aa"), true,
+ "should read: range matches"},
+ {iceberg.StartsWith(iceberg.Reference("some_nulls"),
"dddd"), true,
+ "should read: range matches"},
+ {iceberg.StartsWith(iceberg.Reference("some_nulls"),
"z"), true,
+ "should read: range matches"},
+ {iceberg.StartsWith(iceberg.Reference("no_nulls"),
"a"), true,
+ "should read: range matches"},
+ {iceberg.StartsWith(iceberg.Reference("some_nulls"),
"zzzz"), false,
+ "should skip: range doesn't match"},
+ {iceberg.StartsWith(iceberg.Reference("some_nulls"),
"1"), false,
+ "should skip: range doesn't match"},
+
{iceberg.StartsWith(iceberg.Reference("no_nulls_same_value_a"), "a"), true,
+ "should read: all values start with the
prefix"},
+ {iceberg.NotStartsWith(iceberg.Reference("some_nulls"),
"a"), true,
+ "should read: range matches"},
+ {iceberg.NotStartsWith(iceberg.Reference("some_nulls"),
"aa"), true,
+ "should read: range matches"},
+ {iceberg.NotStartsWith(iceberg.Reference("some_nulls"),
"dddd"), true,
+ "should read: range matches"},
+ {iceberg.NotStartsWith(iceberg.Reference("some_nulls"),
"z"), true,
+ "should read: range matches"},
+ {iceberg.NotStartsWith(iceberg.Reference("no_nulls"),
"a"), true,
+ "should read: range matches"},
+ {iceberg.NotStartsWith(iceberg.Reference("some_nulls"),
"zzzz"), true,
+ "should read: range matches"},
+ {iceberg.NotStartsWith(iceberg.Reference("some_nulls"),
"1"), true,
+ "should read: range matches"},
+
{iceberg.NotStartsWith(iceberg.Reference("all_same_value_or_null"), "a"), true,
+ "should read: range matches"},
+
{iceberg.NotStartsWith(iceberg.Reference("all_same_value_or_null"), "aa"), true,
+ "should read: range matches"},
+
{iceberg.NotStartsWith(iceberg.Reference("all_same_value_or_null"), "A"), true,
+ "should read: range matches"},
+ // Iceberg does not implement SQL 3-way boolean logic,
so the choice of an
+ // all null column matching is by definition in order
to surface more values
+ // to the query engine to allow it to make its own
decision
+
{iceberg.NotStartsWith(iceberg.Reference("all_nulls_missing_nan"), "A"), true,
+ "should read: range matches"},
+
{iceberg.NotStartsWith(iceberg.Reference("no_nulls_same_value_a"), "a"), false,
+ "should not read: all values start with the
prefix"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.expr.String(), func(t *testing.T) {
+ eval, err := iceberg.NewManifestEvaluator(spec,
testSchema,
+ tt.expr, true)
+ require.NoError(t, err)
+
+ result, err := eval(manifest)
+ require.NoError(t, err)
+ assert.Equal(t, tt.expect, result, tt.msg)
+ })
+ }
+ })
+}
+
+// TestRewriteNot checks the output of iceberg.RewriteNotExpr for a set of
+// expressions wrapped in Not. Each case pairs an input with the expression the
+// rewrite is expected to produce: a negated predicate becomes its opposite
+// predicate (EqualTo <-> NotEqualTo, IsIn -> NotIn), a negated And/Or becomes
+// the other operator over negated operands, and a negated AlwaysTrue /
+// AlwaysFalse becomes the other constant.
+func TestRewriteNot(t *testing.T) {
+	tests := []struct {
+		expr, expected iceberg.BooleanExpression
+	}{
+		{
+			iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("x"), 34.56)),
+			iceberg.NotEqualTo(iceberg.Reference("x"), 34.56),
+		},
+		{
+			iceberg.NewNot(iceberg.NotEqualTo(iceberg.Reference("x"), 34.56)),
+			iceberg.EqualTo(iceberg.Reference("x"), 34.56),
+		},
+		{
+			iceberg.NewNot(iceberg.IsIn(iceberg.Reference("x"), 34.56, 23.45)),
+			iceberg.NotIn(iceberg.Reference("x"), 34.56, 23.45),
+		},
+		{
+			iceberg.NewNot(iceberg.NewAnd(
+				iceberg.EqualTo(iceberg.Reference("x"), 34.56),
+				iceberg.EqualTo(iceberg.Reference("y"), 34.56))),
+			iceberg.NewOr(
+				iceberg.NotEqualTo(iceberg.Reference("x"), 34.56),
+				iceberg.NotEqualTo(iceberg.Reference("y"), 34.56)),
+		},
+		{
+			iceberg.NewNot(iceberg.NewOr(
+				iceberg.EqualTo(iceberg.Reference("x"), 34.56),
+				iceberg.EqualTo(iceberg.Reference("y"), 34.56))),
+			iceberg.NewAnd(
+				iceberg.NotEqualTo(iceberg.Reference("x"), 34.56),
+				iceberg.NotEqualTo(iceberg.Reference("y"), 34.56)),
+		},
+		{iceberg.NewNot(iceberg.AlwaysFalse{}), iceberg.AlwaysTrue{}},
+		{iceberg.NewNot(iceberg.AlwaysTrue{}), iceberg.AlwaysFalse{}},
+	}
+
+	for _, tt := range tests {
+		// One subtest per case, named after the input expression.
+		t.Run(tt.expr.String(), func(t *testing.T) {
+			out, err := iceberg.RewriteNotExpr(tt.expr)
+			require.NoError(t, err)
+			// Include both expressions in the failure output; a bare
+			// assert.True would only report "Should be true".
+			assert.Truef(t, out.Equals(tt.expected),
+				"rewrite of %s: expected %s, got %s", tt.expr, tt.expected, out)
+		})
+	}
+}