This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-go.git


The following commit(s) were added to refs/heads/main by this push:
     new d493460  feat(arrow/array): convert RecordReader and iterators (#314)
d493460 is described below

commit d49346023bc0bdb5006502a95a0edb9f77b3fb57
Author: Matt Topol <[email protected]>
AuthorDate: Mon Mar 17 10:23:59 2025 -0400

    feat(arrow/array): convert RecordReader and iterators (#314)
    
    ### Rationale for this change
    With the advent of Go iterators via the [iter](https://pkg.go.dev/iter)
    package, we should provide some easy compatibility and canonicity for
    handling streams of record batches.
    
    ### What changes are included in this PR?
    Two new functions: `ReaderFromIter` and `IterFromReader` for converting
    between the `RecordReader` interface and an iterator of Records and
    errors. This should make it easy to integrate in various packages
    without forcing refactors or boilerplate code.
    
    ### Are these changes tested?
    Yes, unit tests are added for them.
    
    ### Are there any user-facing changes?
    No.
---
 arrow/array/record.go      |  79 +++++++++++++++++++++++++++++++++
 arrow/array/record_test.go | 108 ++++++++++++++++++++++++++++++++++++---------
 2 files changed, 165 insertions(+), 22 deletions(-)

diff --git a/arrow/array/record.go b/arrow/array/record.go
index b8041e2..f29c896 100644
--- a/arrow/array/record.go
+++ b/arrow/array/record.go
@@ -19,6 +19,7 @@ package array
 import (
        "bytes"
        "fmt"
+       "iter"
        "strings"
        "sync/atomic"
 
@@ -405,6 +406,84 @@ func (b *RecordBuilder) UnmarshalJSON(data []byte) error {
        return nil
 }
 
+type iterReader struct {
+       refCount atomic.Int64
+
+       schema *arrow.Schema
+       cur    arrow.Record
+
+       next func() (arrow.Record, error, bool)
+       stop func()
+
+       err error
+}
+
+func (ir *iterReader) Schema() *arrow.Schema { return ir.schema }
+
+func (ir *iterReader) Retain() { ir.refCount.Add(1) }
+func (ir *iterReader) Release() {
+       debug.Assert(ir.refCount.Load() > 0, "too many releases")
+
+       if ir.refCount.Add(-1) == 0 {
+               ir.stop()
+               ir.schema, ir.next = nil, nil
+               if ir.cur != nil {
+                       ir.cur.Release()
+               }
+       }
+}
+
+func (ir *iterReader) Record() arrow.Record { return ir.cur }
+func (ir *iterReader) Err() error           { return ir.err }
+
+func (ir *iterReader) Next() bool {
+       if ir.cur != nil {
+               ir.cur.Release()
+       }
+
+       var ok bool
+       ir.cur, ir.err, ok = ir.next()
+       if ir.err != nil {
+               ir.stop()
+               return false
+       }
+
+       return ok
+}
+
+// ReaderFromIter wraps a go iterator for arrow.Record + error into a RecordReader
+// interface object for ease of use.
+func ReaderFromIter(schema *arrow.Schema, itr iter.Seq2[arrow.Record, error]) RecordReader {
+       next, stop := iter.Pull2(itr)
+       rdr := &iterReader{
+               schema: schema,
+               next:   next,
+               stop:   stop,
+       }
+       rdr.refCount.Add(1)
+       return rdr
+}
+
+// IterFromReader converts a RecordReader interface into an iterator that
+// you can use range on. The semantics are still important, if a record
+// that is returned is desired to be utilized beyond the scope of an iteration
+// then Retain must be called on it.
+func IterFromReader(rdr RecordReader) iter.Seq2[arrow.Record, error] {
+       rdr.Retain()
+       return func(yield func(arrow.Record, error) bool) {
+               defer rdr.Release()
+               for rdr.Next() {
+                       if !yield(rdr.Record(), nil) {
+                               return
+                       }
+               }
+
+               if rdr.Err() != nil {
+                       yield(nil, rdr.Err())
+               }
+       }
+}
+
 var (
        _ arrow.Record = (*simpleRecord)(nil)
        _ RecordReader = (*simpleRecords)(nil)
diff --git a/arrow/array/record_test.go b/arrow/array/record_test.go
index 91a31cb..2a61bdd 100644
--- a/arrow/array/record_test.go
+++ b/arrow/array/record_test.go
@@ -301,33 +301,97 @@ func TestRecordReader(t *testing.T) {
        defer rec2.Release()
 
        recs := []arrow.Record{rec1, rec2}
-       itr, err := array.NewRecordReader(schema, recs)
-       if err != nil {
-               t.Fatal(err)
-       }
-       defer itr.Release()
+       t.Run("simple reader", func(t *testing.T) {
+               itr, err := array.NewRecordReader(schema, recs)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               defer itr.Release()
 
-       itr.Retain()
-       itr.Release()
+               itr.Retain()
+               itr.Release()
 
-       if got, want := itr.Schema(), schema; !got.Equal(want) {
-               t.Fatalf("invalid schema. got=%#v, want=%#v", got, want)
-       }
+               if got, want := itr.Schema(), schema; !got.Equal(want) {
+                       t.Fatalf("invalid schema. got=%#v, want=%#v", got, want)
+               }
 
-       n := 0
-       for itr.Next() {
-               n++
-               if got, want := itr.Record(), recs[n-1]; !reflect.DeepEqual(got, want) {
-                       t.Fatalf("itr[%d], invalid record. got=%#v, want=%#v", n-1, got, want)
+               n := 0
+               for itr.Next() {
+                       n++
+                       if got, want := itr.Record(), recs[n-1]; !reflect.DeepEqual(got, want) {
+                               t.Fatalf("itr[%d], invalid record. got=%#v, want=%#v", n-1, got, want)
+                       }
+               }
+               if err := itr.Err(); err != nil {
+                       t.Fatalf("itr error: %#v", err)
                }
-       }
-       if err := itr.Err(); err != nil {
-               t.Fatalf("itr error: %#v", err)
-       }
 
-       if n != len(recs) {
-               t.Fatalf("invalid number of iterations. got=%d, want=%d", n, len(recs))
-       }
+               if n != len(recs) {
+                       t.Fatalf("invalid number of iterations. got=%d, want=%d", n, len(recs))
+               }
+       })
+
+       t.Run("iter to reader", func(t *testing.T) {
+               itr := func(yield func(arrow.Record, error) bool) {
+                       for _, r := range recs {
+                               if !yield(r, nil) {
+                                       return
+                               }
+                       }
+               }
+
+               rdr := array.ReaderFromIter(schema, itr)
+               defer rdr.Release()
+
+               rdr.Retain()
+               rdr.Release()
+
+               if got, want := rdr.Schema(), schema; !got.Equal(want) {
+                       t.Fatalf("invalid schema. got=%#v, want=%#v", got, want)
+               }
+
+               n := 0
+               for rdr.Next() {
+                       n++
+                       // facet of using the simple record reader with a slice
+                       // by default it will release records when the reader is released
+                       // leading to too many releases on the original record
+                       // so we retain it to keep it from going away while the test runs
+                       rdr.Record().Retain()
+                       if got, want := rdr.Record(), recs[n-1]; !reflect.DeepEqual(got, want) {
+                               t.Fatalf("itr[%d], invalid record. got=%#v, want=%#v", n-1, got, want)
+                       }
+               }
+               if err := rdr.Err(); err != nil {
+                       t.Fatalf("itr error: %#v", err)
+               }
+
+               if n != len(recs) {
+                       t.Fatalf("invalid number of iterations. got=%d, want=%d", n, len(recs))
+               }
+       })
+
+       t.Run("reader to iter", func(t *testing.T) {
+               rdr, err := array.NewRecordReader(schema, recs)
+               if err != nil {
+                       t.Fatal(err)
+               }
+
+               itr := array.IterFromReader(rdr)
+               rdr.Release()
+
+               n := 0
+               for rec, err := range itr {
+                       if err != nil {
+                               t.Fatalf("itr error: %#v", err)
+                       }
+
+                       n++
+                       if got, want := rec, recs[n-1]; !reflect.DeepEqual(got, want) {
+                               t.Fatalf("itr[%d], invalid record. got=%#v, want=%#v", n-1, got, want)
+                       }
+               }
+       })
 
        for _, tc := range []struct {
                name   string

Reply via email to