This is an automated email from the ASF dual-hosted git repository.

etudenhoefner pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-go.git


The following commit(s) were added to refs/heads/main by this push:
     new d8e7cce  feat(table/scanner): Implement Arrow type promotion and 
conversion (#174)
d8e7cce is described below

commit d8e7cce3e6e4d214d62b322f6548d41cc44b6d9b
Author: Matt Topol <[email protected]>
AuthorDate: Wed Oct 23 03:42:43 2024 -0400

    feat(table/scanner): Implement Arrow type promotion and conversion (#174)
---
 .github/workflows/go-ci.yml |   7 +-
 errors.go                   |   1 +
 exprs.go                    |  10 ++
 go.mod                      |  15 ++-
 go.sum                      |  18 +--
 partitions.go               |   3 +-
 schema.go                   |   3 +-
 table/arrow_utils.go        | 311 ++++++++++++++++++++++++++++++++++++++++++++
 table/arrow_utils_test.go   | 116 +++++++++++++++++
 table/table.go              |   3 +-
 types.go                    |  44 ++++++-
 visitors.go                 |  49 +++++++
 12 files changed, 555 insertions(+), 25 deletions(-)

diff --git a/.github/workflows/go-ci.yml b/.github/workflows/go-ci.yml
index 3d0c0c0..60bcb48 100644
--- a/.github/workflows/go-ci.yml
+++ b/.github/workflows/go-ci.yml
@@ -39,7 +39,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        go: [ '1.22', '1.23' ]
+        go: [ '1.23' ]
         os: [ 'ubuntu-latest', 'windows-latest', 'macos-latest' ]
     steps:
     - uses: actions/checkout@v4
@@ -48,10 +48,7 @@ jobs:
       with:
         go-version: ${{ matrix.go }}
         cache: true
-        cache-dependency-path: go.sum
-    - name: Install staticcheck
-      if: matrix.go == '1.22'
-      run: go install honnef.co/go/tools/cmd/[email protected]
+        cache-dependency-path: go.sum    
     - name: Install staticcheck
       if: matrix.go == '1.23'
       run: go install honnef.co/go/tools/cmd/[email protected]
diff --git a/errors.go b/errors.go
index f4fc986..9ecc26d 100644
--- a/errors.go
+++ b/errors.go
@@ -29,4 +29,5 @@ var (
        ErrBadCast                 = errors.New("could not cast value")
        ErrBadLiteral              = errors.New("invalid literal value")
        ErrInvalidBinSerialization = errors.New("invalid binary serialization")
+       ErrResolve                 = errors.New("cannot resolve type")
 )
diff --git a/exprs.go b/exprs.go
index 4bdb8c1..32ecc62 100644
--- a/exprs.go
+++ b/exprs.go
@@ -411,6 +411,7 @@ type BoundReference interface {
 
        Field() NestedField
        Pos() int
+       PosPath() []int
 }
 
 type boundRef[T LiteralType] struct {
@@ -450,6 +451,15 @@ func createBoundRef(field NestedField, acc accessor) 
BoundReference {
 
 func (b *boundRef[T]) Pos() int { return b.acc.pos }
 
+// PosPath returns the position recorded by each accessor in the reference's
+// accessor chain, starting with the outermost accessor and following the
+// nested accessors until there are no more.
+func (b *boundRef[T]) PosPath() []int {
+       path := []int{b.acc.pos}
+       for cur := b.acc.inner; cur != nil; cur = cur.inner {
+               path = append(path, cur.pos)
+       }
+       return path
+}
+
 func (*boundRef[T]) isTerm() {}
 
 func (b *boundRef[T]) String() string {
diff --git a/go.mod b/go.mod
index fc300c9..a9ab8ed 100644
--- a/go.mod
+++ b/go.mod
@@ -17,10 +17,12 @@
 
 module github.com/apache/iceberg-go
 
-go 1.22.7
+go 1.23
+
+toolchain go1.23.2
 
 require (
-       github.com/apache/arrow-go/v18 v18.0.0-20240924011512-14844aea3205
+       github.com/apache/arrow-go/v18 v18.0.1-0.20241022184425-56b794f52a9b
        github.com/aws/aws-sdk-go-v2 v1.32.2
        github.com/aws/aws-sdk-go-v2/config v1.28.0
        github.com/aws/aws-sdk-go-v2/credentials v1.17.41
@@ -34,7 +36,6 @@ require (
        github.com/stretchr/testify v1.9.0
        github.com/twmb/murmur3 v1.1.8
        github.com/wolfeidau/s3iofs v1.5.2
-       golang.org/x/exp v0.0.0-20240909161429-701f63a606c0
        gopkg.in/yaml.v3 v3.0.1
 )
 
@@ -42,8 +43,9 @@ require (
        atomicgo.dev/cursor v0.2.0 // indirect
        atomicgo.dev/keyboard v0.2.9 // indirect
        atomicgo.dev/schedule v0.1.0 // indirect
-       github.com/andybalholm/brotli v1.1.0 // indirect
-       github.com/apache/thrift v0.20.0 // indirect
+       github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c // 
indirect
+       github.com/andybalholm/brotli v1.1.1 // indirect
+       github.com/apache/thrift v0.21.0 // indirect
        github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6 // indirect
        github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.17 // indirect
        github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.21 // indirect
@@ -65,7 +67,7 @@ require (
        github.com/gookit/color v1.5.4 // indirect
        github.com/json-iterator/go v1.1.12 // indirect
        github.com/klauspost/asmfmt v1.3.2 // indirect
-       github.com/klauspost/compress v1.17.10 // indirect
+       github.com/klauspost/compress v1.17.11 // indirect
        github.com/klauspost/cpuid/v2 v2.2.8 // indirect
        github.com/lithammer/fuzzysearch v1.1.8 // indirect
        github.com/mattn/go-runewidth v0.0.15 // indirect
@@ -80,6 +82,7 @@ require (
        github.com/stretchr/objx v0.5.2 // indirect
        github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
        github.com/zeebo/xxh3 v1.0.2 // indirect
+       golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 // indirect
        golang.org/x/mod v0.21.0 // indirect
        golang.org/x/net v0.30.0 // indirect
        golang.org/x/sync v0.8.0 // indirect
diff --git a/go.sum b/go.sum
index 02cb752..cc43dd1 100644
--- a/go.sum
+++ b/go.sum
@@ -17,12 +17,12 @@ github.com/MarvinJWendt/testza v0.3.0/go.mod 
h1:eFcL4I0idjtIx8P9C6KkAuLgATNKpX4/
 github.com/MarvinJWendt/testza v0.4.2/go.mod 
h1:mSdhXiKH8sg/gQehJ63bINcCKp7RtYewEjXsvsVUPbE=
 github.com/MarvinJWendt/testza v0.5.2 
h1:53KDo64C1z/h/d/stCYCPY69bt/OSwjq5KpFNwi+zB4=
 github.com/MarvinJWendt/testza v0.5.2/go.mod 
h1:xu53QFE5sCdjtMCKk8YMQ2MnymimEctc4n3EjyIYvEY=
-github.com/andybalholm/brotli v1.1.0 
h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M=
-github.com/andybalholm/brotli v1.1.0/go.mod 
h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY=
-github.com/apache/arrow-go/v18 v18.0.0-20240924011512-14844aea3205 
h1:/tq9JMJI+i/MO016cGVdKn9c7od1/Ui2uwF78vojPW4=
-github.com/apache/arrow-go/v18 v18.0.0-20240924011512-14844aea3205/go.mod 
h1:MXqyiBhPPITRK1sWzJeXiPh8S+xSCAJVlmzTeMY7l1M=
-github.com/apache/thrift v0.20.0 
h1:631+KvYbsBZxmuJjYwhezVsrfc/TbqtZV4QcxOX1fOI=
-github.com/apache/thrift v0.20.0/go.mod 
h1:hOk1BQqcp2OLzGsyVXdfMk7YFlMxK3aoEVhjD06QhB8=
+github.com/andybalholm/brotli v1.1.1 
h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
+github.com/andybalholm/brotli v1.1.1/go.mod 
h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
+github.com/apache/arrow-go/v18 v18.0.1-0.20241022184425-56b794f52a9b 
h1:E79+ggEd/cv9g4iLe1B8XyBhpqlvqKbCGDmf+RPSwbA=
+github.com/apache/arrow-go/v18 v18.0.1-0.20241022184425-56b794f52a9b/go.mod 
h1:kVPeNv6eFSRhkfWZx1BIRXU6EZnp5g2NqKsuJmKXsO8=
+github.com/apache/thrift v0.21.0 
h1:tdPmh/ptjE1IJnhbhrcl2++TauVjy242rkV/UzJChnE=
+github.com/apache/thrift v0.21.0/go.mod 
h1:W1H8aR/QRtYNvrPeFXBtobyRkd0/YVhTc6i07XIAgDw=
 github.com/atomicgo/cursor v0.0.1/go.mod 
h1:cBON2QmmrysudxNBFthvMtN32r3jxVRIvzkUiF/RuIk=
 github.com/aws/aws-sdk-go-v2 v1.32.2 
h1:AkNLZEyYMLnx/Q/mSKkcMqwNFXMAvFto9bNsHqcTduI=
 github.com/aws/aws-sdk-go-v2 v1.32.2/go.mod 
h1:2SK5n0a2karNTv5tbP1SjsX0uhttou00v/HpXKM1ZUo=
@@ -88,8 +88,8 @@ github.com/json-iterator/go v1.1.12 
h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnr
 github.com/json-iterator/go v1.1.12/go.mod 
h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
 github.com/klauspost/asmfmt v1.3.2 
h1:4Ri7ox3EwapiOjCki+hw14RyKk201CN4rzyCJRFLpK4=
 github.com/klauspost/asmfmt v1.3.2/go.mod 
h1:AG8TuvYojzulgDAMCnYn50l/5QV3Bs/tp6j0HLHbNSE=
-github.com/klauspost/compress v1.17.10 
h1:oXAz+Vh0PMUvJczoi+flxpnBEPxoER1IaAnU/NMPtT0=
-github.com/klauspost/compress v1.17.10/go.mod 
h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
+github.com/klauspost/compress v1.17.11 
h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
+github.com/klauspost/compress v1.17.11/go.mod 
h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
 github.com/klauspost/cpuid/v2 v2.0.9/go.mod 
h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
 github.com/klauspost/cpuid/v2 v2.0.10/go.mod 
h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c=
 github.com/klauspost/cpuid/v2 v2.0.12/go.mod 
h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c=
@@ -151,6 +151,8 @@ github.com/wolfeidau/s3iofs v1.5.2/go.mod 
h1:fPAKzdWmZ1Z2L9vnqL6d1eb7pVsUgkUstxQ
 github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778/go.mod 
h1:2MuV+tbUrU1zIOPMxZ5EncGwgmMJsa+9ucAQZXxsObs=
 github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e 
h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
 github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod 
h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
+github.com/xyproto/randomstring v1.0.5 
h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
+github.com/xyproto/randomstring v1.0.5/go.mod 
h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
 github.com/yuin/goldmark v1.4.13/go.mod 
h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
 github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ=
 github.com/zeebo/assert v1.3.0/go.mod 
h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0=
diff --git a/partitions.go b/partitions.go
index 321af2e..48b730c 100644
--- a/partitions.go
+++ b/partitions.go
@@ -20,9 +20,8 @@ package iceberg
 import (
        "encoding/json"
        "fmt"
+       "slices"
        "strings"
-
-       "golang.org/x/exp/slices"
 )
 
 const (
diff --git a/schema.go b/schema.go
index f6d88d3..18014dc 100644
--- a/schema.go
+++ b/schema.go
@@ -21,11 +21,10 @@ import (
        "encoding/json"
        "fmt"
        "maps"
+       "slices"
        "strings"
        "sync"
        "sync/atomic"
-
-       "golang.org/x/exp/slices"
 )
 
 // Schema is an Iceberg table schema, represented as a struct with
diff --git a/table/arrow_utils.go b/table/arrow_utils.go
index 8f3890a..b44d06f 100644
--- a/table/arrow_utils.go
+++ b/table/arrow_utils.go
@@ -18,12 +18,16 @@
 package table
 
 import (
+       "context"
        "fmt"
        "slices"
        "strconv"
 
        "github.com/apache/arrow-go/v18/arrow"
+       "github.com/apache/arrow-go/v18/arrow/array"
+       "github.com/apache/arrow-go/v18/arrow/compute"
        "github.com/apache/arrow-go/v18/arrow/extensions"
+       "github.com/apache/arrow-go/v18/arrow/memory"
        "github.com/apache/iceberg-go"
 )
 
@@ -412,6 +416,52 @@ func ArrowSchemaToIceberg(sc *arrow.Schema, 
downcastNsTimestamp bool, nameMappin
        }
 }
 
+type convertToSmallTypes struct{}
+
+func (convertToSmallTypes) Schema(_ *arrow.Schema, structResult arrow.Field) 
arrow.Field {
+       return structResult
+}
+
+func (convertToSmallTypes) Struct(_ *arrow.StructType, results []arrow.Field) 
arrow.Field {
+       return arrow.Field{Type: arrow.StructOf(results...)}
+}
+
+func (convertToSmallTypes) Field(field arrow.Field, fieldResult arrow.Field) 
arrow.Field {
+       field.Type = fieldResult.Type
+       return field
+}
+
+func (convertToSmallTypes) List(_ arrow.ListLikeType, elemResult arrow.Field) 
arrow.Field {
+       return arrow.Field{Type: arrow.ListOfField(elemResult)}
+}
+
+func (convertToSmallTypes) Map(_ *arrow.MapType, keyResult, valueResult 
arrow.Field) arrow.Field {
+       return arrow.Field{
+               Type: arrow.MapOfWithMetadata(keyResult.Type, 
keyResult.Metadata,
+                       valueResult.Type, valueResult.Metadata),
+       }
+}
+
+func (convertToSmallTypes) Primitive(dt arrow.DataType) arrow.Field {
+       switch dt.ID() {
+       case arrow.LARGE_STRING:
+               dt = arrow.BinaryTypes.String
+       case arrow.LARGE_BINARY:
+               dt = arrow.BinaryTypes.Binary
+       }
+
+       return arrow.Field{Type: dt}
+}
+
+func ensureSmallArrowTypes(dt arrow.DataType) (arrow.DataType, error) {
+       top, err := VisitArrowSchema(arrow.NewSchema([]arrow.Field{{Type: dt}}, 
nil), convertToSmallTypes{})
+       if err != nil {
+               return nil, err
+       }
+
+       return top.Type.(*arrow.StructType).Field(0).Type, nil
+}
+
 type convertToArrow struct {
        metadata        map[string]string
        includeFieldIDs bool
@@ -540,3 +590,264 @@ func TypeToArrowType(t iceberg.Type, includeFieldIDs 
bool) (arrow.DataType, erro
 
        return top.Type.(*arrow.StructType).Field(0).Type, nil
 }
+
+type arrowAccessor struct {
+       fileSchema *iceberg.Schema
+}
+
+func (a arrowAccessor) SchemaPartner(partner arrow.Array) arrow.Array {
+       return partner
+}
+
+func (a arrowAccessor) FieldPartner(partnerStruct arrow.Array, fieldID int, _ 
string) arrow.Array {
+       if partnerStruct == nil {
+               return nil
+       }
+
+       field, ok := a.fileSchema.FindFieldByID(fieldID)
+       if !ok {
+               return nil
+       }
+
+       if st, ok := partnerStruct.(*array.Struct); ok {
+               if idx, ok := 
st.DataType().(*arrow.StructType).FieldIdx(field.Name); ok {
+                       return st.Field(idx)
+               }
+       }
+
+       panic(fmt.Errorf("cannot find %s in expected partner_struct type %s",
+               field.Name, partnerStruct.DataType()))
+}
+
+func (a arrowAccessor) ListElementPartner(partnerList arrow.Array) arrow.Array 
{
+       if l, ok := partnerList.(array.ListLike); ok {
+               return l.ListValues()
+       }
+       return nil
+}
+
+func (a arrowAccessor) MapKeyPartner(partnerMap arrow.Array) arrow.Array {
+       if m, ok := partnerMap.(*array.Map); ok {
+               return m.Keys()
+       }
+       return nil
+}
+
+func (a arrowAccessor) MapValuePartner(partnerMap arrow.Array) arrow.Array {
+       if m, ok := partnerMap.(*array.Map); ok {
+               return m.Items()
+       }
+       return nil
+}
+
+func retOrPanic[T any](v T, err error) T {
+       if err != nil {
+               panic(err)
+       }
+       return v
+}
+
+type arrowProjectionVisitor struct {
+       ctx                 context.Context
+       fileSchema          *iceberg.Schema
+       includeFieldIDs     bool
+       downcastNsTimestamp bool
+       useLargeTypes       bool
+}
+
+func (a *arrowProjectionVisitor) castIfNeeded(field iceberg.NestedField, vals 
arrow.Array) arrow.Array {
+       fileField, ok := a.fileSchema.FindFieldByID(field.ID)
+       if !ok {
+               panic(fmt.Errorf("could not find field id %d in schema", 
field.ID))
+       }
+
+       typ, ok := fileField.Type.(iceberg.PrimitiveType)
+       if !ok {
+               vals.Retain()
+               return vals
+       }
+
+       if !field.Type.Equals(typ) {
+               promoted := retOrPanic(iceberg.PromoteType(fileField.Type, 
field.Type))
+               targetType := retOrPanic(TypeToArrowType(promoted, 
a.includeFieldIDs))
+               if !a.useLargeTypes {
+                       targetType = 
retOrPanic(ensureSmallArrowTypes(targetType))
+               }
+
+               return retOrPanic(compute.CastArray(a.ctx, vals,
+                       compute.SafeCastOptions(targetType)))
+       }
+
+       targetType := retOrPanic(TypeToArrowType(field.Type, a.includeFieldIDs))
+       if !arrow.TypeEqual(targetType, vals.DataType()) {
+               switch field.Type.(type) {
+               case iceberg.TimestampType:
+                       tt, tgtok := targetType.(*arrow.TimestampType)
+                       vt, valok := vals.DataType().(*arrow.TimestampType)
+
+                       if tgtok && valok && tt.TimeZone == "" && vt.TimeZone 
== "" && tt.Unit == arrow.Microsecond {
+                               if vt.Unit == arrow.Nanosecond && 
a.downcastNsTimestamp {
+                                       return 
retOrPanic(compute.CastArray(a.ctx, vals, compute.UnsafeCastOptions(tt)))
+                               } else if vt.Unit == arrow.Second || vt.Unit == 
arrow.Millisecond {
+                                       return 
retOrPanic(compute.CastArray(a.ctx, vals, compute.SafeCastOptions(tt)))
+                               }
+                       }
+
+                       panic(fmt.Errorf("unsupported schema projection from %s 
to %s",
+                               vals.DataType(), targetType))
+               case iceberg.TimestampTzType:
+                       tt, tgtok := targetType.(*arrow.TimestampType)
+                       vt, valok := vals.DataType().(*arrow.TimestampType)
+
+                       if tgtok && valok && tt.TimeZone == "UTC" &&
+                               slices.Contains(utcAliases, vt.TimeZone) && 
tt.Unit == arrow.Microsecond {
+                               if vt.Unit == arrow.Nanosecond && 
a.downcastNsTimestamp {
+                                       return 
retOrPanic(compute.CastArray(a.ctx, vals, compute.UnsafeCastOptions(tt)))
+                               } else if vt.Unit != arrow.Nanosecond {
+                                       return 
retOrPanic(compute.CastArray(a.ctx, vals, compute.SafeCastOptions(tt)))
+                               }
+                       }
+
+                       panic(fmt.Errorf("unsupported schema projection from %s 
to %s",
+                               vals.DataType(), targetType))
+               }
+       }
+       vals.Retain()
+       return vals
+}
+
+func (a *arrowProjectionVisitor) constructField(field iceberg.NestedField, 
arrowType arrow.DataType) arrow.Field {
+       metadata := map[string]string{}
+       if field.Doc != "" {
+               metadata[ArrowFieldDocKey] = field.Doc
+       }
+
+       if a.includeFieldIDs {
+               metadata[ArrowParquetFieldIDKey] = strconv.Itoa(field.ID)
+       }
+
+       return arrow.Field{
+               Name:     field.Name,
+               Type:     arrowType,
+               Nullable: !field.Required,
+               Metadata: arrow.MetadataFrom(metadata),
+       }
+}
+
+func (a *arrowProjectionVisitor) Schema(_ *iceberg.Schema, _ arrow.Array, 
result arrow.Array) arrow.Array {
+       return result
+}
+
+func (a *arrowProjectionVisitor) Struct(st iceberg.StructType, structArr 
arrow.Array, fieldResults []arrow.Array) arrow.Array {
+       if structArr == nil {
+               return nil
+       }
+
+       fieldArrs := make([]arrow.Array, len(st.FieldList))
+       fields := make([]arrow.Field, len(st.FieldList))
+       for i, field := range st.FieldList {
+               arr := fieldResults[i]
+               if arr != nil {
+                       arr = a.castIfNeeded(field, arr)
+                       defer arr.Release()
+                       fieldArrs[i] = arr
+                       fields[i] = a.constructField(field, arr.DataType())
+               } else if !field.Required {
+                       dt := retOrPanic(TypeToArrowType(field.Type, false))
+
+                       arr = 
array.MakeArrayOfNull(compute.GetAllocator(a.ctx), dt, structArr.Len())
+                       defer arr.Release()
+                       fieldArrs[i] = arr
+                       fields[i] = a.constructField(field, arr.DataType())
+               } else {
+                       panic(fmt.Errorf("%w: field is required, but could not 
be found in file: %s",
+                               iceberg.ErrInvalidSchema, field))
+               }
+       }
+
+       return retOrPanic(array.NewStructArrayWithFields(fieldArrs, fields))
+}
+
+func (a *arrowProjectionVisitor) Field(_ iceberg.NestedField, _ arrow.Array, 
fieldArr arrow.Array) arrow.Array {
+       return fieldArr
+}
+
+// List rebuilds a list-like array for the requested iceberg list type. The
+// element values are cast if needed and the buffers of the source list array
+// are reused as-is. It returns nil when the partner array is not list-like or
+// when no element values were produced.
+//
+// It panics if the source array uses a list layout that cannot be projected;
+// this matches the panic-based error reporting used by the rest of the visitor.
+func (a *arrowProjectionVisitor) List(listType iceberg.ListType, listArr arrow.Array, valArr arrow.Array) arrow.Array {
+       arr, ok := listArr.(array.ListLike)
+       if !ok || valArr == nil {
+               return nil
+       }
+
+       valueArr := a.castIfNeeded(listType.ElementField(), valArr)
+       defer valueArr.Release()
+
+       // Keep the list layout of the source array and only swap in the
+       // rebuilt element field.
+       var outType arrow.ListLikeType
+       elemField := a.constructField(listType.ElementField(), valueArr.DataType())
+       switch arr.DataType().ID() {
+       case arrow.LIST:
+               outType = arrow.ListOfField(elemField)
+       case arrow.LARGE_LIST:
+               outType = arrow.LargeListOfField(elemField)
+       case arrow.LIST_VIEW:
+               // The source buffers are reused below, so the output type has
+               // to stay a (32-bit) list view rather than a large list view.
+               outType = arrow.ListViewOfField(elemField)
+       case arrow.LARGE_LIST_VIEW:
+               outType = arrow.LargeListViewOfField(elemField)
+       default:
+               panic(fmt.Errorf("unsupported list type for schema projection: %s",
+                       arr.DataType()))
+       }
+
+       data := array.NewData(outType, arr.Len(), arr.Data().Buffers(),
+               []arrow.ArrayData{valueArr.Data()}, arr.NullN(), arr.Data().Offset())
+       defer data.Release()
+       return array.MakeFromData(data)
+}
+
+// Map rebuilds a map array for the requested iceberg map type from the
+// projected key and value arrays, casting both if needed. It returns nil when
+// either child result is missing or when the partner is not a map array.
+func (a *arrowProjectionVisitor) Map(m iceberg.MapType, mapArray, keyResult, valResult arrow.Array) arrow.Array {
+       if keyResult == nil || valResult == nil {
+               return nil
+       }
+
+       arr, ok := mapArray.(*array.Map)
+       if !ok {
+               return nil
+       }
+
+       keys := a.castIfNeeded(m.KeyField(), keyResult)
+       defer keys.Release()
+       vals := a.castIfNeeded(m.ValueField(), valResult)
+       defer vals.Release()
+
+       keyField := a.constructField(m.KeyField(), keys.DataType())
+       valField := a.constructField(m.ValueField(), vals.DataType())
+
+       mapType := arrow.MapOfWithMetadata(keyField.Type, keyField.Metadata, valField.Type, valField.Metadata)
+       // The entries struct has one row per key/value pair, so its length is
+       // the number of keys, not the number of rows in the map array.
+       childData := array.NewData(mapType.Elem(), keys.Len(), []*memory.Buffer{nil},
+               []arrow.ArrayData{keys.Data(), vals.Data()}, 0, 0)
+       defer childData.Release()
+       // The map-level buffers, null count and offset come from the source array.
+       newData := array.NewData(mapType, arr.Len(), arr.Data().Buffers(),
+               []arrow.ArrayData{childData}, arr.NullN(), arr.Offset())
+       defer newData.Release()
+       return array.NewMapData(newData)
+}
+
+func (a *arrowProjectionVisitor) Primitive(_ iceberg.PrimitiveType, arr 
arrow.Array) arrow.Array {
+       return arr
+}
+
+// ToRequestedSchema constructs a new record batch matching the requested
+// iceberg schema, casting columns where necessary.
+//
+// fileSchema is the iceberg schema of the file the batch was read from and is
+// used to look up fields by ID. downcastTimestamp allows nanosecond timestamps
+// to be cast down to microseconds, includeFieldIDs adds the field ID to the
+// metadata of the produced Arrow fields, and useLargeTypes keeps large Arrow
+// types for promoted columns instead of converting them to their small
+// variants.
+//
+// The caller is responsible for releasing the returned record.
+func ToRequestedSchema(requested, fileSchema *iceberg.Schema, batch arrow.Record, downcastTimestamp, includeFieldIDs, useLargeTypes bool) (arrow.Record, error) {
+       st := array.RecordToStructArray(batch)
+       defer st.Release()
+
+       result, err := iceberg.VisitSchemaWithPartner[arrow.Array, arrow.Array](requested, st,
+               &arrowProjectionVisitor{
+                       ctx:                 context.Background(),
+                       fileSchema:          fileSchema,
+                       includeFieldIDs:     includeFieldIDs,
+                       downcastNsTimestamp: downcastTimestamp,
+                       useLargeTypes:       useLargeTypes,
+               }, arrowAccessor{fileSchema: fileSchema})
+       if err != nil {
+               return nil, err
+       }
+
+       // The visitor may hand back nil; report that as an error instead of
+       // panicking on the Release call or the type assertion below.
+       if result == nil {
+               return nil, fmt.Errorf("%w: schema projection produced no result",
+                       iceberg.ErrInvalidSchema)
+       }
+       defer result.Release()
+
+       out, ok := result.(*array.Struct)
+       if !ok {
+               return nil, fmt.Errorf("%w: schema projection produced %s, expected a struct",
+                       iceberg.ErrInvalidSchema, result.DataType())
+       }
+
+       return array.RecordFromStructArray(out, nil), nil
+}
diff --git a/table/arrow_utils_test.go b/table/arrow_utils_test.go
index 76c40fd..9c15985 100644
--- a/table/arrow_utils_test.go
+++ b/table/arrow_utils_test.go
@@ -18,10 +18,15 @@
 package table_test
 
 import (
+       "context"
+       "strings"
        "testing"
 
        "github.com/apache/arrow-go/v18/arrow"
+       "github.com/apache/arrow-go/v18/arrow/array"
+       "github.com/apache/arrow-go/v18/arrow/compute"
        "github.com/apache/arrow-go/v18/arrow/extensions"
+       "github.com/apache/arrow-go/v18/arrow/memory"
        "github.com/apache/iceberg-go"
        "github.com/apache/iceberg-go/table"
        "github.com/stretchr/testify/assert"
@@ -393,3 +398,114 @@ func TestArrowSchemaWithNameMapping(t *testing.T) {
                })
        }
 }
+
+var (
+       ArrowSchemaWithAllTimestampPrec = arrow.NewSchema([]arrow.Field{
+               {Name: "timestamp_s", Type: &arrow.TimestampType{Unit: 
arrow.Second}, Nullable: true},
+               {Name: "timestamptz_s", Type: 
arrow.FixedWidthTypes.Timestamp_s, Nullable: true},
+               {Name: "timestamp_ms", Type: &arrow.TimestampType{Unit: 
arrow.Millisecond}, Nullable: true},
+               {Name: "timestamptz_ms", Type: 
arrow.FixedWidthTypes.Timestamp_ms, Nullable: true},
+               {Name: "timestamp_us", Type: &arrow.TimestampType{Unit: 
arrow.Microsecond}, Nullable: true},
+               {Name: "timestamptz_us", Type: 
arrow.FixedWidthTypes.Timestamp_us, Nullable: true},
+               {Name: "timestamp_ns", Type: &arrow.TimestampType{Unit: 
arrow.Nanosecond}, Nullable: true},
+               {Name: "timestamptz_ns", Type: 
arrow.FixedWidthTypes.Timestamp_ns, Nullable: true},
+               {Name: "timestamptz_us_etc_utc", Type: 
&arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: "Etc/UTC"}, Nullable: 
true},
+               {Name: "timestamptz_ns_z", Type: &arrow.TimestampType{Unit: 
arrow.Nanosecond, TimeZone: "Z"}, Nullable: true},
+               {Name: "timestamptz_s_0000", Type: &arrow.TimestampType{Unit: 
arrow.Second, TimeZone: "+00:00"}, Nullable: true},
+       }, nil)
+
+       ArrowSchemaWithAllMicrosecondsTimestampPrec = 
arrow.NewSchema([]arrow.Field{
+               {Name: "timestamp_s", Type: &arrow.TimestampType{Unit: 
arrow.Microsecond}, Nullable: true},
+               {Name: "timestamptz_s", Type: 
arrow.FixedWidthTypes.Timestamp_us, Nullable: true},
+               {Name: "timestamp_ms", Type: &arrow.TimestampType{Unit: 
arrow.Microsecond}, Nullable: true},
+               {Name: "timestamptz_ms", Type: 
arrow.FixedWidthTypes.Timestamp_us, Nullable: true},
+               {Name: "timestamp_us", Type: &arrow.TimestampType{Unit: 
arrow.Microsecond}, Nullable: true},
+               {Name: "timestamptz_us", Type: 
arrow.FixedWidthTypes.Timestamp_us, Nullable: true},
+               {Name: "timestamp_ns", Type: &arrow.TimestampType{Unit: 
arrow.Microsecond}, Nullable: true},
+               {Name: "timestamptz_ns", Type: 
arrow.FixedWidthTypes.Timestamp_us, Nullable: true},
+               {Name: "timestamptz_us_etc_utc", Type: 
&arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: "UTC"}, Nullable: true},
+               {Name: "timestamptz_ns_z", Type: &arrow.TimestampType{Unit: 
arrow.Microsecond, TimeZone: "UTC"}, Nullable: true},
+               {Name: "timestamptz_s_0000", Type: &arrow.TimestampType{Unit: 
arrow.Microsecond, TimeZone: "UTC"}, Nullable: true},
+       }, nil)
+
+       TableSchemaWithAllMicrosecondsTimestampPrec = iceberg.NewSchema(0,
+               iceberg.NestedField{ID: 1, Name: "timestamp_s", Type: 
iceberg.PrimitiveTypes.Timestamp},
+               iceberg.NestedField{ID: 2, Name: "timestamptz_s", Type: 
iceberg.PrimitiveTypes.TimestampTz},
+               iceberg.NestedField{ID: 3, Name: "timestamp_ms", Type: 
iceberg.PrimitiveTypes.Timestamp},
+               iceberg.NestedField{ID: 4, Name: "timestamptz_ms", Type: 
iceberg.PrimitiveTypes.TimestampTz},
+               iceberg.NestedField{ID: 5, Name: "timestamp_us", Type: 
iceberg.PrimitiveTypes.Timestamp},
+               iceberg.NestedField{ID: 6, Name: "timestamptz_us", Type: 
iceberg.PrimitiveTypes.TimestampTz},
+               iceberg.NestedField{ID: 7, Name: "timestamp_ns", Type: 
iceberg.PrimitiveTypes.Timestamp},
+               iceberg.NestedField{ID: 8, Name: "timestamptz_ns", Type: 
iceberg.PrimitiveTypes.TimestampTz},
+               iceberg.NestedField{ID: 9, Name: "timestamptz_us_etc_utc", 
Type: iceberg.PrimitiveTypes.TimestampTz},
+               iceberg.NestedField{ID: 10, Name: "timestamptz_ns_z", Type: 
iceberg.PrimitiveTypes.TimestampTz},
+               iceberg.NestedField{ID: 11, Name: "timestamptz_s_0000", Type: 
iceberg.PrimitiveTypes.TimestampTz},
+       )
+)
+
+func ArrowRecordWithAllTimestampPrec(mem memory.Allocator) arrow.Record {
+       batch, _, err := array.RecordFromJSON(mem, 
ArrowSchemaWithAllTimestampPrec,
+               strings.NewReader(`[
+               {
+                       "timestamp_s": "2023-01-01T19:25:00-05:00",
+                       "timestamptz_s": "2023-01-01T19:25:00Z",
+                       "timestamp_ms": "2023-01-01T19:25:00.123-05:00",
+                       "timestamptz_ms": "2023-01-01T19:25:00.123Z",
+                       "timestamp_us": "2023-01-01T19:25:00.123456-05:00",
+                       "timestamptz_us": "2023-01-01T19:25:00.123456Z",
+                       "timestamp_ns": "2024-07-11T03:30:00.123456789-05:00",
+                       "timestamptz_ns": "2023-01-01T19:25:00.123456789Z",
+                       "timestamptz_us_etc_utc": "2023-01-01T19:25:00.123456Z",
+                       "timestamptz_ns_z": "2024-07-11T03:30:00.123456789Z",
+                       "timestamptz_s_0000": "2023-01-01T19:25:00Z"
+               }, {}, {
+                       "timestamp_s": "2023-03-01T19:25:00-05:00",
+                       "timestamptz_s": "2023-03-01T19:25:00Z",
+                       "timestamp_ms": "2023-03-01T19:25:00.123-05:00",
+                       "timestamptz_ms": "2023-03-01T19:25:00.123Z",
+                       "timestamp_us": "2023-03-01T19:25:00.123456-05:00",
+                       "timestamptz_us": "2023-03-01T19:25:00.123456Z",
+                       "timestamp_ns": "2024-07-11T03:30:00.9876543210-05:00",
+                       "timestamptz_ns": "2023-03-01T19:25:00.9876543210Z",
+                       "timestamptz_us_etc_utc": "2023-03-01T19:25:00.123456Z",
+                       "timestamptz_ns_z": "2024-07-11T03:30:00.9876543210Z",
+                       "timestamptz_s_0000": "2023-03-01T19:25:00Z"
+               }
+       ]`))
+
+       if err != nil {
+               panic(err)
+       }
+
+       return batch
+}
+
+func TestToRequestedSchemaTimestamps(t *testing.T) {
+       mem := memory.NewCheckedAllocator(memory.NewGoAllocator())
+       defer mem.AssertSize(t, 0)
+
+       batch := ArrowRecordWithAllTimestampPrec(mem)
+       defer batch.Release()
+
+       requestedSchema := TableSchemaWithAllMicrosecondsTimestampPrec
+       fileSchema := requestedSchema
+
+       converted, err := table.ToRequestedSchema(requestedSchema, fileSchema, 
batch, true, false, false)
+       require.NoError(t, err)
+       defer converted.Release()
+
+       assert.True(t, 
converted.Schema().Equal(ArrowSchemaWithAllMicrosecondsTimestampPrec), 
"expected: %s\ngot: %s",
+               ArrowSchemaWithAllMicrosecondsTimestampPrec, converted.Schema())
+
+       for i, col := range batch.Columns() {
+               convertedCol := converted.Column(i)
+               if arrow.TypeEqual(col.DataType(), convertedCol.DataType()) {
+                       assert.True(t, array.Equal(col, convertedCol), 
"expected: %s\ngot: %s", col, convertedCol)
+               } else {
+                       expected, err := 
compute.CastArray(context.Background(), col, 
compute.UnsafeCastOptions(convertedCol.DataType()))
+                       require.NoError(t, err)
+                       assert.True(t, array.Equal(expected, convertedCol), 
"expected: %s\ngot: %s", expected, convertedCol)
+                       expected.Release()
+               }
+       }
+}
diff --git a/table/table.go b/table/table.go
index b4cf92c..7064add 100644
--- a/table/table.go
+++ b/table/table.go
@@ -18,9 +18,10 @@
 package table
 
 import (
+       "slices"
+
        "github.com/apache/iceberg-go"
        "github.com/apache/iceberg-go/io"
-       "golang.org/x/exp/slices"
 )
 
 type Identifier = []string
diff --git a/types.go b/types.go
index 6729964..ebc8849 100644
--- a/types.go
+++ b/types.go
@@ -21,12 +21,12 @@ import (
        "encoding/json"
        "fmt"
        "regexp"
+       "slices"
        "strconv"
        "strings"
        "time"
 
        "github.com/apache/arrow-go/v18/arrow/decimal128"
-       "golang.org/x/exp/slices"
 )
 
 var (
@@ -637,3 +637,45 @@ var PrimitiveTypes = struct {
        Binary:      BinaryType{},
        UUID:        UUIDType{},
 }
+
+// PromoteType resolves the type to use when data written as fileType is read
+// as readType.
+//
+// Equal types resolve to fileType unchanged. Otherwise the following
+// promotions are accepted and resolve to readType:
+//   - Int32Type to Int64Type
+//   - Float32Type to Float64Type
+//   - StringType to BinaryType, and BinaryType to StringType
+//   - DecimalType to a DecimalType whose precision and scale are both at
+//     least as large
+//   - FixedType of length 16 to UUIDType
+//
+// Every other combination returns an error wrapping ErrResolve.
+func PromoteType(fileType, readType Type) (Type, error) {
+       // Equal types never need promotion. Checking this up front covers all
+       // types, including those that have a dedicated promotion case below.
+       if fileType.Equals(readType) {
+               return fileType, nil
+       }
+
+       switch t := fileType.(type) {
+       case Int32Type:
+               if _, ok := readType.(Int64Type); ok {
+                       return readType, nil
+               }
+       case Float32Type:
+               if _, ok := readType.(Float64Type); ok {
+                       return readType, nil
+               }
+       case StringType:
+               if _, ok := readType.(BinaryType); ok {
+                       return readType, nil
+               }
+       case BinaryType:
+               if _, ok := readType.(StringType); ok {
+                       return readType, nil
+               }
+       case DecimalType:
+               if rt, ok := readType.(DecimalType); ok {
+                       if t.precision <= rt.precision && t.scale <= rt.scale {
+                               return readType, nil
+                       }
+
+                       return nil, fmt.Errorf("%w: cannot reduce precision from %s to %s",
+                               ErrResolve, fileType, readType)
+               }
+       case FixedType:
+               if _, ok := readType.(UUIDType); ok && t.len == 16 {
+                       return readType, nil
+               }
+       }
+
+       return nil, fmt.Errorf("%w: cannot promote %s to %s", ErrResolve, fileType, readType)
+}
diff --git a/visitors.go b/visitors.go
index 7525026..bb0caab 100644
--- a/visitors.go
+++ b/visitors.go
@@ -19,7 +19,9 @@ package iceberg
 
 import (
        "fmt"
+       "maps"
        "math"
+       "slices"
        "strings"
 
        "github.com/google/uuid"
@@ -395,3 +397,50 @@ func (rewriteNotVisitor) VisitUnbound(pred 
UnboundPredicate) BooleanExpression {
 func (rewriteNotVisitor) VisitBound(pred BoundPredicate) BooleanExpression {
        return pred
 }
+
+// ExtractFieldIDs returns a slice containing the field IDs which are 
referenced
+// by any terms in the given expression. This enables retrieving exactly which
+// fields are needed for an expression.
+func ExtractFieldIDs(expr BooleanExpression) ([]int, error) {
+       res, err := VisitExpr(expr, expressionFieldIDs{})
+       if err != nil {
+               return nil, err
+       }
+
+       out := make([]int, 0, len(res))
+       return slices.AppendSeq(out, maps.Keys(res)), nil
+}
+
+type expressionFieldIDs struct{}
+
+func (expressionFieldIDs) VisitTrue() map[int]struct{} {
+       return map[int]struct{}{}
+}
+
+func (expressionFieldIDs) VisitFalse() map[int]struct{} {
+       return map[int]struct{}{}
+}
+
+func (expressionFieldIDs) VisitNot(child map[int]struct{}) map[int]struct{} {
+       return child
+}
+
+func (expressionFieldIDs) VisitAnd(left, right map[int]struct{}) 
map[int]struct{} {
+       maps.Insert(left, maps.All(right))
+       return left
+}
+
+func (expressionFieldIDs) VisitOr(left, right map[int]struct{}) 
map[int]struct{} {
+       maps.Insert(left, maps.All(right))
+       return left
+}
+
+func (expressionFieldIDs) VisitUnbound(UnboundPredicate) map[int]struct{} {
+       panic("expression field IDs only works for bound expressions")
+}
+
+func (expressionFieldIDs) VisitBound(pred BoundPredicate) map[int]struct{} {
+       return map[int]struct{}{
+               pred.Ref().Field().ID: {},
+       }
+}

Reply via email to