This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-go.git


The following commit(s) were added to refs/heads/main by this push:
     new b419b81  feat(parquet/pqarrow): Add SeekToRow for RecordReader (#321)
b419b81 is described below

commit b419b81866945bde9e9a5d0ec036acc5ebaba2bb
Author: Matt Topol <[email protected]>
AuthorDate: Fri Mar 21 11:36:59 2025 -0400

    feat(parquet/pqarrow): Add SeekToRow for RecordReader (#321)
    
    ### Rationale for this change
    As suggested in
    https://github.com/apache/arrow-go/issues/278#issuecomment-2741136394
    this change allows the RecordReader from the `pqarrow` package to also
    leverage the `SeekToRow` functionality to skip records in a parquet file
    while still respecting the request to skip particular row groups.
    
    ### What changes are included in this PR?
    Implementing a `SeekToRow` method for `pqarrow.RecordReader` to seek the
    record reader to a specific row where the next read will start from.
    
    ### Are these changes tested?
    Yes, unit tests are added for this.
    
    ### Are there any user-facing changes?
    Just the new functions.
---
 parquet/file/column_reader.go       |   1 +
 parquet/file/record_reader.go       |  24 +++++-
 parquet/pqarrow/column_readers.go   |  31 ++++++-
 parquet/pqarrow/file_reader.go      |  79 +++++++++++++++--
 parquet/pqarrow/file_reader_test.go | 167 ++++++++++++++++++++++++++++++++++++
 5 files changed, 290 insertions(+), 12 deletions(-)

diff --git a/parquet/file/column_reader.go b/parquet/file/column_reader.go
index 5faf8bc..03ca5a8 100644
--- a/parquet/file/column_reader.go
+++ b/parquet/file/column_reader.go
@@ -223,6 +223,7 @@ func (c *columnChunkReader) pager() PageReader             
{ return c.rdr }
 func (c *columnChunkReader) setPageReader(rdr PageReader) {
        c.rdr, c.err = rdr, nil
        c.decoders = make(map[format.Encoding]encoding.TypedDecoder)
+       c.newDictionary = false
        c.numBuffered, c.numDecoded = 0, 0
 }
 
diff --git a/parquet/file/record_reader.go b/parquet/file/record_reader.go
index 20d68f4..e8b0dee 100644
--- a/parquet/file/record_reader.go
+++ b/parquet/file/record_reader.go
@@ -63,6 +63,7 @@ type RecordReader interface {
        // ReleaseValues transfers the buffer of data with the values to the 
caller,
        // a new buffer will be allocated on subsequent calls.
        ReleaseValues() *memory.Buffer
+       ResetValues()
        // NullCount returns the number of nulls decoded
        NullCount() int64
        // Type returns the parquet physical type of the column
@@ -78,6 +79,10 @@ type RecordReader interface {
        // Release decrements the ref count by one, releasing the internal 
buffers when
        // the ref count is 0.
        Release()
+       // SeekToRow will shift the record reader so that subsequent reads will
+       // start at the desired row. It will utilize Offset Indexes if they 
exist
+       // to skip pages and seek.
+       SeekToRow(int64) error
 }
 
 // BinaryRecordReader provides an extra GetBuilderChunks function above and 
beyond
@@ -440,12 +445,27 @@ func (rr *recordReader) reserveValues(extra int64) error {
        return rr.recordReaderImpl.ReserveValues(extra, 
rr.leafInfo.HasNullableValues())
 }
 
-func (rr *recordReader) resetValues() {
+func (rr *recordReader) ResetValues() {
        rr.recordReaderImpl.ResetValues()
 }
 
+func (rr *recordReader) SeekToRow(recordIdx int64) error {
+       if err := rr.recordReaderImpl.SeekToRow(recordIdx); err != nil {
+               return err
+       }
+
+       rr.atRecStart = true
+       rr.recordsRead = 0
+       // force re-reading the definition/repetition levels
+       // calling SeekToRow on the underlying column reader will ensure that
+       // the next reads will pull from the correct row
+       rr.levelsPos, rr.levelsWritten = 0, 0
+
+       return nil
+}
+
 func (rr *recordReader) Reset() {
-       rr.resetValues()
+       rr.ResetValues()
 
        if rr.levelsWritten > 0 {
                remain := int(rr.levelsWritten - rr.levelsPos)
diff --git a/parquet/pqarrow/column_readers.go 
b/parquet/pqarrow/column_readers.go
index 921774d..5047d88 100644
--- a/parquet/pqarrow/column_readers.go
+++ b/parquet/pqarrow/column_readers.go
@@ -92,7 +92,7 @@ func (lr *leafReader) IsOrHasRepeatedChild() bool { return 
false }
 
 func (lr *leafReader) LoadBatch(nrecords int64) (err error) {
        lr.releaseOut()
-       lr.recordRdr.Reset()
+       lr.recordRdr.ResetValues()
 
        if err := lr.recordRdr.Reserve(nrecords); err != nil {
                return err
@@ -135,6 +135,16 @@ func (lr *leafReader) clearOut() (out *arrow.Chunked) {
 
 func (lr *leafReader) Field() *arrow.Field { return lr.field }
 
+func (lr *leafReader) SeekToRow(rowIdx int64) error {
+       pr, offset, err := lr.input.FindChunkForRow(rowIdx)
+       if err != nil {
+               return err
+       }
+
+       lr.recordRdr.SetPageReader(pr)
+       return lr.recordRdr.SeekToRow(offset)
+}
+
 func (lr *leafReader) nextRowGroup() error {
        pr, err := lr.input.NextChunk()
        if err != nil {
@@ -227,6 +237,21 @@ func (sr *structReader) GetRepLevels() ([]int16, error) {
        return sr.defRepLevelChild.GetRepLevels()
 }
 
+func (sr *structReader) SeekToRow(rowIdx int64) error {
+       var g errgroup.Group
+       if !sr.props.Parallel {
+               g.SetLimit(1)
+       }
+
+       for _, rdr := range sr.children {
+               g.Go(func() error {
+                       return rdr.SeekToRow(rowIdx)
+               })
+       }
+
+       return g.Wait()
+}
+
 func (sr *structReader) LoadBatch(nrecords int64) error {
        // Load batches in parallel
        // When reading structs with large numbers of columns, the serial load 
is very slow.
@@ -356,6 +381,10 @@ func (lr *listReader) Field() *arrow.Field { return 
lr.field }
 
 func (lr *listReader) IsOrHasRepeatedChild() bool { return true }
 
+func (lr *listReader) SeekToRow(rowIdx int64) error {
+       return lr.itemRdr.SeekToRow(rowIdx)
+}
+
 func (lr *listReader) LoadBatch(nrecords int64) error {
        return lr.itemRdr.LoadBatch(nrecords)
 }
diff --git a/parquet/pqarrow/file_reader.go b/parquet/pqarrow/file_reader.go
index d6eae17..ab7dcec 100644
--- a/parquet/pqarrow/file_reader.go
+++ b/parquet/pqarrow/file_reader.go
@@ -21,6 +21,7 @@ import (
        "errors"
        "fmt"
        "io"
+       "slices"
        "sync"
        "sync/atomic"
 
@@ -116,6 +117,7 @@ type colReaderImpl interface {
        GetDefLevels() ([]int16, error)
        GetRepLevels() ([]int16, error)
        Field() *arrow.Field
+       SeekToRow(int64) error
        IsOrHasRepeatedChild() bool
        Retain()
        Release()
@@ -427,6 +429,20 @@ func (fr *FileReader) getColumnReader(ctx context.Context, 
i int, colFactory itr
 type RecordReader interface {
        array.RecordReader
        arrio.Reader
+       // SeekToRow will shift the record reader so that subsequent calls to 
Read
+       // or Next will begin from the specified row.
+       //
+       // If the record reader was constructed with a request for a subset of 
row
+       // groups, then rows are counted across the requested row groups, not 
the
+       // entire file. This prevents reading row groups that were requested to 
be
+       // skipped, and allows treating the subset of row groups as a single 
collection
+       // of rows.
+       //
+       // If the file contains Offset indexes for a given column, then it will 
be
+       // utilized to skip pages as needed to find the requested row. 
Otherwise page
+       // headers will have to still be read to find the right page to begin 
reading
+       // from.
+       SeekToRow(int64) error
 }
 
 // GetRecordReader returns a record reader that reads only the requested 
column indexes and row groups.
@@ -537,12 +553,8 @@ func (fr *FileReader) getReader(ctx context.Context, field 
*SchemaField, arrowFi
                }
 
                // because we performed getReader concurrently, we need to 
prune out any empty readers
-               for n := len(childReaders) - 1; n >= 0; n-- {
-                       if childReaders[n] == nil {
-                               childReaders = append(childReaders[:n], 
childReaders[n+1:]...)
-                               childFields = append(childFields[:n], 
childFields[n+1:]...)
-                       }
-               }
+               childReaders = slices.DeleteFunc(childReaders,
+                       func(r *ColumnReader) bool { return r == nil })
                if len(childFields) == 0 {
                        return nil, nil
                }
@@ -615,15 +627,45 @@ type columnIterator struct {
        rdr       *file.Reader
        schema    *schema.Schema
        rowGroups []int
+
+       rgIdx int
 }
 
-func (c *columnIterator) NextChunk() (file.PageReader, error) {
+func (c *columnIterator) FindChunkForRow(rowIdx int64) (file.PageReader, 
int64, error) {
        if len(c.rowGroups) == 0 {
+               return nil, 0, nil
+       }
+
+       if rowIdx < 0 || rowIdx > c.rdr.NumRows() {
+               return nil, 0, fmt.Errorf("invalid row index %d, file only has 
%d rows", rowIdx, c.rdr.NumRows())
+       }
+
+       idx := int64(0)
+       for i, rg := range c.rowGroups {
+               rgr := c.rdr.RowGroup(rg)
+               if idx+rgr.NumRows() > rowIdx {
+                       c.rgIdx = i + 1
+                       pr, err := rgr.GetColumnPageReader(c.index)
+                       if err != nil {
+                               return nil, 0, err
+                       }
+
+                       return pr, rowIdx - idx, nil
+               }
+               idx += rgr.NumRows()
+       }
+
+       return nil, 0, fmt.Errorf("%w: invalid row index %d, row group subset 
only has %d total rows",
+               arrow.ErrInvalid, rowIdx, idx)
+}
+
+func (c *columnIterator) NextChunk() (file.PageReader, error) {
+       if len(c.rowGroups) == 0 || c.rgIdx >= len(c.rowGroups) {
                return nil, nil
        }
 
-       rgr := c.rdr.RowGroup(c.rowGroups[0])
-       c.rowGroups = c.rowGroups[1:]
+       rgr := c.rdr.RowGroup(c.rowGroups[c.rgIdx])
+       c.rgIdx++
        return rgr.GetColumnPageReader(c.index)
 }
 
@@ -643,6 +685,25 @@ type recordReader struct {
        refCount int64
 }
 
+func (r *recordReader) SeekToRow(row int64) error {
+       if r.cur != nil {
+               r.cur.Release()
+               r.cur = nil
+       }
+
+       if row < 0 || row >= r.numRows {
+               return fmt.Errorf("invalid row index %d, file only has %d 
rows", row, r.numRows)
+       }
+
+       for _, fr := range r.fieldReaders {
+               if err := fr.SeekToRow(row); err != nil {
+                       return err
+               }
+       }
+
+       return nil
+}
+
 func (r *recordReader) Retain() {
        atomic.AddInt64(&r.refCount, 1)
 }
diff --git a/parquet/pqarrow/file_reader_test.go 
b/parquet/pqarrow/file_reader_test.go
index 9010927..bca5164 100644
--- a/parquet/pqarrow/file_reader_test.go
+++ b/parquet/pqarrow/file_reader_test.go
@@ -285,6 +285,173 @@ func TestRecordReaderSerial(t *testing.T) {
        assert.Nil(t, rec)
 }
 
+func TestRecordReaderSeekToRow(t *testing.T) {
+       mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+       defer mem.AssertSize(t, 0)
+
+       tbl := makeDateTimeTypesTable(mem, true, true)
+       defer tbl.Release()
+
+       var buf bytes.Buffer
+       require.NoError(t, pqarrow.WriteTable(tbl, &buf, tbl.NumRows(), nil, 
pqarrow.NewArrowWriterProperties(pqarrow.WithAllocator(mem))))
+
+       pf, err := file.NewParquetReader(bytes.NewReader(buf.Bytes()), 
file.WithReadProps(parquet.NewReaderProperties(mem)))
+       require.NoError(t, err)
+
+       reader, err := pqarrow.NewFileReader(pf, 
pqarrow.ArrowReadProperties{BatchSize: 2}, mem)
+       require.NoError(t, err)
+
+       sc, err := reader.Schema()
+       assert.NoError(t, err)
+       assert.Truef(t, tbl.Schema().Equal(sc), "expected: %s\ngot: %s", 
tbl.Schema(), sc)
+
+       rr, err := reader.GetRecordReader(context.Background(), nil, nil)
+       assert.NoError(t, err)
+       assert.NotNil(t, rr)
+       defer rr.Release()
+
+       tr := array.NewTableReader(tbl, 2)
+       defer tr.Release()
+
+       rec, err := rr.Read()
+       assert.NoError(t, err)
+       tr.Next()
+       assert.Truef(t, array.RecordEqual(tr.Record(), rec), "expected: 
%s\ngot: %s", tr.Record(), rec)
+
+       require.NoError(t, rr.SeekToRow(0))
+       rec, err = rr.Read()
+       assert.NoError(t, err)
+       assert.Truef(t, array.RecordEqual(tr.Record(), rec), "expected: 
%s\ngot: %s", tr.Record(), rec)
+
+       rec, err = rr.Read()
+       assert.NoError(t, err)
+       tr.Next()
+       assert.Truef(t, array.RecordEqual(tr.Record(), rec), "expected: 
%s\ngot: %s", tr.Record(), rec)
+
+       require.NoError(t, rr.SeekToRow(2))
+       rec, err = rr.Read()
+       assert.NoError(t, err)
+       assert.Truef(t, array.RecordEqual(tr.Record(), rec), "expected: 
%s\ngot: %s", tr.Record(), rec)
+
+       require.NoError(t, rr.SeekToRow(4))
+       rec, err = rr.Read()
+       tr.Next()
+       assert.NoError(t, err)
+       assert.Truef(t, array.RecordEqual(tr.Record(), rec), "expected: 
%s\ngot: %s", tr.Record(), rec)
+}
+
+func TestRecordReaderMultiRowGroup(t *testing.T) {
+       mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+       defer mem.AssertSize(t, 0)
+
+       tbl := makeDateTimeTypesTable(mem, true, true)
+       defer tbl.Release()
+
+       var buf bytes.Buffer
+       require.NoError(t, pqarrow.WriteTable(tbl, &buf, 2, nil, 
pqarrow.NewArrowWriterProperties(pqarrow.WithAllocator(mem))))
+
+       pf, err := file.NewParquetReader(bytes.NewReader(buf.Bytes()), 
file.WithReadProps(parquet.NewReaderProperties(mem)))
+       require.NoError(t, err)
+
+       reader, err := pqarrow.NewFileReader(pf, 
pqarrow.ArrowReadProperties{BatchSize: 2}, mem)
+       require.NoError(t, err)
+
+       sc, err := reader.Schema()
+       assert.NoError(t, err)
+       assert.Truef(t, tbl.Schema().Equal(sc), "expected: %s\ngot: %s", 
tbl.Schema(), sc)
+
+       rr, err := reader.GetRecordReader(context.Background(), nil, nil)
+       assert.NoError(t, err)
+       assert.NotNil(t, rr)
+       defer rr.Release()
+
+       tr := array.NewTableReader(tbl, 2)
+       defer tr.Release()
+
+       rec, err := rr.Read()
+       assert.NoError(t, err)
+       tr.Next()
+       assert.Truef(t, array.RecordEqual(tr.Record(), rec), "expected: 
%s\ngot: %s", tr.Record(), rec)
+
+       rec, err = rr.Read()
+       assert.NoError(t, err)
+       tr.Next()
+       assert.Truef(t, array.RecordEqual(tr.Record(), rec), "expected: 
%s\ngot: %s", tr.Record(), rec)
+
+       rec, err = rr.Read()
+       assert.NoError(t, err)
+       tr.Next()
+       assert.Truef(t, array.RecordEqual(tr.Record(), rec), "expected: 
%s\ngot: %s", tr.Record(), rec)
+
+       rec, err = rr.Read()
+       assert.Same(t, io.EOF, err)
+       assert.Nil(t, rec)
+}
+
+func TestRecordReaderSeekToRowMultiRowGroup(t *testing.T) {
+       mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+       defer mem.AssertSize(t, 0)
+
+       tbl := makeDateTimeTypesTable(mem, true, true)
+       defer tbl.Release()
+
+       var buf bytes.Buffer
+       require.NoError(t, pqarrow.WriteTable(tbl, &buf, 2, nil, 
pqarrow.NewArrowWriterProperties(pqarrow.WithAllocator(mem))))
+
+       pf, err := file.NewParquetReader(bytes.NewReader(buf.Bytes()), 
file.WithReadProps(parquet.NewReaderProperties(mem)))
+       require.NoError(t, err)
+
+       reader, err := pqarrow.NewFileReader(pf, 
pqarrow.ArrowReadProperties{BatchSize: 2}, mem)
+       require.NoError(t, err)
+
+       sc, err := reader.Schema()
+       assert.NoError(t, err)
+       assert.Truef(t, tbl.Schema().Equal(sc), "expected: %s\ngot: %s", 
tbl.Schema(), sc)
+
+       rr, err := reader.GetRecordReader(context.Background(), nil, nil)
+       assert.NoError(t, err)
+       assert.NotNil(t, rr)
+       defer rr.Release()
+
+       tr := array.NewTableReader(tbl, 2)
+       defer tr.Release()
+
+       rec, err := rr.Read()
+       assert.NoError(t, err)
+       tr.Next()
+       first := tr.Record()
+       first.Retain()
+       defer first.Release()
+
+       assert.Truef(t, array.RecordEqual(tr.Record(), rec), "expected: 
%s\ngot: %s", tr.Record(), rec)
+
+       require.NoError(t, rr.SeekToRow(0))
+       rec, err = rr.Read()
+       assert.NoError(t, err)
+       assert.Truef(t, array.RecordEqual(tr.Record(), rec), "expected: 
%s\ngot: %s", tr.Record(), rec)
+
+       rec, err = rr.Read()
+       assert.NoError(t, err)
+       tr.Next()
+       assert.Truef(t, array.RecordEqual(tr.Record(), rec), "expected: 
%s\ngot: %s", tr.Record(), rec)
+
+       require.NoError(t, rr.SeekToRow(2))
+       rec, err = rr.Read()
+       assert.NoError(t, err)
+       assert.Truef(t, array.RecordEqual(tr.Record(), rec), "expected: 
%s\ngot: %s", tr.Record(), rec)
+
+       require.NoError(t, rr.SeekToRow(4))
+       rec, err = rr.Read()
+       tr.Next()
+       assert.NoError(t, err)
+       assert.Truef(t, array.RecordEqual(tr.Record(), rec), "expected: 
%s\ngot: %s", tr.Record(), rec)
+
+       require.NoError(t, rr.SeekToRow(0))
+       rec, err = rr.Read()
+       assert.NoError(t, err)
+       assert.Truef(t, array.RecordEqual(first, rec), "expected: %s\ngot: %s", 
first, rec)
+}
+
 func TestFileReaderWriterMetadata(t *testing.T) {
        mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
        defer mem.AssertSize(t, 0)

Reply via email to